├── LICENSE ├── README.md ├── demo ├── sti.gif ├── ti.gif └── vi.gif ├── nsff_exp ├── Q_Slerp.py ├── configs │ ├── config_balloon1-2.txt │ ├── config_balloon2-2.txt │ ├── config_broom.txt │ ├── config_curls.txt │ ├── config_dynamicFace-2.txt │ ├── config_jumping.txt │ ├── config_kid-running.txt │ ├── config_playground.txt │ ├── config_skating-2.txt │ ├── config_truck2.txt │ └── config_umbrella.txt ├── evaluation.py ├── load_llff.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── dist_model.py │ ├── networks_basic.py │ ├── pretrained_networks.py │ └── weights │ │ ├── v0.0 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth │ │ └── v0.1 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth ├── poseInterpolator.py ├── render_utils.py ├── run_nerf.py ├── run_nerf_helpers.py └── softsplat.py └── nsff_scripts ├── alt_cuda_corr ├── correlation.cpp ├── correlation_kernel.cu └── setup.py ├── colmap_read_model.py ├── core ├── __init__.py ├── corr.py ├── datasets.py ├── extractor.py ├── raft.py ├── update.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── flow_viz.cpython-36.pyc │ └── utils.cpython-36.pyc │ ├── augmentor.py │ ├── flow_viz.py │ ├── frame_utils.py │ └── utils.py ├── download_models.sh ├── flow_utils.py ├── models ├── base_model.py ├── blocks.py ├── midas_net.py └── transforms.py ├── run_flows_video.py ├── run_midas.py └── save_poses_nerf.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Zhengqi Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Scene Flow Fields 2 | PyTorch implementation of the paper "Neural Scene Flow Fields for Space-Time View Synthesis of Dynamic Scenes", CVPR 2021 3 | 4 | [[Project Website]](https://www.cs.cornell.edu/~zl548/NSFF/) [[Paper]](https://arxiv.org/abs/2011.13084) [[Video]](https://www.youtube.com/watch?v=qsMIH7gYRCc&feature=emb_title) 5 | 6 | 7 | ## Dependency 8 | The code is tested with Python 3, PyTorch >= 1.6, and CUDA >= 10.2; the dependencies include 9 | * configargparse 10 | * matplotlib 11 | * opencv 12 | * scikit-image 13 | * scipy 14 | * cupy 15 | * imageio 16 | * tqdm 17 | * kornia 18 | 19 | The current version in this GitHub repository includes some improvements for monocular videos in the wild.
For the reference code that matches the paper's description, please check out [this branch](https://github.com/zhengqili/Neural-Scene-Flow-Fields/tree/5bfedc477bab845d539e7b70d114ba39c1644b0e). 20 | 21 | ## Video preprocessing 22 | 1. Download nerf_data.zip from [link](https://drive.google.com/drive/folders/1G-NFZKEA8KSWojUKecpJPVoq5XCjBLOV?usp=sharing), an example input video with SfM camera poses and intrinsics estimated from [COLMAP](https://colmap.github.io/). (Note: you need to use the COLMAP "colmap image_undistorter" command to undistort the input images and obtain the "dense" folder shown in the example; this dense folder should contain the "images" and "sparse" folders.) 23 | 24 | 2. Download the single-view depth prediction model "model.pt" from [link](https://drive.google.com/drive/folders/1G-NFZKEA8KSWojUKecpJPVoq5XCjBLOV?usp=sharing), and put it in the folder "nsff_scripts". 25 | 26 | 3. Run the following commands to generate the required inputs for training/inference: 27 | ```bash 28 | # Usage 29 | cd nsff_scripts 30 | # Create camera intrinsics/extrinsics in the NSFF format, same as the original NeRF, which uses the imgs2poses.py script from the LLFF code: https://github.com/Fyusion/LLFF/blob/master/imgs2poses.py 31 | python save_poses_nerf.py --data_path "/home/xxx/Neural-Scene-Flow-Fields/kid-running/dense/" 32 | # Resize input images and run the single-view depth model; 33 | # the resize_height argument is the resized image height used for model training, and the width is resized to keep the original aspect ratio 34 | python run_midas.py --data_path "/home/xxx/Neural-Scene-Flow-Fields/kid-running/dense/" --resize_height 288 35 | # Run the optical flow model 36 | ./download_models.sh 37 | python run_flows_video.py --model models/raft-things.pth --data_path /home/xxx/Neural-Scene-Flow-Fields/kid-running/dense/ 38 | ``` 39 | 40 | ## Rendering from an example pretrained model 41 | 1. Download the pretrained model "kid-running_ndc_5f_sv_of_sm_unify3_F00-30.zip" from [link](https://drive.google.com/drive/folders/1G-NFZKEA8KSWojUKecpJPVoq5XCjBLOV?usp=sharing). Unzip it and place it so that the checkpoint is located at "nsff_exp/logs/kid-running_ndc_5f_sv_of_sm_unify3_F00-30/360000.tar". 42 | 43 | Set datadir in configs/config_kid-running.txt to the root directory of the input video. Then go to the directory "nsff_exp": 44 | ```bash 45 | cd nsff_exp 46 | mkdir logs 47 | ``` 48 | 49 | 2. Rendering at a fixed time with viewpoint interpolation 50 | ```bash 51 | python run_nerf.py --config configs/config_kid-running.txt --render_bt --target_idx 10 52 | ``` 53 | 54 | Running the example command should give you the following result: 55 | ![Alt Text](https://github.com/zhengqili/Neural-Scene-Flow-Fields/blob/main/demo/vi.gif) 56 | 57 | 3. Rendering from a fixed viewpoint with time interpolation 58 | ```bash 59 | python run_nerf.py --config configs/config_kid-running.txt --render_lockcam_slowmo --target_idx 8 60 | ``` 61 | 62 | Running the example command should give you the following result: 63 | ![Alt Text](https://github.com/zhengqili/Neural-Scene-Flow-Fields/blob/main/demo/ti.gif) 64 | 65 | 4. Rendering with space-time interpolation 66 | ```bash 67 | python run_nerf.py --config configs/config_kid-running.txt --render_slowmo_bt --target_idx 10 68 | ``` 69 | 70 | Running the example command should give you the following result: 71 | ![Alt Text](https://github.com/zhengqili/Neural-Scene-Flow-Fields/blob/main/demo/sti.gif) 72 | 73 | ## Training 74 | 1. 
In configs/config_kid-running.txt, modify expname to any name you like (different from the original one), and run the following command to train the model: 75 | ```bash 76 | python run_nerf.py --config configs/config_kid-running.txt 77 | ``` 78 | Per-scene training takes ~2 days using 4 Nvidia RTX 2080 Ti GPUs. 79 | 80 | 2. Several parameters in the config files that you may need to tune to train a good model on an in-the-wild video: 81 | * final_height: this must be the same as the --resize_height argument of run_midas.py; in the kid-running case it should be 288. 82 | * N_samples: to render images at a higher resolution, you have to increase the number of sampled points, e.g., to 256 or 512. 83 | * chain_sf: the model enforces local 5-frame consistency if set to True, and 3-frame consistency if set to False. Set it to False for faster training. 84 | * start_frame, end_frame: indicate the training frame range. The default model usually works for videos of 1~2 s, and 30-60 frames work best with the default hyperparameters. Training on longer sequences can cause over-smoothed renderings. To mitigate this effect, you can increase the capacity of the network by increasing netwidth to 512. 85 | * decay_iteration: number of iterations in the initialization stage. Data-driven losses decay every 1000 * decay_iteration steps. The code has been updated to calculate the number of decay iterations automatically. 86 | * no_ndc: our current implementation only supports reconstruction in NDC space, meaning it only works for forward-facing scenes, the same as the original NeRF. 87 | * use_motion_mask, num_extra_sample: whether to use the estimated coarse motion segmentation mask for hard-mining sampling during the initialization stage, and how many extra samples to draw during that stage. 88 | * w_depth, w_optical_flow: loss weights for the single-view depth and geometry consistency priors described in the paper. Weights of (0.4, 0.2) or (0.2, 0.1) usually work best for most videos. 89 | * If you see significant ghosting in the final rendering, you might try the suggestion from [link](https://github.com/zhengqili/Neural-Scene-Flow-Fields/issues/18). 90 | 91 | ## Evaluation on the Dynamic Scene Dataset 92 | 1. Download the Dynamic Scene dataset "dynamic_scene_data_full.zip" from [link](https://drive.google.com/drive/folders/1G-NFZKEA8KSWojUKecpJPVoq5XCjBLOV?usp=sharing). 93 | 94 | 2. Download the pretrained models "dynamic_scene_pretrained_models.zip" from [link](https://drive.google.com/drive/folders/1G-NFZKEA8KSWojUKecpJPVoq5XCjBLOV?usp=sharing), unzip them, and put them in the folder "nsff_exp/logs/". 95 | 96 | 3. Run the following command for each scene to get the quantitative results reported in the paper: 97 | ```bash 98 | # Usage: configs/config_xxx.txt refers to each scene's config file, such as config_balloon1-2.txt in nsff_exp/configs 99 | python evaluation.py --config configs/config_xxx.txt 100 | ``` 101 | 102 | * Note: you have to use the modified LPIPS implementation included in this branch in order to measure the LPIPS error for the dynamic region only, as described in the paper. 
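For reference, below is a minimal sketch of how the bundled LPIPS module in `nsff_exp/models` could be called with a dynamic-region mask. It is only an illustration, not the code path used by `evaluation.py`: the dummy inputs, the 1x1xHxW mask layout, and running it from inside `nsff_exp` are all assumptions.

```python
# Illustrative sketch only; see evaluation.py for the actual evaluation code.
import numpy as np
import torch
import models  # nsff_exp/models/__init__.py (the modified LPIPS included in this branch)

# Dummy stand-ins for a rendered frame, its ground truth, and a dynamic-region mask.
H, W = 288, 512
pred_img = np.random.rand(H, W, 3).astype(np.float32)          # values in [0, 1]
gt_img = np.random.rand(H, W, 3).astype(np.float32)
dynamic_mask = (np.random.rand(H, W) > 0.5).astype(np.float32)  # 1 = dynamic pixel

lpips = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=True)

pred_t = torch.from_numpy(pred_img).permute(2, 0, 1)[None].cuda()   # 1x3xHxW
gt_t = torch.from_numpy(gt_img).permute(2, 0, 1)[None].cuda()
mask_t = torch.from_numpy(dynamic_mask)[None, None].cuda()          # 1x1xHxW (assumed layout)

# normalize=True rescales [0, 1] inputs to [-1, 1], as PerceptualLoss expects.
dyn_lpips = lpips.forward(pred_t, gt_t, mask=mask_t, normalize=True)
print(dyn_lpips)
```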
103 | 104 | ## Acknowledgment 105 | The code is based on the implementations of several prior works: 106 | 107 | * https://github.com/sniklaus/softmax-splatting 108 | * https://github.com/yenchenlin/nerf-pytorch 109 | * https://github.com/JKOK005/dVRK-Linear-Interpolator- 110 | * https://github.com/richzhang/PerceptualSimilarity 111 | * https://github.com/intel-isl/MiDaS 112 | * https://github.com/princeton-vl/RAFT 113 | * https://github.com/NVIDIA/flownet2-pytorch 114 | 115 | ## License 116 | This repository is released under the [MIT license](https://opensource.org/licenses/MIT). 117 | 118 | ## Citation 119 | If you find our code/models useful, please consider citing our paper: 120 | ```bash 121 | @InProceedings{li2020neural, 122 | title={Neural Scene Flow Fields for Space-Time View Synthesis of Dynamic Scenes}, 123 | author={Li, Zhengqi and Niklaus, Simon and Snavely, Noah and Wang, Oliver}, 124 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 125 | year={2021} 126 | } 127 | ``` -------------------------------------------------------------------------------- /demo/sti.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengqili/Neural-Scene-Flow-Fields/d4001759a39b056c95d8bc22da34b10b4fb85afb/demo/sti.gif -------------------------------------------------------------------------------- /demo/ti.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengqili/Neural-Scene-Flow-Fields/d4001759a39b056c95d8bc22da34b10b4fb85afb/demo/ti.gif -------------------------------------------------------------------------------- /demo/vi.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengqili/Neural-Scene-Flow-Fields/d4001759a39b056c95d8bc22da34b10b4fb85afb/demo/vi.gif -------------------------------------------------------------------------------- /nsff_exp/Q_Slerp.py: -------------------------------------------------------------------------------- 1 | # from utilities import * 2 | # import transform as T 3 | # import copy 4 | 5 | import numpy as np 6 | from numpy import * 7 | from math import sqrt, sin, cos, acos, asin 8 | 9 | class quaternion: 10 | '''A quaternion is a compact method of representing a 3D rotation that has 11 | computational advantages including speed and numerical robustness. 12 | 13 | A quaternion has 2 parts, a scalar s, and a vector v, and is typically written:: 14 | 15 | q = s <vx, vy, vz> 16 | 17 | A unit quaternion is one for which M{s^2+vx^2+vy^2+vz^2 = 1}. 18 | 19 | A quaternion can be considered as a rotation about a vector in space where 20 | q = cos(theta/2) sin(theta/2) <vx, vy, vz> 21 | where <vx, vy, vz> is a unit vector. 22 | 23 | Various functions such as INV, NORM, UNIT and PLOT are overloaded for 24 | quaternion objects. 25 | 26 | Arithmetic operators are also overloaded to allow quaternion multiplication, 27 | division, exponentiation, and quaternion-vector multiplication (rotation). 
28 | ''' 29 | 30 | def __init__(self, *args): 31 | ''' 32 | Constructor for quaternion objects: 33 | - q = quaternion() object initialization 34 | - q = quaternion(s, v1, v2, v3) from 4 elements 35 | ''' 36 | 37 | self.vec = []; 38 | 39 | if len(args) == 0: 40 | # default is a null rotation 41 | self.s = 1.0 42 | self.v = matrix([0.0, 0.0, 0.0]) 43 | 44 | elif len(args) == 4: 45 | self.s = args[0]; 46 | self.v = mat(args[1:4]) 47 | 48 | else: 49 | print("error") 50 | return None 51 | 52 | def __repr__(self): 53 | return "%f <%f, %f, %f>" % (self.s, self.v[0,0], self.v[0,1], self.v[0,2]) 54 | 55 | 56 | def tr2q(self, t): 57 | #TR2Q Convert homogeneous transform to a unit-quaternion 58 | # 59 | # Q = tr2q(T) 60 | # 61 | # Return a unit quaternion corresponding to the rotational part of the 62 | # homogeneous transform T. 63 | 64 | qs = sqrt(trace(t)+1)/2.0 65 | kx = t[2,1] - t[1,2] # Oz - Ay 66 | ky = t[0,2] - t[2,0] # Ax - Nz 67 | kz = t[1,0] - t[0,1] # Ny - Ox 68 | 69 | if (t[0,0] >= t[1,1]) and (t[0,0] >= t[2,2]): 70 | kx1 = t[0,0] - t[1,1] - t[2,2] + 1 # Nx - Oy - Az + 1 71 | ky1 = t[1,0] + t[0,1] # Ny + Ox 72 | kz1 = t[2,0] + t[0,2] # Nz + Ax 73 | add = (kx >= 0) 74 | elif (t[1,1] >= t[2,2]): 75 | kx1 = t[1,0] + t[0,1] # Ny + Ox 76 | ky1 = t[1,1] - t[0,0] - t[2,2] + 1 # Oy - Nx - Az + 1 77 | kz1 = t[2,1] + t[1,2] # Oz + Ay 78 | add = (ky >= 0) 79 | else: 80 | kx1 = t[2,0] + t[0,2] # Nz + Ax 81 | ky1 = t[2,1] + t[1,2] # Oz + Ay 82 | kz1 = t[2,2] - t[0,0] - t[1,1] + 1 # Az - Nx - Oy + 1 83 | add = (kz >= 0) 84 | 85 | if add: 86 | kx = kx + kx1 87 | ky = ky + ky1 88 | kz = kz + kz1 89 | else: 90 | kx = kx - kx1 91 | ky = ky - ky1 92 | kz = kz - kz1 93 | 94 | kv = matrix([kx, ky, kz]) 95 | nm = linalg.norm( kv ) 96 | if nm == 0: 97 | self.s = 1.0 98 | self.v = matrix([0.0, 0.0, 0.0]) 99 | 100 | else: 101 | self.s = qs 102 | self.v = (sqrt(1 - qs**2) / nm) * kv 103 | 104 | ############### OPERATORS ######################################### 105 | #PLUS Add two quaternion objects 106 | # 107 | # Invoked by the + operator 108 | # 109 | # q1+q2 standard quaternion addition 110 | def __add__(self, q): 111 | ''' 112 | Return a new quaternion that is the element-wise sum of the operands. 113 | ''' 114 | if isinstance(q, quaternion): 115 | qr = quaternion() 116 | qr.s = 0 117 | 118 | qr.s = self.s + q.s 119 | qr.v = self.v + q.v 120 | 121 | return qr 122 | else: 123 | raise ValueError 124 | 125 | #MINUS Subtract two quaternion objects 126 | # 127 | # Invoked by the - operator 128 | # 129 | # q1-q2 standard quaternion subtraction 130 | 131 | def __sub__(self, q): 132 | ''' 133 | Return a new quaternion that is the element-wise difference of the operands. 134 | ''' 135 | if isinstance(q, quaternion): 136 | qr = quaternion() 137 | qr.s = 0 138 | 139 | qr.s = self.s - q.s 140 | qr.v = self.v - q.v 141 | 142 | return qr 143 | else: 144 | raise ValueError 145 | 146 | # q * q or q * const 147 | def __mul__(self, q2): 148 | ''' 149 | Quaternion product. 
Several cases are handled 150 | 151 | - q * q quaternion multiplication 152 | - q * c element-wise multiplication by constant 153 | - q * v quaternion-vector multiplication q * v * q.inv(); 154 | ''' 155 | qr = quaternion(); 156 | 157 | if isinstance(q2, quaternion): 158 | 159 | #Multiply unit-quaternion by unit-quaternion 160 | # 161 | # QQ = qqmul(Q1, Q2) 162 | 163 | # decompose into scalar and vector components 164 | s1 = self.s; v1 = self.v 165 | s2 = q2.s; v2 = q2.v 166 | 167 | # form the product 168 | qr.s = s1*s2 - v1*v2.T 169 | qr.v = s1*v2 + s2*v1 + cross(v1,v2) 170 | 171 | elif type(q2) is matrix: 172 | 173 | # Multiply vector by unit-quaternion 174 | # 175 | # Rotate the vector V by the unit-quaternion Q. 176 | 177 | if q2.shape == (1,3) or q2.shape == (3,1): 178 | qr = self * quaternion(q2) * self.inv() 179 | return qr.v; 180 | else: 181 | raise ValueError; 182 | 183 | else: 184 | qr.s = self.s * q2 185 | qr.v = self.v * q2 186 | 187 | return qr 188 | 189 | def __rmul__(self, c): 190 | ''' 191 | Quaternion product. Several cases are handled 192 | 193 | - c * q element-wise multiplication by constant 194 | ''' 195 | qr = quaternion() 196 | qr.s = self.s * c 197 | qr.v = self.v * c 198 | 199 | return qr 200 | 201 | def __imul__(self, x): 202 | ''' 203 | Quaternion in-place multiplication 204 | 205 | - q *= q2 206 | 207 | ''' 208 | 209 | if isinstance(x, quaternion): 210 | s1 = self.s; 211 | v1 = self.v 212 | s2 = x.s 213 | v2 = x.v 214 | 215 | # form the product 216 | self.s = s1*s2 - v1*v2.T 217 | self.v = s1*v2 + s2*v1 + cross(v1,v2) 218 | 219 | elif isscalar(x): 220 | self.s *= x; 221 | self.v *= x; 222 | 223 | return self; 224 | 225 | 226 | # def __div__(self, q): 227 | # '''Return quaternion quotient. Several cases handled: 228 | # - q1 / q2 quaternion division implemented as q1 * q2.inv() 229 | # - q1 / c element-wise division 230 | # ''' 231 | # if isinstance(q, quaternion): 232 | # qr = quaternion() 233 | # qr = self * q.inv() 234 | # elif isscalar(q): 235 | # qr.s = self.s / q 236 | # qr.v = self.v / q 237 | # 238 | # return qr 239 | 240 | 241 | def __pow__(self, p): 242 | ''' 243 | Quaternion exponentiation. Only integer exponents are handled. Negative 244 | integer exponents are supported. 245 | ''' 246 | 247 | # check that exponent is an integer 248 | if not isinstance(p, int): 249 | raise ValueError 250 | 251 | qr = quaternion() 252 | q = quaternion(self); 253 | 254 | # multiply by itself so many times 255 | for i in range(0, abs(p)): 256 | qr *= q 257 | 258 | # if exponent was negative, invert it 259 | if p < 0: 260 | qr = qr.inv() 261 | 262 | return qr 263 | 264 | # def copy(self): 265 | # """ 266 | # Return a copy of the quaternion. 267 | # """ 268 | # return copy.copy(self); 269 | # 270 | # def inv(self): 271 | # """Return the inverse. 272 | # 273 | # @rtype: quaternion 274 | # @return: the inverse 275 | # """ 276 | # 277 | # qi = quaternion(self); 278 | # qi.v = -qi.v; 279 | # 280 | # return qi; 281 | # 282 | # 283 | # 284 | # def norm(self): 285 | # """Return the norm of this quaternion. 286 | # 287 | # @rtype: number 288 | # @return: the norm 289 | # """ 290 | # 291 | # return linalg.norm(self.double()) 292 | # 293 | def double(self): 294 | """Return the quaternion as 4-element vector. 
295 | 296 | @rtype: 4-vector 297 | @return: the quaternion elements 298 | """ 299 | return concatenate((mat(self.s), self.v), 1 ) # Debug present 300 | 301 | 302 | # def unit(self): 303 | # """Return an equivalent unit quaternion 304 | # 305 | # @rtype: quaternion 306 | # @return: equivalent unit quaternion 307 | # """ 308 | # 309 | # qr = quaternion() 310 | # nm = self.norm() 311 | # 312 | # qr.s = self.s / nm 313 | # qr.v = self.v / nm 314 | # 315 | # return qr 316 | 317 | def unit_Q(self): 318 | ''' 319 | Function asserts a quaternion Q to be a unit quaternion 320 | s is unchanged as the angle of rotation is assumed to be controlled by the user 321 | v is cast into a unit vector based on |v|^2 = 1 - s^2 and dividing by |v| 322 | ''' 323 | 324 | # Still have some errors with linalg.norm(self.v) 325 | qr = quaternion() 326 | 327 | try: 328 | nm = linalg.norm(self.v) / sqrt(1 - pow(self.s, 2)) 329 | qr.s = self.s 330 | qr.v = self.v / nm 331 | 332 | except: 333 | qr.s = self.s 334 | qr.v = self.v 335 | 336 | return qr 337 | # 338 | # 339 | # def tr(self): 340 | # """Return an equivalent rotation matrix. 341 | # 342 | # @rtype: 4x4 homogeneous transform 343 | # @return: equivalent rotation matrix 344 | # """ 345 | # 346 | # return T.r2t( self.r() ) 347 | # 348 | # def r(self): 349 | # """Return an equivalent rotation matrix. 350 | # 351 | # @rtype: 3x3 orthonormal rotation matrix 352 | # @return: equivalent rotation matrix 353 | # """ 354 | # 355 | # s = self.s; 356 | # x = self.v[0,0] 357 | # y = self.v[0,1] 358 | # z = self.v[0,2] 359 | # 360 | # return matrix([[ 1-2*(y**2+z**2), 2*(x*y-s*z), 2*(x*z+s*y)], 361 | # [2*(x*y+s*z), 1-2*(x**2+z**2), 2*(y*z-s*x)], 362 | # [2*(x*z-s*y), 2*(y*z+s*x), 1-2*(x**2+y**2)]]) 363 | 364 | 365 | 366 | #QINTERP Interpolate rotations expressed by quaternion objects 367 | # 368 | # QI = qinterp(Q1, Q2, R) 369 | # 370 | # Return a unit-quaternion that interpolates between Q1 and Q2 as R moves 371 | # from 0 to 1. This is a spherical linear interpolation (slerp) that can 372 | # be interpretted as interpolation along a great circle arc on a sphere. 373 | # 374 | # If r is a vector, QI, is a cell array of quaternions, each element 375 | # corresponding to sequential elements of R. 376 | # 377 | # See also: CTRAJ, QUATERNION. 378 | 379 | # MOD HISTORY 380 | # 2/99 convert to use of objects 381 | # $Log: qinterp.m,v $ 382 | # Revision 1.3 2002/04/14 11:02:54 pic 383 | # Changed see also line. 384 | # 385 | # Revision 1.2 2002/04/01 12:06:48 pic 386 | # General tidyup, help comments, copyright, see also, RCS keys. 387 | # 388 | # $Revision: 1.3 $ 389 | # 390 | # Copyright (C) 1999-2002, by Peter I. 
Corke 391 | 392 | def interpolate(Q1, Q2, r): 393 | q1 = Q1.double() 394 | q2 = Q2.double() 395 | 396 | theta = acos(q1*q2.T) 397 | q = [] 398 | count = 0 399 | 400 | if isscalar(r): 401 | if r<0 or r>1: 402 | raise Exception('R out of range') 403 | if theta == 0: 404 | q = quaternion(Q1) 405 | else: 406 | Qq = np.copy((sin((1-r)*theta) * q1 + sin(r*theta) * q2) / sin(theta)) 407 | q = quaternion(Qq[0,0], Qq[0,1], Qq[0,2], Qq[0,3]) 408 | else: 409 | for R in r: 410 | if theta == 0: 411 | qq = Q1 412 | else: 413 | qq = quaternion( (sin((1-R)*theta) * q1 + sin(R*theta) * q2) / sin(theta)) 414 | q.append(qq) 415 | return q 416 | 417 | # if __name__ == "__main__": 418 | # Q1 = quaternion(0,1,2,3).unit_Q() # Cast to unit quaternion by scaling v 419 | # Q2 = quaternion(0.5,-1,4,0).unit_Q() 420 | # 421 | # print("Q1 =", Q1) 422 | # print("Q2 =", Q2) 423 | # print(interpolate(Q1, Q2, 0.6)) -------------------------------------------------------------------------------- /nsff_exp/configs/config_balloon1-2.txt: -------------------------------------------------------------------------------- 1 | expname = balloon1-2_ndc_5f_sv_of_unify3 2 | basedir = ./logs 3 | datadir = /home/zl548/nvidia_data_full/Balloon1-2/dense/ 4 | 5 | dataset_type = llff 6 | 7 | factor = 2 8 | llffhold = 10 9 | 10 | N_rand = 1024 11 | N_samples = 128 12 | N_importance = 0 13 | netwidth = 256 14 | 15 | use_viewdirs = True 16 | raw_noise_std = 1e0 17 | no_ndc = False 18 | lindisp = False 19 | no_batching = True 20 | spherify = False 21 | decay_depth_w = True 22 | decay_optical_flow_w = True 23 | use_motion_mask = True 24 | num_extra_sample = 512 25 | decay_iteration = 25 26 | chain_sf = True 27 | 28 | w_depth = 0.02 29 | w_optical_flow = 0.01 30 | w_sm = 0.1 31 | w_sf_reg = 0.1 32 | w_cycle = 1.0 33 | w_prob_reg = 0.1 34 | 35 | start_frame = 0 36 | end_frame = 24 -------------------------------------------------------------------------------- /nsff_exp/configs/config_balloon2-2.txt: -------------------------------------------------------------------------------- 1 | expname = balloon2-2_ndc_5f_sv_of_unify3 2 | basedir = ./logs 3 | datadir = /home/zl548/nvidia_data_full/Balloon2-2/dense/ 4 | 5 | dataset_type = llff 6 | 7 | factor = 2 8 | llffhold = 10 9 | 10 | N_rand = 1024 11 | N_samples = 128 12 | N_importance = 0 13 | netwidth = 256 14 | 15 | use_viewdirs = True 16 | raw_noise_std = 1e0 17 | no_ndc = False 18 | lindisp = False 19 | no_batching = True 20 | spherify = False 21 | decay_depth_w = True 22 | decay_optical_flow_w = True 23 | use_motion_mask = True 24 | num_extra_sample = 512 25 | decay_iteration = 25 26 | chain_sf = True 27 | 28 | w_depth = 0.02 29 | w_optical_flow = 0.01 30 | w_sm = 0.1 31 | w_sf_reg = 0.1 32 | w_cycle = 1.0 33 | w_prob_reg = 0.1 34 | 35 | start_frame = 0 36 | end_frame = 24 -------------------------------------------------------------------------------- /nsff_exp/configs/config_broom.txt: -------------------------------------------------------------------------------- 1 | expname = broom_ndc_5f_unify 2 | 3 | basedir = ./logs 4 | datadir = /phoenix/S7/zl548/nerfie/broom/dense 5 | 6 | dataset_type = llff 7 | 8 | factor = 2 9 | llffhold = 10 10 | 11 | N_rand = 1024 12 | N_samples = 128 13 | N_importance = 0 14 | netwidth = 256 15 | 16 | use_viewdirs = True 17 | raw_noise_std = 1e0 18 | no_ndc = False 19 | lindisp = False 20 | no_batching = True 21 | spherify = False 22 | decay_depth_w = True 23 | decay_optical_flow_w = True 24 | use_motion_mask = True 25 | num_extra_sample = 512 26 | 27 | lrate_decay = 
500 28 | 29 | w_depth = 0.04 30 | w_optical_flow = 0.02 31 | w_sm = 0.1 32 | w_sf_reg = 0.1 33 | w_cycle = 1.0 34 | w_prob_reg = 0.1 35 | w_entropy = 1e-3 36 | 37 | start_frame = 0 38 | end_frame = 196 39 | decay_iteration = 150 40 | 41 | final_height = 480 42 | chain_sf = True 43 | 44 | -------------------------------------------------------------------------------- /nsff_exp/configs/config_curls.txt: -------------------------------------------------------------------------------- 1 | expname = curls_ndc_5f_unify 2 | 3 | basedir = ./logs 4 | datadir = /phoenix/S7/zl548/nerfie/curls/dense 5 | 6 | dataset_type = llff 7 | 8 | factor = 2 9 | llffhold = 10 10 | 11 | N_rand = 1024 12 | N_samples = 128 13 | N_importance = 0 14 | netwidth = 256 15 | 16 | use_viewdirs = True 17 | raw_noise_std = 1e0 18 | no_ndc = False 19 | lindisp = False 20 | no_batching = True 21 | spherify = False 22 | decay_depth_w = True 23 | decay_optical_flow_w = True 24 | use_motion_mask = True 25 | num_extra_sample = 512 26 | 27 | lrate_decay = 500 28 | 29 | w_depth = 0.04 30 | w_optical_flow = 0.02 31 | w_sm = 0.1 32 | w_sf_reg = 0.1 33 | w_cycle = 1.0 34 | w_prob_reg = 0.1 35 | w_entropy = 1e-3 36 | 37 | start_frame = 0 38 | end_frame = 56 39 | decay_iteration = 60 40 | 41 | final_height = 480 42 | chain_sf = True 43 | 44 | -------------------------------------------------------------------------------- /nsff_exp/configs/config_dynamicFace-2.txt: -------------------------------------------------------------------------------- 1 | expname = dynamicFace-2_ndc_5f_sv_of_unify3 2 | basedir = ./logs 3 | 4 | datadir = /home/zl548/nvidia_data_full/DynamicFace-2/dense/ 5 | dataset_type = llff 6 | 7 | factor = 2 8 | llffhold = 10 9 | 10 | N_rand = 1024 11 | N_samples = 128 12 | N_importance = 0 13 | netwidth = 256 14 | 15 | use_viewdirs = True 16 | raw_noise_std = 1e0 17 | no_ndc = False 18 | lindisp = False 19 | no_batching = True 20 | spherify = False 21 | decay_depth_w = True 22 | decay_optical_flow_w = True 23 | use_motion_mask = True 24 | num_extra_sample = 512 25 | decay_iteration = 25 26 | chain_sf = True 27 | 28 | w_depth = 0.04 29 | w_optical_flow = 0.02 30 | w_sm = 0.1 31 | w_sf_reg = 0.01 32 | w_cycle = 1.0 33 | w_prob_reg = 0.1 34 | 35 | start_frame = 0 36 | end_frame = 24 -------------------------------------------------------------------------------- /nsff_exp/configs/config_jumping.txt: -------------------------------------------------------------------------------- 1 | expname = jumping_ndc_5f_sv_of_unify3 2 | basedir = ./logs 3 | datadir = /home/zl548/nvidia_data_full/Jumping/dense/ 4 | 5 | dataset_type = llff 6 | 7 | factor = 2 8 | llffhold = 10 9 | 10 | N_rand = 1024 11 | N_samples = 128 12 | N_importance = 0 13 | netwidth = 256 14 | 15 | use_viewdirs = True 16 | raw_noise_std = 1e0 17 | no_ndc = False 18 | lindisp = False 19 | no_batching = True 20 | spherify = False 21 | decay_depth_w = True 22 | decay_optical_flow_w = True 23 | use_motion_mask = True 24 | num_extra_sample = 512 25 | decay_iteration = 48 26 | chain_sf = True 27 | 28 | w_depth = 0.04 29 | w_optical_flow = 0.02 30 | w_sm = 0.1 31 | w_sf_reg = 0.01 32 | w_cycle = 1.0 33 | w_prob_reg = 0.1 34 | 35 | start_frame = 0 36 | end_frame = 24 -------------------------------------------------------------------------------- /nsff_exp/configs/config_kid-running.txt: -------------------------------------------------------------------------------- 1 | expname = kid-running_ndc_5f_sv_of_sm_unify3 2 | 3 | basedir = ./logs 4 | datadir = 
/phoenix/S7/zl548/nerf_data/kid-running/dense 5 | 6 | dataset_type = llff 7 | 8 | factor = 2 9 | llffhold = 10 10 | 11 | N_rand = 1024 12 | N_samples = 128 13 | N_importance = 0 14 | netwidth = 256 15 | 16 | use_viewdirs = True 17 | raw_noise_std = 1e0 18 | no_ndc = False 19 | lindisp = False 20 | 21 | no_batching = True 22 | spherify = False 23 | decay_depth_w = True 24 | decay_optical_flow_w = True 25 | use_motion_mask = True 26 | num_extra_sample = 512 27 | chain_sf = True 28 | 29 | w_depth = 0.04 30 | w_optical_flow = 0.02 31 | w_sm = 0.1 32 | w_sf_reg = 0.1 33 | w_cycle = 1.0 34 | 35 | start_frame = 0 36 | end_frame = 30 37 | decay_iteration = 30 38 | -------------------------------------------------------------------------------- /nsff_exp/configs/config_playground.txt: -------------------------------------------------------------------------------- 1 | expname = playground_ndc_5f_sv_of_unify3 2 | basedir = ./logs 3 | datadir = /home/zl548/nvidia_data_full/Playground/dense/ 4 | 5 | dataset_type = llff 6 | 7 | factor = 2 8 | llffhold = 10 9 | 10 | N_rand = 1024 11 | N_samples = 128 12 | N_importance = 0 13 | netwidth = 256 14 | 15 | use_viewdirs = True 16 | raw_noise_std = 1e0 17 | no_ndc = False 18 | lindisp = False 19 | no_batching = True 20 | spherify = False 21 | decay_depth_w = True 22 | decay_optical_flow_w = True 23 | use_motion_mask = True 24 | num_extra_sample = 512 25 | decay_iteration = 25 26 | chain_sf = True 27 | 28 | w_depth = 0.02 29 | w_optical_flow = 0.01 30 | w_sm = 0.1 31 | w_sf_reg = 0.1 32 | w_cycle = 1.0 33 | w_prob_reg = 0.1 34 | 35 | start_frame = 0 36 | end_frame = 24 -------------------------------------------------------------------------------- /nsff_exp/configs/config_skating-2.txt: -------------------------------------------------------------------------------- 1 | expname = Skating2_ndc_5f_sv_of_unify3 2 | basedir = ./logs 3 | 4 | datadir = //home/zl548/nvidia_data_full/Skating-2/dense/ 5 | dataset_type = llff 6 | 7 | factor = 2 8 | llffhold = 10 9 | 10 | N_rand = 1024 11 | N_samples = 128 12 | N_importance = 0 13 | netwidth = 256 14 | 15 | use_viewdirs = True 16 | raw_noise_std = 1e0 17 | no_ndc = False 18 | lindisp = False 19 | no_batching = True 20 | spherify = False 21 | decay_depth_w = True 22 | decay_optical_flow_w = True 23 | use_motion_mask = True 24 | num_extra_sample = 512 25 | decay_iteration = 25 26 | chain_sf = True 27 | 28 | w_depth = 0.04 29 | w_optical_flow = 0.02 30 | w_sm = 0.1 31 | w_sf_reg = 0.1 32 | w_cycle = 1.0 33 | w_prob_reg = 0.1 34 | 35 | start_frame = 0 36 | end_frame = 24 -------------------------------------------------------------------------------- /nsff_exp/configs/config_truck2.txt: -------------------------------------------------------------------------------- 1 | expname = truck-2_ndc_5f_sv_of_unify3 2 | basedir = ./logs 3 | datadir = /home/zl548/nvidia_data_full/Truck-2/dense/ 4 | dataset_type = llff 5 | 6 | factor = 2 7 | llffhold = 10 8 | 9 | N_rand = 1024 10 | N_samples = 128 11 | N_importance = 0 12 | netwidth = 256 13 | 14 | use_viewdirs = True 15 | raw_noise_std = 1e0 16 | no_ndc = False 17 | lindisp = False 18 | no_batching = True 19 | spherify = False 20 | decay_depth_w = True 21 | decay_optical_flow_w = True 22 | use_motion_mask = True 23 | num_extra_sample = 512 24 | decay_iteration = 25 25 | chain_sf = True 26 | 27 | w_depth = 0.04 28 | w_optical_flow = 0.02 29 | w_sm = 0.1 30 | w_sf_reg = 0.1 31 | w_cycle = 1.0 32 | w_prob_reg = 0.1 33 | 34 | start_frame = 0 35 | end_frame = 24 
-------------------------------------------------------------------------------- /nsff_exp/configs/config_umbrella.txt: -------------------------------------------------------------------------------- 1 | expname = umbrella_ndc_5f_sv_of_unify3 2 | basedir = ./logs 3 | datadir = /home/zl548/nvidia_data_full/Umbrella/dense/ 4 | 5 | dataset_type = llff 6 | 7 | factor = 2 8 | llffhold = 10 9 | 10 | N_rand = 1024 11 | N_samples = 128 12 | N_importance = 0 13 | netwidth = 256 14 | 15 | use_viewdirs = True 16 | raw_noise_std = 1e0 17 | no_ndc = False 18 | lindisp = False 19 | no_batching = True 20 | spherify = False 21 | decay_depth_w = True 22 | decay_optical_flow_w = True 23 | use_motion_mask = True 24 | num_extra_sample = 512 25 | decay_iteration = 25 26 | chain_sf = True 27 | 28 | w_depth = 0.04 29 | w_optical_flow = 0.02 30 | w_sm = 0.1 31 | w_sf_reg = 0.1 32 | w_cycle = 1.0 33 | w_prob_reg = 0.1 34 | 35 | start_frame = 0 36 | end_frame = 24 -------------------------------------------------------------------------------- /nsff_exp/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from skimage.measure import compare_ssim 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | from models import dist_model 12 | 13 | class PerceptualLoss(torch.nn.Module): 14 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | super(PerceptualLoss, self).__init__() 17 | print('Setting up Perceptual loss...') 18 | self.use_gpu = use_gpu 19 | self.spatial = spatial 20 | self.gpu_ids = gpu_ids 21 | self.model = dist_model.DistModel() 22 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids, version=version) 23 | print('...[%s] initialized'%self.model.name()) 24 | print('...Done') 25 | 26 | def forward(self, pred, target, mask=None, normalize=False): 27 | """ 28 | Pred and target are Variables. 29 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 30 | If normalize is False, assumes the images are already between [-1,+1] 31 | 32 | Inputs pred and target are Nx3xHxW 33 | Output pytorch Variable N long 34 | """ 35 | 36 | if normalize: 37 | target = 2 * target - 1 38 | pred = 2 * pred - 1 39 | 40 | return self.model.forward(target, pred, mask=mask) 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | def l2(p0, p1, range=255.): 47 | return .5*np.mean((p0 / range - p1 / range)**2) 48 | 49 | def psnr(p0, p1, peak=255.): 50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 51 | 52 | def dssim(p0, p1, range=255.): 53 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 
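# Usage sketch (illustrative only, with made-up inputs): the helpers above
# (l2, psnr, dssim) expect numpy images in a 0-255 range by default, e.g.
#   im0 = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
#   im1 = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
#   print(l2(im0, im1), psnr(im0, im1), dssim(im0, im1))
# Note: compare_ssim (imported above from skimage.measure) was removed in
# scikit-image >= 0.18; newer versions expose it as
# skimage.metrics.structural_similarity instead.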
54 | 55 | def rgb2lab(in_img,mean_cent=False): 56 | from skimage import color 57 | img_lab = color.rgb2lab(in_img) 58 | if(mean_cent): 59 | img_lab[:,:,0] = img_lab[:,:,0]-50 60 | return img_lab 61 | 62 | def tensor2np(tensor_obj): 63 | # change dimension of a tensor object into a numpy array 64 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 65 | 66 | def np2tensor(np_obj): 67 | # change dimenion of np array into tensor array 68 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 69 | 70 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 71 | # image tensor to lab tensor 72 | from skimage import color 73 | 74 | img = tensor2im(image_tensor) 75 | img_lab = color.rgb2lab(img) 76 | if(mc_only): 77 | img_lab[:,:,0] = img_lab[:,:,0]-50 78 | if(to_norm and not mc_only): 79 | img_lab[:,:,0] = img_lab[:,:,0]-50 80 | img_lab = img_lab/100. 81 | 82 | return np2tensor(img_lab) 83 | 84 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 85 | from skimage import color 86 | import warnings 87 | warnings.filterwarnings("ignore") 88 | 89 | lab = tensor2np(lab_tensor)*100. 90 | lab[:,:,0] = lab[:,:,0]+50 91 | 92 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 93 | if(return_inbnd): 94 | # convert back to lab, see if we match 95 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 96 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 97 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 98 | return (im2tensor(rgb_back),mask) 99 | else: 100 | return im2tensor(rgb_back) 101 | 102 | def rgb2lab(input): 103 | from skimage import color 104 | return color.rgb2lab(input / 255.) 105 | 106 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 107 | image_numpy = image_tensor[0].cpu().float().numpy() 108 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 109 | return image_numpy.astype(imtype) 110 | 111 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 112 | return torch.Tensor((image / factor - cent) 113 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 114 | 115 | def tensor2vec(vector_tensor): 116 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 117 | 118 | def voc_ap(rec, prec, use_07_metric=False): 119 | """ ap = voc_ap(rec, prec, [use_07_metric]) 120 | Compute VOC AP given precision and recall. 121 | If use_07_metric is true, uses the 122 | VOC 07 11 point method (default:False). 123 | """ 124 | if use_07_metric: 125 | # 11 point metric 126 | ap = 0. 127 | for t in np.arange(0., 1.1, 0.1): 128 | if np.sum(rec >= t) == 0: 129 | p = 0 130 | else: 131 | p = np.max(prec[rec >= t]) 132 | ap = ap + p / 11. 
133 | else: 134 | # correct AP calculation 135 | # first append sentinel values at the end 136 | mrec = np.concatenate(([0.], rec, [1.])) 137 | mpre = np.concatenate(([0.], prec, [0.])) 138 | 139 | # compute the precision envelope 140 | for i in range(mpre.size - 1, 0, -1): 141 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 142 | 143 | # to calculate area under PR curve, look for points 144 | # where X axis (recall) changes value 145 | i = np.where(mrec[1:] != mrec[:-1])[0] 146 | 147 | # and sum (\Delta recall) * prec 148 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 149 | return ap 150 | 151 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 152 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 153 | image_numpy = image_tensor[0].cpu().float().numpy() 154 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 155 | return image_numpy.astype(imtype) 156 | 157 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 158 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 159 | return torch.Tensor((image / factor - cent) 160 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 161 | -------------------------------------------------------------------------------- /nsff_exp/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 | pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 39 | save_path = os.path.join(path, save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /nsff_exp/models/dist_model.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | import os 9 | from collections import OrderedDict 10 | from torch.autograd import Variable 11 | import itertools 12 | from .base_model import BaseModel 13 | from scipy.ndimage import zoom 14 | import fractions 15 | import functools 16 | import 
skimage.transform 17 | from tqdm import tqdm 18 | 19 | 20 | from . import networks_basic as networks 21 | import models as util 22 | 23 | class DistModel(BaseModel): 24 | def name(self): 25 | return self.model_name 26 | 27 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 28 | use_gpu=True, printNet=False, spatial=False, 29 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 30 | ''' 31 | INPUTS 32 | model - ['net-lin'] for linearly calibrated network 33 | ['net'] for off-the-shelf network 34 | ['L2'] for L2 distance in Lab colorspace 35 | ['SSIM'] for ssim in RGB colorspace 36 | net - ['squeeze','alex','vgg'] 37 | model_path - if None, will look in weights/[NET_NAME].pth 38 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 39 | use_gpu - bool - whether or not to use a GPU 40 | printNet - bool - whether or not to print network architecture out 41 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 42 | is_train - bool - [True] for training mode 43 | lr - float - initial learning rate 44 | beta1 - float - initial momentum term for adam 45 | version - 0.1 for latest, 0.0 was original (with a bug) 46 | gpu_ids - int array - [0] by default, gpus to use 47 | ''' 48 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 49 | 50 | self.model = model 51 | self.net = net 52 | self.is_train = is_train 53 | self.spatial = spatial 54 | self.gpu_ids = gpu_ids 55 | self.model_name = '%s [%s]'%(model,net) 56 | 57 | # print('model_path ', model_path) 58 | # sys.exit() 59 | 60 | if(self.model == 'net-lin'): # pretrained net + linear layer 61 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 62 | use_dropout=True, spatial=spatial, version=version, lpips=True) 63 | kw = {} 64 | if not use_gpu: 65 | kw['map_location'] = 'cpu' 66 | if(model_path is None): 67 | import inspect 68 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) 69 | 70 | if(not is_train): 71 | print('Loading model from: %s'%model_path) 72 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 73 | 74 | # sys.exit() 75 | 76 | elif(self.model=='net'): # pretrained network 77 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 78 | elif(self.model in ['L2','l2']): 79 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 80 | self.model_name = 'L2' 81 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 82 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 83 | self.model_name = 'SSIM' 84 | else: 85 | raise ValueError("Model [%s] not recognized." 
% self.model) 86 | 87 | self.parameters = list(self.net.parameters()) 88 | 89 | if self.is_train: # training mode 90 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 91 | self.rankLoss = networks.BCERankingLoss() 92 | self.parameters += list(self.rankLoss.net.parameters()) 93 | self.lr = lr 94 | self.old_lr = lr 95 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 96 | else: # test mode 97 | self.net.eval() 98 | 99 | if(use_gpu): 100 | self.net.to(gpu_ids[0]) 101 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 102 | if(self.is_train): 103 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 104 | 105 | if(printNet): 106 | print('---------- Networks initialized -------------') 107 | networks.print_network(self.net) 108 | print('-----------------------------------------------') 109 | 110 | def forward(self, in0, in1, mask=None, retPerLayer=False): 111 | ''' Function computes the distance between image patches in0 and in1 112 | INPUTS 113 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 114 | OUTPUT 115 | computed distances between in0 and in1 116 | ''' 117 | 118 | return self.net.forward(in0, in1, mask, retPerLayer=retPerLayer) 119 | 120 | # ***** TRAINING FUNCTIONS ***** 121 | def optimize_parameters(self): 122 | self.forward_train() 123 | self.optimizer_net.zero_grad() 124 | self.backward_train() 125 | self.optimizer_net.step() 126 | self.clamp_weights() 127 | 128 | def clamp_weights(self): 129 | for module in self.net.modules(): 130 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 131 | module.weight.data = torch.clamp(module.weight.data,min=0) 132 | 133 | def set_input(self, data): 134 | self.input_ref = data['ref'] 135 | self.input_p0 = data['p0'] 136 | self.input_p1 = data['p1'] 137 | self.input_judge = data['judge'] 138 | 139 | if(self.use_gpu): 140 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 141 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 142 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 143 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 144 | 145 | self.var_ref = Variable(self.input_ref,requires_grad=True) 146 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 147 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 148 | 149 | def forward_train(self): # run forward pass 150 | # print(self.net.module.scaling_layer.shift) 151 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 152 | 153 | self.d0 = self.forward(self.var_ref, self.var_p0) 154 | self.d1 = self.forward(self.var_ref, self.var_p1) 155 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 156 | 157 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 158 | 159 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 
160 | 161 | return self.loss_total 162 | 163 | def backward_train(self): 164 | torch.mean(self.loss_total).backward() 165 | 166 | def compute_accuracy(self,d0,d1,judge): 167 | ''' d0, d1 are Variables, judge is a Tensor ''' 168 | d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) 211 | self.old_lr = lr 212 | 213 | def score_2afc_dataset(data_loader, func, name=''): 214 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 215 | distance function 'func' in dataset 'data_loader' 216 | INPUTS 217 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 218 | func - callable distance function - calling d=func(in0,in1) should take 2 219 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 220 | OUTPUTS 221 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 222 | [1] - dictionary with following elements 223 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 224 | gts - N array in [0,1], preferred patch selected by human evaluators 225 | (closer to "0" for left patch p0, "1" for right patch p1, 226 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 227 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 228 | CONSTS 229 | N - number of test triplets in data_loader 230 | ''' 231 | 232 | d0s = [] 233 | d1s = [] 234 | gts = [] 235 | 236 | for data in tqdm(data_loader.load_data(), desc=name): 237 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 238 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 239 | gts+=data['judge'].cpu().numpy().flatten().tolist() 240 | 241 | d0s = np.array(d0s) 242 | d1s = np.array(d1s) 243 | gts = np.array(gts) 244 | scores = (d0s cy_thresh: # cos(y) not close to zero, standard form 80 | z = math.atan2(-r12, r11) # atan2(cos(y)*sin(z), cos(y)*cos(z)) 81 | y = math.atan2(r13, cy) # atan2(sin(y), cy) 82 | x = math.atan2(-r23, r33) # atan2(cos(y)*sin(x), cos(x)*cos(y)) 83 | else: # cos(y) (close to) zero, so x -> 0.0 (see above) 84 | # so r21 -> sin(z), r22 -> cos(z) and 85 | z = math.atan2(r21, r22) 86 | y = math.atan2(r13, cy) # atan2(sin(y), cy) 87 | x = 0.0 88 | return z, y, x 89 | 90 | 91 | def quat2euler(q): 92 | ''' Return Euler angles corresponding to quaternion `q` 93 | 94 | Parameters 95 | ---------- 96 | q : 4 element sequence 97 | w, x, y, z of quaternion 98 | 99 | Returns 100 | ------- 101 | z : scalar 102 | Rotation angle in radians around z-axis (performed first) 103 | y : scalar 104 | Rotation angle in radians around y-axis 105 | x : scalar 106 | Rotation angle in radians around x-axis (performed last) 107 | 108 | Notes 109 | ----- 110 | It's possible to reduce the amount of calculation a little, by 111 | combining parts of the ``quat2mat`` and ``mat2euler`` functions, but 112 | the reduction in computation is small, and the code repetition is 113 | large. 
114 | ''' 115 | # delayed import to avoid cyclic dependencies 116 | import nibabel.quaternions as nq 117 | return mat2euler(nq.quat2mat(q)) 118 | 119 | def linear_translation(A, B, T): 120 | ''' 121 | Interpolates between 2 start points A an dB linearly 122 | 123 | Input: 124 | A - Start point 125 | B - Ending point 126 | T - intermediate points within a range from 0 - 1, 0 representing point A and 1 representing point B 127 | 128 | Output 129 | C - intermediate pose as an array 130 | ''' 131 | V_AB = B -A 132 | C = A + T*(V_AB) 133 | return C 134 | 135 | def qvec2rotmat(qvec): 136 | return np.array([ 137 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 138 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 139 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 140 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 141 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 142 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 143 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 144 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 145 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 146 | 147 | 148 | def rotmat2qvec(R): 149 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 150 | K = np.array([ 151 | [Rxx - Ryy - Rzz, 0, 0, 0], 152 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 153 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 154 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 155 | eigvals, eigvecs = np.linalg.eigh(K) 156 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 157 | if qvec[0] < 0: 158 | qvec *= -1 159 | return qvec 160 | 161 | 162 | def linear_pose_interp(A_trans, A_rot, B_trans, B_rot, T): 163 | ''' 164 | Pose interpolator that calculates intermediate poses of a vector connecting 2 points 165 | Interpolation is done linearly for both translation and rotation 166 | Translation is done assuming that point A is rotationally invariant. 167 | Rotation is done about the point A. Quaternion SLERP rotation is used to give the quickest rotation from A -> B 168 | ** "Roll" or twisting of the arm is not taken into account in this calculation. 
Separate interpolations have to be done for the roll angle 169 | 170 | Input: 171 | Starting points start_A [X,Y,Z,roll,pitch,yaw] list 172 | Ending points end_B in [X',Y',Z',roll',pitch',yaw'] list 173 | Yaw, Pitch, Roll calculated in radians within bounds [0, 2*pi] 174 | Sequence of rotation: Roll - 1st, Pitch - 2nd, Yaw - 3rd 175 | T = no of intermediate poses 176 | 177 | Output: 178 | list of positions and rotations stored into the variable track 179 | track is a dictionary with keys 'lin' and 'rot' 180 | 181 | track['lin'] - Linear interpolation of interval T of starting positions from A -> B 182 | track['rot'] - Slerp interpolation of quaternion of interval T, arranged as a list in [w x y z] 183 | # track['rot'] - Intermediate Yaw-Pitch-Roll poses of interval T, in sequence YPR 184 | 185 | ''' 186 | 187 | track = {'lin': [], 'rot': []} 188 | 189 | # ra = start_A[3]; pa = start_A[4]; ya = start_A[5] # Yaw/pitch/Roll for A and B 190 | # rb = end_B[3]; pb = end_B[4]; yb = end_B[5] 191 | 192 | # A = array(start_A[:3]); B = array(end_B[:3]) 193 | # [vxa, vya, vza, wa] = Rotation(Vector(A_rot[0, 0], A_rot[0, 1], A_rot[0, 2]), 194 | # Vector(A_rot[1, 0], A_rot[1, 1], A_rot[1, 2]), 195 | # Vector(A_rot[2, 0], A_rot[2, 1], A_rot[2, 2])).GetQuaternion() # Quaternion representation of start and end points 196 | 197 | # [vxb, vyb, vzb, wb] = Rotation(Vector(B_rot[0, 0], B_rot[0, 1], B_rot[0, 2]), 198 | # Vector(B_rot[1, 0], B_rot[1, 1], B_rot[1, 2]), 199 | # Vector(B_rot[2, 0], B_rot[2, 1], B_rot[2, 2])).GetQuaternion() 200 | 201 | q_a = rotmat2qvec(A_rot) 202 | q_b = rotmat2qvec(B_rot) 203 | 204 | # print('q_a ', q_a) 205 | # sys.exit() 206 | 207 | QA = quaternion(q_a[0], q_a[1], q_a[2], q_a[3]) 208 | QB = quaternion(q_b[0], q_b[1], q_b[2], q_b[3]) 209 | 210 | track['lin'] = linear_translation(A_trans, B_trans, T).tolist() 211 | q = interpolate(QA, QB, T) 212 | track['rot'] = [q.s] + (q.v).tolist()[0] # List of quaternion [w x y z] 213 | 214 | # print('track ', track['rot'], track['lin']) 215 | # sys.exit() 216 | return qvec2rotmat(track['rot']), np.array(track['lin']) 217 | # print("Quaternion: ", track['rot'], "Y-P-R: ", quat2euler(track['rot'])) 218 | # return track 219 | 220 | 221 | if __name__ == "__main__": 222 | A = [0, 0, 0, 0, 0, 0] 223 | B = [1, 1, 1, -pi/2, pi/2, -pi/2] 224 | 225 | track = linear_pose_interp(A, B, 1) 226 | # print(track) 227 | -------------------------------------------------------------------------------- /nsff_exp/softsplat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | 5 | import cupy 6 | import re 7 | 8 | kernel_Softsplat_updateOutput = ''' 9 | extern "C" __global__ void kernel_Softsplat_updateOutput( 10 | const int n, 11 | const float* input, 12 | const float* flow, 13 | float* output 14 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 15 | const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); 16 | const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); 17 | const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); 18 | const int intX = ( intIndex ) % SIZE_3(output); 19 | 20 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); 21 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); 22 | 23 | int intNorthwestX = (int) (floor(fltOutputX)); 24 | int intNorthwestY = (int) (floor(fltOutputY)); 
25 | int intNortheastX = intNorthwestX + 1; 26 | int intNortheastY = intNorthwestY; 27 | int intSouthwestX = intNorthwestX; 28 | int intSouthwestY = intNorthwestY + 1; 29 | int intSoutheastX = intNorthwestX + 1; 30 | int intSoutheastY = intNorthwestY + 1; 31 | 32 | float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); 33 | float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); 34 | float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); 35 | float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); 36 | 37 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) { 38 | atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest); 39 | } 40 | 41 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) { 42 | atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast); 43 | } 44 | 45 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) { 46 | atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest); 47 | } 48 | 49 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) { 50 | atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast); 51 | } 52 | } } 53 | ''' 54 | 55 | kernel_Softsplat_updateGradInput = ''' 56 | extern "C" __global__ void kernel_Softsplat_updateGradInput( 57 | const int n, 58 | const float* input, 59 | const float* flow, 60 | const float* gradOutput, 61 | float* gradInput, 62 | float* gradFlow 63 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 64 | const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput); 65 | const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput); 66 | const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput); 67 | const int intX = ( intIndex ) % SIZE_3(gradInput); 68 | 69 | float fltGradInput = 0.0; 70 | 71 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); 72 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); 73 | 74 | int intNorthwestX = (int) (floor(fltOutputX)); 75 | int intNorthwestY = (int) (floor(fltOutputY)); 76 | int intNortheastX = intNorthwestX + 1; 77 | int intNortheastY = intNorthwestY; 78 | int intSouthwestX = intNorthwestX; 79 | int intSouthwestY = intNorthwestY + 1; 80 | int intSoutheastX = intNorthwestX + 1; 81 | int intSoutheastY = intNorthwestY + 1; 82 | 83 | float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); 84 | float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); 85 | float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); 86 | float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); 87 | 88 | if 
((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { 89 | fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; 90 | } 91 | 92 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { 93 | fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast; 94 | } 95 | 96 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { 97 | fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; 98 | } 99 | 100 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { 101 | fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; 102 | } 103 | 104 | gradInput[intIndex] = fltGradInput; 105 | } } 106 | ''' 107 | 108 | kernel_Softsplat_updateGradFlow = ''' 109 | extern "C" __global__ void kernel_Softsplat_updateGradFlow( 110 | const int n, 111 | const float* input, 112 | const float* flow, 113 | const float* gradOutput, 114 | float* gradInput, 115 | float* gradFlow 116 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 117 | float fltGradFlow = 0.0; 118 | 119 | const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow); 120 | const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow); 121 | const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow); 122 | const int intX = ( intIndex ) % SIZE_3(gradFlow); 123 | 124 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); 125 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); 126 | 127 | int intNorthwestX = (int) (floor(fltOutputX)); 128 | int intNorthwestY = (int) (floor(fltOutputY)); 129 | int intNortheastX = intNorthwestX + 1; 130 | int intNortheastY = intNorthwestY; 131 | int intSouthwestX = intNorthwestX; 132 | int intSouthwestY = intNorthwestY + 1; 133 | int intSoutheastX = intNorthwestX + 1; 134 | int intSoutheastY = intNorthwestY + 1; 135 | 136 | float fltNorthwest = 0.0; 137 | float fltNortheast = 0.0; 138 | float fltSouthwest = 0.0; 139 | float fltSoutheast = 0.0; 140 | 141 | if (intC == 0) { 142 | fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY ); 143 | fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY ); 144 | fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY)); 145 | fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY)); 146 | 147 | } else if (intC == 1) { 148 | fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (-1.0)); 149 | fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0)); 150 | fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * ((float) (+1.0)); 151 | fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0)); 152 | 153 | } 154 | 155 | for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) { 156 | float fltInput = VALUE_4(input, intN, intChannel, intY, intX); 157 | 158 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { 159 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, 
intChannel, intNorthwestY, intNorthwestX) * fltNorthwest; 160 | } 161 | 162 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { 163 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast; 164 | } 165 | 166 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { 167 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest; 168 | } 169 | 170 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { 171 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast; 172 | } 173 | } 174 | 175 | gradFlow[intIndex] = fltGradFlow; 176 | } } 177 | ''' 178 | 179 | def cupy_kernel(strFunction, objVariables): 180 | strKernel = globals()[strFunction] 181 | 182 | while True: 183 | objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) 184 | 185 | if objMatch is None: 186 | break 187 | # end 188 | 189 | intArg = int(objMatch.group(2)) 190 | 191 | strTensor = objMatch.group(4) 192 | intSizes = objVariables[strTensor].size() 193 | 194 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) 195 | # end 196 | 197 | while True: 198 | objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel) 199 | 200 | if objMatch is None: 201 | break 202 | # end 203 | 204 | intArgs = int(objMatch.group(2)) 205 | strArgs = objMatch.group(4).split(',') 206 | 207 | strTensor = strArgs[0] 208 | intStrides = objVariables[strTensor].stride() 209 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] 210 | 211 | strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')') 212 | # end 213 | 214 | while True: 215 | objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) 216 | 217 | if objMatch is None: 218 | break 219 | # end 220 | 221 | intArgs = int(objMatch.group(2)) 222 | strArgs = objMatch.group(4).split(',') 223 | 224 | strTensor = strArgs[0] 225 | intStrides = objVariables[strTensor].stride() 226 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] 227 | 228 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') 229 | # end 230 | 231 | return strKernel 232 | # end 233 | 234 | @cupy.memoize(for_each_device=True) 235 | def cupy_launch(strFunction, strKernel): 236 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) 237 | # end 238 | 239 | class _FunctionSoftsplat(torch.autograd.Function): 240 | @staticmethod 241 | def forward(self, input, flow): 242 | self.save_for_backward(input, flow) 243 | 244 | intSamples = input.shape[0] 245 | intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] 246 | intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] 247 | 248 | assert(intFlowDepth == 2) 249 | assert(intInputHeight == intFlowHeight) 250 | assert(intInputWidth == intFlowWidth) 251 | 252 | assert(input.is_contiguous() == True) 253 | assert(flow.is_contiguous() == True) 254 | 255 | output = input.new_zeros([ intSamples, intInputDepth, 
intInputHeight, intInputWidth ]) 256 | 257 | if input.is_cuda == True: 258 | n = output.nelement() 259 | cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', { 260 | 'input': input, 261 | 'flow': flow, 262 | 'output': output 263 | }))( 264 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 265 | block=tuple([ 512, 1, 1 ]), 266 | args=[ n, input.data_ptr(), flow.data_ptr(), output.data_ptr() ] 267 | ) 268 | 269 | elif input.is_cuda == False: 270 | raise NotImplementedError() 271 | 272 | # end 273 | 274 | return output 275 | # end 276 | 277 | @staticmethod 278 | def backward(self, gradOutput): 279 | input, flow = self.saved_tensors 280 | 281 | intSamples = input.shape[0] 282 | intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] 283 | intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] 284 | 285 | assert(intFlowDepth == 2) 286 | assert(intInputHeight == intFlowHeight) 287 | assert(intInputWidth == intFlowWidth) 288 | 289 | assert(gradOutput.is_contiguous() == True) 290 | 291 | gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) if self.needs_input_grad[0] == True else None 292 | gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ]) if self.needs_input_grad[1] == True else None 293 | 294 | if input.is_cuda == True: 295 | if gradInput is not None: 296 | n = gradInput.nelement() 297 | cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', { 298 | 'input': input, 299 | 'flow': flow, 300 | 'gradOutput': gradOutput, 301 | 'gradInput': gradInput, 302 | 'gradFlow': gradFlow 303 | }))( 304 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 305 | block=tuple([ 512, 1, 1 ]), 306 | args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ] 307 | ) 308 | # end 309 | 310 | if gradFlow is not None: 311 | n = gradFlow.nelement() 312 | cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', { 313 | 'input': input, 314 | 'flow': flow, 315 | 'gradOutput': gradOutput, 316 | 'gradInput': gradInput, 317 | 'gradFlow': gradFlow 318 | }))( 319 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 320 | block=tuple([ 512, 1, 1 ]), 321 | args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ] 322 | ) 323 | # end 324 | 325 | elif input.is_cuda == False: 326 | raise NotImplementedError() 327 | 328 | # end 329 | 330 | return gradInput, gradFlow 331 | # end 332 | # end 333 | 334 | def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): 335 | assert(tenMetric is None or tenMetric.shape[1] == 1) 336 | assert(strType in ['summation', 'average', 'linear', 'softmax']) 337 | 338 | if strType == 'average': 339 | tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1) 340 | 341 | elif strType == 'linear': 342 | tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1) 343 | 344 | elif strType == 'softmax': 345 | tenInput = torch.cat([ tenInput * tenMetric.exp(), tenMetric.exp() ], 1) 346 | 347 | # end 348 | 349 | tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow) 350 | 351 | if strType != 'summation': 352 | tenOutput = tenOutput[:, :-1, :, :] / (tenOutput[:, -1:, :, :] + 0.0000001) 353 | # end 354 | 355 | return tenOutput 356 | # end 357 | 358 | class ModuleSoftsplat(torch.nn.Module): 359 | def __init__(self, strType): 360 | 
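        # Note: strType selects the splatting normalization handled by FunctionSoftsplat
        # above ('summation', 'average', 'linear' or 'softmax').
        # Illustrative usage sketch (shapes are assumptions inferred from the asserts above,
        # not documented in this file):
        #   splat = ModuleSoftsplat('softmax')
        #   warped = splat(tenInput, tenFlow, tenMetric)
        #   # tenInput: [N,C,H,W], tenFlow: [N,2,H,W], tenMetric: [N,1,H,W],
        #   # all contiguous CUDA tensors (the CPU path raises NotImplementedError).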
super(ModuleSoftsplat, self).__init__() 361 | 362 | self.strType = strType 363 | # end 364 | 365 | def forward(self, tenInput, tenFlow, tenMetric): 366 | return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType) 367 | # end 368 | # end -------------------------------------------------------------------------------- /nsff_scripts/alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // CUDA forward declarations 5 | std::vector corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /nsff_scripts/alt_cuda_corr/correlation_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | 7 | #define BLOCK_H 4 8 | #define BLOCK_W 8 9 | #define BLOCK_HW BLOCK_H * BLOCK_W 10 | #define CHANNEL_STRIDE 32 11 | 12 | 13 | __forceinline__ __device__ 14 | bool within_bounds(int h, int w, int H, int W) { 15 | return h >= 0 && h < H && w >= 0 && w < W; 16 | } 17 | 18 | template 19 | __global__ void corr_forward_kernel( 20 | const torch::PackedTensorAccessor32 fmap1, 21 | const torch::PackedTensorAccessor32 fmap2, 22 | const torch::PackedTensorAccessor32 coords, 23 | torch::PackedTensorAccessor32 corr, 24 | int r) 25 | { 26 | const int b = blockIdx.x; 27 | const int h0 = blockIdx.y * blockDim.x; 28 | const int w0 = blockIdx.z * blockDim.y; 29 | const int tid = threadIdx.x * blockDim.y + threadIdx.y; 30 | 31 | const int H1 = fmap1.size(1); 32 | const int W1 = fmap1.size(2); 33 | const int H2 = fmap2.size(1); 34 | const int W2 = fmap2.size(2); 35 | const int N = coords.size(1); 36 | const int C = fmap1.size(3); 37 | 38 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; 39 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; 40 | __shared__ scalar_t x2s[BLOCK_HW]; 41 | __shared__ scalar_t y2s[BLOCK_HW]; 42 | 43 | for (int c=0; c(floor(y2s[k1]))-r+iy; 76 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 77 | int c2 = tid % CHANNEL_STRIDE; 78 | 79 | auto fptr = fmap2[b][h2][w2]; 80 | if (within_bounds(h2, w2, H2, W2)) 81 | f2[c2][k1] = fptr[c+c2]; 82 | else 83 
| f2[c2][k1] = 0.0; 84 | } 85 | 86 | __syncthreads(); 87 | 88 | scalar_t s = 0.0; 89 | for (int k=0; k 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) 105 | *(corr_ptr + ix_nw) += nw; 106 | 107 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1)) 108 | *(corr_ptr + ix_ne) += ne; 109 | 110 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) 111 | *(corr_ptr + ix_sw) += sw; 112 | 113 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1)) 114 | *(corr_ptr + ix_se) += se; 115 | } 116 | } 117 | } 118 | } 119 | } 120 | 121 | 122 | template 123 | __global__ void corr_backward_kernel( 124 | const torch::PackedTensorAccessor32 fmap1, 125 | const torch::PackedTensorAccessor32 fmap2, 126 | const torch::PackedTensorAccessor32 coords, 127 | const torch::PackedTensorAccessor32 corr_grad, 128 | torch::PackedTensorAccessor32 fmap1_grad, 129 | torch::PackedTensorAccessor32 fmap2_grad, 130 | torch::PackedTensorAccessor32 coords_grad, 131 | int r) 132 | { 133 | 134 | const int b = blockIdx.x; 135 | const int h0 = blockIdx.y * blockDim.x; 136 | const int w0 = blockIdx.z * blockDim.y; 137 | const int tid = threadIdx.x * blockDim.y + threadIdx.y; 138 | 139 | const int H1 = fmap1.size(1); 140 | const int W1 = fmap1.size(2); 141 | const int H2 = fmap2.size(1); 142 | const int W2 = fmap2.size(2); 143 | const int N = coords.size(1); 144 | const int C = fmap1.size(3); 145 | 146 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; 147 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; 148 | 149 | __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1]; 150 | __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1]; 151 | 152 | __shared__ scalar_t x2s[BLOCK_HW]; 153 | __shared__ scalar_t y2s[BLOCK_HW]; 154 | 155 | for (int c=0; c(floor(y2s[k1]))-r+iy; 190 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 191 | int c2 = tid % CHANNEL_STRIDE; 192 | 193 | auto fptr = fmap2[b][h2][w2]; 194 | if (within_bounds(h2, w2, H2, W2)) 195 | f2[c2][k1] = fptr[c+c2]; 196 | else 197 | f2[c2][k1] = 0.0; 198 | 199 | f2_grad[c2][k1] = 0.0; 200 | } 201 | 202 | __syncthreads(); 203 | 204 | const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1]; 205 | scalar_t g = 0.0; 206 | 207 | int ix_nw = H1*W1*((iy-1) + rd*(ix-1)); 208 | int ix_ne = H1*W1*((iy-1) + rd*ix); 209 | int ix_sw = H1*W1*(iy + rd*(ix-1)); 210 | int ix_se = H1*W1*(iy + rd*ix); 211 | 212 | if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) 213 | g += *(grad_ptr + ix_nw) * dy * dx; 214 | 215 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1)) 216 | g += *(grad_ptr + ix_ne) * dy * (1-dx); 217 | 218 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) 219 | g += *(grad_ptr + ix_sw) * (1-dy) * dx; 220 | 221 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1)) 222 | g += *(grad_ptr + ix_se) * (1-dy) * (1-dx); 223 | 224 | for (int k=0; k(floor(y2s[k1]))-r+iy; 232 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 233 | int c2 = tid % CHANNEL_STRIDE; 234 | 235 | scalar_t* fptr = &fmap2_grad[b][h2][w2][0]; 236 | if (within_bounds(h2, w2, H2, W2)) 237 | atomicAdd(fptr+c+c2, f2_grad[c2][k1]); 238 | } 239 | } 240 | } 241 | } 242 | __syncthreads(); 243 | 244 | 245 | for (int k=0; k corr_cuda_forward( 261 | torch::Tensor fmap1, 262 | torch::Tensor fmap2, 263 | torch::Tensor coords, 264 | int radius) 265 | { 266 | const auto B = coords.size(0); 267 | const auto N = coords.size(1); 268 | const auto H = coords.size(2); 269 | const auto W = coords.size(3); 270 | 271 | const auto rd = 2 * radius + 1; 272 | auto opts = fmap1.options(); 273 | auto corr = torch::zeros({B, 
N, rd*rd, H, W}, opts); 274 | 275 | const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W); 276 | const dim3 threads(BLOCK_H, BLOCK_W); 277 | 278 | corr_forward_kernel<<>>( 279 | fmap1.packed_accessor32(), 280 | fmap2.packed_accessor32(), 281 | coords.packed_accessor32(), 282 | corr.packed_accessor32(), 283 | radius); 284 | 285 | return {corr}; 286 | } 287 | 288 | std::vector corr_cuda_backward( 289 | torch::Tensor fmap1, 290 | torch::Tensor fmap2, 291 | torch::Tensor coords, 292 | torch::Tensor corr_grad, 293 | int radius) 294 | { 295 | const auto B = coords.size(0); 296 | const auto N = coords.size(1); 297 | 298 | const auto H1 = fmap1.size(1); 299 | const auto W1 = fmap1.size(2); 300 | const auto H2 = fmap2.size(1); 301 | const auto W2 = fmap2.size(2); 302 | const auto C = fmap1.size(3); 303 | 304 | auto opts = fmap1.options(); 305 | auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts); 306 | auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts); 307 | auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts); 308 | 309 | const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W); 310 | const dim3 threads(BLOCK_H, BLOCK_W); 311 | 312 | 313 | corr_backward_kernel<<>>( 314 | fmap1.packed_accessor32(), 315 | fmap2.packed_accessor32(), 316 | coords.packed_accessor32(), 317 | corr_grad.packed_accessor32(), 318 | fmap1_grad.packed_accessor32(), 319 | fmap2_grad.packed_accessor32(), 320 | coords_grad.packed_accessor32(), 321 | radius); 322 | 323 | return {fmap1_grad, fmap2_grad, coords_grad}; 324 | } -------------------------------------------------------------------------------- /nsff_scripts/alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | 16 | -------------------------------------------------------------------------------- /nsff_scripts/colmap_read_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch at inf.ethz.ch) 31 | 32 | import os 33 | import sys 34 | import collections 35 | import numpy as np 36 | import struct 37 | 38 | 39 | CameraModel = collections.namedtuple( 40 | "CameraModel", ["model_id", "model_name", "num_params"]) 41 | Camera = collections.namedtuple( 42 | "Camera", ["id", "model", "width", "height", "params"]) 43 | BaseImage = collections.namedtuple( 44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 45 | Point3D = collections.namedtuple( 46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 47 | 48 | class Image(BaseImage): 49 | def qvec2rotmat(self): 50 | return qvec2rotmat(self.qvec) 51 | 52 | 53 | CAMERA_MODELS = { 54 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 55 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 56 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 57 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 58 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 59 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 60 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 61 | CameraModel(model_id=7, model_name="FOV", num_params=5), 62 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 63 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 64 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 65 | } 66 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \ 67 | for camera_model in CAMERA_MODELS]) 68 | 69 | 70 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 71 | """Read and unpack the next bytes from a binary file. 72 | :param fid: 73 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 74 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 75 | :param endian_character: Any of {@, =, <, >, !} 76 | :return: Tuple of read and unpacked values. 
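    Example (illustrative): read_next_bytes(fid, 8, "Q") reads a single unsigned
    64-bit integer, as used below to read the number of cameras/images/points.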
77 | """ 78 | data = fid.read(num_bytes) 79 | return struct.unpack(endian_character + format_char_sequence, data) 80 | 81 | 82 | def read_cameras_text(path): 83 | """ 84 | see: src/base/reconstruction.cc 85 | void Reconstruction::WriteCamerasText(const std::string& path) 86 | void Reconstruction::ReadCamerasText(const std::string& path) 87 | """ 88 | cameras = {} 89 | with open(path, "r") as fid: 90 | while True: 91 | line = fid.readline() 92 | if not line: 93 | break 94 | line = line.strip() 95 | if len(line) > 0 and line[0] != "#": 96 | elems = line.split() 97 | camera_id = int(elems[0]) 98 | model = elems[1] 99 | width = int(elems[2]) 100 | height = int(elems[3]) 101 | params = np.array(tuple(map(float, elems[4:]))) 102 | cameras[camera_id] = Camera(id=camera_id, model=model, 103 | width=width, height=height, 104 | params=params) 105 | return cameras 106 | 107 | 108 | def read_cameras_binary(path_to_model_file): 109 | """ 110 | see: src/base/reconstruction.cc 111 | void Reconstruction::WriteCamerasBinary(const std::string& path) 112 | void Reconstruction::ReadCamerasBinary(const std::string& path) 113 | """ 114 | cameras = {} 115 | with open(path_to_model_file, "rb") as fid: 116 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 117 | for camera_line_index in range(num_cameras): 118 | camera_properties = read_next_bytes( 119 | fid, num_bytes=24, format_char_sequence="iiQQ") 120 | camera_id = camera_properties[0] 121 | model_id = camera_properties[1] 122 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 123 | width = camera_properties[2] 124 | height = camera_properties[3] 125 | num_params = CAMERA_MODEL_IDS[model_id].num_params 126 | params = read_next_bytes(fid, num_bytes=8*num_params, 127 | format_char_sequence="d"*num_params) 128 | cameras[camera_id] = Camera(id=camera_id, 129 | model=model_name, 130 | width=width, 131 | height=height, 132 | params=np.array(params)) 133 | assert len(cameras) == num_cameras 134 | return cameras 135 | 136 | 137 | def read_images_text(path): 138 | """ 139 | see: src/base/reconstruction.cc 140 | void Reconstruction::ReadImagesText(const std::string& path) 141 | void Reconstruction::WriteImagesText(const std::string& path) 142 | """ 143 | images = {} 144 | with open(path, "r") as fid: 145 | while True: 146 | line = fid.readline() 147 | if not line: 148 | break 149 | line = line.strip() 150 | if len(line) > 0 and line[0] != "#": 151 | elems = line.split() 152 | image_id = int(elems[0]) 153 | qvec = np.array(tuple(map(float, elems[1:5]))) 154 | tvec = np.array(tuple(map(float, elems[5:8]))) 155 | camera_id = int(elems[8]) 156 | image_name = elems[9] 157 | elems = fid.readline().split() 158 | xys = np.column_stack([tuple(map(float, elems[0::3])), 159 | tuple(map(float, elems[1::3]))]) 160 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 161 | images[image_id] = Image( 162 | id=image_id, qvec=qvec, tvec=tvec, 163 | camera_id=camera_id, name=image_name, 164 | xys=xys, point3D_ids=point3D_ids) 165 | return images 166 | 167 | 168 | def read_images_binary(path_to_model_file): 169 | """ 170 | see: src/base/reconstruction.cc 171 | void Reconstruction::ReadImagesBinary(const std::string& path) 172 | void Reconstruction::WriteImagesBinary(const std::string& path) 173 | """ 174 | images = {} 175 | with open(path_to_model_file, "rb") as fid: 176 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 177 | for image_index in range(num_reg_images): 178 | binary_image_properties = read_next_bytes( 179 | fid, num_bytes=64, format_char_sequence="idddddddi") 
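            # Layout of the 64-byte record ("idddddddi"): image_id (int32),
            # qvec (4 float64), tvec (3 float64), camera_id (int32) = 4 + 7*8 + 4 bytes.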
180 | image_id = binary_image_properties[0] 181 | qvec = np.array(binary_image_properties[1:5]) 182 | tvec = np.array(binary_image_properties[5:8]) 183 | camera_id = binary_image_properties[8] 184 | image_name = "" 185 | current_char = read_next_bytes(fid, 1, "c")[0] 186 | while current_char != b"\x00": # look for the ASCII 0 entry 187 | image_name += current_char.decode("utf-8") 188 | current_char = read_next_bytes(fid, 1, "c")[0] 189 | num_points2D = read_next_bytes(fid, num_bytes=8, 190 | format_char_sequence="Q")[0] 191 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 192 | format_char_sequence="ddq"*num_points2D) 193 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 194 | tuple(map(float, x_y_id_s[1::3]))]) 195 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 196 | images[image_id] = Image( 197 | id=image_id, qvec=qvec, tvec=tvec, 198 | camera_id=camera_id, name=image_name, 199 | xys=xys, point3D_ids=point3D_ids) 200 | return images 201 | 202 | 203 | def read_points3D_text(path): 204 | """ 205 | see: src/base/reconstruction.cc 206 | void Reconstruction::ReadPoints3DText(const std::string& path) 207 | void Reconstruction::WritePoints3DText(const std::string& path) 208 | """ 209 | points3D = {} 210 | with open(path, "r") as fid: 211 | while True: 212 | line = fid.readline() 213 | if not line: 214 | break 215 | line = line.strip() 216 | if len(line) > 0 and line[0] != "#": 217 | elems = line.split() 218 | point3D_id = int(elems[0]) 219 | xyz = np.array(tuple(map(float, elems[1:4]))) 220 | rgb = np.array(tuple(map(int, elems[4:7]))) 221 | error = float(elems[7]) 222 | image_ids = np.array(tuple(map(int, elems[8::2]))) 223 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 224 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 225 | error=error, image_ids=image_ids, 226 | point2D_idxs=point2D_idxs) 227 | return points3D 228 | 229 | 230 | def read_points3d_binary(path_to_model_file): 231 | """ 232 | see: src/base/reconstruction.cc 233 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 234 | void Reconstruction::WritePoints3DBinary(const std::string& path) 235 | """ 236 | points3D = {} 237 | with open(path_to_model_file, "rb") as fid: 238 | num_points = read_next_bytes(fid, 8, "Q")[0] 239 | for point_line_index in range(num_points): 240 | binary_point_line_properties = read_next_bytes( 241 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 242 | point3D_id = binary_point_line_properties[0] 243 | xyz = np.array(binary_point_line_properties[1:4]) 244 | rgb = np.array(binary_point_line_properties[4:7]) 245 | error = np.array(binary_point_line_properties[7]) 246 | track_length = read_next_bytes( 247 | fid, num_bytes=8, format_char_sequence="Q")[0] 248 | track_elems = read_next_bytes( 249 | fid, num_bytes=8*track_length, 250 | format_char_sequence="ii"*track_length) 251 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 252 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 253 | points3D[point3D_id] = Point3D( 254 | id=point3D_id, xyz=xyz, rgb=rgb, 255 | error=error, image_ids=image_ids, 256 | point2D_idxs=point2D_idxs) 257 | return points3D 258 | 259 | 260 | def read_model(path, ext): 261 | if ext == ".txt": 262 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 263 | images = read_images_text(os.path.join(path, "images" + ext)) 264 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 265 | else: 266 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 267 | 
images = read_images_binary(os.path.join(path, "images" + ext)) 268 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) 269 | return cameras, images, points3D 270 | 271 | 272 | def qvec2rotmat(qvec): 273 | return np.array([ 274 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 275 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 276 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 277 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 278 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 279 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 280 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 281 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 282 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 283 | 284 | 285 | def rotmat2qvec(R): 286 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 287 | K = np.array([ 288 | [Rxx - Ryy - Rzz, 0, 0, 0], 289 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 290 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 291 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 292 | eigvals, eigvecs = np.linalg.eigh(K) 293 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 294 | if qvec[0] < 0: 295 | qvec *= -1 296 | return qvec 297 | 298 | 299 | def main(): 300 | if len(sys.argv) != 3: 301 | print("Usage: python read_model.py path/to/model/folder [.txt,.bin]") 302 | return 303 | 304 | cameras, images, points3D = read_model(path=sys.argv[1], ext=sys.argv[2]) 305 | 306 | print("num_cameras:", len(cameras)) 307 | print("num_images:", len(images)) 308 | print("num_points3D:", len(points3D)) 309 | 310 | 311 | if __name__ == "__main__": 312 | main() 313 | -------------------------------------------------------------------------------- /nsff_scripts/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengqili/Neural-Scene-Flow-Fields/d4001759a39b056c95d8bc22da34b10b4fb85afb/nsff_scripts/core/__init__.py -------------------------------------------------------------------------------- /nsff_scripts/core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1) 38 | dy = torch.linspace(-r, r, 2*r+1) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, 
dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class AlternateCorrBlock: 64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 65 | self.num_levels = num_levels 66 | self.radius = radius 67 | 68 | self.pyramid = [(fmap1, fmap2)] 69 | for i in range(self.num_levels): 70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 72 | self.pyramid.append((fmap1, fmap2)) 73 | 74 | def __call__(self, coords): 75 | coords = coords.permute(0, 2, 3, 1) 76 | B, H, W, _ = coords.shape 77 | dim = self.pyramid[0][0].shape[1] 78 | 79 | corr_list = [] 80 | for i in range(self.num_levels): 81 | r = self.radius 82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 84 | 85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 87 | corr_list.append(corr.squeeze(1)) 88 | 89 | corr = torch.stack(corr_list, dim=1) 90 | corr = corr.reshape(B, -1, H, W) 91 | return corr / torch.sqrt(torch.tensor(dim).float()) 92 | -------------------------------------------------------------------------------- /nsff_scripts/core/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | 8 | import os 9 | import math 10 | import random 11 | from glob import glob 12 | import os.path as osp 13 | 14 | from utils import frame_utils 15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor 16 | 17 | 18 | class FlowDataset(data.Dataset): 19 | def __init__(self, aug_params=None, sparse=False): 20 | self.augmentor = None 21 | self.sparse = sparse 22 | if aug_params is not None: 23 | if sparse: 24 | self.augmentor = SparseFlowAugmentor(**aug_params) 25 | else: 26 | self.augmentor = FlowAugmentor(**aug_params) 27 | 28 | self.is_test = False 29 | self.init_seed = False 30 | self.flow_list = [] 31 | self.image_list = [] 32 | self.extra_info = [] 33 | 34 | def __getitem__(self, index): 35 | 36 | if self.is_test: 37 | img1 = frame_utils.read_gen(self.image_list[index][0]) 38 | img2 = frame_utils.read_gen(self.image_list[index][1]) 39 | img1 = np.array(img1).astype(np.uint8)[..., :3] 40 | img2 = np.array(img2).astype(np.uint8)[..., :3] 41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 43 | return img1, img2, self.extra_info[index] 44 | 45 | if not self.init_seed: 46 | worker_info = torch.utils.data.get_worker_info() 47 | if worker_info is not None: 48 | torch.manual_seed(worker_info.id) 49 | np.random.seed(worker_info.id) 50 | random.seed(worker_info.id) 51 | self.init_seed = True 52 | 53 | index = index % len(self.image_list) 54 | valid = None 55 | if self.sparse: 56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 57 | else: 58 | flow = frame_utils.read_gen(self.flow_list[index]) 59 | 60 | img1 = frame_utils.read_gen(self.image_list[index][0]) 61 | img2 = 
frame_utils.read_gen(self.image_list[index][1]) 62 | 63 | flow = np.array(flow).astype(np.float32) 64 | img1 = np.array(img1).astype(np.uint8) 65 | img2 = np.array(img2).astype(np.uint8) 66 | 67 | # grayscale images 68 | if len(img1.shape) == 2: 69 | img1 = np.tile(img1[...,None], (1, 1, 3)) 70 | img2 = np.tile(img2[...,None], (1, 1, 3)) 71 | else: 72 | img1 = img1[..., :3] 73 | img2 = img2[..., :3] 74 | 75 | if self.augmentor is not None: 76 | if self.sparse: 77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 78 | else: 79 | img1, img2, flow = self.augmentor(img1, img2, flow) 80 | 81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 84 | 85 | if valid is not None: 86 | valid = torch.from_numpy(valid) 87 | else: 88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 89 | 90 | return img1, img2, flow, valid.float() 91 | 92 | 93 | def __rmul__(self, v): 94 | self.flow_list = v * self.flow_list 95 | self.image_list = v * self.image_list 96 | return self 97 | 98 | def __len__(self): 99 | return len(self.image_list) 100 | 101 | 102 | class MpiSintel(FlowDataset): 103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): 104 | super(MpiSintel, self).__init__(aug_params) 105 | flow_root = osp.join(root, split, 'flow') 106 | image_root = osp.join(root, split, dstype) 107 | 108 | if split == 'test': 109 | self.is_test = True 110 | 111 | for scene in os.listdir(image_root): 112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 113 | for i in range(len(image_list)-1): 114 | self.image_list += [ [image_list[i], image_list[i+1]] ] 115 | self.extra_info += [ (scene, i) ] # scene and frame_id 116 | 117 | if split != 'test': 118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 119 | 120 | 121 | class FlyingChairs(FlowDataset): 122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): 123 | super(FlyingChairs, self).__init__(aug_params) 124 | 125 | images = sorted(glob(osp.join(root, '*.ppm'))) 126 | flows = sorted(glob(osp.join(root, '*.flo'))) 127 | assert (len(images)//2 == len(flows)) 128 | 129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 130 | for i in range(len(flows)): 131 | xid = split_list[i] 132 | if (split=='training' and xid==1) or (split=='validation' and xid==2): 133 | self.flow_list += [ flows[i] ] 134 | self.image_list += [ [images[2*i], images[2*i+1]] ] 135 | 136 | 137 | class FlyingThings3D(FlowDataset): 138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): 139 | super(FlyingThings3D, self).__init__(aug_params) 140 | 141 | for cam in ['left']: 142 | for direction in ['into_future', 'into_past']: 143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 145 | 146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 148 | 149 | for idir, fdir in zip(image_dirs, flow_dirs): 150 | images = sorted(glob(osp.join(idir, '*.png')) ) 151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 152 | for i in range(len(flows)-1): 153 | if direction == 'into_future': 154 | self.image_list += [ [images[i], images[i+1]] ] 155 | self.flow_list += [ flows[i] ] 156 | elif direction == 'into_past': 157 | 
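                        # 'into_past' flow is stored at frame i+1, so the image pair is
                        # reversed: flows[i+1] maps images[i+1] back to images[i].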
self.image_list += [ [images[i+1], images[i]] ] 158 | self.flow_list += [ flows[i+1] ] 159 | 160 | 161 | class KITTI(FlowDataset): 162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): 163 | super(KITTI, self).__init__(aug_params, sparse=True) 164 | if split == 'testing': 165 | self.is_test = True 166 | 167 | root = osp.join(root, split) 168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 170 | 171 | for img1, img2 in zip(images1, images2): 172 | frame_id = img1.split('/')[-1] 173 | self.extra_info += [ [frame_id] ] 174 | self.image_list += [ [img1, img2] ] 175 | 176 | if split == 'training': 177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 178 | 179 | 180 | class HD1K(FlowDataset): 181 | def __init__(self, aug_params=None, root='datasets/HD1k'): 182 | super(HD1K, self).__init__(aug_params, sparse=True) 183 | 184 | seq_ix = 0 185 | while 1: 186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 188 | 189 | if len(flows) == 0: 190 | break 191 | 192 | for i in range(len(flows)-1): 193 | self.flow_list += [flows[i]] 194 | self.image_list += [ [images[i], images[i+1]] ] 195 | 196 | seq_ix += 1 197 | 198 | 199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): 200 | """ Create the data loader for the corresponding trainign set """ 201 | 202 | if args.stage == 'chairs': 203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 204 | train_dataset = FlyingChairs(aug_params, split='training') 205 | 206 | elif args.stage == 'things': 207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') 209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') 210 | train_dataset = clean_dataset + final_dataset 211 | 212 | elif args.stage == 'sintel': 213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 217 | 218 | if TRAIN_DS == 'C+T+K+S+H': 219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things 222 | 223 | elif TRAIN_DS == 'C+T+K/S': 224 | train_dataset = 100*sintel_clean + 100*sintel_final + things 225 | 226 | elif args.stage == 'kitti': 227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 228 | train_dataset = KITTI(aug_params, split='training') 229 | 230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True) 232 | 233 | print('Training with %d image pairs' % len(train_dataset)) 234 | return train_loader 235 | 236 | -------------------------------------------------------------------------------- /nsff_scripts/core/extractor.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x+y) 57 | 58 | 59 | 60 | class BottleneckBlock(nn.Module): 61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 62 | super(BottleneckBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | num_groups = planes // 8 70 | 71 | if norm_fn == 'group': 72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | if not stride == 1: 76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | 78 | elif norm_fn == 'batch': 79 | self.norm1 = nn.BatchNorm2d(planes//4) 80 | self.norm2 = nn.BatchNorm2d(planes//4) 81 | self.norm3 = nn.BatchNorm2d(planes) 82 | if not stride == 1: 83 | self.norm4 = nn.BatchNorm2d(planes) 84 | 85 | elif norm_fn == 'instance': 86 | self.norm1 = nn.InstanceNorm2d(planes//4) 87 | self.norm2 = nn.InstanceNorm2d(planes//4) 88 | self.norm3 = nn.InstanceNorm2d(planes) 89 | if not stride == 1: 90 | self.norm4 = nn.InstanceNorm2d(planes) 91 | 92 | elif norm_fn == 'none': 93 | self.norm1 = nn.Sequential() 94 | self.norm2 = nn.Sequential() 95 | self.norm3 = nn.Sequential() 96 | if not stride == 1: 97 | self.norm4 = nn.Sequential() 98 | 99 | if stride == 1: 100 | self.downsample = None 101 | 102 | else: 103 | self.downsample = nn.Sequential( 104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 105 | 106 | 107 | def forward(self, x): 108 | y = x 109 | y = self.relu(self.norm1(self.conv1(y))) 110 | y = 
self.relu(self.norm2(self.conv2(y))) 111 | y = self.relu(self.norm3(self.conv3(y))) 112 | 113 | if self.downsample is not None: 114 | x = self.downsample(x) 115 | 116 | return self.relu(x+y) 117 | 118 | class BasicEncoder(nn.Module): 119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 120 | super(BasicEncoder, self).__init__() 121 | self.norm_fn = norm_fn 122 | 123 | if self.norm_fn == 'group': 124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 125 | 126 | elif self.norm_fn == 'batch': 127 | self.norm1 = nn.BatchNorm2d(64) 128 | 129 | elif self.norm_fn == 'instance': 130 | self.norm1 = nn.InstanceNorm2d(64) 131 | 132 | elif self.norm_fn == 'none': 133 | self.norm1 = nn.Sequential() 134 | 135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 136 | self.relu1 = nn.ReLU(inplace=True) 137 | 138 | self.in_planes = 64 139 | self.layer1 = self._make_layer(64, stride=1) 140 | self.layer2 = self._make_layer(96, stride=2) 141 | self.layer3 = self._make_layer(128, stride=2) 142 | 143 | # output convolution 144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 145 | 146 | self.dropout = None 147 | if dropout > 0: 148 | self.dropout = nn.Dropout2d(p=dropout) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 154 | if m.weight is not None: 155 | nn.init.constant_(m.weight, 1) 156 | if m.bias is not None: 157 | nn.init.constant_(m.bias, 0) 158 | 159 | def _make_layer(self, dim, stride=1): 160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 162 | layers = (layer1, layer2) 163 | 164 | self.in_planes = dim 165 | return nn.Sequential(*layers) 166 | 167 | 168 | def forward(self, x): 169 | 170 | # if input is list, combine batch dimension 171 | is_list = isinstance(x, tuple) or isinstance(x, list) 172 | if is_list: 173 | batch_dim = x[0].shape[0] 174 | x = torch.cat(x, dim=0) 175 | 176 | x = self.conv1(x) 177 | x = self.norm1(x) 178 | x = self.relu1(x) 179 | 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | 184 | x = self.conv2(x) 185 | 186 | if self.training and self.dropout is not None: 187 | x = self.dropout(x) 188 | 189 | if is_list: 190 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 191 | 192 | return x 193 | 194 | 195 | class SmallEncoder(nn.Module): 196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 197 | super(SmallEncoder, self).__init__() 198 | self.norm_fn = norm_fn 199 | 200 | if self.norm_fn == 'group': 201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 202 | 203 | elif self.norm_fn == 'batch': 204 | self.norm1 = nn.BatchNorm2d(32) 205 | 206 | elif self.norm_fn == 'instance': 207 | self.norm1 = nn.InstanceNorm2d(32) 208 | 209 | elif self.norm_fn == 'none': 210 | self.norm1 = nn.Sequential() 211 | 212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 213 | self.relu1 = nn.ReLU(inplace=True) 214 | 215 | self.in_planes = 32 216 | self.layer1 = self._make_layer(32, stride=1) 217 | self.layer2 = self._make_layer(64, stride=2) 218 | self.layer3 = self._make_layer(96, stride=2) 219 | 220 | self.dropout = None 221 | if dropout > 0: 222 | self.dropout = nn.Dropout2d(p=dropout) 223 | 224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 225 | 226 | for m in self.modules(): 227 | if isinstance(m, 
nn.Conv2d): 228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 230 | if m.weight is not None: 231 | nn.init.constant_(m.weight, 1) 232 | if m.bias is not None: 233 | nn.init.constant_(m.bias, 0) 234 | 235 | def _make_layer(self, dim, stride=1): 236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 238 | layers = (layer1, layer2) 239 | 240 | self.in_planes = dim 241 | return nn.Sequential(*layers) 242 | 243 | 244 | def forward(self, x): 245 | 246 | # if input is list, combine batch dimension 247 | is_list = isinstance(x, tuple) or isinstance(x, list) 248 | if is_list: 249 | batch_dim = x[0].shape[0] 250 | x = torch.cat(x, dim=0) 251 | 252 | x = self.conv1(x) 253 | x = self.norm1(x) 254 | x = self.relu1(x) 255 | 256 | x = self.layer1(x) 257 | x = self.layer2(x) 258 | x = self.layer3(x) 259 | x = self.conv2(x) 260 | 261 | if self.training and self.dropout is not None: 262 | x = self.dropout(x) 263 | 264 | if is_list: 265 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 266 | 267 | return x 268 | -------------------------------------------------------------------------------- /nsff_scripts/core/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from update import BasicUpdateBlock, SmallUpdateBlock 7 | from extractor import BasicEncoder, SmallEncoder 8 | from corr import CorrBlock, AlternateCorrBlock 9 | from utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except: 14 | # dummy autocast for PyTorch < 1.6 15 | class autocast: 16 | def __init__(self, enabled): 17 | pass 18 | def __enter__(self): 19 | pass 20 | def __exit__(self, *args): 21 | pass 22 | 23 | 24 | class RAFT(nn.Module): 25 | def __init__(self, args): 26 | super(RAFT, self).__init__() 27 | self.args = args 28 | 29 | if args.small: 30 | self.hidden_dim = hdim = 96 31 | self.context_dim = cdim = 64 32 | args.corr_levels = 4 33 | args.corr_radius = 3 34 | 35 | else: 36 | self.hidden_dim = hdim = 128 37 | self.context_dim = cdim = 128 38 | args.corr_levels = 4 39 | args.corr_radius = 4 40 | 41 | if 'dropout' not in self.args: 42 | self.args.dropout = 0 43 | 44 | if 'alternate_corr' not in self.args: 45 | self.args.alternate_corr = False 46 | 47 | # feature network, context network, and update block 48 | if args.small: 49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) 51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 52 | 53 | else: 54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) 56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 57 | 58 | def freeze_bn(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.BatchNorm2d): 61 | m.eval() 62 | 63 | def initialize_flow(self, img): 64 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 65 | N, C, H, W = img.shape 66 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 67 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 68 | 69 | # 
optical flow computed as difference: flow = coords1 - coords0 70 | return coords0, coords1 71 | 72 | def upsample_flow(self, flow, mask): 73 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 74 | N, _, H, W = flow.shape 75 | mask = mask.view(N, 1, 9, 8, 8, H, W) 76 | mask = torch.softmax(mask, dim=2) 77 | 78 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 79 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 80 | 81 | up_flow = torch.sum(mask * up_flow, dim=2) 82 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 83 | return up_flow.reshape(N, 2, 8*H, 8*W) 84 | 85 | 86 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 87 | """ Estimate optical flow between pair of frames """ 88 | 89 | image1 = 2 * (image1 / 255.0) - 1.0 90 | image2 = 2 * (image2 / 255.0) - 1.0 91 | 92 | image1 = image1.contiguous() 93 | image2 = image2.contiguous() 94 | 95 | hdim = self.hidden_dim 96 | cdim = self.context_dim 97 | 98 | # run the feature network 99 | with autocast(enabled=self.args.mixed_precision): 100 | fmap1, fmap2 = self.fnet([image1, image2]) 101 | 102 | fmap1 = fmap1.float() 103 | fmap2 = fmap2.float() 104 | if self.args.alternate_corr: 105 | corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 106 | else: 107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 108 | 109 | # run the context network 110 | with autocast(enabled=self.args.mixed_precision): 111 | cnet = self.cnet(image1) 112 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 113 | net = torch.tanh(net) 114 | inp = torch.relu(inp) 115 | 116 | coords0, coords1 = self.initialize_flow(image1) 117 | 118 | if flow_init is not None: 119 | coords1 = coords1 + flow_init 120 | 121 | flow_predictions = [] 122 | for itr in range(iters): 123 | coords1 = coords1.detach() 124 | corr = corr_fn(coords1) # index correlation volume 125 | 126 | flow = coords1 - coords0 127 | with autocast(enabled=self.args.mixed_precision): 128 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 129 | 130 | # F(t+1) = F(t) + \Delta(t) 131 | coords1 = coords1 + delta_flow 132 | 133 | # upsample predictions 134 | if up_mask is None: 135 | flow_up = upflow8(coords1 - coords0) 136 | else: 137 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 138 | 139 | flow_predictions.append(flow_up) 140 | 141 | if test_mode: 142 | return coords1 - coords0, flow_up 143 | 144 | return flow_predictions 145 | -------------------------------------------------------------------------------- /nsff_scripts/core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 
| z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 
| self.encoder = BasicMotionEncoder(args) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balence gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /nsff_scripts/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengqili/Neural-Scene-Flow-Fields/d4001759a39b056c95d8bc22da34b10b4fb85afb/nsff_scripts/core/utils/__init__.py -------------------------------------------------------------------------------- /nsff_scripts/core/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengqili/Neural-Scene-Flow-Fields/d4001759a39b056c95d8bc22da34b10b4fb85afb/nsff_scripts/core/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /nsff_scripts/core/utils/__pycache__/flow_viz.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengqili/Neural-Scene-Flow-Fields/d4001759a39b056c95d8bc22da34b10b4fb85afb/nsff_scripts/core/utils/__pycache__/flow_viz.cpython-36.pyc -------------------------------------------------------------------------------- /nsff_scripts/core/utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengqili/Neural-Scene-Flow-Fields/d4001759a39b056c95d8bc22da34b10b4fb85afb/nsff_scripts/core/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /nsff_scripts/core/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | from PIL import Image 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | import torch 11 | from torchvision.transforms import ColorJitter 12 | import torch.nn.functional as F 13 | 14 | 15 | class FlowAugmentor: 16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 17 | 18 | # spatial augmentation params 19 | self.crop_size = crop_size 20 | self.min_scale = min_scale 21 | self.max_scale = max_scale 22 | self.spatial_aug_prob = 0.8 23 | self.stretch_prob = 0.8 24 | self.max_stretch = 0.2 25 | 26 | # flip augmentation params 27 | self.do_flip = do_flip 28 | self.h_flip_prob = 0.5 29 | self.v_flip_prob = 0.1 30 | 31 | # photometric augmentation params 32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 33 | self.asymmetric_color_aug_prob = 0.2 34 | self.eraser_aug_prob = 0.5 35 | 36 | def color_transform(self, img1, img2): 37 | """ Photometric augmentation """ 38 | 39 | # asymmetric 40 | if np.random.rand() < self.asymmetric_color_aug_prob: 
41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 43 | 44 | # symmetric 45 | else: 46 | image_stack = np.concatenate([img1, img2], axis=0) 47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 48 | img1, img2 = np.split(image_stack, 2, axis=0) 49 | 50 | return img1, img2 51 | 52 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 53 | """ Occlusion augmentation """ 54 | 55 | ht, wd = img1.shape[:2] 56 | if np.random.rand() < self.eraser_aug_prob: 57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 58 | for _ in range(np.random.randint(1, 3)): 59 | x0 = np.random.randint(0, wd) 60 | y0 = np.random.randint(0, ht) 61 | dx = np.random.randint(bounds[0], bounds[1]) 62 | dy = np.random.randint(bounds[0], bounds[1]) 63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 64 | 65 | return img1, img2 66 | 67 | def spatial_transform(self, img1, img2, flow): 68 | # randomly sample scale 69 | ht, wd = img1.shape[:2] 70 | min_scale = np.maximum( 71 | (self.crop_size[0] + 8) / float(ht), 72 | (self.crop_size[1] + 8) / float(wd)) 73 | 74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 75 | scale_x = scale 76 | scale_y = scale 77 | if np.random.rand() < self.stretch_prob: 78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 80 | 81 | scale_x = np.clip(scale_x, min_scale, None) 82 | scale_y = np.clip(scale_y, min_scale, None) 83 | 84 | if np.random.rand() < self.spatial_aug_prob: 85 | # rescale the images 86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 89 | flow = flow * [scale_x, scale_y] 90 | 91 | if self.do_flip: 92 | if np.random.rand() < self.h_flip_prob: # h-flip 93 | img1 = img1[:, ::-1] 94 | img2 = img2[:, ::-1] 95 | flow = flow[:, ::-1] * [-1.0, 1.0] 96 | 97 | if np.random.rand() < self.v_flip_prob: # v-flip 98 | img1 = img1[::-1, :] 99 | img2 = img2[::-1, :] 100 | flow = flow[::-1, :] * [1.0, -1.0] 101 | 102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 104 | 105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 108 | 109 | return img1, img2, flow 110 | 111 | def __call__(self, img1, img2, flow): 112 | img1, img2 = self.color_transform(img1, img2) 113 | img1, img2 = self.eraser_transform(img1, img2) 114 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 115 | 116 | img1 = np.ascontiguousarray(img1) 117 | img2 = np.ascontiguousarray(img2) 118 | flow = np.ascontiguousarray(flow) 119 | 120 | return img1, img2, flow 121 | 122 | class SparseFlowAugmentor: 123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 124 | # spatial augmentation params 125 | self.crop_size = crop_size 126 | self.min_scale = min_scale 127 | self.max_scale = max_scale 128 | self.spatial_aug_prob = 0.8 129 | self.stretch_prob = 0.8 130 | self.max_stretch = 0.2 131 | 132 | # flip augmentation params 133 | self.do_flip = do_flip 134 | self.h_flip_prob = 0.5 135 | self.v_flip_prob = 0.1 136 | 
137 | # photometric augmentation params 138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 139 | self.asymmetric_color_aug_prob = 0.2 140 | self.eraser_aug_prob = 0.5 141 | 142 | def color_transform(self, img1, img2): 143 | image_stack = np.concatenate([img1, img2], axis=0) 144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 145 | img1, img2 = np.split(image_stack, 2, axis=0) 146 | return img1, img2 147 | 148 | def eraser_transform(self, img1, img2): 149 | ht, wd = img1.shape[:2] 150 | if np.random.rand() < self.eraser_aug_prob: 151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 152 | for _ in range(np.random.randint(1, 3)): 153 | x0 = np.random.randint(0, wd) 154 | y0 = np.random.randint(0, ht) 155 | dx = np.random.randint(50, 100) 156 | dy = np.random.randint(50, 100) 157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 158 | 159 | return img1, img2 160 | 161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 162 | ht, wd = flow.shape[:2] 163 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 164 | coords = np.stack(coords, axis=-1) 165 | 166 | coords = coords.reshape(-1, 2).astype(np.float32) 167 | flow = flow.reshape(-1, 2).astype(np.float32) 168 | valid = valid.reshape(-1).astype(np.float32) 169 | 170 | coords0 = coords[valid>=1] 171 | flow0 = flow[valid>=1] 172 | 173 | ht1 = int(round(ht * fy)) 174 | wd1 = int(round(wd * fx)) 175 | 176 | coords1 = coords0 * [fx, fy] 177 | flow1 = flow0 * [fx, fy] 178 | 179 | xx = np.round(coords1[:,0]).astype(np.int32) 180 | yy = np.round(coords1[:,1]).astype(np.int32) 181 | 182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 183 | xx = xx[v] 184 | yy = yy[v] 185 | flow1 = flow1[v] 186 | 187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 189 | 190 | flow_img[yy, xx] = flow1 191 | valid_img[yy, xx] = 1 192 | 193 | return flow_img, valid_img 194 | 195 | def spatial_transform(self, img1, img2, flow, valid): 196 | # randomly sample scale 197 | 198 | ht, wd = img1.shape[:2] 199 | min_scale = np.maximum( 200 | (self.crop_size[0] + 1) / float(ht), 201 | (self.crop_size[1] + 1) / float(wd)) 202 | 203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 204 | scale_x = np.clip(scale, min_scale, None) 205 | scale_y = np.clip(scale, min_scale, None) 206 | 207 | if np.random.rand() < self.spatial_aug_prob: 208 | # rescale the images 209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 212 | 213 | if self.do_flip: 214 | if np.random.rand() < 0.5: # h-flip 215 | img1 = img1[:, ::-1] 216 | img2 = img2[:, ::-1] 217 | flow = flow[:, ::-1] * [-1.0, 1.0] 218 | valid = valid[:, ::-1] 219 | 220 | margin_y = 20 221 | margin_x = 50 222 | 223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 225 | 226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 228 | 229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 232 | valid = valid[y0:y0+self.crop_size[0], 
x0:x0+self.crop_size[1]] 233 | return img1, img2, flow, valid 234 | 235 | 236 | def __call__(self, img1, img2, flow, valid): 237 | img1, img2 = self.color_transform(img1, img2) 238 | img1, img2 = self.eraser_transform(img1, img2) 239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 240 | 241 | img1 = np.ascontiguousarray(img1) 242 | img2 = np.ascontiguousarray(img2) 243 | flow = np.ascontiguousarray(flow) 244 | valid = np.ascontiguousarray(valid) 245 | 246 | return img1, img2, flow, valid 247 | -------------------------------------------------------------------------------- /nsff_scripts/core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /nsff_scripts/core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. 
Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- 
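A quick way to sanity-check the Middlebury `.flo` reader/writer defined in `frame_utils.py` above is to round-trip a synthetic flow field. The following is an illustrative sketch only (not part of the repository); it assumes it is run from `nsff_scripts/core/utils` so that `frame_utils` is importable, and that `/tmp` is writable:

```python
import numpy as np
from frame_utils import readFlow, writeFlow  # functions defined in the file above

# hypothetical check: write a small random flow field and read it back
flow = np.random.randn(120, 160, 2).astype(np.float32)
writeFlow('/tmp/roundtrip.flo', flow)

recovered = readFlow('/tmp/roundtrip.flo')
assert recovered.shape == (120, 160, 2)
# exact round trip for float32 data (readFlow assumes a little-endian machine)
assert np.allclose(recovered, flow)
```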
/nsff_scripts/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd): 75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /nsff_scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip 3 | unzip models.zip 4 | -------------------------------------------------------------------------------- /nsff_scripts/flow_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import glob 5 | import cv2 6 | import scipy.io 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | 12 | def read_img(img_dir, img1_name, img2_name): 13 | # print(os.path.join(img_dir, img1_name + '.png')) 14 | return cv2.imread(os.path.join(img_dir, img1_name + '.png')), 
cv2.imread(os.path.join(img_dir, img2_name + '.png')) 15 | 16 | def refinement_flow(fwd_flow, img1, img2): 17 | flow_refine = cv2.VariationalRefinement.create() 18 | 19 | refine_flow = flow_refine.calc(cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY), 20 | cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY), 21 | fwd_flow) 22 | 23 | return refine_flow 24 | 25 | def make_color_wheel(): 26 | """ 27 | Generate color wheel according Middlebury color code 28 | :return: Color wheel 29 | """ 30 | RY = 15 31 | YG = 6 32 | GC = 4 33 | CB = 11 34 | BM = 13 35 | MR = 6 36 | 37 | ncols = RY + YG + GC + CB + BM + MR 38 | 39 | colorwheel = np.zeros([ncols, 3]) 40 | 41 | col = 0 42 | 43 | # RY 44 | colorwheel[0:RY, 0] = 255 45 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) 46 | col += RY 47 | 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) 50 | colorwheel[col:col+YG, 1] = 255 51 | col += YG 52 | 53 | # GC 54 | colorwheel[col:col+GC, 1] = 255 55 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) 56 | col += GC 57 | 58 | # CB 59 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) 60 | colorwheel[col:col+CB, 2] = 255 61 | col += CB 62 | 63 | # BM 64 | colorwheel[col:col+BM, 2] = 255 65 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) 66 | col += + BM 67 | 68 | # MR 69 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 70 | colorwheel[col:col+MR, 0] = 255 71 | 72 | return colorwheel 73 | 74 | 75 | def compute_color(u, v): 76 | """ 77 | compute optical flow color map 78 | :param u: optical flow horizontal map 79 | :param v: optical flow vertical map 80 | :return: optical flow in color code 81 | """ 82 | [h, w] = u.shape 83 | img = np.zeros([h, w, 3]) 84 | nanIdx = np.isnan(u) | np.isnan(v) 85 | u[nanIdx] = 0 86 | v[nanIdx] = 0 87 | 88 | colorwheel = make_color_wheel() 89 | ncols = np.size(colorwheel, 0) 90 | 91 | rad = np.sqrt(u**2+v**2) 92 | 93 | a = np.arctan2(-v, -u) / np.pi 94 | 95 | fk = (a+1) / 2 * (ncols - 1) + 1 96 | 97 | k0 = np.floor(fk).astype(int) 98 | 99 | k1 = k0 + 1 100 | k1[k1 == ncols+1] = 1 101 | f = fk - k0 102 | 103 | for i in range(0, np.size(colorwheel,1)): 104 | tmp = colorwheel[:, i] 105 | col0 = tmp[k0-1] / 255 106 | col1 = tmp[k1-1] / 255 107 | col = (1-f) * col0 + f * col1 108 | 109 | idx = rad <= 1 110 | col[idx] = 1-rad[idx]*(1-col[idx]) 111 | notidx = np.logical_not(idx) 112 | 113 | col[notidx] *= 0.75 114 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) 115 | 116 | return img 117 | 118 | 119 | def flow_to_image(flow, display=False): 120 | """ 121 | Convert flow into middlebury color code image 122 | :param flow: optical flow map 123 | :return: optical flow image in middlebury color 124 | """ 125 | UNKNOWN_FLOW_THRESH = 100 126 | u = flow[:, :, 0] 127 | v = flow[:, :, 1] 128 | 129 | maxu = -999. 130 | maxv = -999. 131 | minu = 999. 132 | minv = 999. 133 | 134 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) 135 | u[idxUnknow] = 0 136 | v[idxUnknow] = 0 137 | 138 | maxu = max(maxu, np.max(u)) 139 | minu = min(minu, np.min(u)) 140 | 141 | maxv = max(maxv, np.max(v)) 142 | minv = min(minv, np.min(v)) 143 | 144 | # sqrt_rad = u**2 + v**2 145 | rad = np.sqrt(u**2 + v**2) 146 | 147 | maxrad = max(-1, np.max(rad)) 148 | 149 | if display: 150 | print("max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. 
%.3f" % (maxrad, minu,maxu, minv, maxv)) 151 | 152 | u = u/(maxrad + np.finfo(float).eps) 153 | v = v/(maxrad + np.finfo(float).eps) 154 | 155 | img = compute_color(u, v) 156 | 157 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 158 | img[idx] = 0 159 | 160 | return np.uint8(img) 161 | 162 | 163 | def warp_flow(img, flow): 164 | h, w = flow.shape[:2] 165 | flow_new = flow.copy() 166 | flow_new[:,:,0] += np.arange(w) 167 | flow_new[:,:,1] += np.arange(h)[:,np.newaxis] 168 | 169 | res = cv2.remap(img, flow_new, None, cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT) 170 | return res 171 | 172 | def resize_flow(flow, img_h, img_w): 173 | # flow = np.load(flow_path) 174 | flow_h, flow_w = flow.shape[0], flow.shape[1] 175 | flow[:, :, 0] *= float(img_w)/float(flow_w) 176 | flow[:, :, 1] *= float(img_h)/float(flow_h) 177 | flow = cv2.resize(flow, (img_w, img_h), interpolation=cv2.INTER_LINEAR) 178 | 179 | return flow 180 | 181 | def extract_poses(im): 182 | R = im.qvec2rotmat() 183 | t = im.tvec.reshape([3,1]) 184 | bottom = np.array([0,0,0,1.]).reshape([1,4]) 185 | 186 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) 187 | 188 | return m 189 | 190 | def load_colmap_data(realdir): 191 | import colmap_read_model as read_model 192 | 193 | camerasfile = os.path.join(realdir, 'sparse/cameras.bin') 194 | camdata = read_model.read_cameras_binary(camerasfile) 195 | 196 | list_of_keys = list(camdata.keys()) 197 | cam = camdata[list_of_keys[0]] 198 | print( 'Cameras', len(cam)) 199 | 200 | h, w, f = cam.height, cam.width, cam.params[0] 201 | # w, h, f = factor * w, factor * h, factor * f 202 | hwf = np.array([h,w,f]).reshape([3,1]) 203 | 204 | imagesfile = os.path.join(realdir, 'sparse/images.bin') 205 | imdata = read_model.read_images_binary(imagesfile) 206 | 207 | w2c_mats = [] 208 | # bottom = np.array([0,0,0,1.]).reshape([1,4]) 209 | 210 | names = [imdata[k].name for k in imdata] 211 | img_keys = [k for k in imdata] 212 | 213 | print( 'Images #', len(names)) 214 | perm = np.argsort(names) 215 | 216 | return imdata, perm, img_keys, hwf 217 | 218 | def skew(x): 219 | return np.array([[0, -x[2], x[1]], 220 | [x[2], 0, -x[0]], 221 | [-x[1], x[0], 0]]) 222 | 223 | 224 | def compute_epipolar_distance(T_21, K, p_1, p_2): 225 | R_21 = T_21[:3, :3] 226 | t_21 = T_21[:3, 3] 227 | 228 | E_mat = np.dot(skew(t_21), R_21) 229 | # compute bearing vector 230 | inv_K = np.linalg.inv(K) 231 | 232 | F_mat = np.dot(np.dot(inv_K.T, E_mat), inv_K) 233 | 234 | l_2 = np.dot(F_mat, p_1) 235 | algebric_e_distance = np.sum(p_2 * l_2, axis=0) 236 | n_term = np.sqrt(l_2[0, :]**2 + l_2[1, :]**2) + 1e-8 237 | geometric_e_distance = algebric_e_distance/n_term 238 | geometric_e_distance = np.abs(geometric_e_distance) 239 | 240 | return geometric_e_distance 241 | 242 | def read_optical_flow(basedir, img_i_name, read_fwd): 243 | flow_dir = os.path.join(basedir, 'flow_i1') 244 | 245 | fwd_flow_path = os.path.join(flow_dir, '%s_fwd.npz'%img_i_name[:-4]) 246 | bwd_flow_path = os.path.join(flow_dir, '%s_bwd.npz'%img_i_name[:-4]) 247 | 248 | if read_fwd: 249 | fwd_data = np.load(fwd_flow_path)#, (w, h)) 250 | fwd_flow, fwd_mask = fwd_data['flow'], fwd_data['mask'] 251 | # fwd_mask = np.float32(fwd_mask) 252 | 253 | # bwd_flow = np.zeros_like(fwd_flow) 254 | return fwd_flow, fwd_mask 255 | else: 256 | bwd_data = np.load(bwd_flow_path)#, (w, h)) 257 | bwd_flow, bwd_mask = bwd_data['flow'], bwd_data['mask'] 258 | # bwd_mask = np.float32(bwd_mask) 259 | # fwd_flow = np.zeros_like(bwd_flow) 260 | return bwd_flow, bwd_mask 261 | #
return fwd_flow, bwd_flow#, fwd_mask, bwd_mask 262 | -------------------------------------------------------------------------------- /nsff_scripts/models/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BaseModel(torch.nn.Module): 6 | def load(self, path): 7 | """Load model from file. 8 | 9 | Args: 10 | path (str): file path 11 | """ 12 | parameters = torch.load(path) 13 | 14 | if "optimizer" in parameters: 15 | parameters = parameters["model"] 16 | 17 | self.load_state_dict(parameters) 18 | -------------------------------------------------------------------------------- /nsff_scripts/models/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def _make_encoder(features, use_pretrained): 6 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 7 | scratch = _make_scratch([256, 512, 1024, 2048], features) 8 | 9 | return pretrained, scratch 10 | 11 | 12 | def _make_resnet_backbone(resnet): 13 | pretrained = nn.Module() 14 | pretrained.layer1 = nn.Sequential( 15 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 16 | ) 17 | 18 | pretrained.layer2 = resnet.layer2 19 | pretrained.layer3 = resnet.layer3 20 | pretrained.layer4 = resnet.layer4 21 | 22 | return pretrained 23 | 24 | 25 | def _make_pretrained_resnext101_wsl(use_pretrained): 26 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 27 | return _make_resnet_backbone(resnet) 28 | 29 | 30 | def _make_scratch(in_shape, out_shape): 31 | scratch = nn.Module() 32 | 33 | scratch.layer1_rn = nn.Conv2d( 34 | in_shape[0], out_shape, kernel_size=3, stride=1, padding=1, bias=False 35 | ) 36 | scratch.layer2_rn = nn.Conv2d( 37 | in_shape[1], out_shape, kernel_size=3, stride=1, padding=1, bias=False 38 | ) 39 | scratch.layer3_rn = nn.Conv2d( 40 | in_shape[2], out_shape, kernel_size=3, stride=1, padding=1, bias=False 41 | ) 42 | scratch.layer4_rn = nn.Conv2d( 43 | in_shape[3], out_shape, kernel_size=3, stride=1, padding=1, bias=False 44 | ) 45 | return scratch 46 | 47 | 48 | class Interpolate(nn.Module): 49 | """Interpolation module. 50 | """ 51 | 52 | def __init__(self, scale_factor, mode): 53 | """Init. 54 | 55 | Args: 56 | scale_factor (float): scaling 57 | mode (str): interpolation mode 58 | """ 59 | super(Interpolate, self).__init__() 60 | 61 | self.interp = nn.functional.interpolate 62 | self.scale_factor = scale_factor 63 | self.mode = mode 64 | 65 | def forward(self, x): 66 | """Forward pass. 67 | 68 | Args: 69 | x (tensor): input 70 | 71 | Returns: 72 | tensor: interpolated data 73 | """ 74 | 75 | x = self.interp( 76 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False 77 | ) 78 | 79 | return x 80 | 81 | 82 | class ResidualConvUnit(nn.Module): 83 | """Residual convolution module. 84 | """ 85 | 86 | def __init__(self, features): 87 | """Init. 88 | 89 | Args: 90 | features (int): number of features 91 | """ 92 | super().__init__() 93 | 94 | self.conv1 = nn.Conv2d( 95 | features, features, kernel_size=3, stride=1, padding=1, bias=True 96 | ) 97 | 98 | self.conv2 = nn.Conv2d( 99 | features, features, kernel_size=3, stride=1, padding=1, bias=True 100 | ) 101 | 102 | self.relu = nn.ReLU(inplace=True) 103 | 104 | def forward(self, x): 105 | """Forward pass. 
106 | 107 | Args: 108 | x (tensor): input 109 | 110 | Returns: 111 | tensor: output 112 | """ 113 | out = self.relu(x) 114 | out = self.conv1(out) 115 | out = self.relu(out) 116 | out = self.conv2(out) 117 | 118 | return out + x 119 | 120 | 121 | class FeatureFusionBlock(nn.Module): 122 | """Feature fusion block. 123 | """ 124 | 125 | def __init__(self, features): 126 | """Init. 127 | 128 | Args: 129 | features (int): number of features 130 | """ 131 | super(FeatureFusionBlock, self).__init__() 132 | 133 | self.resConfUnit1 = ResidualConvUnit(features) 134 | self.resConfUnit2 = ResidualConvUnit(features) 135 | 136 | def forward(self, *xs): 137 | """Forward pass. 138 | 139 | Returns: 140 | tensor: output 141 | """ 142 | output = xs[0] 143 | 144 | if len(xs) == 2: 145 | output += self.resConfUnit1(xs[1]) 146 | 147 | output = self.resConfUnit2(output) 148 | 149 | output = nn.functional.interpolate( 150 | output, scale_factor=2, mode="bilinear", align_corners=True 151 | ) 152 | 153 | return output 154 | -------------------------------------------------------------------------------- /nsff_scripts/models/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from models.base_model import BaseModel 9 | from models.blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(features, use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /nsff_scripts/models/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | class Resize(object): 7 | """Resize sample to given size (width, height). 8 | """ 9 | 10 | def __init__( 11 | self, 12 | width, 13 | height, 14 | resize_target=True, 15 | keep_aspect_ratio=False, 16 | ensure_multiple_of=1, 17 | resize_method="lower_bound", 18 | image_interpolation_method=cv2.INTER_AREA, 19 | ): 20 | """Init. 21 | 22 | Args: 23 | width (int): desired output width 24 | height (int): desired output height 25 | resize_target (bool, optional): 26 | True: Resize the full sample (image, mask, target). 27 | False: Resize image only. 28 | Defaults to True. 29 | keep_aspect_ratio (bool, optional): 30 | True: Keep the aspect ratio of the input sample. 31 | Output sample might not have the given width and height, and 32 | resize behaviour depends on the parameter 'resize_method'. 33 | Defaults to False. 34 | ensure_multiple_of (int, optional): 35 | Output width and height is constrained to be multiple of this parameter. 36 | Defaults to 1. 37 | resize_method (str, optional): 38 | "lower_bound": Output will be at least as large as the given size. 39 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 40 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 41 | Defaults to "lower_bound". 
42 | """ 43 | self.__width = width 44 | self.__height = height 45 | 46 | self.__resize_target = resize_target 47 | self.__keep_aspect_ratio = keep_aspect_ratio 48 | self.__multiple_of = ensure_multiple_of 49 | self.__resize_method = resize_method 50 | self.__image_interpolation_method = image_interpolation_method 51 | 52 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 53 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 54 | 55 | if max_val is not None and y > max_val: 56 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 57 | 58 | if y < min_val: 59 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 60 | 61 | return y 62 | 63 | def get_size(self, width, height): 64 | # determine new height and width 65 | scale_height = self.__height / height 66 | scale_width = self.__width / width 67 | 68 | if self.__keep_aspect_ratio: 69 | if self.__resize_method == "lower_bound": 70 | # scale such that output size is lower bound 71 | if scale_width > scale_height: 72 | # fit width 73 | scale_height = scale_width 74 | else: 75 | # fit height 76 | scale_width = scale_height 77 | elif self.__resize_method == "upper_bound": 78 | # scale such that output size is upper bound 79 | if scale_width < scale_height: 80 | # fit width 81 | scale_height = scale_width 82 | else: 83 | # fit height 84 | scale_width = scale_height 85 | elif self.__resize_method == "minimal": 86 | # scale as least as possbile 87 | if abs(1 - scale_width) < abs(1 - scale_height): 88 | # fit width 89 | scale_height = scale_width 90 | else: 91 | # fit height 92 | scale_width = scale_height 93 | else: 94 | raise ValueError( 95 | f"resize_method {self.__resize_method} not implemented" 96 | ) 97 | 98 | if self.__resize_method == "lower_bound": 99 | new_height = self.constrain_to_multiple_of( 100 | scale_height * height, min_val=self.__height 101 | ) 102 | new_width = self.constrain_to_multiple_of( 103 | scale_width * width, min_val=self.__width 104 | ) 105 | elif self.__resize_method == "upper_bound": 106 | new_height = self.constrain_to_multiple_of( 107 | scale_height * height, max_val=self.__height 108 | ) 109 | new_width = self.constrain_to_multiple_of( 110 | scale_width * width, max_val=self.__width 111 | ) 112 | elif self.__resize_method == "minimal": 113 | new_height = self.constrain_to_multiple_of(scale_height * height) 114 | new_width = self.constrain_to_multiple_of(scale_width * width) 115 | else: 116 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 117 | 118 | return (new_width, new_height) 119 | 120 | def __call__(self, sample): 121 | width, height = self.get_size( 122 | sample["image"].shape[1], sample["image"].shape[0] 123 | ) 124 | 125 | # resize sample 126 | sample["image"] = cv2.resize( 127 | sample["image"], 128 | (width, height), 129 | interpolation=self.__image_interpolation_method, 130 | ) 131 | 132 | if self.__resize_target: 133 | if "disparity" in sample: 134 | sample["disparity"] = cv2.resize( 135 | sample["disparity"], 136 | (width, height), 137 | interpolation=cv2.INTER_NEAREST, 138 | ) 139 | 140 | if "depth" in sample: 141 | sample["depth"] = cv2.resize( 142 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 143 | ) 144 | 145 | sample["mask"] = cv2.resize( 146 | sample["mask"].astype(np.float32), 147 | (width, height), 148 | interpolation=cv2.INTER_NEAREST, 149 | ) 150 | sample["mask"] = sample["mask"].astype(bool) 151 | 152 | return sample 153 | 154 | 155 | class NormalizeImage(object): 156 
| """Normlize image by given mean and std. 157 | """ 158 | 159 | def __init__(self, mean, std): 160 | self.__mean = mean 161 | self.__std = std 162 | 163 | def __call__(self, sample): 164 | sample["image"] = (sample["image"] - self.__mean) / self.__std 165 | 166 | return sample 167 | 168 | 169 | class PrepareForNet(object): 170 | """Prepare sample for usage as network input. 171 | """ 172 | 173 | def __init__(self): 174 | pass 175 | 176 | def __call__(self, sample): 177 | image = np.transpose(sample["image"], (2, 0, 1)) 178 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 179 | 180 | if "mask" in sample: 181 | sample["mask"] = sample["mask"].astype(np.float32) 182 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 183 | 184 | if "disparity" in sample: 185 | disparity = sample["disparity"].astype(np.float32) 186 | sample["disparity"] = np.ascontiguousarray(disparity) 187 | 188 | if "depth" in sample: 189 | depth = sample["depth"].astype(np.float32) 190 | sample["depth"] = np.ascontiguousarray(depth) 191 | 192 | return sample 193 | -------------------------------------------------------------------------------- /nsff_scripts/run_midas.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compute depth maps for images in the input folder. 3 | """ 4 | 5 | import os 6 | import glob 7 | import torch 8 | import cv2 9 | import numpy as np 10 | from torchvision.transforms import Compose 11 | from models.midas_net import MidasNet 12 | from models.transforms import Resize, NormalizeImage, PrepareForNet 13 | 14 | import sys 15 | 16 | import matplotlib 17 | matplotlib.use('Agg') 18 | import matplotlib.pyplot as plt 19 | 20 | VIZ = True 21 | 22 | def read_image(path): 23 | """Read image and output RGB image (0-1). 
24 | 25 | Args: 26 | path (str): path to file 27 | 28 | Returns: 29 | array: RGB image (0-1) 30 | """ 31 | img = cv2.imread(path) 32 | 33 | if img.ndim == 2: 34 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 35 | 36 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 37 | 38 | return img 39 | 40 | def _minify(basedir, factors=[], resolutions=[]): 41 | ''' 42 | Minify the images to small resolution for training 43 | ''' 44 | 45 | needtoload = False 46 | for r in factors: 47 | imgdir = os.path.join(basedir, 'images_{}'.format(r)) 48 | if not os.path.exists(imgdir): 49 | needtoload = True 50 | for r in resolutions: 51 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) 52 | if not os.path.exists(imgdir): 53 | needtoload = True 54 | if not needtoload: 55 | return 56 | 57 | from shutil import copy 58 | from subprocess import check_output 59 | import glob 60 | 61 | imgdir = os.path.join(basedir, 'images') 62 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 63 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] 64 | imgdir_orig = imgdir 65 | 66 | wd = os.getcwd() 67 | 68 | for r in factors + resolutions: 69 | if isinstance(r, int): 70 | name = 'images_{}'.format(r) 71 | resizearg = '{}%'.format(100./r) 72 | else: 73 | name = 'images_{}x{}'.format(r[1], r[0]) 74 | resizearg = '{}x{}'.format(r[1], r[0]) 75 | 76 | imgdir = os.path.join(basedir, name) 77 | if os.path.exists(imgdir): 78 | continue 79 | 80 | print('Minifying', r, basedir) 81 | 82 | os.makedirs(imgdir) 83 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) 84 | 85 | ext = imgs[0].split('.')[-1] 86 | print(ext) 87 | # sys.exit() 88 | img_path_list = glob.glob(os.path.join(imgdir, '*.%s'%ext)) 89 | 90 | for img_path in img_path_list: 91 | save_path = img_path.replace('.jpg', '.png') 92 | img = cv2.imread(img_path) 93 | 94 | print(img.shape, r) 95 | 96 | cv2.imwrite(save_path, 97 | cv2.resize(img, 98 | (r[1], r[0]), 99 | interpolation=cv2.INTER_AREA)) 100 | 101 | if ext != 'png': 102 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) 103 | print('Removed duplicates') 104 | print('Done') 105 | 106 | 107 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) 108 | import imageio 109 | 110 | def run(basedir, 111 | input_path, 112 | output_path, 113 | model_path, 114 | resize_height=288): 115 | """Run MonoDepthNN to compute depth maps. 
116 | 117 | Args: 118 | input_path (str): path to input folder 119 | output_path (str): path to output folder 120 | model_path (str): path to saved model 121 | """ 122 | print("initialize") 123 | 124 | img0 = [os.path.join(basedir, 'images', f) \ 125 | for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ 126 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] 127 | sh = cv2.imread(img0).shape 128 | height = resize_height 129 | factor = sh[0] / float(height) 130 | width = int(round(sh[1] / factor)) 131 | _minify(basedir, resolutions=[[height, width]]) 132 | 133 | # select device 134 | device = torch.device("cuda") 135 | print("device: %s" % device) 136 | 137 | small_img_dir = input_path + '_*x' + str(resize_height) + '/' 138 | print(small_img_dir) 139 | 140 | small_img_path = sorted(glob.glob(glob.glob(small_img_dir)[0] + '/*.png'))[0] 141 | 142 | small_img = cv2.imread(small_img_path) 143 | 144 | print('small_img', small_img.shape) 145 | 146 | # Portrait Orientation 147 | if small_img.shape[0] > small_img.shape[1]: 148 | input_h = 640 149 | input_w = int(round( float(input_h) / small_img.shape[0] * small_img.shape[1])) 150 | # Landscape Orientation 151 | else: 152 | input_w = 640 153 | input_h = int(round( float(input_w) / small_img.shape[1] * small_img.shape[0])) 154 | 155 | print('Monocular depth input_w %d input_h %d '%(input_w, input_h)) 156 | 157 | # load network 158 | model = MidasNet(model_path, non_negative=True) 159 | 160 | transform_1 = Compose( 161 | [ 162 | Resize( 163 | input_w, 164 | input_h, 165 | resize_target=None, 166 | keep_aspect_ratio=True, 167 | ensure_multiple_of=32, 168 | resize_method="upper_bound", 169 | image_interpolation_method=cv2.INTER_AREA, 170 | ), 171 | NormalizeImage(mean=[0.485, 0.456, 0.406], 172 | std=[0.229, 0.224, 0.225]), 173 | PrepareForNet(), 174 | ] 175 | ) 176 | 177 | model.to(device) 178 | model.eval() 179 | 180 | # get input 181 | img_names = sorted(glob.glob(os.path.join(input_path, "*"))) 182 | num_images = len(img_names) 183 | 184 | # create output folder 185 | os.makedirs(output_path, exist_ok=True) 186 | 187 | print("start processing") 188 | 189 | for ind in range(len(img_names)): 190 | 191 | img_name = img_names[ind] 192 | print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) 193 | # input 194 | img = read_image(img_name) 195 | img_input_1 = transform_1({"image": img})["image"] 196 | 197 | # compute 198 | with torch.no_grad(): 199 | sample_1 = torch.from_numpy(img_input_1).to(device).unsqueeze(0) 200 | prediction = model.forward(sample_1) 201 | prediction = ( 202 | torch.nn.functional.interpolate( 203 | prediction.unsqueeze(1), 204 | size=[small_img.shape[0], 205 | small_img.shape[1]], 206 | mode="nearest", 207 | ) 208 | .squeeze() 209 | .cpu() 210 | .numpy() 211 | ) 212 | 213 | # output 214 | filename = os.path.join( 215 | output_path, os.path.splitext(os.path.basename(img_name))[0] 216 | ) 217 | 218 | 219 | if VIZ: 220 | if not os.path.exists('./midas_otuputs'): 221 | os.makedirs('./midas_otuputs') 222 | 223 | plt.figure(figsize=(12, 6)) 224 | plt.subplot(1,2,1) 225 | plt.imshow(img) 226 | plt.subplot(1,2,2) 227 | plt.imshow(prediction, cmap='jet') 228 | plt.savefig('./midas_otuputs/%s'%(img_name.split('/')[-1])) 229 | plt.close() 230 | 231 | print(filename + '.npy') 232 | np.save(filename + '.npy', prediction.astype(np.float32)) 233 | 234 | print("finished") 235 | 236 | import argparse 237 | 238 | if __name__ == "__main__": 239 | parser = argparse.ArgumentParser() 240 | 
parser.add_argument("--data_path", type=str, 241 | help='COLMAP Directory') 242 | # parser.add_argument("--input_w", type=int, default=640, 243 | # help='input image width for monocular depth network') 244 | # parser.add_argument("--input_h", type=int, default=360, 245 | # help='input image height for monocular depth network') 246 | parser.add_argument("--resize_height", type=int, default=288, 247 | help='resized image height for training \ 248 | (width will be resized based on original aspect ratio)') 249 | 250 | args = parser.parse_args() 251 | BASE_DIR = args.data_path 252 | 253 | INPUT_PATH = BASE_DIR + "/images" 254 | OUTPUT_PATH = BASE_DIR + "/disp" 255 | 256 | MODEL_PATH = "model.pt" 257 | if not os.path.exists(OUTPUT_PATH): 258 | os.makedirs(OUTPUT_PATH) 259 | 260 | # set torch options 261 | torch.backends.cudnn.enabled = True 262 | torch.backends.cudnn.benchmark = True 263 | 264 | # compute depth maps 265 | run(BASE_DIR, INPUT_PATH, 266 | OUTPUT_PATH, MODEL_PATH, 267 | args.resize_height) 268 | 269 | 270 | -------------------------------------------------------------------------------- /nsff_scripts/save_poses_nerf.py: -------------------------------------------------------------------------------- 1 | import colmap_read_model as read_model 2 | import numpy as np 3 | import os 4 | import sys 5 | import json 6 | 7 | def get_bbox_corners(points): 8 | lower = points.min(axis=0) 9 | upper = points.max(axis=0) 10 | return np.stack([lower, upper]) 11 | 12 | def filter_outlier_points(points, inner_percentile): 13 | """Filters outlier points.""" 14 | outer = 1.0 - inner_percentile 15 | lower = outer / 2.0 16 | upper = 1.0 - lower 17 | centers_min = np.quantile(points, lower, axis=0) 18 | centers_max = np.quantile(points, upper, axis=0) 19 | result = points.copy() 20 | 21 | too_near = np.any(result < centers_min[None, :], axis=1) 22 | too_far = np.any(result > centers_max[None, :], axis=1) 23 | 24 | return result[~(too_near | too_far)] 25 | 26 | def load_colmap_data(realdir): 27 | camerasfile = os.path.join(realdir, 'sparse/cameras.bin') 28 | camdata = read_model.read_cameras_binary(camerasfile) 29 | 30 | list_of_keys = list(camdata.keys()) 31 | cam = camdata[list_of_keys[0]] 32 | print( 'Cameras', len(cam)) 33 | 34 | h, w, f = cam.height, cam.width, cam.params[0] 35 | # w, h, f = factor * w, factor * h, factor * f 36 | hwf = np.array([h, w, f]).reshape([3,1]) 37 | 38 | imagesfile = os.path.join(realdir, 'sparse/images.bin') 39 | imdata = read_model.read_images_binary(imagesfile) 40 | 41 | w2c_mats = [] 42 | bottom = np.array([0,0,0,1.]).reshape([1,4]) 43 | 44 | names = [imdata[k].name for k in imdata] 45 | img_keys = [k for k in imdata] 46 | 47 | print( 'Images #', len(names)) 48 | perm = np.argsort(names) 49 | 50 | points3dfile = os.path.join(realdir, 'sparse/points3D.bin') 51 | pts3d = read_model.read_points3d_binary(points3dfile) 52 | 53 | # extract point 3D xyz 54 | point_cloud = [] 55 | for key in pts3d: 56 | point_cloud.append(pts3d[key].xyz) 57 | 58 | point_cloud = np.stack(point_cloud, 0) 59 | point_cloud = filter_outlier_points(point_cloud, 0.95) 60 | 61 | bounds_mats = [] 62 | 63 | upper_bound = 1000 64 | 65 | if upper_bound < len(img_keys): 66 | print("Only keeping " + str(upper_bound) + " images!") 67 | 68 | for i in perm[0:min(upper_bound, len(img_keys))]: 69 | im = imdata[img_keys[i]] 70 | print(im.name) 71 | R = im.qvec2rotmat() 72 | t = im.tvec.reshape([3,1]) 73 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) 74 | w2c_mats.append(m) 75 | 76 | pts_3d_idx = 
im.point3D_ids 77 | pts_3d_vis_idx = pts_3d_idx[pts_3d_idx >= 0] 78 | 79 | # 80 | depth_list = [] 81 | for k in range(len(pts_3d_vis_idx)): 82 | point_info = pts3d[pts_3d_vis_idx[k]] 83 | P_g = point_info.xyz 84 | P_c = np.dot(R, P_g.reshape(3, 1)) + t.reshape(3, 1) 85 | depth_list.append(P_c[2]) 86 | 87 | zs = np.array(depth_list) 88 | close_depth, inf_depth = np.percentile(zs, 5), np.percentile(zs, 95) 89 | bounds = np.array([close_depth, inf_depth]) 90 | bounds_mats.append(bounds) 91 | 92 | w2c_mats = np.stack(w2c_mats, 0) 93 | # bounds_mats = np.stack(bounds_mats, 0) 94 | c2w_mats = np.linalg.inv(w2c_mats) 95 | 96 | # bbox_corners = get_bbox_corners(point_cloud) 97 | # also add camera 98 | bbox_corners = get_bbox_corners( 99 | np.concatenate([point_cloud, c2w_mats[:, :3, 3]], axis=0)) 100 | 101 | scene_center = np.mean(bbox_corners, axis=0) 102 | scene_scale = 1.0 / np.sqrt(np.sum((bbox_corners[1] - bbox_corners[0]) ** 2)) 103 | 104 | print('bbox_corners ', bbox_corners) 105 | print('scene_center ', scene_center, scene_scale) 106 | 107 | poses = c2w_mats[:, :3, :4].transpose([1,2,0]) 108 | poses = np.concatenate([poses, np.tile(hwf[..., np.newaxis], 109 | [1,1,poses.shape[-1]])], 1) 110 | 111 | # must switch to [-y, x, z] from [x, -y, -z], NOT [r, u, -t] 112 | poses = np.concatenate([poses[:, 1:2, :], poses[:, 0:1, :], 113 | -poses[:, 2:3, :], 114 | poses[:, 3:4, :], 115 | poses[:, 4:5, :]], 1) 116 | 117 | save_arr = [] 118 | 119 | for i in range((poses.shape[2])): 120 | save_arr.append(np.concatenate([poses[..., i].ravel(), bounds_mats[i]], 0)) 121 | 122 | save_arr = np.array(save_arr) 123 | print(save_arr.shape) 124 | np.save(os.path.join(realdir, 'poses_bounds.npy'), save_arr) 125 | with open(os.path.join(realdir, 'scene.json'), 'w') as f: 126 | json.dump({ 127 | 'scale': scene_scale, 128 | 'center': scene_center.tolist(), 129 | 'bbox': bbox_corners.tolist(), 130 | }, f, indent=2) 131 | 132 | import argparse 133 | 134 | if __name__=='__main__': 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument("--data_path", type=str, 137 | help='COLMAP Directory') 138 | 139 | args = parser.parse_args() 140 | 141 | basedir = args.data_path #"/phoenix/S7/zl548/nerf_data/%s/dense"%scene_name 142 | load_colmap_data(basedir) 143 | print( 'Done with imgs2poses' ) 144 | --------------------------------------------------------------------------------
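For reference, a minimal sketch (illustrative only, not part of the repository) of reading back the outputs that `save_poses_nerf.py` writes; the data path below is a hypothetical placeholder in the same style as the README examples:

```python
import json
import numpy as np

datadir = '/home/xxx/Neural-Scene-Flow-Fields/kid-running/dense'  # hypothetical path

# each row of poses_bounds.npy is a flattened 3x5 matrix (permuted 3x4 camera-to-world
# pose plus the [h, w, f] column) followed by the per-image [close_depth, inf_depth]
# bounds, i.e. 17 values per image
poses_bounds = np.load(f'{datadir}/poses_bounds.npy')
poses = poses_bounds[:, :15].reshape(-1, 3, 5)
bounds = poses_bounds[:, 15:17]

# scene.json stores the normalization computed from the SfM points and camera centers
with open(f'{datadir}/scene.json') as fp:
    scene = json.load(fp)  # keys: 'scale', 'center', 'bbox'

print(poses.shape, bounds.shape, scene['scale'])
```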