├── .gitignore ├── LICENSE ├── README.md ├── datasets ├── KITTI.py ├── __init__.py ├── flyingchairs.py ├── listdataset.py ├── mpisintel.py └── util.py ├── flow_transforms.py ├── images ├── GT_1.png ├── GT_2.png ├── GT_3.png ├── input_1.gif ├── input_2.gif ├── input_3.gif ├── pred_1.png ├── pred_2.png └── pred_3.png ├── main.py ├── models ├── FlowNetC.py ├── FlowNetS.py ├── __init__.py └── util.py ├── multiscaleloss.py ├── requirements.txt ├── run_inference.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | */_pycache__ 2 | *.pyc 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Clément Pinard 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlowNetPytorch 2 | Pytorch implementation of FlowNet by Dosovitskiy et al. 3 | 4 | This repository is a torch implementation of [FlowNet](http://lmb.informatik.uni-freiburg.de/Publications/2015/DFIB15/), by [Alexey Dosovitskiy](http://lmb.informatik.uni-freiburg.de/people/dosovits/) et al. in PyTorch. See Torch implementation [here](https://github.com/ClementPinard/FlowNetTorch) 5 | 6 | This code is mainly inspired from official [imagenet example](https://github.com/pytorch/examples/tree/master/imagenet). 7 | It has not been tested for multiple GPU, but it should work just as in original code. 8 | 9 | The code provides a training example, using [the flying chair dataset](http://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html) , with data augmentation. An implementation for [Scene Flow Datasets](http://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) may be added in the future. 10 | 11 | Two neural network models are currently provided, along with their batch norm variation (experimental) : 12 | 13 | - **FlowNetS** 14 | - **FlowNetSBN** 15 | - **FlowNetC** 16 | - **FlowNetCBN** 17 | 18 | ## Pretrained Models 19 | Thanks to [Kaixhin](https://github.com/Kaixhin) you can download a pretrained version of FlowNetS (from caffe, not from pytorch) [here](https://drive.google.com/drive/folders/1dTpSyc7rIYYG19p1uiDfilcsmSPNy-_3?usp=sharing). 
This folder also contains networks trained from scratch. 20 | 21 | ### Note on network loading 22 | Feed the downloaded file directly to the script; you don't need to uncompress it, even if your desktop environment suggests otherwise. 23 | 24 | ### Note on networks from caffe 25 | These networks expect a BGR input (as opposed to the RGB order used in pytorch). In practice, however, the channel order has little effect on results. 26 | 27 | ## Prerequisites 28 | These modules can be installed with `pip` : 29 | 30 | ``` 31 | pytorch >= 1.2 32 | tensorboard-pytorch 33 | tensorboardX >= 1.4 34 | spatial-correlation-sampler>=0.2.1 35 | imageio 36 | argparse 37 | path.py 38 | ``` 39 | 40 | or 41 | ```bash 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | ## Training on the Flying Chairs Dataset 46 | 47 | First, you need to download [the flying chairs dataset](http://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html). It is ~64GB and we recommend you put it on an SSD drive. 48 | 49 | Default hyperparameters provided in `main.py` are the same as in the caffe training scripts. 50 | 51 | * Example usage for FlowNetS : 52 | 53 | ```bash 54 | python main.py /path/to/flying_chairs/ -b8 -j8 -a flownets 55 | ``` 56 | 57 | We recommend setting `-j` (the number of data loading workers) to a high value when using data augmentation, so that data loading does not slow down training. 58 | 59 | For further help you can type 60 | 61 | ```bash 62 | python main.py -h 63 | ``` 64 | 65 | ## Visualizing training 66 | [Tensorboard-pytorch](https://github.com/lanpa/tensorboard-pytorch) is used for logging. To visualize results, simply type 67 | 68 | ```bash 69 | tensorboard --logdir=/path/to/checkpoints 70 | ``` 71 | 72 | ## Training results 73 | 74 | Models can be downloaded [here](https://drive.google.com/drive/folders/1dTpSyc7rIYYG19p1uiDfilcsmSPNy-_3?usp=sharing) in the pytorch folder. 75 | 76 | Models were trained with default options unless specified. Color warping was not used. 77 | 78 | | Arch | learning rate | batch size | epoch size | filename | validation EPE | 79 | | ----------- | ------------- | ---------- | ---------- | ---------------------------- | -------------- | 80 | | FlowNetS | 1e-4 | 8 | 2700 | flownets_EPE1.951.pth.tar | 1.951 | 81 | | FlowNetS BN | 1e-3 | 32 | 695 | flownets_bn_EPE2.459.pth.tar | 2.459 | 82 | | FlowNetC | 1e-4 | 8 | 2700 | flownetc_EPE1.766.pth.tar | 1.766 | 83 | 84 | *Note* : FlowNetS BN took longer to train and got worse results. It is strongly advised not to use it on the Flying Chairs dataset. 85 | 86 | ## Validation samples 87 | 88 | Predictions are made by FlowNetS. 89 | 90 | Exact code for the Optical Flow -> Color map conversion can be found [here](main.py#L321) 91 | 92 | | Input | prediction | GroundTruth | 93 | |-------|------------|-------------| 94 | | ![input 1](images/input_1.gif) | ![pred 1](images/pred_1.png) | ![GT 1](images/GT_1.png) | 95 | | ![input 2](images/input_2.gif) | ![pred 2](images/pred_2.png) | ![GT 2](images/GT_2.png) | 96 | | ![input 3](images/input_3.gif) | ![pred 3](images/pred_3.png) | ![GT 3](images/GT_3.png) | 97 | 98 | ## Running inference on a set of image pairs 99 | 100 | If you need to run the network on your images, you can download a pretrained network [here](https://drive.google.com/drive/folders/1dTpSyc7rIYYG19p1uiDfilcsmSPNy-_3?usp=sharing) and launch the inference script on your folder of image pairs. 101 | 102 | Your folder needs to have all the image pairs in the same location, with the name pattern 103 | ``` 104 | {image_name}1.{ext} 105 | {image_name}2.{ext} 106 | ``` 107 | 108 | ```bash 109 | python3 run_inference.py /path/to/images/folder /path/to/pretrained 110 | ``` 111 | 112 | As for the `main.py` script, a help menu is available for additional options.
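For reference, here is a minimal sketch (not the official script) of what such an inference run boils down to: load a checkpoint saved by `main.py` (which stores `arch`, `state_dict` and `div_flow`), normalize the image pair the same way as the training `input_transform`, and upsample the quarter-resolution prediction back to image size. The image and checkpoint file names are placeholders, image dimensions are assumed to be multiples of 64, and `run_inference.py` itself adds options for output format and upsampling mode:

```python
import numpy as np
import torch
import torch.nn.functional as F
from imageio import imread

import models
from util import flow2rgb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# checkpoint saved by main.py: a dict with 'arch', 'state_dict', 'div_flow', ...
network_data = torch.load("flownets_EPE1.951.pth.tar", map_location=device)
model = models.__dict__[network_data["arch"]](network_data).to(device).eval()
div_flow = network_data.get("div_flow", 20)


def preprocess(path):
    # same normalization as the training input_transform in main.py
    img = imread(path).astype(np.float32)
    img = torch.from_numpy(img.transpose(2, 0, 1)) / 255
    return img - torch.tensor([0.45, 0.432, 0.411]).view(3, 1, 1)


img1, img2 = preprocess("scene1.png"), preprocess("scene2.png")
inputs = torch.cat([img1, img2]).unsqueeze(0).to(device)

with torch.no_grad():
    # in eval mode the network only returns the finest (quarter-resolution) flow
    flow = model(inputs)
    flow = div_flow * F.interpolate(
        flow, size=img1.shape[-2:], mode="bilinear", align_corners=False
    )

# color visualization, using the same helper as main.py and run_inference.py
rgb_flow = flow2rgb(flow[0].cpu(), max_value=10)
```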
113 | 114 | ## Note on transform functions 115 | 116 | In order to have coherent transformations between inputs and target, we must define new transformations that take both input and target together, since a new random parameter is drawn each time a random transformation is called. 117 | 118 | ### Flow Transformations 119 | 120 | To allow data augmentation, we have considered rotations and translations for the inputs and their effect on the target flow map. 121 | Here is a set of things to take care of in order to achieve a proper data augmentation: 122 | 123 | #### The flow map is directly linked to img1 124 | If you apply a transformation to img1, you have to apply the very same transformation to the flow map, to keep coherent origin points for the flow. 125 | 126 | #### Translation between img1 and img2 127 | Given a translation `(tx,ty)` applied to img2, we will have 128 | ``` 129 | flow[:,:,0] += tx 130 | flow[:,:,1] += ty 131 | ``` 132 | 133 | #### Scale 134 | A scale applied to both img1 and img2 with a zoom parameter `alpha` multiplies the flow by the same amount 135 | ``` 136 | flow *= alpha 137 | ``` 138 | 139 | #### Rotation applied on both images 140 | A rotation applied to both images by an angle `theta` also rotates the flow vectors (`flow[i,j]`) by the same angle 141 | ``` 142 | for all i,j: flow[i,j] = rotate(flow[i,j], theta) 143 | 144 | rotate: (x, y), theta -> (x*cos(theta) - y*sin(theta), x*sin(theta) + y*cos(theta)) 145 | ``` 146 | 147 | #### Rotation applied on img2 148 | Let us consider a rotation by the angle `theta` around the image center. 149 | 150 | We must transform each flow vector based on the coordinates where it lands. At each coordinate `(i, j)`, we have: 151 | ``` 152 | flow[i, j, 0] += (cos(theta) - 1) * (j - w/2 + flow[i, j, 0]) + sin(theta) * (i - h/2 + flow[i, j, 1]) 153 | flow[i, j, 1] += -sin(theta) * (j - w/2 + flow[i, j, 0]) + (cos(theta) - 1) * (i - h/2 + flow[i, j, 1]) 154 | ``` 155 | -------------------------------------------------------------------------------- /datasets/KITTI.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os.path 3 | import glob 4 | from .listdataset import ListDataset 5 | from .util import split2list 6 | import numpy as np 7 | import flow_transforms 8 | 9 | try: 10 | import cv2 11 | except ImportError: 12 | import warnings 13 | 14 | with warnings.catch_warnings(): 15 | warnings.filterwarnings("default", category=ImportWarning) 16 | warnings.warn( 17 | "failed to load openCV, which is needed " 18 | "for KITTI, which uses 16bit PNG images", 19 | ImportWarning, 20 | ) 21 | 22 | """ 23 | Dataset routines for KITTI_flow, 2012 and 2015. 24 | http://www.cvlibs.net/datasets/kitti/eval_flow.php 25 | The dataset is not very big, you might want to only finetune on it for flownet 26 | EPE values are not representative on this dataset because of the sparsity of the GT.
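Flow maps are stored as 16-bit PNGs following the KITTI devkit convention: u and v are encoded as (value - 2**15) / 64 in two channels, and a third channel flags valid pixels (0 means no ground truth), which is what load_flow_from_png below decodes.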
27 | OpenCV is needed to load 16bit png images 28 | """ 29 | 30 | 31 | def load_flow_from_png(png_path): 32 | # The -1 is here to specify not to change the image depth (16bit), and is compatible 33 | # with both OpenCV2 and OpenCV3 34 | flo_file = cv2.imread(png_path, -1) 35 | flo_img = flo_file[:, :, 2:0:-1].astype(np.float32) 36 | invalid = flo_file[:, :, 0] == 0 37 | flo_img = flo_img - 32768 38 | flo_img = flo_img / 64 39 | flo_img[np.abs(flo_img) < 1e-10] = 1e-10 40 | flo_img[invalid, :] = 0 41 | return flo_img 42 | 43 | 44 | def make_dataset(dir, split, split_save_path, occ=True): 45 | """Will search in training folder for folders 'flow_noc' or 'flow_occ' 46 | and 'colored_0' (KITTI 2012) or 'image_2' (KITTI 2015)""" 47 | flow_dir = "flow_occ" if occ else "flow_noc" 48 | assert os.path.isdir(os.path.join(dir, flow_dir)) 49 | img_dir = "colored_0" 50 | if not os.path.isdir(os.path.join(dir, img_dir)): 51 | img_dir = "image_2" 52 | assert os.path.isdir(os.path.join(dir, img_dir)) 53 | 54 | images = [] 55 | for flow_map in glob.iglob(os.path.join(dir, flow_dir, "*.png")): 56 | flow_map = os.path.basename(flow_map) 57 | root_filename = flow_map[:-7] 58 | flow_map = os.path.join(flow_dir, flow_map) 59 | img1 = os.path.join(img_dir, root_filename + "_10.png") 60 | img2 = os.path.join(img_dir, root_filename + "_11.png") 61 | if not ( 62 | os.path.isfile(os.path.join(dir, img1)) 63 | or os.path.isfile(os.path.join(dir, img2)) 64 | ): 65 | continue 66 | images.append([[img1, img2], flow_map]) 67 | 68 | return split2list(images, split, split_save_path, default_split=0.9) 69 | 70 | 71 | def KITTI_loader(root, path_imgs, path_flo): 72 | imgs = [os.path.join(root, path) for path in path_imgs] 73 | flo = os.path.join(root, path_flo) 74 | return [ 75 | cv2.imread(img)[:, :, ::-1].astype(np.float32) for img in imgs 76 | ], load_flow_from_png(flo) 77 | 78 | 79 | def KITTI_occ( 80 | root, 81 | transform=None, 82 | target_transform=None, 83 | co_transform=None, 84 | split=None, 85 | split_save_path=None, 86 | ): 87 | train_list, test_list = make_dataset(root, split, split_save_path, True) 88 | train_dataset = ListDataset( 89 | root, train_list, transform, target_transform, co_transform, loader=KITTI_loader 90 | ) 91 | # All test sample are cropped to lowest possible size of KITTI images 92 | test_dataset = ListDataset( 93 | root, 94 | test_list, 95 | transform, 96 | target_transform, 97 | flow_transforms.CenterCrop((370, 1224)), 98 | loader=KITTI_loader, 99 | ) 100 | 101 | return train_dataset, test_dataset 102 | 103 | 104 | def KITTI_noc( 105 | root, 106 | transform=None, 107 | target_transform=None, 108 | co_transform=None, 109 | split=None, 110 | split_save_path=None, 111 | ): 112 | train_list, test_list = make_dataset(root, split, split_save_path, False) 113 | train_dataset = ListDataset( 114 | root, train_list, transform, target_transform, co_transform, loader=KITTI_loader 115 | ) 116 | # All test sample are cropped to lowest possible size of KITTI images 117 | test_dataset = ListDataset( 118 | root, 119 | test_list, 120 | transform, 121 | target_transform, 122 | flow_transforms.CenterCrop((370, 1224)), 123 | loader=KITTI_loader, 124 | ) 125 | 126 | return train_dataset, test_dataset 127 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .flyingchairs import flying_chairs 2 | from .KITTI import KITTI_occ, KITTI_noc 3 | from .mpisintel import 
mpi_sintel_clean, mpi_sintel_final, mpi_sintel_both 4 | 5 | __all__ = ( 6 | "flying_chairs", 7 | "KITTI_occ", 8 | "KITTI_noc", 9 | "mpi_sintel_clean", 10 | "mpi_sintel_final", 11 | "mpi_sintel_both", 12 | ) 13 | -------------------------------------------------------------------------------- /datasets/flyingchairs.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import glob 3 | from .listdataset import ListDataset 4 | from .util import split2list 5 | 6 | 7 | def make_dataset(dir, split=None, split_save_path=None): 8 | """Will search for triplets that go by the pattern '[name]_img1.ppm [name]_img2.ppm [name]_flow.flo'""" 9 | images = [] 10 | for flow_map in sorted(glob.glob(os.path.join(dir, "*_flow.flo"))): 11 | flow_map = os.path.basename(flow_map) 12 | root_filename = flow_map[:-9] 13 | img1 = root_filename + "_img1.ppm" 14 | img2 = root_filename + "_img2.ppm" 15 | if not ( 16 | os.path.isfile(os.path.join(dir, img1)) 17 | and os.path.isfile(os.path.join(dir, img2)) 18 | ): 19 | continue 20 | 21 | images.append([[img1, img2], flow_map]) 22 | return split2list(images, split, split_save_path, default_split=0.97) 23 | 24 | 25 | def flying_chairs( 26 | root, 27 | transform=None, 28 | target_transform=None, 29 | co_transform=None, 30 | split=None, 31 | split_save_path=None, 32 | ): 33 | train_list, test_list = make_dataset(root, split, split_save_path) 34 | train_dataset = ListDataset( 35 | root, train_list, transform, target_transform, co_transform 36 | ) 37 | test_dataset = ListDataset(root, test_list, transform, target_transform) 38 | 39 | return train_dataset, test_dataset 40 | -------------------------------------------------------------------------------- /datasets/listdataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import os 3 | import os.path 4 | from imageio import imread 5 | import numpy as np 6 | 7 | 8 | def load_flo(path): 9 | with open(path, "rb") as f: 10 | magic = np.fromfile(f, np.float32, count=1) 11 | assert 202021.25 == magic, "Magic number incorrect. 
Invalid .flo file" 12 | h = np.fromfile(f, np.int32, count=1)[0] 13 | w = np.fromfile(f, np.int32, count=1)[0] 14 | data = np.fromfile(f, np.float32, count=2 * w * h) 15 | # Reshape data into 3D array (columns, rows, bands) 16 | data2D = np.resize(data, (w, h, 2)) 17 | return data2D 18 | 19 | 20 | def default_loader(root, path_imgs, path_flo): 21 | imgs = [os.path.join(root, path) for path in path_imgs] 22 | flo = os.path.join(root, path_flo) 23 | return [imread(img).astype(np.float32) for img in imgs], load_flo(flo) 24 | 25 | 26 | class ListDataset(data.Dataset): 27 | def __init__( 28 | self, 29 | root, 30 | path_list, 31 | transform=None, 32 | target_transform=None, 33 | co_transform=None, 34 | loader=default_loader, 35 | ): 36 | 37 | self.root = root 38 | self.path_list = path_list 39 | self.transform = transform 40 | self.target_transform = target_transform 41 | self.co_transform = co_transform 42 | self.loader = loader 43 | 44 | def __getitem__(self, index): 45 | inputs, target = self.path_list[index] 46 | 47 | inputs, target = self.loader(self.root, inputs, target) 48 | if self.co_transform is not None: 49 | inputs, target = self.co_transform(inputs, target) 50 | if self.transform is not None: 51 | inputs[0] = self.transform(inputs[0]) 52 | inputs[1] = self.transform(inputs[1]) 53 | if self.target_transform is not None: 54 | target = self.target_transform(target) 55 | return inputs, target 56 | 57 | def __len__(self): 58 | return len(self.path_list) 59 | -------------------------------------------------------------------------------- /datasets/mpisintel.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import glob 3 | from .listdataset import ListDataset 4 | from .util import split2list 5 | import flow_transforms 6 | 7 | """ 8 | Dataset routines for MPI Sintel. 
9 | http://sintel.is.tue.mpg.de/ 10 | clean version imgs are without shaders, final version imgs are fully rendered 11 | The dataset is not very big, you might want to only pretrain on it for flownet 12 | """ 13 | 14 | 15 | def make_dataset(dataset_dir, split, split_save_path, dataset_type="clean"): 16 | flow_dir = "flow" 17 | assert os.path.isdir(os.path.join(dataset_dir, flow_dir)) 18 | img_dir = dataset_type 19 | assert os.path.isdir(os.path.join(dataset_dir, img_dir)) 20 | 21 | images = [] 22 | for flow_map in sorted( 23 | glob.glob(os.path.join(dataset_dir, flow_dir, "*", "*.flo")) 24 | ): 25 | flow_map = os.path.relpath(flow_map, os.path.join(dataset_dir, flow_dir)) 26 | 27 | scene_dir, filename = os.path.split(flow_map) 28 | no_ext_filename = os.path.splitext(filename)[0] 29 | prefix, frame_nb = no_ext_filename.split("_") 30 | frame_nb = int(frame_nb) 31 | img1 = os.path.join( 32 | img_dir, scene_dir, "{}_{:04d}.png".format(prefix, frame_nb) 33 | ) 34 | img2 = os.path.join( 35 | img_dir, scene_dir, "{}_{:04d}.png".format(prefix, frame_nb + 1) 36 | ) 37 | flow_map = os.path.join(flow_dir, flow_map) 38 | if not ( 39 | os.path.isfile(os.path.join(dataset_dir, img1)) 40 | and os.path.isfile(os.path.join(dataset_dir, img2)) 41 | ): 42 | continue 43 | images.append([[img1, img2], flow_map]) 44 | 45 | return split2list(images, split, split_save_path, default_split=0.87) 46 | 47 | 48 | def mpi_sintel_clean( 49 | root, 50 | transform=None, 51 | target_transform=None, 52 | co_transform=None, 53 | split=None, 54 | split_save_path=None, 55 | ): 56 | train_list, test_list = make_dataset(root, split, split_save_path, "clean") 57 | train_dataset = ListDataset( 58 | root, train_list, transform, target_transform, co_transform 59 | ) 60 | test_dataset = ListDataset( 61 | root, 62 | test_list, 63 | transform, 64 | target_transform, 65 | flow_transforms.CenterCrop((384, 1024)), 66 | ) 67 | 68 | return train_dataset, test_dataset 69 | 70 | 71 | def mpi_sintel_final( 72 | root, 73 | transform=None, 74 | target_transform=None, 75 | co_transform=None, 76 | split=None, 77 | split_save_path=None, 78 | ): 79 | train_list, test_list = make_dataset(root, split, split_save_path, "final") 80 | train_dataset = ListDataset( 81 | root, train_list, transform, target_transform, co_transform 82 | ) 83 | test_dataset = ListDataset( 84 | root, 85 | test_list, 86 | transform, 87 | target_transform, 88 | flow_transforms.CenterCrop((384, 1024)), 89 | ) 90 | 91 | return train_dataset, test_dataset 92 | 93 | 94 | def mpi_sintel_both( 95 | root, 96 | transform=None, 97 | target_transform=None, 98 | co_transform=None, 99 | split=None, 100 | split_save_path=None, 101 | ): 102 | """load images from both clean and final folders. 103 | We cannot shuffle input, because it would very likely cause data snooping 104 | for the clean and final frames are not that different""" 105 | assert isinstance( 106 | split, str 107 | ), "To avoid data snooping, you must provide a static list of train/val when dealing with both clean and final." 
108 | " Look at Sintel_train_val.txt for an example" 109 | train_list1, test_list1 = make_dataset(root, split, split_save_path, "clean") 110 | train_list2, test_list2 = make_dataset(root, split, split_save_path, "final") 111 | train_dataset = ListDataset( 112 | root, train_list1 + train_list2, transform, target_transform, co_transform 113 | ) 114 | test_dataset = ListDataset( 115 | root, 116 | test_list1 + test_list2, 117 | transform, 118 | target_transform, 119 | flow_transforms.CenterCrop((384, 1024)), 120 | ) 121 | 122 | return train_dataset, test_dataset 123 | -------------------------------------------------------------------------------- /datasets/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def split2list(images, split, split_save_path, default_split=0.9): 5 | if isinstance(split, str): 6 | with open(split) as f: 7 | split_values = [x.strip() == "1" for x in f.readlines()] 8 | assert len(images) == len(split_values) 9 | elif split is None: 10 | split_values = np.random.uniform(0, 1, len(images)) < default_split 11 | else: 12 | try: 13 | split = float(split) 14 | except TypeError: 15 | print("Invalid Split value, it must be either a filepath or a float") 16 | raise 17 | split_values = np.random.uniform(0, 1, len(images)) < split 18 | if split_save_path is not None: 19 | with open(split_save_path, "w") as f: 20 | f.write("\n".join(map(lambda x: str(int(x)), split_values))) 21 | train_samples = [sample for sample, split in zip(images, split_values) if split] 22 | test_samples = [sample for sample, split in zip(images, split_values) if not split] 23 | return train_samples, test_samples 24 | -------------------------------------------------------------------------------- /flow_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import random 4 | import numpy as np 5 | import numbers 6 | import types 7 | import scipy.ndimage as ndimage 8 | 9 | """Set of tranform random routines that takes both input and target as arguments, 10 | in order to have random but coherent transformations. 11 | inputs are PIL Image pairs and targets are ndarrays""" 12 | 13 | 14 | class Compose(object): 15 | """Composes several co_transforms together. 16 | For example: 17 | >>> co_transforms.Compose([ 18 | >>> co_transforms.CenterCrop(10), 19 | >>> co_transforms.ToTensor(), 20 | >>> ]) 21 | """ 22 | 23 | def __init__(self, co_transforms): 24 | self.co_transforms = co_transforms 25 | 26 | def __call__(self, input, target): 27 | for t in self.co_transforms: 28 | input, target = t(input, target) 29 | return input, target 30 | 31 | 32 | class ArrayToTensor(object): 33 | """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).""" 34 | 35 | def __call__(self, array): 36 | assert isinstance(array, np.ndarray) 37 | array = np.transpose(array, (2, 0, 1)) 38 | # handle numpy array 39 | tensor = torch.from_numpy(array) 40 | # put it from HWC to CHW format 41 | return tensor.float() 42 | 43 | 44 | class Lambda(object): 45 | """Applies a lambda as a transform""" 46 | 47 | def __init__(self, lambd): 48 | assert isinstance(lambd, types.LambdaType) 49 | self.lambd = lambd 50 | 51 | def __call__(self, input, target): 52 | return self.lambd(input, target) 53 | 54 | 55 | class CenterCrop(object): 56 | """Crops the given inputs and target arrays at the center to have a region of 57 | the given size. 
size can be a tuple (target_height, target_width) 58 | or an integer, in which case the target will be of a square shape (size, size) 59 | Careful, img1 and img2 may not be the same size 60 | """ 61 | 62 | def __init__(self, size): 63 | if isinstance(size, numbers.Number): 64 | self.size = (int(size), int(size)) 65 | else: 66 | self.size = size 67 | 68 | def __call__(self, inputs, target): 69 | h1, w1, _ = inputs[0].shape 70 | h2, w2, _ = inputs[1].shape 71 | th, tw = self.size 72 | x1 = int(round((w1 - tw) / 2.0)) 73 | y1 = int(round((h1 - th) / 2.0)) 74 | x2 = int(round((w2 - tw) / 2.0)) 75 | y2 = int(round((h2 - th) / 2.0)) 76 | 77 | inputs[0] = inputs[0][y1 : y1 + th, x1 : x1 + tw] 78 | inputs[1] = inputs[1][y2 : y2 + th, x2 : x2 + tw] 79 | target = target[y1 : y1 + th, x1 : x1 + tw] 80 | return inputs, target 81 | 82 | 83 | class Scale(object): 84 | """Rescales the inputs and target arrays to the given 'size'. 85 | 'size' will be the size of the smaller edge. 86 | For example, if height > width, then image will be 87 | rescaled to (size * height / width, size) 88 | size: size of the smaller edge 89 | interpolation order: Default: 2 (bilinear) 90 | """ 91 | 92 | def __init__(self, size, order=2): 93 | self.size = size 94 | self.order = order 95 | 96 | def __call__(self, inputs, target): 97 | h, w, _ = inputs[0].shape 98 | if (w <= h and w == self.size) or (h <= w and h == self.size): 99 | return inputs, target 100 | if w < h: 101 | ratio = self.size / w 102 | else: 103 | ratio = self.size / h 104 | 105 | inputs[0] = ndimage.interpolation.zoom(inputs[0], ratio, order=self.order) 106 | inputs[1] = ndimage.interpolation.zoom(inputs[1], ratio, order=self.order) 107 | 108 | target = ndimage.interpolation.zoom(target, ratio, order=self.order) 109 | target *= ratio 110 | return inputs, target 111 | 112 | 113 | class RandomCrop(object): 114 | """Crops the given PIL.Image at a random location to have a region of 115 | the given size. 
size can be a tuple (target_height, target_width) 116 | or an integer, in which case the target will be of a square shape (size, size) 117 | """ 118 | 119 | def __init__(self, size): 120 | if isinstance(size, numbers.Number): 121 | self.size = (int(size), int(size)) 122 | else: 123 | self.size = size 124 | 125 | def __call__(self, inputs, target): 126 | h, w, _ = inputs[0].shape 127 | th, tw = self.size 128 | if w == tw and h == th: 129 | return inputs, target 130 | 131 | x1 = random.randint(0, w - tw) 132 | y1 = random.randint(0, h - th) 133 | inputs[0] = inputs[0][y1 : y1 + th, x1 : x1 + tw] 134 | inputs[1] = inputs[1][y1 : y1 + th, x1 : x1 + tw] 135 | return inputs, target[y1 : y1 + th, x1 : x1 + tw] 136 | 137 | 138 | class RandomHorizontalFlip(object): 139 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5""" 140 | 141 | def __call__(self, inputs, target): 142 | if random.random() < 0.5: 143 | inputs[0] = np.copy(np.fliplr(inputs[0])) 144 | inputs[1] = np.copy(np.fliplr(inputs[1])) 145 | target = np.copy(np.fliplr(target)) 146 | target[:, :, 0] *= -1 147 | return inputs, target 148 | 149 | 150 | class RandomVerticalFlip(object): 151 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5""" 152 | 153 | def __call__(self, inputs, target): 154 | if random.random() < 0.5: 155 | inputs[0] = np.copy(np.flipud(inputs[0])) 156 | inputs[1] = np.copy(np.flipud(inputs[1])) 157 | target = np.copy(np.flipud(target)) 158 | target[:, :, 1] *= -1 159 | return inputs, target 160 | 161 | 162 | class RandomRotate(object): 163 | """Random rotation of the image from -angle to angle (in degrees) 164 | This is useful for dataAugmentation, especially for geometric problems such as FlowEstimation 165 | angle: max angle of the rotation 166 | interpolation order: Default: 2 (bilinear) 167 | reshape: Default: false. If set to true, image size will be set to keep every pixel in the image. 168 | diff_angle: Default: 0. 169 | """ 170 | 171 | def __init__(self, angle, diff_angle=0, order=2, reshape=False): 172 | self.angle = angle 173 | self.reshape = reshape 174 | self.order = order 175 | self.diff_angle = diff_angle 176 | 177 | def __call__(self, inputs, target): 178 | applied_angle = random.uniform(-self.angle, self.angle) 179 | diff = random.uniform(-self.diff_angle, self.diff_angle) 180 | angle1 = applied_angle - diff / 2 181 | angle2 = applied_angle + diff / 2 182 | angle1_rad = angle1 * np.pi / 180 183 | diff_rad = diff * np.pi / 180 184 | 185 | h, w, _ = target.shape 186 | 187 | warped_coords = np.mgrid[:w, :h].T + target 188 | warped_coords -= np.array([w / 2, h / 2]) 189 | 190 | warped_coords_rot = np.zeros_like(target) 191 | 192 | warped_coords_rot[..., 0] = (np.cos(diff_rad) - 1) * warped_coords[ 193 | ..., 0 194 | ] + np.sin(diff_rad) * warped_coords[..., 1] 195 | 196 | warped_coords_rot[..., 1] = ( 197 | -np.sin(diff_rad) * warped_coords[..., 0] 198 | + (np.cos(diff_rad) - 1) * warped_coords[..., 1] 199 | ) 200 | 201 | target += warped_coords_rot 202 | 203 | inputs[0] = ndimage.interpolation.rotate( 204 | inputs[0], angle1, reshape=self.reshape, order=self.order 205 | ) 206 | inputs[1] = ndimage.interpolation.rotate( 207 | inputs[1], angle2, reshape=self.reshape, order=self.order 208 | ) 209 | target = ndimage.interpolation.rotate( 210 | target, angle1, reshape=self.reshape, order=self.order 211 | ) 212 | # flow vectors must be rotated too! 
careful about Y flow which is upside down 213 | target_ = np.copy(target) 214 | target[:, :, 0] = ( 215 | np.cos(angle1_rad) * target_[:, :, 0] 216 | + np.sin(angle1_rad) * target_[:, :, 1] 217 | ) 218 | target[:, :, 1] = ( 219 | -np.sin(angle1_rad) * target_[:, :, 0] 220 | + np.cos(angle1_rad) * target_[:, :, 1] 221 | ) 222 | return inputs, target 223 | 224 | 225 | class RandomTranslate(object): 226 | def __init__(self, translation): 227 | if isinstance(translation, numbers.Number): 228 | self.translation = (int(translation), int(translation)) 229 | else: 230 | self.translation = translation 231 | 232 | def __call__(self, inputs, target): 233 | h, w, _ = inputs[0].shape 234 | th, tw = self.translation 235 | tw = random.randint(-tw, tw) 236 | th = random.randint(-th, th) 237 | if tw == 0 and th == 0: 238 | return inputs, target 239 | # compute x1,x2,y1,y2 for img1 and target, and x3,x4,y3,y4 for img2 240 | x1, x2, x3, x4 = max(0, tw), min(w + tw, w), max(0, -tw), min(w - tw, w) 241 | y1, y2, y3, y4 = max(0, th), min(h + th, h), max(0, -th), min(h - th, h) 242 | 243 | inputs[0] = inputs[0][y1:y2, x1:x2] 244 | inputs[1] = inputs[1][y3:y4, x3:x4] 245 | target = target[y1:y2, x1:x2] 246 | target[:, :, 0] += tw 247 | target[:, :, 1] += th 248 | 249 | return inputs, target 250 | 251 | 252 | class RandomColorWarp(object): 253 | def __init__(self, mean_range=0, std_range=0): 254 | self.mean_range = mean_range 255 | self.std_range = std_range 256 | 257 | def __call__(self, inputs, target): 258 | random_std = np.random.uniform(-self.std_range, self.std_range, 3) 259 | random_mean = np.random.uniform(-self.mean_range, self.mean_range, 3) 260 | random_order = np.random.permutation(3) 261 | 262 | inputs[0] *= 1 + random_std 263 | inputs[0] += random_mean 264 | 265 | inputs[1] *= 1 + random_std 266 | inputs[1] += random_mean 267 | 268 | inputs[0] = inputs[0][:, :, random_order] 269 | inputs[1] = inputs[1][:, :, random_order] 270 | 271 | return inputs, target 272 | -------------------------------------------------------------------------------- /images/GT_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/GT_1.png -------------------------------------------------------------------------------- /images/GT_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/GT_2.png -------------------------------------------------------------------------------- /images/GT_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/GT_3.png -------------------------------------------------------------------------------- /images/input_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/input_1.gif -------------------------------------------------------------------------------- /images/input_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/input_2.gif 
-------------------------------------------------------------------------------- /images/input_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/input_3.gif -------------------------------------------------------------------------------- /images/pred_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/pred_1.png -------------------------------------------------------------------------------- /images/pred_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/pred_2.png -------------------------------------------------------------------------------- /images/pred_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ClementPinard/FlowNetPytorch/990dba7e37c2374df96698f34f011e64d2d1fff0/images/pred_3.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.utils.data 11 | import torchvision.transforms as transforms 12 | import flow_transforms 13 | import models 14 | import datasets 15 | from multiscaleloss import multiscaleEPE, realEPE 16 | import datetime 17 | from torch.utils.tensorboard import SummaryWriter 18 | from util import flow2rgb, AverageMeter, save_checkpoint 19 | import numpy as np 20 | 21 | model_names = sorted( 22 | name for name in models.__dict__ if name.islower() and not name.startswith("__") 23 | ) 24 | dataset_names = sorted(name for name in datasets.__all__) 25 | 26 | parser = argparse.ArgumentParser( 27 | description="PyTorch FlowNet Training on several datasets", 28 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 29 | ) 30 | parser.add_argument("data", metavar="DIR", help="path to dataset") 31 | parser.add_argument( 32 | "--dataset", 33 | metavar="DATASET", 34 | default="flying_chairs", 35 | choices=dataset_names, 36 | help="dataset type : " + " | ".join(dataset_names), 37 | ) 38 | group = parser.add_mutually_exclusive_group() 39 | group.add_argument( 40 | "-s", "--split-file", default=None, type=str, help="test-val split file" 41 | ) 42 | group.add_argument( 43 | "--split-value", 44 | default=0.8, 45 | type=float, 46 | help="test-val split proportion between 0 (only test) and 1 (only train), " 47 | "will be overwritten if a split file is set", 48 | ) 49 | parser.add_argument( 50 | "--split-seed", 51 | type=int, 52 | default=None, 53 | help="Seed the train-val split to enforce reproducibility (consistent restart too)", 54 | ) 55 | parser.add_argument( 56 | "--arch", 57 | "-a", 58 | metavar="ARCH", 59 | default="flownets", 60 | choices=model_names, 61 | help="model architecture, overwritten if pretrained is specified: " 62 | + " | ".join(model_names), 63 | ) 64 | parser.add_argument( 65 | "--solver", default="adam", choices=["adam", "sgd"], help="solver algorithms" 66 | ) 67 | parser.add_argument( 68 | "-j", 69 | "--workers", 70 | 
default=8, 71 | type=int, 72 | metavar="N", 73 | help="number of data loading workers", 74 | ) 75 | parser.add_argument( 76 | "--epochs", default=300, type=int, metavar="N", help="number of total epochs to run" 77 | ) 78 | parser.add_argument( 79 | "--start-epoch", 80 | default=0, 81 | type=int, 82 | metavar="N", 83 | help="manual epoch number (useful on restarts)", 84 | ) 85 | parser.add_argument( 86 | "--epoch-size", 87 | default=1000, 88 | type=int, 89 | metavar="N", 90 | help="manual epoch size (will match dataset size if set to 0)", 91 | ) 92 | parser.add_argument( 93 | "-b", "--batch-size", default=8, type=int, metavar="N", help="mini-batch size" 94 | ) 95 | parser.add_argument( 96 | "--lr", 97 | "--learning-rate", 98 | default=0.0001, 99 | type=float, 100 | metavar="LR", 101 | help="initial learning rate", 102 | ) 103 | parser.add_argument( 104 | "--momentum", 105 | default=0.9, 106 | type=float, 107 | metavar="M", 108 | help="momentum for sgd, alpha parameter for adam", 109 | ) 110 | parser.add_argument( 111 | "--beta", default=0.999, type=float, metavar="M", help="beta parameter for adam" 112 | ) 113 | parser.add_argument( 114 | "--weight-decay", "--wd", default=4e-4, type=float, metavar="W", help="weight decay" 115 | ) 116 | parser.add_argument( 117 | "--bias-decay", default=0, type=float, metavar="B", help="bias decay" 118 | ) 119 | parser.add_argument( 120 | "--multiscale-weights", 121 | "-w", 122 | default=[0.005, 0.01, 0.02, 0.08, 0.32], 123 | type=float, 124 | nargs=5, 125 | help="training weight for each scale, from highest resolution (flow2) to lowest (flow6)", 126 | metavar=("W2", "W3", "W4", "W5", "W6"), 127 | ) 128 | parser.add_argument( 129 | "--sparse", 130 | action="store_true", 131 | help="look for NaNs in target flow when computing EPE, avoid if flow is garantied to be dense," 132 | "automatically seleted when choosing a KITTIdataset", 133 | ) 134 | parser.add_argument( 135 | "--print-freq", "-p", default=10, type=int, metavar="N", help="print frequency" 136 | ) 137 | parser.add_argument( 138 | "-e", 139 | "--evaluate", 140 | dest="evaluate", 141 | action="store_true", 142 | help="evaluate model on validation set", 143 | ) 144 | parser.add_argument( 145 | "--pretrained", dest="pretrained", default=None, help="path to pre-trained model" 146 | ) 147 | parser.add_argument( 148 | "--no-date", action="store_true", help="don't append date timestamp to folder" 149 | ) 150 | parser.add_argument( 151 | "--div-flow", 152 | default=20, 153 | help="value by which flow will be divided. 
Original value is 20 but 1 with batchNorm gives good results", 154 | ) 155 | parser.add_argument( 156 | "--milestones", 157 | default=[100, 150, 200], 158 | metavar="N", 159 | nargs="*", 160 | help="epochs at which learning rate is divided by 2", 161 | ) 162 | 163 | 164 | best_EPE = -1 165 | n_iter = 0 166 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 167 | 168 | 169 | def main(): 170 | global args, best_EPE, n_iter 171 | args = parser.parse_args() 172 | save_path = "{},{},{}epochs{},b{},lr{}".format( 173 | args.arch, 174 | args.solver, 175 | args.epochs, 176 | ",epochSize" + str(args.epoch_size) if args.epoch_size > 0 else "", 177 | args.batch_size, 178 | args.lr, 179 | ) 180 | if not args.no_date: 181 | timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M") 182 | save_path = os.path.join(timestamp, save_path) 183 | save_path = os.path.join(args.dataset, save_path) 184 | print("=> will save everything to {}".format(save_path)) 185 | if not os.path.exists(save_path): 186 | os.makedirs(save_path) 187 | 188 | if args.split_seed is not None: 189 | np.random.seed(args.split_seed) 190 | 191 | train_writer = SummaryWriter(os.path.join(save_path, "train")) 192 | test_writer = SummaryWriter(os.path.join(save_path, "test")) 193 | output_writers = [] 194 | for i in range(3): 195 | output_writers.append(SummaryWriter(os.path.join(save_path, "test", str(i)))) 196 | 197 | # Data loading code 198 | input_transform = transforms.Compose( 199 | [ 200 | flow_transforms.ArrayToTensor(), 201 | transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]), 202 | transforms.Normalize(mean=[0.45, 0.432, 0.411], std=[1, 1, 1]), 203 | ] 204 | ) 205 | target_transform = transforms.Compose( 206 | [ 207 | flow_transforms.ArrayToTensor(), 208 | transforms.Normalize(mean=[0, 0], std=[args.div_flow, args.div_flow]), 209 | ] 210 | ) 211 | 212 | if "KITTI" in args.dataset: 213 | args.sparse = True 214 | if args.sparse: 215 | co_transform = flow_transforms.Compose( 216 | [ 217 | flow_transforms.RandomCrop((320, 448)), 218 | flow_transforms.RandomVerticalFlip(), 219 | flow_transforms.RandomHorizontalFlip(), 220 | ] 221 | ) 222 | else: 223 | co_transform = flow_transforms.Compose( 224 | [ 225 | flow_transforms.RandomTranslate(10), 226 | flow_transforms.RandomRotate(10, 5), 227 | flow_transforms.RandomCrop((320, 448)), 228 | flow_transforms.RandomVerticalFlip(), 229 | flow_transforms.RandomHorizontalFlip(), 230 | ] 231 | ) 232 | 233 | print("=> fetching img pairs in '{}'".format(args.data)) 234 | train_set, test_set = datasets.__dict__[args.dataset]( 235 | args.data, 236 | transform=input_transform, 237 | target_transform=target_transform, 238 | co_transform=co_transform, 239 | split=args.split_file if args.split_file else args.split_value, 240 | split_save_path=os.path.join(save_path, "split.txt"), 241 | ) 242 | print( 243 | "{} samples found, {} train samples and {} test samples ".format( 244 | len(test_set) + len(train_set), len(train_set), len(test_set) 245 | ) 246 | ) 247 | n_iter = args.start_epoch * len(train_set) 248 | train_loader = torch.utils.data.DataLoader( 249 | train_set, 250 | batch_size=args.batch_size, 251 | num_workers=args.workers, 252 | pin_memory=True, 253 | shuffle=True, 254 | ) 255 | val_loader = torch.utils.data.DataLoader( 256 | test_set, 257 | batch_size=args.batch_size, 258 | num_workers=args.workers, 259 | pin_memory=True, 260 | shuffle=False, 261 | ) 262 | 263 | # create model 264 | if args.pretrained: 265 | network_data = torch.load(args.pretrained) 266 | args.arch = 
network_data["arch"] 267 | print("=> using pre-trained model '{}'".format(args.arch)) 268 | else: 269 | network_data = None 270 | print("=> creating model '{}'".format(args.arch)) 271 | 272 | model = models.__dict__[args.arch](network_data).to(device) 273 | 274 | assert args.solver in ["adam", "sgd"] 275 | print("=> setting {} solver".format(args.solver)) 276 | param_groups = [ 277 | {"params": model.bias_parameters(), "weight_decay": args.bias_decay}, 278 | {"params": model.weight_parameters(), "weight_decay": args.weight_decay}, 279 | ] 280 | 281 | if device.type == "cuda": 282 | model = torch.nn.DataParallel(model).cuda() 283 | cudnn.benchmark = True 284 | 285 | if args.solver == "adam": 286 | optimizer = torch.optim.Adam( 287 | param_groups, args.lr, betas=(args.momentum, args.beta) 288 | ) 289 | elif args.solver == "sgd": 290 | optimizer = torch.optim.SGD(param_groups, args.lr, momentum=args.momentum) 291 | 292 | if args.evaluate: 293 | best_EPE = validate(val_loader, model, 0, output_writers) 294 | return 295 | 296 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 297 | optimizer, milestones=args.milestones, gamma=0.5 298 | ) 299 | 300 | for epoch in range(args.start_epoch, args.epochs): 301 | # train for one epoch 302 | train_loss, train_EPE = train( 303 | train_loader, model, optimizer, epoch, train_writer 304 | ) 305 | scheduler.step() 306 | train_writer.add_scalar("mean EPE", train_EPE, epoch) 307 | 308 | # evaluate on validation set 309 | 310 | with torch.no_grad(): 311 | EPE = validate(val_loader, model, epoch, output_writers) 312 | test_writer.add_scalar("mean EPE", EPE, epoch) 313 | 314 | if best_EPE < 0: 315 | best_EPE = EPE 316 | 317 | is_best = EPE < best_EPE 318 | best_EPE = min(EPE, best_EPE) 319 | save_checkpoint( 320 | { 321 | "epoch": epoch + 1, 322 | "arch": args.arch, 323 | "state_dict": model.module.state_dict(), 324 | "best_EPE": best_EPE, 325 | "div_flow": args.div_flow, 326 | }, 327 | is_best, 328 | save_path, 329 | ) 330 | 331 | 332 | def train(train_loader, model, optimizer, epoch, train_writer): 333 | global n_iter, args 334 | batch_time = AverageMeter() 335 | data_time = AverageMeter() 336 | losses = AverageMeter() 337 | flow2_EPEs = AverageMeter() 338 | 339 | epoch_size = ( 340 | len(train_loader) 341 | if args.epoch_size == 0 342 | else min(len(train_loader), args.epoch_size) 343 | ) 344 | 345 | # switch to train mode 346 | model.train() 347 | 348 | end = time.time() 349 | 350 | for i, (input, target) in enumerate(train_loader): 351 | # measure data loading time 352 | data_time.update(time.time() - end) 353 | target = target.to(device) 354 | input = torch.cat(input, 1).to(device) 355 | 356 | # compute output 357 | output = model(input) 358 | if args.sparse: 359 | # Since Target pooling is not very precise when sparse, 360 | # take the highest resolution prediction and upsample it instead of downsampling target 361 | h, w = target.size()[-2:] 362 | output = [F.interpolate(output[0], (h, w)), *output[1:]] 363 | 364 | loss = multiscaleEPE( 365 | output, target, weights=args.multiscale_weights, sparse=args.sparse 366 | ) 367 | flow2_EPE = args.div_flow * realEPE(output[0], target, sparse=args.sparse) 368 | # record loss and EPE 369 | losses.update(loss.item(), target.size(0)) 370 | train_writer.add_scalar("train_loss", loss.item(), n_iter) 371 | flow2_EPEs.update(flow2_EPE.item(), target.size(0)) 372 | 373 | # compute gradient and do optimization step 374 | optimizer.zero_grad() 375 | loss.backward() 376 | optimizer.step() 377 | 378 | # measure elapsed time 
379 | batch_time.update(time.time() - end) 380 | end = time.time() 381 | 382 | if i % args.print_freq == 0: 383 | print( 384 | "Epoch: [{0}][{1}/{2}]\t Time {3}\t Data {4}\t Loss {5}\t EPE {6}".format( 385 | epoch, i, epoch_size, batch_time, data_time, losses, flow2_EPEs 386 | ) 387 | ) 388 | n_iter += 1 389 | if i >= epoch_size: 390 | break 391 | 392 | return losses.avg, flow2_EPEs.avg 393 | 394 | 395 | def validate(val_loader, model, epoch, output_writers): 396 | global args 397 | 398 | batch_time = AverageMeter() 399 | flow2_EPEs = AverageMeter() 400 | 401 | # switch to evaluate mode 402 | model.eval() 403 | 404 | end = time.time() 405 | for i, (input, target) in enumerate(val_loader): 406 | target = target.to(device) 407 | input = torch.cat(input, 1).to(device) 408 | 409 | # compute output 410 | output = model(input) 411 | flow2_EPE = args.div_flow * realEPE(output, target, sparse=args.sparse) 412 | # record EPE 413 | flow2_EPEs.update(flow2_EPE.item(), target.size(0)) 414 | 415 | # measure elapsed time 416 | batch_time.update(time.time() - end) 417 | end = time.time() 418 | 419 | if i < len(output_writers): # log first output of first batches 420 | if epoch == args.start_epoch: 421 | mean_values = torch.tensor( 422 | [0.45, 0.432, 0.411], dtype=input.dtype 423 | ).view(3, 1, 1) 424 | output_writers[i].add_image( 425 | "GroundTruth", flow2rgb(args.div_flow * target[0], max_value=10), 0 426 | ) 427 | output_writers[i].add_image( 428 | "Inputs", (input[0, :3].cpu() + mean_values).clamp(0, 1), 0 429 | ) 430 | output_writers[i].add_image( 431 | "Inputs", (input[0, 3:].cpu() + mean_values).clamp(0, 1), 1 432 | ) 433 | output_writers[i].add_image( 434 | "FlowNet Outputs", 435 | flow2rgb(args.div_flow * output[0], max_value=10), 436 | epoch, 437 | ) 438 | 439 | if i % args.print_freq == 0: 440 | print( 441 | "Test: [{0}/{1}]\t Time {2}\t EPE {3}".format( 442 | i, len(val_loader), batch_time, flow2_EPEs 443 | ) 444 | ) 445 | 446 | print(" * EPE {:.3f}".format(flow2_EPEs.avg)) 447 | 448 | return flow2_EPEs.avg 449 | 450 | 451 | if __name__ == "__main__": 452 | main() 453 | -------------------------------------------------------------------------------- /models/FlowNetC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.init import kaiming_normal_, constant_ 4 | from .util import conv, predict_flow, deconv, crop_like, correlate 5 | 6 | __all__ = ["flownetc", "flownetc_bn"] 7 | 8 | 9 | class FlowNetC(nn.Module): 10 | expansion = 1 11 | 12 | def __init__(self, batchNorm=True): 13 | super(FlowNetC, self).__init__() 14 | 15 | self.batchNorm = batchNorm 16 | self.conv1 = conv(self.batchNorm, 3, 64, kernel_size=7, stride=2) 17 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2) 18 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2) 19 | self.conv_redir = conv(self.batchNorm, 256, 32, kernel_size=1, stride=1) 20 | 21 | self.conv3_1 = conv(self.batchNorm, 473, 256) 22 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 23 | self.conv4_1 = conv(self.batchNorm, 512, 512) 24 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 25 | self.conv5_1 = conv(self.batchNorm, 512, 512) 26 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 27 | self.conv6_1 = conv(self.batchNorm, 1024, 1024) 28 | 29 | self.deconv5 = deconv(1024, 512) 30 | self.deconv4 = deconv(1026, 256) 31 | self.deconv3 = deconv(770, 128) 32 | self.deconv2 = deconv(386, 64) 33 | 34 | self.predict_flow6 = 
predict_flow(1024) 35 | self.predict_flow5 = predict_flow(1026) 36 | self.predict_flow4 = predict_flow(770) 37 | self.predict_flow3 = predict_flow(386) 38 | self.predict_flow2 = predict_flow(194) 39 | 40 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 41 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 42 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 43 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 44 | 45 | for m in self.modules(): 46 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 47 | kaiming_normal_(m.weight, 0.1) 48 | if m.bias is not None: 49 | constant_(m.bias, 0) 50 | elif isinstance(m, nn.BatchNorm2d): 51 | constant_(m.weight, 1) 52 | constant_(m.bias, 0) 53 | 54 | def forward(self, x): 55 | x1 = x[:, :3] 56 | x2 = x[:, 3:] 57 | 58 | out_conv1a = self.conv1(x1) 59 | out_conv2a = self.conv2(out_conv1a) 60 | out_conv3a = self.conv3(out_conv2a) 61 | 62 | out_conv1b = self.conv1(x2) 63 | out_conv2b = self.conv2(out_conv1b) 64 | out_conv3b = self.conv3(out_conv2b) 65 | 66 | out_conv_redir = self.conv_redir(out_conv3a) 67 | out_correlation = correlate(out_conv3a, out_conv3b) 68 | 69 | in_conv3_1 = torch.cat([out_conv_redir, out_correlation], dim=1) 70 | 71 | out_conv3 = self.conv3_1(in_conv3_1) 72 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 73 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 74 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 75 | 76 | flow6 = self.predict_flow6(out_conv6) 77 | flow6_up = crop_like(self.upsampled_flow6_to_5(flow6), out_conv5) 78 | out_deconv5 = crop_like(self.deconv5(out_conv6), out_conv5) 79 | 80 | concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1) 81 | flow5 = self.predict_flow5(concat5) 82 | flow5_up = crop_like(self.upsampled_flow5_to_4(flow5), out_conv4) 83 | out_deconv4 = crop_like(self.deconv4(concat5), out_conv4) 84 | 85 | concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1) 86 | flow4 = self.predict_flow4(concat4) 87 | flow4_up = crop_like(self.upsampled_flow4_to_3(flow4), out_conv3) 88 | out_deconv3 = crop_like(self.deconv3(concat4), out_conv3) 89 | 90 | concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1) 91 | flow3 = self.predict_flow3(concat3) 92 | flow3_up = crop_like(self.upsampled_flow3_to_2(flow3), out_conv2a) 93 | out_deconv2 = crop_like(self.deconv2(concat3), out_conv2a) 94 | 95 | concat2 = torch.cat((out_conv2a, out_deconv2, flow3_up), 1) 96 | flow2 = self.predict_flow2(concat2) 97 | 98 | if self.training: 99 | return flow2, flow3, flow4, flow5, flow6 100 | else: 101 | return flow2 102 | 103 | def weight_parameters(self): 104 | return [param for name, param in self.named_parameters() if "weight" in name] 105 | 106 | def bias_parameters(self): 107 | return [param for name, param in self.named_parameters() if "bias" in name] 108 | 109 | 110 | def flownetc(data=None): 111 | """FlowNetS model architecture from the 112 | "Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852) 113 | 114 | Args: 115 | data : pretrained weights of the network. 
will create a new one if not set 116 | """ 117 | model = FlowNetC(batchNorm=False) 118 | if data is not None: 119 | model.load_state_dict(data["state_dict"]) 120 | return model 121 | 122 | 123 | def flownetc_bn(data=None): 124 | """FlowNetS model architecture from the 125 | "Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852) 126 | 127 | Args: 128 | data : pretrained weights of the network. will create a new one if not set 129 | """ 130 | model = FlowNetC(batchNorm=True) 131 | if data is not None: 132 | model.load_state_dict(data["state_dict"]) 133 | return model 134 | -------------------------------------------------------------------------------- /models/FlowNetS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.init import kaiming_normal_, constant_ 4 | from .util import conv, predict_flow, deconv, crop_like 5 | 6 | __all__ = ["flownets", "flownets_bn"] 7 | 8 | 9 | class FlowNetS(nn.Module): 10 | expansion = 1 11 | 12 | def __init__(self, batchNorm=True): 13 | super(FlowNetS, self).__init__() 14 | 15 | self.batchNorm = batchNorm 16 | self.conv1 = conv(self.batchNorm, 6, 64, kernel_size=7, stride=2) 17 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2) 18 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2) 19 | self.conv3_1 = conv(self.batchNorm, 256, 256) 20 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 21 | self.conv4_1 = conv(self.batchNorm, 512, 512) 22 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 23 | self.conv5_1 = conv(self.batchNorm, 512, 512) 24 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 25 | self.conv6_1 = conv(self.batchNorm, 1024, 1024) 26 | 27 | self.deconv5 = deconv(1024, 512) 28 | self.deconv4 = deconv(1026, 256) 29 | self.deconv3 = deconv(770, 128) 30 | self.deconv2 = deconv(386, 64) 31 | 32 | self.predict_flow6 = predict_flow(1024) 33 | self.predict_flow5 = predict_flow(1026) 34 | self.predict_flow4 = predict_flow(770) 35 | self.predict_flow3 = predict_flow(386) 36 | self.predict_flow2 = predict_flow(194) 37 | 38 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 39 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 40 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 41 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 42 | 43 | for m in self.modules(): 44 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 45 | kaiming_normal_(m.weight, 0.1) 46 | if m.bias is not None: 47 | constant_(m.bias, 0) 48 | elif isinstance(m, nn.BatchNorm2d): 49 | constant_(m.weight, 1) 50 | constant_(m.bias, 0) 51 | 52 | def forward(self, x): 53 | out_conv2 = self.conv2(self.conv1(x)) 54 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 55 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 56 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 57 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 58 | 59 | flow6 = self.predict_flow6(out_conv6) 60 | flow6_up = crop_like(self.upsampled_flow6_to_5(flow6), out_conv5) 61 | out_deconv5 = crop_like(self.deconv5(out_conv6), out_conv5) 62 | 63 | concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1) 64 | flow5 = self.predict_flow5(concat5) 65 | flow5_up = crop_like(self.upsampled_flow5_to_4(flow5), out_conv4) 66 | out_deconv4 = crop_like(self.deconv4(concat5), out_conv4) 67 | 68 | concat4 = torch.cat((out_conv4, out_deconv4, 
flow5_up), 1) 69 | flow4 = self.predict_flow4(concat4) 70 | flow4_up = crop_like(self.upsampled_flow4_to_3(flow4), out_conv3) 71 | out_deconv3 = crop_like(self.deconv3(concat4), out_conv3) 72 | 73 | concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1) 74 | flow3 = self.predict_flow3(concat3) 75 | flow3_up = crop_like(self.upsampled_flow3_to_2(flow3), out_conv2) 76 | out_deconv2 = crop_like(self.deconv2(concat3), out_conv2) 77 | 78 | concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1) 79 | flow2 = self.predict_flow2(concat2) 80 | 81 | if self.training: 82 | return flow2, flow3, flow4, flow5, flow6 83 | else: 84 | return flow2 85 | 86 | def weight_parameters(self): 87 | return [param for name, param in self.named_parameters() if "weight" in name] 88 | 89 | def bias_parameters(self): 90 | return [param for name, param in self.named_parameters() if "bias" in name] 91 | 92 | 93 | def flownets(data=None): 94 | """FlowNetS model architecture from the 95 | "Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852) 96 | 97 | Args: 98 | data : pretrained weights of the network. will create a new one if not set 99 | """ 100 | model = FlowNetS(batchNorm=False) 101 | if data is not None: 102 | model.load_state_dict(data["state_dict"]) 103 | return model 104 | 105 | 106 | def flownets_bn(data=None): 107 | """FlowNetS model architecture from the 108 | "Learning Optical Flow with Convolutional Networks" paper (https://arxiv.org/abs/1504.06852) 109 | 110 | Args: 111 | data : pretrained weights of the network. will create a new one if not set 112 | """ 113 | model = FlowNetS(batchNorm=True) 114 | if data is not None: 115 | model.load_state_dict(data["state_dict"]) 116 | return model 117 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .FlowNetS import * 2 | from .FlowNetC import * 3 | -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | try: 5 | from spatial_correlation_sampler import spatial_correlation_sample 6 | except ImportError as e: 7 | import warnings 8 | 9 | with warnings.catch_warnings(): 10 | warnings.filterwarnings("default", category=ImportWarning) 11 | warnings.warn( 12 | "failed to load custom correlation module" "which is needed for FlowNetC", 13 | ImportWarning, 14 | ) 15 | 16 | 17 | def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1): 18 | if batchNorm: 19 | return nn.Sequential( 20 | nn.Conv2d( 21 | in_planes, 22 | out_planes, 23 | kernel_size=kernel_size, 24 | stride=stride, 25 | padding=(kernel_size - 1) // 2, 26 | bias=False, 27 | ), 28 | nn.BatchNorm2d(out_planes), 29 | nn.LeakyReLU(0.1, inplace=True), 30 | ) 31 | else: 32 | return nn.Sequential( 33 | nn.Conv2d( 34 | in_planes, 35 | out_planes, 36 | kernel_size=kernel_size, 37 | stride=stride, 38 | padding=(kernel_size - 1) // 2, 39 | bias=True, 40 | ), 41 | nn.LeakyReLU(0.1, inplace=True), 42 | ) 43 | 44 | 45 | def predict_flow(in_planes): 46 | return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1, bias=False) 47 | 48 | 49 | def deconv(in_planes, out_planes): 50 | return nn.Sequential( 51 | nn.ConvTranspose2d( 52 | in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False 53 | ), 54 | nn.LeakyReLU(0.1, 
inplace=True), 55 | ) 56 | 57 | 58 | def correlate(input1, input2): 59 | out_corr = spatial_correlation_sample( 60 | input1, 61 | input2, 62 | kernel_size=1, 63 | patch_size=21, 64 | stride=1, 65 | padding=0, 66 | dilation_patch=2, 67 | ) 68 | # collapse dimensions 1 and 2 so the output can be treated as a 69 | # regular 4D tensor 70 | b, ph, pw, h, w = out_corr.size() 71 | out_corr = out_corr.view(b, ph * pw, h, w) / input1.size(1) 72 | return F.leaky_relu_(out_corr, 0.1) 73 | 74 | 75 | def crop_like(input, target): 76 | if input.size()[2:] == target.size()[2:]: 77 | return input 78 | else: 79 | return input[:, :, : target.size(2), : target.size(3)] 80 | -------------------------------------------------------------------------------- /multiscaleloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def EPE(input_flow, target_flow, sparse=False, mean=True): 6 | EPE_map = torch.norm(target_flow - input_flow, 2, 1) 7 | batch_size = EPE_map.size(0) 8 | if sparse: 9 | # invalid flow is defined as having both flow coordinates exactly 0 10 | mask = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0) 11 | 12 | EPE_map = EPE_map[~mask] 13 | if mean: 14 | return EPE_map.mean() 15 | else: 16 | return EPE_map.sum() / batch_size 17 | 18 | 19 | def sparse_max_pool(input, size): 20 | """Downsample the input by considering 0 values as invalid. 21 | 22 | Unfortunately, no generic interpolation mode can resize a sparse map correctly; 23 | the strategy here is to use max pooling for positive values and "min pooling" 24 | for negative values, then sum the two results. 25 | This technique allows sparsity to be minimized, contrary to nearest interpolation, 26 | which could potentially lose information for isolated data points.""" 27 | 28 | positive = (input > 0).float() 29 | negative = (input < 0).float() 30 | output = F.adaptive_max_pool2d(input * positive, size) - F.adaptive_max_pool2d( 31 | -input * negative, size 32 | ) 33 | return output 34 | 35 | 36 | def multiscaleEPE(network_output, target_flow, weights=None, sparse=False): 37 | def one_scale(output, target, sparse): 38 | 39 | b, _, h, w = output.size() 40 | 41 | if sparse: 42 | target_scaled = sparse_max_pool(target, (h, w)) 43 | else: 44 | target_scaled = F.interpolate(target, (h, w), mode="area") 45 | return EPE(output, target_scaled, sparse, mean=False) 46 | 47 | if type(network_output) not in [tuple, list]: 48 | network_output = [network_output] 49 | if weights is None: 50 | weights = [0.005, 0.01, 0.02, 0.08, 0.32] # as in original article 51 | assert len(weights) == len(network_output) 52 | 53 | loss = 0 54 | for output, weight in zip(network_output, weights): 55 | loss += weight * one_scale(output, target_flow, sparse) 56 | return loss 57 | 58 | 59 | def realEPE(output, target, sparse=False): 60 | b, _, h, w = target.size() 61 | upsampled_output = F.interpolate( 62 | output, (h, w), mode="bilinear", align_corners=False 63 | ) 64 | return EPE(upsampled_output, target, sparse, mean=True) 65 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.2 2 | torchvision 3 | numpy 4 | spatial-correlation-sampler>=0.2.1 5 | tensorboard 6 | imageio 7 | argparse 8 | path 9 | tqdm 10 | scipy 11 | -------------------------------------------------------------------------------- /run_inference.py:
-------------------------------------------------------------------------------- 1 | import argparse 2 | from path import Path 3 | 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | import torch.nn.functional as F 7 | import models 8 | from tqdm import tqdm 9 | 10 | import torchvision.transforms as transforms 11 | import flow_transforms 12 | from imageio import imread, imwrite 13 | import numpy as np 14 | from util import flow2rgb 15 | 16 | model_names = sorted( 17 | name for name in models.__dict__ if name.islower() and not name.startswith("__") 18 | ) 19 | 20 | 21 | parser = argparse.ArgumentParser( 22 | description="PyTorch FlowNet inference on a folder of image pairs", 23 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 24 | ) 25 | parser.add_argument( 26 | "data", 27 | metavar="DIR", 28 | help="path to images folder, image names must match '[name]1.[ext]' and '[name]2.[ext]'", 29 | ) 30 | parser.add_argument("pretrained", metavar="PTH", help="path to pre-trained model") 31 | parser.add_argument( 32 | "--output", 33 | "-o", 34 | metavar="DIR", 35 | default=None, 36 | help="path to output folder. If not set, will be created in data folder", 37 | ) 38 | parser.add_argument( 39 | "--output-value", 40 | "-v", 41 | choices=["raw", "vis", "both"], 42 | default="both", 43 | help="which value to output, between raw flow (as a npy file) and color visualization (as an image file)." 44 | " If not set, will output both", 45 | ) 46 | parser.add_argument( 47 | "--div-flow", 48 | default=20, 49 | type=float, 50 | help="value by which flow will be divided. Overwritten if stored in pretrained file", 51 | ) 52 | parser.add_argument( 53 | "--img-exts", 54 | metavar="EXT", 55 | default=["png", "jpg", "bmp", "ppm"], 56 | nargs="*", 57 | type=str, 58 | help="images extensions to glob", 59 | ) 60 | parser.add_argument( 61 | "--max_flow", 62 | default=None, 63 | type=float, 64 | help="max flow value. Flow map color is saturated above this value. If not set, will use flow map's max value", 65 | ) 66 | parser.add_argument( 67 | "--upsampling", 68 | "-u", 69 | choices=["nearest", "bilinear"], 70 | default=None, 71 | help="if not set, will output FlowNet raw flow, " 72 | "which is 4 times downsampled.
If set, will output full resolution flow map, with selected upsampling", 73 | ) 74 | parser.add_argument( 75 | "--bidirectional", 76 | action="store_true", 77 | help="if set, will output inverted flow (from the second image to the first) along with regular flow", 78 | ) 79 | 80 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 81 | 82 | 83 | @torch.no_grad() 84 | def main(): 85 | global args, save_path 86 | args = parser.parse_args() 87 | 88 | if args.output_value == "both": 89 | output_string = "raw output and RGB visualization" 90 | elif args.output_value == "raw": 91 | output_string = "raw output" 92 | elif args.output_value == "vis": 93 | output_string = "RGB visualization" 94 | print("=> will save " + output_string) 95 | data_dir = Path(args.data) 96 | print("=> fetching img pairs in '{}'".format(args.data)) 97 | if args.output is None: 98 | save_path = data_dir / "flow" 99 | else: 100 | save_path = Path(args.output) 101 | print("=> will save everything to {}".format(save_path)) 102 | save_path.makedirs_p() 103 | # Data loading code 104 | input_transform = transforms.Compose( 105 | [ 106 | flow_transforms.ArrayToTensor(), 107 | transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]), 108 | transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1]), 109 | ] 110 | ) 111 | 112 | img_pairs = [] 113 | for ext in args.img_exts: 114 | test_files = data_dir.files("*1.{}".format(ext)) 115 | for file in test_files: 116 | img_pair = file.parent / (file.stem[:-1] + "2.{}".format(ext)) 117 | if img_pair.isfile(): 118 | img_pairs.append([file, img_pair]) 119 | 120 | print("{} samples found".format(len(img_pairs))) 121 | # create model 122 | network_data = torch.load(args.pretrained) 123 | print("=> using pre-trained model '{}'".format(network_data["arch"])) 124 | model = models.__dict__[network_data["arch"]](network_data).to(device) 125 | model.eval() 126 | cudnn.benchmark = True 127 | 128 | if "div_flow" in network_data.keys(): 129 | args.div_flow = network_data["div_flow"] 130 | 131 | for img1_file, img2_file in tqdm(img_pairs): 132 | 133 | img1 = input_transform(imread(img1_file)) 134 | img2 = input_transform(imread(img2_file)) 135 | input_var = torch.cat([img1, img2]).unsqueeze(0) 136 | 137 | if args.bidirectional: 138 | # feed inverted pair along with normal pair 139 | inverted_input_var = torch.cat([img2, img1]).unsqueeze(0) 140 | input_var = torch.cat([input_var, inverted_input_var]) 141 | 142 | input_var = input_var.to(device) 143 | # compute output 144 | output = model(input_var) 145 | if args.upsampling is not None: 146 | output = F.interpolate( 147 | output, size=img1.size()[-2:], mode=args.upsampling, align_corners=False if args.upsampling == "bilinear" else None 148 | ) 149 | for suffix, flow_output in zip(["flow", "inv_flow"], output): 150 | filename = save_path / "{}{}".format(img1_file.stem[:-1], suffix) 151 | if args.output_value in ["vis", "both"]: 152 | rgb_flow = flow2rgb( 153 | args.div_flow * flow_output, max_value=args.max_flow 154 | ) 155 | to_save = (rgb_flow * 255).astype(np.uint8).transpose(1, 2, 0) 156 | imwrite(filename + ".png", to_save) 157 | if args.output_value in ["raw", "both"]: 158 | # Make the flow map a HxWx2 array as in .flo files 159 | to_save = (args.div_flow * flow_output).cpu().numpy().transpose(1, 2, 0) 160 | np.save(filename + ".npy", to_save) 161 | 162 | 163 | if __name__ == "__main__": 164 | main() 165 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 |
import os 2 | import numpy as np 3 | import shutil 4 | import torch 5 | 6 | 7 | def save_checkpoint(state, is_best, save_path, filename="checkpoint.pth.tar"): 8 | torch.save(state, os.path.join(save_path, filename)) 9 | if is_best: 10 | shutil.copyfile( 11 | os.path.join(save_path, filename), 12 | os.path.join(save_path, "model_best.pth.tar"), 13 | ) 14 | 15 | 16 | class AverageMeter(object): 17 | """Computes and stores the average and current value""" 18 | 19 | def __init__(self): 20 | self.reset() 21 | 22 | def reset(self): 23 | self.val = 0 24 | self.avg = 0 25 | self.sum = 0 26 | self.count = 0 27 | 28 | def update(self, val, n=1): 29 | self.val = val 30 | self.sum += val * n 31 | self.count += n 32 | self.avg = self.sum / self.count 33 | 34 | def __repr__(self): 35 | return "{:.3f} ({:.3f})".format(self.val, self.avg) 36 | 37 | 38 | def flow2rgb(flow_map, max_value): 39 | flow_map_np = flow_map.detach().cpu().numpy() 40 | _, h, w = flow_map_np.shape 41 | flow_map_np[:, (flow_map_np[0] == 0) & (flow_map_np[1] == 0)] = float("nan") 42 | rgb_map = np.ones((3, h, w)).astype(np.float32) 43 | if max_value is not None: 44 | normalized_flow_map = flow_map_np / max_value 45 | else: 46 | normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max()) 47 | rgb_map[0] += normalized_flow_map[0] 48 | rgb_map[1] -= 0.5 * (normalized_flow_map[0] + normalized_flow_map[1]) 49 | rgb_map[2] += normalized_flow_map[1] 50 | return rgb_map.clip(0, 1) 51 | --------------------------------------------------------------------------------
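
As a quick sanity check for the helpers above, here is a minimal usage sketch (not part of the repository) that exercises `models.flownets`, `multiscaleEPE`/`realEPE` from `multiscaleloss.py` and `flow2rgb` from `util.py` on random tensors. The tensor shapes, the random data and the `20` scaling factor (the default `--div-flow` of `run_inference.py`) are illustrative assumptions only; it is meant to be run from the repository root.

```python
# Minimal sketch (not part of the repository): exercising the loss and
# visualization helpers on random data. Shapes and values are illustrative.
import numpy as np
import torch

import models
from multiscaleloss import multiscaleEPE, realEPE
from util import flow2rgb

model = models.flownets()  # randomly initialized FlowNetS (no pretrained weights)
model.train()

# Two stacked RGB frames -> 6 input channels; 384x512 is divisible by 64,
# so no cropping is needed in the skip connections.
inputs = torch.randn(2, 6, 384, 512)
target = torch.randn(2, 2, 384, 512)

flows = model(inputs)  # training mode: tuple (flow2, flow3, flow4, flow5, flow6)
loss = multiscaleEPE(flows, target)  # default per-scale weights from the paper

# Sparse ground truth: zeros mark invalid pixels, downsampled with sparse_max_pool
sparse_target = target * (torch.rand(2, 1, 384, 512) > 0.5).float()
sparse_loss = multiscaleEPE(flows, sparse_target, sparse=True)

model.eval()
with torch.no_grad():
    flow2 = model(inputs)  # eval mode: single quarter-resolution flow map
    epe = realEPE(flow2, target)  # bilinear upsampling to target size, then mean EPE

# Color-code the first predicted flow map; 20 is the default --div-flow factor
rgb = flow2rgb(20 * flow2[0], max_value=None)
image = (rgb * 255).astype(np.uint8).transpose(1, 2, 0)  # HxWx3 uint8
print(loss.item(), sparse_loss.item(), epe.item(), image.shape)
```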