├── .gitignore
├── LICENSE
├── README.md
├── example.py
├── flow_vis_torch
├── __init__.py
└── flow_vis_torch.py
├── images
├── frame_0005_flow_vis.png
├── frame_0005_flow_vis_torch.png
├── frame_0014_flow_vis.png
├── frame_0014_flow_vis_torch.png
├── frame_0023_flow_vis.png
└── frame_0023_flow_vis_torch.png
├── setup.cfg
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Flow files
2 | *.flo
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
135 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
136 | # Source: https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore
137 |
138 | # User-specific stuff
139 | .idea/**/workspace.xml
140 | .idea/**/tasks.xml
141 | .idea/**/usage.statistics.xml
142 | .idea/**/dictionaries
143 | .idea/**/shelf
144 |
145 | # Generated files
146 | .idea/**/contentModel.xml
147 |
148 | # Sensitive or high-churn files
149 | .idea/**/dataSources/
150 | .idea/**/dataSources.ids
151 | .idea/**/dataSources.local.xml
152 | .idea/**/sqlDataSources.xml
153 | .idea/**/dynamic.xml
154 | .idea/**/uiDesigner.xml
155 | .idea/**/dbnavigator.xml
156 |
157 | # Gradle
158 | .idea/**/gradle.xml
159 | .idea/**/libraries
160 |
161 | # Gradle and Maven with auto-import
162 | # When using Gradle or Maven with auto-import, you should exclude module files,
163 | # since they will be recreated, and may cause churn. Uncomment if using
164 | # auto-import.
165 | # .idea/artifacts
166 | # .idea/compiler.xml
167 | # .idea/jarRepositories.xml
168 | # .idea/modules.xml
169 | # .idea/*.iml
170 | # .idea/modules
171 | # *.iml
172 | # *.ipr
173 |
174 | # CMake
175 | cmake-build-*/
176 |
177 | # Mongo Explorer plugin
178 | .idea/**/mongoSettings.xml
179 |
180 | # File-based project format
181 | *.iws
182 |
183 | # IntelliJ
184 | out/
185 |
186 | # mpeltonen/sbt-idea plugin
187 | .idea_modules/
188 |
189 | # JIRA plugin
190 | atlassian-ide-plugin.xml
191 |
192 | # Cursive Clojure plugin
193 | .idea/replstate.xml
194 |
195 | # Crashlytics plugin (for Android Studio and IntelliJ)
196 | com_crashlytics_export_strings.xml
197 | crashlytics.properties
198 | crashlytics-build.properties
199 | fabric.properties
200 |
201 | # Editor-based Rest Client
202 | .idea/httpRequests
203 |
204 | # Android studio 3.1+ serialized cache file
205 | .idea/caches/build_file_checksums.ser
206 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Christoph Reich
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Optical Flow Visualization for PyTorch
2 |
3 | [License: MIT](https://github.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/blob/master/LICENSE)
4 |
5 | This repository is a PyTorch fork of the [OpticalFlow_Visualization](https://github.com/tomrunia/OpticalFlow_Visualization) (flow_vis) repository, originally published under the [MIT license](https://github.com/tomrunia/OpticalFlow_Visualization/blob/master/LICENSE.txt). The optical flow visualization follows the color encoding proposed in the paper "[A database and evaluation methodology for optical flow](https://link.springer.com/content/pdf/10.1007/s11263-010-0390-2.pdf)" by Baker et al. published at ICCV 2007 [1].
6 |
7 | ## Installation
8 |
9 | Simply run the following command to install `flow_vis_torch`.
10 |
11 | ```shell script
12 | pip install git+https://github.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch
13 | ```
14 |
15 | ## Usage
16 |
17 | Convert a given flow of the shape `[batch size (optional), 2, height, width]` to an RGB image of the shape `[batch size (optional), 3, height, width]` by calling `flow_vis_torch.flow_to_color`.
18 |
19 | ```python
20 | import flow_vis_torch
21 | flow_rgb = flow_vis_torch.flow_to_color(flow)
22 | ```
23 |
24 | For a detailed example have a look at the [example script](example.py).
25 |
26 | ## Visualizations
27 |
28 | Flow maps taken from the [MPI Sintel Flow Dataset](http://sintel.is.tue.mpg.de/) [2].
29 |
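30 | <table>
31 |   <tr>
32 |     <th>Output flow_vis_torch</th>
33 |     <th>Output flow_vis</th>
34 |   </tr>
35 |   <tr>
36 |     <td><img src="images/frame_0005_flow_vis_torch.png" alt="frame 0005, flow_vis_torch"></td>
37 |     <td><img src="images/frame_0005_flow_vis.png" alt="frame 0005, flow_vis"></td>
38 |   </tr>
39 |   <tr>
40 |     <td><img src="images/frame_0014_flow_vis_torch.png" alt="frame 0014, flow_vis_torch"></td>
41 |     <td><img src="images/frame_0014_flow_vis.png" alt="frame 0014, flow_vis"></td>
42 |   </tr>
43 |   <tr>
44 |     <td><img src="images/frame_0023_flow_vis_torch.png" alt="frame 0023, flow_vis_torch"></td>
45 |     <td><img src="images/frame_0023_flow_vis.png" alt="frame 0023, flow_vis"></td>
46 |   </tr>
47 | </table>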
48 |
49 | ## References
50 |
51 | ```bibtex
52 | [1] @inproceedings{Baker2007,
53 | title={{A Database and Evaluation Methodology for Optical Flow}},
54 | author={Baker, Simon and Roth, Stefan and Scharstein, Daniel and Black, Michael J and Lewis, JP and Szeliski, Richard},
55 | booktitle={{International Conference on Computer Vision (ICCV)}},
56 | pages={1--8},
57 | year={2007},
58 | organization={IEEE}
59 | }
60 | ```
61 |
62 | ```bibtex
63 | [2] @inproceedings{Butler2012,
64 | title={{A Naturalistic Open Source Movie for Optical Flow Evaluation}},
65 | author={Butler, Daniel J and Wulff, Jonas and Stanley, Garrett B and Black, Michael J},
66 | booktitle={{European Conference on Computer Vision (ECCV)}},
67 | pages = {611--625},
68 | year = {2012},
69 | publisher={Springer}
70 | }
71 | ```
72 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision
4 | import flow_vis
5 | import flow_vis_torch
6 |
7 |
def load_flo_file(file_path: str) -> torch.Tensor:
    """
    Function loads a Middlebury .flo optical flow file of arbitrary resolution.
    The file starts with a float32 magic number (202021.25), followed by the int32 width and the int32 height.
    The payload consists of height * width interleaved (u, v) float32 pairs stored row by row.
    :param file_path: (str) Path to flo file
    :return: (torch.Tensor) Torch.Tensor of the shape [2, height, width]
    :raises ValueError: If the header is invalid or the file holds fewer flow values than the header announces
    """
    # Open file
    with open(file_path, "rb") as file:
        # Check magic number (.flo files are little-endian)
        magic: np.ndarray = np.fromfile(file, np.dtype("<f4"), count=1)
        if magic.size != 1 or float(magic[0]) != 202021.25:
            raise ValueError("Invalid .flo file (wrong magic number): {}".format(file_path))
        # Read resolution from the header instead of assuming a fixed image size
        size: np.ndarray = np.fromfile(file, np.dtype("<i4"), count=2)
        if size.size != 2 or int(size[0]) <= 0 or int(size[1]) <= 0:
            raise ValueError("Invalid .flo file (bad width/height header): {}".format(file_path))
        width, height = int(size[0]), int(size[1])
        # Load data
        data_array: np.ndarray = np.fromfile(file, np.dtype("<f4"), count=2 * height * width)
    if data_array.size != 2 * height * width:
        raise ValueError("Truncated .flo file, expected {} values but found {}: {}".format(
            2 * height * width, data_array.size, file_path))
    # Reshape data (native byte order is required by torch.from_numpy)
    flow: np.ndarray = data_array.astype(np.float32, copy=False).reshape((height, width, 2))
    # To PyTorch and reshape
    return torch.from_numpy(flow).permute(2, 0, 1)
22 |
23 |
if __name__ == '__main__':
    # Render every example flow map with both packages so the outputs can be compared side by side
    for file in ["frame_0014.flo", "frame_0005.flo", "frame_0023.flo"]:
        # Load flow maps
        flow = load_flo_file(file_path=file)
        # Standard package (gets the flow as a channel-last NumPy array)
        flow_rgb_flow_vis = torch.from_numpy(
            flow_vis.flow_to_color(flow.clone().permute(1, 2, 0).numpy()).astype(float)).permute(2, 0, 1)
        torchvision.utils.save_image(flow_rgb_flow_vis.float(), file.replace(".flo", "_flow_vis.png"), normalize=True)
        # PyTorch version
        flow_rgb_flow_vis_torch = flow_vis_torch.flow_to_color(flow)
        # Save the PyTorch result (previously the NumPy result was written to this file by mistake)
        torchvision.utils.save_image(flow_rgb_flow_vis_torch.float(), file.replace(".flo", "_flow_vis_torch.png"),
                                     normalize=True)
36 |
--------------------------------------------------------------------------------
/flow_vis_torch/__init__.py:
--------------------------------------------------------------------------------
1 | from .flow_vis_torch import flow_to_color
2 |
--------------------------------------------------------------------------------
/flow_vis_torch/flow_vis_torch.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import torch
4 | from math import pi as PI
5 |
6 |
def get_color_wheel(device: torch.device) -> torch.Tensor:
    """
    Builds the flow color wheel out of six consecutive hue transitions.
    :param device: (torch.device) Device the resulting tensor is moved to
    :return: (torch.Tensor) Color wheel tensor of the shape [55, 3] with values in [0, 255]
    """
    # Each transition: (number of steps, channel held at 255, channel that ramps, ramp runs downwards)
    transitions = (
        (15, 0, 1, False),  # RY
        (6, 1, 0, True),  # YG
        (4, 1, 2, False),  # GC
        (11, 2, 1, True),  # CB
        (13, 2, 0, False),  # BM
        (6, 0, 2, True),  # MR
    )
    total_steps = sum(steps for steps, _, _, _ in transitions)
    wheel = torch.zeros((total_steps, 3), dtype=torch.float32)
    start = 0
    for steps, full_channel, ramp_channel, descending in transitions:
        stop = start + steps
        # Quantized linear ramp 0 -> 255 over the transition (255 itself is never reached)
        ramp = torch.floor(255 * torch.arange(0, steps) / steps)
        wheel[start:stop, full_channel] = 255
        wheel[start:stop, ramp_channel] = (255 - ramp) if descending else ramp
        start = stop
    return wheel.to(device)
50 |
51 |
52 | def _flow_hw_to_color(flow_vertical: torch.Tensor, flow_horizontal: torch.Tensor,
53 | color_wheel: torch.Tensor, device: torch.device) -> torch.Tensor:
54 | """
55 | Private function applies the flow color wheel to flow components (vertical and horizontal).
56 | :param flow_vertical: (torch.Tensor) Vertical flow of the shape [height, width]
57 | :param flow_horizontal: (torch.Tensor) Horizontal flow of the shape [height, width]
58 | :param color_wheel: (torch.Tensor) Color wheel tensor of the shape [55, 3]
59 | :param: device: (torch.device) Device to be used
60 | :return: (torch.Tensor) Visualized flow of the shape [3, height, width]
61 | """
62 | # Get shapes
63 | _, height, width = flow_vertical.shape
64 | # Init flow image
65 | flow_image: torch.Tensor = torch.zeros(3, height, width, dtype=torch.float32, device=device)
66 | # Get number of colors
67 | number_of_colors: int = color_wheel.shape[0]
68 | # Compute norm, angle and factors
69 | flow_norm: torch.Tensor = (flow_vertical ** 2 + flow_horizontal ** 2).sqrt()
70 | angle: torch.Tensor = torch.atan2(- flow_vertical, - flow_horizontal) / PI
71 | fk: torch.Tensor = (angle + 1.) / 2. * (number_of_colors - 1.)
72 | k0: torch.Tensor = torch.floor(fk).long()
73 | k1: torch.Tensor = k0 + 1
74 | k1[k1 == number_of_colors] = 0
75 | f: torch.Tensor = fk - k0
76 | # Iterate over color components
77 | for index in range(color_wheel.shape[1]):
78 | # Get component of all colors
79 | tmp: torch.Tensor = color_wheel[:, index]
80 | # Get colors
81 | color_0: torch.Tensor = tmp[k0] / 255.
82 | color_1: torch.Tensor = tmp[k1] / 255.
83 | # Compute color
84 | color: torch.Tensor = (1. - f) * color_0 + f * color_1
85 | # Get color index
86 | color_index: torch.Tensor = flow_norm <= 1
87 | # Set color saturation
88 | color[color_index] = 1 - flow_norm[color_index] * (1. - color[color_index])
89 | color[~color_index] = color[~color_index] * 0.75
90 | # Set color in image
91 | flow_image[index] = torch.floor(255 * color)
92 | return flow_image
93 |
94 |
def flow_to_color(flow: torch.Tensor, clip_flow: Optional[Union[float, torch.Tensor]] = None,
                  normalize_over_video: bool = False) -> torch.Tensor:
    """
    Function converts a given optical flow map into the classical color schema.
    Channel 0 of the flow is interpreted as the horizontal (u) and channel 1 as the vertical (v) component,
    which is the channel order of .flo files and of the flow_vis reference implementation.
    :param flow: (torch.Tensor) Optical flow tensor of the shape [batch size (optional), 2, height, width].
    :param clip_flow: (Optional[Union[float, torch.Tensor]]) Max value of flow values for clipping (default None).
    :param normalize_over_video: (bool) If true scale is normalized over the whole video (batch).
    :return: (torch.Tensor) Flow visualization (float tensor) with the shape [batch size (if used), 3, height, width].
    """
    # Check parameter types (plain Python ints are accepted as clipping value, too)
    assert torch.is_tensor(flow), "Given flow map must be a torch.Tensor, {} given".format(type(flow))
    assert torch.is_tensor(clip_flow) or isinstance(clip_flow, (int, float)) or clip_flow is None, \
        "Given clip_flow parameter must be a float, a torch.Tensor, or None, {} given".format(type(clip_flow))
    # Check shapes
    assert flow.ndimension() in [3, 4], \
        "Given flow must be a 3D or 4D tensor, given tensor shape {}.".format(flow.shape)
    if torch.is_tensor(clip_flow):
        assert clip_flow.ndimension() == 0, \
            "Given clip_flow tensor must be a scalar, given tensor shape {}.".format(clip_flow.shape)
    # Manage batch dimension
    batch_dimension: bool = True
    if flow.ndimension() == 3:
        flow = flow[None]
        batch_dimension = False
    # Save shape
    batch_size, _, height, width = flow.shape
    # Check flow dimension
    assert flow.shape[1] == 2, "Flow dimension must have the shape 2 but tensor with {} given".format(flow.shape[1])
    # Save device
    device: torch.device = flow.device
    # Clip flow if utilized (only an upper bound is applied)
    if clip_flow is not None:
        flow = flow.clip(max=clip_flow)
    # Get horizontal (u, channel 0) and vertical (v, channel 1) flow, singleton channel dimension is kept
    flow_horizontal: torch.Tensor = flow[:, 0:1]
    flow_vertical: torch.Tensor = flow[:, 1:2]
    # Get max norm of flow for each batch element
    flow_max_norm: torch.Tensor = (flow_vertical ** 2 + flow_horizontal ** 2).sqrt().view(batch_size, -1).max(dim=-1)[0]
    flow_max_norm = flow_max_norm.view(batch_size, 1, 1, 1)
    if normalize_over_video:
        # Use one common scale for all batch elements
        flow_max_norm = flow_max_norm.max(dim=0, keepdim=True)[0]
    # Normalize flow (small constant avoids a division by zero for all-zero flow maps)
    flow_vertical = flow_vertical / (flow_max_norm + 1e-05)
    flow_horizontal = flow_horizontal / (flow_max_norm + 1e-05)
    # Get color wheel
    color_wheel: torch.Tensor = get_color_wheel(device=device)
    # Init flow image
    flow_image = torch.zeros(batch_size, 3, height, width, device=device)
    # Iterate over batch dimension
    for index in range(batch_size):
        flow_image[index] = _flow_hw_to_color(flow_vertical=flow_vertical[index],
                                              flow_horizontal=flow_horizontal[index], color_wheel=color_wheel,
                                              device=device)
    return flow_image if batch_dimension else flow_image[0]
149 |
--------------------------------------------------------------------------------
/images/frame_0005_flow_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/9177370c7c00b4b7dbe4deda6fed734fdff48b2c/images/frame_0005_flow_vis.png
--------------------------------------------------------------------------------
/images/frame_0005_flow_vis_torch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/9177370c7c00b4b7dbe4deda6fed734fdff48b2c/images/frame_0005_flow_vis_torch.png
--------------------------------------------------------------------------------
/images/frame_0014_flow_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/9177370c7c00b4b7dbe4deda6fed734fdff48b2c/images/frame_0014_flow_vis.png
--------------------------------------------------------------------------------
/images/frame_0014_flow_vis_torch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/9177370c7c00b4b7dbe4deda6fed734fdff48b2c/images/frame_0014_flow_vis_torch.png
--------------------------------------------------------------------------------
/images/frame_0023_flow_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/9177370c7c00b4b7dbe4deda6fed734fdff48b2c/images/frame_0023_flow_vis.png
--------------------------------------------------------------------------------
/images/frame_0023_flow_vis_torch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch/9177370c7c00b4b7dbe4deda6fed734fdff48b2c/images/frame_0023_flow_vis_torch.png
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
# Distribution metadata for the flow_vis_torch package, collected in one place and handed to setuptools below
_PACKAGE_METADATA = {
    # Identity
    "name": "flow_vis_torch",
    "version": "0.1",
    "author": "Christoph Reich",
    "license": "MIT",
    # Description shown on package indexes
    "description": "Easy optical flow visualisation in Python (PyTorch).",
    "url": "https://github.com/ChristophReich1996/Optical-Flow-Visualization-PyTorch",
    "keywords": ["optical flow", "visualization", "motion"],
    "classifiers": [
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
    ],
    # Contents and requirements
    "packages": ["flow_vis_torch"],
    "install_requires": ["torch"],
    "python_requires": ">=3.6",
}

setuptools.setup(**_PACKAGE_METADATA)
19 |
--------------------------------------------------------------------------------