├── .gitignore ├── .idea ├── misc.xml ├── modules.xml └── vcs.xml ├── .travis.yml ├── LICENSE ├── README.md ├── setup.py ├── test_req.txt ├── torch-dct.iml └── torch_dct ├── __init__.py ├── _dct.py └── test ├── __init__.py ├── test_dct.py └── test_lineardct.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | ### JetBrains template 108 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 109 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 110 | 111 | # User-specific stuff 112 | .idea/**/workspace.xml 113 | .idea/**/tasks.xml 114 | .idea/**/usage.statistics.xml 115 | .idea/**/dictionaries 116 | .idea/**/shelf 117 | 118 | # Sensitive or high-churn files 119 | .idea/**/dataSources/ 120 | .idea/**/dataSources.ids 121 | .idea/**/dataSources.local.xml 122 | .idea/**/sqlDataSources.xml 123 | .idea/**/dynamic.xml 124 | .idea/**/uiDesigner.xml 125 | .idea/**/dbnavigator.xml 126 | 127 | # Gradle 128 | .idea/**/gradle.xml 129 | .idea/**/libraries 130 | 131 | # Gradle and Maven with auto-import 132 | # When using Gradle or Maven with auto-import, you should exclude module files, 133 | # since they will be recreated, and may cause churn. Uncomment if using 134 | # auto-import. 135 | # .idea/modules.xml 136 | # .idea/*.iml 137 | # .idea/modules 138 | 139 | # CMake 140 | cmake-build-*/ 141 | 142 | # Mongo Explorer plugin 143 | .idea/**/mongoSettings.xml 144 | 145 | # File-based project format 146 | *.iws 147 | 148 | # IntelliJ 149 | out/ 150 | 151 | # mpeltonen/sbt-idea plugin 152 | .idea_modules/ 153 | 154 | # JIRA plugin 155 | atlassian-ide-plugin.xml 156 | 157 | # Cursive Clojure plugin 158 | .idea/replstate.xml 159 | 160 | # Crashlytics plugin (for Android Studio and IntelliJ) 161 | com_crashlytics_export_strings.xml 162 | crashlytics.properties 163 | crashlytics-build.properties 164 | fabric.properties 165 | 166 | # Editor-based Rest Client 167 | .idea/httpRequests 168 | 169 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | IDE 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.5" 4 | - "3.6" 5 | #- "2.7" 6 | install: 7 | - pip install -r test_req.txt | cat 8 | script: 9 | - py.test --verbose --cov=./torch_dct 10 | - codecov -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | (c) Copyright 2018 Ziyang Hu. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DCT (Discrete Cosine Transform) for pytorch 2 | 3 | [![Build Status](https://travis-ci.com/zh217/torch-dct.svg?branch=master)](https://travis-ci.com/zh217/torch-dct) 4 | [![codecov](https://codecov.io/gh/zh217/torch-dct/branch/master/graph/badge.svg)](https://codecov.io/gh/zh217/torch-dct) 5 | [![PyPI version](https://img.shields.io/pypi/v/torch-dct.svg)](https://pypi.python.org/pypi/torch-dct/) 6 | [![PyPI version](https://img.shields.io/pypi/pyversions/torch-dct.svg)](https://pypi.python.org/pypi/torch-dct/) 7 | [![PyPI status](https://img.shields.io/pypi/status/torch-dct.svg)](https://pypi.python.org/pypi/torch-dct/) 8 | [![GitHub license](https://img.shields.io/github/license/zh217/torch-dct.svg)](https://github.com/zh217/torch-dct/blob/master/LICENSE) 9 | 10 | 11 | This library implements DCT in terms of the built-in FFT operations in pytorch so that 12 | back propagation works through it, on both CPU and GPU. For more information on 13 | DCT and the algorithms used here, see 14 | [Wikipedia](https://en.wikipedia.org/wiki/Discrete_cosine_transform) and the paper by 15 | [J. Makhoul](https://ieeexplore.ieee.org/document/1163351/). This 16 | [StackExchange article](https://dsp.stackexchange.com/questions/2807/fast-cosine-transform-via-fft) 17 | might also be helpful. 18 | 19 | The following are currently implemented: 20 | 21 | * 1-D DCT-I and its inverse (which is a scaled DCT-I) 22 | * 1-D DCT-II and its inverse (which is a scaled DCT-III) 23 | * 2-D DCT-II and its inverse (which is a scaled DCT-III) 24 | * 3-D DCT-II and its inverse (which is a scaled DCT-III) 25 | 26 | ## Install 27 | 28 | ``` 29 | pip install torch-dct 30 | ``` 31 | 32 | Requires `torch>=0.4.1` (lower versions are probably OK but I haven't tested them). 33 | 34 | You can run test by getting the source and run `pytest`. To run the test you also 35 | need `scipy` installed. 36 | 37 | ## Usage 38 | 39 | ```python 40 | import torch 41 | import torch_dct as dct 42 | 43 | x = torch.randn(200) 44 | X = dct.dct(x) # DCT-II done through the last dimension 45 | y = dct.idct(X) # scaled DCT-III done through the last dimension 46 | assert (torch.abs(x - y)).sum() < 1e-10 # x == y within numerical tolerance 47 | ``` 48 | 49 | `dct.dct1` and `dct.idct1` are for DCT-I and its inverse. The usage is the same. 50 | 51 | Just replace `dct` and `idct` by `dct_2d`, `dct_3d`, `idct_2d`, `idct_3d`, etc 52 | to get the multidimensional versions. 53 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='torch-dct', 5 | version='0.1.6', 6 | packages=['torch_dct'], 7 | platforms='any', 8 | classifiers=[ 9 | 'Development Status :: 4 - Beta', 10 | 'License :: OSI Approved :: Apache Software License', 11 | 'Programming Language :: Python :: 2', 12 | 'Programming Language :: Python :: 3' 13 | ], 14 | install_requires=['torch>=0.4.1'], 15 | url='https://github.com/zh217/torch-dct', 16 | license='MIT', 17 | author='Ziyang Hu', 18 | author_email='hu.ziyang@cantab.net', 19 | description='Discrete Cosine Transform (DCT) for pytorch', 20 | long_description=open('README.md').read(), 21 | long_description_content_type='text/markdown' 22 | ) 23 | -------------------------------------------------------------------------------- /test_req.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-runner 3 | pytest-cov 4 | codecov 5 | torch>=0.4.1 6 | scipy -------------------------------------------------------------------------------- /torch-dct.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 12 | 13 | 16 | -------------------------------------------------------------------------------- /torch_dct/__init__.py: -------------------------------------------------------------------------------- 1 | from ._dct import dct, idct, dct1, idct1, dct_2d, idct_2d, dct_3d, idct_3d, LinearDCT, apply_linear_2d, apply_linear_3d 2 | -------------------------------------------------------------------------------- /torch_dct/_dct.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | try: 6 | # PyTorch 1.7.0 and newer versions 7 | import torch.fft 8 | 9 | def dct1_rfft_impl(x): 10 | return torch.view_as_real(torch.fft.rfft(x, dim=1)) 11 | 12 | def dct_fft_impl(v): 13 | return torch.view_as_real(torch.fft.fft(v, dim=1)) 14 | 15 | def idct_irfft_impl(V): 16 | return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) 17 | except ImportError: 18 | # PyTorch 1.6.0 and older versions 19 | def dct1_rfft_impl(x): 20 | return torch.rfft(x, 1) 21 | 22 | def dct_fft_impl(v): 23 | return torch.rfft(v, 1, onesided=False) 24 | 25 | def idct_irfft_impl(V): 26 | return torch.irfft(V, 1, onesided=False) 27 | 28 | 29 | 30 | def dct1(x): 31 | """ 32 | Discrete Cosine Transform, Type I 33 | 34 | :param x: the input signal 35 | :return: the DCT-I of the signal over the last dimension 36 | """ 37 | x_shape = x.shape 38 | x = x.view(-1, x_shape[-1]) 39 | x = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1) 40 | 41 | return dct1_rfft_impl(x)[:, :, 0].view(*x_shape) 42 | 43 | 44 | def idct1(X): 45 | """ 46 | The inverse of DCT-I, which is just a scaled DCT-I 47 | 48 | Our definition if idct1 is such that idct1(dct1(x)) == x 49 | 50 | :param X: the input signal 51 | :return: the inverse DCT-I of the signal over the last dimension 52 | """ 53 | n = X.shape[-1] 54 | return dct1(X) / (2 * (n - 1)) 55 | 56 | 57 | def dct(x, norm=None): 58 | """ 59 | Discrete Cosine Transform, Type II (a.k.a. the DCT) 60 | 61 | For the meaning of the parameter `norm`, see: 62 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 63 | 64 | :param x: the input signal 65 | :param norm: the normalization, None or 'ortho' 66 | :return: the DCT-II of the signal over the last dimension 67 | """ 68 | x_shape = x.shape 69 | N = x_shape[-1] 70 | x = x.contiguous().view(-1, N) 71 | 72 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) 73 | 74 | Vc = dct_fft_impl(v) 75 | 76 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N) 77 | W_r = torch.cos(k) 78 | W_i = torch.sin(k) 79 | 80 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i 81 | 82 | if norm == 'ortho': 83 | V[:, 0] /= np.sqrt(N) * 2 84 | V[:, 1:] /= np.sqrt(N / 2) * 2 85 | 86 | V = 2 * V.view(*x_shape) 87 | 88 | return V 89 | 90 | 91 | def idct(X, norm=None): 92 | """ 93 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III 94 | 95 | Our definition of idct is that idct(dct(x)) == x 96 | 97 | For the meaning of the parameter `norm`, see: 98 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 99 | 100 | :param X: the input signal 101 | :param norm: the normalization, None or 'ortho' 102 | :return: the inverse DCT-II of the signal over the last dimension 103 | """ 104 | 105 | x_shape = X.shape 106 | N = x_shape[-1] 107 | 108 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2 109 | 110 | if norm == 'ortho': 111 | X_v[:, 0] *= np.sqrt(N) * 2 112 | X_v[:, 1:] *= np.sqrt(N / 2) * 2 113 | 114 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N) 115 | W_r = torch.cos(k) 116 | W_i = torch.sin(k) 117 | 118 | V_t_r = X_v 119 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) 120 | 121 | V_r = V_t_r * W_r - V_t_i * W_i 122 | V_i = V_t_r * W_i + V_t_i * W_r 123 | 124 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) 125 | 126 | v = idct_irfft_impl(V) 127 | x = v.new_zeros(v.shape) 128 | x[:, ::2] += v[:, :N - (N // 2)] 129 | x[:, 1::2] += v.flip([1])[:, :N // 2] 130 | 131 | return x.view(*x_shape) 132 | 133 | 134 | def dct_2d(x, norm=None): 135 | """ 136 | 2-dimentional Discrete Cosine Transform, Type II (a.k.a. the DCT) 137 | 138 | For the meaning of the parameter `norm`, see: 139 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 140 | 141 | :param x: the input signal 142 | :param norm: the normalization, None or 'ortho' 143 | :return: the DCT-II of the signal over the last 2 dimensions 144 | """ 145 | X1 = dct(x, norm=norm) 146 | X2 = dct(X1.transpose(-1, -2), norm=norm) 147 | return X2.transpose(-1, -2) 148 | 149 | 150 | def idct_2d(X, norm=None): 151 | """ 152 | The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III 153 | 154 | Our definition of idct is that idct_2d(dct_2d(x)) == x 155 | 156 | For the meaning of the parameter `norm`, see: 157 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 158 | 159 | :param X: the input signal 160 | :param norm: the normalization, None or 'ortho' 161 | :return: the DCT-II of the signal over the last 2 dimensions 162 | """ 163 | x1 = idct(X, norm=norm) 164 | x2 = idct(x1.transpose(-1, -2), norm=norm) 165 | return x2.transpose(-1, -2) 166 | 167 | 168 | def dct_3d(x, norm=None): 169 | """ 170 | 3-dimentional Discrete Cosine Transform, Type II (a.k.a. the DCT) 171 | 172 | For the meaning of the parameter `norm`, see: 173 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 174 | 175 | :param x: the input signal 176 | :param norm: the normalization, None or 'ortho' 177 | :return: the DCT-II of the signal over the last 3 dimensions 178 | """ 179 | X1 = dct(x, norm=norm) 180 | X2 = dct(X1.transpose(-1, -2), norm=norm) 181 | X3 = dct(X2.transpose(-1, -3), norm=norm) 182 | return X3.transpose(-1, -3).transpose(-1, -2) 183 | 184 | 185 | def idct_3d(X, norm=None): 186 | """ 187 | The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III 188 | 189 | Our definition of idct is that idct_3d(dct_3d(x)) == x 190 | 191 | For the meaning of the parameter `norm`, see: 192 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 193 | 194 | :param X: the input signal 195 | :param norm: the normalization, None or 'ortho' 196 | :return: the DCT-II of the signal over the last 3 dimensions 197 | """ 198 | x1 = idct(X, norm=norm) 199 | x2 = idct(x1.transpose(-1, -2), norm=norm) 200 | x3 = idct(x2.transpose(-1, -3), norm=norm) 201 | return x3.transpose(-1, -3).transpose(-1, -2) 202 | 203 | 204 | class LinearDCT(nn.Linear): 205 | """Implement any DCT as a linear layer; in practice this executes around 206 | 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will 207 | increase memory usage. 208 | :param in_features: size of expected input 209 | :param type: which dct function in this file to use""" 210 | def __init__(self, in_features, type, norm=None, bias=False): 211 | self.type = type 212 | self.N = in_features 213 | self.norm = norm 214 | super(LinearDCT, self).__init__(in_features, in_features, bias=bias) 215 | 216 | def reset_parameters(self): 217 | # initialise using dct function 218 | I = torch.eye(self.N) 219 | if self.type == 'dct1': 220 | self.weight.data = dct1(I).data.t() 221 | elif self.type == 'idct1': 222 | self.weight.data = idct1(I).data.t() 223 | elif self.type == 'dct': 224 | self.weight.data = dct(I, norm=self.norm).data.t() 225 | elif self.type == 'idct': 226 | self.weight.data = idct(I, norm=self.norm).data.t() 227 | self.weight.requires_grad = False # don't learn this! 228 | 229 | 230 | def apply_linear_2d(x, linear_layer): 231 | """Can be used with a LinearDCT layer to do a 2D DCT. 232 | :param x: the input signal 233 | :param linear_layer: any PyTorch Linear layer 234 | :return: result of linear layer applied to last 2 dimensions 235 | """ 236 | X1 = linear_layer(x) 237 | X2 = linear_layer(X1.transpose(-1, -2)) 238 | return X2.transpose(-1, -2) 239 | 240 | def apply_linear_3d(x, linear_layer): 241 | """Can be used with a LinearDCT layer to do a 3D DCT. 242 | :param x: the input signal 243 | :param linear_layer: any PyTorch Linear layer 244 | :return: result of linear layer applied to last 3 dimensions 245 | """ 246 | X1 = linear_layer(x) 247 | X2 = linear_layer(X1.transpose(-1, -2)) 248 | X3 = linear_layer(X2.transpose(-1, -3)) 249 | return X3.transpose(-1, -3).transpose(-1, -2) 250 | 251 | if __name__ == '__main__': 252 | x = torch.Tensor(1000,4096) 253 | x.normal_(0,1) 254 | linear_dct = LinearDCT(4096, 'dct') 255 | error = torch.abs(dct(x) - linear_dct(x)) 256 | assert error.max() < 1e-3, (error, error.max()) 257 | linear_idct = LinearDCT(4096, 'idct') 258 | error = torch.abs(idct(x) - linear_idct(x)) 259 | assert error.max() < 1e-3, (error, error.max()) 260 | 261 | -------------------------------------------------------------------------------- /torch_dct/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zh217/torch-dct/0804f5ed2ddcaecc24c14b096bd62695e0478cec/torch_dct/test/__init__.py -------------------------------------------------------------------------------- /torch_dct/test/test_dct.py: -------------------------------------------------------------------------------- 1 | import torch_dct as dct 2 | import scipy.fftpack as fftpack 3 | import numpy as np 4 | import torch 5 | 6 | np.random.seed(1) 7 | 8 | EPS = 1e-10 9 | 10 | 11 | def test_dct1(): 12 | for N in [2, 5, 32, 111]: 13 | x = np.random.normal(size=(1, N,)) 14 | ref = fftpack.dct(x, type=1) 15 | act = dct.dct1(torch.tensor(x)).numpy() 16 | assert np.abs(ref - act).max() < EPS, ref 17 | 18 | for d in [2, 3, 4]: 19 | x = np.random.normal(size=(2,) * d) 20 | ref = fftpack.dct(x, type=1) 21 | act = dct.dct1(torch.tensor(x)).numpy() 22 | assert np.abs(ref - act).max() < EPS, ref 23 | 24 | 25 | def test_idct1(): 26 | for N in [2, 5, 32, 111]: 27 | x = np.random.normal(size=(1, N)) 28 | X = dct.dct1(torch.tensor(x)) 29 | y = dct.idct1(X).numpy() 30 | assert np.abs(x - y).max() < EPS, x 31 | 32 | 33 | def test_dct(): 34 | for norm in [None, 'ortho']: 35 | for N in [2, 3, 5, 32, 111]: 36 | x = np.random.normal(size=(1, N,)) 37 | ref = fftpack.dct(x, type=2, norm=norm) 38 | act = dct.dct(torch.tensor(x), norm=norm).numpy() 39 | assert np.abs(ref - act).max() < EPS, (norm, N) 40 | 41 | for d in [2, 3, 4, 11]: 42 | x = np.random.normal(size=(2,) * d) 43 | ref = fftpack.dct(x, type=2, norm=norm) 44 | act = dct.dct(torch.tensor(x), norm=norm).numpy() 45 | assert np.abs(ref - act).max() < EPS, (norm, d) 46 | 47 | 48 | def test_idct(): 49 | for norm in [None, 'ortho']: 50 | for N in [5, 2, 32, 111]: 51 | x = np.random.normal(size=(1, N)) 52 | X = dct.dct(torch.tensor(x), norm=norm) 53 | y = dct.idct(X, norm=norm).numpy() 54 | assert np.abs(x - y).max() < EPS, x 55 | 56 | 57 | def test_cuda(): 58 | if torch.cuda.is_available(): 59 | device = torch.device('cuda:0') 60 | 61 | for N in [2, 5, 32, 111]: 62 | x = np.random.normal(size=(1, N,)) 63 | ref = fftpack.dct(x, type=1) 64 | act = dct.dct1(torch.tensor(x, device=device)).cpu().numpy() 65 | assert np.abs(ref - act).max() < EPS, ref 66 | 67 | for d in [2, 3, 4]: 68 | x = np.random.normal(size=(2,) * d) 69 | ref = fftpack.dct(x, type=1) 70 | act = dct.dct1(torch.tensor(x, device=device)).cpu().numpy() 71 | assert np.abs(ref - act).max() < EPS, ref 72 | 73 | for norm in [None, 'ortho']: 74 | for N in [2, 3, 5, 32, 111]: 75 | x = np.random.normal(size=(1, N,)) 76 | ref = fftpack.dct(x, type=2, norm=norm) 77 | act = dct.dct(torch.tensor(x, device=device), norm=norm).cpu().numpy() 78 | assert np.abs(ref - act).max() < EPS, (norm, N) 79 | 80 | for d in [2, 3, 4, 11]: 81 | x = np.random.normal(size=(2,) * d) 82 | ref = fftpack.dct(x, type=2, norm=norm) 83 | act = dct.dct(torch.tensor(x, device=device), norm=norm).cpu().numpy() 84 | assert np.abs(ref - act).max() < EPS, (norm, d) 85 | 86 | for N in [5, 2, 32, 111]: 87 | x = np.random.normal(size=(1, N)) 88 | X = dct.dct(torch.tensor(x, device=device), norm=norm) 89 | y = dct.idct(X, norm=norm).cpu().numpy() 90 | assert np.abs(x - y).max() < EPS, x 91 | 92 | def test_dct_2d(): 93 | for N1 in [2, 5, 32]: 94 | for N2 in [2, 5, 32]: 95 | x = np.random.normal(size=(1, N1, N2)) 96 | ref = fftpack.dct(x, axis=2, type=2) 97 | ref = fftpack.dct(ref, axis=1, type=2) 98 | act = dct.dct_2d(torch.tensor(x)).numpy() 99 | assert np.abs(ref - act).max() < EPS, (ref, act) 100 | 101 | 102 | def test_idct_2d(): 103 | for N1 in [2, 5, 32]: 104 | for N2 in [2, 5, 32]: 105 | x = np.random.normal(size=(1, N1, N2)) 106 | X = dct.dct_2d(torch.tensor(x)) 107 | y = dct.idct_2d(X).numpy() 108 | assert np.abs(x - y).max() < EPS, x 109 | 110 | 111 | def test_dct_3d(): 112 | for N1 in [2, 5, 32]: 113 | for N2 in [2, 5, 32]: 114 | for N3 in [2, 5, 32]: 115 | x = np.random.normal(size=(1, N1, N2, N3)) 116 | ref = fftpack.dct(x, axis=3, type=2) 117 | ref = fftpack.dct(ref, axis=2, type=2) 118 | ref = fftpack.dct(ref, axis=1, type=2) 119 | act = dct.dct_3d(torch.tensor(x)).numpy() 120 | assert np.abs(ref - act).max() < EPS, (ref, act) 121 | 122 | 123 | def test_idct_3d(): 124 | for N1 in [2, 5, 32]: 125 | for N2 in [2, 5, 32]: 126 | for N3 in [2, 5, 32]: 127 | x = np.random.normal(size=(1, N1, N2, N3)) 128 | X = dct.dct_3d(torch.tensor(x)) 129 | y = dct.idct_3d(X).numpy() 130 | assert np.abs(x - y).max() < EPS, x 131 | -------------------------------------------------------------------------------- /torch_dct/test/test_lineardct.py: -------------------------------------------------------------------------------- 1 | import torch_dct 2 | import scipy.fftpack as fftpack 3 | import numpy as np 4 | import torch 5 | 6 | np.random.seed(1) 7 | 8 | EPS = 1e-3 9 | # THIS IS NOT HOW THESE LAYERS SHOULD BE USED IN PRACTICE 10 | # only written this way for testing convenience 11 | dct1 = lambda x: torch_dct.LinearDCT(x.size(1), type='dct1')(x).data 12 | idct1 = lambda x: torch_dct.LinearDCT(x.size(1), type='idct1')(x).data 13 | def dct(x, norm=None): 14 | return torch_dct.LinearDCT(x.size(1), type='dct', norm=norm)(x).data 15 | def idct(x, norm=None): 16 | return torch_dct.LinearDCT(x.size(1), type='idct', norm=norm)(x).data 17 | 18 | dct_2d = lambda x: torch_dct.apply_linear_2d(x, torch_dct.LinearDCT(x.size(1), type='dct')).data 19 | dct_3d = lambda x: torch_dct.apply_linear_3d(x, torch_dct.LinearDCT(x.size(1), type='dct')).data 20 | idct_2d = lambda x: torch_dct.apply_linear_2d(x, torch_dct.LinearDCT(x.size(1), type='idct')).data 21 | idct_3d = lambda x: torch_dct.apply_linear_3d(x, torch_dct.LinearDCT(x.size(1), type='idct')).data 22 | 23 | def test_dct1(): 24 | for N in [2, 5, 32, 111]: 25 | x = np.random.normal(size=(1, N,)) 26 | ref = fftpack.dct(x, type=1) 27 | act = dct1(torch.tensor(x).float()).numpy() 28 | assert np.abs(ref - act).max() < EPS, ref 29 | 30 | for d in [2, 3, 4]: 31 | x = np.random.normal(size=(2,) * d) 32 | ref = fftpack.dct(x, type=1) 33 | act = dct1(torch.tensor(x).float()).numpy() 34 | assert np.abs(ref - act).max() < EPS, ref 35 | 36 | 37 | def test_idct1(): 38 | for N in [2, 5, 32, 111]: 39 | x = np.random.normal(size=(1, N)) 40 | X = dct1(torch.tensor(x).float()) 41 | y = idct1(X).numpy() 42 | assert np.abs(x - y).max() < EPS, x 43 | 44 | 45 | def test_dct(): 46 | for norm in [None, 'ortho']: 47 | for N in [2, 3, 5, 32, 111]: 48 | x = np.random.normal(size=(1, N,)) 49 | ref = fftpack.dct(x, type=2, norm=norm) 50 | act = dct(torch.tensor(x).float(), norm=norm).numpy() 51 | assert np.abs(ref - act).max() < EPS, (norm, N) 52 | 53 | for d in [2, 3, 4, 11]: 54 | x = np.random.normal(size=(2,) * d) 55 | ref = fftpack.dct(x, type=2, norm=norm) 56 | act = dct(torch.tensor(x).float(), norm=norm).numpy() 57 | assert np.abs(ref - act).max() < EPS, (norm, d) 58 | 59 | 60 | def test_idct(): 61 | for norm in [None, 'ortho']: 62 | for N in [5, 2, 32, 111]: 63 | x = np.random.normal(size=(1, N)) 64 | X = dct(torch.tensor(x).float(), norm=norm) 65 | y = idct(X, norm=norm).numpy() 66 | assert np.abs(x - y).max() < EPS, x 67 | 68 | def test_dct_2d(): 69 | for N1 in [2, 5, 32]: 70 | x = np.random.normal(size=(1, N1, N1)) 71 | ref = fftpack.dct(x, axis=2, type=2) 72 | ref = fftpack.dct(ref, axis=1, type=2) 73 | act = dct_2d(torch.tensor(x).float()).numpy() 74 | assert np.abs(ref - act).max() < EPS, (ref, act) 75 | 76 | 77 | def test_idct_2d(): 78 | for N1 in [2, 5, 32]: 79 | x = np.random.normal(size=(1, N1, N1)) 80 | X = dct_2d(torch.tensor(x).float()) 81 | y = idct_2d(X).numpy() 82 | assert np.abs(x - y).max() < EPS, x 83 | 84 | 85 | def test_dct_3d(): 86 | for N1 in [2, 5, 32]: 87 | x = np.random.normal(size=(1, N1, N1, N1)) 88 | ref = fftpack.dct(x, axis=3, type=2) 89 | ref = fftpack.dct(ref, axis=2, type=2) 90 | ref = fftpack.dct(ref, axis=1, type=2) 91 | act = dct_3d(torch.tensor(x).float()).numpy() 92 | assert np.abs(ref - act).max() < EPS, (ref, act) 93 | 94 | 95 | def test_idct_3d(): 96 | for N1 in [2, 5, 32]: 97 | x = np.random.normal(size=(1, N1, N1, N1)) 98 | X = dct_3d(torch.tensor(x).float()) 99 | y = idct_3d(X).numpy() 100 | assert np.abs(x - y).max() < EPS, x 101 | --------------------------------------------------------------------------------