├── .gitignore
├── LICENSE
├── README.md
├── datasets
│   ├── KITTI.py
│   ├── __init__.py
│   ├── flyingchairs.py
│   ├── listdataset.py
│   ├── mpisintel.py
│   └── util.py
├── flow_transforms.py
├── images
│   ├── GT_1.png
│   ├── GT_2.png
│   ├── GT_3.png
│   ├── input_1.gif
│   ├── input_2.gif
│   ├── input_3.gif
│   ├── pred_1.png
│   ├── pred_2.png
│   └── pred_3.png
├── main.py
├── models
│   ├── FlowNetC.py
│   ├── FlowNetS.py
│   ├── __init__.py
│   └── util.py
├── multiscaleloss.py
├── requirements.txt
├── run_inference.py
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | */__pycache__
2 | *.pyc
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Clément Pinard
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FlowNetPytorch
2 | Pytorch implementation of FlowNet by Dosovitskiy et al.
3 |
4 | This repository is a PyTorch implementation of [FlowNet](http://lmb.informatik.uni-freiburg.de/Publications/2015/DFIB15/), by [Alexey Dosovitskiy](http://lmb.informatik.uni-freiburg.de/people/dosovits/) et al. See the Torch implementation [here](https://github.com/ClementPinard/FlowNetTorch).
5 |
6 | This code is mainly inspired by the official [ImageNet example](https://github.com/pytorch/examples/tree/master/imagenet).
7 | It has not been tested on multiple GPUs, but it should work just as in the original code.
8 |
9 | The code provides a training example on [the Flying Chairs dataset](http://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html), with data augmentation. An implementation for [Scene Flow Datasets](http://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) may be added in the future.
10 |
11 | Two neural network models are currently provided, along with their batch norm variations (experimental):
12 |
13 | - **FlowNetS**
14 | - **FlowNetSBN**
15 | - **FlowNetC**
16 | - **FlowNetCBN**
17 |
18 | ## Pretrained Models
19 | Thanks to [Kaixhin](https://github.com/Kaixhin) you can download a pretrained version of FlowNetS (converted from caffe, not trained with pytorch) [here](https://drive.google.com/drive/folders/1dTpSyc7rIYYG19p1uiDfilcsmSPNy-_3?usp=sharing). This folder also contains networks trained from scratch.
20 |
21 | ### Note on networks loading
22 | Directly feed the downloaded network to the script; you don't need to uncompress it, even if your desktop environment suggests it.
23 |
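For reference, here is a minimal sketch of what `main.py` does when you pass such a checkpoint through `--pretrained`. The path is a placeholder, and the checkpoint is assumed to be a dict with `arch` and `state_dict` keys, as saved by this repository:

```python
import torch

import models  # the models package from this repository

# placeholder path to a downloaded checkpoint (e.g. from the Google Drive folder above)
network_data = torch.load("/path/to/flownets_EPE1.951.pth.tar", map_location="cpu")
print("=> using pre-trained model '{}'".format(network_data["arch"]))

# build the network and load its weights, as main.py does
model = models.__dict__[network_data["arch"]](network_data)
model.eval()
```
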
24 | ### Note on networks from caffe
25 | These networks expect a BGR input (instead of the RGB order usually used in PyTorch). In practice, however, the channel order is not very important.
26 |
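If you do want to reproduce the caffe behaviour exactly, reversing the channel axis before feeding images to a caffe-converted network is enough; a minimal NumPy sketch with a placeholder image:

```python
import numpy as np

img_rgb = np.random.rand(384, 512, 3).astype(np.float32)  # placeholder H x W x 3 image, RGB order
img_bgr = img_rgb[:, :, ::-1]                              # reverse the channel axis: RGB -> BGR
```
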
27 | ## Prerequisites
28 | These modules can be installed with `pip`:
29 |
30 | ```
31 | pytorch >= 1.2
32 | tensorboard-pytorch
33 | tensorboardX >= 1.4
34 | spatial-correlation-sampler>=0.2.1
35 | imageio
36 | argparse
37 | path.py
38 | ```
39 |
40 | or
41 | ```bash
42 | pip install -r requirements.txt
43 | ```
44 |
45 | ## Training on Flying Chairs Dataset
46 |
47 | First, you need to download the [Flying Chairs dataset](http://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html). It is ~64 GB and we recommend putting it on an SSD drive.
48 |
49 | Default hyperparameters provided in `main.py` are the same as in the caffe training scripts.
50 |
51 | * Example usage for FlowNetS:
52 |
53 | ```bash
54 | python main.py /path/to/flying_chairs/ -b8 -j8 -a flownets
55 | ```
56 |
57 | We recommend setting `-j` (the number of data loading workers) to a high value if you use data augmentation, so that data loading does not slow down training.
58 |
59 | For further help you can type
60 |
61 | ```bash
62 | python main.py -h
63 | ```
64 |
65 | ## Visualizing training
66 | [Tensorboard-pytorch](https://github.com/lanpa/tensorboard-pytorch) is used for logging. To visualize results, simply run
67 |
68 | ```bash
69 | tensorboard --logdir=/path/to/checkpoints
70 | ```
71 |
72 | ## Training results
73 |
74 | Models can be downloaded [here](https://drive.google.com/drive/folders/1dTpSyc7rIYYG19p1uiDfilcsmSPNy-_3?usp=sharing) in the pytorch folder.
75 |
76 | Models were trained with default options unless specified. Color warping was not used.
77 |
78 | | Arch | learning rate | batch size | epoch size | filename | validation EPE |
79 | | ----------- | ------------- | ---------- | ---------- | ---------------------------- | -------------- |
80 | | FlowNetS | 1e-4 | 8 | 2700 | flownets_EPE1.951.pth.tar | 1.951 |
81 | | FlowNetS BN | 1e-3 | 32 | 695 | flownets_bn_EPE2.459.pth.tar | 2.459 |
82 | | FlowNetC | 1e-4 | 8 | 2700 | flownetc_EPE1.766.pth.tar | 1.766 |
83 |
84 | *Note*: FlowNetS BN took longer to train and got worse results. It is strongly advised not to use it for the Flying Chairs dataset.
85 |
86 | ## Validation samples
87 |
88 | Predictions are made by FlowNetS.
89 |
90 | The exact code for the optical flow -> color map conversion can be found [here](main.py#L321); a simplified sketch is also shown after the table below.
91 |
92 | | Input | Prediction | Ground truth |
93 | |-------|------------|--------------|
94 | | ![input_1](images/input_1.gif) | ![pred_1](images/pred_1.png) | ![GT_1](images/GT_1.png) |
95 | | ![input_2](images/input_2.gif) | ![pred_2](images/pred_2.png) | ![GT_2](images/GT_2.png) |
96 | | ![input_3](images/input_3.gif) | ![pred_3](images/pred_3.png) | ![GT_3](images/GT_3.png) |
97 |
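For illustration, here is a simplified sketch of this kind of flow -> RGB mapping; it follows the same idea as the repository's `flow2rgb` in `util.py` but is not guaranteed to be identical to it:

```python
import numpy as np

def flow_to_rgb(flow, max_value=10.0):
    """Illustrative mapping from an H x W x 2 flow map to an RGB image in [0, 1]."""
    normalized = np.clip(flow / max_value, -1, 1)
    rgb = np.ones(flow.shape[:2] + (3,), dtype=np.float32)
    rgb[..., 0] += normalized[..., 0]                               # u pushes towards red
    rgb[..., 1] -= 0.5 * (normalized[..., 0] + normalized[..., 1])  # mixed into green
    rgb[..., 2] += normalized[..., 1]                               # v pushes towards blue
    return np.clip(rgb, 0, 1)

flow = np.random.randn(370, 1224, 2).astype(np.float32)  # placeholder flow map
rgb = flow_to_rgb(flow)
```
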
98 | ## Running inference on a set of image pairs
99 |
100 | If you need to run the network on your images, you can download a pretrained network [here](https://drive.google.com/drive/folders/1dTpSyc7rIYYG19p1uiDfilcsmSPNy-_3?usp=sharing) and launch the inference script on your folder of image pairs.
101 |
102 | Your folder needs to have all the image pairs in the same location, with the name pattern
103 | ```
104 | {image_name}1.{ext}
105 | {image_name}2.{ext}
106 | ```
107 |
108 | ```bash
109 | python3 run_inference.py /path/to/images/folder /path/to/pretrained
110 | ```
111 |
112 | As with the `main.py` script, a help menu is available for additional options.
113 |
114 | ## Note on transform functions
115 |
116 | In order to apply coherent transformations to inputs and target, we must define new transformations that take both input and target as arguments; otherwise, a new random parameter would be drawn each time a random transformation is called.
117 |
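For example, `main.py` builds its training co-transform for dense-flow datasets by composing several of these routines:

```python
import flow_transforms

# co-transform used by main.py when the flow is dense (e.g. Flying Chairs)
co_transform = flow_transforms.Compose([
    flow_transforms.RandomTranslate(10),
    flow_transforms.RandomRotate(10, 5),
    flow_transforms.RandomCrop((320, 448)),
    flow_transforms.RandomVerticalFlip(),
    flow_transforms.RandomHorizontalFlip(),
])

# each co-transform takes (inputs, target) and returns them transformed consistently:
# inputs, target = co_transform([img1, img2], flow)
```
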
118 | ### Flow Transformations
119 |
120 | To allow data augmentation, we consider rotations and translations of the inputs and their effect on the target flow map.
121 | Here is a set of things to take care of in order to achieve proper data augmentation.
122 |
123 | #### The Flow Map is directly linked to img1
124 | If you apply a transformation to img1, you have to apply the very same one to the flow map, so that the flow's origin points stay coherent.
125 |
126 | #### Translation between img1 and img2
127 | Given a translation `(tx,ty)` applied on img2, we will have
128 | ```
129 | flow[:,:,0] += tx
130 | flow[:,:,1] += ty
131 | ```
132 |
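A minimal NumPy sketch of this update, mirroring `flow_transforms.RandomTranslate`: both images are cropped to their common overlap and the flow is shifted accordingly (the arrays and translation values below are placeholders):

```python
import numpy as np

h, w = 384, 512
img1 = np.random.rand(h, w, 3)                 # placeholder first image
img2 = np.random.rand(h, w, 3)                 # placeholder second image
flow = np.zeros((h, w, 2), dtype=np.float32)   # placeholder flow map
tw, th = 7, -3                                 # translation applied to img2, in pixels

# crop both images to their common overlap, then shift the flow
x1, x2, x3, x4 = max(0, tw), min(w + tw, w), max(0, -tw), min(w - tw, w)
y1, y2, y3, y4 = max(0, th), min(h + th, h), max(0, -th), min(h - th, h)
img1, img2 = img1[y1:y2, x1:x2], img2[y3:y4, x3:x4]
flow = flow[y1:y2, x1:x2].copy()
flow[:, :, 0] += tw
flow[:, :, 1] += th
```
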
133 | #### Scale
134 | A scale applied to both img1 and img2 with a zoom parameter `alpha` multiplies the flow by the same amount:
135 | ```
136 | flow *= alpha
137 | ```
138 |
139 | #### Rotation applied on both images
140 | A rotation applied to both images by an angle `theta` also rotates the flow vectors (`flow[i,j]`) by the same angle:
141 | ```
142 | for_all i,j: flow[i,j] = rotate(flow[i,j], theta)
143 |
144 | rotate: (x, y), theta -> (x*cos(theta) - y*sin(theta), x*sin(theta) + y*cos(theta))
145 | ```
146 |
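In code, this corresponds to the vector rotation applied at the end of `flow_transforms.RandomRotate`; note the sign convention, which accounts for the image Y axis pointing down (the arrays below are placeholders):

```python
import numpy as np

theta = np.deg2rad(10.0)                                 # rotation angle
flow = np.random.randn(384, 512, 2).astype(np.float32)   # placeholder flow map

# rotate every flow vector, as done at the end of flow_transforms.RandomRotate
flow_ = flow.copy()
flow[:, :, 0] = np.cos(theta) * flow_[:, :, 0] + np.sin(theta) * flow_[:, :, 1]
flow[:, :, 1] = -np.sin(theta) * flow_[:, :, 0] + np.cos(theta) * flow_[:, :, 1]
```
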
147 | #### Rotation applied on img2
148 | Let us consider a rotation by an angle `theta` around the image center.
149 |
150 | We must transform each flow vector based on the coordinates where it lands. At each coordinate `(i, j)`, we have:
151 | ```
152 | flow[i, j, 0] += (cos(theta) - 1) * (j - w/2 + flow[i, j, 0]) + sin(theta) * (i - h/2 + flow[i, j, 1])
153 | flow[i, j, 1] += -sin(theta) * (j - w/2 + flow[i, j, 0]) + (cos(theta) - 1) * (i - h/2 + flow[i, j, 1])
154 | ```
155 |
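This is the correction that `flow_transforms.RandomRotate` computes in a vectorized way (for the relative angle between the two images) before rotating the images themselves; a condensed sketch with placeholder arrays:

```python
import numpy as np

h, w = 384, 512
flow = np.zeros((h, w, 2), dtype=np.float32)   # placeholder flow map
theta = np.deg2rad(5.0)                        # relative rotation applied to img2

# coordinates where each flow vector lands, expressed relative to the image center
warped = np.mgrid[:w, :h].T + flow
warped -= np.array([w / 2, h / 2])

correction = np.zeros_like(flow)
correction[..., 0] = (np.cos(theta) - 1) * warped[..., 0] + np.sin(theta) * warped[..., 1]
correction[..., 1] = -np.sin(theta) * warped[..., 0] + (np.cos(theta) - 1) * warped[..., 1]
flow += correction
```
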
--------------------------------------------------------------------------------
/datasets/KITTI.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os.path
3 | import glob
4 | from .listdataset import ListDataset
5 | from .util import split2list
6 | import numpy as np
7 | import flow_transforms
8 |
9 | try:
10 | import cv2
11 | except ImportError:
12 | import warnings
13 |
14 | with warnings.catch_warnings():
15 | warnings.filterwarnings("default", category=ImportWarning)
16 | warnings.warn(
17 | "failed to load openCV, which is needed"
18 | "for KITTI which uses 16bit PNG images",
19 | ImportWarning,
20 | )
21 |
22 | """
23 | Dataset routines for KITTI_flow, 2012 and 2015.
24 | http://www.cvlibs.net/datasets/kitti/eval_flow.php
25 | The dataset is not very big; you might want to only fine-tune FlowNet on it.
26 | EPE values are not representative on this dataset because of the sparsity of the GT.
27 | OpenCV is needed to load 16-bit PNG images.
28 | """
29 |
30 |
31 | def load_flow_from_png(png_path):
32 | # The -1 is here to specify not to change the image depth (16bit), and is compatible
33 | # with both OpenCV2 and OpenCV3
34 | flo_file = cv2.imread(png_path, -1)
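# KITTI stores flow in 16-bit PNGs with channels (u, v, valid) in RGB order and
# values encoded as flow * 64 + 2**15. cv2 returns channels in BGR order, so
# channel 0 is the validity mask and [:, :, 2:0:-1] below picks (u, v).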
35 | flo_img = flo_file[:, :, 2:0:-1].astype(np.float32)
36 | invalid = flo_file[:, :, 0] == 0
37 | flo_img = flo_img - 32768
38 | flo_img = flo_img / 64
39 | flo_img[np.abs(flo_img) < 1e-10] = 1e-10
40 | flo_img[invalid, :] = 0
41 | return flo_img
42 |
43 |
44 | def make_dataset(dir, split, split_save_path, occ=True):
45 | """Will search in training folder for folders 'flow_noc' or 'flow_occ'
46 | and 'colored_0' (KITTI 2012) or 'image_2' (KITTI 2015)"""
47 | flow_dir = "flow_occ" if occ else "flow_noc"
48 | assert os.path.isdir(os.path.join(dir, flow_dir))
49 | img_dir = "colored_0"
50 | if not os.path.isdir(os.path.join(dir, img_dir)):
51 | img_dir = "image_2"
52 | assert os.path.isdir(os.path.join(dir, img_dir))
53 |
54 | images = []
55 | for flow_map in glob.iglob(os.path.join(dir, flow_dir, "*.png")):
56 | flow_map = os.path.basename(flow_map)
57 | root_filename = flow_map[:-7]
58 | flow_map = os.path.join(flow_dir, flow_map)
59 | img1 = os.path.join(img_dir, root_filename + "_10.png")
60 | img2 = os.path.join(img_dir, root_filename + "_11.png")
61 | if not (
62 | os.path.isfile(os.path.join(dir, img1))
63 | and os.path.isfile(os.path.join(dir, img2))
64 | ):
65 | continue
66 | images.append([[img1, img2], flow_map])
67 |
68 | return split2list(images, split, split_save_path, default_split=0.9)
69 |
70 |
71 | def KITTI_loader(root, path_imgs, path_flo):
72 | imgs = [os.path.join(root, path) for path in path_imgs]
73 | flo = os.path.join(root, path_flo)
74 | return [
75 | cv2.imread(img)[:, :, ::-1].astype(np.float32) for img in imgs
76 | ], load_flow_from_png(flo)
77 |
78 |
79 | def KITTI_occ(
80 | root,
81 | transform=None,
82 | target_transform=None,
83 | co_transform=None,
84 | split=None,
85 | split_save_path=None,
86 | ):
87 | train_list, test_list = make_dataset(root, split, split_save_path, True)
88 | train_dataset = ListDataset(
89 | root, train_list, transform, target_transform, co_transform, loader=KITTI_loader
90 | )
91 | # All test samples are cropped to the lowest possible size of KITTI images
92 | test_dataset = ListDataset(
93 | root,
94 | test_list,
95 | transform,
96 | target_transform,
97 | flow_transforms.CenterCrop((370, 1224)),
98 | loader=KITTI_loader,
99 | )
100 |
101 | return train_dataset, test_dataset
102 |
103 |
104 | def KITTI_noc(
105 | root,
106 | transform=None,
107 | target_transform=None,
108 | co_transform=None,
109 | split=None,
110 | split_save_path=None,
111 | ):
112 | train_list, test_list = make_dataset(root, split, split_save_path, False)
113 | train_dataset = ListDataset(
114 | root, train_list, transform, target_transform, co_transform, loader=KITTI_loader
115 | )
116 | # All test samples are cropped to the lowest possible size of KITTI images
117 | test_dataset = ListDataset(
118 | root,
119 | test_list,
120 | transform,
121 | target_transform,
122 | flow_transforms.CenterCrop((370, 1224)),
123 | loader=KITTI_loader,
124 | )
125 |
126 | return train_dataset, test_dataset
127 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .flyingchairs import flying_chairs
2 | from .KITTI import KITTI_occ, KITTI_noc
3 | from .mpisintel import mpi_sintel_clean, mpi_sintel_final, mpi_sintel_both
4 |
5 | __all__ = (
6 | "flying_chairs",
7 | "KITTI_occ",
8 | "KITTI_noc",
9 | "mpi_sintel_clean",
10 | "mpi_sintel_final",
11 | "mpi_sintel_both",
12 | )
13 |
--------------------------------------------------------------------------------
/datasets/flyingchairs.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import glob
3 | from .listdataset import ListDataset
4 | from .util import split2list
5 |
6 |
7 | def make_dataset(dir, split=None, split_save_path=None):
8 | """Will search for triplets that go by the pattern '[name]_img1.ppm [name]_img2.ppm [name]_flow.flo'"""
9 | images = []
10 | for flow_map in sorted(glob.glob(os.path.join(dir, "*_flow.flo"))):
11 | flow_map = os.path.basename(flow_map)
12 | root_filename = flow_map[:-9]
13 | img1 = root_filename + "_img1.ppm"
14 | img2 = root_filename + "_img2.ppm"
15 | if not (
16 | os.path.isfile(os.path.join(dir, img1))
17 | and os.path.isfile(os.path.join(dir, img2))
18 | ):
19 | continue
20 |
21 | images.append([[img1, img2], flow_map])
22 | return split2list(images, split, split_save_path, default_split=0.97)
23 |
24 |
25 | def flying_chairs(
26 | root,
27 | transform=None,
28 | target_transform=None,
29 | co_transform=None,
30 | split=None,
31 | split_save_path=None,
32 | ):
33 | train_list, test_list = make_dataset(root, split, split_save_path)
34 | train_dataset = ListDataset(
35 | root, train_list, transform, target_transform, co_transform
36 | )
37 | test_dataset = ListDataset(root, test_list, transform, target_transform)
38 |
39 | return train_dataset, test_dataset
40 |
--------------------------------------------------------------------------------
/datasets/listdataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import os
3 | import os.path
4 | from imageio import imread
5 | import numpy as np
6 |
7 |
8 | def load_flo(path):
9 | with open(path, "rb") as f:
10 | magic = np.fromfile(f, np.float32, count=1)
11 | assert 202021.25 == magic, "Magic number incorrect. Invalid .flo file"
12 | h = np.fromfile(f, np.int32, count=1)[0]
13 | w = np.fromfile(f, np.int32, count=1)[0]
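# NOTE: the Middlebury .flo header stores width before height, so the two
# variables above are swapped relative to their names; the (w, h, 2) resize
# below therefore still yields an array of shape (height, width, 2).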
14 | data = np.fromfile(f, np.float32, count=2 * w * h)
15 | # Reshape data into 3D array (columns, rows, bands)
16 | data2D = np.resize(data, (w, h, 2))
17 | return data2D
18 |
19 |
20 | def default_loader(root, path_imgs, path_flo):
21 | imgs = [os.path.join(root, path) for path in path_imgs]
22 | flo = os.path.join(root, path_flo)
23 | return [imread(img).astype(np.float32) for img in imgs], load_flo(flo)
24 |
25 |
26 | class ListDataset(data.Dataset):
27 | def __init__(
28 | self,
29 | root,
30 | path_list,
31 | transform=None,
32 | target_transform=None,
33 | co_transform=None,
34 | loader=default_loader,
35 | ):
36 |
37 | self.root = root
38 | self.path_list = path_list
39 | self.transform = transform
40 | self.target_transform = target_transform
41 | self.co_transform = co_transform
42 | self.loader = loader
43 |
44 | def __getitem__(self, index):
45 | inputs, target = self.path_list[index]
46 |
47 | inputs, target = self.loader(self.root, inputs, target)
48 | if self.co_transform is not None:
49 | inputs, target = self.co_transform(inputs, target)
50 | if self.transform is not None:
51 | inputs[0] = self.transform(inputs[0])
52 | inputs[1] = self.transform(inputs[1])
53 | if self.target_transform is not None:
54 | target = self.target_transform(target)
55 | return inputs, target
56 |
57 | def __len__(self):
58 | return len(self.path_list)
59 |
--------------------------------------------------------------------------------
/datasets/mpisintel.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import glob
3 | from .listdataset import ListDataset
4 | from .util import split2list
5 | import flow_transforms
6 |
7 | """
8 | Dataset routines for MPI Sintel.
9 | http://sintel.is.tue.mpg.de/
10 | clean version images are rendered without shaders, final version images are fully rendered
11 | The dataset is not very big; you might want to only fine-tune FlowNet on it.
12 | """
13 |
14 |
15 | def make_dataset(dataset_dir, split, split_save_path, dataset_type="clean"):
16 | flow_dir = "flow"
17 | assert os.path.isdir(os.path.join(dataset_dir, flow_dir))
18 | img_dir = dataset_type
19 | assert os.path.isdir(os.path.join(dataset_dir, img_dir))
20 |
21 | images = []
22 | for flow_map in sorted(
23 | glob.glob(os.path.join(dataset_dir, flow_dir, "*", "*.flo"))
24 | ):
25 | flow_map = os.path.relpath(flow_map, os.path.join(dataset_dir, flow_dir))
26 |
27 | scene_dir, filename = os.path.split(flow_map)
28 | no_ext_filename = os.path.splitext(filename)[0]
29 | prefix, frame_nb = no_ext_filename.split("_")
30 | frame_nb = int(frame_nb)
31 | img1 = os.path.join(
32 | img_dir, scene_dir, "{}_{:04d}.png".format(prefix, frame_nb)
33 | )
34 | img2 = os.path.join(
35 | img_dir, scene_dir, "{}_{:04d}.png".format(prefix, frame_nb + 1)
36 | )
37 | flow_map = os.path.join(flow_dir, flow_map)
38 | if not (
39 | os.path.isfile(os.path.join(dataset_dir, img1))
40 | and os.path.isfile(os.path.join(dataset_dir, img2))
41 | ):
42 | continue
43 | images.append([[img1, img2], flow_map])
44 |
45 | return split2list(images, split, split_save_path, default_split=0.87)
46 |
47 |
48 | def mpi_sintel_clean(
49 | root,
50 | transform=None,
51 | target_transform=None,
52 | co_transform=None,
53 | split=None,
54 | split_save_path=None,
55 | ):
56 | train_list, test_list = make_dataset(root, split, split_save_path, "clean")
57 | train_dataset = ListDataset(
58 | root, train_list, transform, target_transform, co_transform
59 | )
60 | test_dataset = ListDataset(
61 | root,
62 | test_list,
63 | transform,
64 | target_transform,
65 | flow_transforms.CenterCrop((384, 1024)),
66 | )
67 |
68 | return train_dataset, test_dataset
69 |
70 |
71 | def mpi_sintel_final(
72 | root,
73 | transform=None,
74 | target_transform=None,
75 | co_transform=None,
76 | split=None,
77 | split_save_path=None,
78 | ):
79 | train_list, test_list = make_dataset(root, split, split_save_path, "final")
80 | train_dataset = ListDataset(
81 | root, train_list, transform, target_transform, co_transform
82 | )
83 | test_dataset = ListDataset(
84 | root,
85 | test_list,
86 | transform,
87 | target_transform,
88 | flow_transforms.CenterCrop((384, 1024)),
89 | )
90 |
91 | return train_dataset, test_dataset
92 |
93 |
94 | def mpi_sintel_both(
95 | root,
96 | transform=None,
97 | target_transform=None,
98 | co_transform=None,
99 | split=None,
100 | split_save_path=None,
101 | ):
102 | """load images from both clean and final folders.
103 | We cannot shuffle the split, because it would very likely cause data snooping:
104 | the clean and final frames are not that different"""
105 | assert isinstance(
106 | split, str
107 | ), "To avoid data snooping, you must provide a static list of train/val when dealing with both clean and final." \
108 | " Look at Sintel_train_val.txt for an example"
109 | train_list1, test_list1 = make_dataset(root, split, split_save_path, "clean")
110 | train_list2, test_list2 = make_dataset(root, split, split_save_path, "final")
111 | train_dataset = ListDataset(
112 | root, train_list1 + train_list2, transform, target_transform, co_transform
113 | )
114 | test_dataset = ListDataset(
115 | root,
116 | test_list1 + test_list2,
117 | transform,
118 | target_transform,
119 | flow_transforms.CenterCrop((384, 1024)),
120 | )
121 |
122 | return train_dataset, test_dataset
123 |
--------------------------------------------------------------------------------
/datasets/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def split2list(images, split, split_save_path, default_split=0.9):
5 | if isinstance(split, str):
6 | with open(split) as f:
7 | split_values = [x.strip() == "1" for x in f.readlines()]
8 | assert len(images) == len(split_values)
9 | elif split is None:
10 | split_values = np.random.uniform(0, 1, len(images)) < default_split
11 | else:
12 | try:
13 | split = float(split)
14 | except TypeError:
15 | print("Invalid Split value, it must be either a filepath or a float")
16 | raise
17 | split_values = np.random.uniform(0, 1, len(images)) < split
18 | if split_save_path is not None:
19 | with open(split_save_path, "w") as f:
20 | f.write("\n".join(map(lambda x: str(int(x)), split_values)))
21 | train_samples = [sample for sample, split in zip(images, split_values) if split]
22 | test_samples = [sample for sample, split in zip(images, split_values) if not split]
23 | return train_samples, test_samples
24 |
--------------------------------------------------------------------------------
/flow_transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import random
4 | import numpy as np
5 | import numbers
6 | import types
7 | import scipy.ndimage as ndimage
8 |
9 | """Set of tranform random routines that takes both input and target as arguments,
10 | in order to have random but coherent transformations.
11 | inputs are PIL Image pairs and targets are ndarrays"""
12 |
13 |
14 | class Compose(object):
15 | """Composes several co_transforms together.
16 | For example:
17 | >>> co_transforms.Compose([
18 | >>> co_transforms.CenterCrop(10),
19 | >>> co_transforms.ToTensor(),
20 | >>> ])
21 | """
22 |
23 | def __init__(self, co_transforms):
24 | self.co_transforms = co_transforms
25 |
26 | def __call__(self, input, target):
27 | for t in self.co_transforms:
28 | input, target = t(input, target)
29 | return input, target
30 |
31 |
32 | class ArrayToTensor(object):
33 | """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
34 |
35 | def __call__(self, array):
36 | assert isinstance(array, np.ndarray)
37 | array = np.transpose(array, (2, 0, 1))
38 | # handle numpy array
39 | tensor = torch.from_numpy(array)
40 | # put it from HWC to CHW format
41 | return tensor.float()
42 |
43 |
44 | class Lambda(object):
45 | """Applies a lambda as a transform"""
46 |
47 | def __init__(self, lambd):
48 | assert isinstance(lambd, types.LambdaType)
49 | self.lambd = lambd
50 |
51 | def __call__(self, input, target):
52 | return self.lambd(input, target)
53 |
54 |
55 | class CenterCrop(object):
56 | """Crops the given inputs and target arrays at the center to have a region of
57 | the given size. size can be a tuple (target_height, target_width)
58 | or an integer, in which case the target will be of a square shape (size, size)
59 | Careful, img1 and img2 may not be the same size
60 | """
61 |
62 | def __init__(self, size):
63 | if isinstance(size, numbers.Number):
64 | self.size = (int(size), int(size))
65 | else:
66 | self.size = size
67 |
68 | def __call__(self, inputs, target):
69 | h1, w1, _ = inputs[0].shape
70 | h2, w2, _ = inputs[1].shape
71 | th, tw = self.size
72 | x1 = int(round((w1 - tw) / 2.0))
73 | y1 = int(round((h1 - th) / 2.0))
74 | x2 = int(round((w2 - tw) / 2.0))
75 | y2 = int(round((h2 - th) / 2.0))
76 |
77 | inputs[0] = inputs[0][y1 : y1 + th, x1 : x1 + tw]
78 | inputs[1] = inputs[1][y2 : y2 + th, x2 : x2 + tw]
79 | target = target[y1 : y1 + th, x1 : x1 + tw]
80 | return inputs, target
81 |
82 |
83 | class Scale(object):
84 | """Rescales the inputs and target arrays to the given 'size'.
85 | 'size' will be the size of the smaller edge.
86 | For example, if height > width, then image will be
87 | rescaled to (size * height / width, size)
88 | size: size of the smaller edge
89 | interpolation order: Default: 2 (bilinear)
90 | """
91 |
92 | def __init__(self, size, order=2):
93 | self.size = size
94 | self.order = order
95 |
96 | def __call__(self, inputs, target):
97 | h, w, _ = inputs[0].shape
98 | if (w <= h and w == self.size) or (h <= w and h == self.size):
99 | return inputs, target
100 | if w < h:
101 | ratio = self.size / w
102 | else:
103 | ratio = self.size / h
104 |
105 | inputs[0] = ndimage.interpolation.zoom(inputs[0], ratio, order=self.order)
106 | inputs[1] = ndimage.interpolation.zoom(inputs[1], ratio, order=self.order)
107 |
108 | target = ndimage.interpolation.zoom(target, ratio, order=self.order)
109 | target *= ratio
110 | return inputs, target
111 |
112 |
113 | class RandomCrop(object):
114 | """Crops the given PIL.Image at a random location to have a region of
115 | the given size. size can be a tuple (target_height, target_width)
116 | or an integer, in which case the target will be of a square shape (size, size)
117 | """
118 |
119 | def __init__(self, size):
120 | if isinstance(size, numbers.Number):
121 | self.size = (int(size), int(size))
122 | else:
123 | self.size = size
124 |
125 | def __call__(self, inputs, target):
126 | h, w, _ = inputs[0].shape
127 | th, tw = self.size
128 | if w == tw and h == th:
129 | return inputs, target
130 |
131 | x1 = random.randint(0, w - tw)
132 | y1 = random.randint(0, h - th)
133 | inputs[0] = inputs[0][y1 : y1 + th, x1 : x1 + tw]
134 | inputs[1] = inputs[1][y1 : y1 + th, x1 : x1 + tw]
135 | return inputs, target[y1 : y1 + th, x1 : x1 + tw]
136 |
137 |
138 | class RandomHorizontalFlip(object):
139 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5"""
140 |
141 | def __call__(self, inputs, target):
142 | if random.random() < 0.5:
143 | inputs[0] = np.copy(np.fliplr(inputs[0]))
144 | inputs[1] = np.copy(np.fliplr(inputs[1]))
145 | target = np.copy(np.fliplr(target))
146 | target[:, :, 0] *= -1
147 | return inputs, target
148 |
149 |
150 | class RandomVerticalFlip(object):
151 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5"""
152 |
153 | def __call__(self, inputs, target):
154 | if random.random() < 0.5:
155 | inputs[0] = np.copy(np.flipud(inputs[0]))
156 | inputs[1] = np.copy(np.flipud(inputs[1]))
157 | target = np.copy(np.flipud(target))
158 | target[:, :, 1] *= -1
159 | return inputs, target
160 |
161 |
162 | class RandomRotate(object):
163 | """Random rotation of the image from -angle to angle (in degrees)
164 | This is useful for dataAugmentation, especially for geometric problems such as FlowEstimation
165 | angle: max angle of the rotation
166 | interpolation order: Default: 2 (bilinear)
167 | reshape: Default: false. If set to true, image size will be set to keep every pixel in the image.
168 | diff_angle: Default: 0.
169 | """
170 |
171 | def __init__(self, angle, diff_angle=0, order=2, reshape=False):
172 | self.angle = angle
173 | self.reshape = reshape
174 | self.order = order
175 | self.diff_angle = diff_angle
176 |
177 | def __call__(self, inputs, target):
178 | applied_angle = random.uniform(-self.angle, self.angle)
179 | diff = random.uniform(-self.diff_angle, self.diff_angle)
180 | angle1 = applied_angle - diff / 2
181 | angle2 = applied_angle + diff / 2
182 | angle1_rad = angle1 * np.pi / 180
183 | diff_rad = diff * np.pi / 180
184 |
185 | h, w, _ = target.shape
186 |
187 | warped_coords = np.mgrid[:w, :h].T + target
188 | warped_coords -= np.array([w / 2, h / 2])
189 |
190 | warped_coords_rot = np.zeros_like(target)
191 |
192 | warped_coords_rot[..., 0] = (np.cos(diff_rad) - 1) * warped_coords[
193 | ..., 0
194 | ] + np.sin(diff_rad) * warped_coords[..., 1]
195 |
196 | warped_coords_rot[..., 1] = (
197 | -np.sin(diff_rad) * warped_coords[..., 0]
198 | + (np.cos(diff_rad) - 1) * warped_coords[..., 1]
199 | )
200 |
201 | target += warped_coords_rot
202 |
203 | inputs[0] = ndimage.interpolation.rotate(
204 | inputs[0], angle1, reshape=self.reshape, order=self.order
205 | )
206 | inputs[1] = ndimage.interpolation.rotate(
207 | inputs[1], angle2, reshape=self.reshape, order=self.order
208 | )
209 | target = ndimage.interpolation.rotate(
210 | target, angle1, reshape=self.reshape, order=self.order
211 | )
212 | # flow vectors must be rotated too! careful about Y flow which is upside down
213 | target_ = np.copy(target)
214 | target[:, :, 0] = (
215 | np.cos(angle1_rad) * target_[:, :, 0]
216 | + np.sin(angle1_rad) * target_[:, :, 1]
217 | )
218 | target[:, :, 1] = (
219 | -np.sin(angle1_rad) * target_[:, :, 0]
220 | + np.cos(angle1_rad) * target_[:, :, 1]
221 | )
222 | return inputs, target
223 |
224 |
225 | class RandomTranslate(object):
226 | def __init__(self, translation):
227 | if isinstance(translation, numbers.Number):
228 | self.translation = (int(translation), int(translation))
229 | else:
230 | self.translation = translation
231 |
232 | def __call__(self, inputs, target):
233 | h, w, _ = inputs[0].shape
234 | th, tw = self.translation
235 | tw = random.randint(-tw, tw)
236 | th = random.randint(-th, th)
237 | if tw == 0 and th == 0:
238 | return inputs, target
239 | # compute x1,x2,y1,y2 for img1 and target, and x3,x4,y3,y4 for img2
240 | x1, x2, x3, x4 = max(0, tw), min(w + tw, w), max(0, -tw), min(w - tw, w)
241 | y1, y2, y3, y4 = max(0, th), min(h + th, h), max(0, -th), min(h - th, h)
242 |
243 | inputs[0] = inputs[0][y1:y2, x1:x2]
244 | inputs[1] = inputs[1][y3:y4, x3:x4]
245 | target = target[y1:y2, x1:x2]
246 | target[:, :, 0] += tw
247 | target[:, :, 1] += th
248 |
249 | return inputs, target
250 |
251 |
252 | class RandomColorWarp(object):
253 | def __init__(self, mean_range=0, std_range=0):
254 | self.mean_range = mean_range
255 | self.std_range = std_range
256 |
257 | def __call__(self, inputs, target):
258 | random_std = np.random.uniform(-self.std_range, self.std_range, 3)
259 | random_mean = np.random.uniform(-self.mean_range, self.mean_range, 3)
260 | random_order = np.random.permutation(3)
261 |
262 | inputs[0] *= 1 + random_std
263 | inputs[0] += random_mean
264 |
265 | inputs[1] *= 1 + random_std
266 | inputs[1] += random_mean
267 |
268 | inputs[0] = inputs[0][:, :, random_order]
269 | inputs[1] = inputs[1][:, :, random_order]
270 |
271 | return inputs, target
272 |
--------------------------------------------------------------------------------
/images/GT_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/GT_1.png
--------------------------------------------------------------------------------
/images/GT_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/GT_2.png
--------------------------------------------------------------------------------
/images/GT_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/GT_3.png
--------------------------------------------------------------------------------
/images/input_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/input_1.gif
--------------------------------------------------------------------------------
/images/input_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/input_2.gif
--------------------------------------------------------------------------------
/images/input_3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/input_3.gif
--------------------------------------------------------------------------------
/images/pred_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/pred_1.png
--------------------------------------------------------------------------------
/images/pred_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/pred_2.png
--------------------------------------------------------------------------------
/images/pred_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/pred_3.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.nn.parallel
8 | import torch.backends.cudnn as cudnn
9 | import torch.optim
10 | import torch.utils.data
11 | import torchvision.transforms as transforms
12 | import flow_transforms
13 | import models
14 | import datasets
15 | from multiscaleloss import multiscaleEPE, realEPE
16 | import datetime
17 | from torch.utils.tensorboard import SummaryWriter
18 | from util import flow2rgb, AverageMeter, save_checkpoint
19 | import numpy as np
20 |
21 | model_names = sorted(
22 | name for name in models.__dict__ if name.islower() and not name.startswith("__")
23 | )
24 | dataset_names = sorted(name for name in datasets.__all__)
25 |
26 | parser = argparse.ArgumentParser(
27 | description="PyTorch FlowNet Training on several datasets",
28 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
29 | )
30 | parser.add_argument("data", metavar="DIR", help="path to dataset")
31 | parser.add_argument(
32 | "--dataset",
33 | metavar="DATASET",
34 | default="flying_chairs",
35 | choices=dataset_names,
36 | help="dataset type : " + " | ".join(dataset_names),
37 | )
38 | group = parser.add_mutually_exclusive_group()
39 | group.add_argument(
40 | "-s", "--split-file", default=None, type=str, help="test-val split file"
41 | )
42 | group.add_argument(
43 | "--split-value",
44 | default=0.8,
45 | type=float,
46 | help="test-val split proportion between 0 (only test) and 1 (only train), "
47 | "will be overwritten if a split file is set",
48 | )
49 | parser.add_argument(
50 | "--split-seed",
51 | type=int,
52 | default=None,
53 | help="Seed the train-val split to enforce reproducibility (consistent restart too)",
54 | )
55 | parser.add_argument(
56 | "--arch",
57 | "-a",
58 | metavar="ARCH",
59 | default="flownets",
60 | choices=model_names,
61 | help="model architecture, overwritten if pretrained is specified: "
62 | + " | ".join(model_names),
63 | )
64 | parser.add_argument(
65 | "--solver", default="adam", choices=["adam", "sgd"], help="solver algorithms"
66 | )
67 | parser.add_argument(
68 | "-j",
69 | "--workers",
70 | default=8,
71 | type=int,
72 | metavar="N",
73 | help="number of data loading workers",
74 | )
75 | parser.add_argument(
76 | "--epochs", default=300, type=int, metavar="N", help="number of total epochs to run"
77 | )
78 | parser.add_argument(
79 | "--start-epoch",
80 | default=0,
81 | type=int,
82 | metavar="N",
83 | help="manual epoch number (useful on restarts)",
84 | )
85 | parser.add_argument(
86 | "--epoch-size",
87 | default=1000,
88 | type=int,
89 | metavar="N",
90 | help="manual epoch size (will match dataset size if set to 0)",
91 | )
92 | parser.add_argument(
93 | "-b", "--batch-size", default=8, type=int, metavar="N", help="mini-batch size"
94 | )
95 | parser.add_argument(
96 | "--lr",
97 | "--learning-rate",
98 | default=0.0001,
99 | type=float,
100 | metavar="LR",
101 | help="initial learning rate",
102 | )
103 | parser.add_argument(
104 | "--momentum",
105 | default=0.9,
106 | type=float,
107 | metavar="M",
108 | help="momentum for sgd, alpha parameter for adam",
109 | )
110 | parser.add_argument(
111 | "--beta", default=0.999, type=float, metavar="M", help="beta parameter for adam"
112 | )
113 | parser.add_argument(
114 | "--weight-decay", "--wd", default=4e-4, type=float, metavar="W", help="weight decay"
115 | )
116 | parser.add_argument(
117 | "--bias-decay", default=0, type=float, metavar="B", help="bias decay"
118 | )
119 | parser.add_argument(
120 | "--multiscale-weights",
121 | "-w",
122 | default=[0.005, 0.01, 0.02, 0.08, 0.32],
123 | type=float,
124 | nargs=5,
125 | help="training weight for each scale, from highest resolution (flow2) to lowest (flow6)",
126 | metavar=("W2", "W3", "W4", "W5", "W6"),
127 | )
128 | parser.add_argument(
129 | "--sparse",
130 | action="store_true",
131 | help="look for NaNs in target flow when computing EPE, avoid if flow is garantied to be dense,"
132 | "automatically seleted when choosing a KITTIdataset",
133 | )
134 | parser.add_argument(
135 | "--print-freq", "-p", default=10, type=int, metavar="N", help="print frequency"
136 | )
137 | parser.add_argument(
138 | "-e",
139 | "--evaluate",
140 | dest="evaluate",
141 | action="store_true",
142 | help="evaluate model on validation set",
143 | )
144 | parser.add_argument(
145 | "--pretrained", dest="pretrained", default=None, help="path to pre-trained model"
146 | )
147 | parser.add_argument(
148 | "--no-date", action="store_true", help="don't append date timestamp to folder"
149 | )
150 | parser.add_argument(
151 | "--div-flow",
152 | default=20, type=float,
153 | help="value by which flow will be divided. Original value is 20 but 1 with batchNorm gives good results",
154 | )
155 | parser.add_argument(
156 | "--milestones",
157 | default=[100, 150, 200], type=int,
158 | metavar="N",
159 | nargs="*",
160 | help="epochs at which learning rate is divided by 2",
161 | )
162 |
163 |
164 | best_EPE = -1
165 | n_iter = 0
166 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
167 |
168 |
169 | def main():
170 | global args, best_EPE, n_iter
171 | args = parser.parse_args()
172 | save_path = "{},{},{}epochs{},b{},lr{}".format(
173 | args.arch,
174 | args.solver,
175 | args.epochs,
176 | ",epochSize" + str(args.epoch_size) if args.epoch_size > 0 else "",
177 | args.batch_size,
178 | args.lr,
179 | )
180 | if not args.no_date:
181 | timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
182 | save_path = os.path.join(timestamp, save_path)
183 | save_path = os.path.join(args.dataset, save_path)
184 | print("=> will save everything to {}".format(save_path))
185 | if not os.path.exists(save_path):
186 | os.makedirs(save_path)
187 |
188 | if args.split_seed is not None:
189 | np.random.seed(args.split_seed)
190 |
191 | train_writer = SummaryWriter(os.path.join(save_path, "train"))
192 | test_writer = SummaryWriter(os.path.join(save_path, "test"))
193 | output_writers = []
194 | for i in range(3):
195 | output_writers.append(SummaryWriter(os.path.join(save_path, "test", str(i))))
196 |
197 | # Data loading code
198 | input_transform = transforms.Compose(
199 | [
200 | flow_transforms.ArrayToTensor(),
201 | transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
202 | transforms.Normalize(mean=[0.45, 0.432, 0.411], std=[1, 1, 1]),
203 | ]
204 | )
205 | target_transform = transforms.Compose(
206 | [
207 | flow_transforms.ArrayToTensor(),
208 | transforms.Normalize(mean=[0, 0], std=[args.div_flow, args.div_flow]),
209 | ]
210 | )
211 |
212 | if "KITTI" in args.dataset:
213 | args.sparse = True
214 | if args.sparse:
215 | co_transform = flow_transforms.Compose(
216 | [
217 | flow_transforms.RandomCrop((320, 448)),
218 | flow_transforms.RandomVerticalFlip(),
219 | flow_transforms.RandomHorizontalFlip(),
220 | ]
221 | )
222 | else:
223 | co_transform = flow_transforms.Compose(
224 | [
225 | flow_transforms.RandomTranslate(10),
226 | flow_transforms.RandomRotate(10, 5),
227 | flow_transforms.RandomCrop((320, 448)),
228 | flow_transforms.RandomVerticalFlip(),
229 | flow_transforms.RandomHorizontalFlip(),
230 | ]
231 | )
232 |
233 | print("=> fetching img pairs in '{}'".format(args.data))
234 | train_set, test_set = datasets.__dict__[args.dataset](
235 | args.data,
236 | transform=input_transform,
237 | target_transform=target_transform,
238 | co_transform=co_transform,
239 | split=args.split_file if args.split_file else args.split_value,
240 | split_save_path=os.path.join(save_path, "split.txt"),
241 | )
242 | print(
243 | "{} samples found, {} train samples and {} test samples ".format(
244 | len(test_set) + len(train_set), len(train_set), len(test_set)
245 | )
246 | )
247 | n_iter = args.start_epoch * len(train_set)
248 | train_loader = torch.utils.data.DataLoader(
249 | train_set,
250 | batch_size=args.batch_size,
251 | num_workers=args.workers,
252 | pin_memory=True,
253 | shuffle=True,
254 | )
255 | val_loader = torch.utils.data.DataLoader(
256 | test_set,
257 | batch_size=args.batch_size,
258 | num_workers=args.workers,
259 | pin_memory=True,
260 | shuffle=False,
261 | )
262 |
263 | # create model
264 | if args.pretrained:
265 | network_data = torch.load(args.pretrained)
266 | args.arch = network_data["arch"]
267 | print("=> using pre-trained model '{}'".format(args.arch))
268 | else:
269 | network_data = None
270 | print("=> creating model '{}'".format(args.arch))
271 |
272 | model = models.__dict__[args.arch](network_data).to(device)
273 |
274 | assert args.solver in ["adam", "sgd"]
275 | print("=> setting {} solver".format(args.solver))
276 | param_groups = [
277 | {"params": model.bias_parameters(), "weight_decay": args.bias_decay},
278 | {"params": model.weight_parameters(), "weight_decay": args.weight_decay},
279 | ]
280 |
281 | if device.type == "cuda":
282 | model = torch.nn.DataParallel(model).cuda()
283 | cudnn.benchmark = True
284 |
285 | if args.solver == "adam":
286 | optimizer = torch.optim.Adam(
287 | param_groups, args.lr, betas=(args.momentum, args.beta)
288 | )
289 | elif args.solver == "sgd":
290 | optimizer = torch.optim.SGD(param_groups, args.lr, momentum=args.momentum)
291 |
292 | if args.evaluate:
293 | best_EPE = validate(val_loader, model, 0, output_writers)
294 | return
295 |
296 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
297 | optimizer, milestones=args.milestones, gamma=0.5
298 | )
299 |
300 | for epoch in range(args.start_epoch, args.epochs):
301 | # train for one epoch
302 | train_loss, train_EPE = train(
303 | train_loader, model, optimizer, epoch, train_writer
304 | )
305 | scheduler.step()
306 | train_writer.add_scalar("mean EPE", train_EPE, epoch)
307 |
308 | # evaluate on validation set
309 |
310 | with torch.no_grad():
311 | EPE = validate(val_loader, model, epoch, output_writers)
312 | test_writer.add_scalar("mean EPE", EPE, epoch)
313 |
314 | if best_EPE < 0:
315 | best_EPE = EPE
316 |
317 | is_best = EPE < best_EPE
318 | best_EPE = min(EPE, best_EPE)
319 | save_checkpoint(
320 | {
321 | "epoch": epoch + 1,
322 | "arch": args.arch,
323 | "state_dict": model.module.state_dict(),
324 | "best_EPE": best_EPE,
325 | "div_flow": args.div_flow,
326 | },
327 | is_best,
328 | save_path,
329 | )
330 |
331 |
332 | def train(train_loader, model, optimizer, epoch, train_writer):
333 | global n_iter, args
334 | batch_time = AverageMeter()
335 | data_time = AverageMeter()
336 | losses = AverageMeter()
337 | flow2_EPEs = AverageMeter()
338 |
339 | epoch_size = (
340 | len(train_loader)
341 | if args.epoch_size == 0
342 | else min(len(train_loader), args.epoch_size)
343 | )
344 |
345 | # switch to train mode
346 | model.train()
347 |
348 | end = time.time()
349 |
350 | for i, (input, target) in enumerate(train_loader):
351 | # measure data loading time
352 | data_time.update(time.time() - end)
353 | target = target.to(device)
354 | input = torch.cat(input, 1).to(device)
355 |
356 | # compute output
357 | output = model(input)
358 | if args.sparse:
359 | # Since Target pooling is not very precise when sparse,
360 | # take the highest resolution prediction and upsample it instead of downsampling target
361 | h, w = target.size()[-2:]
362 | output = [F.interpolate(output[0], (h, w)), *output[1:]]
363 |
364 | loss = multiscaleEPE(
365 | output, target, weights=args.multiscale_weights, sparse=args.sparse
366 | )
367 | flow2_EPE = args.div_flow * realEPE(output[0], target, sparse=args.sparse)
368 | # record loss and EPE
369 | losses.update(loss.item(), target.size(0))
370 | train_writer.add_scalar("train_loss", loss.item(), n_iter)
371 | flow2_EPEs.update(flow2_EPE.item(), target.size(0))
372 |
373 | # compute gradient and do optimization step
374 | optimizer.zero_grad()
375 | loss.backward()
376 | optimizer.step()
377 |
378 | # measure elapsed time
379 | batch_time.update(time.time() - end)
380 | end = time.time()
381 |
382 | if i % args.print_freq == 0:
383 | print(
384 | "Epoch: [{0}][{1}/{2}]\t Time {3}\t Data {4}\t Loss {5}\t EPE {6}".format(
385 | epoch, i, epoch_size, batch_time, data_time, losses, flow2_EPEs
386 | )
387 | )
388 | n_iter += 1
389 | if i >= epoch_size:
390 | break
391 |
392 | return losses.avg, flow2_EPEs.avg
393 |
394 |
395 | def validate(val_loader, model, epoch, output_writers):
396 | global args
397 |
398 | batch_time = AverageMeter()
399 | flow2_EPEs = AverageMeter()
400 |
401 | # switch to evaluate mode
402 | model.eval()
403 |
404 | end = time.time()
405 | for i, (input, target) in enumerate(val_loader):
406 | target = target.to(device)
407 | input = torch.cat(input, 1).to(device)
408 |
409 | # compute output
410 | output = model(input)
411 | flow2_EPE = args.div_flow * realEPE(output, target, sparse=args.sparse)
412 | # record EPE
413 | flow2_EPEs.update(flow2_EPE.item(), target.size(0))
414 |
415 | # measure elapsed time
416 | batch_time.update(time.time() - end)
417 | end = time.time()
418 |
419 | if i < len(output_writers): # log first output of first batches
420 | if epoch == args.start_epoch:
421 | mean_values = torch.tensor(
422 | [0.45, 0.432, 0.411], dtype=input.dtype
423 | ).view(3, 1, 1)
424 | output_writers[i].add_image(
425 | "GroundTruth", flow2rgb(args.div_flow * target[0], max_value=10), 0
426 | )
427 | output_writers[i].add_image(
428 | "Inputs", (input[0, :3].cpu() + mean_values).clamp(0, 1), 0
429 | )
430 | output_writers[i].add_image(
431 | "Inputs", (input[0, 3:].cpu() + mean_values).clamp(0, 1), 1
432 | )
433 | output_writers[i].add_image(
434 | "FlowNet Outputs",
435 | flow2rgb(args.div_flow * output[0], max_value=10),
436 | epoch,
437 | )
438 |
439 | if i % args.print_freq == 0:
440 | print(
441 | "Test: [{0}/{1}]\t Time {2}\t EPE {3}".format(
442 | i, len(val_loader), batch_time, flow2_EPEs
443 | )
444 | )
445 |
446 | print(" * EPE {:.3f}".format(flow2_EPEs.avg))
447 |
448 | return flow2_EPEs.avg
449 |
450 |
451 | if __name__ == "__main__":
452 | main()
453 |
--------------------------------------------------------------------------------
/models/FlowNetC.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.init import kaiming_normal_, constant_
4 | from .util import conv, predict_flow, deconv, crop_like, correlate
5 |
6 | __all__ = ["flownetc", "flownetc_bn"]
7 |
8 |
9 | class FlowNetC(nn.Module):
10 | expansion = 1
11 |
12 | def __init__(self, batchNorm=True):
13 | super(FlowNetC, self).__init__()
14 |
15 | self.batchNorm = batchNorm
16 | self.conv1 = conv(self.batchNorm, 3, 64, kernel_size=7, stride=2)
17 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2)
18 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2)
19 | self.conv_redir = conv(self.batchNorm, 256, 32, kernel_size=1, stride=1)
20 |
21 | self.conv3_1 = conv(self.batchNorm, 473, 256)
22 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2)
23 | self.conv4_1 = conv(self.batchNorm, 512, 512)
24 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2)
25 | self.conv5_1 = conv(self.batchNorm, 512, 512)
26 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2)
27 | self.conv6_1 = conv(self.batchNorm, 1024, 1024)
28 |
29 | self.deconv5 = deconv(1024, 512)
30 | self.deconv4 = deconv(1026, 256)
31 | self.deconv3 = deconv(770, 128)
32 | self.deconv2 = deconv(386, 64)
33 |
34 | self.predict_flow6 = predict_flow(1024)
35 | self.predict_flow5 = predict_flow(1026)
36 | self.predict_flow4 = predict_flow(770)
37 | self.predict_flow3 = predict_flow(386)
38 | self.predict_flow2 = predict_flow(194)
39 |
40 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
41 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
42 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
43 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
44 |
45 | for m in self.modules():
46 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
47 | kaiming_normal_(m.weight, 0.1)
48 | if m.bias is not None:
49 | constant_(m.bias, 0)
50 | elif isinstance(m, nn.BatchNorm2d):
51 | constant_(m.weight, 1)
52 | constant_(m.bias, 0)
53 |
54 | def forward(self, x):
55 | x1 = x[:, :3]
56 | x2 = x[:, 3:]
57 |
58 | out_conv1a = self.conv1(x1)
59 | out_conv2a = self.conv2(out_conv1a)
60 | out_conv3a = self.conv3(out_conv2a)
61 |
62 | out_conv1b = self.conv1(x2)
63 | out_conv2b = self.conv2(out_conv1b)
64 | out_conv3b = self.conv3(out_conv2b)
65 |
66 | out_conv_redir = self.conv_redir(out_conv3a)
67 | out_correlation = correlate(out_conv3a, out_conv3b)
68 |
69 | in_conv3_1 = torch.cat([out_conv_redir, out_correlation], dim=1)
70 |
71 | out_conv3 = self.conv3_1(in_conv3_1)
72 | out_conv4 = self.conv4_1(self.conv4(out_conv3))
73 | out_conv5 = self.conv5_1(self.conv5(out_conv4))
74 | out_conv6 = self.conv6_1(self.conv6(out_conv5))
75 |
76 | flow6 = self.predict_flow6(out_conv6)
77 | flow6_up = crop_like(self.upsampled_flow6_to_5(flow6), out_conv5)
78 | out_deconv5 = crop_like(self.deconv5(out_conv6), out_conv5)
79 |
80 | concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
81 | flow5 = self.predict_flow5(concat5)
82 | flow5_up = crop_like(self.upsampled_flow5_to_4(flow5), out_conv4)
83 | out_deconv4 = crop_like(self.deconv4(concat5), out_conv4)
84 |
85 | concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
86 | flow4 = self.predict_flow4(concat4)
87 | flow4_up = crop_like(self.upsampled_flow4_to_3(flow4), out_conv3)
88 | out_deconv3 = crop_like(self.deconv3(concat4), out_conv3)
89 |
90 | concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
91 | flow3 = self.predict_flow3(concat3)
92 | flow3_up = crop_like(self.upsampled_flow3_to_2(flow3), out_conv2a)
93 | out_deconv2 = crop_like(self.deconv2(concat3), out_conv2a)
94 |
95 | concat2 = torch.cat((out_conv2a, out_deconv2, flow3_up), 1)
96 | flow2 = self.predict_flow2(concat2)
97 |
98 | if self.training:
99 | return flow2, flow3, flow4, flow5, flow6
100 | else:
101 | return flow2
102 |
103 | def weight_parameters(self):
104 | return [param for name, param in self.named_parameters() if "weight" in name]
105 |
106 | def bias_parameters(self):
107 | return [param for name, param in self.named_parameters() if "bias" in name]
108 |
109 |
110 | def flownetc(data=None):
111 | """FlowNetS model architecture from the
112 | "Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852)
113 |
114 | Args:
115 | data : pretrained weights of the network. will create a new one if not set
116 | """
117 | model = FlowNetC(batchNorm=False)
118 | if data is not None:
119 | model.load_state_dict(data["state_dict"])
120 | return model
121 |
122 |
123 | def flownetc_bn(data=None):
124 | """FlowNetS model architecture from the
125 | "Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852)
126 |
127 | Args:
128 | data : pretrained weights of the network. will create a new one if not set
129 | """
130 | model = FlowNetC(batchNorm=True)
131 | if data is not None:
132 | model.load_state_dict(data["state_dict"])
133 | return model
134 |
--------------------------------------------------------------------------------
/models/FlowNetS.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.init import kaiming_normal_, constant_
4 | from .util import conv, predict_flow, deconv, crop_like
5 |
6 | __all__ = ["flownets", "flownets_bn"]
7 |
8 |
9 | class FlowNetS(nn.Module):
10 | expansion = 1
11 |
12 | def __init__(self, batchNorm=True):
13 | super(FlowNetS, self).__init__()
14 |
15 | self.batchNorm = batchNorm
16 | self.conv1 = conv(self.batchNorm, 6, 64, kernel_size=7, stride=2)
17 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2)
18 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2)
19 | self.conv3_1 = conv(self.batchNorm, 256, 256)
20 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2)
21 | self.conv4_1 = conv(self.batchNorm, 512, 512)
22 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2)
23 | self.conv5_1 = conv(self.batchNorm, 512, 512)
24 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2)
25 | self.conv6_1 = conv(self.batchNorm, 1024, 1024)
26 |
27 | self.deconv5 = deconv(1024, 512)
28 | self.deconv4 = deconv(1026, 256)
29 | self.deconv3 = deconv(770, 128)
30 | self.deconv2 = deconv(386, 64)
31 |
32 | self.predict_flow6 = predict_flow(1024)
33 | self.predict_flow5 = predict_flow(1026)
34 | self.predict_flow4 = predict_flow(770)
35 | self.predict_flow3 = predict_flow(386)
36 | self.predict_flow2 = predict_flow(194)
37 |
38 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
39 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
40 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
41 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
42 |
43 | for m in self.modules():
44 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
45 | kaiming_normal_(m.weight, 0.1)
46 | if m.bias is not None:
47 | constant_(m.bias, 0)
48 | elif isinstance(m, nn.BatchNorm2d):
49 | constant_(m.weight, 1)
50 | constant_(m.bias, 0)
51 |
52 | def forward(self, x):
53 | out_conv2 = self.conv2(self.conv1(x))
54 | out_conv3 = self.conv3_1(self.conv3(out_conv2))
55 | out_conv4 = self.conv4_1(self.conv4(out_conv3))
56 | out_conv5 = self.conv5_1(self.conv5(out_conv4))
57 | out_conv6 = self.conv6_1(self.conv6(out_conv5))
58 |
59 | flow6 = self.predict_flow6(out_conv6)
60 | flow6_up = crop_like(self.upsampled_flow6_to_5(flow6), out_conv5)
61 | out_deconv5 = crop_like(self.deconv5(out_conv6), out_conv5)
62 |
63 | concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
64 | flow5 = self.predict_flow5(concat5)
65 | flow5_up = crop_like(self.upsampled_flow5_to_4(flow5), out_conv4)
66 | out_deconv4 = crop_like(self.deconv4(concat5), out_conv4)
67 |
68 | concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
69 | flow4 = self.predict_flow4(concat4)
70 | flow4_up = crop_like(self.upsampled_flow4_to_3(flow4), out_conv3)
71 | out_deconv3 = crop_like(self.deconv3(concat4), out_conv3)
72 |
73 | concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
74 | flow3 = self.predict_flow3(concat3)
75 | flow3_up = crop_like(self.upsampled_flow3_to_2(flow3), out_conv2)
76 | out_deconv2 = crop_like(self.deconv2(concat3), out_conv2)
77 |
78 | concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1)
79 | flow2 = self.predict_flow2(concat2)
80 |
81 | if self.training:
82 | return flow2, flow3, flow4, flow5, flow6
83 | else:
84 | return flow2
85 |
86 | def weight_parameters(self):
87 | return [param for name, param in self.named_parameters() if "weight" in name]
88 |
89 | def bias_parameters(self):
90 | return [param for name, param in self.named_parameters() if "bias" in name]
91 |
92 |
93 | def flownets(data=None):
94 | """FlowNetS model architecture from the
95 |     "FlowNet: Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852)
96 |
97 | Args:
98 |         data : pretrained weights of the network. A new, untrained model is created if not set
99 | """
100 | model = FlowNetS(batchNorm=False)
101 | if data is not None:
102 | model.load_state_dict(data["state_dict"])
103 | return model
104 |
105 |
106 | def flownets_bn(data=None):
107 | """FlowNetS model architecture from the
108 |     "FlowNet: Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852)
109 |
110 | Args:
111 |         data : pretrained weights of the network. A new, untrained model is created if not set
112 | """
113 | model = FlowNetS(batchNorm=True)
114 | if data is not None:
115 | model.load_state_dict(data["state_dict"])
116 | return model
117 |
--------------------------------------------------------------------------------
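A minimal usage sketch for the factory above (not part of the repository; it assumes the repository root is on the Python path): it builds an untrained FlowNetS and contrasts the multi-scale outputs returned in training mode with the single finest map returned in eval mode.

import torch
import models

# Sketch: untrained FlowNetS; pass a loaded checkpoint dict (with a "state_dict"
# key) to models.flownets() to restore pretrained weights instead.
model = models.flownets()
pair = torch.randn(1, 6, 384, 512)   # two RGB frames stacked along channels

model.train()
flows = model(pair)                  # 5 flow maps, finest first: 96x128 down to 6x8
print([tuple(f.shape[-2:]) for f in flows])

model.eval()
with torch.no_grad():
    flow2 = model(pair)              # only the finest (1/4 resolution) map
print(flow2.shape)                   # torch.Size([1, 2, 96, 128])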
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .FlowNetS import *
2 | from .FlowNetC import *
3 |
--------------------------------------------------------------------------------
/models/util.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | try:
5 | from spatial_correlation_sampler import spatial_correlation_sample
6 | except ImportError as e:
7 | import warnings
8 |
9 | with warnings.catch_warnings():
10 | warnings.filterwarnings("default", category=ImportWarning)
11 | warnings.warn(
12 |             "failed to load custom correlation module, which is needed for FlowNetC",
13 | ImportWarning,
14 | )
15 |
16 |
17 | def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
18 | if batchNorm:
19 | return nn.Sequential(
20 | nn.Conv2d(
21 | in_planes,
22 | out_planes,
23 | kernel_size=kernel_size,
24 | stride=stride,
25 | padding=(kernel_size - 1) // 2,
26 | bias=False,
27 | ),
28 | nn.BatchNorm2d(out_planes),
29 | nn.LeakyReLU(0.1, inplace=True),
30 | )
31 | else:
32 | return nn.Sequential(
33 | nn.Conv2d(
34 | in_planes,
35 | out_planes,
36 | kernel_size=kernel_size,
37 | stride=stride,
38 | padding=(kernel_size - 1) // 2,
39 | bias=True,
40 | ),
41 | nn.LeakyReLU(0.1, inplace=True),
42 | )
43 |
44 |
45 | def predict_flow(in_planes):
46 | return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1, bias=False)
47 |
48 |
49 | def deconv(in_planes, out_planes):
50 | return nn.Sequential(
51 | nn.ConvTranspose2d(
52 | in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False
53 | ),
54 | nn.LeakyReLU(0.1, inplace=True),
55 | )
56 |
57 |
58 | def correlate(input1, input2):
59 | out_corr = spatial_correlation_sample(
60 | input1,
61 | input2,
62 | kernel_size=1,
63 | patch_size=21,
64 | stride=1,
65 | padding=0,
66 | dilation_patch=2,
67 | )
68 |     # collapse the patch dimensions (1 and 2) so the result can be treated
69 |     # as a regular 4D tensor
70 | b, ph, pw, h, w = out_corr.size()
71 | out_corr = out_corr.view(b, ph * pw, h, w) / input1.size(1)
72 | return F.leaky_relu_(out_corr, 0.1)
73 |
74 |
75 | def crop_like(input, target):
76 | if input.size()[2:] == target.size()[2:]:
77 | return input
78 | else:
79 | return input[:, :, : target.size(2), : target.size(3)]
80 |
--------------------------------------------------------------------------------
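A small shape sketch for the helpers above, assuming the repository root is importable: a strided conv halves the spatial resolution, deconv doubles it, predict_flow emits a 2-channel flow field, and crop_like trims a tensor to a reference size when odd dimensions leave a one-pixel mismatch.

import torch
from models.util import conv, deconv, predict_flow, crop_like

x = torch.randn(1, 64, 96, 128)
down = conv(False, 64, 128, stride=2)(x)   # kernel 3, stride 2, "same"-style padding
print(down.shape)                          # torch.Size([1, 128, 48, 64])
up = deconv(128, 64)(down)                 # transposed conv, kernel 4, stride 2
print(up.shape)                            # torch.Size([1, 64, 96, 128])
flow = predict_flow(128)(down)             # 2-channel flow at the same resolution
print(flow.shape)                          # torch.Size([1, 2, 48, 64])
print(crop_like(up, x).shape)              # unchanged here, since sizes already match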
/multiscaleloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def EPE(input_flow, target_flow, sparse=False, mean=True):
6 | EPE_map = torch.norm(target_flow - input_flow, 2, 1)
7 | batch_size = EPE_map.size(0)
8 | if sparse:
9 |         # invalid flow is defined as both flow coordinates being exactly 0
10 | mask = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0)
11 |
12 | EPE_map = EPE_map[~mask]
13 | if mean:
14 | return EPE_map.mean()
15 | else:
16 | return EPE_map.sum() / batch_size
17 |
18 |
19 | def sparse_max_pool(input, size):
20 | """Downsample the input by considering 0 values as invalid.
21 |
22 |     Unfortunately, no generic interpolation mode can resize a sparse map correctly.
23 |     The strategy here is to use max pooling for positive values and "min pooling"
24 |     for negative values; the two results are then summed.
25 |     This technique keeps sparsity to a minimum, contrary to nearest interpolation,
26 |     which could potentially lose information for isolated data points."""
27 |
28 | positive = (input > 0).float()
29 | negative = (input < 0).float()
30 | output = F.adaptive_max_pool2d(input * positive, size) - F.adaptive_max_pool2d(
31 | -input * negative, size
32 | )
33 | return output
34 |
35 |
36 | def multiscaleEPE(network_output, target_flow, weights=None, sparse=False):
37 | def one_scale(output, target, sparse):
38 |
39 | b, _, h, w = output.size()
40 |
41 | if sparse:
42 | target_scaled = sparse_max_pool(target, (h, w))
43 | else:
44 | target_scaled = F.interpolate(target, (h, w), mode="area")
45 | return EPE(output, target_scaled, sparse, mean=False)
46 |
47 | if type(network_output) not in [tuple, list]:
48 | network_output = [network_output]
49 | if weights is None:
50 | weights = [0.005, 0.01, 0.02, 0.08, 0.32] # as in original article
51 | assert len(weights) == len(network_output)
52 |
53 | loss = 0
54 | for output, weight in zip(network_output, weights):
55 | loss += weight * one_scale(output, target_flow, sparse)
56 | return loss
57 |
58 |
59 | def realEPE(output, target, sparse=False):
60 | b, _, h, w = target.size()
61 | upsampled_output = F.interpolate(
62 | output, (h, w), mode="bilinear", align_corners=False
63 | )
64 | return EPE(upsampled_output, target, sparse, mean=True)
65 |
--------------------------------------------------------------------------------
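A minimal sketch of how this loss is used, assuming the predictions are ordered finest first (as FlowNetS returns them in training mode) and the ground-truth flow is full resolution:

import torch
from multiscaleloss import multiscaleEPE, realEPE

target = torch.randn(4, 2, 384, 512)                      # (B, 2, H, W) ground truth
outputs = [torch.randn(4, 2, 384 // s, 512 // s) for s in (4, 8, 16, 32, 64)]

loss = multiscaleEPE(outputs, target)                     # default per-scale weights
epe = realEPE(outputs[0], target)                         # finest map, upsampled to H x W
print(loss.item(), epe.item())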
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.2
2 | torchvision
3 | numpy
4 | spatial-correlation-sampler>=0.2.1
5 | tensorboard
6 | imageio
7 | argparse
8 | path
9 | tqdm
10 | scipy
11 |
--------------------------------------------------------------------------------
/run_inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from path import Path
3 |
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 | import torch.nn.functional as F
7 | import models
8 | from tqdm import tqdm
9 |
10 | import torchvision.transforms as transforms
11 | import flow_transforms
12 | from imageio import imread, imwrite
13 | import numpy as np
14 | from util import flow2rgb
15 |
16 | model_names = sorted(
17 | name for name in models.__dict__ if name.islower() and not name.startswith("__")
18 | )
19 |
20 |
21 | parser = argparse.ArgumentParser(
22 | description="PyTorch FlowNet inference on a folder of img pairs",
23 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
24 | )
25 | parser.add_argument(
26 | "data",
27 | metavar="DIR",
28 |     help="path to images folder, image names must match '[name]1.[ext]' and '[name]2.[ext]'",
29 | )
30 | parser.add_argument("pretrained", metavar="PTH", help="path to pre-trained model")
31 | parser.add_argument(
32 | "--output",
33 | "-o",
34 | metavar="DIR",
35 | default=None,
36 | help="path to output folder. If not set, will be created in data folder",
37 | )
38 | parser.add_argument(
39 | "--output-value",
40 | "-v",
41 | choices=["raw", "vis", "both"],
42 | default="both",
43 |     help="which value to output, between raw flow (as a .npy file) and color visualization (as an image file)."
44 | " If not set, will output both",
45 | )
46 | parser.add_argument(
47 | "--div-flow",
48 | default=20,
49 | type=float,
50 |     help="value by which flow will be divided. Overwritten if stored in the pretrained file",
51 | )
52 | parser.add_argument(
53 | "--img-exts",
54 | metavar="EXT",
55 | default=["png", "jpg", "bmp", "ppm"],
56 | nargs="*",
57 | type=str,
58 |     help="image extensions to glob",
59 | )
60 | parser.add_argument(
61 | "--max_flow",
62 | default=None,
63 | type=float,
64 | help="max flow value. Flow map color is saturated above this value. If not set, will use flow map's max value",
65 | )
66 | parser.add_argument(
67 | "--upsampling",
68 | "-u",
69 | choices=["nearest", "bilinear"],
70 | default=None,
71 |     help="if not set, will output the raw FlowNet prediction, "
72 |     "which is 4 times downsampled. If set, will output a full resolution flow map with the selected upsampling",
73 | )
74 | parser.add_argument(
75 | "--bidirectional",
76 | action="store_true",
77 |     help="if set, will output inverted flow (from image 2 to image 1) along with regular flow",
78 | )
79 |
80 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
81 |
82 |
83 | @torch.no_grad()
84 | def main():
85 | global args, save_path
86 | args = parser.parse_args()
87 |
88 | if args.output_value == "both":
89 | output_string = "raw output and RGB visualization"
90 | elif args.output_value == "raw":
91 | output_string = "raw output"
92 | elif args.output_value == "vis":
93 | output_string = "RGB visualization"
94 | print("=> will save " + output_string)
95 | data_dir = Path(args.data)
96 | print("=> fetching img pairs in '{}'".format(args.data))
97 | if args.output is None:
98 | save_path = data_dir / "flow"
99 | else:
100 | save_path = Path(args.output)
101 | print("=> will save everything to {}".format(save_path))
102 | save_path.makedirs_p()
103 | # Data loading code
104 | input_transform = transforms.Compose(
105 | [
106 | flow_transforms.ArrayToTensor(),
107 | transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
108 | transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1]),
109 | ]
110 | )
111 |
112 | img_pairs = []
113 | for ext in args.img_exts:
114 | test_files = data_dir.files("*1.{}".format(ext))
115 | for file in test_files:
116 | img_pair = file.parent / (file.stem[:-1] + "2.{}".format(ext))
117 | if img_pair.isfile():
118 | img_pairs.append([file, img_pair])
119 |
120 | print("{} samples found".format(len(img_pairs)))
121 | # create model
122 |     network_data = torch.load(args.pretrained, map_location=device)
123 | print("=> using pre-trained model '{}'".format(network_data["arch"]))
124 | model = models.__dict__[network_data["arch"]](network_data).to(device)
125 | model.eval()
126 | cudnn.benchmark = True
127 |
128 | if "div_flow" in network_data.keys():
129 | args.div_flow = network_data["div_flow"]
130 |
131 | for img1_file, img2_file in tqdm(img_pairs):
132 |
133 | img1 = input_transform(imread(img1_file))
134 | img2 = input_transform(imread(img2_file))
135 | input_var = torch.cat([img1, img2]).unsqueeze(0)
136 |
137 | if args.bidirectional:
138 | # feed inverted pair along with normal pair
139 | inverted_input_var = torch.cat([img2, img1]).unsqueeze(0)
140 | input_var = torch.cat([input_var, inverted_input_var])
141 |
142 | input_var = input_var.to(device)
143 | # compute output
144 | output = model(input_var)
145 |         if args.upsampling is not None:
146 |             # align_corners is only valid for interpolating modes; pass None for nearest
147 |             align = False if args.upsampling == "bilinear" else None
148 |             output = F.interpolate(output, size=img1.size()[-2:], mode=args.upsampling, align_corners=align)
149 | for suffix, flow_output in zip(["flow", "inv_flow"], output):
150 | filename = save_path / "{}{}".format(img1_file.stem[:-1], suffix)
151 | if args.output_value in ["vis", "both"]:
152 | rgb_flow = flow2rgb(
153 | args.div_flow * flow_output, max_value=args.max_flow
154 | )
155 | to_save = (rgb_flow * 255).astype(np.uint8).transpose(1, 2, 0)
156 | imwrite(filename + ".png", to_save)
157 | if args.output_value in ["raw", "both"]:
158 | # Make the flow map a HxWx2 array as in .flo files
159 | to_save = (args.div_flow * flow_output).cpu().numpy().transpose(1, 2, 0)
160 | np.save(filename + ".npy", to_save)
161 |
162 |
163 | if __name__ == "__main__":
164 | main()
165 |
--------------------------------------------------------------------------------
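A hypothetical invocation (paths and checkpoint name are placeholders). The checkpoint is expected to be a dict containing at least "arch" and "state_dict", and optionally "div_flow":

python run_inference.py /data/pairs flownets_checkpoint.pth.tar \
    --output /data/pairs/flow --output-value both --upsampling bilinear --div-flow 20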
/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import shutil
4 | import torch
5 |
6 |
7 | def save_checkpoint(state, is_best, save_path, filename="checkpoint.pth.tar"):
8 | torch.save(state, os.path.join(save_path, filename))
9 | if is_best:
10 | shutil.copyfile(
11 | os.path.join(save_path, filename),
12 | os.path.join(save_path, "model_best.pth.tar"),
13 | )
14 |
15 |
16 | class AverageMeter(object):
17 | """Computes and stores the average and current value"""
18 |
19 | def __init__(self):
20 | self.reset()
21 |
22 | def reset(self):
23 | self.val = 0
24 | self.avg = 0
25 | self.sum = 0
26 | self.count = 0
27 |
28 | def update(self, val, n=1):
29 | self.val = val
30 | self.sum += val * n
31 | self.count += n
32 | self.avg = self.sum / self.count
33 |
34 | def __repr__(self):
35 | return "{:.3f} ({:.3f})".format(self.val, self.avg)
36 |
37 |
38 | def flow2rgb(flow_map, max_value):
39 | flow_map_np = flow_map.detach().cpu().numpy()
40 | _, h, w = flow_map_np.shape
41 | flow_map_np[:, (flow_map_np[0] == 0) & (flow_map_np[1] == 0)] = float("nan")
42 | rgb_map = np.ones((3, h, w)).astype(np.float32)
43 | if max_value is not None:
44 | normalized_flow_map = flow_map_np / max_value
45 | else:
46 | normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max())
47 | rgb_map[0] += normalized_flow_map[0]
48 | rgb_map[1] -= 0.5 * (normalized_flow_map[0] + normalized_flow_map[1])
49 | rgb_map[2] += normalized_flow_map[1]
50 | return rgb_map.clip(0, 1)
51 |
--------------------------------------------------------------------------------
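A short sketch of flow2rgb, which produces the color visualization saved by run_inference.py. Zero-flow pixels are flagged as NaN (treated as invalid) before normalization, so the example uses a non-zero flow field:

import numpy as np
import torch
from util import flow2rgb

flow = torch.ones(2, 64, 64)                             # uniform (1, 1) pixel motion
rgb = flow2rgb(flow, max_value=10)                       # float32 array, shape (3, H, W), values in [0, 1]
img = (rgb * 255).astype(np.uint8).transpose(1, 2, 0)    # HxWx3, ready for imwrite
print(img.shape, img.dtype)                              # (64, 64, 3) uint8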