├── pics ├── shuffle.png └── unshuffle.png ├── LICENSE ├── demo.ipynb ├── README.md ├── .gitignore └── pixel_shuffle3d.py /pics/shuffle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scalyvladimir/pixel_shuffle3d/HEAD/pics/shuffle.png -------------------------------------------------------------------------------- /pics/unshuffle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scalyvladimir/pixel_shuffle3d/HEAD/pics/unshuffle.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Vladimir Chernyy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 15, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "torch.Size([1, 27, 4, 4, 4])\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "from pixel_shuffle3d import PixelUnshuffle3d\n", 18 | "import torch\n", 19 | "\n", 20 | "pixel_unshuffle = PixelUnshuffle3d(3)\n", 21 | "input = torch.randn(1, 1, 12, 12, 12)\n", 22 | "output = pixel_unshuffle(input)\n", 23 | "print(output.size())" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 17, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "name": "stdout", 33 | "output_type": "stream", 34 | "text": [ 35 | "torch.Size([1, 1, 12, 12, 12])\n" 36 | ] 37 | } 38 | ], 39 | "source": [ 40 | "from pixel_shuffle3d import PixelShuffle3d\n", 41 | "import torch\n", 42 | "\n", 43 | "pixel_shuffle = PixelShuffle3d(3)\n", 44 | "input = torch.randn(1, 27, 4, 4, 4)\n", 45 | "output = pixel_shuffle(input)\n", 46 | "print(output.size())" 47 | ] 48 | } 49 | ], 50 | "metadata": { 51 | "kernelspec": { 52 | "display_name": "Python 3", 53 | "language": "python", 54 | "name": "python3" 55 | }, 56 | "language_info": { 57 | "codemirror_mode": { 58 | "name": "ipython", 59 | "version": 3 60 | }, 61 | "file_extension": ".py", 62 | "mimetype": "text/x-python", 63 | "name": "python", 64 | "nbconvert_exporter": "python", 65 | "pygments_lexer": "ipython3", 66 | "version": "3.9.6" 67 | } 68 | }, 69 | "nbformat": 4, 70 | "nbformat_minor": 2 71 | } 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Getting Started 3 | 4 | This repo contains 3D version of original Pixel Shuffle idea from: 
[Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](https://arxiv.org/abs/1609.05158v2), implemented in [PyTorch](https://pytorch.org/docs/stable/generated/torch.nn.PixelUnshuffle.html). 5 | 6 | ### Demonstration 7 | 8 | Visual intuition of how the 3D (un)shuffle operators work 9 | 10 |
11 | drawing 14 |
15 | 16 |
17 | drawing 20 |
21 | 22 | ### Installation 23 | 24 | 1. Clone the repo 25 | ```sh 26 | git clone git@github.com:scalyvladimir/pixel_shuffle3d.git 27 | ``` 28 | 29 | 2. Install the required packages with: 30 | ```sh 31 | pip3 install torch numpy 32 | ``` 33 | 34 | ### Usage 35 | 1. ``PixelUnshuffle3d`` 36 | ```python 37 | from pixel_shuffle3d import PixelUnshuffle3d 38 | import torch 39 | 40 | pixel_unshuffle = PixelUnshuffle3d(3) 41 | input = torch.randn(1, 1, 12, 12, 12) 42 | output = pixel_unshuffle(input) 43 | print(output.size()) 44 | # torch.Size([1, 27, 4, 4, 4]) 45 | ``` 46 | 47 | 2. ``PixelShuffle3d`` 48 | ```python 49 | from pixel_shuffle3d import PixelShuffle3d 50 | import torch 51 | 52 | pixel_shuffle = PixelShuffle3d(3) 53 | input = torch.randn(1, 27, 4, 4, 4) 54 | output = pixel_shuffle(input) 55 | print(output.size()) 56 | # torch.Size([1, 1, 12, 12, 12]) 57 | ``` 58 | 59 | 60 | 61 | 62 | ## Contributing 63 | 64 | Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**. 65 | 66 | If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement". 67 | Don't forget to give the project a star! Thanks again! 68 | 69 | 1. Fork the Project 70 | 2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`) 71 | 3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`) 72 | 4. Push to the Branch (`git push origin feature/AmazingFeature`) 73 | 5. 
Open a Pull Request -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into 
"""3D (volumetric) versions of ``torch.nn.PixelShuffle`` and ``torch.nn.PixelUnshuffle``.

Both modules act on the last four dimensions of the input, read as
``(C, D, H, W)``; any leading dimensions (e.g. a batch dimension) are passed
through unchanged. ``PixelUnshuffle3d(r)`` is the exact inverse of
``PixelShuffle3d(r)``.
"""

import operator

import torch
import torch.nn as nn


def _validate_factor(factor):
    """Return *factor* as a positive ``int``.

    Raises:
        TypeError: if *factor* is ``None`` (i.e. the caller omitted it) or is
            not an integer, e.g. a float.
        ValueError: if *factor* is smaller than 1.
    """
    if factor is None:
        # Mimic the message Python itself raises for a missing positional argument.
        raise TypeError('__init__() missing 1 required positional argument: \'upscale_factor\'')
    # operator.index accepts int-like values (int, NumPy integers, ...) and
    # rejects floats, which would otherwise fail later inside reshape().
    factor = operator.index(factor)
    if factor < 1:
        raise ValueError(f'upscale_factor must be a positive integer, but got {factor}')
    return factor


class PixelShuffle3d(nn.Module):
    """Rearrange ``(*, C * r**3, D, H, W)`` into ``(*, C, D * r, H * r, W * r)``.

    3D counterpart of :class:`torch.nn.PixelShuffle` with ``r = upscale_factor``.
    The input channel axis is split as ``(C, r_d, r_h, r_w)`` (``r_w`` varies
    fastest) and each sub-factor is interleaved with its matching spatial axis, so
    ``out[..., c, d*r + i, h*r + j, w*r + k] == x[..., c*r**3 + i*r*r + j*r + k, d, h, w]``.

    Args:
        upscale_factor: Positive integer factor applied to depth, height and width.
    """

    def __init__(self, upscale_factor=None):
        super().__init__()
        self.upscale_factor = _validate_factor(upscale_factor)

    def extra_repr(self):
        return f'upscale_factor={self.upscale_factor}'

    def forward(self, x):
        """Move channel blocks of *x* into its three trailing spatial dimensions.

        Raises:
            RuntimeError: if *x* has fewer than 4 dimensions, or its channel
                dimension ``x.size(-4)`` is not divisible by ``upscale_factor**3``.
        """
        # (C, D, H, W) are all indexed below, so 4 dimensions is the minimum.
        if x.ndim < 4:
            raise RuntimeError(
                f'pixel_shuffle expects input to have at least 4 dimensions, but got input with {x.ndim} dimension(s)'
            )
        r = self.upscale_factor
        channels, in_depth, in_height, in_width = x.shape[-4:]
        if channels % r**3 != 0:
            raise RuntimeError(
                f'pixel_shuffle expects its input\'s \'channel\' dimension to be divisible by the cube of upscale_factor, but input.size(-4)={channels} is not divisible by {r**3}'
            )

        batch = x.shape[:-4]
        out_channels = channels // r**3

        # (*, C, r_d, r_h, r_w, D, H, W)
        x = x.reshape(*batch, out_channels, r, r, r, in_depth, in_height, in_width)
        # Leading (batch) dims and C keep their positions.
        lead = list(range(x.ndim - 6))
        # -> (*, C, D, r_d, H, r_h, W, r_w); reshape() copies when the permuted
        # layout cannot be expressed as a view.
        x = x.permute(*lead, -3, -6, -2, -5, -1, -4)
        return x.reshape(*batch, out_channels, in_depth * r, in_height * r, in_width * r)


class PixelUnshuffle3d(nn.Module):
    """Rearrange ``(*, C, D * r, H * r, W * r)`` into ``(*, C * r**3, D, H, W)``.

    3D counterpart of :class:`torch.nn.PixelUnshuffle` and the inverse of
    :class:`PixelShuffle3d`. The factor is named ``upscale_factor`` here, although
    the error messages (like PyTorch's 2D module) call it ``downscale_factor``.

    Args:
        upscale_factor: Positive integer factor by which depth, height and width
            are reduced.
    """

    def __init__(self, upscale_factor=None):
        super().__init__()
        self.upscale_factor = _validate_factor(upscale_factor)

    def extra_repr(self):
        return f'upscale_factor={self.upscale_factor}'

    def forward(self, x):
        """Fold ``r x r x r`` spatial blocks of *x* into its channel dimension.

        Raises:
            RuntimeError: if *x* has fewer than 4 dimensions, or its depth,
                height or width is not divisible by ``upscale_factor``.
        """
        # (C, D, H, W) are all indexed below, so 4 dimensions is the minimum.
        if x.ndim < 4:
            raise RuntimeError(
                f'pixel_unshuffle expects input to have at least 4 dimensions, but got input with {x.ndim} dimension(s)'
            )
        r = self.upscale_factor
        for dim, name in ((-3, 'depth'), (-2, 'height'), (-1, 'width')):
            if x.shape[dim] % r != 0:
                raise RuntimeError(
                    f'pixel_unshuffle expects {name} to be divisible by downscale_factor, but input.size({dim})={x.shape[dim]} is not divisible by {r}'
                )

        batch = x.shape[:-4]
        channels, in_depth, in_height, in_width = x.shape[-4:]
        out_depth, out_height, out_width = in_depth // r, in_height // r, in_width // r

        # (*, C, D', r_d, H', r_h, W', r_w)
        x = x.reshape(*batch, channels, out_depth, r, out_height, r, out_width, r)
        # Leading (batch) dims and C keep their positions.
        lead = list(range(x.ndim - 6))
        # -> (*, C, r_d, r_h, r_w, D', H', W')
        x = x.permute(*lead, -5, -3, -1, -6, -4, -2)
        return x.reshape(*batch, channels * r**3, out_depth, out_height, out_width)