├── .gitignore
├── FlyingChairs_train_val.txt
├── LICENSE
├── README.md
├── dataloader
│   ├── MiddleburyList.py
│   ├── MiddleburySubmit.py
│   ├── __init__.py
│   ├── chairslist.py
│   ├── flow_transforms.py
│   ├── hd1klist.py
│   ├── kitti12list.py
│   ├── kitti15list.py
│   ├── kitti15list_train.py
│   ├── kitti15list_val.py
│   ├── kitticliplist.py
│   ├── listflowfile.py
│   ├── mblist.py
│   ├── robloader.py
│   ├── sintellist.py
│   ├── sintellist_clean.py
│   ├── sintellist_final.py
│   ├── sintellist_train.py
│   ├── sintellist_val.py
│   ├── stereo_kittilist12.py
│   ├── stereo_kittilist15.py
│   └── thingslist.py
├── dataset
│   ├── IIW
│   │   ├── hawk_000299.png
│   │   └── hawk_000300.png
│   └── kitti_scene
│       └── testing
│           └── image_2
│               ├── 000042_10.png
│               └── 000042_11.png
├── eval_tmp.py
├── figs
│   ├── architecture.png
│   ├── hawk-vec.png
│   ├── hawk.png
│   ├── kitti-test-42-vec.png
│   ├── kitti-test-42.png
│   ├── output-onlinepngtools.png
│   └── time-breakdown.png
├── flops.py
├── main.py
├── models
│   ├── PWCNet.py
│   ├── VCN.py
│   ├── VCN_small.py
│   ├── __init__.py
│   ├── conv4d.py
│   └── submodule.py
├── order.txt
├── run.sh
├── run_self.sh
├── submission.py
├── thop
│   ├── __init__.py
│   ├── count_hooks.py
│   ├── profile.py
│   └── utils.py
├── utils
│   ├── __init__.py
│   ├── flowlib.py
│   ├── io.py
│   ├── logger.py
│   ├── multiscaleloss.py
│   ├── pfm.py
│   ├── readpfm.py
│   ├── sintel_io.py
│   └── util_flow.py
└── visualize.ipynb

/.gitignore:
--------------------------------------------------------------------------------
pwc-time
vcn-time
vcn-time2
profile.py
results/
weights
weights/
weights_small/
__pycache__/
*/__pycache__/
iter_counts*.txt
tmp/
*.lprof
*.pyc
models/*.pyc
*.npy
.ipynb_checkpoints/*
.ipynb_checkpoints
results/*

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Carnegie Mellon University

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# VCN: Volumetric correspondence networks for optical flow
#### [[project website]](http://www.contrib.andrew.cmu.edu/~gengshay/neurips19flow)

**Requirements**
- python 3.6
- pytorch 1.1.0-1.3.0
- [pytorch correlation module](https://github.com/gengshan-y/Pytorch-Correlation-extension) (optional). This gives around a 20ms speed-up on KITTI-sized images; however, only the forward pass is implemented. Please replace self.corrf() with self.corr() in models/VCN.py and models/VCN_small.py if you plan to use the faster version.
- [weights files (VCN)](https://drive.google.com/drive/folders/1mgadg50ti1QdwfAf6aR2v1pCx-ITsYfE?usp=sharing) Note: you can load them without untarring the files; see [pytorch saving and loading models](https://pytorch.org/tutorials/beginner/saving_loading_models.html).
- [weights files (VCN-small)](https://drive.google.com/drive/folders/16WvCL1Y5IkCoNmEEEbF_qmZPToMhw9VZ?usp=sharing)

## Pre-trained models
#### To test on any two images
Running [visualize.ipynb](./visualize.ipynb) gives you the following flow visualizations with color and vectors. Note: the sintel model "./weights/sintel-ft-trainval/finetune_67999.tar" is trained on multiple datasets and generalizes better than the KITTI model.

(flow visualizations omitted here; see figs/hawk.png, figs/hawk-vec.png, figs/kitti-test-42.png, and figs/kitti-test-42-vec.png)

#### KITTI
**This corresponds to the entry on the [leaderboard](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) (Fl-all=6.30%).**
##### Evaluate on KITTI-15 benchmark

To run + visualize on the KITTI-15 test set,
```
modelname=kitti-ft-trainval
i=149999
CUDA_VISIBLE_DEVICES=0 python submission.py --dataset 2015test --datapath dataset/kitti_scene/testing/ --outdir ./weights/$modelname/ --loadmodel ./weights/$modelname/finetune_$i.tar --maxdisp 512 --fac 2
python eval_tmp.py --path ./weights/$modelname/ --vis yes --dataset 2015test
```

##### Evaluate on KITTI-val
*To see the details of the train-val split, please scroll down to "note on train-val" and run dataloader/kitti15list_val.py, dataloader/kitti15list_train.py, dataloader/sintellist_train.py, and dataloader/sintellist_val.py.*

To evaluate on the 40 validation images of KITTI-15 (0,5,...,195), assuming the data is at /ssd/kitti_scene,
```
modelname=kitti-ft-trainval
i=149999
CUDA_VISIBLE_DEVICES=0 python submission.py --dataset 2015 --datapath /ssd/kitti_scene/training/ --outdir ./weights/$modelname/ --loadmodel ./weights/$modelname/finetune_$i.tar --maxdisp 512 --fac 2
python eval_tmp.py --path ./weights/$modelname/ --vis no --dataset 2015
```

To evaluate + visualize on the KITTI-15 validation set,
```
python eval_tmp.py --path ./weights/$modelname/ --vis yes --dataset 2015
```
Evaluation error on the 40 validation images: Fl-err = 3.9, EPE = 1.144

#### Sintel
**This corresponds to the entry on the [leaderboard](http://sintel.is.tue.mpg.de/quant?metric_id=0&selected_pass=0) (EPE-all-final = 4.404, EPE-all-clean = 2.808).**
##### Evaluate on Sintel-val

To evaluate on the Sintel validation set,
```
modelname=sintel-ft-trainval
i=67999
CUDA_VISIBLE_DEVICES=0 python submission.py --dataset sintel --datapath /ssd/rob_flow/training/ --outdir ./weights/$modelname/ --loadmodel ./weights/$modelname/finetune_$i.tar --maxdisp 448 --fac 1.4
python eval_tmp.py --path ./weights/$modelname/ --vis no --dataset sintel
```
Evaluation error on the Sintel validation images: Fl-err = 7.9, EPE = 2.351


## Train the model
We follow the same stage-wise training procedure as prior work, Chairs->Things->KITTI or Chairs->Things->Sintel, but use far fewer iterations.
If you plan to train the model and reproduce the numbers, please check out our [supplementary material](https://papers.nips.cc/paper/8367-volumetric-correspondence-networks-for-optical-flow) for the differences in hyper-parameters with respect to FlowNet2 and PWC-Net.

#### Pretrain on flying chairs and flying things
Make sure you have downloaded [flying chairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html)
and [flying things **subset**](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html),
and placed them under the same folder, say /ssd/.

To first train on flying chairs for 140k iterations with a batch size of 8, run (assuming you have two gpus)
```
CUDA_VISIBLE_DEVICES=0,1 python main.py --maxdisp 256 --fac 1 --database /ssd/ --logname chairs-0 --savemodel /data/ptmodel/ --epochs 1000 --stage chairs --ngpus 2
```
Then fine-tune on flying things for 80k iterations with a batch size of 8, resuming from your pre-trained model (or from our pretrained model):
```
CUDA_VISIBLE_DEVICES=0,1 python main.py --maxdisp 256 --fac 1 --database /ssd/ --logname things-0 --savemodel /data/ptmodel/ --epochs 1000 --stage things --ngpus 2 --loadmodel ./weights/charis/finetune_141999.tar --retrain false
```
Note that to resume the number of iterations, put the iteration to start from in iter_counts-(your suffix).txt. In this example, I'll put 141999 in iter_counts-0.txt.
Be aware that the program reads/writes iter_counts-(suffix).txt at training time, so you may want to use a different suffix when multiple training programs are running at the same time.

#### Finetune on KITTI / Sintel
Please first download the KITTI 2012/2015 flow datasets if you want to fine-tune on KITTI.
Download [rob_devkit](http://www.cvlibs.net:3000/ageiger/rob_devkit/src/flow/flow) if you want to fine-tune on Sintel.

To fine-tune on KITTI with a batch size of 16, run
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --maxdisp 512 --fac 2 --database /ssd/ --logname kitti-trainval-0 --savemodel /data/ptmodel/ --epochs 1000 --stage 2015trainval --ngpus 4 --loadmodel ./weights/things/finetune_211999.tar --retrain true
```
To fine-tune on Sintel with a batch size of 16, run
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --maxdisp 448 --fac 1.4 --database /ssd/ --logname sintel-trainval-0 --savemodel /data/ptmodel/ --epochs 1000 --stage sinteltrainval --ngpus 4 --loadmodel ./weights/things/finetune_239999.tar --retrain true
```

#### Note on train-val
- To tune hyper-parameters, we use a train-val split for KITTI and Sintel, which is not covered by the above procedure.
- For KITTI we use every 5th image in the training set (0,5,10,...,195) for validation and the rest for training, while for Sintel we manually select several sequences for validation.
- If you plan to use our split, pass "--stage 2015train" or "--stage sinteltrain" for training.
- The numbers in Tab. 3 of the paper are on the whole train-val set (all the data with ground truth).
- You might find run.sh helpful for running evaluation on KITTI/Sintel.

## Measure FLOPS
```
python flops.py
```
gives

PWCNet: flops(G)/params(M): 90.8/9.37

VCN: flops(G)/params(M): 96.5/6.23

#### Note on inference time
The current implementation runs at 180ms/pair on KITTI-sized images at inference time.
A rough breakdown of the running time is: feature extraction - 4.9%, feature correlation - 8.7%, separable 4D convolutions - 56%, truncated soft-argmin (soft winner-take-all) - 20%, and hypotheses fusion - 9.5%.
A detailed breakdown is shown below in the form "name-level percentage".

(breakdown figure omitted here; see figs/time-breakdown.png)

Note that the separable 4D convolutions use fewer FLOPs than the 2D convolutions (i.e., the feature extraction module + hypotheses fusion module, 47.8 vs. 53.3 GFLOPs)
but take 4x more time (56% vs. 14.4%). One reason might be that pytorch (like other packages) is more friendly to networks with more feature channels than to those with a large spatial size, given the same FLOPs. This might be fixed at the conv kernel / hardware level.

Besides, the truncated soft-argmin is implemented with 3D max pooling, which is inefficient and takes more time than expected (a minimal sketch of this operation is given after the acknowledgements).

## Acknowledgement
Thanks [ClementPinard](https://github.com/ClementPinard), [Lyken17](https://github.com/Lyken17), [NVlabs](https://github.com/NVlabs) and many others for open-sourcing their code.
- The pytorch op counter thop is modified from [pytorch-OpCounter](https://github.com/Lyken17/pytorch-OpCounter).
- The correlation module is modified from [Pytorch-Correlation-extension](https://github.com/ClementPinard/Pytorch-Correlation-extension).
- The full 4D convolution is taken from [NCNet](https://github.com/ignacio-rocco/ncnet), but is not used in our model (it is only used in the ablation study).
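#### A sketch of truncated soft-argmin
As promised in the inference-time note above, here is a minimal, self-contained sketch of how a truncated soft-argmin (soft winner-take-all) over a cost volume can be implemented with 3D max pooling. This is an illustration rather than the exact code in models/VCN.py: the (B, U, V, H, W) score layout, the window radius `rad`, and the uniform hypothesis grid are assumptions made for the example.
```
import torch
import torch.nn.functional as F

def truncated_soft_argmin(score, max_disp, rad=2):
    # score: (B, U, V, H, W) matching scores over a U x V grid of flow
    # hypotheses spanning [-max_disp, max_disp] (layout assumed for this sketch)
    B, U, V, H, W = score.shape
    x = score.reshape(B, 1, U, V, H * W)
    # per-pixel best score over all hypotheses
    gmax = x.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
    # 3D max pooling over the two hypothesis dims: lmax[u, v] is the best
    # score within a (2*rad+1)^2 neighborhood of hypothesis (u, v)
    lmax = F.max_pool3d(x, kernel_size=(2 * rad + 1, 2 * rad + 1, 1),
                        stride=1, padding=(rad, rad, 0))
    # keep only hypotheses within `rad` of the per-pixel winner (their
    # neighborhood contains the global max); suppress everything else
    x = torch.where(lmax == gmax, x, torch.full_like(x, -1e9))
    prob = F.softmax(x.reshape(B, U * V, H * W), dim=1).reshape(B, U, V, H, W)
    # expectation over the truncated distribution gives sub-hypothesis flow
    du = torch.linspace(-max_disp, max_disp, U, device=score.device)
    dv = torch.linspace(-max_disp, max_disp, V, device=score.device)
    flow_u = (prob.sum(dim=2) * du.view(1, U, 1, 1)).sum(dim=1)
    flow_v = (prob.sum(dim=1) * dv.view(1, V, 1, 1)).sum(dim=1)
    return torch.stack([flow_u, flow_v], dim=1)  # (B, 2, H, W)
```
A hypothesis keeps its score exactly when its pooled neighborhood contains the per-pixel global maximum, i.e. when it lies within `rad` of the winner, so the max-pooling trick carves out the window around the argmax; making `rad` cover the whole grid recovers the ordinary soft-argmin.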

## Citation
```
@inproceedings{yang2019vcn,
  title={Volumetric Correspondence Networks for Optical Flow},
  author={Yang, Gengshan and Ramanan, Deva},
  booktitle={NeurIPS},
  year={2019}
}
```
--------------------------------------------------------------------------------
/dataloader/MiddleburyList.py:
--------------------------------------------------------------------------------
import torch.utils.data as data
import glob
import pdb
from PIL import Image
import os
import os.path
import numpy as np


def dataloader(filepath, res='Q'):
    filepath = '%s/training%s' % (filepath, res)
    img_list = [i.split('/')[-1] for i in glob.glob('%s/*' % (filepath)) if os.path.isdir(i)]

    left_train = ['%s/%s/im0.png' % (filepath, img) for img in img_list]
    right_train = ['%s/%s/im1.png' % (filepath, img) for img in img_list]
    disp_train_L = ['%s/%s/disp0GT.pfm' % (filepath, img) for img in img_list]
    disp_train_R = ['%s/%s/disp1GT.pfm' % (filepath, img) for img in img_list]

    return left_train, right_train, disp_train_L, left_train, right_train, disp_train_R

--------------------------------------------------------------------------------
/dataloader/MiddleburySubmit.py:
--------------------------------------------------------------------------------
import torch.utils.data as data

import pdb
from PIL import Image
import os
import os.path
import numpy as np
import glob


def dataloader(filepath):
    # the two hard-coded scene lists below are kept for reference only; they
    # are overwritten by the glob that follows
    img_list = ['Adirondack', 'Motorcycle', 'PianoL',
                'Playtable', 'Shelves', 'ArtL',
                'MotorcycleE', 'Pipes', 'PlaytableP',
                'Teddy', 'Jadeplant', 'Piano',
                'Playroom', 'Recycle', 'Vintage']

    img_list = ['Australia', 'Bicycle2', 'Classroom2E',
                'Crusade', 'Djembe', 'Hoops', 'Newkuba',
                'Staircase', 'AustraliaP', 'Classroom2', 'Computer',
                'CrusadeP', 'DjembeL', 'Livingroom', 'Plants']

    img_list = [i.split('/')[-1] for i in glob.glob('%s/*' % filepath) if os.path.isdir(i)]
    #img_list *= 10

    left_train = ['%s/%s/im0.png' % (filepath, img) for img in img_list]
    right_train = ['%s/%s/im1.png' % (filepath, img) for img in img_list]


    return left_train, right_train, left_train

--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/dataloader/__init__.py

--------------------------------------------------------------------------------
/dataloader/chairslist.py:
--------------------------------------------------------------------------------
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np
import glob

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):
    l0_train = []
    l1_train = []
    flow_train = []
    for flow_map in sorted(glob.glob(os.path.join(filepath, '*_flow.flo'))):
        root_filename = flow_map[:-9]
        img1 = root_filename + '_img1.ppm'
        img2 = root_filename + '_img2.ppm'
        if not (os.path.isfile(os.path.join(filepath, img1))
and os.path.isfile(os.path.join(filepath,img2))): 27 | continue 28 | 29 | l0_train.append(img1) 30 | l1_train.append(img2) 31 | flow_train.append(flow_map) 32 | 33 | return l0_train, l1_train, flow_train 34 | -------------------------------------------------------------------------------- /dataloader/flow_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import random 4 | import numpy as np 5 | import numbers 6 | import types 7 | import scipy.ndimage as ndimage 8 | import pdb 9 | import torchvision 10 | import PIL.Image as Image 11 | import cv2 12 | from torch.nn import functional as F 13 | 14 | 15 | class Compose(object): 16 | """ Composes several co_transforms together. 17 | For example: 18 | >>> co_transforms.Compose([ 19 | >>> co_transforms.CenterCrop(10), 20 | >>> co_transforms.ToTensor(), 21 | >>> ]) 22 | """ 23 | 24 | def __init__(self, co_transforms): 25 | self.co_transforms = co_transforms 26 | 27 | def __call__(self, input, target): 28 | for t in self.co_transforms: 29 | input,target = t(input,target) 30 | return input,target 31 | 32 | 33 | class Scale(object): 34 | """ Rescales the inputs and target arrays to the given 'size'. 35 | 'size' will be the size of the smaller edge. 36 | For example, if height > width, then image will be 37 | rescaled to (size * height / width, size) 38 | size: size of the smaller edge 39 | interpolation order: Default: 2 (bilinear) 40 | """ 41 | 42 | def __init__(self, size, order=1): 43 | self.ratio = size 44 | self.order = order 45 | if order==0: 46 | self.code=cv2.INTER_NEAREST 47 | elif order==1: 48 | self.code=cv2.INTER_LINEAR 49 | elif order==2: 50 | self.code=cv2.INTER_CUBIC 51 | 52 | def __call__(self, inputs, target): 53 | if self.ratio==1: 54 | return inputs, target 55 | h, w, _ = inputs[0].shape 56 | ratio = self.ratio 57 | 58 | inputs[0] = cv2.resize(inputs[0], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR) 59 | inputs[1] = cv2.resize(inputs[1], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR) 60 | # keep the mask same 61 | tmp = cv2.resize(target[:,:,2], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_NEAREST) 62 | target = cv2.resize(target, None, fx=ratio,fy=ratio,interpolation=self.code) * ratio 63 | target[:,:,2] = tmp 64 | 65 | 66 | return inputs, target 67 | 68 | 69 | 70 | 71 | class SpatialAug(object): 72 | def __init__(self, crop, scale=None, rot=None, trans=None, squeeze=None, schedule_coeff=1, order=1, black=False): 73 | self.crop = crop 74 | self.scale = scale 75 | self.rot = rot 76 | self.trans = trans 77 | self.squeeze = squeeze 78 | self.t = np.zeros(6) 79 | self.schedule_coeff = schedule_coeff 80 | self.order = order 81 | self.black = black 82 | 83 | def to_identity(self): 84 | self.t[0] = 1; self.t[2] = 0; self.t[4] = 0; self.t[1] = 0; self.t[3] = 1; self.t[5] = 0; 85 | 86 | def left_multiply(self, u0, u1, u2, u3, u4, u5): 87 | result = np.zeros(6) 88 | result[0] = self.t[0]*u0 + self.t[1]*u2; 89 | result[1] = self.t[0]*u1 + self.t[1]*u3; 90 | 91 | result[2] = self.t[2]*u0 + self.t[3]*u2; 92 | result[3] = self.t[2]*u1 + self.t[3]*u3; 93 | 94 | result[4] = self.t[4]*u0 + self.t[5]*u2 + u4; 95 | result[5] = self.t[4]*u1 + self.t[5]*u3 + u5; 96 | self.t = result 97 | 98 | def inverse(self): 99 | result = np.zeros(6) 100 | a = self.t[0]; c = self.t[2]; e = self.t[4]; 101 | b = self.t[1]; d = self.t[3]; f = self.t[5]; 102 | 103 | denom = a*d - b*c; 104 | 105 | result[0] = d / denom; 106 | result[1] = -b / 
denom; 107 | result[2] = -c / denom; 108 | result[3] = a / denom; 109 | result[4] = (c*f-d*e) / denom; 110 | result[5] = (b*e-a*f) / denom; 111 | 112 | return result 113 | 114 | def grid_transform(self, meshgrid, t, normalize=True, gridsize=None): 115 | if gridsize is None: 116 | h, w = meshgrid[0].shape 117 | else: 118 | h, w = gridsize 119 | vgrid = torch.cat([(meshgrid[0] * t[0] + meshgrid[1] * t[2] + t[4])[:,:,np.newaxis], 120 | (meshgrid[0] * t[1] + meshgrid[1] * t[3] + t[5])[:,:,np.newaxis]],-1) 121 | if normalize: 122 | vgrid[:,:,0] = 2.0*vgrid[:,:,0]/max(w-1,1)-1.0 123 | vgrid[:,:,1] = 2.0*vgrid[:,:,1]/max(h-1,1)-1.0 124 | return vgrid 125 | 126 | 127 | def __call__(self, inputs, target): 128 | h, w, _ = inputs[0].shape 129 | th, tw = self.crop 130 | meshgrid = torch.meshgrid([torch.Tensor(range(th)), torch.Tensor(range(tw))])[::-1] 131 | cornergrid = torch.meshgrid([torch.Tensor([0,th-1]), torch.Tensor([0,tw-1])])[::-1] 132 | 133 | for i in range(50): 134 | # im0 135 | self.to_identity() 136 | #TODO add mirror 137 | if np.random.binomial(1,0.5): 138 | mirror = True 139 | else: 140 | mirror = False 141 | ##TODO 142 | #mirror = False 143 | if mirror: 144 | self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th); 145 | else: 146 | self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th); 147 | scale0 = 1; scale1 = 1; squeeze0 = 1; squeeze1 = 1; 148 | if not self.rot is None: 149 | rot0 = np.random.uniform(-self.rot[0],+self.rot[0]) 150 | rot1 = np.random.uniform(-self.rot[1]*self.schedule_coeff, self.rot[1]*self.schedule_coeff) + rot0 151 | self.left_multiply(np.cos(rot0), np.sin(rot0), -np.sin(rot0), np.cos(rot0), 0, 0) 152 | if not self.trans is None: 153 | trans0 = np.random.uniform(-self.trans[0],+self.trans[0], 2) 154 | trans1 = np.random.uniform(-self.trans[1]*self.schedule_coeff,+self.trans[1]*self.schedule_coeff, 2) + trans0 155 | self.left_multiply(1, 0, 0, 1, trans0[0] * tw, trans0[1] * th) 156 | if not self.squeeze is None: 157 | squeeze0 = np.exp(np.random.uniform(-self.squeeze[0], self.squeeze[0])) 158 | squeeze1 = np.exp(np.random.uniform(-self.squeeze[1]*self.schedule_coeff, self.squeeze[1]*self.schedule_coeff)) * squeeze0 159 | if not self.scale is None: 160 | scale0 = np.exp(np.random.uniform(self.scale[2]-self.scale[0], self.scale[2]+self.scale[0])) 161 | scale1 = np.exp(np.random.uniform(-self.scale[1]*self.schedule_coeff, self.scale[1]*self.schedule_coeff)) * scale0 162 | self.left_multiply(1.0/(scale0*squeeze0), 0, 0, 1.0/(scale0/squeeze0), 0, 0) 163 | 164 | self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h); 165 | transmat0 = self.t.copy() 166 | 167 | # im1 168 | self.to_identity() 169 | if mirror: 170 | self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th); 171 | else: 172 | self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th); 173 | if not self.rot is None: 174 | self.left_multiply(np.cos(rot1), np.sin(rot1), -np.sin(rot1), np.cos(rot1), 0, 0) 175 | if not self.trans is None: 176 | self.left_multiply(1, 0, 0, 1, trans1[0] * tw, trans1[1] * th) 177 | self.left_multiply(1.0/(scale1*squeeze1), 0, 0, 1.0/(scale1/squeeze1), 0, 0) 178 | self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h); 179 | transmat1 = self.t.copy() 180 | transmat1_inv = self.inverse() 181 | 182 | if self.black: 183 | # black augmentation, allowing 0 values in the input images 184 | # https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/black_augmentation_layer.cu 185 | break 186 | else: 187 | if ((self.grid_transform(cornergrid, transmat0, gridsize=[float(h),float(w)]).abs()>1).sum() +\ 188 | 
(self.grid_transform(cornergrid, transmat1, gridsize=[float(h),float(w)]).abs()>1).sum()) == 0: 189 | break 190 | if i==49: 191 | print('max_iter in augmentation') 192 | self.to_identity() 193 | self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th); 194 | self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h); 195 | transmat0 = self.t.copy() 196 | transmat1 = self.t.copy() 197 | 198 | # do the real work 199 | vgrid = self.grid_transform(meshgrid, transmat0,gridsize=[float(h),float(w)]) 200 | inputs_0 = F.grid_sample(torch.Tensor(inputs[0]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0) 201 | if self.order == 0: 202 | target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0) 203 | else: 204 | target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0) 205 | 206 | mask_0 = target[:,:,2:3].copy(); mask_0[mask_0==0]=np.nan 207 | if self.order == 0: 208 | mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0) 209 | else: 210 | mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0) 211 | mask_0[torch.isnan(mask_0)] = 0 212 | 213 | 214 | vgrid = self.grid_transform(meshgrid, transmat1,gridsize=[float(h),float(w)]) 215 | inputs_1 = F.grid_sample(torch.Tensor(inputs[1]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0) 216 | 217 | # flow 218 | pos = target_0[:,:,:2] + self.grid_transform(meshgrid, transmat0,normalize=False) 219 | pos = self.grid_transform(pos.permute(2,0,1),transmat1_inv,normalize=False) 220 | if target_0.shape[2]>=4: 221 | # scale 222 | exp = target_0[:,:,3:] * scale1 / scale0 223 | target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1), 224 | (pos[:,:,1] - meshgrid[1]).unsqueeze(-1), 225 | mask_0, 226 | exp], -1) 227 | else: 228 | target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1), 229 | (pos[:,:,1] - meshgrid[1]).unsqueeze(-1), 230 | mask_0], -1) 231 | # target_0[:,:,2].unsqueeze(-1) ], -1) 232 | inputs = [np.asarray(inputs_0), np.asarray(inputs_1)] 233 | target = np.asarray(target) 234 | 235 | return inputs,target 236 | 237 | 238 | class pseudoPCAAug(object): 239 | """ 240 | Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu 241 | This version is faster. 242 | """ 243 | def __init__(self, schedule_coeff=1): 244 | self.augcolor = torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.5, hue=0.5/3.14) 245 | 246 | def __call__(self, inputs, target): 247 | inputs[0] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[0]*255))))/255. 248 | inputs[1] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[1]*255))))/255. 
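        # note: torchvision's ColorJitter re-samples its random parameters on
        # every call, so the two frames above receive independent color jitter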
249 | return inputs,target 250 | 251 | 252 | class PCAAug(object): 253 | """ 254 | Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu 255 | """ 256 | def __init__(self, lmult_pow =[0.4, 0,-0.2], 257 | lmult_mult =[0.4, 0,0, ], 258 | lmult_add =[0.03,0,0, ], 259 | sat_pow =[0.4, 0,0, ], 260 | sat_mult =[0.5, 0,-0.3], 261 | sat_add =[0.03,0,0, ], 262 | col_pow =[0.4, 0,0, ], 263 | col_mult =[0.2, 0,0, ], 264 | col_add =[0.02,0,0, ], 265 | ladd_pow =[0.4, 0,0, ], 266 | ladd_mult =[0.4, 0,0, ], 267 | ladd_add =[0.04,0,0, ], 268 | col_rotate =[1., 0,0, ], 269 | schedule_coeff=1): 270 | # no mean 271 | self.pow_nomean = [1,1,1] 272 | self.add_nomean = [0,0,0] 273 | self.mult_nomean = [1,1,1] 274 | self.pow_withmean = [1,1,1] 275 | self.add_withmean = [0,0,0] 276 | self.mult_withmean = [1,1,1] 277 | self.lmult_pow = 1 278 | self.lmult_mult = 1 279 | self.lmult_add = 0 280 | self.col_angle = 0 281 | if not ladd_pow is None: 282 | self.pow_nomean[0] =np.exp(np.random.normal(ladd_pow[2], ladd_pow[0])) 283 | if not col_pow is None: 284 | self.pow_nomean[1] =np.exp(np.random.normal(col_pow[2], col_pow[0])) 285 | self.pow_nomean[2] =np.exp(np.random.normal(col_pow[2], col_pow[0])) 286 | 287 | if not ladd_add is None: 288 | self.add_nomean[0] =np.random.normal(ladd_add[2], ladd_add[0]) 289 | if not col_add is None: 290 | self.add_nomean[1] =np.random.normal(col_add[2], col_add[0]) 291 | self.add_nomean[2] =np.random.normal(col_add[2], col_add[0]) 292 | 293 | if not ladd_mult is None: 294 | self.mult_nomean[0] =np.exp(np.random.normal(ladd_mult[2], ladd_mult[0])) 295 | if not col_mult is None: 296 | self.mult_nomean[1] =np.exp(np.random.normal(col_mult[2], col_mult[0])) 297 | self.mult_nomean[2] =np.exp(np.random.normal(col_mult[2], col_mult[0])) 298 | 299 | # with mean 300 | if not sat_pow is None: 301 | self.pow_withmean[1] =np.exp(np.random.uniform(sat_pow[2]-sat_pow[0], sat_pow[2]+sat_pow[0])) 302 | self.pow_withmean[2] =self.pow_withmean[1] 303 | if not sat_add is None: 304 | self.add_withmean[1] =np.random.uniform(sat_add[2]-sat_add[0], sat_add[2]+sat_add[0]) 305 | self.add_withmean[2] =self.add_withmean[1] 306 | if not sat_mult is None: 307 | self.mult_withmean[1] = np.exp(np.random.uniform(sat_mult[2]-sat_mult[0], sat_mult[2]+sat_mult[0])) 308 | self.mult_withmean[2] = self.mult_withmean[1] 309 | 310 | if not lmult_pow is None: 311 | self.lmult_pow = np.exp(np.random.uniform(lmult_pow[2]-lmult_pow[0], lmult_pow[2]+lmult_pow[0])) 312 | if not lmult_mult is None: 313 | self.lmult_mult= np.exp(np.random.uniform(lmult_mult[2]-lmult_mult[0], lmult_mult[2]+lmult_mult[0])) 314 | if not lmult_add is None: 315 | self.lmult_add = np.random.uniform(lmult_add[2]-lmult_add[0], lmult_add[2]+lmult_add[0]) 316 | if not col_rotate is None: 317 | self.col_angle= np.random.uniform(col_rotate[2]-col_rotate[0], col_rotate[2]+col_rotate[0]) 318 | 319 | # eigen vectors 320 | self.eigvec = np.reshape([0.51,0.56,0.65,0.79,0.01,-0.62,0.35,-0.83,0.44],[3,3]).transpose() 321 | 322 | 323 | def __call__(self, inputs, target): 324 | inputs[0] = self.pca_image(inputs[0]) 325 | inputs[1] = self.pca_image(inputs[1]) 326 | return inputs,target 327 | 328 | def pca_image(self, rgb): 329 | eig = np.dot(rgb, self.eigvec) 330 | max_rgb = np.clip(rgb,0,np.inf).max((0,1)) 331 | min_rgb = rgb.min((0,1)) 332 | mean_rgb = rgb.mean((0,1)) 333 | max_abs_eig = np.abs(eig).max((0,1)) 334 | max_l = np.sqrt(np.sum(max_abs_eig*max_abs_eig)) 335 | mean_eig = 
np.dot(mean_rgb, self.eigvec) 336 | 337 | # no-mean stuff 338 | eig -= mean_eig[np.newaxis, np.newaxis] 339 | 340 | for c in range(3): 341 | if max_abs_eig[c] > 1e-2: 342 | mean_eig[c] /= max_abs_eig[c] 343 | eig[:,:,c] = eig[:,:,c] / max_abs_eig[c]; 344 | eig[:,:,c] = np.power(np.abs(eig[:,:,c]),self.pow_nomean[c]) *\ 345 | ((eig[:,:,c] > 0) -0.5)*2 346 | eig[:,:,c] = eig[:,:,c] + self.add_nomean[c] 347 | eig[:,:,c] = eig[:,:,c] * self.mult_nomean[c] 348 | eig += mean_eig[np.newaxis,np.newaxis] 349 | 350 | # withmean stuff 351 | if max_abs_eig[0] > 1e-2: 352 | eig[:,:,0] = np.power(np.abs(eig[:,:,0]),self.pow_withmean[0]) * \ 353 | ((eig[:,:,0]>0)-0.5)*2; 354 | eig[:,:,0] = eig[:,:,0] + self.add_withmean[0]; 355 | eig[:,:,0] = eig[:,:,0] * self.mult_withmean[0]; 356 | 357 | s = np.sqrt(eig[:,:,1]*eig[:,:,1] + eig[:,:,2] * eig[:,:,2]) 358 | smask = s > 1e-2 359 | s1 = np.power(s, self.pow_withmean[1]); 360 | s1 = np.clip(s1 + self.add_withmean[1], 0,np.inf) 361 | s1 = s1 * self.mult_withmean[1] 362 | s1 = s1 * smask + s*(1-smask) 363 | 364 | # color angle 365 | if self.col_angle!=0: 366 | temp1 = np.cos(self.col_angle) * eig[:,:,1] - np.sin(self.col_angle) * eig[:,:,2] 367 | temp2 = np.sin(self.col_angle) * eig[:,:,1] + np.cos(self.col_angle) * eig[:,:,2] 368 | eig[:,:,1] = temp1 369 | eig[:,:,2] = temp2 370 | 371 | # to origin magnitude 372 | for c in range(3): 373 | if max_abs_eig[c] > 1e-2: 374 | eig[:,:,c] = eig[:,:,c] * max_abs_eig[c] 375 | 376 | if max_l > 1e-2: 377 | l1 = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2]) 378 | l1 = l1 / max_l 379 | 380 | eig[:,:,1][smask] = (eig[:,:,1] / s * s1)[smask] 381 | eig[:,:,2][smask] = (eig[:,:,2] / s * s1)[smask] 382 | #eig[:,:,1] = (eig[:,:,1] / s * s1) * smask + eig[:,:,1] * (1-smask) 383 | #eig[:,:,2] = (eig[:,:,2] / s * s1) * smask + eig[:,:,2] * (1-smask) 384 | 385 | if max_l > 1e-2: 386 | l = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2]) 387 | l1 = np.power(l1, self.lmult_pow) 388 | l1 = np.clip(l1 + self.lmult_add, 0, np.inf) 389 | l1 = l1 * self.lmult_mult 390 | l1 = l1 * max_l 391 | lmask = l > 1e-2 392 | eig[lmask] = (eig / l[:,:,np.newaxis] * l1[:,:,np.newaxis])[lmask] 393 | for c in range(3): 394 | eig[:,:,c][lmask] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c]))[lmask] 395 | # for c in range(3): 396 | # # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] * lmask + eig[:,:,c] * (1-lmask) 397 | # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] 398 | # eig[:,:,c] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c])) * lmask + eig[:,:,c] * (1-lmask) 399 | 400 | return np.clip(np.dot(eig, self.eigvec.transpose()), 0, 1) 401 | 402 | 403 | class ChromaticAug(object): 404 | """ 405 | Chromatic augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu 406 | """ 407 | def __init__(self, noise = 0.06, 408 | gamma = 0.02, 409 | brightness = 0.02, 410 | contrast = 0.02, 411 | color = 0.02, 412 | schedule_coeff=1): 413 | 414 | self.noise = np.random.uniform(0,noise) 415 | self.gamma = np.exp(np.random.normal(0, gamma*schedule_coeff)) 416 | self.brightness = np.random.normal(0, brightness*schedule_coeff) 417 | self.contrast = np.exp(np.random.normal(0, contrast*schedule_coeff)) 418 | self.color = np.exp(np.random.normal(0, color*schedule_coeff,3)) 419 | 420 | def __call__(self, inputs, target): 421 | inputs[1] = self.chrom_aug(inputs[1]) 422 | # noise 423 | inputs[0]+=np.random.normal(0, self.noise, inputs[0].shape) 424 | 
inputs[1]+=np.random.normal(0, self.noise, inputs[0].shape) 425 | return inputs,target 426 | 427 | def chrom_aug(self, rgb): 428 | # color change 429 | mean_in = rgb.sum(-1) 430 | rgb = rgb*self.color[np.newaxis,np.newaxis] 431 | brightness_coeff = mean_in / (rgb.sum(-1)+0.01) 432 | rgb = np.clip(rgb*brightness_coeff[:,:,np.newaxis],0,1) 433 | # gamma 434 | rgb = np.power(rgb,self.gamma) 435 | # brightness 436 | rgb += self.brightness 437 | # contrast 438 | rgb = 0.5 + ( rgb-0.5)*self.contrast 439 | rgb = np.clip(rgb, 0, 1) 440 | return rgb 441 | -------------------------------------------------------------------------------- /dataloader/hd1klist.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import pdb 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath): 19 | 20 | left_fold = 'image_2/' 21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('HD1K2018') > -1] 22 | train = sorted(train) 23 | 24 | l0_train = [filepath+left_fold+img for img in train] 25 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] 26 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] 27 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train] 28 | 29 | return l0_train, l1_train, flow_train 30 | -------------------------------------------------------------------------------- /dataloader/kitti12list.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | 19 | left_fold = 'colored_0/' 20 | flow_noc = 'flow_occ/' 21 | 22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 23 | 24 | l0_train = [filepath+left_fold+img for img in train] 25 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] 26 | flow_train = [filepath+flow_noc+img for img in train] 27 | 28 | 29 | return l0_train, l1_train, flow_train 30 | -------------------------------------------------------------------------------- /dataloader/kitti15list.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | 19 | left_fold = 'image_2/' 20 | flow_noc = 'flow_occ/' 21 | 22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 23 | 24 | l0_train = 
[filepath+left_fold+img for img in train] 25 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] 26 | flow_train = [filepath+flow_noc+img for img in train] 27 | 28 | 29 | return sorted(l0_train), sorted(l1_train), sorted(flow_train) 30 | -------------------------------------------------------------------------------- /dataloader/kitti15list_train.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | 19 | left_fold = 'image_2/' 20 | flow_noc = 'flow_occ/' 21 | 22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 23 | 24 | train = [i for i in train if int(i.split('_')[0])%5!=0] 25 | 26 | l0_train = [filepath+left_fold+img for img in train] 27 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] 28 | flow_train = [filepath+flow_noc+img for img in train] 29 | 30 | 31 | return sorted(l0_train), sorted(l1_train), sorted(flow_train) 32 | -------------------------------------------------------------------------------- /dataloader/kitti15list_val.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | 19 | left_fold = 'image_2/' 20 | flow_noc = 'flow_occ/' 21 | 22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 23 | 24 | train = [i for i in train if int(i.split('_')[0])%5==0] 25 | 26 | l0_train = [filepath+left_fold+img for img in train] 27 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] 28 | flow_train = [filepath+flow_noc+img for img in train] 29 | 30 | 31 | return sorted(l0_train), sorted(l1_train), sorted(flow_train) 32 | -------------------------------------------------------------------------------- /dataloader/kitticliplist.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import glob 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath): 19 | 20 | train = [img for img in sorted(glob.glob('%s*'%filepath))] 21 | 22 | l0_train = train[:-1] 23 | l1_train = train[1:] 24 | 25 | 26 | return sorted(l0_train), sorted(l1_train), sorted(l0_train) 27 | -------------------------------------------------------------------------------- /dataloader/listflowfile.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 
5 | import os.path 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 10 | ] 11 | 12 | 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 15 | 16 | def dataloader(filepath): 17 | 18 | classes = [d for d in os.listdir(filepath) if os.path.isdir(os.path.join(filepath, d))] 19 | image = [img for img in classes if img.find('frames_cleanpass') > -1] 20 | disp = [dsp for dsp in classes if dsp.find('disparity') > -1] 21 | 22 | monkaa_path = filepath + [x for x in image if 'monkaa' in x][0] 23 | monkaa_disp = filepath + [x for x in disp if 'monkaa' in x][0] 24 | 25 | 26 | monkaa_dir = os.listdir(monkaa_path) 27 | 28 | all_left_img=[] 29 | all_right_img=[] 30 | all_left_disp = [] 31 | all_right_disp = [] 32 | test_left_img=[] 33 | test_right_img=[] 34 | test_left_disp = [] 35 | test_right_disp = [] 36 | 37 | 38 | for dd in monkaa_dir: 39 | for im in os.listdir(monkaa_path+'/'+dd+'/left/'): 40 | if is_image_file(monkaa_path+'/'+dd+'/left/'+im): 41 | all_left_img.append(monkaa_path+'/'+dd+'/left/'+im) 42 | all_left_disp.append(monkaa_disp+'/'+dd+'/left/'+im.split(".")[0]+'.pfm') 43 | all_right_disp.append(monkaa_disp+'/'+dd+'/right/'+im.split(".")[0]+'.pfm') 44 | 45 | for im in os.listdir(monkaa_path+'/'+dd+'/right/'): 46 | if is_image_file(monkaa_path+'/'+dd+'/right/'+im): 47 | all_right_img.append(monkaa_path+'/'+dd+'/right/'+im) 48 | 49 | flying_path = filepath + [x for x in image if x == 'frames_cleanpass'][0] 50 | flying_disp = filepath + [x for x in disp if x == 'disparity'][0] 51 | flying_dir = flying_path+'/TRAIN/' 52 | subdir = ['A','B','C'] 53 | 54 | for ss in subdir: 55 | flying = os.listdir(flying_dir+ss) 56 | 57 | for ff in flying: 58 | imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/') 59 | for im in imm_l: 60 | if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im): 61 | all_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im) 62 | 63 | all_left_disp.append(flying_disp+'/TRAIN/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm') 64 | all_right_disp.append(flying_disp+'/TRAIN/'+ss+'/'+ff+'/right/'+im.split(".")[0]+'.pfm') 65 | 66 | if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im): 67 | all_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im) 68 | 69 | flying_dir = flying_path+'/TEST/' 70 | 71 | subdir = ['A','B','C'] 72 | 73 | for ss in subdir: 74 | flying = os.listdir(flying_dir+ss) 75 | 76 | for ff in flying: 77 | imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/') 78 | for im in imm_l: 79 | if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im): 80 | test_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im) 81 | 82 | test_left_disp.append(flying_disp+'/TEST/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm') 83 | test_right_disp.append(flying_disp+'/TEST/'+ss+'/'+ff+'/right/'+im.split(".")[0]+'.pfm') 84 | 85 | if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im): 86 | test_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im) 87 | 88 | 89 | driving_dir = filepath + [x for x in image if 'driving' in x][0] + '/' 90 | driving_disp = filepath + [x for x in disp if 'driving' in x][0] 91 | 92 | #subdir1 = ['35mm_focallength'] 93 | #subdir2 = ['scene_backwards'] 94 | #subdir3 = ['fast'] 95 | ##TODO was using 15 only 96 | subdir1 = ['35mm_focallength','15mm_focallength'] 97 | subdir2 = ['scene_backwards','scene_forwards'] 98 | subdir3 = ['fast','slow'] 99 | 100 | #subdir1 = ['35mm_focallength'] 101 | #subdir2 = ['scene_backwards'] 102 | #subdir3 = ['fast'] 103 | 104 | for i in subdir1: 105 | for j 
in subdir2: 106 | for k in subdir3: 107 | imm_l = os.listdir(driving_dir+i+'/'+j+'/'+k+'/left/') 108 | for im in imm_l: 109 | if is_image_file(driving_dir+i+'/'+j+'/'+k+'/left/'+im): 110 | all_left_img.append(driving_dir+i+'/'+j+'/'+k+'/left/'+im) 111 | all_left_disp.append(driving_disp+'/'+i+'/'+j+'/'+k+'/left/'+im.split(".")[0]+'.pfm') 112 | all_right_disp.append(driving_disp+'/'+i+'/'+j+'/'+k+'/right/'+im.split(".")[0]+'.pfm') 113 | 114 | if is_image_file(driving_dir+i+'/'+j+'/'+k+'/right/'+im): 115 | all_right_img.append(driving_dir+i+'/'+j+'/'+k+'/right/'+im) 116 | 117 | 118 | return all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp, all_right_disp, test_right_disp 119 | 120 | 121 | -------------------------------------------------------------------------------- /dataloader/mblist.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | 19 | left_fold = 'image_2/' 20 | flow_noc = 'flow_occ/' 21 | 22 | train = [img for img in os.listdir(filepath+left_fold) if 'Middlebury' in img and img.find('_10') > -1] 23 | 24 | l0_train = [filepath+left_fold+img for img in train] 25 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] 26 | flow_train = [filepath+flow_noc+img for img in train] 27 | 28 | 29 | return l0_train, l1_train, flow_train 30 | -------------------------------------------------------------------------------- /dataloader/robloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numbers 3 | import torch 4 | import torch.utils.data as data 5 | import torch 6 | import torchvision.transforms as transforms 7 | import random 8 | from PIL import Image, ImageOps 9 | import numpy as np 10 | import torchvision 11 | from . 
import flow_transforms 12 | import pdb 13 | import cv2 14 | from utils.flowlib import read_flow 15 | from utils.util_flow import readPFM 16 | 17 | 18 | def default_loader(path): 19 | return Image.open(path).convert('RGB') 20 | 21 | def flow_loader(path): 22 | if '.pfm' in path: 23 | data = readPFM(path)[0] 24 | data[:,:,2] = 1 25 | return data 26 | else: 27 | return read_flow(path) 28 | 29 | 30 | def disparity_loader(path): 31 | if '.png' in path: 32 | data = Image.open(path) 33 | data = np.ascontiguousarray(data,dtype=np.float32)/256 34 | return data 35 | else: 36 | return readPFM(path)[0] 37 | 38 | class myImageFloder(data.Dataset): 39 | def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader= flow_loader, scale=1.,shape=[320,448], order=1, noise=0.06, pca_augmentor=True, prob = 1., cover=False, black=False): 40 | self.iml0 = iml0 41 | self.iml1 = iml1 42 | self.flowl0 = flowl0 43 | self.loader = loader 44 | self.dploader = dploader 45 | self.scale=scale 46 | self.shape=shape 47 | self.order=order 48 | self.noise = noise 49 | self.pca_augmentor = pca_augmentor 50 | self.prob = prob 51 | self.cover = cover 52 | self.black = black 53 | 54 | def __getitem__(self, index): 55 | iml0 = self.iml0[index] 56 | iml1 = self.iml1[index] 57 | flowl0= self.flowl0[index] 58 | th, tw = self.shape 59 | 60 | iml0 = self.loader(iml0) 61 | iml1 = self.loader(iml1) 62 | iml1 = np.asarray(iml1)/255. 63 | iml0 = np.asarray(iml0)/255. 64 | iml0 = iml0[:,:,::-1].copy() 65 | iml1 = iml1[:,:,::-1].copy() 66 | flowl0 = self.dploader(flowl0) 67 | flowl0 = np.ascontiguousarray(flowl0,dtype=np.float32) 68 | flowl0[np.isnan(flowl0)] = 1e6 # set to max 69 | 70 | ## following data augmentation procedure in PWCNet 71 | ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu 72 | import __main__ # a workaround for "discount_coeff" 73 | try: 74 | with open('iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f: 75 | iter_counts = int(f.readline()) 76 | except: 77 | iter_counts = 0 78 | schedule = [0.5, 1., 50000.] # initial coeff, final_coeff, half life 79 | schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \ 80 | (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1) 81 | 82 | if self.pca_augmentor: 83 | pca_augmentor = flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff) 84 | else: 85 | pca_augmentor = flow_transforms.Scale(1., order=0) 86 | 87 | if np.random.binomial(1,self.prob): 88 | co_transform = flow_transforms.Compose([ 89 | flow_transforms.Scale(self.scale, order=self.order), 90 | flow_transforms.SpatialAug([th,tw],scale=[0.4,0.03,0.2], 91 | rot=[0.4,0.03], 92 | trans=[0.4,0.03], 93 | squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order, black=self.black), 94 | flow_transforms.PCAAug(schedule_coeff=schedule_coeff), 95 | flow_transforms.ChromaticAug( schedule_coeff=schedule_coeff, noise=self.noise), 96 | ]) 97 | else: 98 | co_transform = flow_transforms.Compose([ 99 | flow_transforms.Scale(self.scale, order=self.order), 100 | flow_transforms.SpatialAug([th,tw], trans=[0.4,0.03], order=self.order, black=self.black) 101 | ]) 102 | 103 | augmented,flowl0 = co_transform([iml0, iml1], flowl0) 104 | iml0 = augmented[0] 105 | iml1 = augmented[1] 106 | 107 | if self.cover: 108 | ## randomly cover a region 109 | # following sec. 
3.2 of http://openaccess.thecvf.com/content_CVPR_2019/html/Yang_Hierarchical_Deep_Stereo_Matching_on_High-Resolution_Images_CVPR_2019_paper.html 110 | if np.random.binomial(1,0.5): 111 | #sx = int(np.random.uniform(25,100)) 112 | #sy = int(np.random.uniform(25,100)) 113 | sx = int(np.random.uniform(50,125)) 114 | sy = int(np.random.uniform(50,125)) 115 | #sx = int(np.random.uniform(50,150)) 116 | #sy = int(np.random.uniform(50,150)) 117 | cx = int(np.random.uniform(sx,iml1.shape[0]-sx)) 118 | cy = int(np.random.uniform(sy,iml1.shape[1]-sy)) 119 | iml1[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis] 120 | 121 | iml0 = torch.Tensor(np.transpose(iml0,(2,0,1))) 122 | iml1 = torch.Tensor(np.transpose(iml1,(2,0,1))) 123 | 124 | return iml0, iml1, flowl0 125 | 126 | def __len__(self): 127 | return len(self.iml0) 128 | -------------------------------------------------------------------------------- /dataloader/sintellist.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import pdb 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath): 19 | 20 | left_fold = 'image_2/' 21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1] 22 | 23 | l0_train = [filepath+left_fold+img for img in train] 24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] 25 | 26 | #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val 27 | 28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] 29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train] 30 | 31 | 32 | return l0_train, l1_train, flow_train 33 | -------------------------------------------------------------------------------- /dataloader/sintellist_clean.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import pdb 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath): 19 | 20 | left_fold = 'image_2/' 21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_clean') > -1] 22 | 23 | l0_train = [filepath+left_fold+img for img in train] 24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] 25 | 26 | #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val 27 | 28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] 29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train] 30 | 31 | return l0_train, l1_train, flow_train 32 | -------------------------------------------------------------------------------- /dataloader/sintellist_final.py: 
--------------------------------------------------------------------------------
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np
import pdb

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_final') > -1]

    l0_train = [filepath+left_fold+img for img in train]
    l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]

    #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val

    l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
    flow_train = [img.replace('image_2','flow_occ') for img in l0_train]

    return l0_train, l1_train, flow_train

--------------------------------------------------------------------------------
/dataloader/sintellist_train.py:
--------------------------------------------------------------------------------
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np
import pdb

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]

    l0_train = [filepath+left_fold+img for img in train]
    l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]

    l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # exclude the manually-selected validation sequences

    l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
    flow_train = [img.replace('image_2','flow_occ') for img in l0_train]


    return l0_train, l1_train, flow_train

--------------------------------------------------------------------------------
/dataloader/sintellist_val.py:
--------------------------------------------------------------------------------
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np
import pdb

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def dataloader(filepath):

    left_fold = 'image_2/'
    train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]

    l0_train = [filepath+left_fold+img for img in train]
    l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]

    l0_train = [i for i in l0_train if ('_2_' in i)
and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i)] # remove 10 as val 27 | #l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # remove 10 as val 28 | 29 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] 30 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train] 31 | 32 | 33 | return sorted(l0_train)[::3], sorted(l1_train)[::3], sorted(flow_train)[::3] 34 | # return sorted(l0_train)[::10], sorted(l1_train)[::10], sorted(flow_train)[::10] 35 | -------------------------------------------------------------------------------- /dataloader/stereo_kittilist12.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | 19 | left_fold = 'colored_0/' 20 | right_fold = 'colored_1/' 21 | disp_noc = 'disp_occ/' 22 | 23 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 24 | 25 | train = image[:] 26 | val = image[160:] 27 | 28 | left_train = [filepath+left_fold+img for img in train] 29 | right_train = [filepath+right_fold+img for img in train] 30 | disp_train = [filepath+disp_noc+img for img in train] 31 | 32 | 33 | left_val = [filepath+left_fold+img for img in val] 34 | right_val = [filepath+right_fold+img for img in val] 35 | disp_val = [filepath+disp_noc+img for img in val] 36 | 37 | return left_train, right_train, disp_train, left_val, right_val, disp_val 38 | -------------------------------------------------------------------------------- /dataloader/stereo_kittilist15.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | import pdb 4 | from PIL import Image 5 | import os 6 | import os.path 7 | import numpy as np 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath, typ = 'train'): 19 | 20 | left_fold = 'image_2/' 21 | right_fold = 'image_3/' 22 | disp_L = 'disp_occ_0/' 23 | disp_R = 'disp_occ_1/' 24 | 25 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 26 | image = sorted(image) 27 | imglist = [1,3,6,20,26,35,38,41,43,44,49,60,67,70,81,84,89,97,109,119,122,123,129,130,132,134,141,144,152,158,159,165,171,174,179,182, 184,186,187,196] 28 | if typ == 'train': 29 | train = [image[i] for i in range(200) if i not in imglist] 30 | elif typ == 'trainval': 31 | train = [image[i] for i in range(200)] 32 | val = [image[i] for i in imglist] 33 | 34 | left_train = [filepath+left_fold+img for img in train] 35 | right_train = [filepath+right_fold+img for img in train] 36 | disp_train_L = [filepath+disp_L+img for img in train] 37 | #disp_train_R = [filepath+disp_R+img for img in train] 38 | 39 | left_val = [filepath+left_fold+img for img in val] 40 | right_val = [filepath+right_fold+img for img in val] 41 | disp_val_L = [filepath+disp_L+img 
for img in val] 42 | #disp_val_R = [filepath+disp_R+img for img in val] 43 | 44 | return left_train, right_train, disp_train_L, left_val, right_val, disp_val_L 45 | -------------------------------------------------------------------------------- /dataloader/thingslist.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | exc_list = [ 19 | '0004117.flo', 20 | '0003149.flo', 21 | '0001203.flo', 22 | '0003147.flo', 23 | '0003666.flo', 24 | '0006337.flo', 25 | '0006336.flo', 26 | '0007126.flo', 27 | '0004118.flo', 28 | ] 29 | 30 | left_fold = 'image_clean/left/' 31 | flow_noc = 'flow/left/into_future/' 32 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0] 33 | 34 | l0_trainlf = [filepath+left_fold+img.replace('flo','png') for img in train] 35 | l1_trainlf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlf] 36 | flow_trainlf = [filepath+flow_noc+img for img in train] 37 | 38 | 39 | exc_list = [ 40 | '0003148.flo', 41 | '0004117.flo', 42 | '0002890.flo', 43 | '0003149.flo', 44 | '0001203.flo', 45 | '0003666.flo', 46 | '0006337.flo', 47 | '0006336.flo', 48 | '0004118.flo', 49 | ] 50 | 51 | left_fold = 'image_clean/right/' 52 | flow_noc = 'flow/right/into_future/' 53 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0] 54 | 55 | l0_trainrf = [filepath+left_fold+img.replace('flo','png') for img in train] 56 | l1_trainrf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrf] 57 | flow_trainrf = [filepath+flow_noc+img for img in train] 58 | 59 | 60 | exc_list = [ 61 | '0004237.flo', 62 | '0004705.flo', 63 | '0004045.flo', 64 | '0004346.flo', 65 | '0000161.flo', 66 | '0000931.flo', 67 | '0000121.flo', 68 | '0010822.flo', 69 | '0004117.flo', 70 | '0006023.flo', 71 | '0005034.flo', 72 | '0005054.flo', 73 | '0000162.flo', 74 | '0000053.flo', 75 | '0005055.flo', 76 | '0003147.flo', 77 | '0004876.flo', 78 | '0000163.flo', 79 | '0006878.flo', 80 | ] 81 | 82 | left_fold = 'image_clean/left/' 83 | flow_noc = 'flow/left/into_past/' 84 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0] 85 | 86 | l0_trainlp = [filepath+left_fold+img.replace('flo','png') for img in train] 87 | l1_trainlp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlp] 88 | flow_trainlp = [filepath+flow_noc+img for img in train] 89 | 90 | exc_list = [ 91 | '0003148.flo', 92 | '0004705.flo', 93 | '0000161.flo', 94 | '0000121.flo', 95 | '0004117.flo', 96 | '0000160.flo', 97 | '0005034.flo', 98 | '0005054.flo', 99 | '0000162.flo', 100 | '0000053.flo', 101 | '0005055.flo', 102 | '0003147.flo', 103 | '0001549.flo', 104 | '0000163.flo', 105 | '0006336.flo', 106 | '0001648.flo', 107 | '0006878.flo', 108 | ] 109 | 110 | left_fold = 'image_clean/right/' 111 | flow_noc = 'flow/right/into_past/' 112 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0] 
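# Note: the 'into_future' lists above pair frame t with frame t+1 (the +1 in the
# frame-number arithmetic), while the 'into_past' lists pair frame t with frame
# t-1 (the -1 offset below). Each exc_list enumerates .flo files to skip for that
# view/direction, presumably invalid or unmatched ground-truth samples in
# FlyingThings3D. For example, for a (hypothetical) file '.../left/0000123.png'
# the into_past partner resolves to '.../left/0000122.png' via:
#   '%s/%s.png'%(img.rsplit('/',1)[0], '%07d'%(-1+int(img.split('.')[0].split('/')[-1])))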
113 | 114 | l0_trainrp = [filepath+left_fold+img.replace('flo','png') for img in train] 115 | l1_trainrp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrp] 116 | flow_trainrp = [filepath+flow_noc+img for img in train] 117 | 118 | 119 | l0_train = l0_trainlf + l0_trainrf + l0_trainlp + l0_trainrp 120 | l1_train = l1_trainlf + l1_trainrf + l1_trainlp + l1_trainrp 121 | flow_train = flow_trainlf + flow_trainrf + flow_trainlp + flow_trainrp 122 | return l0_train, l1_train, flow_train 123 | -------------------------------------------------------------------------------- /dataset/IIW/hawk_000299.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/dataset/IIW/hawk_000299.png -------------------------------------------------------------------------------- /dataset/IIW/hawk_000300.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/dataset/IIW/hawk_000300.png -------------------------------------------------------------------------------- /dataset/kitti_scene/testing/image_2/000042_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/dataset/kitti_scene/testing/image_2/000042_10.png -------------------------------------------------------------------------------- /dataset/kitti_scene/testing/image_2/000042_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/dataset/kitti_scene/testing/image_2/000042_11.png -------------------------------------------------------------------------------- /eval_tmp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from matplotlib import pyplot as plt 3 | import numpy as np 4 | import sys 5 | sys.path.insert(0,'utils') 6 | from utils.flowlib import flow_to_image, read_flow, compute_color, visualize_flow 7 | from utils.io import mkdir_p 8 | import pdb 9 | import glob 10 | 11 | import argparse 12 | 13 | parser = argparse.ArgumentParser(description='') 14 | parser.add_argument('--path', default='/data/ptmodel/', 15 | help='database') 16 | parser.add_argument('--vis', default='no', 17 | help='database') 18 | parser.add_argument('--dataset', default='2015', 19 | help='database') 20 | args = parser.parse_args() 21 | 22 | aepe_s = [] 23 | fall_s64 = [] 24 | fall_s32 = [] 25 | fall_s16 = [] 26 | fall_s8 = [] 27 | fall_s = [] 28 | oor_tp = [] 29 | oor_fp = [] 30 | 31 | # dataloader 32 | if args.dataset == '2015': 33 | #from dataloader import kitti15list as DA 34 | #from dataloader import kitti15list_val_lidar as DA 35 | from dataloader import kitti15list_val as DA 36 | datapath = '/ssd/kitti_scene/training/' 37 | elif args.dataset == '2015test': 38 | from dataloader import kitti15list as DA 39 | datapath = '/ssd/kitti_scene/testing/' 40 | elif args.dataset == 'kitticlip': 41 | from dataloader import kitticliplist as DA 42 | #datapath = '/ssd/rob_flow/test/image_2/Kitti2015_000140_' 43 | datapath = '/data/gengshay/KITTI_png/2011_09_30/2011_09_30_drive_0028_sync/image_02/data/' 44 | elif args.dataset == 'tumclip': 45 | from dataloader import kitticliplist as DA 46 | datapath = 
'/data/gengshay/TUM/rgbd_dataset_freiburg1_plant/rgb/' 47 | elif args.dataset == '2012': 48 | from dataloader import kitti12list as DA 49 | datapath = '/ssd/data_stereo_flow/training/' 50 | elif args.dataset == '2012test': 51 | from dataloader import kitti12list as DA 52 | datapath = '/ssd/data_stereo_flow/testing/' 53 | elif args.dataset == 'mb': 54 | from dataloader import mblist as DA 55 | datapath = '/ssd/rob_flow/training/' 56 | elif args.dataset == 'sintel': 57 | #from dataloader import sintellist as DA 58 | from dataloader import sintellist_val as DA 59 | #from dataloader import sintellist_clean as DA 60 | datapath = '/ssd/rob_flow/training/' 61 | elif args.dataset == 'hd1k': 62 | from dataloader import hd1klist as DA 63 | datapath = '/ssd/rob_flow/training/' 64 | elif args.dataset == 'mbtest': 65 | from dataloader import mblist as DA 66 | datapath = '/ssd/rob_flow/test/' 67 | elif args.dataset == 'sinteltest': 68 | from dataloader import sintellist as DA 69 | datapath = '/ssd/rob_flow/test/' 70 | elif args.dataset == 'chairs': 71 | from dataloader import chairslist as DA 72 | datapath = '/ssd/FlyingChairs_release/data/' 73 | test_left_img, test_right_img ,flow_paths= DA.dataloader(datapath) 74 | 75 | #pdb.set_trace() 76 | #with open('/data/gengshay/PWC-Net/Caffe/sintel_test1.txt','w') as f: 77 | # for i in test_left_img: 78 | # f.write(i+'\n') 79 | # 80 | #with open('/data/gengshay/PWC-Net/Caffe/sintel_test2.txt','w') as f: 81 | # for i in test_right_img: 82 | # f.write(i+'\n') 83 | # 84 | #with open('/data/gengshay/PWC-Net/Caffe/sintel_testout.txt','w') as f: 85 | # for i in test_left_img: 86 | # f.write('/data/ptmodel/pwcnet-1/sintel/%s.flo'%(i.split('/')[-1].split('.')[0])+'\n') 87 | #exit() 88 | if args.dataset == 'chairs': 89 | with open('FlyingChairs_train_val.txt', 'r') as f: 90 | split = [int(i) for i in f.readlines()] 91 | test_left_img = [test_left_img[i] for i,flag in enumerate(split) if flag==2] 92 | test_right_img = [test_right_img[i] for i,flag in enumerate(split) if flag==2] 93 | flow_paths = [flow_paths[i] for i,flag in enumerate(split) if flag==2] 94 | 95 | #pdb.set_trace() 96 | #test_left_img = [i for i in test_left_img if 'clean' in i] 97 | #test_right_img = [i for i in test_right_img if 'clean' in i] 98 | #flow_paths = [i for i in flow_paths if 'clean' in i] 99 | 100 | #for i,gtflow_path in enumerate(sorted(flow_paths)): 101 | for i,gtflow_path in enumerate(flow_paths): 102 | #if not 'Sintel_clean_cave_4_10' in gtflow_path: 103 | # continue 104 | #if i%10!=1: 105 | # continue 106 | num = gtflow_path.split('/')[-1].strip().replace('flow.flo','img1.png') 107 | if not 'test' in args.dataset and not 'clip' in args.dataset: 108 | gtflow = read_flow(gtflow_path) 109 | num = num.replace('jpg','png') 110 | flow = read_flow('%s/%s/%s'%(args.path,args.dataset,num)) 111 | if args.vis == 'yes': 112 | #flowimg = flow_to_image(flow) 113 | flowimg = flow_to_image(flow)*np.linalg.norm(flow[:,:,:2],2,2)[:,:,np.newaxis]/100./255. 114 | #gtflowimg = compute_color(gtflow[:,:,0]/20, gtflow[:,:,1]/20)/255. 115 | #flowimg = compute_color(flow[:,:,0]/20, flow[:,:,1]/20)/255. 
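# The flow visualization above encodes direction as hue (flow_to_image) and then
# scales each pixel by its flow magnitude (the np.linalg.norm(...)/100./255.
# factor), so larger displacements render brighter; the commented compute_color
# lines are an alternative fixed-scale (divide-by-20) color coding.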
116 | mkdir_p('%s/%s/flowimg'%(args.path,args.dataset)) 117 | plt.imsave('%s/%s/flowimg/%s'%(args.path,args.dataset,num), flowimg) 118 | if 'test' in args.dataset or 'clip' in args.dataset: 119 | continue 120 | gtflowimg = flow_to_image(gtflow) 121 | mkdir_p('%s/%s/gtimg'%(args.path,args.dataset)) 122 | plt.imsave('%s/%s/gtimg/%s'%(args.path,args.dataset,num), gtflowimg) 123 | 124 | mask = gtflow[:,:,2]==1 125 | 126 | ## occlusion 127 | #H,W,_ = gtflow.shape 128 | #xx = np.tile(np.asarray(range(0, W))[np.newaxis:],(H,1)) 129 | #yy = np.tile(np.asarray(range(0, H))[:,np.newaxis],(1,W)) 130 | #occmask = np.logical_or( np.logical_or(xx + gtflow[:,:,0] <0, xx + gtflow[:,:,0]>W-1), 131 | # np.logical_or(yy + gtflow[:,:,1] <0, yy + gtflow[:,:,1]>H-1)) 132 | #mask = np.logical_and(mask,~occmask) 133 | 134 | 135 | 136 | # if args.dataset == 'mb': 137 | # ##TODO 138 | # mask = np.logical_and(np.logical_and(np.abs(gtflow[:,:,0]) < 16,np.abs(gtflow[:,:,1]) < 16), mask) 139 | gtflow = gtflow[:,:,:2] 140 | flow = flow[:,:,:2] 141 | 142 | epe = np.sqrt(np.power(gtflow - flow,2).sum(-1))[mask] 143 | gt_mag = np.sqrt(np.power(gtflow,2).sum(-1))[mask] 144 | 145 | 146 | #aepe_s.append( epe.mean() ) 147 | #fall_s.append( np.sum(np.logical_and(epe > 3, epe/gt_mag > 0.05)) / float(epe.size) ) 148 | 149 | clippx = [0,1000] 150 | inrangepx = np.logical_and((np.abs(gtflow)>=clippx[0]).sum(-1), (np.abs(gtflow)<clippx[1]).prod(-1))[mask]>0 151 | 152 | isoor = (np.abs(flow)>clippx).sum(-1)>0 153 | gtoortp = mask*((np.abs(gtflow)>clippx).sum(-1)>0) 154 | gtoorfp = mask*((np.abs(gtflow)>clippx).sum(-1)==0) 155 | oor_tp.append(isoor[gtoortp]) 156 | oor_fp.append(isoor[gtoorfp]) 157 | if args.vis == 'yes' and 'test' not in args.dataset: 158 | epeimg = np.sqrt(np.power(gtflow - flow,2).sum(-1))*(mask*(np.logical_and((np.abs(gtflow)>=clippx[0]).sum(-1), (np.abs(gtflow)<clippx[1]).prod(-1))>0)) 159 | mkdir_p('%s/%s/epeimg'%(args.path,args.dataset)) 160 | plt.imsave('%s/%s/epeimg/%s'%(args.path,args.dataset,num), epeimg) 161 | 162 | aepe_s.append( epe[inrangepx] ) 163 | fall_s64.append( (epe > 64)[inrangepx]) 164 | fall_s32.append( (epe > 32)[inrangepx]) 165 | fall_s16.append( (epe > 16)[inrangepx]) 166 | fall_s8.append( (epe > 8)[inrangepx]) 167 | fall_s.append( np.logical_and(epe > 3, epe/gt_mag > 0.05)[inrangepx]) 168 | # aepe_s.append( epe ) 169 | #fall_s.append( epe[gt_mag<32] > 8) 170 | # fall_s.append( np.logical_and(epe > 3, epe/gt_mag > 0.05)) 171 | # print(gtflow_path) 172 | #for i in [np.mean(i) for i in aepe_s]: 173 | # print('%f'%i) 174 | #for i in [np.mean(i) for i in fall_s]: 175 | # print('%f'%i) 176 | #print('\t%.1f/%.1f/%.1f/%.1f/%.1f/%.3f'%( 177 | # np.mean( 100*np.concatenate(fall_s64,0)), 178 | # np.mean( 100*np.concatenate(fall_s32,0)), 179 | # np.mean( 100*np.concatenate(fall_s16,0)), 180 | # np.mean( 100*np.concatenate(fall_s8,0)), 181 | # np.mean( 100*np.concatenate(fall_s,0)), 182 | # np.mean( np.concatenate(aepe_s,0))) ) 183 | print('\t%.1f/%.3f'%( 184 | np.mean( 100*np.concatenate(fall_s,0)), 185 | np.mean( np.concatenate(aepe_s,0))) ) 186 | #print('\t%.1f/%.1f'%(100*np.mean( np.concatenate(oor_tp,0) ), 100*np.mean( np.concatenate(oor_fp,0) )) ) 187 | -------------------------------------------------------------------------------- /figs/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/architecture.png -------------------------------------------------------------------------------- /figs/hawk-vec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/hawk-vec.png -------------------------------------------------------------------------------- /figs/hawk.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/hawk.png -------------------------------------------------------------------------------- /figs/kitti-test-42-vec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/kitti-test-42-vec.png -------------------------------------------------------------------------------- /figs/kitti-test-42.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/kitti-test-42.png -------------------------------------------------------------------------------- /figs/output-onlinepngtools.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/output-onlinepngtools.png -------------------------------------------------------------------------------- /figs/time-breakdown.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/time-breakdown.png -------------------------------------------------------------------------------- /flops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from thop import profile 3 | import pdb 4 | from models.PWCNet import PWCDCNet 5 | from models.VCN import VCN 6 | #from models.VCN_small import VCN 7 | 8 | vcn = VCN([1, 1280,384]) 9 | pwcnet = PWCDCNet([1, 1280,384]) 10 | 11 | f_pwc, p_pwc = profile(pwcnet, input_size=(2, 3, 384,1280), device='cuda') 12 | f_vcn, p_vcn = profile(vcn, input_size=(2, 3, 384,1280), device='cuda') 13 | print('PWCNet: \tflops(G)/params(M):%.1f/%.2f'%(f_pwc/1e9,p_pwc/1e6)) 14 | print('VCN: \t\tflops(G)/params(M):%.1f/%.2f'%(f_vcn/1e9,p_vcn/1e6)) 15 | 16 | 17 | #from models.conv4d import butterfly4D, sepConv4dBlock, sepConv4d 18 | # 19 | #model = torch.nn.Conv2d(30,30, (3,3), stride=1, padding=(1, 1), bias=False) 20 | #f, p = profile(model, input_size=(9*9,30,64,128), device='cuda') 21 | # 22 | ##model = torch.nn.Conv3d(30,30, (1,3,3), stride=1, padding=(0, 1, 1), bias=False) 23 | ##f, p = profile(model, input_size=(1,30,9*9,64,128), device='cuda') 24 | # 25 | # 26 | ###model = butterfly4D(64, 12,withbn=True,full=True) 27 | ###model = sepConv4dBlock(16,16,with_bn=True, stride=(1,1,1),full=True) 28 | ##model = sepConv4d(20, 20, (1,1,1), with_bn=True, full=True) 29 | ##f, p = profile(model, input_size=(1,20,9,9,64,128), device='cuda') 30 | # 31 | ##model = torch.nn.Conv2d(256,256,(3,3),padding=(1,1)) 32 | ##f, p = profile(model, input_size=(1,256,64,128), device='cuda') 33 | # 34 | #print('flops(G)/params(M):%.1f/%.2f'%(f/1e9,p/1e6)) 35 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import cv2 3 | cv2.setNumThreads(0) 4 | import sys 5 | import pdb 6 | import argparse 7 | import collections 8 | import os 9 | import random 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.optim as optim 15 
| import torch.utils.data 16 | from torch.autograd import Variable 17 | import torch.nn.functional as F 18 | import numpy as np 19 | import time 20 | from utils.flowlib import flow_to_image 21 | from models import * 22 | from utils import logger 23 | torch.backends.cudnn.benchmark=True 24 | from utils.multiscaleloss import realEPE 25 | 26 | 27 | parser = argparse.ArgumentParser(description='VCN') 28 | parser.add_argument('--maxdisp', type=int ,default=256, 29 | help='maximum disparity; out-of-range pixels will be masked out. Only affects the coarsest cost volume size') 30 | parser.add_argument('--fac', type=float ,default=1, 31 | help='controls the shape of the search grid. Only affects the coarsest cost volume size') 32 | parser.add_argument('--logname', default='logname', 33 | help='name of the log file') 34 | parser.add_argument('--database', default='/', 35 | help='path to the database') 36 | parser.add_argument('--epochs', type=int, default=300, 37 | help='number of epochs to train') 38 | parser.add_argument('--loadmodel', default=None, 39 | help='path of the pre-trained model') 40 | parser.add_argument('--model', default='VCN', 41 | help='VCN or VCN_small') 42 | parser.add_argument('--savemodel', default='./', 43 | help='path to save the model') 44 | parser.add_argument('--retrain', default='false', 45 | help='whether to reset moving mean / other hyperparameters') 46 | parser.add_argument('--stage', default='chairs', 47 | help='one of {chairs, things, 2015train, 2015trainval, sinteltrain, sinteltrainval}') 48 | parser.add_argument('--ngpus', type=int, default=2, 49 | help='number of gpus to use.') 50 | args = parser.parse_args() 51 | 52 | if args.model == 'VCN': 53 | from models.VCN import VCN 54 | elif args.model == 'VCN_small': 55 | from models.VCN_small import VCN 56 | 57 | # fix random seed 58 | torch.manual_seed(1) 59 | def _init_fn(worker_id): 60 | np.random.seed() 61 | random.seed() 62 | torch.manual_seed(8) # do it again 63 | torch.cuda.manual_seed(1) 64 | 65 | ## set hyperparameters for training 66 | ngpus = args.ngpus 67 | batch_size = 4*ngpus 68 | if args.stage == 'chairs' or args.stage == 'things': 69 | lr_schedule = 'slong_ours' 70 | else: 71 | lr_schedule = 'rob_ours' 72 | #baselr = 1e-4 73 | baselr = 1e-3 74 | worker_mul = int(2) 75 | #worker_mul = int(0) 76 | if args.stage == 'chairs' or args.stage == 'things': 77 | datashape = [320,448] 78 | elif '2015' in args.stage: 79 | datashape = [256,768] 80 | elif 'sintel' in args.stage: 81 | datashape = [320,576] 82 | else: 83 | print('error') 84 | exit(0) 85 | 86 | ## dataloader 87 | from dataloader import robloader as dr 88 | 89 | if args.stage == 'chairs' or 'sintel' in args.stage: 90 | # flying chairs 91 | from dataloader import chairslist as lc 92 | iml0, iml1, flowl0 = lc.dataloader('%s/FlyingChairs_release/data/'%args.database) 93 | with open('order.txt','r') as f: 94 | order = [int(i) for i in f.readline().split(' ')] 95 | with open('FlyingChairs_train_val.txt', 'r') as f: 96 | split = [int(i) for i in f.readlines()] 97 | iml0 = [iml0[i] for i in order if split[i]==1] 98 | iml1 = [iml1[i] for i in order if split[i]==1] 99 | flowl0 = [flowl0[i] for i in order if split[i]==1] 100 | loader_chairs = dr.myImageFloder(iml0,iml1,flowl0, shape = datashape) 101 | 102 | if args.stage == 'things' or 'sintel' in args.stage: 103 | # flying things 104 | from dataloader import thingslist as lt 105 | iml0, iml1, flowl0 = lt.dataloader('%s/FlyingThings3D_subset/train/'%args.database) 106 | loader_things = 
dr.myImageFloder(iml0,iml1,flowl0,shape = datashape,scale=1, order=1) 107 | 108 | # fine-tuning datasets 109 | if args.stage == '2015train': 110 | from dataloader import kitti15list_train as lk15 111 | else: 112 | from dataloader import kitti15list as lk15 113 | if args.stage == 'sinteltrain': 114 | from dataloader import sintellist_train as ls 115 | else: 116 | from dataloader import sintellist as ls 117 | from dataloader import kitti12list as lk12 118 | from dataloader import hd1klist as lh 119 | 120 | if 'sintel' in args.stage: 121 | iml0, iml1, flowl0 = lk15.dataloader('%s/kitti_scene/training/'%args.database) 122 | loader_kitti15 = dr.myImageFloder(iml0,iml1,flowl0, shape=datashape, scale=1, order=0, noise=0) # SINTEL 123 | iml0, iml1, flowl0 = lh.dataloader('%s/rob_flow/training/'%args.database) 124 | loader_hd1k = dr.myImageFloder(iml0,iml1,flowl0,shape=datashape, scale=0.5,order=0, noise=0) 125 | iml0, iml1, flowl0 = ls.dataloader('%s/rob_flow/training/'%args.database) 126 | loader_sintel = dr.myImageFloder(iml0,iml1,flowl0, shape=datashape, scale=1, order=1, noise=0) 127 | if '2015' in args.stage: 128 | iml0, iml1, flowl0 = lk12.dataloader('%s/data_stereo_flow/training/'%args.database) 129 | loader_kitti12 = dr.myImageFloder(iml0,iml1,flowl0, shape=datashape, scale=1, order=0, prob=0.5) 130 | iml0, iml1, flowl0 = lk15.dataloader('%s/kitti_scene/training/'%args.database) 131 | loader_kitti15 = dr.myImageFloder(iml0,iml1,flowl0, shape=datashape, scale=1, order=0, prob=0.5) # KITTI 132 | 133 | if args.stage=='chairs': 134 | data_inuse = torch.utils.data.ConcatDataset([loader_chairs]*100) 135 | elif args.stage=='things': 136 | data_inuse = torch.utils.data.ConcatDataset([loader_things]*100) 137 | elif '2015' in args.stage: 138 | data_inuse = torch.utils.data.ConcatDataset([loader_kitti15]*50+[loader_kitti12]*50) 139 | for i in data_inuse.datasets: 140 | i.black = True 141 | i.cover = True 142 | elif 'sintel' in args.stage: 143 | data_inuse = torch.utils.data.ConcatDataset([loader_kitti15]*200+[loader_hd1k]*40 + [loader_sintel]*150 + [loader_chairs]*2 + [loader_things]) 144 | for i in data_inuse.datasets: 145 | i.black = True 146 | i.cover = True 147 | else: 148 | print('error') 149 | exit(0) 150 | 151 | 152 | 153 | ## Stereo data 154 | #from dataloader import stereo_kittilist15 as lks15 155 | #from dataloader import stereo_kittilist12 as lks12 156 | #from dataloader import MiddleburyList as lmbs 157 | #from dataloader import listflowfile as lsfs 158 | # 159 | #def disparity_loader_sf(path): 160 | # from utils.util_flow import readPFM as rp 161 | # out = rp(path)[0] 162 | # shape = (out.shape[0], out.shape[1], 1) 163 | # out = np.concatenate((-out[:,:,np.newaxis],np.zeros(shape),np.ones(shape)),-1) 164 | # return out 165 | # 166 | #def disparity_loader_mb(path): 167 | # from utils.util_flow import readPFM as rp 168 | # out = rp(path)[0] 169 | # mask = np.asarray(out!=np.inf,float)[:,:,np.newaxis] 170 | # out[out==np.inf]=0 171 | # 172 | # shape = (out.shape[0], out.shape[1], 1) 173 | ## out = np.concatenate((-out[:,:,np.newaxis],np.zeros(shape),np.ones(shape)),-1) 174 | # out = np.concatenate((-out[:,:,np.newaxis],np.zeros(shape),mask),-1) 175 | # return out 176 | # 177 | #def disparity_loader(path): 178 | ## from utils.util_flow import readPFM as rp 179 | ## out = rp(path)[0] 180 | # from PIL import Image 181 | # out = Image.open(path) 182 | # out = np.ascontiguousarray(out,dtype=np.float32)/256 183 | # mask = np.asarray(out>0,float)[:,:,np.newaxis] 184 | # 185 | # shape = 
(out.shape[0], out.shape[1], 1) 186 | ## out = np.concatenate((-out[:,:,np.newaxis],np.zeros(shape),np.ones(shape)),-1) 187 | # out = np.concatenate((-out[:,:,np.newaxis],np.zeros(shape),mask),-1) 188 | # return out 189 | ##iml0, iml1, flowl0, _, _, _ = lks15.dataloader('%s/kitti_scene/training/'%args.database, typ='trainval') 190 | ##loader_stereo_15 = dr.myImageFloder(iml0,iml1,flowl0,shape = datashape,scale=1, order=0, prob=0.5,dploader=disparity_loader) 191 | ##iml0, iml1, flowl0, _, _, _ = lks12.dataloader('%s/data_stereo_flow/training/'%args.database) 192 | ##loader_stereo_12 = dr.myImageFloder(iml0,iml1,flowl0,shape = datashape,scale=1, order=0, prob=0.5,dploader=disparity_loader) 193 | ##iml0, iml1, flowl0, _, _, _ = lmbs.dataloader('%s/mb-ex-training/'%args.database, res='F') 194 | ##loader_stereo_mb = dr.myImageFloder(iml0,iml1,flowl0,shape = datashape,scale=0.5, order=1, prob=0.5,dploader=disparity_loader_mb) 195 | ##iml0, iml1, flowl0, _, _, _,_,_ = lsfs.dataloader('%s/sceneflow/'%args.database) 196 | ##loader_stereo_sf = dr.myImageFloder(iml0,iml1,flowl0,shape = datashape,scale=1, order=1, dploader=disparity_loader_sf) 197 | 198 | #data_inuse = torch.utils.data.ConcatDataset([loader_stereo_15]*75+[loader_stereo_12]*75+[loader_stereo_mb]*600+[loader_stereo_sf]) 199 | #data_inuse = torch.utils.data.ConcatDataset([loader_stereo_15]*50+[loader_stereo_12]*50+[loader_stereo_mb]*600+[loader_chairs]) 200 | #data_inuse = torch.utils.data.ConcatDataset([loader_chairs]*2 + [loader_things] +[loader_stereo_15]*300+[loader_stereo_12]*300) # stereo transfer 201 | #data_inuse = torch.utils.data.ConcatDataset([loader_stereo_15]*20+[loader_stereo_12]*20) # stereo transfer 202 | #data_inuse = torch.utils.data.ConcatDataset([loader_kitti15]*20+[loader_kitti12]*20+[loader_stereo_15]*20+[loader_stereo_12]*20) 203 | print('%d batches per epoch'%(len(data_inuse)//batch_size)) 204 | 205 | #TODO 206 | model = VCN([batch_size//ngpus]+data_inuse.datasets[0].shape[::-1], md=[int(4*(args.maxdisp/256)), 4,4,4,4], fac=args.fac) 207 | model = nn.DataParallel(model) 208 | model.cuda() 209 | 210 | total_iters = 0 211 | mean_L=[[0.33,0.33,0.33]] 212 | mean_R=[[0.33,0.33,0.33]] 213 | if args.loadmodel is not None: 214 | pretrained_dict = torch.load(args.loadmodel) 215 | pretrained_dict['state_dict'] = {k:v for k,v in pretrained_dict['state_dict'].items()} 216 | 217 | model.load_state_dict(pretrained_dict['state_dict'],strict=False) 218 | if args.retrain == 'true': 219 | print('re-training') 220 | else: 221 | with open('./iter_counts-%d.txt'%int(args.logname.split('-')[-1]), 'r') as f: 222 | total_iters = int(f.readline()) 223 | print('resuming from %d'%total_iters) 224 | mean_L=pretrained_dict['mean_L'] 225 | mean_R=pretrained_dict['mean_R'] 226 | 227 | 228 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 229 | optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), amsgrad=False) 230 | 231 | def train(imgL,imgR,flowl0): 232 | model.train() 233 | imgL = Variable(torch.FloatTensor(imgL)) 234 | imgR = Variable(torch.FloatTensor(imgR)) 235 | flowl0 = Variable(torch.FloatTensor(flowl0)) 236 | 237 | imgL, imgR, flowl0 = imgL.cuda(), imgR.cuda(), flowl0.cuda() 238 | mask = (flowl0[:,:,:,2] == 1) & (flowl0[:,:,:,0].abs() < args.maxdisp) & (flowl0[:,:,:,1].abs() < (args.maxdisp//args.fac)) 239 | mask.detach_(); 240 | 241 | # rearrange inputs 242 | groups = [] 243 | for i in range(ngpus): 244 | 
groups.append(imgL[i*batch_size//ngpus:(i+1)*batch_size//ngpus]) 245 | groups.append(imgR[i*batch_size//ngpus:(i+1)*batch_size//ngpus]) 246 | 247 | # forward-backward 248 | optimizer.zero_grad() 249 | output = model(torch.cat(groups,0), [flowl0,mask]) 250 | loss = output[-2].mean() 251 | loss.backward() 252 | optimizer.step() 253 | 254 | vis = {} 255 | vis['output2'] = output[0].detach().cpu().numpy() 256 | vis['output3'] = output[1].detach().cpu().numpy() 257 | vis['output4'] = output[2].detach().cpu().numpy() 258 | vis['output5'] = output[3].detach().cpu().numpy() 259 | vis['output6'] = output[4].detach().cpu().numpy() 260 | vis['oor'] = output[6][0].detach().cpu().numpy() 261 | vis['gt'] = flowl0[:,:,:,:].detach().cpu().numpy() 262 | if mask.sum(): 263 | vis['AEPE'] = realEPE(output[0].detach(), flowl0.permute(0,3,1,2).detach(),mask,sparse=False) 264 | vis['mask'] = mask 265 | return loss.data,vis 266 | 267 | def adjust_learning_rate(optimizer, total_iters): 268 | if lr_schedule == 'slong': 269 | if total_iters < 200000: 270 | lr = baselr 271 | elif total_iters < 300000: 272 | lr = baselr/2. 273 | elif total_iters < 400000: 274 | lr = baselr/4. 275 | elif total_iters < 500000: 276 | lr = baselr/8. 277 | elif total_iters < 600000: 278 | lr = baselr/16. 279 | if lr_schedule == 'slong_ours': 280 | if total_iters < 70000: 281 | lr = baselr 282 | elif total_iters < 130000: 283 | lr = baselr/2. 284 | elif total_iters < 190000: 285 | lr = baselr/4. 286 | elif total_iters < 240000: 287 | lr = baselr/8. 288 | elif total_iters < 290000: 289 | lr = baselr/16. 290 | if lr_schedule == 'slong_pwc': 291 | if total_iters < 400000: 292 | lr = baselr 293 | elif total_iters < 600000: 294 | lr = baselr/2. 295 | elif total_iters < 800000: 296 | lr = baselr/4. 297 | elif total_iters < 1000000: 298 | lr = baselr/8. 299 | elif total_iters < 1200000: 300 | lr = baselr/16. 301 | if lr_schedule == 'sfine_pwc': 302 | if total_iters < 1400000: 303 | lr = baselr/10. 304 | elif total_iters < 1500000: 305 | lr = baselr/20. 306 | elif total_iters < 1600000: 307 | lr = baselr/40. 308 | elif total_iters < 1700000: 309 | lr = baselr/80. 310 | if lr_schedule == 'sfine': 311 | if total_iters < 700000: 312 | lr = baselr/10. 313 | elif total_iters < 750000: 314 | lr = baselr/20. 315 | elif total_iters < 800000: 316 | lr = baselr/40. 317 | elif total_iters < 850000: 318 | lr = baselr/80. 319 | if lr_schedule == 'rob_ours': 320 | if total_iters < 30000: 321 | lr = baselr 322 | elif total_iters < 40000: 323 | lr = baselr / 2. 324 | elif total_iters < 50000: 325 | lr = baselr / 4. 326 | elif total_iters < 60000: 327 | lr = baselr / 8. 328 | elif total_iters < 70000: 329 | lr = baselr / 16. 330 | elif total_iters < 100000: 331 | lr = baselr 332 | elif total_iters < 110000: 333 | lr = baselr / 2. 334 | elif total_iters < 120000: 335 | lr = baselr / 4. 336 | elif total_iters < 130000: 337 | lr = baselr / 8. 338 | elif total_iters < 140000: 339 | lr = baselr / 16. 
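# All schedules above implement piecewise-constant decay: the learning rate is
# halved at fixed iteration milestones, and 'rob_ours' restarts at baselr at
# iteration 70k to run the decay a second time for robust fine-tuning. Note that
# lr is left unset once total_iters passes a schedule's last milestone, so each
# schedule assumes training stops before then. An equivalent table-driven sketch
# (not used by the code) for 'rob_ours':
#   milestones = [30000,40000,50000,60000,70000,100000,110000,120000,130000,140000]
#   scales = [1, .5, .25, .125, .0625, 1, .5, .25, .125, .0625]
#   lr = baselr*scales[np.searchsorted(milestones, total_iters, side='right')]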
340 | print(lr) 341 | for param_group in optimizer.param_groups: 342 | param_group['lr'] = lr 343 | 344 | # get global counts 345 | with open('./iter_counts-%d.txt'%int(args.logname.split('-')[-1]), 'w') as f: 346 | f.write('%d'%total_iters) 347 | 348 | def main(): 349 | TrainImgLoader = torch.utils.data.DataLoader( 350 | data_inuse, 351 | batch_size= batch_size, shuffle= True, num_workers=worker_mul*batch_size, drop_last=True, worker_init_fn=_init_fn, pin_memory=True) 352 | log = logger.Logger(args.savemodel, name=args.logname) 353 | start_full_time = time.time() 354 | global total_iters 355 | 356 | for epoch in range(1, args.epochs+1): 357 | total_train_loss = 0 358 | total_train_aepe = 0 359 | 360 | # training loop 361 | for batch_idx, (imgL_crop, imgR_crop, flowl0) in enumerate(TrainImgLoader): 362 | if batch_idx % 100 == 0: 363 | adjust_learning_rate(optimizer,total_iters) 364 | 365 | if total_iters < 1000: 366 | # subtract mean 367 | mean_L.append( np.asarray(imgL_crop.mean(0).mean(1).mean(1)) ) 368 | mean_R.append( np.asarray(imgR_crop.mean(0).mean(1).mean(1)) ) 369 | imgL_crop -= torch.from_numpy(np.asarray(mean_L).mean(0)[np.newaxis,:,np.newaxis, np.newaxis]).float() 370 | imgR_crop -= torch.from_numpy(np.asarray(mean_R).mean(0)[np.newaxis,:,np.newaxis, np.newaxis]).float() 371 | 372 | start_time = time.time() 373 | loss,vis = train(imgL_crop,imgR_crop, flowl0) 374 | print('Iter %d training loss = %.3f , time = %.2f' %(batch_idx, loss, time.time() - start_time)) 375 | total_train_loss += loss 376 | total_train_aepe += vis['AEPE'] 377 | 378 | if total_iters %10 == 0: 379 | log.scalar_summary('train/loss_batch',loss, total_iters) 380 | log.scalar_summary('train/aepe_batch',vis['AEPE'], total_iters) 381 | if total_iters %100 == 0: 382 | log.image_summary('train/left',imgL_crop[0:1],total_iters) 383 | log.image_summary('train/right',imgR_crop[0:1],total_iters) 384 | log.histo_summary('train/pred_hist',vis['output2'], total_iters) 385 | if len(np.asarray(vis['gt']))>0: 386 | log.histo_summary('train/gt_hist',np.asarray(vis['gt']), total_iters) 387 | gu = vis['gt'][0,:,:,0]; gv = vis['gt'][0,:,:,1] 388 | gu = gu*np.asarray(vis['mask'][0].float().cpu()); gv = gv*np.asarray(vis['mask'][0].float().cpu()) 389 | mask = vis['mask'][0].float().cpu() 390 | log.image_summary('train/gt0', flow_to_image(np.concatenate((gu[:,:,np.newaxis],gv[:,:,np.newaxis],mask[:,:,np.newaxis]),-1))[np.newaxis],total_iters) 391 | log.image_summary('train/output2',flow_to_image(vis['output2'][0].transpose((1,2,0)))[np.newaxis],total_iters) 392 | log.image_summary('train/output3',flow_to_image(vis['output3'][0].transpose((1,2,0)))[np.newaxis],total_iters) 393 | log.image_summary('train/output4',flow_to_image(vis['output4'][0].transpose((1,2,0)))[np.newaxis],total_iters) 394 | log.image_summary('train/output5',flow_to_image(vis['output5'][0].transpose((1,2,0)))[np.newaxis],total_iters) 395 | log.image_summary('train/output6',flow_to_image(vis['output6'][0].transpose((1,2,0)))[np.newaxis],total_iters) 396 | log.image_summary('train/oor',vis['oor'][np.newaxis],total_iters) 397 | torch.cuda.empty_cache() 398 | total_iters += 1 399 | # get global counts 400 | with open('./iter_counts-%d.txt'%int(args.logname.split('-')[-1]), 'w') as f: 401 | f.write('%d'%total_iters) 402 | 403 | if (total_iters + 1)%2000==0: 404 | #SAVE 405 | savefilename = args.savemodel+'/'+args.logname+'/finetune_'+str(total_iters)+'.tar' 406 | save_dict = model.state_dict() 407 | save_dict = collections.OrderedDict({k:v for k,v in save_dict.items() 
if ('flow_reg' not in k or 'conv1' in k) and ('grid' not in k)}) 408 | torch.save({ 409 | 'iters': total_iters, 410 | 'state_dict': save_dict, 411 | 'train_loss': total_train_loss/len(TrainImgLoader), 412 | 'mean_L': mean_L, 413 | 'mean_R': mean_R, 414 | }, savefilename) 415 | 416 | log.scalar_summary('train/loss',total_train_loss/len(TrainImgLoader), epoch) 417 | log.scalar_summary('train/aepe',total_train_aepe/len(TrainImgLoader), epoch) 418 | 419 | 420 | 421 | print('full finetune time = %.2f HR' %((time.time() - start_full_time)/3600)) 422 | 423 | 424 | 425 | if __name__ == '__main__': 426 | main() 427 | -------------------------------------------------------------------------------- /models/PWCNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of the PWC-DC network for optical flow estimation by Sun et al., 2018 3 | 4 | Jinwei Gu and Zhile Ren 5 | 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | import os 13 | os.environ['PYTHON_EGG_CACHE'] = 'tmp/' # a writable directory 14 | import numpy as np 15 | import pdb 16 | import time 17 | 18 | 19 | 20 | 21 | __all__ = [ 22 | 'pwcnet' 23 | ] 24 | 25 | def pwcnet(data=None): 26 | """PWC-DC network from the 27 | "PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume" paper (https://arxiv.org/abs/1709.02371) 28 | 29 | Args: 30 | data : pretrained weights of the network. Will create a new one if not set 31 | """ 32 | model = PWCDCNet() 33 | if data is not None: 34 | model.load_state_dict(data['state_dict']) 35 | return model 36 | 37 | 38 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 39 | return nn.Sequential( 40 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 41 | padding=padding, dilation=dilation, bias=True), 42 | nn.LeakyReLU(0.1)) 43 | 44 | def predict_flow(in_planes): 45 | return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True) 46 | 47 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): 48 | return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True) 49 | 50 | 51 | 52 | class WarpModule(nn.Module): 53 | def __init__(self, size): 54 | super(WarpModule, self).__init__() 55 | B,W,H = size 56 | # mesh grid 57 | xx = torch.arange(0, W).view(1,-1).repeat(H,1) 58 | yy = torch.arange(0, H).view(-1,1).repeat(1,W) 59 | xx = xx.view(1,1,H,W).repeat(B,1,1,1) 60 | yy = yy.view(1,1,H,W).repeat(B,1,1,1) 61 | self.register_buffer('grid',torch.cat((xx,yy),1).float()) 62 | 63 | def forward(self, x, flo): 64 | """ 65 | warp an image/tensor (im2) back to im1, according to the optical flow 66 | 67 | x: [B, C, H, W] (im2) 68 | flo: [B, 2, H, W] flow 69 | 70 | """ 71 | B, C, H, W = x.size() 72 | vgrid = self.grid + flo 73 | 74 | # scale grid to [-1,1] 75 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:]/max(W-1,1)-1.0 76 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:]/max(H-1,1)-1.0 77 | 78 | vgrid = vgrid.permute(0,2,3,1) 79 | output = nn.functional.grid_sample(x, vgrid) 80 | mask = ((vgrid[:,:,:,0].abs()<1) * (vgrid[:,:,:,1].abs()<1)) >0 81 | return output*mask.unsqueeze(1).float() 82 | 83 | 84 | class PWCDCNet(nn.Module): 85 | """ 86 | PWC-DC net: adds dilated convolutions and DenseNet connections 87 | 88 | """ 89 | def __init__(self, size, md=4): 90 | """ 91 | input: md --- maximum displacement (for correlation. 
default: 4), after warping 92 | """ 93 | 94 | super(PWCDCNet,self).__init__() 95 | self.conv1a = conv(3, 16, kernel_size=3, stride=2) 96 | self.conv1aa = conv(16, 16, kernel_size=3, stride=1) 97 | self.conv1b = conv(16, 16, kernel_size=3, stride=1) 98 | self.conv2a = conv(16, 32, kernel_size=3, stride=2) 99 | self.conv2aa = conv(32, 32, kernel_size=3, stride=1) 100 | self.conv2b = conv(32, 32, kernel_size=3, stride=1) 101 | self.conv3a = conv(32, 64, kernel_size=3, stride=2) 102 | self.conv3aa = conv(64, 64, kernel_size=3, stride=1) 103 | self.conv3b = conv(64, 64, kernel_size=3, stride=1) 104 | self.conv4a = conv(64, 96, kernel_size=3, stride=2) 105 | self.conv4aa = conv(96, 96, kernel_size=3, stride=1) 106 | self.conv4b = conv(96, 96, kernel_size=3, stride=1) 107 | self.conv5a = conv(96, 128, kernel_size=3, stride=2) 108 | self.conv5aa = conv(128,128, kernel_size=3, stride=1) 109 | self.conv5b = conv(128,128, kernel_size=3, stride=1) 110 | self.conv6aa = conv(128,196, kernel_size=3, stride=2) 111 | self.conv6a = conv(196,196, kernel_size=3, stride=1) 112 | self.conv6b = conv(196,196, kernel_size=3, stride=1) 113 | 114 | self.warp5 = WarpModule([size[0],size[1]//32,size[2]//32]) 115 | self.warp4 = WarpModule([size[0],size[1]//16,size[2]//16]) 116 | self.warp3 = WarpModule([size[0],size[1]//8,size[2]//8]) 117 | self.warp2 = WarpModule([size[0],size[1]//4,size[2]//4]) 118 | self.leakyRELU = nn.LeakyReLU(0.1) 119 | 120 | nd = (2*md+1)**2 121 | dd = np.cumsum([128,128,96,64,32]) 122 | 123 | od = nd 124 | self.conv6_0 = conv(od, 128, kernel_size=3, stride=1) 125 | self.conv6_1 = conv(od+dd[0],128, kernel_size=3, stride=1) 126 | self.conv6_2 = conv(od+dd[1],96, kernel_size=3, stride=1) 127 | self.conv6_3 = conv(od+dd[2],64, kernel_size=3, stride=1) 128 | self.conv6_4 = conv(od+dd[3],32, kernel_size=3, stride=1) 129 | self.predict_flow6 = predict_flow(od+dd[4]) 130 | self.deconv6 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 131 | self.upfeat6 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 132 | 133 | od = nd+128+4 134 | self.conv5_0 = conv(od, 128, kernel_size=3, stride=1) 135 | self.conv5_1 = conv(od+dd[0],128, kernel_size=3, stride=1) 136 | self.conv5_2 = conv(od+dd[1],96, kernel_size=3, stride=1) 137 | self.conv5_3 = conv(od+dd[2],64, kernel_size=3, stride=1) 138 | self.conv5_4 = conv(od+dd[3],32, kernel_size=3, stride=1) 139 | self.predict_flow5 = predict_flow(od+dd[4]) 140 | self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 141 | self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 142 | 143 | od = nd+96+4 144 | self.conv4_0 = conv(od, 128, kernel_size=3, stride=1) 145 | self.conv4_1 = conv(od+dd[0],128, kernel_size=3, stride=1) 146 | self.conv4_2 = conv(od+dd[1],96, kernel_size=3, stride=1) 147 | self.conv4_3 = conv(od+dd[2],64, kernel_size=3, stride=1) 148 | self.conv4_4 = conv(od+dd[3],32, kernel_size=3, stride=1) 149 | self.predict_flow4 = predict_flow(od+dd[4]) 150 | self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 151 | self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 152 | 153 | od = nd+64+4 154 | self.conv3_0 = conv(od, 128, kernel_size=3, stride=1) 155 | self.conv3_1 = conv(od+dd[0],128, kernel_size=3, stride=1) 156 | self.conv3_2 = conv(od+dd[1],96, kernel_size=3, stride=1) 157 | self.conv3_3 = conv(od+dd[2],64, kernel_size=3, stride=1) 158 | self.conv3_4 = conv(od+dd[3],32, kernel_size=3, stride=1) 159 | self.predict_flow3 = predict_flow(od+dd[4]) 160 | self.deconv3 = deconv(2, 2, 
kernel_size=4, stride=2, padding=1) 161 | self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1) 162 | 163 | od = nd+32+4 164 | self.conv2_0 = conv(od, 128, kernel_size=3, stride=1) 165 | self.conv2_1 = conv(od+dd[0],128, kernel_size=3, stride=1) 166 | self.conv2_2 = conv(od+dd[1],96, kernel_size=3, stride=1) 167 | self.conv2_3 = conv(od+dd[2],64, kernel_size=3, stride=1) 168 | self.conv2_4 = conv(od+dd[3],32, kernel_size=3, stride=1) 169 | self.predict_flow2 = predict_flow(od+dd[4]) 170 | self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1) 171 | 172 | self.dc_conv1 = conv(od+dd[4], 128, kernel_size=3, stride=1, padding=1, dilation=1) 173 | self.dc_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2) 174 | self.dc_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4) 175 | self.dc_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8) 176 | self.dc_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16) 177 | self.dc_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1) 178 | self.dc_conv7 = predict_flow(32) 179 | 180 | for m in self.modules(): 181 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 182 | nn.init.kaiming_normal(m.weight.data, mode='fan_in') 183 | if m.bias is not None: 184 | m.bias.data.zero_() 185 | 186 | # load weights 187 | #pretrained_dict = torch.load('/data/gengshay/PWC-Net//PyTorch/pwc_net_chairs.pth.tar') 188 | #pretrained_dict['state_dict'] = {k:v for k,v in pretrained_dict.items() if 'grid' not in k and 'flow_reg' not in k} 189 | #self.load_state_dict(pretrained_dict['state_dict'],strict=False) 190 | 191 | def corr(self, refimg_fea, targetimg_fea): 192 | maxdisp=4 193 | b,c,h,w = refimg_fea.shape 194 | targetimg_fea = F.unfold(targetimg_fea, (2*maxdisp+1,2*maxdisp+1), padding=maxdisp).view(b,c,2*maxdisp+1, 2*maxdisp+1,h,w) # (2*maxdisp+1)x(2*maxdisp+1) displaced copies of the target features 195 | cost = refimg_fea.view(b,c,h,w)[:,:,np.newaxis, np.newaxis]*targetimg_fea 196 | cost = cost.sum(1) 197 | 198 | b, ph, pw, h, w = cost.size() 199 | cost = cost.view(b, ph * pw, h, w)/refimg_fea.size(1) # normalize by the number of feature channels 200 | return cost 201 | 202 | def weight_parameters(self): 203 | return [param for name, param in self.named_parameters() if 'weight' in name] 204 | 205 | def bias_parameters(self): 206 | return [param for name, param in self.named_parameters() if 'bias' in name] 207 | 208 | #@profile 209 | def forward(self,im): 210 | bs = im.shape[0]//2 211 | c01 = self.conv1b(self.conv1aa(self.conv1a(im))) 212 | c02 = self.conv2b(self.conv2aa(self.conv2a(c01))) 213 | c03 = self.conv3b(self.conv3aa(self.conv3a(c02))) 214 | c04 = self.conv4b(self.conv4aa(self.conv4a(c03))) 215 | c05 = self.conv5b(self.conv5aa(self.conv5a(c04))) 216 | c06 = self.conv6b(self.conv6a(self.conv6aa(c05))) 217 | c11 = c01[:bs]; c21 = c01[bs:] 218 | c12 = c02[:bs]; c22 = c02[bs:] 219 | c13 = c03[:bs]; c23 = c03[bs:] 220 | c14 = c04[:bs]; c24 = c04[bs:] 221 | c15 = c05[:bs]; c25 = c05[bs:] 222 | c16 = c06[:bs]; c26 = c06[bs:] 223 | 224 | corr6 = self.corr(c16, c26) 225 | corr6 = self.leakyRELU(corr6) 226 | x = torch.cat((self.conv6_0(corr6), corr6),1) 227 | x = torch.cat((self.conv6_1(x), x),1) 228 | x = torch.cat((self.conv6_2(x), x),1) 229 | x = torch.cat((self.conv6_3(x), x),1) 230 | x = torch.cat((self.conv6_4(x), x),1) 231 | flow6 = self.predict_flow6(x) 232 | up_flow6 = self.deconv6(flow6) 233 | up_feat6 = self.upfeat6(x) 234 | 235 | 236 | warp5 = self.warp5(c25, up_flow6*0.625) 237 | corr5 = self.corr(c15, 
warp5) 238 | corr5 = self.leakyRELU(corr5) 239 | x = torch.cat((corr5, c15, up_flow6, up_feat6), 1) 240 | x = torch.cat((self.conv5_0(x), x),1) 241 | x = torch.cat((self.conv5_1(x), x),1) 242 | x = torch.cat((self.conv5_2(x), x),1) 243 | x = torch.cat((self.conv5_3(x), x),1) 244 | x = torch.cat((self.conv5_4(x), x),1) 245 | flow5 = self.predict_flow5(x) 246 | up_flow5 = self.deconv5(flow5) 247 | up_feat5 = self.upfeat5(x) 248 | 249 | 250 | warp4 = self.warp4(c24, up_flow5*1.25) 251 | corr4 = self.corr(c14, warp4) 252 | corr4 = self.leakyRELU(corr4) 253 | x = torch.cat((corr4, c14, up_flow5, up_feat5), 1) 254 | x = torch.cat((self.conv4_0(x), x),1) 255 | x = torch.cat((self.conv4_1(x), x),1) 256 | x = torch.cat((self.conv4_2(x), x),1) 257 | x = torch.cat((self.conv4_3(x), x),1) 258 | x = torch.cat((self.conv4_4(x), x),1) 259 | flow4 = self.predict_flow4(x) 260 | up_flow4 = self.deconv4(flow4) 261 | up_feat4 = self.upfeat4(x) 262 | 263 | 264 | warp3 = self.warp3(c23, up_flow4*2.5) 265 | corr3 = self.corr(c13, warp3) 266 | corr3 = self.leakyRELU(corr3) 267 | x = torch.cat((corr3, c13, up_flow4, up_feat4), 1) 268 | x = torch.cat((self.conv3_0(x), x),1) 269 | x = torch.cat((self.conv3_1(x), x),1) 270 | x = torch.cat((self.conv3_2(x), x),1) 271 | x = torch.cat((self.conv3_3(x), x),1) 272 | x = torch.cat((self.conv3_4(x), x),1) 273 | flow3 = self.predict_flow3(x) 274 | up_flow3 = self.deconv3(flow3) 275 | up_feat3 = self.upfeat3(x) 276 | 277 | 278 | warp2 = self.warp2(c22, up_flow3*5.0) 279 | corr2 = self.corr(c12, warp2) 280 | corr2 = self.leakyRELU(corr2) 281 | x = torch.cat((corr2, c12, up_flow3, up_feat3), 1) 282 | x = torch.cat((self.conv2_0(x), x),1) 283 | x = torch.cat((self.conv2_1(x), x),1) 284 | x = torch.cat((self.conv2_2(x), x),1) 285 | x = torch.cat((self.conv2_3(x), x),1) 286 | x = torch.cat((self.conv2_4(x), x),1) 287 | flow2 = self.predict_flow2(x) 288 | 289 | 290 | x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x)))) 291 | flow2 += self.dc_conv7(self.dc_conv6(self.dc_conv5(x))) 292 | 293 | flow2 = F.upsample(flow2, [im.size()[2],im.size()[3]], mode='bilinear') 294 | 295 | if self.training: 296 | flow3 = F.upsample(flow3, [im.size()[2],im.size()[3]], mode='bilinear') 297 | flow4 = F.upsample(flow4, [im.size()[2],im.size()[3]], mode='bilinear') 298 | flow5 = F.upsample(flow5, [im.size()[2],im.size()[3]], mode='bilinear') 299 | flow6 = F.upsample(flow6, [im.size()[2],im.size()[3]], mode='bilinear') 300 | return flow2*20,flow3*20,flow4*20,flow5*20,flow6*20,flow2, flow2[:,0] 301 | else: 302 | return flow2*20, flow2[0,0] 303 | 304 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/models/__init__.py -------------------------------------------------------------------------------- /models/conv4d.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import torch.nn as nn 3 | import math 4 | import torch 5 | from torch.nn.parameter import Parameter 6 | import torch.nn.functional as F 7 | from torch.nn import Module 8 | from torch.nn.modules.conv import _ConvNd 9 | from torch.nn.modules.utils import _quadruple 10 | from torch.autograd import Variable 11 | from torch.nn import Conv2d 12 | 13 | def conv4d(data,filters,bias=None,permute_filters=True,use_half=False): 14 | """ 15 | This is done by stacking results 
of multiple 3D convolutions, and is very slow. 16 | Taken from https://github.com/ignacio-rocco/ncnet 17 | """ 18 | b,c,h,w,d,t=data.size() 19 | 20 | data=data.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop 21 | 22 | # Same permutation is done with filters, unless already provided with permutation 23 | if permute_filters: 24 | filters=filters.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop 25 | 26 | c_out=filters.size(1) 27 | if use_half: 28 | output = Variable(torch.HalfTensor(h,b,c_out,w,d,t),requires_grad=data.requires_grad) 29 | else: 30 | output = Variable(torch.zeros(h,b,c_out,w,d,t),requires_grad=data.requires_grad) 31 | 32 | padding=filters.size(0)//2 33 | if use_half: 34 | Z=Variable(torch.zeros(padding,b,c,w,d,t).half()) 35 | else: 36 | Z=Variable(torch.zeros(padding,b,c,w,d,t)) 37 | 38 | if data.is_cuda: 39 | Z=Z.cuda(data.get_device()) 40 | output=output.cuda(data.get_device()) 41 | 42 | data_padded = torch.cat((Z,data,Z),0) 43 | 44 | 45 | for i in range(output.size(0)): # loop on first feature dimension 46 | # convolve with center channel of filter (at position=padding) 47 | output[i,:,:,:,:,:]=F.conv3d(data_padded[i+padding,:,:,:,:,:], 48 | filters[padding,:,:,:,:,:], bias=bias, stride=1, padding=padding) 49 | # convolve with upper/lower channels of filter (at postions [:padding] [padding+1:]) 50 | for p in range(1,padding+1): 51 | output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding-p,:,:,:,:,:], 52 | filters[padding-p,:,:,:,:,:], bias=None, stride=1, padding=padding) 53 | output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding+p,:,:,:,:,:], 54 | filters[padding+p,:,:,:,:,:], bias=None, stride=1, padding=padding) 55 | 56 | output=output.permute(1,2,0,3,4,5).contiguous() 57 | return output 58 | 59 | class Conv4d(_ConvNd): 60 | """Applies a 4D convolution over an input signal composed of several input 61 | planes. 
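The 4D kernel is applied as a sum of 3D convolutions: conv4d() above slices the
filters along their first spatial dimension and, for every output slice,
accumulates the 3D convolutions of the correspondingly shifted slices of the
zero-padded input, so the (u,v,h,w) extent of the input is preserved.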
62 | """ 63 | 64 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True): 65 | # stride, dilation and groups !=1 functionality not tested 66 | stride=1 67 | dilation=1 68 | groups=1 69 | # zero padding is added automatically in conv4d function to preserve tensor size 70 | padding = 0 71 | kernel_size = _quadruple(kernel_size) 72 | stride = _quadruple(stride) 73 | padding = _quadruple(padding) 74 | dilation = _quadruple(dilation) 75 | super(Conv4d, self).__init__( 76 | in_channels, out_channels, kernel_size, stride, padding, dilation, 77 | False, _quadruple(0), groups, bias) 78 | # weights will be sliced along one dimension during convolution loop 79 | # make the looping dimension to be the first one in the tensor, 80 | # so that we don't need to call contiguous() inside the loop 81 | self.pre_permuted_filters=pre_permuted_filters 82 | if self.pre_permuted_filters: 83 | self.weight.data=self.weight.data.permute(2,0,1,3,4,5).contiguous() 84 | self.use_half=False 85 | # self.isbias = bias 86 | # if not self.isbias: 87 | # self.bn = torch.nn.BatchNorm1d(out_channels) 88 | 89 | 90 | def forward(self, input): 91 | out = conv4d(input, self.weight, bias=self.bias,permute_filters=not self.pre_permuted_filters,use_half=self.use_half) # filters pre-permuted in constructor 92 | # if not self.isbias: 93 | # b,c,u,v,h,w = out.shape 94 | # out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w) 95 | return out 96 | 97 | class fullConv4d(torch.nn.Module): 98 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True): 99 | super(fullConv4d, self).__init__() 100 | self.conv = Conv4d(in_channels, out_channels, kernel_size, bias=bias, pre_permuted_filters=pre_permuted_filters) 101 | self.isbias = bias 102 | if not self.isbias: 103 | self.bn = torch.nn.BatchNorm1d(out_channels) 104 | 105 | def forward(self, input): 106 | out = self.conv(input) 107 | if not self.isbias: 108 | b,c,u,v,h,w = out.shape 109 | out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w) 110 | return out 111 | 112 | class butterfly4D(torch.nn.Module): 113 | ''' 114 | butterfly 4d 115 | ''' 116 | def __init__(self, fdima, fdimb, withbn=True, full=True,groups=1): 117 | super(butterfly4D, self).__init__() 118 | self.proj = nn.Sequential(projfeat4d(fdima, fdimb, 1, with_bn=withbn,groups=groups), 119 | nn.ReLU(inplace=True),) 120 | self.conva1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups) 121 | self.conva2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups) 122 | self.convb3 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups) 123 | self.convb2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups) 124 | self.convb1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups) 125 | 126 | #@profile 127 | def forward(self,x): 128 | out = self.proj(x) 129 | b,c,u,v,h,w = out.shape # 9x9 130 | 131 | out1 = self.conva1(out) # 5x5, 3 132 | _,c1,u1,v1,h1,w1 = out1.shape 133 | 134 | out2 = self.conva2(out1) # 3x3, 9 135 | _,c2,u2,v2,h2,w2 = out2.shape 136 | 137 | out2 = self.convb3(out2) # 3x3, 9 138 | 139 | tout1 = F.upsample(out2.view(b,c,u2,v2,-1),(u1,v1,h2*w2),mode='trilinear').view(b,c,u1,v1,h2,w2) # 5x5 140 | tout1 = F.upsample(tout1.view(b,c,-1,h2,w2),(u1*v1,h1,w1),mode='trilinear').view(b,c,u1,v1,h1,w1) # 5x5 141 | out1 = tout1 + out1 142 | out1 = self.convb2(out1) 143 | 144 | tout = 
F.upsample(out1.view(b,c,u1,v1,-1),(u,v,h1*w1),mode='trilinear').view(b,c,u,v,h1,w1) 145 | tout = F.upsample(tout.view(b,c,-1,h1,w1),(u*v,h,w),mode='trilinear').view(b,c,u,v,h,w) 146 | out = tout + out 147 | out = self.convb1(out) 148 | 149 | return out 150 | 151 | 152 | 153 | class projfeat4d(torch.nn.Module): 154 | ''' 155 | Turn 3d projection into 2d projection 156 | ''' 157 | def __init__(self, in_planes, out_planes, stride, with_bn=True,groups=1): 158 | super(projfeat4d, self).__init__() 159 | self.with_bn = with_bn 160 | self.stride = stride 161 | self.conv1 = nn.Conv3d(in_planes, out_planes, 1, (stride,stride,1), padding=0,bias=not with_bn,groups=groups) 162 | self.bn = nn.BatchNorm3d(out_planes) 163 | 164 | def forward(self,x): 165 | b,c,u,v,h,w = x.size() 166 | x = self.conv1(x.view(b,c,u,v,h*w)) 167 | if self.with_bn: 168 | x = self.bn(x) 169 | _,c,u,v,_ = x.shape 170 | x = x.view(b,c,u,v,h,w) 171 | return x 172 | 173 | class sepConv4d(torch.nn.Module): 174 | ''' 175 | Separable 4d convolution block as 2 3D convolutions 176 | ''' 177 | def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, ksize=3, full=True,groups=1): 178 | super(sepConv4d, self).__init__() 179 | bias = not with_bn 180 | self.isproj = False 181 | self.stride = stride[0] 182 | expand = 1 183 | 184 | if with_bn: 185 | if in_planes != out_planes: 186 | self.isproj = True 187 | self.proj = nn.Sequential(nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0,groups=groups), 188 | nn.BatchNorm2d(out_planes)) 189 | if full: 190 | self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups), 191 | nn.BatchNorm3d(in_planes)) 192 | else: 193 | self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups), 194 | nn.BatchNorm3d(in_planes)) 195 | self.conv2 = nn.Sequential(nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups), 196 | nn.BatchNorm3d(in_planes*expand)) 197 | else: 198 | if in_planes != out_planes: 199 | self.isproj = True 200 | self.proj = nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0,groups=groups) 201 | if full: 202 | self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups) 203 | else: 204 | self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups) 205 | self.conv2 = nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups) 206 | self.relu = nn.ReLU(inplace=True) 207 | 208 | #@profile 209 | def forward(self,x): 210 | b,c,u,v,h,w = x.shape 211 | x = self.conv2(x.view(b,c,u,v,-1)) # WTA convolution over (u,v) 212 | b,c,u,v,_ = x.shape 213 | x = self.relu(x) 214 | x = self.conv1(x.view(b,c,-1,h,w)) # spatial convolution over (x,y) 215 | b,c,_,h,w = x.shape 216 | 217 | if self.isproj: 218 | x = self.proj(x.view(b,c,-1,w)) 219 | x = x.view(b,-1,u,v,h,w) 220 | return x 221 | 222 | 223 | class sepConv4dBlock(torch.nn.Module): 224 | ''' 225 | Separable 4d convolution block as 2 2D convolutions and a projection 226 | layer 227 | ''' 228 | def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, full=True,groups=1): 229 
| super(sepConv4dBlock, self).__init__() 230 | if in_planes == out_planes and stride==(1,1,1): 231 | self.downsample = None 232 | else: 233 | if full: 234 | self.downsample = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn,ksize=1, full=full,groups=groups) 235 | else: 236 | self.downsample = projfeat4d(in_planes, out_planes,stride[0], with_bn=with_bn,groups=groups) 237 | self.conv1 = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn, full=full ,groups=groups) 238 | self.conv2 = sepConv4d(out_planes, out_planes,(1,1,1), with_bn=with_bn, full=full,groups=groups) 239 | self.relu1 = nn.ReLU(inplace=True) 240 | self.relu2 = nn.ReLU(inplace=True) 241 | 242 | #@profile 243 | def forward(self,x): 244 | out = self.relu1(self.conv1(x)) 245 | if self.downsample: 246 | x = self.downsample(x) 247 | out = self.relu2(x + self.conv2(out)) 248 | return out 249 | 250 | 251 | ##import torch.backends.cudnn as cudnn 252 | ##cudnn.benchmark = True 253 | #import time 254 | ##im = torch.randn(9,64,9,160,224).cuda() 255 | ##net = torch.nn.Conv3d(64, 64, 3).cuda() 256 | ##net = Conv4d(1,1,3,bias=True,pre_permuted_filters=True).cuda() 257 | ##net = sepConv4dBlock(2,2,stride=(1,1,1)).cuda() 258 | # 259 | ##im = torch.randn(1,16,9,9,96,320).cuda() 260 | ##net = sepConv4d(16,16,with_bn=False).cuda() 261 | # 262 | ##im = torch.randn(1,16,81,96,320).cuda() 263 | ##net = torch.nn.Conv3d(16,16,(1,3,3),padding=(0,1,1)).cuda() 264 | # 265 | ##im = torch.randn(1,16,9,9,96*320).cuda() 266 | ##net = torch.nn.Conv3d(16,16,(3,3,1),padding=(1,1,0)).cuda() 267 | # 268 | ##im = torch.randn(10000,10,9,9).cuda() 269 | ##net = torch.nn.Conv2d(10,10,3,padding=1).cuda() 270 | # 271 | ##im = torch.randn(81,16,96,320).cuda() 272 | ##net = torch.nn.Conv2d(16,16,3,padding=1).cuda() 273 | #c= int(16 *1) 274 | #cp = int(16 *1) 275 | #h=int(96 *4) 276 | #w=int(320 *4) 277 | #k=3 278 | #im = torch.randn(1,c,h,w).cuda() 279 | #net = torch.nn.Conv2d(c,cp,k,padding=k//2).cuda() 280 | # 281 | #im2 = torch.randn(cp,k*k*c).cuda() 282 | #im1 = F.unfold(im, (k,k), padding=k//2)[0] 283 | # 284 | # 285 | #net(im) 286 | #net(im) 287 | #torch.mm(im2,im1) 288 | #torch.mm(im2,im1) 289 | #torch.cuda.synchronize() 290 | #beg = time.time() 291 | #for i in range(100): 292 | # net(im) 293 | # #im1 = F.unfold(im, (k,k), padding=k//2)[0] 294 | # torch.mm(im2,im1) 295 | #torch.cuda.synchronize() 296 | #print('%f'%((time.time()-beg)*10.)) 297 | -------------------------------------------------------------------------------- /models/submodule.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | import math 8 | import numpy as np 9 | import pdb 10 | 11 | class residualBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, in_channels, n_filters, stride=1, downsample=None,dilation=1,with_bn=True): 15 | super(residualBlock, self).__init__() 16 | if dilation > 1: 17 | padding = dilation 18 | else: 19 | padding = 1 20 | 21 | if with_bn: 22 | self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, padding, dilation=dilation) 23 | self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1) 24 | else: 25 | self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, padding, dilation=dilation,with_bn=False) 26 | self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, with_bn=False) 27 | 
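# ---------------------------------------------------------------------------
# Usage sketch for the separable 4D convolutions in models/conv4d.py above
# (an illustrative addition; the tensor sizes are made up). sepConv4d
# factorizes a 4D convolution over a cost volume of shape (b,c,u,v,h,w) into
# a (k,k,1) convolution over the displacement dims (u,v) and a (1,k,k)
# convolution over the spatial dims (h,w); sepConv4dBlock stacks two of them
# with a residual connection.
#
#   import torch
#   from models.conv4d import sepConv4d, sepConv4dBlock
#   cost = torch.randn(1, 16, 9, 9, 32, 48)       # (b, c, u, v, h, w)
#   out = sepConv4d(16, 16, with_bn=False)(cost)  # shape is preserved
#   out = sepConv4dBlock(16, 16, stride=(1,1,1), with_bn=False)(cost)
# ---------------------------------------------------------------------------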
self.downsample = downsample 28 | self.relu = nn.LeakyReLU(0.1, inplace=True) 29 | 30 | def forward(self, x): 31 | residual = x 32 | 33 | out = self.convbnrelu1(x) 34 | out = self.convbn2(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | return self.relu(out) 41 | 42 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 43 | return nn.Sequential( 44 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 45 | padding=padding, dilation=dilation, bias=True), 46 | nn.BatchNorm2d(out_planes), 47 | nn.LeakyReLU(0.1,inplace=True)) 48 | 49 | 50 | class conv2DBatchNorm(nn.Module): 51 | def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True): 52 | super(conv2DBatchNorm, self).__init__() 53 | bias = not with_bn 54 | 55 | if dilation > 1: 56 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, 57 | padding=padding, stride=stride, bias=bias, dilation=dilation) 58 | 59 | else: 60 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, 61 | padding=padding, stride=stride, bias=bias, dilation=1) 62 | 63 | 64 | if with_bn: 65 | self.cb_unit = nn.Sequential(conv_mod, 66 | nn.BatchNorm2d(int(n_filters)),) 67 | else: 68 | self.cb_unit = nn.Sequential(conv_mod,) 69 | 70 | def forward(self, inputs): 71 | outputs = self.cb_unit(inputs) 72 | return outputs 73 | 74 | class conv2DBatchNormRelu(nn.Module): 75 | def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True): 76 | super(conv2DBatchNormRelu, self).__init__() 77 | bias = not with_bn 78 | if dilation > 1: 79 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, 80 | padding=padding, stride=stride, bias=bias, dilation=dilation) 81 | 82 | else: 83 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, 84 | padding=padding, stride=stride, bias=bias, dilation=1) 85 | 86 | if with_bn: 87 | self.cbr_unit = nn.Sequential(conv_mod, 88 | nn.BatchNorm2d(int(n_filters)), 89 | nn.LeakyReLU(0.1, inplace=True),) 90 | else: 91 | self.cbr_unit = nn.Sequential(conv_mod, 92 | nn.LeakyReLU(0.1, inplace=True),) 93 | 94 | def forward(self, inputs): 95 | outputs = self.cbr_unit(inputs) 96 | return outputs 97 | 98 | class pyramidPooling(nn.Module): 99 | 100 | def __init__(self, in_channels, with_bn=True, levels=4): 101 | super(pyramidPooling, self).__init__() 102 | self.levels = levels 103 | 104 | self.paths = [] 105 | for i in range(levels): 106 | self.paths.append(conv2DBatchNormRelu(in_channels, in_channels, 1, 1, 0, with_bn=with_bn)) 107 | self.path_module_list = nn.ModuleList(self.paths) 108 | self.relu = nn.LeakyReLU(0.1, inplace=True) 109 | 110 | def forward(self, x): 111 | h, w = x.shape[2:] 112 | 113 | k_sizes = [] 114 | strides = [] 115 | for pool_size in np.linspace(1,min(h,w)//2,self.levels,dtype=int): 116 | k_sizes.append((int(h/pool_size), int(w/pool_size))) 117 | strides.append((int(h/pool_size), int(w/pool_size))) 118 | k_sizes = k_sizes[::-1] 119 | strides = strides[::-1] 120 | 121 | pp_sum = x 122 | 123 | for i, module in enumerate(self.path_module_list): 124 | out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0) 125 | out = module(out) 126 | out = F.upsample(out, size=(h,w), mode='bilinear') 127 | pp_sum = pp_sum + 1./self.levels*out 128 | pp_sum = self.relu(pp_sum/2.) 129 | 130 | return pp_sum 131 | 132 | class pspnet(nn.Module): 133 | """ 134 | Modified PSPNet. 
https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py 135 | """ 136 | def __init__(self, is_proj=True,groups=1): 137 | super(pspnet, self).__init__() 138 | self.inplanes = 32 139 | self.is_proj = is_proj 140 | 141 | # Encoder 142 | self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16, 143 | padding=1, stride=2) 144 | self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16, 145 | padding=1, stride=1) 146 | self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32, 147 | padding=1, stride=1) 148 | # Vanilla Residual Blocks 149 | self.res_block3 = self._make_layer(residualBlock,64,1,stride=2) 150 | self.res_block5 = self._make_layer(residualBlock,128,1,stride=2) 151 | self.res_block6 = self._make_layer(residualBlock,128,1,stride=2) 152 | self.res_block7 = self._make_layer(residualBlock,128,1,stride=2) 153 | self.pyramid_pooling = pyramidPooling(128, levels=3) 154 | 155 | # Iconvs 156 | self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2), 157 | conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, 158 | padding=1, stride=1)) 159 | self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128, 160 | padding=1, stride=1) 161 | self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2), 162 | conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, 163 | padding=1, stride=1)) 164 | self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128, 165 | padding=1, stride=1) 166 | self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2), 167 | conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, 168 | padding=1, stride=1)) 169 | self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, 170 | padding=1, stride=1) 171 | self.upconv3 = nn.Sequential(nn.Upsample(scale_factor=2), 172 | conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32, 173 | padding=1, stride=1)) 174 | self.iconv2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64, 175 | padding=1, stride=1) 176 | 177 | if self.is_proj: 178 | self.proj6 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1) 179 | self.proj5 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1) 180 | self.proj4 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1) 181 | self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1) 182 | self.proj2 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1) 183 | 184 | for m in self.modules(): 185 | if isinstance(m, nn.Conv2d): 186 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 187 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 188 | if hasattr(m.bias,'data'): 189 | m.bias.data.zero_() 190 | 191 | 192 | def _make_layer(self, block, planes, blocks, stride=1): 193 | downsample = None 194 | if stride != 1 or self.inplanes != planes * block.expansion: 195 | downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion, 196 | kernel_size=1, stride=stride, bias=False), 197 | nn.BatchNorm2d(planes * block.expansion),) 198 | layers = [] 199 | layers.append(block(self.inplanes, planes, stride, downsample)) 200 | self.inplanes = planes * block.expansion 201 | for i in range(1, blocks): 202 | layers.append(block(self.inplanes, planes)) 203 | return nn.Sequential(*layers) 204 | 205 | def forward(self, x): 206 | # H, W -> H/2, W/2 207 | conv1 = self.convbnrelu1_1(x) 208 | conv1 = self.convbnrelu1_2(conv1) 209 | conv1 = self.convbnrelu1_3(conv1) 210 | 211 | ## H/2, W/2 -> H/4, W/4 212 | pool1 = F.max_pool2d(conv1, 3, 2, 1) 213 | 214 | # H/4, W/4 -> H/16, W/16 215 | rconv3 = self.res_block3(pool1) 216 | conv4 = self.res_block5(rconv3) 217 | conv5 = self.res_block6(conv4) 218 | conv6 = self.res_block7(conv5) 219 | conv6 = self.pyramid_pooling(conv6) 220 | 221 | conv6x = F.upsample(conv6, [conv5.size()[2],conv5.size()[3]],mode='bilinear') 222 | concat5 = torch.cat((conv5,self.upconv6[1](conv6x)),dim=1) 223 | conv5 = self.iconv5(concat5) 224 | 225 | conv5x = F.upsample(conv5, [conv4.size()[2],conv4.size()[3]],mode='bilinear') 226 | concat4 = torch.cat((conv4,self.upconv5[1](conv5x)),dim=1) 227 | conv4 = self.iconv4(concat4) 228 | 229 | conv4x = F.upsample(conv4, [rconv3.size()[2],rconv3.size()[3]],mode='bilinear') 230 | concat3 = torch.cat((rconv3,self.upconv4[1](conv4x)),dim=1) 231 | conv3 = self.iconv3(concat3) 232 | 233 | conv3x = F.upsample(conv3, [pool1.size()[2],pool1.size()[3]],mode='bilinear') 234 | concat2 = torch.cat((pool1,self.upconv3[1](conv3x)),dim=1) 235 | conv2 = self.iconv2(concat2) 236 | 237 | if self.is_proj: 238 | proj6 = self.proj6(conv6) 239 | proj5 = self.proj5(conv5) 240 | proj4 = self.proj4(conv4) 241 | proj3 = self.proj3(conv3) 242 | proj2 = self.proj2(conv2) 243 | return proj6,proj5,proj4,proj3,proj2 244 | else: 245 | return conv6, conv5, conv4, conv3, conv2 246 | 247 | 248 | class pspnet_s(nn.Module): 249 | """ 250 | Modified PSPNet. 
https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py 251 | """ 252 | def __init__(self, is_proj=True,groups=1): 253 | super(pspnet_s, self).__init__() 254 | self.inplanes = 32 255 | self.is_proj = is_proj 256 | 257 | # Encoder 258 | self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16, 259 | padding=1, stride=2) 260 | self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16, 261 | padding=1, stride=1) 262 | self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32, 263 | padding=1, stride=1) 264 | # Vanilla Residual Blocks 265 | self.res_block3 = self._make_layer(residualBlock,64,1,stride=2) 266 | self.res_block5 = self._make_layer(residualBlock,128,1,stride=2) 267 | self.res_block6 = self._make_layer(residualBlock,128,1,stride=2) 268 | self.res_block7 = self._make_layer(residualBlock,128,1,stride=2) 269 | self.pyramid_pooling = pyramidPooling(128, levels=3) 270 | 271 | # Iconvs 272 | self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2), 273 | conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, 274 | padding=1, stride=1)) 275 | self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128, 276 | padding=1, stride=1) 277 | self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2), 278 | conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, 279 | padding=1, stride=1)) 280 | self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128, 281 | padding=1, stride=1) 282 | self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2), 283 | conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, 284 | padding=1, stride=1)) 285 | self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64, 286 | padding=1, stride=1) 287 | #self.upconv3 = nn.Sequential(nn.Upsample(scale_factor=2), 288 | # conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32, 289 | # padding=1, stride=1)) 290 | #self.iconv2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64, 291 | # padding=1, stride=1) 292 | 293 | if self.is_proj: 294 | self.proj6 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1) 295 | self.proj5 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1) 296 | self.proj4 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1) 297 | self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1) 298 | #self.proj2 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1) 299 | 300 | for m in self.modules(): 301 | if isinstance(m, nn.Conv2d): 302 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 303 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 304 | if hasattr(m.bias,'data'): 305 | m.bias.data.zero_() 306 | 307 | 308 | def _make_layer(self, block, planes, blocks, stride=1): 309 | downsample = None 310 | if stride != 1 or self.inplanes != planes * block.expansion: 311 | downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion, 312 | kernel_size=1, stride=stride, bias=False), 313 | nn.BatchNorm2d(planes * block.expansion),) 314 | layers = [] 315 | layers.append(block(self.inplanes, planes, stride, downsample)) 316 | self.inplanes = planes * block.expansion 317 | for i in range(1, blocks): 318 | layers.append(block(self.inplanes, planes)) 319 | return nn.Sequential(*layers) 320 | 321 | def forward(self, x): 322 | # H, W -> H/2, W/2 323 | conv1 = self.convbnrelu1_1(x) 324 | conv1 = self.convbnrelu1_2(conv1) 325 | conv1 = self.convbnrelu1_3(conv1) 326 | 327 | ## H/2, W/2 -> H/4, W/4 328 | pool1 = F.max_pool2d(conv1, 3, 2, 1) 329 | 330 | # H/4, W/4 -> H/16, W/16 331 | rconv3 = self.res_block3(pool1) 332 | conv4 = self.res_block5(rconv3) 333 | conv5 = self.res_block6(conv4) 334 | conv6 = self.res_block7(conv5) 335 | conv6 = self.pyramid_pooling(conv6) 336 | 337 | conv6x = F.upsample(conv6, [conv5.size()[2],conv5.size()[3]],mode='bilinear') 338 | concat5 = torch.cat((conv5,self.upconv6[1](conv6x)),dim=1) 339 | conv5 = self.iconv5(concat5) 340 | 341 | conv5x = F.upsample(conv5, [conv4.size()[2],conv4.size()[3]],mode='bilinear') 342 | concat4 = torch.cat((conv4,self.upconv5[1](conv5x)),dim=1) 343 | conv4 = self.iconv4(concat4) 344 | 345 | conv4x = F.upsample(conv4, [rconv3.size()[2],rconv3.size()[3]],mode='bilinear') 346 | concat3 = torch.cat((rconv3,self.upconv4[1](conv4x)),dim=1) 347 | conv3 = self.iconv3(concat3) 348 | 349 | #conv3x = F.upsample(conv3, [pool1.size()[2],pool1.size()[3]],mode='bilinear') 350 | #concat2 = torch.cat((pool1,self.upconv3[1](conv3x)),dim=1) 351 | #conv2 = self.iconv2(concat2) 352 | 353 | if self.is_proj: 354 | proj6 = self.proj6(conv6) 355 | proj5 = self.proj5(conv5) 356 | proj4 = self.proj4(conv4) 357 | proj3 = self.proj3(conv3) 358 | # proj2 = self.proj2(conv2) 359 | # return proj6,proj5,proj4,proj3,proj2 360 | return proj6,proj5,proj4,proj3 361 | else: 362 | # return conv6, conv5, conv4, conv3, conv2 363 | return conv6, conv5, conv4, conv3 364 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | modelname=$1 2 | filename=$modelname-$2 3 | array=( 4 | 1999 3999 5999 7999 9999 11999 13999 15999 17999 19999 21999 23999 25999 27999 29999 31999 33999 35999 37999 39999 41999 43999 45999 47999 49999 51999 53999 55999 57999 59999 61999 63999 65999 67999 69999 71999 73999 75999 75999 77999 79999 81999 83999 85999 87999 89999 91999 93999 95999 97999 99999 101999 103999 105999 107999 109999 111999 113999 115999 117999 119999 121999 123999 125999 127999 129999 131999 133999 135999 137999 139999 5 | # 141999 143999 145999 147999 6 | #149999 7 | # 151999 153999 155999 157999 159999 8 | # 161999 163999 165999 167999 169999 171999 173999 175999 177999 179999 181999 183999 185999 187999 189999 191999 193999 195999 197999 9 | #199999 201999 203999 205999 207999 209999 211999 213999 215999 217999 219999 10 | #221999 223999 225999 227999 229999 231999 233999 11 | ) 12 | 13 | #echo $modelname > results/$filename 14 | echo $modelname > results/s-$filename 15 | mkdir /data/ptmodel/$modelname/kitti15 16 | mkdir /data/ptmodel/$modelname/kitti15/flow 17 | for i in "${array[@]}" 18 | 
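# Usage sketch (example only; assumes checkpoints saved as
# /data/ptmodel/<modelname>/finetune_<iter>.tar): each pass of this loop
# evaluates one checkpoint index from `array` on the Sintel validation split
# and appends the eval_tmp.py score to results/s-<modelname>-<tag>, e.g.
#   bash run.sh sintel-ft val-sweep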
do 19 | # echo $i >> results/$filename 20 | # CUDA_VISIBLE_DEVICES=0 python submission.py --dataset 2015 --datapath /ssd/kitti_scene/training/ --outdir /data/ptmodel/$modelname/ --loadmodel /data/ptmodel/$modelname/finetune_$i.tar --testres 1 21 | # python eval_tmp.py --path /data/ptmodel/$modelname/ --vis no --dataset 2015 >> results/$filename 22 | 23 | echo $i >> results/s-$filename 24 | CUDA_VISIBLE_DEVICES=0 python submission.py --dataset sintel --datapath /ssd/rob_flow/training/ --outdir /data/ptmodel/$modelname/ --loadmodel /data/ptmodel/$modelname/finetune_$i.tar --testres 1 --maxdisp 448 --fac 1.4 25 | python eval_tmp.py --path /data/ptmodel/$modelname/ --vis no --dataset sintel >> results/s-$filename 26 | 27 | done 28 | -------------------------------------------------------------------------------- /run_self.sh: -------------------------------------------------------------------------------- 1 | ## point $datapath to the folder of your images 2 | datapath=./dataset/IIW/ 3 | modelname=things 4 | i=239999 5 | CUDA_VISIBLE_DEVICES=1 python submission.py --dataset kitticlip --datapath $datapath/ --outdir ./weights/$modelname/ --loadmodel ./weights/$modelname/finetune_$i.tar --maxdisp 256 --fac 1 6 | -------------------------------------------------------------------------------- /submission.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | sys.path.insert(0,'utils/') 4 | #sys.path.insert(0,'dataloader/') 5 | sys.path.insert(0,'models/') 6 | import cv2 7 | import pdb 8 | import argparse 9 | import numpy as np 10 | import skimage.io 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.backends.cudnn as cudnn 15 | import torch.utils.data 16 | from torch.autograd import Variable 17 | import time 18 | from utils.io import mkdir_p 19 | from utils.util_flow import write_flow, save_pfm 20 | cudnn.benchmark = False 21 | 22 | parser = argparse.ArgumentParser(description='VCN') 23 | parser.add_argument('--dataset', default='2015', 24 | help='{2015: KITTI-15, sintel}') 25 | parser.add_argument('--datapath', default='/ssd/kitti_scene/training/', 26 | help='data path') 27 | parser.add_argument('--loadmodel', default=None, 28 | help='model path') 29 | parser.add_argument('--outdir', default='output', 30 | help='output path') 31 | parser.add_argument('--model', default='VCN', 32 | help='VCN or VCN_small') 33 | parser.add_argument('--testres', type=float, default=1, 34 | help='resolution, {1: original resolution, 2: 2X resolution}') 35 | parser.add_argument('--maxdisp', type=int ,default=256, 36 | help='maximum disparity. Only affects the coarsest cost volume size') 37 | parser.add_argument('--fac', type=float ,default=1, 38 | help='controls the shape of search grid.
Only affects the coarsest cost volume size') 39 | args = parser.parse_args() 40 | 41 | 42 | 43 | # dataloader 44 | if args.dataset == '2015': 45 | #from dataloader import kitti15list as DA 46 | from dataloader import kitti15list_val as DA 47 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)] 48 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath) 49 | elif args.dataset == '2015test': 50 | from dataloader import kitti15list as DA 51 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)] 52 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath) 53 | elif args.dataset == 'tumclip': 54 | from dataloader import kitticliplist as DA 55 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)] 56 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath) 57 | elif args.dataset == 'kitticlip': 58 | from dataloader import kitticliplist as DA 59 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)] 60 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath) 61 | elif args.dataset == '2012': 62 | from dataloader import kitti12list as DA 63 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)] 64 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath) 65 | elif args.dataset == '2012test': 66 | from dataloader import kitti12list as DA 67 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)] 68 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath) 69 | elif args.dataset == 'mb': 70 | from dataloader import mblist as DA 71 | maxw,maxh = [int(args.testres*640), int(args.testres*512)] 72 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath) 73 | elif args.dataset == 'chairs': 74 | from dataloader import chairslist as DA 75 | maxw,maxh = [int(args.testres*512), int(args.testres*384)] 76 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath) 77 | elif args.dataset == 'sinteltest': 78 | from dataloader import sintellist as DA 79 | maxw,maxh = [int(args.testres*1024), int(args.testres*448)] 80 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath) 81 | elif args.dataset == 'sintel': 82 | #from dataloader import sintellist_clean as DA 83 | from dataloader import sintellist_val as DA 84 | #from dataloader import sintellist_val_2s as DA 85 | maxw,maxh = [int(args.testres*1024), int(args.testres*448)] 86 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath) 87 | elif args.dataset == 'hd1k': 88 | from dataloader import hd1klist as DA 89 | maxw,maxh = [int(args.testres*2560), int(args.testres*1088)] 90 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath) 91 | elif args.dataset == 'mbstereo': 92 | from dataloader import MiddleburySubmit as DA 93 | maxw,maxh = [int(args.testres*900), int(args.testres*750)] 94 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath) 95 | elif args.dataset == 'k15stereo': 96 | from dataloader import stereo_kittilist15 as DA 97 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)] 98 | test_left_img, test_right_img ,_,_,_,_= DA.dataloader(args.datapath, typ='trainval') 99 | elif args.dataset == 'k12stereo': 100 | from dataloader import stereo_kittilist12 as DA 101 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)] 102 | test_left_img, test_right_img ,_,_,_,_= DA.dataloader(args.datapath) 103 | if args.dataset == 'chairs': 104 | with open('FlyingChairs_train_val.txt', 'r') as f: 105 | split = [int(i) for i in f.readlines()] 106 | test_left_img = [test_left_img[i] for i,flag in enumerate(split) if flag==2] 107 |
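# Sketch of the split-file convention used here (inferred from the filtering
# code around it): FlyingChairs_train_val.txt carries one integer flag per
# image pair, and pairs flagged 2 form the held-out split kept below; flag 1
# presumably marks training pairs.
#
#   with open('FlyingChairs_train_val.txt', 'r') as f:
#       flags = [int(line) for line in f.readlines()]
#   val_indices = [i for i, flag in enumerate(flags) if flag == 2]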
test_right_img = [test_right_img[i] for i,flag in enumerate(split) if flag==2] 108 | 109 | if args.model == 'VCN': 110 | from models.VCN import VCN 111 | elif args.model == 'VCN_small': 112 | from models.VCN_small import VCN 113 | #if '2015' in args.dataset: 114 | # model = VCN([1, maxw, maxh], md=[8,4,4,4,4], fac=2) 115 | #elif 'sintel' in args.dataset: 116 | # model = VCN([1, maxw, maxh], md=[7,4,4,4,4], fac=1.4) 117 | #else: 118 | # model = VCN([1, maxw, maxh], md=[4,4,4,4,4], fac=1) 119 | model = VCN([1, maxw, maxh], md=[int(4*(args.maxdisp/256)),4,4,4,4], fac=args.fac) 120 | 121 | model = nn.DataParallel(model, device_ids=[0]) 122 | model.cuda() 123 | if args.loadmodel is not None: 124 | 125 | pretrained_dict = torch.load(args.loadmodel) 126 | mean_L=pretrained_dict['mean_L'] 127 | mean_R=pretrained_dict['mean_R'] 128 | pretrained_dict['state_dict'] = {k:v for k,v in pretrained_dict['state_dict'].items() if 'grid' not in k and (('flow_reg' not in k) or ('conv1' in k))} 129 | 130 | model.load_state_dict(pretrained_dict['state_dict'],strict=False) 131 | else: 132 | mean_L = [[0.33,0.33,0.33]] 133 | mean_R = [[0.33,0.33,0.33]] 134 | print('dry run') 135 | 136 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 137 | 138 | 139 | mkdir_p('%s/%s/'% (args.outdir, args.dataset)) 140 | def main(): 141 | model.eval() 142 | ttime_all = [] 143 | for inx in range(len(test_left_img)): 144 | print(test_left_img[inx]) 145 | imgL_o = skimage.io.imread(test_left_img[inx]) 146 | imgR_o = skimage.io.imread(test_right_img[inx]) 147 | 148 | # for gray input images 149 | if len(imgL_o.shape) == 2: 150 | imgL_o = np.tile(imgL_o[:,:,np.newaxis],(1,1,3)) 151 | imgR_o = np.tile(imgR_o[:,:,np.newaxis],(1,1,3)) 152 | 153 | # resize 154 | maxh = imgL_o.shape[0]*args.testres 155 | maxw = imgL_o.shape[1]*args.testres 156 | max_h = int(maxh // 64 * 64) 157 | max_w = int(maxw // 64 * 64) 158 | if max_h < maxh: max_h += 64 159 | if max_w < maxw: max_w += 64 160 | 161 | input_size = imgL_o.shape 162 | imgL = cv2.resize(imgL_o,(max_w, max_h)) 163 | imgR = cv2.resize(imgR_o,(max_w, max_h)) 164 | 165 | # flip channel, subtract mean 166 | imgL = imgL[:,:,::-1].copy() / 255. - np.asarray(mean_L).mean(0)[np.newaxis,np.newaxis,:] 167 | imgR = imgR[:,:,::-1].copy() / 255. 
- np.asarray(mean_R).mean(0)[np.newaxis,np.newaxis,:] 168 | imgL = np.transpose(imgL, [2,0,1])[np.newaxis] 169 | imgR = np.transpose(imgR, [2,0,1])[np.newaxis] 170 | 171 | # support for any resolution inputs 172 | from models.VCN import WarpModule, flow_reg 173 | if hasattr(model.module, 'flow_reg64'): 174 | model.module.flow_reg64 = flow_reg([1,max_w//64,max_h//64], ent=model.module.flow_reg64.ent, maxdisp=model.module.flow_reg64.md, fac=model.module.flow_reg64.fac).cuda() 175 | if hasattr(model.module, 'flow_reg32'): 176 | model.module.flow_reg32 = flow_reg([1,max_w//64*2,max_h//64*2], ent=model.module.flow_reg32.ent, maxdisp=model.module.flow_reg32.md, fac=model.module.flow_reg32.fac).cuda() 177 | if hasattr(model.module, 'flow_reg16'): 178 | model.module.flow_reg16 = flow_reg([1,max_w//64*4,max_h//64*4], ent=model.module.flow_reg16.ent, maxdisp=model.module.flow_reg16.md, fac=model.module.flow_reg16.fac).cuda() 179 | if hasattr(model.module, 'flow_reg8'): 180 | model.module.flow_reg8 = flow_reg([1,max_w//64*8, max_h//64*8], ent=model.module.flow_reg8.ent, maxdisp=model.module.flow_reg8.md , fac = model.module.flow_reg8.fac).cuda() 181 | if hasattr(model.module, 'flow_reg4'): 182 | model.module.flow_reg4 = flow_reg([1,max_w//64*16, max_h//64*16 ], ent=model.module.flow_reg4.ent, maxdisp=model.module.flow_reg4.md , fac = model.module.flow_reg4.fac).cuda() 183 | model.module.warp5 = WarpModule([1,max_w//32,max_h//32]).cuda() 184 | model.module.warp4 = WarpModule([1,max_w//16,max_h//16]).cuda() 185 | model.module.warp3 = WarpModule([1,max_w//8, max_h//8]).cuda() 186 | model.module.warp2 = WarpModule([1,max_w//4, max_h//4]).cuda() 187 | model.module.warpx = WarpModule([1,max_w, max_h]).cuda() 188 | 189 | # forward 190 | imgL = Variable(torch.FloatTensor(imgL).cuda()) 191 | imgR = Variable(torch.FloatTensor(imgR).cuda()) 192 | with torch.no_grad(): 193 | imgLR = torch.cat([imgL,imgR],0) 194 | model.eval() 195 | torch.cuda.synchronize() 196 | start_time = time.time() 197 | rts = model(imgLR) 198 | torch.cuda.synchronize() 199 | ttime = (time.time() - start_time); print('time = %.2f' % (ttime*1000) ) 200 | ttime_all.append(ttime) 201 | pred_disp, entropy = rts 202 | 203 | # upsampling 204 | pred_disp = torch.squeeze(pred_disp).data.cpu().numpy() 205 | pred_disp = cv2.resize(np.transpose(pred_disp,(1,2,0)), (input_size[1], input_size[0])) 206 | pred_disp[:,:,0] *= input_size[1] / max_w 207 | pred_disp[:,:,1] *= input_size[0] / max_h 208 | flow = np.ones([pred_disp.shape[0],pred_disp.shape[1],3]) 209 | flow[:,:,:2] = pred_disp 210 | entropy = torch.squeeze(entropy).data.cpu().numpy() 211 | entropy = cv2.resize(entropy, (input_size[1], input_size[0])) 212 | 213 | # save predictions 214 | if args.dataset == 'mbstereo': 215 | dirname = '%s/%s/%s'%(args.outdir, args.dataset, test_left_img[inx].split('/')[-2]) 216 | mkdir_p(dirname) 217 | idxname = ('%s/%s')%(dirname.rsplit('/',1)[-1],test_left_img[inx].split('/')[-1]) 218 | else: 219 | idxname = test_left_img[inx].split('/')[-1] 220 | 221 | if args.dataset == 'mbstereo': 222 | with open(test_left_img[inx].replace('im0.png','calib.txt')) as f: 223 | lines = f.readlines() 224 | #max_disp = int(int(lines[9].split('=')[-1])) 225 | max_disp = int(int(lines[6].split('=')[-1])) 226 | with open('%s/%s/%s'% (args.outdir, args.dataset,idxname.replace('im0.png','disp0IO.pfm')),'w') as f: 227 | save_pfm(f,np.clip(-flow[::-1,:,0].astype(np.float32),0,max_disp) ) 228 | with open('%s/%s/%s/timeIO.txt'%(args.outdir, args.dataset,idxname.split('/')[0]),'w') as f: 
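# Sketch of the encodings written out in this block (a note added for
# clarity): KITTI stereo expects disparity as a 16-bit PNG scaled by 256,
# and KITTI-style flow PNGs store flow*64 + 2**15 per channel
# (read_png_file in utils/flowlib.py implements the inverse).
#
#   import numpy as np
#   disp = -flow[:, :, 0]                              # disparity from x-flow
#   png16 = (disp.astype(np.float32) * 256).astype('uint16')
#   disp_back = png16.astype(np.float32) / 256.        # decoding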
229 | f.write(str(ttime)) 230 | elif args.dataset == 'k15stereo' or args.dataset == 'k12stereo': 231 | skimage.io.imsave('%s/%s/%s.png'% (args.outdir, args.dataset,idxname.split('.')[0]),(-flow[:,:,0].astype(np.float32)*256).astype('uint16')) 232 | else: 233 | write_flow('%s/%s/%s.png'% (args.outdir, args.dataset,idxname.rsplit('.',1)[0]), flow.copy()) 234 | cv2.imwrite('%s/%s/ent-%s.png'% (args.outdir, args.dataset,idxname.rsplit('.',1)[0]), entropy*200) 235 | 236 | torch.cuda.empty_cache() 237 | print(np.mean(ttime_all)) 238 | 239 | 240 | 241 | if __name__ == '__main__': 242 | main() 243 | 244 | -------------------------------------------------------------------------------- /thop/__init__.py: -------------------------------------------------------------------------------- 1 | from .profile import profile -------------------------------------------------------------------------------- /thop/count_hooks.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | multiply_adds = 1 7 | 8 | 9 | def count_convNd(m, x, y): 10 | x = x[0] 11 | 12 | kernel_ops = m.weight.size()[2:].numel() 13 | bias_ops = 1 if m.bias is not None else 0 14 | 15 | # cout x oW x oH 16 | total_ops = y.nelement() * (m.in_channels//m.groups * kernel_ops + bias_ops) 17 | m.total_ops = torch.Tensor([int(total_ops)]) 18 | 19 | 20 | def count_conv2d(m, x, y): 21 | x = x[0] 22 | 23 | cin = m.in_channels 24 | cout = m.out_channels 25 | kh, kw = m.kernel_size 26 | batch_size = x.size()[0] 27 | 28 | out_h = y.size(2) 29 | out_w = y.size(3) 30 | 31 | # ops per output element 32 | # kernel_mul = kh * kw * cin 33 | # kernel_add = kh * kw * cin - 1 34 | kernel_ops = multiply_adds * kh * kw 35 | bias_ops = 1 if m.bias is not None else 0 36 | ops_per_element = kernel_ops + bias_ops 37 | 38 | # total ops 39 | # num_out_elements = y.numel() 40 | output_elements = batch_size * out_w * out_h * cout 41 | total_ops = output_elements * ops_per_element * cin // m.groups 42 | 43 | m.total_ops = torch.Tensor([int(total_ops)]) 44 | 45 | 46 | def count_convtranspose2d(m, x, y): 47 | x = x[0] 48 | 49 | cin = m.in_channels 50 | cout = m.out_channels 51 | kh, kw = m.kernel_size 52 | # batch_size = x.size()[0] 53 | 54 | out_h = y.size(2) 55 | out_w = y.size(3) 56 | 57 | # ops per output element 58 | # kernel_mul = kh * kw * cin 59 | # kernel_add = kh * kw * cin - 1 60 | kernel_ops = multiply_adds * kh * kw * cin // m.groups 61 | bias_ops = 1 if m.bias is not None else 0 62 | ops_per_element = kernel_ops + bias_ops 63 | 64 | # total ops 65 | # num_out_elements = y.numel() 66 | # output_elements = batch_size * out_w * out_h * cout 67 | ops_per_element = m.weight.nelement() 68 | output_elements = y.nelement() 69 | total_ops = output_elements * ops_per_element 70 | 71 | m.total_ops = torch.Tensor([int(total_ops)]) 72 | #import pdb; pdb.set_trace() 73 | #print(m.total_ops) 74 | 75 | 76 | def count_bn(m, x, y): 77 | x = x[0] 78 | 79 | nelements = x.numel() 80 | # subtract, divide, gamma, beta 81 | total_ops = 4 * nelements 82 | 83 | m.total_ops = torch.Tensor([int(total_ops)]) 84 | 85 | 86 | def count_relu(m, x, y): 87 | x = x[0] 88 | 89 | nelements = x.numel() 90 | total_ops = nelements 91 | 92 | m.total_ops = torch.Tensor([int(total_ops)]) 93 | 94 | 95 | def count_softmax(m, x, y): 96 | x = x[0] 97 | 98 | batch_size, nfeatures = x.size() 99 | 100 | total_exp = nfeatures 101 | total_add = nfeatures - 1 102 | total_div = nfeatures 103 | total_ops =
batch_size * (total_exp + total_add + total_div) 104 | 105 | m.total_ops = torch.Tensor([int(total_ops)]) 106 | 107 | 108 | def count_maxpool(m, x, y): 109 | kernel_ops = torch.prod(torch.Tensor([m.kernel_size])) 110 | num_elements = y.numel() 111 | total_ops = kernel_ops * num_elements 112 | 113 | m.total_ops = torch.Tensor([int(total_ops)]) 114 | 115 | 116 | def count_adap_maxpool(m, x, y): 117 | kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze() 118 | kernel_ops = torch.prod(kernel) 119 | num_elements = y.numel() 120 | total_ops = kernel_ops * num_elements 121 | 122 | m.total_ops = torch.Tensor([int(total_ops)]) 123 | 124 | 125 | def count_avgpool(m, x, y): 126 | total_add = torch.prod(torch.Tensor([m.kernel_size])) 127 | total_div = 1 128 | kernel_ops = total_add + total_div 129 | num_elements = y.numel() 130 | total_ops = kernel_ops * num_elements 131 | 132 | m.total_ops = torch.Tensor([int(total_ops)]) 133 | 134 | 135 | def count_adap_avgpool(m, x, y): 136 | kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze() 137 | total_add = torch.prod(kernel) 138 | total_div = 1 139 | kernel_ops = total_add + total_div 140 | num_elements = y.numel() 141 | total_ops = kernel_ops * num_elements 142 | 143 | m.total_ops = torch.Tensor([int(total_ops)]) 144 | 145 | 146 | def count_linear(m, x, y): 147 | # per output element 148 | total_mul = m.in_features 149 | total_add = m.in_features - 1 150 | num_elements = y.numel() 151 | total_ops = (total_mul + total_add) * num_elements 152 | 153 | m.total_ops = torch.Tensor([int(total_ops)]) 154 | -------------------------------------------------------------------------------- /thop/profile.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules.conv import _ConvNd 6 | 7 | from .count_hooks import * 8 | 9 | register_hooks = { 10 | nn.Conv1d: count_convNd, 11 | nn.Conv2d: count_convNd, 12 | nn.Conv3d: count_convNd, 13 | nn.ConvTranspose2d: count_convNd, 14 | 15 | nn.BatchNorm1d: count_bn, 16 | nn.BatchNorm2d: count_bn, 17 | nn.BatchNorm3d: count_bn, 18 | 19 | nn.ReLU: count_relu, 20 | nn.ReLU6: count_relu, 21 | nn.LeakyReLU: count_relu, 22 | 23 | nn.MaxPool1d: count_maxpool, 24 | nn.MaxPool2d: count_maxpool, 25 | nn.MaxPool3d: count_maxpool, 26 | nn.AdaptiveMaxPool1d: count_adap_maxpool, 27 | nn.AdaptiveMaxPool2d: count_adap_maxpool, 28 | nn.AdaptiveMaxPool3d: count_adap_maxpool, 29 | 30 | nn.AvgPool1d: count_avgpool, 31 | nn.AvgPool2d: count_avgpool, 32 | nn.AvgPool3d: count_avgpool, 33 | 34 | nn.AdaptiveAvgPool1d: count_adap_avgpool, 35 | nn.AdaptiveAvgPool2d: count_adap_avgpool, 36 | nn.AdaptiveAvgPool3d: count_adap_avgpool, 37 | nn.Linear: count_linear, 38 | nn.Dropout: None, 39 | } 40 | 41 | 42 | def profile(model, input_size, custom_ops={}, device="cpu"): 43 | handler_collection = [] 44 | 45 | def add_hooks(m): 46 | if len(list(m.children())) > 0: 47 | return 48 | 49 | m.register_buffer('total_ops', torch.zeros(1)) 50 | m.register_buffer('total_params', torch.zeros(1)) 51 | 52 | for p in m.parameters(): 53 | m.total_params += torch.Tensor([p.numel()]) 54 | 55 | m_type = type(m) 56 | fn = None 57 | 58 | if m_type in custom_ops: 59 | fn = custom_ops[m_type] 60 | elif m_type in register_hooks: 61 | fn = register_hooks[m_type] 62 | else: 63 | print("Not implemented for ", m) 64 | 65 | if fn is not None: 66 | handler = m.register_forward_hook(fn) 67 | 
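# Minimal usage sketch for profile() (example only; flops.py in this repo
# presumably drives it for the full VCN model):
#
#   import torch.nn as nn
#   from thop import profile
#   net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
#   total_ops, total_params = profile(net, input_size=(1, 3, 64, 64))
#   print('%.1f MFLOPs, %d params' % (total_ops / 1e6, total_params))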
handler_collection.append(handler) 68 | 69 | original_device = model.parameters().__next__().device 70 | training = model.training 71 | 72 | model.eval().to(device) 73 | model.apply(add_hooks) 74 | 75 | x = torch.zeros(input_size).to(device) 76 | with torch.no_grad(): 77 | model(x) 78 | 79 | total_ops = 0 80 | total_params = 0 81 | for m in model.modules(): 82 | if len(list(m.children())) > 0: # skip for non-leaf module 83 | #import pdb; pdb.set_trace() 84 | #if 'butterfly' in str(m._get_name()): break 85 | print('-> %s'%(str(m._get_name()))) 86 | continue 87 | #if not '2d' in str(m._get_name()): continue 88 | #if not '3d' in str(m._get_name()): continue 89 | print("Registered FLOP counter (%.1f M/%.1f) for module %s" % (m.total_ops/1e6, m.total_params, str(m))) 90 | total_ops += m.total_ops 91 | total_params += m.total_params 92 | 93 | total_ops = total_ops.item() 94 | total_params = total_params.item() 95 | 96 | model.train(training).to(original_device) 97 | for handler in handler_collection: 98 | handler.remove() 99 | 100 | return total_ops, total_params 101 | -------------------------------------------------------------------------------- /thop/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def clever_format(num, format="%.2f"): 3 | if num > 1e12: 4 | return format % (num / 1e12) + "T" 5 | if num > 1e9: 6 | return format % (num / 1e9) + "G" 7 | if num > 1e6: 8 | return format % (num / 1e6) + "M" 9 | if num > 1e3: 10 | return format % (num / 1e3) + "K" 11 | return format % num # fall through for small counts so a string is always returned -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/utils/__init__.py -------------------------------------------------------------------------------- /utils/flowlib.py: -------------------------------------------------------------------------------- 1 | """ 2 | # ============================== 3 | # flowlib.py 4 | # library for optical flow processing 5 | # Author: Ruoteng Li 6 | # Date: 6th Aug 2016 7 | # ============================== 8 | """ 9 | import png 10 | from .
import pfm 11 | import numpy as np 12 | import matplotlib.colors as cl 13 | import matplotlib.pyplot as plt 14 | from PIL import Image 15 | import cv2 16 | import pdb 17 | 18 | 19 | UNKNOWN_FLOW_THRESH = 1e7 20 | SMALLFLOW = 0.0 21 | LARGEFLOW = 1e8 22 | 23 | """ 24 | ============= 25 | Flow Section 26 | ============= 27 | """ 28 | 29 | 30 | def show_flow(filename): 31 | """ 32 | visualize optical flow map using matplotlib 33 | :param filename: optical flow file 34 | :return: None 35 | """ 36 | flow = read_flow(filename) 37 | img = flow_to_image(flow) 38 | plt.imshow(img) 39 | plt.show() 40 | 41 | 42 | def visualize_flow(flow, mode='Y'): 43 | """ 44 | visualize the input flow 45 | :param flow: input flow in array 46 | :param mode: choose which color mode to visualize the flow (Y: YCbCr, RGB: RGB color) 47 | :return: None 48 | """ 49 | if mode == 'Y': 50 | # YCbCr color wheel 51 | img = flow_to_image(flow) 52 | plt.imshow(img) 53 | plt.show() 54 | elif mode == 'RGB': 55 | (h, w) = flow.shape[0:2] 56 | du = flow[:, :, 0] 57 | dv = flow[:, :, 1] 58 | valid = flow[:, :, 2] 59 | max_flow = max(np.max(du), np.max(dv)) 60 | img = np.zeros((h, w, 3), dtype=np.float64) 61 | # angle layer 62 | img[:, :, 0] = np.arctan2(dv, du) / (2 * np.pi) 63 | # magnitude layer, normalized to 1 64 | img[:, :, 1] = np.sqrt(du * du + dv * dv) * 8 / max_flow 65 | # phase layer 66 | img[:, :, 2] = 8 - img[:, :, 1] 67 | # clip to [0,1] 68 | small_idx = img[:, :, 0:3] < 0 69 | large_idx = img[:, :, 0:3] > 1 70 | img[small_idx] = 0 71 | img[large_idx] = 1 72 | # convert to rgb 73 | img = cl.hsv_to_rgb(img) 74 | # remove invalid point 75 | #import pdb; pdb.set_trace() 76 | img[:, :, 0] = img[:, :, 0] * valid 77 | img[:, :, 1] = img[:, :, 1] * valid 78 | img[:, :, 2] = img[:, :, 2] * valid 79 | # show 80 | plt.imshow(img) 81 | plt.show() 82 | 83 | return None 84 | 85 | 86 | def read_flow(filename): 87 | """ 88 | read optical flow data from flow file 89 | :param filename: name of the flow file 90 | :return: optical flow data in numpy array 91 | """ 92 | if filename.endswith('.flo'): 93 | flow = read_flo_file(filename) 94 | elif filename.endswith('.png'): 95 | flow = read_png_file(filename) 96 | elif filename.endswith('.pfm'): 97 | flow = read_pfm_file(filename) 98 | else: 99 | raise Exception('Invalid flow file format!') 100 | 101 | return flow 102 | 103 | 104 | def write_flow(flow, filename): 105 | """ 106 | write optical flow in Middlebury .flo format 107 | :param flow: optical flow map 108 | :param filename: optical flow file path to be saved 109 | :return: None 110 | """ 111 | f = open(filename, 'wb') 112 | magic = np.array([202021.25], dtype=np.float32) 113 | (height, width) = flow.shape[0:2] 114 | w = np.array([width], dtype=np.int32) 115 | h = np.array([height], dtype=np.int32) 116 | magic.tofile(f) 117 | w.tofile(f) 118 | h.tofile(f) 119 | flow.tofile(f) 120 | f.close() 121 | 122 | 123 | def save_flow_image(flow, image_file): 124 | """ 125 | save flow visualization into image file 126 | :param flow: optical flow data 127 | :param image_file: output image file path 128 | :return: None 129 | """ 130 | flow_img = flow_to_image(flow) 131 | img_out = Image.fromarray(flow_img) 132 | img_out.save(image_file) 133 | 134 | 135 | def flowfile_to_imagefile(flow_file, image_file): 136 | """ 137 | convert flowfile into image file 138 | :param flow_file: optical flow file path 139 | :param image_file: output image file path 140 | :return: None 141 | """ 142 | flow = read_flow(flow_file) 143 | save_flow_image(flow, image_file) 144 | 145 | 146 | def segment_flow(flow): 147 | h =
flow.shape[0] 148 | w = flow.shape[1] 149 | u = flow[:, :, 0] 150 | v = flow[:, :, 1] 151 | 152 | idx = ((abs(u) > LARGEFLOW) | (abs(v) > LARGEFLOW)) 153 | idx2 = (abs(u) == SMALLFLOW) 154 | class0 = (v == 0) & (u == 0) 155 | u[idx2] = 0.00001 156 | tan_value = v / u 157 | 158 | class1 = (tan_value < 1) & (tan_value >= 0) & (u > 0) & (v >= 0) 159 | class2 = (tan_value >= 1) & (u >= 0) & (v >= 0) 160 | class3 = (tan_value < -1) & (u <= 0) & (v >= 0) 161 | class4 = (tan_value < 0) & (tan_value >= -1) & (u < 0) & (v >= 0) 162 | class8 = (tan_value >= -1) & (tan_value < 0) & (u > 0) & (v <= 0) 163 | class7 = (tan_value < -1) & (u >= 0) & (v <= 0) 164 | class6 = (tan_value >= 1) & (u <= 0) & (v <= 0) 165 | class5 = (tan_value >= 0) & (tan_value < 1) & (u < 0) & (v <= 0) 166 | 167 | seg = np.zeros((h, w)) 168 | 169 | seg[class1] = 1 170 | seg[class2] = 2 171 | seg[class3] = 3 172 | seg[class4] = 4 173 | seg[class5] = 5 174 | seg[class6] = 6 175 | seg[class7] = 7 176 | seg[class8] = 8 177 | seg[class0] = 0 178 | seg[idx] = 0 179 | 180 | return seg 181 | 182 | 183 | def flow_error(tu, tv, u, v): 184 | """ 185 | Calculate average end point error 186 | :param tu: ground-truth horizontal flow map 187 | :param tv: ground-truth vertical flow map 188 | :param u: estimated horizontal flow map 189 | :param v: estimated vertical flow map 190 | :return: End point error of the estimated flow 191 | """ 192 | smallflow = 0.0 193 | ''' 194 | stu = tu[bord+1:end-bord,bord+1:end-bord] 195 | stv = tv[bord+1:end-bord,bord+1:end-bord] 196 | su = u[bord+1:end-bord,bord+1:end-bord] 197 | sv = v[bord+1:end-bord,bord+1:end-bord] 198 | ''' 199 | stu = tu[:] 200 | stv = tv[:] 201 | su = u[:] 202 | sv = v[:] 203 | 204 | idxUnknow = (abs(stu) > UNKNOWN_FLOW_THRESH) | (abs(stv) > UNKNOWN_FLOW_THRESH) 205 | stu[idxUnknow] = 0 206 | stv[idxUnknow] = 0 207 | su[idxUnknow] = 0 208 | sv[idxUnknow] = 0 209 | 210 | ind2 = [(np.absolute(stu) > smallflow) | (np.absolute(stv) > smallflow)] 211 | index_su = su[ind2] 212 | index_sv = sv[ind2] 213 | an = 1.0 / np.sqrt(index_su ** 2 + index_sv ** 2 + 1) 214 | un = index_su * an 215 | vn = index_sv * an 216 | 217 | index_stu = stu[ind2] 218 | index_stv = stv[ind2] 219 | tn = 1.0 / np.sqrt(index_stu ** 2 + index_stv ** 2 + 1) 220 | tun = index_stu * tn 221 | tvn = index_stv * tn 222 | 223 | ''' 224 | angle = un * tun + vn * tvn + (an * tn) 225 | index = [angle == 1.0] 226 | angle[index] = 0.999 227 | ang = np.arccos(angle) 228 | mang = np.mean(ang) 229 | mang = mang * 180 / np.pi 230 | ''' 231 | 232 | epe = np.sqrt((stu - su) ** 2 + (stv - sv) ** 2) 233 | epe = epe[ind2] 234 | mepe = np.mean(epe) 235 | return mepe 236 | 237 | 238 | def flow_to_image(flow): 239 | """ 240 | Convert flow into middlebury color code image 241 | :param flow: optical flow map 242 | :return: optical flow image in middlebury color 243 | """ 244 | u = flow[:, :, 0] 245 | v = flow[:, :, 1] 246 | 247 | maxu = -999. 248 | maxv = -999. 249 | minu = 999. 250 | minv = 999. 
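# Round-trip sketch for the flow I/O helpers above and this color-coding
# function (the file paths are hypothetical):
#
#   flow = read_flow('frame_0001.flo')   # .flo/.png/.pfm dispatch above
#   write_flow(flow, '/tmp/copy.flo')    # Middlebury .flo writer above
#   img = flow_to_image(flow)            # uint8 Middlebury color coding
#   save_flow_image(flow, '/tmp/vis.png')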
251 | 252 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) 253 | u[idxUnknow] = 0 254 | v[idxUnknow] = 0 255 | 256 | maxu = max(maxu, np.max(u)) 257 | minu = min(minu, np.min(u)) 258 | 259 | maxv = max(maxv, np.max(v)) 260 | minv = min(minv, np.min(v)) 261 | 262 | rad = np.sqrt(u ** 2 + v ** 2) 263 | maxrad = max(-1, np.max(rad)) 264 | 265 | u = u/(maxrad + np.finfo(float).eps) 266 | v = v/(maxrad + np.finfo(float).eps) 267 | 268 | img = compute_color(u, v) 269 | 270 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 271 | img[idx] = 0 272 | 273 | return np.uint8(img) 274 | 275 | 276 | def evaluate_flow_file(gt_file, pred_file): 277 | """ 278 | evaluate the estimated optical flow end point error according to ground truth provided 279 | :param gt_file: ground truth file path 280 | :param pred_file: estimated optical flow file path 281 | :return: end point error, float32 282 | """ 283 | # Read flow files and calculate the errors 284 | gt_flow = read_flow(gt_file) # ground truth flow 285 | eva_flow = read_flow(pred_file) # predicted flow 286 | # Calculate errors 287 | average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], eva_flow[:, :, 0], eva_flow[:, :, 1]) 288 | return average_pe 289 | 290 | 291 | def evaluate_flow(gt_flow, pred_flow): 292 | """ 293 | gt: ground-truth flow 294 | pred: estimated flow 295 | """ 296 | average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], pred_flow[:, :, 0], pred_flow[:, :, 1]) 297 | return average_pe 298 | 299 | 300 | """ 301 | ============== 302 | Disparity Section 303 | ============== 304 | """ 305 | 306 | 307 | def read_disp_png(file_name): 308 | """ 309 | Read optical flow from KITTI .png file 310 | :param file_name: name of the flow file 311 | :return: optical flow data in matrix 312 | """ 313 | image_object = png.Reader(filename=file_name) 314 | image_direct = image_object.asDirect() 315 | image_data = list(image_direct[2]) 316 | (w, h) = image_direct[3]['size'] 317 | channel = len(image_data[0]) / w 318 | flow = np.zeros((h, w, channel), dtype=np.uint16) 319 | for i in range(len(image_data)): 320 | for j in range(channel): 321 | flow[i, :, j] = image_data[i][j::channel] 322 | return flow[:, :, 0] / 256 323 | 324 | 325 | def disp_to_flowfile(disp, filename): 326 | """ 327 | Read KITTI disparity file in png format 328 | :param disp: disparity matrix 329 | :param filename: the flow file name to save 330 | :return: None 331 | """ 332 | f = open(filename, 'wb') 333 | magic = np.array([202021.25], dtype=np.float32) 334 | (height, width) = disp.shape[0:2] 335 | w = np.array([width], dtype=np.int32) 336 | h = np.array([height], dtype=np.int32) 337 | empty_map = np.zeros((height, width), dtype=np.float32) 338 | data = np.dstack((disp, empty_map)) 339 | magic.tofile(f) 340 | w.tofile(f) 341 | h.tofile(f) 342 | data.tofile(f) 343 | f.close() 344 | 345 | 346 | """ 347 | ============== 348 | Image Section 349 | ============== 350 | """ 351 | 352 | 353 | def read_image(filename): 354 | """ 355 | Read normal image of any format 356 | :param filename: name of the image file 357 | :return: image data in matrix uint8 type 358 | """ 359 | img = Image.open(filename) 360 | im = np.array(img) 361 | return im 362 | 363 | 364 | def warp_image(im, flow): 365 | """ 366 | Use optical flow to warp image to the next 367 | :param im: image to warp 368 | :param flow: optical flow 369 | :return: warped image 370 | """ 371 | from scipy import interpolate 372 | image_height = im.shape[0] 373 | image_width = im.shape[1] 374 | flow_height = 
flow.shape[0] 375 | flow_width = flow.shape[1] 376 | n = image_height * image_width 377 | (iy, ix) = np.mgrid[0:image_height, 0:image_width] 378 | (fy, fx) = np.mgrid[0:flow_height, 0:flow_width] 379 | fx = fx.astype(np.float64) 380 | fy = fy.astype(np.float64) 381 | fx += flow[:,:,0] 382 | fy += flow[:,:,1] 383 | mask = np.logical_or(fx <0 , fx > flow_width) 384 | mask = np.logical_or(mask, fy < 0) 385 | mask = np.logical_or(mask, fy > flow_height) 386 | fx = np.minimum(np.maximum(fx, 0), flow_width) 387 | fy = np.minimum(np.maximum(fy, 0), flow_height) 388 | points = np.concatenate((ix.reshape(n,1), iy.reshape(n,1)), axis=1) 389 | xi = np.concatenate((fx.reshape(n, 1), fy.reshape(n,1)), axis=1) 390 | warp = np.zeros((image_height, image_width, im.shape[2])) 391 | for i in range(im.shape[2]): 392 | channel = im[:, :, i] 393 | plt.imshow(channel, cmap='gray') 394 | values = channel.reshape(n, 1) 395 | new_channel = interpolate.griddata(points, values, xi, method='cubic') 396 | new_channel = np.reshape(new_channel, [flow_height, flow_width]) 397 | new_channel[mask] = 1 398 | warp[:, :, i] = new_channel.astype(np.uint8) 399 | 400 | return warp.astype(np.uint8) 401 | 402 | 403 | """ 404 | ============== 405 | Others 406 | ============== 407 | """ 408 | 409 | def pfm_to_flo(pfm_file): 410 | flow_filename = pfm_file[0:pfm_file.find('.pfm')] + '.flo' 411 | (data, scale) = pfm.readPFM(pfm_file) 412 | flow = data[:, :, 0:2] 413 | write_flow(flow, flow_filename) 414 | 415 | 416 | def scale_image(image, new_range): 417 | """ 418 | Linearly scale the image into desired range 419 | :param image: input image 420 | :param new_range: the new range to be aligned 421 | :return: image normalized in new range 422 | """ 423 | min_val = np.min(image).astype(np.float32) 424 | max_val = np.max(image).astype(np.float32) 425 | min_val_new = np.array(min(new_range), dtype=np.float32) 426 | max_val_new = np.array(max(new_range), dtype=np.float32) 427 | scaled_image = (image - min_val) / (max_val - min_val) * (max_val_new - min_val_new) + min_val_new 428 | return scaled_image.astype(np.uint8) 429 | 430 | 431 | def compute_color(u, v): 432 | """ 433 | compute optical flow color map 434 | :param u: optical flow horizontal map 435 | :param v: optical flow vertical map 436 | :return: optical flow in color code 437 | """ 438 | [h, w] = u.shape 439 | img = np.zeros([h, w, 3]) 440 | nanIdx = np.isnan(u) | np.isnan(v) 441 | u[nanIdx] = 0 442 | v[nanIdx] = 0 443 | 444 | colorwheel = make_color_wheel() 445 | ncols = np.size(colorwheel, 0) 446 | 447 | rad = np.sqrt(u**2+v**2) 448 | 449 | a = np.arctan2(-v, -u) / np.pi 450 | 451 | fk = (a+1) / 2 * (ncols - 1) + 1 452 | 453 | k0 = np.floor(fk).astype(int) 454 | 455 | k1 = k0 + 1 456 | k1[k1 == ncols+1] = 1 457 | f = fk - k0 458 | 459 | for i in range(0, np.size(colorwheel,1)): 460 | tmp = colorwheel[:, i] 461 | col0 = tmp[k0-1] / 255 462 | col1 = tmp[k1-1] / 255 463 | col = (1-f) * col0 + f * col1 464 | 465 | idx = rad <= 1 466 | col[idx] = 1-rad[idx]*(1-col[idx]) 467 | notidx = np.logical_not(idx) 468 | 469 | col[notidx] *= 0.75 470 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) 471 | 472 | return img 473 | 474 | 475 | def make_color_wheel(): 476 | """ 477 | Generate color wheel according Middlebury color code 478 | :return: Color wheel 479 | """ 480 | RY = 15 481 | YG = 6 482 | GC = 4 483 | CB = 11 484 | BM = 13 485 | MR = 6 486 | 487 | ncols = RY + YG + GC + CB + BM + MR 488 | 489 | colorwheel = np.zeros([ncols, 3]) 490 | 491 | col = 0 492 | 493 | # RY 494 | 
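# The wheel assembled below concatenates six hue ramps (RY, YG, GC, CB, BM,
# MR), giving ncols = 15+6+4+11+13+6 = 55 RGB rows. A quick sanity check
# (example only):
#
#   wheel = make_color_wheel()
#   assert wheel.shape == (55, 3)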
colorwheel[0:RY, 0] = 255 495 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) 496 | col += RY 497 | 498 | # YG 499 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) 500 | colorwheel[col:col+YG, 1] = 255 501 | col += YG 502 | 503 | # GC 504 | colorwheel[col:col+GC, 1] = 255 505 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) 506 | col += GC 507 | 508 | # CB 509 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) 510 | colorwheel[col:col+CB, 2] = 255 511 | col += CB 512 | 513 | # BM 514 | colorwheel[col:col+BM, 2] = 255 515 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) 516 | col += + BM 517 | 518 | # MR 519 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 520 | colorwheel[col:col+MR, 0] = 255 521 | 522 | return colorwheel 523 | 524 | 525 | def read_flo_file(filename): 526 | """ 527 | Read from Middlebury .flo file 528 | :param flow_file: name of the flow file 529 | :return: optical flow data in matrix 530 | """ 531 | f = open(filename, 'rb') 532 | magic = np.fromfile(f, np.float32, count=1) 533 | data2d = None 534 | 535 | if 202021.25 != magic: 536 | print('Magic number incorrect. Invalid .flo file') 537 | else: 538 | w = np.fromfile(f, np.int32, count=1) 539 | h = np.fromfile(f, np.int32, count=1) 540 | #print("Reading %d x %d flow file in .flo format" % (h, w)) 541 | flow = np.ones((h[0],w[0],3)) 542 | data2d = np.fromfile(f, np.float32, count=2 * w[0] * h[0]) 543 | # reshape data into 3D array (columns, rows, channels) 544 | data2d = np.resize(data2d, (h[0], w[0], 2)) 545 | flow[:,:,:2] = data2d 546 | f.close() 547 | return flow 548 | 549 | 550 | def read_png_file(flow_file): 551 | """ 552 | Read from KITTI .png file 553 | :param flow_file: name of the flow file 554 | :return: optical flow data in matrix 555 | """ 556 | flow = cv2.imread(flow_file,-1)[:,:,::-1].astype(np.float64) 557 | # flow_object = png.Reader(filename=flow_file) 558 | # flow_direct = flow_object.asDirect() 559 | # flow_data = list(flow_direct[2]) 560 | # (w, h) = flow_direct[3]['size'] 561 | # #print("Reading %d x %d flow file in .png format" % (h, w)) 562 | # flow = np.zeros((h, w, 3), dtype=np.float64) 563 | # for i in range(len(flow_data)): 564 | # flow[i, :, 0] = flow_data[i][0::3] 565 | # flow[i, :, 1] = flow_data[i][1::3] 566 | # flow[i, :, 2] = flow_data[i][2::3] 567 | 568 | invalid_idx = (flow[:, :, 2] == 0) 569 | flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0 570 | flow[invalid_idx, 0] = 0 571 | flow[invalid_idx, 1] = 0 572 | return flow 573 | 574 | 575 | def read_pfm_file(flow_file): 576 | """ 577 | Read from .pfm file 578 | :param flow_file: name of the flow file 579 | :return: optical flow data in matrix 580 | """ 581 | (data, scale) = pfm.readPFM(flow_file) 582 | return data 583 | 584 | 585 | # fast resample layer 586 | def resample(img, sz): 587 | """ 588 | img: flow map to be resampled 589 | sz: new flow map size. 
Must be [height,width] 590 | """ 591 | original_image_size = img.shape 592 | in_height = img.shape[0] 593 | in_width = img.shape[1] 594 | out_height = sz[0] 595 | out_width = sz[1] 596 | out_flow = np.zeros((out_height, out_width, 2)) 597 | # find scale 598 | height_scale = float(in_height) / float(out_height) 599 | width_scale = float(in_width) / float(out_width) 600 | 601 | [x,y] = np.meshgrid(range(out_width), range(out_height)) 602 | xx = x * width_scale 603 | yy = y * height_scale 604 | x0 = np.floor(xx).astype(np.int32) 605 | x1 = x0 + 1 606 | y0 = np.floor(yy).astype(np.int32) 607 | y1 = y0 + 1 608 | 609 | x0 = np.clip(x0,0,in_width-1) 610 | x1 = np.clip(x1,0,in_width-1) 611 | y0 = np.clip(y0,0,in_height-1) 612 | y1 = np.clip(y1,0,in_height-1) 613 | 614 | Ia = img[y0,x0,:] 615 | Ib = img[y1,x0,:] 616 | Ic = img[y0,x1,:] 617 | Id = img[y1,x1,:] 618 | 619 | wa = (y1-yy) * (x1-xx) 620 | wb = (yy-y0) * (x1-xx) 621 | wc = (y1-yy) * (xx-x0) 622 | wd = (yy-y0) * (xx-x0) 623 | out_flow[:,:,0] = (Ia[:,:,0]*wa + Ib[:,:,0]*wb + Ic[:,:,0]*wc + Id[:,:,0]*wd) * out_width / in_width 624 | out_flow[:,:,1] = (Ia[:,:,1]*wa + Ib[:,:,1]*wb + Ic[:,:,1]*wc + Id[:,:,1]*wd) * out_height / in_height 625 | 626 | return out_flow 627 | 628 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | import shutil 4 | import sys 5 | import traceback 6 | import zipfile 7 | 8 | if sys.version_info[0] == 2: 9 | import urllib2 10 | else: 11 | import urllib.request 12 | 13 | 14 | # Converts a string to bytes (for writing the string into a file). Provided for 15 | # compatibility with Python 2 and 3. 16 | def StrToBytes(text): 17 | if sys.version_info[0] == 2: 18 | return text 19 | else: 20 | return bytes(text, 'UTF-8') 21 | 22 | 23 | # Outputs the given text and lets the user input a response (submitted by 24 | # pressing the return key). Provided for compatibility with Python 2 and 3. 25 | def GetUserInput(text): 26 | if sys.version_info[0] == 2: 27 | return raw_input(text) 28 | else: 29 | return input(text) 30 | 31 | 32 | # Creates the given directory (hierarchy), which may already exist. Provided for 33 | # compatibility with Python 2 and 3. 34 | def MakeDirsExistOk(directory_path): 35 | try: 36 | os.makedirs(directory_path) 37 | except OSError as exception: 38 | if exception.errno != errno.EEXIST: 39 | raise 40 | 41 | 42 | # Deletes all files and folders within the given folder. 43 | def DeleteFolderContents(folder_path): 44 | for file_name in os.listdir(folder_path): 45 | file_path = os.path.join(folder_path, file_name) 46 | try: 47 | if os.path.isfile(file_path): 48 | os.unlink(file_path) 49 | else: #if os.path.isdir(file_path): 50 | shutil.rmtree(file_path) 51 | except Exception as e: 52 | print('Exception in DeleteFolderContents():') 53 | print(e) 54 | print('Stack trace:') 55 | print(traceback.format_exc()) 56 | 57 | 58 | # Creates the given directory, or deletes all content of the directory 59 | # in case it already exists. 60 | def MakeCleanDirectory(folder_path): 61 | if os.path.isdir(folder_path): 62 | DeleteFolderContents(folder_path) 63 | else: 64 | MakeDirsExistOk(folder_path) 65 | 66 | 67 | # Downloads the given URL to a file in the given directory. Returns the 68 | # path to the downloaded file.
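# Usage sketch for the download helpers below (the URL is hypothetical;
# DownloadFile prompts before overwriting an existing file):
#   path = DownloadFile('http://example.com/data.zip', '/tmp')
#   UnzipFile(path, '/tmp/data')
#   DownloadAndUnzipFile('http://example.com/data.zip', '/tmp', '/tmp/data')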
66 | 
67 | # Downloads the given URL to a file in the given directory. Returns the
68 | # path to the downloaded file.
69 | # In part adapted from: https://stackoverflow.com/questions/22676
70 | def DownloadFile(url, dest_dir_path):
71 |     file_name = url.split('/')[-1]
72 |     dest_file_path = os.path.join(dest_dir_path, file_name)
73 | 
74 |     if os.path.isfile(dest_file_path):
75 |         print('The following file already exists:')
76 |         print(dest_file_path)
77 |         print('Please enter o to re-download and overwrite the file, or s to skip downloading it.')
78 |         while True:
79 |             response = GetUserInput("> ")
80 |             if response == 's':
81 |                 return dest_file_path
82 |             elif response == 'o':
83 |                 break
84 |             else:
85 |                 print('Please enter o or s.')
86 | 
87 |     url_object = None
88 |     if sys.version_info[0] == 2:
89 |         url_object = urllib2.urlopen(url)
90 |     else:
91 |         url_object = urllib.request.urlopen(url)
92 | 
93 |     with open(dest_file_path, 'wb') as outfile:
94 |         meta = url_object.info()
95 |         file_size = 0
96 |         if sys.version_info[0] == 2:
97 |             file_size = int(meta.getheaders("Content-Length")[0])
98 |         else:
99 |             file_size = int(meta["Content-Length"])
100 |         print("Downloading: %s (size [bytes]: %s)" % (url, file_size))
101 | 
102 |         file_size_downloaded = 0
103 |         block_size = 8192
104 |         while True:
105 |             buffer = url_object.read(block_size)
106 |             if not buffer:
107 |                 break
108 | 
109 |             file_size_downloaded += len(buffer)
110 |             outfile.write(buffer)
111 | 
112 |             sys.stdout.write("%d / %d (%.1f%%)\r" % (file_size_downloaded, file_size, file_size_downloaded * 100. / file_size))
113 |             sys.stdout.flush()
114 | 
115 |     return dest_file_path
116 | 
117 | 
118 | # Unzips the given zip file into the given directory.
119 | def UnzipFile(file_path, unzip_dir_path, overwrite=True):
120 |     zip_ref = zipfile.ZipFile(file_path, 'r')
121 | 
122 |     if not overwrite:
123 |         for f in zip_ref.namelist():
124 |             if not os.path.isfile(os.path.join(unzip_dir_path, f)):
125 |                 zip_ref.extract(f, path=unzip_dir_path)
126 |             else:
127 |                 print('Not overwriting {}'.format(f))
128 |     else:
129 |         zip_ref.extractall(unzip_dir_path)
130 |     zip_ref.close()
131 | 
132 | 
133 | # Creates a zip file with the contents of the given directory.
134 | # The archive_base_path must not include the extension .zip. The full, final
135 | # path of the archive is returned by the function.
136 | def ZipDirectory(archive_base_path, root_dir_path):
137 |     # shutil.make_archive(...) is avoided here: it always includes a './' folder.
138 |     with zipfile.ZipFile(archive_base_path+'.zip', "w", compression=zipfile.ZIP_DEFLATED) as zf:
139 |         base_path = os.path.normpath(root_dir_path)
140 |         for dirpath, dirnames, filenames in os.walk(root_dir_path):
141 |             for name in sorted(dirnames):
142 |                 path = os.path.normpath(os.path.join(dirpath, name))
143 |                 zf.write(path, os.path.relpath(path, base_path))
144 |             for name in filenames:
145 |                 path = os.path.normpath(os.path.join(dirpath, name))
146 |                 if os.path.isfile(path):
147 |                     zf.write(path, os.path.relpath(path, base_path))
148 | 
149 |     return archive_base_path+'.zip'
150 | 
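# Illustrative usage sketch (not from the original module; the paths are
# placeholders): packaging a results folder into a flat zip, e.g. for a
# benchmark upload. Entries are stored relative to the folder itself, without
# the leading './' that shutil.make_archive would add.
def _example_zip_results(results_dir='results/flow', archive_base='results/flow_submission'):
    archive_path = ZipDirectory(archive_base, results_dir)
    return archive_path  # 'results/flow_submission.zip'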
151 | 
152 | # Downloads a zip file and directly unzips it.
153 | def DownloadAndUnzipFile(url, archive_dir_path, unzip_dir_path, overwrite=True):
154 |     archive_path = DownloadFile(url, archive_dir_path)
155 |     UnzipFile(archive_path, unzip_dir_path, overwrite=overwrite)
156 | 
157 | def mkdir_p(path):
158 |     try:
159 |         os.makedirs(path)
160 |     except OSError as exc:  # Python >2.5
161 |         if exc.errno == errno.EEXIST and os.path.isdir(path):
162 |             pass
163 |         else:
164 |             raise
165 | 
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | File: logger.py
3 | Modified by: Senthil Purushwalkam
4 | Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
5 | Email: spurushw<at>andrew<dot>cmu<dot>edu
6 | Github: https://github.com/senthilps8
7 | Description:
8 | """
9 | import pdb
10 | import tensorflow as tf
11 | import torch
12 | from torch.autograd import Variable
13 | import numpy as np
14 | import scipy.misc
15 | import os
16 | try:
17 |     from StringIO import StringIO  # Python 2.7
18 | except ImportError:
19 |     from io import BytesIO  # Python 3.x
20 | 
21 | 
22 | class Logger(object):
23 | 
24 |     def __init__(self, log_dir, name=None):
25 |         """Create a summary writer logging to log_dir/name."""
26 |         # name defaults to 'temp', so the writer always logs into a subdirectory
27 |         if name is None:
28 |             name = 'temp'
29 |         self.name = name
30 |         try:
31 |             os.makedirs(os.path.join(log_dir, name))
32 |         except OSError:
33 |             pass
34 |         self.writer = tf.summary.FileWriter(os.path.join(log_dir, name),
35 |                                             filename_suffix=name)
36 | 
37 |     def scalar_summary(self, tag, value, step):
38 |         """Log a scalar variable."""
39 |         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
40 |         self.writer.add_summary(summary, step)
41 | 
42 |     def image_summary(self, tag, images, step):
43 |         """Log a list of images."""
44 | 
45 |         img_summaries = []
46 |         for i, img in enumerate(images):
47 |             # Write the image to a string buffer
48 |             try:
49 |                 s = StringIO()
50 |             except NameError:  # StringIO is unavailable under Python 3
51 |                 s = BytesIO()
52 |             scipy.misc.toimage(img).save(s, format="png")
53 | 
54 |             # Create an Image object
55 |             img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
56 |                                        height=img.shape[0],
57 |                                        width=img.shape[1])
58 |             # Create a Summary value
59 |             img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
60 | 
61 |         # Create and write Summary
62 |         summary = tf.Summary(value=img_summaries)
63 |         self.writer.add_summary(summary, step)
64 | 
65 |     def histo_summary(self, tag, values, step, bins=1000):
66 |         """Log a histogram of the tensor of values."""
67 | 
68 |         # Create a histogram using numpy
69 |         counts, bin_edges = np.histogram(values, bins=bins)
70 | 
71 |         # Fill the fields of the histogram proto
72 |         hist = tf.HistogramProto()
73 |         hist.min = float(np.min(values))
74 |         hist.max = float(np.max(values))
75 |         hist.num = int(np.prod(values.shape))
76 |         hist.sum = float(np.sum(values))
77 |         hist.sum_squares = float(np.sum(values**2))
78 | 
79 |         # Drop the start of the first bin
80 |         bin_edges = bin_edges[1:]
81 | 
82 |         # Add bin edges and counts
83 |         for edge in bin_edges:
84 |             hist.bucket_limit.append(edge)
85 |         for c in counts:
86 |             hist.bucket.append(c)
87 | 
88 |         # Create and write Summary
89 |         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
90 |         self.writer.add_summary(summary, step)
91 |         self.writer.flush()
92 | 
93 |     def to_np(self, x):
94 |         return x.data.cpu().numpy()
95 | 
96 |     def to_var(self, x):
97 |         if torch.cuda.is_available():
98 |             x = x.cuda()
99 |         return Variable(x)
100 | 
101 |     def model_param_histo_summary(self, model, step):
102 |         """Log histogram summaries of a model's parameters
103 |         and parameter gradients.
104 |         """
105 |         for tag, value in model.named_parameters():
106 |             if value.grad is None:
107 |                 continue
108 |             tag = tag.replace('.', '/')
109 |             tag = self.name+'/'+tag
110 |             self.histo_summary(tag, self.to_np(value), step)
111 |             self.histo_summary(tag+'/grad', self.to_np(value.grad), step)
112 | 
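# Illustrative usage sketch (not from the original module; the log directory
# is a placeholder): logging a training curve with the Logger above. Point
# TensorBoard at log_dir to view the scalars.
def _example_logging(log_dir='tmp/logs'):
    logger = Logger(log_dir, name='example')
    for step in range(10):
        fake_loss = 1.0 / (step + 1)  # stand-in for a real training loss
        logger.scalar_summary('train/loss', fake_loss, step)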
--------------------------------------------------------------------------------
/utils/multiscaleloss.py:
--------------------------------------------------------------------------------
1 | """
2 | Taken from https://github.com/ClementPinard/FlowNetPytorch
3 | """
4 | import pdb
5 | import torch
6 | import torch.nn.functional as F
7 | 
8 | 
9 | def EPE(input_flow, target_flow, mask, sparse=False, mean=True):
10 |     #mask = target_flow[:,2]>0
11 |     target_flow = target_flow[:, :2]
12 |     EPE_map = torch.norm(target_flow-input_flow, 2, 1)
13 |     batch_size = EPE_map.size(0)
14 |     if sparse:
15 |         # invalid flow is defined with both flow coordinates being exactly 0
16 |         mask = ~((target_flow[:, 0] == 0) & (target_flow[:, 1] == 0))
17 |     # keep only the pixels marked valid by the mask
18 |     EPE_map = EPE_map[mask]
19 |     if mean:
20 |         return EPE_map.mean()
21 |     else:
22 |         return EPE_map.sum()/batch_size
23 | 
24 | 
25 | def rob_EPE(input_flow, target_flow, mask, sparse=False, mean=True):
26 |     #mask = target_flow[:,2]>0
27 |     target_flow = target_flow[:, :2]
28 |     # robust generalized Charbonnier penalty instead of the plain L2 norm:
29 |     # EPE_map = torch.norm(target_flow-input_flow,2,1)
30 |     EPE_map = (torch.norm(target_flow-input_flow, 1, 1)+0.01).pow(0.4)
31 |     batch_size = EPE_map.size(0)
32 |     if sparse:
33 |         # invalid flow is defined with both flow coordinates being exactly 0
34 |         mask = ~((target_flow[:, 0] == 0) & (target_flow[:, 1] == 0))
35 |     # keep only the pixels marked valid by the mask
36 |     EPE_map = EPE_map[mask]
37 |     if mean:
38 |         return EPE_map.mean()
39 |     else:
40 |         return EPE_map.sum()/batch_size
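# Illustrative usage sketch (not from the original module): the two penalties
# above on dummy tensors. EPE is the mean L2 flow error over valid pixels;
# rob_EPE replaces it with a generalized Charbonnier penalty
# (|d|_1 + 0.01)^0.4 that down-weights large outliers during training.
def _example_epe():
    pred = torch.zeros(1, 2, 4, 4)                # predicted flow, B x 2 x H x W
    gt = torch.ones(1, 3, 4, 4)                   # target flow + validity channel
    mask = torch.ones(1, 4, 4, dtype=torch.bool)  # every pixel valid
    return EPE(pred, gt, mask), rob_EPE(pred, gt, mask)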
41 | 
42 | 
43 | def sparse_max_pool(input, size):
44 |     '''Downsample the input by considering 0 values as invalid.
45 | 
46 |     Unfortunately, no generic interpolation mode can resize a sparse map correctly,
47 |     so the strategy here is to use max pooling for positive values and "min pooling"
48 |     for negative values; the two results are then summed.
49 |     This technique allows sparsity to be minimized, contrary to nearest interpolation,
50 |     which could potentially lose information for isolated data points.'''
51 | 
52 |     positive = (input > 0).float()
53 |     negative = (input < 0).float()
54 |     output = F.adaptive_max_pool2d(input * positive, size) - F.adaptive_max_pool2d(-input * negative, size)
55 |     return output
56 | 
57 | 
58 | def multiscaleEPE(network_output, target_flow, mask, weights=None, sparse=False, rob_loss=False):
59 |     def one_scale(output, target, mask, sparse):
60 | 
61 |         b, _, h, w = output.size()
62 | 
63 |         if sparse:
64 |             target_scaled = sparse_max_pool(target, (h, w))
65 |         else:
66 |             target_scaled = F.interpolate(target, (h, w), mode='area')
67 |             mask = F.interpolate(mask.float().unsqueeze(1), (h, w), mode='bilinear').squeeze(1) == 1
68 |         if rob_loss:
69 |             return rob_EPE(output, target_scaled, mask, sparse, mean=False)
70 |         else:
71 |             return EPE(output, target_scaled, mask, sparse, mean=False)
72 | 
73 |     if type(network_output) not in [tuple, list]:
74 |         network_output = [network_output]
75 |     if weights is None:
76 |         weights = [0.005, 0.01, 0.02, 0.08, 0.32]  # as in the original article
77 |     assert(len(weights) == len(network_output))
78 | 
79 |     loss = 0
80 |     for output, weight in zip(network_output, weights):
81 |         loss += weight * one_scale(output, target_flow, mask, sparse)
82 |     return loss
83 | 
84 | 
85 | def realEPE(output, target, mask, sparse=False):
86 |     b, _, h, w = target.size()
87 |     upsampled_output = F.interpolate(output, (h, w), mode='bilinear', align_corners=False)
88 |     return EPE(upsampled_output, target, mask, sparse, mean=True)
--------------------------------------------------------------------------------
/utils/pfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 | 
5 | def readPFM(file):
6 |     file = open(file, 'rb')
7 | 
8 |     color = None
9 |     width = None
10 |     height = None
11 |     scale = None
12 |     endian = None
13 | 
14 |     header = file.readline().rstrip()
15 |     if sys.version[0] == '3':
16 |         header = header.decode('utf-8')
17 |     if header == 'PF':
18 |         color = True
19 |     elif header == 'Pf':
20 |         color = False
21 |     else:
22 |         raise Exception('Not a PFM file.')
23 | 
24 |     if sys.version[0] == '3':
25 |         dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
26 |     else:
27 |         dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
28 |     if dim_match:
29 |         width, height = map(int, dim_match.groups())
30 |     else:
31 |         raise Exception('Malformed PFM header.')
32 | 
33 |     if sys.version[0] == '3':
34 |         scale = float(file.readline().rstrip().decode('utf-8'))
35 |     else:
36 |         scale = float(file.readline().rstrip())
37 | 
38 |     if scale < 0:  # little-endian
39 |         endian = '<'
40 |         scale = -scale
41 |     else:
42 |         endian = '>'  # big-endian
43 | 
44 |     data = np.fromfile(file, endian + 'f')
45 |     shape = (height, width, 3) if color else (height, width)
46 | 
47 |     data = np.reshape(data, shape)
48 |     data = np.flipud(data)
49 |     return data, scale
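# Illustrative usage sketch (not from the original module; the path is a
# placeholder): reading a map stored as PFM. PFM files are written bottom-up,
# which is why readPFM flips the array vertically before returning it.
def _example_read_pfm(path='tmp/disp0.pfm'):
    data, scale = readPFM(path)
    return data.shape, scale  # (H, W) or (H, W, 3), plus the scale factor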
50 | 
51 | 
52 | def writePFM(file, image, scale=1):
53 |     file = open(file, 'wb')
54 | 
55 |     color = None
56 | 
57 |     if image.dtype.name != 'float32':
58 |         raise Exception('Image dtype must be float32.')
59 | 
60 |     image = np.flipud(image)
61 | 
62 |     if len(image.shape) == 3 and image.shape[2] == 3:  # color image
63 |         color = True
64 |     elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:  # greyscale
65 |         color = False
66 |     else:
67 |         raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
68 | 
69 |     # encode the header: the file is opened in binary mode, so plain str
70 |     # writes would fail under Python 3
71 |     file.write(('PF\n' if color else 'Pf\n').encode('ascii'))
72 |     file.write(('%d %d\n' % (image.shape[1], image.shape[0])).encode('ascii'))
73 | 
74 |     endian = image.dtype.byteorder
75 | 
76 |     if endian == '<' or endian == '=' and sys.byteorder == 'little':
77 |         scale = -scale
78 | 
79 |     file.write(('%f\n' % scale).encode('ascii'))
80 | 
81 |     image.tofile(file)
82 | 
--------------------------------------------------------------------------------
/utils/readpfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 | 
5 | 
6 | def readPFM(file):
7 |     file = open(file, 'rb')
8 | 
9 |     color = None
10 |     width = None
11 |     height = None
12 |     scale = None
13 |     endian = None
14 | 
15 |     header = file.readline().rstrip()
16 |     if sys.version[0] == '3':
17 |         header = header.decode('utf-8')
18 |     if header == 'PF':
19 |         color = True
20 |     elif header == 'Pf':
21 |         color = False
22 |     else:
23 |         raise Exception('Not a PFM file.')
24 | 
25 |     if sys.version[0] == '3':
26 |         dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
27 |     else:
28 |         dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
29 |     if dim_match:
30 |         width, height = map(int, dim_match.groups())
31 |     else:
32 |         raise Exception('Malformed PFM header.')
33 | 
34 |     if sys.version[0] == '3':
35 |         scale = float(file.readline().rstrip().decode('utf-8'))
36 |     else:
37 |         scale = float(file.readline().rstrip())
38 | 
39 |     if scale < 0:  # little-endian
40 |         endian = '<'
41 |         scale = -scale
42 |     else:
43 |         endian = '>'  # big-endian
44 | 
45 |     data = np.fromfile(file, endian + 'f')
46 |     shape = (height, width, 3) if color else (height, width)
47 | 
48 |     data = np.reshape(data, shape)
49 |     data = np.flipud(data)
50 |     return data, scale
51 | 
--------------------------------------------------------------------------------
/utils/sintel_io.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python2
2 | 
3 | """
4 | I/O script to save and load the data coming with the MPI-Sintel low-level
5 | computer vision benchmark.
6 | 
7 | For more details about the benchmark, please visit www.mpi-sintel.de
8 | 
9 | CHANGELOG:
10 | v1.0 (2015/02/03): First release
11 | 
12 | Copyright (c) 2015 Jonas Wulff
13 | Max Planck Institute for Intelligent Systems, Tuebingen, Germany
14 | 
15 | """
16 | 
17 | # Requirements: Numpy and PIL/Pillow
18 | import numpy as np
19 | from PIL import Image
20 | 
21 | # Check for endianness, based on Daniel Scharstein's optical flow code.
22 | # Using little-endian architecture, these two should be equal.
23 | TAG_FLOAT = 202021.25
24 | TAG_CHAR = b'PIEH'  # bytes, so the tag can be written to binary files
25 | 
26 | def flow_read(filename):
27 |     """ Read optical flow from file, return (U,V) tuple.
28 | 
29 |     Original code by Deqing Sun, adapted from Daniel Scharstein.
30 |     """
31 |     f = open(filename,'rb')
32 |     check = np.fromfile(f,dtype=np.float32,count=1)[0]
33 |     assert check == TAG_FLOAT, ' flow_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? 
'.format(TAG_FLOAT,check) 34 | width = np.fromfile(f,dtype=np.int32,count=1)[0] 35 | height = np.fromfile(f,dtype=np.int32,count=1)[0] 36 | size = width*height 37 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' flow_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) 38 | tmp = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width*2)) 39 | u = tmp[:,np.arange(width)*2] 40 | v = tmp[:,np.arange(width)*2 + 1] 41 | return u,v 42 | 43 | def flow_write(filename,uv,v=None): 44 | """ Write optical flow to file. 45 | 46 | If v is None, uv is assumed to contain both u and v channels, 47 | stacked in depth. 48 | 49 | Original code by Deqing Sun, adapted from Daniel Scharstein. 50 | """ 51 | nBands = 2 52 | 53 | if v is None: 54 | assert(uv.ndim == 3) 55 | assert(uv.shape[2] == 2) 56 | u = uv[:,:,0] 57 | v = uv[:,:,1] 58 | else: 59 | u = uv 60 | 61 | assert(u.shape == v.shape) 62 | height,width = u.shape 63 | f = open(filename,'wb') 64 | # write the header 65 | f.write(TAG_CHAR) 66 | np.array(width).astype(np.int32).tofile(f) 67 | np.array(height).astype(np.int32).tofile(f) 68 | # arrange into matrix form 69 | tmp = np.zeros((height, width*nBands)) 70 | tmp[:,np.arange(width)*2] = u 71 | tmp[:,np.arange(width)*2 + 1] = v 72 | tmp.astype(np.float32).tofile(f) 73 | f.close() 74 | 75 | 76 | def depth_read(filename): 77 | """ Read depth data from file, return as numpy array. """ 78 | f = open(filename,'rb') 79 | check = np.fromfile(f,dtype=np.float32,count=1)[0] 80 | assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) 81 | width = np.fromfile(f,dtype=np.int32,count=1)[0] 82 | height = np.fromfile(f,dtype=np.int32,count=1)[0] 83 | size = width*height 84 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) 85 | depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width)) 86 | return depth 87 | 88 | def depth_write(filename, depth): 89 | """ Write depth to file. """ 90 | height,width = depth.shape[:2] 91 | f = open(filename,'wb') 92 | # write the header 93 | f.write(TAG_CHAR) 94 | np.array(width).astype(np.int32).tofile(f) 95 | np.array(height).astype(np.int32).tofile(f) 96 | 97 | depth.astype(np.float32).tofile(f) 98 | f.close() 99 | 100 | 101 | def disparity_write(filename,disparity,bitdepth=16): 102 | """ Write disparity to file. 103 | 104 | bitdepth can be either 16 (default) or 32. 105 | 106 | The maximum disparity is 1024, since the image width in Sintel 107 | is 1024. 108 | """ 109 | d = disparity.copy() 110 | 111 | # Clip disparity. 112 | d[d>1024] = 1024 113 | d[d<0] = 0 114 | 115 | d_r = (d / 4.0).astype('uint8') 116 | d_g = ((d * (2.0**6)) % 256).astype('uint8') 117 | 118 | out = np.zeros((d.shape[0],d.shape[1],3),dtype='uint8') 119 | out[:,:,0] = d_r 120 | out[:,:,1] = d_g 121 | 122 | if bitdepth > 16: 123 | d_b = (d * (2**14) % 256).astype('uint8') 124 | out[:,:,2] = d_b 125 | 126 | Image.fromarray(out,'RGB').save(filename,'PNG') 127 | 128 | 129 | def disparity_read(filename): 130 | """ Return disparity read from filename. 
""" 131 | f_in = np.array(Image.open(filename)) 132 | d_r = f_in[:,:,0].astype('float64') 133 | d_g = f_in[:,:,1].astype('float64') 134 | d_b = f_in[:,:,2].astype('float64') 135 | 136 | depth = d_r * 4 + d_g / (2**6) + d_b / (2**14) 137 | return depth 138 | 139 | 140 | #def cam_read(filename): 141 | # """ Read camera data, return (M,N) tuple. 142 | # 143 | # M is the intrinsic matrix, N is the extrinsic matrix, so that 144 | # 145 | # x = M*N*X, 146 | # with x being a point in homogeneous image pixel coordinates, X being a 147 | # point in homogeneous world coordinates. 148 | # """ 149 | # txtdata = np.loadtxt(filename) 150 | # intrinsic = txtdata[0,:9].reshape((3,3)) 151 | # extrinsic = textdata[1,:12].reshape((3,4)) 152 | # return intrinsic,extrinsic 153 | # 154 | # 155 | #def cam_write(filename,M,N): 156 | # """ Write intrinsic matrix M and extrinsic matrix N to file. """ 157 | # Z = np.zeros((2,12)) 158 | # Z[0,:9] = M.ravel() 159 | # Z[1,:12] = N.ravel() 160 | # np.savetxt(filename,Z) 161 | 162 | def cam_read(filename): 163 | """ Read camera data, return (M,N) tuple. 164 | 165 | M is the intrinsic matrix, N is the extrinsic matrix, so that 166 | 167 | x = M*N*X, 168 | with x being a point in homogeneous image pixel coordinates, X being a 169 | point in homogeneous world coordinates. 170 | """ 171 | f = open(filename,'rb') 172 | check = np.fromfile(f,dtype=np.float32,count=1)[0] 173 | assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) 174 | M = np.fromfile(f,dtype='float64',count=9).reshape((3,3)) 175 | N = np.fromfile(f,dtype='float64',count=12).reshape((3,4)) 176 | return M,N 177 | 178 | def cam_write(filename, M, N): 179 | """ Write intrinsic matrix M and extrinsic matrix N to file. """ 180 | f = open(filename,'wb') 181 | # write the header 182 | f.write(TAG_CHAR) 183 | M.astype('float64').tofile(f) 184 | N.astype('float64').tofile(f) 185 | f.close() 186 | 187 | 188 | def segmentation_write(filename,segmentation): 189 | """ Write segmentation to file. """ 190 | 191 | segmentation_ = segmentation.astype('int32') 192 | seg_r = np.floor(segmentation_ / (256**2)).astype('uint8') 193 | seg_g = np.floor((segmentation_ % (256**2)) / 256).astype('uint8') 194 | seg_b = np.floor(segmentation_ % 256).astype('uint8') 195 | 196 | out = np.zeros((segmentation.shape[0],segmentation.shape[1],3),dtype='uint8') 197 | out[:,:,0] = seg_r 198 | out[:,:,1] = seg_g 199 | out[:,:,2] = seg_b 200 | 201 | Image.fromarray(out,'RGB').save(filename,'PNG') 202 | 203 | 204 | def segmentation_read(filename): 205 | """ Return disparity read from filename. 
""" 206 | f_in = np.array(Image.open(filename)) 207 | seg_r = f_in[:,:,0].astype('int32') 208 | seg_g = f_in[:,:,1].astype('int32') 209 | seg_b = f_in[:,:,2].astype('int32') 210 | 211 | segmentation = (seg_r * 256 + seg_g) * 256 + seg_b 212 | return segmentation 213 | 214 | 215 | -------------------------------------------------------------------------------- /utils/util_flow.py: -------------------------------------------------------------------------------- 1 | import math 2 | import png 3 | import struct 4 | import array 5 | import numpy as np 6 | import cv2 7 | import pdb 8 | 9 | from io import * 10 | 11 | UNKNOWN_FLOW_THRESH = 1e9; 12 | UNKNOWN_FLOW = 1e10; 13 | 14 | # Middlebury checks 15 | TAG_STRING = 'PIEH' # use this when WRITING the file 16 | TAG_FLOAT = 202021.25 # check for this when READING the file 17 | 18 | def readPFM(file): 19 | import re 20 | file = open(file, 'rb') 21 | 22 | color = None 23 | width = None 24 | height = None 25 | scale = None 26 | endian = None 27 | 28 | header = file.readline().rstrip() 29 | if header == b'PF': 30 | color = True 31 | elif header == b'Pf': 32 | color = False 33 | else: 34 | raise Exception('Not a PFM file.') 35 | 36 | dim_match = re.match(b'^(\d+)\s(\d+)\s$', file.readline()) 37 | if dim_match: 38 | width, height = map(int, dim_match.groups()) 39 | else: 40 | raise Exception('Malformed PFM header.') 41 | 42 | scale = float(file.readline().rstrip()) 43 | if scale < 0: # little-endian 44 | endian = '<' 45 | scale = -scale 46 | else: 47 | endian = '>' # big-endian 48 | 49 | data = np.fromfile(file, endian + 'f') 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | return data, scale 55 | 56 | 57 | def save_pfm(file, image, scale = 1): 58 | import sys 59 | color = None 60 | 61 | if image.dtype.name != 'float32': 62 | raise Exception('Image dtype must be float32.') 63 | 64 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 65 | color = True 66 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 67 | color = False 68 | else: 69 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 70 | 71 | file.write('PF\n' if color else 'Pf\n') 72 | file.write('%d %d\n' % (image.shape[1], image.shape[0])) 73 | 74 | endian = image.dtype.byteorder 75 | 76 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 77 | scale = -scale 78 | 79 | file.write('%f\n' % scale) 80 | 81 | image.tofile(file) 82 | 83 | 84 | def ReadMiddleburyFloFile(path): 85 | """ Read .FLO file as specified by Middlebury. 86 | 87 | Returns tuple (width, height, u, v, mask), where u, v, mask are flat 88 | arrays of values. 89 | """ 90 | 91 | with open(path, 'rb') as fil: 92 | tag = struct.unpack('f', fil.read(4))[0] 93 | width = struct.unpack('i', fil.read(4))[0] 94 | height = struct.unpack('i', fil.read(4))[0] 95 | 96 | assert tag == TAG_FLOAT 97 | 98 | #data = np.fromfile(path, dtype=np.float, count=-1) 99 | #data = data[3:] 100 | 101 | fmt = 'f' * width*height*2 102 | data = struct.unpack(fmt, fil.read(4*width*height*2)) 103 | 104 | u = data[::2] 105 | v = data[1::2] 106 | 107 | mask = map(lambda x,y: abs(x) 0: 144 | # print(u[ind], v[ind], mask[ind], row[3*x], row[3*x+1], row[3*x+2]) 145 | 146 | #png_reader.close() 147 | 148 | return (width, height, u, v, mask) 149 | 150 | 151 | def WriteMiddleburyFloFile(path, width, height, u, v, mask=None): 152 | """ Write .FLO file as specified by Middlebury. 
153 | """ 154 | 155 | if mask is not None: 156 | u_masked = map(lambda x,y: x if y else UNKNOWN_FLOW, u, mask) 157 | v_masked = map(lambda x,y: x if y else UNKNOWN_FLOW, v, mask) 158 | else: 159 | u_masked = u 160 | v_masked = v 161 | 162 | fmt = 'f' * width*height*2 163 | # Interleave lists 164 | data = [x for t in zip(u_masked,v_masked) for x in t] 165 | 166 | with open(path, 'wb') as fil: 167 | fil.write(str.encode(TAG_STRING)) 168 | fil.write(struct.pack('i', width)) 169 | fil.write(struct.pack('i', height)) 170 | fil.write(struct.pack(fmt, *data)) 171 | 172 | 173 | def write_flow(path,flow): 174 | 175 | invalid_idx = (flow[:, :, 2] == 0) 176 | flow[:, :, 0:2] = flow[:, :, 0:2]*64.+ 2 ** 15 177 | flow[invalid_idx, 0] = 0 178 | flow[invalid_idx, 1] = 0 179 | 180 | flow = flow.astype(np.uint16) 181 | flow = cv2.imwrite(path, flow[:,:,::-1]) 182 | 183 | #WriteKittiPngFile(path, 184 | # flow.shape[1], flow.shape[0], flow[:,:,0].flatten(), 185 | # flow[:,:,1].flatten(), flow[:,:,2].flatten()) 186 | 187 | 188 | 189 | def WriteKittiPngFile(path, width, height, u, v, mask=None): 190 | """ Write 16-bit .PNG file as specified by KITTI-2015 (flow). 191 | 192 | u, v are lists of float values 193 | mask is a list of floats, denoting the *valid* pixels. 194 | """ 195 | 196 | data = array.array('H',[0])*width*height*3 197 | 198 | for i,(u_,v_,mask_) in enumerate(zip(u,v,mask)): 199 | data[3*i] = int(u_*64.0+2**15) 200 | data[3*i+1] = int(v_*64.0+2**15) 201 | data[3*i+2] = int(mask_) 202 | 203 | # if mask_ > 0: 204 | # print(data[3*i], data[3*i+1],data[3*i+2]) 205 | 206 | with open(path, 'wb') as png_file: 207 | png_writer = png.Writer(width=width, height=height, bitdepth=16, compression=3, greyscale=False) 208 | png_writer.write_array(png_file, data) 209 | 210 | 211 | def ConvertMiddleburyFloToKittiPng(src_path, dest_path): 212 | width, height, u, v, mask = ReadMiddleburyFloFile(src_path) 213 | WriteKittiPngFile(dest_path, width, height, u, v, mask=mask) 214 | 215 | def ConvertKittiPngToMiddleburyFlo(src_path, dest_path): 216 | width, height, u, v, mask = ReadKittiPngFile(src_path) 217 | WriteMiddleburyFloFile(dest_path, width, height, u, v, mask=mask) 218 | 219 | 220 | def ParseFilenameKitti(filename): 221 | # Parse kitti filename (seq_frameno.xx), 222 | # return seq, frameno, ext. 
223 | # Be aware that seq might contain the dataset name (if contained as prefix) 224 | ext = filename[filename.rfind('.'):] 225 | frameno = filename[filename.rfind('_')+1:filename.rfind('.')] 226 | frameno = int(frameno) 227 | seq = filename[:filename.rfind('_')] 228 | return seq, frameno, ext 229 | 230 | 231 | def read_calib_file(filepath): 232 | """Read in a calibration file and parse into a dictionary.""" 233 | data = {} 234 | 235 | with open(filepath, 'r') as f: 236 | for line in f.readlines(): 237 | key, value = line.split(':', 1) 238 | # The only non-float values in these files are dates, which 239 | # we don't care about anyway 240 | try: 241 | data[key] = np.array([float(x) for x in value.split()]) 242 | except ValueError: 243 | pass 244 | 245 | return data 246 | 247 | def load_calib_cam_to_cam(cam_to_cam_file): 248 | # We'll return the camera calibration as a dictionary 249 | data = {} 250 | 251 | # Load and parse the cam-to-cam calibration data 252 | filedata = read_calib_file(cam_to_cam_file) 253 | 254 | # Create 3x4 projection matrices 255 | P_rect_00 = np.reshape(filedata['P_rect_00'], (3, 4)) 256 | P_rect_10 = np.reshape(filedata['P_rect_01'], (3, 4)) 257 | P_rect_20 = np.reshape(filedata['P_rect_02'], (3, 4)) 258 | P_rect_30 = np.reshape(filedata['P_rect_03'], (3, 4)) 259 | 260 | # Compute the camera intrinsics 261 | data['K_cam0'] = P_rect_00[0:3, 0:3] 262 | data['K_cam1'] = P_rect_10[0:3, 0:3] 263 | data['K_cam2'] = P_rect_20[0:3, 0:3] 264 | data['K_cam3'] = P_rect_30[0:3, 0:3] 265 | 266 | data['b00'] = P_rect_00[0, 3] / P_rect_00[0, 0] 267 | data['b10'] = P_rect_10[0, 3] / P_rect_10[0, 0] 268 | data['b20'] = P_rect_20[0, 3] / P_rect_20[0, 0] 269 | data['b30'] = P_rect_30[0, 3] / P_rect_30[0, 0] 270 | 271 | return data 272 | 273 | --------------------------------------------------------------------------------
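# Illustrative usage sketch (not from the original module; the path is a
# placeholder): loading a KITTI cam-to-cam calibration file with the helpers
# above. The returned dictionary holds the 3x3 rectified intrinsics
# K_cam0..K_cam3 and the offsets b00..b30 (P_rect[0, 3] divided by the focal
# length).
def _example_load_calib(path='dataset/kitti_scene/calib_cam_to_cam.txt'):
    calib = load_calib_cam_to_cam(path)
    K2 = calib['K_cam2']    # intrinsics of the left color camera
    offset2 = calib['b20']  # rectified horizontal offset of camera 2
    return K2, offset2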