├── pics
├── shuffle.png
└── unshuffle.png
├── LICENSE
├── demo.ipynb
├── README.md
├── .gitignore
└── pixel_shuffle3d.py
/pics/shuffle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scalyvladimir/pixel_shuffle3d/HEAD/pics/shuffle.png
--------------------------------------------------------------------------------
/pics/unshuffle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scalyvladimir/pixel_shuffle3d/HEAD/pics/unshuffle.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Vladimir Chernyy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 15,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "torch.Size([1, 27, 4, 4, 4])\n"
13 | ]
14 | }
15 | ],
16 | "source": [
17 | "from pixel_shuffle3d import PixelUnshuffle3d\n",
18 | "import torch\n",
19 | "\n",
20 | "pixel_unshuffle = PixelUnshuffle3d(3)\n",
21 | "input = torch.randn(1, 1, 12, 12, 12)\n",
22 | "output = pixel_unshuffle(input)\n",
23 | "print(output.size())"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 17,
29 | "metadata": {},
30 | "outputs": [
31 | {
32 | "name": "stdout",
33 | "output_type": "stream",
34 | "text": [
35 | "torch.Size([1, 1, 12, 12, 12])\n"
36 | ]
37 | }
38 | ],
39 | "source": [
40 | "from pixel_shuffle3d import PixelShuffle3d\n",
41 | "import torch\n",
42 | "\n",
43 | "pixel_shuffle = PixelShuffle3d(3)\n",
44 | "input = torch.randn(1, 27, 4, 4, 4)\n",
45 | "output = pixel_shuffle(input)\n",
46 | "print(output.size())"
47 | ]
48 | }
49 | ],
50 | "metadata": {
51 | "kernelspec": {
52 | "display_name": "Python 3",
53 | "language": "python",
54 | "name": "python3"
55 | },
56 | "language_info": {
57 | "codemirror_mode": {
58 | "name": "ipython",
59 | "version": 3
60 | },
61 | "file_extension": ".py",
62 | "mimetype": "text/x-python",
63 | "name": "python",
64 | "nbconvert_exporter": "python",
65 | "pygments_lexer": "ipython3",
66 | "version": "3.9.6"
67 | }
68 | },
69 | "nbformat": 4,
70 | "nbformat_minor": 2
71 | }
72 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Getting Started
3 |
4 | This repo contains a 3D version of the original Pixel Shuffle operation from [Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](https://arxiv.org/abs/1609.05158v2), implemented in [PyTorch](https://pytorch.org/docs/stable/generated/torch.nn.PixelUnshuffle.html).
5 |
6 | ### Demonstration
7 |
8 | Visual intuition for how the 3D (un)shuffle operators work:
9 |
10 |
11 |
14 |
15 |
16 |
17 |
20 |
21 |
22 | ### Installation
23 |
24 | 1. Clone the repo
25 | ```sh
26 | git clone git@github.com:scalyvladimir/pixel_shuffle3d.git
27 | ```
28 |
29 | 2. Install the required packages with:
30 | ```sh
31 | pip3 install torch
32 | ```
33 |
34 | ### Usage
35 | 1. ``PixelUnshuffle3d``
36 | ```python
37 | from pixel_shuffle3d import PixelUnshuffle3d
38 | import torch
39 |
40 | pixel_unshuffle = PixelUnshuffle3d(3)
41 | input = torch.randn(1, 1, 12, 12, 12)
42 | output = pixel_unshuffle(input)
43 | print(output.size())
44 | # torch.Size([1, 27, 4, 4, 4])
45 | ```
46 |
47 | 2. ``PixelShuffle3d``
48 | ```python
49 | from pixel_shuffle3d import PixelShuffle3d
50 | import torch
51 |
52 | pixel_shuffle = PixelShuffle3d(3)
53 | input = torch.randn(1, 27, 4, 4, 4)
54 | output = pixel_shuffle(input)
55 | print(output.size())
56 | # torch.Size([1, 1, 12, 12, 12])
57 | ```
58 |
59 |
60 |
61 |
62 | ## Contributing
63 |
64 | Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**.
65 |
66 | If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement".
67 | Don't forget to give the project a star! Thanks again!
68 |
69 | 1. Fork the Project
70 | 2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`)
71 | 3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`)
72 | 4. Push to the Branch (`git push origin feature/AmazingFeature`)
73 | 5. Open a Pull Request
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/pixel_shuffle3d.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
class PixelShuffle3d(nn.Module):
    """Rearrange a ``(*, C * r**3, D, H, W)`` tensor into ``(*, C, D * r, H * r, W * r)``.

    Volumetric analogue of ``torch.nn.PixelShuffle``. Channel index
    ``c * r**3 + i * r**2 + j * r + k`` of the input is moved to output voxel
    ``(c, d * r + i, h * r + j, w * r + k)``, so the depth offset varies
    slowest and the width offset fastest.

    Args:
        upscale_factor: factor ``r`` by which depth, height and width grow.
            It is required; the ``None`` default exists only to raise a
            ``TypeError`` that mimics a missing positional argument.
    """

    def __init__(self, upscale_factor=None):
        super().__init__()

        if upscale_factor is None:
            raise TypeError('__init__() missing 1 required positional argument: \'upscale_factor\'')

        self.upscale_factor = upscale_factor

    def extra_repr(self):
        """Show the factor in ``repr(module)``, like ``nn.PixelShuffle``."""
        return f'upscale_factor={self.upscale_factor}'

    def forward(self, x):
        """Shuffle channel groups of ``x`` into spatial blocks.

        Args:
            x: tensor of shape ``(*, C * r**3, D, H, W)``; any number of
                leading batch dimensions (including none) is allowed.

        Returns:
            Tensor of shape ``(*, C, D * r, H * r, W * r)``.

        Raises:
            RuntimeError: if ``upscale_factor`` is not positive, ``x`` has
                fewer than 4 dimensions, or the channel count is not
                divisible by ``upscale_factor ** 3``.
        """
        r = self.upscale_factor
        if r <= 0:
            raise RuntimeError(f'pixel_shuffle expects a positive upscale_factor, but got {r}')
        # Channel, depth, height and width are all read below, so 4 dims are
        # the minimum (the previous `< 3` guard let 3-D inputs hit IndexError).
        if x.ndim < 4:
            raise RuntimeError(
                f'pixel_shuffle expects input to have at least 4 dimensions, but got input with {x.ndim} dimension(s)'
            )
        if x.shape[-4] % r**3 != 0:
            raise RuntimeError(
                f'pixel_shuffle expects its input\'s \'channel\' dimension to be divisible by the cube of upscale_factor, but input.size(-4)={x.shape[-4]} is not divisible by {r**3}'
            )

        *batch, channels, in_depth, in_height, in_width = x.shape
        n_out = channels // r**3

        # Split channels into (n_out, r_d, r_h, r_w) ahead of the spatial dims.
        split = x.reshape(*batch, n_out, r, r, r, in_depth, in_height, in_width)
        # Keep batch dims and n_out in place; pair each spatial dim with its
        # sub-voxel offset: (D, r_d, H, r_h, W, r_w).
        leading = range(split.ndim - 6)
        interleaved = split.permute(*leading, -3, -6, -2, -5, -1, -4)

        return interleaved.reshape(*batch, n_out, in_depth * r, in_height * r, in_width * r)
45 |
class PixelUnshuffle3d(nn.Module):
    """Rearrange a ``(*, C, D * r, H * r, W * r)`` tensor into ``(*, C * r**3, D, H, W)``.

    Volumetric analogue of ``torch.nn.PixelUnshuffle`` and the inverse of
    ``PixelShuffle3d``. Input voxel ``(c, d * r + i, h * r + j, w * r + k)``
    is moved to output channel ``c * r**3 + i * r**2 + j * r + k``.

    Args:
        upscale_factor: factor ``r`` by which depth, height and width shrink
            (the "downscale factor" in the error messages). It is required;
            the ``None`` default exists only to raise a ``TypeError`` that
            mimics a missing positional argument.
    """

    def __init__(self, upscale_factor=None):
        super().__init__()

        if upscale_factor is None:
            raise TypeError('__init__() missing 1 required positional argument: \'upscale_factor\'')

        self.upscale_factor = upscale_factor

    def extra_repr(self):
        """Show the factor in ``repr(module)``, like ``nn.PixelUnshuffle``."""
        return f'upscale_factor={self.upscale_factor}'

    def forward(self, x):
        """Fold each ``r x r x r`` spatial block of ``x`` into channels.

        Args:
            x: tensor of shape ``(*, C, D * r, H * r, W * r)``; any number of
                leading batch dimensions (including none) is allowed.

        Returns:
            Tensor of shape ``(*, C * r**3, D, H, W)``.

        Raises:
            RuntimeError: if ``upscale_factor`` is not positive, ``x`` has
                fewer than 4 dimensions, or depth/height/width are not
                divisible by ``upscale_factor``.
        """
        r = self.upscale_factor
        if r <= 0:
            raise RuntimeError(f'pixel_unshuffle expects a positive downscale_factor, but got {r}')
        # Channel, depth, height and width are all read below, so 4 dims are
        # the minimum (the previous `< 3` guard let 3-D inputs hit ValueError).
        if x.ndim < 4:
            raise RuntimeError(
                f'pixel_unshuffle expects input to have at least 4 dimensions, but got input with {x.ndim} dimension(s)'
            )
        for dim, name in ((-3, 'depth'), (-2, 'height'), (-1, 'width')):
            if x.shape[dim] % r != 0:
                raise RuntimeError(
                    f'pixel_unshuffle expects {name} to be divisible by downscale_factor, but input.size({dim})={x.shape[dim]} is not divisible by {r}'
                )

        *batch, channels, in_depth, in_height, in_width = x.shape
        out_depth = in_depth // r
        out_height = in_height // r
        out_width = in_width // r

        # Split each spatial dim into (coarse index, sub-voxel offset).
        split = x.reshape(*batch, channels, out_depth, r, out_height, r, out_width, r)
        # Keep batch dims and channels in place; gather the three offsets
        # right after channels so they merge into it: (r_d, r_h, r_w, D, H, W).
        leading = range(split.ndim - 6)
        gathered = split.permute(*leading, -5, -3, -1, -6, -4, -2)

        return gathered.reshape(*batch, channels * r**3, out_depth, out_height, out_width)
--------------------------------------------------------------------------------