├── .gitignore
├── .idea
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── .travis.yml
├── LICENSE
├── README.md
├── setup.py
├── test_req.txt
├── torch-dct.iml
└── torch_dct
    ├── __init__.py
    ├── _dct.py
    └── test
        ├── __init__.py
        ├── test_dct.py
        └── test_lineardct.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 |
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | # mypy
106 | .mypy_cache/
107 | ### JetBrains template
108 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
109 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
110 |
111 | # User-specific stuff
112 | .idea/**/workspace.xml
113 | .idea/**/tasks.xml
114 | .idea/**/usage.statistics.xml
115 | .idea/**/dictionaries
116 | .idea/**/shelf
117 |
118 | # Sensitive or high-churn files
119 | .idea/**/dataSources/
120 | .idea/**/dataSources.ids
121 | .idea/**/dataSources.local.xml
122 | .idea/**/sqlDataSources.xml
123 | .idea/**/dynamic.xml
124 | .idea/**/uiDesigner.xml
125 | .idea/**/dbnavigator.xml
126 |
127 | # Gradle
128 | .idea/**/gradle.xml
129 | .idea/**/libraries
130 |
131 | # Gradle and Maven with auto-import
132 | # When using Gradle or Maven with auto-import, you should exclude module files,
133 | # since they will be recreated, and may cause churn. Uncomment if using
134 | # auto-import.
135 | # .idea/modules.xml
136 | # .idea/*.iml
137 | # .idea/modules
138 |
139 | # CMake
140 | cmake-build-*/
141 |
142 | # Mongo Explorer plugin
143 | .idea/**/mongoSettings.xml
144 |
145 | # File-based project format
146 | *.iws
147 |
148 | # IntelliJ
149 | out/
150 |
151 | # mpeltonen/sbt-idea plugin
152 | .idea_modules/
153 |
154 | # JIRA plugin
155 | atlassian-ide-plugin.xml
156 |
157 | # Cursive Clojure plugin
158 | .idea/replstate.xml
159 |
160 | # Crashlytics plugin (for Android Studio and IntelliJ)
161 | com_crashlytics_export_strings.xml
162 | crashlytics.properties
163 | crashlytics-build.properties
164 | fabric.properties
165 |
166 | # Editor-based Rest Client
167 | .idea/httpRequests
168 |
169 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | IDE
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - "3.5"
4 | - "3.6"
5 | #- "2.7"
6 | install:
7 | - pip install -r test_req.txt | cat
8 | script:
9 | - py.test --verbose --cov=./torch_dct
10 | - codecov
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | (c) Copyright 2018 Ziyang Hu.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DCT (Discrete Cosine Transform) for pytorch
2 |
3 | [](https://travis-ci.com/zh217/torch-dct)
4 | [](https://codecov.io/gh/zh217/torch-dct)
5 | [](https://pypi.python.org/pypi/torch-dct/)
6 | [](https://pypi.python.org/pypi/torch-dct/)
7 | [](https://pypi.python.org/pypi/torch-dct/)
8 | [](https://github.com/zh217/torch-dct/blob/master/LICENSE)
9 |
10 |
11 | This library implements DCT in terms of the built-in FFT operations in pytorch so that
12 | back propagation works through it, on both CPU and GPU. For more information on
13 | DCT and the algorithms used here, see
14 | [Wikipedia](https://en.wikipedia.org/wiki/Discrete_cosine_transform) and the paper by
15 | [J. Makhoul](https://ieeexplore.ieee.org/document/1163351/). This
16 | [StackExchange article](https://dsp.stackexchange.com/questions/2807/fast-cosine-transform-via-fft)
17 | might also be helpful.
18 |
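For reference, here is a minimal NumPy sketch (not taken from this repo; sizes illustrative) of the FFT-based DCT-II construction from Makhoul's paper that the code in `torch_dct/_dct.py` mirrors: reorder the even-index samples followed by the reversed odd-index samples, take an FFT, rotate each bin by `exp(-i*pi*k/(2N))`, and twice the real part is the unnormalized DCT-II.

```python
import numpy as np
import scipy.fftpack as fftpack

N = 8
x = np.random.randn(N)

# Even-index samples first, then the odd-index samples in reverse order.
v = np.concatenate([x[::2], x[1::2][::-1]])
k = np.arange(N)
# Twice the real part of the rotated FFT equals the unnormalized DCT-II.
dct2 = 2 * np.real(np.fft.fft(v) * np.exp(-1j * np.pi * k / (2 * N)))

assert np.allclose(dct2, fftpack.dct(x, type=2))
```
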
19 | The following are currently implemented:
20 |
21 | * 1-D DCT-I and its inverse (which is a scaled DCT-I)
22 | * 1-D DCT-II and its inverse (which is a scaled DCT-III)
23 | * 2-D DCT-II and its inverse (which is a scaled DCT-III)
24 | * 3-D DCT-II and its inverse (which is a scaled DCT-III)
25 |
26 | ## Install
27 |
28 | ```
29 | pip install torch-dct
30 | ```
31 |
32 | Requires `torch>=0.4.1` (lower versions are probably OK but I haven't tested them).
33 |
34 | You can run the tests by getting the source and running `pytest`. Running the tests
35 | also requires `scipy`.
36 |
37 | ## Usage
38 |
39 | ```python
40 | import torch
41 | import torch_dct as dct
42 |
43 | x = torch.randn(200)
44 | X = dct.dct(x)   # DCT-II along the last dimension
45 | y = dct.idct(X)  # scaled DCT-III (the inverse) along the last dimension
46 | assert (torch.abs(x - y)).sum() < 1e-10 # x == y within numerical tolerance
47 | ```
48 |
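Because the transform is built from PyTorch's FFT operations, gradients flow through it on both CPU and GPU; a minimal sketch (tensor sizes illustrative):

```python
import torch
import torch_dct as dct

x = torch.randn(8, 64, requires_grad=True)
X = dct.dct(x, norm='ortho')  # DCT-II along the last dimension
X.sum().backward()            # backpropagate through the transform
assert x.grad is not None and x.grad.shape == x.shape
```
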
49 | `dct.dct1` and `dct.idct1` are for DCT-I and its inverse. The usage is the same.
50 |
51 | Just replace `dct` and `idct` with `dct_2d`, `dct_3d`, `idct_2d`, `idct_3d`, etc.,
52 | to get the multidimensional versions.
53 |
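For example, a minimal sketch of the 2-D case (the 3-D functions follow the same pattern):

```python
import torch
import torch_dct as dct

img = torch.randn(16, 16)
Y = dct.dct_2d(img, norm='ortho')    # DCT-II over the last two dimensions
img2 = dct.idct_2d(Y, norm='ortho')  # inverse; recovers img up to numerical error
assert (img - img2).abs().max() < 1e-5
```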
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name='torch-dct',
5 | version='0.1.6',
6 | packages=['torch_dct'],
7 | platforms='any',
8 | classifiers=[
9 | 'Development Status :: 4 - Beta',
10 | 'License :: OSI Approved :: MIT License',
11 | 'Programming Language :: Python :: 2',
12 | 'Programming Language :: Python :: 3'
13 | ],
14 | install_requires=['torch>=0.4.1'],
15 | url='https://github.com/zh217/torch-dct',
16 | license='MIT',
17 | author='Ziyang Hu',
18 | author_email='hu.ziyang@cantab.net',
19 | description='Discrete Cosine Transform (DCT) for pytorch',
20 | long_description=open('README.md').read(),
21 | long_description_content_type='text/markdown'
22 | )
23 |
--------------------------------------------------------------------------------
/test_req.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | pytest-runner
3 | pytest-cov
4 | codecov
5 | torch>=0.4.1
6 | scipy
--------------------------------------------------------------------------------
/torch-dct.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/torch_dct/__init__.py:
--------------------------------------------------------------------------------
1 | from ._dct import dct, idct, dct1, idct1, dct_2d, idct_2d, dct_3d, idct_3d, LinearDCT, apply_linear_2d, apply_linear_3d
2 |
--------------------------------------------------------------------------------
/torch_dct/_dct.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | try:
6 | # PyTorch 1.7.0 and newer versions
7 | import torch.fft
8 |
9 | def dct1_rfft_impl(x):
10 | return torch.view_as_real(torch.fft.rfft(x, dim=1))
11 |
12 | def dct_fft_impl(v):
13 | return torch.view_as_real(torch.fft.fft(v, dim=1))
14 |
15 | def idct_irfft_impl(V):
16 | return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
17 | except ImportError:
18 | # PyTorch 1.6.0 and older versions
19 | def dct1_rfft_impl(x):
20 | return torch.rfft(x, 1)
21 |
22 | def dct_fft_impl(v):
23 | return torch.rfft(v, 1, onesided=False)
24 |
25 | def idct_irfft_impl(V):
26 | return torch.irfft(V, 1, onesided=False)
27 |
28 |
29 |
30 | def dct1(x):
31 | """
32 | Discrete Cosine Transform, Type I
33 |
34 | :param x: the input signal
35 | :return: the DCT-I of the signal over the last dimension
36 | """
37 | x_shape = x.shape
38 | x = x.view(-1, x_shape[-1])
39 | x = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)
40 |
41 | return dct1_rfft_impl(x)[:, :, 0].view(*x_shape)
42 |
43 |
44 | def idct1(X):
45 | """
46 | The inverse of DCT-I, which is just a scaled DCT-I
47 |
48 | Our definition of idct1 is such that idct1(dct1(x)) == x
49 |
50 | :param X: the input signal
51 | :return: the inverse DCT-I of the signal over the last dimension
52 | """
53 | n = X.shape[-1]
54 | return dct1(X) / (2 * (n - 1))
55 |
56 |
57 | def dct(x, norm=None):
58 | """
59 | Discrete Cosine Transform, Type II (a.k.a. the DCT)
60 |
61 | For the meaning of the parameter `norm`, see:
62 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
63 |
64 | :param x: the input signal
65 | :param norm: the normalization, None or 'ortho'
66 | :return: the DCT-II of the signal over the last dimension
67 | """
68 | x_shape = x.shape
69 | N = x_shape[-1]
70 | x = x.contiguous().view(-1, N)
71 |
72 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
73 |
74 | Vc = dct_fft_impl(v)
75 |
76 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
77 | W_r = torch.cos(k)
78 | W_i = torch.sin(k)
79 |
80 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
81 |
82 | if norm == 'ortho':
83 | V[:, 0] /= np.sqrt(N) * 2
84 | V[:, 1:] /= np.sqrt(N / 2) * 2
85 |
86 | V = 2 * V.view(*x_shape)
87 |
88 | return V
89 |
90 |
91 | def idct(X, norm=None):
92 | """
93 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
94 |
95 | Our definition of idct is that idct(dct(x)) == x
96 |
97 | For the meaning of the parameter `norm`, see:
98 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
99 |
100 | :param X: the input signal
101 | :param norm: the normalization, None or 'ortho'
102 | :return: the inverse DCT-II of the signal over the last dimension
103 | """
104 |
105 | x_shape = X.shape
106 | N = x_shape[-1]
107 |
108 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2
109 |
110 | if norm == 'ortho':
111 | X_v[:, 0] *= np.sqrt(N) * 2
112 | X_v[:, 1:] *= np.sqrt(N / 2) * 2
113 |
114 | k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
115 | W_r = torch.cos(k)
116 | W_i = torch.sin(k)
117 |
118 | V_t_r = X_v
119 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
120 |
121 | V_r = V_t_r * W_r - V_t_i * W_i
122 | V_i = V_t_r * W_i + V_t_i * W_r
123 |
124 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
125 |
126 | v = idct_irfft_impl(V)
127 | x = v.new_zeros(v.shape)
128 | x[:, ::2] += v[:, :N - (N // 2)]
129 | x[:, 1::2] += v.flip([1])[:, :N // 2]
130 |
131 | return x.view(*x_shape)
132 |
133 |
134 | def dct_2d(x, norm=None):
135 | """
136 | 2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
137 |
138 | For the meaning of the parameter `norm`, see:
139 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
140 |
141 | :param x: the input signal
142 | :param norm: the normalization, None or 'ortho'
143 | :return: the DCT-II of the signal over the last 2 dimensions
144 | """
145 | X1 = dct(x, norm=norm)
146 | X2 = dct(X1.transpose(-1, -2), norm=norm)
147 | return X2.transpose(-1, -2)
148 |
149 |
150 | def idct_2d(X, norm=None):
151 | """
152 | The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
153 |
154 | Our definition of idct is that idct_2d(dct_2d(x)) == x
155 |
156 | For the meaning of the parameter `norm`, see:
157 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
158 |
159 | :param X: the input signal
160 | :param norm: the normalization, None or 'ortho'
161 | :return: the inverse DCT-II of the signal over the last 2 dimensions
162 | """
163 | x1 = idct(X, norm=norm)
164 | x2 = idct(x1.transpose(-1, -2), norm=norm)
165 | return x2.transpose(-1, -2)
166 |
167 |
168 | def dct_3d(x, norm=None):
169 | """
170 | 3-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
171 |
172 | For the meaning of the parameter `norm`, see:
173 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
174 |
175 | :param x: the input signal
176 | :param norm: the normalization, None or 'ortho'
177 | :return: the DCT-II of the signal over the last 3 dimensions
178 | """
179 | X1 = dct(x, norm=norm)
180 | X2 = dct(X1.transpose(-1, -2), norm=norm)
181 | X3 = dct(X2.transpose(-1, -3), norm=norm)
182 | return X3.transpose(-1, -3).transpose(-1, -2)
183 |
184 |
185 | def idct_3d(X, norm=None):
186 | """
187 | The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III
188 |
189 | Our definition of idct is that idct_3d(dct_3d(x)) == x
190 |
191 | For the meaning of the parameter `norm`, see:
192 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
193 |
194 | :param X: the input signal
195 | :param norm: the normalization, None or 'ortho'
196 | :return: the inverse DCT-II of the signal over the last 3 dimensions
197 | """
198 | x1 = idct(X, norm=norm)
199 | x2 = idct(x1.transpose(-1, -2), norm=norm)
200 | x3 = idct(x2.transpose(-1, -3), norm=norm)
201 | return x3.transpose(-1, -3).transpose(-1, -2)
202 |
203 |
204 | class LinearDCT(nn.Linear):
205 | """Implement any DCT as a linear layer; in practice this executes around
206 | 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
207 | increase memory usage.
208 | :param in_features: size of expected input
209 | :param type: which dct function in this file to use"""
210 | def __init__(self, in_features, type, norm=None, bias=False):
211 | self.type = type
212 | self.N = in_features
213 | self.norm = norm
214 | super(LinearDCT, self).__init__(in_features, in_features, bias=bias)
215 |
216 | def reset_parameters(self):
217 | # initialise using dct function
218 | I = torch.eye(self.N)
219 | if self.type == 'dct1':
220 | self.weight.data = dct1(I).data.t()
221 | elif self.type == 'idct1':
222 | self.weight.data = idct1(I).data.t()
223 | elif self.type == 'dct':
224 | self.weight.data = dct(I, norm=self.norm).data.t()
225 | elif self.type == 'idct':
226 | self.weight.data = idct(I, norm=self.norm).data.t()
227 | self.weight.requires_grad = False # don't learn this!
228 |
229 |
230 | def apply_linear_2d(x, linear_layer):
231 | """Can be used with a LinearDCT layer to do a 2D DCT.
232 | :param x: the input signal
233 | :param linear_layer: any PyTorch Linear layer
234 | :return: result of linear layer applied to last 2 dimensions
235 | """
236 | X1 = linear_layer(x)
237 | X2 = linear_layer(X1.transpose(-1, -2))
238 | return X2.transpose(-1, -2)
239 |
240 | def apply_linear_3d(x, linear_layer):
241 | """Can be used with a LinearDCT layer to do a 3D DCT.
242 | :param x: the input signal
243 | :param linear_layer: any PyTorch Linear layer
244 | :return: result of linear layer applied to last 3 dimensions
245 | """
246 | X1 = linear_layer(x)
247 | X2 = linear_layer(X1.transpose(-1, -2))
248 | X3 = linear_layer(X2.transpose(-1, -3))
249 | return X3.transpose(-1, -3).transpose(-1, -2)
250 |
251 | if __name__ == '__main__':
252 | x = torch.Tensor(1000,4096)
253 | x.normal_(0,1)
254 | linear_dct = LinearDCT(4096, 'dct')
255 | error = torch.abs(dct(x) - linear_dct(x))
256 | assert error.max() < 1e-3, (error, error.max())
257 | linear_idct = LinearDCT(4096, 'idct')
258 | error = torch.abs(idct(x) - linear_idct(x))
259 | assert error.max() < 1e-3, (error, error.max())
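    # A hedged sketch (not part of the original checks): apply_linear_2d reuses a
    # LinearDCT layer over the last two dimensions; the trailing sizes must match
    # the layer's in_features (128 here is illustrative).
    x2 = torch.empty(16, 128, 128).normal_(0, 1)
    error = torch.abs(dct_2d(x2) - apply_linear_2d(x2, LinearDCT(128, 'dct')))
    assert error.max() < 1e-3, (error, error.max())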
260 |
261 |
--------------------------------------------------------------------------------
/torch_dct/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zh217/torch-dct/0804f5ed2ddcaecc24c14b096bd62695e0478cec/torch_dct/test/__init__.py
--------------------------------------------------------------------------------
/torch_dct/test/test_dct.py:
--------------------------------------------------------------------------------
1 | import torch_dct as dct
2 | import scipy.fftpack as fftpack
3 | import numpy as np
4 | import torch
5 |
6 | np.random.seed(1)
7 |
8 | EPS = 1e-10
9 |
10 |
11 | def test_dct1():
12 | for N in [2, 5, 32, 111]:
13 | x = np.random.normal(size=(1, N,))
14 | ref = fftpack.dct(x, type=1)
15 | act = dct.dct1(torch.tensor(x)).numpy()
16 | assert np.abs(ref - act).max() < EPS, ref
17 |
18 | for d in [2, 3, 4]:
19 | x = np.random.normal(size=(2,) * d)
20 | ref = fftpack.dct(x, type=1)
21 | act = dct.dct1(torch.tensor(x)).numpy()
22 | assert np.abs(ref - act).max() < EPS, ref
23 |
24 |
25 | def test_idct1():
26 | for N in [2, 5, 32, 111]:
27 | x = np.random.normal(size=(1, N))
28 | X = dct.dct1(torch.tensor(x))
29 | y = dct.idct1(X).numpy()
30 | assert np.abs(x - y).max() < EPS, x
31 |
32 |
33 | def test_dct():
34 | for norm in [None, 'ortho']:
35 | for N in [2, 3, 5, 32, 111]:
36 | x = np.random.normal(size=(1, N,))
37 | ref = fftpack.dct(x, type=2, norm=norm)
38 | act = dct.dct(torch.tensor(x), norm=norm).numpy()
39 | assert np.abs(ref - act).max() < EPS, (norm, N)
40 |
41 | for d in [2, 3, 4, 11]:
42 | x = np.random.normal(size=(2,) * d)
43 | ref = fftpack.dct(x, type=2, norm=norm)
44 | act = dct.dct(torch.tensor(x), norm=norm).numpy()
45 | assert np.abs(ref - act).max() < EPS, (norm, d)
46 |
47 |
48 | def test_idct():
49 | for norm in [None, 'ortho']:
50 | for N in [5, 2, 32, 111]:
51 | x = np.random.normal(size=(1, N))
52 | X = dct.dct(torch.tensor(x), norm=norm)
53 | y = dct.idct(X, norm=norm).numpy()
54 | assert np.abs(x - y).max() < EPS, x
55 |
56 |
57 | def test_cuda():
58 | if torch.cuda.is_available():
59 | device = torch.device('cuda:0')
60 |
61 | for N in [2, 5, 32, 111]:
62 | x = np.random.normal(size=(1, N,))
63 | ref = fftpack.dct(x, type=1)
64 | act = dct.dct1(torch.tensor(x, device=device)).cpu().numpy()
65 | assert np.abs(ref - act).max() < EPS, ref
66 |
67 | for d in [2, 3, 4]:
68 | x = np.random.normal(size=(2,) * d)
69 | ref = fftpack.dct(x, type=1)
70 | act = dct.dct1(torch.tensor(x, device=device)).cpu().numpy()
71 | assert np.abs(ref - act).max() < EPS, ref
72 |
73 | for norm in [None, 'ortho']:
74 | for N in [2, 3, 5, 32, 111]:
75 | x = np.random.normal(size=(1, N,))
76 | ref = fftpack.dct(x, type=2, norm=norm)
77 | act = dct.dct(torch.tensor(x, device=device), norm=norm).cpu().numpy()
78 | assert np.abs(ref - act).max() < EPS, (norm, N)
79 |
80 | for d in [2, 3, 4, 11]:
81 | x = np.random.normal(size=(2,) * d)
82 | ref = fftpack.dct(x, type=2, norm=norm)
83 | act = dct.dct(torch.tensor(x, device=device), norm=norm).cpu().numpy()
84 | assert np.abs(ref - act).max() < EPS, (norm, d)
85 |
86 | for N in [5, 2, 32, 111]:
87 | x = np.random.normal(size=(1, N))
88 | X = dct.dct(torch.tensor(x, device=device), norm=norm)
89 | y = dct.idct(X, norm=norm).cpu().numpy()
90 | assert np.abs(x - y).max() < EPS, x
91 |
92 | def test_dct_2d():
93 | for N1 in [2, 5, 32]:
94 | for N2 in [2, 5, 32]:
95 | x = np.random.normal(size=(1, N1, N2))
96 | ref = fftpack.dct(x, axis=2, type=2)
97 | ref = fftpack.dct(ref, axis=1, type=2)
98 | act = dct.dct_2d(torch.tensor(x)).numpy()
99 | assert np.abs(ref - act).max() < EPS, (ref, act)
100 |
101 |
102 | def test_idct_2d():
103 | for N1 in [2, 5, 32]:
104 | for N2 in [2, 5, 32]:
105 | x = np.random.normal(size=(1, N1, N2))
106 | X = dct.dct_2d(torch.tensor(x))
107 | y = dct.idct_2d(X).numpy()
108 | assert np.abs(x - y).max() < EPS, x
109 |
110 |
111 | def test_dct_3d():
112 | for N1 in [2, 5, 32]:
113 | for N2 in [2, 5, 32]:
114 | for N3 in [2, 5, 32]:
115 | x = np.random.normal(size=(1, N1, N2, N3))
116 | ref = fftpack.dct(x, axis=3, type=2)
117 | ref = fftpack.dct(ref, axis=2, type=2)
118 | ref = fftpack.dct(ref, axis=1, type=2)
119 | act = dct.dct_3d(torch.tensor(x)).numpy()
120 | assert np.abs(ref - act).max() < EPS, (ref, act)
121 |
122 |
123 | def test_idct_3d():
124 | for N1 in [2, 5, 32]:
125 | for N2 in [2, 5, 32]:
126 | for N3 in [2, 5, 32]:
127 | x = np.random.normal(size=(1, N1, N2, N3))
128 | X = dct.dct_3d(torch.tensor(x))
129 | y = dct.idct_3d(X).numpy()
130 | assert np.abs(x - y).max() < EPS, x
131 |
--------------------------------------------------------------------------------
/torch_dct/test/test_lineardct.py:
--------------------------------------------------------------------------------
1 | import torch_dct
2 | import scipy.fftpack as fftpack
3 | import numpy as np
4 | import torch
5 |
6 | np.random.seed(1)
7 |
8 | EPS = 1e-3
9 | # THIS IS NOT HOW THESE LAYERS SHOULD BE USED IN PRACTICE
10 | # only written this way for testing convenience
11 | dct1 = lambda x: torch_dct.LinearDCT(x.size(1), type='dct1')(x).data
12 | idct1 = lambda x: torch_dct.LinearDCT(x.size(1), type='idct1')(x).data
13 | def dct(x, norm=None):
14 | return torch_dct.LinearDCT(x.size(1), type='dct', norm=norm)(x).data
15 | def idct(x, norm=None):
16 | return torch_dct.LinearDCT(x.size(1), type='idct', norm=norm)(x).data
17 |
18 | dct_2d = lambda x: torch_dct.apply_linear_2d(x, torch_dct.LinearDCT(x.size(1), type='dct')).data
19 | dct_3d = lambda x: torch_dct.apply_linear_3d(x, torch_dct.LinearDCT(x.size(1), type='dct')).data
20 | idct_2d = lambda x: torch_dct.apply_linear_2d(x, torch_dct.LinearDCT(x.size(1), type='idct')).data
21 | idct_3d = lambda x: torch_dct.apply_linear_3d(x, torch_dct.LinearDCT(x.size(1), type='idct')).data
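# In real use you would build a LinearDCT once (it precomputes and stores the DCT
# matrix) and reuse it, e.g. `layer = torch_dct.LinearDCT(4096, 'dct'); X = layer(x)`
# (sizes illustrative); the helpers above rebuild the layer on every call only so the
# tests below can share the same call signatures as torch_dct.dct et al.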
22 |
23 | def test_dct1():
24 | for N in [2, 5, 32, 111]:
25 | x = np.random.normal(size=(1, N,))
26 | ref = fftpack.dct(x, type=1)
27 | act = dct1(torch.tensor(x).float()).numpy()
28 | assert np.abs(ref - act).max() < EPS, ref
29 |
30 | for d in [2, 3, 4]:
31 | x = np.random.normal(size=(2,) * d)
32 | ref = fftpack.dct(x, type=1)
33 | act = dct1(torch.tensor(x).float()).numpy()
34 | assert np.abs(ref - act).max() < EPS, ref
35 |
36 |
37 | def test_idct1():
38 | for N in [2, 5, 32, 111]:
39 | x = np.random.normal(size=(1, N))
40 | X = dct1(torch.tensor(x).float())
41 | y = idct1(X).numpy()
42 | assert np.abs(x - y).max() < EPS, x
43 |
44 |
45 | def test_dct():
46 | for norm in [None, 'ortho']:
47 | for N in [2, 3, 5, 32, 111]:
48 | x = np.random.normal(size=(1, N,))
49 | ref = fftpack.dct(x, type=2, norm=norm)
50 | act = dct(torch.tensor(x).float(), norm=norm).numpy()
51 | assert np.abs(ref - act).max() < EPS, (norm, N)
52 |
53 | for d in [2, 3, 4, 11]:
54 | x = np.random.normal(size=(2,) * d)
55 | ref = fftpack.dct(x, type=2, norm=norm)
56 | act = dct(torch.tensor(x).float(), norm=norm).numpy()
57 | assert np.abs(ref - act).max() < EPS, (norm, d)
58 |
59 |
60 | def test_idct():
61 | for norm in [None, 'ortho']:
62 | for N in [5, 2, 32, 111]:
63 | x = np.random.normal(size=(1, N))
64 | X = dct(torch.tensor(x).float(), norm=norm)
65 | y = idct(X, norm=norm).numpy()
66 | assert np.abs(x - y).max() < EPS, x
67 |
68 | def test_dct_2d():
69 | for N1 in [2, 5, 32]:
70 | x = np.random.normal(size=(1, N1, N1))
71 | ref = fftpack.dct(x, axis=2, type=2)
72 | ref = fftpack.dct(ref, axis=1, type=2)
73 | act = dct_2d(torch.tensor(x).float()).numpy()
74 | assert np.abs(ref - act).max() < EPS, (ref, act)
75 |
76 |
77 | def test_idct_2d():
78 | for N1 in [2, 5, 32]:
79 | x = np.random.normal(size=(1, N1, N1))
80 | X = dct_2d(torch.tensor(x).float())
81 | y = idct_2d(X).numpy()
82 | assert np.abs(x - y).max() < EPS, x
83 |
84 |
85 | def test_dct_3d():
86 | for N1 in [2, 5, 32]:
87 | x = np.random.normal(size=(1, N1, N1, N1))
88 | ref = fftpack.dct(x, axis=3, type=2)
89 | ref = fftpack.dct(ref, axis=2, type=2)
90 | ref = fftpack.dct(ref, axis=1, type=2)
91 | act = dct_3d(torch.tensor(x).float()).numpy()
92 | assert np.abs(ref - act).max() < EPS, (ref, act)
93 |
94 |
95 | def test_idct_3d():
96 | for N1 in [2, 5, 32]:
97 | x = np.random.normal(size=(1, N1, N1, N1))
98 | X = dct_3d(torch.tensor(x).float())
99 | y = idct_3d(X).numpy()
100 | assert np.abs(x - y).max() < EPS, x
101 |
--------------------------------------------------------------------------------