├── .gitignore
├── CITATION.cff
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── examples
│   ├── basic_dice_cubemap.png
│   ├── basic_equirectangular.png
│   ├── example_world_map_dice_cubemap.png
│   ├── example_world_map_equirectangular.png
│   ├── example_world_map_equirectangular_rotated.png
│   ├── example_world_map_horizon_cubemap.png
│   └── example_world_map_perspective.png
├── pytorch360convert
│   ├── __init__.py
│   ├── pytorch360convert.py
│   └── version.py
├── setup.py
└── tests
    ├── __init__.py
    └── test_module.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | abstract: "Utilities for converting between different cubemap, equirectangular, and perspective panoramic image formats."
2 | authors:
3 | - family-names: Egan
4 | given-names: Ben
5 | cff-version: 1.2.0
6 | date-released: "2024-12-15"
7 | keywords:
8 | - equirectangular
9 | - panorama
10 | - "360 degrees"
11 | - "360 degree images"
12 | - cubemap
13 | - research
14 | license: MIT
15 | message: "If you use this software, please cite it using these metadata."
16 | repository-code: "https://github.com/ProGamerGov/pytorch360convert"
17 | title: "pytorch360convert"
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to pytorch360convert
2 |
3 | This project follows simple linting and testing guidelines, described below.
4 |
5 | ## Linting
6 |
7 |
8 | Linting is simple to perform. First, install the tools:
9 |
10 | ```
11 | pip install black flake8 mypy ufmt pytest-cov
12 |
13 | ```
14 |
15 | Linting:
16 |
17 | ```
18 | cd pytorch360convert
19 | black .
20 | ufmt format .
21 | cd ..
22 | ```
23 |
24 | Checking:
25 |
26 | ```
27 | cd pytorch360convert
28 | black --check --diff .
29 | flake8 . --ignore=E203,W503 --max-line-length=88 --exclude build,dist
30 | ufmt check .
31 | mypy . --ignore-missing-imports --allow-redefinition --explicit-package-bases
32 | cd ..
33 | ```
34 |
35 |
36 | ## Testing
37 |
38 | Tests can be run like this:
39 |
40 | ```
41 | pip install pytest pytest-cov
42 | ```
43 |
44 | ```
45 | cd pytorch360convert
46 | pytest -ra --cov=. --cov-report term-missing
47 | cd ..
48 | ```
49 |
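To run a single test during development, standard `pytest` selection works (a sketch; substitute the name of the test you want):

```
pytest tests/test_module.py -k "name_of_test"
```
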
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 sunset
4 |
5 | Copyright (c) 2024 Ben Egan
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 📷 PyTorch 360° Image Conversion Toolkit
2 |
3 | [](https://pypi.org/project/pytorch360convert/)
4 |
5 |
6 | ## Overview
7 |
8 | This PyTorch-based library provides powerful and differentiable image transformation utilities for converting between different panoramic image formats:
9 |
10 | - **Equirectangular (360°) Images**
11 | - **Cubemap Representations**
12 | - **Perspective Projections**
13 |
14 | Built as an improved PyTorch implementation of the original [py360convert](https://github.com/sunset1995/py360convert) project, this library offers flexible, CPU & GPU-accelerated functions.
15 |
16 |
17 |
18 | 
19 |
20 |
21 | * Equirectangular format
22 |
23 |
24 |
25 | 
26 |
27 |
28 | * Cubemap 'dice' format
29 |
30 |
31 | ## 🔧 Requirements
32 |
33 | - Python 3.7+
34 | - [PyTorch](https://pytorch.org/)
35 |
36 |
37 | ## 📦 Installation
38 |
39 | You can easily install the library using pip:
40 |
41 | ```bash
42 | pip install pytorch360convert
43 | ```
44 |
45 | Or you can install it from source. First, install PyTorch:
46 |
47 | ```bash
48 | pip install torch
49 | ```
50 |
51 | Then clone the repository:
52 |
53 | ```bash
54 | git clone https://github.com/ProGamerGov/pytorch360convert.git
55 | cd pytorch360convert
56 | pip install .
57 | ```
58 |
59 |
60 | ## 🚀 Key Features
61 |
62 | - Lossless conversion between image formats.
63 | - Supports different cubemap input formats (horizon, list, stack, dict, dice).
64 | - Configurable sampling modes (bilinear, nearest).
65 | - Supports different dtypes (float16, float32, float64, bfloat16).
66 | - CPU support.
67 | - GPU acceleration.
68 | - Differentiable transformations for deep learning pipelines.
69 | - [TorchScript](https://pytorch.org/docs/stable/jit.html) (JIT) support.
70 |
71 |
72 | ## 💡 Usage Examples
73 |
74 |
75 | ### Helper Functions
76 |
77 | First we'll setup some helper functions:
78 |
79 | ```bash
80 | pip install torchvision pillow
81 | ```
82 |
83 |
84 | ```python
85 | import torch
86 | from torchvision.transforms import ToTensor, ToPILImage
87 | from PIL import Image
88 |
89 | def load_image_to_tensor(image_path: str) -> torch.Tensor:
90 | """Load an image as a PyTorch tensor."""
91 | return ToTensor()(Image.open(image_path).convert('RGB'))
92 |
93 | def save_tensor_as_image(tensor: torch.Tensor, save_path: str) -> None:
94 | """Save a PyTorch tensor as an image."""
95 | ToPILImage()(tensor).save(save_path)
96 |
97 | ```
98 |
99 | ### Equirectangular to Cubemap Conversion
100 |
101 | Converting equirectangular images into cubemaps is easy. For simplicity, we'll use the 'dice' format, which places all cube faces into a single 4x3 grid image.
102 |
103 | ```python
104 | from pytorch360convert import e2c
105 |
106 | # Load equirectangular image (3, 1376, 2752)
107 | equi_image = load_image_to_tensor("examples/example_world_map_equirectangular.png")
108 | face_w = equi_image.shape[2] // 4 # 2752 / 4 = 688
109 |
110 | # Convert to cubemap (dice format)
111 | cubemap = e2c(
112 | equi_image, # CHW format
113 | face_w=face_w, # Width of each cube face
114 | mode='bilinear', # Sampling interpolation
115 | cube_format='dice' # Output cubemap layout
116 | )
117 |
118 | # Save cubemap faces
119 | save_tensor_as_image(cubemap, "dice_cubemap.jpg")
120 | ```
121 |
122 | | Equirectangular Input | Cubemap 'Dice' Output |
123 | | :---: | :----: |
124 | |  |  |
125 |
126 | | Cubemap 'Horizon' Output |
127 | | :---: |
128 | |  |
129 |
130 | ### Cubemap to Equirectangular Conversion
131 |
132 | We can also convert cubemaps into equirectangular images, like so.
133 |
134 | ```python
135 | from pytorch360convert import c2e
136 |
137 | # Load cubemap in 'dice' format
138 | cubemap = load_image_to_tensor("dice_cubemap.jpg")
139 |
140 | # Convert cubemap back to equirectangular
141 | equirectangular = c2e(
142 | cubemap, # Cubemap tensor(s)
143 | mode='bilinear', # Sampling interpolation
144 | cube_format='dice' # Input cubemap layout
145 | )
146 |
147 | save_tensor_as_image(equirectangular, "equirectangular.jpg")
148 | ```
149 |
150 | ### Equirectangular to Perspective Projection
151 |
152 | ```python
153 | from pytorch360convert import e2p
154 |
155 | # Load equirectangular input
156 | equi_image = load_image_to_tensor("examples/example_world_map_equirectangular.png")
157 |
158 | # Extract perspective view from equirectangular image
159 | perspective_view = e2p(
160 | equi_image, # Equirectangular image
161 | fov_deg=(70, 60), # Horizontal and vertical FOV
162 | h_deg=260, # Horizontal rotation
163 | v_deg=50, # Vertical rotation
164 | out_hw=(512, 768), # Output image dimensions
165 | mode='bilinear' # Sampling interpolation
166 | )
167 |
168 | save_tensor_as_image(perspective_view, "perspective.jpg")
169 | ```
170 |
171 | | Equirectangular Input | Perspective Output |
172 | | :---: | :----: |
173 | |  |  |
174 |
175 |
176 |
177 | ### Equirectangular to Equirectangular
178 |
179 | ```python
180 | from pytorch360convert import e2e
181 |
182 | # Load equirectangular input
183 | equi_image = load_image_to_tensor("examples/example_world_map_equirectangular.png")
184 |
185 | # Rotate an equirectangular image around one or more axes
186 | rotated_equi = e2e(
187 | equi_image, # Equirectangular image
188 |     h_deg=90.0,       # Horizontal rotation/shift
189 |     v_deg=200.0,      # Vertical rotation/shift
190 | roll=45.0, # Clockwise/counter clockwise rotation
191 | mode='bilinear' # Sampling interpolation
192 | )
193 |
194 | save_tensor_as_image(rotated_equi, "rotated.jpg")
195 | ```
196 |
197 | | Equirectangular Input | Rotated Output |
198 | | :---: | :----: |
199 | |  |  |
200 |
201 |
202 | ## 📚 Basic Functions
203 |
204 | ### `e2c(e_img, face_w=None, mode='bilinear', cube_format='dice', channels_first=True)`
205 | Converts an equirectangular image to a cubemap projection.
206 |
207 | - **Parameters**:
208 | - `e_img` (torch.Tensor): Equirectangular CHW image tensor.
209 |   - `face_w` (int, optional): Cube face width. If set to `None`, `face_w` will be calculated as `H // 2`, where `H` is the input image height. Default: `None`.
210 | - `mode` (str, optional): Sampling interpolation mode. Options are `bilinear`, `bicubic`, and `nearest`. Default: `bilinear`
211 | - `cube_format` (str, optional): The desired output cubemap format. Options are `dict`, `list`, `horizon`, `stack`, and `dice`. Default: `dice`
212 | - `stack` (torch.Tensor): Stack of 6 faces, in the order of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
213 | - `list` (list of torch.Tensor): List of 6 faces, in the order of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
214 | - `dict` (dict of torch.Tensor): Dictionary with keys pointing to face tensors. Keys are: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
215 | - `dice` (torch.Tensor): A cubemap in a 'dice' layout.
216 | - `horizon` (torch.Tensor): A cubemap in a 'horizon' layout, a 1x6 grid in the order: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
217 |   - `channels_first` (bool, optional): Input image and output cubemap channel format (CHW or HWC). Defaults to the PyTorch CHW standard of `True`.
218 |
219 | - **Returns**: Cubemap representation of the input image as a tensor, list of tensors, or dict of tensors.
220 |
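For example, a minimal sketch of the `dict` output format (reusing the `equi_image` tensor loaded in the usage examples above):

```python
from pytorch360convert import e2c

faces = e2c(equi_image, face_w=256, cube_format='dict')
front = faces['Front']  # CHW tensor of shape (3, 256, 256)
```
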
221 | ### `c2e(cubemap, h=None, w=None, mode='bilinear', cube_format='dice', channels_first=True)`
222 | Converts a cubemap projection to an equirectangular image.
223 |
224 | - **Parameters**:
225 | - `cubemap` (torch.Tensor, list of torch.Tensor, or dict of torch.Tensor): Cubemap image tensor, list of tensors, or dict of tensors. Note that tensors should be in the shape of: `CHW`, except for when `cube_format = 'stack'`, in which case a batch dimension is present. Inputs should match the corresponding `cube_format`.
226 |   - `h` (int, optional): Output image height. If set to `None`, `face_w * 2` will be used. Default: `None`.
227 |   - `w` (int, optional): Output image width. If set to `None`, `face_w * 4` will be used. Default: `None`.
228 | - `mode` (str, optional): Sampling interpolation mode. Options are `bilinear`, `bicubic`, and `nearest`. Default: `bilinear`
229 | - `cube_format` (str, optional): Input cubemap format. Options are `dict`, `list`, `horizon`, `stack`, and `dice`. Default: `dice`
230 | - `stack` (torch.Tensor): Stack of 6 faces, in the order of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
231 | - `list` (list of torch.Tensor): List of 6 faces, in the order of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
232 | - `dict` (dict of torch.Tensor): Dictionary with keys pointing to face tensors. Keys are expected to be: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
233 | - `dice` (torch.Tensor): A cubemap in a 'dice' layout.
234 | - `horizon` (torch.Tensor): A cubemap in a 'horizon' layout, a 1x6 grid in the order of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
235 | - `channels_first` (bool, optional): Input cubemap channel format (CHW or HWC). Defaults to the PyTorch CHW standard of `True`.
236 |
237 | - **Returns**: Equirectangular projection of the input cubemap as a tensor.
238 |
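As a round-trip sketch with the `stack` format (shapes assume the `face_w=256` example above):

```python
from pytorch360convert import e2c, c2e

stack = e2c(equi_image, face_w=256, cube_format='stack')  # (6, 3, 256, 256)
equi_back = c2e(stack, h=512, w=1024, cube_format='stack')  # (3, 512, 1024)
```
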
239 | ### `e2p(e_img, fov_deg, h_deg=0.0, v_deg=0.0, out_hw=(512, 512), in_rot_deg=0.0, mode='bilinear', channels_first=True)`
240 | Extracts a perspective view from an equirectangular image.
241 |
242 | - **Parameters**:
243 | - `e_img` (torch.Tensor): Equirectangular CHW or NCHW image tensor.
244 |   - `fov_deg` (float or tuple of float): Field of view in degrees. If a single value is provided, it will be used for both the horizontal and vertical FOV. If using a tuple, values are expected to be in the following format: (h_fov_deg, v_fov_deg).
245 |   - `h_deg` (float, optional): Horizontal viewing angle in degrees (-Left/+Right). Default: `0.0`
246 |   - `v_deg` (float, optional): Vertical viewing angle in degrees (-Down/+Up). Default: `0.0`
247 |   - `out_hw` (int or tuple of int, optional): Output image dimensions in the shape of `(height, width)`. Default: `(512, 512)`
248 |   - `in_rot_deg` (float, optional): In-plane rotation angle in degrees. Default: `0.0`
249 | - `mode` (str, optional): Sampling interpolation mode. Options are `bilinear`, `bicubic`, and `nearest`. Default: `bilinear`
250 |   - `channels_first` (bool, optional): Input image channel format (CHW or HWC). Defaults to the PyTorch CHW standard of `True`.
251 |
252 | - **Returns**: Perspective view of the equirectangular image as a tensor.
253 |
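When `fov_deg` is a single float, it is used for both axes; a minimal sketch:

```python
from pytorch360convert import e2p

view = e2p(equi_image, fov_deg=90.0, out_hw=(256, 256))  # 90° horizontal and vertical FOV
```
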
254 | ### `e2e(e_img, h_deg=0.0, v_deg=0.0, roll=0.0, mode='bilinear', channels_first=True)`
255 |
256 | Rotate an equirectangular image along one or more axes (roll, pitch, and yaw) to produce a horizontal shift, vertical shift, or to roll the image.
257 |
258 | - **Parameters**:
259 | - `e_img` (torch.Tensor): Equirectangular CHW or NCHW image tensor.
260 | - `roll` (float, optional): Roll angle in degrees (-Counter_Clockwise/+Clockwise). Rotates the image along the x-axis. Default: `0.0`
261 | - `h_deg` (float, optional): Yaw angle in degrees (-Left/+Right). Rotates the image along the z-axis to produce a horizontal shift. Default: `0.0`
262 | - `v_deg` (float, optional): Pitch angle in degrees (-Down/+Up). Rotates the image along the y-axis to produce a vertical shift. Default: `0.0`
263 | - `mode` (str, optional): Sampling interpolation mode. Options are `bilinear`, `bicubic`, and `nearest`. Default: `bilinear`
264 |   - `channels_first` (bool, optional): Input image channel format (CHW or HWC). Defaults to the PyTorch CHW standard of `True`.
265 |
266 | - **Returns**: A modified equirectangular image tensor.
267 |
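Since the transforms are differentiable, gradients flow back through them to the input; a minimal sketch with a random tensor standing in for a real image:

```python
import torch
from pytorch360convert import e2e

x = torch.rand(3, 64, 128, requires_grad=True)
y = e2e(x, h_deg=30.0, v_deg=0.0)
y.sum().backward()
print(x.grad.shape)  # torch.Size([3, 64, 128])
```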
268 |
269 | ## 🤝 Contributing
270 |
271 | Contributions are welcome! Please feel free to submit a Pull Request.
272 |
273 |
274 | ## 🔬 Citation
275 |
276 | If you use this library in your research or project, please refer to the included [CITATION.cff](CITATION.cff) file or cite it as follows:
277 |
278 | ### BibTeX
279 | ```bibtex
280 | @misc{egan2024pytorch360convert,
281 | title={PyTorch 360° Image Conversion Toolkit},
282 | author={Egan, Ben},
283 | year={2024},
284 | publisher={GitHub},
285 | howpublished={\url{https://github.com/ProGamerGov/pytorch360convert}}
286 | }
287 | ```
288 |
289 | ### APA Style
290 | ```
291 | Egan, B. (2024). PyTorch 360° Image Conversion Toolkit [Computer software]. GitHub. https://github.com/ProGamerGov/pytorch360convert
292 | ```
293 |
--------------------------------------------------------------------------------
/examples/basic_dice_cubemap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/basic_dice_cubemap.png
--------------------------------------------------------------------------------
/examples/basic_equirectangular.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/basic_equirectangular.png
--------------------------------------------------------------------------------
/examples/example_world_map_dice_cubemap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/example_world_map_dice_cubemap.png
--------------------------------------------------------------------------------
/examples/example_world_map_equirectangular.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/example_world_map_equirectangular.png
--------------------------------------------------------------------------------
/examples/example_world_map_equirectangular_rotated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/example_world_map_equirectangular_rotated.png
--------------------------------------------------------------------------------
/examples/example_world_map_horizon_cubemap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/example_world_map_horizon_cubemap.png
--------------------------------------------------------------------------------
/examples/example_world_map_perspective.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/example_world_map_perspective.png
--------------------------------------------------------------------------------
/pytorch360convert/__init__.py:
--------------------------------------------------------------------------------
1 | from pytorch360convert.version import __version__ # noqa
2 | from .pytorch360convert import (
3 | c2e,
4 | cube_dice2h,
5 | cube_dict2h,
6 | cube_h2dice,
7 | cube_h2dict,
8 | cube_h2list,
9 | cube_list2h,
10 | e2c,
11 | e2e,
12 | e2p,
13 | pad_180_to_360,
14 | )
15 |
--------------------------------------------------------------------------------
/pytorch360convert/pytorch360convert.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def rotation_matrix(rad: torch.Tensor, ax: torch.Tensor) -> torch.Tensor:
8 | """
9 | Create a rotation matrix for a given angle and axis.
10 |
11 | Args:
12 | rad (torch.Tensor): Rotation angle in radians.
13 | ax (torch.Tensor): Rotation axis vector.
14 |
15 | Returns:
16 | torch.Tensor: 3x3 rotation matrix.
17 | """
18 | ax = ax / torch.sqrt((ax**2).sum())
19 | c = torch.cos(rad)
20 | s = torch.sin(rad)
21 | R = torch.diag(torch.cat([c, c, c]))
22 | R = R + (1.0 - c) * torch.ger(ax, ax)
23 | K = torch.stack(
24 | [
25 | torch.stack(
26 | [torch.tensor(0.0, device=ax.device, dtype=ax.dtype), -ax[2], ax[1]]
27 | ),
28 | torch.stack(
29 | [ax[2], torch.tensor(0.0, device=ax.device, dtype=ax.dtype), -ax[0]]
30 | ),
31 | torch.stack(
32 | [-ax[1], ax[0], torch.tensor(0.0, device=ax.device, dtype=ax.dtype)]
33 | ),
34 | ],
35 | dim=0,
36 | )
37 | R = R + K * s
38 | return R
39 |
40 |
41 | def _nhwc2nchw(x: torch.Tensor) -> torch.Tensor:
42 | """
43 | Convert NHWC to NCHW or HWC to CHW format.
44 |
45 | Args:
46 | x (torch.Tensor): Input tensor to be converted, either in NCHW or CHW format.
47 |
48 | Returns:
49 | torch.Tensor: The converted tensor in NCHW or CHW format.
50 | """
51 | assert x.dim() == 3 or x.dim() == 4
52 | return x.permute(2, 0, 1) if x.dim() == 3 else x.permute(0, 3, 1, 2)
53 |
54 |
55 | def _nchw2nhwc(x: torch.Tensor) -> torch.Tensor:
56 | """
57 | Convert NCHW to NHWC or CHW to HWC format.
58 |
59 | Args:
60 | x (torch.Tensor): Input tensor to be converted, either in NCHW or CHW format.
61 |
62 | Returns:
63 | torch.Tensor: The converted tensor in NHWC or HWC format.
64 | """
65 | assert x.dim() == 3 or x.dim() == 4
66 | return x.permute(1, 2, 0) if x.dim() == 3 else x.permute(0, 2, 3, 1)
67 |
68 |
69 | def _slice_chunk(
70 | index: int, width: int, offset: int = 0, device: torch.device = torch.device("cpu")
71 | ) -> torch.Tensor:
72 | """
73 | Generate a tensor of indices for a chunk of values.
74 |
75 | Args:
76 | index (int): The starting index for the chunk.
77 | width (int): The number of indices in the chunk.
78 | offset (int, optional): An offset added to the starting index, default is 0.
79 | device (torch.device, optional): The device for the tensor.
80 | Default: torch.device('cpu')
81 |
82 | Returns:
83 | torch.Tensor: A tensor containing the indices for the chunk.
84 | """
85 | start = index * width + offset
86 | # Create a tensor of indices instead of using slice
87 | return torch.arange(start, start + width, dtype=torch.long, device=device)
88 |
89 |
90 | def _face_slice(
91 | index: int, face_w: int, device: torch.device = torch.device("cpu")
92 | ) -> torch.Tensor:
93 | """
94 | Generate a slice of indices based on the face width.
95 |
96 | Args:
97 | index (int): The starting index.
98 | face_w (int): The width of the face (number of indices).
99 | device (torch.device, optional): The device for the tensor.
100 | Default: torch.device('cpu')
101 |
102 | Returns:
103 | torch.Tensor: A tensor containing the slice of indices.
104 | """
105 | return _slice_chunk(index, face_w, device=device)
106 |
107 |
108 | def xyzcube(
109 | face_w: int,
110 | device: torch.device = torch.device("cpu"),
111 | dtype: torch.dtype = torch.float32,
112 | ) -> torch.Tensor:
113 | """
114 | Generate cube coordinates for equirectangular projection.
115 |
116 | Args:
117 | face_w (int): Width of each cube face.
118 | device (torch.device, optional): Device to create tensor on.
119 | Default: torch.device('cpu')
120 | dtype (torch.dtype, optional): Data type of the tensor.
121 | Default: torch.float32
122 |
123 | Returns:
124 | torch.Tensor: Cube coordinates tensor of shape (face_w, face_w * 6, 3).
125 | """
126 | out = torch.empty((face_w, face_w * 6, 3), dtype=dtype, device=device)
127 | rng = torch.linspace(-0.5, 0.5, steps=face_w, dtype=dtype, device=device)
128 | x, y = torch.meshgrid(rng, -rng, indexing="xy") # shape (face_w, face_w)
129 |
130 | # Pre-compute flips
131 | x_flip = torch.flip(x, [1])
132 | y_flip = torch.flip(y, [0])
133 |
134 | # Front face (z = 0.5)
135 | front_indices = _face_slice(0, face_w, device)
136 | out[:, front_indices, 0] = x
137 | out[:, front_indices, 1] = y
138 | out[:, front_indices, 2] = 0.5
139 |
140 | # Right face (x = 0.5)
141 | right_indices = _face_slice(1, face_w, device)
142 | out[:, right_indices, 0] = 0.5
143 | out[:, right_indices, 1] = y
144 | out[:, right_indices, 2] = x_flip
145 |
146 | # Back face (z = -0.5)
147 | back_indices = _face_slice(2, face_w, device)
148 | out[:, back_indices, 0] = x_flip
149 | out[:, back_indices, 1] = y
150 | out[:, back_indices, 2] = -0.5
151 |
152 | # Left face (x = -0.5)
153 | left_indices = _face_slice(3, face_w, device)
154 | out[:, left_indices, 0] = -0.5
155 | out[:, left_indices, 1] = y
156 | out[:, left_indices, 2] = x
157 |
158 | # Up face (y = 0.5)
159 | up_indices = _face_slice(4, face_w, device)
160 | out[:, up_indices, 0] = x
161 | out[:, up_indices, 1] = 0.5
162 | out[:, up_indices, 2] = y_flip
163 |
164 | # Down face (y = -0.5)
165 | down_indices = _face_slice(5, face_w, device)
166 | out[:, down_indices, 0] = x
167 | out[:, down_indices, 1] = -0.5
168 | out[:, down_indices, 2] = y
169 |
170 | return out
171 |
172 |
173 | def equirect_uvgrid(
174 | h: int,
175 | w: int,
176 | device: torch.device = torch.device("cpu"),
177 | dtype: torch.dtype = torch.float32,
178 | ) -> torch.Tensor:
179 | """
180 | Generate UV grid for equirectangular projection.
181 |
182 | Args:
183 | h (int): Height of the grid.
184 | w (int): Width of the grid.
185 | device (torch.device, optional): Device to create tensor on.
186 | Default: torch.device('cpu')
187 | dtype (torch.dtype, optional): Data type of the tensor.
188 | Default: torch.float32
189 |
190 | Returns:
191 | torch.Tensor: UV grid of shape (h, w, 2).
192 | """
193 | u = torch.linspace(-torch.pi, torch.pi, steps=w, dtype=dtype, device=device)
194 | v = torch.linspace(torch.pi, -torch.pi, steps=h, dtype=dtype, device=device) / 2
195 | grid_v, grid_u = torch.meshgrid(v, u, indexing="ij")
196 | uv = torch.stack([grid_u, grid_v], dim=-1)
197 | return uv
198 |
199 |
200 | def equirect_facetype(
201 | h: int,
202 | w: int,
203 | device: torch.device = torch.device("cpu"),
204 | dtype: torch.dtype = torch.float32,
205 | ) -> torch.Tensor:
206 | """
207 | Determine face types for equirectangular projection.
208 |
209 | Args:
210 | h (int): Height of the grid.
211 | w (int): Width of the grid.
212 | device (torch.device, optional): Device to create tensor on.
213 | Default: torch.device('cpu')
214 | dtype (torch.dtype, optional): Data type of the tensor.
215 | Default: torch.float32
216 |
217 | Returns:
218 | torch.Tensor: Face type tensor of shape (h, w) with integer face
219 | indices.
220 | """
221 | tp = (
222 | torch.arange(4, device=device)
223 | .repeat_interleave(w // 4)
224 | .unsqueeze(0)
225 | .repeat(h, 1)
226 | )
227 | tp = torch.roll(tp, shifts=3 * (w // 8), dims=1)
228 |
229 | # Prepare ceil mask
230 | mask = torch.zeros((h, w // 4), dtype=torch.bool, device=device)
231 | idx = torch.linspace(-torch.pi, torch.pi, w // 4, device=device, dtype=dtype) / 4
232 | idx = torch.round(h / 2 - torch.atan(torch.cos(idx)) * h / torch.pi).to(torch.long)
233 | for i, j in enumerate(idx):
234 | mask[:j, i] = True
235 | mask = torch.roll(torch.cat([mask] * 4, dim=1), shifts=3 * (w // 8), dims=1)
236 |
237 | tp[mask] = 4
238 | tp[torch.flip(mask, [0])] = 5
239 | return tp
240 |
241 |
242 | def xyzpers(
243 | h_fov: float,
244 | v_fov: float,
245 | u: float,
246 | v: float,
247 | out_hw: Tuple[int, int],
248 | in_rot: float,
249 | device: torch.device = torch.device("cpu"),
250 | dtype: torch.dtype = torch.float32,
251 | ) -> torch.Tensor:
252 | """
253 | Generate perspective projection coordinates.
254 |
255 | Args:
256 |         h_fov (float): Horizontal field of view in radians.
257 |         v_fov (float): Vertical field of view in radians.
258 |         u (float): Horizontal rotation angle in radians.
259 |         v (float): Vertical rotation angle in radians.
260 |         out_hw (tuple of int): Output height and width.
261 |         in_rot (float): Input rotation angle in radians.
262 | device (torch.device, optional): Device to create tensor on.
263 | Default: torch.device('cpu')
264 | dtype (torch.dtype, optional): Data type of the tensor.
265 | Default: torch.float32
266 |
267 | Returns:
268 | torch.Tensor: Perspective projection coordinates tensor.
269 | """
270 | h_fov = torch.tensor([h_fov], dtype=dtype, device=device)
271 | v_fov = torch.tensor([v_fov], dtype=dtype, device=device)
272 | u = torch.tensor([u], dtype=dtype, device=device)
273 | v = torch.tensor([v], dtype=dtype, device=device)
274 | in_rot = torch.tensor([in_rot], dtype=dtype, device=device)
275 |
276 | out = torch.ones((*out_hw, 3), dtype=dtype, device=device)
277 | x_max = torch.tan(h_fov / 2)
278 | y_max = torch.tan(v_fov / 2)
279 | y_range = torch.linspace(
280 | -y_max.item(), y_max.item(), steps=out_hw[0], dtype=dtype, device=device
281 | )
282 | x_range = torch.linspace(
283 | -x_max.item(), x_max.item(), steps=out_hw[1], dtype=dtype, device=device
284 | )
285 | grid_y, grid_x = torch.meshgrid(-y_range, x_range, indexing="ij")
286 | out[..., 0] = grid_x
287 | out[..., 1] = grid_y
288 |
289 | Rx = rotation_matrix(v, torch.tensor([1.0, 0.0, 0.0], dtype=dtype, device=device))
290 | Ry = rotation_matrix(u, torch.tensor([0.0, 1.0, 0.0], dtype=dtype, device=device))
291 | Ri = rotation_matrix(
292 | in_rot, torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device) @ Rx @ Ry
293 | )
294 |
295 | # Apply R = Rx*Ry*Ri to each vector
296 | # like this: out * Rx * Ry * Ri (assuming row vectors)
297 | out = out @ Rx @ Ry @ Ri
298 | return out
299 |
300 |
301 | def xyz2uv(xyz: torch.Tensor) -> torch.Tensor:
302 | """
303 |     Transform Cartesian (x, y, z) coordinates to spherical (r, u, v)
304 |     coordinates, returning only (u, v).
305 |
306 | Args:
307 | xyz (torch.Tensor): Input 3D coordinates tensor.
308 |
309 | Returns:
310 | torch.Tensor: UV coordinates tensor.
311 | """
312 | x = xyz[..., 0]
313 | y = xyz[..., 1]
314 | z = xyz[..., 2]
315 | u = torch.atan2(x, z)
316 | c = torch.sqrt(x**2 + z**2)
317 | v = torch.atan2(y, c)
318 | return torch.stack([u, v], dim=-1)
319 |
320 |
321 | def uv2unitxyz(uv: torch.Tensor) -> torch.Tensor:
322 | """
323 | Convert UV coordinates to unit 3D Cartesian coordinates.
324 |
325 | Args:
326 | uv (torch.Tensor): Input UV coordinates tensor.
327 |
328 | Returns:
329 | torch.Tensor: Unit 3D coordinates tensor.
330 | """
331 | u = uv[..., 0]
332 | v = uv[..., 1]
333 | y = torch.sin(v)
334 | c = torch.cos(v)
335 | x = c * torch.sin(u)
336 | z = c * torch.cos(u)
337 | return torch.stack([x, y, z], dim=-1)
338 |
339 |
340 | def uv2coor(uv: torch.Tensor, h: int, w: int) -> torch.Tensor:
341 | """
342 | Convert UV coordinates to image coordinates.
343 |
344 | Args:
345 | uv (torch.Tensor): Input UV coordinates tensor.
346 | h (int): Image height.
347 | w (int): Image width.
348 |
349 | Returns:
350 | torch.Tensor: Image coordinates tensor.
351 | """
352 | u = uv[..., 0]
353 | v = uv[..., 1]
354 | coor_x = (u / (2 * torch.pi) + 0.5) * w - 0.5
355 | coor_y = (-v / torch.pi + 0.5) * h - 0.5
356 | return torch.stack([coor_x, coor_y], dim=-1)
357 |
358 |
359 | def coor2uv(coorxy: torch.Tensor, h: int, w: int) -> torch.Tensor:
360 | """
361 | Convert image coordinates to UV coordinates.
362 |
363 | Args:
364 | coorxy (torch.Tensor): Input image coordinates tensor.
365 | h (int): Image height.
366 | w (int): Image width.
367 |
368 | Returns:
369 | torch.Tensor: UV coordinates tensor.
370 | """
371 | coor_x = coorxy[..., 0]
372 | coor_y = coorxy[..., 1]
373 | u = ((coor_x + 0.5) / w - 0.5) * 2 * torch.pi
374 | v = -((coor_y + 0.5) / h - 0.5) * torch.pi
375 | return torch.stack([u, v], dim=-1)
376 |
377 |
378 | def pad_cube_faces(cube_faces: torch.Tensor) -> torch.Tensor:
379 | """
380 |     Adds 1 pixel of padding around each cube face, using pixels from the
381 |     neighbouring faces.
382 |
383 | Args:
384 | cube_faces (torch.Tensor): Tensor of shape [6, H, W, C] representing the 6
385 | faces. Expected face order is: FRONT=0, RIGHT=1, BACK=2, LEFT=3, UP=4,
386 | DOWN=5
387 |
388 | Returns:
389 | torch.Tensor: Padded tensor of shape [6, H+2, W+2, C]
390 | """
391 | # Define face indices as constants instead of enum
392 | FRONT, RIGHT, BACK, LEFT, UP, DOWN = 0, 1, 2, 3, 4, 5
393 |
394 | # Create padded tensor with zeros
395 | padded = torch.zeros(
396 | cube_faces.shape[0],
397 | cube_faces.shape[1] + 2,
398 | cube_faces.shape[2] + 2,
399 | cube_faces.shape[3],
400 | dtype=cube_faces.dtype,
401 | device=cube_faces.device,
402 | )
403 |
404 | # Copy original data to center of padded tensor
405 | padded[:, 1:-1, 1:-1, :] = cube_faces
406 |
407 | # Pad above/below
408 | padded[FRONT, 0, 1:-1, :] = padded[UP, -2, 1:-1, :]
409 | padded[FRONT, -1, 1:-1, :] = padded[DOWN, 1, 1:-1, :]
410 |
411 | padded[RIGHT, 0, 1:-1, :] = padded[UP, 1:-1, -2, :].flip(0)
412 | padded[RIGHT, -1, 1:-1, :] = padded[DOWN, 1:-1, -2, :]
413 |
414 | padded[BACK, 0, 1:-1, :] = padded[UP, 1, 1:-1, :].flip(0)
415 | padded[BACK, -1, 1:-1, :] = padded[DOWN, -2, 1:-1, :].flip(0)
416 |
417 | padded[LEFT, 0, 1:-1, :] = padded[UP, 1:-1, 1, :]
418 | padded[LEFT, -1, 1:-1, :] = padded[DOWN, 1:-1, 1, :].flip(0)
419 |
420 | padded[UP, 0, 1:-1, :] = padded[BACK, 1, 1:-1, :].flip(0)
421 | padded[UP, -1, 1:-1, :] = padded[FRONT, 1, 1:-1, :]
422 |
423 | padded[DOWN, 0, 1:-1, :] = padded[FRONT, -2, 1:-1, :]
424 | padded[DOWN, -1, 1:-1, :] = padded[BACK, -2, 1:-1, :].flip(0)
425 |
426 | # Pad left/right
427 | padded[FRONT, 1:-1, 0, :] = padded[LEFT, 1:-1, -2, :]
428 | padded[FRONT, 1:-1, -1, :] = padded[RIGHT, 1:-1, 1, :]
429 |
430 | padded[RIGHT, 1:-1, 0, :] = padded[FRONT, 1:-1, -2, :]
431 | padded[RIGHT, 1:-1, -1, :] = padded[BACK, 1:-1, 1, :]
432 |
433 | padded[BACK, 1:-1, 0, :] = padded[RIGHT, 1:-1, -2, :]
434 | padded[BACK, 1:-1, -1, :] = padded[LEFT, 1:-1, 1, :]
435 |
436 | padded[LEFT, 1:-1, 0, :] = padded[BACK, 1:-1, -2, :]
437 | padded[LEFT, 1:-1, -1, :] = padded[FRONT, 1:-1, 1, :]
438 |
439 | padded[UP, 1:-1, 0, :] = padded[LEFT, 1, 1:-1, :]
440 | padded[UP, 1:-1, -1, :] = padded[RIGHT, 1, 1:-1, :].flip(0)
441 |
442 | padded[DOWN, 1:-1, 0, :] = padded[LEFT, -2, 1:-1, :].flip(0)
443 | padded[DOWN, 1:-1, -1, :] = padded[RIGHT, -2, 1:-1, :]
444 | return padded
445 |
446 |
447 | def grid_sample_wrap(
448 | image: torch.Tensor,
449 | coor_x: torch.Tensor,
450 | coor_y: torch.Tensor,
451 | mode: str = "bilinear",
452 | padding_mode: str = "border",
453 | ) -> torch.Tensor:
454 | """
455 | Sample from an image with wrapped horizontal coordinates.
456 |
457 | Args:
458 | image (torch.Tensor): Input image tensor in the shape of [H, W, C] or
459 | [N, H, W, C].
460 | coor_x (torch.Tensor): X coordinates for sampling.
461 | coor_y (torch.Tensor): Y coordinates for sampling.
462 | mode (str, optional): Sampling interpolation mode, 'nearest',
463 | 'bicubic', or 'bilinear'. Default: 'bilinear'.
464 |         padding_mode (str, optional): Padding mode for out-of-bounds coordinates,
465 |             e.g. 'border', 'zeros', or 'reflection'. Default: 'border'
466 |
467 | Returns:
468 | torch.Tensor: Sampled image tensor.
469 | """
470 |
471 | assert image.dim() == 4 or image.dim() == 3
472 | # Permute image to NCHW
473 | if image.dim() == 3:
474 | has_batch = False
475 | H, W, _ = image.shape
476 | # [H,W,C] -> [1,C,H,W]
477 | img_t = image.permute(2, 0, 1).unsqueeze(0)
478 | else:
479 | has_batch = True
480 | _, H, W, _ = image.shape
481 | # [N,H,W,C] -> [N,C,H,W]
482 | img_t = image.permute(0, 3, 1, 2)
483 |
484 | # coor_x, coor_y: [H_out, W_out]
485 | # We must create a grid for F.grid_sample:
486 | # grid_sample expects: input [N, C, H, W], grid [N, H_out, W_out, 2]
487 | # Normalized coords: x: [-1, 1], y: [-1, 1]
488 | # Handle wrapping horizontally: coor_x modulo W
489 | coor_x_wrapped = torch.remainder(coor_x, W) # wrap horizontally
490 | coor_y_clamped = coor_y.clamp(min=0, max=H - 1)
491 |
492 | # Normalize
493 | grid_x = (coor_x_wrapped / (W - 1)) * 2 - 1
494 | grid_y = (coor_y_clamped / (H - 1)) * 2 - 1
495 | grid = torch.stack([grid_x, grid_y], dim=-1) # [H_out, W_out, 2]
496 |
497 | grid = grid.unsqueeze(0) # [1, H_out, W_out,2]
498 | if has_batch:
499 | grid = grid.repeat(img_t.shape[0], 1, 1, 1)
500 |
501 |     # F.grid_sample expects input of shape [N, C, H, W] and a grid of shape
502 |     # [N, H_out, W_out, 2], with grid[:,:,:,0] = x and grid[:,:,:,1] = y,
503 |     # both normalized to [-1, 1]. The grid built above already follows this
504 |     # convention, and align_corners=True maps -1 and 1 to the centers of
505 |     # the corner pixels.
506 |
507 | if (
508 | img_t.dtype == torch.float16 or img_t.dtype == torch.bfloat16
509 | ) and img_t.device == torch.device("cpu"):
510 | sampled = F.grid_sample(
511 | img_t.float(),
512 | grid.float(),
513 | mode=mode,
514 | padding_mode=padding_mode,
515 | align_corners=True,
516 | ).to(img_t.dtype)
517 | else:
518 | sampled = F.grid_sample(
519 | img_t, grid, mode=mode, padding_mode=padding_mode, align_corners=True
520 | )
521 |
522 | if has_batch:
523 | sampled = sampled.permute(0, 2, 3, 1)
524 | else:
525 | # [1, C, H_out, W_out]
526 | sampled = sampled.squeeze(0).permute(1, 2, 0) # [H_out, W_out, C]
527 | return sampled
528 |
529 |
530 | def sample_equirec(
531 | e_img: torch.Tensor, coor_xy: torch.Tensor, mode: str = "bilinear"
532 | ) -> torch.Tensor:
533 | """
534 | Sample from an equirectangular image.
535 |
536 | Args:
537 | e_img (torch.Tensor): Equirectangular image tensor in the shape of:
538 | [H, W, C].
539 | coor_xy (torch.Tensor): Sampling coordinates in the shape of
540 | [H_out, W_out, 2].
541 | mode (str, optional): Sampling interpolation mode, 'nearest',
542 | 'bicubic', or 'bilinear'. Default: 'bilinear'.
543 |
544 | Returns:
545 | torch.Tensor: Sampled image tensor.
546 | """
547 | coor_x = coor_xy[..., 0]
548 | coor_y = coor_xy[..., 1]
549 | return grid_sample_wrap(e_img, coor_x, coor_y, mode=mode)
550 |
551 |
552 | def sample_cubefaces(
553 | cube_faces: torch.Tensor,
554 | tp: torch.Tensor,
555 | coor_y: torch.Tensor,
556 | coor_x: torch.Tensor,
557 | mode: str = "bilinear",
558 | ) -> torch.Tensor:
559 | """
560 | Sample from cube faces.
561 |
562 | Args:
563 | cube_faces (torch.Tensor): Cube faces tensor in the shape of:
564 | [6, face_w, face_w, C].
565 | tp (torch.Tensor): Face type tensor.
566 | coor_y (torch.Tensor): Y coordinates for sampling.
567 | coor_x (torch.Tensor): X coordinates for sampling.
568 |         mode (str, optional): Sampling interpolation mode, 'nearest',
569 |             'bicubic', or 'bilinear'. Default: 'bilinear'
570 |
571 | Returns:
572 | torch.Tensor: Sampled cube faces tensor.
573 | """
574 | # cube_faces: [6, face_w, face_w, C]
575 | # We must sample according to tp (face index), coor_y, coor_x
576 | # First we must flatten all faces into a single big image (like cube_h)
577 | # We can do per-face sampling. Instead of map_coordinates
578 | # (tp, y, x), we know each pixel belongs to a certain face.
579 |
580 | # For differentiability and simplicity, let's do a trick:
581 | # Create a big image [face_w,face_w*6, C] (cube_h) and sample from it using
582 | # coor_x, coor_y and tp.
583 | cube_faces = pad_cube_faces(cube_faces)
584 | coor_y = coor_y + 1
585 | coor_x = coor_x + 1
586 | cube_faces_mod = cube_faces.clone()
587 |
588 | face_w = cube_faces_mod.shape[1]
589 | cube_h = torch.cat(
590 | [cube_faces_mod[i] for i in range(6)], dim=1
591 | ) # [face_w, face_w*6, C]
592 |
593 | # We need to map (tp, coor_y, coor_x) -> coordinates in cube_h
594 | # cube_h faces: 0:F, 1:R, 2:B, 3:L, 4:U, 5:D in order
595 | # If tp==0: x in [0, face_w-1] + offset 0
596 | # If tp==1: x in [0, face_w-1] + offset face_w
597 | # etc.
598 |
599 | # coor_x, coor_y are in face coordinates [0, face_w-1]
600 | # offset for face
601 | # x_offset = tp * face_w
602 |
603 | # Construct a single image indexing:
604 | # To handle tp indexing, let's create global_x = coor_x + tp * face_w
605 | # But tp might have shape (H_out,W_out)
606 | global_x = coor_x + tp.to(dtype=cube_h.dtype) * face_w
607 | global_y = coor_y
608 |
609 | return grid_sample_wrap(cube_h, global_x, global_y, mode=mode)
610 |
611 |
612 | def cube_h2list(cube_h: torch.Tensor) -> List[torch.Tensor]:
613 | """
614 | Convert a horizontal cube representation to a list of cube faces.
615 |
616 | Args:
617 | cube_h (torch.Tensor): Horizontal cube representation tensor in the
618 | shape of: [w, w*6, C] or [B, w, w*6, C].
619 |
620 | Returns:
621 | List[torch.Tensor]: List of cube face tensors in the order of:
622 | ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']
623 | """
624 | assert cube_h.dim() == 3 or cube_h.dim() == 4
625 | w = cube_h.shape[0] if cube_h.dim() == 3 else cube_h.shape[1]
626 | return [cube_h[..., i * w : (i + 1) * w, :] for i in range(6)]
627 |
628 |
629 | def cube_list2h(cube_list: List[torch.Tensor]) -> torch.Tensor:
630 | """
631 | Convert a list of cube faces to a horizontal cube representation.
632 |
633 | Args:
634 | cube_list (list of torch.Tensor): List of cube face tensors, in order
635 | of ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']
636 |
637 | Returns:
638 | torch.Tensor: Horizontal cube representation tensor.
639 | """
640 | assert all(
641 | cube.shape == cube_list[0].shape for cube in cube_list
642 | ), "All cube faces should have the same shape."
643 | assert all(
644 | cube.device == cube_list[0].device for cube in cube_list
645 | ), "All cube faces should have the same device."
646 | assert all(
647 | cube.dtype == cube_list[0].dtype for cube in cube_list
648 | ), "All cube faces should have the same dtype."
649 | return torch.cat(cube_list, dim=1)
650 |
651 |
652 | def cube_h2dict(
653 | cube_h: torch.Tensor,
654 | face_keys: Optional[List[str]] = None,
655 | ) -> Dict[str, torch.Tensor]:
656 | """
657 | Convert a horizontal cube representation to a dictionary of cube faces.
658 |
659 |     Face order: Front, Right, Back, Left, Up, Down.
661 |
662 | Args:
663 | cube_h (torch.Tensor): Horizontal cube representation tensor in the
664 | shape of: [w, w*6, C].
665 | face_keys (list of str, optional): List of face keys in order.
666 | Default: '["Front", "Right", "Back", "Left", "Up", "Down"]'
667 |
668 | Returns:
669 | Dict[str, torch.Tensor]: Dictionary of cube faces with keys:
670 | ["Front", "Right", "Back", "Left", "Up", "Down"].
671 | """
672 | if face_keys is None:
673 | face_keys = ["Front", "Right", "Back", "Left", "Up", "Down"]
674 | cube_list = cube_h2list(cube_h)
675 | return dict(zip(face_keys, cube_list))
676 |
677 |
678 | def cube_dict2h(
679 | cube_dict: Dict[str, torch.Tensor],
680 | face_keys: Optional[List[str]] = None,
681 | ) -> torch.Tensor:
682 | """
683 | Convert a dictionary of cube faces to a horizontal cube representation.
684 |
685 | Args:
686 | cube_dict (dict of str to torch.Tensor): Dictionary of cube faces.
687 | face_keys (list of str, optional): List of face keys in order.
688 | Default: '["Front", "Right", "Back", "Left", "Up", "Down"]'
689 |
690 | Returns:
691 | torch.Tensor: Horizontal cube representation tensor.
692 | """
693 | if face_keys is None:
694 | face_keys = ["Front", "Right", "Back", "Left", "Up", "Down"]
695 | return cube_list2h([cube_dict[k] for k in face_keys])
696 |
697 |
698 | def cube_h2dice(cube_h: torch.Tensor) -> torch.Tensor:
699 | """
700 | Convert a horizontal cube representation to a dice layout representation.
701 |
702 | Output order: Front Right Back Left Up Down
703 | ┌────┬────┬────┬────┐
704 | │ │ U │ │ │
705 | ├────┼────┼────┼────┤
706 | │ L │ F │ R │ B │
707 | ├────┼────┼────┼────┤
708 | │ │ D │ │ │
709 | └────┴────┴────┴────┘
710 |
711 | Args:
712 | cube_h (torch.Tensor): Horizontal cube representation tensor in the
713 | shape of: [w, w*6, C].
714 |
715 | Returns:
716 | torch.Tensor: Dice layout cube representation tensor in the shape of:
717 | [w*3, w*4, C].
718 | """
719 | w = cube_h.shape[0]
720 | cube_dice = torch.zeros(
721 | (w * 3, w * 4, cube_h.shape[2]), dtype=cube_h.dtype, device=cube_h.device
722 | )
723 | cube_list = cube_h2list(cube_h)
724 | sxy = [(1, 1), (2, 1), (3, 1), (0, 1), (1, 0), (1, 2)]
725 | for i, (sx, sy) in enumerate(sxy):
726 | face = cube_list[i]
727 | cube_dice[sy * w : (sy + 1) * w, sx * w : (sx + 1) * w] = face
728 | return cube_dice
729 |
730 |
731 | def cube_dice2h(cube_dice: torch.Tensor) -> torch.Tensor:
732 | """
733 | Convert a dice layout representation to a horizontal cube representation.
734 |
735 | Input order: Front Right Back Left Up Down
736 | ┌────┬────┬────┬────┐
737 | │ │ U │ │ │
738 | ├────┼────┼────┼────┤
739 | │ L │ F │ R │ B │
740 | ├────┼────┼────┼────┤
741 | │ │ D │ │ │
742 | └────┴────┴────┴────┘
743 |
744 | Args:
745 | cube_dice (torch.Tensor): Dice layout cube representation tensor in the
746 | shape of: [w*3, w*4, C].
747 |
748 | Returns:
749 | torch.Tensor: Horizontal cube representation tensor in the shape of:
750 | [w, w*6, C].
751 | """
752 | w = cube_dice.shape[0] // 3
753 | cube_h = torch.zeros(
754 | (w, w * 6, cube_dice.shape[2]), dtype=cube_dice.dtype, device=cube_dice.device
755 | )
756 | sxy = [(1, 1), (2, 1), (3, 1), (0, 1), (1, 0), (1, 2)]
757 | for i, (sx, sy) in enumerate(sxy):
758 | face = cube_dice[sy * w : (sy + 1) * w, sx * w : (sx + 1) * w]
759 | cube_h[:, i * w : (i + 1) * w] = face
760 | return cube_h
761 |
762 |
763 | def c2e(
764 | cubemap: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]],
765 | h: Optional[int] = None,
766 | w: Optional[int] = None,
767 | mode: str = "bilinear",
768 | cube_format: str = "dice",
769 | channels_first: bool = True,
770 | ) -> torch.Tensor:
771 | """
772 | Convert a cubemap to an equirectangular projection.
773 |
774 | Args:
775 | cubemap (torch.Tensor, list of torch.Tensor, or dict of torch.Tensor):
776 | Cubemap image tensor, list of tensors, or dict of tensors. Note
777 | that tensors should be in the shape of: 'CHW', except for when
778 | `cube_format = 'stack'`, in which case a batch dimension is
779 | present 'BCHW'.
780 |         h (int, optional): Height of the output equirectangular image. If set
781 |             to None, face_w * 2 will be used.
782 |             Default: `face_w * 2`
783 |         w (int, optional): Width of the output equirectangular image. If set
784 |             to None, face_w * 4 will be used.
785 |             Default: `face_w * 4`
786 | mode (str, optional): Sampling interpolation mode, 'nearest',
787 | 'bicubic', or 'bilinear'. Default: 'bilinear'.
788 | cube_format (str, optional): The input 'cubemap' format. Options
789 | are 'dict', 'list', 'horizon', 'stack', and 'dice'. Default: 'dice'
790 | The chosen 'cube_format' should correspond to the provided
791 | 'cubemap' input.
792 | The list of options are:
793 | - 'stack' (torch.Tensor): Stack of 6 faces, in the order
794 | of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
795 | - 'list' (list of torch.Tensor): List of 6 faces, in the order
796 | of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
797 | - 'dict' (dict of torch.Tensor): Dictionary with keys pointing to
798 | face tensors. Dict keys are:
799 | ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
800 | - 'dice' (torch.Tensor): A cubemap in a 'dice' layout.
801 | - 'horizon' (torch.Tensor): A cubemap in a 'horizon' layout,
802 | a 1x6 grid in the order:
803 | ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
804 | channels_first (bool, optional): The channel format used by the cubemap
805 | tensor(s). PyTorch uses channels first. Default: 'True'
806 |
807 | Returns:
808 | torch.Tensor: A CHW equirectangular projection tensor.
809 |
810 | Raises:
811 | NotImplementedError: If an unknown cube_format is provided.
812 | """
813 |
814 | if cube_format == "stack":
815 | assert (
816 | isinstance(cubemap, torch.Tensor)
817 | and len(cubemap.shape) == 4
818 | and cubemap.shape[0] == 6
819 | )
820 | cubemap = [cubemap[i] for i in range(cubemap.shape[0])]
821 | cube_format = "list"
822 |
823 | # Ensure input is in HWC format for processing
824 | if channels_first:
825 | if cube_format == "list" and isinstance(cubemap, (list, tuple)):
826 | cubemap = [r.permute(1, 2, 0) for r in cubemap]
827 | elif cube_format == "dict" and torch.jit.isinstance(
828 | cubemap, Dict[str, torch.Tensor]
829 | ):
830 | cubemap = {k: v.permute(1, 2, 0) for k, v in cubemap.items()} # type: ignore
831 | elif cube_format in ["horizon", "dice"] and isinstance(cubemap, torch.Tensor):
832 | cubemap = cubemap.permute(1, 2, 0)
833 | else:
834 | raise NotImplementedError("unknown cube_format and cubemap type")
835 |
836 | if cube_format == "horizon" and isinstance(cubemap, torch.Tensor):
837 | assert cubemap.dim() == 3
838 | cube_h = cubemap
839 | elif cube_format == "list" and isinstance(cubemap, (list, tuple)):
840 | assert all([r.dim() == 3 for r in cubemap])
841 | cube_h = cube_list2h(cubemap)
842 | elif cube_format == "dict" and torch.jit.isinstance(
843 | cubemap, Dict[str, torch.Tensor]
844 | ):
845 | assert all(v.dim() == 3 for k, v in cubemap.items()) # type: ignore[union-attr]
846 | cube_h = cube_dict2h(cubemap) # type: ignore[arg-type]
847 | elif cube_format == "dice" and isinstance(cubemap, torch.Tensor):
848 | assert len(cubemap.shape) == 3
849 | cube_h = cube_dice2h(cubemap)
850 | else:
851 | raise NotImplementedError("unknown cube_format and cubemap type")
852 | assert isinstance(cube_h, torch.Tensor) # Mypy wants this
853 |
854 | device = cube_h.device
855 | dtype = cube_h.dtype
856 | face_w = cube_h.shape[0]
857 | assert cube_h.shape[1] == face_w * 6
858 |
859 | h = face_w * 2 if h is None else h
860 | w = face_w * 4 if w is None else w
861 |
862 | uv = equirect_uvgrid(h, w, device=device, dtype=dtype)
863 | u, v = uv[..., 0], uv[..., 1]
864 |
865 | cube_faces = torch.stack(
866 | torch.split(cube_h, face_w, dim=1), dim=0
867 | ) # [6, face_w, face_w, C]
868 |
869 | tp = equirect_facetype(h, w, device=device, dtype=dtype)
870 |
871 | coor_x = torch.zeros((h, w), device=device, dtype=dtype)
872 | coor_y = torch.zeros((h, w), device=device, dtype=dtype)
873 |
874 | # front, right, back, left
875 | for i in range(4):
876 | mask = tp == i
877 | coor_x[mask] = 0.5 * torch.tan(u[mask] - torch.pi * i / 2)
878 | coor_y[mask] = -0.5 * torch.tan(v[mask]) / torch.cos(u[mask] - torch.pi * i / 2)
879 |
880 | # Up
881 | mask = tp == 4
882 | c = 0.5 * torch.tan(torch.pi / 2 - v[mask])
883 | coor_x[mask] = c * torch.sin(u[mask])
884 | coor_y[mask] = c * torch.cos(u[mask])
885 |
886 | # Down
887 | mask = tp == 5
888 | c = 0.5 * torch.tan(torch.pi / 2 - torch.abs(v[mask]))
889 | coor_x[mask] = c * torch.sin(u[mask])
890 | coor_y[mask] = -c * torch.cos(u[mask])
891 |
892 | coor_x = (torch.clamp(coor_x, -0.5, 0.5) + 0.5) * face_w
893 | coor_y = (torch.clamp(coor_y, -0.5, 0.5) + 0.5) * face_w
894 |
895 | equirec = sample_cubefaces(cube_faces, tp, coor_y, coor_x, mode)
896 |
897 | # Convert back to CHW if required
898 | equirec = _nhwc2nchw(equirec) if channels_first else equirec
899 | return equirec
900 |
901 |
902 | def e2c(
903 | e_img: torch.Tensor,
904 | face_w: Optional[int] = None,
905 | mode: str = "bilinear",
906 | cube_format: str = "dice",
907 | channels_first: bool = True,
908 | ) -> Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]]:
909 | """
910 | Convert an equirectangular image to a cubemap.
911 |
912 | Args:
913 | e_img (torch.Tensor): Input equirectangular image tensor of shape
914 | [C, H, W] or [H, W, C].
915 |         face_w (int, optional): Width of each square cube face. If set to None,
916 |             then face_w will be calculated as H // 2. Default: None
917 | mode (str, optional): Sampling interpolation mode, 'nearest',
918 | 'bicubic', or 'bilinear'. Default: 'bilinear'.
919 | cube_format (str, optional): The desired output cubemap format. Options
920 | are 'dict', 'list', 'horizon', 'stack', and 'dice'. Default: 'dice'
921 | The list of options are:
922 | - 'stack' (torch.Tensor): Stack of 6 faces, in the order
923 | of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
924 | - 'list' (list of torch.Tensor): List of 6 faces, in the order
925 | of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
926 | - 'dict' (dict of torch.Tensor): Dictionary with keys pointing to
927 | face tensors. Dict keys are:
928 | ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
929 | - 'dice' (torch.Tensor): A cubemap in a 'dice' layout.
930 | - 'horizon' (torch.Tensor): A cubemap in a 'horizon' layout,
931 | a 1x6 grid in the order:
932 | ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'].
933 | channels_first (bool, optional): The channel format of e_img. PyTorch
934 | uses channels first. Default: 'True'
935 |
936 | Returns:
937 | Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]]:
938 | A cubemap in the specified format.
939 |
940 | Raises:
941 | NotImplementedError: If an unknown cube_format is provided.
942 | """
943 | assert e_img.dim() == 3 or e_img.dim() == 4, (
944 | "e_img should be in the shape of [N,C,H,W], [C,H,W], [N,H,W,C], "
945 | f"or [H,W,C], got shape of: {e_img.shape}"
946 | )
947 |
948 | e_img = _nchw2nhwc(e_img) if channels_first else e_img
949 | h, w = e_img.shape[:2] if e_img.dim() == 3 else e_img.shape[1:3]
950 | face_w = h // 2 if face_w is None else face_w
951 |
952 | # returns [face_w, face_w*6, 3] in order
953 | # [Front, Right, Back, Left, Up, Down]
954 | xyz = xyzcube(face_w, device=e_img.device, dtype=e_img.dtype)
955 | uv = xyz2uv(xyz)
956 | coor_xy = uv2coor(uv, h, w)
957 | # Sample all channels:
958 | out_c = sample_equirec(e_img, coor_xy, mode) # [face_w, 6*face_w, C]
959 | # out_c shape: we did it directly for each pixel in the cube map
960 |
961 | result: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor], None] = (
962 | None
963 | )
964 | if cube_format == "horizon":
965 | result = out_c
966 | elif cube_format == "list" or cube_format == "stack":
967 | result = cube_h2list(out_c)
968 | elif cube_format == "dict":
969 | result = cube_h2dict(out_c)
970 | elif cube_format == "dice":
971 | result = cube_h2dice(out_c)
972 | else:
973 | raise NotImplementedError("unknown cube_format")
974 |
975 | # Convert to CHW if required
976 | if channels_first:
977 | if cube_format == "list" or cube_format == "stack":
978 | assert isinstance(result, (list, tuple))
979 | result = [_nhwc2nchw(r) for r in result]
980 | elif cube_format == "dict":
981 | assert torch.jit.isinstance(result, Dict[str, torch.Tensor])
982 | result = {k: _nhwc2nchw(v) for k, v in result.items()} # type: ignore[union-attr]
983 | elif cube_format in ["horizon", "dice"]:
984 | assert isinstance(result, torch.Tensor)
985 | result = _nhwc2nchw(result)
986 | if cube_format == "stack" and isinstance(result, (list, tuple)):
987 | result = torch.stack(result)
988 | return result
989 |
990 |
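A short sketch of `e2c` with the 'dict' output format described above (the input tensor here is arbitrary):

```python
import torch
from pytorch360convert.pytorch360convert import e2c

equi = torch.rand(3, 512, 1024)  # [C, H, W] equirectangular input

# 'dict' format returns one [C, face_w, face_w] tensor per face.
faces = e2c(equi, face_w=256, mode="bilinear", cube_format="dict")
print(faces["Front"].shape)  # torch.Size([3, 256, 256])
```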
991 | def e2p(
992 | e_img: torch.Tensor,
993 | fov_deg: Union[float, Tuple[float, float]],
994 | h_deg: float = 0.0,
995 | v_deg: float = 0.0,
996 | out_hw: Tuple[int, int] = (512, 512),
997 | in_rot_deg: float = 0.0,
998 | mode: str = "bilinear",
999 | channels_first: bool = True,
1000 | ) -> torch.Tensor:
1001 | """
1002 | Convert an equirectangular image to a perspective projection.
1003 |
1004 | Args:
1005 | e_img (torch.Tensor): Input equirectangular image tensor in the shape
1006 | of: [C, H, W] or [H, W, C]. Or with a batch dimension: [B, C, H, W]
1007 | or [B, H, W, C].
1008 |         fov_deg (float or tuple of floats): Field of view in degrees. Can be
1009 |             a single float or an (h_fov, v_fov) tuple.
1010 | h_deg (float, optional): Horizontal rotation angle in degrees
1011 | (-Left/+Right). Default: 0.0
1012 | v_deg (float, optional): Vertical rotation angle in degrees
1013 | (-Down/+Up). Default: 0.0
1014 | out_hw (tuple of int, optional): The output image size specified as
1015 | a tuple of (height, width). Default: (512, 512)
1016 |         in_rot_deg (float, optional): In-plane (roll) rotation of the output
1017 |             view, in degrees. Default: 0.0
1018 | mode (str, optional): Sampling interpolation mode, 'nearest',
1019 | 'bicubic', or 'bilinear'. Default: 'bilinear'.
1020 | channels_first (bool, optional): The channel format of e_img. PyTorch
1021 |             uses channels first. Default: True
1022 |
1023 | Returns:
1024 | torch.Tensor: Perspective projection image tensor.
1025 | """
1026 | assert e_img.dim() == 3 or e_img.dim() == 4, (
1027 | "e_img should be in the shape of [N,C,H,W], [C,H,W], [N,H,W,C], "
1028 | f"or [H,W,C], got shape of: {e_img.shape}"
1029 | )
1030 |
1031 | # Ensure input is in HWC format for processing
1032 | e_img = _nchw2nhwc(e_img) if channels_first else e_img
1033 | h, w = e_img.shape[:2] if e_img.dim() == 3 else e_img.shape[1:3]
1034 |
1035 | if isinstance(fov_deg, (list, tuple)):
1036 | h_fov_rad = fov_deg[0] * torch.pi / 180
1037 | v_fov_rad = fov_deg[1] * torch.pi / 180
1038 | else:
1039 | fov = fov_deg * torch.pi / 180
1040 | h_fov_rad = fov
1041 | v_fov_rad = fov
1042 |
1043 | in_rot = in_rot_deg * torch.pi / 180
1044 |
1045 | u = -h_deg * torch.pi / 180
1046 | v = v_deg * torch.pi / 180
1047 |
1048 | xyz = xyzpers(
1049 | h_fov_rad,
1050 | v_fov_rad,
1051 | u,
1052 | v,
1053 | out_hw,
1054 | in_rot,
1055 | device=e_img.device,
1056 | dtype=e_img.dtype,
1057 | )
1058 | uv = xyz2uv(xyz)
1059 | coor_xy = uv2coor(uv, h, w)
1060 |
1061 | pers_img = sample_equirec(e_img, coor_xy, mode)
1062 |
1063 | # Convert back to CHW if required
1064 | pers_img = _nhwc2nchw(pers_img) if channels_first else pers_img
1065 | return pers_img
1066 |
1067 |
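For illustration, a minimal `e2p` call extracting a single perspective view (the angle and size values are chosen arbitrarily):

```python
import torch
from pytorch360convert.pytorch360convert import e2p

equi = torch.rand(3, 512, 1024)  # [C, H, W] equirectangular input

# A 90 degree square FOV, looking 30 degrees right and 15 degrees up.
view = e2p(equi, fov_deg=90.0, h_deg=30.0, v_deg=15.0, out_hw=(256, 256))
print(view.shape)  # torch.Size([3, 256, 256])
```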
1068 | def e2e(
1069 | e_img: torch.Tensor,
1070 | roll: float = 0.0,
1071 | h_deg: float = 0.0,
1072 | v_deg: float = 0.0,
1073 | mode: str = "bilinear",
1074 | channels_first: bool = True,
1075 | ) -> torch.Tensor:
1076 | """
1077 | Apply rotations to an equirectangular image along the roll, pitch, and yaw
1078 | axes.
1079 |
1080 | This function rotates an equirectangular image tensor along the roll
1081 |     (x-axis), pitch (y-axis), and yaw (z-axis) axes; the pitch and yaw
1082 |     rotations appear as vertical and horizontal shifts in the image.
1083 |
1084 | Args:
1085 | e_img (torch.Tensor): Input equirectangular image tensor in the shape
1086 | of: [C, H, W] or [H, W, C]. Or with a batch dimension: [B, C, H, W]
1087 | or [B, H, W, C].
1088 | roll (float, optional): Roll angle in degrees. Rotates the image along
1089 | the x-axis. Roll directions: (-counter_clockwise/+clockwise).
1090 | Default: 0.0
1091 | h_deg (float, optional): Yaw angle in degrees (-left/+right). Rotates the
1092 | image along the z-axis to produce a horizontal shift. Default: 0.0
1093 | v_deg (float, optional): Pitch angle in degrees (-down/+up). Rotates the
1094 | image along the y-axis to produce a vertical shift. Default: 0.0
1095 | mode (str, optional): Sampling interpolation mode, 'nearest',
1096 | 'bicubic', or 'bilinear'. Default: 'bilinear'.
1097 | channels_first (bool, optional): The channel format of e_img. PyTorch
1098 |             uses channels first. Default: True
1099 |
1100 | Returns:
1101 | torch.Tensor: Modified equirectangular image.
1102 | """
1103 |
1105 |     yaw = h_deg
1106 |     pitch = v_deg
1107 |
1108 | assert e_img.dim() == 3 or e_img.dim() == 4, (
1109 | "e_img should be in the shape of [N,C,H,W], [C,H,W], [N,H,W,C], "
1110 | f"or [H,W,C], got shape of: {e_img.shape}"
1111 | )
1112 |
1113 | # Ensure input is in HWC format for processing
1114 | e_img = _nchw2nhwc(e_img) if channels_first else e_img
1115 | h, w = e_img.shape[:2] if e_img.dim() == 3 else e_img.shape[1:3]
1116 |
1117 | # Convert angles to radians
1118 | roll_rad = torch.tensor(
1119 | [roll * torch.pi / 180.0], device=e_img.device, dtype=e_img.dtype
1120 | )
1121 | pitch_rad = torch.tensor(
1122 | [pitch * torch.pi / 180.0], device=e_img.device, dtype=e_img.dtype
1123 | )
1124 | yaw_rad = torch.tensor(
1125 | [yaw * torch.pi / 180.0], device=e_img.device, dtype=e_img.dtype
1126 | )
1127 |
1128 | # Create base coordinates for the output image
1129 | y_range = torch.linspace(0, h - 1, h, device=e_img.device, dtype=e_img.dtype)
1130 | x_range = torch.linspace(0, w - 1, w, device=e_img.device, dtype=e_img.dtype)
1131 | grid_y, grid_x = torch.meshgrid(y_range, x_range, indexing="ij")
1132 |
1133 | # Convert pixel coordinates to spherical coordinates
1134 | uv = coor2uv(torch.stack([grid_x, grid_y], dim=-1), h, w)
1135 |
1136 | # Convert to unit vectors on sphere
1137 | xyz = uv2unitxyz(uv)
1138 |
1139 |     # Create rotation matrices (roll about x, yaw about z, pitch about y)
1140 |     R_roll = rotation_matrix(
1141 |         roll_rad, torch.tensor([1.0, 0.0, 0.0], device=e_img.device, dtype=e_img.dtype)
1142 |     )
1143 |     R_yaw = rotation_matrix(
1144 |         yaw_rad, torch.tensor([0.0, 0.0, 1.0], device=e_img.device, dtype=e_img.dtype)
1145 |     )
1146 |     R_pitch = rotation_matrix(
1147 |         pitch_rad, torch.tensor([0.0, 1.0, 0.0], device=e_img.device, dtype=e_img.dtype)
1148 |     )
1149 |
1150 |     # Apply rotations: first roll, then yaw, then pitch
1151 |     xyz_rot = xyz @ R_roll @ R_yaw @ R_pitch
1152 |
1153 | # Convert back to UV coordinates
1154 | uv_rot = xyz2uv(xyz_rot)
1155 |
1156 | # Convert UV coordinates to pixel coordinates
1157 | coor_xy = uv2coor(uv_rot, h, w)
1158 |
1159 | # Sample from the input image
1160 | rotated = sample_equirec(e_img, coor_xy, mode=mode)
1161 |
1162 | # Return to original channel format if needed
1163 | rotated = _nhwc2nchw(rotated) if channels_first else rotated
1164 | return rotated
1165 |
1166 |
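A minimal sketch of `e2e`, rotating a panorama in place (the angles are arbitrary):

```python
import torch
from pytorch360convert.pytorch360convert import e2e

equi = torch.rand(3, 512, 1024)

# Shift the panorama 90 degrees to the right and tilt it 10 degrees up.
rotated = e2e(equi, roll=0.0, h_deg=90.0, v_deg=10.0, mode="bilinear")
print(rotated.shape)  # torch.Size([3, 512, 1024])
```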
1167 | def pad_180_to_360(
1168 | e_img: torch.Tensor, fill_value: float = 0.0, channels_first: bool = True
1169 | ) -> torch.Tensor:
1170 | """
1171 |     Pads a 180 degree equirectangular image tensor to make it a full 360 degree
1172 |     image, by padding the left and right sides.
1173 |
1174 | For an image of width W (covering 180 degrees), the full panorama requires the
1175 | total image width to be doubled. Padding is evenly split between the left and
1176 | right sides.
1177 |
1178 | Args:
1179 | e_img (torch.Tensor): Input equirectangular image tensor in the shape
1180 | of: [C, H, W] or [H, W, C]. Or with a batch dimension: [B, C, H, W]
1181 | or [B, H, W, C].
1182 |         fill_value (float, optional): The constant value used for padding. Default: 0.0
1183 | channels_first (bool, optional): The channel format of e_img. PyTorch
1184 |             uses channels first. Default: True
1185 |
1186 | Returns:
1187 | torch.Tensor: The padded tensor.
1188 | """
1189 | assert e_img.dim() in [3, 4]
1190 | e_img = _nhwc2nchw(e_img) if not channels_first else e_img
1191 | H, W = e_img.shape[1:] if e_img.dim() == 3 else e_img.shape[2:]
1192 | pad_left = W // 2
1193 | pad_right = W - pad_left
1194 |
1195 | if e_img.ndim == 3:
1196 | e_img_padded = F.pad(
1197 | e_img.unsqueeze(0), (pad_left, pad_right), mode="constant", value=fill_value
1198 | ).squeeze(0)
1199 | elif e_img.ndim == 4:
1200 | e_img_padded = F.pad(
1201 | e_img, (pad_left, pad_right), mode="constant", value=fill_value
1202 | )
1203 | else:
1204 | raise ValueError(
1205 | "e_img must be either 3D (CHW) or 4D (NCHW), got {e_img.shape}"
1206 | )
1207 |
1208 | e_img_padded = _nchw2nhwc(e_img_padded) if not channels_first else e_img_padded
1209 | return e_img_padded
1210 |
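A quick sketch of `pad_180_to_360`: a 180 degree image of width W comes back with width 2W (the input here is random):

```python
import torch
from pytorch360convert.pytorch360convert import pad_180_to_360

half_pano = torch.rand(3, 512, 512)  # 180 degrees of content
full_pano = pad_180_to_360(half_pano, fill_value=0.0)
print(full_pano.shape)  # torch.Size([3, 512, 1024])
```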
--------------------------------------------------------------------------------
/pytorch360convert/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.2"
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from setuptools import find_packages, setup
4 |
5 | # Load the current project version
6 | exec(open("pytorch360convert/version.py").read())
7 |
8 |
9 | # Convert relative image links to full links for PyPI
10 | def _relative_to_full_link(long_description: str) -> str:
11 | """
12 | Converts relative image links in a README to full GitHub URLs.
13 |
14 |     This function replaces relative image links (e.g., in <img> tags and
15 | Markdown ![]() syntax) with their corresponding full GitHub URLs, appending
16 | `?raw=true` for direct access to raw images.
17 |
18 | Links are only replaced if they point to the 'examples' directory, and are
19 |     in the format of: `<img src="examples/...">` or
20 |     `![](examples/...)`.
21 |
22 | Args:
23 | long_description (str): The text containing relative image links.
24 |
25 | Returns:
26 | str: The modified text with full image URLs.
27 | """
28 | import re
29 |
30 | # Base URL for raw GitHub links
31 | github_base_url = "https://github.com/ProGamerGov/pytorch360convert/raw/main/"
32 |
33 |     # Replace relative links in <img> tags
34 | long_description = re.sub(
35 | r'(
=1.8.0",
85 | ],
86 | packages=find_packages(exclude=("tests", "tests.*")),
87 | classifiers=[
88 | "Development Status :: 4 - Beta",
89 | "Intended Audience :: Developers",
90 | "Intended Audience :: Education",
91 | "Intended Audience :: Science/Research",
92 | "Intended Audience :: Information Technology",
93 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
94 | "Topic :: Scientific/Engineering :: Image Processing",
95 | "Topic :: Software Development",
96 | "Topic :: Software Development :: Libraries",
97 | "Topic :: Games/Entertainment",
98 | "Topic :: Multimedia :: Graphics :: Viewers",
99 | "License :: OSI Approved :: MIT License",
100 | "Programming Language :: Python :: 3",
101 | "Programming Language :: Python :: 3.8",
102 | "Programming Language :: Python :: 3.9",
103 | "Programming Language :: Python :: 3.10",
104 | "Programming Language :: Python :: 3.11",
105 | "Programming Language :: Python :: 3.12",
106 | "Programming Language :: Python :: 3.13",
107 | "Programming Language :: Python :: 3.14",
108 | ],
109 | )
110 |
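Much of `_relative_to_full_link` above was lost when the angle-bracketed HTML in its body was stripped during extraction. The sketch below illustrates the kind of substitution its docstring describes; the function name and the exact patterns are assumptions, not the original code:

```python
import re

GITHUB_BASE_URL = "https://github.com/ProGamerGov/pytorch360convert/raw/main/"

def relative_to_full_link_sketch(text: str) -> str:
    # Rewrite <img src="examples/..."> sources to full raw GitHub URLs.
    text = re.sub(
        r'(<img src=")(examples/[^"]+)(")',
        lambda m: m.group(1) + GITHUB_BASE_URL + m.group(2) + "?raw=true" + m.group(3),
        text,
    )
    # Rewrite Markdown image links ![alt](examples/...) the same way.
    text = re.sub(
        r"(!\[[^\]]*\]\()(examples/[^)]+)(\))",
        lambda m: m.group(1) + GITHUB_BASE_URL + m.group(2) + "?raw=true" + m.group(3),
        text,
    )
    return text
```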
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Module tests
2 |
--------------------------------------------------------------------------------
/tests/test_module.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import math
3 | import random
4 | import unittest
5 |
6 | import numpy as np
7 |
8 | import torch
9 | from pytorch360convert.pytorch360convert import (
10 | _face_slice,
11 | _nchw2nhwc,
12 | _nhwc2nchw,
13 | _slice_chunk,
14 | c2e,
15 | coor2uv,
16 | cube_dice2h,
17 | cube_dict2h,
18 | cube_h2dice,
19 | cube_h2dict,
20 | cube_h2list,
21 | cube_list2h,
22 | e2c,
23 | e2e,
24 | e2p,
25 | equirect_facetype,
26 | equirect_uvgrid,
27 | grid_sample_wrap,
28 | pad_180_to_360,
29 | pad_cube_faces,
30 | rotation_matrix,
31 | sample_cubefaces,
32 | uv2coor,
33 | uv2unitxyz,
34 | xyz2uv,
35 | xyzcube,
36 | xyzpers,
37 | )
38 |
39 |
40 | def assertTensorAlmostEqual(
41 | self, actual: torch.Tensor, expected: torch.Tensor, delta: float = 0.0001
42 | ) -> None:
43 | """
44 | Args:
45 |
46 |         self (unittest.TestCase): A unittest instance.
47 | actual (torch.Tensor): A tensor to compare with expected.
48 | expected (torch.Tensor): A tensor to compare with actual.
49 | delta (float, optional): The allowed difference between actual and expected.
50 | Default: 0.0001
51 | """
52 | self.assertEqual(actual.shape, expected.shape)
53 | self.assertEqual(actual.device, expected.device)
54 | self.assertEqual(actual.dtype, expected.dtype)
55 | self.assertAlmostEqual(
56 | torch.sum(torch.abs(actual - expected)).item(), 0.0, delta=delta
57 | )
58 |
59 |
60 | def _create_test_faces(face_height: int = 512, face_width: int = 512) -> torch.Tensor:
61 | # Create unique colors for faces (6 colors)
62 | face_colors = [
63 | [0.0, 0.0, 0.0],
64 | [0.2, 0.2, 0.2],
65 | [0.4, 0.4, 0.4],
66 | [0.6, 0.6, 0.6],
67 | [0.8, 0.8, 0.8],
68 | [1.0, 1.0, 1.0],
69 | ]
70 | face_colors = torch.as_tensor(face_colors).view(6, 3, 1, 1)
71 |
72 | # Create and color faces (6 squares)
73 | faces = torch.ones([6, 3] + [face_height, face_width]) * face_colors
74 | return faces
75 |
76 |
77 | def _create_dice_layout(
78 | faces: torch.Tensor, face_h: int = 512, face_w: int = 512
79 | ) -> torch.Tensor:
80 | H, W = face_h, face_w
81 | cube = torch.zeros((3, 3 * H, 4 * W))
82 | cube[:, 0 : 1 * H, W : 2 * W] = faces[4]
83 | cube[:, 1 * H : 2 * H, 0:W] = faces[3]
84 | cube[:, 1 * H : 2 * H, W : 2 * W] = faces[0]
85 | cube[:, 1 * H : 2 * H, 2 * W : 3 * W] = faces[1]
86 | cube[:, 1 * H : 2 * H, 3 * W : 4 * W] = faces[2]
87 | cube[:, 2 * H : 3 * H, W : 2 * W] = faces[5]
88 | return cube
89 |
90 |
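For reference, the `_create_dice_layout` helper above arranges the six faces on its 3H x 4W canvas like this (Up and Down sit above and below the Front tile):

```
    +---+
    | U |
+---+---+---+---+
| L | F | R | B |
+---+---+---+---+
    | D |
    +---+
```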
91 | def _get_c2e_4x4_exact_tensor() -> torch.Tensor:
92 | a = 0.4000000059604645
93 | b = 0.6000000238418579
94 | c = 0.507179856300354
95 | d = 0.0000
96 | e = 0.09061708301305771
97 | f = 0.20000000298023224
98 | g = 0.36016160249710083
99 |
100 | # Create the expected middle part for each of the 3 matrices
101 | middle = [a, a, b, b, b, c, d, d, d, e, f, f, f, g, a, a]
102 | middle = torch.tensor(middle).unsqueeze(0) # Shape (1, 16)
103 |
104 | # Create the base output for the tensor
105 | expected_output = torch.zeros(8, 16)
106 |
107 |     # Fill the first two and last two rows with constant values
108 | expected_output[0:2] = 0.800000011920929
109 | expected_output[6:8] = 1.0000
110 |
111 | # Fill the middle rows with the specific values
112 | expected_output[2:6] = middle
113 |
114 |     # Exact values for row 5 (the last of the middle rows)
115 | last_row = [
116 | 0.7569681406021118,
117 | 0.8475320339202881,
118 | 1.0,
119 | 0.8708105087280273,
120 | 0.8414928317070007,
121 | 0.791767954826355,
122 | 0.9714627265930176,
123 | 0.6305789947509766,
124 | 0.6305789947509766,
125 | 0.5338935256004333,
126 | 0.8733490109443665,
127 | 0.6829856634140015,
128 | 0.7416210174560547,
129 | 0.19919204711914062,
130 | 0.8475320935249329,
131 | 0.7569681406021118,
132 | ]
133 | expected_output[5] = torch.tensor(last_row)
134 |
135 | # Now, create the 3 matrices by stacking the result
136 | expected_output = torch.stack([expected_output] * 3, dim=0)
137 |
138 | return expected_output
139 |
140 |
141 | def _get_e2c_4x4_exact_tensor() -> torch.Tensor:
142 | f = [[1.0, 1.2951672077178955, 1.7048327922821045, 2.0]] * 4
143 | r = [[2.0, 2.2951672077178955, 2.7048325538635254, 3.0]] * 4
144 | b = [[3.0, 3.0, 3.0, 0.0]] * 4
145 | l = [[0.0, 0.29516735672950745, 0.7048328518867493, 1.0]] * 4
146 | u = [
147 | [0.0, 3.0, 3.0, 3.0],
148 | [0.29516735672950745, 0.0, 3.0, 2.7048325538635254],
149 | [0.7048328518867493, 1.0, 2.0, 2.2951672077178955],
150 | [1.0, 1.2951672077178955, 1.7048327922821045, 2.0],
151 | ]
152 | d = u[::-1]
153 | expected_out = torch.stack(
154 | [
155 | torch.tensor(f).repeat(3, 1, 1),
156 | torch.tensor(r).repeat(3, 1, 1),
157 | torch.tensor(b).repeat(3, 1, 1),
158 | torch.tensor(l).repeat(3, 1, 1),
159 | torch.tensor(u).repeat(3, 1, 1),
160 | torch.tensor(d).repeat(3, 1, 1),
161 | ]
162 | )
163 | return expected_out
164 |
165 |
166 | class TestFunctionsBaseTest(unittest.TestCase):
167 | def setUp(self) -> None:
168 | seed = 1234
169 | random.seed(seed)
170 | np.random.seed(seed)
171 | torch.manual_seed(seed)
172 | torch.cuda.manual_seed_all(seed)
173 | torch.backends.cudnn.deterministic = True
174 |
175 | def test_rotation_matrix(self) -> None:
176 | # Test identity rotation (0 radians around any axis)
177 | axis = torch.tensor([1.0, 0.0, 0.0])
178 | angle = torch.tensor([0.0])
179 | result = rotation_matrix(angle, axis)
180 | expected = torch.eye(3)
181 | torch.testing.assert_close(result, expected, rtol=1e-6, atol=1e-6)
182 |
183 | def test_rotation_matrix_90deg(self) -> None:
184 | # Test 90-degree rotation around x-axis
185 | axis = torch.tensor([1.0, 0.0, 0.0])
186 | angle = torch.tensor([math.pi / 2])
187 | result = rotation_matrix(angle, axis)
188 | expected = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
189 | torch.testing.assert_close(result, expected, rtol=1e-6, atol=1e-6)
190 |
191 | # Test rotation matrix properties
192 | # Should be orthogonal (R * R.T = I)
193 | result_t = result.t()
194 | identity = torch.mm(result, result_t)
195 | torch.testing.assert_close(identity, torch.eye(3), rtol=1e-6, atol=1e-6)
196 |
197 | def test_nhwc_to_nchw_channels_first(self) -> None:
198 | input_tensor = torch.rand(2, 3, 4, 5)
199 | converted_tensor = _nhwc2nchw(input_tensor)
200 | self.assertEqual(converted_tensor.shape, (2, 5, 3, 4))
201 |
202 | def test_nchw_to_nhwc_channels_first(self) -> None:
203 | input_tensor = torch.rand(2, 5, 3, 4)
204 | converted_tensor = _nchw2nhwc(input_tensor)
205 | self.assertEqual(converted_tensor.shape, (2, 3, 4, 5))
206 |
207 | def test_slice_chunk_default(self) -> None:
208 | index = 2
209 | width = 3
210 | offset = 0
211 | expected = torch.tensor([6, 7, 8], dtype=torch.long)
212 | result = _slice_chunk(index, width, offset)
213 | torch.testing.assert_close(result, expected)
214 |
215 | def test_slice_chunk_with_offset(self) -> None:
216 | # Test with a non-zero offset
217 | index = 2
218 | width = 3
219 | offset = 1
220 | expected = torch.tensor([7, 8, 9], dtype=torch.long)
221 | result = _slice_chunk(index, width, offset)
222 | torch.testing.assert_close(result, expected)
223 |
224 | def test_slice_chunk_gpu(self) -> None:
225 | if not torch.cuda.is_available():
226 |             raise unittest.SkipTest("Skipping CUDA test as CUDA is not available.")
227 | index = 2
228 | width = 3
229 | offset = 0
230 | expected = torch.tensor([6, 7, 8], dtype=torch.long).cuda()
231 | result = _slice_chunk(index, width, offset, device=expected.device)
232 | torch.testing.assert_close(result, expected)
233 | self.assertTrue(result.is_cuda)
234 |
235 | def test_face_slice(self) -> None:
236 | # Test _face_slice, which internally calls _slice_chunk
237 | index = 2
238 | face_w = 3
239 | expected = torch.tensor([6, 7, 8], dtype=torch.long)
240 | result = _face_slice(index, face_w)
241 | torch.testing.assert_close(result, expected)
242 |
243 | def test_face_slice_gpu(self) -> None:
244 | if not torch.cuda.is_available():
245 |             raise unittest.SkipTest("Skipping CUDA test as CUDA is not available.")
246 | # Test _face_slice, which internally calls _slice_chunk
247 | index = 2
248 | face_w = 3
249 | expected = torch.tensor([6, 7, 8], dtype=torch.long).cuda()
250 | result = _face_slice(index, face_w, device=expected.device)
251 | torch.testing.assert_close(result, expected)
252 | self.assertTrue(result.is_cuda)
253 |
254 | def test_xyzcube(self) -> None:
255 | face_w = 4
256 | result = xyzcube(face_w)
257 |
258 | # Check shape
259 | self.assertEqual(result.shape, (face_w, face_w * 6, 3))
260 |
261 | # Check that coordinates are normalized (-0.5 to 0.5)
262 | self.assertTrue(torch.all(result >= -0.5))
263 | self.assertTrue(torch.all(result <= 0.5))
264 |
265 | # Test front face center point (adjusting for coordinate system)
266 | center_idx = face_w // 2
267 | front_center = result[center_idx, center_idx]
268 | expected_front = torch.tensor([0.0, 0.0, 0.5])
269 | torch.testing.assert_close(front_center, expected_front, rtol=0.17, atol=0.17)
270 |
271 | def test_cube_h2list(self) -> None:
272 | # Create a mock tensor with a shape [w, w*6, C]
273 | w = 3 # width of the cube face
274 | C = 2 # number of channels (e.g., RGB)
275 | cube_h = torch.randn(w, w * 6, C) # Random tensor with dimensions [3, 18, 2]
276 |
277 | # Call the function
278 | result = cube_h2list(cube_h)
279 |
280 | # Assert that the result is a list of 6 tensors (one for each face)
281 | self.assertEqual(len(result), 6)
282 |
283 | # Assert each tensor has the correct shape [w, w, C]
284 | for tensor in result:
285 | self.assertEqual(tensor.shape, (w, w, C))
286 |
287 | # Ensure the shapes are sliced correctly
288 | for i in range(6):
289 | self.assertTrue(torch.equal(result[i], cube_h[:, i * w : (i + 1) * w, :]))
290 |
291 | def test_cube_h2dict(self) -> None:
292 | # Create a mock tensor with a shape [w, w*6, C]
293 | w = 3 # width of the cube face
294 | C = 2 # number of channels (e.g., RGB)
295 | cube_h = torch.randn(w, w * 6, C) # Random tensor with dimensions [3, 18, 2]
296 | face_keys = ["Front", "Right", "Back", "Left", "Up", "Down"]
297 |
298 | # Call the function
299 | result = cube_h2dict(cube_h, face_keys)
300 |
301 | # Assert that the result is a dictionary with 6 entries
302 | self.assertEqual(len(result), 6)
303 |
304 | # Assert that the dictionary keys are correct
305 | self.assertEqual(list(result.keys()), face_keys)
306 |
307 | # Assert each tensor has the correct shape [w, w, C]
308 | for face in face_keys:
309 | self.assertEqual(result[face].shape, (w, w, C))
310 |
311 | # Check that the values correspond to the expected slices of the input tensor
312 | for i, face in enumerate(face_keys):
313 | self.assertTrue(
314 | torch.equal(result[face], cube_h[:, i * w : (i + 1) * w, :])
315 | )
316 |
317 | def test_equirect_uvgrid(self) -> None:
318 | h, w = 8, 16
319 | result = equirect_uvgrid(h, w)
320 |
321 | # Check shape
322 | self.assertEqual(result.shape, (h, w, 2))
323 |
324 | # Check ranges
325 | u = result[..., 0]
326 | v = result[..., 1]
327 | self.assertTrue(torch.all(u >= -torch.pi))
328 | self.assertTrue(torch.all(u <= torch.pi))
329 | self.assertTrue(torch.all(v >= -torch.pi / 2))
330 | self.assertTrue(torch.all(v <= torch.pi / 2))
331 |
332 | # Check center point
333 | center_h, center_w = h // 2, w // 2
334 | center_point = result[center_h, center_w]
335 | expected_center = torch.tensor([0.0, 0.0])
336 | torch.testing.assert_close(
337 | center_point, expected_center, rtol=0.225, atol=0.225
338 | )
339 |
340 | def test_equirect_facetype(self) -> None:
341 | h, w = 8, 16
342 | result = equirect_facetype(h, w)
343 |
344 | # Check shape
345 | self.assertEqual(result.shape, (h, w))
346 |
347 | # Check face type range (0-5 for 6 faces)
348 | self.assertTrue(torch.all(result >= 0))
349 | self.assertTrue(torch.all(result <= 5))
350 |
351 | # Check sum
352 | self.assertEqual(result.sum().item(), 384.0)
353 |
354 | # Check dtype
355 | self.assertEqual(result.dtype, torch.int64)
356 |
357 | def test_equirect_facetype_large(self) -> None:
358 | h, w = 512 * 2, 512 * 4
359 | result = equirect_facetype(h, w)
360 |
361 | # Check shape
362 | self.assertEqual(result.shape, (h, w))
363 |
364 | # Check face type range (0-5 for 6 faces)
365 | self.assertTrue(torch.all(result >= 0))
366 | self.assertTrue(torch.all(result <= 5))
367 |
368 | # Check sum
369 | self.assertEqual(result.sum().item(), 6510864.0)
370 |
371 | # Check dtype
372 | self.assertEqual(result.dtype, torch.int64)
373 |
374 | def test_equirect_facetype_float64(self) -> None:
375 | h, w = 8, 16
376 | result = equirect_facetype(h, w, dtype=torch.float64)
377 |
378 | # Check shape
379 | self.assertEqual(result.shape, (h, w))
380 |
381 | # Check face type range (0-5 for 6 faces)
382 | self.assertTrue(torch.all(result >= 0))
383 | self.assertTrue(torch.all(result <= 5))
384 |
385 | # Check sum
386 | self.assertEqual(result.sum().item(), 384.0)
387 |
388 | # Check dtype
389 | self.assertEqual(result.dtype, torch.int64)
390 |
391 | def test_equirect_facetype_float16(self) -> None:
392 | h, w = 8, 16
393 | result = equirect_facetype(h, w, dtype=torch.float16)
394 |
395 | # Check shape
396 | self.assertEqual(result.shape, (h, w))
397 |
398 | # Check face type range (0-5 for 6 faces)
399 | self.assertTrue(torch.all(result >= 0))
400 | self.assertTrue(torch.all(result <= 5))
401 |
402 | # Check sum
403 | self.assertEqual(result.sum().item(), 384.0)
404 |
405 | # Check dtype
406 | self.assertEqual(result.dtype, torch.int64)
407 |
408 | def test_equirect_facetype_gpu(self) -> None:
409 | if not torch.cuda.is_available():
410 |             raise unittest.SkipTest("Skipping CUDA test as CUDA is not available.")
411 | h, w = 8, 16
412 | result = equirect_facetype(h, w, device=torch.device("cuda"))
413 |
414 | # Check shape
415 | self.assertEqual(result.shape, (h, w))
416 |
417 | # Check face type range (0-5 for 6 faces)
418 | self.assertTrue(torch.all(result >= 0))
419 | self.assertTrue(torch.all(result <= 5))
420 |
421 | # Check sum
422 | self.assertEqual(result.sum().item(), 384.0)
423 |
424 | # Check dtype
425 | self.assertEqual(result.dtype, torch.int64)
426 |
427 | # Check cuda
428 | self.assertTrue(result.is_cuda)
429 |
430 | def test_xyz2uv_and_uv2unitxyz(self) -> None:
431 | # Create test points
432 | xyz = torch.tensor(
433 | [
434 | [1.0, 0.0, 0.0], # right
435 | [0.0, 1.0, 0.0], # up
436 | [0.0, 0.0, 1.0], # front
437 | ]
438 | )
439 |
440 | # Convert xyz to uv
441 | uv = xyz2uv(xyz)
442 |
443 | # Convert back to xyz
444 | xyz_reconstructed = uv2unitxyz(uv)
445 |
446 | # Normalize input xyz for comparison
447 | xyz_normalized = torch.nn.functional.normalize(xyz, dim=-1)
448 |
449 | # Verify reconstruction
450 | torch.testing.assert_close(
451 | xyz_normalized, xyz_reconstructed, rtol=1e-6, atol=1e-6
452 | )
453 |
454 | def test_uv2coor_and_coor2uv(self) -> None:
455 | h, w = 8, 16
456 | # Create test UV coordinates
457 | test_uv = torch.tensor(
458 | [
459 | [0.0, 0.0], # center
460 | [torch.pi / 2, 0.0], # right quadrant
461 | [-torch.pi / 2, 0.0], # left quadrant
462 | ]
463 | )
464 |
465 | # Convert UV to image coordinates
466 | coor = uv2coor(test_uv, h, w)
467 |
468 | # Convert back to UV
469 | uv_reconstructed = coor2uv(coor, h, w)
470 |
471 | # Verify reconstruction
472 | torch.testing.assert_close(test_uv, uv_reconstructed, rtol=1e-5, atol=1e-5)
473 |
474 | def test_grid_sample_wrap(self) -> None:
475 | # Create test image
476 | h, w = 4, 8
477 | channels = 3
478 | image = torch.arange(h * w * channels, dtype=torch.float32)
479 | image = image.reshape(h, w, channels)
480 |
481 | # Test basic sampling
482 | coor_x = torch.tensor([[1.5, 2.5], [3.5, 4.5]])
483 | coor_y = torch.tensor([[1.5, 1.5], [2.5, 2.5]])
484 |
485 | # Test both interpolation modes
486 | result_bilinear = grid_sample_wrap(image, coor_x, coor_y, mode="bilinear")
487 | result_nearest = grid_sample_wrap(image, coor_x, coor_y, mode="nearest")
488 |
489 | # Check shapes
490 | self.assertEqual(result_bilinear.shape, (2, 2, channels))
491 | self.assertEqual(result_nearest.shape, (2, 2, channels))
492 |
493 | # Test horizontal wrapping
494 | wrap_x = torch.tensor([[w - 1.5, 0.5]])
495 | wrap_y = torch.tensor([[1.5, 1.5]])
496 | result_wrap = grid_sample_wrap(image, wrap_x, wrap_y, mode="bilinear")
497 |
498 | # Check that wrapped coordinates produce similar values
499 | # We use a larger tolerance here due to interpolation differences
500 | torch.testing.assert_close(
501 | result_wrap[0, 0],
502 | result_wrap[0, 1],
503 | rtol=0.5,
504 | atol=0.5,
505 | )
506 |
507 | def test_grid_sample_wrap_cpu_float16(self) -> None:
508 | # Create test image
509 | h, w = 4, 8
510 | channels = 3
511 | image = torch.arange(h * w * channels, dtype=torch.float16)
512 | image = image.reshape(h, w, channels)
513 |
514 | # Test basic sampling
515 | coor_x = torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float16)
516 | coor_y = torch.tensor([[1.5, 1.5], [2.5, 2.5]], dtype=torch.float16)
517 |
518 | # Test both interpolation modes
519 | result_bilinear = grid_sample_wrap(image, coor_x, coor_y, mode="bilinear")
520 | self.assertEqual(result_bilinear.dtype, torch.float16)
521 |
522 | def test_sample_cubefaces_cpu_float16(self) -> None:
523 | # Face type tensor (which face to sample)
524 |         tp = torch.tensor([[0, 1], [2, 3]], dtype=torch.float16)  # Face indices to sample
525 |
526 | # Coordinates for sampling
527 | coor_y = torch.tensor(
528 | [[0.0, 1.0], [0.0, 1.0]], dtype=torch.float16
529 | ) # y-coordinates
530 | coor_x = torch.tensor(
531 | [[0.0, 1.0], [0.0, 1.0]], dtype=torch.float16
532 | ) # x-coordinates
533 |
534 | mode = "bilinear"
535 |
536 | # Call sample_cubefaces
537 | output = sample_cubefaces(
538 | torch.ones([6, 8, 8, 3], dtype=torch.float16), tp, coor_y, coor_x, mode
539 | )
540 | self.assertEqual(output.dtype, tp.dtype)
541 |
542 | def test_sample_cubefaces(self) -> None:
543 | # Face type tensor (which face to sample)
544 |         tp = torch.tensor([[0, 1], [2, 3]], dtype=torch.float32)  # Face indices to sample
545 |
546 | # Coordinates for sampling
547 | coor_y = torch.tensor(
548 | [[0.0, 1.0], [0.0, 1.0]], dtype=torch.float32
549 | ) # y-coordinates
550 | coor_x = torch.tensor(
551 | [[0.0, 1.0], [0.0, 1.0]], dtype=torch.float32
552 | ) # x-coordinates
553 |
554 | mode = "bilinear"
555 |
556 | # Call sample_cubefaces
557 | output = sample_cubefaces(torch.ones(6, 8, 8, 3), tp, coor_y, coor_x, mode)
558 | self.assertEqual(output.sum().item(), 12.0)
559 |
560 | def test_c2e_then_e2c(self) -> None:
561 | face_width = 512
562 | test_faces = _create_test_faces(face_width, face_width)
563 | equi_img = c2e(
564 | test_faces,
565 | face_width * 2,
566 | face_width * 4,
567 | mode="bilinear",
568 | cube_format="stack",
569 | )
570 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
571 | cubic_img = e2c(
572 | equi_img, face_w=face_width, mode="bilinear", cube_format="stack"
573 | )
574 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr]
575 | # assertTensorAlmostEqual(self, cubic_img, test_faces)
576 |
577 | def test_c2e_then_e2c_gpu(self) -> None:
578 | if not torch.cuda.is_available():
579 |             raise unittest.SkipTest("Skipping CUDA test as CUDA is not available.")
580 | face_width = 512
581 | test_faces = _create_test_faces(face_width, face_width).cuda()
582 | equi_img = c2e(
583 | test_faces,
584 | face_width * 2,
585 | face_width * 4,
586 | mode="bilinear",
587 | cube_format="stack",
588 | )
589 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
590 | cubic_img = e2c(
591 | equi_img, face_w=face_width, mode="bilinear", cube_format="stack"
592 | )
593 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr]
594 | self.assertTrue(cubic_img.is_cuda) # type: ignore[union-attr]
595 | # assertTensorAlmostEqual(self, cubic_img, test_faces)
596 |
597 | def test_c2e_stack_grad(self) -> None:
598 | face_width = 512
599 | test_faces = torch.ones([6, 3, face_width, face_width], requires_grad=True)
600 | equi_img = c2e(
601 | test_faces,
602 | face_width * 2,
603 | face_width * 4,
604 | mode="bilinear",
605 | cube_format="stack",
606 | )
607 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
608 | self.assertTrue(equi_img.requires_grad)
609 |
610 | def test_e2c_stack_grad(self) -> None:
611 | face_width = 512
612 | test_faces = torch.ones([3, face_width * 2, face_width * 4], requires_grad=True)
613 | cubic_img = e2c(
614 | test_faces, face_w=face_width, mode="bilinear", cube_format="stack"
615 | )
616 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr]
617 | self.assertTrue(cubic_img.requires_grad) # type: ignore[union-attr]
618 |
619 | def test_c2e_then_e2c_stack_grad(self) -> None:
620 | face_width = 512
621 | test_faces = torch.ones([6, 3, face_width, face_width], requires_grad=True)
622 | equi_img = c2e(
623 | test_faces,
624 | face_width * 2,
625 | face_width * 4,
626 | mode="bilinear",
627 | cube_format="stack",
628 | )
629 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
630 | cubic_img = e2c(
631 | equi_img, face_w=face_width, mode="bilinear", cube_format="stack"
632 | )
633 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr]
634 | self.assertTrue(cubic_img.requires_grad) # type: ignore[union-attr]
635 |
636 | def test_c2e_list_grad(self) -> None:
637 | face_width = 512
638 | test_faces = torch.ones([6, 3, face_width, face_width], requires_grad=True)
639 | test_faces = [test_faces[i] for i in range(test_faces.shape[0])]
640 | equi_img = c2e(
641 | test_faces,
642 | face_width * 2,
643 | face_width * 4,
644 | mode="bilinear",
645 | cube_format="list",
646 | )
647 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
648 | self.assertTrue(equi_img.requires_grad)
649 |
650 | def test_e2c_list_grad(self) -> None:
651 | face_width = 512
652 | equi_img = torch.ones([3, face_width * 2, face_width * 4], requires_grad=True)
653 | cubic_img = e2c(
654 | equi_img, face_w=face_width, mode="bilinear", cube_format="list"
655 | )
656 | for i in range(6):
657 | self.assertEqual(list(cubic_img[i].shape), [3, face_width, face_width]) # type: ignore[index]
658 | for i in range(6):
659 | self.assertTrue(cubic_img[i].requires_grad) # type: ignore[index]
660 |
661 | def test_c2e_then_e2c_list_grad(self) -> None:
662 | face_width = 512
663 | test_faces_tensors = torch.ones(
664 | [6, 3, face_width, face_width], requires_grad=True
665 | )
666 | test_faces = [test_faces_tensors[i] for i in range(test_faces_tensors.shape[0])]
667 | equi_img = c2e(
668 | test_faces,
669 | face_width * 2,
670 | face_width * 4,
671 | mode="bilinear",
672 | cube_format="list",
673 | )
674 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
675 | cubic_img = e2c(
676 | equi_img, face_w=face_width, mode="bilinear", cube_format="stack"
677 | )
678 | for i in range(6):
679 | self.assertEqual(list(cubic_img[i].shape), [3, face_width, face_width]) # type: ignore[index]
680 | for i in range(6):
681 | self.assertTrue(cubic_img[i].requires_grad) # type: ignore[index]
682 |
683 | def test_c2e_dict_grad(self) -> None:
684 | dict_keys = ["Front", "Right", "Back", "Left", "Up", "Down"]
685 | face_width = 512
686 | test_faces_tensors = torch.ones(
687 | [6, 3, face_width, face_width], requires_grad=True
688 | )
689 | test_faces = {k: test_faces_tensors[i] for i, k in zip(range(6), dict_keys)}
690 | equi_img = c2e(
691 | test_faces,
692 | face_width * 2,
693 | face_width * 4,
694 | mode="bilinear",
695 | cube_format="dict",
696 | )
697 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
698 | self.assertTrue(equi_img.requires_grad)
699 |
700 | def test_c2e_then_e2c_dict_grad(self) -> None:
701 | dict_keys = ["Front", "Right", "Back", "Left", "Up", "Down"]
702 | face_width = 512
703 | test_faces_tensors = torch.ones(
704 | [6, 3, face_width, face_width], requires_grad=True
705 | )
706 | test_faces = {k: test_faces_tensors[i] for i, k in zip(range(6), dict_keys)}
707 | equi_img = c2e(
708 | test_faces,
709 | face_width * 2,
710 | face_width * 4,
711 | mode="bilinear",
712 | cube_format="dict",
713 | )
714 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
715 | cubic_img = e2c(
716 | equi_img, face_w=face_width, mode="bilinear", cube_format="dict"
717 | )
718 | for i in dict_keys:
719 | self.assertEqual(list(cubic_img[i].shape), [3, face_width, face_width]) # type: ignore
720 | for i in dict_keys:
721 | self.assertTrue(cubic_img[i].requires_grad) # type: ignore
722 |
723 | def test_c2e_stack_nohw_grad(self) -> None:
724 | face_width = 512
725 | test_faces = torch.ones([6, 3, face_width, face_width], requires_grad=True)
726 | equi_img = c2e(
727 | test_faces,
728 | mode="bilinear",
729 | cube_format="stack",
730 | )
731 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
732 | self.assertTrue(equi_img.requires_grad)
733 |
734 | def test_sample_cubefaces_py360convert(self) -> None:
735 | try:
736 | import py360convert as p360
737 |         except ImportError:
738 | raise unittest.SkipTest(
739 | "py360convert not installed, skipping sample_cubefaces test"
740 | )
741 | # Face type tensor (which face to sample)
742 |         tp = torch.tensor([[0, 1], [2, 3]], dtype=torch.float32)  # Face indices to sample
743 |
744 | # Coordinates for sampling
745 | coor_y = torch.tensor(
746 | [[0.0, 1.0], [0.0, 1.0]], dtype=torch.float32
747 | ) # y-coordinates
748 | coor_x = torch.tensor(
749 | [[0.0, 1.0], [0.0, 1.0]], dtype=torch.float32
750 | ) # x-coordinates
751 |
752 | mode = "bilinear"
753 |
754 | # Call sample_cubefaces
755 | output = sample_cubefaces(torch.ones(6, 8, 8, 3), tp, coor_y, coor_x, mode)
756 | output_np = p360.sample_cubefaces(
757 | torch.ones(6, 8, 8).numpy(),
758 | tp.numpy(),
759 | coor_y.numpy(),
760 | coor_x.numpy(),
761 | mode,
762 | )
763 | self.assertEqual(output.sum(), output_np.sum() * 3)
764 |
765 | def test_c2e_py360convert(self) -> None:
766 | try:
767 | import py360convert as p360
768 |         except ImportError:
769 | raise unittest.SkipTest("py360convert not installed, skipping c2e test")
770 |
771 | face_width = 512
772 | test_faces = _create_test_faces(face_width, face_width)
773 | test_faces = [test_faces[i] for i in range(test_faces.shape[0])]
774 | equi_img = c2e(
775 | test_faces,
776 | face_width * 2,
777 | face_width * 4,
778 | mode="bilinear",
779 | cube_format="list",
780 | )
781 | test_faces_np = [t.permute(1, 2, 0).numpy() for t in test_faces]
782 |
783 | equi_img_np = p360.c2e(
784 | test_faces_np,
785 | face_width * 2,
786 | face_width * 4,
787 | mode="bilinear",
788 | cube_format="list",
789 | )
790 | equi_img_np_tensor = torch.from_numpy(equi_img_np).permute(2, 0, 1).float()
791 | assertTensorAlmostEqual(self, equi_img, equi_img_np_tensor, 2722.8169)
792 |
793 | def test_c2e_then_e2c_py360convert(self) -> None:
794 | try:
795 | import py360convert as p360
796 |         except ImportError:
797 | raise unittest.SkipTest(
798 | "py360convert not installed, skipping c2e and e2c test"
799 | )
800 |
801 | face_width = 512
802 | test_faces = _create_test_faces(face_width, face_width)
803 | test_faces = [test_faces[i] for i in range(test_faces.shape[0])]
804 | equi_img = c2e(
805 | test_faces,
806 | face_width * 2,
807 | face_width * 4,
808 | mode="bilinear",
809 | cube_format="list",
810 | )
811 | test_faces_np = [t.permute(1, 2, 0).numpy() for t in test_faces]
812 |
813 | equi_img_np = p360.c2e(
814 | test_faces_np,
815 | face_width * 2,
816 | face_width * 4,
817 | mode="bilinear",
818 | cube_format="list",
819 | )
820 |
821 | cubic_img = e2c(
822 | equi_img, face_w=face_width, mode="bilinear", cube_format="dice"
823 | )
824 |
825 | cubic_img_np = p360.e2c(
826 | equi_img_np, face_w=face_width, mode="bilinear", cube_format="dice"
827 | )
828 | cubic_img_np_tensor = torch.from_numpy(cubic_img_np).permute(2, 0, 1).float()
829 |
830 | assertTensorAlmostEqual(self, cubic_img, cubic_img_np_tensor, 5858.8921) # type: ignore[arg-type]
831 |
832 | def test_e2c_horizon_grad(self) -> None:
833 | face_width = 512
834 | test_faces = torch.ones([3, face_width * 2, face_width * 4], requires_grad=True)
835 | cubic_img = e2c(
836 | test_faces, face_w=face_width, mode="bilinear", cube_format="horizon"
837 | )
838 | self.assertEqual(list(cubic_img.shape), [3, face_width, face_width * 6]) # type: ignore[union-attr]
839 | self.assertTrue(cubic_img.requires_grad) # type: ignore[union-attr]
840 |
841 | def test_e2c_dice_grad(self) -> None:
842 | face_width = 512
843 | test_faces = torch.ones([3, face_width * 2, face_width * 4], requires_grad=True)
844 | cubic_img = e2c(
845 | test_faces, face_w=face_width, mode="bilinear", cube_format="dice"
846 | )
847 | self.assertEqual(list(cubic_img.shape), [3, face_width * 3, face_width * 4]) # type: ignore[union-attr]
848 | self.assertTrue(cubic_img.requires_grad) # type: ignore[union-attr]
849 |
850 | def test_c2e_then_e2c_dice_grad(self) -> None:
851 | face_width = 512
852 | test_faces = torch.ones([3, face_width * 3, face_width * 4], requires_grad=True)
853 | equi_img = c2e(
854 | test_faces,
855 | face_width * 2,
856 | face_width * 4,
857 | mode="bilinear",
858 | cube_format="dice",
859 | )
860 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
861 | cubic_img = e2c(
862 | equi_img, face_w=face_width, mode="bilinear", cube_format="dice"
863 | )
864 | self.assertEqual(list(cubic_img.shape), [3, face_width * 3, face_width * 4]) # type: ignore[union-attr]
865 | self.assertTrue(cubic_img.requires_grad) # type: ignore[union-attr]
866 |
867 | def test_e2p(self) -> None:
868 | # Create a simple test equirectangular image
869 | h, w = 64, 128
870 | channels = 3
871 | e_img = torch.zeros((channels, h, w))
872 | # Add some recognizable pattern
873 | e_img[0, :, :] = torch.linspace(0, 1, w).repeat(
874 | h, 1
875 | ) # Red gradient horizontally
876 | e_img[1, :, :] = (
877 | torch.linspace(0, 1, h).unsqueeze(1).repeat(1, w)
878 | ) # Green gradient vertically
879 |
880 | # Test basic perspective projection
881 | fov_deg = 90.0
882 | u_deg = 0.0
883 | v_deg = 0.0
884 | out_hw = (32, 32)
885 |
886 | result = e2p(e_img, fov_deg, u_deg, v_deg, out_hw)
887 |
888 | # Check output shape
889 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]])
890 |
891 | # Test with different FOV
892 | narrow_fov = e2p(e_img, 45.0, u_deg, v_deg, out_hw)
893 | wide_fov = e2p(e_img, 120.0, u_deg, v_deg, out_hw)
894 |
895 | # Narrow FOV should have less variation in values than wide FOV
896 | self.assertTrue(torch.std(narrow_fov) < torch.std(wide_fov))
897 |
898 | # Test with rotation
899 | rotated = e2p(e_img, fov_deg, 90.0, v_deg, out_hw) # 90 degrees right
900 |
901 | # Test with different output sizes
902 | large_output = e2p(e_img, fov_deg, u_deg, v_deg, (64, 64))
903 | self.assertEqual(list(large_output.shape), [channels, 64, 64])
904 |
905 | # Test with rectangular output
906 | rect_output = e2p(e_img, fov_deg, u_deg, v_deg, (32, 64))
907 | self.assertEqual(list(rect_output.shape), [channels, 32, 64])
908 |
909 | # Test with different FOV for height and width
910 | fov_hw = (90.0, 60.0)
911 | diff_fov = e2p(e_img, fov_hw, u_deg, v_deg, out_hw)
912 | self.assertEqual(list(diff_fov.shape), [channels, out_hw[0], out_hw[1]])
913 |
914 | def test_e2c_stack_face_w_none(self) -> None:
915 | face_width = 512
916 | test_faces = torch.ones([3, face_width * 2, face_width * 4], requires_grad=True)
917 | cubic_img = e2c(test_faces, face_w=None, mode="bilinear", cube_format="stack")
918 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr]
919 |
920 | def test_e2c_stack_gpu(self) -> None:
921 | if not torch.cuda.is_available():
922 |             raise unittest.SkipTest("Skipping CUDA test as CUDA is not available.")
923 | face_width = 512
924 | test_faces = torch.ones([3, face_width * 2, face_width * 4]).cuda()
925 | cubic_img = e2c(
926 | test_faces, face_w=face_width, mode="bilinear", cube_format="stack"
927 | )
928 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr]
929 | self.assertTrue(cubic_img.is_cuda) # type: ignore[union-attr]
930 |
931 | def test_c2e_stack_gpu(self) -> None:
932 | if not torch.cuda.is_available():
933 |             raise unittest.SkipTest("Skipping CUDA test as CUDA is not available.")
934 | face_width = 512
935 | test_faces = torch.ones([6, 3, face_width, face_width]).cuda()
936 | equi_img = c2e(
937 | test_faces,
938 | face_width * 2,
939 | face_width * 4,
940 | mode="bilinear",
941 | cube_format="stack",
942 | )
943 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
944 | self.assertTrue(equi_img.is_cuda)
945 |
946 | def test_e2c_stack_float16(self) -> None:
947 | face_width = 512
948 | test_faces = torch.ones(
949 | [3, face_width * 2, face_width * 4], dtype=torch.float16
950 | )
951 | cubic_img = e2c(
952 | test_faces, face_w=face_width, mode="bilinear", cube_format="stack"
953 | )
954 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr]
955 | self.assertEqual(cubic_img.dtype, torch.float16) # type: ignore[union-attr]
956 |
957 | def test_e2c_stack_float64(self) -> None:
958 | face_width = 512
959 | test_faces = torch.ones(
960 | [3, face_width * 2, face_width * 4], dtype=torch.float64
961 | )
962 | cubic_img = e2c(
963 | test_faces, face_w=face_width, mode="bilinear", cube_format="stack"
964 | )
965 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr]
966 | self.assertEqual(cubic_img.dtype, torch.float64) # type: ignore[union-attr]
967 |
968 | def test_c2e_stack_float16(self) -> None:
969 | face_width = 512
970 | test_faces = torch.ones([6, 3, face_width, face_width], dtype=torch.float16)
971 | equi_img = c2e(
972 | test_faces,
973 | face_width * 2,
974 | face_width * 4,
975 | mode="bilinear",
976 | cube_format="stack",
977 | )
978 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
979 | self.assertEqual(equi_img.dtype, torch.float16)
980 |
981 | def test_c2e_stack_float64(self) -> None:
982 | face_width = 512
983 | test_faces = torch.ones([6, 3, face_width, face_width], dtype=torch.float64)
984 | equi_img = c2e(
985 | test_faces,
986 | face_width * 2,
987 | face_width * 4,
988 | mode="bilinear",
989 | cube_format="stack",
990 | )
991 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
992 | self.assertEqual(equi_img.dtype, torch.float64)
993 |
994 | def test_e2p_gpu(self) -> None:
995 | if not torch.cuda.is_available():
996 |             raise unittest.SkipTest("Skipping CUDA test as CUDA is not available.")
997 | # Create a simple test equirectangular image
998 | h, w = 64, 128
999 | channels = 3
1000 | e_img = torch.zeros((channels, h, w)).cuda()
1001 |
1002 | fov_deg = 90.0
1003 | h_deg = 0.0
1004 | v_deg = 0.0
1005 | out_hw = (32, 32)
1006 |
1007 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw)
1008 |
1009 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]])
1010 | self.assertTrue(result.is_cuda)
1011 |
1012 | def test_e2p_grad(self) -> None:
1013 | # Create a simple test equirectangular image
1014 | h, w = 64, 128
1015 | channels = 3
1016 | e_img = torch.zeros((channels, h, w), requires_grad=True)
1017 |
1018 | fov_deg = 90.0
1019 | h_deg = 0.0
1020 | v_deg = 0.0
1021 | out_hw = (32, 32)
1022 |
1023 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw)
1024 |
1025 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]])
1026 | self.assertTrue(result.requires_grad)
1027 |
1028 | def test_e2p_float16(self) -> None:
1029 | # Create a simple test equirectangular image
1030 | h, w = 64, 128
1031 | channels = 3
1032 | e_img = torch.zeros((channels, h, w), dtype=torch.float16)
1033 |
1034 | fov_deg = 90.0
1035 | h_deg = 0.0
1036 | v_deg = 0.0
1037 | out_hw = (32, 32)
1038 |
1039 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw)
1040 |
1041 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]])
1042 | self.assertEqual(result.dtype, torch.float16)
1043 |
1044 | def test_e2p_float64(self) -> None:
1045 | # Create a simple test equirectangular image
1046 | h, w = 64, 128
1047 | channels = 3
1048 | e_img = torch.zeros((channels, h, w), dtype=torch.float64)
1049 |
1050 | fov_deg = 90.0
1051 | h_deg = 0.0
1052 | v_deg = 0.0
1053 | out_hw = (32, 32)
1054 |
1055 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw)
1056 |
1057 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]])
1058 | self.assertEqual(result.dtype, torch.float64)
1059 |
1060 | def test_c2e_stack_1channel(self) -> None:
1061 | channels = 1
1062 | face_width = 512
1063 | test_faces = torch.ones(
1064 | [6, channels, face_width, face_width], dtype=torch.float64
1065 | )
1066 | equi_img = c2e(
1067 | test_faces,
1068 | face_width * 2,
1069 | face_width * 4,
1070 | mode="bilinear",
1071 | cube_format="stack",
1072 | )
1073 | self.assertEqual(
1074 | list(equi_img.shape), [channels, face_width * 2, face_width * 4]
1075 | )
1076 |
1077 | def test_c2e_stack_4channels(self) -> None:
1078 | channels = 4
1079 | face_width = 512
1080 | test_faces = torch.ones(
1081 | [6, channels, face_width, face_width], dtype=torch.float64
1082 | )
1083 | equi_img = c2e(
1084 | test_faces,
1085 | face_width * 2,
1086 | face_width * 4,
1087 | mode="bilinear",
1088 | cube_format="stack",
1089 | )
1090 | self.assertEqual(
1091 | list(equi_img.shape), [channels, face_width * 2, face_width * 4]
1092 | )
1093 |
1094 | def test_e2c_stack_1channel(self) -> None:
1095 | channels = 1
1096 | face_width = 512
1097 | test_faces = torch.ones([channels, face_width * 2, face_width * 4])
1098 | cubic_img = e2c(
1099 | test_faces, face_w=face_width, mode="bilinear", cube_format="stack"
1100 | )
1101 | self.assertEqual(list(cubic_img.shape), [6, channels, face_width, face_width]) # type: ignore[union-attr]
1102 |
1103 | def test_e2c_stack_4channels(self) -> None:
1104 | channels = 4
1105 | face_width = 512
1106 | test_faces = torch.ones([channels, face_width * 2, face_width * 4])
1107 | cubic_img = e2c(
1108 | test_faces, face_w=face_width, mode="bilinear", cube_format="stack"
1109 | )
1110 | self.assertEqual(list(cubic_img.shape), [6, channels, face_width, face_width]) # type: ignore[union-attr]
1111 |
1112 | def test_e2p_1channel(self) -> None:
1113 | h, w = 64, 128
1114 | channels = 1
1115 | e_img = torch.zeros((channels, h, w))
1116 |
1117 | fov_deg = 90.0
1118 | h_deg = 0.0
1119 | v_deg = 0.0
1120 | out_hw = (32, 32)
1121 |
1122 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw)
1123 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]])
1124 |
1125 | def test_e2p_4channels(self) -> None:
1126 | h, w = 64, 128
1127 | channels = 4
1128 | e_img = torch.zeros((channels, h, w))
1129 |
1130 | fov_deg = 90.0
1131 | h_deg = 0.0
1132 | v_deg = 0.0
1133 | out_hw = (32, 32)
1134 |
1135 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw)
1136 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]])
1137 |
1138 | def test_c2e_stack_jit(self) -> None:
1139 | channels = 3
1140 | face_width = 512
1141 | test_faces = torch.ones(
1142 | [6, channels, face_width, face_width], dtype=torch.float64
1143 | )
1144 |
1145 | c2e_jit = torch.jit.script(c2e)
1146 | equi_img = c2e_jit(
1147 | test_faces,
1148 | face_width * 2,
1149 | face_width * 4,
1150 | mode="bilinear",
1151 | cube_format="stack",
1152 | )
1153 | self.assertEqual(
1154 | list(equi_img.shape), [channels, face_width * 2, face_width * 4]
1155 | )
1156 |
1157 | def test_e2c_stack_jit(self) -> None:
1158 | channels = 3
1159 | face_width = 512
1160 | test_faces = torch.ones([channels, face_width * 2, face_width * 4])
1161 | e2c_jit = torch.jit.script(e2c)
1162 | cubic_img = e2c_jit(
1163 | test_faces, face_w=face_width, mode="bilinear", cube_format="stack"
1164 | )
1165 | self.assertEqual(list(cubic_img.shape), [6, channels, face_width, face_width])
1166 |
1167 | def test_e2p_jit(self) -> None:
1168 | h, w = 64, 128
1169 | channels = 3
1170 | e_img = torch.zeros((channels, h, w))
1171 |
1172 | fov_deg = 90.0
1173 | h_deg = 0.0
1174 | v_deg = 0.0
1175 | out_hw = (32, 32)
1176 |
1177 | e2p_jit = torch.jit.script(e2p)
1178 |
1179 | result = e2p_jit(e_img, fov_deg, h_deg, v_deg, out_hw)
1180 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]])
1181 |
1182 | def test_e2p_exact(self) -> None:
1183 | channels = 3
1184 | out_hw = (2, 4)
1185 | x_input0 = torch.arange(1, 256).repeat(1, 256, 2).float()
1186 | x_input1 = torch.arange(1, 256).repeat(1, 256, 2).float() * 2.0
1187 | x_input2 = torch.arange(1, 256).repeat(1, 256, 2).float() * 4.0
1188 | x_input = torch.cat([x_input0, x_input1, x_input2], 0)
1189 |
1190 | result = e2p(
1191 | x_input,
1192 | fov_deg=(40, 60),
1193 | h_deg=10,
1194 | v_deg=50,
1195 | out_hw=out_hw,
1196 | mode="bilinear",
1197 | )
1198 |
1199 | a = torch.tensor(
1200 | [
1201 | [
1202 | 183.03811645507812,
1203 | 225.49948120117188,
1204 | 58.833866119384766,
1205 | 101.29522705078125,
1206 | ],
1207 | [
1208 | 243.3969268798828,
1209 | 5.628509521484375,
1210 | 23.704801559448242,
1211 | 40.9364013671875,
1212 | ],
1213 | ]
1214 | )
1215 | b = torch.tensor(
1216 | [
1217 | [
1218 | 366.07623291015625,
1219 | 450.99896240234375,
1220 | 117.66773223876953,
1221 | 202.5904541015625,
1222 | ],
1223 | [
1224 | 486.7938537597656,
1225 | 11.25701904296875,
1226 | 47.409603118896484,
1227 | 81.872802734375,
1228 | ],
1229 | ]
1230 | )
1231 | c = torch.tensor(
1232 | [
1233 | [
1234 | 732.1524658203125,
1235 | 901.9979248046875,
1236 | 235.33546447753906,
1237 | 405.180908203125,
1238 | ],
1239 | [
1240 | 973.5877075195312,
1241 | 22.5140380859375,
1242 | 94.81920623779297,
1243 | 163.74560546875,
1244 | ],
1245 | ]
1246 | )
1247 | expected_output = torch.stack([a, b, c])
1248 |
1249 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]])
1250 | self.assertTrue(torch.allclose(result, expected_output))
1251 |
1252 | def test_e2p_exact_batch(self) -> None:
1253 | batch = 2
1254 | channels = 3
1255 | out_hw = (2, 4)
1256 | x_input0 = torch.arange(1, 256).repeat(1, 256, 2).float()
1257 | x_input1 = torch.arange(1, 256).repeat(1, 256, 2).float() * 2.0
1258 | x_input2 = torch.arange(1, 256).repeat(1, 256, 2).float() * 4.0
1259 | x_input = torch.cat([x_input0, x_input1, x_input2], 0).repeat(batch, 1, 1, 1)
1260 |
1261 | result = e2p(
1262 | x_input,
1263 | fov_deg=(40, 60),
1264 | h_deg=10,
1265 | v_deg=50,
1266 | out_hw=out_hw,
1267 | mode="bilinear",
1268 | )
1269 |
1270 | a = torch.tensor(
1271 | [
1272 | [
1273 | 183.03811645507812,
1274 | 225.49948120117188,
1275 | 58.833866119384766,
1276 | 101.29522705078125,
1277 | ],
1278 | [
1279 | 243.3969268798828,
1280 | 5.628509521484375,
1281 | 23.704801559448242,
1282 | 40.9364013671875,
1283 | ],
1284 | ]
1285 | )
1286 | b = torch.tensor(
1287 | [
1288 | [
1289 | 366.07623291015625,
1290 | 450.99896240234375,
1291 | 117.66773223876953,
1292 | 202.5904541015625,
1293 | ],
1294 | [
1295 | 486.7938537597656,
1296 | 11.25701904296875,
1297 | 47.409603118896484,
1298 | 81.872802734375,
1299 | ],
1300 | ]
1301 | )
1302 | c = torch.tensor(
1303 | [
1304 | [
1305 | 732.1524658203125,
1306 | 901.9979248046875,
1307 | 235.33546447753906,
1308 | 405.180908203125,
1309 | ],
1310 | [
1311 | 973.5877075195312,
1312 | 22.5140380859375,
1313 | 94.81920623779297,
1314 | 163.74560546875,
1315 | ],
1316 | ]
1317 | )
1318 | expected_output = torch.stack([a, b, c]).repeat(batch, 1, 1, 1)
1319 |
1320 | self.assertEqual(list(result.shape), [batch, channels, out_hw[0], out_hw[1]])
1321 | self.assertTrue(torch.allclose(result, expected_output))
1322 |
1323 | def test_c2e_stack_exact(self) -> None:
1324 | expected_output = _get_c2e_4x4_exact_tensor()
1325 | tile_w = 4
1326 | x_input = _create_test_faces(tile_w, tile_w)
1327 | output_cubic_tensor = c2e(x_input, mode="bilinear", cube_format="stack")
1328 | self.assertTrue(torch.allclose(output_cubic_tensor, expected_output))
1329 |
1330 | def test_c2e_list_exact(self) -> None:
1331 | expected_output = _get_c2e_4x4_exact_tensor()
1332 | tile_w = 4
1333 | x_input = _create_test_faces(tile_w, tile_w)
1335 | x_input_list = [x_input[i] for i in range(6)]
1336 | output_cubic_tensor = c2e(x_input_list, mode="bilinear", cube_format="list")
1337 | self.assertTrue(torch.allclose(output_cubic_tensor, expected_output))
1338 |
1339 | def test_c2e_dict_exact(self) -> None:
1340 | expected_output = _get_c2e_4x4_exact_tensor()
1341 | tile_w = 4
1342 | x_input = _create_test_faces(tile_w, tile_w)
1343 | dict_keys = ["Front", "Right", "Back", "Left", "Up", "Down"]
1344 | x_input_dict = {k: x_input[i] for i, k in enumerate(dict_keys)}
1345 | output_cubic_tensor = c2e(x_input_dict, mode="bilinear", cube_format="dict")
1346 | self.assertTrue(torch.allclose(output_cubic_tensor, expected_output))
1347 |
1348 | def test_c2e_horizon_exact(self) -> None:
1349 | expected_output = _get_c2e_4x4_exact_tensor()
1350 | tile_w = 4
1351 | x_input = _create_test_faces(tile_w, tile_w)
1352 | x_input_horizon = torch.cat([x_input[i] for i in range(6)], 2)
1353 | output_cubic_tensor = c2e(
1354 | x_input_horizon, mode="bilinear", cube_format="horizon"
1355 | )
1356 | self.assertTrue(torch.allclose(output_cubic_tensor, expected_output))
1357 |
1358 | def test_c2e_dice_exact(self) -> None:
1359 | expected_output = _get_c2e_4x4_exact_tensor()
1360 | tile_w = 4
1361 | x_input = _create_test_faces(tile_w, tile_w)
1362 | x_input = _create_dice_layout(x_input, tile_w, tile_w)
1363 | output_cubic_tensor = c2e(x_input, mode="bilinear", cube_format="dice")
1364 | self.assertTrue(torch.allclose(output_cubic_tensor, expected_output))
1365 |
1366 | def test_e2c_stack_exact(self) -> None:
1367 | x_input = torch.arange(0, 4).repeat(3, 2, 1).float()
1368 | output_tensor = e2c(x_input, face_w=4, mode="bilinear", cube_format="stack")
1369 | expected_output = _get_e2c_4x4_exact_tensor()
1370 | self.assertTrue(torch.allclose(output_tensor, expected_output)) # type: ignore[arg-type]
1371 |
1372 | def test_e2e_float16(self) -> None:
1373 | channels = 4
1374 | face_width = 512
1375 | test_equi = torch.ones(
1376 | [channels, face_width * 2, face_width * 4], dtype=torch.float16
1377 | )
1378 | equi_img = e2e(
1379 | test_equi,
1380 | h_deg=45,
1381 | v_deg=45,
1382 | roll=25,
1383 | mode="bilinear",
1384 | )
1385 | self.assertEqual(list(equi_img.shape), list(test_equi.shape))
1386 | self.assertEqual(equi_img.dtype, test_equi.dtype)
1387 |
1388 | def test_e2e_float64(self) -> None:
1389 | channels = 4
1390 | face_width = 512
1391 | test_equi = torch.ones(
1392 | [channels, face_width * 2, face_width * 4], dtype=torch.float64
1393 | )
1394 | equi_img = e2e(
1395 | test_equi,
1396 | h_deg=45,
1397 | v_deg=45,
1398 | roll=25,
1399 | mode="bilinear",
1400 | )
1401 | self.assertEqual(list(equi_img.shape), list(test_equi.shape))
1402 | self.assertEqual(equi_img.dtype, test_equi.dtype)
1403 |
1404 | def test_e2e_1channel(self) -> None:
1405 | channels = 1
1406 | face_width = 512
1407 | test_equi = torch.ones([channels, face_width * 2, face_width * 4])
1408 | equi_img = e2e(
1409 | test_equi,
1410 | h_deg=45,
1411 | v_deg=45,
1412 | roll=25,
1413 | mode="bilinear",
1414 | )
1415 | self.assertEqual(list(equi_img.shape), list(test_equi.shape))
1416 | self.assertEqual(equi_img.dtype, test_equi.dtype)
1417 |
1418 | def test_e2e_4channels(self) -> None:
1419 | channels = 4
1420 | face_width = 512
1421 | test_equi = torch.ones([channels, face_width * 2, face_width * 4])
1422 | equi_img = e2e(
1423 | test_equi,
1424 | h_deg=45,
1425 | v_deg=45,
1426 | roll=25,
1427 | mode="bilinear",
1428 | )
1429 | self.assertEqual(list(equi_img.shape), list(test_equi.shape))
1430 | self.assertEqual(equi_img.dtype, test_equi.dtype)
1431 |
1432 | def test_e2e_grad(self) -> None:
1433 | channels = 3
1434 | face_width = 512
1435 | test_equi = torch.ones(
1436 | [channels, face_width * 2, face_width * 4], requires_grad=True
1437 | )
1438 | equi_img = e2e(
1439 | test_equi,
1440 | h_deg=45,
1441 | v_deg=45,
1442 | roll=25,
1443 | mode="bilinear",
1444 | )
1445 | self.assertEqual(list(equi_img.shape), list(test_equi.shape))
1446 | self.assertEqual(equi_img.dtype, test_equi.dtype)
1447 | self.assertTrue(equi_img.requires_grad)
1448 |
1449 | def test_e2e_gpu(self) -> None:
1450 | if not torch.cuda.is_available():
1451 | raise unittest.SkipTest("Skipping CUDA test because CUDA is not available.")
1452 | channels = 3
1453 | face_width = 512
1454 | test_equi = torch.ones([channels, face_width * 2, face_width * 4]).cuda()
1455 | equi_img = e2e(
1456 | test_equi,
1457 | h_deg=45,
1458 | v_deg=45,
1459 | roll=25,
1460 | mode="bilinear",
1461 | )
1462 | self.assertEqual(list(equi_img.shape), list(test_equi.shape))
1463 | self.assertEqual(equi_img.dtype, test_equi.dtype)
1464 | self.assertTrue(equi_img.is_cuda)
1465 |
1466 | def test_e2e_jit(self) -> None:
1467 | channels = 3
1468 | face_width = 512
1469 | test_equi = torch.ones([channels, face_width * 2, face_width * 4])
1470 | e2e_jit = torch.jit.script(e2e)
1471 | equi_img = e2e_jit(
1472 | test_equi,
1473 | h_deg=45,
1474 | v_deg=45,
1475 | roll=25,
1476 | mode="bilinear",
1477 | )
1478 | self.assertEqual(list(equi_img.shape), list(test_equi.shape))
1479 |
1480 | def test_e2e_batch(self) -> None:
1481 | batch = 2
1482 | channels = 4
1483 | face_width = 512
1484 | test_equi = torch.ones([batch, channels, face_width * 2, face_width * 4])
1485 | equi_img = e2e(
1486 | test_equi,
1487 | h_deg=45,
1488 | v_deg=45,
1489 | roll=25,
1490 | mode="bilinear",
1491 | )
1492 | self.assertEqual(list(equi_img.shape), list(test_equi.shape))
1493 |
1494 | def test_e2e_exact(self) -> None:
1495 | channels = 1
1496 | face_width = 2
1497 |
1498 | w = face_width * 2
1499 | h = face_width * 4  # note: h > w here; the exact values below match this [1, 8, 4] input
1500 | test_equi = torch.arange(h * w * channels, dtype=torch.float32)
1501 | test_equi = test_equi.reshape(channels, h, w)
1502 |
1503 | result = e2e(
1504 | test_equi,
1505 | h_deg=45,
1506 | v_deg=45,
1507 | roll=25,
1508 | mode="bilinear",
1509 | )
1510 | self.assertEqual(list(result.shape), list(test_equi.shape))
1511 |
1512 | expected_output = torch.tensor(
1513 | [
1514 | [
1515 | [
1516 | 8.846327781677246,
1517 | 7.38915491104126,
1518 | 10.024271011352539,
1519 | 11.211331367492676,
1520 | ],
1521 | [
1522 | 9.07547378540039,
1523 | 3.879967451095581,
1524 | 12.109222412109375,
1525 | 15.09968376159668,
1526 | ],
1527 | [
1528 | 10.49519157409668,
1529 | 2.552384614944458,
1530 | 14.621706008911133,
1531 | 19.000062942504883,
1532 | ],
1533 | [
1534 | 12.717361450195312,
1535 | 4.980112552642822,
1536 | 17.24430274963379,
1537 | 22.87851905822754,
1538 | ],
1539 | [
1540 | 15.3826265335083,
1541 | 8.457935333251953,
1542 | 19.71427345275879,
1543 | 26.661039352416992,
1544 | ],
1545 | [
1546 | 18.184263229370117,
1547 | 12.159286499023438,
1548 | 21.694225311279297,
1549 | 29.683902740478516,
1550 | ],
1551 | [
1552 | 20.887590408325195,
1553 | 15.91322135925293,
1554 | 22.679805755615234,
1555 | 29.112091064453125,
1556 | ],
1557 | [
1558 | 20.43504524230957,
1559 | 19.638233184814453,
1560 | 22.215164184570312,
1561 | 23.207149505615234,
1562 | ],
1563 | ]
1564 | ]
1565 | )
1566 | self.assertTrue(torch.allclose(result, expected_output))
1567 |
1568 | def test_e2e_exact_batch(self) -> None:
1569 | batch = 2
1570 | channels = 1
1571 | face_width = 2
1572 |
1573 | w = face_width * 2
1574 | h = face_width * 4
1575 | test_equi = torch.arange(batch * h * w * channels, dtype=torch.float32)
1576 | test_equi = test_equi.reshape(batch, channels, h, w)
1577 |
1578 | result = e2e(
1579 | test_equi,
1580 | h_deg=45,
1581 | v_deg=45,
1582 | roll=25,
1583 | mode="bilinear",
1584 | )
1585 | self.assertEqual(list(result.shape), list(test_equi.shape))
1586 |
1587 | expected_output = torch.tensor(
1588 | [
1589 | [
1590 | [
1591 | [
1592 | 8.846327781677246,
1593 | 7.38915491104126,
1594 | 10.024271011352539,
1595 | 11.211331367492676,
1596 | ],
1597 | [
1598 | 9.07547378540039,
1599 | 3.879967451095581,
1600 | 12.109222412109375,
1601 | 15.09968376159668,
1602 | ],
1603 | [
1604 | 10.49519157409668,
1605 | 2.552384614944458,
1606 | 14.621706008911133,
1607 | 19.000062942504883,
1608 | ],
1609 | [
1610 | 12.717361450195312,
1611 | 4.980112552642822,
1612 | 17.24430274963379,
1613 | 22.87851905822754,
1614 | ],
1615 | [
1616 | 15.3826265335083,
1617 | 8.457935333251953,
1618 | 19.71427345275879,
1619 | 26.661039352416992,
1620 | ],
1621 | [
1622 | 18.184263229370117,
1623 | 12.159286499023438,
1624 | 21.694225311279297,
1625 | 29.683902740478516,
1626 | ],
1627 | [
1628 | 20.887590408325195,
1629 | 15.91322135925293,
1630 | 22.679805755615234,
1631 | 29.112091064453125,
1632 | ],
1633 | [
1634 | 20.43504524230957,
1635 | 19.638233184814453,
1636 | 22.215164184570312,
1637 | 23.207149505615234,
1638 | ],
1639 | ]
1640 | ],
1641 | [
1642 | [
1643 | [
1644 | 40.84632873535156,
1645 | 39.38915252685547,
1646 | 42.02427291870117,
1647 | 43.21133041381836,
1648 | ],
1649 | [
1650 | 41.07547378540039,
1651 | 35.879966735839844,
1652 | 44.109222412109375,
1653 | 47.09968566894531,
1654 | ],
1655 | [
1656 | 42.49519348144531,
1657 | 34.5523796081543,
1658 | 46.6217041015625,
1659 | 51.000064849853516,
1660 | ],
1661 | [
1662 | 44.71736145019531,
1663 | 36.9801139831543,
1664 | 49.244300842285156,
1665 | 54.87852096557617,
1666 | ],
1667 | [
1668 | 47.382625579833984,
1669 | 40.45793533325195,
1670 | 51.71427536010742,
1671 | 58.66103744506836,
1672 | ],
1673 | [
1674 | 50.18426513671875,
1675 | 44.15928649902344,
1676 | 53.6942253112793,
1677 | 61.683902740478516,
1678 | ],
1679 | [
1680 | 52.88758850097656,
1681 | 47.9132194519043,
1682 | 54.679805755615234,
1683 | 61.112091064453125,
1684 | ],
1685 | [
1686 | 52.43504333496094,
1687 | 51.63823318481445,
1688 | 54.21516418457031,
1689 | 55.207149505615234,
1690 | ],
1691 | ]
1692 | ],
1693 | ]
1694 | )
1695 | self.assertEqual(result.shape, expected_output.shape)
1696 | self.assertTrue(torch.allclose(result, expected_output))
1697 |
1698 | def test_c2e_stack_bfloat16(self) -> None:
1699 | dtype = torch.bfloat16
1700 | face_width = 512
1701 | test_faces = torch.ones([6, 3, face_width, face_width], dtype=dtype)
1702 | equi_img = c2e(
1703 | test_faces,
1704 | face_width * 2,
1705 | face_width * 4,
1706 | mode="bilinear",
1707 | cube_format="stack",
1708 | )
1709 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
1710 | self.assertEqual(equi_img.dtype, dtype)
1711 |
1712 | def test_e2c_stack_bfloat16(self) -> None:
1713 | dtype = torch.bfloat16
1714 | face_width = 512
1715 | test_equi = torch.ones([3, face_width * 2, face_width * 4], dtype=dtype)
1716 | cubic_img = e2c(
1717 | test_equi, face_w=face_width, mode="bilinear", cube_format="stack"
1718 | )
1719 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr]
1720 | self.assertEqual(cubic_img.dtype, dtype) # type: ignore[union-attr]
1721 |
1722 | def test_e2p_bfloat16(self) -> None:
1723 | # Create a simple test equirectangular image
1724 | dtype = torch.bfloat16
1725 | h, w = 64, 128
1726 | channels = 3
1727 | e_img = torch.zeros((channels, h, w), dtype=dtype)
1728 |
1729 | fov_deg = 90.0
1730 | h_deg = 0.0
1731 | v_deg = 0.0
1732 | out_hw = (32, 32)
1733 |
1734 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw)
1735 |
1736 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]])
1737 | self.assertEqual(result.dtype, dtype)
1738 |
1739 | def test_e2e_bfloat16(self) -> None:
1740 | dtype = torch.bfloat16
1741 | channels = 4
1742 | face_width = 512
1743 | test_equi = torch.ones([channels, face_width * 2, face_width * 4], dtype=dtype)
1744 | equi_img = e2e(
1745 | test_equi,
1746 | h_deg=45,
1747 | v_deg=45,
1748 | roll=25,
1749 | mode="bilinear",
1750 | )
1751 | self.assertEqual(list(equi_img.shape), list(test_equi.shape))
1752 | self.assertEqual(equi_img.dtype, dtype)
1753 |
1754 | def test_c2e_stack_bfloat16_cuda(self) -> None:
1755 | if not torch.cuda.is_available():
1756 | raise unittest.SkipTest("Skipping CUDA test because CUDA is not available.")
1757 | dtype = torch.bfloat16
1758 | face_width = 512
1759 | test_faces = torch.ones([6, 3, face_width, face_width], dtype=dtype).cuda()
1760 | equi_img = c2e(
1761 | test_faces,
1762 | face_width * 2,
1763 | face_width * 4,
1764 | mode="bilinear",
1765 | cube_format="stack",
1766 | )
1767 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4])
1768 | self.assertEqual(equi_img.dtype, dtype)
1769 | self.assertTrue(equi_img.is_cuda)
1770 |
1771 | def test_e2c_stack_bfloat16_cuda(self) -> None:
1772 | if not torch.cuda.is_available():
1773 | raise unittest.SkipTest("Skipping CUDA test because CUDA is not available.")
1774 | dtype = torch.bfloat16
1775 | face_width = 512
1776 | test_equi = torch.ones([3, face_width * 2, face_width * 4], dtype=dtype).cuda()
1777 | cubic_img = e2c(
1778 | test_equi, face_w=face_width, mode="bilinear", cube_format="stack"
1779 | )
1780 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr]
1781 | self.assertEqual(cubic_img.dtype, dtype) # type: ignore[union-attr]
1782 | self.assertTrue(cubic_img.is_cuda) # type: ignore[union-attr]
1783 |
1784 | def test_e2p_bfloat16_cuda(self) -> None:
1785 | if not torch.cuda.is_available():
1786 | raise unittest.SkipTest("Skipping CUDA test because CUDA is not available.")
1787 | # Create a simple test equirectangular image
1788 | dtype = torch.bfloat16
1789 | h, w = 64, 128
1790 | channels = 3
1791 | e_img = torch.zeros((channels, h, w), dtype=dtype).cuda()
1792 |
1793 | fov_deg = 90.0
1794 | h_deg = 0.0
1795 | v_deg = 0.0
1796 | out_hw = (32, 32)
1797 |
1798 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw)
1799 |
1800 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]])
1801 | self.assertEqual(result.dtype, dtype)
1802 | self.assertTrue(result.is_cuda)
1803 |
1804 | def test_e2e_bfloat16_cuda(self) -> None:
1805 | if not torch.cuda.is_available():
1806 | raise unittest.SkipTest("Skipping CUDA test because CUDA is not available.")
1807 | dtype = torch.bfloat16
1808 | channels = 4
1809 | face_width = 512
1810 | test_equi = torch.ones(
1811 | [channels, face_width * 2, face_width * 4], dtype=dtype
1812 | ).cuda()
1813 | equi_img = e2e(
1814 | test_equi,
1815 | h_deg=45,
1816 | v_deg=45,
1817 | roll=25,
1818 | mode="bilinear",
1819 | )
1820 | self.assertEqual(list(equi_img.shape), list(test_equi.shape))
1821 | self.assertEqual(equi_img.dtype, dtype)
1822 | self.assertTrue(equi_img.is_cuda)
1823 |
1824 | def test_pad_180_to_360_channels_first(self) -> None:
1825 | e_img = torch.ones(3, 4, 4) # [C, H, W] = [3, 4, 4]
1826 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True)
1827 | self.assertEqual(padded_img.shape, (3, 4, 8))
1828 | self.assertTrue(torch.all(padded_img[:, :, 0] == 0.0))
1829 | self.assertTrue(torch.all(padded_img[:, :, -1] == 0.0))
1830 | self.assertTrue(torch.all(padded_img[:, :, 2:6] == 1.0))
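# Why columns 2:6: pad_180_to_360 doubles the width (4 -> 8), and per the
# assertions here the 180-degree input ends up centered, leaving two
# fill_value columns on each side and the original content in columns 2
# through 5. The channels-last, batch, dtype and JIT variants below check
# the same layout.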
1831 |
1832 | def test_pad_180_to_360_channels_last(self) -> None:
1833 | e_img = torch.ones(4, 4, 3) # [H, W, C] = [4, 4, 3]
1834 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=False)
1835 | self.assertEqual(padded_img.shape, (4, 8, 3))
1836 | self.assertTrue(torch.all(padded_img[:, 0, :] == 0.0))
1837 | self.assertTrue(torch.all(padded_img[:, -1, :] == 0.0))
1838 | self.assertTrue(torch.all(padded_img[:, 2:6, :] == 1.0))
1839 |
1840 | def test_pad_180_to_360_batch(self) -> None:
1841 | e_img = torch.ones(2, 3, 4, 4) # [B, C, H, W] = [2, 3, 4, 4]
1842 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True)
1843 | self.assertEqual(padded_img.shape, (2, 3, 4, 8))
1844 | self.assertTrue(torch.all(padded_img[:, :, :, 0] == 0.0))
1845 | self.assertTrue(torch.all(padded_img[:, :, :, -1] == 0.0))
1846 | self.assertTrue(torch.all(padded_img[:, :, :, 2:6] == 1.0))
1847 |
1848 | def test_pad_180_to_360_gpu(self) -> None:
1849 | if torch.cuda.is_available():
1850 | e_img = torch.ones(3, 4, 4, device="cuda") # [C, H, W] = [3, 4, 4]
1851 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True)
1852 | self.assertTrue(padded_img.is_cuda)
1853 | self.assertEqual(padded_img.shape, (3, 4, 8))
1854 | self.assertTrue(torch.all(padded_img[:, :, 0] == 0.0))
1855 | self.assertTrue(torch.all(padded_img[:, :, -1] == 0.0))
1856 | self.assertTrue(torch.all(padded_img[:, :, 2:6] == 1.0))
1857 |
1858 | def test_pad_180_to_360_float16(self) -> None:
1859 | e_img = torch.ones(3, 4, 4, dtype=torch.float16)
1860 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True)
1861 | self.assertEqual(padded_img.shape, (3, 4, 8))
1862 | self.assertEqual(padded_img.dtype, torch.float16)
1863 | self.assertTrue(torch.all(padded_img[:, :, 0] == 0.0))
1864 | self.assertTrue(torch.all(padded_img[:, :, -1] == 0.0))
1865 | self.assertTrue(torch.all(padded_img[:, :, 2:6] == 1.0))
1866 |
1867 | def test_pad_180_to_360_float64(self) -> None:
1868 | e_img = torch.ones(3, 4, 4, dtype=torch.float64)
1869 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True)
1870 | self.assertEqual(padded_img.shape, (3, 4, 8))
1871 | self.assertEqual(padded_img.dtype, torch.float64)
1872 | self.assertTrue(torch.all(padded_img[:, :, 0] == 0.0))
1873 | self.assertTrue(torch.all(padded_img[:, :, -1] == 0.0))
1874 | self.assertTrue(torch.all(padded_img[:, :, 2:6] == 1.0))
1875 |
1876 | def test_pad_180_to_360_bfloat16(self) -> None:
1877 | e_img = torch.ones(3, 4, 4, dtype=torch.bfloat16)
1878 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True)
1879 | self.assertEqual(padded_img.shape, (3, 4, 8))
1880 | self.assertEqual(padded_img.dtype, torch.bfloat16)
1881 | self.assertTrue(torch.all(padded_img[:, :, 0] == 0.0))
1882 | self.assertTrue(torch.all(padded_img[:, :, -1] == 0.0))
1883 | self.assertTrue(torch.all(padded_img[:, :, 2:6] == 1.0))
1884 |
1885 | def test_pad_180_to_360_gradient(self) -> None:
1886 | e_img = torch.ones(3, 4, 4, dtype=torch.float32, requires_grad=True)
1887 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True)
1888 | self.assertTrue(padded_img.requires_grad)
1889 |
1890 | def test_pad_180_to_360_jit(self) -> None:
1891 | e_img = torch.ones(3, 4, 4) # [C, H, W] = [3, 4, 4]
1892 | pad_180_to_360_jit = torch.jit.script(pad_180_to_360)
1893 | padded_img = pad_180_to_360_jit(e_img, fill_value=0.0, channels_first=True)
1894 | self.assertEqual(padded_img.shape, (3, 4, 8))
1895 | self.assertTrue(torch.all(padded_img[:, :, 0] == 0.0))
1896 | self.assertTrue(torch.all(padded_img[:, :, -1] == 0.0))
1897 | self.assertTrue(torch.all(padded_img[:, :, 2:6] == 1.0))
1898 |
1899 | def _prepare_test_data_cube_padding(
1900 | self,
1901 | height: int = 4,
1902 | width: int = 4,
1903 | channels: int = 3,
1904 | dtype: torch.dtype = torch.float32,
1905 | device: str = "cpu",
1906 | requires_grad: bool = False,
1907 | ) -> torch.Tensor:
1908 | """Helper to create test cube faces with specific configuration"""
1909 | # Create test faces where each face has a unique value
1910 | cube_faces = torch.zeros(
1911 | (6, height, width, channels), dtype=dtype, device=device
1912 | )
1913 |
1914 | # Set each face to a unique value for easy identification
1915 | for face_idx in range(6):
1916 | cube_faces[face_idx, :, :, :] = face_idx / 5.0 # Normalize to [0, 1]
1917 |
1918 | if requires_grad:
1919 | cube_faces.requires_grad_(True)
1920 |
1921 | return cube_faces
1922 |
1923 | def test_pad_cube_faces_shape(self) -> None:
1924 | """Test that the output shape is correct"""
1925 | # Use equal height and width to avoid dimension mismatch when flipping
1926 | height, width, channels = 4, 4, 3
1927 | cube_faces = self._prepare_test_data_cube_padding(height, width, channels)
1928 |
1929 | padded_faces = pad_cube_faces(cube_faces)
1930 |
1931 | self.assertEqual(padded_faces.shape, (6, height + 2, width + 2, channels))
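# i.e. a one-pixel ring is added on every side: height and width each grow
# by 2, while the face and channel dimensions are unchanged (4x4 faces
# become 6x6, which the hard-coded shapes in the tests below rely on).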
1932 |
1933 | def test_pad_cube_faces_jit(self) -> None:
1934 | """Test that the function works with torch.jit.script"""
1935 | cube_faces = self._prepare_test_data_cube_padding(4, 4, 3)
1936 |
1937 | # Script the function
1938 | pad_cube_faces_jit = torch.jit.script(pad_cube_faces)
1939 |
1940 | # Run the scripted function
1941 | padded_faces = pad_cube_faces_jit(cube_faces)
1942 |
1943 | # Verify output shape
1944 | self.assertEqual(padded_faces.shape, (6, 6, 6, 3))
1945 |
1946 | # Compare with non-scripted function
1947 | expected_output = pad_cube_faces(cube_faces)
1948 | self.assertTrue(torch.allclose(padded_faces, expected_output))
1949 |
1950 | def test_pad_cube_faces_gradient(self) -> None:
1951 | """Test that gradients flow through the function"""
1952 | cube_faces = self._prepare_test_data_cube_padding(requires_grad=True)
1953 |
1954 | padded_faces = pad_cube_faces(cube_faces)
1955 |
1956 | # Check that requires_grad is preserved
1957 | self.assertTrue(padded_faces.requires_grad)
1958 |
1959 | # Test gradient flow by computing a loss and backpropagating
1960 | loss = padded_faces.sum()
1961 | loss.backward()
1962 |
1963 | # Verify that the gradient exists and is not zero
1964 | self.assertIsNotNone(cube_faces.grad)
1965 | self.assertFalse(
1966 | torch.allclose(cube_faces.grad, torch.zeros_like(cube_faces.grad)) # type: ignore[arg-type]
1967 | )
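# The all-zero case cannot occur here: pad_cube_faces copies every input
# pixel into the padded interior (and edge pixels additionally into the
# neighbors' borders), so d(loss)/d(input) is at least 1 everywhere for
# loss = padded_faces.sum().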
1968 |
1969 | def test_pad_cube_faces_cuda(self) -> None:
1970 | """Test that the function works on CUDA device if available"""
1971 | if not torch.cuda.is_available():
1972 | self.skipTest("CUDA not available")
1973 |
1974 | cube_faces = self._prepare_test_data_cube_padding(device="cuda")
1975 |
1976 | padded_faces = pad_cube_faces(cube_faces)
1977 |
1978 | # Check that the output is on the correct device
1979 | self.assertTrue(padded_faces.is_cuda)
1980 |
1981 | # Verify shape
1982 | self.assertEqual(padded_faces.shape, (6, 6, 6, 3))
1983 |
1984 | def test_pad_cube_faces_float16(self) -> None:
1985 | """Test that the function works with float16 precision"""
1986 | cube_faces = self._prepare_test_data_cube_padding(dtype=torch.float16)
1987 |
1988 | padded_faces = pad_cube_faces(cube_faces)
1989 |
1990 | # Check that dtype is preserved
1991 | self.assertEqual(padded_faces.dtype, torch.float16)
1992 |
1993 | def test_pad_cube_faces_float32(self) -> None:
1994 | """Test that the function works with float32 precision"""
1995 | cube_faces = self._prepare_test_data_cube_padding(dtype=torch.float32)
1996 |
1997 | padded_faces = pad_cube_faces(cube_faces)
1998 |
1999 | # Check that dtype is preserved
2000 | self.assertEqual(padded_faces.dtype, torch.float32)
2001 |
2002 | def test_pad_cube_faces_float64(self) -> None:
2003 | """Test that the function works with float64 precision"""
2004 | cube_faces = self._prepare_test_data_cube_padding(dtype=torch.float64)
2005 |
2006 | padded_faces = pad_cube_faces(cube_faces)
2007 |
2008 | # Check that dtype is preserved
2009 | self.assertEqual(padded_faces.dtype, torch.float64)
2010 |
2011 | def test_pad_cube_faces_bfloat16(self) -> None:
2012 | """Test that the function works with bfloat16 precision"""
2013 | # Skip if bfloat16 is not supported
2014 | if not hasattr(torch, "bfloat16"):
2015 | self.skipTest("bfloat16 not supported in this PyTorch version")
2016 |
2017 | cube_faces = self._prepare_test_data_cube_padding(dtype=torch.bfloat16)
2018 |
2019 | padded_faces = pad_cube_faces(cube_faces)
2020 |
2021 | # Check that dtype is preserved
2022 | self.assertEqual(padded_faces.dtype, torch.bfloat16)
2023 |
2024 | def test_pad_cube_faces_exact_values(self) -> None:
2025 | """Test the exact values of padded faces for a small input"""
2026 | # Define small cube faces with easily identifiable values
2027 | cube_faces = torch.zeros((6, 2, 2, 1))
2028 |
2029 | # FRONT=0, RIGHT=1, BACK=2, LEFT=3, UP=4, DOWN=5
2030 | for face_idx in range(6):
2031 | # Set each face to its index value for easier verification
2032 | cube_faces[face_idx, :, :, :] = face_idx
2033 |
2034 | padded_faces = pad_cube_faces(cube_faces)
2035 |
2036 | # Verify central values (should be unchanged)
2037 | for face_idx in range(6):
2038 | self.assertTrue(
2039 | torch.all(padded_faces[face_idx, 1:-1, 1:-1, :] == face_idx)
2040 | )
2041 |
2042 | # Verify specific border values based on the padding logic
2043 | # These assertions check that the padding values come from the correct neighboring faces
2044 |
2045 | # Check FRONT face padding
2046 | self.assertTrue(torch.all(padded_faces[0, 0, 1:-1, :] == 4)) # Top from UP
2047 | self.assertTrue(
2048 | torch.all(padded_faces[0, -1, 1:-1, :] == 5)
2049 | ) # Bottom from DOWN
2050 | self.assertTrue(torch.all(padded_faces[0, 1:-1, 0, :] == 3)) # Left from LEFT
2051 | self.assertTrue(
2052 | torch.all(padded_faces[0, 1:-1, -1, :] == 1)
2053 | ) # Right from RIGHT
2054 |
2055 | # Check additional values for other faces to verify correct padding
2056 | # RIGHT face
2057 | self.assertTrue(torch.all(padded_faces[1, 1:-1, 0, :] == 0)) # Left from FRONT
2058 |
2059 | # BACK face
2060 | self.assertTrue(torch.all(padded_faces[2, 1:-1, -1, :] == 3)) # Right from LEFT
2061 |
2062 | # UP face
2063 | self.assertTrue(
2064 | torch.all(padded_faces[4, -1, 1:-1, :] == 0)
2065 | ) # Bottom from FRONT
2066 |
--------------------------------------------------------------------------------