├── .gitignore ├── LICENSE ├── README.md ├── example.py ├── flow_vis_torch ├── __init__.py └── flow_vis_torch.py ├── images ├── frame_0005_flow_vis.png ├── frame_0005_flow_vis_torch.png ├── frame_0014_flow_vis.png ├── frame_0014_flow_vis_torch.png ├── frame_0023_flow_vis.png └── frame_0023_flow_vis_torch.png ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Flow files 2 | *.flo 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 135 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 136 | # Source: https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore 137 | 138 | # User-specific stuff 139 | .idea/**/workspace.xml 140 | .idea/**/tasks.xml 141 | .idea/**/usage.statistics.xml 142 | .idea/**/dictionaries 143 | .idea/**/shelf 144 | 145 | # Generated files 146 | .idea/**/contentModel.xml 147 | 148 | # Sensitive or high-churn files 149 | .idea/**/dataSources/ 150 | .idea/**/dataSources.ids 151 | .idea/**/dataSources.local.xml 152 | .idea/**/sqlDataSources.xml 153 | .idea/**/dynamic.xml 154 | .idea/**/uiDesigner.xml 155 | .idea/**/dbnavigator.xml 156 | 157 | # Gradle 158 | .idea/**/gradle.xml 159 | .idea/**/libraries 160 | 161 | # Gradle and Maven with auto-import 162 | # When using Gradle or Maven with auto-import, you should exclude module files, 163 | # since they will be recreated, and may cause churn. Uncomment if using 164 | # auto-import. 165 | # .idea/artifacts 166 | # .idea/compiler.xml 167 | # .idea/jarRepositories.xml 168 | # .idea/modules.xml 169 | # .idea/*.iml 170 | # .idea/modules 171 | # *.iml 172 | # *.ipr 173 | 174 | # CMake 175 | cmake-build-*/ 176 | 177 | # Mongo Explorer plugin 178 | .idea/**/mongoSettings.xml 179 | 180 | # File-based project format 181 | *.iws 182 | 183 | # IntelliJ 184 | out/ 185 | 186 | # mpeltonen/sbt-idea plugin 187 | .idea_modules/ 188 | 189 | # JIRA plugin 190 | atlassian-ide-plugin.xml 191 | 192 | # Cursive Clojure plugin 193 | .idea/replstate.xml 194 | 195 | # Crashlytics plugin (for Android Studio and IntelliJ) 196 | com_crashlytics_export_strings.xml 197 | crashlytics.properties 198 | crashlytics-build.properties 199 | fabric.properties 200 | 201 | # Editor-based Rest Client 202 | .idea/httpRequests 203 | 204 | # Android studio 3.1+ serialized cache file 205 | .idea/caches/build_file_checksums.ser 206 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Christoph Reich 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Optical Flow Visualization for PyTorch 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/blob/master/LICENSE) 4 | 5 | This repository is a PyTorch fork of the [OpticalFlow_Visualization](https://github.com/tomrunia/OpticalFlow_Visualization) (flow_vis) repository, originally published under the [MIT license](https://github.com/tomrunia/OpticalFlow_Visualization/blob/master/LICENSE.txt). The optical flow visualization follows the color encoding proposed in the paper "[A database and evaluation methodology for optical flow](https://link.springer.com/content/pdf/10.1007/s11263-010-0390-2.pdf)" by Baker et al. published at ICCV 2007 [1]. 6 | 7 | ## Installation 8 | 9 | Simply run the following command to install `flow_vis_torch`. 10 | 11 | ```shell script 12 | pip install git+https://github.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch 13 | ``` 14 | 15 | ## Usage 16 | 17 | Convert a given flow of the shape `[batch size (optional), 2, height, width]` to an RGB image of the shape `[batch size (optional), 3, height, width]` by calling `flow_vis_torch.flow_to_color`. 18 | 19 | ```python 20 | import flow_vis_torch 21 | flow_rgb = flow_vis_torch.flow_to_color(flow) 22 | ``` 23 | 24 | For a detailed example have a look at the [example script](example.py). 25 | 26 | ## Visualizations 27 | 28 | Flow maps taken from the [MPI Sintel Flow Dataset](http://sintel.is.tue.mpg.de/) [2]. 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 |
Output flow_vis_torch Output flow_vis
1 2
3 4
5 6
48 | 49 | ## References 50 | 51 | ```bibtex 52 | [1] @inproceedings{Baker2007, 53 | title={{A Database and Evaluation Methodology for Optical Flow}}, 54 | author={Baker, Simon and Roth, Stefan and Scharstein, Daniel and Black, Michael J and Lewis, JP and Szeliski, Richard}, 55 | booktitle={{International Conference on Computer Vision (ICCV)}}, 56 | pages={1--8}, 57 | year={2007}, 58 | organization={IEEE} 59 | } 60 | ``` 61 | 62 | ```bibtex 63 | [2] @inproceedings{Butler2012, 64 | title={{A Naturalistic Open Source Movie for Optical Flow Evaluation}}, 65 | author={Butler, Daniel J and Wulff, Jonas and Stanley, Garrett B and Black, Michael J}, 66 | booktitle={{European Conference on Computer Vision (ECCV)}}, 67 | pages = {611--625}, 68 | year = {2012}, 69 | publisher={Springer} 70 | } 71 | ``` 72 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision 4 | import flow_vis 5 | import flow_vis_torch 6 | 7 | 8 | def load_flo_file(file_path: str) -> torch.Tensor: 9 | """ 10 | Function loads a .flo file. 11 | :param file_path: (str) Path to flo file 12 | :return: (torch.Tensor) Torch.Tensor of the shape [2, height, width] 13 | """ 14 | # Open file 15 | with open(file_path, "rb") as file: 16 | # Load date 17 | data_array: np.ndarray = np.fromfile(file, np.float32)[3:] 18 | # Reshape data 19 | flow: np.ndarray = data_array.reshape((436, 1024, 2)) 20 | # To PyTorch and reshape 21 | return torch.from_numpy(flow).permute(2, 0, 1) 22 | 23 | 24 | if __name__ == '__main__': 25 | for file in ["frame_0014.flo", "frame_0005.flo", "frame_0023.flo"]: 26 | # Load flow maps 27 | flow = load_flo_file(file_path=file) 28 | # Standard package 29 | flow_rgb_flow_vis = torch.from_numpy( 30 | flow_vis.flow_to_color(flow.clone().permute(1, 2, 0).numpy()).astype(float)).permute(2, 0, 1) 31 | torchvision.utils.save_image(flow_rgb_flow_vis.float(), file.replace(".flo", "_flow_vis.png"), normalize=True) 32 | # PyTorch version 33 | flow_rgb_flow_vis_torch = flow_vis_torch.flow_to_color(flow) 34 | torchvision.utils.save_image(flow_rgb_flow_vis.float(), file.replace(".flo", "_flow_vis_torch.png"), 35 | normalize=True) 36 | -------------------------------------------------------------------------------- /flow_vis_torch/__init__.py: -------------------------------------------------------------------------------- 1 | from .flow_vis_torch import flow_to_color 2 | -------------------------------------------------------------------------------- /flow_vis_torch/flow_vis_torch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import torch 4 | from math import pi as PI 5 | 6 | 7 | def get_color_wheel(device: torch.device) -> torch.Tensor: 8 | """ 9 | Generates the color wheel. 10 | :param device: (torch.device) Device to be used 11 | :return: (torch.Tensor) Color wheel tensor of the shape [55, 3] 12 | """ 13 | # Set constants 14 | RY: int = 15 15 | YG: int = 6 16 | GC: int = 4 17 | CB: int = 11 18 | BM: int = 13 19 | MR: int = 6 20 | # Init color wheel 21 | color_wheel: torch.Tensor = torch.zeros((RY + YG + GC + CB + BM + MR, 3), dtype=torch.float32) 22 | # Init counter 23 | counter: int = 0 24 | # RY 25 | color_wheel[0:RY, 0] = 255 26 | color_wheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY) 27 | counter: int = counter + RY 28 | # YG 29 | color_wheel[counter:counter + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG) 30 | color_wheel[counter:counter + YG, 1] = 255 31 | counter: int = counter + YG 32 | # GC 33 | color_wheel[counter:counter + GC, 1] = 255 34 | color_wheel[counter:counter + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC) 35 | counter: int = counter + GC 36 | # CB 37 | color_wheel[counter:counter + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB) 38 | color_wheel[counter:counter + CB, 2] = 255 39 | counter: int = counter + CB 40 | # BM 41 | color_wheel[counter:counter + BM, 2] = 255 42 | color_wheel[counter:counter + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM) 43 | counter: int = counter + BM 44 | # MR 45 | color_wheel[counter:counter + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR) 46 | color_wheel[counter:counter + MR, 0] = 255 47 | # To device 48 | color_wheel: torch.Tensor = color_wheel.to(device) 49 | return color_wheel 50 | 51 | 52 | def _flow_hw_to_color(flow_vertical: torch.Tensor, flow_horizontal: torch.Tensor, 53 | color_wheel: torch.Tensor, device: torch.device) -> torch.Tensor: 54 | """ 55 | Private function applies the flow color wheel to flow components (vertical and horizontal). 56 | :param flow_vertical: (torch.Tensor) Vertical flow of the shape [height, width] 57 | :param flow_horizontal: (torch.Tensor) Horizontal flow of the shape [height, width] 58 | :param color_wheel: (torch.Tensor) Color wheel tensor of the shape [55, 3] 59 | :param: device: (torch.device) Device to be used 60 | :return: (torch.Tensor) Visualized flow of the shape [3, height, width] 61 | """ 62 | # Get shapes 63 | _, height, width = flow_vertical.shape 64 | # Init flow image 65 | flow_image: torch.Tensor = torch.zeros(3, height, width, dtype=torch.float32, device=device) 66 | # Get number of colors 67 | number_of_colors: int = color_wheel.shape[0] 68 | # Compute norm, angle and factors 69 | flow_norm: torch.Tensor = (flow_vertical ** 2 + flow_horizontal ** 2).sqrt() 70 | angle: torch.Tensor = torch.atan2(- flow_vertical, - flow_horizontal) / PI 71 | fk: torch.Tensor = (angle + 1.) / 2. * (number_of_colors - 1.) 72 | k0: torch.Tensor = torch.floor(fk).long() 73 | k1: torch.Tensor = k0 + 1 74 | k1[k1 == number_of_colors] = 0 75 | f: torch.Tensor = fk - k0 76 | # Iterate over color components 77 | for index in range(color_wheel.shape[1]): 78 | # Get component of all colors 79 | tmp: torch.Tensor = color_wheel[:, index] 80 | # Get colors 81 | color_0: torch.Tensor = tmp[k0] / 255. 82 | color_1: torch.Tensor = tmp[k1] / 255. 83 | # Compute color 84 | color: torch.Tensor = (1. - f) * color_0 + f * color_1 85 | # Get color index 86 | color_index: torch.Tensor = flow_norm <= 1 87 | # Set color saturation 88 | color[color_index] = 1 - flow_norm[color_index] * (1. - color[color_index]) 89 | color[~color_index] = color[~color_index] * 0.75 90 | # Set color in image 91 | flow_image[index] = torch.floor(255 * color) 92 | return flow_image 93 | 94 | 95 | def flow_to_color(flow: torch.Tensor, clip_flow: Optional[Union[float, torch.Tensor]] = None, 96 | normalize_over_video: bool = False) -> torch.Tensor: 97 | """ 98 | Function converts a given optical flow map into the classical color schema. 99 | :param flow: (torch.Tensor) Optical flow tensor of the shape [batch size (optional), 2, height, width]. 100 | :param clip_flow: (Optional[Union[float, torch.Tensor]]) Max value of flow values for clipping (default None). 101 | :param normalize_over_video: (bool) If true scale is normalized over the whole video (batch). 102 | :return: (torch.Tensor) Flow visualization (float tensor) with the shape [batch size (if used), 3, height, width]. 103 | """ 104 | # Check parameter types 105 | assert torch.is_tensor(flow), "Given flow map must be a torch.Tensor, {} given".format(type(flow)) 106 | assert torch.is_tensor(clip_flow) or isinstance(clip_flow, float) or clip_flow is None, \ 107 | "Given clip_flow parameter must be a float, a torch.Tensor, or None, {} given".format(type(clip_flow)) 108 | # Check shapes 109 | assert flow.ndimension() in [3, 4], \ 110 | "Given flow must be a 3D or 4D tensor, given tensor shape {}.".format(flow.shape) 111 | if torch.is_tensor(clip_flow): 112 | assert clip_flow.ndimension() == 0, \ 113 | "Given clip_flow tensor must be a scalar, given tensor shape {}.".format(clip_flow.shape) 114 | # Manage batch dimension 115 | batch_dimension: bool = True 116 | if flow.ndimension() == 3: 117 | flow = flow[None] 118 | batch_dimension: bool = False 119 | # Save shape 120 | batch_size, _, height, width = flow.shape 121 | # Check flow dimension 122 | assert flow.shape[1] == 2, "Flow dimension must have the shape 2 but tensor with {} given".format(flow.shape[1]) 123 | # Save device 124 | device: torch.device = flow.device 125 | # Clip flow if utilized 126 | if clip_flow is not None: 127 | flow = flow.clip(max=clip_flow) 128 | # Get horizontal and vertical flow 129 | flow_vertical: torch.Tensor = flow[:, 0:1] 130 | flow_horizontal: torch.Tensor = flow[:, 1:2] 131 | # Get max norm of flow 132 | flow_max_norm: torch.Tensor = (flow_vertical ** 2 + flow_horizontal ** 2).sqrt().view(batch_size, -1).max(dim=-1)[0] 133 | flow_max_norm: torch.Tensor = flow_max_norm.view(batch_size, 1, 1, 1) 134 | if normalize_over_video: 135 | flow_max_norm: Tensor = flow_max_norm.max(dim=0, keepdim=True)[0] 136 | # Normalize flow 137 | flow_vertical: torch.Tensor = flow_vertical / (flow_max_norm + 1e-05) 138 | flow_horizontal: torch.Tensor = flow_horizontal / (flow_max_norm + 1e-05) 139 | # Get color wheel 140 | color_wheel: torch.Tensor = get_color_wheel(device=device) 141 | # Init flow image 142 | flow_image = torch.zeros(batch_size, 3, height, width, device=device) 143 | # Iterate over batch dimension 144 | for index in range(batch_size): 145 | flow_image[index] = _flow_hw_to_color(flow_vertical=flow_vertical[index], 146 | flow_horizontal=flow_horizontal[index], color_wheel=color_wheel, 147 | device=device) 148 | return flow_image if batch_dimension else flow_image[0] 149 | -------------------------------------------------------------------------------- /images/frame_0005_flow_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/9177370c7c00b4b7dbe4deda6fed734fdff48b2c/images/frame_0005_flow_vis.png -------------------------------------------------------------------------------- /images/frame_0005_flow_vis_torch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/9177370c7c00b4b7dbe4deda6fed734fdff48b2c/images/frame_0005_flow_vis_torch.png -------------------------------------------------------------------------------- /images/frame_0014_flow_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/9177370c7c00b4b7dbe4deda6fed734fdff48b2c/images/frame_0014_flow_vis.png -------------------------------------------------------------------------------- /images/frame_0014_flow_vis_torch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/9177370c7c00b4b7dbe4deda6fed734fdff48b2c/images/frame_0014_flow_vis_torch.png -------------------------------------------------------------------------------- /images/frame_0023_flow_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/9177370c7c00b4b7dbe4deda6fed734fdff48b2c/images/frame_0023_flow_vis.png -------------------------------------------------------------------------------- /images/frame_0023_flow_vis_torch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/9177370c7c00b4b7dbe4deda6fed734fdff48b2c/images/frame_0023_flow_vis_torch.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="flow_vis_torch", 5 | packages=["flow_vis_torch"], 6 | version="0.1", 7 | license="MIT", 8 | author="Christoph Reich", 9 | description="Easy optical flow visualisation in Python (PyTorch).", 10 | url="https://github.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch", 11 | keywords=["optical flow", "visualization", "motion"], 12 | install_requires=["torch"], 13 | classifiers=[ 14 | "Programming Language :: Python :: 3", 15 | "License :: OSI Approved :: MIT License", 16 | ], 17 | python_requires=">=3.6", 18 | ) 19 | --------------------------------------------------------------------------------