├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── build.py ├── pytorch_fft ├── __init__.py ├── fft │ ├── __init__.py │ ├── autograd.py │ └── fft.py └── src │ ├── generic │ ├── helpers.c │ ├── th_fft_cuda.c │ ├── th_fft_cuda.h │ ├── th_irfft_cuda.c │ └── th_rfft_cuda.c │ ├── th_fft_cuda.c │ ├── th_fft_cuda.h │ ├── th_fft_generate_double.h │ ├── th_fft_generate_float.h │ └── th_fft_generate_helpers.h ├── setup.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | pytorch_fft/_ext 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include build.py 2 | recursive-include pytorch_fft/src * -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A PyTorch wrapper for CUDA FFTs [![License][license-image]][license] 2 | 3 | [license-image]: http://img.shields.io/badge/license-Apache--2-blue.svg?style=flat 4 | [license]: LICENSE 5 | 6 | *A package that provides a PyTorch C extension for performing batches of 2D CuFFT 7 | transformations, by [Eric Wong](https://github.com/riceric22)* 8 | 9 | Update: FFT functionality is now officially in PyTorch 0.4, see the 10 | documentation [here](https://pytorch.org/docs/0.4.0/torch.html?highlight=fft#torch.fft). 
This repository is only useful for older versions of PyTorch, and will no longer
be updated.

## Installation

This package is on PyPI. Install with `pip install pytorch-fft`.

## Usage

+ From the `pytorch_fft.fft` module, you can use the following to do
  forward and backward FFT transformations (complex to complex)
  + `fft` and `ifft` for 1D transformations
  + `fft2` and `ifft2` for 2D transformations
  + `fft3` and `ifft3` for 3D transformations
+ From the same module, you can also use the following for
  real to complex / complex to real FFT transformations
  + `rfft` and `irfft` for 1D transformations
  + `rfft2` and `irfft2` for 2D transformations
  + `rfft3` and `irfft3` for 3D transformations
+ For a `d`-D transformation, the input tensors are required to have >= (d+1)
  dimensions (n1 x ... x nk x m1 x ... x md) where `n1 x ... x nk` is the
  batch of FFT transformations, and `m1 x ... x md` are the dimensions of the
  `d`-D transformation. `d` must be a number from 1 to 3.
+ Finally, the module contains the following helper functions you may find
  useful
  + `reverse(X, group_size=1)` reverses the elements of a tensor and returns
    the result in a new tensor. Note that PyTorch does not currently support
    negative slicing, see this
    [issue](https://github.com/pytorch/pytorch/issues/229). If a group size is
    supplied, the elements will be reversed in groups of that size.
  + `expand(X, imag=False, odd=False)` takes a tensor output of a real 2D or 3D
    FFT and expands it with its redundant entries to match the output of a
    complex FFT.
44 | + For autograd support, use the following functions in the 45 | `pytorch_fft.fft.autograd` module: 46 | + `Fft` and `Ifft` for 1D transformations 47 | + `Fft2d` and `Ifft2d` for 2D transformations 48 | + `Fft3d` and `Ifft3d` for 3D transformations 49 | 50 | 51 | ```Python 52 | # Example that does a batch of three 2D transformations of size 4 by 5. 53 | import torch 54 | import pytorch_fft.fft as fft 55 | 56 | A_real, A_imag = torch.randn(3,4,5).cuda(), torch.zeros(3,4,5).cuda() 57 | B_real, B_imag = fft.fft2(A_real, A_imag) 58 | fft.ifft2(B_real, B_imag) # equals (A, zeros) 59 | 60 | B_real, B_imag = fft.rfft2(A) # is a truncated version which omits 61 | # redundant entries 62 | 63 | reverse(torch.arange(0,6)) # outputs [5,4,3,2,1,0] 64 | reverse(torch.arange(0,6), 2) # outputs [4,5,2,3,0,1] 65 | 66 | expand(B_real) # is equivalent to fft.fft2(A, zeros)[0] 67 | expand(B_imag, imag=True) # is equivalent to fft.fft2(A, zeros)[1] 68 | ``` 69 | 70 | 71 | ```Python 72 | # Example that uses the autograd for 2D fft: 73 | import torch 74 | from torch.autograd import Variable 75 | import pytorch_fft.fft.autograd as fft 76 | import numpy as np 77 | 78 | f = fft.Fft2d() 79 | invf= fft.Ifft2d() 80 | 81 | fx, fy = (Variable(torch.arange(0,100).view((1,1,10,10)).cuda(), requires_grad=True), 82 | Variable(torch.zeros(1, 1, 10, 10).cuda(),requires_grad=True)) 83 | k1,k2 = f(fx,fy) 84 | z = k1.sum() + k2.sum() 85 | z.backward() 86 | print(fx.grad, fy.grad) 87 | ``` 88 | 89 | ## Notes 90 | + This follows NumPy semantics and behavior, so `ifft2(fft2(x)) = x`. Note 91 | that CuFFT semantics for inverse FFT only flip the sign of the transform, 92 | but it is not a true inverse. 93 | + Similarly, the real to complex / complex to real variants also follow NumPy 94 | semantics and behavior. 
import os
import torch
from torch.utils.ffi import create_extension

# Resolve paths relative to this file so the build works regardless of the
# current working directory.  Previously `include_dirs` was computed from
# os.getcwd(), which broke builds invoked from any other directory, and the
# `this_file` variable was computed but never used.
this_file = os.path.dirname(os.path.abspath(__file__))

sources = []
headers = []
defines = []
with_cuda = False

# The extension wraps CuFFT, so there is nothing to build without CUDA.
if torch.cuda.is_available():
    print('Including CUDA code.')
    sources += ['pytorch_fft/src/th_fft_cuda.c']
    headers += ['pytorch_fft/src/th_fft_cuda.h']
    defines += [('WITH_CUDA', None)]
    with_cuda = True

# C extension exposed to Python as pytorch_fft._ext.th_fft.
ffi = create_extension(
    'pytorch_fft._ext.th_fft',
    package=True,
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda,
    include_dirs=[os.path.join(this_file, 'pytorch_fft', 'src')],
    library_dirs=['/usr/local/cuda/lib64'],
    libraries=['cufft']
)

if __name__ == '__main__':
    ffi.build()
import torch
from .fft import fft,ifft,fft2,ifft2,fft3,ifft3,rfft,irfft,rfft2,irfft2,rfft3,irfft3


def make_contiguous(*Xs):
    """Return the given tensors, replacing each non-contiguous one with a
    contiguous copy of itself."""
    return tuple(X.contiguous() if not X.is_contiguous() else X for X in Xs)


def contiguous_clone(X):
    """Return a tensor that is safe to mutate in place: a clone when X is
    already contiguous, otherwise a contiguous copy (itself a new tensor)."""
    return X.clone() if X.is_contiguous() else X.contiguous()


class Fft(torch.autograd.Function):
    """Autograd wrapper for the 1D complex-to-complex FFT."""

    def forward(self, X_re, X_im):
        X_re, X_im = make_contiguous(X_re, X_im)
        return fft(X_re, X_im)

    def backward(self, grad_output_re, grad_output_im):
        grad_output_re, grad_output_im = make_contiguous(grad_output_re,
                                                         grad_output_im)
        # The gradient is the same transform applied with the real and
        # imaginary parts swapped on both input and output.
        gi, gr = fft(grad_output_im, grad_output_re)
        return gr, gi


class Ifft(torch.autograd.Function):
    """Autograd wrapper for the 1D complex-to-complex inverse FFT."""

    def forward(self, k_re, k_im):
        k_re, k_im = make_contiguous(k_re, k_im)
        return ifft(k_re, k_im)

    def backward(self, grad_output_re, grad_output_im):
        grad_output_re, grad_output_im = make_contiguous(grad_output_re,
                                                         grad_output_im)
        gi, gr = ifft(grad_output_im, grad_output_re)
        return gr, gi


class Fft2d(torch.autograd.Function):
    """Autograd wrapper for the 2D complex-to-complex FFT."""

    def forward(self, X_re, X_im):
        X_re, X_im = make_contiguous(X_re, X_im)
        return fft2(X_re, X_im)

    def backward(self, grad_output_re, grad_output_im):
        grad_output_re, grad_output_im = make_contiguous(grad_output_re,
                                                         grad_output_im)
        gi, gr = fft2(grad_output_im, grad_output_re)
        return gr, gi
class Ifft2d(torch.autograd.Function):
    """Autograd wrapper for the 2D complex-to-complex inverse FFT."""

    def forward(self, k_re, k_im):
        k_re, k_im = make_contiguous(k_re, k_im)
        return ifft2(k_re, k_im)

    def backward(self, grad_output_re, grad_output_im):
        grad_output_re, grad_output_im = make_contiguous(grad_output_re,
                                                         grad_output_im)
        # The gradient is the same transform applied with the real and
        # imaginary parts swapped on both input and output.
        gi, gr = ifft2(grad_output_im, grad_output_re)
        return gr, gi


class Fft3d(torch.autograd.Function):
    """Autograd wrapper for the 3D complex-to-complex FFT."""

    def forward(self, X_re, X_im):
        X_re, X_im = make_contiguous(X_re, X_im)
        return fft3(X_re, X_im)

    def backward(self, grad_output_re, grad_output_im):
        grad_output_re, grad_output_im = make_contiguous(grad_output_re,
                                                         grad_output_im)
        gi, gr = fft3(grad_output_im, grad_output_re)
        return gr, gi


class Ifft3d(torch.autograd.Function):
    """Autograd wrapper for the 3D complex-to-complex inverse FFT."""

    def forward(self, k_re, k_im):
        k_re, k_im = make_contiguous(k_re, k_im)
        return ifft3(k_re, k_im)

    def backward(self, grad_output_re, grad_output_im):
        grad_output_re, grad_output_im = make_contiguous(grad_output_re,
                                                         grad_output_im)
        gi, gr = ifft3(grad_output_im, grad_output_re)
        return gr, gi


def _halve_shared_bins(grad_re, grad_im, input_size):
    """In-place: divide by two every last-axis bin except the first (and, for
    an even input size, also except the last).

    Previously this identical parity test on the saved input size was
    copy-pasted once for the real gradient and once for the imaginary
    gradient in each real-FFT backward pass.
    """
    half = slice(1, None) if input_size & 1 else slice(1, -1)
    grad_re[..., half] /= 2
    grad_im[..., half] /= 2


class Rfft(torch.autograd.Function):
    """Autograd wrapper for the 1D real-to-complex FFT."""

    def forward(self, X_re):
        X_re = X_re.contiguous()
        # backward() needs the original last-dim size: rfft's output is
        # truncated to size(-1)//2 + 1.
        self._to_save_input_size = X_re.size(-1)
        return rfft(X_re)

    def backward(self, grad_output_re, grad_output_im):
        # Clone (and make contiguous) so the in-place scaling below cannot
        # modify the caller's gradient tensors.
        grad_output_re = contiguous_clone(grad_output_re)
        grad_output_im = contiguous_clone(grad_output_im)

        _halve_shared_bins(grad_output_re, grad_output_im,
                           self._to_save_input_size)

        gr = irfft(grad_output_re, grad_output_im, self._to_save_input_size,
                   normalize=False)
        return gr


class Irfft(torch.autograd.Function):
    """Autograd wrapper for the 1D complex-to-real inverse FFT."""

    def forward(self, k_re, k_im):
        k_re, k_im = make_contiguous(k_re, k_im)
        return irfft(k_re, k_im)

    def backward(self, grad_output_re):
        grad_output_re = grad_output_re.contiguous()
        gr, gi = rfft(grad_output_re)

        # Rescale the last axis: endpoint bins by 1/N, interior bins by 2/N.
        N = grad_output_re.size(-1)
        gr[..., 0] /= N
        gr[..., 1:-1] /= N / 2
        gr[..., -1] /= N

        gi[..., 0] /= N
        gi[..., 1:-1] /= N / 2
        gi[..., -1] /= N
        return gr, gi


class Rfft2d(torch.autograd.Function):
    """Autograd wrapper for the 2D real-to-complex FFT."""

    def forward(self, X_re):
        X_re = X_re.contiguous()
        self._to_save_input_size = X_re.size(-1)
        return rfft2(X_re)

    def backward(self, grad_output_re, grad_output_im):
        # Clone (and make contiguous) so the in-place scaling below cannot
        # modify the caller's gradient tensors.
        grad_output_re = contiguous_clone(grad_output_re)
        grad_output_im = contiguous_clone(grad_output_im)

        _halve_shared_bins(grad_output_re, grad_output_im,
                           self._to_save_input_size)

        gr = irfft2(grad_output_re, grad_output_im, self._to_save_input_size,
                    normalize=False)
        return gr


class Irfft2d(torch.autograd.Function):
    """Autograd wrapper for the 2D complex-to-real inverse FFT."""

    def forward(self, k_re, k_im):
        k_re, k_im = make_contiguous(k_re, k_im)
        return irfft2(k_re, k_im)

    def backward(self, grad_output_re):
        grad_output_re = grad_output_re.contiguous()
        gr, gi = rfft2(grad_output_re)

        # Rescale the last axis: endpoint bins by 1/N, interior bins by 2/N,
        # where N is the total number of transformed elements.
        N = grad_output_re.size(-1) * grad_output_re.size(-2)
        gr[..., 0] /= N
        gr[..., 1:-1] /= N / 2
        gr[..., -1] /= N

        gi[..., 0] /= N
        gi[..., 1:-1] /= N / 2
        gi[..., -1] /= N
        return gr, gi
class Rfft3d(torch.autograd.Function):
    """Autograd wrapper for the 3D real-to-complex FFT."""

    def forward(self, X_re):
        X_re = X_re.contiguous()
        # backward() needs the original last-dim size: rfft3's output is
        # truncated to size(-1)//2 + 1.
        self._to_save_input_size = X_re.size(-1)
        return rfft3(X_re)

    def backward(self, grad_output_re, grad_output_im):
        # Clone (and make contiguous) so the in-place scaling below cannot
        # modify the caller's gradient tensors.
        grad_output_re = contiguous_clone(grad_output_re)
        grad_output_im = contiguous_clone(grad_output_im)

        # Halve every last-axis bin except the first (and, for an even input
        # size, also except the last).  The parity test was previously
        # duplicated for the real and imaginary parts.
        half = slice(1, None) if self._to_save_input_size & 1 else slice(1, -1)
        grad_output_re[..., half] /= 2
        grad_output_im[..., half] /= 2

        gr = irfft3(grad_output_re, grad_output_im, self._to_save_input_size,
                    normalize=False)
        return gr


class Irfft3d(torch.autograd.Function):
    """Autograd wrapper for the 3D complex-to-real inverse FFT."""

    def forward(self, k_re, k_im):
        k_re, k_im = make_contiguous(k_re, k_im)
        return irfft3(k_re, k_im)

    def backward(self, grad_output_re):
        grad_output_re = grad_output_re.contiguous()
        gr, gi = rfft3(grad_output_re)

        # Rescale the last axis: endpoint bins by 1/N, interior bins by 2/N,
        # where N is the total number of transformed elements.
        N = (grad_output_re.size(-1) * grad_output_re.size(-2)
             * grad_output_re.size(-3))
        gr[..., 0] /= N
        gr[..., 1:-1] /= N / 2
        gr[..., -1] /= N

        gi[..., 0] /= N
        gi[..., 1:-1] /= N / 2
        gi[..., -1] /= N
        return gr, gi


# ---------------------------------------------------------------------------
# pytorch_fft/fft/fft.py
# ---------------------------------------------------------------------------
# functions/fft.py
import torch
from .._ext import th_fft


def _complex_func(X, name):
    """Return the CUDA kernel `name` matching X's dtype.

    Factors out the float32/float64 dispatch that was previously copy-pasted
    in every public transform below; unsupported dtypes raise
    NotImplementedError, exactly as before.
    """
    if X.dtype == torch.float32:
        return getattr(th_fft, 'th_Float_' + name)
    if X.dtype == torch.float64:
        return getattr(th_fft, 'th_Double_' + name)
    raise NotImplementedError


def _fft(X_re, X_im, f, rank):
    """Validate inputs, allocate zeroed outputs and run the C kernel `f`.

    Shared driver for all complex-to-complex transforms.  Returns the pair
    (out_re, out_im), each the same size as the inputs.  Raises ValueError
    for mismatched sizes, too few dimensions, non-CUDA or non-contiguous
    inputs.
    """
    if not(X_re.size() == X_im.size()):
        raise ValueError("Real and imaginary tensors must have the same dimension.")
    if not(X_re.dim() >= rank+1 and X_im.dim() >= rank+1):
        raise ValueError("Inputs must have at least {} dimensions.".format(rank+1))
    if not(X_re.is_cuda and X_im.is_cuda):
        raise ValueError("Input must be a CUDA tensor.")
    if not(X_re.is_contiguous() and X_im.is_contiguous()):
        raise ValueError("Input must be contiguous.")

    Y1, Y2 = tuple(X_re.new(*X_re.size()).zero_() for _ in range(2))
    f(X_re, X_im, Y1, Y2)
    return (Y1, Y2)


def fft(X_re, X_im):
    """1D complex-to-complex FFT over the last dimension."""
    return _fft(X_re, X_im, _complex_func(X_re, 'fft1'), 1)


def ifft(X_re, X_im):
    """1D inverse FFT, normalized by the transform length (NumPy semantics)."""
    N = X_re.size(-1)
    Y1, Y2 = _fft(X_re, X_im, _complex_func(X_re, 'ifft1'), 1)
    return (Y1/N, Y2/N)


def fft2(X_re, X_im):
    """2D complex-to-complex FFT over the last two dimensions."""
    return _fft(X_re, X_im, _complex_func(X_re, 'fft2'), 2)


def ifft2(X_re, X_im):
    """2D inverse FFT, normalized by the number of transformed elements."""
    N = X_re.size(-1)*X_re.size(-2)
    Y1, Y2 = _fft(X_re, X_im, _complex_func(X_re, 'ifft2'), 2)
    return (Y1/N, Y2/N)


def fft3(X_re, X_im):
    """3D complex-to-complex FFT over the last three dimensions."""
    return _fft(X_re, X_im, _complex_func(X_re, 'fft3'), 3)


def ifft3(X_re, X_im):
    """3D inverse FFT, normalized by the number of transformed elements."""
    N = X_re.size(-1)*X_re.size(-2)*X_re.size(-3)
    Y1, Y2 = _fft(X_re, X_im, _complex_func(X_re, 'ifft3'), 3)
    return (Y1/N, Y2/N)


# Kept for backward compatibility with earlier revisions (no longer used).
_s = slice(None, None, None)
def _real_func(X, name):
    """Return the CUDA kernel `name` matching X's dtype (float32/float64).

    Factors out the dtype dispatch previously copy-pasted in every public
    function below; unsupported dtypes raise NotImplementedError, as before.
    """
    if X.dtype == torch.float32:
        return getattr(th_fft, 'th_Float_' + name)
    if X.dtype == torch.float64:
        return getattr(th_fft, 'th_Double_' + name)
    raise NotImplementedError


def _rfft(X, f, rank):
    """Validate the real input, allocate zeroed outputs and run kernel `f`.

    The outputs keep X's shape except for the last dimension, which is
    truncated to X.size(-1)//2 + 1 (redundant conjugate entries omitted, as
    in numpy.fft.rfft).  Commented-out experimental code from an earlier
    revision has been removed.
    """
    if not(X.dim() >= rank+1):
        raise ValueError("Input must have at least {} dimensions.".format(rank+1))
    if not(X.is_cuda):
        raise ValueError("Input must be a CUDA tensor.")
    if not(X.is_contiguous()):
        raise ValueError("Input must be contiguous.")

    new_size = tuple(X.size())[:-1] + (X.size(-1)//2 + 1,)
    Y1, Y2 = tuple(X.new(*new_size).zero_() for _ in range(2))
    f(X, Y1, Y2)
    return (Y1, Y2)


def rfft(X):
    """1D real-to-complex FFT over the last dimension."""
    return _rfft(X, _real_func(X, 'rfft1'), 1)


def rfft2(X):
    """2D real-to-complex FFT over the last two dimensions."""
    return _rfft(X, _real_func(X, 'rfft2'), 2)


def rfft3(X):
    """3D real-to-complex FFT over the last three dimensions."""
    return _rfft(X, _real_func(X, 'rfft3'), 3)


def _irfft(X_re, X_im, f, rank, N, normalize):
    """Validate inputs, allocate the real output and run kernel `f`.

    N is the desired last-dimension size of the real output; when None it is
    inferred as (X_re.size(-1) - 1) * 2.  When `normalize` is set the result
    is divided by the total number of transformed elements (NumPy
    semantics).
    """
    if not(X_re.size() == X_im.size()):
        raise ValueError("Real and imaginary tensors must have the same dimension.")
    if not(X_re.dim() >= rank+1 and X_im.dim() >= rank+1):
        raise ValueError("Inputs must have at least {} dimensions.".format(rank+1))
    if not(X_re.is_cuda and X_im.is_cuda):
        raise ValueError("Input must be a CUDA tensor.")
    if not(X_re.is_contiguous() and X_im.is_contiguous()):
        raise ValueError("Input must be contiguous.")

    input_size = X_re.size(-1)

    if N is not None:
        if input_size != N // 2 + 1:
            raise ValueError("Input size must be equal to n/2 + 1")
    else:
        N = (X_re.size(-1) - 1)*2

    new_size = tuple(X_re.size())[:-1] + (N,)
    Y = X_re.new(*new_size).zero_()
    f(X_re, X_im, Y)

    if normalize:
        # Product of the sizes of the `rank` transformed dimensions.
        M = 1
        for i in range(rank):
            M *= new_size[-(i+1)]
        return Y/M
    return Y


def irfft(X_re, X_im, n=None, normalize=True):
    """1D complex-to-real inverse FFT (numpy.fft.irfft semantics)."""
    return _irfft(X_re, X_im, _real_func(X_re, 'irfft1'), 1, n, normalize)


def irfft2(X_re, X_im, n=None, normalize=True):
    """2D complex-to-real inverse FFT."""
    return _irfft(X_re, X_im, _real_func(X_re, 'irfft2'), 2, n, normalize)


def irfft3(X_re, X_im, n=None, normalize=True):
    """3D complex-to-real inverse FFT."""
    return _irfft(X_re, X_im, _real_func(X_re, 'irfft3'), 3, n, normalize)


def reverse(X, group_size=1):
    """Return a new tensor whose elements are those of X in reverse order,
    taken in consecutive groups of `group_size` elements."""
    if not(X.is_cuda):
        raise ValueError("Input must be a CUDA tensor.")
    if not(X.is_contiguous()):
        raise ValueError("Input must be contiguous.")

    # The reverse kernels do not follow the th_<Type>_<name> scheme, so they
    # are dispatched explicitly here.
    if X.dtype == torch.float32:
        f = th_fft.reverse_Float
    elif X.dtype == torch.float64:
        f = th_fft.reverse_Double
    else:
        raise NotImplementedError
    Y = X.new(*X.size())
    f(X, Y, group_size)
    return Y
    Y[i] = X

    # Select the interior frequencies that must be mirrored into the upper
    # half of the last dimension (conjugate-symmetry reconstruction).
    if odd:
        i = tuple(slice(None, None, None) for _ in range(X.dim() - 1)) + (slice(-(N3-N2),None, None),)
    else:
        i = tuple(slice(None, None, None) for _ in range(X.dim() - 1)) + (slice(-(1+N3-N2),-1, None),)
    X0 = X[i].contiguous()

    # Flip the selected frequencies, rotate the second-to-last dimension by
    # one row, then flip again in groups — together this realizes the
    # index reversal (k -> -k mod N) on both trailing dimensions.
    X0 = reverse(X0)
    i0 = (tuple(slice(None, None, None) for _ in range(X.dim() - 2)) +
          (slice(-1,None, None), slice(None, None, None)))
    i1 = (tuple(slice(None, None, None) for _ in range(X.dim() - 2)) +
          (slice(None, -1, None), slice(None, None, None)))
    X0 = torch.cat([X0[i0], X0[i1]], -2)
    X0 = reverse(X0, N1*(N3-N2))

    i = tuple(slice(None, None, None) for _ in range(X.dim() - 1)) + (slice(N2, None, None),)
    if not imag:
        Y[i] = X0
    else:
        # Imaginary parts of a real signal's spectrum are antisymmetric.
        Y[i] = -X0
    return Y

def roll_n(X, axis, n):
    """Roll tensor X by n positions along `axis` (like numpy.roll)."""
    f_idx = tuple(slice(None, None, None) if i != axis else slice(0,n,None)
                  for i in range(X.dim()))
    b_idx = tuple(slice(None, None, None) if i != axis else slice(n,None,None)
                  for i in range(X.dim()))
    front = X[f_idx]
    back = X[b_idx]
    return torch.cat([back, front],axis)
--------------------------------------------------------------------------------
/pytorch_fft/src/generic/helpers.c:
--------------------------------------------------------------------------------
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/helpers.c"
#else

// Interleave two real device arrays a (real parts) and b (imaginary parts)
// into the n-element complex array c, using two strided device-to-device
// copies with destination pitch 2*sizeof(real).
void pair2complex(real *a, real *b, cufft_complex *c, int n)
{
  real *c_tmp = (real*)c;
  cudaMemcpy2D(c_tmp, 2*sizeof(real),
      a, sizeof(real),
      sizeof(real), n, cudaMemcpyDeviceToDevice);
  cudaMemcpy2D(c_tmp+1, 2*sizeof(real),
      b, sizeof(real),
      sizeof(real), n, cudaMemcpyDeviceToDevice);
}

// De-interleave the n-element complex array a into real parts b and
// imaginary parts c (inverse of pair2complex).
void complex2pair(cufft_complex *a, real *b, real *c, int n)
{
  real *a_tmp = (real*)a;
  cudaMemcpy2D(b,
               sizeof(real),
      a_tmp, 2*sizeof(real),
      sizeof(real), n, cudaMemcpyDeviceToDevice);
  cudaMemcpy2D(c, sizeof(real),
      a_tmp+1, 2*sizeof(real),
      sizeof(real), n, cudaMemcpyDeviceToDevice);
}

// Copy input into output with the n/group_size groups of group_size
// contiguous elements in reversed order.
// NOTE(review): cudaMemcpy2D takes size_t pitches, so the negative source
// pitch -sizeof(real)*group_size wraps to a huge unsigned value; walking
// the source backwards this way is not documented behavior — confirm it is
// intentional and actually works on all supported drivers.
void reverse_(THCTensor *input, THCTensor *output, int group_size)
{
  real *input_data = THCTensor_(data)(state, input);
  real *output_data = THCTensor_(data)(state, output);
  int n = THCTensor_(nElement)(state, input);

  cudaMemcpy2D(output_data, sizeof(real)*group_size,
      input_data+n-group_size, -sizeof(real)*group_size,
      sizeof(real)*group_size, n/group_size, cudaMemcpyDeviceToDevice);
}

#endif
--------------------------------------------------------------------------------
/pytorch_fft/src/generic/th_fft_cuda.c:
--------------------------------------------------------------------------------
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/th_fft_cuda.c"
#else

// Generic FFT entry point; the concrete transform (type, direction) is
// fixed by the th_fft_generate_*.h header that includes this file.
// Inputs/outputs are (real, imag) tensor pairs.  Returns 0 when the four
// tensors do not all share the same size, nonzero presumably on success
// (tail of this function is lost in extraction — see next chunk).
int th_(THCTensor *input1, THCTensor *input2, THCTensor *output1, THCTensor *output2)
{
  // Require that all tensors be of the same size.
  if (!THCTensor_(isSameSizeAs)(state, input1, output1))
    return 0;
  if (!THCTensor_(isSameSizeAs)(state, input1, output2))
    return 0;
  if (!THCTensor_(isSameSizeAs)(state, input1, input2))
    return 0;

  // Get the tensor dimensions (batchsize, rows, cols).
  int ndim = THCTensor_(nDimension)(state, input1);
  int batch = 1;
  int i, d;
  // NOTE(review): extraction garbling — everything from the `<` of this
  // for-loop through the #include targets below was swallowed when
  // angle-bracketed text was stripped.  The rest of generic/th_fft_cuda.c
  // (the cuFFT plan/exec body), its #endif, the file separator, and the
  // header names of the four #includes are missing; recover them from the
  // upstream repository before editing this region.
  for(i=0; i 2 | #include 3 | #include 4 | #include
// this symbol will be resolved automatically from PyTorch libs
extern THCState *state;

// Token-pasting name generators: th_ expands to e.g. th_Float_fft2
// depending on the Real type and the func_name set by each generate header.
#define th_ TH_CONCAT_4(th_, Real, _, func_name)
#define pair2complex TH_CONCAT_2(Real, 2complex)
#define complex2pair TH_CONCAT_2(complex2, Real)
#define reverse_ TH_CONCAT_2(reverse_, Real)

#include "th_fft_generate_helpers.h"

// Instantiate 1-, 2- and 3-D transforms for float and double.
#define cufft_rank 1
#include "th_fft_generate_float.h"
#include "th_fft_generate_double.h"
#undef cufft_rank

#define cufft_rank 2
#include "th_fft_generate_float.h"
#include "th_fft_generate_double.h"
#undef cufft_rank

#define cufft_rank 3
#include "th_fft_generate_float.h"
#include "th_fft_generate_double.h"
#undef cufft_rank
--------------------------------------------------------------------------------
/pytorch_fft/src/th_fft_cuda.h:
--------------------------------------------------------------------------------
// Exported C API.  Complex transforms take (re_in, im_in, re_out, im_out)
// and return 0 on size mismatch.
int th_Float_fft1(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *output1, THCudaTensor *output2);
int th_Float_ifft1(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *output1, THCudaTensor *output2);
int th_Double_fft1(THCudaDoubleTensor *input1, THCudaDoubleTensor *input2, THCudaDoubleTensor *output1, THCudaDoubleTensor *output2);
int th_Double_ifft1(THCudaDoubleTensor *input1, THCudaDoubleTensor *input2, THCudaDoubleTensor *output1, THCudaDoubleTensor *output2);

int th_Float_fft2(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *output1, THCudaTensor *output2);
int th_Float_ifft2(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *output1, THCudaTensor *output2);
int th_Double_fft2(THCudaDoubleTensor *input1, THCudaDoubleTensor *input2, THCudaDoubleTensor *output1, THCudaDoubleTensor *output2);
int
    th_Double_ifft2(THCudaDoubleTensor *input1, THCudaDoubleTensor *input2, THCudaDoubleTensor *output1, THCudaDoubleTensor *output2);

int th_Float_fft3(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *output1, THCudaTensor *output2);
int th_Float_ifft3(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *output1, THCudaTensor *output2);
int th_Double_fft3(THCudaDoubleTensor *input1, THCudaDoubleTensor *input2, THCudaDoubleTensor *output1, THCudaDoubleTensor *output2);
int th_Double_ifft3(THCudaDoubleTensor *input1, THCudaDoubleTensor *input2, THCudaDoubleTensor *output1, THCudaDoubleTensor *output2);

// Real-to-complex forward transforms: real input, (re, im) outputs.
int th_Float_rfft1(THCudaTensor *input1, THCudaTensor *output1, THCudaTensor *output2);
// Complex-to-real inverse transforms: (re, im) inputs, real output.
int th_Float_irfft1(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *output1);
int th_Double_rfft1(THCudaDoubleTensor *input1, THCudaDoubleTensor *output1, THCudaDoubleTensor *output2);
int th_Double_irfft1(THCudaDoubleTensor *input1, THCudaDoubleTensor *input2, THCudaDoubleTensor *output1);

int th_Float_rfft2(THCudaTensor *input1, THCudaTensor *output1, THCudaTensor *output2);
int th_Float_irfft2(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *output1);
int th_Double_rfft2(THCudaDoubleTensor *input1, THCudaDoubleTensor *output1, THCudaDoubleTensor *output2);
int th_Double_irfft2(THCudaDoubleTensor *input1, THCudaDoubleTensor *input2, THCudaDoubleTensor *output1);

int th_Float_rfft3(THCudaTensor *input1, THCudaTensor *output1, THCudaTensor *output2);
int th_Float_irfft3(THCudaTensor *input1, THCudaTensor *input2, THCudaTensor *output1);
int th_Double_rfft3(THCudaDoubleTensor *input1, THCudaDoubleTensor *output1, THCudaDoubleTensor *output2);
int th_Double_irfft3(THCudaDoubleTensor *input1, THCudaDoubleTensor *input2, THCudaDoubleTensor *output1);

// Group-wise order reversal (see generic/helpers.c).
void reverse_Float(THCudaTensor *input, THCudaTensor *output, int group_size);
void
     reverse_Double(THCudaDoubleTensor *input, THCudaDoubleTensor *output, int group_size);

// void expand_2D_Float(THCudaTensor *input, THCudaTensor *output);
// void expand_2D_Double(THCudaDoubleTensor *input, THCudaDoubleTensor *output);
--------------------------------------------------------------------------------
/pytorch_fft/src/th_fft_generate_double.h:
--------------------------------------------------------------------------------
// Generate Double FFTs.
// Each #define/#include/#undef stanza instantiates one generic template
// (generic/th_*fft_cuda.c) for the double type at the current cufft_rank.
// The define/undef pairing is load-bearing — do not reorder.
#define cufft_complex cufftDoubleComplex

#define cufft_type CUFFT_Z2Z
#define cufft_exec cufftExecZ2Z

#define cufft_direction CUFFT_FORWARD
#define func_name TH_CONCAT_2(fft, cufft_rank)

#include "generic/th_fft_cuda.c"
#include "THCGenerateDoubleType.h"

#undef cufft_direction
#undef func_name

#define cufft_direction CUFFT_INVERSE
#define func_name TH_CONCAT_2(ifft, cufft_rank)

#include "generic/th_fft_cuda.c"
#include "THCGenerateDoubleType.h"

#undef cufft_direction
#undef func_name

#undef cufft_type
#undef cufft_exec

// Generate Double rFFTs (real-to-complex forward).
#define cufft_type CUFFT_D2Z
#define cufft_exec cufftExecD2Z
#define func_name TH_CONCAT_2(rfft, cufft_rank)

#include "generic/th_rfft_cuda.c"
#include "THCGenerateDoubleType.h"

#undef cufft_type
#undef cufft_exec
#undef func_name

// Complex-to-real inverse.
#define cufft_type CUFFT_Z2D
#define cufft_exec cufftExecZ2D
#define func_name TH_CONCAT_2(irfft, cufft_rank)

#include "generic/th_irfft_cuda.c"
#include "THCGenerateDoubleType.h"

#undef cufft_type
#undef cufft_exec
#undef func_name

#undef cufft_complex
--------------------------------------------------------------------------------
/pytorch_fft/src/th_fft_generate_float.h:
--------------------------------------------------------------------------------
// Generate float FFTs
#define cufft_complex cufftComplex 3 | 4 | #define cufft_type CUFFT_C2C 5 | #define cufft_exec cufftExecC2C 6 | 7 | #define cufft_direction CUFFT_FORWARD 8 | #define func_name TH_CONCAT_2(fft, cufft_rank) 9 | 10 | #include "generic/th_fft_cuda.c" 11 | #include "THCGenerateFloatType.h" 12 | 13 | #undef func_name 14 | #undef cufft_direction 15 | 16 | #define cufft_direction CUFFT_INVERSE 17 | #define func_name TH_CONCAT_2(ifft, cufft_rank) 18 | 19 | #include "generic/th_fft_cuda.c" 20 | #include "THCGenerateFloatType.h" 21 | 22 | #undef func_name 23 | #undef cufft_direction 24 | 25 | 26 | #undef cufft_type 27 | #undef cufft_exec 28 | 29 | // Generate float rFFTs 30 | #define cufft_type CUFFT_R2C 31 | #define cufft_exec cufftExecR2C 32 | #define func_name TH_CONCAT_2(rfft, cufft_rank) 33 | 34 | #include "generic/th_rfft_cuda.c" 35 | #include "THCGenerateFloatType.h" 36 | 37 | #undef func_name 38 | #undef cufft_type 39 | #undef cufft_exec 40 | 41 | #define cufft_type CUFFT_C2R 42 | #define cufft_exec cufftExecC2R 43 | #define func_name TH_CONCAT_2(irfft, cufft_rank) 44 | 45 | #include "generic/th_irfft_cuda.c" 46 | #include "THCGenerateFloatType.h" 47 | 48 | #undef func_name 49 | #undef cufft_type 50 | #undef cufft_exec 51 | 52 | #undef cufft_complex -------------------------------------------------------------------------------- /pytorch_fft/src/th_fft_generate_helpers.h: -------------------------------------------------------------------------------- 1 | // Generate float and double helpers 2 | #define cufft_complex cufftComplex 3 | 4 | #include "generic/helpers.c" 5 | #include "THCGenerateFloatType.h" 6 | 7 | #undef cufft_complex 8 | 9 | #define cufft_complex cufftDoubleComplex 10 | 11 | #include "generic/helpers.c" 12 | #include "THCGenerateDoubleType.h" 13 | 14 | #undef cufft_complex -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | import sys 3 | 4 | from setuptools import setup, find_packages 5 | 6 | import build 7 | 8 | this_file = os.path.dirname(__file__) 9 | 10 | setup( 11 | name="pytorch_fft", 12 | version="0.15", 13 | description="A PyTorch wrapper for CUDA FFTs", 14 | url="https://github.com/locuslab/pytorch_fft", 15 | author="Eric Wong", 16 | author_email="ericwong@cs.cmu.edu", 17 | # Require cffi. 18 | install_requires=["cffi>=1.0.0"], 19 | setup_requires=["cffi>=1.0.0"], 20 | # Exclude the build files. 21 | packages=find_packages(exclude=["build"]), 22 | # Package where to put the extensions. Has to be a prefix of build.py. 23 | ext_package="", 24 | # Extensions to compile. 25 | cffi_modules=[ 26 | os.path.join(this_file, "build.py:ffi") 27 | ], 28 | ) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.manual_seed(0) 3 | # from _ext import th_fft 4 | import pytorch_fft.fft as cfft 5 | import pytorch_fft.fft.autograd as afft 6 | import numpy as np 7 | import numpy.fft as nfft 8 | 9 | def run_c2c(x, z, _f1, _f2, _if1, _if2, atol): 10 | y1, y2 = _f1(x, z) 11 | x_np = x.cpu().numpy().squeeze() 12 | y_np = _f2(x_np) 13 | assert np.allclose(y1.cpu().numpy(), y_np.real, atol=atol) 14 | assert np.allclose(y2.cpu().numpy(), y_np.imag, atol=atol) 15 | 16 | x0, z0 = _if1(y1, y2) 17 | x0_np = _if2(y_np) 18 | assert np.allclose(x0.cpu().numpy(), x0_np.real, atol=atol) 19 | assert np.allclose(z0.cpu().numpy(), x0_np.imag, atol=atol) 20 | 21 | 22 | def test_c2c(_f1, _f2, _if1, _if2): 23 | batch = 3 24 | nch = 4 25 | n = 5 26 | m = 7 27 | x = torch.randn(batch*nch*n*m).view(batch, nch, n, m).cuda() 28 | z = torch.zeros(batch, nch, n, m).cuda() 29 | run_c2c(x, z, _f1, _f2, _if1, _if2, 1e-6) 30 | run_c2c(x.double(), z.double(), _f1, _f2, _if1, _if2, 1e-14) 31 | 32 | 33 | 34 | def run_r2c(x, _f1, _f2, _if1, _if2, atol): 35 | y1, y2 = _f1(x) 36 | 
x_np = x.cpu().numpy().squeeze() 37 | y_np = _f2(x_np) 38 | assert np.allclose(y1.cpu().numpy(), y_np.real, atol=atol) 39 | assert np.allclose(y2.cpu().numpy(), y_np.imag, atol=atol) 40 | 41 | x0 = _if1(y1, y2) 42 | x0_np = _if2(y_np) 43 | assert np.allclose(x0.cpu().numpy(), x0_np.real, atol=atol) 44 | 45 | 46 | def test_r2c(_f1, _f2, _if1, _if2): 47 | batch = 3 48 | nch = 2 49 | n = 2 50 | m = 4 51 | x = torch.randn(batch*nch*n*m).view(batch, nch, n, m).cuda() 52 | run_r2c(x, _f1, _f2, _if1, _if2, 1e-6) 53 | run_r2c(x.double(), _f1, _f2, _if1, _if2, 1e-14) 54 | 55 | def test_expand(): 56 | X = torch.randn(2,2,4,4).cuda().double() 57 | zeros = torch.zeros(2,2,4,4).cuda().double() 58 | r1, r2 = cfft.rfft2(X) 59 | c1, c2 = cfft.fft2(X, zeros) 60 | assert np.allclose(cfft.expand(r1).cpu().numpy(), c1.cpu().numpy()) 61 | assert np.allclose(cfft.expand(r2, imag=True).cpu().numpy(), c2.cpu().numpy()) 62 | r1, r2 = cfft.rfft3(X) 63 | c1, c2 = cfft.fft3(X, zeros) 64 | assert np.allclose(cfft.expand(r1).cpu().numpy(), c1.cpu().numpy()) 65 | assert np.allclose(cfft.expand(r2, imag=True).cpu().numpy(), c2.cpu().numpy()) 66 | 67 | X = torch.randn(2,2,5,5).cuda().double() 68 | zeros = torch.zeros(2,2,5,5).cuda().double() 69 | r1, r2 = cfft.rfft3(X) 70 | c1, c2 = cfft.fft3(X, zeros) 71 | assert np.allclose(cfft.expand(r1, odd=True).cpu().numpy(), c1.cpu().numpy()) 72 | assert np.allclose(cfft.expand(r2, imag=True, odd=True).cpu().numpy(), c2.cpu().numpy()) 73 | 74 | def create_real_var(*args): 75 | return (torch.autograd.Variable(torch.randn(*args).double().cuda(), requires_grad=True),) 76 | 77 | def create_complex_var(*args): 78 | return (torch.autograd.Variable(torch.randn(*args).double().cuda(), requires_grad=True), 79 | torch.autograd.Variable(torch.randn(*args).double().cuda(), requires_grad=True)) 80 | 81 | def test_fft_gradcheck(): 82 | invar = create_complex_var(5,10) 83 | assert torch.autograd.gradcheck(afft.Fft(), invar) 84 | 85 | def test_ifft_gradcheck(): 86 | invar 
          = create_complex_var(5,10)
    assert torch.autograd.gradcheck(afft.Ifft(), invar)

# Each test below numerically verifies the autograd backward of one
# transform via torch.autograd.gradcheck on small double-precision inputs.
def test_fft2d_gradcheck():
    invar = create_complex_var(5,5,5)
    assert torch.autograd.gradcheck(afft.Fft2d(), invar)

def test_ifft2d_gradcheck():
    invar = create_complex_var(5,5,5)
    assert torch.autograd.gradcheck(afft.Ifft2d(), invar)

def test_fft3d_gradcheck():
    invar = create_complex_var(5,3,3,3)
    assert torch.autograd.gradcheck(afft.Fft3d(), invar)

def test_ifft3d_gradcheck():
    invar = create_complex_var(5,3,3,3)
    assert torch.autograd.gradcheck(afft.Ifft3d(), invar)

# The real transforms are checked at both an even and an odd last
# dimension, since the half-spectrum width differs between the two.
def test_rfft_gradcheck():
    invar = create_real_var(5,10)
    assert torch.autograd.gradcheck(afft.Rfft(), invar)

    invar = create_real_var(5,11)
    assert torch.autograd.gradcheck(afft.Rfft(), invar)

def test_rfft2d_gradcheck():
    invar = create_real_var(5,6,6)
    assert torch.autograd.gradcheck(afft.Rfft2d(), invar)

    invar = create_real_var(5,5,5)
    assert torch.autograd.gradcheck(afft.Rfft2d(), invar)

def test_rfft3d_gradcheck():
    invar = create_real_var(5,4,4,4)
    assert torch.autograd.gradcheck(afft.Rfft3d(), invar)

    invar = create_real_var(5,3,3,3)
    assert torch.autograd.gradcheck(afft.Rfft3d(), invar)

def test_irfft_gradcheck():
    invar = create_complex_var(5,11)
    assert torch.autograd.gradcheck(afft.Irfft(), invar)

def test_irfft2d_gradcheck():
    invar = create_complex_var(5,5,5)
    assert torch.autograd.gradcheck(afft.Irfft2d(), invar)

def test_irfft3d_gradcheck():
    invar = create_complex_var(5,3,3,3)
    assert torch.autograd.gradcheck(afft.Irfft3d(), invar)

if __name__ == "__main__":
    if torch.cuda.is_available():
        # numpy has no fft3/ifft3; build 3-D equivalents over the trailing axes.
        nfft3 = lambda x: nfft.fftn(x,axes=(1,2,3))
        nifft3 = lambda x: nfft.ifftn(x,axes=(1,2,3))

        cfs = [cfft.fft, cfft.fft2,
               cfft.fft3]
        nfs = [nfft.fft, nfft.fft2, nfft3]
        cifs = [cfft.ifft, cfft.ifft2, cfft.ifft3]
        nifs = [nfft.ifft, nfft.ifft2, nifft3]

        # Compare each CUDA complex transform (and its inverse) against
        # the corresponding numpy reference.
        for args in zip(cfs, nfs, cifs, nifs):
            test_c2c(*args)

        nrfft3 = lambda x: nfft.rfftn(x,axes=(1,2,3))
        nirfft3 = lambda x: nfft.irfftn(x,axes=(1,2,3))

        cfs = [cfft.rfft, cfft.rfft2, cfft.rfft3]
        nfs = [nfft.rfft, nfft.rfft2, nrfft3]
        cifs = [cfft.irfft, cfft.irfft2, cfft.irfft3]
        nifs = [nfft.irfft, nfft.irfft2, nirfft3]

        # Same comparison for the real-to-complex transforms.
        for args in zip(cfs, nfs, cifs, nifs):
            test_r2c(*args)

        test_expand()
        test_fft_gradcheck()
        test_ifft_gradcheck()
        test_fft2d_gradcheck()
        test_ifft2d_gradcheck()
        test_fft3d_gradcheck()
        test_ifft3d_gradcheck()

        test_rfft_gradcheck()
        test_irfft_gradcheck()
        test_rfft2d_gradcheck()
        test_irfft2d_gradcheck()
        test_rfft3d_gradcheck()
        test_irfft3d_gradcheck()
    else:
        print("Cuda not available, cannot test.")
--------------------------------------------------------------------------------