├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── pixelshuffle_invert.py └── test_pixelshuffle_speed.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 
94 | #Pipfile.lock 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | .dmypy.json 124 | dmypy.json 125 | 126 | # Pyre type checker 127 | .pyre/ 128 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 onesixth 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pixelshuffle_invert_pytorch 2 | Fast pixelshuffle and pixelshuffle-invert implementations in PyTorch. 3 | 4 | pixelshuffle_invert 是 pixelshuffle 的逆向操作。 5 | pixelshuffle_invert is the inverse operation of pixelshuffle. 6 | 7 | 与官方实现(v1.3.1)相比，我的实现支持在高和宽方向使用不同的缩放倍数。 8 | Unlike the official implementation (v1.3.1), mine supports different scale factors for height and width. 9 | 10 | # 速度测试 / Speed Test 11 | 12 | PixelShuffle 13 | 14 | 测试输出。看起来比官方实现还快一点点:) 15 | Test output. Mine looks a little faster than the official implementation :) 16 | ``` 17 | Warm up 18 | cuda time 1165.0928955078125 19 | perf_counter time 4.8265442 20 | cuda time 1009.6998901367188 21 | perf_counter time 1.0099245999999997 22 | Warm up finish 23 | 24 | Testing my speed 25 | cuda time 60574.80859375 26 | perf_counter time 60.57620209999999 27 | 28 | Testing pytorch 1.3.1 official speed 29 | cuda time 64837.421875 30 | perf_counter time 64.8391723 31 | 32 | 33 | Process finished with exit code 0 34 | 35 | ``` 36 | 37 | 你可以在你的计算机上测试。 38 | You can run the benchmark on your own computer:
39 | ``` 40 | python test_pixelshuffle_speed.py 41 | ``` 42 | 43 | # 怎么用 / How to use 44 | 45 | ## PixelShuffle 46 | official 47 | ``` 48 | import torch 49 | import torch.nn.functional as F 50 | 51 | x = torch.rand(5, 256, 128, 128) # BCHW 52 | y = F.pixel_shuffle(x, 2) 53 | print(y.shape) 54 | ``` 55 | my code 56 | ``` 57 | import torch 58 | from pixelshuffle_invert import pixelshuffle 59 | 60 | x = torch.rand(5, 256, 128, 128) # BCHW 61 | y = pixelshuffle(x, (2, 2)) 62 | print(y.shape) 63 | ``` 64 | 65 | ## PixelShuffle_invert 66 | no official implementation 67 | 68 | my code 69 | ``` 70 | import torch 71 | from pixelshuffle_invert import pixelshuffle_invert 72 | 73 | x = torch.rand(5, 256, 128, 128) # BCHW 74 | y = pixelshuffle_invert(x, (2, 2)) 75 | print(y.shape) 76 | ``` 77 | 78 | # References 79 | https://arxiv.org/abs/1609.05158 80 | Wait to add... 81 | -------------------------------------------------------------------------------- /pixelshuffle_invert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | 4 | 5 | @torch.jit.script 6 | def pixelshuffle(x: torch.Tensor, factor_hw: Tuple[int, int]): 7 | pH = factor_hw[0] 8 | pW = factor_hw[1] 9 | y = x 10 | B, iC, iH, iW = y.shape 11 | oC, oH, oW = iC//(pH*pW), iH*pH, iW*pW 12 | y = y.reshape(B, oC, pH, pW, iH, iW) 13 | y = y.permute(0, 1, 4, 2, 5, 3) # B, oC, iH, pH, iW, pW 14 | y = y.reshape(B, oC, oH, oW) 15 | return y 16 | 17 | 18 | @torch.jit.script 19 | def pixelshuffle_invert(x: torch.Tensor, factor_hw: Tuple[int, int]): 20 | pH = factor_hw[0] 21 | pW = factor_hw[1] 22 | y = x 23 | B, iC, iH, iW = y.shape 24 | oC, oH, oW = iC*(pH*pW), iH//pH, iW//pW 25 | y = y.reshape(B, iC, oH, pH, oW, pW) 26 | y = y.permute(0, 1, 3, 5, 2, 4) # B, iC, pH, pW, oH, oW 27 | y = y.reshape(B, oC, oH, oW) 28 | return y 29 | 30 | 31 | if __name__ == '__main__': 32 | import torch.nn.functional as F 33 | 34 | print('Check function correct') 35 | 
print() 36 | 37 | for s in [1, 2, 4, 8, 16]: 38 | print('Checking scale {}'.format(s)) 39 | x = torch.rand(5, 256, 128, 128) # BCHW 40 | 41 | y1 = F.pixel_shuffle(x, s) 42 | y2 = pixelshuffle(x, (s, s)) 43 | 44 | assert torch.allclose(y1, y2) 45 | print('pixelshuffle works correctly.') 46 | 47 | rev_x = pixelshuffle_invert(y1, (s, s)) 48 | 49 | assert torch.allclose(x, rev_x) 50 | print('pixelshuffle_invert works correctly.') 51 | print() 52 | -------------------------------------------------------------------------------- /test_pixelshuffle_speed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from pixelshuffle_invert import pixelshuffle 4 | import time 5 | 6 | 7 | def test_speed(func, args, epoch=50): 8 | a = torch.rand(10, 64, 512, 512) 9 | a.requires_grad = True 10 | 11 | start_record = torch.cuda.Event(enable_timing=True) 12 | end_record = torch.cuda.Event(enable_timing=True) 13 | 14 | start_time = time.perf_counter() 15 | start_record.record() 16 | for _ in range(epoch): 17 | loss = func(a, **args).mean() 18 | loss.backward() 19 | end_record.record() 20 | end_time = time.perf_counter() 21 | 22 | torch.cuda.synchronize() 23 | 24 | print('cuda time', start_record.elapsed_time(end_record)) 25 | print('perf_counter time', end_time - start_time) 26 | 27 | 28 | if __name__ == '__main__': 29 | print('Warm up') 30 | test_speed(pixelshuffle, {'factor_hw': (2, 2)}, epoch=1) 31 | test_speed(F.pixel_shuffle, {'upscale_factor': 2}, epoch=1) 32 | print('Warm up finish') 33 | print() 34 | 35 | print('Testing my speed') 36 | test_speed(pixelshuffle, {'factor_hw': (2, 2)}) 37 | print() 38 | 39 | print('Testing pytorch {} official speed'.format(torch.__version__)) 40 | test_speed(F.pixel_shuffle, {'upscale_factor': 2}) 41 | print() 42 | --------------------------------------------------------------------------------