├── .gitignore ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── examples ├── basic_dice_cubemap.png ├── basic_equirectangular.png ├── example_world_map_dice_cubemap.png ├── example_world_map_equirectangular.png ├── example_world_map_equirectangular_rotated.png ├── example_world_map_horizon_cubemap.png └── example_world_map_perspective.png ├── pytorch360convert ├── __init__.py ├── pytorch360convert.py └── version.py ├── setup.py └── tests ├── __init__.py └── test_module.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | abstract: "Utilities for converting between different cubemap, equirectangular, and perspective image formats." 2 | authors: 3 | - family-names: Egan 4 | given-names: Ben 5 | cff-version: 1.2.0 6 | date-released: "2024-12-15" 7 | keywords: 8 | - equirectangular 9 | - panorama 10 | - "360 degrees" 11 | - "360 degree images" 12 | - cubemap 13 | - research 14 | license: MIT 15 | message: "If you use this software, please cite it using these metadata." 16 | repository-code: "https://github.com/ProGamerGov/pytorch360convert" 17 | title: "pytorch360convert" -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to pytorch360convert 2 | 3 | This project uses simple linting and testing guidelines, as you'll see below. 4 | 5 | ## Linting 6 | 7 | 8 | Linting is simple to perform. 9 | 10 | ``` 11 | pip install black flake8 mypy ufmt pytest-cov 12 | 13 | ``` 14 | 15 | Linting: 16 | 17 | ``` 18 | cd pytorch360convert 19 | black . 20 | ufmt format . 21 | cd .. 22 | ``` 23 | 24 | Checking: 25 | 26 | ``` 27 | cd pytorch360convert 28 | black --check --diff . 29 | flake8 . --ignore=E203,W503 --max-line-length=88 --exclude build,dist 30 | ufmt check . 31 | mypy . --ignore-missing-imports --allow-redefinition --explicit-package-bases 32 | cd .. 33 | ``` 34 | 35 | 36 | ## Testing 37 | 38 | Tests can be run like this: 39 | 40 | ``` 41 | pip install pytest pytest-cov 42 | ``` 43 | 44 | ``` 45 | cd pytorch360convert 46 | pytest -ra --cov=. --cov-report term-missing 47 | cd ..
48 | ``` 49 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 sunset 4 | 5 | Copyright (c) 2024 Ben Egan 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 📷 PyTorch 360° Image Conversion Toolkit 2 | 3 | [![PyPI - Version](https://img.shields.io/pypi/v/pytorch360convert)](https://pypi.org/project/pytorch360convert/) 4 | 5 | 6 | ## Overview 7 | 8 | This PyTorch-based library provides powerful and differentiable image transformation utilities for converting between different panoramic image formats: 9 | 10 | - **Equirectangular (360°) Images** 11 | - **Cubemap Representations** 12 | - **Perspective Projections** 13 | 14 | Built as an improved PyTorch implementation of the original [py360convert](https://github.com/sunset1995/py360convert) project, this library offers flexible, CPU & GPU-accelerated functions. 15 | 16 | 17 |
18 | 19 |
20 | 21 | * Equirectangular format 22 | 23 | 24 |
25 | 26 |
26 | 27 | * Cubemap 'dice' format 28 | 29 | 30 | 31 | ## 🔧 Requirements 32 | 33 | - Python 3.7+ 34 | - [PyTorch](https://pytorch.org/) 35 | 36 | 37 | ## 📦 Installation 38 | 39 | You can easily install the library using pip: 40 | 41 | ```bash 42 | pip install pytorch360convert 43 | ``` 44 | 45 | Or you can install it from source. First, install PyTorch: 46 | 47 | ```bash 48 | pip install torch 49 | ``` 50 | 51 | Then clone the repository and install it: 52 | 53 | ```bash 54 | git clone https://github.com/ProGamerGov/pytorch360convert.git 55 | cd pytorch360convert 56 | pip install . 57 | ``` 58 | 59 | 60 | ## 🚀 Key Features 61 | 62 | - Lossless conversion between image formats. 63 | - Supports different cubemap input formats (horizon, list, stack, dict, dice). 64 | - Configurable sampling modes (bilinear, nearest). 65 | - Supports different dtypes (float16, float32, float64, bfloat16). 66 | - CPU support. 67 | - GPU acceleration. 68 | - Differentiable transformations for deep learning pipelines (see the gradient sketch further below). 69 | - [TorchScript](https://pytorch.org/docs/stable/jit.html) (JIT) support (see the tracing sketch below). 70 | 71 | 72 | ## 💡 Usage Examples 73 | 74 | 75 | ### Helper Functions 76 | 77 | First, we'll set up some helper functions: 78 | 79 | ```bash 80 | pip install torchvision pillow 81 | ``` 82 | 83 | 84 | ```python 85 | import torch 86 | from torchvision.transforms import ToTensor, ToPILImage 87 | from PIL import Image 88 | 89 | def load_image_to_tensor(image_path: str) -> torch.Tensor: 90 | """Load an image as a PyTorch tensor.""" 91 | return ToTensor()(Image.open(image_path).convert('RGB')) 92 | 93 | def save_tensor_as_image(tensor: torch.Tensor, save_path: str) -> None: 94 | """Save a PyTorch tensor as an image.""" 95 | ToPILImage()(tensor).save(save_path) 96 | 97 | ``` 98 | 99 | ### Equirectangular to Cubemap Conversion 100 | 101 | Converting equirectangular images into cubemaps is easy. For simplicity, we'll use the 'dice' format, which places all six cube faces into a single image on a 3-row by 4-column grid. 102 | 103 | ```python 104 | from pytorch360convert import e2c 105 | 106 | # Load equirectangular image (3, 1376, 2752) 107 | equi_image = load_image_to_tensor("examples/example_world_map_equirectangular.png") 108 | face_w = equi_image.shape[2] // 4 # 2752 / 4 = 688 109 | 110 | # Convert to cubemap (dice format) 111 | cubemap = e2c( 112 | equi_image, # CHW format 113 | face_w=face_w, # Width of each cube face 114 | mode='bilinear', # Sampling interpolation 115 | cube_format='dice' # Output cubemap layout 116 | ) 117 | 118 | # Save the dice-layout cubemap 119 | save_tensor_as_image(cubemap, "dice_cubemap.jpg") 120 | ``` 121 | 122 | | Equirectangular Input | Cubemap 'Dice' Output | 123 | | :---: | :----: | 124 | | ![](examples/example_world_map_equirectangular.png) | ![](examples/example_world_map_dice_cubemap.png) | 125 | 126 | | Cubemap 'Horizon' Output | 127 | | :---: | 128 | | ![](examples/example_world_map_horizon_cubemap.png) | 129 |
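As noted in the Key Features list, the transforms can be used with TorchScript. Below is a minimal tracing sketch; the `project` wrapper, input sizes, and angles are arbitrary choices for this example, and `torch.jit.trace` bakes the example input's shape into the traced graph:

```python
import torch
from pytorch360convert import e2p

def project(x: torch.Tensor) -> torch.Tensor:
    # A fixed-parameter perspective projection, convenient for tracing.
    return e2p(x, fov_deg=90.0, h_deg=30.0, v_deg=0.0, out_hw=(256, 256))

# Trace with an example CHW input, then reuse with same-shaped inputs.
traced = torch.jit.trace(project, torch.rand(3, 512, 1024))
print(traced(torch.rand(3, 512, 1024)).shape)  # torch.Size([3, 256, 256])
```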
130 | ### Cubemap to Equirectangular Conversion 131 | 132 | We can also convert cubemaps into equirectangular images, like so. 133 | 134 | ```python 135 | from pytorch360convert import c2e 136 | 137 | # Load cubemap in 'dice' format 138 | cubemap = load_image_to_tensor("dice_cubemap.jpg") 139 | 140 | # Convert cubemap back to equirectangular 141 | equirectangular = c2e( 142 | cubemap, # Cubemap tensor(s) 143 | mode='bilinear', # Sampling interpolation 144 | cube_format='dice' # Input cubemap layout 145 | ) 146 | 147 | save_tensor_as_image(equirectangular, "equirectangular.jpg") 148 | ``` 149 | 150 | ### Equirectangular to Perspective Projection 151 | 152 | ```python 153 | from pytorch360convert import e2p 154 | 155 | # Load equirectangular input 156 | equi_image = load_image_to_tensor("examples/example_world_map_equirectangular.png") 157 | 158 | # Extract perspective view from equirectangular image 159 | perspective_view = e2p( 160 | equi_image, # Equirectangular image 161 | fov_deg=(70, 60), # Horizontal and vertical FOV 162 | h_deg=260, # Horizontal rotation 163 | v_deg=50, # Vertical rotation 164 | out_hw=(512, 768), # Output image dimensions 165 | mode='bilinear' # Sampling interpolation 166 | ) 167 | 168 | save_tensor_as_image(perspective_view, "perspective.jpg") 169 | ``` 170 | 171 | | Equirectangular Input | Perspective Output | 172 | | :---: | :----: | 173 | | ![](examples/example_world_map_equirectangular.png) | ![](examples/example_world_map_perspective.png) | 174 | 175 | 176 | 177 | ### Equirectangular to Equirectangular 178 | 179 | ```python 180 | from pytorch360convert import e2e 181 | 182 | # Load equirectangular input 183 | equi_image = load_image_to_tensor("examples/example_world_map_equirectangular.png") 184 | 185 | # Rotate an equirectangular image around one or more axes 186 | rotated_equi = e2e( 187 | equi_image, # Equirectangular image 188 | h_deg=90.0, # Horizontal rotation/shift 189 | v_deg=200.0, # Vertical rotation/shift 190 | roll=45.0, # Clockwise/counter-clockwise rotation 191 | mode='bilinear' # Sampling interpolation 192 | ) 193 | 194 | save_tensor_as_image(rotated_equi, "rotated.jpg") 195 | ``` 196 | 197 | | Equirectangular Input | Rotated Output | 198 | | :---: | :----: | 199 | | ![](examples/example_world_map_equirectangular.png) | ![](examples/example_world_map_equirectangular_rotated.png) |
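All of the transforms above are built from differentiable PyTorch operations, so gradients flow from the output back to the input image. A minimal sketch (the tensor sizes and the mean "loss" are arbitrary choices for the example):

```python
import torch
from pytorch360convert import e2p

# A dummy CHW equirectangular image that requires gradients.
equi = torch.rand(3, 256, 512, requires_grad=True)

# Any scalar loss on the projected view backpropagates to the input.
view = e2p(equi, fov_deg=90.0, h_deg=45.0, v_deg=0.0, out_hw=(128, 128))
view.mean().backward()
print(equi.grad.shape)  # torch.Size([3, 256, 512])
```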
200 | 201 | 202 | ## 📚 Basic Functions 203 | 204 | ### `e2c(e_img, face_w=None, mode='bilinear', cube_format='dice')` 205 | Converts an equirectangular image to a cubemap projection. 206 | 207 | - **Parameters**: 208 | - `e_img` (torch.Tensor): Equirectangular CHW image tensor. 209 | - `face_w` (int, optional): Cube face width. If set to None, then face_w will be calculated as `H // 2`. Default: `None`. 210 | - `mode` (str, optional): Sampling interpolation mode. Options are `bilinear`, `bicubic`, and `nearest`. Default: `bilinear` 211 | - `cube_format` (str, optional): The desired output cubemap format. Options are `dict`, `list`, `horizon`, `stack`, and `dice`. Default: `dice` 212 | - `stack` (torch.Tensor): Stack of 6 faces, in the order of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 213 | - `list` (list of torch.Tensor): List of 6 faces, in the order of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 214 | - `dict` (dict of torch.Tensor): Dictionary with keys pointing to face tensors. Keys are: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 215 | - `dice` (torch.Tensor): A cubemap in a 'dice' layout. 216 | - `horizon` (torch.Tensor): A cubemap in a 'horizon' layout, a 1x6 grid in the order: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 217 | - `channels_first` (bool, optional): Input image channel format (CHW or HWC). Defaults to the PyTorch CHW standard of `True`. 218 | 219 | - **Returns**: Cubemap representation of the input image as a tensor, list of tensors, or dict of tensors. 220 | 221 | ### `c2e(cubemap, h=None, w=None, mode='bilinear', cube_format='dice')` 222 | Converts a cubemap projection to an equirectangular image. 223 | 224 | - **Parameters**: 225 | - `cubemap` (torch.Tensor, list of torch.Tensor, or dict of torch.Tensor): Cubemap image tensor, list of tensors, or dict of tensors. Note that tensors should be in the shape of: `CHW`, except for when `cube_format = 'stack'`, in which case a batch dimension is present. Inputs should match the corresponding `cube_format`. 226 | - `h` (int, optional): Output image height. If set to None, `face_w * 2` will be used. Default: `None`. 227 | - `w` (int, optional): Output image width. If set to None, `face_w * 4` will be used. Default: `None`. 228 | - `mode` (str, optional): Sampling interpolation mode. Options are `bilinear`, `bicubic`, and `nearest`. Default: `bilinear` 229 | - `cube_format` (str, optional): Input cubemap format. Options are `dict`, `list`, `horizon`, `stack`, and `dice`. Default: `dice` 230 | - `stack` (torch.Tensor): Stack of 6 faces, in the order of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 231 | - `list` (list of torch.Tensor): List of 6 faces, in the order of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 232 | - `dict` (dict of torch.Tensor): Dictionary with keys pointing to face tensors. Keys are expected to be: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 233 | - `dice` (torch.Tensor): A cubemap in a 'dice' layout. 234 | - `horizon` (torch.Tensor): A cubemap in a 'horizon' layout, a 1x6 grid in the order of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 235 | - `channels_first` (bool, optional): Input cubemap channel format (CHW or HWC). Defaults to the PyTorch CHW standard of `True`. 236 | 237 | - **Returns**: Equirectangular projection of the input cubemap as a tensor (see the `stack` round-trip sketch after `e2p` below). 238 | 239 | ### `e2p(e_img, fov_deg, h_deg=0.0, v_deg=0.0, out_hw=(512, 512), in_rot_deg=0.0, mode='bilinear')` 240 | Extracts a perspective view from an equirectangular image. 241 | 242 | - **Parameters**: 243 | - `e_img` (torch.Tensor): Equirectangular CHW or NCHW image tensor. 244 | - `fov_deg` (float or tuple of float): Field of view in degrees. If a single value is provided, it will be used for both horizontal and vertical degrees. If using a tuple, values are expected to be in the following format: (h_fov_deg, v_fov_deg). 245 | - `h_deg` (float, optional): Horizontal viewing angle in degrees (-Left/+Right). Default: `0.0` 246 | - `v_deg` (float, optional): Vertical viewing angle in degrees (-Down/+Up). Default: `0.0` 247 | - `out_hw` (int or tuple of int, optional): Output image dimensions in the shape of `(height, width)`. Default: `(512, 512)` 248 | - `in_rot_deg` (float, optional): In-plane rotation angle in degrees. Default: `0` 249 | - `mode` (str, optional): Sampling interpolation mode. Options are `bilinear`, `bicubic`, and `nearest`. Default: `bilinear` 250 | - `channels_first` (bool, optional): Input image channel format (CHW or HWC). Defaults to the PyTorch CHW standard of `True`. 251 | 252 | - **Returns**: Perspective view of the equirectangular image as a tensor.
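To make the `stack` format concrete, here is a minimal round-trip sketch using `e2c` and `c2e` as documented above (the tensor sizes are arbitrary choices for the example):

```python
import torch
from pytorch360convert import c2e, e2c

equi = torch.rand(3, 512, 1024)  # CHW equirectangular image

# 'stack' returns a single [6, C, face_w, face_w] tensor of faces,
# in the order: Front, Right, Back, Left, Up, Down.
faces = e2c(equi, face_w=256, cube_format='stack')
print(faces.shape)  # torch.Size([6, 3, 256, 256])

# With h and w left as None, the output is face_w * 2 by face_w * 4.
equi_again = c2e(faces, mode='bilinear', cube_format='stack')
print(equi_again.shape)  # torch.Size([3, 512, 1024])
```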
253 | 254 | ### `e2e(e_img, roll=0.0, h_deg=0.0, v_deg=0.0, mode='bilinear')` 255 | 256 | Rotate an equirectangular image along one or more axes (roll, pitch, and yaw) to produce a horizontal shift, vertical shift, or to roll the image. 257 | 258 | - **Parameters**: 259 | - `e_img` (torch.Tensor): Equirectangular CHW or NCHW image tensor. 260 | - `roll` (float, optional): Roll angle in degrees (-Counter-clockwise/+Clockwise). Rotates the image along the x-axis. Default: `0.0` 261 | - `h_deg` (float, optional): Yaw angle in degrees (-Left/+Right). Rotates the image along the z-axis to produce a horizontal shift. Default: `0.0` 262 | - `v_deg` (float, optional): Pitch angle in degrees (-Down/+Up). Rotates the image along the y-axis to produce a vertical shift. Default: `0.0` 263 | - `mode` (str, optional): Sampling interpolation mode. Options are `bilinear`, `bicubic`, and `nearest`. Default: `bilinear` 264 | - `channels_first` (bool, optional): Input image channel format (CHW or HWC). Defaults to the PyTorch CHW standard of `True`. 265 | 266 | - **Returns**: A modified equirectangular image tensor. 267 | 268 | 269 | ## 🤝 Contributing 270 | 271 | Contributions are welcome! Please feel free to submit a Pull Request. 272 | 273 | 274 | ## 🔬 Citation 275 | 276 | If you use this library in your research or project, please refer to the included [CITATION.cff](CITATION.cff) file or cite it as follows: 277 | 278 | ### BibTeX 279 | ```bibtex 280 | @misc{egan2024pytorch360convert, 281 | title={PyTorch 360° Image Conversion Toolkit}, 282 | author={Egan, Ben}, 283 | year={2024}, 284 | publisher={GitHub}, 285 | howpublished={\url{https://github.com/ProGamerGov/pytorch360convert}} 286 | } 287 | ``` 288 | 289 | ### APA Style 290 | ``` 291 | Egan, B. (2024). PyTorch 360° Image Conversion Toolkit [Computer software]. GitHub.
https://github.com/ProGamerGov/pytorch360convert 292 | ``` 293 | -------------------------------------------------------------------------------- /examples/basic_dice_cubemap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/basic_dice_cubemap.png -------------------------------------------------------------------------------- /examples/basic_equirectangular.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/basic_equirectangular.png -------------------------------------------------------------------------------- /examples/example_world_map_dice_cubemap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/example_world_map_dice_cubemap.png -------------------------------------------------------------------------------- /examples/example_world_map_equirectangular.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/example_world_map_equirectangular.png -------------------------------------------------------------------------------- /examples/example_world_map_equirectangular_rotated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/example_world_map_equirectangular_rotated.png -------------------------------------------------------------------------------- /examples/example_world_map_horizon_cubemap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/example_world_map_horizon_cubemap.png -------------------------------------------------------------------------------- /examples/example_world_map_perspective.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProGamerGov/pytorch360convert/f13a92c539164274f6387f7cdf68fecb3e62ac59/examples/example_world_map_perspective.png -------------------------------------------------------------------------------- /pytorch360convert/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch360convert.version import __version__ # noqa 2 | from .pytorch360convert import ( 3 | c2e, 4 | cube_dice2h, 5 | cube_dict2h, 6 | cube_h2dice, 7 | cube_h2dict, 8 | cube_h2list, 9 | cube_list2h, 10 | e2c, 11 | e2e, 12 | e2p, 13 | pad_180_to_360, 14 | ) 15 | -------------------------------------------------------------------------------- /pytorch360convert/pytorch360convert.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def rotation_matrix(rad: torch.Tensor, ax: torch.Tensor) -> torch.Tensor: 8 | """ 9 | Create a rotation matrix for a given angle and axis. 
10 | 11 | Args: 12 | rad (torch.Tensor): Rotation angle in radians. 13 | ax (torch.Tensor): Rotation axis vector. 14 | 15 | Returns: 16 | torch.Tensor: 3x3 rotation matrix. 17 | """ 18 | ax = ax / torch.sqrt((ax**2).sum()) 19 | c = torch.cos(rad) 20 | s = torch.sin(rad) 21 | R = torch.diag(torch.cat([c, c, c])) 22 | R = R + (1.0 - c) * torch.ger(ax, ax) 23 | K = torch.stack( 24 | [ 25 | torch.stack( 26 | [torch.tensor(0.0, device=ax.device, dtype=ax.dtype), -ax[2], ax[1]] 27 | ), 28 | torch.stack( 29 | [ax[2], torch.tensor(0.0, device=ax.device, dtype=ax.dtype), -ax[0]] 30 | ), 31 | torch.stack( 32 | [-ax[1], ax[0], torch.tensor(0.0, device=ax.device, dtype=ax.dtype)] 33 | ), 34 | ], 35 | dim=0, 36 | ) 37 | R = R + K * s 38 | return R 39 | 40 | 41 | def _nhwc2nchw(x: torch.Tensor) -> torch.Tensor: 42 | """ 43 | Convert NHWC to NCHW or HWC to CHW format. 44 | 45 | Args: 46 | x (torch.Tensor): Input tensor to be converted, either in NHWC or HWC format. 47 | 48 | Returns: 49 | torch.Tensor: The converted tensor in NCHW or CHW format. 50 | """ 51 | assert x.dim() == 3 or x.dim() == 4 52 | return x.permute(2, 0, 1) if x.dim() == 3 else x.permute(0, 3, 1, 2) 53 | 54 | 55 | def _nchw2nhwc(x: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Convert NCHW to NHWC or CHW to HWC format. 58 | 59 | Args: 60 | x (torch.Tensor): Input tensor to be converted, either in NCHW or CHW format. 61 | 62 | Returns: 63 | torch.Tensor: The converted tensor in NHWC or HWC format. 64 | """ 65 | assert x.dim() == 3 or x.dim() == 4 66 | return x.permute(1, 2, 0) if x.dim() == 3 else x.permute(0, 2, 3, 1) 67 | 68 | 69 | def _slice_chunk( 70 | index: int, width: int, offset: int = 0, device: torch.device = torch.device("cpu") 71 | ) -> torch.Tensor: 72 | """ 73 | Generate a tensor of indices for a chunk of values. 74 | 75 | Args: 76 | index (int): The starting index for the chunk. 77 | width (int): The number of indices in the chunk. 78 | offset (int, optional): An offset added to the starting index, default is 0. 79 | device (torch.device, optional): The device for the tensor. 80 | Default: torch.device('cpu') 81 | 82 | Returns: 83 | torch.Tensor: A tensor containing the indices for the chunk. 84 | """ 85 | start = index * width + offset 86 | # Create a tensor of indices instead of using slice 87 | return torch.arange(start, start + width, dtype=torch.long, device=device) 88 | 89 | 90 | def _face_slice( 91 | index: int, face_w: int, device: torch.device = torch.device("cpu") 92 | ) -> torch.Tensor: 93 | """ 94 | Generate a slice of indices based on the face width. 95 | 96 | Args: 97 | index (int): The starting index. 98 | face_w (int): The width of the face (number of indices). 99 | device (torch.device, optional): The device for the tensor. 100 | Default: torch.device('cpu') 101 | 102 | Returns: 103 | torch.Tensor: A tensor containing the slice of indices. 104 | """ 105 | return _slice_chunk(index, face_w, device=device) 106 | 107 | 108 | def xyzcube( 109 | face_w: int, 110 | device: torch.device = torch.device("cpu"), 111 | dtype: torch.dtype = torch.float32, 112 | ) -> torch.Tensor: 113 | """ 114 | Generate cube coordinates for equirectangular projection. 115 | 116 | Args: 117 | face_w (int): Width of each cube face. 118 | device (torch.device, optional): Device to create tensor on. 119 | Default: torch.device('cpu') 120 | dtype (torch.dtype, optional): Data type of the tensor. 121 | Default: torch.float32 122 | 123 | Returns: 124 | torch.Tensor: Cube coordinates tensor of shape (face_w, face_w * 6, 3).
125 | """ 126 | out = torch.empty((face_w, face_w * 6, 3), dtype=dtype, device=device) 127 | rng = torch.linspace(-0.5, 0.5, steps=face_w, dtype=dtype, device=device) 128 | x, y = torch.meshgrid(rng, -rng, indexing="xy") # shape (face_w, face_w) 129 | 130 | # Pre-compute flips 131 | x_flip = torch.flip(x, [1]) 132 | y_flip = torch.flip(y, [0]) 133 | 134 | # Front face (z = 0.5) 135 | front_indices = _face_slice(0, face_w, device) 136 | out[:, front_indices, 0] = x 137 | out[:, front_indices, 1] = y 138 | out[:, front_indices, 2] = 0.5 139 | 140 | # Right face (x = 0.5) 141 | right_indices = _face_slice(1, face_w, device) 142 | out[:, right_indices, 0] = 0.5 143 | out[:, right_indices, 1] = y 144 | out[:, right_indices, 2] = x_flip 145 | 146 | # Back face (z = -0.5) 147 | back_indices = _face_slice(2, face_w, device) 148 | out[:, back_indices, 0] = x_flip 149 | out[:, back_indices, 1] = y 150 | out[:, back_indices, 2] = -0.5 151 | 152 | # Left face (x = -0.5) 153 | left_indices = _face_slice(3, face_w, device) 154 | out[:, left_indices, 0] = -0.5 155 | out[:, left_indices, 1] = y 156 | out[:, left_indices, 2] = x 157 | 158 | # Up face (y = 0.5) 159 | up_indices = _face_slice(4, face_w, device) 160 | out[:, up_indices, 0] = x 161 | out[:, up_indices, 1] = 0.5 162 | out[:, up_indices, 2] = y_flip 163 | 164 | # Down face (y = -0.5) 165 | down_indices = _face_slice(5, face_w, device) 166 | out[:, down_indices, 0] = x 167 | out[:, down_indices, 1] = -0.5 168 | out[:, down_indices, 2] = y 169 | 170 | return out 171 | 172 | 173 | def equirect_uvgrid( 174 | h: int, 175 | w: int, 176 | device: torch.device = torch.device("cpu"), 177 | dtype: torch.dtype = torch.float32, 178 | ) -> torch.Tensor: 179 | """ 180 | Generate UV grid for equirectangular projection. 181 | 182 | Args: 183 | h (int): Height of the grid. 184 | w (int): Width of the grid. 185 | device (torch.device, optional): Device to create tensor on. 186 | Default: torch.device('cpu') 187 | dtype (torch.dtype, optional): Data type of the tensor. 188 | Default: torch.float32 189 | 190 | Returns: 191 | torch.Tensor: UV grid of shape (h, w, 2). 192 | """ 193 | u = torch.linspace(-torch.pi, torch.pi, steps=w, dtype=dtype, device=device) 194 | v = torch.linspace(torch.pi, -torch.pi, steps=h, dtype=dtype, device=device) / 2 195 | grid_v, grid_u = torch.meshgrid(v, u, indexing="ij") 196 | uv = torch.stack([grid_u, grid_v], dim=-1) 197 | return uv 198 | 199 | 200 | def equirect_facetype( 201 | h: int, 202 | w: int, 203 | device: torch.device = torch.device("cpu"), 204 | dtype: torch.dtype = torch.float32, 205 | ) -> torch.Tensor: 206 | """ 207 | Determine face types for equirectangular projection. 208 | 209 | Args: 210 | h (int): Height of the grid. 211 | w (int): Width of the grid. 212 | device (torch.device, optional): Device to create tensor on. 213 | Default: torch.device('cpu') 214 | dtype (torch.dtype, optional): Data type of the tensor. 215 | Default: torch.float32 216 | 217 | Returns: 218 | torch.Tensor: Face type tensor of shape (h, w) with integer face 219 | indices. 
220 | """ 221 | tp = ( 222 | torch.arange(4, device=device) 223 | .repeat_interleave(w // 4) 224 | .unsqueeze(0) 225 | .repeat(h, 1) 226 | ) 227 | tp = torch.roll(tp, shifts=3 * (w // 8), dims=1) 228 | 229 | # Prepare ceil mask 230 | mask = torch.zeros((h, w // 4), dtype=torch.bool, device=device) 231 | idx = torch.linspace(-torch.pi, torch.pi, w // 4, device=device, dtype=dtype) / 4 232 | idx = torch.round(h / 2 - torch.atan(torch.cos(idx)) * h / torch.pi).to(torch.long) 233 | for i, j in enumerate(idx): 234 | mask[:j, i] = True 235 | mask = torch.roll(torch.cat([mask] * 4, dim=1), shifts=3 * (w // 8), dims=1) 236 | 237 | tp[mask] = 4 238 | tp[torch.flip(mask, [0])] = 5 239 | return tp 240 | 241 | 242 | def xyzpers( 243 | h_fov: float, 244 | v_fov: float, 245 | u: float, 246 | v: float, 247 | out_hw: Tuple[int, int], 248 | in_rot: float, 249 | device: torch.device = torch.device("cpu"), 250 | dtype: torch.dtype = torch.float32, 251 | ) -> torch.Tensor: 252 | """ 253 | Generate perspective projection coordinates. 254 | 255 | Args: 256 | h_fov (torch.Tensor): Horizontal field of view in radians. 257 | v_fov (torch.Tensor): Vertical field of view in radians. 258 | u (float): Horizontal rotation angle in radians. 259 | v (float): Vertical rotation angle in radians. 260 | out_hw (tuple of int): Output height and width. 261 | in_rot (torch.Tensor): Input rotation angle in radians. 262 | device (torch.device, optional): Device to create tensor on. 263 | Default: torch.device('cpu') 264 | dtype (torch.dtype, optional): Data type of the tensor. 265 | Default: torch.float32 266 | 267 | Returns: 268 | torch.Tensor: Perspective projection coordinates tensor. 269 | """ 270 | h_fov = torch.tensor([h_fov], dtype=dtype, device=device) 271 | v_fov = torch.tensor([v_fov], dtype=dtype, device=device) 272 | u = torch.tensor([u], dtype=dtype, device=device) 273 | v = torch.tensor([v], dtype=dtype, device=device) 274 | in_rot = torch.tensor([in_rot], dtype=dtype, device=device) 275 | 276 | out = torch.ones((*out_hw, 3), dtype=dtype, device=device) 277 | x_max = torch.tan(h_fov / 2) 278 | y_max = torch.tan(v_fov / 2) 279 | y_range = torch.linspace( 280 | -y_max.item(), y_max.item(), steps=out_hw[0], dtype=dtype, device=device 281 | ) 282 | x_range = torch.linspace( 283 | -x_max.item(), x_max.item(), steps=out_hw[1], dtype=dtype, device=device 284 | ) 285 | grid_y, grid_x = torch.meshgrid(-y_range, x_range, indexing="ij") 286 | out[..., 0] = grid_x 287 | out[..., 1] = grid_y 288 | 289 | Rx = rotation_matrix(v, torch.tensor([1.0, 0.0, 0.0], dtype=dtype, device=device)) 290 | Ry = rotation_matrix(u, torch.tensor([0.0, 1.0, 0.0], dtype=dtype, device=device)) 291 | Ri = rotation_matrix( 292 | in_rot, torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device) @ Rx @ Ry 293 | ) 294 | 295 | # Apply R = Rx*Ry*Ri to each vector 296 | # like this: out * Rx * Ry * Ri (assuming row vectors) 297 | out = out @ Rx @ Ry @ Ri 298 | return out 299 | 300 | 301 | def xyz2uv(xyz: torch.Tensor) -> torch.Tensor: 302 | """ 303 | Transform cartesian (x, y, z) to spherical(r, u, v), and 304 | only outputs (u, v). 305 | 306 | Args: 307 | xyz (torch.Tensor): Input 3D coordinates tensor. 308 | 309 | Returns: 310 | torch.Tensor: UV coordinates tensor. 
311 | """ 312 | x = xyz[..., 0] 313 | y = xyz[..., 1] 314 | z = xyz[..., 2] 315 | u = torch.atan2(x, z) 316 | c = torch.sqrt(x**2 + z**2) 317 | v = torch.atan2(y, c) 318 | return torch.stack([u, v], dim=-1) 319 | 320 | 321 | def uv2unitxyz(uv: torch.Tensor) -> torch.Tensor: 322 | """ 323 | Convert UV coordinates to unit 3D Cartesian coordinates. 324 | 325 | Args: 326 | uv (torch.Tensor): Input UV coordinates tensor. 327 | 328 | Returns: 329 | torch.Tensor: Unit 3D coordinates tensor. 330 | """ 331 | u = uv[..., 0] 332 | v = uv[..., 1] 333 | y = torch.sin(v) 334 | c = torch.cos(v) 335 | x = c * torch.sin(u) 336 | z = c * torch.cos(u) 337 | return torch.stack([x, y, z], dim=-1) 338 | 339 | 340 | def uv2coor(uv: torch.Tensor, h: int, w: int) -> torch.Tensor: 341 | """ 342 | Convert UV coordinates to image coordinates. 343 | 344 | Args: 345 | uv (torch.Tensor): Input UV coordinates tensor. 346 | h (int): Image height. 347 | w (int): Image width. 348 | 349 | Returns: 350 | torch.Tensor: Image coordinates tensor. 351 | """ 352 | u = uv[..., 0] 353 | v = uv[..., 1] 354 | coor_x = (u / (2 * torch.pi) + 0.5) * w - 0.5 355 | coor_y = (-v / torch.pi + 0.5) * h - 0.5 356 | return torch.stack([coor_x, coor_y], dim=-1) 357 | 358 | 359 | def coor2uv(coorxy: torch.Tensor, h: int, w: int) -> torch.Tensor: 360 | """ 361 | Convert image coordinates to UV coordinates. 362 | 363 | Args: 364 | coorxy (torch.Tensor): Input image coordinates tensor. 365 | h (int): Image height. 366 | w (int): Image width. 367 | 368 | Returns: 369 | torch.Tensor: UV coordinates tensor. 370 | """ 371 | coor_x = coorxy[..., 0] 372 | coor_y = coorxy[..., 1] 373 | u = ((coor_x + 0.5) / w - 0.5) * 2 * torch.pi 374 | v = -((coor_y + 0.5) / h - 0.5) * torch.pi 375 | return torch.stack([u, v], dim=-1) 376 | 377 | 378 | def pad_cube_faces(cube_faces: torch.Tensor) -> torch.Tensor: 379 | """ 380 | Adds 1 pixel of padding around each cube face, using pixels from the neighbouring 381 | faces, for each face. 382 | 383 | Args: 384 | cube_faces (torch.Tensor): Tensor of shape [6, H, W, C] representing the 6 385 | faces. 
Expected face order is: FRONT=0, RIGHT=1, BACK=2, LEFT=3, UP=4, 386 | DOWN=5 387 | 388 | Returns: 389 | torch.Tensor: Padded tensor of shape [6, H+2, W+2, C] 390 | """ 391 | # Define face indices as constants instead of enum 392 | FRONT, RIGHT, BACK, LEFT, UP, DOWN = 0, 1, 2, 3, 4, 5 393 | 394 | # Create padded tensor with zeros 395 | padded = torch.zeros( 396 | cube_faces.shape[0], 397 | cube_faces.shape[1] + 2, 398 | cube_faces.shape[2] + 2, 399 | cube_faces.shape[3], 400 | dtype=cube_faces.dtype, 401 | device=cube_faces.device, 402 | ) 403 | 404 | # Copy original data to center of padded tensor 405 | padded[:, 1:-1, 1:-1, :] = cube_faces 406 | 407 | # Pad above/below 408 | padded[FRONT, 0, 1:-1, :] = padded[UP, -2, 1:-1, :] 409 | padded[FRONT, -1, 1:-1, :] = padded[DOWN, 1, 1:-1, :] 410 | 411 | padded[RIGHT, 0, 1:-1, :] = padded[UP, 1:-1, -2, :].flip(0) 412 | padded[RIGHT, -1, 1:-1, :] = padded[DOWN, 1:-1, -2, :] 413 | 414 | padded[BACK, 0, 1:-1, :] = padded[UP, 1, 1:-1, :].flip(0) 415 | padded[BACK, -1, 1:-1, :] = padded[DOWN, -2, 1:-1, :].flip(0) 416 | 417 | padded[LEFT, 0, 1:-1, :] = padded[UP, 1:-1, 1, :] 418 | padded[LEFT, -1, 1:-1, :] = padded[DOWN, 1:-1, 1, :].flip(0) 419 | 420 | padded[UP, 0, 1:-1, :] = padded[BACK, 1, 1:-1, :].flip(0) 421 | padded[UP, -1, 1:-1, :] = padded[FRONT, 1, 1:-1, :] 422 | 423 | padded[DOWN, 0, 1:-1, :] = padded[FRONT, -2, 1:-1, :] 424 | padded[DOWN, -1, 1:-1, :] = padded[BACK, -2, 1:-1, :].flip(0) 425 | 426 | # Pad left/right 427 | padded[FRONT, 1:-1, 0, :] = padded[LEFT, 1:-1, -2, :] 428 | padded[FRONT, 1:-1, -1, :] = padded[RIGHT, 1:-1, 1, :] 429 | 430 | padded[RIGHT, 1:-1, 0, :] = padded[FRONT, 1:-1, -2, :] 431 | padded[RIGHT, 1:-1, -1, :] = padded[BACK, 1:-1, 1, :] 432 | 433 | padded[BACK, 1:-1, 0, :] = padded[RIGHT, 1:-1, -2, :] 434 | padded[BACK, 1:-1, -1, :] = padded[LEFT, 1:-1, 1, :] 435 | 436 | padded[LEFT, 1:-1, 0, :] = padded[BACK, 1:-1, -2, :] 437 | padded[LEFT, 1:-1, -1, :] = padded[FRONT, 1:-1, 1, :] 438 | 439 | padded[UP, 1:-1, 0, :] = padded[LEFT, 1, 1:-1, :] 440 | padded[UP, 1:-1, -1, :] = padded[RIGHT, 1, 1:-1, :].flip(0) 441 | 442 | padded[DOWN, 1:-1, 0, :] = padded[LEFT, -2, 1:-1, :].flip(0) 443 | padded[DOWN, 1:-1, -1, :] = padded[RIGHT, -2, 1:-1, :] 444 | return padded 445 | 446 | 447 | def grid_sample_wrap( 448 | image: torch.Tensor, 449 | coor_x: torch.Tensor, 450 | coor_y: torch.Tensor, 451 | mode: str = "bilinear", 452 | padding_mode: str = "border", 453 | ) -> torch.Tensor: 454 | """ 455 | Sample from an image with wrapped horizontal coordinates. 456 | 457 | Args: 458 | image (torch.Tensor): Input image tensor in the shape of [H, W, C] or 459 | [N, H, W, C]. 460 | coor_x (torch.Tensor): X coordinates for sampling. 461 | coor_y (torch.Tensor): Y coordinates for sampling. 462 | mode (str, optional): Sampling interpolation mode, 'nearest', 463 | 'bicubic', or 'bilinear'. Default: 'bilinear'. 464 | padding_mode (str, optional): Padding mode for out-of-bounds 465 | coordinates. Default: 'border' 466 | 467 | Returns: 468 | torch.Tensor: Sampled image tensor.
469 | """ 470 | 471 | assert image.dim() == 4 or image.dim() == 3 472 | # Permute image to NCHW 473 | if image.dim() == 3: 474 | has_batch = False 475 | H, W, _ = image.shape 476 | # [H,W,C] -> [1,C,H,W] 477 | img_t = image.permute(2, 0, 1).unsqueeze(0) 478 | else: 479 | has_batch = True 480 | _, H, W, _ = image.shape 481 | # [N,H,W,C] -> [N,C,H,W] 482 | img_t = image.permute(0, 3, 1, 2) 483 | 484 | # coor_x, coor_y: [H_out, W_out] 485 | # We must create a grid for F.grid_sample: 486 | # grid_sample expects: input [N, C, H, W], grid [N, H_out, W_out, 2] 487 | # Normalized coords: x: [-1, 1], y: [-1, 1] 488 | # Handle wrapping horizontally: coor_x modulo W 489 | coor_x_wrapped = torch.remainder(coor_x, W) # wrap horizontally 490 | coor_y_clamped = coor_y.clamp(min=0, max=H - 1) 491 | 492 | # Normalize 493 | grid_x = (coor_x_wrapped / (W - 1)) * 2 - 1 494 | grid_y = (coor_y_clamped / (H - 1)) * 2 - 1 495 | grid = torch.stack([grid_x, grid_y], dim=-1) # [H_out, W_out, 2] 496 | 497 | grid = grid.unsqueeze(0) # [1, H_out, W_out,2] 498 | if has_batch: 499 | grid = grid.repeat(img_t.shape[0], 1, 1, 1) 500 | 501 | # grid_sample: note that the code samples using (y, x) order if 502 | # align_corners=False, we must be careful: 503 | # grid is defined as grid[:,:,:,0] = x, grid[:,:,:,1] = y, 504 | # PyTorch grid_sample expects grid in form (N, H_out, W_out, 2), 505 | # with grid[:,:,:,0] = x and grid[:,:,:,1] = y 506 | 507 | if ( 508 | img_t.dtype == torch.float16 or img_t.dtype == torch.bfloat16 509 | ) and img_t.device == torch.device("cpu"): 510 | sampled = F.grid_sample( 511 | img_t.float(), 512 | grid.float(), 513 | mode=mode, 514 | padding_mode=padding_mode, 515 | align_corners=True, 516 | ).to(img_t.dtype) 517 | else: 518 | sampled = F.grid_sample( 519 | img_t, grid, mode=mode, padding_mode=padding_mode, align_corners=True 520 | ) 521 | 522 | if has_batch: 523 | sampled = sampled.permute(0, 2, 3, 1) 524 | else: 525 | # [1, C, H_out, W_out] 526 | sampled = sampled.squeeze(0).permute(1, 2, 0) # [H_out, W_out, C] 527 | return sampled 528 | 529 | 530 | def sample_equirec( 531 | e_img: torch.Tensor, coor_xy: torch.Tensor, mode: str = "bilinear" 532 | ) -> torch.Tensor: 533 | """ 534 | Sample from an equirectangular image. 535 | 536 | Args: 537 | e_img (torch.Tensor): Equirectangular image tensor in the shape of: 538 | [H, W, C]. 539 | coor_xy (torch.Tensor): Sampling coordinates in the shape of 540 | [H_out, W_out, 2]. 541 | mode (str, optional): Sampling interpolation mode, 'nearest', 542 | 'bicubic', or 'bilinear'. Default: 'bilinear'. 543 | 544 | Returns: 545 | torch.Tensor: Sampled image tensor. 546 | """ 547 | coor_x = coor_xy[..., 0] 548 | coor_y = coor_xy[..., 1] 549 | return grid_sample_wrap(e_img, coor_x, coor_y, mode=mode) 550 | 551 | 552 | def sample_cubefaces( 553 | cube_faces: torch.Tensor, 554 | tp: torch.Tensor, 555 | coor_y: torch.Tensor, 556 | coor_x: torch.Tensor, 557 | mode: str = "bilinear", 558 | ) -> torch.Tensor: 559 | """ 560 | Sample from cube faces. 561 | 562 | Args: 563 | cube_faces (torch.Tensor): Cube faces tensor in the shape of: 564 | [6, face_w, face_w, C]. 565 | tp (torch.Tensor): Face type tensor. 566 | coor_y (torch.Tensor): Y coordinates for sampling. 567 | coor_x (torch.Tensor): X coordinates for sampling. 568 | mode (str, optional): Sampling interpolation mode, 'nearest' or 569 | 'bilinear'. Default: 'bilinear' 570 | 571 | Returns: 572 | torch.Tensor: Sampled cube faces tensor. 
573 | """ 574 | # cube_faces: [6, face_w, face_w, C] 575 | # We must sample according to tp (face index), coor_y, coor_x 576 | # First we must flatten all faces into a single big image (like cube_h) 577 | # We can do per-face sampling. Instead of map_coordinates 578 | # (tp, y, x), we know each pixel belongs to a certain face. 579 | 580 | # For differentiability and simplicity, let's do a trick: 581 | # Create a big image [face_w,face_w*6, C] (cube_h) and sample from it using 582 | # coor_x, coor_y and tp. 583 | cube_faces = pad_cube_faces(cube_faces) 584 | coor_y = coor_y + 1 585 | coor_x = coor_x + 1 586 | cube_faces_mod = cube_faces.clone() 587 | 588 | face_w = cube_faces_mod.shape[1] 589 | cube_h = torch.cat( 590 | [cube_faces_mod[i] for i in range(6)], dim=1 591 | ) # [face_w, face_w*6, C] 592 | 593 | # We need to map (tp, coor_y, coor_x) -> coordinates in cube_h 594 | # cube_h faces: 0:F, 1:R, 2:B, 3:L, 4:U, 5:D in order 595 | # If tp==0: x in [0, face_w-1] + offset 0 596 | # If tp==1: x in [0, face_w-1] + offset face_w 597 | # etc. 598 | 599 | # coor_x, coor_y are in face coordinates [0, face_w-1] 600 | # offset for face 601 | # x_offset = tp * face_w 602 | 603 | # Construct a single image indexing: 604 | # To handle tp indexing, let's create global_x = coor_x + tp * face_w 605 | # But tp might have shape (H_out,W_out) 606 | global_x = coor_x + tp.to(dtype=cube_h.dtype) * face_w 607 | global_y = coor_y 608 | 609 | return grid_sample_wrap(cube_h, global_x, global_y, mode=mode) 610 | 611 | 612 | def cube_h2list(cube_h: torch.Tensor) -> List[torch.Tensor]: 613 | """ 614 | Convert a horizontal cube representation to a list of cube faces. 615 | 616 | Args: 617 | cube_h (torch.Tensor): Horizontal cube representation tensor in the 618 | shape of: [w, w*6, C] or [B, w, w*6, C]. 619 | 620 | Returns: 621 | List[torch.Tensor]: List of cube face tensors in the order of: 622 | ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'] 623 | """ 624 | assert cube_h.dim() == 3 or cube_h.dim() == 4 625 | w = cube_h.shape[0] if cube_h.dim() == 3 else cube_h.shape[1] 626 | return [cube_h[..., i * w : (i + 1) * w, :] for i in range(6)] 627 | 628 | 629 | def cube_list2h(cube_list: List[torch.Tensor]) -> torch.Tensor: 630 | """ 631 | Convert a list of cube faces to a horizontal cube representation. 632 | 633 | Args: 634 | cube_list (list of torch.Tensor): List of cube face tensors, in order 635 | of ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'] 636 | 637 | Returns: 638 | torch.Tensor: Horizontal cube representation tensor. 639 | """ 640 | assert all( 641 | cube.shape == cube_list[0].shape for cube in cube_list 642 | ), "All cube faces should have the same shape." 643 | assert all( 644 | cube.device == cube_list[0].device for cube in cube_list 645 | ), "All cube faces should have the same device." 646 | assert all( 647 | cube.dtype == cube_list[0].dtype for cube in cube_list 648 | ), "All cube faces should have the same dtype." 649 | return torch.cat(cube_list, dim=1) 650 | 651 | 652 | def cube_h2dict( 653 | cube_h: torch.Tensor, 654 | face_keys: Optional[List[str]] = None, 655 | ) -> Dict[str, torch.Tensor]: 656 | """ 657 | Convert a horizontal cube representation to a dictionary of cube faces. 658 | 659 | Order: F R B L U D 660 | dice layout: 3*face_w x 4*face_w 661 | 662 | Args: 663 | cube_h (torch.Tensor): Horizontal cube representation tensor in the 664 | shape of: [w, w*6, C]. 665 | face_keys (list of str, optional): List of face keys in order. 
666 | Default: '["Front", "Right", "Back", "Left", "Up", "Down"]' 667 | 668 | Returns: 669 | Dict[str, torch.Tensor]: Dictionary of cube faces with keys: 670 | ["Front", "Right", "Back", "Left", "Up", "Down"]. 671 | """ 672 | if face_keys is None: 673 | face_keys = ["Front", "Right", "Back", "Left", "Up", "Down"] 674 | cube_list = cube_h2list(cube_h) 675 | return dict(zip(face_keys, cube_list)) 676 | 677 | 678 | def cube_dict2h( 679 | cube_dict: Dict[str, torch.Tensor], 680 | face_keys: Optional[List[str]] = None, 681 | ) -> torch.Tensor: 682 | """ 683 | Convert a dictionary of cube faces to a horizontal cube representation. 684 | 685 | Args: 686 | cube_dict (dict of str to torch.Tensor): Dictionary of cube faces. 687 | face_keys (list of str, optional): List of face keys in order. 688 | Default: '["Front", "Right", "Back", "Left", "Up", "Down"]' 689 | 690 | Returns: 691 | torch.Tensor: Horizontal cube representation tensor. 692 | """ 693 | if face_keys is None: 694 | face_keys = ["Front", "Right", "Back", "Left", "Up", "Down"] 695 | return cube_list2h([cube_dict[k] for k in face_keys]) 696 | 697 | 698 | def cube_h2dice(cube_h: torch.Tensor) -> torch.Tensor: 699 | """ 700 | Convert a horizontal cube representation to a dice layout representation. 701 | 702 | Output order: Front Right Back Left Up Down 703 | ┌────┬────┬────┬────┐ 704 | │ │ U │ │ │ 705 | ├────┼────┼────┼────┤ 706 | │ L │ F │ R │ B │ 707 | ├────┼────┼────┼────┤ 708 | │ │ D │ │ │ 709 | └────┴────┴────┴────┘ 710 | 711 | Args: 712 | cube_h (torch.Tensor): Horizontal cube representation tensor in the 713 | shape of: [w, w*6, C]. 714 | 715 | Returns: 716 | torch.Tensor: Dice layout cube representation tensor in the shape of: 717 | [w*3, w*4, C]. 718 | """ 719 | w = cube_h.shape[0] 720 | cube_dice = torch.zeros( 721 | (w * 3, w * 4, cube_h.shape[2]), dtype=cube_h.dtype, device=cube_h.device 722 | ) 723 | cube_list = cube_h2list(cube_h) 724 | sxy = [(1, 1), (2, 1), (3, 1), (0, 1), (1, 0), (1, 2)] 725 | for i, (sx, sy) in enumerate(sxy): 726 | face = cube_list[i] 727 | cube_dice[sy * w : (sy + 1) * w, sx * w : (sx + 1) * w] = face 728 | return cube_dice 729 | 730 | 731 | def cube_dice2h(cube_dice: torch.Tensor) -> torch.Tensor: 732 | """ 733 | Convert a dice layout representation to a horizontal cube representation. 734 | 735 | Input order: Front Right Back Left Up Down 736 | ┌────┬────┬────┬────┐ 737 | │ │ U │ │ │ 738 | ├────┼────┼────┼────┤ 739 | │ L │ F │ R │ B │ 740 | ├────┼────┼────┼────┤ 741 | │ │ D │ │ │ 742 | └────┴────┴────┴────┘ 743 | 744 | Args: 745 | cube_dice (torch.Tensor): Dice layout cube representation tensor in the 746 | shape of: [w*3, w*4, C]. 747 | 748 | Returns: 749 | torch.Tensor: Horizontal cube representation tensor in the shape of: 750 | [w, w*6, C]. 751 | """ 752 | w = cube_dice.shape[0] // 3 753 | cube_h = torch.zeros( 754 | (w, w * 6, cube_dice.shape[2]), dtype=cube_dice.dtype, device=cube_dice.device 755 | ) 756 | sxy = [(1, 1), (2, 1), (3, 1), (0, 1), (1, 0), (1, 2)] 757 | for i, (sx, sy) in enumerate(sxy): 758 | face = cube_dice[sy * w : (sy + 1) * w, sx * w : (sx + 1) * w] 759 | cube_h[:, i * w : (i + 1) * w] = face 760 | return cube_h 761 | 762 | 763 | def c2e( 764 | cubemap: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]], 765 | h: Optional[int] = None, 766 | w: Optional[int] = None, 767 | mode: str = "bilinear", 768 | cube_format: str = "dice", 769 | channels_first: bool = True, 770 | ) -> torch.Tensor: 771 | """ 772 | Convert a cubemap to an equirectangular projection. 
773 | 774 | Args: 775 | cubemap (torch.Tensor, list of torch.Tensor, or dict of torch.Tensor): 776 | Cubemap image tensor, list of tensors, or dict of tensors. Note 777 | that tensors should be in the shape of: 'CHW', except for when 778 | `cube_format = 'stack'`, in which case a batch dimension is 779 | present 'BCHW'. 780 | h (int, optional): Height of the output equirectangular image. If set 781 | to None, face_w * 2 will be used. 782 | Default: `face_w * 2` 783 | w (int, optional): Width of the output equirectangular image. If set 784 | to None, face_w * 4 will be used. 785 | Default: `face_w * 4` 786 | mode (str, optional): Sampling interpolation mode, 'nearest', 787 | 'bicubic', or 'bilinear'. Default: 'bilinear'. 788 | cube_format (str, optional): The input 'cubemap' format. Options 789 | are 'dict', 'list', 'horizon', 'stack', and 'dice'. Default: 'dice' 790 | The chosen 'cube_format' should correspond to the provided 791 | 'cubemap' input. 792 | The list of options are: 793 | - 'stack' (torch.Tensor): Stack of 6 faces, in the order 794 | of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 795 | - 'list' (list of torch.Tensor): List of 6 faces, in the order 796 | of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 797 | - 'dict' (dict of torch.Tensor): Dictionary with keys pointing to 798 | face tensors. Dict keys are: 799 | ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 800 | - 'dice' (torch.Tensor): A cubemap in a 'dice' layout. 801 | - 'horizon' (torch.Tensor): A cubemap in a 'horizon' layout, 802 | a 1x6 grid in the order: 803 | ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 804 | channels_first (bool, optional): The channel format used by the cubemap 805 | tensor(s). PyTorch uses channels first. Default: 'True' 806 | 807 | Returns: 808 | torch.Tensor: A CHW equirectangular projection tensor. 809 | 810 | Raises: 811 | NotImplementedError: If an unknown cube_format is provided.
812 | """ 813 | 814 | if cube_format == "stack": 815 | assert ( 816 | isinstance(cubemap, torch.Tensor) 817 | and len(cubemap.shape) == 4 818 | and cubemap.shape[0] == 6 819 | ) 820 | cubemap = [cubemap[i] for i in range(cubemap.shape[0])] 821 | cube_format = "list" 822 | 823 | # Ensure input is in HWC format for processing 824 | if channels_first: 825 | if cube_format == "list" and isinstance(cubemap, (list, tuple)): 826 | cubemap = [r.permute(1, 2, 0) for r in cubemap] 827 | elif cube_format == "dict" and torch.jit.isinstance( 828 | cubemap, Dict[str, torch.Tensor] 829 | ): 830 | cubemap = {k: v.permute(1, 2, 0) for k, v in cubemap.items()} # type: ignore 831 | elif cube_format in ["horizon", "dice"] and isinstance(cubemap, torch.Tensor): 832 | cubemap = cubemap.permute(1, 2, 0) 833 | else: 834 | raise NotImplementedError("unknown cube_format and cubemap type") 835 | 836 | if cube_format == "horizon" and isinstance(cubemap, torch.Tensor): 837 | assert cubemap.dim() == 3 838 | cube_h = cubemap 839 | elif cube_format == "list" and isinstance(cubemap, (list, tuple)): 840 | assert all([r.dim() == 3 for r in cubemap]) 841 | cube_h = cube_list2h(cubemap) 842 | elif cube_format == "dict" and torch.jit.isinstance( 843 | cubemap, Dict[str, torch.Tensor] 844 | ): 845 | assert all(v.dim() == 3 for k, v in cubemap.items()) # type: ignore[union-attr] 846 | cube_h = cube_dict2h(cubemap) # type: ignore[arg-type] 847 | elif cube_format == "dice" and isinstance(cubemap, torch.Tensor): 848 | assert len(cubemap.shape) == 3 849 | cube_h = cube_dice2h(cubemap) 850 | else: 851 | raise NotImplementedError("unknown cube_format and cubemap type") 852 | assert isinstance(cube_h, torch.Tensor) # Mypy wants this 853 | 854 | device = cube_h.device 855 | dtype = cube_h.dtype 856 | face_w = cube_h.shape[0] 857 | assert cube_h.shape[1] == face_w * 6 858 | 859 | h = face_w * 2 if h is None else h 860 | w = face_w * 4 if w is None else w 861 | 862 | uv = equirect_uvgrid(h, w, device=device, dtype=dtype) 863 | u, v = uv[..., 0], uv[..., 1] 864 | 865 | cube_faces = torch.stack( 866 | torch.split(cube_h, face_w, dim=1), dim=0 867 | ) # [6, face_w, face_w, C] 868 | 869 | tp = equirect_facetype(h, w, device=device, dtype=dtype) 870 | 871 | coor_x = torch.zeros((h, w), device=device, dtype=dtype) 872 | coor_y = torch.zeros((h, w), device=device, dtype=dtype) 873 | 874 | # front, right, back, left 875 | for i in range(4): 876 | mask = tp == i 877 | coor_x[mask] = 0.5 * torch.tan(u[mask] - torch.pi * i / 2) 878 | coor_y[mask] = -0.5 * torch.tan(v[mask]) / torch.cos(u[mask] - torch.pi * i / 2) 879 | 880 | # Up 881 | mask = tp == 4 882 | c = 0.5 * torch.tan(torch.pi / 2 - v[mask]) 883 | coor_x[mask] = c * torch.sin(u[mask]) 884 | coor_y[mask] = c * torch.cos(u[mask]) 885 | 886 | # Down 887 | mask = tp == 5 888 | c = 0.5 * torch.tan(torch.pi / 2 - torch.abs(v[mask])) 889 | coor_x[mask] = c * torch.sin(u[mask]) 890 | coor_y[mask] = -c * torch.cos(u[mask]) 891 | 892 | coor_x = (torch.clamp(coor_x, -0.5, 0.5) + 0.5) * face_w 893 | coor_y = (torch.clamp(coor_y, -0.5, 0.5) + 0.5) * face_w 894 | 895 | equirec = sample_cubefaces(cube_faces, tp, coor_y, coor_x, mode) 896 | 897 | # Convert back to CHW if required 898 | equirec = _nhwc2nchw(equirec) if channels_first else equirec 899 | return equirec 900 | 901 | 902 | def e2c( 903 | e_img: torch.Tensor, 904 | face_w: Optional[int] = None, 905 | mode: str = "bilinear", 906 | cube_format: str = "dice", 907 | channels_first: bool = True, 908 | ) -> Union[torch.Tensor, List[torch.Tensor], 
Dict[str, torch.Tensor]]: 909 | """ 910 | Convert an equirectangular image to a cubemap. 911 | 912 | Args: 913 | e_img (torch.Tensor): Input equirectangular image tensor of shape 914 | [C, H, W] or [H, W, C]. 915 | face_w (int, optional): Width of each square cube shaped face. If set to None, 916 | then face_w will be calculated as H // 2. Default: None 917 | mode (str, optional): Sampling interpolation mode, 'nearest', 918 | 'bicubic', or 'bilinear'. Default: 'bilinear'. 919 | cube_format (str, optional): The desired output cubemap format. Options 920 | are 'dict', 'list', 'horizon', 'stack', and 'dice'. Default: 'dice' 921 | The list of options are: 922 | - 'stack' (torch.Tensor): Stack of 6 faces, in the order 923 | of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 924 | - 'list' (list of torch.Tensor): List of 6 faces, in the order 925 | of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 926 | - 'dict' (dict of torch.Tensor): Dictionary with keys pointing to 927 | face tensors. Dict keys are: 928 | ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 929 | - 'dice' (torch.Tensor): A cubemap in a 'dice' layout. 930 | - 'horizon' (torch.Tensor): A cubemap in a 'horizon' layout, 931 | a 1x6 grid in the order: 932 | ['Front', 'Right', 'Back', 'Left', 'Up', 'Down']. 933 | channels_first (bool, optional): The channel format of e_img. PyTorch 934 | uses channels first. Default: 'True' 935 | 936 | Returns: 937 | Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]]: 938 | A cubemap in the specified format. 939 | 940 | Raises: 941 | NotImplementedError: If an unknown cube_format is provided. 942 | """ 943 | assert e_img.dim() == 3 or e_img.dim() == 4, ( 944 | "e_img should be in the shape of [N,C,H,W], [C,H,W], [N,H,W,C], " 945 | f"or [H,W,C], got shape of: {e_img.shape}" 946 | ) 947 | 948 | e_img = _nchw2nhwc(e_img) if channels_first else e_img 949 | h, w = e_img.shape[:2] if e_img.dim() == 3 else e_img.shape[1:3] 950 | face_w = h // 2 if face_w is None else face_w 951 | 952 | # returns [face_w, face_w*6, 3] in order 953 | # [Front, Right, Back, Left, Up, Down] 954 | xyz = xyzcube(face_w, device=e_img.device, dtype=e_img.dtype) 955 | uv = xyz2uv(xyz) 956 | coor_xy = uv2coor(uv, h, w) 957 | # Sample all channels: 958 | out_c = sample_equirec(e_img, coor_xy, mode) # [face_w, 6*face_w, C] 959 | # out_c shape: we did it directly for each pixel in the cube map 960 | 961 | result: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor], None] = ( 962 | None 963 | ) 964 | if cube_format == "horizon": 965 | result = out_c 966 | elif cube_format == "list" or cube_format == "stack": 967 | result = cube_h2list(out_c) 968 | elif cube_format == "dict": 969 | result = cube_h2dict(out_c) 970 | elif cube_format == "dice": 971 | result = cube_h2dice(out_c) 972 | else: 973 | raise NotImplementedError("unknown cube_format") 974 | 975 | # Convert to CHW if required 976 | if channels_first: 977 | if cube_format == "list" or cube_format == "stack": 978 | assert isinstance(result, (list, tuple)) 979 | result = [_nhwc2nchw(r) for r in result] 980 | elif cube_format == "dict": 981 | assert torch.jit.isinstance(result, Dict[str, torch.Tensor]) 982 | result = {k: _nhwc2nchw(v) for k, v in result.items()} # type: ignore[union-attr] 983 | elif cube_format in ["horizon", "dice"]: 984 | assert isinstance(result, torch.Tensor) 985 | result = _nhwc2nchw(result) 986 | if cube_format == "stack" and isinstance(result, (list, tuple)): 987 | result = torch.stack(result) 988 | return result 989 | 990 | 991 | 
def e2p( 992 | e_img: torch.Tensor, 993 | fov_deg: Union[float, Tuple[float, float]], 994 | h_deg: float = 0.0, 995 | v_deg: float = 0.0, 996 | out_hw: Tuple[int, int] = (512, 512), 997 | in_rot_deg: float = 0.0, 998 | mode: str = "bilinear", 999 | channels_first: bool = True, 1000 | ) -> torch.Tensor: 1001 | """ 1002 | Convert an equirectangular image to a perspective projection. 1003 | 1004 | Args: 1005 | e_img (torch.Tensor): Input equirectangular image tensor in the shape 1006 | of: [C, H, W] or [H, W, C]. Or with a batch dimension: [B, C, H, W] 1007 | or [B, H, W, C]. 1008 | fov_deg (float or tuple of float): Field of view in degrees. 1009 | Can be a single float or an (h_fov, v_fov) tuple. 1010 | h_deg (float, optional): Horizontal rotation angle in degrees 1011 | (-left/+right). Default: 0.0 1012 | v_deg (float, optional): Vertical rotation angle in degrees 1013 | (-down/+up). Default: 0.0 1014 | out_hw (tuple of int, optional): The output image size specified as 1015 | a tuple of (height, width). Default: (512, 512) 1016 | in_rot_deg (float, optional): Input rotation angle in degrees. 1017 | Default: 0.0 1018 | mode (str, optional): Sampling interpolation mode, 'nearest', 1019 | 'bicubic', or 'bilinear'. Default: 'bilinear'. 1020 | channels_first (bool, optional): The channel format of e_img. PyTorch 1021 | uses channels first. Default: True 1022 | 1023 | Returns: 1024 | torch.Tensor: Perspective projection image tensor. 1025 | """ 1026 | assert e_img.dim() == 3 or e_img.dim() == 4, ( 1027 | "e_img should be in the shape of [N,C,H,W], [C,H,W], [N,H,W,C], " 1028 | f"or [H,W,C], got shape: {e_img.shape}" 1029 | ) 1030 | 1031 | # Ensure input is in HWC format for processing 1032 | e_img = _nchw2nhwc(e_img) if channels_first else e_img 1033 | h, w = e_img.shape[:2] if e_img.dim() == 3 else e_img.shape[1:3] 1034 | 1035 | if isinstance(fov_deg, (list, tuple)): 1036 | h_fov_rad = fov_deg[0] * torch.pi / 180 1037 | v_fov_rad = fov_deg[1] * torch.pi / 180 1038 | else: 1039 | fov = fov_deg * torch.pi / 180 1040 | h_fov_rad = fov 1041 | v_fov_rad = fov 1042 | 1043 | in_rot = in_rot_deg * torch.pi / 180 1044 | 1045 | u = -h_deg * torch.pi / 180 1046 | v = v_deg * torch.pi / 180 1047 | 1048 | xyz = xyzpers( 1049 | h_fov_rad, 1050 | v_fov_rad, 1051 | u, 1052 | v, 1053 | out_hw, 1054 | in_rot, 1055 | device=e_img.device, 1056 | dtype=e_img.dtype, 1057 | ) 1058 | uv = xyz2uv(xyz) 1059 | coor_xy = uv2coor(uv, h, w) 1060 | 1061 | pers_img = sample_equirec(e_img, coor_xy, mode) 1062 | 1063 | # Convert back to CHW if required 1064 | pers_img = _nhwc2nchw(pers_img) if channels_first else pers_img 1065 | return pers_img 1066 | 1067 | 1068 | def e2e( 1069 | e_img: torch.Tensor, 1070 | roll: float = 0.0, 1071 | h_deg: float = 0.0, 1072 | v_deg: float = 0.0, 1073 | mode: str = "bilinear", 1074 | channels_first: bool = True, 1075 | ) -> torch.Tensor: 1076 | """ 1077 | Apply rotations to an equirectangular image around the roll, pitch, and yaw 1078 | axes. 1079 | 1080 | This function rotates an equirectangular image tensor around the roll 1081 | (x-axis), pitch (y-axis), and yaw (z-axis) axes, which correspond to 1082 | rotations that produce vertical and horizontal shifts in the image. 1083 | 1084 | Args: 1085 | e_img (torch.Tensor): Input equirectangular image tensor in the shape 1086 | of: [C, H, W] or [H, W, C]. Or with a batch dimension: [B, C, H, W] 1087 | or [B, H, W, C]. 1088 | roll (float, optional): Roll angle in degrees. Rotates the image around 1089 | the x-axis.
Roll direction: (-counterclockwise/+clockwise). 1090 | Default: 0.0 1091 | h_deg (float, optional): Yaw angle in degrees (-left/+right). Rotates the 1092 | image around the z-axis to produce a horizontal shift. Default: 0.0 1093 | v_deg (float, optional): Pitch angle in degrees (-down/+up). Rotates the 1094 | image around the y-axis to produce a vertical shift. Default: 0.0 1095 | mode (str, optional): Sampling interpolation mode, 'nearest', 1096 | 'bicubic', or 'bilinear'. Default: 'bilinear'. 1097 | channels_first (bool, optional): The channel format of e_img. PyTorch 1098 | uses channels first. Default: True 1099 | 1100 | Returns: 1101 | torch.Tensor: Modified equirectangular image. 1102 | """ 1103 | 1104 | # Map the shift arguments onto the yaw/pitch names used below 1105 | yaw = h_deg 1106 | pitch = v_deg 1107 | 1108 | assert e_img.dim() == 3 or e_img.dim() == 4, ( 1109 | "e_img should be in the shape of [N,C,H,W], [C,H,W], [N,H,W,C], " 1110 | f"or [H,W,C], got shape: {e_img.shape}" 1111 | ) 1112 | 1113 | # Ensure input is in HWC format for processing 1114 | e_img = _nchw2nhwc(e_img) if channels_first else e_img 1115 | h, w = e_img.shape[:2] if e_img.dim() == 3 else e_img.shape[1:3] 1116 | 1117 | # Convert angles to radians 1118 | roll_rad = torch.tensor( 1119 | [roll * torch.pi / 180.0], device=e_img.device, dtype=e_img.dtype 1120 | ) 1121 | pitch_rad = torch.tensor( 1122 | [pitch * torch.pi / 180.0], device=e_img.device, dtype=e_img.dtype 1123 | ) 1124 | yaw_rad = torch.tensor( 1125 | [yaw * torch.pi / 180.0], device=e_img.device, dtype=e_img.dtype 1126 | ) 1127 | 1128 | # Create base coordinates for the output image 1129 | y_range = torch.linspace(0, h - 1, h, device=e_img.device, dtype=e_img.dtype) 1130 | x_range = torch.linspace(0, w - 1, w, device=e_img.device, dtype=e_img.dtype) 1131 | grid_y, grid_x = torch.meshgrid(y_range, x_range, indexing="ij") 1132 | 1133 | # Convert pixel coordinates to spherical coordinates 1134 | uv = coor2uv(torch.stack([grid_x, grid_y], dim=-1), h, w) 1135 | 1136 | # Convert to unit vectors on sphere 1137 | xyz = uv2unitxyz(uv) 1138 | 1139 | # Create rotation matrices (Ry encodes yaw about the z-axis, Rz encodes pitch about the y-axis) 1140 | Rx = rotation_matrix( 1141 | roll_rad, torch.tensor([1.0, 0.0, 0.0], device=e_img.device, dtype=e_img.dtype) 1142 | ) 1143 | Ry = rotation_matrix( 1144 | yaw_rad, torch.tensor([0.0, 0.0, 1.0], device=e_img.device, dtype=e_img.dtype) 1145 | ) 1146 | Rz = rotation_matrix( 1147 | pitch_rad, torch.tensor([0.0, 1.0, 0.0], device=e_img.device, dtype=e_img.dtype) 1148 | ) 1149 | 1150 | # Apply rotations: first roll, then yaw, then pitch 1151 | xyz_rot = xyz @ Rx @ Ry @ Rz 1152 | 1153 | # Convert back to UV coordinates 1154 | uv_rot = xyz2uv(xyz_rot) 1155 | 1156 | # Convert UV coordinates to pixel coordinates 1157 | coor_xy = uv2coor(uv_rot, h, w) 1158 | 1159 | # Sample from the input image 1160 | rotated = sample_equirec(e_img, coor_xy, mode=mode) 1161 | 1162 | # Return to original channel format if needed 1163 | rotated = _nhwc2nchw(rotated) if channels_first else rotated 1164 | return rotated 1165 | 1166 | 1167 | def pad_180_to_360( 1168 | e_img: torch.Tensor, fill_value: float = 0.0, channels_first: bool = True 1169 | ) -> torch.Tensor: 1170 | """ 1171 | Pads a 180 degree equirectangular image tensor (CHW/NCHW, or HWC/NHWC when 1172 | channels_first is False) into a full 360 degree image by padding the left and right sides. 1173 | 1174 | For an image of width W (covering 180 degrees), the full panorama requires the 1175 | total image width to be doubled. Padding is evenly split between the left and 1176 | right sides.
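For example, a [3, 256, 512] CHW input becomes [3, 256, 1024], with 256 columns of fill_value added on each side.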
1177 | 1178 | Args: 1179 | e_img (torch.Tensor): Input equirectangular image tensor in the shape 1180 | of: [C, H, W] or [H, W, C]. Or with a batch dimension: [B, C, H, W] 1181 | or [B, H, W, C]. 1182 | fill_value (float, optional): The constant value for padding. Default: 0.0 1183 | channels_first (bool, optional): The channel format of e_img. PyTorch 1184 | uses channels first. Default: True 1185 | 1186 | Returns: 1187 | torch.Tensor: The padded tensor. 1188 | """ 1189 | assert e_img.dim() in [3, 4] 1190 | e_img = _nhwc2nchw(e_img) if not channels_first else e_img 1191 | H, W = e_img.shape[1:] if e_img.dim() == 3 else e_img.shape[2:] 1192 | pad_left = W // 2 1193 | pad_right = W - pad_left 1194 | 1195 | if e_img.ndim == 3: 1196 | e_img_padded = F.pad( 1197 | e_img.unsqueeze(0), (pad_left, pad_right), mode="constant", value=fill_value 1198 | ).squeeze(0) 1199 | elif e_img.ndim == 4: 1200 | e_img_padded = F.pad( 1201 | e_img, (pad_left, pad_right), mode="constant", value=fill_value 1202 | ) 1203 | else: 1204 | raise ValueError( 1205 | f"e_img must be either 3D (CHW) or 4D (NCHW), got {e_img.shape}" 1206 | ) 1207 | 1208 | e_img_padded = _nchw2nhwc(e_img_padded) if not channels_first else e_img_padded 1209 | return e_img_padded 1210 | -------------------------------------------------------------------------------- /pytorch360convert/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.2" 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from setuptools import find_packages, setup 4 | 5 | # Load the current project version 6 | exec(open("pytorch360convert/version.py").read()) 7 | 8 | 9 | # Convert relative image links to full links for PyPI 10 | def _relative_to_full_link(long_description: str) -> str: 11 | """ 12 | Converts relative image links in a README to full GitHub URLs. 13 | 14 | This function replaces relative image links (e.g., in `<img>` tags and 15 | Markdown ![]() syntax) with their corresponding full GitHub URLs, appending 16 | `?raw=true` for direct access to raw images. 17 | 18 | Links are only replaced if they point to the 'examples' directory, and are 19 | in the format of: `<img src="examples/...">` or 20 | `![](examples/...)`. 21 | 22 | Args: 23 | long_description (str): The text containing relative image links. 24 | 25 | Returns: 26 | str: The modified text with full image URLs.
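Example (illustrative; the exact output format depends on the substitution patterns below): a link such as `![](examples/basic_dice_cubemap.png)` would become `![](https://github.com/ProGamerGov/pytorch360convert/raw/main/examples/basic_dice_cubemap.png?raw=true)`.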
27 | """ 28 | import re 29 | 30 | # Base URL for raw GitHub links 31 | github_base_url = "https://github.com/ProGamerGov/pytorch360convert/raw/main/" 32 | 33 | # Replace relative links in 34 | long_description = re.sub( 35 | r'(=1.8.0", 85 | ], 86 | packages=find_packages(exclude=("tests", "tests.*")), 87 | classifiers=[ 88 | "Development Status :: 4 - Beta", 89 | "Intended Audience :: Developers", 90 | "Intended Audience :: Education", 91 | "Intended Audience :: Science/Research", 92 | "Intended Audience :: Information Technology", 93 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 94 | "Topic :: Scientific/Engineering :: Image Processing", 95 | "Topic :: Software Development", 96 | "Topic :: Software Development :: Libraries", 97 | "Topic :: Games/Entertainment", 98 | "Topic :: Multimedia :: Graphics :: Viewers", 99 | "License :: OSI Approved :: MIT License", 100 | "Programming Language :: Python :: 3", 101 | "Programming Language :: Python :: 3.8", 102 | "Programming Language :: Python :: 3.9", 103 | "Programming Language :: Python :: 3.10", 104 | "Programming Language :: Python :: 3.11", 105 | "Programming Language :: Python :: 3.12", 106 | "Programming Language :: Python :: 3.13", 107 | "Programming Language :: Python :: 3.14", 108 | ], 109 | ) 110 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Module tests 2 | -------------------------------------------------------------------------------- /tests/test_module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import math 3 | import random 4 | import unittest 5 | 6 | import numpy as np 7 | 8 | import torch 9 | from pytorch360convert.pytorch360convert import ( 10 | _face_slice, 11 | _nchw2nhwc, 12 | _nhwc2nchw, 13 | _slice_chunk, 14 | c2e, 15 | coor2uv, 16 | cube_dice2h, 17 | cube_dict2h, 18 | cube_h2dice, 19 | cube_h2dict, 20 | cube_h2list, 21 | cube_list2h, 22 | e2c, 23 | e2e, 24 | e2p, 25 | equirect_facetype, 26 | equirect_uvgrid, 27 | grid_sample_wrap, 28 | pad_180_to_360, 29 | pad_cube_faces, 30 | rotation_matrix, 31 | sample_cubefaces, 32 | uv2coor, 33 | uv2unitxyz, 34 | xyz2uv, 35 | xyzcube, 36 | xyzpers, 37 | ) 38 | 39 | 40 | def assertTensorAlmostEqual( 41 | self, actual: torch.Tensor, expected: torch.Tensor, delta: float = 0.0001 42 | ) -> None: 43 | """ 44 | Args: 45 | 46 | self (): A unittest instance. 47 | actual (torch.Tensor): A tensor to compare with expected. 48 | expected (torch.Tensor): A tensor to compare with actual. 49 | delta (float, optional): The allowed difference between actual and expected. 
50 | Default: 0.0001 51 | """ 52 | self.assertEqual(actual.shape, expected.shape) 53 | self.assertEqual(actual.device, expected.device) 54 | self.assertEqual(actual.dtype, expected.dtype) 55 | self.assertAlmostEqual( 56 | torch.sum(torch.abs(actual - expected)).item(), 0.0, delta=delta 57 | ) 58 | 59 | 60 | def _create_test_faces(face_height: int = 512, face_width: int = 512) -> torch.Tensor: 61 | # Create unique colors for faces (6 colors) 62 | face_colors = [ 63 | [0.0, 0.0, 0.0], 64 | [0.2, 0.2, 0.2], 65 | [0.4, 0.4, 0.4], 66 | [0.6, 0.6, 0.6], 67 | [0.8, 0.8, 0.8], 68 | [1.0, 1.0, 1.0], 69 | ] 70 | face_colors = torch.as_tensor(face_colors).view(6, 3, 1, 1) 71 | 72 | # Create and color faces (6 squares) 73 | faces = torch.ones([6, 3] + [face_height, face_width]) * face_colors 74 | return faces 75 | 76 | 77 | def _create_dice_layout( 78 | faces: torch.Tensor, face_h: int = 512, face_w: int = 512 79 | ) -> torch.Tensor: 80 | H, W = face_h, face_w 81 | cube = torch.zeros((3, 3 * H, 4 * W)) 82 | cube[:, 0 : 1 * H, W : 2 * W] = faces[4] 83 | cube[:, 1 * H : 2 * H, 0:W] = faces[3] 84 | cube[:, 1 * H : 2 * H, W : 2 * W] = faces[0] 85 | cube[:, 1 * H : 2 * H, 2 * W : 3 * W] = faces[1] 86 | cube[:, 1 * H : 2 * H, 3 * W : 4 * W] = faces[2] 87 | cube[:, 2 * H : 3 * H, W : 2 * W] = faces[5] 88 | return cube 89 | 90 | 91 | def _get_c2e_4x4_exact_tensor() -> torch.Tensor: 92 | a = 0.4000000059604645 93 | b = 0.6000000238418579 94 | c = 0.507179856300354 95 | d = 0.0000 96 | e = 0.09061708301305771 97 | f = 0.20000000298023224 98 | g = 0.36016160249710083 99 | 100 | # Create the expected middle part for each of the 3 matrices 101 | middle = [a, a, b, b, b, c, d, d, d, e, f, f, f, g, a, a] 102 | middle = torch.tensor(middle).unsqueeze(0) # Shape (1, 16) 103 | 104 | # Create the base output for the tensor 105 | expected_output = torch.zeros(8, 16) 106 | 107 | # Fill the top two and bottom two rows with the constant values 108 | expected_output[0:2] = 0.800000011920929 109 | expected_output[6:8] = 1.0000 110 | 111 | # Fill the middle rows with the specific values 112 | expected_output[2:6] = middle 113 | 114 | # Exact values for the final gradient row (row 5) 115 | last_row = [ 116 | 0.7569681406021118, 117 | 0.8475320339202881, 118 | 1.0, 119 | 0.8708105087280273, 120 | 0.8414928317070007, 121 | 0.791767954826355, 122 | 0.9714627265930176, 123 | 0.6305789947509766, 124 | 0.6305789947509766, 125 | 0.5338935256004333, 126 | 0.8733490109443665, 127 | 0.6829856634140015, 128 | 0.7416210174560547, 129 | 0.19919204711914062, 130 | 0.8475320935249329, 131 | 0.7569681406021118, 132 | ] 133 | expected_output[5] = torch.tensor(last_row) 134 | 135 | # Stack 3 copies of the matrix, one per channel 136 | expected_output = torch.stack([expected_output] * 3, dim=0) 137 | 138 | return expected_output 139 | 140 | 141 | def _get_e2c_4x4_exact_tensor() -> torch.Tensor: 142 | f = [[1.0, 1.2951672077178955, 1.7048327922821045, 2.0]] * 4 143 | r = [[2.0, 2.2951672077178955, 2.7048325538635254, 3.0]] * 4 144 | b = [[3.0, 3.0, 3.0, 0.0]] * 4 145 | l = [[0.0, 0.29516735672950745, 0.7048328518867493, 1.0]] * 4 146 | u = [ 147 | [0.0, 3.0, 3.0, 3.0], 148 | [0.29516735672950745, 0.0, 3.0, 2.7048325538635254], 149 | [0.7048328518867493, 1.0, 2.0, 2.2951672077178955], 150 | [1.0, 1.2951672077178955, 1.7048327922821045, 2.0], 151 | ] 152 | d = u[::-1] 153 | expected_out = torch.stack( 154 | [ 155 | torch.tensor(f).repeat(3, 1, 1), 156 | torch.tensor(r).repeat(3, 1, 1), 157 | torch.tensor(b).repeat(3, 1, 1), 158 | torch.tensor(l).repeat(3, 1, 1), 159 |
torch.tensor(u).repeat(3, 1, 1), 160 | torch.tensor(d).repeat(3, 1, 1), 161 | ] 162 | ) 163 | return expected_out 164 | 165 | 166 | class TestFunctionsBaseTest(unittest.TestCase): 167 | def setUp(self) -> None: 168 | seed = 1234 169 | random.seed(seed) 170 | np.random.seed(seed) 171 | torch.manual_seed(seed) 172 | torch.cuda.manual_seed_all(seed) 173 | torch.backends.cudnn.deterministic = True 174 | 175 | def test_rotation_matrix(self) -> None: 176 | # Test identity rotation (0 radians around any axis) 177 | axis = torch.tensor([1.0, 0.0, 0.0]) 178 | angle = torch.tensor([0.0]) 179 | result = rotation_matrix(angle, axis) 180 | expected = torch.eye(3) 181 | torch.testing.assert_close(result, expected, rtol=1e-6, atol=1e-6) 182 | 183 | def test_rotation_matrix_90deg(self) -> None: 184 | # Test 90-degree rotation around x-axis 185 | axis = torch.tensor([1.0, 0.0, 0.0]) 186 | angle = torch.tensor([math.pi / 2]) 187 | result = rotation_matrix(angle, axis) 188 | expected = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]]) 189 | torch.testing.assert_close(result, expected, rtol=1e-6, atol=1e-6) 190 | 191 | # Test rotation matrix properties 192 | # Should be orthogonal (R * R.T = I) 193 | result_t = result.t() 194 | identity = torch.mm(result, result_t) 195 | torch.testing.assert_close(identity, torch.eye(3), rtol=1e-6, atol=1e-6) 196 | 197 | def test_nhwc_to_nchw_channels_first(self) -> None: 198 | input_tensor = torch.rand(2, 3, 4, 5) 199 | converted_tensor = _nhwc2nchw(input_tensor) 200 | self.assertEqual(converted_tensor.shape, (2, 5, 3, 4)) 201 | 202 | def test_nchw_to_nhwc_channels_first(self) -> None: 203 | input_tensor = torch.rand(2, 5, 3, 4) 204 | converted_tensor = _nchw2nhwc(input_tensor) 205 | self.assertEqual(converted_tensor.shape, (2, 3, 4, 5)) 206 | 207 | def test_slice_chunk_default(self) -> None: 208 | index = 2 209 | width = 3 210 | offset = 0 211 | expected = torch.tensor([6, 7, 8], dtype=torch.long) 212 | result = _slice_chunk(index, width, offset) 213 | torch.testing.assert_close(result, expected) 214 | 215 | def test_slice_chunk_with_offset(self) -> None: 216 | # Test with a non-zero offset 217 | index = 2 218 | width = 3 219 | offset = 1 220 | expected = torch.tensor([7, 8, 9], dtype=torch.long) 221 | result = _slice_chunk(index, width, offset) 222 | torch.testing.assert_close(result, expected) 223 | 224 | def test_slice_chunk_gpu(self) -> None: 225 | if not torch.cuda.is_available(): 226 | raise unittest.SkipTest("Skipping CUDA test due to not supporting CUDA.") 227 | index = 2 228 | width = 3 229 | offset = 0 230 | expected = torch.tensor([6, 7, 8], dtype=torch.long).cuda() 231 | result = _slice_chunk(index, width, offset, device=expected.device) 232 | torch.testing.assert_close(result, expected) 233 | self.assertTrue(result.is_cuda) 234 | 235 | def test_face_slice(self) -> None: 236 | # Test _face_slice, which internally calls _slice_chunk 237 | index = 2 238 | face_w = 3 239 | expected = torch.tensor([6, 7, 8], dtype=torch.long) 240 | result = _face_slice(index, face_w) 241 | torch.testing.assert_close(result, expected) 242 | 243 | def test_face_slice_gpu(self) -> None: 244 | if not torch.cuda.is_available(): 245 | raise unittest.SkipTest("Skipping CUDA test due to not supporting CUDA.") 246 | # Test _face_slice, which internally calls _slice_chunk 247 | index = 2 248 | face_w = 3 249 | expected = torch.tensor([6, 7, 8], dtype=torch.long).cuda() 250 | result = _face_slice(index, face_w, device=expected.device) 251 | torch.testing.assert_close(result, 
expected) 252 | self.assertTrue(result.is_cuda) 253 | 254 | def test_xyzcube(self) -> None: 255 | face_w = 4 256 | result = xyzcube(face_w) 257 | 258 | # Check shape 259 | self.assertEqual(result.shape, (face_w, face_w * 6, 3)) 260 | 261 | # Check that coordinates are normalized (-0.5 to 0.5) 262 | self.assertTrue(torch.all(result >= -0.5)) 263 | self.assertTrue(torch.all(result <= 0.5)) 264 | 265 | # Test front face center point (adjusting for coordinate system) 266 | center_idx = face_w // 2 267 | front_center = result[center_idx, center_idx] 268 | expected_front = torch.tensor([0.0, 0.0, 0.5]) 269 | torch.testing.assert_close(front_center, expected_front, rtol=0.17, atol=0.17) 270 | 271 | def test_cube_h2list(self) -> None: 272 | # Create a mock tensor with a shape [w, w*6, C] 273 | w = 3 # width of the cube face 274 | C = 2 # number of channels (e.g., RGB) 275 | cube_h = torch.randn(w, w * 6, C) # Random tensor with dimensions [3, 18, 2] 276 | 277 | # Call the function 278 | result = cube_h2list(cube_h) 279 | 280 | # Assert that the result is a list of 6 tensors (one for each face) 281 | self.assertEqual(len(result), 6) 282 | 283 | # Assert each tensor has the correct shape [w, w, C] 284 | for tensor in result: 285 | self.assertEqual(tensor.shape, (w, w, C)) 286 | 287 | # Ensure the shapes are sliced correctly 288 | for i in range(6): 289 | self.assertTrue(torch.equal(result[i], cube_h[:, i * w : (i + 1) * w, :])) 290 | 291 | def test_cube_h2dict(self) -> None: 292 | # Create a mock tensor with a shape [w, w*6, C] 293 | w = 3 # width of the cube face 294 | C = 2 # number of channels (e.g., RGB) 295 | cube_h = torch.randn(w, w * 6, C) # Random tensor with dimensions [3, 18, 2] 296 | face_keys = ["Front", "Right", "Back", "Left", "Up", "Down"] 297 | 298 | # Call the function 299 | result = cube_h2dict(cube_h, face_keys) 300 | 301 | # Assert that the result is a dictionary with 6 entries 302 | self.assertEqual(len(result), 6) 303 | 304 | # Assert that the dictionary keys are correct 305 | self.assertEqual(list(result.keys()), face_keys) 306 | 307 | # Assert each tensor has the correct shape [w, w, C] 308 | for face in face_keys: 309 | self.assertEqual(result[face].shape, (w, w, C)) 310 | 311 | # Check that the values correspond to the expected slices of the input tensor 312 | for i, face in enumerate(face_keys): 313 | self.assertTrue( 314 | torch.equal(result[face], cube_h[:, i * w : (i + 1) * w, :]) 315 | ) 316 | 317 | def test_equirect_uvgrid(self) -> None: 318 | h, w = 8, 16 319 | result = equirect_uvgrid(h, w) 320 | 321 | # Check shape 322 | self.assertEqual(result.shape, (h, w, 2)) 323 | 324 | # Check ranges 325 | u = result[..., 0] 326 | v = result[..., 1] 327 | self.assertTrue(torch.all(u >= -torch.pi)) 328 | self.assertTrue(torch.all(u <= torch.pi)) 329 | self.assertTrue(torch.all(v >= -torch.pi / 2)) 330 | self.assertTrue(torch.all(v <= torch.pi / 2)) 331 | 332 | # Check center point 333 | center_h, center_w = h // 2, w // 2 334 | center_point = result[center_h, center_w] 335 | expected_center = torch.tensor([0.0, 0.0]) 336 | torch.testing.assert_close( 337 | center_point, expected_center, rtol=0.225, atol=0.225 338 | ) 339 | 340 | def test_equirect_facetype(self) -> None: 341 | h, w = 8, 16 342 | result = equirect_facetype(h, w) 343 | 344 | # Check shape 345 | self.assertEqual(result.shape, (h, w)) 346 | 347 | # Check face type range (0-5 for 6 faces) 348 | self.assertTrue(torch.all(result >= 0)) 349 | self.assertTrue(torch.all(result <= 5)) 350 | 351 | # Check sum 352 | 
self.assertEqual(result.sum().item(), 384.0) 353 | 354 | # Check dtype 355 | self.assertEqual(result.dtype, torch.int64) 356 | 357 | def test_equirect_facetype_large(self) -> None: 358 | h, w = 512 * 2, 512 * 4 359 | result = equirect_facetype(h, w) 360 | 361 | # Check shape 362 | self.assertEqual(result.shape, (h, w)) 363 | 364 | # Check face type range (0-5 for 6 faces) 365 | self.assertTrue(torch.all(result >= 0)) 366 | self.assertTrue(torch.all(result <= 5)) 367 | 368 | # Check sum 369 | self.assertEqual(result.sum().item(), 6510864.0) 370 | 371 | # Check dtype 372 | self.assertEqual(result.dtype, torch.int64) 373 | 374 | def test_equirect_facetype_float64(self) -> None: 375 | h, w = 8, 16 376 | result = equirect_facetype(h, w, dtype=torch.float64) 377 | 378 | # Check shape 379 | self.assertEqual(result.shape, (h, w)) 380 | 381 | # Check face type range (0-5 for 6 faces) 382 | self.assertTrue(torch.all(result >= 0)) 383 | self.assertTrue(torch.all(result <= 5)) 384 | 385 | # Check sum 386 | self.assertEqual(result.sum().item(), 384.0) 387 | 388 | # Check dtype 389 | self.assertEqual(result.dtype, torch.int64) 390 | 391 | def test_equirect_facetype_float16(self) -> None: 392 | h, w = 8, 16 393 | result = equirect_facetype(h, w, dtype=torch.float16) 394 | 395 | # Check shape 396 | self.assertEqual(result.shape, (h, w)) 397 | 398 | # Check face type range (0-5 for 6 faces) 399 | self.assertTrue(torch.all(result >= 0)) 400 | self.assertTrue(torch.all(result <= 5)) 401 | 402 | # Check sum 403 | self.assertEqual(result.sum().item(), 384.0) 404 | 405 | # Check dtype 406 | self.assertEqual(result.dtype, torch.int64) 407 | 408 | def test_equirect_facetype_gpu(self) -> None: 409 | if not torch.cuda.is_available(): 410 | raise unittest.SkipTest("Skipping CUDA test due to not supporting CUDA.") 411 | h, w = 8, 16 412 | result = equirect_facetype(h, w, device=torch.device("cuda")) 413 | 414 | # Check shape 415 | self.assertEqual(result.shape, (h, w)) 416 | 417 | # Check face type range (0-5 for 6 faces) 418 | self.assertTrue(torch.all(result >= 0)) 419 | self.assertTrue(torch.all(result <= 5)) 420 | 421 | # Check sum 422 | self.assertEqual(result.sum().item(), 384.0) 423 | 424 | # Check dtype 425 | self.assertEqual(result.dtype, torch.int64) 426 | 427 | # Check cuda 428 | self.assertTrue(result.is_cuda) 429 | 430 | def test_xyz2uv_and_uv2unitxyz(self) -> None: 431 | # Create test points 432 | xyz = torch.tensor( 433 | [ 434 | [1.0, 0.0, 0.0], # right 435 | [0.0, 1.0, 0.0], # up 436 | [0.0, 0.0, 1.0], # front 437 | ] 438 | ) 439 | 440 | # Convert xyz to uv 441 | uv = xyz2uv(xyz) 442 | 443 | # Convert back to xyz 444 | xyz_reconstructed = uv2unitxyz(uv) 445 | 446 | # Normalize input xyz for comparison 447 | xyz_normalized = torch.nn.functional.normalize(xyz, dim=-1) 448 | 449 | # Verify reconstruction 450 | torch.testing.assert_close( 451 | xyz_normalized, xyz_reconstructed, rtol=1e-6, atol=1e-6 452 | ) 453 | 454 | def test_uv2coor_and_coor2uv(self) -> None: 455 | h, w = 8, 16 456 | # Create test UV coordinates 457 | test_uv = torch.tensor( 458 | [ 459 | [0.0, 0.0], # center 460 | [torch.pi / 2, 0.0], # right quadrant 461 | [-torch.pi / 2, 0.0], # left quadrant 462 | ] 463 | ) 464 | 465 | # Convert UV to image coordinates 466 | coor = uv2coor(test_uv, h, w) 467 | 468 | # Convert back to UV 469 | uv_reconstructed = coor2uv(coor, h, w) 470 | 471 | # Verify reconstruction 472 | torch.testing.assert_close(test_uv, uv_reconstructed, rtol=1e-5, atol=1e-5) 473 | 474 | def test_grid_sample_wrap(self) -> 
None: 475 | # Create test image 476 | h, w = 4, 8 477 | channels = 3 478 | image = torch.arange(h * w * channels, dtype=torch.float32) 479 | image = image.reshape(h, w, channels) 480 | 481 | # Test basic sampling 482 | coor_x = torch.tensor([[1.5, 2.5], [3.5, 4.5]]) 483 | coor_y = torch.tensor([[1.5, 1.5], [2.5, 2.5]]) 484 | 485 | # Test both interpolation modes 486 | result_bilinear = grid_sample_wrap(image, coor_x, coor_y, mode="bilinear") 487 | result_nearest = grid_sample_wrap(image, coor_x, coor_y, mode="nearest") 488 | 489 | # Check shapes 490 | self.assertEqual(result_bilinear.shape, (2, 2, channels)) 491 | self.assertEqual(result_nearest.shape, (2, 2, channels)) 492 | 493 | # Test horizontal wrapping 494 | wrap_x = torch.tensor([[w - 1.5, 0.5]]) 495 | wrap_y = torch.tensor([[1.5, 1.5]]) 496 | result_wrap = grid_sample_wrap(image, wrap_x, wrap_y, mode="bilinear") 497 | 498 | # Check that wrapped coordinates produce similar values 499 | # We use a larger tolerance here due to interpolation differences 500 | torch.testing.assert_close( 501 | result_wrap[0, 0], 502 | result_wrap[0, 1], 503 | rtol=0.5, 504 | atol=0.5, 505 | ) 506 | 507 | def test_grid_sample_wrap_cpu_float16(self) -> None: 508 | # Create test image 509 | h, w = 4, 8 510 | channels = 3 511 | image = torch.arange(h * w * channels, dtype=torch.float16) 512 | image = image.reshape(h, w, channels) 513 | 514 | # Test basic sampling 515 | coor_x = torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float16) 516 | coor_y = torch.tensor([[1.5, 1.5], [2.5, 2.5]], dtype=torch.float16) 517 | 518 | # Test both interpolation modes 519 | result_bilinear = grid_sample_wrap(image, coor_x, coor_y, mode="bilinear") 520 | self.assertEqual(result_bilinear.dtype, torch.float16) 521 | 522 | def test_sample_cubefaces_cpu_float16(self) -> None: 523 | # Face type tensor (which face to sample) 524 | tp = torch.tensor([[0, 1], [2, 3]], dtype=torch.float16) # Random face types 525 | 526 | # Coordinates for sampling 527 | coor_y = torch.tensor( 528 | [[0.0, 1.0], [0.0, 1.0]], dtype=torch.float16 529 | ) # y-coordinates 530 | coor_x = torch.tensor( 531 | [[0.0, 1.0], [0.0, 1.0]], dtype=torch.float16 532 | ) # x-coordinates 533 | 534 | mode = "bilinear" 535 | 536 | # Call sample_cubefaces 537 | output = sample_cubefaces( 538 | torch.ones([6, 8, 8, 3], dtype=torch.float16), tp, coor_y, coor_x, mode 539 | ) 540 | self.assertEqual(output.dtype, tp.dtype) 541 | 542 | def test_sample_cubefaces(self) -> None: 543 | # Face type tensor (which face to sample) 544 | tp = torch.tensor([[0, 1], [2, 3]], dtype=torch.float32) # Random face types 545 | 546 | # Coordinates for sampling 547 | coor_y = torch.tensor( 548 | [[0.0, 1.0], [0.0, 1.0]], dtype=torch.float32 549 | ) # y-coordinates 550 | coor_x = torch.tensor( 551 | [[0.0, 1.0], [0.0, 1.0]], dtype=torch.float32 552 | ) # x-coordinates 553 | 554 | mode = "bilinear" 555 | 556 | # Call sample_cubefaces 557 | output = sample_cubefaces(torch.ones(6, 8, 8, 3), tp, coor_y, coor_x, mode) 558 | self.assertEqual(output.sum().item(), 12.0) 559 | 560 | def test_c2e_then_e2c(self) -> None: 561 | face_width = 512 562 | test_faces = _create_test_faces(face_width, face_width) 563 | equi_img = c2e( 564 | test_faces, 565 | face_width * 2, 566 | face_width * 4, 567 | mode="bilinear", 568 | cube_format="stack", 569 | ) 570 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 571 | cubic_img = e2c( 572 | equi_img, face_w=face_width, mode="bilinear", cube_format="stack" 573 | ) 574 | 
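# Round trip: faces -> equirectangular -> faces should restore the
# [6, C, face_w, face_w] stacked layout. Only shapes are asserted here,
# since resampling makes an exact value comparison too strict (see the
# commented-out tensor check below).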
self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr] 575 | # assertTensorAlmostEqual(self, cubic_img, test_faces) 576 | 577 | def test_c2e_then_e2c_gpu(self) -> None: 578 | if not torch.cuda.is_available(): 579 | raise unittest.SkipTest("Skipping CUDA test due to not supporting CUDA.") 580 | face_width = 512 581 | test_faces = _create_test_faces(face_width, face_width).cuda() 582 | equi_img = c2e( 583 | test_faces, 584 | face_width * 2, 585 | face_width * 4, 586 | mode="bilinear", 587 | cube_format="stack", 588 | ) 589 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 590 | cubic_img = e2c( 591 | equi_img, face_w=face_width, mode="bilinear", cube_format="stack" 592 | ) 593 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr] 594 | self.assertTrue(cubic_img.is_cuda) # type: ignore[union-attr] 595 | # assertTensorAlmostEqual(self, cubic_img, test_faces) 596 | 597 | def test_c2e_stack_grad(self) -> None: 598 | face_width = 512 599 | test_faces = torch.ones([6, 3, face_width, face_width], requires_grad=True) 600 | equi_img = c2e( 601 | test_faces, 602 | face_width * 2, 603 | face_width * 4, 604 | mode="bilinear", 605 | cube_format="stack", 606 | ) 607 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 608 | self.assertTrue(equi_img.requires_grad) 609 | 610 | def test_e2c_stack_grad(self) -> None: 611 | face_width = 512 612 | test_faces = torch.ones([3, face_width * 2, face_width * 4], requires_grad=True) 613 | cubic_img = e2c( 614 | test_faces, face_w=face_width, mode="bilinear", cube_format="stack" 615 | ) 616 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr] 617 | self.assertTrue(cubic_img.requires_grad) # type: ignore[union-attr] 618 | 619 | def test_c2e_then_e2c_stack_grad(self) -> None: 620 | face_width = 512 621 | test_faces = torch.ones([6, 3, face_width, face_width], requires_grad=True) 622 | equi_img = c2e( 623 | test_faces, 624 | face_width * 2, 625 | face_width * 4, 626 | mode="bilinear", 627 | cube_format="stack", 628 | ) 629 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 630 | cubic_img = e2c( 631 | equi_img, face_w=face_width, mode="bilinear", cube_format="stack" 632 | ) 633 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr] 634 | self.assertTrue(cubic_img.requires_grad) # type: ignore[union-attr] 635 | 636 | def test_c2e_list_grad(self) -> None: 637 | face_width = 512 638 | test_faces = torch.ones([6, 3, face_width, face_width], requires_grad=True) 639 | test_faces = [test_faces[i] for i in range(test_faces.shape[0])] 640 | equi_img = c2e( 641 | test_faces, 642 | face_width * 2, 643 | face_width * 4, 644 | mode="bilinear", 645 | cube_format="list", 646 | ) 647 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 648 | self.assertTrue(equi_img.requires_grad) 649 | 650 | def test_e2c_list_grad(self) -> None: 651 | face_width = 512 652 | equi_img = torch.ones([3, face_width * 2, face_width * 4], requires_grad=True) 653 | cubic_img = e2c( 654 | equi_img, face_w=face_width, mode="bilinear", cube_format="list" 655 | ) 656 | for i in range(6): 657 | self.assertEqual(list(cubic_img[i].shape), [3, face_width, face_width]) # type: ignore[index] 658 | for i in range(6): 659 | self.assertTrue(cubic_img[i].requires_grad) # type: ignore[index] 660 | 661 | def test_c2e_then_e2c_list_grad(self) 
-> None: 662 | face_width = 512 663 | test_faces_tensors = torch.ones( 664 | [6, 3, face_width, face_width], requires_grad=True 665 | ) 666 | test_faces = [test_faces_tensors[i] for i in range(test_faces_tensors.shape[0])] 667 | equi_img = c2e( 668 | test_faces, 669 | face_width * 2, 670 | face_width * 4, 671 | mode="bilinear", 672 | cube_format="list", 673 | ) 674 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 675 | cubic_img = e2c( 676 | equi_img, face_w=face_width, mode="bilinear", cube_format="stack" 677 | ) 678 | for i in range(6): 679 | self.assertEqual(list(cubic_img[i].shape), [3, face_width, face_width]) # type: ignore[index] 680 | for i in range(6): 681 | self.assertTrue(cubic_img[i].requires_grad) # type: ignore[index] 682 | 683 | def test_c2e_dict_grad(self) -> None: 684 | dict_keys = ["Front", "Right", "Back", "Left", "Up", "Down"] 685 | face_width = 512 686 | test_faces_tensors = torch.ones( 687 | [6, 3, face_width, face_width], requires_grad=True 688 | ) 689 | test_faces = {k: test_faces_tensors[i] for i, k in zip(range(6), dict_keys)} 690 | equi_img = c2e( 691 | test_faces, 692 | face_width * 2, 693 | face_width * 4, 694 | mode="bilinear", 695 | cube_format="dict", 696 | ) 697 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 698 | self.assertTrue(equi_img.requires_grad) 699 | 700 | def test_c2e_then_e2c_dict_grad(self) -> None: 701 | dict_keys = ["Front", "Right", "Back", "Left", "Up", "Down"] 702 | face_width = 512 703 | test_faces_tensors = torch.ones( 704 | [6, 3, face_width, face_width], requires_grad=True 705 | ) 706 | test_faces = {k: test_faces_tensors[i] for i, k in zip(range(6), dict_keys)} 707 | equi_img = c2e( 708 | test_faces, 709 | face_width * 2, 710 | face_width * 4, 711 | mode="bilinear", 712 | cube_format="dict", 713 | ) 714 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 715 | cubic_img = e2c( 716 | equi_img, face_w=face_width, mode="bilinear", cube_format="dict" 717 | ) 718 | for i in dict_keys: 719 | self.assertEqual(list(cubic_img[i].shape), [3, face_width, face_width]) # type: ignore 720 | for i in dict_keys: 721 | self.assertTrue(cubic_img[i].requires_grad) # type: ignore 722 | 723 | def test_c2e_stack_nohw_grad(self) -> None: 724 | face_width = 512 725 | test_faces = torch.ones([6, 3, face_width, face_width], requires_grad=True) 726 | equi_img = c2e( 727 | test_faces, 728 | mode="bilinear", 729 | cube_format="stack", 730 | ) 731 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 732 | self.assertTrue(equi_img.requires_grad) 733 | 734 | def test_sample_cubefaces_py360convert(self) -> None: 735 | try: 736 | import py360convert as p360 737 | except ImportError: 738 | raise unittest.SkipTest( 739 | "py360convert not installed, skipping sample_cubefaces test" 740 | ) 741 | # Face type tensor (which face to sample) 742 | tp = torch.tensor([[0, 1], [2, 3]], dtype=torch.float32) # Random face types 743 | 744 | # Coordinates for sampling 745 | coor_y = torch.tensor( 746 | [[0.0, 1.0], [0.0, 1.0]], dtype=torch.float32 747 | ) # y-coordinates 748 | coor_x = torch.tensor( 749 | [[0.0, 1.0], [0.0, 1.0]], dtype=torch.float32 750 | ) # x-coordinates 751 | 752 | mode = "bilinear" 753 | 754 | # Call sample_cubefaces 755 | output = sample_cubefaces(torch.ones(6, 8, 8, 3), tp, coor_y, coor_x, mode) 756 | output_np = p360.sample_cubefaces( 757 | torch.ones(6, 8, 8).numpy(), 758 | tp.numpy(), 759 | coor_y.numpy(), 760 | coor_x.numpy(), 761 | mode, 762 | ) 763 |
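# The torch path samples a 3-channel cube of ones, while py360convert
# samples a single-channel cube, so the torch sum should be exactly
# three times the NumPy sum.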
self.assertEqual(output.sum(), output_np.sum() * 3) 764 | 765 | def test_c2e_py360convert(self) -> None: 766 | try: 767 | import py360convert as p360 768 | except ImportError: 769 | raise unittest.SkipTest("py360convert not installed, skipping c2e test") 770 | 771 | face_width = 512 772 | test_faces = _create_test_faces(face_width, face_width) 773 | test_faces = [test_faces[i] for i in range(test_faces.shape[0])] 774 | equi_img = c2e( 775 | test_faces, 776 | face_width * 2, 777 | face_width * 4, 778 | mode="bilinear", 779 | cube_format="list", 780 | ) 781 | test_faces_np = [t.permute(1, 2, 0).numpy() for t in test_faces] 782 | 783 | equi_img_np = p360.c2e( 784 | test_faces_np, 785 | face_width * 2, 786 | face_width * 4, 787 | mode="bilinear", 788 | cube_format="list", 789 | ) 790 | equi_img_np_tensor = torch.from_numpy(equi_img_np).permute(2, 0, 1).float() 791 | assertTensorAlmostEqual(self, equi_img, equi_img_np_tensor, 2722.8169) 792 | 793 | def test_c2e_then_e2c_py360convert(self) -> None: 794 | try: 795 | import py360convert as p360 796 | except ImportError: 797 | raise unittest.SkipTest( 798 | "py360convert not installed, skipping c2e and e2c test" 799 | ) 800 | 801 | face_width = 512 802 | test_faces = _create_test_faces(face_width, face_width) 803 | test_faces = [test_faces[i] for i in range(test_faces.shape[0])] 804 | equi_img = c2e( 805 | test_faces, 806 | face_width * 2, 807 | face_width * 4, 808 | mode="bilinear", 809 | cube_format="list", 810 | ) 811 | test_faces_np = [t.permute(1, 2, 0).numpy() for t in test_faces] 812 | 813 | equi_img_np = p360.c2e( 814 | test_faces_np, 815 | face_width * 2, 816 | face_width * 4, 817 | mode="bilinear", 818 | cube_format="list", 819 | ) 820 | 821 | cubic_img = e2c( 822 | equi_img, face_w=face_width, mode="bilinear", cube_format="dice" 823 | ) 824 | 825 | cubic_img_np = p360.e2c( 826 | equi_img_np, face_w=face_width, mode="bilinear", cube_format="dice" 827 | ) 828 | cubic_img_np_tensor = torch.from_numpy(cubic_img_np).permute(2, 0, 1).float() 829 | 830 | assertTensorAlmostEqual(self, cubic_img, cubic_img_np_tensor, 5858.8921) # type: ignore[arg-type] 831 | 832 | def test_e2c_horizon_grad(self) -> None: 833 | face_width = 512 834 | test_faces = torch.ones([3, face_width * 2, face_width * 4], requires_grad=True) 835 | cubic_img = e2c( 836 | test_faces, face_w=face_width, mode="bilinear", cube_format="horizon" 837 | ) 838 | self.assertEqual(list(cubic_img.shape), [3, face_width, face_width * 6]) # type: ignore[union-attr] 839 | self.assertTrue(cubic_img.requires_grad) # type: ignore[union-attr] 840 | 841 | def test_e2c_dice_grad(self) -> None: 842 | face_width = 512 843 | test_faces = torch.ones([3, face_width * 2, face_width * 4], requires_grad=True) 844 | cubic_img = e2c( 845 | test_faces, face_w=face_width, mode="bilinear", cube_format="dice" 846 | ) 847 | self.assertEqual(list(cubic_img.shape), [3, face_width * 3, face_width * 4]) # type: ignore[union-attr] 848 | self.assertTrue(cubic_img.requires_grad) # type: ignore[union-attr] 849 | 850 | def test_c2e_then_e2c_dice_grad(self) -> None: 851 | face_width = 512 852 | test_faces = torch.ones([3, face_width * 3, face_width * 4], requires_grad=True) 853 | equi_img = c2e( 854 | test_faces, 855 | face_width * 2, 856 | face_width * 4, 857 | mode="bilinear", 858 | cube_format="dice", 859 | ) 860 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 861 | cubic_img = e2c( 862 | equi_img, face_w=face_width, mode="bilinear", cube_format="dice" 863 | ) 864 | self.assertEqual(list(cubic_img.shape), [3,
face_width * 3, face_width * 4]) # type: ignore[union-attr] 865 | self.assertTrue(cubic_img.requires_grad) # type: ignore[union-attr] 866 | 867 | def test_e2p(self) -> None: 868 | # Create a simple test equirectangular image 869 | h, w = 64, 128 870 | channels = 3 871 | e_img = torch.zeros((channels, h, w)) 872 | # Add some recognizable pattern 873 | e_img[0, :, :] = torch.linspace(0, 1, w).repeat( 874 | h, 1 875 | ) # Red gradient horizontally 876 | e_img[1, :, :] = ( 877 | torch.linspace(0, 1, h).unsqueeze(1).repeat(1, w) 878 | ) # Green gradient vertically 879 | 880 | # Test basic perspective projection 881 | fov_deg = 90.0 882 | u_deg = 0.0 883 | v_deg = 0.0 884 | out_hw = (32, 32) 885 | 886 | result = e2p(e_img, fov_deg, u_deg, v_deg, out_hw) 887 | 888 | # Check output shape 889 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]]) 890 | 891 | # Test with different FOV 892 | narrow_fov = e2p(e_img, 45.0, u_deg, v_deg, out_hw) 893 | wide_fov = e2p(e_img, 120.0, u_deg, v_deg, out_hw) 894 | 895 | # Narrow FOV should have less variation in values than wide FOV 896 | self.assertTrue(torch.std(narrow_fov) < torch.std(wide_fov)) 897 | 898 | # Test with rotation 899 | rotated = e2p(e_img, fov_deg, 90.0, v_deg, out_hw) # 90 degrees right 900 | 901 | # Test with different output sizes 902 | large_output = e2p(e_img, fov_deg, u_deg, v_deg, (64, 64)) 903 | self.assertEqual(list(large_output.shape), [channels, 64, 64]) 904 | 905 | # Test with rectangular output 906 | rect_output = e2p(e_img, fov_deg, u_deg, v_deg, (32, 64)) 907 | self.assertEqual(list(rect_output.shape), [channels, 32, 64]) 908 | 909 | # Test with different FOV for height and width 910 | fov_hw = (90.0, 60.0) 911 | diff_fov = e2p(e_img, fov_hw, u_deg, v_deg, out_hw) 912 | self.assertEqual(list(diff_fov.shape), [channels, out_hw[0], out_hw[1]]) 913 | 914 | def test_e2c_stack_face_w_none(self) -> None: 915 | face_width = 512 916 | test_faces = torch.ones([3, face_width * 2, face_width * 4], requires_grad=True) 917 | cubic_img = e2c(test_faces, face_w=None, mode="bilinear", cube_format="stack") 918 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr] 919 | 920 | def test_e2c_stack_gpu(self) -> None: 921 | if not torch.cuda.is_available(): 922 | raise unittest.SkipTest("Skipping CUDA test due to not supporting CUDA.") 923 | face_width = 512 924 | test_faces = torch.ones([3, face_width * 2, face_width * 4]).cuda() 925 | cubic_img = e2c( 926 | test_faces, face_w=face_width, mode="bilinear", cube_format="stack" 927 | ) 928 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr] 929 | self.assertTrue(cubic_img.is_cuda) # type: ignore[union-attr] 930 | 931 | def test_c2e_stack_gpu(self) -> None: 932 | if not torch.cuda.is_available(): 933 | raise unittest.SkipTest("Skipping CUDA test due to not supporting CUDA.") 934 | face_width = 512 935 | test_faces = torch.ones([6, 3, face_width, face_width]).cuda() 936 | equi_img = c2e( 937 | test_faces, 938 | face_width * 2, 939 | face_width * 4, 940 | mode="bilinear", 941 | cube_format="stack", 942 | ) 943 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 944 | self.assertTrue(equi_img.is_cuda) 945 | 946 | def test_e2c_stack_float16(self) -> None: 947 | face_width = 512 948 | test_faces = torch.ones( 949 | [3, face_width * 2, face_width * 4], dtype=torch.float16 950 | ) 951 | cubic_img = e2c( 952 | test_faces, face_w=face_width, mode="bilinear", 
cube_format="stack" 953 | ) 954 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr] 955 | self.assertEqual(cubic_img.dtype, torch.float16) # type: ignore[union-attr] 956 | 957 | def test_e2c_stack_float64(self) -> None: 958 | face_width = 512 959 | test_faces = torch.ones( 960 | [3, face_width * 2, face_width * 4], dtype=torch.float64 961 | ) 962 | cubic_img = e2c( 963 | test_faces, face_w=face_width, mode="bilinear", cube_format="stack" 964 | ) 965 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr] 966 | self.assertEqual(cubic_img.dtype, torch.float64) # type: ignore[union-attr] 967 | 968 | def test_c2e_stack_float16(self) -> None: 969 | face_width = 512 970 | test_faces = torch.ones([6, 3, face_width, face_width], dtype=torch.float16) 971 | equi_img = c2e( 972 | test_faces, 973 | face_width * 2, 974 | face_width * 4, 975 | mode="bilinear", 976 | cube_format="stack", 977 | ) 978 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 979 | self.assertEqual(equi_img.dtype, torch.float16) 980 | 981 | def test_c2e_stack_float64(self) -> None: 982 | face_width = 512 983 | test_faces = torch.ones([6, 3, face_width, face_width], dtype=torch.float64) 984 | equi_img = c2e( 985 | test_faces, 986 | face_width * 2, 987 | face_width * 4, 988 | mode="bilinear", 989 | cube_format="stack", 990 | ) 991 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 992 | self.assertEqual(equi_img.dtype, torch.float64) 993 | 994 | def test_e2p_gpu(self) -> None: 995 | if not torch.cuda.is_available(): 996 | raise unittest.SkipTest("Skipping CUDA test due to not supporting CUDA.") 997 | # Create a simple test equirectangular image 998 | h, w = 64, 128 999 | channels = 3 1000 | e_img = torch.zeros((channels, h, w)).cuda() 1001 | 1002 | fov_deg = 90.0 1003 | h_deg = 0.0 1004 | v_deg = 0.0 1005 | out_hw = (32, 32) 1006 | 1007 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw) 1008 | 1009 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]]) 1010 | self.assertTrue(result.is_cuda) 1011 | 1012 | def test_e2p_grad(self) -> None: 1013 | # Create a simple test equirectangular image 1014 | h, w = 64, 128 1015 | channels = 3 1016 | e_img = torch.zeros((channels, h, w), requires_grad=True) 1017 | 1018 | fov_deg = 90.0 1019 | h_deg = 0.0 1020 | v_deg = 0.0 1021 | out_hw = (32, 32) 1022 | 1023 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw) 1024 | 1025 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]]) 1026 | self.assertTrue(result.requires_grad) 1027 | 1028 | def test_e2p_float16(self) -> None: 1029 | # Create a simple test equirectangular image 1030 | h, w = 64, 128 1031 | channels = 3 1032 | e_img = torch.zeros((channels, h, w), dtype=torch.float16) 1033 | 1034 | fov_deg = 90.0 1035 | h_deg = 0.0 1036 | v_deg = 0.0 1037 | out_hw = (32, 32) 1038 | 1039 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw) 1040 | 1041 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]]) 1042 | self.assertEqual(result.dtype, torch.float16) 1043 | 1044 | def test_e2p_float64(self) -> None: 1045 | # Create a simple test equirectangular image 1046 | h, w = 64, 128 1047 | channels = 3 1048 | e_img = torch.zeros((channels, h, w), dtype=torch.float64) 1049 | 1050 | fov_deg = 90.0 1051 | h_deg = 0.0 1052 | v_deg = 0.0 1053 | out_hw = (32, 32) 1054 | 1055 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw) 1056 | 1057 | 
self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]]) 1058 | self.assertEqual(result.dtype, torch.float64) 1059 | 1060 | def test_c2e_stack_1channel(self) -> None: 1061 | channels = 1 1062 | face_width = 512 1063 | test_faces = torch.ones( 1064 | [6, channels, face_width, face_width], dtype=torch.float64 1065 | ) 1066 | equi_img = c2e( 1067 | test_faces, 1068 | face_width * 2, 1069 | face_width * 4, 1070 | mode="bilinear", 1071 | cube_format="stack", 1072 | ) 1073 | self.assertEqual( 1074 | list(equi_img.shape), [channels, face_width * 2, face_width * 4] 1075 | ) 1076 | 1077 | def test_c2e_stack_4channels(self) -> None: 1078 | channels = 4 1079 | face_width = 512 1080 | test_faces = torch.ones( 1081 | [6, channels, face_width, face_width], dtype=torch.float64 1082 | ) 1083 | equi_img = c2e( 1084 | test_faces, 1085 | face_width * 2, 1086 | face_width * 4, 1087 | mode="bilinear", 1088 | cube_format="stack", 1089 | ) 1090 | self.assertEqual( 1091 | list(equi_img.shape), [channels, face_width * 2, face_width * 4] 1092 | ) 1093 | 1094 | def test_e2c_stack_1channel(self) -> None: 1095 | channels = 1 1096 | face_width = 512 1097 | test_faces = torch.ones([channels, face_width * 2, face_width * 4]) 1098 | cubic_img = e2c( 1099 | test_faces, face_w=face_width, mode="bilinear", cube_format="stack" 1100 | ) 1101 | self.assertEqual(list(cubic_img.shape), [6, channels, face_width, face_width]) # type: ignore[union-attr] 1102 | 1103 | def test_e2c_stack_4channels(self) -> None: 1104 | channels = 4 1105 | face_width = 512 1106 | test_faces = torch.ones([channels, face_width * 2, face_width * 4]) 1107 | cubic_img = e2c( 1108 | test_faces, face_w=face_width, mode="bilinear", cube_format="stack" 1109 | ) 1110 | self.assertEqual(list(cubic_img.shape), [6, channels, face_width, face_width]) # type: ignore[union-attr] 1111 | 1112 | def test_e2p_1channel(self) -> None: 1113 | h, w = 64, 128 1114 | channels = 1 1115 | e_img = torch.zeros((channels, h, w)) 1116 | 1117 | fov_deg = 90.0 1118 | h_deg = 0.0 1119 | v_deg = 0.0 1120 | out_hw = (32, 32) 1121 | 1122 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw) 1123 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]]) 1124 | 1125 | def test_e2p_4channels(self) -> None: 1126 | h, w = 64, 128 1127 | channels = 4 1128 | e_img = torch.zeros((channels, h, w)) 1129 | 1130 | fov_deg = 90.0 1131 | h_deg = 0.0 1132 | v_deg = 0.0 1133 | out_hw = (32, 32) 1134 | 1135 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw) 1136 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]]) 1137 | 1138 | def test_c2e_stack_jit(self) -> None: 1139 | channels = 3 1140 | face_width = 512 1141 | test_faces = torch.ones( 1142 | [6, channels, face_width, face_width], dtype=torch.float64 1143 | ) 1144 | 1145 | c2e_jit = torch.jit.script(c2e) 1146 | equi_img = c2e_jit( 1147 | test_faces, 1148 | face_width * 2, 1149 | face_width * 4, 1150 | mode="bilinear", 1151 | cube_format="stack", 1152 | ) 1153 | self.assertEqual( 1154 | list(equi_img.shape), [channels, face_width * 2, face_width * 4] 1155 | ) 1156 | 1157 | def test_e2c_stack_jit(self) -> None: 1158 | channels = 3 1159 | face_width = 512 1160 | test_faces = torch.ones([channels, face_width * 2, face_width * 4]) 1161 | e2c_jit = torch.jit.script(e2c) 1162 | cubic_img = e2c_jit( 1163 | test_faces, face_w=face_width, mode="bilinear", cube_format="stack" 1164 | ) 1165 | self.assertEqual(list(cubic_img.shape), [6, channels, face_width, face_width]) 1166 | 1167 | def test_e2p_jit(self) -> None: 
1168 | h, w = 64, 128 1169 | channels = 3 1170 | e_img = torch.zeros((channels, h, w)) 1171 | 1172 | fov_deg = 90.0 1173 | h_deg = 0.0 1174 | v_deg = 0.0 1175 | out_hw = (32, 32) 1176 | 1177 | e2p_jit = torch.jit.script(e2p) 1178 | 1179 | result = e2p_jit(e_img, fov_deg, h_deg, v_deg, out_hw) 1180 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]]) 1181 | 1182 | def test_e2p_exact(self) -> None: 1183 | channels = 3 1184 | out_hw = (2, 4) 1185 | x_input0 = torch.arange(1, 256).repeat(1, 256, 2).float() 1186 | x_input1 = torch.arange(1, 256).repeat(1, 256, 2).float() * 2.0 1187 | x_input2 = torch.arange(1, 256).repeat(1, 256, 2).float() * 4.0 1188 | x_input = torch.cat([x_input0, x_input1, x_input2], 0) 1189 | 1190 | result = e2p( 1191 | x_input, 1192 | fov_deg=(40, 60), 1193 | h_deg=10, 1194 | v_deg=50, 1195 | out_hw=out_hw, 1196 | mode="bilinear", 1197 | ) 1198 | 1199 | a = torch.tensor( 1200 | [ 1201 | [ 1202 | 183.03811645507812, 1203 | 225.49948120117188, 1204 | 58.833866119384766, 1205 | 101.29522705078125, 1206 | ], 1207 | [ 1208 | 243.3969268798828, 1209 | 5.628509521484375, 1210 | 23.704801559448242, 1211 | 40.9364013671875, 1212 | ], 1213 | ] 1214 | ) 1215 | b = torch.tensor( 1216 | [ 1217 | [ 1218 | 366.07623291015625, 1219 | 450.99896240234375, 1220 | 117.66773223876953, 1221 | 202.5904541015625, 1222 | ], 1223 | [ 1224 | 486.7938537597656, 1225 | 11.25701904296875, 1226 | 47.409603118896484, 1227 | 81.872802734375, 1228 | ], 1229 | ] 1230 | ) 1231 | c = torch.tensor( 1232 | [ 1233 | [ 1234 | 732.1524658203125, 1235 | 901.9979248046875, 1236 | 235.33546447753906, 1237 | 405.180908203125, 1238 | ], 1239 | [ 1240 | 973.5877075195312, 1241 | 22.5140380859375, 1242 | 94.81920623779297, 1243 | 163.74560546875, 1244 | ], 1245 | ] 1246 | ) 1247 | expected_output = torch.stack([a, b, c]) 1248 | 1249 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]]) 1250 | self.assertTrue(torch.allclose(result, expected_output)) 1251 | 1252 | def test_e2p_exact_batch(self) -> None: 1253 | batch = 2 1254 | channels = 3 1255 | out_hw = (2, 4) 1256 | x_input0 = torch.arange(1, 256).repeat(1, 256, 2).float() 1257 | x_input1 = torch.arange(1, 256).repeat(1, 256, 2).float() * 2.0 1258 | x_input2 = torch.arange(1, 256).repeat(1, 256, 2).float() * 4.0 1259 | x_input = torch.cat([x_input0, x_input1, x_input2], 0).repeat(batch, 1, 1, 1) 1260 | 1261 | result = e2p( 1262 | x_input, 1263 | fov_deg=(40, 60), 1264 | h_deg=10, 1265 | v_deg=50, 1266 | out_hw=out_hw, 1267 | mode="bilinear", 1268 | ) 1269 | 1270 | a = torch.tensor( 1271 | [ 1272 | [ 1273 | 183.03811645507812, 1274 | 225.49948120117188, 1275 | 58.833866119384766, 1276 | 101.29522705078125, 1277 | ], 1278 | [ 1279 | 243.3969268798828, 1280 | 5.628509521484375, 1281 | 23.704801559448242, 1282 | 40.9364013671875, 1283 | ], 1284 | ] 1285 | ) 1286 | b = torch.tensor( 1287 | [ 1288 | [ 1289 | 366.07623291015625, 1290 | 450.99896240234375, 1291 | 117.66773223876953, 1292 | 202.5904541015625, 1293 | ], 1294 | [ 1295 | 486.7938537597656, 1296 | 11.25701904296875, 1297 | 47.409603118896484, 1298 | 81.872802734375, 1299 | ], 1300 | ] 1301 | ) 1302 | c = torch.tensor( 1303 | [ 1304 | [ 1305 | 732.1524658203125, 1306 | 901.9979248046875, 1307 | 235.33546447753906, 1308 | 405.180908203125, 1309 | ], 1310 | [ 1311 | 973.5877075195312, 1312 | 22.5140380859375, 1313 | 94.81920623779297, 1314 | 163.74560546875, 1315 | ], 1316 | ] 1317 | ) 1318 | expected_output = torch.stack([a, b, c]).repeat(batch, 1, 1, 1) 1319 | 1320 | 
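# Batching should not change per-item values: every batch entry must match
# the unbatched expected tensor (within allclose tolerance).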
self.assertEqual(list(result.shape), [batch, channels, out_hw[0], out_hw[1]]) 1321 | self.assertTrue(torch.allclose(result, expected_output)) 1322 | 1323 | def test_c2e_stack_exact(self) -> None: 1324 | expected_output = _get_c2e_4x4_exact_tensor() 1325 | tile_w = 4 1326 | x_input = _create_test_faces(tile_w, tile_w) 1327 | output_cubic_tensor = c2e(x_input, mode="bilinear", cube_format="stack") 1328 | self.assertTrue(torch.allclose(output_cubic_tensor, expected_output)) 1329 | 1330 | def test_c2e_list_exact(self) -> None: 1331 | expected_output = _get_c2e_4x4_exact_tensor() 1332 | tile_w = 4 1333 | x_input = _create_test_faces(tile_w, tile_w) 1334 | dict_keys = ["Front", "Right", "Back", "Left", "Up", "Down"] 1335 | x_input_list = [x_input[i] for i in range(6)] 1336 | output_cubic_tensor = c2e(x_input_list, mode="bilinear", cube_format="list") 1337 | self.assertTrue(torch.allclose(output_cubic_tensor, expected_output)) 1338 | 1339 | def test_c2e_dict_exact(self) -> None: 1340 | expected_output = _get_c2e_4x4_exact_tensor() 1341 | tile_w = 4 1342 | x_input = _create_test_faces(tile_w, tile_w) 1343 | dict_keys = ["Front", "Right", "Back", "Left", "Up", "Down"] 1344 | x_input_dict = {k: x_input[i] for i, k in zip(range(6), dict_keys)} 1345 | output_cubic_tensor = c2e(x_input_dict, mode="bilinear", cube_format="dict") 1346 | self.assertTrue(torch.allclose(output_cubic_tensor, expected_output)) 1347 | 1348 | def test_c2e_horizon_exact(self) -> None: 1349 | expected_output = _get_c2e_4x4_exact_tensor() 1350 | tile_w = 4 1351 | x_input = _create_test_faces(tile_w, tile_w) 1352 | x_input_horizon = torch.cat([x_input[i] for i in range(6)], 2) 1353 | output_cubic_tensor = c2e( 1354 | x_input_horizon, mode="bilinear", cube_format="horizon" 1355 | ) 1356 | self.assertTrue(torch.allclose(output_cubic_tensor, expected_output)) 1357 | 1358 | def test_c2e_dice_exact(self) -> None: 1359 | expected_output = _get_c2e_4x4_exact_tensor() 1360 | tile_w = 4 1361 | x_input = _create_test_faces(tile_w, tile_w) 1362 | x_input = _create_dice_layout(x_input, tile_w, tile_w) 1363 | output_cubic_tensor = c2e(x_input, mode="bilinear", cube_format="dice") 1364 | self.assertTrue(torch.allclose(output_cubic_tensor, expected_output)) 1365 | 1366 | def test_e2c_stack_exact(self) -> None: 1367 | x_input = torch.arange(0, 4).repeat(3, 2, 1).float() 1368 | output_tensor = e2c(x_input, face_w=4, mode="bilinear", cube_format="stack") 1369 | expected_output = _get_e2c_4x4_exact_tensor() 1370 | self.assertTrue(torch.allclose(output_tensor, expected_output)) # type: ignore[arg-type] 1371 | 1372 | def test_e2e_float16(self) -> None: 1373 | channels = 4 1374 | face_width = 512 1375 | test_equi = torch.ones( 1376 | [channels, face_width * 2, face_width * 4], dtype=torch.float16 1377 | ) 1378 | equi_img = e2e( 1379 | test_equi, 1380 | h_deg=45, 1381 | v_deg=45, 1382 | roll=25, 1383 | mode="bilinear", 1384 | ) 1385 | self.assertEqual(list(equi_img.shape), list(test_equi.shape)) 1386 | self.assertEqual(equi_img.dtype, test_equi.dtype) 1387 | 1388 | def test_e2e_float64(self) -> None: 1389 | channels = 4 1390 | face_width = 512 1391 | test_equi = torch.ones( 1392 | [channels, face_width * 2, face_width * 4], dtype=torch.float64 1393 | ) 1394 | equi_img = e2e( 1395 | test_equi, 1396 | h_deg=45, 1397 | v_deg=45, 1398 | roll=25, 1399 | mode="bilinear", 1400 | ) 1401 | self.assertEqual(list(equi_img.shape), list(test_equi.shape)) 1402 | self.assertEqual(equi_img.dtype, test_equi.dtype) 1403 | 1404 | def test_e2e_1channel(self) -> None: 1405 | 
channels = 1 1406 | face_width = 512 1407 | test_equi = torch.ones([channels, face_width * 2, face_width * 4]) 1408 | equi_img = e2e( 1409 | test_equi, 1410 | h_deg=45, 1411 | v_deg=45, 1412 | roll=25, 1413 | mode="bilinear", 1414 | ) 1415 | self.assertEqual(list(equi_img.shape), list(test_equi.shape)) 1416 | self.assertEqual(equi_img.dtype, test_equi.dtype) 1417 | 1418 | def test_e2e_4channels(self) -> None: 1419 | channels = 4 1420 | face_width = 512 1421 | test_equi = torch.ones([channels, face_width * 2, face_width * 4]) 1422 | equi_img = e2e( 1423 | test_equi, 1424 | h_deg=45, 1425 | v_deg=45, 1426 | roll=25, 1427 | mode="bilinear", 1428 | ) 1429 | self.assertEqual(list(equi_img.shape), list(test_equi.shape)) 1430 | self.assertEqual(equi_img.dtype, test_equi.dtype) 1431 | 1432 | def test_e2e_grad(self) -> None: 1433 | channels = 3 1434 | face_width = 512 1435 | test_equi = torch.ones( 1436 | [channels, face_width * 2, face_width * 4], requires_grad=True 1437 | ) 1438 | equi_img = e2e( 1439 | test_equi, 1440 | h_deg=45, 1441 | v_deg=45, 1442 | roll=25, 1443 | mode="bilinear", 1444 | ) 1445 | self.assertEqual(list(equi_img.shape), list(test_equi.shape)) 1446 | self.assertEqual(equi_img.dtype, test_equi.dtype) 1447 | self.assertTrue(equi_img.requires_grad) 1448 | 1449 | def test_e2e_gpu(self) -> None: 1450 | if not torch.cuda.is_available(): 1451 | raise unittest.SkipTest("Skipping CUDA test as CUDA is not available.") 1452 | channels = 3 1453 | face_width = 512 1454 | test_equi = torch.ones([channels, face_width * 2, face_width * 4]).cuda() 1455 | equi_img = e2e( 1456 | test_equi, 1457 | h_deg=45, 1458 | v_deg=45, 1459 | roll=25, 1460 | mode="bilinear", 1461 | ) 1462 | self.assertEqual(list(equi_img.shape), list(test_equi.shape)) 1463 | self.assertEqual(equi_img.dtype, test_equi.dtype) 1464 | self.assertTrue(equi_img.is_cuda) 1465 | 1466 | def test_e2e_jit(self) -> None: 1467 | channels = 3 1468 | face_width = 512 1469 | test_equi = torch.ones([channels, face_width * 2, face_width * 4]) 1470 | e2e_jit = torch.jit.script(e2e) 1471 | equi_img = e2e_jit( 1472 | test_equi, 1473 | h_deg=45, 1474 | v_deg=45, 1475 | roll=25, 1476 | mode="bilinear", 1477 | ) 1478 | self.assertEqual(list(equi_img.shape), list(test_equi.shape)) 1479 | 1480 | def test_e2e_batch(self) -> None: 1481 | batch = 2 1482 | channels = 4 1483 | face_width = 512 1484 | test_equi = torch.ones([batch, channels, face_width * 2, face_width * 4]) 1485 | equi_img = e2e( 1486 | test_equi, 1487 | h_deg=45, 1488 | v_deg=45, 1489 | roll=25, 1490 | mode="bilinear", 1491 | ) 1492 | self.assertEqual(list(equi_img.shape), list(test_equi.shape)) 1493 | 1494 | def test_e2e_exact(self) -> None: 1495 | channels = 1 1496 | face_width = 2 1497 | 1498 | w = face_width * 2 1499 | h = face_width * 4 1500 | test_equi = torch.arange(h * w * channels, dtype=torch.float32) 1501 | test_equi = test_equi.reshape(channels, h, w) 1502 | 1503 | result = e2e( 1504 | test_equi, 1505 | h_deg=45, 1506 | v_deg=45, 1507 | roll=25, 1508 | mode="bilinear", 1509 | ) 1510 | self.assertEqual(list(result.shape), list(test_equi.shape)) 1511 | 1512 | expected_output = torch.tensor( 1513 | [ 1514 | [ 1515 | [ 1516 | 8.846327781677246, 1517 | 7.38915491104126, 1518 | 10.024271011352539, 1519 | 11.211331367492676, 1520 | ], 1521 | [ 1522 | 9.07547378540039, 1523 | 3.879967451095581, 1524 | 12.109222412109375, 1525 | 15.09968376159668, 1526 | ], 1527 | [ 1528 | 10.49519157409668, 1529 | 2.552384614944458, 1530 | 14.621706008911133, 1531 | 19.000062942504883, 1532 | ],
1533 | [ 1534 | 12.717361450195312, 1535 | 4.980112552642822, 1536 | 17.24430274963379, 1537 | 22.87851905822754, 1538 | ], 1539 | [ 1540 | 15.3826265335083, 1541 | 8.457935333251953, 1542 | 19.71427345275879, 1543 | 26.661039352416992, 1544 | ], 1545 | [ 1546 | 18.184263229370117, 1547 | 12.159286499023438, 1548 | 21.694225311279297, 1549 | 29.683902740478516, 1550 | ], 1551 | [ 1552 | 20.887590408325195, 1553 | 15.91322135925293, 1554 | 22.679805755615234, 1555 | 29.112091064453125, 1556 | ], 1557 | [ 1558 | 20.43504524230957, 1559 | 19.638233184814453, 1560 | 22.215164184570312, 1561 | 23.207149505615234, 1562 | ], 1563 | ] 1564 | ] 1565 | ) 1566 | self.assertTrue(torch.allclose(result, expected_output)) 1567 | 1568 | def test_e2e_exact_batch(self) -> None: 1569 | batch = 2 1570 | channels = 1 1571 | face_width = 2 1572 | 1573 | w = face_width * 2 1574 | h = face_width * 4 1575 | test_equi = torch.arange(batch * h * w * channels, dtype=torch.float32) 1576 | test_equi = test_equi.reshape(batch, channels, h, w) 1577 | 1578 | result = e2e( 1579 | test_equi, 1580 | h_deg=45, 1581 | v_deg=45, 1582 | roll=25, 1583 | mode="bilinear", 1584 | ) 1585 | self.assertEqual(list(result.shape), list(test_equi.shape)) 1586 | 1587 | expected_output = torch.tensor( 1588 | [ 1589 | [ 1590 | [ 1591 | [ 1592 | 8.846327781677246, 1593 | 7.38915491104126, 1594 | 10.024271011352539, 1595 | 11.211331367492676, 1596 | ], 1597 | [ 1598 | 9.07547378540039, 1599 | 3.879967451095581, 1600 | 12.109222412109375, 1601 | 15.09968376159668, 1602 | ], 1603 | [ 1604 | 10.49519157409668, 1605 | 2.552384614944458, 1606 | 14.621706008911133, 1607 | 19.000062942504883, 1608 | ], 1609 | [ 1610 | 12.717361450195312, 1611 | 4.980112552642822, 1612 | 17.24430274963379, 1613 | 22.87851905822754, 1614 | ], 1615 | [ 1616 | 15.3826265335083, 1617 | 8.457935333251953, 1618 | 19.71427345275879, 1619 | 26.661039352416992, 1620 | ], 1621 | [ 1622 | 18.184263229370117, 1623 | 12.159286499023438, 1624 | 21.694225311279297, 1625 | 29.683902740478516, 1626 | ], 1627 | [ 1628 | 20.887590408325195, 1629 | 15.91322135925293, 1630 | 22.679805755615234, 1631 | 29.112091064453125, 1632 | ], 1633 | [ 1634 | 20.43504524230957, 1635 | 19.638233184814453, 1636 | 22.215164184570312, 1637 | 23.207149505615234, 1638 | ], 1639 | ] 1640 | ], 1641 | [ 1642 | [ 1643 | [ 1644 | 40.84632873535156, 1645 | 39.38915252685547, 1646 | 42.02427291870117, 1647 | 43.21133041381836, 1648 | ], 1649 | [ 1650 | 41.07547378540039, 1651 | 35.879966735839844, 1652 | 44.109222412109375, 1653 | 47.09968566894531, 1654 | ], 1655 | [ 1656 | 42.49519348144531, 1657 | 34.5523796081543, 1658 | 46.6217041015625, 1659 | 51.000064849853516, 1660 | ], 1661 | [ 1662 | 44.71736145019531, 1663 | 36.9801139831543, 1664 | 49.244300842285156, 1665 | 54.87852096557617, 1666 | ], 1667 | [ 1668 | 47.382625579833984, 1669 | 40.45793533325195, 1670 | 51.71427536010742, 1671 | 58.66103744506836, 1672 | ], 1673 | [ 1674 | 50.18426513671875, 1675 | 44.15928649902344, 1676 | 53.6942253112793, 1677 | 61.683902740478516, 1678 | ], 1679 | [ 1680 | 52.88758850097656, 1681 | 47.9132194519043, 1682 | 54.679805755615234, 1683 | 61.112091064453125, 1684 | ], 1685 | [ 1686 | 52.43504333496094, 1687 | 51.63823318481445, 1688 | 54.21516418457031, 1689 | 55.207149505615234, 1690 | ], 1691 | ] 1692 | ], 1693 | ] 1694 | ) 1695 | self.assertEqual(result.shape, expected_output.shape) 1696 | self.assertTrue(torch.allclose(result, expected_output)) 1697 | 1698 | def test_c2e_stack_bfloat16(self) -> None: 1699 | dtype = 
torch.bfloat16 1700 | face_width = 512 1701 | test_faces = torch.ones([6, 3, face_width, face_width], dtype=dtype) 1702 | equi_img = c2e( 1703 | test_faces, 1704 | face_width * 2, 1705 | face_width * 4, 1706 | mode="bilinear", 1707 | cube_format="stack", 1708 | ) 1709 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 1710 | self.assertEqual(equi_img.dtype, dtype) 1711 | 1712 | def test_e2c_stack_bfloat16(self) -> None: 1713 | dtype = torch.bfloat16 1714 | face_width = 512 1715 | test_equi = torch.ones([3, face_width * 2, face_width * 4], dtype=dtype) 1716 | cubic_img = e2c( 1717 | test_equi, face_w=face_width, mode="bilinear", cube_format="stack" 1718 | ) 1719 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr] 1720 | self.assertEqual(cubic_img.dtype, dtype) # type: ignore[union-attr] 1721 | 1722 | def test_e2p_bfloat16(self) -> None: 1723 | # Create a simple test equirectangular image 1724 | dtype = torch.bfloat16 1725 | h, w = 64, 128 1726 | channels = 3 1727 | e_img = torch.zeros((channels, h, w), dtype=dtype) 1728 | 1729 | fov_deg = 90.0 1730 | h_deg = 0.0 1731 | v_deg = 0.0 1732 | out_hw = (32, 32) 1733 | 1734 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw) 1735 | 1736 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]]) 1737 | self.assertEqual(result.dtype, dtype) 1738 | 1739 | def test_e2e_bfloat16(self) -> None: 1740 | dtype = torch.bfloat16 1741 | channels = 4 1742 | face_width = 512 1743 | test_equi = torch.ones([channels, face_width * 2, face_width * 4], dtype=dtype) 1744 | equi_img = e2e( 1745 | test_equi, 1746 | h_deg=45, 1747 | v_deg=45, 1748 | roll=25, 1749 | mode="bilinear", 1750 | ) 1751 | self.assertEqual(list(equi_img.shape), list(test_equi.shape)) 1752 | self.assertEqual(equi_img.dtype, dtype) 1753 | 1754 | def test_c2e_stack_bfloat16_cuda(self) -> None: 1755 | if not torch.cuda.is_available(): 1756 | raise unittest.SkipTest("Skipping CUDA test as CUDA is not available.") 1757 | dtype = torch.bfloat16 1758 | face_width = 512 1759 | test_faces = torch.ones([6, 3, face_width, face_width], dtype=dtype).cuda() 1760 | equi_img = c2e( 1761 | test_faces, 1762 | face_width * 2, 1763 | face_width * 4, 1764 | mode="bilinear", 1765 | cube_format="stack", 1766 | ) 1767 | self.assertEqual(list(equi_img.shape), [3, face_width * 2, face_width * 4]) 1768 | self.assertEqual(equi_img.dtype, dtype) 1769 | self.assertTrue(equi_img.is_cuda) 1770 | 1771 | def test_e2c_stack_bfloat16_cuda(self) -> None: 1772 | if not torch.cuda.is_available(): 1773 | raise unittest.SkipTest("Skipping CUDA test as CUDA is not available.") 1774 | dtype = torch.bfloat16 1775 | face_width = 512 1776 | test_equi = torch.ones([3, face_width * 2, face_width * 4], dtype=dtype).cuda() 1777 | cubic_img = e2c( 1778 | test_equi, face_w=face_width, mode="bilinear", cube_format="stack" 1779 | ) 1780 | self.assertEqual(list(cubic_img.shape), [6, 3, face_width, face_width]) # type: ignore[union-attr] 1781 | self.assertEqual(cubic_img.dtype, dtype) # type: ignore[union-attr] 1782 | self.assertTrue(cubic_img.is_cuda) # type: ignore[union-attr] 1783 | 1784 | def test_e2p_bfloat16_cuda(self) -> None: 1785 | if not torch.cuda.is_available(): 1786 | raise unittest.SkipTest("Skipping CUDA test as CUDA is not available.") 1787 | # Create a simple test equirectangular image 1788 | dtype = torch.bfloat16 1789 | h, w = 64, 128 1790 | channels = 3 1791 | e_img = torch.zeros((channels, h, w), dtype=dtype).cuda() 1792 | 1793 |
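# e2p samples a perspective view out of the equirectangular image: fov_deg is
# the field of view, h_deg/v_deg pick the viewing direction (both in degrees),
# and out_hw is the (height, width) of the output.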
fov_deg = 90.0 1794 | h_deg = 0.0 1795 | v_deg = 0.0 1796 | out_hw = (32, 32) 1797 | 1798 | result = e2p(e_img, fov_deg, h_deg, v_deg, out_hw) 1799 | 1800 | self.assertEqual(list(result.shape), [channels, out_hw[0], out_hw[1]]) 1801 | self.assertEqual(result.dtype, dtype) 1802 | self.assertTrue(result.is_cuda) 1803 | 1804 | def test_e2e_bfloat16_cuda(self) -> None: 1805 | if not torch.cuda.is_available(): 1806 | raise unittest.SkipTest("Skipping CUDA test as CUDA is not available.") 1807 | dtype = torch.bfloat16 1808 | channels = 4 1809 | face_width = 512 1810 | test_equi = torch.ones( 1811 | [channels, face_width * 2, face_width * 4], dtype=dtype 1812 | ).cuda() 1813 | equi_img = e2e( 1814 | test_equi, 1815 | h_deg=45, 1816 | v_deg=45, 1817 | roll=25, 1818 | mode="bilinear", 1819 | ) 1820 | self.assertEqual(list(equi_img.shape), list(test_equi.shape)) 1821 | self.assertEqual(equi_img.dtype, dtype) 1822 | self.assertTrue(equi_img.is_cuda) 1823 | 1824 | def test_pad_180_to_360_channels_first(self) -> None: 1825 | e_img = torch.ones(3, 4, 4) # [C, H, W] = [3, 4, 4] 1826 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True) 1827 | self.assertEqual(padded_img.shape, (3, 4, 8)) 1828 | self.assertTrue(torch.all(padded_img[:, :, 0] == 0.0)) 1829 | self.assertTrue(torch.all(padded_img[:, :, -1] == 0.0)) 1830 | self.assertTrue(torch.all(padded_img[:, :, 2:6] == 1.0)) 1831 | 1832 | def test_pad_180_to_360_channels_last(self) -> None: 1833 | e_img = torch.ones(4, 4, 3) # [H, W, C] = [4, 4, 3] 1834 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=False) 1835 | self.assertEqual(padded_img.shape, (4, 8, 3)) 1836 | self.assertTrue(torch.all(padded_img[:, 0, :] == 0.0)) 1837 | self.assertTrue(torch.all(padded_img[:, -1, :] == 0.0)) 1838 | self.assertTrue(torch.all(padded_img[:, 2:6, :] == 1.0)) 1839 | 1840 | def test_pad_180_to_360_batch(self) -> None: 1841 | e_img = torch.ones(2, 3, 4, 4) # [B, C, H, W] = [2, 3, 4, 4] 1842 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True) 1843 | self.assertEqual(padded_img.shape, (2, 3, 4, 8)) 1844 | self.assertTrue(torch.all(padded_img[:, :, :, 0] == 0.0)) 1845 | self.assertTrue(torch.all(padded_img[:, :, :, -1] == 0.0)) 1846 | self.assertTrue(torch.all(padded_img[:, :, :, 2:6] == 1.0)) 1847 | 1848 | def test_pad_180_to_360_gpu(self) -> None: 1849 | if not torch.cuda.is_available(): 1850 | raise unittest.SkipTest("Skipping CUDA test as CUDA is not available.") 1851 | e_img = torch.ones(3, 4, 4, device="cuda") # [C, H, W] = [3, 4, 4] 1852 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True) 1853 | self.assertTrue(padded_img.is_cuda) 1854 | self.assertEqual(padded_img.shape, (3, 4, 8)) 1855 | self.assertTrue(torch.all(padded_img[:, :, 0] == 0.0)) 1856 | self.assertTrue(torch.all(padded_img[:, :, -1] == 0.0)) 1857 | self.assertTrue(torch.all(padded_img[:, :, 2:6] == 1.0)) 1858 | def test_pad_180_to_360_float16(self) -> None: 1859 | e_img = torch.ones(3, 4, 4, dtype=torch.float16) 1860 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True) 1861 | self.assertEqual(padded_img.shape, (3, 4, 8)) 1862 | self.assertEqual(padded_img.dtype, torch.float16) 1863 | self.assertTrue(torch.all(padded_img[:, :, 0] == 0.0)) 1864 | self.assertTrue(torch.all(padded_img[:, :, -1] == 0.0)) 1865 | self.assertTrue(torch.all(padded_img[:, :, 2:6] == 1.0)) 1866 | 1867 | def test_pad_180_to_360_float64(self) -> None: 1868 | e_img = torch.ones(3, 4, 4, dtype=torch.float64) 1869 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True) 1870 |
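# pad_180_to_360 doubles the width (4 -> 8 columns), centering the original
# 180 degree content in columns 2:6 and filling the two columns on each side
# with fill_value.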
self.assertEqual(padded_img.shape, (3, 4, 8)) 1871 | self.assertEqual(padded_img.dtype, torch.float64) 1872 | self.assertTrue(torch.all(padded_img[:, :, 0] == 0.0)) 1873 | self.assertTrue(torch.all(padded_img[:, :, -1] == 0.0)) 1874 | self.assertTrue(torch.all(padded_img[:, :, 2:6] == 1.0)) 1875 | 1876 | def test_pad_180_to_360_bfloat16(self) -> None: 1877 | e_img = torch.ones(3, 4, 4, dtype=torch.bfloat16) 1878 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True) 1879 | self.assertEqual(padded_img.shape, (3, 4, 8)) 1880 | self.assertEqual(padded_img.dtype, torch.bfloat16) 1881 | self.assertTrue(torch.all(padded_img[:, :, 0] == 0.0)) 1882 | self.assertTrue(torch.all(padded_img[:, :, -1] == 0.0)) 1883 | self.assertTrue(torch.all(padded_img[:, :, 2:6] == 1.0)) 1884 | 1885 | def test_pad_180_to_360_gradient(self) -> None: 1886 | e_img = torch.ones(3, 4, 4, dtype=torch.float32, requires_grad=True) 1887 | padded_img = pad_180_to_360(e_img, fill_value=0.0, channels_first=True) 1888 | self.assertTrue(padded_img.requires_grad) 1889 | 1890 | def test_pad_180_to_360_jit(self) -> None: 1891 | e_img = torch.ones(3, 4, 4) # [C, H, W] = [3, 4, 4] 1892 | pad_180_to_360_jit = torch.jit.script(pad_180_to_360) 1893 | padded_img = pad_180_to_360_jit(e_img, fill_value=0.0, channels_first=True) 1894 | self.assertEqual(padded_img.shape, (3, 4, 8)) 1895 | self.assertTrue(torch.all(padded_img[:, :, 0] == 0.0)) 1896 | self.assertTrue(torch.all(padded_img[:, :, -1] == 0.0)) 1897 | self.assertTrue(torch.all(padded_img[:, :, 2:6] == 1.0)) 1898 | 1899 | def _prepare_test_data_cube_padding( 1900 | self, 1901 | height: int = 4, 1902 | width: int = 4, 1903 | channels: int = 3, 1904 | dtype: torch.dtype = torch.float32, 1905 | device: str = "cpu", 1906 | requires_grad: bool = False, 1907 | ) -> torch.Tensor: 1908 | """Helper to create test cube faces with specific configuration""" 1909 | # Create test faces where each face has a unique value 1910 | cube_faces = torch.zeros( 1911 | (6, height, width, channels), dtype=dtype, device=device 1912 | ) 1913 | 1914 | # Set each face to a unique value for easy identification 1915 | for face_idx in range(6): 1916 | cube_faces[face_idx, :, :, :] = face_idx / 5.0 # Normalize to [0, 1] 1917 | 1918 | if requires_grad: 1919 | cube_faces.requires_grad_(True) 1920 | 1921 | return cube_faces 1922 | 1923 | def test_pad_cube_faces_shape(self) -> None: 1924 | """Test that the output shape is correct""" 1925 | # Use equal height and width to avoid dimension mismatch when flipping 1926 | height, width, channels = 4, 4, 3 1927 | cube_faces = self._prepare_test_data_cube_padding(height, width, channels) 1928 | 1929 | padded_faces = pad_cube_faces(cube_faces) 1930 | 1931 | self.assertEqual(padded_faces.shape, (6, height + 2, width + 2, channels)) 1932 | 1933 | def test_pad_cube_faces_jit(self) -> None: 1934 | """Test that the function works with torch.jit.script""" 1935 | cube_faces = self._prepare_test_data_cube_padding(4, 4, 3) 1936 | 1937 | # Script the function 1938 | pad_cube_faces_jit = torch.jit.script(pad_cube_faces) 1939 | 1940 | # Run the scripted function 1941 | padded_faces = pad_cube_faces_jit(cube_faces) 1942 | 1943 | # Verify output shape 1944 | self.assertEqual(padded_faces.shape, (6, 6, 6, 3)) 1945 | 1946 | # Compare with non-scripted function 1947 | expected_output = pad_cube_faces(cube_faces) 1948 | self.assertTrue(torch.allclose(padded_faces, expected_output)) 1949 | 1950 | def test_pad_cube_faces_gradient(self) -> None: 1951 | """Test that gradients flow 
through the function""" 1952 | cube_faces = self._prepare_test_data_cube_padding(requires_grad=True) 1953 | 1954 | padded_faces = pad_cube_faces(cube_faces) 1955 | 1956 | # Check that requires_grad is preserved 1957 | self.assertTrue(padded_faces.requires_grad) 1958 | 1959 | # Test gradient flow by computing a loss and backpropagating 1960 | loss = padded_faces.sum() 1961 | loss.backward() 1962 | 1963 | # Verify that the gradient exists and is not zero 1964 | self.assertIsNotNone(cube_faces.grad) 1965 | self.assertFalse( 1966 | torch.allclose(cube_faces.grad, torch.zeros_like(cube_faces.grad)) # type: ignore[arg-type] 1967 | ) 1968 | 1969 | def test_pad_cube_faces_cuda(self) -> None: 1970 | """Test that the function works on CUDA device if available""" 1971 | if not torch.cuda.is_available(): 1972 | self.skipTest("CUDA not available") 1973 | 1974 | cube_faces = self._prepare_test_data_cube_padding(device="cuda") 1975 | 1976 | padded_faces = pad_cube_faces(cube_faces) 1977 | 1978 | # Check that the output is on the correct device 1979 | self.assertTrue(padded_faces.is_cuda) 1980 | 1981 | # Verify shape 1982 | self.assertEqual(padded_faces.shape, (6, 6, 6, 3)) 1983 | 1984 | def test_pad_cube_faces_float16(self) -> None: 1985 | """Test that the function works with float16 precision""" 1986 | cube_faces = self._prepare_test_data_cube_padding(dtype=torch.float16) 1987 | 1988 | padded_faces = pad_cube_faces(cube_faces) 1989 | 1990 | # Check that dtype is preserved 1991 | self.assertEqual(padded_faces.dtype, torch.float16) 1992 | 1993 | def test_pad_cube_faces_float32(self) -> None: 1994 | """Test that the function works with float32 precision""" 1995 | cube_faces = self._prepare_test_data_cube_padding(dtype=torch.float32) 1996 | 1997 | padded_faces = pad_cube_faces(cube_faces) 1998 | 1999 | # Check that dtype is preserved 2000 | self.assertEqual(padded_faces.dtype, torch.float32) 2001 | 2002 | def test_pad_cube_faces_float64(self) -> None: 2003 | """Test that the function works with float64 precision""" 2004 | cube_faces = self._prepare_test_data_cube_padding(dtype=torch.float64) 2005 | 2006 | padded_faces = pad_cube_faces(cube_faces) 2007 | 2008 | # Check that dtype is preserved 2009 | self.assertEqual(padded_faces.dtype, torch.float64) 2010 | 2011 | def test_pad_cube_faces_bfloat16(self) -> None: 2012 | """Test that the function works with bfloat16 precision""" 2013 | # Skip if bfloat16 is not supported 2014 | if not hasattr(torch, "bfloat16"): 2015 | self.skipTest("bfloat16 not supported in this PyTorch version") 2016 | 2017 | cube_faces = self._prepare_test_data_cube_padding(dtype=torch.bfloat16) 2018 | 2019 | padded_faces = pad_cube_faces(cube_faces) 2020 | 2021 | # Check that dtype is preserved 2022 | self.assertEqual(padded_faces.dtype, torch.bfloat16) 2023 | 2024 | def test_pad_cube_faces_exact_values(self) -> None: 2025 | """Test the exact values of padded faces for a small input""" 2026 | # Define small cube faces with easily identifiable values 2027 | cube_faces = torch.zeros((6, 2, 2, 1)) 2028 | 2029 | # FRONT=0, RIGHT=1, BACK=2, LEFT=3, UP=4, DOWN=5 2030 | for face_idx in range(6): 2031 | # Set each face to its index value for easier verification 2032 | cube_faces[face_idx, :, :, :] = face_idx 2033 | 2034 | padded_faces = pad_cube_faces(cube_faces) 2035 | 2036 | # Verify central values (should be unchanged) 2037 | for face_idx in range(6): 2038 | self.assertTrue( 2039 | torch.all(padded_faces[face_idx, 1:-1, 1:-1, :] == face_idx) 2040 | ) 2041 | 2042 | # Verify specific border 
values based on the padding logic 2043 | # These assertions check that the padding values come from the correct neighboring faces 2044 | 2045 | # Check FRONT face padding 2046 | self.assertTrue(torch.all(padded_faces[0, 0, 1:-1, :] == 4)) # Top from UP 2047 | self.assertTrue( 2048 | torch.all(padded_faces[0, -1, 1:-1, :] == 5) 2049 | ) # Bottom from DOWN 2050 | self.assertTrue(torch.all(padded_faces[0, 1:-1, 0, :] == 3)) # Left from LEFT 2051 | self.assertTrue( 2052 | torch.all(padded_faces[0, 1:-1, -1, :] == 1) 2053 | ) # Right from RIGHT 2054 | 2055 | # Check additional values for other faces to verify correct padding 2056 | # RIGHT face 2057 | self.assertTrue(torch.all(padded_faces[1, 1:-1, 0, :] == 0)) # Left from FRONT 2058 | 2059 | # BACK face 2060 | self.assertTrue(torch.all(padded_faces[2, 1:-1, -1, :] == 3)) # Right from LEFT 2061 | 2062 | # UP face 2063 | self.assertTrue( 2064 | torch.all(padded_faces[4, -1, 1:-1, :] == 0) 2065 | ) # Bottom from FRONT 2066 | --------------------------------------------------------------------------------
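# Appendix: a minimal usage sketch, not a file from the repository. It shows
# how the conversions exercised by the tests above fit together, assuming the
# package exposes `c2e` and `e2c` at the top level as the tests suggest.
import torch

from pytorch360convert import c2e, e2c

face_w = 64
# Six cube faces stacked in the order Front, Right, Back, Left, Up, Down.
faces = torch.rand(6, 3, face_w, face_w)
# Cubemap -> equirectangular; output height and width are passed positionally.
equi = c2e(faces, face_w * 2, face_w * 4, mode="bilinear", cube_format="stack")
assert equi.shape == (3, face_w * 2, face_w * 4)
# Equirectangular -> cubemap round trip using the same "stack" layout.
faces_rt = e2c(equi, face_w=face_w, mode="bilinear", cube_format="stack")
assert faces_rt.shape == (6, 3, face_w, face_w)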