├── .gitignore
├── FlyingChairs_train_val.txt
├── LICENSE
├── README.md
├── dataloader
│   ├── MiddleburyList.py
│   ├── MiddleburySubmit.py
│   ├── __init__.py
│   ├── chairslist.py
│   ├── flow_transforms.py
│   ├── hd1klist.py
│   ├── kitti12list.py
│   ├── kitti15list.py
│   ├── kitti15list_train.py
│   ├── kitti15list_val.py
│   ├── kitticliplist.py
│   ├── listflowfile.py
│   ├── mblist.py
│   ├── robloader.py
│   ├── sintellist.py
│   ├── sintellist_clean.py
│   ├── sintellist_final.py
│   ├── sintellist_train.py
│   ├── sintellist_val.py
│   ├── stereo_kittilist12.py
│   ├── stereo_kittilist15.py
│   └── thingslist.py
├── dataset
│   ├── IIW
│   │   ├── hawk_000299.png
│   │   └── hawk_000300.png
│   └── kitti_scene
│       └── testing
│           └── image_2
│               ├── 000042_10.png
│               └── 000042_11.png
├── eval_tmp.py
├── figs
│   ├── architecture.png
│   ├── hawk-vec.png
│   ├── hawk.png
│   ├── kitti-test-42-vec.png
│   ├── kitti-test-42.png
│   ├── output-onlinepngtools.png
│   └── time-breakdown.png
├── flops.py
├── main.py
├── models
│   ├── PWCNet.py
│   ├── VCN.py
│   ├── VCN_small.py
│   ├── __init__.py
│   ├── conv4d.py
│   └── submodule.py
├── order.txt
├── run.sh
├── run_self.sh
├── submission.py
├── thop
│   ├── __init__.py
│   ├── count_hooks.py
│   ├── profile.py
│   └── utils.py
├── utils
│   ├── __init__.py
│   ├── flowlib.py
│   ├── io.py
│   ├── logger.py
│   ├── multiscaleloss.py
│   ├── pfm.py
│   ├── readpfm.py
│   ├── sintel_io.py
│   └── util_flow.py
└── visualize.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | pwc-time
2 | vcn-time
3 | vcn-time2
4 | profile.py
5 | results/
6 | weights
7 | weights/
8 | weights_small/
9 | __pycache__/
10 | */__pycache__/
11 | iter_counts*.txt
12 | tmp/
13 | *.lprof
14 | *.pyc
15 | models/*.pyc
16 | *.npy
17 | .ipynb_checkpoints/*
18 | .ipynb_checkpoints
19 | results/*
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Carnegie Mellon University
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VCN: Volumetric correspondence networks for optical flow
2 | #### [[project website]](http://www.contrib.andrew.cmu.edu/~gengshay/neurips19flow)
3 |
4 | ![architecture](figs/architecture.png)
5 | **Requirements**
6 | - python 3.6
7 | - pytorch 1.1.0-1.3.0
8 | - [pytorch correlation module](https://github.com/gengshan-y/Pytorch-Correlation-extension) (optional) This gives around 20ms speed-up on KITTI-sized images. However, only forward pass is implemented. Please replace self.corrf() with self.corr() in models/VCN.py and models/VCN_small.py if you plan to use the faster version.
9 | - [weights files (VCN)](https://drive.google.com/drive/folders/1mgadg50ti1QdwfAf6aR2v1pCx-ITsYfE?usp=sharing) Note: You can load them without untarring the files, see [pytorch saving and loading models](https://pytorch.org/tutorials/beginner/saving_loading_models.html).
10 | - [weights files (VCN-small)](https://drive.google.com/drive/folders/16WvCL1Y5IkCoNmEEEbF_qmZPToMhw9VZ?usp=sharing)
11 |
12 | ## Pre-trained models
13 | #### To test on any two images
14 | Running [visualize.ipynb](./visualize.ipynb) gives you the following flow visualizations with color and vectors. Note: the sintel model "./weights/sintel-ft-trainval/finetune_67999.tar" is trained on multiple datasets and generalizes better than the KITTI model.
15 |
16 | ![hawk](figs/hawk.png)
17 | ![hawk-vec](figs/hawk-vec.png)
18 | ![kitti-test-42](figs/kitti-test-42.png)
19 | ![kitti-test-42-vec](figs/kitti-test-42-vec.png)
20 |
21 | #### KITTI
22 | **This corresponds to the entry on the [leaderboard](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) (Fl-all=6.30%).**
23 | ##### Evaluate on KITTI-15 benchmark
24 |
25 | To run + visualize on KITTI-15 test set,
26 | ```
27 | modelname=kitti-ft-trainval
28 | i=149999
29 | CUDA_VISIBLE_DEVICES=0 python submission.py --dataset 2015test --datapath dataset/kitti_scene/testing/ --outdir ./weights/$modelname/ --loadmodel ./weights/$modelname/finetune_$i.tar --maxdisp 512 --fac 2
30 | python eval_tmp.py --path ./weights/$modelname/ --vis yes --dataset 2015test
31 | ```
32 |
33 | ##### Evaluate on KITTI-val
34 | *To see the details of the train-val split, please scroll down to "note on train-val" and run dataloader/kitti15list_val.py, dataloader/kitti15list_train.py, dataloader/sintellist_train.py, and dataloader/sintellist_val.py.*
35 |
36 | To evaluate on the 40 validation images of KITTI-15 (frames 0, 5, ..., 195), assuming the data is at /ssd/kitti_scene, run
37 | ```
38 | modelname=kitti-ft-trainval
39 | i=149999
40 | CUDA_VISIBLE_DEVICES=0 python submission.py --dataset 2015 --datapath /ssd/kitti_scene/training/ --outdir ./weights/$modelname/ --loadmodel ./weights/$modelname/finetune_$i.tar --maxdisp 512 --fac 2
41 | python eval_tmp.py --path ./weights/$modelname/ --vis no --dataset 2015
42 | ```
43 |
44 | To evaluate + visualize on KITTI-15 validation set,
45 | ```
46 | python eval_tmp.py --path ./weights/$modelname/ --vis yes --dataset 2015
47 | ```
48 | Evaluation error on the 40 validation images: Fl-err = 3.9, EPE = 1.144
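
For reference, EPE is the mean Euclidean distance between predicted and ground-truth flow over valid pixels, and KITTI's Fl counts a pixel as an outlier when its error exceeds both 3 px and 5% of the ground-truth flow magnitude. A minimal sketch of the two metrics (array shapes are assumptions, not eval_tmp.py's exact code):
```
import numpy as np

def epe_and_fl(pred, gt, valid):
    # pred, gt: [H, W, 2] flow fields; valid: [H, W] boolean mask
    err = np.linalg.norm(pred - gt, axis=-1)[valid]
    mag = np.linalg.norm(gt, axis=-1)[valid]
    # EPE in pixels; Fl as a percentage of valid pixels
    return err.mean(), 100. * ((err > 3.) & (err > 0.05 * mag)).mean()
```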
49 |
50 | #### Sintel
51 | **This corresponds to the entry on the [leaderboard](http://sintel.is.tue.mpg.de/quant?metric_id=0&selected_pass=0) (EPE-all-final = 4.404, EPE-all-clean = 2.808).**
52 | ##### Evaluate on Sintel-val
53 |
54 | To evaluate on Sintel validation set,
55 | ```
56 | modelname=sintel-ft-trainval
57 | i=67999
58 | CUDA_VISIBLE_DEVICES=0 python submission.py --dataset sintel --datapath /ssd/rob_flow/training/ --outdir ./weights/$modelname/ --loadmodel ./weights/$modelname/finetune_$i.tar --maxdisp 448 --fac 1.4
59 | python eval_tmp.py --path ./weights/$modelname/ --vis no --dataset sintel
60 | ```
61 | Evaluation error on the Sintel validation images: Fl-err = 7.9, EPE = 2.351
62 |
63 |
64 | ## Train the model
65 | We follow the same stage-wise training procedure as prior work, Chairs->Things->KITTI or Chairs->Things->Sintel, but use far fewer iterations.
66 | If you plan to train the model and reproduce the numbers, please check out our [supplementary material](https://papers.nips.cc/paper/8367-volumetric-correspondence-networks-for-optical-flow) for the differences in hyper-parameters with FlowNet2 and PWCNet.
67 |
68 | #### Pretrain on flying chairs and flying things
69 | Make sure you have downloaded [flying chairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html)
70 | and [flying things **subset**](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html),
71 | and placed them under the same folder, say /ssd/.
72 |
73 | To first train on flying chairs for 140k iterations with a batch size of 8, run (assuming you have two GPUs)
74 | ```
75 | CUDA_VISIBLE_DEVICES=0,1 python main.py --maxdisp 256 --fac 1 --database /ssd/ --logname chairs-0 --savemodel /data/ptmodel/ --epochs 1000 --stage chairs --ngpus 2
76 | ```
77 | Then fine-tune on flying things for 80k iterations with a batch size of 8, resuming from your pre-trained model or from ours:
78 | ```
79 | CUDA_VISIBLE_DEVICES=0,1 python main.py --maxdisp 256 --fac 1 --database /ssd/ --logname things-0 --savemodel /data/ptmodel/ --epochs 1000 --stage things --ngpus 2 --loadmodel ./weights/charis/finetune_141999.tar --retrain false
80 | ```
81 | Note that to resume the iteration count, put the iteration to start from in iter_counts-(your suffix).txt. In this example, put 141999 in iter_counts-0.txt.
82 | Be aware that the program reads and writes iter_counts-(suffix).txt at training time, so use a different suffix for each run when multiple training programs are running at the same time.
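
A minimal sketch of seeding that counter before resuming (the suffix "0" here matches the trailing "-0" of --logname above):
```
# start counting from iteration 141999 for runs using suffix "0"
with open('iter_counts-0.txt', 'w') as f:
    f.write('141999')
```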
83 |
84 | #### Finetune on KITTI / Sintel
85 | Please first download the kitti 2012/2015 flow dataset if you want to fine-tune on kitti.
86 | Download [rob_devkit](http://www.cvlibs.net:3000/ageiger/rob_devkit/src/flow/flow) if you want to fine-tune on sintel.
87 |
88 | To fine-tune on KITTI with a batch size of 16, run
89 | ```
90 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --maxdisp 512 --fac 2 --database /ssd/ --logname kitti-trainval-0 --savemodel /data/ptmodel/ --epochs 1000 --stage 2015trainval --ngpus 4 --loadmodel ./weights/things/finetune_211999.tar --retrain true
91 | ```
92 | To fine-tune on Sintel with a batch size of 16, run
93 | ```
94 | CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --maxdisp 448 --fac 1.4 --database /ssd/ --logname sintel-trainval-0 --savemodel /data/ptmodel/ --epochs 1000 --stage sinteltrainval --ngpus 4 --loadmodel ./weights/things/finetune_239999.tar --retrain true
95 | ```
96 |
97 | #### Note on train-val
98 | - To tune hyper-parameters, we use a train-val split for KITTI and Sintel, which is not covered by the
99 | above procedure.
100 | - For KITTI we use every 5th image in the training set (0, 5, 10, ..., 195) for validation and the rest for training; for Sintel, we manually select several sequences for validation (see the snippet after this list).
101 | - If you plan to use our split, put "--stage 2015train" or "--stage sinteltrain" for training.
102 | - The numbers in Tab. 3 of the paper are computed on the whole train-val set (all the data with ground truth).
103 | - You might find run.sh helpful to run evaluation on KITTI/Sintel.
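
For reference, the KITTI val selection in dataloader/kitti15list_val.py boils down to a modulo-5 filter on the frame index (the 'image_2' path below is illustrative):
```
import os

frames = sorted(f for f in os.listdir('image_2') if '_10' in f)
val   = [f for f in frames if int(f.split('_')[0]) % 5 == 0]  # 0, 5, ..., 195
train = [f for f in frames if int(f.split('_')[0]) % 5 != 0]
```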
104 |
105 | ## Measure FLOPS
106 | ```
107 | python flops.py
108 | ```
109 | gives
110 |
111 | PWCNet: flops(G)/params(M):90.8/9.37
112 |
113 | VCN: flops(G)/params(M):96.5/6.23
114 |
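For a rough idea of how such numbers are produced with the bundled thop op counter, here is a minimal sketch; the stand-in model is hypothetical, and flops.py presumably applies the same kind of profiling to VCN and PWCNet:
```
import torch
import torch.nn as nn
from thop import profile  # op counter bundled in this repo

# stand-in two-layer model; profile VCN/PWCNet to reproduce the numbers above
net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.Conv2d(16, 2, 3, padding=1))
flops, params = profile(net, inputs=(torch.randn(1, 3, 256, 256),))
print('flops(G)/params(M): %.1f/%.2f' % (flops / 1e9, params / 1e6))
```
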
115 | #### Note on inference time
116 | The current implementation runs at 180ms/pair on KITTI-sized images at inference time.
117 | A rough breakdown of the running time is: feature extraction - 4.9%, feature correlation - 8.7%, separable 4D convolutions - 56%, truncated soft-argmin (soft winner-take-all) - 20%, and hypotheses fusion - 9.5%.
118 | A detailed breakdown is shown below in the form "name-level percentage".
119 |
120 | ![time breakdown](figs/time-breakdown.png)
121 |
122 | Note that separable 4D convolutions use fewer FLOPS than the 2D convolutions (i.e., the feature extraction module plus the hypotheses fusion module, 47.8 vs. 53.3 GFLOPS)
123 | but take 4x more time (56% vs. 14.4%). One reason might be that PyTorch (like other packages) is better optimized for networks with many feature channels than for those with large spatial extents at the same FLOP count. This might be fixed at the conv kernel / hardware level.
124 |
125 | Besides, the truncated soft-argmin is implemented with 3D max pooling, which is inefficient and takes more time than expected.
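
For intuition, here is a minimal sketch of a truncated soft-argmin over a cost volume; the shapes, names, and the 1D max-pool truncation are illustrative assumptions, not the repo's exact 3D max-pooling implementation:
```
import torch
import torch.nn.functional as F

def truncated_soft_argmin(cost, flow_cands, keep=9, temperature=1.0):
    # cost: [B, D, H, W] matching scores over D flow hypotheses
    # flow_cands: [D, 2] candidate flow vector per hypothesis
    B, D, H, W = cost.shape
    # truncate: keep hypotheses near per-pixel maxima, found with a
    # max pool over the hypothesis dimension (keep must be odd)
    flat = cost.permute(0, 2, 3, 1).reshape(B * H * W, 1, D)
    pooled = F.max_pool1d(flat, kernel_size=keep, stride=1, padding=keep // 2)
    pooled = pooled.reshape(B, H, W, D).permute(0, 3, 1, 2)
    mask = (cost == pooled).float()                  # 1 near local maxima
    logits = cost / temperature + (mask - 1.) * 1e9  # suppress the rest
    weights = F.softmax(logits, dim=1)               # [B, D, H, W]
    return torch.einsum('bdhw,dc->bchw', weights, flow_cands)  # [B, 2, H, W]
```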
126 |
127 | ## Acknowledgement
128 | Thanks [ClementPinard](https://github.com/ClementPinard), [Lyken17](https://github.com/Lyken17), [NVlabs](https://github.com/NVlabs) and many others for open-sourcing their code.
129 | - Pytorch op counter thop is modified from [pytorch-OpCounter](https://github.com/Lyken17/pytorch-OpCounter)
130 | - Correlation module is modified from [Pytorch-Correlation-extension](https://github.com/ClementPinard/Pytorch-Correlation-extension)
131 | - Full 4D convolution is taken from [NCNet](https://github.com/ignacio-rocco/ncnet), but is not used for our model (only used in Ablation study).
132 |
133 | ## Citation
134 | ```
135 | @inproceedings{yang2019vcn,
136 | title={Volumetric Correspondence Networks for Optical Flow},
137 | author={Yang, Gengshan and Ramanan, Deva},
138 | booktitle={NeurIPS},
139 | year={2019}
140 | }
141 | ```
142 |
--------------------------------------------------------------------------------
/dataloader/MiddleburyList.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import glob
3 | import pdb
4 | from PIL import Image
5 | import os
6 | import os.path
7 | import numpy as np
8 |
9 |
10 | def dataloader(filepath,res='Q'):
11 | filepath = '%s/training%s'%(filepath,res)
12 | img_list = [i.split('/')[-1] for i in glob.glob('%s/*'%(filepath)) if os.path.isdir(i)]
13 |
14 | left_train = ['%s/%s/im0.png'% (filepath,img) for img in img_list]
15 | right_train = ['%s/%s/im1.png'% (filepath,img) for img in img_list]
16 | disp_train_L = ['%s/%s/disp0GT.pfm' % (filepath,img) for img in img_list]
17 | disp_train_R = ['%s/%s/disp1GT.pfm' % (filepath,img) for img in img_list]
18 |
19 | return left_train, right_train, disp_train_L, left_train, right_train, disp_train_R
20 |
21 |
--------------------------------------------------------------------------------
/dataloader/MiddleburySubmit.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | import pdb
4 | from PIL import Image
5 | import os
6 | import os.path
7 | import numpy as np
8 | import glob
9 |
10 |
11 | def dataloader(filepath):
12 | img_list = ['Adirondack', 'Motorcycle', 'PianoL',
13 | 'Playtable', 'Shelves', 'ArtL',
14 | 'MotorcycleE', 'Pipes', 'PlaytableP',
15 | 'Teddy', 'Jadeplant', 'Piano',
16 | 'Playroom', 'Recycle', 'Vintage']
17 |
18 | img_list = ['Australia', 'Bicycle2', 'Classroom2E',
19 | 'Crusade', 'Djembe', 'Hoops', 'Newkuba',
20 | 'Staircase', 'AustraliaP', 'Classroom2', 'Computer',
21 | 'CrusadeP', 'DjembeL', 'Livingroom', 'Plants']
22 |
23 | img_list = [i.split('/')[-1] for i in glob.glob('%s/*'%filepath) if os.path.isdir(i)] # overrides the two hard-coded lists above
24 | #img_list *= 10
25 |
26 | left_train = ['%s/%s/im0.png'% (filepath,img) for img in img_list]
27 | right_train = ['%s/%s/im1.png'% (filepath,img) for img in img_list]
28 |
29 |
30 | return left_train, right_train, left_train
31 |
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/dataloader/__init__.py
--------------------------------------------------------------------------------
/dataloader/chairslist.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import glob
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 | l0_train = []
20 | l1_train = []
21 | flow_train = []
22 | for flow_map in sorted(glob.glob(os.path.join(filepath,'*_flow.flo'))):
23 | root_filename = flow_map[:-9]  # strip the 9-character suffix '_flow.flo'
24 | img1 = root_filename+'_img1.ppm'
25 | img2 = root_filename+'_img2.ppm'
26 | if not (os.path.isfile(os.path.join(filepath,img1)) and os.path.isfile(os.path.join(filepath,img2))):
27 | continue
28 |
29 | l0_train.append(img1)
30 | l1_train.append(img2)
31 | flow_train.append(flow_map)
32 |
33 | return l0_train, l1_train, flow_train
34 |
--------------------------------------------------------------------------------
/dataloader/flow_transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import random
4 | import numpy as np
5 | import numbers
6 | import types
7 | import scipy.ndimage as ndimage
8 | import pdb
9 | import torchvision
10 | import PIL.Image as Image
11 | import cv2
12 | from torch.nn import functional as F
13 |
14 |
15 | class Compose(object):
16 | """ Composes several co_transforms together.
17 | For example:
18 | >>> co_transforms.Compose([
19 | >>> co_transforms.CenterCrop(10),
20 | >>> co_transforms.ToTensor(),
21 | >>> ])
22 | """
23 |
24 | def __init__(self, co_transforms):
25 | self.co_transforms = co_transforms
26 |
27 | def __call__(self, input, target):
28 | for t in self.co_transforms:
29 | input,target = t(input,target)
30 | return input,target
31 |
32 |
33 | class Scale(object):
34 | """ Rescales the inputs and target arrays to the given 'size'.
35 | 'size' will be the size of the smaller edge.
36 | For example, if height > width, then image will be
37 | rescaled to (size * height / width, size)
38 | size: size of the smaller edge
39 | interpolation order: Default: 2 (bilinear)
40 | """
41 |
42 | def __init__(self, size, order=1):
43 | self.ratio = size
44 | self.order = order
45 | if order==0:
46 | self.code=cv2.INTER_NEAREST
47 | elif order==1:
48 | self.code=cv2.INTER_LINEAR
49 | elif order==2:
50 | self.code=cv2.INTER_CUBIC
51 |
52 | def __call__(self, inputs, target):
53 | if self.ratio==1:
54 | return inputs, target
55 | h, w, _ = inputs[0].shape
56 | ratio = self.ratio
57 |
58 | inputs[0] = cv2.resize(inputs[0], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
59 | inputs[1] = cv2.resize(inputs[1], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_LINEAR)
60 | # keep the mask same
61 | tmp = cv2.resize(target[:,:,2], None, fx=ratio,fy=ratio,interpolation=cv2.INTER_NEAREST)
62 | target = cv2.resize(target, None, fx=ratio,fy=ratio,interpolation=self.code) * ratio
63 | target[:,:,2] = tmp
64 |
65 |
66 | return inputs, target
67 |
68 |
69 |
70 |
71 | class SpatialAug(object):
72 | def __init__(self, crop, scale=None, rot=None, trans=None, squeeze=None, schedule_coeff=1, order=1, black=False):
73 | self.crop = crop
74 | self.scale = scale
75 | self.rot = rot
76 | self.trans = trans
77 | self.squeeze = squeeze
78 | self.t = np.zeros(6)  # affine transform: x' = t0*x + t2*y + t4, y' = t1*x + t3*y + t5
79 | self.schedule_coeff = schedule_coeff
80 | self.order = order
81 | self.black = black
82 |
83 | def to_identity(self):
84 | self.t[0] = 1; self.t[2] = 0; self.t[4] = 0; self.t[1] = 0; self.t[3] = 1; self.t[5] = 0;
85 |
86 | def left_multiply(self, u0, u1, u2, u3, u4, u5):
87 | result = np.zeros(6)
88 | result[0] = self.t[0]*u0 + self.t[1]*u2;
89 | result[1] = self.t[0]*u1 + self.t[1]*u3;
90 |
91 | result[2] = self.t[2]*u0 + self.t[3]*u2;
92 | result[3] = self.t[2]*u1 + self.t[3]*u3;
93 |
94 | result[4] = self.t[4]*u0 + self.t[5]*u2 + u4;
95 | result[5] = self.t[4]*u1 + self.t[5]*u3 + u5;
96 | self.t = result
97 |
98 | def inverse(self):
99 | result = np.zeros(6)
100 | a = self.t[0]; c = self.t[2]; e = self.t[4];
101 | b = self.t[1]; d = self.t[3]; f = self.t[5];
102 |
103 | denom = a*d - b*c;
104 |
105 | result[0] = d / denom;
106 | result[1] = -b / denom;
107 | result[2] = -c / denom;
108 | result[3] = a / denom;
109 | result[4] = (c*f-d*e) / denom;
110 | result[5] = (b*e-a*f) / denom;
111 |
112 | return result
113 |
114 | def grid_transform(self, meshgrid, t, normalize=True, gridsize=None):
115 | if gridsize is None:
116 | h, w = meshgrid[0].shape
117 | else:
118 | h, w = gridsize
119 | vgrid = torch.cat([(meshgrid[0] * t[0] + meshgrid[1] * t[2] + t[4])[:,:,np.newaxis],
120 | (meshgrid[0] * t[1] + meshgrid[1] * t[3] + t[5])[:,:,np.newaxis]],-1)
121 | if normalize:
122 | vgrid[:,:,0] = 2.0*vgrid[:,:,0]/max(w-1,1)-1.0
123 | vgrid[:,:,1] = 2.0*vgrid[:,:,1]/max(h-1,1)-1.0
124 | return vgrid
125 |
126 |
127 | def __call__(self, inputs, target):
128 | h, w, _ = inputs[0].shape
129 | th, tw = self.crop
130 | meshgrid = torch.meshgrid([torch.Tensor(range(th)), torch.Tensor(range(tw))])[::-1]
131 | cornergrid = torch.meshgrid([torch.Tensor([0,th-1]), torch.Tensor([0,tw-1])])[::-1]
132 |
133 | for i in range(50):
134 | # im0
135 | self.to_identity()
136 | #TODO add mirror
137 | if np.random.binomial(1,0.5):
138 | mirror = True
139 | else:
140 | mirror = False
141 | ##TODO
142 | #mirror = False
143 | if mirror:
144 | self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th);
145 | else:
146 | self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
147 | scale0 = 1; scale1 = 1; squeeze0 = 1; squeeze1 = 1;
148 | if not self.rot is None:
149 | rot0 = np.random.uniform(-self.rot[0],+self.rot[0])
150 | rot1 = np.random.uniform(-self.rot[1]*self.schedule_coeff, self.rot[1]*self.schedule_coeff) + rot0
151 | self.left_multiply(np.cos(rot0), np.sin(rot0), -np.sin(rot0), np.cos(rot0), 0, 0)
152 | if not self.trans is None:
153 | trans0 = np.random.uniform(-self.trans[0],+self.trans[0], 2)
154 | trans1 = np.random.uniform(-self.trans[1]*self.schedule_coeff,+self.trans[1]*self.schedule_coeff, 2) + trans0
155 | self.left_multiply(1, 0, 0, 1, trans0[0] * tw, trans0[1] * th)
156 | if not self.squeeze is None:
157 | squeeze0 = np.exp(np.random.uniform(-self.squeeze[0], self.squeeze[0]))
158 | squeeze1 = np.exp(np.random.uniform(-self.squeeze[1]*self.schedule_coeff, self.squeeze[1]*self.schedule_coeff)) * squeeze0
159 | if not self.scale is None:
160 | scale0 = np.exp(np.random.uniform(self.scale[2]-self.scale[0], self.scale[2]+self.scale[0]))
161 | scale1 = np.exp(np.random.uniform(-self.scale[1]*self.schedule_coeff, self.scale[1]*self.schedule_coeff)) * scale0
162 | self.left_multiply(1.0/(scale0*squeeze0), 0, 0, 1.0/(scale0/squeeze0), 0, 0)
163 |
164 | self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
165 | transmat0 = self.t.copy()
166 |
167 | # im1
168 | self.to_identity()
169 | if mirror:
170 | self.left_multiply(-1, 0, 0, 1, .5 * tw, -.5 * th);
171 | else:
172 | self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
173 | if not self.rot is None:
174 | self.left_multiply(np.cos(rot1), np.sin(rot1), -np.sin(rot1), np.cos(rot1), 0, 0)
175 | if not self.trans is None:
176 | self.left_multiply(1, 0, 0, 1, trans1[0] * tw, trans1[1] * th)
177 | self.left_multiply(1.0/(scale1*squeeze1), 0, 0, 1.0/(scale1/squeeze1), 0, 0)
178 | self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
179 | transmat1 = self.t.copy()
180 | transmat1_inv = self.inverse()
181 |
182 | if self.black:
183 | # black augmentation, allowing 0 values in the input images
184 | # https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/black_augmentation_layer.cu
185 | break
186 | else:
187 | if ((self.grid_transform(cornergrid, transmat0, gridsize=[float(h),float(w)]).abs()>1).sum() +\
188 | (self.grid_transform(cornergrid, transmat1, gridsize=[float(h),float(w)]).abs()>1).sum()) == 0:
189 | break
190 | if i==49:
191 | print('max_iter in augmentation')
192 | self.to_identity()
193 | self.left_multiply(1, 0, 0, 1, -.5 * tw, -.5 * th);
194 | self.left_multiply(1, 0, 0, 1, .5 * w, .5 * h);
195 | transmat0 = self.t.copy()
196 | transmat1 = self.t.copy()
197 |
198 | # do the real work
199 | vgrid = self.grid_transform(meshgrid, transmat0,gridsize=[float(h),float(w)])
200 | inputs_0 = F.grid_sample(torch.Tensor(inputs[0]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
201 | if self.order == 0:
202 | target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
203 | else:
204 | target_0 = F.grid_sample(torch.Tensor(target).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
205 |
206 | mask_0 = target[:,:,2:3].copy(); mask_0[mask_0==0]=np.nan
207 | if self.order == 0:
208 | mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis], mode='nearest')[0].permute(1,2,0)
209 | else:
210 | mask_0 = F.grid_sample(torch.Tensor(mask_0).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
211 | mask_0[torch.isnan(mask_0)] = 0
212 |
213 |
214 | vgrid = self.grid_transform(meshgrid, transmat1,gridsize=[float(h),float(w)])
215 | inputs_1 = F.grid_sample(torch.Tensor(inputs[1]).permute(2,0,1)[np.newaxis], vgrid[np.newaxis])[0].permute(1,2,0)
216 |
217 | # flow
218 | pos = target_0[:,:,:2] + self.grid_transform(meshgrid, transmat0,normalize=False)
219 | pos = self.grid_transform(pos.permute(2,0,1),transmat1_inv,normalize=False)
220 | if target_0.shape[2]>=4:
221 | # scale
222 | exp = target_0[:,:,3:] * scale1 / scale0
223 | target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
224 | (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
225 | mask_0,
226 | exp], -1)
227 | else:
228 | target = torch.cat([ (pos[:,:,0] - meshgrid[0]).unsqueeze(-1),
229 | (pos[:,:,1] - meshgrid[1]).unsqueeze(-1),
230 | mask_0], -1)
231 | # target_0[:,:,2].unsqueeze(-1) ], -1)
232 | inputs = [np.asarray(inputs_0), np.asarray(inputs_1)]
233 | target = np.asarray(target)
234 |
235 | return inputs,target
236 |
237 |
238 | class pseudoPCAAug(object):
239 | """
240 | Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
241 | This version is faster.
242 | """
243 | def __init__(self, schedule_coeff=1):
244 | self.augcolor = torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.5, hue=0.5/3.14)
245 |
246 | def __call__(self, inputs, target):
247 | inputs[0] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[0]*255))))/255.
248 | inputs[1] = np.asarray(self.augcolor(Image.fromarray(np.uint8(inputs[1]*255))))/255.
249 | return inputs,target
250 |
251 |
252 | class PCAAug(object):
253 | """
254 | Chromatic Eigen Augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
255 | """
256 | def __init__(self, lmult_pow =[0.4, 0,-0.2],
257 | lmult_mult =[0.4, 0,0, ],
258 | lmult_add =[0.03,0,0, ],
259 | sat_pow =[0.4, 0,0, ],
260 | sat_mult =[0.5, 0,-0.3],
261 | sat_add =[0.03,0,0, ],
262 | col_pow =[0.4, 0,0, ],
263 | col_mult =[0.2, 0,0, ],
264 | col_add =[0.02,0,0, ],
265 | ladd_pow =[0.4, 0,0, ],
266 | ladd_mult =[0.4, 0,0, ],
267 | ladd_add =[0.04,0,0, ],
268 | col_rotate =[1., 0,0, ],
269 | schedule_coeff=1):
270 | # no mean
271 | self.pow_nomean = [1,1,1]
272 | self.add_nomean = [0,0,0]
273 | self.mult_nomean = [1,1,1]
274 | self.pow_withmean = [1,1,1]
275 | self.add_withmean = [0,0,0]
276 | self.mult_withmean = [1,1,1]
277 | self.lmult_pow = 1
278 | self.lmult_mult = 1
279 | self.lmult_add = 0
280 | self.col_angle = 0
281 | if not ladd_pow is None:
282 | self.pow_nomean[0] =np.exp(np.random.normal(ladd_pow[2], ladd_pow[0]))
283 | if not col_pow is None:
284 | self.pow_nomean[1] =np.exp(np.random.normal(col_pow[2], col_pow[0]))
285 | self.pow_nomean[2] =np.exp(np.random.normal(col_pow[2], col_pow[0]))
286 |
287 | if not ladd_add is None:
288 | self.add_nomean[0] =np.random.normal(ladd_add[2], ladd_add[0])
289 | if not col_add is None:
290 | self.add_nomean[1] =np.random.normal(col_add[2], col_add[0])
291 | self.add_nomean[2] =np.random.normal(col_add[2], col_add[0])
292 |
293 | if not ladd_mult is None:
294 | self.mult_nomean[0] =np.exp(np.random.normal(ladd_mult[2], ladd_mult[0]))
295 | if not col_mult is None:
296 | self.mult_nomean[1] =np.exp(np.random.normal(col_mult[2], col_mult[0]))
297 | self.mult_nomean[2] =np.exp(np.random.normal(col_mult[2], col_mult[0]))
298 |
299 | # with mean
300 | if not sat_pow is None:
301 | self.pow_withmean[1] =np.exp(np.random.uniform(sat_pow[2]-sat_pow[0], sat_pow[2]+sat_pow[0]))
302 | self.pow_withmean[2] =self.pow_withmean[1]
303 | if not sat_add is None:
304 | self.add_withmean[1] =np.random.uniform(sat_add[2]-sat_add[0], sat_add[2]+sat_add[0])
305 | self.add_withmean[2] =self.add_withmean[1]
306 | if not sat_mult is None:
307 | self.mult_withmean[1] = np.exp(np.random.uniform(sat_mult[2]-sat_mult[0], sat_mult[2]+sat_mult[0]))
308 | self.mult_withmean[2] = self.mult_withmean[1]
309 |
310 | if not lmult_pow is None:
311 | self.lmult_pow = np.exp(np.random.uniform(lmult_pow[2]-lmult_pow[0], lmult_pow[2]+lmult_pow[0]))
312 | if not lmult_mult is None:
313 | self.lmult_mult= np.exp(np.random.uniform(lmult_mult[2]-lmult_mult[0], lmult_mult[2]+lmult_mult[0]))
314 | if not lmult_add is None:
315 | self.lmult_add = np.random.uniform(lmult_add[2]-lmult_add[0], lmult_add[2]+lmult_add[0])
316 | if not col_rotate is None:
317 | self.col_angle= np.random.uniform(col_rotate[2]-col_rotate[0], col_rotate[2]+col_rotate[0])
318 |
319 | # eigen vectors
320 | self.eigvec = np.reshape([0.51,0.56,0.65,0.79,0.01,-0.62,0.35,-0.83,0.44],[3,3]).transpose()
321 |
322 |
323 | def __call__(self, inputs, target):
324 | inputs[0] = self.pca_image(inputs[0])
325 | inputs[1] = self.pca_image(inputs[1])
326 | return inputs,target
327 |
328 | def pca_image(self, rgb):
329 | eig = np.dot(rgb, self.eigvec)
330 | max_rgb = np.clip(rgb,0,np.inf).max((0,1))
331 | min_rgb = rgb.min((0,1))
332 | mean_rgb = rgb.mean((0,1))
333 | max_abs_eig = np.abs(eig).max((0,1))
334 | max_l = np.sqrt(np.sum(max_abs_eig*max_abs_eig))
335 | mean_eig = np.dot(mean_rgb, self.eigvec)
336 |
337 | # no-mean stuff
338 | eig -= mean_eig[np.newaxis, np.newaxis]
339 |
340 | for c in range(3):
341 | if max_abs_eig[c] > 1e-2:
342 | mean_eig[c] /= max_abs_eig[c]
343 | eig[:,:,c] = eig[:,:,c] / max_abs_eig[c];
344 | eig[:,:,c] = np.power(np.abs(eig[:,:,c]),self.pow_nomean[c]) *\
345 | ((eig[:,:,c] > 0) -0.5)*2
346 | eig[:,:,c] = eig[:,:,c] + self.add_nomean[c]
347 | eig[:,:,c] = eig[:,:,c] * self.mult_nomean[c]
348 | eig += mean_eig[np.newaxis,np.newaxis]
349 |
350 | # withmean stuff
351 | if max_abs_eig[0] > 1e-2:
352 | eig[:,:,0] = np.power(np.abs(eig[:,:,0]),self.pow_withmean[0]) * \
353 | ((eig[:,:,0]>0)-0.5)*2;
354 | eig[:,:,0] = eig[:,:,0] + self.add_withmean[0];
355 | eig[:,:,0] = eig[:,:,0] * self.mult_withmean[0];
356 |
357 | s = np.sqrt(eig[:,:,1]*eig[:,:,1] + eig[:,:,2] * eig[:,:,2])
358 | smask = s > 1e-2
359 | s1 = np.power(s, self.pow_withmean[1]);
360 | s1 = np.clip(s1 + self.add_withmean[1], 0,np.inf)
361 | s1 = s1 * self.mult_withmean[1]
362 | s1 = s1 * smask + s*(1-smask)
363 |
364 | # color angle
365 | if self.col_angle!=0:
366 | temp1 = np.cos(self.col_angle) * eig[:,:,1] - np.sin(self.col_angle) * eig[:,:,2]
367 | temp2 = np.sin(self.col_angle) * eig[:,:,1] + np.cos(self.col_angle) * eig[:,:,2]
368 | eig[:,:,1] = temp1
369 | eig[:,:,2] = temp2
370 |
371 | # to origin magnitude
372 | for c in range(3):
373 | if max_abs_eig[c] > 1e-2:
374 | eig[:,:,c] = eig[:,:,c] * max_abs_eig[c]
375 |
376 | if max_l > 1e-2:
377 | l1 = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
378 | l1 = l1 / max_l
379 |
380 | eig[:,:,1][smask] = (eig[:,:,1] / s * s1)[smask]
381 | eig[:,:,2][smask] = (eig[:,:,2] / s * s1)[smask]
382 | #eig[:,:,1] = (eig[:,:,1] / s * s1) * smask + eig[:,:,1] * (1-smask)
383 | #eig[:,:,2] = (eig[:,:,2] / s * s1) * smask + eig[:,:,2] * (1-smask)
384 |
385 | if max_l > 1e-2:
386 | l = np.sqrt(eig[:,:,0]*eig[:,:,0] + eig[:,:,1]*eig[:,:,1] + eig[:,:,2]*eig[:,:,2])
387 | l1 = np.power(l1, self.lmult_pow)
388 | l1 = np.clip(l1 + self.lmult_add, 0, np.inf)
389 | l1 = l1 * self.lmult_mult
390 | l1 = l1 * max_l
391 | lmask = l > 1e-2
392 | eig[lmask] = (eig / l[:,:,np.newaxis] * l1[:,:,np.newaxis])[lmask]
393 | for c in range(3):
394 | eig[:,:,c][lmask] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c]))[lmask]
395 | # for c in range(3):
396 | # # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask] * lmask + eig[:,:,c] * (1-lmask)
397 | # eig[:,:,c][lmask] = (eig[:,:,c] / l * l1)[lmask]
398 | # eig[:,:,c] = (np.clip(eig[:,:,c], -np.inf, max_abs_eig[c])) * lmask + eig[:,:,c] * (1-lmask)
399 |
400 | return np.clip(np.dot(eig, self.eigvec.transpose()), 0, 1)
401 |
402 |
403 | class ChromaticAug(object):
404 | """
405 | Chromatic augmentation: https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
406 | """
407 | def __init__(self, noise = 0.06,
408 | gamma = 0.02,
409 | brightness = 0.02,
410 | contrast = 0.02,
411 | color = 0.02,
412 | schedule_coeff=1):
413 |
414 | self.noise = np.random.uniform(0,noise)
415 | self.gamma = np.exp(np.random.normal(0, gamma*schedule_coeff))
416 | self.brightness = np.random.normal(0, brightness*schedule_coeff)
417 | self.contrast = np.exp(np.random.normal(0, contrast*schedule_coeff))
418 | self.color = np.exp(np.random.normal(0, color*schedule_coeff,3))
419 |
420 | def __call__(self, inputs, target):
421 | inputs[1] = self.chrom_aug(inputs[1])
422 | # noise
423 | inputs[0]+=np.random.normal(0, self.noise, inputs[0].shape)
424 | inputs[1]+=np.random.normal(0, self.noise, inputs[0].shape)
425 | return inputs,target
426 |
427 | def chrom_aug(self, rgb):
428 | # color change
429 | mean_in = rgb.sum(-1)
430 | rgb = rgb*self.color[np.newaxis,np.newaxis]
431 | brightness_coeff = mean_in / (rgb.sum(-1)+0.01)
432 | rgb = np.clip(rgb*brightness_coeff[:,:,np.newaxis],0,1)
433 | # gamma
434 | rgb = np.power(rgb,self.gamma)
435 | # brightness
436 | rgb += self.brightness
437 | # contrast
438 | rgb = 0.5 + ( rgb-0.5)*self.contrast
439 | rgb = np.clip(rgb, 0, 1)
440 | return rgb
441 |
--------------------------------------------------------------------------------
/dataloader/hd1klist.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import pdb
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | left_fold = 'image_2/'
21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('HD1K2018') > -1]
22 | train = sorted(train)
23 |
24 | l0_train = [filepath+left_fold+img for img in train]
25 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
26 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
27 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
28 |
29 | return l0_train, l1_train, flow_train
30 |
--------------------------------------------------------------------------------
/dataloader/kitti12list.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'colored_0/'
20 | flow_noc = 'flow_occ/'
21 |
22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23 |
24 | l0_train = [filepath+left_fold+img for img in train]
25 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
26 | flow_train = [filepath+flow_noc+img for img in train]
27 |
28 |
29 | return l0_train, l1_train, flow_train
30 |
--------------------------------------------------------------------------------
/dataloader/kitti15list.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'image_2/'
20 | flow_noc = 'flow_occ/'
21 |
22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23 |
24 | l0_train = [filepath+left_fold+img for img in train]
25 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
26 | flow_train = [filepath+flow_noc+img for img in train]
27 |
28 |
29 | return sorted(l0_train), sorted(l1_train), sorted(flow_train)
30 |
--------------------------------------------------------------------------------
/dataloader/kitti15list_train.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'image_2/'
20 | flow_noc = 'flow_occ/'
21 |
22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23 |
24 | train = [i for i in train if int(i.split('_')[0])%5!=0]
25 |
26 | l0_train = [filepath+left_fold+img for img in train]
27 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
28 | flow_train = [filepath+flow_noc+img for img in train]
29 |
30 |
31 | return sorted(l0_train), sorted(l1_train), sorted(flow_train)
32 |
--------------------------------------------------------------------------------
/dataloader/kitti15list_val.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'image_2/'
20 | flow_noc = 'flow_occ/'
21 |
22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23 |
24 | train = [i for i in train if int(i.split('_')[0])%5==0]
25 |
26 | l0_train = [filepath+left_fold+img for img in train]
27 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
28 | flow_train = [filepath+flow_noc+img for img in train]
29 |
30 |
31 | return sorted(l0_train), sorted(l1_train), sorted(flow_train)
32 |
--------------------------------------------------------------------------------
/dataloader/kitticliplist.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import glob
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | train = [img for img in sorted(glob.glob('%s*'%filepath))]
21 |
22 | l0_train = train[:-1]
23 | l1_train = train[1:]
24 |
25 |
26 | return sorted(l0_train), sorted(l1_train), sorted(l0_train)
27 |
--------------------------------------------------------------------------------
/dataloader/listflowfile.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 |
7 | IMG_EXTENSIONS = [
8 | '.jpg', '.JPG', '.jpeg', '.JPEG',
9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
10 | ]
11 |
12 |
13 | def is_image_file(filename):
14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
15 |
16 | def dataloader(filepath):
17 |
18 | classes = [d for d in os.listdir(filepath) if os.path.isdir(os.path.join(filepath, d))]
19 | image = [img for img in classes if img.find('frames_cleanpass') > -1]
20 | disp = [dsp for dsp in classes if dsp.find('disparity') > -1]
21 |
22 | monkaa_path = filepath + [x for x in image if 'monkaa' in x][0]
23 | monkaa_disp = filepath + [x for x in disp if 'monkaa' in x][0]
24 |
25 |
26 | monkaa_dir = os.listdir(monkaa_path)
27 |
28 | all_left_img=[]
29 | all_right_img=[]
30 | all_left_disp = []
31 | all_right_disp = []
32 | test_left_img=[]
33 | test_right_img=[]
34 | test_left_disp = []
35 | test_right_disp = []
36 |
37 |
38 | for dd in monkaa_dir:
39 | for im in os.listdir(monkaa_path+'/'+dd+'/left/'):
40 | if is_image_file(monkaa_path+'/'+dd+'/left/'+im):
41 | all_left_img.append(monkaa_path+'/'+dd+'/left/'+im)
42 | all_left_disp.append(monkaa_disp+'/'+dd+'/left/'+im.split(".")[0]+'.pfm')
43 | all_right_disp.append(monkaa_disp+'/'+dd+'/right/'+im.split(".")[0]+'.pfm')
44 |
45 | for im in os.listdir(monkaa_path+'/'+dd+'/right/'):
46 | if is_image_file(monkaa_path+'/'+dd+'/right/'+im):
47 | all_right_img.append(monkaa_path+'/'+dd+'/right/'+im)
48 |
49 | flying_path = filepath + [x for x in image if x == 'frames_cleanpass'][0]
50 | flying_disp = filepath + [x for x in disp if x == 'disparity'][0]
51 | flying_dir = flying_path+'/TRAIN/'
52 | subdir = ['A','B','C']
53 |
54 | for ss in subdir:
55 | flying = os.listdir(flying_dir+ss)
56 |
57 | for ff in flying:
58 | imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/')
59 | for im in imm_l:
60 | if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im):
61 | all_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im)
62 |
63 | all_left_disp.append(flying_disp+'/TRAIN/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm')
64 | all_right_disp.append(flying_disp+'/TRAIN/'+ss+'/'+ff+'/right/'+im.split(".")[0]+'.pfm')
65 |
66 | if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im):
67 | all_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im)
68 |
69 | flying_dir = flying_path+'/TEST/'
70 |
71 | subdir = ['A','B','C']
72 |
73 | for ss in subdir:
74 | flying = os.listdir(flying_dir+ss)
75 |
76 | for ff in flying:
77 | imm_l = os.listdir(flying_dir+ss+'/'+ff+'/left/')
78 | for im in imm_l:
79 | if is_image_file(flying_dir+ss+'/'+ff+'/left/'+im):
80 | test_left_img.append(flying_dir+ss+'/'+ff+'/left/'+im)
81 |
82 | test_left_disp.append(flying_disp+'/TEST/'+ss+'/'+ff+'/left/'+im.split(".")[0]+'.pfm')
83 | test_right_disp.append(flying_disp+'/TEST/'+ss+'/'+ff+'/right/'+im.split(".")[0]+'.pfm')
84 |
85 | if is_image_file(flying_dir+ss+'/'+ff+'/right/'+im):
86 | test_right_img.append(flying_dir+ss+'/'+ff+'/right/'+im)
87 |
88 |
89 | driving_dir = filepath + [x for x in image if 'driving' in x][0] + '/'
90 | driving_disp = filepath + [x for x in disp if 'driving' in x][0]
91 |
92 | #subdir1 = ['35mm_focallength']
93 | #subdir2 = ['scene_backwards']
94 | #subdir3 = ['fast']
95 | ##TODO was using 15 only
96 | subdir1 = ['35mm_focallength','15mm_focallength']
97 | subdir2 = ['scene_backwards','scene_forwards']
98 | subdir3 = ['fast','slow']
99 |
100 | #subdir1 = ['35mm_focallength']
101 | #subdir2 = ['scene_backwards']
102 | #subdir3 = ['fast']
103 |
104 | for i in subdir1:
105 | for j in subdir2:
106 | for k in subdir3:
107 | imm_l = os.listdir(driving_dir+i+'/'+j+'/'+k+'/left/')
108 | for im in imm_l:
109 | if is_image_file(driving_dir+i+'/'+j+'/'+k+'/left/'+im):
110 | all_left_img.append(driving_dir+i+'/'+j+'/'+k+'/left/'+im)
111 | all_left_disp.append(driving_disp+'/'+i+'/'+j+'/'+k+'/left/'+im.split(".")[0]+'.pfm')
112 | all_right_disp.append(driving_disp+'/'+i+'/'+j+'/'+k+'/right/'+im.split(".")[0]+'.pfm')
113 |
114 | if is_image_file(driving_dir+i+'/'+j+'/'+k+'/right/'+im):
115 | all_right_img.append(driving_dir+i+'/'+j+'/'+k+'/right/'+im)
116 |
117 |
118 | return all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp, all_right_disp, test_right_disp
119 |
120 |
121 |
--------------------------------------------------------------------------------
/dataloader/mblist.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'image_2/'
20 | flow_noc = 'flow_occ/'
21 |
22 | train = [img for img in os.listdir(filepath+left_fold) if 'Middlebury' in img and img.find('_10') > -1]
23 |
24 | l0_train = [filepath+left_fold+img for img in train]
25 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
26 | flow_train = [filepath+flow_noc+img for img in train]
27 |
28 |
29 | return l0_train, l1_train, flow_train
30 |
--------------------------------------------------------------------------------
/dataloader/robloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numbers
3 | import torch
4 | import torch.utils.data as data
5 | import torch
6 | import torchvision.transforms as transforms
7 | import random
8 | from PIL import Image, ImageOps
9 | import numpy as np
10 | import torchvision
11 | from . import flow_transforms
12 | import pdb
13 | import cv2
14 | from utils.flowlib import read_flow
15 | from utils.util_flow import readPFM
16 |
17 |
18 | def default_loader(path):
19 | return Image.open(path).convert('RGB')
20 |
21 | def flow_loader(path):
22 | if '.pfm' in path:
23 | data = readPFM(path)[0]
24 | data[:,:,2] = 1  # PFM flow carries no validity channel; mark all pixels valid
25 | return data
26 | else:
27 | return read_flow(path)
28 |
29 |
30 | def disparity_loader(path):
31 | if '.png' in path:
32 | data = Image.open(path)
33 | data = np.ascontiguousarray(data,dtype=np.float32)/256
34 | return data
35 | else:
36 | return readPFM(path)[0]
37 |
38 | class myImageFloder(data.Dataset):
39 | def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader= flow_loader, scale=1.,shape=[320,448], order=1, noise=0.06, pca_augmentor=True, prob = 1., cover=False, black=False):
40 | self.iml0 = iml0
41 | self.iml1 = iml1
42 | self.flowl0 = flowl0
43 | self.loader = loader
44 | self.dploader = dploader
45 | self.scale=scale
46 | self.shape=shape
47 | self.order=order
48 | self.noise = noise
49 | self.pca_augmentor = pca_augmentor
50 | self.prob = prob
51 | self.cover = cover
52 | self.black = black
53 |
54 | def __getitem__(self, index):
55 | iml0 = self.iml0[index]
56 | iml1 = self.iml1[index]
57 | flowl0= self.flowl0[index]
58 | th, tw = self.shape
59 |
60 | iml0 = self.loader(iml0)
61 | iml1 = self.loader(iml1)
62 | iml1 = np.asarray(iml1)/255.
63 | iml0 = np.asarray(iml0)/255.
64 | iml0 = iml0[:,:,::-1].copy()
65 | iml1 = iml1[:,:,::-1].copy()
66 | flowl0 = self.dploader(flowl0)
67 | flowl0 = np.ascontiguousarray(flowl0,dtype=np.float32)
68 | flowl0[np.isnan(flowl0)] = 1e6 # set to max
69 |
70 | ## following data augmentation procedure in PWCNet
71 | ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
72 | import __main__ # a workaround for "discount_coeff"
73 | try:
74 | with open('iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f:
75 | iter_counts = int(f.readline())
76 | except:
77 | iter_counts = 0
78 | schedule = [0.5, 1., 50000.] # initial coeff, final_coeff, half life
79 | schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \
80 | (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1)
81 |
82 | if self.pca_augmentor:
83 | pca_augmentor = flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff)
84 | else:
85 | pca_augmentor = flow_transforms.Scale(1., order=0)
86 |
87 | if np.random.binomial(1,self.prob):
88 | co_transform = flow_transforms.Compose([
89 | flow_transforms.Scale(self.scale, order=self.order),
90 | flow_transforms.SpatialAug([th,tw],scale=[0.4,0.03,0.2],
91 | rot=[0.4,0.03],
92 | trans=[0.4,0.03],
93 | squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order, black=self.black),
94 | flow_transforms.PCAAug(schedule_coeff=schedule_coeff),
95 | flow_transforms.ChromaticAug( schedule_coeff=schedule_coeff, noise=self.noise),
96 | ])
97 | else:
98 | co_transform = flow_transforms.Compose([
99 | flow_transforms.Scale(self.scale, order=self.order),
100 | flow_transforms.SpatialAug([th,tw], trans=[0.4,0.03], order=self.order, black=self.black)
101 | ])
102 |
103 | augmented,flowl0 = co_transform([iml0, iml1], flowl0)
104 | iml0 = augmented[0]
105 | iml1 = augmented[1]
106 |
107 | if self.cover:
108 | ## randomly cover a region
109 | # following sec. 3.2 of http://openaccess.thecvf.com/content_CVPR_2019/html/Yang_Hierarchical_Deep_Stereo_Matching_on_High-Resolution_Images_CVPR_2019_paper.html
110 | if np.random.binomial(1,0.5):
111 | #sx = int(np.random.uniform(25,100))
112 | #sy = int(np.random.uniform(25,100))
113 | sx = int(np.random.uniform(50,125))
114 | sy = int(np.random.uniform(50,125))
115 | #sx = int(np.random.uniform(50,150))
116 | #sy = int(np.random.uniform(50,150))
117 | cx = int(np.random.uniform(sx,iml1.shape[0]-sx))
118 | cy = int(np.random.uniform(sy,iml1.shape[1]-sy))
119 | iml1[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis]
120 |
121 | iml0 = torch.Tensor(np.transpose(iml0,(2,0,1)))
122 | iml1 = torch.Tensor(np.transpose(iml1,(2,0,1)))
123 |
124 | return iml0, iml1, flowl0
125 |
126 | def __len__(self):
127 | return len(self.iml0)
128 |
--------------------------------------------------------------------------------
/dataloader/sintellist.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import pdb
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | left_fold = 'image_2/'
21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]
22 |
23 | l0_train = [filepath+left_fold+img for img in train]
24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25 |
26 | #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val
27 |
28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
30 |
31 |
32 | return l0_train, l1_train, flow_train
33 |
--------------------------------------------------------------------------------
/dataloader/sintellist_clean.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import pdb
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | left_fold = 'image_2/'
21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_clean') > -1]
22 |
23 | l0_train = [filepath+left_fold+img for img in train]
24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25 |
26 | #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val
27 |
28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
30 |
31 | return l0_train, l1_train, flow_train
32 |
--------------------------------------------------------------------------------
/dataloader/sintellist_final.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import pdb
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | left_fold = 'image_2/'
21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_final') > -1]
22 |
23 | l0_train = [filepath+left_fold+img for img in train]
24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25 |
26 | #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val
27 |
28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
30 |
31 | # pdb.set_trace()  # leftover debug breakpoint; disabled so the loader runs unattended
32 | return l0_train, l1_train, flow_train
33 |
--------------------------------------------------------------------------------
/dataloader/sintellist_train.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import pdb
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | left_fold = 'image_2/'
21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]
22 |
23 | l0_train = [filepath+left_fold+img for img in train]
24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25 |
26 | l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # drop the held-out val sequences
27 |
28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
30 |
31 |
32 | return l0_train, l1_train, flow_train
33 |
--------------------------------------------------------------------------------
/dataloader/sintellist_val.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import pdb
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | left_fold = 'image_2/'
21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]
22 |
23 | l0_train = [filepath+left_fold+img for img in train]
24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25 |
26 | l0_train = [i for i in l0_train if ('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i)] # keep only the held-out val sequences
27 | #l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # remove 10 as val
28 |
29 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
30 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
31 |
32 |
33 | return sorted(l0_train)[::3], sorted(l1_train)[::3], sorted(flow_train)[::3]
34 | # return sorted(l0_train)[::10], sorted(l1_train)[::10], sorted(flow_train)[::10]
35 |
--------------------------------------------------------------------------------
/dataloader/stereo_kittilist12.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'colored_0/'
20 | right_fold = 'colored_1/'
21 | disp_noc = 'disp_occ/' # note: despite the name, this points at the 'occ' disparity maps
22 |
23 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
24 |
25 | train = image[:]  # all pairs are used for training
26 | val = image[160:] # the tail doubles as validation and overlaps the training set
27 |
28 | left_train = [filepath+left_fold+img for img in train]
29 | right_train = [filepath+right_fold+img for img in train]
30 | disp_train = [filepath+disp_noc+img for img in train]
31 |
32 |
33 | left_val = [filepath+left_fold+img for img in val]
34 | right_val = [filepath+right_fold+img for img in val]
35 | disp_val = [filepath+disp_noc+img for img in val]
36 |
37 | return left_train, right_train, disp_train, left_val, right_val, disp_val
38 |
--------------------------------------------------------------------------------
/dataloader/stereo_kittilist15.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | import pdb
4 | from PIL import Image
5 | import os
6 | import os.path
7 | import numpy as np
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath, typ = 'train'):
19 |
20 | left_fold = 'image_2/'
21 | right_fold = 'image_3/'
22 | disp_L = 'disp_occ_0/'
23 | disp_R = 'disp_occ_1/'
24 |
25 | image = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
26 | image = sorted(image)
27 | imglist = [1,3,6,20,26,35,38,41,43,44,49,60,67,70,81,84,89,97,109,119,122,123,129,130,132,134,141,144,152,158,159,165,171,174,179,182,184,186,187,196]
28 | if typ == 'train':
29 | train = [image[i] for i in range(200) if i not in imglist]
30 | elif typ == 'trainval':
31 | train = [image[i] for i in range(200)]
32 | val = [image[i] for i in imglist]
33 |
34 | left_train = [filepath+left_fold+img for img in train]
35 | right_train = [filepath+right_fold+img for img in train]
36 | disp_train_L = [filepath+disp_L+img for img in train]
37 | #disp_train_R = [filepath+disp_R+img for img in train]
38 |
39 | left_val = [filepath+left_fold+img for img in val]
40 | right_val = [filepath+right_fold+img for img in val]
41 | disp_val_L = [filepath+disp_L+img for img in val]
42 | #disp_val_R = [filepath+disp_R+img for img in val]
43 |
44 | return left_train, right_train, disp_train_L, left_val, right_val, disp_val_L
45 |
--------------------------------------------------------------------------------
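Note: the 40 indices hard-coded in `imglist` above are the fixed KITTI-2015 validation split; `typ='train'` excludes them from training while `typ='trainval'` folds all 200 pairs back in. A hedged usage sketch (the dataset path is a placeholder):

    from dataloader import stereo_kittilist15 as lks15

    left, right, disp_L, left_val, right_val, disp_val_L = \
        lks15.dataloader('/path/to/kitti_scene/training/', typ='train')
    print(len(left), len(left_val))  # 160 training pairs, 40 validation pairs
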
/dataloader/thingslist.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 | exc_list = [
19 | '0004117.flo',
20 | '0003149.flo',
21 | '0001203.flo',
22 | '0003147.flo',
23 | '0003666.flo',
24 | '0006337.flo',
25 | '0006336.flo',
26 | '0007126.flo',
27 | '0004118.flo',
28 | ]
29 |
30 | left_fold = 'image_clean/left/'
31 | flow_noc = 'flow/left/into_future/'
32 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
33 |
34 | l0_trainlf = [filepath+left_fold+img.replace('flo','png') for img in train]
35 | l1_trainlf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlf]
36 | flow_trainlf = [filepath+flow_noc+img for img in train]
37 |
38 |
39 | exc_list = [
40 | '0003148.flo',
41 | '0004117.flo',
42 | '0002890.flo',
43 | '0003149.flo',
44 | '0001203.flo',
45 | '0003666.flo',
46 | '0006337.flo',
47 | '0006336.flo',
48 | '0004118.flo',
49 | ]
50 |
51 | left_fold = 'image_clean/right/'
52 | flow_noc = 'flow/right/into_future/'
53 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
54 |
55 | l0_trainrf = [filepath+left_fold+img.replace('flo','png') for img in train]
56 | l1_trainrf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrf]
57 | flow_trainrf = [filepath+flow_noc+img for img in train]
58 |
59 |
60 | exc_list = [
61 | '0004237.flo',
62 | '0004705.flo',
63 | '0004045.flo',
64 | '0004346.flo',
65 | '0000161.flo',
66 | '0000931.flo',
67 | '0000121.flo',
68 | '0010822.flo',
69 | '0004117.flo',
70 | '0006023.flo',
71 | '0005034.flo',
72 | '0005054.flo',
73 | '0000162.flo',
74 | '0000053.flo',
75 | '0005055.flo',
76 | '0003147.flo',
77 | '0004876.flo',
78 | '0000163.flo',
79 | '0006878.flo',
80 | ]
81 |
82 | left_fold = 'image_clean/left/'
83 | flow_noc = 'flow/left/into_past/'
84 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
85 |
86 | l0_trainlp = [filepath+left_fold+img.replace('flo','png') for img in train]
87 | l1_trainlp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlp]
88 | flow_trainlp = [filepath+flow_noc+img for img in train]
89 |
90 | exc_list = [
91 | '0003148.flo',
92 | '0004705.flo',
93 | '0000161.flo',
94 | '0000121.flo',
95 | '0004117.flo',
96 | '0000160.flo',
97 | '0005034.flo',
98 | '0005054.flo',
99 | '0000162.flo',
100 | '0000053.flo',
101 | '0005055.flo',
102 | '0003147.flo',
103 | '0001549.flo',
104 | '0000163.flo',
105 | '0006336.flo',
106 | '0001648.flo',
107 | '0006878.flo',
108 | ]
109 |
110 | left_fold = 'image_clean/right/'
111 | flow_noc = 'flow/right/into_past/'
112 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
113 |
114 | l0_trainrp = [filepath+left_fold+img.replace('flo','png') for img in train]
115 | l1_trainrp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrp]
116 | flow_trainrp = [filepath+flow_noc+img for img in train]
117 |
118 |
119 | l0_train = l0_trainlf + l0_trainrf + l0_trainlp + l0_trainrp
120 | l1_train = l1_trainlf + l1_trainrf + l1_trainlp + l1_trainrp
121 | flow_train = flow_trainlf + flow_trainrf + flow_trainlp + flow_trainrp
122 | return l0_train, l1_train, flow_train
123 |
--------------------------------------------------------------------------------
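Note: the `%07d` arithmetic above pairs each FlyingThings3D flow field with its source and target frames: `into_future` flow at frame t points to frame t+1, `into_past` to frame t-1. The same path manipulation in isolation, on a made-up example path:

    img0 = '/data/image_clean/left/0000123.png'  # hypothetical frame path

    def neighbor(path, step):
        # replace the 7-digit frame index by index+step, keeping the zero padding
        head, name = path.rsplit('/', 1)
        return '%s/%07d.png' % (head, int(name.split('.')[0]) + step)

    print(neighbor(img0, +1))  # .../0000124.png (into_future target)
    print(neighbor(img0, -1))  # .../0000122.png (into_past target)
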
/dataset/IIW/hawk_000299.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/dataset/IIW/hawk_000299.png
--------------------------------------------------------------------------------
/dataset/IIW/hawk_000300.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/dataset/IIW/hawk_000300.png
--------------------------------------------------------------------------------
/dataset/kitti_scene/testing/image_2/000042_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/dataset/kitti_scene/testing/image_2/000042_10.png
--------------------------------------------------------------------------------
/dataset/kitti_scene/testing/image_2/000042_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/dataset/kitti_scene/testing/image_2/000042_11.png
--------------------------------------------------------------------------------
/eval_tmp.py:
--------------------------------------------------------------------------------
1 | import os
2 | from matplotlib import pyplot as plt
3 | import numpy as np
4 | import sys
5 | sys.path.insert(0,'utils')
6 | from utils.flowlib import flow_to_image, read_flow, compute_color, visualize_flow
7 | from utils.io import mkdir_p
8 | import pdb
9 | import glob
10 |
11 | import argparse
12 |
13 | parser = argparse.ArgumentParser(description='evaluate predicted optical flow against ground truth')
14 | parser.add_argument('--path', default='/data/ptmodel/',
15 | help='folder containing the predicted .flo files')
16 | parser.add_argument('--vis', default='no',
17 | help="'yes' to also save flow and error visualizations")
18 | parser.add_argument('--dataset', default='2015',
19 | help='dataset to evaluate, e.g. 2015, 2012, sintel, chairs')
20 | args = parser.parse_args()
21 |
22 | aepe_s = []
23 | fall_s64 = []
24 | fall_s32 = []
25 | fall_s16 = []
26 | fall_s8 = []
27 | fall_s = []
28 | oor_tp = []
29 | oor_fp = []
30 |
31 | # dataloader
32 | if args.dataset == '2015':
33 | #from dataloader import kitti15list as DA
34 | #from dataloader import kitti15list_val_lidar as DA
35 | from dataloader import kitti15list_val as DA
36 | datapath = '/ssd/kitti_scene/training/'
37 | elif args.dataset == '2015test':
38 | from dataloader import kitti15list as DA
39 | datapath = '/ssd/kitti_scene/testing/'
40 | elif args.dataset == 'kitticlip':
41 | from dataloader import kitticliplist as DA
42 | #datapath = '/ssd/rob_flow/test/image_2/Kitti2015_000140_'
43 | datapath = '/data/gengshay/KITTI_png/2011_09_30/2011_09_30_drive_0028_sync/image_02/data/'
44 | elif args.dataset == 'tumclip':
45 | from dataloader import kitticliplist as DA
46 | datapath = '/data/gengshay/TUM/rgbd_dataset_freiburg1_plant/rgb/'
47 | elif args.dataset == '2012':
48 | from dataloader import kitti12list as DA
49 | datapath = '/ssd/data_stereo_flow/training/'
50 | elif args.dataset == '2012test':
51 | from dataloader import kitti12list as DA
52 | datapath = '/ssd/data_stereo_flow/testing/'
53 | elif args.dataset == 'mb':
54 | from dataloader import mblist as DA
55 | datapath = '/ssd/rob_flow/training/'
56 | elif args.dataset == 'sintel':
57 | #from dataloader import sintellist as DA
58 | from dataloader import sintellist_val as DA
59 | #from dataloader import sintellist_clean as DA
60 | datapath = '/ssd/rob_flow/training/'
61 | elif args.dataset == 'hd1k':
62 | from dataloader import hd1klist as DA
63 | datapath = '/ssd/rob_flow/training/'
64 | elif args.dataset == 'mbtest':
65 | from dataloader import mblist as DA
66 | datapath = '/ssd/rob_flow/test/'
67 | elif args.dataset == 'sinteltest':
68 | from dataloader import sintellist as DA
69 | datapath = '/ssd/rob_flow/test/'
70 | elif args.dataset == 'chairs':
71 | from dataloader import chairslist as DA
72 | datapath = '/ssd/FlyingChairs_release/data/'
73 | test_left_img, test_right_img, flow_paths = DA.dataloader(datapath)
74 |
75 | #pdb.set_trace()
76 | #with open('/data/gengshay/PWC-Net/Caffe/sintel_test1.txt','w') as f:
77 | # for i in test_left_img:
78 | # f.write(i+'\n')
79 | #
80 | #with open('/data/gengshay/PWC-Net/Caffe/sintel_test2.txt','w') as f:
81 | # for i in test_right_img:
82 | # f.write(i+'\n')
83 | #
84 | #with open('/data/gengshay/PWC-Net/Caffe/sintel_testout.txt','w') as f:
85 | # for i in test_left_img:
86 | # f.write('/data/ptmodel/pwcnet-1/sintel/%s.flo'%(i.split('/')[-1].split('.')[0])+'\n')
87 | #exit()
88 | if args.dataset == 'chairs':
89 | with open('FlyingChairs_train_val.txt', 'r') as f:
90 | split = [int(i) for i in f.readlines()]
91 | test_left_img = [test_left_img[i] for i,flag in enumerate(split) if flag==2]
92 | test_right_img = [test_right_img[i] for i,flag in enumerate(split) if flag==2]
93 | flow_paths = [flow_paths[i] for i,flag in enumerate(split) if flag==2]
94 |
95 | #pdb.set_trace()
96 | #test_left_img = [i for i in test_left_img if 'clean' in i]
97 | #test_right_img = [i for i in test_right_img if 'clean' in i]
98 | #flow_paths = [i for i in flow_paths if 'clean' in i]
99 |
100 | #for i,gtflow_path in enumerate(sorted(flow_paths)):
101 | for i,gtflow_path in enumerate(flow_paths):
102 | #if not 'Sintel_clean_cave_4_10' in gtflow_path:
103 | # continue
104 | #if i%10!=1:
105 | # continue
106 | num = gtflow_path.split('/')[-1].strip().replace('flow.flo','img1.png')
107 | if not 'test' in args.dataset and not 'clip' in args.dataset:
108 | gtflow = read_flow(gtflow_path)
109 | num = num.replace('jpg','png')
110 | flow = read_flow('%s/%s/%s'%(args.path,args.dataset,num))
111 | if args.vis == 'yes':
112 | #flowimg = flow_to_image(flow)
113 | flowimg = flow_to_image(flow)*np.linalg.norm(flow[:,:,:2],2,2)[:,:,np.newaxis]/100./255.
114 | #gtflowimg = compute_color(gtflow[:,:,0]/20, gtflow[:,:,1]/20)/255.
115 | #flowimg = compute_color(flow[:,:,0]/20, flow[:,:,1]/20)/255.
116 | mkdir_p('%s/%s/flowimg'%(args.path,args.dataset))
117 | plt.imsave('%s/%s/flowimg/%s'%(args.path,args.dataset,num), flowimg)
118 | if 'test' in args.dataset or 'clip' in args.dataset:
119 | continue
120 | gtflowimg = flow_to_image(gtflow)
121 | mkdir_p('%s/%s/gtimg'%(args.path,args.dataset))
122 | plt.imsave('%s/%s/gtimg/%s'%(args.path,args.dataset,num), gtflowimg)
123 |
124 | mask = gtflow[:,:,2]==1
125 |
126 | ## occlusion
127 | #H,W,_ = gtflow.shape
128 | #xx = np.tile(np.asarray(range(0, W))[np.newaxis:],(H,1))
129 | #yy = np.tile(np.asarray(range(0, H))[:,np.newaxis],(1,W))
130 | #occmask = np.logical_or( np.logical_or(xx + gtflow[:,:,0] <0, xx + gtflow[:,:,0]>W-1),
131 | # np.logical_or(yy + gtflow[:,:,1] <0, yy + gtflow[:,:,1]>H-1))
132 | #mask = np.logical_and(mask,~occmask)
133 |
134 |
135 |
136 | # if args.dataset == 'mb':
137 | # ##TODO
138 | # mask = np.logical_and(np.logical_and(np.abs(gtflow[:,:,0]) < 16,np.abs(gtflow[:,:,1]) < 16), mask)
139 | gtflow = gtflow[:,:,:2]
140 | flow = flow[:,:,:2]
141 |
142 | epe = np.sqrt(np.power(gtflow - flow,2).sum(-1))[mask]
143 | gt_mag = np.sqrt(np.power(gtflow,2).sum(-1))[mask]
144 |
145 |
146 | #aepe_s.append( epe.mean() )
147 | #fall_s.append( np.sum(np.logical_and(epe > 3, epe/gt_mag > 0.05)) / float(epe.size) )
148 |
149 | clippx = [0,1000]
150 | inrangepx = np.logical_and((np.abs(gtflow)>=clippx[0]).sum(-1), (np.abs(gtflow)<clippx[1]).sum(-1))[mask]
151 | aepe_s.append( epe[inrangepx] )
152 | # (a line defining isoor, the predicted out-of-range mask used below, was lost in extraction)
153 | gtoortp = mask*((np.abs(gtflow)>clippx).sum(-1)>0)
154 | gtoorfp = mask*((np.abs(gtflow)>clippx).sum(-1)==0)
155 | oor_tp.append(isoor[gtoortp])
156 | oor_fp.append(isoor[gtoorfp])
157 | if args.vis == 'yes' and 'test' not in args.dataset:
158 | epeimg = np.sqrt(np.power(gtflow - flow,2).sum(-1))*(mask*(np.logical_and((np.abs(gtflow)>=clippx[0]).sum(-1), (np.abs(gtflow)<clippx[1]).sum(-1))))
159 | mkdir_p('%s/%s/epeimg'%(args.path,args.dataset))
160 | plt.imsave('%s/%s/epeimg/%s'%(args.path,args.dataset,num), epeimg)
161 |
162 |
163 | fall_s64.append( (epe > 64)[inrangepx])
164 | fall_s32.append( (epe > 32)[inrangepx])
165 | fall_s16.append( (epe > 16)[inrangepx])
166 | fall_s8.append( (epe > 8)[inrangepx])
167 | fall_s.append( np.logical_and(epe > 3, epe/gt_mag > 0.05)[inrangepx])
168 | # aepe_s.append( epe )
169 | #fall_s.append( epe[gt_mag<32] > 8)
170 | # fall_s.append( np.logical_and(epe > 3, epe/gt_mag > 0.05))
171 | # print(gtflow_path)
172 | #for i in [np.mean(i) for i in aepe_s]:
173 | # print('%f'%i)
174 | #for i in [np.mean(i) for i in fall_s]:
175 | # print('%f'%i)
176 | #print('\t%.1f/%.1f/%.1f/%.1f/%.1f/%.3f'%(
177 | # np.mean( 100*np.concatenate(fall_s64,0)),
178 | # np.mean( 100*np.concatenate(fall_s32,0)),
179 | # np.mean( 100*np.concatenate(fall_s16,0)),
180 | # np.mean( 100*np.concatenate(fall_s8,0)),
181 | # np.mean( 100*np.concatenate(fall_s,0)),
182 | # np.mean( np.concatenate(aepe_s,0))) )
183 | print('\t%.1f/%.3f'%(
184 | np.mean( 100*np.concatenate(fall_s,0)),
185 | np.mean( np.concatenate(aepe_s,0))) )
186 | #print('\t%.1f/%.1f'%(100*np.mean( np.concatenate(oor_tp,0) ), 100*np.mean( np.concatenate(oor_fp,0) )) )
187 |
--------------------------------------------------------------------------------
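Note: the loop above accumulates the two standard flow metrics: average end-point error (AEPE) and the KITTI outlier rate Fl, which counts a pixel as wrong when its EPE exceeds both 3px and 5% of the ground-truth magnitude. A self-contained sketch of the same computation (with a small epsilon guard added against zero-magnitude pixels):

    import numpy as np

    def flow_metrics(gtflow, flow):
        """AEPE and Fl-all over valid pixels; gtflow is HxWx3 with a validity
        mask in the last channel, flow is HxWx2, as in eval_tmp.py."""
        mask = gtflow[:, :, 2] == 1
        epe = np.sqrt(((gtflow[:, :, :2] - flow) ** 2).sum(-1))[mask]
        mag = np.sqrt((gtflow[:, :, :2] ** 2).sum(-1))[mask]
        fl = np.logical_and(epe > 3, epe / np.maximum(mag, 1e-9) > 0.05).mean()
        return epe.mean(), 100 * fl

    gt = np.dstack([np.ones((4, 4)), np.zeros((4, 4)), np.ones((4, 4))])
    pred = np.zeros((4, 4, 2))
    print('AEPE %.3f / Fl %.1f%%' % flow_metrics(gt, pred))
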
/figs/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/architecture.png
--------------------------------------------------------------------------------
/figs/hawk-vec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/hawk-vec.png
--------------------------------------------------------------------------------
/figs/hawk.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/hawk.png
--------------------------------------------------------------------------------
/figs/kitti-test-42-vec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/kitti-test-42-vec.png
--------------------------------------------------------------------------------
/figs/kitti-test-42.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/kitti-test-42.png
--------------------------------------------------------------------------------
/figs/output-onlinepngtools.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/output-onlinepngtools.png
--------------------------------------------------------------------------------
/figs/time-breakdown.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/figs/time-breakdown.png
--------------------------------------------------------------------------------
/flops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from thop import profile
3 | import pdb
4 | from models.PWCNet import PWCDCNet
5 | from models.VCN import VCN
6 | #from models.VCN_small import VCN
7 |
8 | vcn = VCN([1, 1280,384])
9 | pwcnet = PWCDCNet([1, 1280,384])
10 |
11 | f_pwc, p_pwc = profile(pwcnet, input_size=(2, 3, 384,1280), device='cuda')
12 | f_vcn, p_vcn = profile(vcn, input_size=(2, 3, 384,1280), device='cuda')
13 | print('PWCNet: \tflops(G)/params(M):%.1f/%.2f'%(f_pwc/1e9,p_pwc/1e6))
14 | print('VCN: \t\tflops(G)/params(M):%.1f/%.2f'%(f_vcn/1e9,p_vcn/1e6))
15 |
16 |
17 | #from models.conv4d import butterfly4D, sepConv4dBlock, sepConv4d
18 | #
19 | #model = torch.nn.Conv2d(30,30, (3,3), stride=1, padding=(1, 1), bias=False)
20 | #f, p = profile(model, input_size=(9*9,30,64,128), device='cuda')
21 | #
22 | ##model = torch.nn.Conv3d(30,30, (1,3,3), stride=1, padding=(0, 1, 1), bias=False)
23 | ##f, p = profile(model, input_size=(1,30,9*9,64,128), device='cuda')
24 | #
25 | #
26 | ###model = butterfly4D(64, 12,withbn=True,full=True)
27 | ###model = sepConv4dBlock(16,16,with_bn=True, stride=(1,1,1),full=True)
28 | ##model = sepConv4d(20, 20, (1,1,1), with_bn=True, full=True)
29 | ##f, p = profile(model, input_size=(1,20,9,9,64,128), device='cuda')
30 | #
31 | ##model = torch.nn.Conv2d(256,256,(3,3),padding=(1,1))
32 | ##f, p = profile(model, input_size=(1,256,64,128), device='cuda')
33 | #
34 | #print('flops(G)/params(M):%.1f/%.2f'%(f/1e9,p/1e6))
35 |
--------------------------------------------------------------------------------
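Note: flops.py relies on the thop copy vendored in this repository, whose `profile` takes an `input_size` and builds the dummy input on the given device itself. A minimal sketch of the same call on a toy layer (requires a CUDA device, like the script):

    import torch
    from thop import profile  # the copy shipped in this repository

    conv = torch.nn.Conv2d(32, 32, 3, padding=1, bias=False)
    flops, params = profile(conv, input_size=(1, 32, 384, 1280), device='cuda')
    print('flops(G)/params(M): %.2f/%.3f' % (flops / 1e9, params / 1e6))
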
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import cv2
3 | cv2.setNumThreads(0)
4 | import sys
5 | import pdb
6 | import argparse
7 | import collections
8 | import os
9 | import random
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.parallel
13 | import torch.backends.cudnn as cudnn
14 | import torch.optim as optim
15 | import torch.utils.data
16 | from torch.autograd import Variable
17 | import torch.nn.functional as F
18 | import numpy as np
19 | import time
20 | from utils.flowlib import flow_to_image
21 | from models import *
22 | from utils import logger
23 | torch.backends.cudnn.benchmark=True
24 | from utils.multiscaleloss import realEPE
25 |
26 |
27 | parser = argparse.ArgumentParser(description='VCN')
28 | parser.add_argument('--maxdisp', type=int ,default=256,
29 | help='maxium disparity, out of range pixels will be masked out. Only affect the coarsest cost volume size')
30 | parser.add_argument('--fac', type=float ,default=1,
31 | help='controls the shape of search grid. Only affect the coarsest cost volume size')
32 | parser.add_argument('--logname', default='logname',
33 | help='name of the log file')
34 | parser.add_argument('--database', default='/',
35 | help='path to the database')
36 | parser.add_argument('--epochs', type=int, default=300,
37 | help='number of epochs to train')
38 | parser.add_argument('--loadmodel', default=None,
39 | help='path of the pre-trained model')
40 | parser.add_argument('--model', default='VCN',
41 | help='VCN or VCN_small')
42 | parser.add_argument('--savemodel', default='./',
43 | help='path to save the model')
44 | parser.add_argument('--retrain', default='false',
45 | help='whether to reset moving mean / other hyperparameters')
46 | parser.add_argument('--stage', default='chairs',
47 | help='one of {chairs, things, 2015train, 2015trainval, sinteltrain, sinteltrainval}')
48 | parser.add_argument('--ngpus', type=int, default=2,
49 | help='number of gpus to use.')
50 | args = parser.parse_args()
51 |
52 | if args.model == 'VCN':
53 | from models.VCN import VCN
54 | elif args.model == 'VCN_small':
55 | from models.VCN_small import VCN
56 |
57 | # fix random seed
58 | torch.manual_seed(1)
59 | def _init_fn(worker_id):
60 | np.random.seed()
61 | random.seed()
62 | torch.manual_seed(8) # do it again
63 | torch.cuda.manual_seed(1)
64 |
65 | ## set hyperparameters for training
66 | ngpus = args.ngpus
67 | batch_size = 4*ngpus
68 | if args.stage == 'chairs' or args.stage == 'things':
69 | lr_schedule = 'slong_ours'
70 | else:
71 | lr_schedule = 'rob_ours'
72 | #baselr = 1e-4
73 | baselr = 1e-3
74 | worker_mul = int(2)
75 | #worker_mul = int(0)
76 | if args.stage == 'chairs' or args.stage == 'things':
77 | datashape = [320,448]
78 | elif '2015' in args.stage:
79 | datashape = [256,768]
80 | elif 'sintel' in args.stage:
81 | datashape = [320,576]
82 | else:
83 | print('error: unknown stage %s'%args.stage)
84 | exit(1)
85 |
86 | ## dataloader
87 | from dataloader import robloader as dr
88 |
89 | if args.stage == 'chairs' or 'sintel' in args.stage:
90 | # flying chairs
91 | from dataloader import chairslist as lc
92 | iml0, iml1, flowl0 = lc.dataloader('%s/FlyingChairs_release/data/'%args.database)
93 | with open('order.txt','r') as f:
94 | order = [int(i) for i in f.readline().split(' ')]
95 | with open('FlyingChairs_train_val.txt', 'r') as f:
96 | split = [int(i) for i in f.readlines()]
97 | iml0 = [iml0[i] for i in order if split[i]==1]
98 | iml1 = [iml1[i] for i in order if split[i]==1]
99 | flowl0 = [flowl0[i] for i in order if split[i]==1]
100 | loader_chairs = dr.myImageFloder(iml0,iml1,flowl0, shape = datashape)
101 |
102 | if args.stage == 'things' or 'sintel' in args.stage:
103 | # flying things (FlyingThings3D)
104 | from dataloader import thingslist as lt
105 | iml0, iml1, flowl0 = lt.dataloader('%s/FlyingThings3D_subset/train/'%args.database)
106 | loader_things = dr.myImageFloder(iml0,iml1,flowl0,shape = datashape,scale=1, order=1)
107 |
108 | # fine-tuning datasets
109 | if args.stage == '2015train':
110 | from dataloader import kitti15list_train as lk15
111 | else:
112 | from dataloader import kitti15list as lk15
113 | if args.stage == 'sinteltrain':
114 | from dataloader import sintellist_train as ls
115 | else:
116 | from dataloader import sintellist as ls
117 | from dataloader import kitti12list as lk12
118 | from dataloader import hd1klist as lh
119 |
120 | if 'sintel' in args.stage:
121 | iml0, iml1, flowl0 = lk15.dataloader('%s/kitti_scene/training/'%args.database)
122 | loader_kitti15 = dr.myImageFloder(iml0,iml1,flowl0, shape=datashape, scale=1, order=0, noise=0) # SINTEL
123 | iml0, iml1, flowl0 = lh.dataloader('%s/rob_flow/training/'%args.database)
124 | loader_hd1k = dr.myImageFloder(iml0,iml1,flowl0,shape=datashape, scale=0.5,order=0, noise=0)
125 | iml0, iml1, flowl0 = ls.dataloader('%s/rob_flow/training/'%args.database)
126 | loader_sintel = dr.myImageFloder(iml0,iml1,flowl0, shape=datashape, scale=1, order=1, noise=0)
127 | if '2015' in args.stage:
128 | iml0, iml1, flowl0 = lk12.dataloader('%s/data_stereo_flow/training/'%args.database)
129 | loader_kitti12 = dr.myImageFloder(iml0,iml1,flowl0, shape=datashape, scale=1, order=0, prob=0.5)
130 | iml0, iml1, flowl0 = lk15.dataloader('%s/kitti_scene/training/'%args.database)
131 | loader_kitti15 = dr.myImageFloder(iml0,iml1,flowl0, shape=datashape, scale=1, order=0, prob=0.5) # KITTI
132 |
133 | if args.stage=='chairs':
134 | data_inuse = torch.utils.data.ConcatDataset([loader_chairs]*100)
135 | elif args.stage=='things':
136 | data_inuse = torch.utils.data.ConcatDataset([loader_things]*100)
137 | elif '2015' in args.stage:
138 | data_inuse = torch.utils.data.ConcatDataset([loader_kitti15]*50+[loader_kitti12]*50)
139 | for i in data_inuse.datasets:
140 | i.black = True
141 | i.cover = True
142 | elif 'sintel' in args.stage:
143 | data_inuse = torch.utils.data.ConcatDataset([loader_kitti15]*200+[loader_hd1k]*40 + [loader_sintel]*150 + [loader_chairs]*2 + [loader_things])
144 | for i in data_inuse.datasets:
145 | i.black = True
146 | i.cover = True
147 | else:
148 | print('error: unknown stage %s'%args.stage)
149 | exit(1)
150 |
151 |
152 |
153 | ## Stereo data
154 | #from dataloader import stereo_kittilist15 as lks15
155 | #from dataloader import stereo_kittilist12 as lks12
156 | #from dataloader import MiddleburyList as lmbs
157 | #from dataloader import listflowfile as lsfs
158 | #
159 | #def disparity_loader_sf(path):
160 | # from utils.util_flow import readPFM as rp
161 | # out = rp(path)[0]
162 | # shape = (out.shape[0], out.shape[1], 1)
163 | # out = np.concatenate((-out[:,:,np.newaxis],np.zeros(shape),np.ones(shape)),-1)
164 | # return out
165 | #
166 | #def disparity_loader_mb(path):
167 | # from utils.util_flow import readPFM as rp
168 | # out = rp(path)[0]
169 | # mask = np.asarray(out!=np.inf,float)[:,:,np.newaxis]
170 | # out[out==np.inf]=0
171 | #
172 | # shape = (out.shape[0], out.shape[1], 1)
173 | ## out = np.concatenate((-out[:,:,np.newaxis],np.zeros(shape),np.ones(shape)),-1)
174 | # out = np.concatenate((-out[:,:,np.newaxis],np.zeros(shape),mask),-1)
175 | # return out
176 | #
177 | #def disparity_loader(path):
178 | ## from utils.util_flow import readPFM as rp
179 | ## out = rp(path)[0]
180 | # from PIL import Image
181 | # out = Image.open(path)
182 | # out = np.ascontiguousarray(out,dtype=np.float32)/256
183 | # mask = np.asarray(out>0,float)[:,:,np.newaxis]
184 | #
185 | # shape = (out.shape[0], out.shape[1], 1)
186 | ## out = np.concatenate((-out[:,:,np.newaxis],np.zeros(shape),np.ones(shape)),-1)
187 | # out = np.concatenate((-out[:,:,np.newaxis],np.zeros(shape),mask),-1)
188 | # return out
189 | ##iml0, iml1, flowl0, _, _, _ = lks15.dataloader('%s/kitti_scene/training/'%args.database, typ='trainval')
190 | ##loader_stereo_15 = dr.myImageFloder(iml0,iml1,flowl0,shape = datashape,scale=1, order=0, prob=0.5,dploader=disparity_loader)
191 | ##iml0, iml1, flowl0, _, _, _ = lks12.dataloader('%s/data_stereo_flow/training/'%args.database)
192 | ##loader_stereo_12 = dr.myImageFloder(iml0,iml1,flowl0,shape = datashape,scale=1, order=0, prob=0.5,dploader=disparity_loader)
193 | ##iml0, iml1, flowl0, _, _, _ = lmbs.dataloader('%s/mb-ex-training/'%args.database, res='F')
194 | ##loader_stereo_mb = dr.myImageFloder(iml0,iml1,flowl0,shape = datashape,scale=0.5, order=1, prob=0.5,dploader=disparity_loader_mb)
195 | ##iml0, iml1, flowl0, _, _, _,_,_ = lsfs.dataloader('%s/sceneflow/'%args.database)
196 | ##loader_stereo_sf = dr.myImageFloder(iml0,iml1,flowl0,shape = datashape,scale=1, order=1, dploader=disparity_loader_sf)
197 |
198 | #data_inuse = torch.utils.data.ConcatDataset([loader_stereo_15]*75+[loader_stereo_12]*75+[loader_stereo_mb]*600+[loader_stereo_sf])
199 | #data_inuse = torch.utils.data.ConcatDataset([loader_stereo_15]*50+[loader_stereo_12]*50+[loader_stereo_mb]*600+[loader_chairs])
200 | #data_inuse = torch.utils.data.ConcatDataset([loader_chairs]*2 + [loader_things] +[loader_stereo_15]*300+[loader_stereo_12]*300) # stereo transfer
201 | #data_inuse = torch.utils.data.ConcatDataset([loader_stereo_15]*20+[loader_stereo_12]*20) # stereo transfer
202 | #data_inuse = torch.utils.data.ConcatDataset([loader_kitti15]*20+[loader_kitti12]*20+[loader_stereo_15]*20+[loader_stereo_12]*20)
203 | print('%d batches per epoch'%(len(data_inuse)//batch_size))
204 |
205 | #TODO
206 | model = VCN([batch_size//ngpus]+data_inuse.datasets[0].shape[::-1], md=[int(4*(args.maxdisp/256)), 4,4,4,4], fac=args.fac)
207 | model = nn.DataParallel(model)
208 | model.cuda()
209 |
210 | total_iters = 0
211 | mean_L=[[0.33,0.33,0.33]]
212 | mean_R=[[0.33,0.33,0.33]]
213 | if args.loadmodel is not None:
214 | pretrained_dict = torch.load(args.loadmodel)
215 | pretrained_dict['state_dict'] = {k:v for k,v in pretrained_dict['state_dict'].items()}
216 |
217 | model.load_state_dict(pretrained_dict['state_dict'],strict=False)
218 | if args.retrain == 'true':
219 | print('re-training')
220 | else:
221 | with open('./iter_counts-%d.txt'%int(args.logname.split('-')[-1]), 'r') as f:
222 | total_iters = int(f.readline())
223 | print('resuming from %d'%total_iters)
224 | mean_L=pretrained_dict['mean_L']
225 | mean_R=pretrained_dict['mean_R']
226 |
227 |
228 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
229 | optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), amsgrad=False)
230 |
231 | def train(imgL,imgR,flowl0):
232 | model.train()
233 | imgL = Variable(torch.FloatTensor(imgL))
234 | imgR = Variable(torch.FloatTensor(imgR))
235 | flowl0 = Variable(torch.FloatTensor(flowl0))
236 |
237 | imgL, imgR, flowl0 = imgL.cuda(), imgR.cuda(), flowl0.cuda()
238 | mask = (flowl0[:,:,:,2] == 1) & (flowl0[:,:,:,0].abs() < args.maxdisp) & (flowl0[:,:,:,1].abs() < (args.maxdisp//args.fac))
239 | mask.detach_();
240 |
241 | # rearrange inputs
242 | groups = []
243 | for i in range(ngpus):
244 | groups.append(imgL[i*batch_size//ngpus:(i+1)*batch_size//ngpus])
245 | groups.append(imgR[i*batch_size//ngpus:(i+1)*batch_size//ngpus])
246 |
247 | # forward-backward
248 | optimizer.zero_grad()
249 | output = model(torch.cat(groups,0), [flowl0,mask])
250 | loss = output[-2].mean()
251 | loss.backward()
252 | optimizer.step()
253 |
254 | vis = {}
255 | vis['output2'] = output[0].detach().cpu().numpy()
256 | vis['output3'] = output[1].detach().cpu().numpy()
257 | vis['output4'] = output[2].detach().cpu().numpy()
258 | vis['output5'] = output[3].detach().cpu().numpy()
259 | vis['output6'] = output[4].detach().cpu().numpy()
260 | vis['oor'] = output[6][0].detach().cpu().numpy()
261 | vis['gt'] = flowl0[:,:,:,:].detach().cpu().numpy()
262 | if mask.sum():
263 | vis['AEPE'] = realEPE(output[0].detach(), flowl0.permute(0,3,1,2).detach(),mask,sparse=False)
264 | vis['mask'] = mask
265 | return loss.data,vis
266 |
267 | def adjust_learning_rate(optimizer, total_iters):
268 | if lr_schedule == 'slong':
269 | if total_iters < 200000:
270 | lr = baselr
271 | elif total_iters < 300000:
272 | lr = baselr/2.
273 | elif total_iters < 400000:
274 | lr = baselr/4.
275 | elif total_iters < 500000:
276 | lr = baselr/8.
277 | elif total_iters < 600000:
278 | lr = baselr/16.
279 | if lr_schedule == 'slong_ours':
280 | if total_iters < 70000:
281 | lr = baselr
282 | elif total_iters < 130000:
283 | lr = baselr/2.
284 | elif total_iters < 190000:
285 | lr = baselr/4.
286 | elif total_iters < 240000:
287 | lr = baselr/8.
288 | elif total_iters < 290000:
289 | lr = baselr/16.
290 | if lr_schedule == 'slong_pwc':
291 | if total_iters < 400000:
292 | lr = baselr
293 | elif total_iters < 600000:
294 | lr = baselr/2.
295 | elif total_iters < 800000:
296 | lr = baselr/4.
297 | elif total_iters < 1000000:
298 | lr = baselr/8.
299 | elif total_iters < 1200000:
300 | lr = baselr/16.
301 | if lr_schedule == 'sfine_pwc':
302 | if total_iters < 1400000:
303 | lr = baselr/10.
304 | elif total_iters < 1500000:
305 | lr = baselr/20.
306 | elif total_iters < 1600000:
307 | lr = baselr/40.
308 | elif total_iters < 1700000:
309 | lr = baselr/80.
310 | if lr_schedule == 'sfine':
311 | if total_iters < 700000:
312 | lr = baselr/10.
313 | elif total_iters < 750000:
314 | lr = baselr/20.
315 | elif total_iters < 800000:
316 | lr = baselr/40.
317 | elif total_iters < 850000:
318 | lr = baselr/80.
319 | if lr_schedule == 'rob_ours':
320 | if total_iters < 30000:
321 | lr = baselr
322 | elif total_iters < 40000:
323 | lr = baselr / 2.
324 | elif total_iters < 50000:
325 | lr = baselr / 4.
326 | elif total_iters < 60000:
327 | lr = baselr / 8.
328 | elif total_iters < 70000:
329 | lr = baselr / 16.
330 | elif total_iters < 100000:
331 | lr = baselr
332 | elif total_iters < 110000:
333 | lr = baselr / 2.
334 | elif total_iters < 120000:
335 | lr = baselr / 4.
336 | elif total_iters < 130000:
337 | lr = baselr / 8.
338 | elif total_iters < 140000:
339 | lr = baselr / 16.
340 | print(lr)
341 | for param_group in optimizer.param_groups:
342 | param_group['lr'] = lr
343 |
344 | # persist the global iteration count so training can resume
345 | with open('./iter_counts-%d.txt'%int(args.logname.split('-')[-1]), 'w') as f:
346 | f.write('%d'%total_iters)
347 |
348 | def main():
349 | TrainImgLoader = torch.utils.data.DataLoader(
350 | data_inuse,
351 | batch_size= batch_size, shuffle= True, num_workers=worker_mul*batch_size, drop_last=True, worker_init_fn=_init_fn, pin_memory=True)
352 | log = logger.Logger(args.savemodel, name=args.logname)
353 | start_full_time = time.time()
354 | global total_iters
355 |
356 | for epoch in range(1, args.epochs+1):
357 | total_train_loss = 0
358 | total_train_aepe = 0
359 |
360 | # training loop
361 | for batch_idx, (imgL_crop, imgR_crop, flowl0) in enumerate(TrainImgLoader):
362 | if batch_idx % 100 == 0:
363 | adjust_learning_rate(optimizer,total_iters)
364 |
365 | if total_iters < 1000:
366 | # subtract mean
367 | mean_L.append( np.asarray(imgL_crop.mean(0).mean(1).mean(1)) )
368 | mean_R.append( np.asarray(imgR_crop.mean(0).mean(1).mean(1)) )
369 | imgL_crop -= torch.from_numpy(np.asarray(mean_L).mean(0)[np.newaxis,:,np.newaxis, np.newaxis]).float()
370 | imgR_crop -= torch.from_numpy(np.asarray(mean_R).mean(0)[np.newaxis,:,np.newaxis, np.newaxis]).float()
371 |
372 | start_time = time.time()
373 | loss,vis = train(imgL_crop,imgR_crop, flowl0)
374 | print('Iter %d training loss = %.3f , time = %.2f' %(batch_idx, loss, time.time() - start_time))
375 | total_train_loss += loss
376 | total_train_aepe += vis['AEPE']
377 |
378 | if total_iters %10 == 0:
379 | log.scalar_summary('train/loss_batch',loss, total_iters)
380 | log.scalar_summary('train/aepe_batch',vis['AEPE'], total_iters)
381 | if total_iters %100 == 0:
382 | log.image_summary('train/left',imgL_crop[0:1],total_iters)
383 | log.image_summary('train/right',imgR_crop[0:1],total_iters)
384 | log.histo_summary('train/pred_hist',vis['output2'], total_iters)
385 | if len(np.asarray(vis['gt']))>0:
386 | log.histo_summary('train/gt_hist',np.asarray(vis['gt']), total_iters)
387 | gu = vis['gt'][0,:,:,0]; gv = vis['gt'][0,:,:,1]
388 | gu = gu*np.asarray(vis['mask'][0].float().cpu()); gv = gv*np.asarray(vis['mask'][0].float().cpu())
389 | mask = vis['mask'][0].float().cpu()
390 | log.image_summary('train/gt0', flow_to_image(np.concatenate((gu[:,:,np.newaxis],gv[:,:,np.newaxis],mask[:,:,np.newaxis]),-1))[np.newaxis],total_iters)
391 | log.image_summary('train/output2',flow_to_image(vis['output2'][0].transpose((1,2,0)))[np.newaxis],total_iters)
392 | log.image_summary('train/output3',flow_to_image(vis['output3'][0].transpose((1,2,0)))[np.newaxis],total_iters)
393 | log.image_summary('train/output4',flow_to_image(vis['output4'][0].transpose((1,2,0)))[np.newaxis],total_iters)
394 | log.image_summary('train/output5',flow_to_image(vis['output5'][0].transpose((1,2,0)))[np.newaxis],total_iters)
395 | log.image_summary('train/output6',flow_to_image(vis['output6'][0].transpose((1,2,0)))[np.newaxis],total_iters)
396 | log.image_summary('train/oor',vis['oor'][np.newaxis],total_iters)
397 | torch.cuda.empty_cache()
398 | total_iters += 1
399 | # persist the global iteration count so training can resume
400 | with open('./iter_counts-%d.txt'%int(args.logname.split('-')[-1]), 'w') as f:
401 | f.write('%d'%total_iters)
402 |
403 | if (total_iters + 1)%2000==0:
404 | #SAVE
405 | savefilename = args.savemodel+'/'+args.logname+'/finetune_'+str(total_iters)+'.tar'
406 | save_dict = model.state_dict()
407 | save_dict = collections.OrderedDict({k:v for k,v in save_dict.items() if ('flow_reg' not in k or 'conv1' in k) and ('grid' not in k)})
408 | torch.save({
409 | 'iters': total_iters,
410 | 'state_dict': save_dict,
411 | 'train_loss': total_train_loss/len(TrainImgLoader),
412 | 'mean_L': mean_L,
413 | 'mean_R': mean_R,
414 | }, savefilename)
415 |
416 | log.scalar_summary('train/loss',total_train_loss/len(TrainImgLoader), epoch)
417 | log.scalar_summary('train/aepe',total_train_aepe/len(TrainImgLoader), epoch)
418 |
419 |
420 |
421 | print('full finetune time = %.2f HR' %((time.time() - start_full_time)/3600))
422 | #print(max_epo) # disabled: max_epo is never defined in this script
423 |
424 |
425 | if __name__ == '__main__':
426 | main()
427 |
--------------------------------------------------------------------------------
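Note: `adjust_learning_rate` above encodes piecewise-constant schedules keyed on the global iteration count; `rob_ours` even restarts at the base rate at 70k iterations for a second annealing cycle, and the if-chains leave `lr` unset once the last boundary is passed. The `slong_ours` schedule rewritten as a compact table lookup, a sketch that holds the final rate instead of leaving `lr` undefined:

    # Table form of the 'slong_ours' schedule from main.py.
    SLONG_OURS = [(70000, 1.0), (130000, 0.5), (190000, 0.25),
                  (240000, 0.125), (290000, 0.0625)]

    def lr_at(total_iters, baselr=1e-3, table=SLONG_OURS):
        for bound, scale in table:
            if total_iters < bound:
                return baselr * scale
        return baselr * table[-1][1]  # hold the final rate afterwards

    assert lr_at(0) == 1e-3 and lr_at(100000) == 5e-4 and lr_at(250000) == 6.25e-5
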
/models/PWCNet.py:
--------------------------------------------------------------------------------
1 | """
2 | implementation of the PWC-DC network for optical flow estimation by Sun et al., 2018
3 |
4 | Jinwei Gu and Zhile Ren
5 |
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.autograd import Variable
12 | import os
13 | os.environ['PYTHON_EGG_CACHE'] = 'tmp/' # a writable directory
14 | import numpy as np
15 | import pdb
16 | import time
17 |
18 |
19 |
20 |
21 | __all__ = [
22 | 'pwc_dc_net', 'pwc_dc_net_old', 'pwcnet'
23 | ]
24 |
25 | def pwcnet(data=None):
26 | """FlowNetS model architecture from the
27 | "Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852)
28 |
29 | Args:
30 | data : pretrained weights of the network. will create a new one if not set
31 | """
32 | model = PWCDCNet()
33 | if data is not None:
34 | model.load_state_dict(data['state_dict'])
35 | return model
36 |
37 |
38 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
39 | return nn.Sequential(
40 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
41 | padding=padding, dilation=dilation, bias=True),
42 | nn.LeakyReLU(0.1))
43 |
44 | def predict_flow(in_planes):
45 | return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True)
46 |
47 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
48 | return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True)
49 |
50 |
51 |
52 | class WarpModule(nn.Module):
53 | def __init__(self, size):
54 | super(WarpModule, self).__init__()
55 | B,W,H = size
56 | # mesh grid
57 | xx = torch.arange(0, W).view(1,-1).repeat(H,1)
58 | yy = torch.arange(0, H).view(-1,1).repeat(1,W)
59 | xx = xx.view(1,1,H,W).repeat(B,1,1,1)
60 | yy = yy.view(1,1,H,W).repeat(B,1,1,1)
61 | self.register_buffer('grid',torch.cat((xx,yy),1).float())
62 |
63 | def forward(self, x, flo):
64 | """
65 | warp an image/tensor (im2) back to im1, according to the optical flow
66 |
67 | x: [B, C, H, W] (im2)
68 | flo: [B, 2, H, W] flow
69 |
70 | """
71 | B, C, H, W = x.size()
72 | vgrid = self.grid + flo
73 |
74 | # scale grid to [-1,1]
75 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:]/max(W-1,1)-1.0
76 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:]/max(H-1,1)-1.0
77 |
78 | vgrid = vgrid.permute(0,2,3,1)
79 | output = nn.functional.grid_sample(x, vgrid)
80 | mask = ((vgrid[:,:,:,0].abs()<1) * (vgrid[:,:,:,1].abs()<1)) >0
81 | return output*mask.unsqueeze(1).float()
82 |
83 |
84 | class PWCDCNet(nn.Module):
85 | """
86 | PWC-DC net. add dilation convolution and densenet connections
87 |
88 | """
89 | def __init__(self, size, md=4):
90 | """
91 | input: md --- maximum displacement for correlation (default: 4), applied after warping
92 |
93 | """
94 | super(PWCDCNet,self).__init__()
95 | self.conv1a = conv(3, 16, kernel_size=3, stride=2)
96 | self.conv1aa = conv(16, 16, kernel_size=3, stride=1)
97 | self.conv1b = conv(16, 16, kernel_size=3, stride=1)
98 | self.conv2a = conv(16, 32, kernel_size=3, stride=2)
99 | self.conv2aa = conv(32, 32, kernel_size=3, stride=1)
100 | self.conv2b = conv(32, 32, kernel_size=3, stride=1)
101 | self.conv3a = conv(32, 64, kernel_size=3, stride=2)
102 | self.conv3aa = conv(64, 64, kernel_size=3, stride=1)
103 | self.conv3b = conv(64, 64, kernel_size=3, stride=1)
104 | self.conv4a = conv(64, 96, kernel_size=3, stride=2)
105 | self.conv4aa = conv(96, 96, kernel_size=3, stride=1)
106 | self.conv4b = conv(96, 96, kernel_size=3, stride=1)
107 | self.conv5a = conv(96, 128, kernel_size=3, stride=2)
108 | self.conv5aa = conv(128,128, kernel_size=3, stride=1)
109 | self.conv5b = conv(128,128, kernel_size=3, stride=1)
110 | self.conv6aa = conv(128,196, kernel_size=3, stride=2)
111 | self.conv6a = conv(196,196, kernel_size=3, stride=1)
112 | self.conv6b = conv(196,196, kernel_size=3, stride=1)
113 |
114 | self.warp5 = WarpModule([size[0],size[1]//32,size[2]//32])
115 | self.warp4 = WarpModule([size[0],size[1]//16,size[2]//16])
116 | self.warp3 = WarpModule([size[0],size[1]//8,size[2]//8])
117 | self.warp2 = WarpModule([size[0],size[1]//4,size[2]//4])
118 | self.leakyRELU = nn.LeakyReLU(0.1)
119 |
120 | nd = (2*md+1)**2
121 | dd = np.cumsum([128,128,96,64,32])
122 |
123 | od = nd
124 | self.conv6_0 = conv(od, 128, kernel_size=3, stride=1)
125 | self.conv6_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
126 | self.conv6_2 = conv(od+dd[1],96, kernel_size=3, stride=1)
127 | self.conv6_3 = conv(od+dd[2],64, kernel_size=3, stride=1)
128 | self.conv6_4 = conv(od+dd[3],32, kernel_size=3, stride=1)
129 | self.predict_flow6 = predict_flow(od+dd[4])
130 | self.deconv6 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
131 | self.upfeat6 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1)
132 |
133 | od = nd+128+4
134 | self.conv5_0 = conv(od, 128, kernel_size=3, stride=1)
135 | self.conv5_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
136 | self.conv5_2 = conv(od+dd[1],96, kernel_size=3, stride=1)
137 | self.conv5_3 = conv(od+dd[2],64, kernel_size=3, stride=1)
138 | self.conv5_4 = conv(od+dd[3],32, kernel_size=3, stride=1)
139 | self.predict_flow5 = predict_flow(od+dd[4])
140 | self.deconv5 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
141 | self.upfeat5 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1)
142 |
143 | od = nd+96+4
144 | self.conv4_0 = conv(od, 128, kernel_size=3, stride=1)
145 | self.conv4_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
146 | self.conv4_2 = conv(od+dd[1],96, kernel_size=3, stride=1)
147 | self.conv4_3 = conv(od+dd[2],64, kernel_size=3, stride=1)
148 | self.conv4_4 = conv(od+dd[3],32, kernel_size=3, stride=1)
149 | self.predict_flow4 = predict_flow(od+dd[4])
150 | self.deconv4 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
151 | self.upfeat4 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1)
152 |
153 | od = nd+64+4
154 | self.conv3_0 = conv(od, 128, kernel_size=3, stride=1)
155 | self.conv3_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
156 | self.conv3_2 = conv(od+dd[1],96, kernel_size=3, stride=1)
157 | self.conv3_3 = conv(od+dd[2],64, kernel_size=3, stride=1)
158 | self.conv3_4 = conv(od+dd[3],32, kernel_size=3, stride=1)
159 | self.predict_flow3 = predict_flow(od+dd[4])
160 | self.deconv3 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
161 | self.upfeat3 = deconv(od+dd[4], 2, kernel_size=4, stride=2, padding=1)
162 |
163 | od = nd+32+4
164 | self.conv2_0 = conv(od, 128, kernel_size=3, stride=1)
165 | self.conv2_1 = conv(od+dd[0],128, kernel_size=3, stride=1)
166 | self.conv2_2 = conv(od+dd[1],96, kernel_size=3, stride=1)
167 | self.conv2_3 = conv(od+dd[2],64, kernel_size=3, stride=1)
168 | self.conv2_4 = conv(od+dd[3],32, kernel_size=3, stride=1)
169 | self.predict_flow2 = predict_flow(od+dd[4])
170 | self.deconv2 = deconv(2, 2, kernel_size=4, stride=2, padding=1)
171 |
172 | self.dc_conv1 = conv(od+dd[4], 128, kernel_size=3, stride=1, padding=1, dilation=1)
173 | self.dc_conv2 = conv(128, 128, kernel_size=3, stride=1, padding=2, dilation=2)
174 | self.dc_conv3 = conv(128, 128, kernel_size=3, stride=1, padding=4, dilation=4)
175 | self.dc_conv4 = conv(128, 96, kernel_size=3, stride=1, padding=8, dilation=8)
176 | self.dc_conv5 = conv(96, 64, kernel_size=3, stride=1, padding=16, dilation=16)
177 | self.dc_conv6 = conv(64, 32, kernel_size=3, stride=1, padding=1, dilation=1)
178 | self.dc_conv7 = predict_flow(32)
179 |
180 | for m in self.modules():
181 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
182 | nn.init.kaiming_normal_(m.weight.data, mode='fan_in')
183 | if m.bias is not None:
184 | m.bias.data.zero_()
185 |
186 | # load weights
187 | #pretrained_dict = torch.load('/data/gengshay/PWC-Net//PyTorch/pwc_net_chairs.pth.tar')
188 | #pretrained_dict['state_dict'] = {k:v for k,v in pretrained_dict.items() if 'grid' not in k and 'flow_reg' not in k}
189 | #self.load_state_dict(pretrained_dict['state_dict'],strict=False)
190 |
191 | def corr(self, refimg_fea, targetimg_fea):
192 | maxdisp=4
193 | b,c,h,w = refimg_fea.shape
194 | targetimg_fea = F.unfold(targetimg_fea, (2*maxdisp+1,2*maxdisp+1), padding=maxdisp).view(b,c,2*maxdisp+1, 2*maxdisp+1,h,w)
195 | cost = refimg_fea.view(b,c,h,w)[:,:,np.newaxis, np.newaxis]*targetimg_fea # dot product over a (2*maxdisp+1)^2 search window
196 | cost = cost.sum(1)
197 |
198 | b, ph, pw, h, w = cost.size()
199 | cost = cost.view(b, ph * pw, h, w)/refimg_fea.size(1)
200 | return cost
201 |
202 | def weight_parameters(self):
203 | return [param for name, param in self.named_parameters() if 'weight' in name]
204 |
205 | def bias_parameters(self):
206 | return [param for name, param in self.named_parameters() if 'bias' in name]
207 |
208 | #@profile
209 | def forward(self,im):
210 | bs = im.shape[0]//2
211 | c01 = self.conv1b(self.conv1aa(self.conv1a(im)))
212 | c02 = self.conv2b(self.conv2aa(self.conv2a(c01)))
213 | c03 = self.conv3b(self.conv3aa(self.conv3a(c02)))
214 | c04 = self.conv4b(self.conv4aa(self.conv4a(c03)))
215 | c05 = self.conv5b(self.conv5aa(self.conv5a(c04)))
216 | c06 = self.conv6b(self.conv6a(self.conv6aa(c05)))
217 | c11 = c01[:bs]; c21 = c01[bs:]
218 | c12 = c02[:bs]; c22 = c02[bs:]
219 | c13 = c03[:bs]; c23 = c03[bs:]
220 | c14 = c04[:bs]; c24 = c04[bs:]
221 | c15 = c05[:bs]; c25 = c05[bs:]
222 | c16 = c06[:bs]; c26 = c06[bs:]
223 |
224 | corr6 = self.corr(c16, c26)
225 | corr6 = self.leakyRELU(corr6)
226 | x = torch.cat((self.conv6_0(corr6), corr6),1)
227 | x = torch.cat((self.conv6_1(x), x),1)
228 | x = torch.cat((self.conv6_2(x), x),1)
229 | x = torch.cat((self.conv6_3(x), x),1)
230 | x = torch.cat((self.conv6_4(x), x),1)
231 | flow6 = self.predict_flow6(x)
232 | up_flow6 = self.deconv6(flow6)
233 | up_feat6 = self.upfeat6(x)
234 |
235 |
236 | warp5 = self.warp5(c25, up_flow6*0.625)
237 | corr5 = self.corr(c15, warp5)
238 | corr5 = self.leakyRELU(corr5)
239 | x = torch.cat((corr5, c15, up_flow6, up_feat6), 1)
240 | x = torch.cat((self.conv5_0(x), x),1)
241 | x = torch.cat((self.conv5_1(x), x),1)
242 | x = torch.cat((self.conv5_2(x), x),1)
243 | x = torch.cat((self.conv5_3(x), x),1)
244 | x = torch.cat((self.conv5_4(x), x),1)
245 | flow5 = self.predict_flow5(x)
246 | up_flow5 = self.deconv5(flow5)
247 | up_feat5 = self.upfeat5(x)
248 |
249 |
250 | warp4 = self.warp4(c24, up_flow5*1.25)
251 | corr4 = self.corr(c14, warp4)
252 | corr4 = self.leakyRELU(corr4)
253 | x = torch.cat((corr4, c14, up_flow5, up_feat5), 1)
254 | x = torch.cat((self.conv4_0(x), x),1)
255 | x = torch.cat((self.conv4_1(x), x),1)
256 | x = torch.cat((self.conv4_2(x), x),1)
257 | x = torch.cat((self.conv4_3(x), x),1)
258 | x = torch.cat((self.conv4_4(x), x),1)
259 | flow4 = self.predict_flow4(x)
260 | up_flow4 = self.deconv4(flow4)
261 | up_feat4 = self.upfeat4(x)
262 |
263 |
264 | warp3 = self.warp3(c23, up_flow4*2.5)
265 | corr3 = self.corr(c13, warp3)
266 | corr3 = self.leakyRELU(corr3)
267 | x = torch.cat((corr3, c13, up_flow4, up_feat4), 1)
268 | x = torch.cat((self.conv3_0(x), x),1)
269 | x = torch.cat((self.conv3_1(x), x),1)
270 | x = torch.cat((self.conv3_2(x), x),1)
271 | x = torch.cat((self.conv3_3(x), x),1)
272 | x = torch.cat((self.conv3_4(x), x),1)
273 | flow3 = self.predict_flow3(x)
274 | up_flow3 = self.deconv3(flow3)
275 | up_feat3 = self.upfeat3(x)
276 |
277 |
278 | warp2 = self.warp2(c22, up_flow3*5.0)
279 | corr2 = self.corr(c12, warp2)
280 | corr2 = self.leakyRELU(corr2)
281 | x = torch.cat((corr2, c12, up_flow3, up_feat3), 1)
282 | x = torch.cat((self.conv2_0(x), x),1)
283 | x = torch.cat((self.conv2_1(x), x),1)
284 | x = torch.cat((self.conv2_2(x), x),1)
285 | x = torch.cat((self.conv2_3(x), x),1)
286 | x = torch.cat((self.conv2_4(x), x),1)
287 | flow2 = self.predict_flow2(x)
288 |
289 |
290 | x = self.dc_conv4(self.dc_conv3(self.dc_conv2(self.dc_conv1(x))))
291 | flow2 += self.dc_conv7(self.dc_conv6(self.dc_conv5(x)))
292 |
293 | flow2 = F.upsample(flow2, [im.size()[2],im.size()[3]], mode='bilinear')
294 |
295 | if self.training:
296 | flow3 = F.upsample(flow3, [im.size()[2],im.size()[3]], mode='bilinear')
297 | flow4 = F.upsample(flow4, [im.size()[2],im.size()[3]], mode='bilinear')
298 | flow5 = F.upsample(flow5, [im.size()[2],im.size()[3]], mode='bilinear')
299 | flow6 = F.upsample(flow6, [im.size()[2],im.size()[3]], mode='bilinear')
300 | return flow2*20,flow3*20,flow4*20,flow5*20,flow6*20,flow2, flow2[:,0]
301 | else:
302 | return flow2*20, flow2[0,0]
303 |
304 |
--------------------------------------------------------------------------------
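Note: PWCDCNet.corr builds its local cost volume with `F.unfold`, comparing each reference feature against a (2*md+1)x(2*md+1) neighborhood of the warped target features. A shape check of that construction on random tensors, assuming the same md=4:

    import torch
    import torch.nn.functional as F

    def corr(ref, tgt, md=4):
        # cost volume over a (2*md+1)^2 search window, as in PWCDCNet.corr
        b, c, h, w = ref.shape
        tgt = F.unfold(tgt, (2*md+1, 2*md+1), padding=md).view(b, c, 2*md+1, 2*md+1, h, w)
        cost = (ref[:, :, None, None] * tgt).sum(1)  # dot product per offset
        return cost.view(b, (2*md+1)**2, h, w) / c   # normalized by channel count

    ref, tgt = torch.randn(2, 16, 24, 32), torch.randn(2, 16, 24, 32)
    print(corr(ref, tgt).shape)  # torch.Size([2, 81, 24, 32])
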
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/models/__init__.py
--------------------------------------------------------------------------------
/models/conv4d.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import torch.nn as nn
3 | import math
4 | import torch
5 | from torch.nn.parameter import Parameter
6 | import torch.nn.functional as F
7 | from torch.nn import Module
8 | from torch.nn.modules.conv import _ConvNd
9 | from torch.nn.modules.utils import _quadruple
10 | from torch.autograd import Variable
11 | from torch.nn import Conv2d
12 |
13 | def conv4d(data,filters,bias=None,permute_filters=True,use_half=False):
14 | """
15 | Performs a 4D convolution by stacking the results of multiple 3D convolutions; this is very slow.
16 | Taken from https://github.com/ignacio-rocco/ncnet
17 | """
18 | b,c,h,w,d,t=data.size()
19 |
20 | data=data.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop
21 |
22 | # Same permutation is done with filters, unless already provided with permutation
23 | if permute_filters:
24 | filters=filters.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop
25 |
26 | c_out=filters.size(1)
27 | if use_half:
28 | output = Variable(torch.HalfTensor(h,b,c_out,w,d,t),requires_grad=data.requires_grad)
29 | else:
30 | output = Variable(torch.zeros(h,b,c_out,w,d,t),requires_grad=data.requires_grad)
31 |
32 | padding=filters.size(0)//2
33 | if use_half:
34 | Z=Variable(torch.zeros(padding,b,c,w,d,t).half())
35 | else:
36 | Z=Variable(torch.zeros(padding,b,c,w,d,t))
37 |
38 | if data.is_cuda:
39 | Z=Z.cuda(data.get_device())
40 | output=output.cuda(data.get_device())
41 |
42 | data_padded = torch.cat((Z,data,Z),0)
43 |
44 |
45 | for i in range(output.size(0)): # loop on first feature dimension
46 | # convolve with center channel of filter (at position=padding)
47 | output[i,:,:,:,:,:]=F.conv3d(data_padded[i+padding,:,:,:,:,:],
48 | filters[padding,:,:,:,:,:], bias=bias, stride=1, padding=padding)
49 | # convolve with upper/lower channels of filter (at positions [:padding] and [padding+1:])
50 | for p in range(1,padding+1):
51 | output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding-p,:,:,:,:,:],
52 | filters[padding-p,:,:,:,:,:], bias=None, stride=1, padding=padding)
53 | output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding+p,:,:,:,:,:],
54 | filters[padding+p,:,:,:,:,:], bias=None, stride=1, padding=padding)
55 |
56 | output=output.permute(1,2,0,3,4,5).contiguous()
57 | return output
58 |
59 | class Conv4d(_ConvNd):
60 | """Applies a 4D convolution over an input signal composed of several input
61 | planes.
62 | """
63 |
64 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
65 | # stride, dilation and groups !=1 functionality not tested
66 | stride=1
67 | dilation=1
68 | groups=1
69 | # zero padding is added automatically in conv4d function to preserve tensor size
70 | padding = 0
71 | kernel_size = _quadruple(kernel_size)
72 | stride = _quadruple(stride)
73 | padding = _quadruple(padding)
74 | dilation = _quadruple(dilation)
75 | super(Conv4d, self).__init__(
76 | in_channels, out_channels, kernel_size, stride, padding, dilation,
77 | False, _quadruple(0), groups, bias)
78 | # weights will be sliced along one dimension during convolution loop
79 | # make the looping dimension to be the first one in the tensor,
80 | # so that we don't need to call contiguous() inside the loop
81 | self.pre_permuted_filters=pre_permuted_filters
82 | if self.pre_permuted_filters:
83 | self.weight.data=self.weight.data.permute(2,0,1,3,4,5).contiguous()
84 | self.use_half=False
85 | # self.isbias = bias
86 | # if not self.isbias:
87 | # self.bn = torch.nn.BatchNorm1d(out_channels)
88 |
89 |
90 | def forward(self, input):
91 | out = conv4d(input, self.weight, bias=self.bias,permute_filters=not self.pre_permuted_filters,use_half=self.use_half) # filters pre-permuted in constructor
92 | # if not self.isbias:
93 | # b,c,u,v,h,w = out.shape
94 | # out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w)
95 | return out
96 |
97 | class fullConv4d(torch.nn.Module):
98 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
99 | super(fullConv4d, self).__init__()
100 | self.conv = Conv4d(in_channels, out_channels, kernel_size, bias=bias, pre_permuted_filters=pre_permuted_filters)
101 | self.isbias = bias
102 | if not self.isbias:
103 | self.bn = torch.nn.BatchNorm1d(out_channels)
104 |
105 | def forward(self, input):
106 | out = self.conv(input)
107 | if not self.isbias:
108 | b,c,u,v,h,w = out.shape
109 | out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w)
110 | return out
111 |
112 | class butterfly4D(torch.nn.Module):
113 | '''
114 | butterfly 4d
115 | '''
116 | def __init__(self, fdima, fdimb, withbn=True, full=True,groups=1):
117 | super(butterfly4D, self).__init__()
118 | self.proj = nn.Sequential(projfeat4d(fdima, fdimb, 1, with_bn=withbn,groups=groups),
119 | nn.ReLU(inplace=True),)
120 | self.conva1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups)
121 | self.conva2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups)
122 | self.convb3 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
123 | self.convb2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
124 | self.convb1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
125 |
126 | #@profile
127 | def forward(self,x):
128 | out = self.proj(x)
129 | b,c,u,v,h,w = out.shape # 9x9
130 |
131 | out1 = self.conva1(out) # 5x5, 3
132 | _,c1,u1,v1,h1,w1 = out1.shape
133 |
134 | out2 = self.conva2(out1) # 3x3, 9
135 | _,c2,u2,v2,h2,w2 = out2.shape
136 |
137 | out2 = self.convb3(out2) # 3x3, 9
138 |
139 | tout1 = F.upsample(out2.view(b,c,u2,v2,-1),(u1,v1,h2*w2),mode='trilinear').view(b,c,u1,v1,h2,w2) # 5x5
140 | tout1 = F.upsample(tout1.view(b,c,-1,h2,w2),(u1*v1,h1,w1),mode='trilinear').view(b,c,u1,v1,h1,w1) # 5x5
141 | out1 = tout1 + out1
142 | out1 = self.convb2(out1)
143 |
144 | tout = F.upsample(out1.view(b,c,u1,v1,-1),(u,v,h1*w1),mode='trilinear').view(b,c,u,v,h1,w1)
145 | tout = F.upsample(tout.view(b,c,-1,h1,w1),(u*v,h,w),mode='trilinear').view(b,c,u,v,h,w)
146 | out = tout + out
147 | out = self.convb1(out)
148 |
149 | return out
150 |
151 |
152 |
153 | class projfeat4d(torch.nn.Module):
154 |     '''
155 |     Channel projection (1x1) for 4D cost volumes, implemented as a 3D conv over (u,v,h*w)
156 |     '''
157 | def __init__(self, in_planes, out_planes, stride, with_bn=True,groups=1):
158 | super(projfeat4d, self).__init__()
159 | self.with_bn = with_bn
160 | self.stride = stride
161 | self.conv1 = nn.Conv3d(in_planes, out_planes, 1, (stride,stride,1), padding=0,bias=not with_bn,groups=groups)
162 | self.bn = nn.BatchNorm3d(out_planes)
163 |
164 | def forward(self,x):
165 | b,c,u,v,h,w = x.size()
166 | x = self.conv1(x.view(b,c,u,v,h*w))
167 | if self.with_bn:
168 | x = self.bn(x)
169 | _,c,u,v,_ = x.shape
170 | x = x.view(b,c,u,v,h,w)
171 | return x
172 |
173 | class sepConv4d(torch.nn.Module):
174 |     '''
175 |     Separable 4D convolution, implemented as two 3D convolutions
176 |     '''
177 | def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, ksize=3, full=True,groups=1):
178 | super(sepConv4d, self).__init__()
179 | bias = not with_bn
180 | self.isproj = False
181 | self.stride = stride[0]
182 | expand = 1
183 |
184 | if with_bn:
185 | if in_planes != out_planes:
186 | self.isproj = True
187 | self.proj = nn.Sequential(nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0,groups=groups),
188 | nn.BatchNorm2d(out_planes))
189 | if full:
190 | self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups),
191 | nn.BatchNorm3d(in_planes))
192 | else:
193 | self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups),
194 | nn.BatchNorm3d(in_planes))
195 | self.conv2 = nn.Sequential(nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups),
196 | nn.BatchNorm3d(in_planes*expand))
197 | else:
198 | if in_planes != out_planes:
199 | self.isproj = True
200 | self.proj = nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0,groups=groups)
201 | if full:
202 | self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups)
203 | else:
204 | self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups)
205 | self.conv2 = nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups)
206 | self.relu = nn.ReLU(inplace=True)
207 |
208 | #@profile
209 | def forward(self,x):
210 | b,c,u,v,h,w = x.shape
211 | x = self.conv2(x.view(b,c,u,v,-1)) # WTA convolution over (u,v)
212 | b,c,u,v,_ = x.shape
213 | x = self.relu(x)
214 | x = self.conv1(x.view(b,c,-1,h,w)) # spatial convolution over (x,y)
215 | b,c,_,h,w = x.shape
216 |
217 | if self.isproj:
218 | x = self.proj(x.view(b,c,-1,w))
219 | x = x.view(b,-1,u,v,h,w)
220 | return x
221 |
222 |
223 | class sepConv4dBlock(torch.nn.Module):
224 |     '''
225 |     Residual block of two separable 4D convolutions, with an optional
226 |     downsampling/projection shortcut
227 |     '''
228 | def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, full=True,groups=1):
229 | super(sepConv4dBlock, self).__init__()
230 | if in_planes == out_planes and stride==(1,1,1):
231 | self.downsample = None
232 | else:
233 | if full:
234 | self.downsample = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn,ksize=1, full=full,groups=groups)
235 | else:
236 | self.downsample = projfeat4d(in_planes, out_planes,stride[0], with_bn=with_bn,groups=groups)
237 | self.conv1 = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn, full=full ,groups=groups)
238 | self.conv2 = sepConv4d(out_planes, out_planes,(1,1,1), with_bn=with_bn, full=full,groups=groups)
239 | self.relu1 = nn.ReLU(inplace=True)
240 | self.relu2 = nn.ReLU(inplace=True)
241 |
242 | #@profile
243 | def forward(self,x):
244 | out = self.relu1(self.conv1(x))
245 | if self.downsample:
246 | x = self.downsample(x)
247 | out = self.relu2(x + self.conv2(out))
248 | return out
249 |
250 |
251 | 
--------------------------------------------------------------------------------
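A minimal shape check for the separable 4D operators above (a sketch: the tensor sizes are illustrative, and the repo root is assumed to be on PYTHONPATH):

```python
# Sketch: sepConv4dBlock and butterfly4D preserve the (b,c,u,v,h,w) layout.
import torch
from models.conv4d import butterfly4D, sepConv4dBlock

vol = torch.randn(1, 16, 9, 9, 16, 32)            # b, c, u, v, h, w
block = sepConv4dBlock(16, 16, stride=(1, 1, 1))  # residual separable 4D conv
print(block(vol).shape)                           # torch.Size([1, 16, 9, 9, 16, 32])

bfly = butterfly4D(16, 12)                        # project to 12 channels, then coarse-to-fine filtering
print(bfly(vol).shape)                            # torch.Size([1, 12, 9, 9, 16, 32])
```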
/models/submodule.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.data
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 | import math
8 | import numpy as np
9 | import pdb
10 |
11 | class residualBlock(nn.Module):
12 | expansion = 1
13 |
14 | def __init__(self, in_channels, n_filters, stride=1, downsample=None,dilation=1,with_bn=True):
15 | super(residualBlock, self).__init__()
16 | if dilation > 1:
17 | padding = dilation
18 | else:
19 | padding = 1
20 |
21 | if with_bn:
22 | self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, padding, dilation=dilation)
23 | self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1)
24 | else:
25 | self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, padding, dilation=dilation,with_bn=False)
26 | self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, with_bn=False)
27 | self.downsample = downsample
28 | self.relu = nn.LeakyReLU(0.1, inplace=True)
29 |
30 | def forward(self, x):
31 | residual = x
32 |
33 | out = self.convbnrelu1(x)
34 | out = self.convbn2(out)
35 |
36 | if self.downsample is not None:
37 | residual = self.downsample(x)
38 |
39 | out += residual
40 | return self.relu(out)
41 |
42 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
43 | return nn.Sequential(
44 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
45 | padding=padding, dilation=dilation, bias=True),
46 | nn.BatchNorm2d(out_planes),
47 | nn.LeakyReLU(0.1,inplace=True))
48 |
49 |
50 | class conv2DBatchNorm(nn.Module):
51 | def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True):
52 | super(conv2DBatchNorm, self).__init__()
53 | bias = not with_bn
54 |
55 | if dilation > 1:
56 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
57 | padding=padding, stride=stride, bias=bias, dilation=dilation)
58 |
59 | else:
60 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
61 | padding=padding, stride=stride, bias=bias, dilation=1)
62 |
63 |
64 | if with_bn:
65 | self.cb_unit = nn.Sequential(conv_mod,
66 | nn.BatchNorm2d(int(n_filters)),)
67 | else:
68 | self.cb_unit = nn.Sequential(conv_mod,)
69 |
70 | def forward(self, inputs):
71 | outputs = self.cb_unit(inputs)
72 | return outputs
73 |
74 | class conv2DBatchNormRelu(nn.Module):
75 | def __init__(self, in_channels, n_filters, k_size, stride, padding, dilation=1, with_bn=True):
76 | super(conv2DBatchNormRelu, self).__init__()
77 | bias = not with_bn
78 | if dilation > 1:
79 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
80 | padding=padding, stride=stride, bias=bias, dilation=dilation)
81 |
82 | else:
83 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
84 | padding=padding, stride=stride, bias=bias, dilation=1)
85 |
86 | if with_bn:
87 | self.cbr_unit = nn.Sequential(conv_mod,
88 | nn.BatchNorm2d(int(n_filters)),
89 | nn.LeakyReLU(0.1, inplace=True),)
90 | else:
91 | self.cbr_unit = nn.Sequential(conv_mod,
92 | nn.LeakyReLU(0.1, inplace=True),)
93 |
94 | def forward(self, inputs):
95 | outputs = self.cbr_unit(inputs)
96 | return outputs
97 |
98 | class pyramidPooling(nn.Module):
99 |
100 | def __init__(self, in_channels, with_bn=True, levels=4):
101 | super(pyramidPooling, self).__init__()
102 | self.levels = levels
103 |
104 | self.paths = []
105 | for i in range(levels):
106 | self.paths.append(conv2DBatchNormRelu(in_channels, in_channels, 1, 1, 0, with_bn=with_bn))
107 | self.path_module_list = nn.ModuleList(self.paths)
108 | self.relu = nn.LeakyReLU(0.1, inplace=True)
109 |
110 | def forward(self, x):
111 | h, w = x.shape[2:]
112 |
113 | k_sizes = []
114 | strides = []
115 | for pool_size in np.linspace(1,min(h,w)//2,self.levels,dtype=int):
116 | k_sizes.append((int(h/pool_size), int(w/pool_size)))
117 | strides.append((int(h/pool_size), int(w/pool_size)))
118 | k_sizes = k_sizes[::-1]
119 | strides = strides[::-1]
120 |
121 | pp_sum = x
122 |
123 | for i, module in enumerate(self.path_module_list):
124 | out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
125 | out = module(out)
126 | out = F.upsample(out, size=(h,w), mode='bilinear')
127 | pp_sum = pp_sum + 1./self.levels*out
128 | pp_sum = self.relu(pp_sum/2.)
129 |
130 | return pp_sum
131 |
132 | class pspnet(nn.Module):
133 | """
134 | Modified PSPNet. https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py
135 | """
136 | def __init__(self, is_proj=True,groups=1):
137 | super(pspnet, self).__init__()
138 | self.inplanes = 32
139 | self.is_proj = is_proj
140 |
141 | # Encoder
142 | self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16,
143 | padding=1, stride=2)
144 | self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16,
145 | padding=1, stride=1)
146 | self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32,
147 | padding=1, stride=1)
148 | # Vanilla Residual Blocks
149 | self.res_block3 = self._make_layer(residualBlock,64,1,stride=2)
150 | self.res_block5 = self._make_layer(residualBlock,128,1,stride=2)
151 | self.res_block6 = self._make_layer(residualBlock,128,1,stride=2)
152 | self.res_block7 = self._make_layer(residualBlock,128,1,stride=2)
153 | self.pyramid_pooling = pyramidPooling(128, levels=3)
154 |
155 | # Iconvs
156 | self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2),
157 | conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
158 | padding=1, stride=1))
159 | self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
160 | padding=1, stride=1)
161 | self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2),
162 | conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
163 | padding=1, stride=1))
164 | self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
165 | padding=1, stride=1)
166 | self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2),
167 | conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
168 | padding=1, stride=1))
169 | self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
170 | padding=1, stride=1)
171 | self.upconv3 = nn.Sequential(nn.Upsample(scale_factor=2),
172 | conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
173 | padding=1, stride=1))
174 | self.iconv2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64,
175 | padding=1, stride=1)
176 |
177 | if self.is_proj:
178 | self.proj6 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
179 | self.proj5 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
180 | self.proj4 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
181 | self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1)
182 | self.proj2 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1)
183 |
184 | for m in self.modules():
185 | if isinstance(m, nn.Conv2d):
186 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
187 | m.weight.data.normal_(0, math.sqrt(2. / n))
188 | if hasattr(m.bias,'data'):
189 | m.bias.data.zero_()
190 |
191 |
192 | def _make_layer(self, block, planes, blocks, stride=1):
193 | downsample = None
194 | if stride != 1 or self.inplanes != planes * block.expansion:
195 | downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
196 | kernel_size=1, stride=stride, bias=False),
197 | nn.BatchNorm2d(planes * block.expansion),)
198 | layers = []
199 | layers.append(block(self.inplanes, planes, stride, downsample))
200 | self.inplanes = planes * block.expansion
201 | for i in range(1, blocks):
202 | layers.append(block(self.inplanes, planes))
203 | return nn.Sequential(*layers)
204 |
205 | def forward(self, x):
206 | # H, W -> H/2, W/2
207 | conv1 = self.convbnrelu1_1(x)
208 | conv1 = self.convbnrelu1_2(conv1)
209 | conv1 = self.convbnrelu1_3(conv1)
210 |
211 |         # H/2, W/2 -> H/4, W/4
212 |         pool1 = F.max_pool2d(conv1, 3, 2, 1)
213 | 
214 |         # H/4, W/4 -> H/64, W/64 (four stride-2 residual blocks)
215 | rconv3 = self.res_block3(pool1)
216 | conv4 = self.res_block5(rconv3)
217 | conv5 = self.res_block6(conv4)
218 | conv6 = self.res_block7(conv5)
219 | conv6 = self.pyramid_pooling(conv6)
220 |
221 | conv6x = F.upsample(conv6, [conv5.size()[2],conv5.size()[3]],mode='bilinear')
222 |         concat5 = torch.cat((conv5,self.upconv6[1](conv6x)),dim=1) # upconv6[1] applies only the conv; resizing was done by F.upsample above
223 | conv5 = self.iconv5(concat5)
224 |
225 | conv5x = F.upsample(conv5, [conv4.size()[2],conv4.size()[3]],mode='bilinear')
226 | concat4 = torch.cat((conv4,self.upconv5[1](conv5x)),dim=1)
227 | conv4 = self.iconv4(concat4)
228 |
229 | conv4x = F.upsample(conv4, [rconv3.size()[2],rconv3.size()[3]],mode='bilinear')
230 | concat3 = torch.cat((rconv3,self.upconv4[1](conv4x)),dim=1)
231 | conv3 = self.iconv3(concat3)
232 |
233 | conv3x = F.upsample(conv3, [pool1.size()[2],pool1.size()[3]],mode='bilinear')
234 | concat2 = torch.cat((pool1,self.upconv3[1](conv3x)),dim=1)
235 | conv2 = self.iconv2(concat2)
236 |
237 | if self.is_proj:
238 | proj6 = self.proj6(conv6)
239 | proj5 = self.proj5(conv5)
240 | proj4 = self.proj4(conv4)
241 | proj3 = self.proj3(conv3)
242 | proj2 = self.proj2(conv2)
243 | return proj6,proj5,proj4,proj3,proj2
244 | else:
245 | return conv6, conv5, conv4, conv3, conv2
246 |
247 |
248 | class pspnet_s(nn.Module):
249 |     """
250 |     Modified PSPNet, shallower variant (stops at stride 8; no stride-4 level). https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/models/pspnet.py
251 |     """
252 | def __init__(self, is_proj=True,groups=1):
253 | super(pspnet_s, self).__init__()
254 | self.inplanes = 32
255 | self.is_proj = is_proj
256 |
257 | # Encoder
258 | self.convbnrelu1_1 = conv2DBatchNormRelu(in_channels=3, k_size=3, n_filters=16,
259 | padding=1, stride=2)
260 | self.convbnrelu1_2 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=16,
261 | padding=1, stride=1)
262 | self.convbnrelu1_3 = conv2DBatchNormRelu(in_channels=16, k_size=3, n_filters=32,
263 | padding=1, stride=1)
264 | # Vanilla Residual Blocks
265 | self.res_block3 = self._make_layer(residualBlock,64,1,stride=2)
266 | self.res_block5 = self._make_layer(residualBlock,128,1,stride=2)
267 | self.res_block6 = self._make_layer(residualBlock,128,1,stride=2)
268 | self.res_block7 = self._make_layer(residualBlock,128,1,stride=2)
269 | self.pyramid_pooling = pyramidPooling(128, levels=3)
270 |
271 | # Iconvs
272 | self.upconv6 = nn.Sequential(nn.Upsample(scale_factor=2),
273 | conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
274 | padding=1, stride=1))
275 | self.iconv5 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
276 | padding=1, stride=1)
277 | self.upconv5 = nn.Sequential(nn.Upsample(scale_factor=2),
278 | conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
279 | padding=1, stride=1))
280 | self.iconv4 = conv2DBatchNormRelu(in_channels=192, k_size=3, n_filters=128,
281 | padding=1, stride=1)
282 | self.upconv4 = nn.Sequential(nn.Upsample(scale_factor=2),
283 | conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
284 | padding=1, stride=1))
285 | self.iconv3 = conv2DBatchNormRelu(in_channels=128, k_size=3, n_filters=64,
286 | padding=1, stride=1)
287 | #self.upconv3 = nn.Sequential(nn.Upsample(scale_factor=2),
288 | # conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=32,
289 | # padding=1, stride=1))
290 | #self.iconv2 = conv2DBatchNormRelu(in_channels=64, k_size=3, n_filters=64,
291 | # padding=1, stride=1)
292 |
293 | if self.is_proj:
294 | self.proj6 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
295 | self.proj5 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
296 | self.proj4 = conv2DBatchNormRelu(in_channels=128,k_size=1,n_filters=128//groups, padding=0,stride=1)
297 | self.proj3 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1)
298 | #self.proj2 = conv2DBatchNormRelu(in_channels=64, k_size=1,n_filters=64//groups, padding=0,stride=1)
299 |
300 | for m in self.modules():
301 | if isinstance(m, nn.Conv2d):
302 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
303 | m.weight.data.normal_(0, math.sqrt(2. / n))
304 | if hasattr(m.bias,'data'):
305 | m.bias.data.zero_()
306 |
307 |
308 | def _make_layer(self, block, planes, blocks, stride=1):
309 | downsample = None
310 | if stride != 1 or self.inplanes != planes * block.expansion:
311 | downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion,
312 | kernel_size=1, stride=stride, bias=False),
313 | nn.BatchNorm2d(planes * block.expansion),)
314 | layers = []
315 | layers.append(block(self.inplanes, planes, stride, downsample))
316 | self.inplanes = planes * block.expansion
317 | for i in range(1, blocks):
318 | layers.append(block(self.inplanes, planes))
319 | return nn.Sequential(*layers)
320 |
321 | def forward(self, x):
322 | # H, W -> H/2, W/2
323 | conv1 = self.convbnrelu1_1(x)
324 | conv1 = self.convbnrelu1_2(conv1)
325 | conv1 = self.convbnrelu1_3(conv1)
326 |
327 |         # H/2, W/2 -> H/4, W/4
328 |         pool1 = F.max_pool2d(conv1, 3, 2, 1)
329 | 
330 |         # H/4, W/4 -> H/64, W/64 (four stride-2 residual blocks)
331 | rconv3 = self.res_block3(pool1)
332 | conv4 = self.res_block5(rconv3)
333 | conv5 = self.res_block6(conv4)
334 | conv6 = self.res_block7(conv5)
335 | conv6 = self.pyramid_pooling(conv6)
336 |
337 | conv6x = F.upsample(conv6, [conv5.size()[2],conv5.size()[3]],mode='bilinear')
338 | concat5 = torch.cat((conv5,self.upconv6[1](conv6x)),dim=1)
339 | conv5 = self.iconv5(concat5)
340 |
341 | conv5x = F.upsample(conv5, [conv4.size()[2],conv4.size()[3]],mode='bilinear')
342 | concat4 = torch.cat((conv4,self.upconv5[1](conv5x)),dim=1)
343 | conv4 = self.iconv4(concat4)
344 |
345 | conv4x = F.upsample(conv4, [rconv3.size()[2],rconv3.size()[3]],mode='bilinear')
346 | concat3 = torch.cat((rconv3,self.upconv4[1](conv4x)),dim=1)
347 | conv3 = self.iconv3(concat3)
348 |
349 | #conv3x = F.upsample(conv3, [pool1.size()[2],pool1.size()[3]],mode='bilinear')
350 | #concat2 = torch.cat((pool1,self.upconv3[1](conv3x)),dim=1)
351 | #conv2 = self.iconv2(concat2)
352 |
353 | if self.is_proj:
354 | proj6 = self.proj6(conv6)
355 | proj5 = self.proj5(conv5)
356 | proj4 = self.proj4(conv4)
357 | proj3 = self.proj3(conv3)
358 | # proj2 = self.proj2(conv2)
359 | # return proj6,proj5,proj4,proj3,proj2
360 | return proj6,proj5,proj4,proj3
361 | else:
362 | # return conv6, conv5, conv4, conv3, conv2
363 | return conv6, conv5, conv4, conv3
364 |
--------------------------------------------------------------------------------
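A quick sanity check of the pspnet feature pyramid above (a sketch; the 128x192 input is illustrative and is a multiple of 64, so the coarsest stride-64 level is well defined):

```python
# Sketch: pspnet returns five feature maps at strides 64/32/16/8/4.
import torch
from models.submodule import pspnet

net = pspnet(is_proj=True).eval()
with torch.no_grad():
    for f in net(torch.randn(1, 3, 128, 192)):
        print(f.shape)   # spatial sizes 2x3, 4x6, 8x12, 16x24, 32x48
```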
/run.sh:
--------------------------------------------------------------------------------
1 | modelname=$1
2 | filename=$modelname-$2
3 | array=(
4 | 1999 3999 5999 7999 9999 11999 13999 15999 17999 19999 21999 23999 25999 27999 29999 31999 33999 35999 37999 39999 41999 43999 45999 47999 49999 51999 53999 55999 57999 59999 61999 63999 65999 67999 69999 71999 73999 75999 77999 79999 81999 83999 85999 87999 89999 91999 93999 95999 97999 99999 101999 103999 105999 107999 109999 111999 113999 115999 117999 119999 121999 123999 125999 127999 129999 131999 133999 135999 137999 139999
5 | # 141999 143999 145999 147999
6 | #149999
7 | # 151999 153999 155999 157999 159999
8 | # 161999 163999 165999 167999 169999 171999 173999 175999 177999 179999 181999 183999 185999 187999 189999 191999 193999 195999 197999
9 | #199999 201999 203999 205999 207999 209999 211999 213999 215999 217999 219999
10 | #221999 223999 225999 227999 229999 231999 233999
11 | )
12 |
13 | #echo $modelname > results/$filename
14 | echo $modelname > results/s-$filename
15 | mkdir -p /data/ptmodel/$modelname/kitti15
16 | mkdir -p /data/ptmodel/$modelname/kitti15/flow
17 | for i in "${array[@]}"
18 | do
19 | # echo $i >> results/$filename
20 | # CUDA_VISIBLE_DEVICES=0 python submission.py --dataset 2015 --datapath /ssd/kitti_scene/training/ --outdir /data/ptmodel/$modelname/ --loadmodel /data/ptmodel/$modelname/finetune_$i.tar --testres 1
21 | # python eval_tmp.py --path /data/ptmodel/$modelname/ --vis no --dataset 2015 >> results/$filename
22 |
23 | echo $i >> results/s-$filename
24 | CUDA_VISIBLE_DEVICES=0 python submission.py --dataset sintel --datapath /ssd/rob_flow/training/ --outdir /data/ptmodel/$modelname/ --loadmodel /data/ptmodel/$modelname/finetune_$i.tar --testres 1 --maxdisp 448 --fac 1.4
25 | python eval_tmp.py --path /data/ptmodel/$modelname/ --vis no --dataset sintel >> results/s-$filename
26 |
27 | done
28 |
--------------------------------------------------------------------------------
/run_self.sh:
--------------------------------------------------------------------------------
1 | ## point $datapath to the folder of your images
2 | datapath=./dataset/IIW/
3 | modelname=things
4 | i=239999
5 | CUDA_VISIBLE_DEVICES=1 python submission.py --dataset kitticlip --datapath $datapath/ --outdir ./weights/$modelname/ --loadmodel ./weights/$modelname/finetune_$i.tar --maxdisp 256 --fac 1
6 |
--------------------------------------------------------------------------------
/submission.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | sys.path.insert(0,'utils/')
4 | #sys.path.insert(0,'dataloader/')
5 | sys.path.insert(0,'models/')
6 | import cv2
7 | import pdb
8 | import argparse
9 | import numpy as np
10 | import skimage.io
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.parallel
14 | import torch.backends.cudnn as cudnn
15 | import torch.utils.data
16 | from torch.autograd import Variable
17 | import time
18 | from utils.io import mkdir_p
19 | from utils.util_flow import write_flow, save_pfm
20 | cudnn.benchmark = False
21 |
22 | parser = argparse.ArgumentParser(description='VCN')
23 | parser.add_argument('--dataset', default='2015',
24 | help='{2015: KITTI-15, sintel}')
25 | parser.add_argument('--datapath', default='/ssd/kitti_scene/training/',
26 | help='data path')
27 | parser.add_argument('--loadmodel', default=None,
28 | help='model path')
29 | parser.add_argument('--outdir', default='output',
30 | help='output path')
31 | parser.add_argument('--model', default='VCN',
32 | help='VCN or VCN_small')
33 | parser.add_argument('--testres', type=float, default=1,
34 | help='resolution, {1: original resolution, 2: 2X resolution}')
35 | parser.add_argument('--maxdisp', type=int, default=256,
36 |                     help='maximum disparity. Only affects the coarsest cost volume size')
37 | parser.add_argument('--fac', type=float, default=1,
38 |                     help='controls the shape of the search grid. Only affects the coarsest cost volume size')
39 | args = parser.parse_args()
40 |
41 |
42 |
43 | # dataloader
44 | if args.dataset == '2015':
45 | #from dataloader import kitti15list as DA
46 | from dataloader import kitti15list_val as DA
47 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)]
48 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath)
49 | elif args.dataset == '2015test':
50 | from dataloader import kitti15list as DA
51 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)]
52 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath)
53 | elif args.dataset == 'tumclip':
54 | from dataloader import kitticliplist as DA
55 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)]
56 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath)
57 | elif args.dataset == 'kitticlip':
58 | from dataloader import kitticliplist as DA
59 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)]
60 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath)
61 | elif args.dataset == '2012':
62 | from dataloader import kitti12list as DA
63 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)]
64 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath)
65 | elif args.dataset == '2012test':
66 | from dataloader import kitti12list as DA
67 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)]
68 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath)
69 | elif args.dataset == 'mb':
70 | from dataloader import mblist as DA
71 | maxw,maxh = [int(args.testres*640), int(args.testres*512)]
72 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath)
73 | elif args.dataset == 'chairs':
74 | from dataloader import chairslist as DA
75 | maxw,maxh = [int(args.testres*512), int(args.testres*384)]
76 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath)
77 | elif args.dataset == 'sinteltest':
78 | from dataloader import sintellist as DA
79 | maxw,maxh = [int(args.testres*1024), int(args.testres*448)]
80 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath)
81 | elif args.dataset == 'sintel':
82 | #from dataloader import sintellist_clean as DA
83 | from dataloader import sintellist_val as DA
84 | #from dataloader import sintellist_val_2s as DA
85 | maxw,maxh = [int(args.testres*1024), int(args.testres*448)]
86 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath)
87 | elif args.dataset == 'hd1k':
88 | from dataloader import hd1klist as DA
89 | maxw,maxh = [int(args.testres*2560), int(args.testres*1088)]
90 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath)
91 | elif args.dataset == 'mbstereo':
92 | from dataloader import MiddleburySubmit as DA
93 | maxw,maxh = [int(args.testres*900), int(args.testres*750)]
94 | test_left_img, test_right_img ,_= DA.dataloader(args.datapath)
95 | elif args.dataset == 'k15stereo':
96 | from dataloader import stereo_kittilist15 as DA
97 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)]
98 | test_left_img, test_right_img ,_,_,_,_= DA.dataloader(args.datapath, typ='trainval')
99 | elif args.dataset == 'k12stereo':
100 | from dataloader import stereo_kittilist12 as DA
101 | maxw,maxh = [int(args.testres*1280), int(args.testres*384)]
102 | test_left_img, test_right_img ,_,_,_,_= DA.dataloader(args.datapath)
103 | if args.dataset == 'chairs':
104 | with open('FlyingChairs_train_val.txt', 'r') as f:
105 | split = [int(i) for i in f.readlines()]
106 | test_left_img = [test_left_img[i] for i,flag in enumerate(split) if flag==2]
107 | test_right_img = [test_right_img[i] for i,flag in enumerate(split) if flag==2]
108 |
109 | if args.model == 'VCN':
110 | from models.VCN import VCN
111 | elif args.model == 'VCN_small':
112 | from models.VCN_small import VCN
113 | #if '2015' in args.dataset:
114 | # model = VCN([1, maxw, maxh], md=[8,4,4,4,4], fac=2)
115 | #elif 'sintel' in args.dataset:
116 | # model = VCN([1, maxw, maxh], md=[7,4,4,4,4], fac=1.4)
117 | #else:
118 | # model = VCN([1, maxw, maxh], md=[4,4,4,4,4], fac=1)
119 | model = VCN([1, maxw, maxh], md=[int(4*(args.maxdisp/256)),4,4,4,4], fac=args.fac)
120 |
121 | model = nn.DataParallel(model, device_ids=[0])
122 | model.cuda()
123 | if args.loadmodel is not None:
124 |
125 | pretrained_dict = torch.load(args.loadmodel)
126 | mean_L=pretrained_dict['mean_L']
127 | mean_R=pretrained_dict['mean_R']
128 | pretrained_dict['state_dict'] = {k:v for k,v in pretrained_dict['state_dict'].items() if 'grid' not in k and (('flow_reg' not in k) or ('conv1' in k))}
129 |
130 | model.load_state_dict(pretrained_dict['state_dict'],strict=False)
131 | else:
132 | mean_L = [[0.33,0.33,0.33]]
133 | mean_R = [[0.33,0.33,0.33]]
134 | print('dry run')
135 |
136 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
137 |
138 |
139 | mkdir_p('%s/%s/'% (args.outdir, args.dataset))
140 | def main():
141 | model.eval()
142 | ttime_all = []
143 | for inx in range(len(test_left_img)):
144 | print(test_left_img[inx])
145 | imgL_o = skimage.io.imread(test_left_img[inx])
146 | imgR_o = skimage.io.imread(test_right_img[inx])
147 |
148 | # for gray input images
149 | if len(imgL_o.shape) == 2:
150 | imgL_o = np.tile(imgL_o[:,:,np.newaxis],(1,1,3))
151 | imgR_o = np.tile(imgR_o[:,:,np.newaxis],(1,1,3))
152 |
153 | # resize
154 | maxh = imgL_o.shape[0]*args.testres
155 | maxw = imgL_o.shape[1]*args.testres
156 | max_h = int(maxh // 64 * 64)
157 | max_w = int(maxw // 64 * 64)
158 | if max_h < maxh: max_h += 64
159 | if max_w < maxw: max_w += 64
160 |
161 | input_size = imgL_o.shape
162 | imgL = cv2.resize(imgL_o,(max_w, max_h))
163 | imgR = cv2.resize(imgR_o,(max_w, max_h))
164 |
165 | # flip channel, subtract mean
166 | imgL = imgL[:,:,::-1].copy() / 255. - np.asarray(mean_L).mean(0)[np.newaxis,np.newaxis,:]
167 | imgR = imgR[:,:,::-1].copy() / 255. - np.asarray(mean_R).mean(0)[np.newaxis,np.newaxis,:]
168 | imgL = np.transpose(imgL, [2,0,1])[np.newaxis]
169 | imgR = np.transpose(imgR, [2,0,1])[np.newaxis]
170 |
171 | # support for any resolution inputs
172 | from models.VCN import WarpModule, flow_reg
173 | if hasattr(model.module, 'flow_reg64'):
174 | model.module.flow_reg64 = flow_reg([1,max_w//64,max_h//64], ent=model.module.flow_reg64.ent, maxdisp=model.module.flow_reg64.md, fac=model.module.flow_reg64.fac).cuda()
175 | if hasattr(model.module, 'flow_reg32'):
176 | model.module.flow_reg32 = flow_reg([1,max_w//64*2,max_h//64*2], ent=model.module.flow_reg32.ent, maxdisp=model.module.flow_reg32.md, fac=model.module.flow_reg32.fac).cuda()
177 | if hasattr(model.module, 'flow_reg16'):
178 | model.module.flow_reg16 = flow_reg([1,max_w//64*4,max_h//64*4], ent=model.module.flow_reg16.ent, maxdisp=model.module.flow_reg16.md, fac=model.module.flow_reg16.fac).cuda()
179 | if hasattr(model.module, 'flow_reg8'):
180 | model.module.flow_reg8 = flow_reg([1,max_w//64*8, max_h//64*8], ent=model.module.flow_reg8.ent, maxdisp=model.module.flow_reg8.md , fac = model.module.flow_reg8.fac).cuda()
181 | if hasattr(model.module, 'flow_reg4'):
182 | model.module.flow_reg4 = flow_reg([1,max_w//64*16, max_h//64*16 ], ent=model.module.flow_reg4.ent, maxdisp=model.module.flow_reg4.md , fac = model.module.flow_reg4.fac).cuda()
183 | model.module.warp5 = WarpModule([1,max_w//32,max_h//32]).cuda()
184 | model.module.warp4 = WarpModule([1,max_w//16,max_h//16]).cuda()
185 | model.module.warp3 = WarpModule([1,max_w//8, max_h//8]).cuda()
186 | model.module.warp2 = WarpModule([1,max_w//4, max_h//4]).cuda()
187 | model.module.warpx = WarpModule([1,max_w, max_h]).cuda()
188 |
189 | # forward
190 | imgL = Variable(torch.FloatTensor(imgL).cuda())
191 | imgR = Variable(torch.FloatTensor(imgR).cuda())
192 | with torch.no_grad():
193 | imgLR = torch.cat([imgL,imgR],0)
194 | model.eval()
195 | torch.cuda.synchronize()
196 | start_time = time.time()
197 | rts = model(imgLR)
198 | torch.cuda.synchronize()
199 | ttime = (time.time() - start_time); print('time = %.2f' % (ttime*1000) )
200 | ttime_all.append(ttime)
201 | pred_disp, entropy = rts
202 |
203 | # upsampling
204 | pred_disp = torch.squeeze(pred_disp).data.cpu().numpy()
205 | pred_disp = cv2.resize(np.transpose(pred_disp,(1,2,0)), (input_size[1], input_size[0]))
206 | pred_disp[:,:,0] *= input_size[1] / max_w
207 | pred_disp[:,:,1] *= input_size[0] / max_h
208 | flow = np.ones([pred_disp.shape[0],pred_disp.shape[1],3])
209 | flow[:,:,:2] = pred_disp
210 | entropy = torch.squeeze(entropy).data.cpu().numpy()
211 | entropy = cv2.resize(entropy, (input_size[1], input_size[0]))
212 |
213 | # save predictions
214 | if args.dataset == 'mbstereo':
215 | dirname = '%s/%s/%s'%(args.outdir, args.dataset, test_left_img[inx].split('/')[-2])
216 | mkdir_p(dirname)
217 | idxname = ('%s/%s')%(dirname.rsplit('/',1)[-1],test_left_img[inx].split('/')[-1])
218 | else:
219 | idxname = test_left_img[inx].split('/')[-1]
220 |
221 | if args.dataset == 'mbstereo':
222 | with open(test_left_img[inx].replace('im0.png','calib.txt')) as f:
223 | lines = f.readlines()
224 | #max_disp = int(int(lines[9].split('=')[-1]))
225 |             max_disp = int(lines[6].split('=')[-1])
226 | with open('%s/%s/%s'% (args.outdir, args.dataset,idxname.replace('im0.png','disp0IO.pfm')),'w') as f:
227 | save_pfm(f,np.clip(-flow[::-1,:,0].astype(np.float32),0,max_disp) )
228 | with open('%s/%s/%s/timeIO.txt'%(args.outdir, args.dataset,idxname.split('/')[0]),'w') as f:
229 | f.write(str(ttime))
230 | elif args.dataset == 'k15stereo' or args.dataset == 'k12stereo':
231 | skimage.io.imsave('%s/%s/%s.png'% (args.outdir, args.dataset,idxname.split('.')[0]),(-flow[:,:,0].astype(np.float32)*256).astype('uint16'))
232 | else:
233 | write_flow('%s/%s/%s.png'% (args.outdir, args.dataset,idxname.rsplit('.',1)[0]), flow.copy())
234 | cv2.imwrite('%s/%s/ent-%s.png'% (args.outdir, args.dataset,idxname.rsplit('.',1)[0]), entropy*200)
235 |
236 | torch.cuda.empty_cache()
237 | print(np.mean(ttime_all))
238 |
239 |
240 |
241 | if __name__ == '__main__':
242 | main()
243 |
244 |
--------------------------------------------------------------------------------
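The resolution handling in main() above, reduced to a standalone sketch: each side is rounded up to the next multiple of 64 (the coarsest feature stride), and the u/v flow components are rescaled when the prediction is resized back. The 375x1242 KITTI-sized input is illustrative:

```python
# Sketch of the round-up-to-64 logic used in submission.py.
def next_multiple_of_64(n):
    m = int(n) // 64 * 64
    return m if m >= n else m + 64

h, w = 375, 1242                        # a KITTI-sized image
max_h, max_w = next_multiple_of_64(h), next_multiple_of_64(w)
print(max_h, max_w)                     # 384 1280
# after inference at (max_h, max_w), the flow map is resized back to (h, w)
# and its u/v channels are scaled by w/max_w and h/max_h respectively
print(w / max_w, h / max_h)             # 0.9703125 0.9765625
```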
/thop/__init__.py:
--------------------------------------------------------------------------------
1 | from .profile import profile
--------------------------------------------------------------------------------
/thop/count_hooks.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | multiply_adds = 1
7 |
8 |
9 | def count_convNd(m, x, y):
10 | x = x[0]
11 |
12 | kernel_ops = m.weight.size()[2:].numel()
13 | bias_ops = 1 if m.bias is not None else 0
14 |
15 | # cout x oW x oH
16 | total_ops = y.nelement() * (m.in_channels//m.groups * kernel_ops + bias_ops)
17 | m.total_ops = torch.Tensor([int(total_ops)])
18 |
19 |
20 | def count_conv2d(m, x, y):
21 | x = x[0]
22 |
23 | cin = m.in_channels
24 | cout = m.out_channels
25 | kh, kw = m.kernel_size
26 | batch_size = x.size()[0]
27 |
28 | out_h = y.size(2)
29 | out_w = y.size(3)
30 |
31 | # ops per output element
32 | # kernel_mul = kh * kw * cin
33 | # kernel_add = kh * kw * cin - 1
34 | kernel_ops = multiply_adds * kh * kw
35 | bias_ops = 1 if m.bias is not None else 0
36 | ops_per_element = kernel_ops + bias_ops
37 |
38 | # total ops
39 | # num_out_elements = y.numel()
40 | output_elements = batch_size * out_w * out_h * cout
41 | total_ops = output_elements * ops_per_element * cin // m.groups
42 |
43 | m.total_ops = torch.Tensor([int(total_ops)])
44 |
45 |
46 | def count_convtranspose2d(m, x, y):  # unused: profile.py registers count_convNd for nn.ConvTranspose2d
47 | x = x[0]
48 |
49 | cin = m.in_channels
50 | cout = m.out_channels
51 | kh, kw = m.kernel_size
52 | # batch_size = x.size()[0]
53 |
54 | out_h = y.size(2)
55 | out_w = y.size(3)
56 |
57 | # ops per output element
58 | # kernel_mul = kh * kw * cin
59 | # kernel_add = kh * kw * cin - 1
60 | kernel_ops = multiply_adds * kh * kw * cin // m.groups
61 | bias_ops = 1 if m.bias is not None else 0
62 | ops_per_element = kernel_ops + bias_ops
63 |
64 | # total ops
65 | # num_out_elements = y.numel()
66 | # output_elements = batch_size * out_w * out_h * cout
67 | ops_per_element = m.weight.nelement()
68 | output_elements = y.nelement()
69 | total_ops = output_elements * ops_per_element
70 |
71 | m.total_ops = torch.Tensor([int(total_ops)])
74 |
75 |
76 | def count_bn(m, x, y):
77 | x = x[0]
78 |
79 | nelements = x.numel()
80 | # subtract, divide, gamma, beta
81 | total_ops = 4 * nelements
82 |
83 | m.total_ops = torch.Tensor([int(total_ops)])
84 |
85 |
86 | def count_relu(m, x, y):
87 | x = x[0]
88 |
89 | nelements = x.numel()
90 | total_ops = nelements
91 |
92 | m.total_ops = torch.Tensor([int(total_ops)])
93 |
94 |
95 | def count_softmax(m, x, y):
96 | x = x[0]
97 |
98 | batch_size, nfeatures = x.size()
99 |
100 | total_exp = nfeatures
101 | total_add = nfeatures - 1
102 | total_div = nfeatures
103 | total_ops = batch_size * (total_exp + total_add + total_div)
104 |
105 | m.total_ops = torch.Tensor([int(total_ops)])
106 |
107 |
108 | def count_maxpool(m, x, y):
109 | kernel_ops = torch.prod(torch.Tensor([m.kernel_size]))
110 | num_elements = y.numel()
111 | total_ops = kernel_ops * num_elements
112 |
113 | m.total_ops = torch.Tensor([int(total_ops)])
114 |
115 |
116 | def count_adap_maxpool(m, x, y):
117 | kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze()
118 | kernel_ops = torch.prod(kernel)
119 | num_elements = y.numel()
120 | total_ops = kernel_ops * num_elements
121 |
122 | m.total_ops = torch.Tensor([int(total_ops)])
123 |
124 |
125 | def count_avgpool(m, x, y):
126 | total_add = torch.prod(torch.Tensor([m.kernel_size]))
127 | total_div = 1
128 | kernel_ops = total_add + total_div
129 | num_elements = y.numel()
130 | total_ops = kernel_ops * num_elements
131 |
132 | m.total_ops = torch.Tensor([int(total_ops)])
133 |
134 |
135 | def count_adap_avgpool(m, x, y):
136 | kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze()
137 | total_add = torch.prod(kernel)
138 | total_div = 1
139 | kernel_ops = total_add + total_div
140 | num_elements = y.numel()
141 | total_ops = kernel_ops * num_elements
142 |
143 | m.total_ops = torch.Tensor([int(total_ops)])
144 |
145 |
146 | def count_linear(m, x, y):
147 | # per output element
148 | total_mul = m.in_features
149 | total_add = m.in_features - 1
150 | num_elements = y.numel()
151 | total_ops = (total_mul + total_add) * num_elements
152 |
153 | m.total_ops = torch.Tensor([int(total_ops)])
154 |
--------------------------------------------------------------------------------
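A worked instance of count_convNd above (a sketch; the layer and input sizes are made up): each output element of a convolution costs in_channels/groups * kernel_ops multiply-adds, plus one add if a bias is present.

```python
import torch
import torch.nn as nn

m = nn.Conv2d(16, 32, 3, padding=1)
y = m(torch.randn(1, 16, 64, 64))

kernel_ops = m.weight.size()[2:].numel()    # 3*3 = 9
bias_ops = 1 if m.bias is not None else 0   # 1
# same formula as count_convNd: output elements x (cin/groups * kernel_ops + bias)
total_ops = y.nelement() * (m.in_channels // m.groups * kernel_ops + bias_ops)
print(total_ops)                            # 131072 * 145 = 19005440
```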
/thop/profile.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.modules.conv import _ConvNd
6 |
7 | from .count_hooks import *
8 |
9 | register_hooks = {
10 | nn.Conv1d: count_convNd,
11 | nn.Conv2d: count_convNd,
12 | nn.Conv3d: count_convNd,
13 | nn.ConvTranspose2d: count_convNd,
14 |
15 | nn.BatchNorm1d: count_bn,
16 | nn.BatchNorm2d: count_bn,
17 | nn.BatchNorm3d: count_bn,
18 |
19 | nn.ReLU: count_relu,
20 | nn.ReLU6: count_relu,
21 | nn.LeakyReLU: count_relu,
22 |
23 | nn.MaxPool1d: count_maxpool,
24 | nn.MaxPool2d: count_maxpool,
25 | nn.MaxPool3d: count_maxpool,
26 | nn.AdaptiveMaxPool1d: count_adap_maxpool,
27 | nn.AdaptiveMaxPool2d: count_adap_maxpool,
28 | nn.AdaptiveMaxPool3d: count_adap_maxpool,
29 |
30 | nn.AvgPool1d: count_avgpool,
31 | nn.AvgPool2d: count_avgpool,
32 | nn.AvgPool3d: count_avgpool,
33 |
34 | nn.AdaptiveAvgPool1d: count_adap_avgpool,
35 | nn.AdaptiveAvgPool2d: count_adap_avgpool,
36 | nn.AdaptiveAvgPool3d: count_adap_avgpool,
37 | nn.Linear: count_linear,
38 | nn.Dropout: None,
39 | }
40 |
41 |
42 | def profile(model, input_size, custom_ops={}, device="cpu"):
43 | handler_collection = []
44 |
45 | def add_hooks(m):
46 | if len(list(m.children())) > 0:
47 | return
48 |
49 | m.register_buffer('total_ops', torch.zeros(1))
50 | m.register_buffer('total_params', torch.zeros(1))
51 |
52 | for p in m.parameters():
53 | m.total_params += torch.Tensor([p.numel()])
54 |
55 | m_type = type(m)
56 | fn = None
57 |
58 | if m_type in custom_ops:
59 | fn = custom_ops[m_type]
60 | elif m_type in register_hooks:
61 | fn = register_hooks[m_type]
62 | else:
63 | print("Not implemented for ", m)
64 |
65 | if fn is not None:
66 | handler = m.register_forward_hook(fn)
67 | handler_collection.append(handler)
68 |
69 | original_device = model.parameters().__next__().device
70 | training = model.training
71 |
72 | model.eval().to(device)
73 | model.apply(add_hooks)
74 |
75 | x = torch.zeros(input_size).to(device)
76 | with torch.no_grad():
77 | model(x)
78 |
79 | total_ops = 0
80 | total_params = 0
81 | for m in model.modules():
82 | if len(list(m.children())) > 0: # skip for non-leaf module
83 | #import pdb; pdb.set_trace()
84 | #if 'butterfly' in str(m._get_name()): break
85 | print('-> %s'%(str(m._get_name())))
86 | continue
87 | #if not '2d' in str(m._get_name()): continue
88 | #if not '3d' in str(m._get_name()): continue
89 |         print("Counted %.1f M ops / %.1f params for module %s" % (m.total_ops/1e6, m.total_params, str(m)))
90 | total_ops += m.total_ops
91 | total_params += m.total_params
92 |
93 | total_ops = total_ops.item()
94 | total_params = total_params.item()
95 |
96 | model.train(training).to(original_device)
97 | for handler in handler_collection:
98 | handler.remove()
99 |
100 | return total_ops, total_params
101 |
--------------------------------------------------------------------------------
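Typical use of profile() above, on a toy model (a sketch; flops.py in this repo presumably drives it the same way on VCN):

```python
import torch.nn as nn
from thop import profile

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
                      nn.ReLU(),
                      nn.Conv2d(16, 8, 3, padding=1))
total_ops, total_params = profile(model, input_size=(1, 3, 32, 32))
print(total_ops / 1e6, total_params)  # multiply-adds in millions, parameter count
```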
/thop/utils.py:
--------------------------------------------------------------------------------
1 |
2 | def clever_format(num, format="%.2f"):
3 | if num > 1e12:
4 | return format % (num / 1e12) + "T"
5 | if num > 1e9:
6 | return format % (num / 1e9) + "G"
7 | if num > 1e6:
8 | return format % (num / 1e6) + "M"
9 | if num > 1e3:
10 |         return format % (num / 1e3) + "K"
11 |     return format % num
--------------------------------------------------------------------------------
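clever_format above turns raw counts into human-readable strings, e.g. for reporting the totals returned by profile() (a sketch):

```python
from thop.utils import clever_format

print(clever_format(19005440))       # '19.01M'
print(clever_format(3.2e9, "%.1f"))  # '3.2G'
```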
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gengshan-y/VCN/00c4befdbdf4e42050867996a6f686f52086e01a/utils/__init__.py
--------------------------------------------------------------------------------
/utils/flowlib.py:
--------------------------------------------------------------------------------
1 | """
2 | # ==============================
3 | # flowlib.py
4 | # library for optical flow processing
5 | # Author: Ruoteng Li
6 | # Date: 6th Aug 2016
7 | # ==============================
8 | """
9 | import png
10 | from . import pfm
11 | import numpy as np
12 | import matplotlib.colors as cl
13 | import matplotlib.pyplot as plt
14 | from PIL import Image
15 | import cv2
16 | import pdb
17 |
18 |
19 | UNKNOWN_FLOW_THRESH = 1e7
20 | SMALLFLOW = 0.0
21 | LARGEFLOW = 1e8
22 |
23 | """
24 | =============
25 | Flow Section
26 | =============
27 | """
28 |
29 |
30 | def show_flow(filename):
31 | """
32 | visualize optical flow map using matplotlib
33 | :param filename: optical flow file
34 | :return: None
35 | """
36 | flow = read_flow(filename)
37 | img = flow_to_image(flow)
38 | plt.imshow(img)
39 | plt.show()
40 |
41 |
42 | def visualize_flow(flow, mode='Y'):
43 | """
44 | this function visualize the input flow
45 | :param flow: input flow in array
46 |     :param mode: choose which color mode to visualize the flow (Y: YCbCr, RGB: RGB color)
47 | :return: None
48 | """
49 | if mode == 'Y':
50 |         # YCbCr color wheel
51 | img = flow_to_image(flow)
52 | plt.imshow(img)
53 | plt.show()
54 | elif mode == 'RGB':
55 | (h, w) = flow.shape[0:2]
56 | du = flow[:, :, 0]
57 | dv = flow[:, :, 1]
58 | valid = flow[:, :, 2]
59 | max_flow = max(np.max(du), np.max(dv))
60 | img = np.zeros((h, w, 3), dtype=np.float64)
61 | # angle layer
62 | img[:, :, 0] = np.arctan2(dv, du) / (2 * np.pi)
63 |         # magnitude layer, scaled by 8/max_flow
64 | img[:, :, 1] = np.sqrt(du * du + dv * dv) * 8 / max_flow
65 | # phase layer
66 | img[:, :, 2] = 8 - img[:, :, 1]
67 | # clip to [0,1]
68 | small_idx = img[:, :, 0:3] < 0
69 | large_idx = img[:, :, 0:3] > 1
70 | img[small_idx] = 0
71 | img[large_idx] = 1
72 | # convert to rgb
73 | img = cl.hsv_to_rgb(img)
74 | # remove invalid point
76 | img[:, :, 0] = img[:, :, 0] * valid
77 | img[:, :, 1] = img[:, :, 1] * valid
78 | img[:, :, 2] = img[:, :, 2] * valid
79 | # show
80 | plt.imshow(img)
81 | plt.show()
82 |
83 | return None
84 |
85 |
86 | def read_flow(filename):
87 | """
88 | read optical flow data from flow file
89 | :param filename: name of the flow file
90 | :return: optical flow data in numpy array
91 | """
92 | if filename.endswith('.flo'):
93 | flow = read_flo_file(filename)
94 | elif filename.endswith('.png'):
95 | flow = read_png_file(filename)
96 | elif filename.endswith('.pfm'):
97 | flow = read_pfm_file(filename)
98 | else:
99 | raise Exception('Invalid flow file format!')
100 |
101 | return flow
102 |
103 |
104 | def write_flow(flow, filename):
105 | """
106 | write optical flow in Middlebury .flo format
107 | :param flow: optical flow map
108 | :param filename: optical flow file path to be saved
109 | :return: None
110 | """
111 | f = open(filename, 'wb')
112 | magic = np.array([202021.25], dtype=np.float32)
113 | (height, width) = flow.shape[0:2]
114 | w = np.array([width], dtype=np.int32)
115 | h = np.array([height], dtype=np.int32)
116 | magic.tofile(f)
117 | w.tofile(f)
118 | h.tofile(f)
119 | flow.tofile(f)
120 | f.close()
121 |
122 |
123 | def save_flow_image(flow, image_file):
124 | """
125 | save flow visualization into image file
126 | :param flow: optical flow data
127 |     :param image_file: image file path to save the visualization
128 | :return: None
129 | """
130 | flow_img = flow_to_image(flow)
131 | img_out = Image.fromarray(flow_img)
132 | img_out.save(image_file)
133 |
134 |
135 | def flowfile_to_imagefile(flow_file, image_file):
136 | """
137 |     convert a flow file into an image file
138 |     :param flow_file: optical flow file path
139 |     :param image_file: image file path to save the visualization
140 | :return: None
141 | """
142 | flow = read_flow(flow_file)
143 | save_flow_image(flow, image_file)
144 |
145 |
146 | def segment_flow(flow):
147 | h = flow.shape[0]
148 | w = flow.shape[1]
149 | u = flow[:, :, 0]
150 | v = flow[:, :, 1]
151 |
152 | idx = ((abs(u) > LARGEFLOW) | (abs(v) > LARGEFLOW))
153 | idx2 = (abs(u) == SMALLFLOW)
154 | class0 = (v == 0) & (u == 0)
155 | u[idx2] = 0.00001
156 | tan_value = v / u
157 |
158 | class1 = (tan_value < 1) & (tan_value >= 0) & (u > 0) & (v >= 0)
159 | class2 = (tan_value >= 1) & (u >= 0) & (v >= 0)
160 | class3 = (tan_value < -1) & (u <= 0) & (v >= 0)
161 | class4 = (tan_value < 0) & (tan_value >= -1) & (u < 0) & (v >= 0)
162 | class8 = (tan_value >= -1) & (tan_value < 0) & (u > 0) & (v <= 0)
163 | class7 = (tan_value < -1) & (u >= 0) & (v <= 0)
164 | class6 = (tan_value >= 1) & (u <= 0) & (v <= 0)
165 | class5 = (tan_value >= 0) & (tan_value < 1) & (u < 0) & (v <= 0)
166 |
167 | seg = np.zeros((h, w))
168 |
169 | seg[class1] = 1
170 | seg[class2] = 2
171 | seg[class3] = 3
172 | seg[class4] = 4
173 | seg[class5] = 5
174 | seg[class6] = 6
175 | seg[class7] = 7
176 | seg[class8] = 8
177 | seg[class0] = 0
178 | seg[idx] = 0
179 |
180 | return seg
181 |
182 |
183 | def flow_error(tu, tv, u, v):
184 | """
185 | Calculate average end point error
186 | :param tu: ground-truth horizontal flow map
187 | :param tv: ground-truth vertical flow map
188 | :param u: estimated horizontal flow map
189 | :param v: estimated vertical flow map
190 | :return: End point error of the estimated flow
191 | """
192 | smallflow = 0.0
193 | '''
194 | stu = tu[bord+1:end-bord,bord+1:end-bord]
195 | stv = tv[bord+1:end-bord,bord+1:end-bord]
196 | su = u[bord+1:end-bord,bord+1:end-bord]
197 | sv = v[bord+1:end-bord,bord+1:end-bord]
198 | '''
199 | stu = tu[:]
200 | stv = tv[:]
201 | su = u[:]
202 | sv = v[:]
203 |
204 | idxUnknow = (abs(stu) > UNKNOWN_FLOW_THRESH) | (abs(stv) > UNKNOWN_FLOW_THRESH)
205 | stu[idxUnknow] = 0
206 | stv[idxUnknow] = 0
207 | su[idxUnknow] = 0
208 | sv[idxUnknow] = 0
209 |
210 |     ind2 = (np.absolute(stu) > smallflow) | (np.absolute(stv) > smallflow)
211 | index_su = su[ind2]
212 | index_sv = sv[ind2]
213 | an = 1.0 / np.sqrt(index_su ** 2 + index_sv ** 2 + 1)
214 | un = index_su * an
215 | vn = index_sv * an
216 |
217 | index_stu = stu[ind2]
218 | index_stv = stv[ind2]
219 | tn = 1.0 / np.sqrt(index_stu ** 2 + index_stv ** 2 + 1)
220 | tun = index_stu * tn
221 | tvn = index_stv * tn
222 |
223 | '''
224 | angle = un * tun + vn * tvn + (an * tn)
225 | index = [angle == 1.0]
226 | angle[index] = 0.999
227 | ang = np.arccos(angle)
228 | mang = np.mean(ang)
229 | mang = mang * 180 / np.pi
230 | '''
231 |
232 | epe = np.sqrt((stu - su) ** 2 + (stv - sv) ** 2)
233 | epe = epe[ind2]
234 | mepe = np.mean(epe)
235 | return mepe
236 |
237 |
238 | def flow_to_image(flow):
239 | """
240 | Convert flow into middlebury color code image
241 | :param flow: optical flow map
242 | :return: optical flow image in middlebury color
243 | """
244 | u = flow[:, :, 0]
245 | v = flow[:, :, 1]
246 |
247 | maxu = -999.
248 | maxv = -999.
249 | minu = 999.
250 | minv = 999.
251 |
252 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
253 | u[idxUnknow] = 0
254 | v[idxUnknow] = 0
255 |
256 | maxu = max(maxu, np.max(u))
257 | minu = min(minu, np.min(u))
258 |
259 | maxv = max(maxv, np.max(v))
260 | minv = min(minv, np.min(v))
261 |
262 | rad = np.sqrt(u ** 2 + v ** 2)
263 | maxrad = max(-1, np.max(rad))
264 |
265 | u = u/(maxrad + np.finfo(float).eps)
266 | v = v/(maxrad + np.finfo(float).eps)
267 |
268 | img = compute_color(u, v)
269 |
270 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
271 | img[idx] = 0
272 |
273 | return np.uint8(img)
274 |
275 |
276 | def evaluate_flow_file(gt_file, pred_file):
277 | """
278 | evaluate the estimated optical flow end point error according to ground truth provided
279 | :param gt_file: ground truth file path
280 | :param pred_file: estimated optical flow file path
281 | :return: end point error, float32
282 | """
283 | # Read flow files and calculate the errors
284 | gt_flow = read_flow(gt_file) # ground truth flow
285 | eva_flow = read_flow(pred_file) # predicted flow
286 | # Calculate errors
287 | average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], eva_flow[:, :, 0], eva_flow[:, :, 1])
288 | return average_pe
289 |
290 |
291 | def evaluate_flow(gt_flow, pred_flow):
292 | """
293 | gt: ground-truth flow
294 | pred: estimated flow
295 | """
296 | average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], pred_flow[:, :, 0], pred_flow[:, :, 1])
297 | return average_pe
298 |
299 |
300 | """
301 | ==============
302 | Disparity Section
303 | ==============
304 | """
305 |
306 |
307 | def read_disp_png(file_name):
308 | """
309 |     Read disparity map from KITTI .png file
310 |     :param file_name: name of the disparity file
311 |     :return: disparity data in matrix
312 | """
313 | image_object = png.Reader(filename=file_name)
314 | image_direct = image_object.asDirect()
315 | image_data = list(image_direct[2])
316 | (w, h) = image_direct[3]['size']
317 |     channel = len(image_data[0]) // w
318 | flow = np.zeros((h, w, channel), dtype=np.uint16)
319 | for i in range(len(image_data)):
320 | for j in range(channel):
321 | flow[i, :, j] = image_data[i][j::channel]
322 | return flow[:, :, 0] / 256
323 |
324 |
325 | def disp_to_flowfile(disp, filename):
326 | """
327 |     Save a disparity map in Middlebury .flo format (v channel set to zero)
328 | :param disp: disparity matrix
329 | :param filename: the flow file name to save
330 | :return: None
331 | """
332 | f = open(filename, 'wb')
333 | magic = np.array([202021.25], dtype=np.float32)
334 | (height, width) = disp.shape[0:2]
335 | w = np.array([width], dtype=np.int32)
336 | h = np.array([height], dtype=np.int32)
337 | empty_map = np.zeros((height, width), dtype=np.float32)
338 | data = np.dstack((disp, empty_map))
339 | magic.tofile(f)
340 | w.tofile(f)
341 | h.tofile(f)
342 | data.tofile(f)
343 | f.close()
344 |
345 |
346 | """
347 | ==============
348 | Image Section
349 | ==============
350 | """
351 |
352 |
353 | def read_image(filename):
354 | """
355 | Read normal image of any format
356 | :param filename: name of the image file
357 | :return: image data in matrix uint8 type
358 | """
359 | img = Image.open(filename)
360 | im = np.array(img)
361 | return im
362 |
363 |
364 | def warp_image(im, flow):
365 | """
366 |     Warp an image to the next frame using optical flow
367 | :param im: image to warp
368 | :param flow: optical flow
369 | :return: warped image
370 | """
371 | from scipy import interpolate
372 | image_height = im.shape[0]
373 | image_width = im.shape[1]
374 | flow_height = flow.shape[0]
375 | flow_width = flow.shape[1]
376 | n = image_height * image_width
377 | (iy, ix) = np.mgrid[0:image_height, 0:image_width]
378 | (fy, fx) = np.mgrid[0:flow_height, 0:flow_width]
379 | fx = fx.astype(np.float64)
380 | fy = fy.astype(np.float64)
381 | fx += flow[:,:,0]
382 | fy += flow[:,:,1]
383 | mask = np.logical_or(fx <0 , fx > flow_width)
384 | mask = np.logical_or(mask, fy < 0)
385 | mask = np.logical_or(mask, fy > flow_height)
386 | fx = np.minimum(np.maximum(fx, 0), flow_width)
387 | fy = np.minimum(np.maximum(fy, 0), flow_height)
388 | points = np.concatenate((ix.reshape(n,1), iy.reshape(n,1)), axis=1)
389 | xi = np.concatenate((fx.reshape(n, 1), fy.reshape(n,1)), axis=1)
390 | warp = np.zeros((image_height, image_width, im.shape[2]))
391 | for i in range(im.shape[2]):
392 | channel = im[:, :, i]
394 | values = channel.reshape(n, 1)
395 | new_channel = interpolate.griddata(points, values, xi, method='cubic')
396 | new_channel = np.reshape(new_channel, [flow_height, flow_width])
397 | new_channel[mask] = 1
398 | warp[:, :, i] = new_channel.astype(np.uint8)
399 |
400 | return warp.astype(np.uint8)
401 |
402 |
403 | """
404 | ==============
405 | Others
406 | ==============
407 | """
408 |
409 | def pfm_to_flo(pfm_file):
410 | flow_filename = pfm_file[0:pfm_file.find('.pfm')] + '.flo'
411 | (data, scale) = pfm.readPFM(pfm_file)
412 | flow = data[:, :, 0:2]
413 | write_flow(flow, flow_filename)
414 |
415 |
416 | def scale_image(image, new_range):
417 | """
418 | Linearly scale the image into desired range
419 | :param image: input image
420 |     :param new_range: target (min, max) range
421 | :return: image normalized in new range
422 | """
423 | min_val = np.min(image).astype(np.float32)
424 | max_val = np.max(image).astype(np.float32)
425 | min_val_new = np.array(min(new_range), dtype=np.float32)
426 | max_val_new = np.array(max(new_range), dtype=np.float32)
427 | scaled_image = (image - min_val) / (max_val - min_val) * (max_val_new - min_val_new) + min_val_new
428 | return scaled_image.astype(np.uint8)
429 |
430 |
431 | def compute_color(u, v):
432 | """
433 | compute optical flow color map
434 | :param u: optical flow horizontal map
435 | :param v: optical flow vertical map
436 | :return: optical flow in color code
437 | """
438 | [h, w] = u.shape
439 | img = np.zeros([h, w, 3])
440 | nanIdx = np.isnan(u) | np.isnan(v)
441 | u[nanIdx] = 0
442 | v[nanIdx] = 0
443 |
444 | colorwheel = make_color_wheel()
445 | ncols = np.size(colorwheel, 0)
446 |
447 | rad = np.sqrt(u**2+v**2)
448 |
449 | a = np.arctan2(-v, -u) / np.pi
450 |
451 | fk = (a+1) / 2 * (ncols - 1) + 1
452 |
453 | k0 = np.floor(fk).astype(int)
454 |
455 | k1 = k0 + 1
456 | k1[k1 == ncols+1] = 1
457 | f = fk - k0
458 |
459 | for i in range(0, np.size(colorwheel,1)):
460 | tmp = colorwheel[:, i]
461 | col0 = tmp[k0-1] / 255
462 | col1 = tmp[k1-1] / 255
463 | col = (1-f) * col0 + f * col1
464 |
465 | idx = rad <= 1
466 | col[idx] = 1-rad[idx]*(1-col[idx])
467 | notidx = np.logical_not(idx)
468 |
469 | col[notidx] *= 0.75
470 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
471 |
472 | return img
473 |
474 |
475 | def make_color_wheel():
476 | """
477 |     Generate color wheel according to the Middlebury color code
478 | :return: Color wheel
479 | """
480 | RY = 15
481 | YG = 6
482 | GC = 4
483 | CB = 11
484 | BM = 13
485 | MR = 6
486 |
487 | ncols = RY + YG + GC + CB + BM + MR
488 |
489 | colorwheel = np.zeros([ncols, 3])
490 |
491 | col = 0
492 |
493 | # RY
494 | colorwheel[0:RY, 0] = 255
495 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
496 | col += RY
497 |
498 | # YG
499 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
500 | colorwheel[col:col+YG, 1] = 255
501 | col += YG
502 |
503 | # GC
504 | colorwheel[col:col+GC, 1] = 255
505 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
506 | col += GC
507 |
508 | # CB
509 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
510 | colorwheel[col:col+CB, 2] = 255
511 | col += CB
512 |
513 | # BM
514 | colorwheel[col:col+BM, 2] = 255
515 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
516 |     col += BM
517 |
518 | # MR
519 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
520 | colorwheel[col:col+MR, 0] = 255
521 |
522 | return colorwheel
523 |
524 |
525 | def read_flo_file(filename):
526 | """
527 | Read from Middlebury .flo file
528 |     :param filename: name of the flow file
529 | :return: optical flow data in matrix
530 | """
531 | f = open(filename, 'rb')
532 | magic = np.fromfile(f, np.float32, count=1)
533 |     flow = None  # returned if the magic-number check fails
534 |
535 | if 202021.25 != magic:
536 | print('Magic number incorrect. Invalid .flo file')
537 | else:
538 | w = np.fromfile(f, np.int32, count=1)
539 | h = np.fromfile(f, np.int32, count=1)
540 | #print("Reading %d x %d flow file in .flo format" % (h, w))
541 | flow = np.ones((h[0],w[0],3))
542 | data2d = np.fromfile(f, np.float32, count=2 * w[0] * h[0])
543 | # reshape data into 3D array (columns, rows, channels)
544 | data2d = np.resize(data2d, (h[0], w[0], 2))
545 | flow[:,:,:2] = data2d
546 | f.close()
547 | return flow
548 |
549 |
550 | def read_png_file(flow_file):
551 | """
552 | Read from KITTI .png file
553 | :param flow_file: name of the flow file
554 | :return: optical flow data in matrix
555 | """
556 | flow = cv2.imread(flow_file,-1)[:,:,::-1].astype(np.float64)
557 | # flow_object = png.Reader(filename=flow_file)
558 | # flow_direct = flow_object.asDirect()
559 | # flow_data = list(flow_direct[2])
560 | # (w, h) = flow_direct[3]['size']
561 | # #print("Reading %d x %d flow file in .png format" % (h, w))
562 | # flow = np.zeros((h, w, 3), dtype=np.float64)
563 | # for i in range(len(flow_data)):
564 | # flow[i, :, 0] = flow_data[i][0::3]
565 | # flow[i, :, 1] = flow_data[i][1::3]
566 | # flow[i, :, 2] = flow_data[i][2::3]
567 |
568 | invalid_idx = (flow[:, :, 2] == 0)
569 | flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0
570 | flow[invalid_idx, 0] = 0
571 | flow[invalid_idx, 1] = 0
572 | return flow
573 |
574 |
575 | def read_pfm_file(flow_file):
576 | """
577 | Read from .pfm file
578 | :param flow_file: name of the flow file
579 | :return: optical flow data in matrix
580 | """
581 | (data, scale) = pfm.readPFM(flow_file)
582 | return data
583 |
584 |
585 | # fast resample layer
586 | def resample(img, sz):
587 | """
588 | img: flow map to be resampled
589 |     sz: new flow map size. Must be [height, width]
590 | """
591 | original_image_size = img.shape
592 | in_height = img.shape[0]
593 | in_width = img.shape[1]
594 | out_height = sz[0]
595 | out_width = sz[1]
596 | out_flow = np.zeros((out_height, out_width, 2))
597 | # find scale
598 | height_scale = float(in_height) / float(out_height)
599 | width_scale = float(in_width) / float(out_width)
600 |
601 | [x,y] = np.meshgrid(range(out_width), range(out_height))
602 | xx = x * width_scale
603 | yy = y * height_scale
604 | x0 = np.floor(xx).astype(np.int32)
605 | x1 = x0 + 1
606 | y0 = np.floor(yy).astype(np.int32)
607 | y1 = y0 + 1
608 |
609 | x0 = np.clip(x0,0,in_width-1)
610 | x1 = np.clip(x1,0,in_width-1)
611 | y0 = np.clip(y0,0,in_height-1)
612 | y1 = np.clip(y1,0,in_height-1)
613 |
614 | Ia = img[y0,x0,:]
615 | Ib = img[y1,x0,:]
616 | Ic = img[y0,x1,:]
617 | Id = img[y1,x1,:]
618 |
619 | wa = (y1-yy) * (x1-xx)
620 | wb = (yy-y0) * (x1-xx)
621 | wc = (y1-yy) * (xx-x0)
622 | wd = (yy-y0) * (xx-x0)
623 | out_flow[:,:,0] = (Ia[:,:,0]*wa + Ib[:,:,0]*wb + Ic[:,:,0]*wc + Id[:,:,0]*wd) * out_width / in_width
624 | out_flow[:,:,1] = (Ia[:,:,1]*wa + Ib[:,:,1]*wb + Ic[:,:,1]*wc + Id[:,:,1]*wd) * out_height / in_height
625 |
626 | return out_flow
627 |
628 |
--------------------------------------------------------------------------------
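
A minimal usage sketch for the flowlib helpers above (`read_flo_file`, `compute_color`, `resample`). The path `ex.flo` is hypothetical and the import assumes the repo root is on `PYTHONPATH`; since `compute_color` saturates at radius 1, the flow is normalized by its maximum magnitude before color-coding.

```python
import numpy as np
from utils.flowlib import read_flo_file, compute_color, resample

flow = read_flo_file('ex.flo')               # H x W x 3; channels 0-1 hold (u, v)
u, v = flow[:, :, 0], flow[:, :, 1]

# normalize so the color wheel's saturated core (radius <= 1) is fully used
maxrad = max(np.sqrt(u ** 2 + v ** 2).max(), 1e-8)
img = compute_color(u / maxrad, v / maxrad)  # H x W x 3 color-coded flow

# bilinear resampling to half resolution; flow magnitudes are rescaled inside
half = resample(flow[:, :, :2], [flow.shape[0] // 2, flow.shape[1] // 2])
```
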
/utils/io.py:
--------------------------------------------------------------------------------
1 | import errno
2 | import os
3 | import shutil
4 | import sys
5 | import traceback
6 | import zipfile
7 |
8 | if sys.version_info[0] == 2:
9 | import urllib2
10 | else:
11 | import urllib.request
12 |
13 |
14 | # Converts a string to bytes (for writing the string into a file). Provided for
15 | # compatibility with Python 2 and 3.
16 | def StrToBytes(text):
17 | if sys.version_info[0] == 2:
18 | return text
19 | else:
20 | return bytes(text, 'UTF-8')
21 |
22 |
23 | # Outputs the given text and lets the user input a response (submitted by
24 | # pressing the return key). Provided for compatibility with Python 2 and 3.
25 | def GetUserInput(text):
26 | if sys.version_info[0] == 2:
27 | return raw_input(text)
28 | else:
29 | return input(text)
30 |
31 |
32 | # Creates the given directory (hierarchy), which may already exist. Provided for
33 | # compatibility with Python 2 and 3.
34 | def MakeDirsExistOk(directory_path):
35 | try:
36 | os.makedirs(directory_path)
37 | except OSError as exception:
38 | if exception.errno != errno.EEXIST:
39 | raise
40 |
41 |
42 | # Deletes all files and folders within the given folder.
43 | def DeleteFolderContents(folder_path):
44 | for file_name in os.listdir(folder_path):
45 | file_path = os.path.join(folder_path, file_name)
46 | try:
47 | if os.path.isfile(file_path):
48 | os.unlink(file_path)
49 | else: #if os.path.isdir(file_path):
50 | shutil.rmtree(file_path)
51 | except Exception as e:
52 | print('Exception in DeleteFolderContents():')
53 | print(e)
54 | print('Stack trace:')
55 | print(traceback.format_exc())
56 |
57 |
58 | # Creates the given directory, respectively deletes all content of the directory
59 | # in case it already exists.
60 | def MakeCleanDirectory(folder_path):
61 | if os.path.isdir(folder_path):
62 | DeleteFolderContents(folder_path)
63 | else:
64 | MakeDirsExistOk(folder_path)
65 |
66 |
67 | # Downloads the given URL to a file in the given directory. Returns the
68 | # path to the downloaded file.
69 | # In part adapted from: https://stackoverflow.com/questions/22676
70 | def DownloadFile(url, dest_dir_path):
71 | file_name = url.split('/')[-1]
72 | dest_file_path = os.path.join(dest_dir_path, file_name)
73 |
74 | if os.path.isfile(dest_file_path):
75 | print('The following file already exists:')
76 | print(dest_file_path)
77 | print('Please choose whether to re-download and overwrite the file [o] or to skip downloading this file [s] by entering o or s.')
78 | while True:
79 | response = GetUserInput("> ")
80 | if response == 's':
81 | return dest_file_path
82 | elif response == 'o':
83 | break
84 | else:
85 | print('Please enter o or s.')
86 |
87 | url_object = None
88 | if sys.version_info[0] == 2:
89 | url_object = urllib2.urlopen(url)
90 | else:
91 | url_object = urllib.request.urlopen(url)
92 |
93 | with open(dest_file_path, 'wb') as outfile:
94 | meta = url_object.info()
95 | file_size = 0
96 | if sys.version_info[0] == 2:
97 | file_size = int(meta.getheaders("Content-Length")[0])
98 | else:
99 | file_size = int(meta["Content-Length"])
100 | print("Downloading: %s (size [bytes]: %s)" % (url, file_size))
101 |
102 | file_size_downloaded = 0
103 | block_size = 8192
104 | while True:
105 | buffer = url_object.read(block_size)
106 | if not buffer:
107 | break
108 |
109 | file_size_downloaded += len(buffer)
110 | outfile.write(buffer)
111 |
112 |             sys.stdout.write("%d / %d (%.3f%%)\r" % (file_size_downloaded, file_size, file_size_downloaded * 100. / file_size))
113 | sys.stdout.flush()
114 |
115 | return dest_file_path
116 |
117 |
118 | # Unzips the given zip file into the given directory.
119 | def UnzipFile(file_path, unzip_dir_path, overwrite=True):
120 |     zip_ref = zipfile.ZipFile(file_path, 'r')  # let ZipFile manage the underlying file handle
121 |
122 | if not overwrite:
123 | for f in zip_ref.namelist():
124 | if not os.path.isfile(os.path.join(unzip_dir_path, f)):
125 | zip_ref.extract(f, path=unzip_dir_path)
126 | else:
127 | print('Not overwriting {}'.format(f))
128 | else:
129 | zip_ref.extractall(unzip_dir_path)
130 | zip_ref.close()
131 |
132 |
133 | # Creates a zip file with the contents of the given directory.
134 | # The archive_base_path must not include the extension .zip. The full, final
135 | # path of the archive is returned by the function.
136 | def ZipDirectory(archive_base_path, root_dir_path):
137 | # return shutil.make_archive(archive_base_path, 'zip', root_dir_path) # THIS WILL ALWAYS HAVE ./ FOLDER INCLUDED
138 | with zipfile.ZipFile(archive_base_path+'.zip', "w", compression=zipfile.ZIP_DEFLATED) as zf:
139 | base_path = os.path.normpath(root_dir_path)
140 | for dirpath, dirnames, filenames in os.walk(root_dir_path):
141 | for name in sorted(dirnames):
142 | path = os.path.normpath(os.path.join(dirpath, name))
143 | zf.write(path, os.path.relpath(path, base_path))
144 | for name in filenames:
145 | path = os.path.normpath(os.path.join(dirpath, name))
146 | if os.path.isfile(path):
147 | zf.write(path, os.path.relpath(path, base_path))
148 |
149 | return archive_base_path+'.zip'
150 |
151 |
152 | # Downloads a zip file and directly unzips it.
153 | def DownloadAndUnzipFile(url, archive_dir_path, unzip_dir_path, overwrite=True):
154 | archive_path = DownloadFile(url, archive_dir_path)
155 | UnzipFile(archive_path, unzip_dir_path, overwrite=overwrite)
156 |
157 | def mkdir_p(path):
158 | try:
159 | os.makedirs(path)
160 | except OSError as exc: # Python >2.5
161 | if exc.errno == errno.EEXIST and os.path.isdir(path):
162 | pass
163 | else:
164 | raise
165 |
--------------------------------------------------------------------------------
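
A short sketch of the download helpers above; the URL and directories are hypothetical, shown only to illustrate the call pattern.

```python
from utils.io import DownloadAndUnzipFile, mkdir_p

mkdir_p('tmp/archives')   # like os.makedirs, but tolerates existing directories
mkdir_p('tmp/unzipped')

# Downloads tmp/archives/data.zip (prompting if the file already exists),
# then extracts the archive into tmp/unzipped.
DownloadAndUnzipFile('https://example.com/data.zip', 'tmp/archives', 'tmp/unzipped')
```
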
/utils/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | File: logger.py
3 | Modified by: Senthil Purushwalkam
4 | Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
5 | Email: spurushw@andrew.cmu.edu
6 | Github: https://github.com/senthilps8
7 | Description:
8 | """
9 | import pdb
10 | import tensorflow as tf
11 | import torch; from torch.autograd import Variable  # torch itself is needed by to_var()
12 | import numpy as np
13 | import scipy.misc
14 | import os
15 | try:
16 | from StringIO import StringIO # Python 2.7
17 | except ImportError:
18 | from io import BytesIO # Python 3.x
19 |
20 |
21 | class Logger(object):
22 |
23 | def __init__(self, log_dir, name=None):
24 | """Create a summary writer logging to log_dir."""
25 | if name is None:
26 | name = 'temp'
27 | self.name = name
28 | if name is not None:
29 | try:
30 | os.makedirs(os.path.join(log_dir, name))
31 | except:
32 | pass
33 | self.writer = tf.summary.FileWriter(os.path.join(log_dir, name),
34 | filename_suffix=name)
35 | else:
36 | self.writer = tf.summary.FileWriter(log_dir, filename_suffix=name)
37 |
38 | def scalar_summary(self, tag, value, step):
39 | """Log a scalar variable."""
40 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
41 | self.writer.add_summary(summary, step)
42 |
43 | def image_summary(self, tag, images, step):
44 | """Log a list of images."""
45 |
46 | img_summaries = []
47 | for i, img in enumerate(images):
48 | # Write the image to a string
49 | try:
50 | s = StringIO()
51 | except:
52 | s = BytesIO()
53 | scipy.misc.toimage(img).save(s, format="png")
54 |
55 | # Create an Image object
56 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
57 | height=img.shape[0],
58 | width=img.shape[1])
59 | # Create a Summary value
60 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
61 |
62 | # Create and write Summary
63 | summary = tf.Summary(value=img_summaries)
64 | self.writer.add_summary(summary, step)
65 |
66 | def histo_summary(self, tag, values, step, bins=1000):
67 | """Log a histogram of the tensor of values."""
68 |
69 | # Create a histogram using numpy
70 | counts, bin_edges = np.histogram(values, bins=bins)
71 |
72 | # Fill the fields of the histogram proto
73 | hist = tf.HistogramProto()
74 | hist.min = float(np.min(values))
75 | hist.max = float(np.max(values))
76 | hist.num = int(np.prod(values.shape))
77 | hist.sum = float(np.sum(values))
78 | hist.sum_squares = float(np.sum(values**2))
79 |
80 | # Drop the start of the first bin
81 | bin_edges = bin_edges[1:]
82 |
83 | # Add bin edges and counts
84 | for edge in bin_edges:
85 | hist.bucket_limit.append(edge)
86 | for c in counts:
87 | hist.bucket.append(c)
88 |
89 | # Create and write Summary
90 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
91 | self.writer.add_summary(summary, step)
92 | self.writer.flush()
93 |
94 | def to_np(self, x):
95 | return x.data.cpu().numpy()
96 |
97 | def to_var(self, x):
98 | if torch.cuda.is_available():
99 | x = x.cuda()
100 | return Variable(x)
101 |
102 | def model_param_histo_summary(self, model, step):
103 | """log histogram summary of model's parameters
104 | and parameter gradients
105 | """
106 | for tag, value in model.named_parameters():
107 | if value.grad is None:
108 | continue
109 | tag = tag.replace('.', '/')
110 | tag = self.name+'/'+tag
111 | self.histo_summary(tag, self.to_np(value), step)
112 | self.histo_summary(tag+'/grad', self.to_np(value.grad), step)
113 |
114 |
--------------------------------------------------------------------------------
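
A usage sketch for the `Logger` above. It assumes TensorFlow 1.x (which provides `tf.summary.FileWriter`) and, for image summaries, an older SciPy that still ships `scipy.misc.toimage`; the log directory and tags are hypothetical.

```python
import numpy as np
from utils.logger import Logger

logger = Logger('./logs', name='exp0')           # writes event files to ./logs/exp0
for step in range(100):
    loss = 1.0 / (step + 1)                      # stand-in for a real training loss
    logger.scalar_summary('train/loss', loss, step)

# log a histogram of some tensor-like values at the final step
logger.histo_summary('weights/fc1', np.random.randn(1000), step)
```
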
/utils/multiscaleloss.py:
--------------------------------------------------------------------------------
1 | """
2 | Taken from https://github.com/ClementPinard/FlowNetPytorch
3 | """
4 | import pdb
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def EPE(input_flow, target_flow, mask, sparse=False, mean=True):
10 | #mask = target_flow[:,2]>0
11 | target_flow = target_flow[:,:2]
12 | EPE_map = torch.norm(target_flow-input_flow,2,1)
13 | batch_size = EPE_map.size(0)
14 |     if sparse:
15 |         # invalid flow is defined with both flow coordinates being exactly 0;
16 |         # keep the valid pixels so the EPE_map[mask] indexing below is consistent
17 |         mask = (target_flow[:,0] != 0) | (target_flow[:,1] != 0)
18 | 
19 | if mean:
20 | return EPE_map[mask].mean()
21 | else:
22 | return EPE_map[mask].sum()/batch_size
23 |
24 | def rob_EPE(input_flow, target_flow, mask, sparse=False, mean=True):
25 | #mask = target_flow[:,2]>0
26 | target_flow = target_flow[:,:2]
27 | #TODO
28 | # EPE_map = torch.norm(target_flow-input_flow,2,1)
29 | EPE_map = (torch.norm(target_flow-input_flow,1,1)+0.01).pow(0.4)
30 | batch_size = EPE_map.size(0)
31 |     if sparse:
32 |         # invalid flow is defined with both flow coordinates being exactly 0;
33 |         # keep the valid pixels so the EPE_map[mask] indexing below is consistent
34 |         mask = (target_flow[:,0] != 0) | (target_flow[:,1] != 0)
35 | 
36 | if mean:
37 | return EPE_map[mask].mean()
38 | else:
39 | return EPE_map[mask].sum()/batch_size
40 |
41 | def sparse_max_pool(input, size):
42 | '''Downsample the input by considering 0 values as invalid.
43 |
44 | Unfortunately, no generic interpolation mode can resize a sparse map correctly,
45 | the strategy here is to use max pooling for positive values and "min pooling"
46 | for negative values, the two results are then summed.
47 |     This technique allows sparsity to be minimized, contrary to nearest interpolation,
48 | which could potentially lose information for isolated data points.'''
49 |
50 | positive = (input > 0).float()
51 | negative = (input < 0).float()
52 | output = F.adaptive_max_pool2d(input * positive, size) - F.adaptive_max_pool2d(-input * negative, size)
53 | return output
54 |
55 |
56 | def multiscaleEPE(network_output, target_flow, mask, weights=None, sparse=False, rob_loss = False):
57 | def one_scale(output, target, mask, sparse):
58 |
59 | b, _, h, w = output.size()
60 |
61 | if sparse:
62 | target_scaled = sparse_max_pool(target, (h, w))
63 | else:
64 | target_scaled = F.interpolate(target, (h, w), mode='area')
65 | mask = F.interpolate(mask.float().unsqueeze(1), (h, w), mode='bilinear').squeeze(1)==1
66 | if rob_loss:
67 | return rob_EPE(output, target_scaled, mask, sparse, mean=False)
68 | else:
69 | return EPE(output, target_scaled, mask, sparse, mean=False)
70 |
71 | if type(network_output) not in [tuple, list]:
72 | network_output = [network_output]
73 | if weights is None:
74 | weights = [0.005, 0.01, 0.02, 0.08, 0.32] # as in original article
75 | assert(len(weights) == len(network_output))
76 |
77 | loss = 0
78 | for output, weight in zip(network_output, weights):
79 | loss += weight * one_scale(output, target_flow, mask, sparse)
80 | return loss
81 |
82 |
83 | def realEPE(output, target, mask, sparse=False):
84 | b, _, h, w = target.size()
85 | upsampled_output = F.interpolate(output, (h,w), mode='bilinear', align_corners=False)
86 | return EPE(upsampled_output, target,mask, sparse, mean=True)
87 |
--------------------------------------------------------------------------------
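
A toy invocation of the loss above with random tensors, assuming the `utils` package is importable. The five predictions mirror the five default weights; a real VCN forward pass would supply the pyramid outputs instead.

```python
import torch
from utils.multiscaleloss import multiscaleEPE, realEPE

B, H, W = 2, 64, 96
target = torch.randn(B, 2, H, W)                  # ground-truth flow
mask = torch.ones(B, H, W) > 0                    # all pixels valid

# five predictions, coarse to fine, matching the five default weights
outputs = [torch.randn(B, 2, H // 2 ** i, W // 2 ** i) for i in range(4, -1, -1)]

loss = multiscaleEPE(outputs, target, mask, sparse=False)
epe = realEPE(outputs[-1], target, mask)          # end-point error at full resolution
```
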
/utils/pfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 |
5 | def readPFM(file):
6 | file = open(file, 'rb')
7 |
8 | color = None
9 | width = None
10 | height = None
11 | scale = None
12 | endian = None
13 |
14 | header = file.readline().rstrip()
15 | if (sys.version[0]) == '3':
16 | header = header.decode('utf-8')
17 | if header == 'PF':
18 | color = True
19 | elif header == 'Pf':
20 | color = False
21 | else:
22 | raise Exception('Not a PFM file.')
23 |
24 | if (sys.version[0]) == '3':
25 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
26 | else:
27 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
28 | if dim_match:
29 | width, height = map(int, dim_match.groups())
30 | else:
31 | raise Exception('Malformed PFM header.')
32 |
33 | if (sys.version[0]) == '3':
34 | scale = float(file.readline().rstrip().decode('utf-8'))
35 | else:
36 | scale = float(file.readline().rstrip())
37 |
38 | if scale < 0: # little-endian
39 | endian = '<'
40 | scale = -scale
41 | else:
42 | endian = '>' # big-endian
43 |
44 | data = np.fromfile(file, endian + 'f')
45 | shape = (height, width, 3) if color else (height, width)
46 |
47 | data = np.reshape(data, shape)
48 | data = np.flipud(data)
49 | return data, scale
50 |
51 |
52 | def writePFM(file, image, scale=1):
53 | file = open(file, 'wb')
54 |
55 | color = None
56 |
57 | if image.dtype.name != 'float32':
58 | raise Exception('Image dtype must be float32.')
59 |
60 | image = np.flipud(image)
61 |
62 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
63 | color = True
64 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
65 | color = False
66 | else:
67 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
68 |
69 |     file.write(('PF\n' if color else 'Pf\n').encode('ascii'))
70 |     file.write(('%d %d\n' % (image.shape[1], image.shape[0])).encode('ascii'))
71 |
72 | endian = image.dtype.byteorder
73 |
74 |     if endian == '<' or (endian == '=' and sys.byteorder == 'little'):
75 | scale = -scale
76 |
77 |     file.write(('%f\n' % scale).encode('ascii'))
78 |
79 | image.tofile(file)
80 |
--------------------------------------------------------------------------------
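
A round-trip sketch for the PFM helpers above: PFM stores float32 images bottom-up with a scale header whose sign encodes endianness. `ex.pfm` is a throwaway path.

```python
import numpy as np
from utils import pfm

img = np.random.rand(4, 5).astype(np.float32)   # greyscale H x W
pfm.writePFM('ex.pfm', img)
data, scale = pfm.readPFM('ex.pfm')
assert data.shape == (4, 5) and np.allclose(data, img)
```
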
/utils/readpfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 |
5 |
6 | def readPFM(file):
7 | file = open(file, 'rb')
8 |
9 | color = None
10 | width = None
11 | height = None
12 | scale = None
13 | endian = None
14 |
15 | header = file.readline().rstrip()
16 | if (sys.version[0]) == '3':
17 | header = header.decode('utf-8')
18 | if header == 'PF':
19 | color = True
20 | elif header == 'Pf':
21 | color = False
22 | else:
23 | raise Exception('Not a PFM file.')
24 |
25 | if (sys.version[0]) == '3':
26 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
27 | else:
28 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
29 | if dim_match:
30 | width, height = map(int, dim_match.groups())
31 | else:
32 | raise Exception('Malformed PFM header.')
33 |
34 | if (sys.version[0]) == '3':
35 | scale = float(file.readline().rstrip().decode('utf-8'))
36 | else:
37 | scale = float(file.readline().rstrip())
38 |
39 | if scale < 0: # little-endian
40 | endian = '<'
41 | scale = -scale
42 | else:
43 | endian = '>' # big-endian
44 |
45 | data = np.fromfile(file, endian + 'f')
46 | shape = (height, width, 3) if color else (height, width)
47 |
48 | data = np.reshape(data, shape)
49 | data = np.flipud(data)
50 | return data, scale
51 |
52 |
--------------------------------------------------------------------------------
/utils/sintel_io.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python2
2 |
3 | """
4 | I/O script to save and load the data coming with the MPI-Sintel low-level
5 | computer vision benchmark.
6 |
7 | For more details about the benchmark, please visit www.mpi-sintel.de
8 |
9 | CHANGELOG:
10 | v1.0 (2015/02/03): First release
11 |
12 | Copyright (c) 2015 Jonas Wulff
13 | Max Planck Institute for Intelligent Systems, Tuebingen, Germany
14 |
15 | """
16 |
17 | # Requirements: Numpy and PIL/Pillow
18 | import numpy as np
19 | from PIL import Image
20 |
21 | # Check for endianness, based on Daniel Scharstein's optical flow code.
22 | # Using little-endian architecture, these two should be equal.
23 | TAG_FLOAT = 202021.25
24 | TAG_CHAR = b'PIEH'  # bytes, so the binary writes below also work under Python 3
25 |
26 | def flow_read(filename):
27 | """ Read optical flow from file, return (U,V) tuple.
28 |
29 | Original code by Deqing Sun, adapted from Daniel Scharstein.
30 | """
31 | f = open(filename,'rb')
32 | check = np.fromfile(f,dtype=np.float32,count=1)[0]
33 | assert check == TAG_FLOAT, ' flow_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
34 | width = np.fromfile(f,dtype=np.int32,count=1)[0]
35 | height = np.fromfile(f,dtype=np.int32,count=1)[0]
36 | size = width*height
37 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' flow_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height)
38 | tmp = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width*2))
39 | u = tmp[:,np.arange(width)*2]
40 | v = tmp[:,np.arange(width)*2 + 1]
41 | return u,v
42 |
43 | def flow_write(filename,uv,v=None):
44 | """ Write optical flow to file.
45 |
46 | If v is None, uv is assumed to contain both u and v channels,
47 | stacked in depth.
48 |
49 | Original code by Deqing Sun, adapted from Daniel Scharstein.
50 | """
51 | nBands = 2
52 |
53 | if v is None:
54 | assert(uv.ndim == 3)
55 | assert(uv.shape[2] == 2)
56 | u = uv[:,:,0]
57 | v = uv[:,:,1]
58 | else:
59 | u = uv
60 |
61 | assert(u.shape == v.shape)
62 | height,width = u.shape
63 | f = open(filename,'wb')
64 | # write the header
65 | f.write(TAG_CHAR)
66 | np.array(width).astype(np.int32).tofile(f)
67 | np.array(height).astype(np.int32).tofile(f)
68 | # arrange into matrix form
69 | tmp = np.zeros((height, width*nBands))
70 | tmp[:,np.arange(width)*2] = u
71 | tmp[:,np.arange(width)*2 + 1] = v
72 | tmp.astype(np.float32).tofile(f)
73 | f.close()
74 |
75 |
76 | def depth_read(filename):
77 | """ Read depth data from file, return as numpy array. """
78 | f = open(filename,'rb')
79 | check = np.fromfile(f,dtype=np.float32,count=1)[0]
80 | assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
81 | width = np.fromfile(f,dtype=np.int32,count=1)[0]
82 | height = np.fromfile(f,dtype=np.int32,count=1)[0]
83 | size = width*height
84 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height)
85 | depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width))
86 | return depth
87 |
88 | def depth_write(filename, depth):
89 | """ Write depth to file. """
90 | height,width = depth.shape[:2]
91 | f = open(filename,'wb')
92 | # write the header
93 | f.write(TAG_CHAR)
94 | np.array(width).astype(np.int32).tofile(f)
95 | np.array(height).astype(np.int32).tofile(f)
96 |
97 | depth.astype(np.float32).tofile(f)
98 | f.close()
99 |
100 |
101 | def disparity_write(filename,disparity,bitdepth=16):
102 | """ Write disparity to file.
103 |
104 | bitdepth can be either 16 (default) or 32.
105 |
106 | The maximum disparity is 1024, since the image width in Sintel
107 | is 1024.
108 | """
109 | d = disparity.copy()
110 |
111 | # Clip disparity.
112 | d[d>1024] = 1024
113 | d[d<0] = 0
114 |
115 | d_r = (d / 4.0).astype('uint8')
116 | d_g = ((d * (2.0**6)) % 256).astype('uint8')
117 |
118 | out = np.zeros((d.shape[0],d.shape[1],3),dtype='uint8')
119 | out[:,:,0] = d_r
120 | out[:,:,1] = d_g
121 |
122 | if bitdepth > 16:
123 | d_b = (d * (2**14) % 256).astype('uint8')
124 | out[:,:,2] = d_b
125 |
126 | Image.fromarray(out,'RGB').save(filename,'PNG')
127 |
128 |
129 | def disparity_read(filename):
130 | """ Return disparity read from filename. """
131 | f_in = np.array(Image.open(filename))
132 | d_r = f_in[:,:,0].astype('float64')
133 | d_g = f_in[:,:,1].astype('float64')
134 | d_b = f_in[:,:,2].astype('float64')
135 |
136 | depth = d_r * 4 + d_g / (2**6) + d_b / (2**14)
137 | return depth
138 |
139 |
140 | #def cam_read(filename):
141 | # """ Read camera data, return (M,N) tuple.
142 | #
143 | # M is the intrinsic matrix, N is the extrinsic matrix, so that
144 | #
145 | # x = M*N*X,
146 | # with x being a point in homogeneous image pixel coordinates, X being a
147 | # point in homogeneous world coordinates.
148 | # """
149 | # txtdata = np.loadtxt(filename)
150 | # intrinsic = txtdata[0,:9].reshape((3,3))
151 | # extrinsic = textdata[1,:12].reshape((3,4))
152 | # return intrinsic,extrinsic
153 | #
154 | #
155 | #def cam_write(filename,M,N):
156 | # """ Write intrinsic matrix M and extrinsic matrix N to file. """
157 | # Z = np.zeros((2,12))
158 | # Z[0,:9] = M.ravel()
159 | # Z[1,:12] = N.ravel()
160 | # np.savetxt(filename,Z)
161 |
162 | def cam_read(filename):
163 | """ Read camera data, return (M,N) tuple.
164 |
165 | M is the intrinsic matrix, N is the extrinsic matrix, so that
166 |
167 | x = M*N*X,
168 | with x being a point in homogeneous image pixel coordinates, X being a
169 | point in homogeneous world coordinates.
170 | """
171 | f = open(filename,'rb')
172 | check = np.fromfile(f,dtype=np.float32,count=1)[0]
173 | assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
174 | M = np.fromfile(f,dtype='float64',count=9).reshape((3,3))
175 | N = np.fromfile(f,dtype='float64',count=12).reshape((3,4))
176 | return M,N
177 |
178 | def cam_write(filename, M, N):
179 | """ Write intrinsic matrix M and extrinsic matrix N to file. """
180 | f = open(filename,'wb')
181 | # write the header
182 | f.write(TAG_CHAR)
183 | M.astype('float64').tofile(f)
184 | N.astype('float64').tofile(f)
185 | f.close()
186 |
187 |
188 | def segmentation_write(filename,segmentation):
189 | """ Write segmentation to file. """
190 |
191 | segmentation_ = segmentation.astype('int32')
192 | seg_r = np.floor(segmentation_ / (256**2)).astype('uint8')
193 | seg_g = np.floor((segmentation_ % (256**2)) / 256).astype('uint8')
194 | seg_b = np.floor(segmentation_ % 256).astype('uint8')
195 |
196 | out = np.zeros((segmentation.shape[0],segmentation.shape[1],3),dtype='uint8')
197 | out[:,:,0] = seg_r
198 | out[:,:,1] = seg_g
199 | out[:,:,2] = seg_b
200 |
201 | Image.fromarray(out,'RGB').save(filename,'PNG')
202 |
203 |
204 | def segmentation_read(filename):
205 | """ Return disparity read from filename. """
206 | f_in = np.array(Image.open(filename))
207 | seg_r = f_in[:,:,0].astype('int32')
208 | seg_g = f_in[:,:,1].astype('int32')
209 | seg_b = f_in[:,:,2].astype('int32')
210 |
211 | segmentation = (seg_r * 256 + seg_g) * 256 + seg_b
212 | return segmentation
213 |
214 |
215 |
--------------------------------------------------------------------------------
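
A round-trip sketch for the Sintel I/O helpers above (it relies on `TAG_CHAR` being bytes, as defined above); `ex.flo` and `ex_disp.png` are throwaway paths.

```python
import numpy as np
from utils import sintel_io

u = np.random.rand(10, 12).astype(np.float32)
v = np.random.rand(10, 12).astype(np.float32)
sintel_io.flow_write('ex.flo', np.dstack([u, v]))
u2, v2 = sintel_io.flow_read('ex.flo')
assert np.allclose(u, u2) and np.allclose(v, v2)

# Disparity is packed into RGB as d = R*4 + G/2**6 + B/2**14; with the
# default 16-bit write the B channel is dropped, so the step size is 1/2**6.
d = np.random.rand(10, 12).astype(np.float64) * 100
sintel_io.disparity_write('ex_disp.png', d)
d2 = sintel_io.disparity_read('ex_disp.png')
assert np.abs(d - d2).max() < 1.0 / 2 ** 6
```
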
/utils/util_flow.py:
--------------------------------------------------------------------------------
1 | import math
2 | import png
3 | import struct
4 | import array
5 | import numpy as np
6 | import cv2
7 | import pdb
8 |
9 | from io import *
10 |
11 | UNKNOWN_FLOW_THRESH = 1e9
12 | UNKNOWN_FLOW = 1e10
13 |
14 | # Middlebury checks
15 | TAG_STRING = 'PIEH' # use this when WRITING the file
16 | TAG_FLOAT = 202021.25 # check for this when READING the file
17 |
18 | def readPFM(file):
19 | import re
20 | file = open(file, 'rb')
21 |
22 | color = None
23 | width = None
24 | height = None
25 | scale = None
26 | endian = None
27 |
28 | header = file.readline().rstrip()
29 | if header == b'PF':
30 | color = True
31 | elif header == b'Pf':
32 | color = False
33 | else:
34 | raise Exception('Not a PFM file.')
35 |
36 |     dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
37 | if dim_match:
38 | width, height = map(int, dim_match.groups())
39 | else:
40 | raise Exception('Malformed PFM header.')
41 |
42 | scale = float(file.readline().rstrip())
43 | if scale < 0: # little-endian
44 | endian = '<'
45 | scale = -scale
46 | else:
47 | endian = '>' # big-endian
48 |
49 | data = np.fromfile(file, endian + 'f')
50 | shape = (height, width, 3) if color else (height, width)
51 |
52 | data = np.reshape(data, shape)
53 | data = np.flipud(data)
54 | return data, scale
55 |
56 |
57 | def save_pfm(file, image, scale = 1):
58 | import sys
59 | color = None
60 |
61 | if image.dtype.name != 'float32':
62 | raise Exception('Image dtype must be float32.')
63 |
64 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
65 | color = True
66 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
67 | color = False
68 | else:
69 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
70 |
71 |     file.write(('PF\n' if color else 'Pf\n').encode('ascii'))
72 |     file.write(('%d %d\n' % (image.shape[1], image.shape[0])).encode('ascii'))
73 |
74 | endian = image.dtype.byteorder
75 |
76 |     if endian == '<' or (endian == '=' and sys.byteorder == 'little'):
77 | scale = -scale
78 |
79 |     file.write(('%f\n' % scale).encode('ascii'))
80 |
81 | image.tofile(file)
82 |
83 |
84 | def ReadMiddleburyFloFile(path):
85 | """ Read .FLO file as specified by Middlebury.
86 |
87 | Returns tuple (width, height, u, v, mask), where u, v, mask are flat
88 | arrays of values.
89 | """
90 |
91 | with open(path, 'rb') as fil:
92 | tag = struct.unpack('f', fil.read(4))[0]
93 | width = struct.unpack('i', fil.read(4))[0]
94 | height = struct.unpack('i', fil.read(4))[0]
95 |
96 | assert tag == TAG_FLOAT
97 |
98 | #data = np.fromfile(path, dtype=np.float, count=-1)
99 | #data = data[3:]
100 |
101 | fmt = 'f' * width*height*2
102 | data = struct.unpack(fmt, fil.read(4*width*height*2))
103 |
104 | u = data[::2]
105 | v = data[1::2]
106 |
107 |     mask = map(lambda x,y: abs(x) < UNKNOWN_FLOW_THRESH and abs(y) < UNKNOWN_FLOW_THRESH, u, v)
108 |     u = map(lambda x: 0 if abs(x) > UNKNOWN_FLOW_THRESH else x, u)
109 |     v = map(lambda x: 0 if abs(x) > UNKNOWN_FLOW_THRESH else x, v)
110 | 
111 |     return (width, height, u, v, mask)
112 | 
113 | 
114 | def ReadKittiPngFile(path):
115 |     """ Read 16-bit .PNG file as specified by KITTI-2015 (flow).
116 | 
117 |     Returns a tuple, (width, height, u, v, mask), where u, v, mask
118 |     are flat arrays of values.
119 |     """
120 |     # Read .png file.
121 |     png_reader = png.Reader(path)
122 |     data = png_reader.read()
123 |     if data[3]['bitdepth'] != 16:
124 |         raise Exception('bitdepth of ' + path + ' is not 16')
125 | 
126 |     width = data[0]
127 |     height = data[1]
128 | 
129 |     # Get list of rows.
130 |     rows = list(data[2])
131 | 
132 |     u = array.array('f', [0]) * width*height
133 |     v = array.array('f', [0]) * width*height
134 |     mask = array.array('f', [0]) * width*height
135 | 
136 |     for y, row in enumerate(rows):
137 |         for x in range(width):
138 |             ind = width*y+x
139 |             # undo the KITTI encoding: value = flow*64 + 2**15
140 |             u[ind] = (row[3*x] - 2**15) / 64.0
141 |             v[ind] = (row[3*x+1] - 2**15) / 64.0
142 |             mask[ind] = row[3*x+2]
143 |             #if mask[ind] > 0:
144 | # print(u[ind], v[ind], mask[ind], row[3*x], row[3*x+1], row[3*x+2])
145 |
146 | #png_reader.close()
147 |
148 | return (width, height, u, v, mask)
149 |
150 |
151 | def WriteMiddleburyFloFile(path, width, height, u, v, mask=None):
152 | """ Write .FLO file as specified by Middlebury.
153 | """
154 |
155 | if mask is not None:
156 | u_masked = map(lambda x,y: x if y else UNKNOWN_FLOW, u, mask)
157 | v_masked = map(lambda x,y: x if y else UNKNOWN_FLOW, v, mask)
158 | else:
159 | u_masked = u
160 | v_masked = v
161 |
162 | fmt = 'f' * width*height*2
163 | # Interleave lists
164 | data = [x for t in zip(u_masked,v_masked) for x in t]
165 |
166 | with open(path, 'wb') as fil:
167 | fil.write(str.encode(TAG_STRING))
168 | fil.write(struct.pack('i', width))
169 | fil.write(struct.pack('i', height))
170 | fil.write(struct.pack(fmt, *data))
171 |
172 |
173 | def write_flow(path,flow):
174 |
175 | invalid_idx = (flow[:, :, 2] == 0)
176 | flow[:, :, 0:2] = flow[:, :, 0:2]*64.+ 2 ** 15
177 | flow[invalid_idx, 0] = 0
178 | flow[invalid_idx, 1] = 0
179 |
180 | flow = flow.astype(np.uint16)
181 |     cv2.imwrite(path, flow[:,:,::-1])
182 |
183 | #WriteKittiPngFile(path,
184 | # flow.shape[1], flow.shape[0], flow[:,:,0].flatten(),
185 | # flow[:,:,1].flatten(), flow[:,:,2].flatten())
186 |
187 |
188 |
189 | def WriteKittiPngFile(path, width, height, u, v, mask=None):
190 | """ Write 16-bit .PNG file as specified by KITTI-2015 (flow).
191 |
192 | u, v are lists of float values
193 | mask is a list of floats, denoting the *valid* pixels.
194 | """
195 |
196 | data = array.array('H',[0])*width*height*3
197 |
198 | for i,(u_,v_,mask_) in enumerate(zip(u,v,mask)):
199 | data[3*i] = int(u_*64.0+2**15)
200 | data[3*i+1] = int(v_*64.0+2**15)
201 | data[3*i+2] = int(mask_)
202 |
203 | # if mask_ > 0:
204 | # print(data[3*i], data[3*i+1],data[3*i+2])
205 |
206 | with open(path, 'wb') as png_file:
207 | png_writer = png.Writer(width=width, height=height, bitdepth=16, compression=3, greyscale=False)
208 | png_writer.write_array(png_file, data)
209 |
210 |
211 | def ConvertMiddleburyFloToKittiPng(src_path, dest_path):
212 | width, height, u, v, mask = ReadMiddleburyFloFile(src_path)
213 | WriteKittiPngFile(dest_path, width, height, u, v, mask=mask)
214 |
215 | def ConvertKittiPngToMiddleburyFlo(src_path, dest_path):
216 | width, height, u, v, mask = ReadKittiPngFile(src_path)
217 | WriteMiddleburyFloFile(dest_path, width, height, u, v, mask=mask)
218 |
219 |
220 | def ParseFilenameKitti(filename):
221 | # Parse kitti filename (seq_frameno.xx),
222 | # return seq, frameno, ext.
223 | # Be aware that seq might contain the dataset name (if contained as prefix)
224 | ext = filename[filename.rfind('.'):]
225 | frameno = filename[filename.rfind('_')+1:filename.rfind('.')]
226 | frameno = int(frameno)
227 | seq = filename[:filename.rfind('_')]
228 | return seq, frameno, ext
229 |
230 |
231 | def read_calib_file(filepath):
232 | """Read in a calibration file and parse into a dictionary."""
233 | data = {}
234 |
235 | with open(filepath, 'r') as f:
236 | for line in f.readlines():
237 | key, value = line.split(':', 1)
238 | # The only non-float values in these files are dates, which
239 | # we don't care about anyway
240 | try:
241 | data[key] = np.array([float(x) for x in value.split()])
242 | except ValueError:
243 | pass
244 |
245 | return data
246 |
247 | def load_calib_cam_to_cam(cam_to_cam_file):
248 | # We'll return the camera calibration as a dictionary
249 | data = {}
250 |
251 | # Load and parse the cam-to-cam calibration data
252 | filedata = read_calib_file(cam_to_cam_file)
253 |
254 | # Create 3x4 projection matrices
255 | P_rect_00 = np.reshape(filedata['P_rect_00'], (3, 4))
256 | P_rect_10 = np.reshape(filedata['P_rect_01'], (3, 4))
257 | P_rect_20 = np.reshape(filedata['P_rect_02'], (3, 4))
258 | P_rect_30 = np.reshape(filedata['P_rect_03'], (3, 4))
259 |
260 | # Compute the camera intrinsics
261 | data['K_cam0'] = P_rect_00[0:3, 0:3]
262 | data['K_cam1'] = P_rect_10[0:3, 0:3]
263 | data['K_cam2'] = P_rect_20[0:3, 0:3]
264 | data['K_cam3'] = P_rect_30[0:3, 0:3]
265 |
266 | data['b00'] = P_rect_00[0, 3] / P_rect_00[0, 0]
267 | data['b10'] = P_rect_10[0, 3] / P_rect_10[0, 0]
268 | data['b20'] = P_rect_20[0, 3] / P_rect_20[0, 0]
269 | data['b30'] = P_rect_30[0, 3] / P_rect_30[0, 0]
270 |
271 | return data
272 |
273 |
--------------------------------------------------------------------------------
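
The KITTI flow PNG convention used by `write_flow`, `WriteKittiPngFile` and `ReadKittiPngFile` above packs `(u, v, valid)` into 16-bit RGB as `R = u*64 + 2**15`, `G = v*64 + 2**15`, `B = valid`, so decoding is `(value - 2**15) / 64`. The sketch below illustrates just that arithmetic; `encode_kitti` and `decode_kitti` are hypothetical helpers, not part of the repo.

```python
import numpy as np

def encode_kitti(u, v):
    # hypothetical helper: pack flow into the 16-bit KITTI channel values
    return np.stack([u * 64.0 + 2 ** 15, v * 64.0 + 2 ** 15], -1).astype(np.uint16)

def decode_kitti(png16):
    # inverse mapping, as done per pixel in ReadKittiPngFile above
    return (png16.astype(np.float64) - 2 ** 15) / 64.0

u = np.array([[1.5, -3.25]])
assert np.allclose(decode_kitti(encode_kitti(u, u)[..., 0]), u)
```
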