├── .gitignore ├── DTU-MATLAB ├── BaseEval2Obj_web.m ├── BaseEvalMain_web.m ├── ComputeStat_web.m ├── MaxDistCP.m ├── PointCompareMain.m ├── README.md ├── plyread.m └── reducePts_haa.m ├── LICENSE ├── README.md ├── assets └── overview.png ├── datasets ├── __init__.py ├── bld_train.py ├── data_io.py ├── dtu_yao.py ├── general_eval.py ├── preprocess.py └── tnt_eval.py ├── dynamic_fusion.py ├── finetune.py ├── gipuma.py ├── lists ├── bld │ ├── training_list.txt │ └── validation_list.txt ├── dtu │ ├── test.txt │ ├── train.txt │ ├── trainval.txt │ └── val.txt └── tnt │ ├── adv.txt │ └── inter.txt ├── models ├── FMT.py ├── TransMVSNet.py ├── __init__.py ├── dcn.py ├── module.py └── position_encoding.py ├── requirements.txt ├── scripts ├── test_dtu.sh ├── test_tnt.sh ├── train.sh └── train_bld_fintune.sh ├── test.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | datasets/__pycache__/ 3 | models/__pycache__/ -------------------------------------------------------------------------------- /DTU-MATLAB/BaseEval2Obj_web.m: -------------------------------------------------------------------------------- 1 | function BaseEval2Obj_web(BaseEval,method_string,outputPath) 2 | 3 | if(nargin<3) 4 | outputPath='./'; 5 | end 6 | 7 | % tresshold for coloring alpha channel in the range of 0-10 mm 8 | dist_tresshold=10; 9 | 10 | cSet=BaseEval.cSet; 11 | 12 | Qdata=BaseEval.Qdata; 13 | alpha=min(BaseEval.Ddata,dist_tresshold)/dist_tresshold; 14 | 15 | fid=fopen([outputPath method_string '2Stl_' num2str(cSet) ' .obj'],'w+'); 16 | 17 | for cP=1:size(Qdata,2) 18 | if(BaseEval.DataInMask(cP)) 19 | C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold) 20 | else 21 | C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points outside the mask (which are not included in the analysis) 22 | end 23 | fprintf(fid,'v %f %f %f %f %f %f\n',[Qdata(1,cP) Qdata(2,cP) Qdata(3,cP) C(1) C(2) C(3)]); 24 | end 25 | fclose(fid); 26 | 27 | disp('Data2Stl saved as obj') 28 | 29 | Qstl=BaseEval.Qstl; 30 | fid=fopen([outputPath 'Stl2' method_string '_' num2str(cSet) '.obj'],'w+'); 31 | 32 | alpha=min(BaseEval.Dstl,dist_tresshold)/dist_tresshold; 33 | 34 | for cP=1:size(Qstl,2) 35 | if(BaseEval.StlAbovePlane(cP)) 36 | C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold) 37 | else 38 | C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points below plane (which are not included in the analysis) 39 | end 40 | fprintf(fid,'v %f %f %f %f %f %f\n',[Qstl(1,cP) Qstl(2,cP) Qstl(3,cP) C(1) C(2) C(3)]); 41 | end 42 | fclose(fid); 43 | 44 | disp('Stl2Data saved as obj') -------------------------------------------------------------------------------- /DTU-MATLAB/BaseEvalMain_web.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | format compact 4 | clc 5 | 6 | % script to calculate distances have been measured for all included scans (UsedSets) 7 | % DTU dataset pcd path 8 | dataPath='/0_workspace/MVS/DTU_MATLAB_eval'; 9 | % my result path 10 | plyPath='/0_workspace/MVS/DTU_MATLAB_eval/test_train_syncbn_4src_1021'; 11 | % eval result path 12 | resultsPath='/0_workspace/MVS/DTU_MATLAB_eval/test_train_syncbn_4src_1021_result'; 13 | 14 | method_string='mvsnet'; 15 | light_string=''; % l3 is the setting with all lights on, l7 is randomly sampled between the 7 
settings (index 0-6) 16 | representation_string='Points'; %mvs representation 'Points' or 'Surfaces' 17 | 18 | switch representation_string 19 | case 'Points' 20 | eval_string='_Eval_'; %results naming 21 | settings_string=''; 22 | end 23 | 24 | % get sets used in evaluation 25 | UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118]; 26 | 27 | dst=0.2; %Min dist between points when reducing 28 | 29 | for cIdx=1:length(UsedSets) 30 | %Data set number 31 | cSet = UsedSets(cIdx) 32 | %input data name 33 | %DataInName=[plyPath sprintf('/%sscan%d%s%s.ply',lower(method_string),cSet,light_string,settings_string)] 34 | DataInName=[plyPath sprintf('/mvsnet%03d_l3.ply',cSet)] 35 | 36 | %results name 37 | EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat'] 38 | 39 | %check if file is already computed 40 | if(~exist(EvalName,'file')) 41 | disp(DataInName); 42 | 43 | time=clock;time(4:5), drawnow 44 | 45 | tic 46 | Mesh = plyread(DataInName); 47 | Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]'; 48 | toc 49 | 50 | BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath); 51 | 52 | disp('Saving results'), drawnow 53 | toc 54 | save(EvalName,'BaseEval'); 55 | toc 56 | 57 | % write obj-file of evaluation 58 | % BaseEval2Obj_web(BaseEval,method_string, resultsPath) 59 | % toc 60 | time=clock;time(4:5), drawnow 61 | 62 | BaseEval.MaxDist=20; %outlier threshold of 20 mm 63 | 64 | BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane 65 | BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl=Low(1) & Qfrom(2,:)>=Low(2) & Qfrom(3,:)>=Low(3) &... 18 | Qfrom(1,:)=Low(1) & Qto(2,:)>=Low(2) & Qto(3,:)>=Low(3) &... 25 | Qto(1,:)3)] 49 | end 50 | 51 | -------------------------------------------------------------------------------- /DTU-MATLAB/PointCompareMain.m: -------------------------------------------------------------------------------- 1 | function BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath) 2 | % evaluation function the calculates the distantes from the reference data (stl) to the evalution points (Qdata) and the 3 | % distances from the evaluation points to the reference 4 | 5 | tic 6 | % reduce points 0.2 mm neighbourhood density 7 | Qdata=reducePts_haa(Qdata,dst); 8 | toc 9 | 10 | StlInName=[dataPath '/Points/stl/stl' sprintf('%03d',cSet) '_total.ply']; 11 | 12 | StlMesh = plyread(StlInName); %STL points already reduced 0.2 mm neighbourhood density 13 | Qstl=[StlMesh.vertex.x StlMesh.vertex.y StlMesh.vertex.z]'; 14 | 15 | %Load Mask (ObsMask) and Bounding box (BB) and Resolution (Res) 16 | Margin=10; 17 | MaskName=[dataPath '/ObsMask/ObsMask' num2str(cSet) '_' num2str(Margin) '.mat']; 18 | load(MaskName) 19 | 20 | MaxDist=60; 21 | disp('Computing Data 2 Stl distances') 22 | Ddata = MaxDistCP(Qstl,Qdata,BB,MaxDist); 23 | toc 24 | 25 | disp('Computing Stl 2 Data distances') 26 | Dstl=MaxDistCP(Qdata,Qstl,BB,MaxDist); 27 | disp('Distances computed') 28 | toc 29 | 30 | %use mask 31 | %From Get mask - inverted & modified. 
32 | One=ones(1,size(Qdata,2)); 33 | Qv=(Qdata-BB(1,:)'*One)/Res+1; 34 | Qv=round(Qv); 35 | 36 | Midx1=find(Qv(1,:)>0 & Qv(1,:)<=size(ObsMask,1) & Qv(2,:)>0 & Qv(2,:)<=size(ObsMask,2) & Qv(3,:)>0 & Qv(3,:)<=size(ObsMask,3)); 37 | MidxA=sub2ind(size(ObsMask),Qv(1,Midx1),Qv(2,Midx1),Qv(3,Midx1)); 38 | Midx2=find(ObsMask(MidxA)); 39 | 40 | BaseEval.DataInMask(1:size(Qv,2))=false; 41 | BaseEval.DataInMask(Midx1(Midx2))=true; %If Data is within the mask 42 | 43 | BaseEval.cSet=cSet; 44 | BaseEval.Margin=Margin; %Margin of masks 45 | BaseEval.dst=dst; %Min dist between points when reducing 46 | BaseEval.Qdata=Qdata; %Input data points 47 | BaseEval.Ddata=Ddata; %distance from data to stl 48 | BaseEval.Qstl=Qstl; %Input stl points 49 | BaseEval.Dstl=Dstl; %Distance from the stl to data 50 | 51 | load([dataPath '/ObsMask/Plane' num2str(cSet)],'P') 52 | BaseEval.GroundPlane=P; % Plane used to destinguise which Stl points are 'used' 53 | BaseEval.StlAbovePlane=(P'*[Qstl;ones(1,size(Qstl,2))])>0; %Is stl above 'ground plane' 54 | BaseEval.Time=clock; %Time when computation is finished 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /DTU-MATLAB/README.md: -------------------------------------------------------------------------------- 1 | This folder contains the minimum necessary scripts for running DTU evaluation. You can download the full code [here](http://roboimagedata2.compute.dtu.dk/data/MVS/SampleSet.zip). 2 | 3 | We assume that you have run the [depth fusion](../), and have the final `.ply` files for each scan. 4 | 5 | # Installation 6 | 7 | You need to have `Matlab`. It is a **must**, you cannot use workaround free software like `octave`, I tried but some libraries are not available (e.g. `KDTreeSearch`) 8 | 9 | # Data download 10 | 11 | 1. Download [Points](http://roboimagedata2.compute.dtu.dk/data/MVS/Points.zip) and [SampleSet](http://roboimagedata2.compute.dtu.dk/data/MVS/SampleSet.zip). 12 | 2. Extract the files and arrange in the following structure: 13 | ``` 14 | ├── DTU (can be anywhere) 15 | │ ├── Points 16 | │ └── ObsMask 17 | ``` 18 | Where the `ObsMask` folder is taken from `SampleSet/MVS Data`. 19 | 20 | # Quantitative evaluation 21 | 22 | 1. Change `dataPath`, `plyPath`, `resultsPath` [here](https://github.com/kwea123/CasMVSNet_pl/blob/784cec6635fa819bab0d716c15ba07972c260293/evaluations/dtu/BaseEvalMain_web.m#L8-L10). Be careful that you need to ensure `resultsPath` folder already exists because `Matlab` won't automatically create... 23 | 2. Open `Matlab` and run `BaseEvalMain_web.m`. It will compute the metrics for each scan specified [here](https://github.com/kwea123/CasMVSNet_pl/blob/master/evaluations/dtu/BaseEvalMain_web.m#L23). This step will take **A VERY LONG TIME**: for the point cloud I provide, each of them takes **~20mins** to evaluate... the time depends on the point cloud size. 24 | 3. Set `resultsPath` [here](https://github.com/kwea123/CasMVSNet_pl/blob/master/evaluations/dtu/ComputeStat_web.m#L10) the same as above, then run `ComputeStat_web.m`, it will compute the average metrics of all scans. The final numbers `mean acc` and `mean comp` are the final result, and `overall` is just the average of these two numbers. You can then compare these numbers with papers/other implementations. 25 | 26 | ## Result 27 | Since it takes a long time to evaluate (5 hours using default settings...), I provide the numbers here for comparison with some open source methods: 28 | 29 | | | Acc. | Comp. 
| Overall | resolution | 30 | | --- | --- | --- | --- | --- | 31 | | [MVSNet](https://github.com/YoYo000/MVSNet) | 0.396 | 0.527 | 0.462 | 1600x1184 | 32 | | [MVSNet_pytorch](https://github.com/xy-guo/MVSNet_pytorch) | 0.4492 | 0.3796 | 0.4144 | 1600x1184 | 33 | | *[MVSNet_pytorch](https://github.com/xy-guo/MVSNet_pytorch) | 0.5229 | 0.4514 | 0.4871 | 1152x864 | 34 | | *[R-MVSNet](https://github.com/YoYo000/MVSNet) | 0.383 | 0.452 | 0.4175 | 1600x1184 | 35 | | CasMVSNet paper (fusibile) | 0.325 | 0.385 | 0.355 | 1152x864 | 36 | | *[CasMVSNet](https://github.com/alibaba/cascade-stereo/tree/master/CasMVSNet) | 0.3779 | 0.3645 | 0.3712 | 1152x864 | 37 | | PointMVSNet paper | 0.361 | 0.421 | 0.391 | 1280x960 | 38 | | *[PointMVSNet](https://github.com/callmeray/PointMVSNet) | 0.6344 | 0.6481 | 0.6412 | 1280x960 | 39 | | [UCSNet](https://github.com/touristCheng/UCSNet) paper | 0.330 | 0.372 | 0.351 | 1600x1184 | 40 | | [PVAMVSNet](https://github.com/yhw-yhw/PVAMVSNet) paper | 0.372 | 0.350 | 0.361 | 1600x1184 | 41 | | *This repo(TransMVSNet) | 0.321 | 0.289 | 0.305 | 1152x864 | 42 | 43 | 1. The number of views used is 5 for all methods. 44 | 2. Generally, larger resolution leads to better result. 45 | 46 | -------------------------------------------------------------------------------- /DTU-MATLAB/plyread.m: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | function [Elements,varargout] = plyread(Path,Str) 3 | %PLYREAD Read a PLY 3D data file. 4 | % [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file 5 | % FILENAME and returns a structure DATA. The fields in this structure 6 | % are defined by the PLY header; each element type is a field and each 7 | % element property is a subfield. If the file contains any comments, 8 | % they are returned in a cell string array COMMENTS. 9 | % 10 | % [TRI,PTS] = PLYREAD(FILENAME,'tri') or 11 | % [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex 12 | % and face data into triangular connectivity and vertex arrays. The 13 | % mesh can then be displayed using the TRISURF command. 14 | % 15 | % Note: This function is slow for large mesh files (+50K faces), 16 | % especially when reading data with list type properties. 
17 | % 18 | % Example: 19 | % [Tri,Pts] = PLYREAD('cow.ply','tri'); 20 | % trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3)); 21 | % colormap(gray); axis equal; 22 | % 23 | % See also: PLYWRITE 24 | 25 | % Pascal Getreuer 2004 26 | 27 | [fid,Msg] = fopen(Path,'rt'); % open file in read text mode 28 | 29 | if fid == -1, error(Msg); end 30 | 31 | Buf = fscanf(fid,'%s',1); 32 | if ~strcmp(Buf,'ply') 33 | fclose(fid); 34 | error('Not a PLY file.'); 35 | end 36 | 37 | 38 | %%% read header %%% 39 | 40 | Position = ftell(fid); 41 | Format = ''; 42 | NumComments = 0; 43 | Comments = {}; % for storing any file comments 44 | NumElements = 0; 45 | NumProperties = 0; 46 | Elements = []; % structure for holding the element data 47 | ElementCount = []; % number of each type of element in file 48 | PropertyTypes = []; % corresponding structure recording property types 49 | ElementNames = {}; % list of element names in the order they are stored in the file 50 | PropertyNames = []; % structure of lists of property names 51 | 52 | while 1 53 | Buf = fgetl(fid); % read one line from file 54 | BufRem = Buf; 55 | Token = {}; 56 | Count = 0; 57 | 58 | while ~isempty(BufRem) % split line into tokens 59 | [tmp,BufRem] = strtok(BufRem); 60 | 61 | if ~isempty(tmp) 62 | Count = Count + 1; % count tokens 63 | Token{Count} = tmp; 64 | end 65 | end 66 | 67 | if Count % parse line 68 | switch lower(Token{1}) 69 | case 'format' % read data format 70 | if Count >= 2 71 | Format = lower(Token{2}); 72 | 73 | if Count == 3 & ~strcmp(Token{3},'1.0') 74 | fclose(fid); 75 | error('Only PLY format version 1.0 supported.'); 76 | end 77 | end 78 | case 'comment' % read file comment 79 | NumComments = NumComments + 1; 80 | Comments{NumComments} = ''; 81 | for i = 2:Count 82 | Comments{NumComments} = [Comments{NumComments},Token{i},' ']; 83 | end 84 | case 'element' % element name 85 | if Count >= 3 86 | if isfield(Elements,Token{2}) 87 | fclose(fid); 88 | error(['Duplicate element name, ''',Token{2},'''.']); 89 | end 90 | 91 | NumElements = NumElements + 1; 92 | NumProperties = 0; 93 | Elements = setfield(Elements,Token{2},[]); 94 | PropertyTypes = setfield(PropertyTypes,Token{2},[]); 95 | ElementNames{NumElements} = Token{2}; 96 | PropertyNames = setfield(PropertyNames,Token{2},{}); 97 | CurElement = Token{2}; 98 | ElementCount(NumElements) = str2double(Token{3}); 99 | 100 | if isnan(ElementCount(NumElements)) 101 | fclose(fid); 102 | error(['Bad element definition: ',Buf]); 103 | end 104 | else 105 | error(['Bad element definition: ',Buf]); 106 | end 107 | case 'property' % element property 108 | if ~isempty(CurElement) & Count >= 3 109 | NumProperties = NumProperties + 1; 110 | eval(['tmp=isfield(Elements.',CurElement,',Token{Count});'],... 111 | 'fclose(fid);error([''Error reading property: '',Buf])'); 112 | 113 | if tmp 114 | error(['Duplicate property name, ''',CurElement,'.',Token{2},'''.']); 115 | end 116 | 117 | % add property subfield to Elements 118 | eval(['Elements.',CurElement,'.',Token{Count},'=[];'], ... 119 | 'fclose(fid);error([''Error reading property: '',Buf])'); 120 | % add property subfield to PropertyTypes and save type 121 | eval(['PropertyTypes.',CurElement,'.',Token{Count},'={Token{2:Count-1}};'], ... 122 | 'fclose(fid);error([''Error reading property: '',Buf])'); 123 | % record property name order 124 | eval(['PropertyNames.',CurElement,'{NumProperties}=Token{Count};'], ... 
125 | 'fclose(fid);error([''Error reading property: '',Buf])'); 126 | else 127 | fclose(fid); 128 | 129 | if isempty(CurElement) 130 | error(['Property definition without element definition: ',Buf]); 131 | else 132 | error(['Bad property definition: ',Buf]); 133 | end 134 | end 135 | case 'end_header' % end of header, break from while loop 136 | break; 137 | end 138 | end 139 | end 140 | 141 | %%% set reading for specified data format %%% 142 | 143 | if isempty(Format) 144 | warning('Data format unspecified, assuming ASCII.'); 145 | Format = 'ascii'; 146 | end 147 | 148 | switch Format 149 | case 'ascii' 150 | Format = 0; 151 | case 'binary_little_endian' 152 | Format = 1; 153 | case 'binary_big_endian' 154 | Format = 2; 155 | otherwise 156 | fclose(fid); 157 | error(['Data format ''',Format,''' not supported.']); 158 | end 159 | 160 | if ~Format 161 | Buf = fscanf(fid,'%f'); % read the rest of the file as ASCII data 162 | BufOff = 1; 163 | else 164 | % reopen the file in read binary mode 165 | fclose(fid); 166 | 167 | if Format == 1 168 | fid = fopen(Path,'r','ieee-le.l64'); % little endian 169 | else 170 | fid = fopen(Path,'r','ieee-be.l64'); % big endian 171 | end 172 | 173 | % find the end of the header again (using ftell on the old handle doesn't give the correct position) 174 | BufSize = 8192; 175 | Buf = [blanks(10),char(fread(fid,BufSize,'uchar')')]; 176 | i = []; 177 | tmp = -11; 178 | 179 | while isempty(i) 180 | i = findstr(Buf,['end_header',13,10]); % look for end_header + CR/LF 181 | i = [i,findstr(Buf,['end_header',10])]; % look for end_header + LF 182 | 183 | if isempty(i) 184 | tmp = tmp + BufSize; 185 | Buf = [Buf(BufSize+1:BufSize+10),char(fread(fid,BufSize,'uchar')')]; 186 | end 187 | end 188 | 189 | % seek to just after the line feed 190 | fseek(fid,i + tmp + 11 + (Buf(i + 10) == 13),-1); 191 | end 192 | 193 | 194 | %%% read element data %%% 195 | 196 | % PLY and MATLAB data types (for fread) 197 | PlyTypeNames = {'char','uchar','short','ushort','int','uint','float','double', ... 
198 | 'char8','uchar8','short16','ushort16','int32','uint32','float32','double64'}; 199 | MatlabTypeNames = {'schar','uchar','int16','uint16','int32','uint32','single','double'}; 200 | SizeOf = [1,1,2,2,4,4,4,8]; % size in bytes of each type 201 | 202 | for i = 1:NumElements 203 | % get current element property information 204 | eval(['CurPropertyNames=PropertyNames.',ElementNames{i},';']); 205 | eval(['CurPropertyTypes=PropertyTypes.',ElementNames{i},';']); 206 | NumProperties = size(CurPropertyNames,2); 207 | 208 | % fprintf('Reading %s...\n',ElementNames{i}); 209 | 210 | if ~Format %%% read ASCII data %%% 211 | for j = 1:NumProperties 212 | Token = getfield(CurPropertyTypes,CurPropertyNames{j}); 213 | 214 | if strcmpi(Token{1},'list') 215 | Type(j) = 1; 216 | else 217 | Type(j) = 0; 218 | end 219 | end 220 | 221 | % parse buffer 222 | if ~any(Type) 223 | % no list types 224 | Data = reshape(Buf(BufOff:BufOff+ElementCount(i)*NumProperties-1),NumProperties,ElementCount(i))'; 225 | BufOff = BufOff + ElementCount(i)*NumProperties; 226 | else 227 | ListData = cell(NumProperties,1); 228 | 229 | for k = 1:NumProperties 230 | ListData{k} = cell(ElementCount(i),1); 231 | end 232 | 233 | % list type 234 | for j = 1:ElementCount(i) 235 | for k = 1:NumProperties 236 | if ~Type(k) 237 | Data(j,k) = Buf(BufOff); 238 | BufOff = BufOff + 1; 239 | else 240 | tmp = Buf(BufOff); 241 | ListData{k}{j} = Buf(BufOff+(1:tmp))'; 242 | BufOff = BufOff + tmp + 1; 243 | end 244 | end 245 | end 246 | end 247 | else %%% read binary data %%% 248 | % translate PLY data type names to MATLAB data type names 249 | ListFlag = 0; % = 1 if there is a list type 250 | SameFlag = 1; % = 1 if all types are the same 251 | 252 | for j = 1:NumProperties 253 | Token = getfield(CurPropertyTypes,CurPropertyNames{j}); 254 | 255 | if ~strcmp(Token{1},'list') % non-list type 256 | tmp = rem(strmatch(Token{1},PlyTypeNames,'exact')-1,8)+1; 257 | 258 | if ~isempty(tmp) 259 | TypeSize(j) = SizeOf(tmp); 260 | Type{j} = MatlabTypeNames{tmp}; 261 | TypeSize2(j) = 0; 262 | Type2{j} = ''; 263 | 264 | SameFlag = SameFlag & strcmp(Type{1},Type{j}); 265 | else 266 | fclose(fid); 267 | error(['Unknown property data type, ''',Token{1},''', in ', ... 268 | ElementNames{i},'.',CurPropertyNames{j},'.']); 269 | end 270 | else % list type 271 | if length(Token) == 3 272 | ListFlag = 1; 273 | SameFlag = 0; 274 | tmp = rem(strmatch(Token{2},PlyTypeNames,'exact')-1,8)+1; 275 | tmp2 = rem(strmatch(Token{3},PlyTypeNames,'exact')-1,8)+1; 276 | 277 | if ~isempty(tmp) & ~isempty(tmp2) 278 | TypeSize(j) = SizeOf(tmp); 279 | Type{j} = MatlabTypeNames{tmp}; 280 | TypeSize2(j) = SizeOf(tmp2); 281 | Type2{j} = MatlabTypeNames{tmp2}; 282 | else 283 | fclose(fid); 284 | error(['Unknown property data type, ''list ',Token{2},' ',Token{3},''', in ', ... 
285 | ElementNames{i},'.',CurPropertyNames{j},'.']); 286 | end 287 | else 288 | fclose(fid); 289 | error(['Invalid list syntax in ',ElementNames{i},'.',CurPropertyNames{j},'.']); 290 | end 291 | end 292 | end 293 | 294 | % read file 295 | if ~ListFlag 296 | if SameFlag 297 | % no list types, all the same type (fast) 298 | Data = fread(fid,[NumProperties,ElementCount(i)],Type{1})'; 299 | else 300 | % no list types, mixed type 301 | Data = zeros(ElementCount(i),NumProperties); 302 | 303 | for j = 1:ElementCount(i) 304 | for k = 1:NumProperties 305 | Data(j,k) = fread(fid,1,Type{k}); 306 | end 307 | end 308 | end 309 | else 310 | ListData = cell(NumProperties,1); 311 | 312 | for k = 1:NumProperties 313 | ListData{k} = cell(ElementCount(i),1); 314 | end 315 | 316 | if NumProperties == 1 317 | BufSize = 512; 318 | SkipNum = 4; 319 | j = 0; 320 | 321 | % list type, one property (fast if lists are usually the same length) 322 | while j < ElementCount(i) 323 | Position = ftell(fid); 324 | % read in BufSize count values, assuming all counts = SkipNum 325 | [Buf,BufSize] = fread(fid,BufSize,Type{1},SkipNum*TypeSize2(1)); 326 | Miss = find(Buf ~= SkipNum); % find first count that is not SkipNum 327 | fseek(fid,Position + TypeSize(1),-1); % seek back to after first count 328 | 329 | if isempty(Miss) % all counts are SkipNum 330 | Buf = fread(fid,[SkipNum,BufSize],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; 331 | fseek(fid,-TypeSize(1),0); % undo last skip 332 | 333 | for k = 1:BufSize 334 | ListData{1}{j+k} = Buf(k,:); 335 | end 336 | 337 | j = j + BufSize; 338 | BufSize = floor(1.5*BufSize); 339 | else 340 | if Miss(1) > 1 % some counts are SkipNum 341 | Buf2 = fread(fid,[SkipNum,Miss(1)-1],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; 342 | 343 | for k = 1:Miss(1)-1 344 | ListData{1}{j+k} = Buf2(k,:); 345 | end 346 | 347 | j = j + k; 348 | end 349 | 350 | % read in the list with the missed count 351 | SkipNum = Buf(Miss(1)); 352 | j = j + 1; 353 | ListData{1}{j} = fread(fid,[1,SkipNum],Type2{1}); 354 | BufSize = ceil(0.6*BufSize); 355 | end 356 | end 357 | else 358 | % list type(s), multiple properties (slow) 359 | Data = zeros(ElementCount(i),NumProperties); 360 | 361 | for j = 1:ElementCount(i) 362 | for k = 1:NumProperties 363 | if isempty(Type2{k}) 364 | Data(j,k) = fread(fid,1,Type{k}); 365 | else 366 | tmp = fread(fid,1,Type{k}); 367 | ListData{k}{j} = fread(fid,[1,tmp],Type2{k}); 368 | end 369 | end 370 | end 371 | end 372 | end 373 | end 374 | 375 | % put data into Elements structure 376 | for k = 1:NumProperties 377 | if (~Format & ~Type(k)) | (Format & isempty(Type2{k})) 378 | eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=Data(:,k);']); 379 | else 380 | eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=ListData{k};']); 381 | end 382 | end 383 | end 384 | 385 | clear Data ListData; 386 | fclose(fid); 387 | 388 | if (nargin > 1 & strcmpi(Str,'Tri')) | nargout > 2 389 | % find vertex element field 390 | Name = {'vertex','Vertex','point','Point','pts','Pts'}; 391 | Names = []; 392 | 393 | for i = 1:length(Name) 394 | if any(strcmp(ElementNames,Name{i})) 395 | Names = getfield(PropertyNames,Name{i}); 396 | Name = Name{i}; 397 | break; 398 | end 399 | end 400 | 401 | if any(strcmp(Names,'x')) & any(strcmp(Names,'y')) & any(strcmp(Names,'z')) 402 | eval(['varargout{1}=[Elements.',Name,'.x,Elements.',Name,'.y,Elements.',Name,'.z];']); 403 | else 404 | varargout{1} = zeros(1,3); 405 | end 406 | 407 | varargout{2} = Elements; 408 | varargout{3} = Comments; 409 | Elements 
= []; 410 | 411 | % find face element field 412 | Name = {'face','Face','poly','Poly','tri','Tri'}; 413 | Names = []; 414 | 415 | for i = 1:length(Name) 416 | if any(strcmp(ElementNames,Name{i})) 417 | Names = getfield(PropertyNames,Name{i}); 418 | Name = Name{i}; 419 | break; 420 | end 421 | end 422 | 423 | if ~isempty(Names) 424 | % find vertex indices property subfield 425 | PropertyName = {'vertex_indices','vertex_indexes','vertex_index','indices','indexes'}; 426 | 427 | for i = 1:length(PropertyName) 428 | if any(strcmp(Names,PropertyName{i})) 429 | PropertyName = PropertyName{i}; 430 | break; 431 | end 432 | end 433 | 434 | if ~iscell(PropertyName) 435 | % convert face index lists to triangular connectivity 436 | eval(['FaceIndices=varargout{2}.',Name,'.',PropertyName,';']); 437 | N = length(FaceIndices); 438 | Elements = zeros(N*2,3); 439 | Extra = 0; 440 | 441 | for k = 1:N 442 | Elements(k,:) = FaceIndices{k}(1:3); 443 | 444 | for j = 4:length(FaceIndices{k}) 445 | Extra = Extra + 1; 446 | Elements(N + Extra,:) = [Elements(k,[1,j-1]),FaceIndices{k}(j)]; 447 | end 448 | end 449 | Elements = Elements(1:N+Extra,:) + 1; 450 | end 451 | end 452 | else 453 | varargout{1} = Comments; 454 | end -------------------------------------------------------------------------------- /DTU-MATLAB/reducePts_haa.m: -------------------------------------------------------------------------------- 1 | function [ptsOut,indexSet] = reducePts_haa(pts, dst) 2 | 3 | %Reduces a point set, pts, in a stochastic manner, such that the minimum sdistance 4 | % between points is 'dst'. Writen by abd, edited by haa, then by raje 5 | 6 | nPoints=size(pts,2); 7 | 8 | indexSet=true(nPoints,1); 9 | RandOrd=randperm(nPoints); 10 | 11 | %tic 12 | NS = KDTreeSearcher(pts'); 13 | %toc 14 | 15 | % search the KNTree for close neighbours in a chunk-wise fashion to save memory if point cloud is really big 16 | Chunks=1:min(4e6,nPoints-1):nPoints; 17 | Chunks(end)=nPoints; 18 | 19 | for cChunk=1:(length(Chunks)-1) 20 | Range=Chunks(cChunk):Chunks(cChunk+1); 21 | 22 | idx = rangesearch(NS,pts(:,RandOrd(Range))',dst); 23 | 24 | for i = 1:size(idx,1) 25 | id =RandOrd(i-1+Chunks(cChunk)); 26 | if (indexSet(id)) 27 | indexSet(idx{i}) = 0; 28 | indexSet(id) = 1; 29 | end 30 | end 31 | end 32 | 33 | ptsOut = pts(:,indexSet); 34 | 35 | disp(['downsample factor: ' num2str(nPoints/sum(indexSet))]); 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2022 Megvii Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # (CVPR2022) TransMVSNet: Global Context-aware Multi-view Stereo Network with Transformers 4 | 5 | 6 | ## [Paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Ding_TransMVSNet_Global_Context-Aware_Multi-View_Stereo_Network_With_Transformers_CVPR_2022_paper.pdf) | [Project Page](https://dingyikang.github.io/transmvsnet.github.io/) | [Arxiv](https://arxiv.org/abs/2111.14600/) | [Models](https://drive.google.com/drive/folders/1ZJ9bx9qZENEoXv5i5izKCNszlaCNBMkJ?usp=sharing/) 7 | 8 | **Tips**: If you meet any problems when reproduce our results, please contact Yikang Ding (dyk20@mails.tsinghua.edu.cn). We are happy to help you solve the problems and share our experience. 9 | 10 | 11 | ## ⚠ Change log 12 | * 09.2022: Add more detailed instruction of how to reproduce the reported results (see [testing-on-dtu](#-testing-on-dtu)). 13 | * 09.2022: Fix the bugs in MATLAB evaluation code (remove the debug code). 14 | * 09.2022: Fix the bug of default fuse parameters of gipuma, which could have a great impact on the final results. 15 | * 09.2022: Update the website link and instruction of installing gipuma, which would affect the fusion quality. 16 | 17 | 18 | ## 📔 Introduction 19 | In this paper, we present TransMVSNet, based on our exploration of feature matching in multi-view stereo (MVS). We analogize MVS back to its nature of a feature matching task and therefore propose a powerful Feature Matching Transformer (FMT) to leverage intra- (self-) and inter- (cross-) attention to aggregate long-range context information within and across images. To facilitate a better adaptation of the FMT, we leverage an Adaptive Receptive Field (ARF) module to ensure a smooth transit in scopes of features and bridge different stages with a feature pathway to pass transformed features and gradients across different scales. In addition, we apply pair-wise feature correlation to measure similarity between features, and adopt ambiguity-reducing focal loss to strengthen the supervision. To the best of our knowledge, TransMVSNet is the first attempt to leverage Transformer into the task of MVS. As a result, our method achieves state-of-the-art performance on DTU dataset, Tanks and Temples benchmark, and BlendedMVS dataset. 20 | ![](assets/overview.png) 21 | 22 | 23 | ## 🔧 Installation 24 | Our code is tested with Python==3.6/3.7/3.8, PyTorch==1.6.0/1.7.0/1.9.0, CUDA==10.2 on Ubuntu-18.04 with NVIDIA GeForce RTX 2080Ti. Similar or higher version should work well. 25 | 26 | To use TransMVSNet, clone this repo: 27 | ``` 28 | git clone https://github.com/MegviiRobot/TransMVSNet.git 29 | cd TransMVSNet 30 | ``` 31 | 32 | We highly recommend using [Anaconda](https://www.anaconda.com/) to manage the python environment: 33 | ``` 34 | conda create -n transmvsnet python=3.6 35 | conda activate transmvsnet 36 | pip install -r requirements.txt 37 | ``` 38 | 44 | We also recommend using apex, you can install apex from the [official repo](https://www.github.com/nvidia/apex). 
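A typical from-source install of apex looks like the sketch below; the exact pip flags depend on your pip and CUDA versions, so follow the apex README if the CUDA extension build fails:
```
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir \
    --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```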
45 | 46 | 47 | ## 📦 Data preparation 48 | In TransMVSNet, we mainly use [DTU](https://roboimagedata.compute.dtu.dk/), [BlendedMVS](https://github.com/YoYo000/BlendedMVS/) and [Tanks and Temples](https://www.tanksandtemples.org/) to train and evaluate our models. You can prepare the corresponding data by following the instructions below. 49 | 50 | ### ✔ DTU 51 | For DTU training set, you can download the preprocessed [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view) 52 | and [Depths_raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip) 53 | (both from [Original MVSNet](https://github.com/YoYo000/MVSNet)), and unzip them to construct a dataset folder like: 54 | ``` 55 | dtu_training 56 | ├── Cameras 57 | ├── Depths 58 | ├── Depths_raw 59 | └── Rectified 60 | ``` 61 | For DTU testing set, you can download the preprocessed [DTU testing data](https://drive.google.com/open?id=135oKPefcPTsdtLRzoDAQtPpHuoIrpRI_) (from [Original MVSNet](https://github.com/YoYo000/MVSNet)) and unzip it as the test data folder, which should contain one ``cams`` folder, one ``images`` folder and one ``pair.txt`` file. 62 | 63 | ### ✔ BlendedMVS 64 | We use the [low-res set](https://1drv.ms/u/s!Ag8Dbz2Aqc81gVDgxb8MDGgoV74S?e=hJKlvV) of BlendedMVS dataset for both training and testing. You can download the [low-res set](https://1drv.ms/u/s!Ag8Dbz2Aqc81gVDgxb8MDGgoV74S?e=hJKlvV) from [orignal BlendedMVS](https://github.com/YoYo000/BlendedMVS) and unzip it to form the dataset folder like below: 65 | 66 | ``` 67 | BlendedMVS 68 | ├── 5a0271884e62597cdee0d0eb 69 | │ ├── blended_images 70 | │ ├── cams 71 | │ └── rendered_depth_maps 72 | ├── 59338e76772c3e6384afbb15 73 | ├── 59f363a8b45be22330016cad 74 | ├── ... 75 | ├── all_list.txt 76 | ├── training_list.txt 77 | └── validation_list.txt 78 | ``` 79 | 80 | ### ✔ Tanks and Temples 81 | Download our preprocessed [Tanks and Temples dataset](https://drive.google.com/file/d/1IHG5GCJK1pDVhDtTHFS3sY-ePaK75Qzg/view?usp=sharing) and unzip it to form the dataset folder like below: 82 | ``` 83 | tankandtemples 84 | ├── advanced 85 | │ ├── Auditorium 86 | │ ├── Ballroom 87 | │ ├── ... 88 | │ └── Temple 89 | └── intermediate 90 | ├── Family 91 | ├── Francis 92 | ├── ... 93 | └── Train 94 | ``` 95 | 96 | ## 📈 Training 97 | 98 | ### ✔ Training on DTU 99 | Set the configuration in ``scripts/train.sh``: 100 | * Set ``MVS_TRAINING`` as the path of DTU training set. 101 | * Set ``LOG_DIR`` to save the checkpoints. 102 | * Change ``NGPUS`` to suit your device. 103 | * We use ``torch.distributed.launch`` by default. 104 | 105 | To train your own model, just run: 106 | ``` 107 | bash scripts/train.sh 108 | ``` 109 | You can conveniently modify more hyper-parameters in ``scripts/train.sh`` according to the argparser in ``train.py``, such as ``summary_freq``, ``save_freq``, and so on. 110 | 111 | ### ✔ Finetune on BlendedMVS 112 | For a fair comparison with other SOTA methods on Tanks and Temples benchmark, we finetune our model on BlendedMVS dataset after training on DTU dataset. 113 | 114 | Set the configuration in ``scripts/train_bld_fintune.sh``: 115 | * Set ``MVS_TRAINING`` as the path of BlendedMVS dataset. 116 | * Set ``LOG_DIR`` to save the checkpoints and training log. 117 | * Set ``CKPT`` as path of the loaded ``.ckpt`` which is trained on DTU dataset. 
118 | 119 | To finetune your own model, just run: 120 | ``` 121 | bash scripts/train_bld_fintune.sh 122 | ``` 123 | 124 | ## 📊 Testing 125 | For easy testing, you can download our [pre-trained models](https://drive.google.com/drive/folders/1ZJ9bx9qZENEoXv5i5izKCNszlaCNBMkJ?usp=sharing) and put them in `checkpoints` folder, or use your own models and follow the instruction below. 126 | 127 | ### ✔ Testing on DTU 128 | 129 | **Important Tips:** to reproduce our reported results, you need to: 130 | * compile and install the modified `gipuma` from [Yao Yao](https://github.com/YoYo000/fusibile) as introduced below 131 | * use the latest code as we have fixed tiny bugs and updated the fusion parameters 132 | * make sure you install the right version of python and pytorch, use some old versions would throw warnings of the default action of `align_corner` in several functions, which would affect the final results 133 | * be aware that we only test the code on 2080Ti and Ubuntu 18.04, other devices and systems might get slightly different results 134 | * make sure that you use the `model_dtu.ckpt` for testing 135 | 136 | 137 | To start testing, set the configuration in ``scripts/test_dtu.sh``: 138 | * Set ``TESTPATH`` as the path of DTU testing set. 139 | * Set ``TESTLIST`` as the path of test list (.txt file). 140 | * Set ``CKPT_FILE`` as the path of the model weights. 141 | * Set ``OUTDIR`` as the path to save results. 142 | 143 | Run: 144 | ``` 145 | bash scripts/test_dtu.sh 146 | ``` 147 | **Note:** You can use the `gipuma` fusion method or `normal` fusion method to fuse the point clouds. **In our experiments, we use the `gipuma` fusion method by default**. 148 | With using the uploaded ckpt and latest code, these two fusion methods would get the below results: 149 | | Fuse | Overall | 150 | | --- | --- | 151 | | gipuma | 0.304 | 152 | | normal | 0.314 | 153 | 154 | 155 | To install the `gipuma`, clone the modified version from [Yao Yao](https://github.com/YoYo000/fusibile). 156 | Modify the line-10 in `CMakeLists.txt` to suit your GPUs. Othervise you would meet warnings when compile it, which would lead to failure and get 0 points in fused point cloud. For example, if you use 2080Ti GPU, modify the line-10 to: 157 | ``` 158 | set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-O3 --use_fast_math --ptxas-options=-v -std=c++11 --compiler-options -Wall -gencode arch=compute_70,code=sm_70) 159 | ``` 160 | If you use other kind of GPUs, please modify the arch code to suit your device (`arch=compute_XX,code=sm_XX`). 161 | Then install it by `cmake .` and `make`, which will generate the executable file at `FUSIBILE_EXE_PATH`. 162 | Please note 163 | 164 | 165 | 166 | For quantitative evaluation on DTU dataset, download [SampleSet](http://roboimagedata.compute.dtu.dk/?page_id=36) and [Points](http://roboimagedata.compute.dtu.dk/?page_id=36). Unzip them and place `Points` folder in `SampleSet/MVS Data/`. The structure looks like: 167 | ``` 168 | SampleSet 169 | ├──MVS Data 170 | └──Points 171 | ``` 172 | In ``DTU-MATLAB/BaseEvalMain_web.m``, set `dataPath` as path to `SampleSet/MVS Data/`, `plyPath` as directory that stores the reconstructed point clouds and `resultsPath` as directory to store the evaluation results. Then run ``DTU-MATLAB/BaseEvalMain_web.m`` in matlab. 173 | 174 | We also upload our final point cloud results to [here](https://drive.google.com/drive/folders/1a3b0tDoPj9y7GMhOSjb5TBRq7ahYjp4f?usp=sharing). 
You can easily download them and evaluate them using the `MATLAB` scripts, the results look like: 175 | 176 | 177 | | Acc. (mm) | Comp. (mm) | Overall (mm) | 178 | |-----------|------------|--------------| 179 | | 0.321 | 0.289 | 0.305 | 180 | 181 | 182 | 183 | 184 | ### ✔ Testing on Tanks and Temples 185 | We recommend using the finetuned models (``model_bld.ckpt``) to test on Tanks and Temples benchmark. 186 | 187 | Similarly, set the configuration in ``scripts/test_tnt.sh``: 188 | * Set ``TESTPATH`` as the path of intermediate set or advanced set. 189 | * Set ``TESTLIST`` as the path of test list (.txt file). 190 | * Set ``CKPT_FILE`` as the path of the model weights. 191 | * Set ``OUTDIR`` as the path to save resutls. 192 | 193 | To generate point cloud results, just run: 194 | ``` 195 | bash scripts/test_tnt.sh 196 | ``` 197 | Note that: 198 | * The parameters of point cloud fusion have not been studied thoroughly and the performance can be better if cherry-picking more appropriate thresholds for each of the scenes. 199 | * The dynamic fusion code is borrowed from [AA-RMVSNet](https://github.com/QT-Zhu/AA-RMVSNet). 200 | 201 | For quantitative evaluation, you can upload your point clouds to [Tanks and Temples benchmark](https://www.tanksandtemples.org/). 202 | 203 | ## 🔗 Citation 204 | 205 | ```bibtex 206 | @inproceedings{ding2022transmvsnet, 207 | title={Transmvsnet: Global context-aware multi-view stereo network with transformers}, 208 | author={Ding, Yikang and Yuan, Wentao and Zhu, Qingtian and Zhang, Haotian and Liu, Xiangyue and Wang, Yuanjiang and Liu, Xiao}, 209 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 210 | pages={8585--8594}, 211 | year={2022} 212 | } 213 | ``` 214 | 215 | ## 📌 Acknowledgments 216 | We borrow some code from [CasMVSNet](https://github.com/alibaba/cascade-stereo/tree/master/CasMVSNet), [LoFTR](https://github.com/zju3dv/LoFTR) and [AA-RMVSNet](https://github.com/QT-Zhu/AA-RMVSNet). We thank the authors for releasing the source code. 
217 | -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megvii-research/TransMVSNet/16100feb97309a846f73af4be9ea4e2e833baf5d/assets/overview.png -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | # find the dataset definition by name, for example dtu_yao (dtu_yao.py) 5 | def find_dataset_def(dataset_name): 6 | module_name = 'datasets.{}'.format(dataset_name) 7 | module = importlib.import_module(module_name) 8 | return getattr(module, "MVSDataset") 9 | -------------------------------------------------------------------------------- /datasets/bld_train.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | from datasets.data_io import * 6 | from datasets.preprocess import * 7 | 8 | class MVSDataset(Dataset): 9 | def __init__(self, datapath, listfile, mode, nviews, ndepths=192, interval_scale=1.0, 10 | origin_size=False, light_idx=-1, image_scale=1.0, **kwargs): 11 | super(MVSDataset, self).__init__() 12 | self.datapath = datapath 13 | self.listfile = listfile 14 | self.mode = mode 15 | self.nviews = nviews 16 | self.ndepths = ndepths 17 | self.interval_scale = interval_scale 18 | self.origin_size = origin_size 19 | self.light_idx=light_idx 20 | self.image_scale = image_scale # use to resize image 21 | 22 | print('dataset: origin_size {}, light_idx:{}, image_scale:{}'.format( 23 | self.origin_size, self.light_idx, self.image_scale)) 24 | 25 | assert self.mode in ["train", "val", "test"] 26 | self.metas = self.build_list() 27 | 28 | def build_list(self): 29 | metas = [] 30 | with open(self.listfile) as f: 31 | scans = f.readlines() 32 | scans = [line.rstrip() for line in scans] 33 | 34 | # scans 35 | for scan in scans: 36 | pair_file = "{}/cams/pair.txt".format(scan) 37 | # read the pair file 38 | with open(os.path.join(self.datapath, pair_file)) as f: 39 | num_viewpoint = int(f.readline()) 40 | for view_idx in range(num_viewpoint): 41 | ref_view = int(f.readline().rstrip()) 42 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 43 | if len(src_views) < self.nviews -1: 44 | print('less ref_view small {}'.format(self.nviews-1)) 45 | continue 46 | metas.append((scan, ref_view, src_views,)) 47 | print("dataset", self.mode, "metas:", len(metas)) 48 | return metas 49 | 50 | def __len__(self): 51 | return len(self.metas) 52 | 53 | def read_cam_file(self, filename): 54 | with open(filename) as f: 55 | lines = f.readlines() 56 | lines = [line.rstrip() for line in lines] 57 | # extrinsics: line [1,5), 4x4 matrix 58 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 59 | # intrinsics: line [7-10), 3x3 matrix 60 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 61 | intrinsics[:2, :] /= 4.0 62 | 63 | if self.image_scale != 1.0: # origin: 1.0 64 | intrinsics[:2, :] *= self.image_scale 65 | 66 | depth_min = float(lines[11].split()[0]) 67 | # depth_interval = float(lines[11].split()[1]) * self.interval_scale 68 | depth_max = float(lines[11].split()[-1]) 69 | depth_interval = float(depth_max - depth_min) / self.ndepths 70 | return intrinsics, 
extrinsics, depth_min, depth_interval 71 | 72 | def read_img(self, filename): 73 | img = Image.open(filename) 74 | np_img = np.array(img, dtype=np.float32) / 255. # origin version on 2020/02/20 75 | return np_img 76 | 77 | 78 | def center_img(self, img): # this is very important for batch normalization 79 | img = img.astype(np.float32) 80 | var = np.var(img, axis=(0,1), keepdims=True) 81 | mean = np.mean(img, axis=(0,1), keepdims=True) 82 | return (img - mean) / (np.sqrt(var) + 0.00000001) 83 | 84 | def read_depth(self, filename): 85 | depth_image = np.array(read_pfm(filename)[0], dtype=np.float32) 86 | return depth_image 87 | 88 | 89 | def __getitem__(self, idx): 90 | 91 | #print('idx: {}, flip_falg {}'.format(idx, flip_flag)) 92 | meta = self.metas[idx] 93 | scan, ref_view, src_views = meta 94 | # use only the reference view and first nviews-1 source views 95 | view_ids = [ref_view] + src_views[:self.nviews - 1] 96 | 97 | imgs = [] 98 | mask = None 99 | depth = None 100 | depth_values = None 101 | proj_matrices = [] 102 | for i, vid in enumerate(view_ids): 103 | # NOTE that the id in image file names is from 000000000 104 | img_filename = os.path.join(self.datapath, 105 | '{}/blended_images/{:0>8}.jpg'.format(scan, vid)) 106 | 107 | proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid)) 108 | depth_filename = os.path.join(self.datapath, '{}/rendered_depth_maps/{:0>8}.pfm'.format(scan, vid)) 109 | 110 | if i == 0: 111 | depth_name = depth_filename 112 | 113 | imgs.append(self.read_img(img_filename)) 114 | intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename) 115 | 116 | # multiply intrinsics and extrinsics to get projection matrix 117 | proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) # 118 | proj_mat[0, :4, :4] = extrinsics 119 | proj_mat[1, :3, :3] = intrinsics 120 | 121 | proj_matrices.append(proj_mat) 122 | 123 | if i == 0: # reference view 124 | depth_values = np.arange(depth_min, depth_interval * (self.ndepths-0.5) + depth_min, depth_interval, 125 | dtype=np.float32) # the set is [) 126 | 127 | depth_values = np.concatenate((depth_values,depth_values[::-1]),axis=0) 128 | 129 | depth_end = depth_interval * (self.ndepths-1) + depth_min 130 | 131 | depth = self.read_depth(depth_filename) 132 | mask_ = np.array((depth >= depth_min) & (depth <= depth_end), dtype=np.float32) 133 | mask_ms = { 134 | "stage1": cv2.resize(mask_, (mask_.shape[1]//4, mask_.shape[0]//4), interpolation=cv2.INTER_NEAREST), 135 | "stage2": cv2.resize(mask_, (mask_.shape[1]//2, mask_.shape[0]//2), interpolation=cv2.INTER_NEAREST), 136 | "stage3": mask_, 137 | } 138 | h,w = depth.shape 139 | depth_ms = { 140 | "stage1": cv2.resize(depth, (w//4, h//4), interpolation=cv2.INTER_NEAREST), 141 | "stage2": cv2.resize(depth, (w//2, h//2), interpolation=cv2.INTER_NEAREST), 142 | "stage3": depth, 143 | } 144 | depth_max = depth_interval * self.ndepths + depth_min 145 | depth_values = np.arange(depth_min, depth_max, depth_interval, dtype=np.float32) 146 | 147 | imgs = np.stack(imgs).transpose([0, 3, 1, 2]) 148 | 149 | proj_matrices = np.stack(proj_matrices) 150 | stage2_pjmats = proj_matrices.copy() 151 | stage2_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2 152 | stage3_pjmats = proj_matrices.copy() 153 | stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4 154 | 155 | proj_matrices_ms = { 156 | "stage1": proj_matrices, 157 | "stage2": stage2_pjmats, 158 | "stage3": stage3_pjmats 159 | } 160 | 161 | return {"imgs": imgs, 162 | 
"proj_matrices": proj_matrices_ms, 163 | "depth": depth_ms, 164 | "depth_values": depth_values, # generate depth index 165 | "mask": mask_ms, 166 | "depth_interval": depth_interval, 167 | 'name':depth_name,} 168 | -------------------------------------------------------------------------------- /datasets/data_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import sys 4 | 5 | 6 | def read_pfm(filename): 7 | file = open(filename, 'rb') 8 | color = None 9 | width = None 10 | height = None 11 | scale = None 12 | endian = None 13 | 14 | header = file.readline().decode('utf-8').rstrip() 15 | if header == 'PF': 16 | color = True 17 | elif header == 'Pf': 18 | color = False 19 | else: 20 | raise Exception('Not a PFM file.') 21 | 22 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 23 | if dim_match: 24 | width, height = map(int, dim_match.groups()) 25 | else: 26 | raise Exception('Malformed PFM header.') 27 | 28 | scale = float(file.readline().rstrip()) 29 | if scale < 0: # little-endian 30 | endian = '<' 31 | scale = -scale 32 | else: 33 | endian = '>' # big-endian 34 | 35 | data = np.fromfile(file, endian + 'f') 36 | shape = (height, width, 3) if color else (height, width) 37 | 38 | data = np.reshape(data, shape) 39 | data = np.flipud(data) 40 | file.close() 41 | return data, scale 42 | 43 | 44 | def save_pfm(filename, image, scale=1): 45 | file = open(filename, "wb") 46 | color = None 47 | 48 | image = np.flipud(image) 49 | 50 | if image.dtype.name != 'float32': 51 | raise Exception('Image dtype must be float32.') 52 | 53 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 54 | color = True 55 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 56 | color = False 57 | else: 58 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 59 | 60 | file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8')) 61 | file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8')) 62 | 63 | endian = image.dtype.byteorder 64 | 65 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 66 | scale = -scale 67 | 68 | file.write(('%f\n' % scale).encode('utf-8')) 69 | 70 | image.tofile(file) 71 | file.close() 72 | 73 | import random, cv2 74 | class RandomCrop(object): 75 | def __init__(self, CropSize=0.1): 76 | self.CropSize = CropSize 77 | 78 | def __call__(self, image, normal): 79 | h, w = normal.shape[:2] 80 | img_h, img_w = image.shape[:2] 81 | CropSize_w, CropSize_h = max(1, int(w * self.CropSize)), max(1, int(h * self.CropSize)) 82 | x1, y1 = random.randint(0, CropSize_w), random.randint(0, CropSize_h) 83 | x2, y2 = random.randint(w - CropSize_w, w), random.randint(h - CropSize_h, h) 84 | 85 | normal_crop = normal[y1:y2, x1:x2] 86 | normal_resize = cv2.resize(normal_crop, (w, h), interpolation=cv2.INTER_NEAREST) 87 | 88 | image_crop = image[4*y1:4*y2, 4*x1:4*x2] 89 | image_resize = cv2.resize(image_crop, (img_w, img_h), interpolation=cv2.INTER_LINEAR) 90 | 91 | 92 | return image_resize, normal_resize 93 | -------------------------------------------------------------------------------- /datasets/dtu_yao.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | import os, cv2, time, math 4 | from PIL import Image 5 | from datasets.data_io import * 6 | 7 | # the DTU dataset preprocessed by Yao Yao 
(only for training) 8 | class MVSDataset(Dataset): 9 | def __init__(self, datapath, listfile, mode, nviews, ndepths=192, interval_scale=1.06, **kwargs): 10 | super(MVSDataset, self).__init__() 11 | self.datapath = datapath 12 | self.listfile = listfile 13 | self.mode = mode 14 | self.nviews = nviews 15 | self.ndepths = ndepths 16 | self.interval_scale = interval_scale 17 | self.kwargs = kwargs 18 | print("mvsdataset kwargs", self.kwargs) 19 | 20 | assert self.mode in ["train", "val", "test"] 21 | self.metas = self.build_list() 22 | 23 | def build_list(self): 24 | metas = [] 25 | with open(self.listfile) as f: 26 | scans = f.readlines() 27 | scans = [line.rstrip() for line in scans] 28 | 29 | # scans 30 | for scan in scans: 31 | pair_file = "Cameras/pair.txt" 32 | # read the pair file 33 | with open(os.path.join(self.datapath, pair_file)) as f: 34 | num_viewpoint = int(f.readline()) 35 | # viewpoints (49) 36 | for view_idx in range(num_viewpoint): 37 | ref_view = int(f.readline().rstrip()) 38 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 39 | # light conditions 0-6 40 | for light_idx in range(7): 41 | metas.append((scan, light_idx, ref_view, src_views)) 42 | print("dataset", self.mode, "metas:", len(metas)) 43 | return metas 44 | 45 | def __len__(self): 46 | return len(self.metas) 47 | 48 | def read_cam_file(self, filename): 49 | with open(filename) as f: 50 | lines = f.readlines() 51 | lines = [line.rstrip() for line in lines] 52 | # extrinsics: line [1,5), 4x4 matrix 53 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 54 | # intrinsics: line [7-10), 3x3 matrix 55 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 56 | # depth_min & depth_interval: line 11 57 | depth_min = float(lines[11].split()[0]) 58 | depth_interval = float(lines[11].split()[1]) * self.interval_scale 59 | return intrinsics, extrinsics, depth_min, depth_interval 60 | 61 | def read_img(self, filename): 62 | img = Image.open(filename) 63 | # scale 0~255 to 0~1 64 | np_img = np.array(img, dtype=np.float32) / 255. 
65 | return np_img 66 | 67 | def prepare_img(self, hr_img): 68 | #w1600-h1200-> 800-600 ; crop -> 640, 512; downsample 1/4 -> 160, 128 69 | 70 | #downsample 71 | h, w = hr_img.shape 72 | hr_img_ds = cv2.resize(hr_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST) 73 | #crop 74 | h, w = hr_img_ds.shape 75 | target_h, target_w = 512, 640 76 | start_h, start_w = (h - target_h)//2, (w - target_w)//2 77 | hr_img_crop = hr_img_ds[start_h: start_h + target_h, start_w: start_w + target_w] 78 | 79 | return hr_img_crop 80 | 81 | def read_mask_hr(self, filename): 82 | img = Image.open(filename) 83 | np_img = np.array(img, dtype=np.float32) 84 | np_img = (np_img > 10).astype(np.float32) 85 | np_img = self.prepare_img(np_img) 86 | 87 | h, w = np_img.shape 88 | np_img_ms = { 89 | "stage1": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_NEAREST), 90 | "stage2": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST), 91 | "stage3": np_img, 92 | } 93 | return np_img_ms 94 | 95 | def read_depth(self, filename): 96 | # read pfm depth file 97 | return np.array(read_pfm(filename)[0], dtype=np.float32) 98 | 99 | def read_depth_hr(self, filename): 100 | # read pfm depth file 101 | depth_hr = np.array(read_pfm(filename)[0], dtype=np.float32) 102 | depth_lr = self.prepare_img(depth_hr) 103 | 104 | h, w = depth_lr.shape 105 | depth_lr_ms = { 106 | "stage1": cv2.resize(depth_lr, (w//4, h//4), interpolation=cv2.INTER_NEAREST), 107 | "stage2": cv2.resize(depth_lr, (w//2, h//2), interpolation=cv2.INTER_NEAREST), 108 | "stage3": depth_lr, 109 | } 110 | return depth_lr_ms 111 | 112 | def __getitem__(self, idx): 113 | meta = self.metas[idx] 114 | scan, light_idx, ref_view, src_views = meta 115 | # use only the reference view and first nviews-1 source views 116 | view_ids = [ref_view] + src_views[:self.nviews - 1] 117 | 118 | imgs = [] 119 | mask = None 120 | depth_values = None 121 | proj_matrices = [] 122 | 123 | for i, vid in enumerate(view_ids): 124 | # NOTE that the id in image file names is from 1 to 49 (not 0~48) 125 | img_filename = os.path.join(self.datapath, 126 | 'Rectified/{}_train/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx)) 127 | 128 | mask_filename_hr = os.path.join(self.datapath, 'Depths_raw/{}/depth_visual_{:0>4}.png'.format(scan, vid)) 129 | depth_filename_hr = os.path.join(self.datapath, 'Depths_raw/{}/depth_map_{:0>4}.pfm'.format(scan, vid)) 130 | 131 | proj_mat_filename = os.path.join(self.datapath, 'Cameras/train/{:0>8}_cam.txt').format(vid) 132 | 133 | 134 | img = self.read_img(img_filename) 135 | 136 | intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename) 137 | 138 | 139 | proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) # 140 | proj_mat[0, :4, :4] = extrinsics 141 | proj_mat[1, :3, :3] = intrinsics 142 | 143 | proj_matrices.append(proj_mat) 144 | 145 | if i == 0: # reference view 146 | mask_read_ms = self.read_mask_hr(mask_filename_hr) 147 | depth_ms = self.read_depth_hr(depth_filename_hr) 148 | 149 | #get depth values 150 | depth_max = depth_interval * self.ndepths + depth_min 151 | depth_values = np.arange(depth_min, depth_max, depth_interval, dtype=np.float32) 152 | 153 | mask = mask_read_ms 154 | 155 | imgs.append(img) 156 | 157 | #all 158 | imgs = np.stack(imgs).transpose([0, 3, 1, 2]) 159 | #ms proj_mats 160 | proj_matrices = np.stack(proj_matrices) 161 | stage2_pjmats = proj_matrices.copy() 162 | stage2_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2 163 | stage3_pjmats = proj_matrices.copy() 164 | 
stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4 165 | 166 | proj_matrices_ms = { 167 | "stage1": proj_matrices, 168 | "stage2": stage2_pjmats, 169 | "stage3": stage3_pjmats 170 | } 171 | 172 | return {"imgs": imgs, 173 | "proj_matrices": proj_matrices_ms, 174 | "depth": depth_ms, 175 | "depth_values": depth_values, 176 | "depth_interval": depth_interval, 177 | "mask": mask } 178 | -------------------------------------------------------------------------------- /datasets/general_eval.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | import os, cv2, time 4 | from PIL import Image 5 | from datasets.data_io import * 6 | 7 | s_h, s_w = 0, 0 8 | class MVSDataset(Dataset): 9 | def __init__(self, datapath, listfile, mode, nviews, ndepths=192, interval_scale=1.06, **kwargs): 10 | super(MVSDataset, self).__init__() 11 | self.datapath = datapath 12 | self.listfile = listfile 13 | self.mode = mode 14 | self.nviews = nviews 15 | self.ndepths = ndepths 16 | self.interval_scale = interval_scale 17 | self.max_h, self.max_w = kwargs["max_h"], kwargs["max_w"] 18 | self.fix_res = kwargs.get("fix_res", False) #whether to fix the resolution of input image. 19 | self.fix_wh = False 20 | 21 | assert self.mode == "test" 22 | self.metas = self.build_list() 23 | 24 | def build_list(self): 25 | metas = [] 26 | scans = self.listfile 27 | 28 | interval_scale_dict = {} 29 | # scans 30 | for scan in scans: 31 | # determine the interval scale of each scene. default is 1.06 32 | if isinstance(self.interval_scale, float): 33 | interval_scale_dict[scan] = self.interval_scale 34 | else: 35 | interval_scale_dict[scan] = self.interval_scale[scan] 36 | 37 | pair_file = "{}/pair.txt".format(scan) 38 | # read the pair file 39 | with open(os.path.join(self.datapath, pair_file)) as f: 40 | num_viewpoint = int(f.readline()) 41 | # viewpoints 42 | for view_idx in range(num_viewpoint): 43 | ref_view = int(f.readline().rstrip()) 44 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 45 | # filter by no src view and fill to nviews 46 | if len(src_views) > 0: 47 | if len(src_views) < self.nviews: 48 | print("{}< num_views:{}".format(len(src_views), self.nviews)) 49 | src_views += [src_views[0]] * (self.nviews - len(src_views)) 50 | metas.append((scan, ref_view, src_views, scan)) 51 | 52 | self.interval_scale = interval_scale_dict 53 | print("dataset", self.mode, "metas:", len(metas), "interval_scale:{}".format(self.interval_scale)) 54 | return metas 55 | 56 | def __len__(self): 57 | return len(self.metas) 58 | 59 | def read_cam_file(self, filename, interval_scale): 60 | with open(filename) as f: 61 | lines = f.readlines() 62 | lines = [line.rstrip() for line in lines] 63 | # extrinsics: line [1,5), 4x4 matrix 64 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 65 | # intrinsics: line [7-10), 3x3 matrix 66 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 67 | intrinsics[:2, :] /= 4.0 68 | # depth_min & depth_interval: line 11 69 | depth_min = float(lines[11].split()[0]) 70 | depth_interval = float(lines[11].split()[1]) 71 | 72 | if len(lines[11].split()) >= 3: 73 | num_depth = lines[11].split()[2] 74 | depth_max = depth_min + int(float(num_depth)) * depth_interval 75 | depth_interval = (depth_max - depth_min) / self.ndepths 76 | 77 | depth_interval *= interval_scale 78 | 79 | return intrinsics, extrinsics, depth_min, 
depth_interval 80 | 81 | def read_img(self, filename): 82 | img = Image.open(filename) 83 | # scale 0~255 to 0~1 84 | np_img = np.array(img, dtype=np.float32) / 255. 85 | 86 | return np_img 87 | 88 | def read_depth(self, filename): 89 | # read pfm depth file 90 | return np.array(read_pfm(filename)[0], dtype=np.float32) 91 | 92 | def scale_mvs_input(self, img, intrinsics, max_w, max_h, base=32): 93 | h, w = img.shape[:2] 94 | if h > max_h or w > max_w: 95 | scale = 1.0 * max_h / h 96 | if scale * w > max_w: 97 | scale = 1.0 * max_w / w 98 | new_w, new_h = scale * w // base * base, scale * h // base * base 99 | else: 100 | new_w, new_h = 1.0 * w // base * base, 1.0 * h // base * base 101 | 102 | scale_w = 1.0 * new_w / w 103 | scale_h = 1.0 * new_h / h 104 | intrinsics[0, :] *= scale_w 105 | intrinsics[1, :] *= scale_h 106 | 107 | img = cv2.resize(img, (int(new_w), int(new_h))) 108 | 109 | return img, intrinsics 110 | 111 | def __getitem__(self, idx): 112 | global s_h, s_w 113 | meta = self.metas[idx] 114 | scan, ref_view, src_views, scene_name = meta 115 | # use only the reference view and first nviews-1 source views 116 | view_ids = [ref_view] + src_views[:self.nviews - 1] 117 | 118 | imgs = [] 119 | depth_values = None 120 | proj_matrices = [] 121 | 122 | for i, vid in enumerate(view_ids): 123 | img_filename = os.path.join(self.datapath, '{}/images_post/{:0>8}.jpg'.format(scan, vid)) 124 | if not os.path.exists(img_filename): 125 | img_filename = os.path.join(self.datapath, '{}/images/{:0>8}.jpg'.format(scan, vid)) 126 | 127 | proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid)) 128 | 129 | img = self.read_img(img_filename) 130 | intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename, interval_scale= 131 | self.interval_scale[scene_name]) 132 | # scale input 133 | img, intrinsics = self.scale_mvs_input(img, intrinsics, self.max_w, self.max_h) 134 | 135 | if self.fix_res: 136 | # using the same standard height or width in entire scene. 137 | s_h, s_w = img.shape[:2] 138 | self.fix_res = False 139 | self.fix_wh = True 140 | 141 | if i == 0: 142 | if not self.fix_wh: 143 | # using the same standard height or width in each nviews. 
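                    # (With fix_res the module-level globals s_h/s_w stay frozen to the
                    #  first image processed, so the whole run shares one size; otherwise
                    #  each sample resets them to its reference view and only the views
                    #  inside that sample are forced to match.)
                    # For reference, scale_mvs_input keeps sizes divisible by base=32:
                    # a 1600x1200 DTU test image with max_h=864 / max_w=1152 (typical
                    # test-script values, an assumption here) comes out as 1152x864,
                    # with the intrinsics rescaled by the same factors.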
144 | s_h, s_w = img.shape[:2] 145 | 146 | # resize to standard height or width 147 | c_h, c_w = img.shape[:2] 148 | if (c_h != s_h) or (c_w != s_w): 149 | scale_h = 1.0 * s_h / c_h 150 | scale_w = 1.0 * s_w / c_w 151 | img = cv2.resize(img, (s_w, s_h)) 152 | intrinsics[0, :] *= scale_w 153 | intrinsics[1, :] *= scale_h 154 | 155 | 156 | imgs.append(img) 157 | # extrinsics, intrinsics 158 | proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) # 159 | proj_mat[0, :4, :4] = extrinsics 160 | proj_mat[1, :3, :3] = intrinsics 161 | proj_matrices.append(proj_mat) 162 | 163 | if i == 0: # reference view 164 | depth_values = np.arange(depth_min, depth_interval * (self.ndepths - 0.5) + depth_min, depth_interval, 165 | dtype=np.float32) 166 | 167 | #all 168 | imgs = np.stack(imgs).transpose([0, 3, 1, 2]) 169 | proj_matrices = np.stack(proj_matrices) 170 | 171 | stage2_pjmats = proj_matrices.copy() 172 | stage2_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2 173 | stage3_pjmats = proj_matrices.copy() 174 | stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4 175 | 176 | proj_matrices_ms = { 177 | "stage1": proj_matrices, 178 | "stage2": stage2_pjmats, 179 | "stage3": stage3_pjmats 180 | } 181 | 182 | return {"imgs": imgs, 183 | "proj_matrices": proj_matrices_ms, 184 | "depth_values": depth_values, 185 | "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"} 186 | -------------------------------------------------------------------------------- /datasets/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import math 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | def scale_camera(cam, scale=1): 8 | """ resize input in order to produce sampled depth map """ 9 | new_cam = np.copy(cam) 10 | 11 | # focal: 12 | new_cam[0][0] = cam[0][0] * scale 13 | new_cam[1][1] = cam[1][1] * scale 14 | # principle point: 15 | new_cam[0][2] = cam[0][2] * scale 16 | new_cam[1][2] = cam[1][2] * scale 17 | return new_cam 18 | 19 | def scale_image(image, scale=1, interpolation='linear'): 20 | """ resize image using cv2 """ 21 | if interpolation == 'linear': 22 | return cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) 23 | if interpolation == 'nearest': 24 | return cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST) 25 | 26 | def scale_mvs_input(images, cams, depth_image=None, scale=1,view_num=5): 27 | """ resize input to fit into the memory """ 28 | new_images = [] 29 | new_cams=[] 30 | for view in range(view_num): 31 | new_images.append(scale_image(images[view], scale=scale)) 32 | new_cams.append(scale_camera(cams[view], scale=scale)) 33 | new_images = np.array(new_images) 34 | if depth_image is None: 35 | #return images, cams 36 | return new_images, new_cams 37 | else: 38 | depth_image = scale_image(depth_image, scale=scale, interpolation='nearest') 39 | return new_images, cams, depth_image 40 | 41 | def crop_mvs_input(images, cams, depth_image=None,view_num=5,max_h=1200,max_w=1600,base_image_size=8): 42 | """ resize images and cameras to fit the network (can be divided by base image size) """ 43 | 44 | new_images = [] 45 | # crop images and cameras 46 | for view in range(view_num): 47 | h, w = images[view].shape[0:2] 48 | new_h = h 49 | new_w = w 50 | if new_h > max_h: 51 | new_h = max_h 52 | else: 53 | new_h = int(math.ceil(h /base_image_size) * base_image_size) 54 | if new_w > max_w: 55 | new_w = max_w 56 | else: 57 | new_w = int(math.ceil(w /base_image_size) * 
base_image_size) 58 | start_h = int(math.ceil((h - new_h) / 2)) 59 | start_w = int(math.ceil((w - new_w) / 2)) 60 | finish_h = start_h + new_h 61 | finish_w = start_w + new_w 62 | 63 | new_images.append(images[view][start_h:finish_h, start_w:finish_w]) 64 | cams[view][0][2] = cams[view][0][2] - start_w 65 | cams[view][1][2] = cams[view][1][2] - start_h 66 | 67 | new_images = np.stack(new_images) 68 | # crop depth image 69 | if not depth_image is None: 70 | depth_image = depth_image[start_h:finish_h, start_w:finish_w] 71 | return new_images, cams, depth_image 72 | else: 73 | return new_images, cams 74 | 75 | -------------------------------------------------------------------------------- /datasets/tnt_eval.py: -------------------------------------------------------------------------------- 1 | from re import T 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import os, cv2 5 | from PIL import Image 6 | from datasets.data_io import * 7 | 8 | # Test any dataset with scale and center crop 9 | s_h, s_w = 0, 0 10 | class MVSDataset(Dataset): 11 | def __init__(self, datapath, listfile, mode, nviews, ndepths=192, interval_scale=1.0, 12 | max_h=704,max_w=1280, inverse_depth=False, **kwargs): 13 | super(MVSDataset, self).__init__() 14 | self.datapath = datapath 15 | self.mode = mode 16 | self.nviews = nviews 17 | self.ndepths = ndepths 18 | self.interval_scale = interval_scale 19 | self.fix_res = kwargs.get("fix_res", True) #whether to fix the resolution of input image. 20 | self.fix_wh = False 21 | self.max_h=max_h 22 | self.max_w=max_w 23 | self.inverse_depth=inverse_depth 24 | self.scans = listfile 25 | 26 | self.image_sizes = {'Family': (1920, 1080), 27 | 'Francis': (1920, 1080), 28 | 'Horse': (1920, 1080), 29 | 'Lighthouse': (2048, 1080), 30 | 'M60': (2048, 1080), 31 | 'Panther': (2048, 1080), 32 | 'Playground': (1920, 1080), 33 | 'Train': (1920, 1080), 34 | 'Auditorium': (1920, 1080), 35 | 'Ballroom': (1920, 1080), 36 | 'Courtroom': (1920, 1080), 37 | 'Museum': (1920, 1080), 38 | 'Palace': (1920, 1080), 39 | 'Temple': (1920, 1080)} 40 | 41 | assert self.mode == "test" 42 | self.metas = self.build_list() 43 | print('Data Loader : data_eval_T&T**************' ) 44 | 45 | def build_list(self): 46 | metas = [] 47 | 48 | scans = self.scans 49 | 50 | for scan in scans: 51 | pair_file = "{}/pair.txt".format(scan) 52 | # read the pair file 53 | with open(os.path.join(self.datapath, pair_file)) as f: 54 | num_viewpoint = int(f.readline()) 55 | for view_idx in range(num_viewpoint): 56 | ref_view = int(f.readline().rstrip()) 57 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 58 | if len(src_views) == 0: 59 | continue 60 | metas.append((scan, ref_view, src_views)) 61 | print("dataset", self.mode, "metas:", len(metas)) 62 | return metas 63 | 64 | def __len__(self): 65 | return len(self.metas) 66 | 67 | def read_cam_file(self, filename): 68 | with open(filename) as f: 69 | lines = f.readlines() 70 | lines = [line.rstrip() for line in lines] 71 | # extrinsics: line [1,5), 4x4 matrix 72 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 73 | # intrinsics: line [7-10), 3x3 matrix 74 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 75 | intrinsics[:2, :] /= 4.0 76 | 77 | depth_min = float(lines[11].split()[0]) 78 | depth_max = float(lines[11].split()[1]) 79 | depth_interval = float((depth_max - depth_min) / self.ndepths) 80 | 81 | return intrinsics, extrinsics, depth_min, 
depth_interval, depth_max 82 | 83 | def read_img(self, filename): 84 | img = Image.open(filename) 85 | 86 | np_img = np.array(img, dtype=np.float32) / 255. 87 | return np_img 88 | 89 | def center_img(self, img): # this is very important for batch normalization 90 | img = img.astype(np.float32) 91 | var = np.var(img, axis=(0,1), keepdims=True) 92 | mean = np.mean(img, axis=(0,1), keepdims=True) 93 | return (img - mean) / (np.sqrt(var) ) 94 | 95 | def read_depth(self, filename): 96 | # read pfm depth file 97 | return np.array(read_pfm(filename)[0], dtype=np.float32) 98 | 99 | def scale_mvs_input(self, img, intrinsics, max_w, max_h, base=32): 100 | h, w = img.shape[:2] 101 | if h > max_h or w > max_w: 102 | scale = 1.0 * max_h / h 103 | if scale * w > max_w: 104 | scale = 1.0 * max_w / w 105 | new_w, new_h = scale * w // base * base, scale * h // base * base 106 | else: 107 | new_w, new_h = 1.0 * w // base * base, 1.0 * h // base * base 108 | 109 | scale_w = 1.0 * new_w / w 110 | scale_h = 1.0 * new_h / h 111 | intrinsics[0, :] *= scale_w 112 | intrinsics[1, :] *= scale_h 113 | 114 | img = cv2.resize(img, (int(new_w), int(new_h))) 115 | 116 | return img, intrinsics 117 | 118 | def __getitem__(self, idx): 119 | global s_h, s_w 120 | meta = self.metas[idx] 121 | scan, ref_view, src_views = meta 122 | # img_w, img_h = self.image_sizes[scan] 123 | 124 | if self.nviews>len(src_views): 125 | self.nviews=len(src_views)+1 126 | 127 | # use only the reference view and first nviews-1 source views 128 | view_ids = [ref_view] + src_views[:self.nviews - 1] 129 | 130 | imgs = [] 131 | depth_values = None 132 | proj_matrices = [] 133 | 134 | for i, vid in enumerate(view_ids): 135 | img_filename = os.path.join(self.datapath, '{}/images/{:0>8}.jpg'.format(scan, vid)) 136 | proj_mat_filename = os.path.join(self.datapath, '{}/cams_1/{:0>8}_cam.txt'.format(scan, vid)) 137 | 138 | img = (self.read_img(img_filename)) 139 | # imgs.append(self.read_img(img_filename)) 140 | intrinsics, extrinsics, depth_min, depth_interval, depth_max = self.read_cam_file(proj_mat_filename) 141 | 142 | img, intrinsics = self.scale_mvs_input(img, intrinsics, self.image_sizes[scan][0], self.image_sizes[scan][1]) 143 | # img, intrinsics = self.scale_mvs_input(img, intrinsics, self.max_w, self.max_h) 144 | 145 | if self.fix_res: 146 | # using the same standard height or width in entire scene. 147 | s_h, s_w = img.shape[:2] 148 | self.fix_res = False 149 | self.fix_wh = True 150 | 151 | if i == 0: 152 | if not self.fix_wh: 153 | # using the same standard height or width in each nviews. 154 | s_h, s_w = img.shape[:2] 155 | 156 | c_h, c_w = img.shape[:2] 157 | if (c_h != s_h) or (c_w != s_w): 158 | scale_h = 1.0 * s_h / c_h 159 | scale_w = 1.0 * s_w / c_w 160 | img = cv2.resize(img, (s_w, s_h)) 161 | intrinsics[0, :] *= scale_w 162 | intrinsics[1, :] *= scale_h 163 | 164 | imgs.append(img) 165 | # extrinsics, intrinsics 166 | proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) # 167 | proj_mat[0, :4, :4] = extrinsics 168 | proj_mat[1, :3, :3] = intrinsics 169 | proj_matrices.append(proj_mat) 170 | 171 | if i == 0: # reference view 172 | if self.inverse_depth is False: 173 | depth_values = np.arange(depth_min, depth_interval * (self.ndepths) + depth_min, depth_interval, 174 | dtype=np.float32) 175 | else: 176 | print('********* Here we use inverse depth for all stage! 
***********') 177 | depth_end = depth_max - depth_interval / self.interval_scale 178 | depth_values = np.linspace(1.0 / depth_end, 1.0 / depth_min, self.ndepths, endpoint=False) 179 | depth_values = 1.0 / depth_values 180 | depth_values = depth_values.astype(np.float32) 181 | 182 | 183 | imgs = np.stack(imgs).transpose([0, 3, 1, 2]) # B,C,H,W 184 | proj_matrices = np.stack(proj_matrices) 185 | 186 | stage2_pjmats = proj_matrices.copy() 187 | stage2_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2 188 | stage3_pjmats = proj_matrices.copy() 189 | stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4 190 | 191 | proj_matrices_ms = { 192 | "stage1": proj_matrices, 193 | "stage2": stage2_pjmats, 194 | "stage3": stage3_pjmats 195 | } 196 | 197 | return {"imgs": imgs, 198 | "proj_matrices": proj_matrices_ms, 199 | "depth_values": depth_values, 200 | "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"} -------------------------------------------------------------------------------- /dynamic_fusion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Type 4 | import numpy as np 5 | from utils import print_args 6 | import sys 7 | from datasets.data_io import read_pfm 8 | from plyfile import PlyData, PlyElement 9 | from PIL import Image 10 | import cv2 11 | from multiprocessing import Pool 12 | from functools import partial 13 | 14 | 15 | parser = argparse.ArgumentParser(description='Filter and fuse depth maps') 16 | 17 | parser.add_argument('--testpath', help='testing data path') 18 | parser.add_argument('--tntpath', help='tnt data path') 19 | parser.add_argument('--testlist', help='testing scan list') 20 | parser.add_argument('--outdir', help='output dir') 21 | parser.add_argument('--photo_threshold', type=float, default=0.3, help='photo threshold for filter confidence') 22 | parser.add_argument('--display', action='store_true', help='display depth images and masks') 23 | parser.add_argument('--test_dataset', choices=['dtu','tnt'], help='which dataset to evaluate') 24 | parser.add_argument('--thres_view', type=int, default=3, help='threshold of num view') 25 | 26 | # parse arguments and check 27 | args = parser.parse_args() 28 | print("argv:", sys.argv[1:]) 29 | print_args(args) 30 | 31 | 32 | # read intrinsics and extrinsics 33 | def read_camera_parameters(filename,scale,index,flag): 34 | with open(filename) as f: 35 | lines = f.readlines() 36 | lines = [line.rstrip() for line in lines] 37 | # extrinsics: line [1,5), 4x4 matrix 38 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 39 | # intrinsics: line [7-10), 3x3 matrix 40 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 41 | 42 | intrinsics[:2, :] *= scale 43 | 44 | if (flag==0): 45 | intrinsics[0,2]-=index 46 | else: 47 | intrinsics[1,2]-=index 48 | 49 | return intrinsics, extrinsics 50 | 51 | # read an image 52 | def read_img(filename): 53 | img = Image.open(filename) 54 | # scale 0~255 to 0~1 55 | np_img = np.array(img, dtype=np.float32) / 255. 56 | return np_img 57 | 58 | # save a binary mask 59 | def save_mask(filename, mask): 60 | assert mask.dtype == np.bool 61 | mask = mask.astype(np.uint8) * 255 62 | Image.fromarray(mask).save(filename) 63 | 64 | # read a pair file, [(ref_view1, [src_view1-1, ...]), (ref_view2, [src_view2-1, ...]), ...] 
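# Illustrative pair.txt layout assumed by the parser below (ids and scores made up):
#   49                            <- number of viewpoints
#   0                             <- reference view id
#   10 14 2182.1 16 2044.9 ...    <- "<num_src> <src_id> <score> <src_id> <score> ..."
# the [1::2] slice keeps the source-view ids and drops the interleaved scores.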
65 | def read_pair_file(filename): 66 | data = [] 67 | with open(filename) as f: 68 | num_viewpoint = int(f.readline()) 69 | for view_idx in range(num_viewpoint): 70 | ref_view = int(f.readline().rstrip()) 71 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 72 | if len(src_views) == 0: 73 | continue 74 | data.append((ref_view, src_views)) 75 | return data 76 | 77 | # project the reference point cloud into the source view, then project back 78 | def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): 79 | width, height = depth_ref.shape[1], depth_ref.shape[0] 80 | ## step1. project reference pixels to the source view 81 | # reference view x, y 82 | x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) 83 | x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1]) 84 | # reference 3D space 85 | xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref), 86 | np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1])) 87 | # source 3D space 88 | xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)), 89 | np.vstack((xyz_ref, np.ones_like(x_ref))))[:3] 90 | # source view x, y 91 | K_xyz_src = np.matmul(intrinsics_src, xyz_src) 92 | xy_src = K_xyz_src[:2] / K_xyz_src[2:3] 93 | 94 | ## step2. reproject the source view points with source view depth estimation 95 | # find the depth estimation of the source view 96 | x_src = xy_src[0].reshape([height, width]).astype(np.float32) 97 | y_src = xy_src[1].reshape([height, width]).astype(np.float32) 98 | sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR) 99 | # mask = sampled_depth_src > 0 100 | 101 | # source 3D space 102 | # NOTE that we should use sampled source-view depth_here to project back 103 | xyz_src = np.matmul(np.linalg.inv(intrinsics_src), 104 | np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1])) 105 | # reference 3D space 106 | xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)), 107 | np.vstack((xyz_src, np.ones_like(x_ref))))[:3] 108 | # source view x, y, depth 109 | depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32) 110 | K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected) 111 | xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3] 112 | x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32) 113 | y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32) 114 | 115 | return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src 116 | 117 | def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src 118 | ): 119 | width, height = depth_ref.shape[1], depth_ref.shape[0] 120 | x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) 121 | depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(depth_ref, 122 | intrinsics_ref, 123 | extrinsics_ref, 124 | depth_src, 125 | intrinsics_src, 126 | extrinsics_src) 127 | # check |p_reproj-p_1| < 1 128 | dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2) 129 | 130 | # check |d_reproj-d_1| / d_1 < 0.01 131 | depth_diff = np.abs(depth_reprojected - depth_ref) 132 | relative_depth_diff = depth_diff / depth_ref 133 | masks=[] 134 | for i in range(2,11): 135 | mask = np.logical_and(dist < i/4, relative_depth_diff < i/1300) 136 | masks.append(mask) 137 | vis_mask = 
np.logical_and(dist < 1, relative_depth_diff < 0.01) 138 | depth_reprojected[~mask] = 0 139 | 140 | return masks, mask, depth_reprojected, x2d_src, y2d_src, vis_mask 141 | 142 | def filter_depth(scan_folder, out_folder, pair_path, plyfilename, photo_threshold): 143 | # the pair file 144 | pair_file = os.path.join(scan_folder, "pair.txt") 145 | pair_file = pair_path 146 | # for the final point cloud 147 | vertexs = [] 148 | vertex_colors = [] 149 | 150 | pair_data = read_pair_file(pair_file) 151 | 152 | for ref_view, src_views in pair_data: 153 | # load the reference image 154 | ref_img = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(ref_view))) 155 | # load the estimated depth of the reference view 156 | ref_depth_est = read_pfm(os.path.join(scan_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0] 157 | 158 | # load the photometric mask of the reference view 159 | confidence = read_pfm(os.path.join(scan_folder, 'confidence/{:0>8}.pfm'.format(ref_view)))[0] 160 | 161 | 162 | scale=float(confidence.shape[0])/ref_img.shape[0] 163 | index=int((int(ref_img.shape[1]*scale)-confidence.shape[1])/2) 164 | index_p=(int(ref_img.shape[1]*scale)-confidence.shape[1])-index 165 | flag=0 166 | if (confidence.shape[1]/ref_img.shape[1]>scale): 167 | scale=float(confidence.shape[1])/ref_img.shape[1] 168 | index=int((int(ref_img.shape[0]*scale)-confidence.shape[0])/2) 169 | index_p=(int(ref_img.shape[0]*scale)-confidence.shape[0])-index 170 | flag=1 171 | 172 | ref_img=cv2.resize(ref_img,(int(ref_img.shape[1]*scale),int(ref_img.shape[0]*scale))) 173 | if (flag==0): 174 | ref_img=ref_img[:,index:ref_img.shape[1]-index_p,:] 175 | else: 176 | ref_img=ref_img[index:ref_img.shape[0]-index_p,:,:] 177 | 178 | # load the camera parameters 179 | ref_intrinsics, ref_extrinsics = read_camera_parameters( 180 | os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(ref_view)),scale,index,flag) 181 | 182 | photo_mask = confidence > photo_threshold 183 | 184 | 185 | all_srcview_depth_ests = [] 186 | 187 | # compute the geometric mask 188 | geo_mask_sum = 0 189 | geo_mask_sums=[] 190 | vis_masks=[] 191 | n=1 192 | for src_view in src_views: 193 | n+=1 194 | ct = 0 195 | for src_view in src_views: 196 | ct = ct + 1 197 | src_depth_est = read_pfm(os.path.join(scan_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))[0] 198 | 199 | src_intrinsics, src_extrinsics = read_camera_parameters( 200 | os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(src_view)),scale,index,flag) 201 | 202 | masks, geo_mask, depth_reprojected, x2d_src, y2d_src, vis_mask= check_geometric_consistency(ref_depth_est, ref_intrinsics, 203 | ref_extrinsics, 204 | src_depth_est, 205 | src_intrinsics, src_extrinsics) 206 | 207 | vis_masks.append(vis_mask*src_view) 208 | 209 | if (ct==1): 210 | for i in range(2,n): 211 | geo_mask_sums.append(masks[i-2].astype(np.int32)) 212 | else : 213 | for i in range(2,n): 214 | geo_mask_sums[i-2]+=masks[i-2].astype(np.int32) 215 | 216 | geo_mask_sum+=geo_mask.astype(np.int32) 217 | 218 | all_srcview_depth_ests.append(depth_reprojected) 219 | 220 | # Modify 221 | # geo_mask=geo_mask_sum>=n 222 | geo_mask=geo_mask_sum>=args.thres_view 223 | 224 | for i in range (2,n): 225 | geo_mask=np.logical_or(geo_mask,geo_mask_sums[i-2]>=i) 226 | print(geo_mask.mean()) 227 | 228 | depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1) 229 | 230 | if (not isinstance(geo_mask, bool)): 231 | 232 | final_mask = np.logical_and(photo_mask, geo_mask) 233 | 234 | 
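            # Recap of the dynamic geometric check used above: for each source view,
            # check_geometric_consistency builds masks for i = 2..10 with tolerances
            # (reprojection error < i/4 px, relative depth error < i/1300). A pixel
            # passes geo filtering if at least i views satisfy the i-th tolerance for
            # some i, or at least args.thres_view views satisfy the widest tolerance;
            # final_mask additionally requires confidence > photo_threshold.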
os.makedirs(os.path.join(out_folder, "mask"), exist_ok=True) 235 | 236 | save_mask(os.path.join(out_folder, "mask/{:0>8}_photo.png".format(ref_view)), photo_mask) 237 | save_mask(os.path.join(out_folder, "mask/{:0>8}_geo.png".format(ref_view)), geo_mask) 238 | save_mask(os.path.join(out_folder, "mask/{:0>8}_final.png".format(ref_view)), final_mask) 239 | 240 | print("processing {}, ref-view{:0>2}, photo/geo/final-mask:{}/{}/{}".format(scan_folder, ref_view, 241 | photo_mask.mean(), 242 | geo_mask.mean(), 243 | final_mask.mean())) 244 | 245 | if args.display: 246 | cv2.imshow('ref_img', ref_img[:, :, ::-1]) 247 | cv2.imshow('ref_depth', ref_depth_est / 800) 248 | cv2.imshow('ref_depth * photo_mask', ref_depth_est * photo_mask.astype(np.float32) / 800) 249 | cv2.imshow('ref_depth * geo_mask', ref_depth_est * geo_mask.astype(np.float32) / 800) 250 | cv2.imshow('ref_depth * mask', ref_depth_est * final_mask.astype(np.float32) / 800) 251 | cv2.waitKey(0) 252 | 253 | height, width = depth_est_averaged.shape[:2] 254 | x, y = np.meshgrid(np.arange(0, width), np.arange(0, height)) 255 | valid_points = final_mask 256 | print("valid_points", valid_points.mean()) 257 | x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[valid_points] 258 | color = ref_img[:, :, :][valid_points] 259 | xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics), 260 | np.vstack((x, y, np.ones_like(x))) * depth) 261 | xyz_world = np.matmul(np.linalg.inv(ref_extrinsics), 262 | np.vstack((xyz_ref, np.ones_like(x))))[:3] 263 | vertexs.append(xyz_world.transpose((1, 0))) 264 | vertex_colors.append((color * 255).astype(np.uint8)) 265 | 266 | 267 | vertexs = np.concatenate(vertexs, axis=0) 268 | vertex_colors = np.concatenate(vertex_colors, axis=0) 269 | vertexs = np.array([tuple(v) for v in vertexs], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) 270 | vertex_colors = np.array([tuple(v) for v in vertex_colors], dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 271 | 272 | vertex_all = np.empty(len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr) 273 | for prop in vertexs.dtype.names: 274 | vertex_all[prop] = vertexs[prop] 275 | for prop in vertex_colors.dtype.names: 276 | vertex_all[prop] = vertex_colors[prop] 277 | 278 | el = PlyElement.describe(vertex_all, 'vertex') 279 | PlyData([el]).write(plyfilename) 280 | print("saving the final model to", plyfilename) 281 | 282 | def worker(scan): 283 | scan_folder = os.path.join(args.testpath, scan) 284 | out_folder = os.path.join(args.outdir, scan) 285 | pair_path = os.path.join(args.tntpath, f"{scan}/pair.txt") 286 | if (args.test_dataset=='dtu'): 287 | scan_id = int(scan[4:]) 288 | photo_threshold=0.3 289 | filter_depth(scan_folder, out_folder, pair_path, os.path.join(args.outdir, 'mvsnet_{:0>3}_l3.ply'.format(scan_id) ), photo_threshold) 290 | if (args.test_dataset=='tnt'): 291 | photo_threshold=args.photo_threshold 292 | filter_depth(scan_folder, out_folder, pair_path, os.path.join(args.outdir, scan + '.ply'), photo_threshold) 293 | 294 | 295 | if __name__ == '__main__': 296 | 297 | with open(args.testlist) as f: 298 | scans = f.readlines() 299 | scans = [line.rstrip() for line in scans] 300 | partial_func = partial(worker) 301 | 302 | p = Pool(8) 303 | p.map(partial_func, scans) 304 | p.close() 305 | p.join() 306 | 307 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, time, gc, datetime 2 | from 
models.module import focal_loss_bld 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.parallel 6 | import torch.backends.cudnn as cudnn 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from tensorboardX import SummaryWriter 10 | from datasets import find_dataset_def 11 | from models import * 12 | from utils import * 13 | import torch.distributed as dist 14 | 15 | cudnn.benchmark = True 16 | 17 | parser = argparse.ArgumentParser(description='A PyTorch Implementation of Cascade Cost Volume MVSNet') 18 | parser.add_argument('--mode', default='train', help='train or test', choices=['train', 'test', 'profile']) 19 | parser.add_argument('--model', default='mvsnet', help='select model') 20 | parser.add_argument('--device', default='cuda', help='select model') 21 | parser.add_argument('--dataset', default='dtu_yao', help='select dataset') 22 | parser.add_argument('--trainpath', help='train datapath') 23 | parser.add_argument('--testpath', help='test datapath') 24 | parser.add_argument('--trainlist', help='train list') 25 | parser.add_argument('--testlist', help='test list') 26 | parser.add_argument('--epochs', type=int, default=16, help='number of epochs to train') 27 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 28 | parser.add_argument('--lrepochs', type=str, default="10,12,14:2", help='epoch ids to downscale lr and the downscale rate') 29 | parser.add_argument('--wd', type=float, default=0.0001, help='weight decay') 30 | parser.add_argument('--nviews', type=int, default=5, help='total number of views') 31 | parser.add_argument('--batch_size', type=int, default=1, help='train batch size') 32 | parser.add_argument('--numdepth', type=int, default=192, help='the number of depth values') 33 | parser.add_argument('--interval_scale', type=float, default=1.06, help='the number of depth values') 34 | parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint') 35 | parser.add_argument('--logdir', default='./checkpoints', help='the directory to save checkpoints/logs') 36 | parser.add_argument('--resume', action='store_true', help='continue to train the model') 37 | parser.add_argument('--summary_freq', type=int, default=50, help='print and summary frequency') 38 | parser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency') 39 | parser.add_argument('--eval_freq', type=int, default=1, help='eval freq') 40 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') 41 | parser.add_argument('--pin_m', action='store_true', help='data loader pin memory') 42 | parser.add_argument("--local_rank", type=int, default=0) 43 | parser.add_argument('--share_cr', action='store_true', help='whether share the cost volume regularization') 44 | parser.add_argument('--ndepths', type=str, default="48,32,8", help='ndepths') 45 | parser.add_argument('--depth_inter_r', type=str, default="4,1,0.5", help='depth_intervals_ratio') 46 | parser.add_argument('--dlossw', type=str, default="1.0,1.0,1.0", help='depth loss weight for different stage') 47 | parser.add_argument('--cr_base_chs', type=str, default="8,8,8", help='cost regularization base channels') 48 | parser.add_argument('--grad_method', type=str, default="detach", choices=["detach", "undetach"], help='grad method') 49 | parser.add_argument('--using_apex', action='store_true', help='using apex, need to install apex') 50 | parser.add_argument('--sync_bn', action='store_true',help='enabling apex sync BN.') 51 | 
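# Note: --lrepochs is parsed in train() as "<epoch,epoch,...>:<rate>", so the default
# "10,12,14:2" halves the learning rate after epochs 10, 12 and 14 (with a 500-iteration
# warmup via WarmupMultiStepLR). The script is meant to be launched through a
# torch.distributed launcher that sets WORLD_SIZE and --local_rank; without one it
# falls back to single-process DataParallel.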
parser.add_argument('--opt-level', type=str, default="O0") 52 | parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) 53 | parser.add_argument('--loss-scale', type=str, default=None) 54 | 55 | 56 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 57 | is_distributed = num_gpus > 1 58 | 59 | # main function 60 | def train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args): 61 | milestones = [len(TrainImgLoader) * int(epoch_idx) for epoch_idx in args.lrepochs.split(':')[0].split(',')] 62 | lr_gamma = 1 / float(args.lrepochs.split(':')[1]) 63 | lr_scheduler = WarmupMultiStepLR(optimizer, milestones, gamma=lr_gamma, warmup_factor=1.0/3, warmup_iters=500, 64 | last_epoch=len(TrainImgLoader) * start_epoch - 1) 65 | 66 | for epoch_idx in range(start_epoch, args.epochs): 67 | global_step = len(TrainImgLoader) * epoch_idx 68 | 69 | # training 70 | if is_distributed: 71 | TrainImgLoader.sampler.set_epoch(epoch_idx) 72 | for batch_idx, sample in enumerate(TrainImgLoader): 73 | start_time = time.time() 74 | global_step = len(TrainImgLoader) * epoch_idx + batch_idx 75 | do_summary = global_step % args.summary_freq == 0 76 | loss, scalar_outputs, image_outputs = train_sample(model, model_loss, optimizer, sample, args) 77 | lr_scheduler.step() 78 | if (not is_distributed) or (dist.get_rank() == 0): 79 | if do_summary: 80 | save_scalars(logger, 'train', scalar_outputs, global_step) 81 | save_images(logger, 'train', image_outputs, global_step) 82 | print( 83 | "Epoch {}/{}, Iter {}/{}, lr {:.6f}, train loss = {:.3f}, depth loss = {:.3f}, epe = {:.3f}, less1 = {:.3f}, less3 = {:.3f}, time = {:.3f}".format( 84 | epoch_idx, args.epochs, batch_idx, len(TrainImgLoader), 85 | optimizer.param_groups[0]["lr"], loss, 86 | scalar_outputs['depth_loss'], 87 | scalar_outputs['epe'], 88 | scalar_outputs['less1'], 89 | scalar_outputs['less3'], 90 | time.time() - start_time)) 91 | del scalar_outputs, image_outputs 92 | 93 | # checkpoint 94 | if (not is_distributed) or (dist.get_rank() == 0): 95 | if (epoch_idx + 1) % args.save_freq == 0: 96 | torch.save({ 97 | 'epoch': epoch_idx, 98 | 'model': model.module.state_dict(), 99 | 'optimizer': optimizer.state_dict()}, 100 | "{}/model_{:0>6}.ckpt".format(args.logdir, epoch_idx)) 101 | gc.collect() 102 | 103 | # testing 104 | if (epoch_idx % args.eval_freq == 0) or (epoch_idx == args.epochs - 1): 105 | avg_test_scalars = DictAverageMeter() 106 | for batch_idx, sample in enumerate(TestImgLoader): 107 | start_time = time.time() 108 | global_step = len(TrainImgLoader) * epoch_idx + batch_idx 109 | do_summary = global_step % args.summary_freq == 0 110 | loss, scalar_outputs, image_outputs = test_sample_depth(model, model_loss, sample, args) 111 | if (not is_distributed) or (dist.get_rank() == 0): 112 | if do_summary: 113 | save_scalars(logger, 'test', scalar_outputs, global_step) 114 | save_images(logger, 'test', image_outputs, global_step) 115 | print("Epoch {}/{}, Iter {}/{}, test loss = {:.3f}, depth loss = {:.3f}, epe = {:.3f}, less1 = {:.3f}, less3 = {:.3f}, time = {:3f}".format( 116 | epoch_idx, args.epochs, 117 | batch_idx, 118 | len(TestImgLoader), loss, 119 | scalar_outputs["depth_loss"], 120 | scalar_outputs['epe'], 121 | scalar_outputs['less1'], 122 | scalar_outputs['less3'], 123 | time.time() - start_time)) 124 | avg_test_scalars.update(scalar_outputs) 125 | del scalar_outputs, image_outputs 126 | 127 | if (not is_distributed) or (dist.get_rank() == 0): 128 | save_scalars(logger, 'fulltest', 
avg_test_scalars.mean(), global_step) 129 | print("avg_test_scalars:", avg_test_scalars.mean()) 130 | gc.collect() 131 | 132 | 133 | def test(model, model_loss, TestImgLoader, args): 134 | avg_test_scalars = DictAverageMeter() 135 | for batch_idx, sample in enumerate(TestImgLoader): 136 | start_time = time.time() 137 | loss, scalar_outputs, image_outputs = test_sample_depth(model, model_loss, sample, args) 138 | avg_test_scalars.update(scalar_outputs) 139 | del scalar_outputs, image_outputs 140 | if (not is_distributed) or (dist.get_rank() == 0): 141 | print('Iter {}/{}, test loss = {:.3f}, time = {:3f}'.format(batch_idx, len(TestImgLoader), loss, 142 | time.time() - start_time)) 143 | if batch_idx % 100 == 0: 144 | print("Iter {}/{}, test results = {}".format(batch_idx, len(TestImgLoader), avg_test_scalars.mean())) 145 | if (not is_distributed) or (dist.get_rank() == 0): 146 | print("final", avg_test_scalars.mean()) 147 | 148 | 149 | def train_sample(model, model_loss, optimizer, sample, args): 150 | model.train() 151 | optimizer.zero_grad() 152 | 153 | sample_cuda = tocuda(sample) 154 | depth_gt_ms = sample_cuda["depth"] 155 | mask_ms = sample_cuda["mask"] 156 | 157 | num_stage = len([int(nd) for nd in args.ndepths.split(",") if nd]) 158 | depth_gt = depth_gt_ms["stage{}".format(num_stage)] 159 | mask = mask_ms["stage{}".format(num_stage)] 160 | try: 161 | outputs = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"]) 162 | depth_est = outputs["depth"] 163 | 164 | loss, depth_loss, epe, less1, less3 = model_loss(outputs, depth_gt_ms, mask_ms, sample_cuda["depth_interval"], dlossw=[float(e) for e in args.dlossw.split(",") if e]) 165 | 166 | if np.isnan(loss.item()): 167 | raise NanError 168 | 169 | if is_distributed and args.using_apex: 170 | with amp.scale_loss(loss, optimizer) as scaled_loss: 171 | scaled_loss.backward() 172 | else: 173 | loss.backward() 174 | 175 | optimizer.step() 176 | 177 | except NanError: 178 | print(f'nan error accur!!') 179 | gc.collect() 180 | torch.cuda.empty_cache() 181 | 182 | scalar_outputs = {"loss": loss, 183 | "depth_loss": depth_loss, 184 | "epe": epe, 185 | "less1":less1, 186 | "less3":less3 } 187 | 188 | image_outputs = {"depth_est": depth_est * mask, 189 | "depth_est_nomask": depth_est, 190 | "depth_gt": sample["depth"]["stage1"], 191 | "ref_img": sample["imgs"][:, 0], 192 | "mask": sample["mask"]["stage1"], 193 | "errormap": (depth_est - depth_gt).abs() * mask, 194 | } 195 | 196 | 197 | if is_distributed: 198 | scalar_outputs = reduce_scalar_outputs(scalar_outputs) 199 | 200 | return tensor2float(scalar_outputs["loss"]), tensor2float(scalar_outputs), tensor2numpy(image_outputs) 201 | 202 | 203 | @make_nograd_func 204 | def test_sample_depth(model, model_loss, sample, args): 205 | if is_distributed: 206 | model_eval = model.module 207 | else: 208 | model_eval = model 209 | model_eval.eval() 210 | 211 | sample_cuda = tocuda(sample) 212 | depth_gt_ms = sample_cuda["depth"] 213 | mask_ms = sample_cuda["mask"] 214 | 215 | num_stage = len([int(nd) for nd in args.ndepths.split(",") if nd]) 216 | depth_gt = depth_gt_ms["stage{}".format(num_stage)] 217 | mask = mask_ms["stage{}".format(num_stage)] 218 | 219 | outputs = model_eval(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"]) 220 | depth_est = outputs["depth"] 221 | 222 | loss, depth_loss, epe, less1, less3 = model_loss(outputs, depth_gt_ms, mask_ms, sample_cuda["depth_interval"], dlossw=[float(e) for e in args.dlossw.split(",") if e]) 223 | 
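    # model_loss here is focal_loss_bld (bound in the main block below); it returns the
    # total loss plus depth_loss, epe and the less1/less3 metrics. Evaluation runs under
    # @make_nograd_func and, when distributed, calls model.module so the unwrapped
    # network is used.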
224 | scalar_outputs = {"loss": loss, 225 | "depth_loss": depth_loss, 226 | "epe": epe, 227 | "less1": less1, 228 | "less3": less3 229 | } 230 | 231 | image_outputs = {"depth_est": depth_est * mask, 232 | "depth_est_nomask": depth_est, 233 | "depth_gt": sample["depth"]["stage1"], 234 | "ref_img": sample["imgs"][:, 0], 235 | "mask": sample["mask"]["stage1"], 236 | "errormap": (depth_est - depth_gt).abs() * mask} 237 | 238 | if is_distributed: 239 | scalar_outputs = reduce_scalar_outputs(scalar_outputs) 240 | 241 | return tensor2float(scalar_outputs["loss"]), tensor2float(scalar_outputs), tensor2numpy(image_outputs) 242 | 243 | def profile(): 244 | warmup_iter = 5 245 | iter_dataloader = iter(TestImgLoader) 246 | 247 | @make_nograd_func 248 | def do_iteration(): 249 | torch.cuda.synchronize() 250 | torch.cuda.synchronize() 251 | start_time = time.perf_counter() 252 | test_sample_depth(next(iter_dataloader), detailed_summary=True) 253 | torch.cuda.synchronize() 254 | end_time = time.perf_counter() 255 | return end_time - start_time 256 | 257 | for i in range(warmup_iter): 258 | t = do_iteration() 259 | print('WarpUp Iter {}, time = {:.4f}'.format(i, t)) 260 | 261 | with torch.autograd.profiler.profile(enabled=True, use_cuda=True) as prof: 262 | for i in range(5): 263 | t = do_iteration() 264 | print('Profile Iter {}, time = {:.4f}'.format(i, t)) 265 | time.sleep(0.02) 266 | 267 | if prof is not None: 268 | # print(prof) 269 | trace_fn = 'chrome-trace.bin' 270 | prof.export_chrome_trace(trace_fn) 271 | print("chrome trace file is written to: ", trace_fn) 272 | 273 | 274 | if __name__ == '__main__': 275 | # parse arguments and check 276 | args = parser.parse_args() 277 | 278 | # using sync_bn by using nvidia-apex, need to install apex. 279 | if args.sync_bn: 280 | assert args.using_apex, "must set using apex and install nvidia-apex" 281 | if args.using_apex: 282 | try: 283 | from apex.parallel import DistributedDataParallel as DDP 284 | from apex.fp16_utils import * 285 | from apex import amp, optimizers 286 | from apex.multi_tensor_apply import multi_tensor_applier 287 | except ImportError: 288 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") 289 | 290 | if args.resume: 291 | assert args.mode == "train" 292 | assert args.loadckpt is None 293 | if args.testpath is None: 294 | args.testpath = args.trainpath 295 | 296 | if is_distributed: 297 | torch.cuda.set_device(args.local_rank) 298 | torch.distributed.init_process_group( 299 | backend="nccl", init_method="env://" 300 | ) 301 | synchronize() 302 | 303 | set_random_seed(args.seed) 304 | # device = torch.device(args.device) 305 | device = torch.device(args.local_rank) 306 | 307 | if (not is_distributed) or (dist.get_rank() == 0): 308 | # create logger for mode "train" and "testall" 309 | if args.mode == "train": 310 | if not os.path.isdir(args.logdir): 311 | os.makedirs(args.logdir) 312 | current_time_str = str(datetime.datetime.now().strftime('%Y%m%d_%H%M%S')) 313 | print("current time", current_time_str) 314 | print("creating new summary file") 315 | logger = SummaryWriter(args.logdir) 316 | print("argv:", sys.argv[1:]) 317 | print_args(args) 318 | 319 | # model, optimizer 320 | model = TransMVSNet(refine=False, ndepths=[int(nd) for nd in args.ndepths.split(",") if nd], 321 | depth_interals_ratio=[float(d_i) for d_i in args.depth_inter_r.split(",") if d_i], 322 | share_cr=args.share_cr, 323 | cr_base_chs=[int(ch) for ch in args.cr_base_chs.split(",") if ch], 324 | 
grad_method=args.grad_method) 325 | model.to(device) 326 | model_loss = focal_loss_bld 327 | 328 | if args.sync_bn: 329 | import apex 330 | print("using apex synced BN") 331 | model = apex.parallel.convert_syncbn_model(model) 332 | 333 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, betas=(0.9, 0.999), weight_decay=args.wd) 334 | 335 | # load parameters 336 | start_epoch = 0 337 | if args.resume: 338 | saved_models = [fn for fn in os.listdir(args.logdir) if fn.endswith(".link")] 339 | saved_models = sorted(saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0])) 340 | # use the latest checkpoint file 341 | loadckpt = os.path.join(args.logdir, saved_models[-1]) 342 | print("resuming", loadckpt) 343 | state_dict = torch.load(loadckpt, map_location=torch.device("cpu")) 344 | # state_dict = load_checkpoint(loadckpt) 345 | model.load_state_dict(state_dict['model']) 346 | optimizer.load_state_dict(state_dict['optimizer']) 347 | start_epoch = state_dict['epoch'] + 1 348 | elif args.loadckpt: 349 | # load checkpoint file specified by args.loadckpt 350 | print("loading model {}".format(args.loadckpt)) 351 | state_dict = torch.load(args.loadckpt, map_location=torch.device("cpu")) 352 | # state_dict = load_checkpoint(args.loadckpt) 353 | model.load_state_dict(state_dict['model']) 354 | 355 | if (not is_distributed) or (dist.get_rank() == 0): 356 | print("start at epoch {}".format(start_epoch)) 357 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 358 | 359 | if args.using_apex: 360 | # Initialize Amp 361 | model, optimizer = amp.initialize(model, optimizer, 362 | opt_level=args.opt_level, 363 | keep_batchnorm_fp32=args.keep_batchnorm_fp32, 364 | loss_scale=args.loss_scale 365 | ) 366 | 367 | if is_distributed: 368 | print("Let's use", torch.cuda.device_count(), "GPUs!") 369 | model = torch.nn.parallel.DistributedDataParallel( 370 | model, device_ids=[args.local_rank], output_device=args.local_rank, 371 | ) 372 | else: 373 | if torch.cuda.is_available(): 374 | print("Let's use", torch.cuda.device_count(), "GPUs!") 375 | model = nn.DataParallel(model) 376 | 377 | # dataset, dataloader 378 | MVSDataset = find_dataset_def(args.dataset) 379 | train_dataset = MVSDataset(args.trainpath, args.trainlist, "train", args.nviews, args.numdepth, args.interval_scale) 380 | test_dataset = MVSDataset(args.testpath, args.testlist, "test", args.nviews, args.numdepth, args.interval_scale) 381 | 382 | if is_distributed: 383 | train_sampler = torch.utils.data.DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), 384 | rank=dist.get_rank()) 385 | test_sampler = torch.utils.data.DistributedSampler(test_dataset, num_replicas=dist.get_world_size(), 386 | rank=dist.get_rank()) 387 | 388 | TrainImgLoader = DataLoader(train_dataset, args.batch_size, sampler=train_sampler, num_workers=1, 389 | drop_last=True, 390 | pin_memory=args.pin_m) 391 | TestImgLoader = DataLoader(test_dataset, args.batch_size, sampler=test_sampler, num_workers=1, drop_last=False, 392 | pin_memory=args.pin_m) 393 | else: 394 | TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=1, drop_last=True, 395 | pin_memory=args.pin_m) 396 | TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=1, drop_last=False, 397 | pin_memory=args.pin_m) 398 | 399 | 400 | if args.mode == "train": 401 | train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args) 402 | elif 
args.mode == "test": 403 | test(model, model_loss, TestImgLoader, args) 404 | elif args.mode == "profile": 405 | profile() 406 | else: 407 | raise NotImplementedError 408 | -------------------------------------------------------------------------------- /gipuma.py: -------------------------------------------------------------------------------- 1 | import os, sys, shutil, gc 2 | from utils import * 3 | from datasets.data_io import read_pfm, save_pfm 4 | from struct import * 5 | import numpy as np 6 | 7 | # read intrinsics and extrinsics 8 | def read_camera_parameters(filename): 9 | with open(filename) as f: 10 | lines = f.readlines() 11 | lines = [line.rstrip() for line in lines] 12 | # extrinsics: line [1,5), 4x4 matrix 13 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 14 | # intrinsics: line [7-10), 3x3 matrix 15 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 16 | # TODO: assume the feature is 1/4 of the original image size 17 | # intrinsics[:2, :] /= 4 18 | return intrinsics, extrinsics 19 | 20 | def read_gipuma_dmb(path): 21 | '''read Gipuma .dmb format image''' 22 | 23 | with open(path, "rb") as fid: 24 | image_type = unpack(' 0, 1, 0)) 100 | mask_image = np.reshape(mask_image, (image_shape[0], image_shape[1], 1)) 101 | mask_image = np.tile(mask_image, [1, 1, 3]) 102 | mask_image = np.float32(mask_image) 103 | 104 | normal_image = np.multiply(normal_image, mask_image) 105 | normal_image = np.float32(normal_image) 106 | 107 | write_gipuma_dmb(out_normal_path, normal_image) 108 | return 109 | 110 | 111 | def mvsnet_to_gipuma(dense_folder, gipuma_point_folder): 112 | image_folder = os.path.join(dense_folder, 'images') 113 | cam_folder = os.path.join(dense_folder, 'cams') 114 | 115 | gipuma_cam_folder = os.path.join(gipuma_point_folder, 'cams') 116 | gipuma_image_folder = os.path.join(gipuma_point_folder, 'images') 117 | if not os.path.isdir(gipuma_point_folder): 118 | os.mkdir(gipuma_point_folder) 119 | if not os.path.isdir(gipuma_cam_folder): 120 | os.mkdir(gipuma_cam_folder) 121 | if not os.path.isdir(gipuma_image_folder): 122 | os.mkdir(gipuma_image_folder) 123 | 124 | # convert cameras 125 | image_names = os.listdir(image_folder) 126 | for image_name in image_names: 127 | image_prefix = os.path.splitext(image_name)[0] 128 | in_cam_file = os.path.join(cam_folder, image_prefix + '_cam.txt') 129 | out_cam_file = os.path.join(gipuma_cam_folder, image_name + '.P') 130 | mvsnet_to_gipuma_cam(in_cam_file, out_cam_file) 131 | 132 | # copy images to gipuma image folder 133 | image_names = os.listdir(image_folder) 134 | for image_name in image_names: 135 | in_image_file = os.path.join(image_folder, image_name) 136 | out_image_file = os.path.join(gipuma_image_folder, image_name) 137 | shutil.copy(in_image_file, out_image_file) 138 | 139 | # convert depth maps and fake normal maps 140 | gipuma_prefix = '2333__' 141 | for image_name in image_names: 142 | image_prefix = os.path.splitext(image_name)[0] 143 | sub_depth_folder = os.path.join(gipuma_point_folder, gipuma_prefix + image_prefix) 144 | if not os.path.isdir(sub_depth_folder): 145 | os.mkdir(sub_depth_folder) 146 | in_depth_pfm = os.path.join(dense_folder, "depth_est", image_prefix + '_prob_filtered.pfm') 147 | out_depth_dmb = os.path.join(sub_depth_folder, 'disp.dmb') 148 | fake_normal_dmb = os.path.join(sub_depth_folder, 'normals.dmb') 149 | mvsnet_to_gipuma_dmb(in_depth_pfm, out_depth_dmb) 150 | fake_gipuma_normal(out_depth_dmb, 
fake_normal_dmb) 151 | 152 | 153 | def probability_filter(dense_folder, prob_threshold): 154 | image_folder = os.path.join(dense_folder, 'images') 155 | 156 | # convert cameras 157 | image_names = os.listdir(image_folder) 158 | for image_name in image_names: 159 | image_prefix = os.path.splitext(image_name)[0] 160 | init_depth_map_path = os.path.join(dense_folder, "depth_est", image_prefix + '.pfm') 161 | prob_map_path = os.path.join(dense_folder, "confidence", image_prefix + '.pfm') 162 | out_depth_map_path = os.path.join(dense_folder, "depth_est", image_prefix + '_prob_filtered.pfm') 163 | 164 | depth_map, _ = read_pfm(init_depth_map_path) 165 | prob_map, _ = read_pfm(prob_map_path) 166 | depth_map[prob_map < prob_threshold] = 0 167 | save_pfm(out_depth_map_path, depth_map) 168 | 169 | 170 | def depth_map_fusion(point_folder, fusibile_exe_path, disp_thresh, num_consistent): 171 | cam_folder = os.path.join(point_folder, 'cams') 172 | image_folder = os.path.join(point_folder, 'images') 173 | depth_min = 0.001 174 | depth_max = 100000 175 | normal_thresh = 360 176 | 177 | cmd = fusibile_exe_path 178 | cmd = cmd + ' -input_folder ' + point_folder + '/' 179 | cmd = cmd + ' -p_folder ' + cam_folder + '/' 180 | cmd = cmd + ' -images_folder ' + image_folder + '/' 181 | cmd = cmd + ' --depth_min=' + str(depth_min) 182 | cmd = cmd + ' --depth_max=' + str(depth_max) 183 | cmd = cmd + ' --normal_thresh=' + str(normal_thresh) 184 | cmd = cmd + ' --disp_thresh=' + str(disp_thresh) 185 | cmd = cmd + ' --num_consistent=' + str(num_consistent) 186 | print(cmd) 187 | os.system(cmd) 188 | 189 | return 190 | 191 | 192 | def gipuma_filter(testlist, outdir, prob_threshold, disp_threshold, num_consistent, fusibile_exe_path): 193 | 194 | for scan in testlist: 195 | 196 | out_folder = os.path.join(outdir, scan) 197 | dense_folder = out_folder 198 | 199 | point_folder = os.path.join(dense_folder, 'points_mvsnet') 200 | if not os.path.isdir(point_folder): 201 | os.mkdir(point_folder) 202 | 203 | # probability filter 204 | print('filter depth map with probability map') 205 | probability_filter(dense_folder, prob_threshold) 206 | 207 | # convert to gipuma format 208 | print('Convert mvsnet output to gipuma input') 209 | mvsnet_to_gipuma(dense_folder, point_folder) 210 | 211 | # depth map fusion with gipuma 212 | print('Run depth map fusion & filter') 213 | depth_map_fusion(point_folder, fusibile_exe_path, disp_threshold, num_consistent) 214 | -------------------------------------------------------------------------------- /lists/bld/training_list.txt: -------------------------------------------------------------------------------- 1 | 5c1f33f1d33e1f2e4aa6dda4 2 | 5bfe5ae0fe0ea555e6a969ca 3 | 5bff3c5cfe0ea555e6bcbf3a 4 | 58eaf1513353456af3a1682a 5 | 5bfc9d5aec61ca1dd69132a2 6 | 5bf18642c50e6f7f8bdbd492 7 | 5bf26cbbd43923194854b270 8 | 5bf17c0fd439231948355385 9 | 5be3ae47f44e235bdbbc9771 10 | 5be3a5fb8cfdd56947f6b67c 11 | 5bbb6eb2ea1cfa39f1af7e0c 12 | 5ba75d79d76ffa2c86cf2f05 13 | 5bb7a08aea1cfa39f1a947ab 14 | 5b864d850d072a699b32f4ae 15 | 5b6eff8b67b396324c5b2672 16 | 5b6e716d67b396324c2d77cb 17 | 5b69cc0cb44b61786eb959bf 18 | 5b62647143840965efc0dbde 19 | 5b60fa0c764f146feef84df0 20 | 5b558a928bbfb62204e77ba2 21 | 5b271079e0878c3816dacca4 22 | 5b08286b2775267d5b0634ba 23 | 5afacb69ab00705d0cefdd5b 24 | 5af28cea59bc705737003253 25 | 5af02e904c8216544b4ab5a2 26 | 5aa515e613d42d091d29d300 27 | 5c34529873a8df509ae57b58 28 | 5c34300a73a8df509add216d 29 | 5c1af2e2bee9a723c963d019 30 | 5c1892f726173c3a09ea9aeb 31 | 
5c0d13b795da9479e12e2ee9 32 | 5c062d84a96e33018ff6f0a6 33 | 5bfd0f32ec61ca1dd69dc77b 34 | 5bf21799d43923194842c001 35 | 5bf3a82cd439231948877aed 36 | 5bf03590d4392319481971dc 37 | 5beb6e66abd34c35e18e66b9 38 | 5be883a4f98cee15019d5b83 39 | 5be47bf9b18881428d8fbc1d 40 | 5bcf979a6d5f586b95c258cd 41 | 5bce7ac9ca24970bce4934b6 42 | 5bb8a49aea1cfa39f1aa7f75 43 | 5b78e57afc8fcf6781d0c3ba 44 | 5b21e18c58e2823a67a10dd8 45 | 5b22269758e2823a67a3bd03 46 | 5b192eb2170cf166458ff886 47 | 5ae2e9c5fe405c5076abc6b2 48 | 5adc6bd52430a05ecb2ffb85 49 | 5ab8b8e029f5351f7f2ccf59 50 | 5abc2506b53b042ead637d86 51 | 5ab85f1dac4291329b17cb50 52 | 5a969eea91dfc339a9a3ad2c 53 | 5a8aa0fab18050187cbe060e 54 | 5a7d3db14989e929563eb153 55 | 5a69c47d0d5d0a7f3b2e9752 56 | 5a618c72784780334bc1972d 57 | 5a6464143d809f1d8208c43c 58 | 5a588a8193ac3d233f77fbca 59 | 5a57542f333d180827dfc132 60 | 5a572fd9fc597b0478a81d14 61 | 5a563183425d0f5186314855 62 | 5a4a38dad38c8a075495b5d2 63 | 5a48d4b2c7dab83a7d7b9851 64 | 5a489fb1c7dab83a7d7b1070 65 | 5a48ba95c7dab83a7d7b44ed 66 | 5a3ca9cb270f0e3f14d0eddb 67 | 5a3cb4e4270f0e3f14d12f43 68 | 5a3f4aba5889373fbbc5d3b5 69 | 5a0271884e62597cdee0d0eb 70 | 59e864b2a9e91f2c5529325f 71 | 599aa591d5b41f366fed0d58 72 | 59350ca084b7f26bf5ce6eb8 73 | 59338e76772c3e6384afbb15 74 | 5c20ca3a0843bc542d94e3e2 75 | 5c1dbf200843bc542d8ef8c4 76 | 5c1b1500bee9a723c96c3e78 77 | 5bea87f4abd34c35e1860ab5 78 | 5c2b3ed5e611832e8aed46bf 79 | 57f8d9bbe73f6760f10e916a 80 | 5bf7d63575c26f32dbf7413b 81 | 5be4ab93870d330ff2dce134 82 | 5bd43b4ba6b28b1ee86b92dd 83 | 5bccd6beca24970bce448134 84 | 5bc5f0e896b66a2cd8f9bd36 85 | 5b908d3dc6ab78485f3d24a9 86 | 5b2c67b5e0878c381608b8d8 87 | 5b4933abf2b5f44e95de482a 88 | 5b3b353d8d46a939f93524b9 89 | 5acf8ca0f3d8a750097e4b15 90 | 5ab8713ba3799a1d138bd69a 91 | 5aa235f64a17b335eeaf9609 92 | 5aa0f9d7a9efce63548c69a1 93 | 5a8315f624b8e938486e0bd8 94 | 5a48c4e9c7dab83a7d7b5cc7 95 | 59ecfd02e225f6492d20fcc9 96 | 59f87d0bfa6280566fb38c9a 97 | 59f363a8b45be22330016cad 98 | 59f70ab1e5c5d366af29bf3e 99 | 59e75a2ca9e91f2c5526005d 100 | 5947719bf1b45630bd096665 101 | 5947b62af1b45630bd0c2a02 102 | 59056e6760bb961de55f3501 103 | 58f7f7299f5b5647873cb110 104 | 58cf4771d0f5fb221defe6da 105 | 58d36897f387231e6c929903 106 | 58c4bb4f4a69c55606122be4 107 | -------------------------------------------------------------------------------- /lists/bld/validation_list.txt: -------------------------------------------------------------------------------- 1 | 5b7a3890fc8fcf6781e2593a 2 | 5c189f2326173c3a09ed7ef3 3 | 5b950c71608de421b1e7318f 4 | 5a6400933d809f1d8200af15 5 | 59d2657f82ca7774b1ec081d 6 | 5ba19a8a360c7c30c1c169df 7 | 59817e4a1bd4b175e7038d19 8 | -------------------------------------------------------------------------------- /lists/dtu/test.txt: -------------------------------------------------------------------------------- 1 | scan1 2 | scan4 3 | scan9 4 | scan10 5 | scan11 6 | scan12 7 | scan13 8 | scan15 9 | scan23 10 | scan24 11 | scan29 12 | scan32 13 | scan33 14 | scan34 15 | scan48 16 | scan49 17 | scan62 18 | scan75 19 | scan77 20 | scan110 21 | scan114 22 | scan118 -------------------------------------------------------------------------------- /lists/dtu/train.txt: -------------------------------------------------------------------------------- 1 | scan2 2 | scan6 3 | scan7 4 | scan8 5 | scan14 6 | scan16 7 | scan18 8 | scan19 9 | scan20 10 | scan22 11 | scan30 12 | scan31 13 | scan36 14 | scan39 15 | scan41 16 | scan42 17 | scan44 18 | scan45 19 | scan46 20 | scan47 21 | scan50 22 | 
scan51 23 | scan52 24 | scan53 25 | scan55 26 | scan57 27 | scan58 28 | scan60 29 | scan61 30 | scan63 31 | scan64 32 | scan65 33 | scan68 34 | scan69 35 | scan70 36 | scan71 37 | scan72 38 | scan74 39 | scan76 40 | scan83 41 | scan84 42 | scan85 43 | scan87 44 | scan88 45 | scan89 46 | scan90 47 | scan91 48 | scan92 49 | scan93 50 | scan94 51 | scan95 52 | scan96 53 | scan97 54 | scan98 55 | scan99 56 | scan100 57 | scan101 58 | scan102 59 | scan103 60 | scan104 61 | scan105 62 | scan107 63 | scan108 64 | scan109 65 | scan111 66 | scan112 67 | scan113 68 | scan115 69 | scan116 70 | scan119 71 | scan120 72 | scan121 73 | scan122 74 | scan123 75 | scan124 76 | scan125 77 | scan126 78 | scan127 79 | scan128 -------------------------------------------------------------------------------- /lists/dtu/trainval.txt: -------------------------------------------------------------------------------- 1 | scan2 2 | scan6 3 | scan7 4 | scan8 5 | scan14 6 | scan16 7 | scan18 8 | scan19 9 | scan20 10 | scan22 11 | scan30 12 | scan31 13 | scan36 14 | scan39 15 | scan41 16 | scan42 17 | scan44 18 | scan45 19 | scan46 20 | scan47 21 | scan50 22 | scan51 23 | scan52 24 | scan53 25 | scan55 26 | scan57 27 | scan58 28 | scan60 29 | scan61 30 | scan63 31 | scan64 32 | scan65 33 | scan68 34 | scan69 35 | scan70 36 | scan71 37 | scan72 38 | scan74 39 | scan76 40 | scan83 41 | scan84 42 | scan85 43 | scan87 44 | scan88 45 | scan89 46 | scan90 47 | scan91 48 | scan92 49 | scan93 50 | scan94 51 | scan95 52 | scan96 53 | scan97 54 | scan98 55 | scan99 56 | scan100 57 | scan101 58 | scan102 59 | scan103 60 | scan104 61 | scan105 62 | scan107 63 | scan108 64 | scan109 65 | scan111 66 | scan112 67 | scan113 68 | scan115 69 | scan116 70 | scan119 71 | scan120 72 | scan121 73 | scan122 74 | scan123 75 | scan124 76 | scan125 77 | scan126 78 | scan127 79 | scan128 80 | scan3 81 | scan5 82 | scan17 83 | scan21 84 | scan28 85 | scan35 86 | scan37 87 | scan38 88 | scan40 89 | scan43 90 | scan56 91 | scan59 92 | scan66 93 | scan67 94 | scan82 95 | scan86 96 | scan106 97 | scan117 -------------------------------------------------------------------------------- /lists/dtu/val.txt: -------------------------------------------------------------------------------- 1 | scan3 2 | scan5 3 | scan17 4 | scan21 5 | scan28 6 | scan35 7 | scan37 8 | scan38 9 | scan40 10 | scan43 11 | scan56 12 | scan59 13 | scan66 14 | scan67 15 | scan82 16 | scan86 17 | scan106 18 | scan117 -------------------------------------------------------------------------------- /lists/tnt/adv.txt: -------------------------------------------------------------------------------- 1 | Auditorium 2 | Ballroom 3 | Courtroom 4 | Museum 5 | Palace 6 | Temple -------------------------------------------------------------------------------- /lists/tnt/inter.txt: -------------------------------------------------------------------------------- 1 | Family 2 | Francis 3 | Horse 4 | Lighthouse 5 | M60 6 | Panther 7 | Playground 8 | Train -------------------------------------------------------------------------------- /models/FMT.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import einops 7 | 8 | ''' 9 | - We provide two different positional encoding methods as shown below. 10 | - You can easily switch different pos-enc in the __init__() function of FMT. 
11 | - In our experiments, PositionEncodingSuperGule usually cost more GPU memory. 12 | ''' 13 | from .position_encoding import PositionEncodingSuperGule, PositionEncodingSine 14 | 15 | 16 | class LinearAttention(nn.Module): 17 | def __init__(self, eps=1e-6): 18 | super(LinearAttention, self).__init__() 19 | self.feature_map = lambda x: torch.nn.functional.elu(x) + 1 20 | self.eps = eps 21 | 22 | def forward(self, queries, keys, values): 23 | Q = self.feature_map(queries) 24 | K = self.feature_map(keys) 25 | 26 | # Compute the KV matrix, namely the dot product of keys and values so 27 | # that we never explicitly compute the attention matrix and thus 28 | # decrease the complexity 29 | KV = torch.einsum("nshd,nshm->nhmd", K, values) 30 | 31 | # Compute the normalizer 32 | Z = 1/(torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1))+self.eps) 33 | 34 | # Finally compute and return the new values 35 | V = torch.einsum("nlhd,nhmd,nlh->nlhm", Q, KV, Z) 36 | 37 | return V.contiguous() 38 | 39 | 40 | class AttentionLayer(nn.Module): 41 | def __init__(self, attention, d_model, n_heads, d_keys=None, 42 | d_values=None): 43 | super(AttentionLayer, self).__init__() 44 | 45 | # Fill d_keys and d_values 46 | d_keys = d_keys or (d_model//n_heads) 47 | d_values = d_values or (d_model//n_heads) 48 | 49 | self.inner_attention = attention 50 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 51 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 52 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 53 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 54 | self.n_heads = n_heads 55 | 56 | def forward(self, queries, keys, values): 57 | # Extract the dimensions into local variables 58 | N, L, _ = queries.shape 59 | _, S, _ = keys.shape 60 | H = self.n_heads 61 | 62 | # Project the queries/keys/values 63 | queries = self.query_projection(queries).view(N, L, H, -1) 64 | keys = self.key_projection(keys).view(N, S, H, -1) 65 | values = self.value_projection(values).view(N, S, H, -1) 66 | 67 | # Compute the attention 68 | new_values = self.inner_attention( 69 | queries, 70 | keys, 71 | values, 72 | ).view(N, L, -1) 73 | 74 | # Project the output and return 75 | return self.out_projection(new_values) 76 | 77 | 78 | class EncoderLayer(nn.Module): 79 | def __init__(self, d_model, n_heads, d_keys=None, d_values=None, d_ff=None, dropout=0.0, 80 | activation="relu"): 81 | super(EncoderLayer, self).__init__() 82 | 83 | d_keys = d_keys or (d_model//n_heads) 84 | inner_attention = LinearAttention() 85 | attention = AttentionLayer(inner_attention, d_model, n_heads, d_keys, d_values) 86 | 87 | d_ff = d_ff or 2 * d_model 88 | self.attention = attention 89 | self.linear1 = nn.Linear(d_model, d_ff) 90 | self.linear2 = nn.Linear(d_ff, d_model) 91 | self.norm1 = nn.LayerNorm(d_model) 92 | self.norm2 = nn.LayerNorm(d_model) 93 | self.dropout = nn.Dropout(dropout) 94 | self.activation = getattr(F, activation) 95 | 96 | def forward(self, x, source): 97 | # Normalize the masks 98 | N = x.shape[0] 99 | L = x.shape[1] 100 | 101 | # Run self attention and add it to the input 102 | x = x + self.dropout(self.attention( 103 | x, source, source, 104 | )) 105 | 106 | # Run the fully connected part of the layer 107 | y = x = self.norm1(x) 108 | y = self.dropout(self.activation(self.linear1(y))) 109 | y = self.dropout(self.linear2(y)) 110 | 111 | return self.norm2(x+y) 112 | 113 | 114 | class FMT(nn.Module): 115 | def __init__(self, config): 116 | super(FMT, self).__init__() 117 | 118 | self.d_model = 
config['d_model'] 119 | self.nhead = config['nhead'] 120 | self.layer_names = config['layer_names'] 121 | encoder_layer = EncoderLayer(config['d_model'], config['nhead']) 122 | self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) 123 | self._reset_parameters() 124 | 125 | # self.pos_encoding = PositionEncodingSuperGule(config['d_model']) 126 | self.pos_encoding = PositionEncodingSine(config['d_model']) 127 | 128 | def _reset_parameters(self): 129 | for p in self.parameters(): 130 | if p.dim() > 1: 131 | nn.init.xavier_uniform_(p) 132 | 133 | def forward(self, ref_feature=None, src_feature=None, feat="ref"): 134 | """ 135 | Args: 136 | ref_feature(torch.Tensor): [N, C, H, W] 137 | src_feature(torch.Tensor): [N, C, H, W] 138 | """ 139 | 140 | assert ref_feature is not None 141 | 142 | if feat == "ref": # only self attention layer 143 | 144 | assert self.d_model == ref_feature.size(1) 145 | _, _, H, _ = ref_feature.shape 146 | 147 | ref_feature = einops.rearrange(self.pos_encoding(ref_feature), 'n c h w -> n (h w) c') 148 | 149 | ref_feature_list = [] 150 | for layer, name in zip(self.layers, self.layer_names): # every self attention layer 151 | if name == 'self': 152 | ref_feature = layer(ref_feature, ref_feature) 153 | ref_feature_list.append(einops.rearrange(ref_feature, 'n (h w) c -> n c h w', h=H)) 154 | return ref_feature_list 155 | 156 | elif feat == "src": 157 | 158 | assert self.d_model == ref_feature[0].size(1) 159 | _, _, H, _ = ref_feature[0].shape 160 | 161 | ref_feature = [einops.rearrange(_, 'n c h w -> n (h w) c') for _ in ref_feature] 162 | 163 | src_feature = einops.rearrange(self.pos_encoding(src_feature), 'n c h w -> n (h w) c') 164 | 165 | for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)): 166 | if name == 'self': 167 | src_feature = layer(src_feature, src_feature) 168 | elif name == 'cross': 169 | src_feature = layer(src_feature, ref_feature[i // 2]) 170 | else: 171 | raise KeyError 172 | return einops.rearrange(src_feature, 'n (h w) c -> n c h w', h=H) 173 | else: 174 | raise ValueError("Wrong feature name") 175 | 176 | 177 | 178 | class FMT_with_pathway(nn.Module): 179 | def __init__(self, 180 | base_channels=8, 181 | FMT_config={ 182 | 'd_model': 32, 183 | 'nhead': 8, 184 | 'layer_names': ['self', 'cross'] * 4}): 185 | 186 | super(FMT_with_pathway, self).__init__() 187 | 188 | self.FMT = FMT(FMT_config) 189 | 190 | self.dim_reduction_1 = nn.Conv2d(base_channels * 4, base_channels * 2, 1, bias=False) 191 | self.dim_reduction_2 = nn.Conv2d(base_channels * 2, base_channels * 1, 1, bias=False) 192 | 193 | self.smooth_1 = nn.Conv2d(base_channels * 2, base_channels * 2, 3, padding=1, bias=False) 194 | self.smooth_2 = nn.Conv2d(base_channels * 1, base_channels * 1, 3, padding=1, bias=False) 195 | 196 | def _upsample_add(self, x, y): 197 | """_upsample_add. Upsample and add two feature maps. 198 | 199 | :param x: top feature map to be upsampled. 200 | :param y: lateral feature map. 201 | """ 202 | 203 | _, _, H, W = y.size() 204 | return F.interpolate(x, size=(H, W), mode='bilinear') + y 205 | 206 | 207 | def forward(self, features): 208 | """forward. 
209 | 210 | :param features: multiple views and multiple stages features 211 | """ 212 | 213 | for nview_idx, feature_multi_stages in enumerate(features): 214 | if nview_idx == 0: # ref view 215 | ref_fea_t_list = self.FMT(feature_multi_stages["stage1"].clone(), feat="ref") 216 | feature_multi_stages["stage1"] = ref_fea_t_list[-1] 217 | feature_multi_stages["stage2"] = self.smooth_1(self._upsample_add(self.dim_reduction_1(feature_multi_stages["stage1"]), feature_multi_stages["stage2"])) 218 | feature_multi_stages["stage3"] = self.smooth_2(self._upsample_add(self.dim_reduction_2(feature_multi_stages["stage2"]), feature_multi_stages["stage3"])) 219 | 220 | else: # src view 221 | feature_multi_stages["stage1"] = self.FMT([_.clone() for _ in ref_fea_t_list], feature_multi_stages["stage1"].clone(), feat="src") 222 | feature_multi_stages["stage2"] = self.smooth_1(self._upsample_add(self.dim_reduction_1(feature_multi_stages["stage1"]), feature_multi_stages["stage2"])) 223 | feature_multi_stages["stage3"] = self.smooth_2(self._upsample_add(self.dim_reduction_2(feature_multi_stages["stage2"]), feature_multi_stages["stage3"])) 224 | 225 | return features 226 | -------------------------------------------------------------------------------- /models/TransMVSNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .module import * 5 | from .FMT import FMT_with_pathway 6 | 7 | Align_Corners_Range = False 8 | 9 | class PixelwiseNet(nn.Module): 10 | 11 | def __init__(self): 12 | 13 | super(PixelwiseNet, self).__init__() 14 | self.conv0 = ConvBnReLU3D(in_channels=1, out_channels=16, kernel_size=1, stride=1, pad=0) 15 | self.conv1 = ConvBnReLU3D(in_channels=16, out_channels=8, kernel_size=1, stride=1, pad=0) 16 | self.conv2 = nn.Conv3d(in_channels=8, out_channels=1, kernel_size=1, stride=1, padding=0) 17 | self.output = nn.Sigmoid() 18 | 19 | def forward(self, x1): 20 | """forward. 21 | 22 | :param x1: [B, 1, D, H, W] 23 | """ 24 | 25 | x1 = self.conv2(self.conv1(self.conv0(x1))).squeeze(1) # [B, D, H, W] 26 | output = self.output(x1) 27 | output = torch.max(output, dim=1, keepdim=True)[0] # [B, 1, H ,W] 28 | 29 | return output 30 | 31 | 32 | class DepthNet(nn.Module): 33 | def __init__(self): 34 | super(DepthNet, self).__init__() 35 | self.pixel_wise_net = PixelwiseNet() 36 | 37 | def forward(self, features, proj_matrices, depth_values, num_depth, cost_regularization, prob_volume_init=None, view_weights=None): 38 | """forward. 39 | 40 | :param stage_idx: int, index of stage, [1, 2, 3], stage_1 corresponds to lowest image resolution 41 | :param features: torch.Tensor, TODO: [B, C, H, W] 42 | :param proj_matrices: torch.Tensor, 43 | :param depth_values: torch.Tensor, TODO: [B, D, H, W] 44 | :param num_depth: int, Ndepth 45 | :param cost_regularization: nn.Module, regularization network 46 | :param prob_volume_init: 47 | :param view_weights: pixel wise view weights for src views 48 | """ 49 | proj_matrices = torch.unbind(proj_matrices, 1) 50 | assert len(features) == len(proj_matrices), "Different number of images and projection matrices" 51 | assert depth_values.shape[1] == num_depth, "depth_values.shape[1]:{} num_depth:{}".format(depth_values.shapep[1], num_depth) 52 | 53 | # step 1. feature extraction 54 | ref_feature, src_features = features[0], features[1:] # [B, C, H, W] 55 | ref_proj, src_projs = proj_matrices[0], proj_matrices[1:] # [B, 2, 4, 4] 56 | 57 | # step 2. 
differentiable homograph, build cost volume 58 | if view_weights == None: 59 | view_weight_list = [] 60 | 61 | similarity_sum = 0 62 | pixel_wise_weight_sum = 1e-5 63 | 64 | for i, (src_fea, src_proj) in enumerate(zip(src_features, src_projs)): # src_fea: [B, C, H, W] 65 | src_proj_new = src_proj[:, 0].clone() # [B, 4, 4] 66 | src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], src_proj[:, 0, :3, :4]) 67 | ref_proj_new = ref_proj[:, 0].clone() # [B, 4, 4] 68 | ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4]) 69 | warped_volume = homo_warping(src_fea, src_proj_new, ref_proj_new, depth_values) 70 | similarity = (warped_volume * ref_feature.unsqueeze(2)).mean(1, keepdim=True) 71 | 72 | if view_weights == None: 73 | view_weight = self.pixel_wise_net(similarity) # [B, 1, H, W] 74 | view_weight_list.append(view_weight) 75 | else: 76 | view_weight = view_weights[:, i:i+1] 77 | 78 | if self.training: 79 | similarity_sum = similarity_sum + similarity * view_weight.unsqueeze(1) # [B, 1, D, H, W] 80 | pixel_wise_weight_sum = pixel_wise_weight_sum + view_weight.unsqueeze(1) # [B, 1, 1, H, W] 81 | else: 82 | # TODO: this is only a temporal solution to save memory, better way? 83 | similarity_sum += similarity * view_weight.unsqueeze(1) 84 | pixel_wise_weight_sum += view_weight.unsqueeze(1) 85 | 86 | del warped_volume 87 | # aggregate multiple similarity across all the source views 88 | similarity = similarity_sum.div_(pixel_wise_weight_sum) # [B, 1, D, H, W] 89 | 90 | # step 3. cost volume regularization 91 | cost_reg = cost_regularization(similarity) 92 | prob_volume_pre = cost_reg.squeeze(1) 93 | 94 | if prob_volume_init is not None: 95 | prob_volume_pre += prob_volume_init 96 | 97 | prob_volume = torch.exp(F.log_softmax(prob_volume_pre, dim=1)) 98 | depth = depth_wta(prob_volume, depth_values=depth_values) 99 | 100 | with torch.no_grad(): 101 | photometric_confidence = torch.max(prob_volume, dim=1)[0] 102 | if view_weights == None: 103 | view_weights = torch.cat(view_weight_list, dim=1) # [B, Nview, H, W] 104 | return {"depth": depth, "photometric_confidence": photometric_confidence, "prob_volume": prob_volume, "depth_values": depth_values}, view_weights.detach() 105 | else: 106 | return {"depth": depth, "photometric_confidence": photometric_confidence, "prob_volume": prob_volume, "depth_values": depth_values} 107 | 108 | 109 | class TransMVSNet(nn.Module): 110 | def __init__(self, refine=False, ndepths=[48, 32, 8], depth_interals_ratio=[4, 2, 1], share_cr=False, 111 | grad_method="detach", arch_mode="fpn", cr_base_chs=[8, 8, 8]): 112 | super(TransMVSNet, self).__init__() 113 | self.refine = refine 114 | self.share_cr = share_cr 115 | self.ndepths = ndepths 116 | self.depth_interals_ratio = depth_interals_ratio 117 | self.grad_method = grad_method 118 | self.arch_mode = arch_mode 119 | self.cr_base_chs = cr_base_chs 120 | self.num_stage = len(ndepths) 121 | print("**********netphs:{}, depth_intervals_ratio:{}, grad:{}, chs:{}************".format(ndepths, 122 | depth_interals_ratio, self.grad_method, self.cr_base_chs)) 123 | 124 | assert len(ndepths) == len(depth_interals_ratio) 125 | 126 | self.stage_infos = { 127 | "stage1":{ 128 | "scale": 4.0, 129 | }, 130 | "stage2": { 131 | "scale": 2.0, 132 | }, 133 | "stage3": { 134 | "scale": 1.0, 135 | } 136 | } 137 | 138 | self.feature = FeatureNet(base_channels=8) 139 | 140 | self.FMT_with_pathway = FMT_with_pathway() 141 | 142 | if self.share_cr: 143 | self.cost_regularization = CostRegNet(in_channels=1, 
base_channels=8) 144 | else: 145 | self.cost_regularization = nn.ModuleList([CostRegNet(in_channels=1, base_channels=self.cr_base_chs[i]) 146 | for i in range(self.num_stage)]) 147 | if self.refine: 148 | self.refine_network = RefineNet() 149 | 150 | self.DepthNet = DepthNet() 151 | 152 | def forward(self, imgs, proj_matrices, depth_values, test_tnt=False): 153 | depth_min = float(depth_values[0, 0].cpu().numpy()) 154 | depth_max = float(depth_values[0, -1].cpu().numpy()) 155 | depth_interval = (depth_max - depth_min) / depth_values.size(1) 156 | 157 | # step 1. feature extraction 158 | features = [] 159 | for nview_idx in range(imgs.size(1)): 160 | img = imgs[:, nview_idx] 161 | features.append(self.feature(img)) 162 | 163 | features = self.FMT_with_pathway(features) 164 | 165 | outputs = {} 166 | depth, cur_depth = None, None 167 | view_weights = None 168 | for stage_idx in range(self.num_stage): 169 | features_stage = [feat["stage{}".format(stage_idx + 1)] for feat in features] 170 | proj_matrices_stage = proj_matrices["stage{}".format(stage_idx + 1)] 171 | stage_scale = self.stage_infos["stage{}".format(stage_idx + 1)]["scale"] 172 | 173 | Using_inverse_d = False 174 | 175 | if depth is not None: 176 | if self.grad_method == "detach": 177 | cur_depth = depth.detach() 178 | else: 179 | cur_depth = depth 180 | cur_depth = F.interpolate(cur_depth.unsqueeze(1), 181 | [img.shape[2], img.shape[3]], mode='bilinear', 182 | align_corners=Align_Corners_Range).squeeze(1) 183 | else: 184 | cur_depth = depth_values 185 | 186 | # [B, D, H, W] 187 | depth_range_samples = get_depth_range_samples(cur_depth=cur_depth, 188 | ndepth=self.ndepths[stage_idx], 189 | depth_inteval_pixel=self.depth_interals_ratio[stage_idx] * depth_interval, 190 | dtype=img[0].dtype, 191 | device=img[0].device, 192 | shape=[img.shape[0], img.shape[2], img.shape[3]], 193 | max_depth=depth_max, 194 | min_depth=depth_min, 195 | use_inverse_depth=Using_inverse_d) 196 | 197 | if stage_idx + 1 > 1: # for stage 2 and 3 198 | view_weights = F.interpolate(view_weights, scale_factor=2, mode="nearest") 199 | 200 | if view_weights == None: # stage 1 201 | outputs_stage, view_weights = self.DepthNet( 202 | features_stage, 203 | proj_matrices_stage, 204 | depth_values=F.interpolate(depth_range_samples.unsqueeze(1), [self.ndepths[stage_idx], img.shape[2]//int(stage_scale), img.shape[3]//int(stage_scale)], mode='trilinear', align_corners=Align_Corners_Range).squeeze(1), 205 | num_depth=self.ndepths[stage_idx], 206 | cost_regularization=self.cost_regularization if self.share_cr else self.cost_regularization[stage_idx], view_weights=view_weights) 207 | else: 208 | outputs_stage = self.DepthNet( 209 | features_stage, 210 | proj_matrices_stage, 211 | depth_values=F.interpolate(depth_range_samples.unsqueeze(1), [self.ndepths[stage_idx], img.shape[2]//int(stage_scale), img.shape[3]//int(stage_scale)], mode='trilinear', align_corners=Align_Corners_Range).squeeze(1), 212 | num_depth=self.ndepths[stage_idx], 213 | cost_regularization=self.cost_regularization if self.share_cr else self.cost_regularization[stage_idx], view_weights=view_weights) 214 | 215 | wta_index_map = torch.argmax(outputs_stage['prob_volume'], dim=1, keepdim=True).type(torch.long) 216 | depth = torch.gather(outputs_stage['depth_values'], 1, wta_index_map).squeeze(1) 217 | outputs_stage['depth'] = depth 218 | 219 | outputs["stage{}".format(stage_idx + 1)] = outputs_stage 220 | outputs.update(outputs_stage) 221 | 222 | if self.refine: 223 | refined_depth = 
self.refine_network(torch.cat((imgs[:, 0], depth), 1)) 224 | outputs["refined_depth"] = refined_depth 225 | 226 | return outputs 227 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.TransMVSNet import TransMVSNet, trans_mvsnet_loss 2 | -------------------------------------------------------------------------------- /models/dcn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.autograd import Function 10 | from torch.nn.modules.utils import _pair 11 | from torch.autograd.function import once_differentiable 12 | from torchvision.ops import DeformConv2d, deform_conv2d 13 | 14 | 15 | class DCNv2(nn.Module): 16 | 17 | def __init__(self, in_channels, out_channels, 18 | kernel_size, stride, padding, dilation=1, deformable_groups=1): 19 | super(DCNv2, self).__init__() 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.kernel_size = _pair(kernel_size) 23 | self.stride = _pair(stride) 24 | self.padding = _pair(padding) 25 | self.dilation = _pair(dilation) 26 | self.deformable_groups = deformable_groups 27 | 28 | self.weight = nn.Parameter(torch.Tensor( 29 | out_channels, in_channels, *self.kernel_size)) 30 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self): 34 | n = self.in_channels 35 | for k in self.kernel_size: 36 | n *= k 37 | stdv = 1. / math.sqrt(n) 38 | self.weight.data.uniform_(-stdv, stdv) 39 | self.bias.data.zero_() 40 | 41 | 42 | 43 | class DCN(DCNv2): 44 | 45 | def __init__(self, in_channels, out_channels, 46 | kernel_size, stride, padding, 47 | dilation=1, deformable_groups=1,bias=True): 48 | super(DCN, self).__init__(in_channels, out_channels, 49 | kernel_size, stride, padding, dilation, deformable_groups) 50 | 51 | channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1] 52 | self.conv_offset_mask = nn.Conv2d(self.in_channels, 53 | channels_, 54 | kernel_size=self.kernel_size, 55 | stride=self.stride, 56 | padding=self.padding, 57 | bias=True) 58 | if bias==False: 59 | self.bias = None 60 | self.init_offset() 61 | 62 | def init_offset(self): 63 | self.conv_offset_mask.weight.data.zero_() 64 | self.conv_offset_mask.bias.data.zero_() 65 | 66 | def forward(self, input): 67 | out = self.conv_offset_mask(input) 68 | o1, o2, mask = torch.chunk(out, 3, dim=1) 69 | offset = torch.cat((o1, o2), dim=1) 70 | mask = torch.sigmoid(mask) 71 | return deform_conv2d( 72 | input, 73 | offset, 74 | self.weight, 75 | self.bias, 76 | self.stride, 77 | self.padding, 78 | self.dilation, 79 | mask=mask 80 | ) 81 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | 6 | class PositionEncodingSuperGule(nn.Module): 7 | def __init__(self,d_model): 8 | super().__init__() 9 | self.d_model=d_model 10 | self.kenc = KeypointEncoder(d_model, [32, 64]) 11 | 12 | def forward(self,x): 13 | # x : N,C,H,W 14 | y_position = torch.ones((x.shape[2], 
x.shape[3])).cumsum(0).float().unsqueeze(0).to(x) 15 | x_position = torch.ones((x.shape[2], x.shape[3])).cumsum(1).float().unsqueeze(0).to(x) 16 | xy_position = torch.cat([x_position, y_position]) - 1 17 | xy_position = xy_position.view(2, -1).permute(1, 0).repeat(x.shape[0], 1, 1) 18 | xy_position_n = normalize_keypoints(xy_position, x.shape) 19 | ret = x + self.kenc(xy_position_n).view(x.shape) 20 | return ret 21 | 22 | 23 | class PositionEncodingSine(nn.Module): 24 | """ 25 | This is a sinusoidal position encoding that generalized to 2-dimensional images 26 | """ 27 | 28 | def __init__(self, d_model, max_shape=(600, 600), temp_bug_fix=True): 29 | """ 30 | Args: 31 | max_shape (tuple): for 1/4 featmap, the max length of 600 corresponds to 2400 pixels 32 | temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41), 33 | the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact 34 | on the final performance. For now, we keep both impls for backward compatability. 35 | We will remove the buggy impl after re-training all variants of our released models. 36 | """ 37 | super().__init__() 38 | 39 | pe = torch.zeros((d_model, *max_shape)) 40 | y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) 41 | x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) 42 | if temp_bug_fix: 43 | div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) 44 | else: # a buggy implementation (for backward compatability only) 45 | div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2)) 46 | div_term = div_term[:, None, None] # [C//4, 1, 1] 47 | pe[0::4, :, :] = torch.sin(x_position * div_term) 48 | pe[1::4, :, :] = torch.cos(x_position * div_term) 49 | pe[2::4, :, :] = torch.sin(y_position * div_term) 50 | pe[3::4, :, :] = torch.cos(y_position * div_term) 51 | 52 | self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] 53 | # self.register_buffer('pe11', pe.unsqueeze(0)) # [1, C, H, W] 54 | 55 | def forward(self, x): 56 | """ 57 | Args: 58 | x: [N, C, H, W] 59 | """ 60 | return x + self.pe[:, :, :x.size(2), :x.size(3)] 61 | 62 | 63 | def MLP(channels: list, do_bn=True): 64 | """ Multi-layer perceptron """ 65 | n = len(channels) 66 | layers = [] 67 | for i in range(1, n): 68 | layers.append( 69 | nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) 70 | if i < (n-1): 71 | if do_bn: 72 | layers.append(nn.BatchNorm1d(channels[i])) 73 | layers.append(nn.ReLU()) 74 | return nn.Sequential(*layers) 75 | 76 | 77 | def normalize_keypoints(kpts, image_shape): 78 | """ Normalize keypoints locations based on image image_shape""" 79 | _, _, height, width = image_shape 80 | one = kpts.new_tensor(1) 81 | size = torch.stack([one*width, one*height])[None] 82 | center = size / 2 83 | scaling = size.max(1, keepdim=True).values * 0.7 84 | return (kpts - center[:, None, :]) / scaling[:, None, :] 85 | 86 | 87 | class KeypointEncoder(nn.Module): 88 | """ Joint encoding of visual appearance and location using MLPs""" 89 | def __init__(self, feature_dim, layers): 90 | super().__init__() 91 | 92 | self.encoder = MLP([2] + layers + [feature_dim]) 93 | nn.init.constant_(self.encoder[-1].bias, 0.0) 94 | 95 | def forward(self, kpts): 96 | inputs = kpts.transpose(1, 2) 97 | return self.encoder(inputs) 98 | 99 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | einops==0.4.1 2 | matplotlib==3.3.4 3 | multiprocess==0.70.12.2 4 | open3d==0.13.0 5 | opencv-python==4.5.3.56 6 | Pillow==8.3.2 7 | plyfile==0.7.4 8 | tensorboard==2.6.0 9 | tensorboard-data-server==0.6.1 10 | tensorboard-plugin-wit==1.8.0 11 | tensorboardX==2.4 12 | torch==1.9.1 13 | torchvision==0.10.1 14 | tornado==6.1 15 | tqdm==4.62.3 16 | trimesh==3.9.41 -------------------------------------------------------------------------------- /scripts/test_dtu.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | TESTPATH="/data/DTU/dtu-test" # path to dataset dtu_test 3 | TESTLIST="lists/dtu/test.txt" 4 | CKPT_FILE="checkpoints/model_dtu.ckpt" # path to checkpoint file, you need to use the model_dtu.ckpt for testing 5 | FUSIBLE_PATH="" # path to fusible of gipuma 6 | OUTDIR="outputs/dtu_testing" # path to output 7 | if [ ! -d $OUTDIR ]; then 8 | mkdir -p $OUTDIR 9 | fi 10 | 11 | 12 | python test.py \ 13 | --dataset=general_eval \ 14 | --batch_size=1 \ 15 | --testpath=$TESTPATH \ 16 | --testlist=$TESTLIST \ 17 | --loadckpt=$CKPT_FILE \ 18 | --outdir=$OUTDIR \ 19 | --numdepth=192 \ 20 | --ndepths="48,32,8" \ 21 | --depth_inter_r="4.0,1.0,0.5" \ 22 | --interval_scale=1.06 \ 23 | --filter_method="gipuma" \ 24 | --fusibile_exe_path=$FUSIBLE_PATH 25 | #--filter_method="normal" 26 | 27 | -------------------------------------------------------------------------------- /scripts/test_tnt.sh: -------------------------------------------------------------------------------- 1 | # run this script in the root path of TransMVSNet 2 | TESTPATH="/data/TankandTemples/intermediate" # path to dataset 3 | TESTLIST="lists/tnt/inter.txt" # "lists/tnt/adv.txt" 4 | CKPT_FILE="checkpoints/model_bld.ckpt" # path to checkpoint 5 | OUTDIR="outputs/tnt_testing/inter" # path to save the results 6 | if [ ! -d $OUTDIR ]; then 7 | mkdir -p $OUTDIR 8 | fi 9 | 10 | 11 | python test.py \ 12 | --dataset=tnt_eval \ 13 | --num_view=10 \ 14 | --batch_size=1 \ 15 | --interval_scale=1.0 \ 16 | --numdepth=192 \ 17 | --ndepths="48,32,8" \ 18 | --depth_inter_r="4,1,0.5" \ 19 | --testpath=$TESTPATH \ 20 | --testlist=$TESTLIST \ 21 | --outdir=$OUTDIR \ 22 | --filter_method="dynamic" \ 23 | --loadckpt $CKPT_FILE ${@:2} 24 | 25 | 26 | python dynamic_fusion.py \ 27 | --testpath=$OUTDIR \ 28 | --tntpath=$TESTPATH \ 29 | --testlist=$TESTLIST \ 30 | --outdir=$OUTDIR \ 31 | --photo_threshold=0.18 \ 32 | --thres_view=5 \ 33 | --test_dataset=tnt -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # run this script in the root path of TransMVSNet 3 | MVS_TRAINING="/data/DTU/mvs_training/dtu/" # path to dataset mvs_training 4 | LOG_DIR="./outputs/dtu_training" # path to checkpoints 5 | if [ ! 
-d $LOG_DIR ]; then 6 | mkdir -p $LOG_DIR 7 | fi 8 | 9 | NGPUS=8 10 | BATCH_SIZE=1 11 | python -m torch.distributed.launch --nproc_per_node=$NGPUS train.py \ 12 | --logdir=$LOG_DIR \ 13 | --dataset=dtu_yao \ 14 | --batch_size=$BATCH_SIZE \ 15 | --epochs=16 \ 16 | --trainpath=$MVS_TRAINING \ 17 | --trainlist=lists/dtu/train.txt \ 18 | --testlist=lists/dtu/val.txt \ 19 | --numdepth=192 \ 20 | --ndepths="48,32,8" \ 21 | --nviews=5 \ 22 | --wd=0.0001 \ 23 | --depth_inter_r="4.0,1.0,0.5" \ 24 | --lrepochs="6,8,12:2" \ 25 | --dlossw="1.0,1.0,1.0" | tee -a $LOG_DIR/log.txt 26 | -------------------------------------------------------------------------------- /scripts/train_bld_fintune.sh: -------------------------------------------------------------------------------- 1 | # run this script in the root path of TransMVSNet 2 | MVS_TRAINING="/data/BlendedMVS/" # path to BlendedMVS dataset 3 | CKPT="checkpoints/model_dtu.ckpt" # path to checkpoint 4 | LOG_DIR="outputs/bld_finetune" 5 | if [ ! -d $LOG_DIR ]; then 6 | mkdir -p $LOG_DIR 7 | fi 8 | 9 | 10 | NGPUS=8 11 | BATCH_SIZE=1 12 | python -m torch.distributed.launch --nproc_per_node=$NGPUS finetune.py \ 13 | --logdir=$LOG_DIR \ 14 | --dataset=bld_train \ 15 | --trainpath=$MVS_TRAINING \ 16 | --ndepths="48,32,8" \ 17 | --depth_inter_r="4,1,0.5" \ 18 | --dlossw="1.0,1.0,1.0" \ 19 | --loadckpt=$CKPT \ 20 | --eval_freq=1 \ 21 | --wd=0.0001 \ 22 | --nviews=4 \ 23 | --batch_size=$BATCH_SIZE \ 24 | --lr=0.0002 \ 25 | --lrepochs="6,10,14:2" \ 26 | --epochs=16 \ 27 | --trainlist=lists/bld/training_list.txt \ 28 | --testlist=lists/bld/validation_list.txt \ 29 | --numdepth=192 ${@:3} | tee -a $LOG_DIR/log.txt 30 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse, os, time, sys, gc, cv2, signal 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.parallel 5 | import torch.backends.cudnn as cudnn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from torch.utils.data import DataLoader 9 | from datasets import find_dataset_def 10 | from models import * 11 | from utils import * 12 | from datasets.data_io import read_pfm, save_pfm 13 | from plyfile import PlyData, PlyElement 14 | from PIL import Image 15 | from gipuma import gipuma_filter 16 | from multiprocessing import Pool 17 | from functools import partial 18 | 19 | cudnn.benchmark = True 20 | 21 | parser = argparse.ArgumentParser(description='Predict depth, filter, and fuse') 22 | parser.add_argument('--model', default='mvsnet', help='select model') 23 | parser.add_argument('--dataset', default='dtu_yao_eval', help='select dataset') 24 | parser.add_argument('--testpath', help='testing data dir for some scenes') 25 | parser.add_argument('--testpath_single_scene', help='testing data path for single scene') 26 | parser.add_argument('--testlist', help='testing scene list') 27 | parser.add_argument('--batch_size', type=int, default=1, help='testing batch size') 28 | parser.add_argument('--numdepth', type=int, default=192, help='the number of depth values') 29 | parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint') 30 | parser.add_argument('--outdir', default='./outputs', help='output dir') 31 | parser.add_argument('--display', action='store_true', help='display depth images and masks') 32 | parser.add_argument('--share_cr', action='store_true', help='whether share the cost volume regularization') 33 | 
parser.add_argument('--ndepths', type=str, default="48,32,8", help='ndepths') 34 | parser.add_argument('--depth_inter_r', type=str, default="4,2,1", help='depth_intervals_ratio') 35 | parser.add_argument('--cr_base_chs', type=str, default="8,8,8", help='cost regularization base channels') 36 | parser.add_argument('--grad_method', type=str, default="detach", choices=["detach", "undetach"], help='grad method') 37 | parser.add_argument('--interval_scale', type=float, required=True, help='the depth interval scale') 38 | parser.add_argument('--num_view', type=int, default=5, help='num of view') 39 | parser.add_argument('--max_h', type=int, default=864, help='testing max h') 40 | parser.add_argument('--max_w', type=int, default=1152, help='testing max w') 41 | parser.add_argument('--fix_res', action='store_true', help='scene all using same res') 42 | parser.add_argument('--num_worker', type=int, default=4, help='depth_filer worker') 43 | parser.add_argument('--save_freq', type=int, default=20, help='save freq of local pcd') 44 | parser.add_argument('--filter_method', type=str, default='normal', choices=["gipuma", "normal", "dynamic"], help="filter method") 45 | #filter 46 | parser.add_argument('--conf', type=float, default=0.03, help='prob confidence') 47 | parser.add_argument('--thres_view', type=int, default=5, help='threshold of num view') 48 | #filter by gimupa 49 | parser.add_argument('--fusibile_exe_path', type=str, default='../fusibile/fusibile') 50 | parser.add_argument('--prob_threshold', type=float, default='0.01') 51 | parser.add_argument('--disp_threshold', type=float, default='0.25') 52 | parser.add_argument('--num_consistent', type=float, default='3') 53 | # parse arguments and check 54 | args = parser.parse_args() 55 | print("argv:", sys.argv[1:]) 56 | print_args(args) 57 | if args.testpath_single_scene: 58 | args.testpath = os.path.dirname(args.testpath_single_scene) 59 | 60 | num_stage = len([int(nd) for nd in args.ndepths.split(",") if nd]) 61 | 62 | Interval_Scale = args.interval_scale 63 | print("***********Interval_Scale**********\n", Interval_Scale) 64 | 65 | 66 | # read intrinsics and extrinsics 67 | def read_camera_parameters(filename): 68 | with open(filename) as f: 69 | lines = f.readlines() 70 | lines = [line.rstrip() for line in lines] 71 | # extrinsics: line [1,5), 4x4 matrix 72 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 73 | # intrinsics: line [7-10), 3x3 matrix 74 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 75 | return intrinsics, extrinsics 76 | 77 | 78 | # read an image 79 | def read_img(filename): 80 | img = Image.open(filename) 81 | # scale 0~255 to 0~1 82 | np_img = np.array(img, dtype=np.float32) / 255. 83 | return np_img 84 | 85 | 86 | # read a binary mask 87 | def read_mask(filename): 88 | return read_img(filename) > 0.5 89 | 90 | 91 | # save a binary mask 92 | def save_mask(filename, mask): 93 | assert mask.dtype == np.bool 94 | mask = mask.astype(np.uint8) * 255 95 | Image.fromarray(mask).save(filename) 96 | 97 | 98 | # read a pair file, [(ref_view1, [src_view1-1, ...]), (ref_view2, [src_view2-1, ...]), ...] 
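# A minimal sketch of the pair-file layout this parser expects (inferred from the reads in
# read_pair_file() below; the specific ids and scores are illustrative placeholders, not taken
# from a real scan):
#
#   49                          <- total number of viewpoints
#   0                           <- reference view id
#   10 10 2346.4 1 2036.5 ...   <- source-view count, then (src_view_id, score) pairs
#   1                           <- next reference view id
#   ...
#
# read_pair_file() keeps only the source-view ids (every second token after the count)
# and discards the matching scores.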
99 | def read_pair_file(filename): 100 | data = [] 101 | with open(filename) as f: 102 | num_viewpoint = int(f.readline()) 103 | # 49 viewpoints 104 | for view_idx in range(num_viewpoint): 105 | ref_view = int(f.readline().rstrip()) 106 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 107 | if len(src_views) > 0: 108 | data.append((ref_view, src_views)) 109 | return data 110 | 111 | def write_cam(file, cam): 112 | f = open(file, "w") 113 | f.write('extrinsic\n') 114 | for i in range(0, 4): 115 | for j in range(0, 4): 116 | f.write(str(cam[0][i][j]) + ' ') 117 | f.write('\n') 118 | f.write('\n') 119 | 120 | f.write('intrinsic\n') 121 | for i in range(0, 3): 122 | for j in range(0, 3): 123 | f.write(str(cam[1][i][j]) + ' ') 124 | f.write('\n') 125 | 126 | f.write('\n' + str(cam[1][3][0]) + ' ' + str(cam[1][3][1]) + ' ' + str(cam[1][3][2]) + ' ' + str(cam[1][3][3]) + '\n') 127 | 128 | f.close() 129 | 130 | def save_depth(testlist): 131 | 132 | for scene in testlist: 133 | save_scene_depth([scene]) 134 | 135 | # run CasMVS model to save depth maps and confidence maps 136 | def save_scene_depth(testlist): 137 | # dataset, dataloader 138 | MVSDataset = find_dataset_def(args.dataset) 139 | test_dataset = MVSDataset(args.testpath, testlist, "test", args.num_view, args.numdepth, Interval_Scale, 140 | max_h=args.max_h, max_w=args.max_w, fix_res=args.fix_res) 141 | TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False) 142 | 143 | # model 144 | model = TransMVSNet(refine=False, ndepths=[int(nd) for nd in args.ndepths.split(",") if nd], 145 | depth_interals_ratio=[float(d_i) for d_i in args.depth_inter_r.split(",") if d_i], 146 | share_cr=args.share_cr, 147 | cr_base_chs=[int(ch) for ch in args.cr_base_chs.split(",") if ch], 148 | grad_method=args.grad_method) 149 | 150 | # load checkpoint file specified by args.loadckpt 151 | print("loading model {}".format(args.loadckpt)) 152 | state_dict = torch.load(args.loadckpt, map_location=torch.device("cpu")) 153 | model.load_state_dict(state_dict['model'], strict=True) 154 | model = nn.DataParallel(model) 155 | model.cuda() 156 | model.eval() 157 | 158 | with torch.no_grad(): 159 | for batch_idx, sample in enumerate(TestImgLoader): 160 | sample_cuda = tocuda(sample) 161 | start_time = time.time() 162 | outputs = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"]) 163 | end_time = time.time() 164 | outputs = tensor2numpy(outputs) 165 | del sample_cuda 166 | filenames = sample["filename"] 167 | cams = sample["proj_matrices"]["stage{}".format(num_stage)].numpy() 168 | imgs = sample["imgs"].numpy() 169 | print('Iter {}/{}, Time:{} Res:{}'.format(batch_idx, len(TestImgLoader), end_time - start_time, imgs[0].shape)) 170 | 171 | # save depth maps and confidence maps 172 | for filename, cam, img, depth_est, photometric_confidence, conf_1, conf_2 in zip(filenames, cams, imgs, \ 173 | outputs["depth"], outputs["photometric_confidence"], outputs['stage1']["photometric_confidence"], outputs['stage2']["photometric_confidence"]): 174 | img = img[0] #ref view 175 | cam = cam[0] #ref cam 176 | H,W = photometric_confidence.shape 177 | conf_1 = cv2.resize(conf_1, (W,H)) 178 | conf_2 = cv2.resize(conf_2, (W,H)) 179 | conf_final = photometric_confidence * conf_1 * conf_2 180 | 181 | depth_filename = os.path.join(args.outdir, filename.format('depth_est', '.pfm')) 182 | confidence_filename = os.path.join(args.outdir, filename.format('confidence', '.pfm')) 183 | cam_filename = 
os.path.join(args.outdir, filename.format('cams', '_cam.txt')) 184 | img_filename = os.path.join(args.outdir, filename.format('images', '.jpg')) 185 | ply_filename = os.path.join(args.outdir, filename.format('ply_local', '.ply')) 186 | os.makedirs(depth_filename.rsplit('/', 1)[0], exist_ok=True) 187 | os.makedirs(confidence_filename.rsplit('/', 1)[0], exist_ok=True) 188 | os.makedirs(cam_filename.rsplit('/', 1)[0], exist_ok=True) 189 | os.makedirs(img_filename.rsplit('/', 1)[0], exist_ok=True) 190 | os.makedirs(ply_filename.rsplit('/', 1)[0], exist_ok=True) 191 | #save depth maps 192 | save_pfm(depth_filename, depth_est) 193 | depth_color = visualize_depth(depth_est) 194 | cv2.imwrite(os.path.join(args.outdir, filename.format('depth_est', '.png')), depth_color) 195 | #save confidence maps 196 | save_pfm(confidence_filename, conf_final) 197 | cv2.imwrite(os.path.join(args.outdir, filename.format('confidence', '_3.png')), visualize_depth(photometric_confidence)) 198 | cv2.imwrite(os.path.join(args.outdir, filename.format('confidence', '_1.png')),visualize_depth(conf_1)) 199 | cv2.imwrite(os.path.join(args.outdir, filename.format('confidence', '_2.png')),visualize_depth(conf_2)) 200 | cv2.imwrite(os.path.join(args.outdir, filename.format('confidence', '_final.png')),visualize_depth(conf_final)) 201 | #save cams, img 202 | write_cam(cam_filename, cam) 203 | img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0, 255).astype(np.uint8) 204 | img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 205 | cv2.imwrite(img_filename, img_bgr) 206 | 207 | # if num_stage == 1: 208 | # downsample_img = cv2.resize(img, (int(img.shape[1] * 0.25), int(img.shape[0] * 0.25))) 209 | # elif num_stage == 2: 210 | # downsample_img = cv2.resize(img, (int(img.shape[1] * 0.5), int(img.shape[0] * 0.5))) 211 | # elif num_stage == 3: 212 | # downsample_img = img 213 | 214 | # if batch_idx % args.save_freq == 0: 215 | # generate_pointcloud(downsample_img, depth_est, ply_filename, cam[1, :3, :3]) 216 | 217 | torch.cuda.empty_cache() 218 | gc.collect() 219 | 220 | 221 | # project the reference point cloud into the source view, then project back 222 | def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): 223 | width, height = depth_ref.shape[1], depth_ref.shape[0] 224 | ## step1. project reference pixels to the source view 225 | # reference view x, y 226 | x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) 227 | x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1]) 228 | # reference 3D space 229 | xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref), 230 | np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1])) 231 | # source 3D space 232 | xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)), 233 | np.vstack((xyz_ref, np.ones_like(x_ref))))[:3] 234 | # source view x, y 235 | K_xyz_src = np.matmul(intrinsics_src, xyz_src) 236 | xy_src = K_xyz_src[:2] / K_xyz_src[2:3] 237 | 238 | ## step2. 
reproject the source view points with source view depth estimation 239 | # find the depth estimation of the source view 240 | x_src = xy_src[0].reshape([height, width]).astype(np.float32) 241 | y_src = xy_src[1].reshape([height, width]).astype(np.float32) 242 | sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR) 243 | # mask = sampled_depth_src > 0 244 | 245 | # source 3D space 246 | # NOTE that we should use sampled source-view depth_here to project back 247 | xyz_src = np.matmul(np.linalg.inv(intrinsics_src), 248 | np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1])) 249 | # reference 3D space 250 | xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)), 251 | np.vstack((xyz_src, np.ones_like(x_ref))))[:3] 252 | # source view x, y, depth 253 | depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32) 254 | K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected) 255 | xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3] 256 | x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32) 257 | y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32) 258 | 259 | return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src 260 | 261 | 262 | def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): 263 | width, height = depth_ref.shape[1], depth_ref.shape[0] 264 | x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) 265 | depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, 266 | depth_src, intrinsics_src, extrinsics_src) 267 | # check |p_reproj-p_1| < 1 268 | dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2) 269 | 270 | # check |d_reproj-d_1| / d_1 < 0.01 271 | depth_diff = np.abs(depth_reprojected - depth_ref) 272 | relative_depth_diff = depth_diff / depth_ref 273 | 274 | mask = np.logical_and(dist < 1, relative_depth_diff < 0.01) 275 | depth_reprojected[~mask] = 0 276 | 277 | return mask, depth_reprojected, x2d_src, y2d_src 278 | 279 | 280 | def filter_depth(pair_folder, scan_folder, out_folder, plyfilename): 281 | # the pair file 282 | pair_file = os.path.join(pair_folder, "pair.txt") 283 | # for the final point cloud 284 | vertexs = [] 285 | vertex_colors = [] 286 | 287 | pair_data = read_pair_file(pair_file) 288 | nviews = len(pair_data) 289 | 290 | # for each reference view and the corresponding source views 291 | for ref_view, src_views in pair_data: 292 | # src_views = src_views[:args.num_view] 293 | # load the camera parameters 294 | ref_intrinsics, ref_extrinsics = read_camera_parameters( 295 | os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(ref_view))) 296 | # load the reference image 297 | ref_img = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(ref_view))) 298 | # load the estimated depth of the reference view 299 | ref_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0] 300 | # load the photometric mask of the reference view 301 | confidence = read_pfm(os.path.join(out_folder, 'confidence/{:0>8}.pfm'.format(ref_view)))[0] 302 | photo_mask = confidence > args.conf 303 | 304 | all_srcview_depth_ests = [] 305 | all_srcview_x = [] 306 | all_srcview_y = [] 307 | all_srcview_geomask = [] 308 | 309 | # compute the geometric mask 310 | geo_mask_sum = 0 311 | for 
src_view in src_views: 312 | # camera parameters of the source view 313 | src_intrinsics, src_extrinsics = read_camera_parameters( 314 | os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(src_view))) 315 | # the estimated depth of the source view 316 | src_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))[0] 317 | 318 | geo_mask, depth_reprojected, x2d_src, y2d_src = check_geometric_consistency(ref_depth_est, ref_intrinsics, ref_extrinsics, 319 | src_depth_est, 320 | src_intrinsics, src_extrinsics) 321 | geo_mask_sum += geo_mask.astype(np.int32) 322 | all_srcview_depth_ests.append(depth_reprojected) 323 | all_srcview_x.append(x2d_src) 324 | all_srcview_y.append(y2d_src) 325 | all_srcview_geomask.append(geo_mask) 326 | 327 | depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1) 328 | # at least 3 source views matched 329 | geo_mask = geo_mask_sum >= args.thres_view 330 | final_mask = np.logical_and(photo_mask, geo_mask) 331 | 332 | os.makedirs(os.path.join(out_folder, "mask"), exist_ok=True) 333 | save_mask(os.path.join(out_folder, "mask/{:0>8}_photo.png".format(ref_view)), photo_mask) 334 | save_mask(os.path.join(out_folder, "mask/{:0>8}_geo.png".format(ref_view)), geo_mask) 335 | save_mask(os.path.join(out_folder, "mask/{:0>8}_final.png".format(ref_view)), final_mask) 336 | 337 | print("processing {}, ref-view{:0>2}, photo/geo/final-mask:{}/{}/{}".format(scan_folder, ref_view, 338 | photo_mask.mean(), 339 | geo_mask.mean(), final_mask.mean())) 340 | 341 | if args.display: 342 | import cv2 343 | cv2.imshow('ref_img', ref_img[:, :, ::-1]) 344 | cv2.imshow('ref_depth', ref_depth_est / 800) 345 | cv2.imshow('ref_depth * photo_mask', ref_depth_est * photo_mask.astype(np.float32) / 800) 346 | cv2.imshow('ref_depth * geo_mask', ref_depth_est * geo_mask.astype(np.float32) / 800) 347 | cv2.imshow('ref_depth * mask', ref_depth_est * final_mask.astype(np.float32) / 800) 348 | cv2.waitKey(0) 349 | 350 | height, width = depth_est_averaged.shape[:2] 351 | x, y = np.meshgrid(np.arange(0, width), np.arange(0, height)) 352 | # valid_points = np.logical_and(final_mask, ~used_mask[ref_view]) 353 | valid_points = final_mask 354 | print("valid_points", valid_points.mean()) 355 | x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[valid_points] 356 | #color = ref_img[1:-16:4, 1::4, :][valid_points] # hardcoded for DTU dataset 357 | 358 | if num_stage == 1: 359 | color = ref_img[1::4, 1::4, :][valid_points] 360 | elif num_stage == 2: 361 | color = ref_img[1::2, 1::2, :][valid_points] 362 | elif num_stage == 3: 363 | color = ref_img[valid_points] 364 | 365 | xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics), 366 | np.vstack((x, y, np.ones_like(x))) * depth) 367 | xyz_world = np.matmul(np.linalg.inv(ref_extrinsics), 368 | np.vstack((xyz_ref, np.ones_like(x))))[:3] 369 | vertexs.append(xyz_world.transpose((1, 0))) 370 | vertex_colors.append((color * 255).astype(np.uint8)) 371 | 372 | 373 | vertexs = np.concatenate(vertexs, axis=0) 374 | vertex_colors = np.concatenate(vertex_colors, axis=0) 375 | vertexs = np.array([tuple(v) for v in vertexs], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) 376 | vertex_colors = np.array([tuple(v) for v in vertex_colors], dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 377 | 378 | vertex_all = np.empty(len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr) 379 | for prop in vertexs.dtype.names: 380 | vertex_all[prop] = vertexs[prop] 381 | for prop in vertex_colors.dtype.names: 
382 | vertex_all[prop] = vertex_colors[prop] 383 | 384 | el = PlyElement.describe(vertex_all, 'vertex') 385 | PlyData([el]).write(plyfilename) 386 | print("saving the final model to", plyfilename) 387 | 388 | 389 | def init_worker(): 390 | ''' 391 | Catch Ctrl+C signal to termiante workers 392 | ''' 393 | signal.signal(signal.SIGINT, signal.SIG_IGN) 394 | 395 | 396 | def pcd_filter_worker(scan): 397 | if args.testlist != "all": 398 | scan_id = int(scan[4:]) 399 | save_name = 'mvsnet{:0>3}_l3.ply'.format(scan_id) 400 | else: 401 | save_name = '{}.ply'.format(scan) 402 | pair_folder = os.path.join(args.testpath, scan) 403 | scan_folder = os.path.join(args.outdir, scan) 404 | out_folder = os.path.join(args.outdir, scan) 405 | filter_depth(pair_folder, scan_folder, out_folder, os.path.join(args.outdir, save_name)) 406 | 407 | 408 | def pcd_filter(testlist, number_worker): 409 | 410 | partial_func = partial(pcd_filter_worker) 411 | 412 | p = Pool(number_worker, init_worker) 413 | try: 414 | p.map(partial_func, testlist) 415 | except KeyboardInterrupt: 416 | print("....\nCaught KeyboardInterrupt, terminating workers") 417 | p.terminate() 418 | else: 419 | p.close() 420 | p.join() 421 | 422 | if __name__ == '__main__': 423 | 424 | if args.testlist != "all": 425 | with open(args.testlist) as f: 426 | content = f.readlines() 427 | testlist = [line.rstrip() for line in content] 428 | else: 429 | testlist = [e for e in os.listdir(args.testpath) if os.path.isdir(os.path.join(args.testpath, e))] \ 430 | if not args.testpath_single_scene else [os.path.basename(args.testpath_single_scene)] 431 | 432 | # step1. save all the depth maps and the masks in outputs directory 433 | save_depth(testlist) 434 | 435 | # step2. filter saved depth maps with photometric confidence maps and geometric constraints 436 | if args.filter_method == "normal": 437 | pcd_filter(testlist, args.num_worker) 438 | elif args.filter_method == "gipuma": 439 | gipuma_filter(testlist, args.outdir, args.prob_threshold, args.disp_threshold, args.num_consistent, 440 | args.fusibile_exe_path) 441 | elif args.filter_method == "dynamic": 442 | pass 443 | else: 444 | raise NotImplementedError 445 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, time, gc, datetime 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.parallel 5 | import torch.backends.cudnn as cudnn 6 | import torch.optim as optim 7 | from torch.utils.data import DataLoader 8 | from tensorboardX import SummaryWriter 9 | from datasets import find_dataset_def 10 | from models import * 11 | from utils import * 12 | import torch.distributed as dist 13 | 14 | cudnn.benchmark = True 15 | 16 | parser = argparse.ArgumentParser(description='official Implementation of TransMVSNet') 17 | parser.add_argument('--mode', default='train', help='train or test', choices=['train', 'test', 'profile']) 18 | parser.add_argument('--model', default='mvsnet', help='select model') 19 | parser.add_argument('--device', default='cuda', help='select model') 20 | parser.add_argument('--dataset', default='dtu_yao', help='select dataset') 21 | parser.add_argument('--trainpath', help='train datapath') 22 | parser.add_argument('--testpath', help='test datapath') 23 | parser.add_argument('--trainlist', help='train list') 24 | parser.add_argument('--testlist', help='test list') 25 | parser.add_argument('--epochs', type=int, default=16, help='number of 
epochs to train') 26 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 27 | parser.add_argument('--lrepochs', type=str, default="10,12,14:2", help='epoch ids to downscale lr and the downscale rate') 28 | parser.add_argument('--wd', type=float, default=0.0001, help='weight decay') 29 | parser.add_argument('--nviews', type=int, default=5, help='total number of views') 30 | parser.add_argument('--batch_size', type=int, default=1, help='train batch size') 31 | parser.add_argument('--numdepth', type=int, default=192, help='the number of depth values') 32 | parser.add_argument('--interval_scale', type=float, default=1.06, help='the number of depth values') 33 | parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint') 34 | parser.add_argument('--logdir', default='./checkpoints', help='the directory to save checkpoints/logs') 35 | parser.add_argument('--resume', action='store_true', help='continue to train the model') 36 | parser.add_argument('--summary_freq', type=int, default=10, help='print and summary frequency') 37 | parser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency') 38 | parser.add_argument('--eval_freq', type=int, default=1, help='eval freq') 39 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') 40 | parser.add_argument('--pin_m', action='store_true', help='data loader pin memory') 41 | parser.add_argument("--local_rank", type=int, default=0) 42 | parser.add_argument('--share_cr', action='store_true', help='whether share the cost volume regularization') 43 | parser.add_argument('--ndepths', type=str, default="48,32,8", help='ndepths') 44 | parser.add_argument('--depth_inter_r', type=str, default="4,2,1", help='depth_intervals_ratio') 45 | parser.add_argument('--dlossw', type=str, default="0.5,1.0,2.0", help='depth loss weight for different stage') 46 | parser.add_argument('--cr_base_chs', type=str, default="8,8,8", help='cost regularization base channels') 47 | parser.add_argument('--grad_method', type=str, default="detach", choices=["detach", "undetach"], help='grad method') 48 | parser.add_argument('--using_apex', action='store_true', help='using apex, need to install apex') 49 | parser.add_argument('--sync_bn', action='store_true',help='enabling apex sync BN.') 50 | parser.add_argument('--opt-level', type=str, default="O0") 51 | parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) 52 | parser.add_argument('--loss-scale', type=str, default=None) 53 | 54 | 55 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 56 | is_distributed = num_gpus > 1 57 | 58 | 59 | # main function 60 | def train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args): 61 | milestones = [len(TrainImgLoader) * int(epoch_idx) for epoch_idx in args.lrepochs.split(':')[0].split(',')] 62 | lr_gamma = 1 / float(args.lrepochs.split(':')[1]) 63 | lr_scheduler = WarmupMultiStepLR(optimizer, milestones, gamma=lr_gamma, warmup_factor=1.0/3, warmup_iters=500, 64 | last_epoch=len(TrainImgLoader) * start_epoch - 1) 65 | 66 | for epoch_idx in range(start_epoch, args.epochs): 67 | global_step = len(TrainImgLoader) * epoch_idx 68 | 69 | # training 70 | if is_distributed: 71 | TrainImgLoader.sampler.set_epoch(epoch_idx) 72 | for batch_idx, sample in enumerate(TrainImgLoader): 73 | start_time = time.time() 74 | global_step = len(TrainImgLoader) * epoch_idx + batch_idx 75 | do_summary = global_step % args.summary_freq == 0 76 | loss, 
scalar_outputs, image_outputs = train_sample(model, model_loss, optimizer, sample, args) 77 | lr_scheduler.step() 78 | if (not is_distributed) or (dist.get_rank() == 0): 79 | if do_summary: 80 | save_scalars(logger, 'train', scalar_outputs, global_step) 81 | # save_images(logger, 'train', image_outputs, global_step) 82 | print( 83 | "Epoch {}/{}, Iter {}/{}, lr {:.6f}, train loss = {:.3f}, depth loss = {:.3f}, entropy loss = {:.3f}, time = {:.3f}".format( 84 | epoch_idx, args.epochs, batch_idx, len(TrainImgLoader), 85 | optimizer.param_groups[0]["lr"], loss, 86 | scalar_outputs['depth_loss'], 87 | scalar_outputs['entropy_loss'], 88 | time.time() - start_time)) 89 | del scalar_outputs, image_outputs 90 | 91 | # checkpoint 92 | if (not is_distributed) or (dist.get_rank() == 0): 93 | if (epoch_idx + 1) % args.save_freq == 0: 94 | torch.save({ 95 | 'epoch': epoch_idx, 96 | 'model': model.module.state_dict(), 97 | 'optimizer': optimizer.state_dict()}, 98 | "{}/model_{:0>6}.ckpt".format(args.logdir, epoch_idx)) 99 | gc.collect() 100 | 101 | # testing 102 | if (epoch_idx % args.eval_freq == 0) or (epoch_idx == args.epochs - 1): 103 | avg_test_scalars = DictAverageMeter() 104 | for batch_idx, sample in enumerate(TestImgLoader): 105 | start_time = time.time() 106 | global_step = len(TrainImgLoader) * epoch_idx + batch_idx 107 | do_summary = global_step % args.summary_freq == 0 108 | loss, scalar_outputs, image_outputs = test_sample_depth(model, model_loss, sample, args) 109 | if (not is_distributed) or (dist.get_rank() == 0): 110 | if do_summary: 111 | save_scalars(logger, 'test', scalar_outputs, global_step) 112 | # save_images(logger, 'test', image_outputs, global_step) 113 | print("Epoch {}/{}, Iter {}/{}, test loss = {:.3f}, depth loss = {:.3f}, entropy loss = {:.3f}, time = {:3f}".format( 114 | epoch_idx, args.epochs, 115 | batch_idx, 116 | len(TestImgLoader), loss, 117 | scalar_outputs["depth_loss"], 118 | scalar_outputs['entropy_loss'], 119 | time.time() - start_time)) 120 | avg_test_scalars.update(scalar_outputs) 121 | del scalar_outputs, image_outputs 122 | 123 | if (not is_distributed) or (dist.get_rank() == 0): 124 | save_scalars(logger, 'fulltest', avg_test_scalars.mean(), global_step) 125 | print("avg_test_scalars:", avg_test_scalars.mean()) 126 | gc.collect() 127 | 128 | 129 | def test(model, model_loss, TestImgLoader, args): 130 | avg_test_scalars = DictAverageMeter() 131 | for batch_idx, sample in enumerate(TestImgLoader): 132 | start_time = time.time() 133 | loss, scalar_outputs, image_outputs = test_sample_depth(model, model_loss, sample, args) 134 | avg_test_scalars.update(scalar_outputs) 135 | del scalar_outputs, image_outputs 136 | if (not is_distributed) or (dist.get_rank() == 0): 137 | print('Iter {}/{}, test loss = {:.3f}, time = {:3f}'.format(batch_idx, len(TestImgLoader), loss, 138 | time.time() - start_time)) 139 | if batch_idx % 100 == 0: 140 | print("Iter {}/{}, test results = {}".format(batch_idx, len(TestImgLoader), avg_test_scalars.mean())) 141 | if (not is_distributed) or (dist.get_rank() == 0): 142 | print("final", avg_test_scalars.mean()) 143 | 144 | 145 | def train_sample(model, model_loss, optimizer, sample, args): 146 | model.train() 147 | optimizer.zero_grad() 148 | 149 | sample_cuda = tocuda(sample) 150 | depth_gt_ms = sample_cuda["depth"] 151 | mask_ms = sample_cuda["mask"] 152 | 153 | num_stage = len([int(nd) for nd in args.ndepths.split(",") if nd]) 154 | depth_gt = depth_gt_ms["stage{}".format(num_stage)] 155 | mask = mask_ms["stage{}".format(num_stage)] 
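# The try block below is the core optimization step: forward the batch through TransMVSNet,
# compute the multi-stage loss weighted by --dlossw, and guard against NaN losses (a NanError
# aborts the step and frees GPU memory). The backward pass goes through apex AMP when running
# distributed with --using_apex, and through a plain loss.backward() otherwise, followed by
# optimizer.step().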
156 | try: 157 | outputs = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"]) 158 | depth_est = outputs["depth"] 159 | 160 | loss, depth_loss, entropy_loss, depth_entropy = model_loss(outputs, depth_gt_ms, mask_ms, dlossw=[float(e) for e in args.dlossw.split(",") if e]) 161 | 162 | if np.isnan(loss.item()): 163 | raise NanError 164 | 165 | if is_distributed and args.using_apex: 166 | with amp.scale_loss(loss, optimizer) as scaled_loss: 167 | scaled_loss.backward() 168 | else: 169 | loss.backward() 170 | 171 | optimizer.step() 172 | 173 | except NanError: 174 | print(f'nan error occur!!') 175 | gc.collect() 176 | torch.cuda.empty_cache() 177 | 178 | scalar_outputs = {"loss": loss, 179 | "depth_loss": depth_loss, 180 | "entropy_loss": entropy_loss, 181 | "abs_depth_error": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5), 182 | "thres2mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 2), 183 | "thres4mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 4), 184 | "thres8mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 8),} 185 | 186 | image_outputs = {"depth_est": depth_est * mask, 187 | "depth_est_nomask": depth_est, 188 | "depth_gt": sample["depth"]["stage1"], 189 | "ref_img": sample["imgs"][:, 0], 190 | "mask": sample["mask"]["stage1"], 191 | "errormap": (depth_est - depth_gt).abs() * mask, 192 | } 193 | 194 | if is_distributed: 195 | scalar_outputs = reduce_scalar_outputs(scalar_outputs) 196 | 197 | return tensor2float(scalar_outputs["loss"]), tensor2float(scalar_outputs), tensor2numpy(image_outputs) 198 | 199 | 200 | @make_nograd_func 201 | def test_sample_depth(model, model_loss, sample, args): 202 | if is_distributed: 203 | model_eval = model.module 204 | else: 205 | model_eval = model 206 | model_eval.eval() 207 | 208 | sample_cuda = tocuda(sample) 209 | depth_gt_ms = sample_cuda["depth"] 210 | mask_ms = sample_cuda["mask"] 211 | 212 | num_stage = len([int(nd) for nd in args.ndepths.split(",") if nd]) 213 | depth_gt = depth_gt_ms["stage{}".format(num_stage)] 214 | mask = mask_ms["stage{}".format(num_stage)] 215 | 216 | outputs = model_eval(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"]) 217 | depth_est = outputs["depth"] 218 | 219 | loss, depth_loss, entropy_loss, depth_entropy = model_loss(outputs, depth_gt_ms, mask_ms, dlossw=[float(e) for e in args.dlossw.split(",") if e]) 220 | 221 | scalar_outputs = {"loss": loss, 222 | "depth_loss": depth_loss, 223 | "entropy_loss": entropy_loss, 224 | "abs_depth_error": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5), 225 | "thres2mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 2), 226 | "thres4mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 4), 227 | "thres8mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 8), 228 | "thres14mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 14), 229 | "thres20mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 20), 230 | 231 | "thres2mm_abserror": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5, [0, 2.0]), 232 | "thres4mm_abserror": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5, [2.0, 4.0]), 233 | "thres8mm_abserror": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5, [4.0, 8.0]), 234 | "thres14mm_abserror": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5, [8.0, 14.0]), 235 | "thres20mm_abserror": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5, [14.0, 20.0]), 236 | "thres>20mm_abserror": AbsDepthError_metrics(depth_est, depth_gt, mask > 
0.5, [20.0, 1e5]), 237 | } 238 | 239 | image_outputs = {"depth_est": depth_est * mask, 240 | "depth_est_nomask": depth_est, 241 | "depth_gt": sample["depth"]["stage1"], 242 | "ref_img": sample["imgs"][:, 0], 243 | "mask": sample["mask"]["stage1"], 244 | "errormap": (depth_entropy - depth_gt).abs() * mask} 245 | 246 | if is_distributed: 247 | scalar_outputs = reduce_scalar_outputs(scalar_outputs) 248 | 249 | return tensor2float(scalar_outputs["loss"]), tensor2float(scalar_outputs), tensor2numpy(image_outputs) 250 | 251 | def profile(): 252 | warmup_iter = 5 253 | iter_dataloader = iter(TestImgLoader) 254 | 255 | @make_nograd_func 256 | def do_iteration(): 257 | torch.cuda.synchronize() 258 | torch.cuda.synchronize() 259 | start_time = time.perf_counter() 260 | test_sample_depth(next(iter_dataloader), detailed_summary=True) 261 | torch.cuda.synchronize() 262 | end_time = time.perf_counter() 263 | return end_time - start_time 264 | 265 | for i in range(warmup_iter): 266 | t = do_iteration() 267 | print('WarpUp Iter {}, time = {:.4f}'.format(i, t)) 268 | 269 | with torch.autograd.profiler.profile(enabled=True, use_cuda=True) as prof: 270 | for i in range(5): 271 | t = do_iteration() 272 | print('Profile Iter {}, time = {:.4f}'.format(i, t)) 273 | time.sleep(0.02) 274 | 275 | if prof is not None: 276 | # print(prof) 277 | trace_fn = 'chrome-trace.bin' 278 | prof.export_chrome_trace(trace_fn) 279 | print("chrome trace file is written to: ", trace_fn) 280 | 281 | 282 | if __name__ == '__main__': 283 | # parse arguments and check 284 | args = parser.parse_args() 285 | 286 | # using sync_bn by using nvidia-apex, need to install apex. 287 | if args.sync_bn: 288 | assert args.using_apex, "must set using apex and install nvidia-apex" 289 | if args.using_apex: 290 | try: 291 | from apex.parallel import DistributedDataParallel as DDP 292 | from apex.fp16_utils import * 293 | from apex import amp, optimizers 294 | from apex.multi_tensor_apply import multi_tensor_applier 295 | except ImportError: 296 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") 297 | 298 | if args.resume: 299 | assert args.mode == "train" 300 | assert args.loadckpt is None 301 | if args.testpath is None: 302 | args.testpath = args.trainpath 303 | 304 | if is_distributed: 305 | torch.cuda.set_device(args.local_rank) 306 | torch.distributed.init_process_group( 307 | backend="nccl", init_method="env://" 308 | ) 309 | synchronize() 310 | 311 | set_random_seed(args.seed) 312 | # device = torch.device(args.device) 313 | device = torch.device(args.local_rank) 314 | 315 | if (not is_distributed) or (dist.get_rank() == 0): 316 | # create logger for mode "train" and "testall" 317 | if args.mode == "train": 318 | if not os.path.isdir(args.logdir): 319 | os.makedirs(args.logdir) 320 | current_time_str = str(datetime.datetime.now().strftime('%Y%m%d_%H%M%S')) 321 | print("current time", current_time_str) 322 | print("creating new summary file") 323 | logger = SummaryWriter(args.logdir) 324 | print("argv:", sys.argv[1:]) 325 | print_args(args) 326 | 327 | # model, optimizer 328 | model = TransMVSNet(refine=False, ndepths=[int(nd) for nd in args.ndepths.split(",") if nd], 329 | depth_interals_ratio=[float(d_i) for d_i in args.depth_inter_r.split(",") if d_i], 330 | share_cr=args.share_cr, 331 | cr_base_chs=[int(ch) for ch in args.cr_base_chs.split(",") if ch], 332 | grad_method=args.grad_method) 333 | model.to(device) 334 | model_loss = trans_mvsnet_loss 335 | 336 | if args.sync_bn: 337 | import 
apex 338 | print("using apex synced BN") 339 | model = apex.parallel.convert_syncbn_model(model) 340 | 341 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, betas=(0.9, 0.999), weight_decay=args.wd) 342 | 343 | # load parameters 344 | start_epoch = 0 345 | if args.resume: 346 | saved_models = [fn for fn in os.listdir(args.logdir) if fn.endswith(".ckpt")] 347 | saved_models = sorted(saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0])) 348 | # use the latest checkpoint file 349 | loadckpt = os.path.join(args.logdir, saved_models[-1]) 350 | print("resuming", loadckpt) 351 | state_dict = torch.load(loadckpt, map_location=torch.device("cpu")) 352 | model.load_state_dict(state_dict['model']) 353 | optimizer.load_state_dict(state_dict['optimizer']) 354 | start_epoch = state_dict['epoch'] + 1 355 | elif args.loadckpt: 356 | # load checkpoint file specified by args.loadckpt 357 | print("loading model {}".format(args.loadckpt)) 358 | state_dict = torch.load(args.loadckpt, map_location=torch.device("cpu")) 359 | model.load_state_dict(state_dict['model']) 360 | 361 | if (not is_distributed) or (dist.get_rank() == 0): 362 | print("start at epoch {}".format(start_epoch)) 363 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 364 | 365 | if args.using_apex: 366 | # Initialize Amp 367 | model, optimizer = amp.initialize(model, optimizer, 368 | opt_level=args.opt_level, 369 | keep_batchnorm_fp32=args.keep_batchnorm_fp32, 370 | loss_scale=args.loss_scale 371 | ) 372 | 373 | if is_distributed: 374 | print("Let's use", torch.cuda.device_count(), "GPUs!") 375 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 376 | model = torch.nn.parallel.DistributedDataParallel( 377 | model, device_ids=[args.local_rank], output_device=args.local_rank, 378 | ) 379 | else: 380 | if torch.cuda.is_available(): 381 | print("Let's use", torch.cuda.device_count(), "GPUs!") 382 | model = nn.DataParallel(model) 383 | 384 | # dataset, dataloader 385 | MVSDataset = find_dataset_def(args.dataset) 386 | train_dataset = MVSDataset(args.trainpath, args.trainlist, "train", args.nviews, args.numdepth, args.interval_scale) 387 | test_dataset = MVSDataset(args.testpath, args.testlist, "test", args.nviews, args.numdepth, args.interval_scale) 388 | 389 | if is_distributed: 390 | train_sampler = torch.utils.data.DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()) 391 | test_sampler = torch.utils.data.DistributedSampler(test_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()) 392 | TrainImgLoader = DataLoader(train_dataset, args.batch_size, sampler=train_sampler, num_workers=2,drop_last=True, pin_memory=args.pin_m) 393 | TestImgLoader = DataLoader(test_dataset, args.batch_size, sampler=test_sampler, num_workers=2, drop_last=False, pin_memory=args.pin_m) 394 | else: 395 | TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=0, drop_last=True, pin_memory=args.pin_m) 396 | TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=0, drop_last=False, pin_memory=args.pin_m) 397 | 398 | 399 | if args.mode == "train": 400 | train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args) 401 | elif args.mode == "test": 402 | test(model, model_loss, TestImgLoader, args) 403 | elif args.mode == "profile": 404 | profile() 405 | else: 406 | raise NotImplementedError 407 | 
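The checkpoints written by train() above bundle the epoch index, the unwrapped model weights (saved via model.module.state_dict()) and the optimizer state. The following is a minimal loading sketch for offline use; the import path, the checkpoint path and all numeric hyper-parameters are illustrative placeholders, not values taken from this repository's scripts:

import torch
from models.TransMVSNet import TransMVSNet  # assumed import path for the model class

# example stage configuration; substitute the values the checkpoint was trained with
model = TransMVSNet(refine=False, ndepths=[48, 32, 8],
                    depth_interals_ratio=[4.0, 2.0, 1.0],
                    share_cr=False, cr_base_chs=[8, 8, 8],
                    grad_method="detach")

state_dict = torch.load("outputs/model_000015.ckpt", map_location=torch.device("cpu"))
model.load_state_dict(state_dict["model"])  # weights were stored from model.module.state_dict()
model.eval()
print("checkpoint from epoch", state_dict["epoch"])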
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torchvision.utils as vutils 3 | import torch, random 4 | import torch.nn.functional as F 5 | import cv2 6 | from typing import List, Union, Tuple, Dict 7 | # from refile import * 8 | import io 9 | import os 10 | import json 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import matplotlib.pyplot as plt 15 | 16 | class NanError(Exception): 17 | pass 18 | 19 | def recursive_apply(obj: Union[List, Dict], func): 20 | assert type(obj) == dict or type(obj) == list 21 | idx_iter = obj if type(obj) == dict else range(len(obj)) 22 | for k in idx_iter: 23 | if type(obj[k]) == list or type(obj[k]) == dict: 24 | recursive_apply(obj[k], func) 25 | else: 26 | obj[k] = func(obj[k]) 27 | 28 | def visualize_depth(depth, mask=None, depth_min=None, depth_max=None, direct=False): 29 | """Visualize the depth map with colormap. 30 | Rescales the values so that depth_min and depth_max map to 0 and 1, 31 | respectively. 32 | """ 33 | if not direct: 34 | depth = 1.0 / (depth + 1e-6) 35 | invalid_mask = np.logical_or(np.isnan(depth), np.logical_not(np.isfinite(depth))) 36 | if mask is not None: 37 | invalid_mask += np.logical_not(mask) 38 | if depth_min is None: 39 | depth_min = np.percentile(depth[np.logical_not(invalid_mask)], 5) 40 | if depth_max is None: 41 | depth_max = np.percentile(depth[np.logical_not(invalid_mask)], 95) 42 | depth[depth < depth_min] = depth_min 43 | depth[depth > depth_max] = depth_max 44 | depth[invalid_mask] = depth_max 45 | 46 | depth_scaled = (depth - depth_min) / (depth_max - depth_min) 47 | depth_scaled_uint8 = np.uint8(depth_scaled * 255) 48 | depth_color = cv2.applyColorMap(depth_scaled_uint8, cv2.COLORMAP_MAGMA) 49 | depth_color[invalid_mask, :] = 0 50 | 51 | return depth_color 52 | 53 | 54 | def save_model_vis(obj, save_dir: str, job_name: str, global_step: int, max_keep: int): 55 | os.makedirs(os.path.join(save_dir, job_name), exist_ok=True) 56 | record_file = os.path.join(save_dir, job_name, 'record') 57 | cktp_file = os.path.join(save_dir, job_name, f'{global_step}.tar') 58 | if not os.path.exists(record_file): 59 | with open(record_file, 'w+') as f: 60 | json.dump([], f) 61 | with open(record_file, 'r') as f: 62 | record = json.load(f) 63 | record.append(global_step) 64 | if len(record) > max_keep: 65 | old = record[0] 66 | record = record[1:] 67 | os.remove(os.path.join(save_dir, job_name, f'{old}.tar')) 68 | torch.save(obj, cktp_file) 69 | with open(record_file, 'w') as f: 70 | json.dump(record, f) 71 | 72 | 73 | def load_model_vis(model: nn.Module, load_path: str, load_step: int): 74 | model.load_state_dict(torch.load(load_path)['model']) 75 | return 0 76 | 77 | 78 | # print arguments 79 | def print_args(args): 80 | print("################################ args ################################") 81 | for k, v in args.__dict__.items(): 82 | print("{0: <10}\t{1: <30}\t{2: <20}".format(k, str(v), str(type(v)))) 83 | print("########################################################################") 84 | 85 | 86 | # torch.no_grad warpper for functions 87 | def make_nograd_func(func): 88 | def wrapper(*f_args, **f_kwargs): 89 | with torch.no_grad(): 90 | ret = func(*f_args, **f_kwargs) 91 | return ret 92 | 93 | return wrapper 94 | 95 | 96 | # convert a function into recursive style to handle nested dict/list/tuple variables 97 | def 
make_recursive_func(func): 98 | def wrapper(vars): 99 | if isinstance(vars, list): 100 | return [wrapper(x) for x in vars] 101 | elif isinstance(vars, tuple): 102 | return tuple([wrapper(x) for x in vars]) 103 | elif isinstance(vars, dict): 104 | return {k: wrapper(v) for k, v in vars.items()} 105 | else: 106 | return func(vars) 107 | 108 | return wrapper 109 | 110 | 111 | @make_recursive_func 112 | def tensor2float(vars): 113 | if isinstance(vars, float): 114 | return vars 115 | elif isinstance(vars, torch.Tensor): 116 | return vars.data.item() 117 | else: 118 | raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars))) 119 | 120 | 121 | @make_recursive_func 122 | def tensor2numpy(vars): 123 | if isinstance(vars, np.ndarray): 124 | return vars 125 | elif isinstance(vars, torch.Tensor): 126 | return vars.detach().cpu().numpy().copy() 127 | else: 128 | raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) 129 | 130 | 131 | @make_recursive_func 132 | def tocuda(vars): 133 | if isinstance(vars, torch.Tensor): 134 | return vars.to(torch.device("cuda")) 135 | elif isinstance(vars, str): 136 | return vars 137 | else: 138 | raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) 139 | 140 | 141 | def save_scalars(logger, mode, scalar_dict, global_step): 142 | scalar_dict = tensor2float(scalar_dict) 143 | for key, value in scalar_dict.items(): 144 | if not isinstance(value, (list, tuple)): 145 | name = '{}/{}'.format(mode, key) 146 | logger.add_scalar(name, value, global_step) 147 | else: 148 | for idx in range(len(value)): 149 | name = '{}/{}_{}'.format(mode, key, idx) 150 | logger.add_scalar(name, value[idx], global_step) 151 | 152 | 153 | def save_images(logger, mode, images_dict, global_step): 154 | images_dict = tensor2numpy(images_dict) 155 | 156 | def preprocess(name, img): 157 | if not (len(img.shape) == 3 or len(img.shape) == 4): 158 | raise NotImplementedError("invalid img shape {}:{} in save_images".format(name, img.shape)) 159 | if len(img.shape) == 3: 160 | img = img[:, np.newaxis, :, :] 161 | img = torch.from_numpy(img[:1]) 162 | return vutils.make_grid(img, padding=0, nrow=1, normalize=True, scale_each=True) 163 | 164 | for key, value in images_dict.items(): 165 | if not isinstance(value, (list, tuple)): 166 | name = '{}/{}'.format(mode, key) 167 | logger.add_image(name, preprocess(name, value), global_step) 168 | else: 169 | for idx in range(len(value)): 170 | name = '{}/{}_{}'.format(mode, key, idx) 171 | logger.add_image(name, preprocess(name, value[idx]), global_step) 172 | 173 | 174 | class DictAverageMeter(object): 175 | def __init__(self): 176 | self.data = {} 177 | self.count = 0 178 | 179 | def update(self, new_input): 180 | self.count += 1 181 | if len(self.data) == 0: 182 | for k, v in new_input.items(): 183 | if not isinstance(v, float): 184 | raise NotImplementedError("invalid data {}: {}".format(k, type(v))) 185 | self.data[k] = v 186 | else: 187 | for k, v in new_input.items(): 188 | if not isinstance(v, float): 189 | raise NotImplementedError("invalid data {}: {}".format(k, type(v))) 190 | self.data[k] += v 191 | 192 | def mean(self): 193 | return {k: v / self.count for k, v in self.data.items()} 194 | 195 | 196 | # a wrapper to compute metrics for each image individually 197 | def compute_metrics_for_each_image(metric_func): 198 | def wrapper(depth_est, depth_gt, mask, *args): 199 | batch_size = depth_gt.shape[0] 200 | results = [] 201 | # compute result one by one 202 | for 
idx in range(batch_size): 203 | ret = metric_func(depth_est[idx], depth_gt[idx], mask[idx], *args) 204 | results.append(ret) 205 | return torch.stack(results).mean() 206 | 207 | return wrapper 208 | 209 | 210 | @make_nograd_func 211 | @compute_metrics_for_each_image 212 | def Thres_metrics(depth_est, depth_gt, mask, thres): 213 | assert isinstance(thres, (int, float)) 214 | depth_est, depth_gt = depth_est[mask], depth_gt[mask] 215 | errors = torch.abs(depth_est - depth_gt) 216 | err_mask = errors > thres 217 | return torch.mean(err_mask.float()) 218 | 219 | 220 | # NOTE: please do not use this to build up training loss 221 | @make_nograd_func 222 | @compute_metrics_for_each_image 223 | def AbsDepthError_metrics(depth_est, depth_gt, mask, thres=None): 224 | depth_est, depth_gt = depth_est[mask], depth_gt[mask] 225 | error = (depth_est - depth_gt).abs() 226 | if thres is not None: 227 | error = error[(error >= float(thres[0])) & (error <= float(thres[1]))] 228 | if error.shape[0] == 0: 229 | return torch.tensor(0, device=error.device, dtype=error.dtype) 230 | return torch.mean(error) 231 | 232 | import torch.distributed as dist 233 | def synchronize(): 234 | """ 235 | Helper function to synchronize (barrier) among all processes when 236 | using distributed training 237 | """ 238 | if not dist.is_available(): 239 | return 240 | if not dist.is_initialized(): 241 | return 242 | world_size = dist.get_world_size() 243 | if world_size == 1: 244 | return 245 | dist.barrier() 246 | 247 | def get_world_size(): 248 | if not dist.is_available(): 249 | return 1 250 | if not dist.is_initialized(): 251 | return 1 252 | return dist.get_world_size() 253 | 254 | def reduce_scalar_outputs(scalar_outputs): 255 | world_size = get_world_size() 256 | if world_size < 2: 257 | return scalar_outputs 258 | with torch.no_grad(): 259 | names = [] 260 | scalars = [] 261 | for k in sorted(scalar_outputs.keys()): 262 | names.append(k) 263 | scalars.append(scalar_outputs[k]) 264 | scalars = torch.stack(scalars, dim=0) 265 | dist.reduce(scalars, dst=0) 266 | if dist.get_rank() == 0: 267 | # only main process gets accumulated, so only divide by 268 | # world_size in this case 269 | scalars /= world_size 270 | reduced_scalars = {k: v for k, v in zip(names, scalars)} 271 | 272 | return reduced_scalars 273 | 274 | import torch 275 | from bisect import bisect_right 276 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 277 | # separating MultiStepLR with WarmupLR 278 | # but the current LRScheduler design doesn't allow it 279 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 280 | def __init__( 281 | self, 282 | optimizer, 283 | milestones, 284 | gamma=0.1, 285 | warmup_factor=1.0 / 3, 286 | warmup_iters=500, 287 | warmup_method="linear", 288 | last_epoch=-1, 289 | ): 290 | if not list(milestones) == sorted(milestones): 291 | raise ValueError( 292 | "Milestones should be a list of" " increasing integers. 
Got {}", 293 | milestones, 294 | ) 295 | 296 | if warmup_method not in ("constant", "linear"): 297 | raise ValueError( 298 | "Only 'constant' or 'linear' warmup_method accepted" 299 | "got {}".format(warmup_method) 300 | ) 301 | self.milestones = milestones 302 | self.gamma = gamma 303 | self.warmup_factor = warmup_factor 304 | self.warmup_iters = warmup_iters 305 | self.warmup_method = warmup_method 306 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 307 | 308 | def get_lr(self): 309 | warmup_factor = 1 310 | if self.last_epoch < self.warmup_iters: 311 | if self.warmup_method == "constant": 312 | warmup_factor = self.warmup_factor 313 | elif self.warmup_method == "linear": 314 | alpha = float(self.last_epoch) / self.warmup_iters 315 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 316 | #print("base_lr {}, warmup_factor {}, self.gamma {}, self.milesotnes {}, self.last_epoch{}".format( 317 | # self.base_lrs[0], warmup_factor, self.gamma, self.milestones, self.last_epoch)) 318 | return [ 319 | base_lr 320 | * warmup_factor 321 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 322 | for base_lr in self.base_lrs 323 | ] 324 | 325 | 326 | def set_random_seed(seed): 327 | random.seed(seed) 328 | np.random.seed(seed) 329 | torch.manual_seed(seed) 330 | torch.cuda.manual_seed_all(seed) 331 | 332 | 333 | def local_pcd(depth, intr): 334 | nx = depth.shape[1] # w 335 | ny = depth.shape[0] # h 336 | x, y = np.meshgrid(np.arange(nx), np.arange(ny), indexing='xy') 337 | x = x.reshape(nx * ny) 338 | y = y.reshape(nx * ny) 339 | p2d = np.array([x, y, np.ones_like(y)]) 340 | p3d = np.matmul(np.linalg.inv(intr), p2d) 341 | depth = depth.reshape(1, nx * ny) 342 | p3d *= depth 343 | p3d = np.transpose(p3d, (1, 0)) 344 | p3d = p3d.reshape(ny, nx, 3).astype(np.float32) 345 | return p3d 346 | 347 | def generate_pointcloud(rgb, depth, ply_file, intr, scale=1.0): 348 | """ 349 | Generate a colored point cloud in PLY format from a color and a depth image. 350 | 351 | Input: 352 | rgb_file -- filename of color image 353 | depth_file -- filename of depth image 354 | ply_file -- filename of ply file 355 | 356 | """ 357 | fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2] 358 | points = [] 359 | for v in range(rgb.shape[0]): 360 | for u in range(rgb.shape[1]): 361 | color = rgb[v, u] #rgb.getpixel((u, v)) 362 | Z = depth[v, u] / scale 363 | if Z == 0: continue 364 | X = (u - cx) * Z / fx 365 | Y = (v - cy) * Z / fy 366 | points.append("%f %f %f %d %d %d 0\n" % (X, Y, Z, color[0], color[1], color[2])) 367 | file = open(ply_file, "w") 368 | file.write('''ply 369 | format ascii 1.0 370 | element vertex %d 371 | property float x 372 | property float y 373 | property float z 374 | property uchar red 375 | property uchar green 376 | property uchar blue 377 | property uchar alpha 378 | end_header 379 | %s 380 | ''' % (len(points), "".join(points))) 381 | file.close() 382 | print("save ply, fx:{}, fy:{}, cx:{}, cy:{}".format(fx, fy, cx, cy)) 383 | 384 | 385 | --------------------------------------------------------------------------------