├── .gitignore ├── README.md ├── assets ├── coeff.png ├── lena.jpg ├── patagonia.jpg ├── runtime_benchmark.pdf └── runtime_benchmark.png ├── examples ├── benchmark_runtime.py ├── compare_both.py ├── example_lena.py ├── example_numpy.py └── example_pytorch.py ├── license.txt ├── requirements.txt ├── setup.py ├── steerable ├── SCFpyr_NumPy.py ├── SCFpyr_PyTorch.py ├── __init__.py ├── math_utils.py └── utils.py └── tests ├── test_ifft.py ├── test_torch_fft.py ├── test_torch_fftshift.py └── test_torch_fftshift2d.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom ignores 2 | .idea/ 3 | .vscode 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | 56 | # Sphinx documentation 57 | docs/_build/ 58 | 59 | # PyBuilder 60 | target/ 61 | 62 | #Ipython Notebook 63 | .ipynb_checkpoints -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Complex Steerable Pyramid in PyTorch 2 | 3 | This is a PyTorch implementation of the Complex Steerable Pyramid described in [Portilla and Simoncelli (IJCV, 2000)](http://www.cns.nyu.edu/~lcv/pubs/makeAbs.php?loc=Portilla99). 4 | 5 | It uses PyTorch's efficient spectral decomposition functions `torch.fft` and `torch.ifft`. Just like a normal convolution layer, the complex steerable pyramid expects a batch of images of shape `[N,C,H,W]`, and it currently supports only grayscale images (`C=1`). It returns a `list` structure containing the low-pass, high-pass and intermediate levels of the pyramid for each image in the batch (as `torch.Tensor`). Computing the steerable pyramid is significantly faster on the GPU, as the runtime benchmark below shows. 6 | 7 | 8 | 9 | ## Usage 10 | 11 | In addition to the PyTorch implementation defined in `SCFpyr_PyTorch`, the original NumPy/SciPy version is also included in `SCFpyr_NumPy` for completeness and comparison. As the GPU implementation benefits greatly from parallelization, the PyTorch `build` method expects image batches of shape `[N,1,H,W]` containing `N` grayscale images of shape `HxW`, whereas the NumPy version decomposes a single `[H,W]` image at a time.
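As a minimal sketch of preparing such a batch, the hypothetical helper `to_pyramid_batch` (not part of `steerable.utils`) stacks equally sized `HxW` grayscale arrays and adds the channel axis. It assumes 8-bit inputs that should be rescaled to `[0, 1]`; the scaling actually applied by `utils.load_image_batch` is not shown in this repository excerpt, so treat that choice as an assumption.

```python
import numpy as np


def to_pyramid_batch(images):
    """Stack grayscale images of identical shape [H,W] into an [N,1,H,W] float32 batch.

    Hypothetical helper: 8-bit pixel values are assumed and rescaled to [0, 1].
    """
    arrays = [np.asarray(im) for im in images]
    if any(a.ndim != 2 for a in arrays):
        raise ValueError('expected grayscale images of shape [H,W]')
    if len({a.shape for a in arrays}) != 1:
        raise ValueError('all images in a batch must share the same [H,W]')
    batch = np.stack(arrays, axis=0).astype(np.float32) / 255.0  # [N,H,W]
    return batch[:, None, :, :]  # insert channel axis C=1 -> [N,1,H,W]


# Usage with two random 64x64 8-bit "images"
rng = np.random.default_rng(0)
imgs = [rng.integers(0, 256, size=(64, 64), dtype=np.uint8) for _ in range(2)]
batch = to_pyramid_batch(imgs)
print(batch.shape, batch.dtype)  # (2, 1, 64, 64) float32
```

With PyTorch installed, such an array can be moved to the device with `torch.from_numpy(batch).to(device)` as in the usage example, while the NumPy pyramid takes one `[H,W]` slice at a time, e.g. `batch[i, 0]`.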
12 | 13 | ```python 14 | import cv2 15 | import torch 16 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch 17 | import steerable.utils as utils 18 | 19 | device = torch.device('cuda:0')  # 'cpu' requires a PyTorch build with MKL 20 | 21 | # Load batch of images [N,1,H,W] 22 | im_batch_numpy = utils.load_image_batch(...) 23 | im_batch_torch = torch.from_numpy(im_batch_numpy).to(device) 24 | # Initialize Complex Steerable Pyramid 25 | pyr = SCFpyr_PyTorch(height=5, nbands=4, scale_factor=2, device=device) 26 | 27 | # Decompose entire batch of images 28 | coeff = pyr.build(im_batch_torch) 29 | 30 | # Reconstruct batch of images again 31 | im_batch_reconstructed = pyr.reconstruct(coeff) 32 | 33 | # Visualization of the first example in the batch 34 | coeff_single = utils.extract_from_batch(coeff, 0) 35 | coeff_grid = utils.make_grid_coeff(coeff_single, normalize=True) 36 | cv2.imshow('Complex Steerable Pyramid', coeff_grid) 37 | cv2.waitKey(0) 38 | ``` 39 | 40 | ## Benchmark 41 | 42 | Performing the complex steerable pyramid (CSP) decomposition in parallel on the GPU using PyTorch results in a significant speed-up, and increasing the batch size further reduces the runtime per image. The plot below compares the `numpy` and `torch` implementations as a function of the batch size `N` and the input image size. These results were obtained on a Linux desktop with an NVIDIA Titan X GPU. 43 | 44 | 45 | 46 | ## Installation 47 | 48 | Clone and install: 49 | 50 | ```sh 51 | git clone https://github.com/tomrunia/PyTorchSteerablePyramid.git 52 | cd PyTorchSteerablePyramid 53 | pip install -r requirements.txt 54 | python setup.py install 55 | ``` 56 | 57 | ## Requirements 58 | 59 | - Python 2.7 or 3.6 (other versions might also work) 60 | - Numpy (developed with 1.15.4) 61 | - Scipy (developed with 1.1.0) 62 | - PyTorch >= 0.4.0 (developed with 1.0.0; see note below) 63 | 64 | The steerable pyramid utilizes `torch.fft` and `torch.ifft` to perform operations in the spectral domain.
At the moment, PyTorch only implements these operations for the GPU or with the MKL back-end on the CPU. Therefore, if you want to run the code on the CPU you might need to compile PyTorch from source with MKL enabled. Use `torch.backends.mkl.is_available()` to check if MKL is installed. 65 | 66 | ## References 67 | 68 | - [J. Portilla and E.P. Simoncelli, Complex Steerable Pyramid (IJCV, 2000)](http://www.cns.nyu.edu/pub/eero/portilla99-reprint.pdf) 69 | - [The Steerable Pyramid](http://www.cns.nyu.edu/~eero/steerpyr/) 70 | - [Official implementation: matPyrTools](http://www.cns.nyu.edu/~lcv/software.php) 71 | - [perceptual repository by Dzung Nguyen](https://github.com/andreydung/Steerable-filter) 72 | 73 | ## License 74 | 75 | MIT License 76 | 77 | Copyright (c) 2018 Tom Runia (tomrunia@gmail.com) 78 | 79 | Permission is hereby granted, free of charge, to any person obtaining a copy 80 | of this software and associated documentation files (the "Software"), to deal 81 | in the Software without restriction, including without limitation the rights 82 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 83 | copies of the Software, and to permit persons to whom the Software is 84 | furnished to do so, subject to the following conditions: 85 | 86 | The above copyright notice and this permission notice shall be included in all 87 | copies or substantial portions of the Software. 88 | 89 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 90 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 91 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 92 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 93 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 94 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 95 | SOFTWARE. 
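The Requirements note states that the pyramid performs its filtering in the spectral domain. The following NumPy sketch illustrates that pattern, modelled on the first high-pass/low-pass split in `SCFpyr_NumPy.build`: forward FFT, `fftshift`, multiplication by a radial mask, `ifftshift`, inverse FFT. The helpers `log_radius_grid` and `raised_cosine_masks` are hypothetical stand-ins and are not claimed to reproduce `math_utils.prepare_grid` or `math_utils.rcosFn` exactly. The one property borrowed from the source is that the low-pass mask is defined as `sqrt(1 - hi**2)`, so `lo**2 + hi**2 = 1`; applying each mask a second time and summing therefore recovers the input up to floating-point error.

```python
import numpy as np


def log_radius_grid(h, w):
    """log2 of the normalised frequency radius on an fftshift-ed [h, w] grid.

    Hypothetical helper (assumes w >= 2): the DC entry (radius 0) is replaced
    by its left neighbour so that log2 stays finite.
    """
    ys = (np.arange(h) - h // 2) / (h / 2)  # vertical frequency, DC at h // 2
    xs = (np.arange(w) - w // 2) / (w / 2)  # horizontal frequency, DC at w // 2
    yy, xx = np.meshgrid(ys, xs, indexing='ij')
    rad = np.sqrt(xx ** 2 + yy ** 2)
    rad[h // 2, w // 2] = rad[h // 2, w // 2 - 1]
    return np.log2(rad)


def raised_cosine_masks(log_rad, center=-0.5, width=1.0):
    """Low/high-pass masks with a smooth transition in log2-frequency.

    hi rises from 0 to 1 as log_rad goes from center - width/2 to
    center + width/2; lo = sqrt(1 - hi**2), so lo**2 + hi**2 == 1.
    """
    t = np.clip((log_rad - (center - width / 2)) / width, 0.0, 1.0)
    hi = np.sin(0.5 * np.pi * t)
    lo = np.sqrt(1.0 - hi ** 2)
    return lo, hi


def split_and_reconstruct(im):
    """Split a grayscale image into spectral low/high-pass parts and resynthesise it."""
    lo0mask, hi0mask = raised_cosine_masks(log_radius_grid(*im.shape))
    imdft = np.fft.fftshift(np.fft.fft2(im))  # zero frequency moved to the centre
    lo0dft = imdft * lo0mask
    hi0dft = imdft * hi0mask
    hi0 = np.fft.ifft2(np.fft.ifftshift(hi0dft)).real  # high-pass residual (spatial)
    # Synthesis: apply each mask a second time and sum the spectra
    recon_dft = lo0dft * lo0mask + hi0dft * hi0mask
    recon = np.fft.ifft2(np.fft.ifftshift(recon_dft)).real
    return hi0, recon


rng = np.random.default_rng(0)
im = rng.random((64, 48))
hi0, recon = split_and_reconstruct(im)
print('max reconstruction error:', np.abs(recon - im).max())
```

The full pyramid repeats this idea recursively on the low-pass spectrum, adding oriented angular masks and subsampling at each level; this sketch stops at the first split.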
-------------------------------------------------------------------------------- /assets/coeff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomrunia/PyTorchSteerablePyramid/e54981e7fcfd24263354d9c11fe70cb44457a594/assets/coeff.png -------------------------------------------------------------------------------- /assets/lena.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomrunia/PyTorchSteerablePyramid/e54981e7fcfd24263354d9c11fe70cb44457a594/assets/lena.jpg -------------------------------------------------------------------------------- /assets/patagonia.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomrunia/PyTorchSteerablePyramid/e54981e7fcfd24263354d9c11fe70cb44457a594/assets/patagonia.jpg -------------------------------------------------------------------------------- /assets/runtime_benchmark.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomrunia/PyTorchSteerablePyramid/e54981e7fcfd24263354d9c11fe70cb44457a594/assets/runtime_benchmark.pdf -------------------------------------------------------------------------------- /assets/runtime_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomrunia/PyTorchSteerablePyramid/e54981e7fcfd24263354d9c11fe70cb44457a594/assets/runtime_benchmark.png -------------------------------------------------------------------------------- /examples/benchmark_runtime.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the 
"Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-10 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import time 20 | import argparse 21 | 22 | import numpy as np 23 | import torch 24 | import matplotlib.pyplot as plt 25 | 26 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch 27 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy 28 | import steerable.utils as utils 29 | 30 | import cortex.plot 31 | cortex.plot.init_plotting() 32 | colors = cortex.plot.nature_colors() 33 | 34 | 35 | if __name__ == "__main__": 36 | 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg') 39 | parser.add_argument('--batch_sizes', type=str, default='1,8,16,32,64,128,256') 40 | parser.add_argument('--image_sizes', type=str, default='128,256,512') 41 | parser.add_argument('--num_runs', type=int, default='5') 42 | parser.add_argument('--pyr_nlevels', type=int, default='5') 43 | parser.add_argument('--pyr_nbands', type=int, default='4') 44 | parser.add_argument('--pyr_scale_factor', type=int, default='2') 45 | parser.add_argument('--device', type=str, default='cuda:0') 46 | config = parser.parse_args() 47 | 48 | config.batch_sizes = list(map(int, config.batch_sizes.split(','))) 49 | config.image_sizes = list(map(int, config.image_sizes.split(','))) 50 | 51 | device = utils.get_device(config.device) 52 | 53 | ################################################################################ 54 | 55 | pyr_numpy = SCFpyr_NumPy( 56 | height=config.pyr_nlevels, 57 | nbands=config.pyr_nbands, 58 | 
scale_factor=config.pyr_scale_factor, 59 | ) 60 | 61 | pyr_torch = SCFpyr_PyTorch( 62 | height=config.pyr_nlevels, 63 | nbands=config.pyr_nbands, 64 | scale_factor=config.pyr_scale_factor, 65 | device=device 66 | ) 67 | 68 | ############################################################################ 69 | # Run Benchmark 70 | 71 | durations_numpy = np.zeros((len(config.batch_sizes), len(config.image_sizes), config.num_runs)) 72 | durations_torch = np.zeros((len(config.batch_sizes), len(config.image_sizes), config.num_runs)) 73 | 74 | for batch_idx, batch_size in enumerate(config.batch_sizes): 75 | 76 | for size_idx, im_size in enumerate(config.image_sizes): 77 | 78 | for run_idx in range(config.num_runs): 79 | 80 | im_batch_numpy = utils.load_image_batch(config.image_file, batch_size, im_size) 81 | im_batch_torch = torch.from_numpy(im_batch_numpy).to(device) 82 | 83 | # NumPy implementation 84 | start_time = time.time() 85 | 86 | for image in im_batch_numpy: 87 | coeff = pyr_numpy.build(image[0,]) 88 | 89 | duration = time.time()-start_time 90 | durations_numpy[batch_idx,size_idx,run_idx] = duration 91 | print('BatchSize: {batch_size} | ImSize: {im_size} | NumPy Run {curr_run}/{num_run} | Duration: {duration:.3f} seconds.'.format( 92 | batch_size=batch_size, 93 | im_size=im_size, 94 | curr_run=run_idx+1, 95 | num_run=config.num_runs, 96 | duration=duration 97 | )) 98 | 99 | # PyTorch Implementation 100 | start_time = time.time() 101 | 102 | im_batch_torch = torch.from_numpy(im_batch_numpy).to(device) 103 | coeff = pyr_torch.build(im_batch_torch) 104 | 105 | duration = time.time()-start_time 106 | durations_torch[batch_idx,size_idx,run_idx] = duration 107 | print('BatchSize: {batch_size} | ImSize: {im_size} | Torch Run {curr_run}/{num_run} | Duration: {duration:.3f} seconds.'.format( 108 | batch_size=batch_size, 109 | im_size=im_size, 110 | curr_run=run_idx+1, 111 | num_run=config.num_runs, 112 | duration=duration 113 | )) 114 | 115 | 
115 | np.save('./assets/durations_numpy.npy', durations_numpy) 116 | np.save('./assets/durations_torch.npy', durations_torch) 117 | 118 | ################################################################################ 119 | # Plotting 120 | 121 | durations_numpy = np.load('./assets/durations_numpy.npy') 122 | durations_torch = np.load('./assets/durations_torch.npy') 123 | 124 | for i, num_examples in enumerate(config.batch_sizes): 125 | if num_examples == 8: continue 126 | avg_durations_numpy = np.mean(durations_numpy[i,:], -1) / num_examples 127 | avg_durations_torch = np.mean(durations_torch[i,:], -1) / num_examples 128 | plt.plot(config.image_sizes, avg_durations_numpy, marker='o', linestyle='-', lw=1.4, color=colors[i], label='numpy (N = {})'.format(num_examples)) 129 | plt.plot(config.image_sizes, avg_durations_torch, marker='d', linestyle='--', lw=1.4, color=colors[i], label='pytorch (N = {})'.format(num_examples)) 130 | 131 | plt.title('Runtime Benchmark ({} levels, {} bands, average {numruns} runs)'.format(config.pyr_nlevels, config.pyr_nbands, numruns=config.num_runs)) 132 | plt.xlabel('Image Size (px)') 133 | plt.ylabel('Time per Example (s)') 134 | plt.xlim((100, 550)) 135 | plt.ylim((-0.01, 0.2)) 136 | plt.xticks(config.image_sizes) 137 | plt.legend(ncol=2, loc='upper left') 138 | plt.tight_layout() 139 | 140 | plt.savefig('./assets/runtime_benchmark.png', dpi=600) 141 | plt.savefig('./assets/runtime_benchmark.pdf') 142 | plt.show() 143 | -------------------------------------------------------------------------------- /examples/compare_both.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify,
merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import torch 21 | import cv2 22 | 23 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy 24 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch 25 | import steerable.utils as utils 26 | 27 | ################################################################################ 28 | ################################################################################ 29 | # Common 30 | 31 | image_file = './assets/lena.jpg' 32 | image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE) 33 | image = cv2.resize(image, (200,200)) 34 | 35 | # Number of pyramid levels 36 | pyr_height = 5 37 | 38 | # Number of orientation bands 39 | pyr_nbands = 4 40 | 41 | # Tolerance for error checking 42 | tolerance = 1e-3 43 | 44 | ################################################################################ 45 | # NumPy 46 | 47 | pyr_numpy = SCFpyr_NumPy(pyr_height, pyr_nbands, scale_factor=2) 48 | coeff_numpy = pyr_numpy.build(image) 49 | reconstruction_numpy = pyr_numpy.reconstruct(coeff_numpy) 50 | reconstruction_numpy = reconstruction_numpy.astype(np.uint8) 51 | 52 | print('#'*60) 53 | 54 | ################################################################################ 55 | # PyTorch 56 | 57 | device = torch.device('cuda:0') 58 | 59 | im_batch = torch.from_numpy(image[None,None,:,:]) 60 | im_batch = im_batch.to(device).float() 61 | 62 | pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device) 63 | coeff_torch = pyr_torch.build(im_batch) 64 | reconstruction_torch = pyr_torch.reconstruct(coeff_torch) 65 | reconstruction_torch = reconstruction_torch.cpu().numpy()[0,] 66 | 67 | # Extract 
first example from the batch and move to CPU 68 | coeff_torch = utils.extract_from_batch(coeff_torch, 0) 69 | 70 | ################################################################################ 71 | # Check correctness 72 | 73 | print('#'*60) 74 | assert len(coeff_numpy) == len(coeff_torch) 75 | 76 | for level, _ in enumerate(coeff_numpy): 77 | 78 | print('Pyramid Level {level}'.format(level=level)) 79 | coeff_level_numpy = coeff_numpy[level] 80 | coeff_level_torch = coeff_torch[level] 81 | 82 | assert isinstance(coeff_level_torch, type(coeff_level_numpy)) 83 | 84 | if isinstance(coeff_level_numpy, np.ndarray): 85 | 86 | # Low- or High-Pass 87 | print(' NumPy. min = {min:.3f}, max = {max:.3f},' 88 | ' mean = {mean:.3f}, std = {std:.3f}'.format( 89 | min=np.min(coeff_level_numpy), max=np.max(coeff_level_numpy), 90 | mean=np.mean(coeff_level_numpy), std=np.std(coeff_level_numpy))) 91 | 92 | print(' PyTorch. min = {min:.3f}, max = {max:.3f},' 93 | ' mean = {mean:.3f}, std = {std:.3f}'.format( 94 | min=np.min(coeff_level_torch), max=np.max(coeff_level_torch), 95 | mean=np.mean(coeff_level_torch), std=np.std(coeff_level_torch))) 96 | 97 | # Check numerical correctness 98 | assert np.allclose(coeff_level_numpy, coeff_level_torch, atol=tolerance) 99 | 100 | elif isinstance(coeff_level_numpy, list): 101 | 102 | # Intermediate bands 103 | for band, _ in enumerate(coeff_level_numpy): 104 | 105 | band_numpy = coeff_level_numpy[band] 106 | band_torch = coeff_level_torch[band] 107 | 108 | print(' Orientation Band {}'.format(band)) 109 | print(' NumPy. min = {min:.3f}, max = {max:.3f},' 110 | ' mean = {mean:.3f}, std = {std:.3f}'.format( 111 | min=np.min(band_numpy), max=np.max(band_numpy), 112 | mean=np.mean(band_numpy), std=np.std(band_numpy))) 113 | 114 | print(' PyTorch. 
min = {min:.3f}, max = {max:.3f},' 115 | ' mean = {mean:.3f}, std = {std:.3f}'.format( 116 | min=np.min(band_torch), max=np.max(band_torch), 117 | mean=np.mean(band_torch), std=np.std(band_torch))) 118 | 119 | # Check numerical correctness 120 | assert np.allclose(band_numpy, band_torch, atol=tolerance) 121 | 122 | ################################################################################ 123 | # Visualize 124 | 125 | coeff_grid_numpy = utils.make_grid_coeff(coeff_numpy, normalize=False) 126 | coeff_grid_torch = utils.make_grid_coeff(coeff_torch, normalize=False) 127 | 128 | cv2.imshow('image', image) 129 | cv2.imshow('coeff numpy', np.ascontiguousarray(coeff_grid_numpy)) 130 | cv2.imshow('coeff torch', np.ascontiguousarray(coeff_grid_torch)) 131 | cv2.imshow('reconstruction numpy', reconstruction_numpy.astype(np.uint8)) 132 | cv2.imshow('reconstruction torch', reconstruction_torch.astype(np.uint8)) 133 | 134 | cv2.waitKey(0) 135 | -------------------------------------------------------------------------------- /examples/example_lena.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 
11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import argparse 20 | import cv2 21 | 22 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy 23 | import steerable.utils as utils 24 | 25 | ################################################################################ 26 | ################################################################################ 27 | 28 | if __name__ == "__main__": 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg') 32 | parser.add_argument('--batch_size', type=int, default='1') 33 | parser.add_argument('--image_size', type=int, default='200') 34 | parser.add_argument('--pyr_nlevels', type=int, default='5') 35 | parser.add_argument('--pyr_nbands', type=int, default='4') 36 | parser.add_argument('--pyr_scale_factor', type=int, default='2') 37 | parser.add_argument('--visualize', type=bool, default=True) 38 | config = parser.parse_args() 39 | 40 | ############################################################################ 41 | # Build the complex steerable pyramid 42 | 43 | pyr = SCFpyr_NumPy( 44 | height=config.pyr_nlevels, 45 | nbands=config.pyr_nbands, 46 | scale_factor=config.pyr_scale_factor, 47 | ) 48 | 49 | ############################################################################ 50 | # Create a batch and feed-forward 51 | 52 | image = cv2.imread('./assets/lena.jpg', cv2.IMREAD_GRAYSCALE) 53 | image = cv2.resize(image, dsize=(200,200)) 54 | 55 | coeff = pyr.build(image) 56 | 57 | grid = utils.make_grid_coeff(coeff, normalize=True) 58 | 59 | cv2.imwrite('./assets/coeff.png', grid) 60 | cv2.imshow('image', image) 61 | cv2.imshow('coeff', grid) 62 | cv2.waitKey(0) 63 | 64 | -------------------------------------------------------------------------------- /examples/example_numpy.py: 
-------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import time 20 | import argparse 21 | import numpy as np 22 | 23 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy 24 | import steerable.utils as utils 25 | 26 | ################################################################################ 27 | ################################################################################ 28 | 29 | if __name__ == "__main__": 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg') 33 | parser.add_argument('--batch_size', type=int, default='1') 34 | parser.add_argument('--image_size', type=int, default='200') 35 | parser.add_argument('--pyr_nlevels', type=int, default='5') 36 | parser.add_argument('--pyr_nbands', type=int, default='4') 37 | parser.add_argument('--pyr_scale_factor', type=int, default='2') 38 | parser.add_argument('--visualize', type=bool, default=True) 39 | config = parser.parse_args() 40 | 41 | ############################################################################ 42 | # Build the complex steerable pyramid 43 | 44 | pyr = SCFpyr_NumPy( 45 | height=config.pyr_nlevels, 46 | nbands=config.pyr_nbands, 47 | scale_factor=config.pyr_scale_factor, 48 | 
) 49 | 50 | ############################################################################ 51 | # Create a batch and feed-forward 52 | 53 | start_time = time.time() 54 | 55 | im_batch_numpy = utils.load_image_batch(config.image_file, config.batch_size, config.image_size) 56 | im_batch_numpy = im_batch_numpy.squeeze(1) # NumPy version expects [H,W] images (no channel dim) 57 | 58 | # Compute Steerable Pyramid 59 | start_time = time.time() 60 | for image in im_batch_numpy: 61 | coeff = pyr.build(image) 62 | 63 | duration = time.time()-start_time 64 | print('Finished decomposing {batch_size} images in {duration:.1f} seconds.'.format( 65 | batch_size=config.batch_size, 66 | duration=duration 67 | )) 68 | 69 | ############################################################################ 70 | # Visualization 71 | 72 | if config.visualize: 73 | import cv2 74 | coeff_grid = utils.make_grid_coeff(coeff, normalize=True) 75 | cv2.imshow('image', (im_batch_numpy[0,]*255.).astype(np.uint8)) 76 | cv2.imshow('coeff', coeff_grid) 77 | cv2.waitKey(0) 78 | 79 | -------------------------------------------------------------------------------- /examples/example_pytorch.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions.
11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import argparse 20 | import time 21 | import numpy as np 22 | import torch 23 | 24 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch 25 | import steerable.utils as utils 26 | 27 | ################################################################################ 28 | ################################################################################ 29 | 30 | if __name__ == "__main__": 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg') 34 | parser.add_argument('--batch_size', type=int, default='32') 35 | parser.add_argument('--image_size', type=int, default='200') 36 | parser.add_argument('--pyr_nlevels', type=int, default='5') 37 | parser.add_argument('--pyr_nbands', type=int, default='4') 38 | parser.add_argument('--pyr_scale_factor', type=int, default='2') 39 | parser.add_argument('--device', type=str, default='cuda:0') 40 | parser.add_argument('--visualize', type=bool, default=True) 41 | config = parser.parse_args() 42 | 43 | device = utils.get_device(config.device) 44 | 45 | ############################################################################ 46 | # Build the complex steerable pyramid 47 | 48 | pyr = SCFpyr_PyTorch( 49 | height=config.pyr_nlevels, 50 | nbands=config.pyr_nbands, 51 | scale_factor=config.pyr_scale_factor, 52 | device=device 53 | ) 54 | 55 | ############################################################################ 56 | # Create a batch and feed-forward 57 | 58 | start_time = time.time() 59 | 60 | # Load Batch 61 | im_batch_numpy = utils.load_image_batch(config.image_file, config.batch_size, config.image_size) 62 | im_batch_torch = torch.from_numpy(im_batch_numpy).to(device) 63 | 64 | # Compute Steerable Pyramid 65 | coeff = pyr.build(im_batch_torch) 66 | 67 | 
67 | duration = time.time()-start_time 68 | print('Finished decomposing {batch_size} images in {duration:.1f} seconds.'.format( 69 | batch_size=config.batch_size, 70 | duration=duration 71 | )) 72 | 73 | ############################################################################ 74 | # Visualization 75 | 76 | # Extract a single example from the batch; 77 | # this also moves it to the CPU as NumPy arrays 78 | coeff = utils.extract_from_batch(coeff, 0) 79 | 80 | if config.visualize: 81 | import cv2 82 | coeff_grid = utils.make_grid_coeff(coeff, normalize=True) 83 | cv2.imshow('image', (im_batch_numpy[0,0,]*255.).astype(np.uint8)) 84 | cv2.imshow('coeff', coeff_grid) 85 | cv2.waitKey(0) 86 | -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tom Runia (tomrunia@gmail.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | six 2 | numpy 3 | scipy 4 | torch >= 0.4.0 5 | matplotlib 6 | pillow -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from setuptools import setup 3 | 4 | setup( 5 | name='steerable_pytorch', 6 | version='0.1', 7 | author='Tom Runia', 8 | author_email='tomrunia@gmail.com', 9 | url='https://github.com/tomrunia/PyTorchSteerablePyramid', 10 | description='Complex Steerable Pyramids in PyTorch', 11 | long_description='Fast CPU/CUDA implementation of the Complex Steerable Pyramid in PyTorch.', 12 | license='MIT', 13 | packages=['steerable'], 14 | scripts=[] 15 | ) -------------------------------------------------------------------------------- /steerable/SCFpyr_NumPy.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions.
11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | from scipy.special import factorial 21 | 22 | import steerable.math_utils as math_utils 23 | pointOp = math_utils.pointOp 24 | 25 | ################################################################################ 26 | 27 | class SCFpyr_NumPy(): 28 | ''' 29 | This is a modified version of buildSFpyr that constructs a 30 | complex-valued steerable pyramid using Hilbert-transform pairs 31 | of filters. Note that the imaginary parts will *not* be steerable. 32 | 33 | Description of this transform appears in: Portilla & Simoncelli, 34 | International Journal of Computer Vision, 40(1):49-71, Oct 2000. 35 | Further information: http://www.cns.nyu.edu/~eero/STEERPYR/ 36 | 37 | Modified code from the perceptual repository: 38 | https://github.com/andreydung/Steerable-filter 39 | 40 | This code looks very similar to the original Matlab code: 41 | https://github.com/LabForComputationalVision/matlabPyrTools/blob/master/buildSCFpyr.m 42 | 43 | Also looks very similar to the original Python code presented here: 44 | https://github.com/LabForComputationalVision/pyPyrTools/blob/master/pyPyrTools/SCFpyr.py 45 | 46 | ''' 47 | 48 | def __init__(self, height=5, nbands=4, scale_factor=2): 49 | self.nbands = nbands # number of orientation bands 50 | self.height = height # including low-pass and high-pass 51 | self.scale_factor = scale_factor 52 | 53 | # Cache constants 54 | self.lutsize = 1024 55 | self.Xcosn = np.pi * np.array(range(-(2*self.lutsize+1), (self.lutsize+2)))/self.lutsize 56 | self.alpha = (self.Xcosn + np.pi) % (2*np.pi) - np.pi 57 | 58 | 59 | ################################################################################ 60 | # Construction of Steerable Pyramid 61 | 62 | def build(self, im): 63 | ''' Decomposes an image into its complex steerable
pyramid. 64 | The pyramid typically has ~4 levels and 4-8 orientations. 65 | 66 | Args: 67 | im (np.ndarray): single grayscale image of shape [H,W] 68 | 69 | Returns: 70 | pyramid: list containing np.ndarray objects storing the pyramid 71 | ''' 72 | 73 | assert len(im.shape) == 2, 'Input im must be grayscale' 74 | height, width = im.shape 75 | 76 | # Check whether image size is sufficient for number of levels 77 | if self.height > int(np.floor(np.log2(min(width, height))) - 2): 78 | raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height)) 79 | 80 | # Prepare a grid 81 | log_rad, angle = math_utils.prepare_grid(height, width) 82 | 83 | # Radial transition function (a raised cosine in log-frequency): 84 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5) 85 | Yrcos = np.sqrt(Yrcos) 86 | 87 | YIrcos = np.sqrt(1 - Yrcos**2) 88 | lo0mask = pointOp(log_rad, YIrcos, Xrcos) 89 | hi0mask = pointOp(log_rad, Yrcos, Xrcos) 90 | 91 | # Shift the zero-frequency component to the center of the spectrum.
92 | imdft = np.fft.fftshift(np.fft.fft2(im)) 93 | 94 | # Low-pass 95 | lo0dft = imdft * lo0mask 96 | 97 | # Recursively build the steerable pyramid 98 | coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height-1) 99 | 100 | # High-pass 101 | hi0dft = imdft * hi0mask 102 | hi0 = np.fft.ifft2(np.fft.ifftshift(hi0dft)) 103 | coeff.insert(0, hi0.real) 104 | return coeff 105 | 106 | 107 | def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): 108 | 109 | if height <= 1: 110 | 111 | # Low-pass 112 | lo0 = np.fft.ifftshift(lodft) 113 | lo0 = np.fft.ifft2(lo0) 114 | coeff = [lo0.real] 115 | 116 | else: 117 | 118 | Xrcos = Xrcos - np.log2(self.scale_factor) 119 | 120 | #################################################################### 121 | ####################### Orientation bandpass ####################### 122 | #################################################################### 123 | 124 | himask = pointOp(log_rad, Yrcos, Xrcos) 125 | 126 | order = self.nbands - 1 127 | const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order)) 128 | Ycosn = 2*np.sqrt(const) * np.power(np.cos(self.Xcosn), order) * (np.abs(self.alpha) < np.pi/2) 129 | 130 | # Loop through all orientation bands 131 | orientations = [] 132 | for b in range(self.nbands): 133 | anglemask = pointOp(angle, Ycosn, self.Xcosn + np.pi*b/self.nbands) 134 | banddft = np.power(complex(0, -1), self.nbands - 1) * lodft * anglemask * himask 135 | band = np.fft.ifft2(np.fft.ifftshift(banddft)) 136 | orientations.append(band) 137 | 138 | #################################################################### 139 | ######################## Subsample lowpass ######################### 140 | #################################################################### 141 | 142 | dims = np.array(lodft.shape) 143 | 144 | # Both are integer arrays of size 2 145 | low_ind_start = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(int) 146 | low_ind_end
= (low_ind_start + np.ceil((dims-0.5)/2)).astype(int) 147 | 148 | # Selection 149 | log_rad = log_rad[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]] 150 | angle = angle[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]] 151 | lodft = lodft[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]] 152 | 153 | # Subsampling in frequency domain 154 | YIrcos = np.abs(np.sqrt(1 - Yrcos**2)) 155 | lomask = pointOp(log_rad, YIrcos, Xrcos) 156 | lodft = lomask * lodft 157 | 158 | #################################################################### 159 | ####################### Recursion next level ####################### 160 | #################################################################### 161 | 162 | coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos, height-1) 163 | coeff.insert(0, orientations) 164 | 165 | return coeff 166 | 167 | ############################################################################ 168 | ########################### RECONSTRUCTION ################################# 169 | ############################################################################ 170 | 171 | def reconstruct(self, coeff): 172 | 173 | if self.nbands != len(coeff[1]): 174 | raise Exception("Unmatched number of orientations") 175 | 176 | height, width = coeff[0].shape 177 | log_rad, angle = math_utils.prepare_grid(height, width) 178 | 179 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5) 180 | Yrcos = np.sqrt(Yrcos) 181 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2)) 182 | 183 | lo0mask = pointOp(log_rad, YIrcos, Xrcos) 184 | hi0mask = pointOp(log_rad, Yrcos, Xrcos) 185 | 186 | tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle) 187 | 188 | hidft = np.fft.fftshift(np.fft.fft2(coeff[0])) 189 | outdft = tempdft * lo0mask + hidft * hi0mask 190 | 191 | reconstruction = np.fft.ifftshift(outdft) 192 | reconstruction = np.fft.ifft2(reconstruction) 193 | reconstruction = reconstruction.real 194 | 195 | return reconstruction 
196 | 197 | def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle): 198 | 199 | if len(coeff) == 1: 200 | dft = np.fft.fft2(coeff[0]) 201 | dft = np.fft.fftshift(dft) 202 | return dft 203 | 204 | Xrcos = Xrcos - np.log2(self.scale_factor) 205 | 206 | #################################################################### 207 | ####################### Orientation Residue ######################## 208 | #################################################################### 209 | 210 | himask = pointOp(log_rad, Yrcos, Xrcos) 211 | 212 | lutsize = 1024 213 | Xcosn = np.pi * np.array(range(-(2*lutsize+1), (lutsize+2)))/lutsize 214 | order = self.nbands - 1 215 | const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order)) 216 | Ycosn = np.sqrt(const) * np.power(np.cos(Xcosn), order) 217 | 218 | orientdft = np.zeros(coeff[0][0].shape) 219 | 220 | for b in range(self.nbands): 221 | anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands) 222 | banddft = np.fft.fft2(coeff[0][b]) 223 | banddft = np.fft.fftshift(banddft) 224 | orientdft = orientdft + np.power(complex(0, 1), order) * banddft * anglemask * himask 225 | 226 | #################################################################### 227 | ########## Lowpass components are upsampled and convolved ########## 228 | #################################################################### 229 | 230 | dims = np.array(coeff[0][0].shape) 231 | 232 | lostart = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(np.int32) 233 | loend = lostart + np.ceil((dims-0.5)/2).astype(np.int32) 234 | 235 | nlog_rad = log_rad[lostart[0]:loend[0], lostart[1]:loend[1]] 236 | nangle = angle[lostart[0]:loend[0], lostart[1]:loend[1]] 237 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2)) 238 | lomask = pointOp(nlog_rad, YIrcos, Xrcos) 239 | 240 | ################################################################################ 241 | 242 | # Recursive call for image reconstruction
nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos, nangle) 244 | 245 | resdft = np.zeros(dims, 'complex') 246 | resdft[lostart[0]:loend[0], lostart[1]:loend[1]] = nresdft * lomask 247 | 248 | return resdft + orientdft 249 | -------------------------------------------------------------------------------- /steerable/SCFpyr_PyTorch.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import torch 21 | from scipy.special import factorial 22 | 23 | import steerable.math_utils as math_utils 24 | pointOp = math_utils.pointOp 25 | 26 | ################################################################################ 27 | ################################################################################ 28 | 29 | 30 | class SCFpyr_PyTorch(object): 31 | ''' 32 | This is a modified version of buildSFpyr, that constructs a 33 | complex-valued steerable pyramid using Hilbert-transform pairs 34 | of filters. Note that the imaginary parts will *not* be steerable. 35 | 36 | Description of this transform appears in: Portilla & Simoncelli, 37 | International Journal of Computer Vision, 40(1):49-71, Oct 2000.
38 | Further information: http://www.cns.nyu.edu/~eero/STEERPYR/ 39 | 40 | Modified code from the perceptual repository: 41 | https://github.com/andreydung/Steerable-filter 42 | 43 | This code looks very similar to the original Matlab code: 44 | https://github.com/LabForComputationalVision/matlabPyrTools/blob/master/buildSCFpyr.m 45 | 46 | Also looks very similar to the original Python code presented here: 47 | https://github.com/LabForComputationalVision/pyPyrTools/blob/master/pyPyrTools/SCFpyr.py 48 | 49 | ''' 50 | 51 | def __init__(self, height=5, nbands=4, scale_factor=2, device=None): 52 | self.height = height # including low-pass and high-pass 53 | self.nbands = nbands # number of orientation bands 54 | self.scale_factor = scale_factor 55 | self.device = torch.device('cpu') if device is None else device 56 | 57 | # Cache constants 58 | self.lutsize = 1024 59 | self.Xcosn = np.pi * np.array(range(-(2*self.lutsize+1), (self.lutsize+2)))/self.lutsize 60 | self.alpha = (self.Xcosn + np.pi) % (2*np.pi) - np.pi 61 | self.complex_fact_construct = np.power(complex(0, -1), self.nbands-1) 62 | self.complex_fact_reconstruct = np.power(complex(0, 1), self.nbands-1) 63 | 64 | ################################################################################ 65 | # Construction of Steerable Pyramid 66 | 67 | def build(self, im_batch): 68 | ''' Decomposes a batch of images into a complex steerable pyramid. 69 | The pyramid typically has ~4 levels and 4-8 orientations.
70 | 71 | Args: 72 | im_batch (torch.Tensor): Batch of images of shape [N,C,H,W] 73 | 74 | Returns: 75 | pyramid: list containing torch.Tensor objects storing the pyramid 76 | ''' 77 | 78 | assert im_batch.device == self.device, 'Devices invalid (pyr = {}, batch = {})'.format(self.device, im_batch.device) 79 | assert im_batch.dtype == torch.float32, 'Image batch must be torch.float32' 80 | assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]' 81 | assert im_batch.shape[1] == 1, 'Second dimension must be 1 encoding grayscale image' 82 | 83 | im_batch = im_batch.squeeze(1) # remove the channel dim, giving [N,H,W] 84 | height, width = im_batch.shape[1], im_batch.shape[2] 85 | 86 | # Check whether image size is sufficient for number of levels 87 | if self.height > int(np.floor(np.log2(min(width, height))) - 2): 88 | raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height)) 89 | 90 | # Prepare a grid 91 | log_rad, angle = math_utils.prepare_grid(height, width) 92 | 93 | # Radial transition function (a raised cosine in log-frequency): 94 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5) 95 | Yrcos = np.sqrt(Yrcos) 96 | 97 | YIrcos = np.sqrt(1 - Yrcos**2) 98 | 99 | lo0mask = pointOp(log_rad, YIrcos, Xrcos) 100 | hi0mask = pointOp(log_rad, Yrcos, Xrcos) 101 | 102 | # Note that we expand dims to support broadcasting later 103 | lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device) 104 | hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device) 105 | 106 | # Fourier transform (2D) and shifting 107 | batch_dft = torch.rfft(im_batch, signal_ndim=2, onesided=False) 108 | batch_dft = math_utils.batch_fftshift2d(batch_dft) 109 | 110 | # Low-pass 111 | lo0dft = batch_dft * lo0mask 112 | 113 | # Start recursively building the pyramids 114 | coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height-1) 115 | 116 | # High-pass 117 | hi0dft = batch_dft * hi0mask 118 | hi0 = math_utils.batch_ifftshift2d(hi0dft) 119
| hi0 = torch.ifft(hi0, signal_ndim=2) 120 | hi0_real = torch.unbind(hi0, -1)[0] 121 | coeff.insert(0, hi0_real) 122 | return coeff 123 | 124 | def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): 125 | 126 | if height <= 1: 127 | 128 | # Low-pass 129 | lo0 = math_utils.batch_ifftshift2d(lodft) 130 | lo0 = torch.ifft(lo0, signal_ndim=2) 131 | lo0_real = torch.unbind(lo0, -1)[0] 132 | coeff = [lo0_real] 133 | 134 | else: 135 | 136 | Xrcos = Xrcos - np.log2(self.scale_factor) 137 | 138 | #################################################################### 139 | ####################### Orientation bandpass ####################### 140 | #################################################################### 141 | 142 | himask = pointOp(log_rad, Yrcos, Xrcos) 143 | himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device) 144 | 145 | order = self.nbands - 1 146 | const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order)) 147 | Ycosn = 2*np.sqrt(const) * np.power(np.cos(self.Xcosn), order) * (np.abs(self.alpha) < np.pi/2) # [n,] 148 | 149 | # Loop through all orientation bands 150 | orientations = [] 151 | for b in range(self.nbands): 152 | 153 | anglemask = pointOp(angle, Ycosn, self.Xcosn + np.pi*b/self.nbands) 154 | anglemask = anglemask[None,:,:,None] # for broadcasting 155 | anglemask = torch.from_numpy(anglemask).float().to(self.device) 156 | 157 | # Bandpass filtering 158 | banddft = lodft * anglemask * himask 159 | 160 | # Now multiply with complex number 161 | # (x+yi)(u+vi) = (xu-yv) + (xv+yu)i 162 | banddft = torch.unbind(banddft, -1) 163 | banddft_real = self.complex_fact_construct.real*banddft[0] - self.complex_fact_construct.imag*banddft[1] 164 | banddft_imag = self.complex_fact_construct.real*banddft[1] + self.complex_fact_construct.imag*banddft[0] 165 | banddft = torch.stack((banddft_real, banddft_imag), -1) 166 | 167 | band = math_utils.batch_ifftshift2d(banddft) 168 | band = 
torch.ifft(band, signal_ndim=2) 169 | orientations.append(band) 170 | 171 | #################################################################### 172 | ######################## Subsample lowpass ######################### 173 | #################################################################### 174 | 175 | # Don't consider batch_size and imag/real dim 176 | dims = np.array(lodft.shape[1:3]) 177 | 178 | # Both are integer arrays of size 2 179 | low_ind_start = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(int) 180 | low_ind_end = (low_ind_start + np.ceil((dims-0.5)/2)).astype(int) 181 | 182 | # Subsampling indices 183 | log_rad = log_rad[low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]] 184 | angle = angle[low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]] 185 | 186 | # Actual subsampling 187 | lodft = lodft[:,low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1],:] 188 | 189 | # Filtering 190 | YIrcos = np.abs(np.sqrt(1 - Yrcos**2)) 191 | lomask = pointOp(log_rad, YIrcos, Xrcos) 192 | lomask = torch.from_numpy(lomask[None,:,:,None]).float() 193 | lomask = lomask.to(self.device) 194 | 195 | # Multiplication in the Fourier domain (i.e. convolution in the spatial domain) 196 | lodft = lomask * lodft 197 | 198 | #################################################################### 199 | ####################### Recursion next level ####################### 200 | #################################################################### 201 | 202 | coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos, height-1) 203 | coeff.insert(0, orientations) 204 | 205 | return coeff 206 | 207 | ############################################################################ 208 | ########################### RECONSTRUCTION ################################# 209 | ############################################################################ 210 | 211 | def reconstruct(self, coeff): 212 | 213 | if self.nbands != len(coeff[1]): 214 | raise Exception("Unmatched number of
orientations") 215 | 216 | height, width = coeff[0].shape[1], coeff[0].shape[2] 217 | log_rad, angle = math_utils.prepare_grid(height, width) 218 | 219 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5) 220 | Yrcos = np.sqrt(Yrcos) 221 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2)) 222 | 223 | lo0mask = pointOp(log_rad, YIrcos, Xrcos) 224 | hi0mask = pointOp(log_rad, Yrcos, Xrcos) 225 | 226 | # Note that we expand dims to support broadcasting later 227 | lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device) 228 | hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device) 229 | 230 | # Start recursive reconstruction 231 | tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle) 232 | 233 | hidft = torch.rfft(coeff[0], signal_ndim=2, onesided=False) 234 | hidft = math_utils.batch_fftshift2d(hidft) 235 | 236 | outdft = tempdft * lo0mask + hidft * hi0mask 237 | 238 | reconstruction = math_utils.batch_ifftshift2d(outdft) 239 | reconstruction = torch.ifft(reconstruction, signal_ndim=2) 240 | reconstruction = torch.unbind(reconstruction, -1)[0] # real 241 | 242 | return reconstruction 243 | 244 | def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle): 245 | 246 | if len(coeff) == 1: 247 | dft = torch.rfft(coeff[0], signal_ndim=2, onesided=False) 248 | dft = math_utils.batch_fftshift2d(dft) 249 | return dft 250 | 251 | Xrcos = Xrcos - np.log2(self.scale_factor) 252 | 253 | #################################################################### 254 | ####################### Orientation Residue ######################## 255 | #################################################################### 256 | 257 | himask = pointOp(log_rad, Yrcos, Xrcos) 258 | himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device) 259 | 260 | lutsize = 1024 261 | Xcosn = np.pi * np.array(range(-(2*lutsize+1), (lutsize+2)))/lutsize 262 | order = self.nbands - 1 263 | const = np.power(2, 2*order) * np.square(factorial(order)) /
(self.nbands * factorial(2*order)) 264 | Ycosn = np.sqrt(const) * np.power(np.cos(Xcosn), order) 265 | 266 | orientdft = torch.zeros_like(coeff[0][0]) 267 | for b in range(self.nbands): 268 | 269 | anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands) 270 | anglemask = anglemask[None,:,:,None] # for broadcasting 271 | anglemask = torch.from_numpy(anglemask).float().to(self.device) 272 | 273 | banddft = torch.fft(coeff[0][b], signal_ndim=2) 274 | banddft = math_utils.batch_fftshift2d(banddft) 275 | 276 | banddft = banddft * anglemask * himask 277 | banddft = torch.unbind(banddft, -1) 278 | banddft_real = self.complex_fact_reconstruct.real*banddft[0] - self.complex_fact_reconstruct.imag*banddft[1] 279 | banddft_imag = self.complex_fact_reconstruct.real*banddft[1] + self.complex_fact_reconstruct.imag*banddft[0] 280 | banddft = torch.stack((banddft_real, banddft_imag), -1) 281 | 282 | orientdft = orientdft + banddft 283 | 284 | #################################################################### 285 | ########## Lowpass components are upsampled and convolved ########## 286 | #################################################################### 287 | 288 | dims = np.array(coeff[0][0].shape[1:3]) 289 | 290 | lostart = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(np.int32) 291 | loend = lostart + np.ceil((dims-0.5)/2).astype(np.int32) 292 | 293 | nlog_rad = log_rad[lostart[0]:loend[0], lostart[1]:loend[1]] 294 | nangle = angle[lostart[0]:loend[0], lostart[1]:loend[1]] 295 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2)) 296 | 297 | 298 | # Filtering 299 | lomask = pointOp(nlog_rad, YIrcos, Xrcos) 300 | lomask = torch.from_numpy(lomask[None,:,:,None]) 301 | lomask = lomask.float().to(self.device) 302 | 303 | ################################################################################ 304 | 305 | # Recursive call for image reconstruction 306 | nresdft = self._reconstruct_levels(coeff[1:], nlog_rad,
Xrcos, Yrcos, nangle) 307 | 308 | resdft = torch.zeros_like(coeff[0][0]).to(self.device) 309 | resdft[:,lostart[0]:loend[0], lostart[1]:loend[1],:] = nresdft * lomask 310 | 311 | return resdft + orientdft 312 | -------------------------------------------------------------------------------- /steerable/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomrunia/PyTorchSteerablePyramid/e54981e7fcfd24263354d9c11fe70cb44457a594/steerable/__init__.py -------------------------------------------------------------------------------- /steerable/math_utils.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 
11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import torch 21 | 22 | ################################################################################ 23 | ################################################################################ 24 | 25 | def roll_n(X, axis, n): 26 | f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim())) 27 | b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim())) 28 | front = X[f_idx] 29 | back = X[b_idx] 30 | return torch.cat([back, front], axis) 31 | 32 | def batch_fftshift2d(x): 33 | real, imag = torch.unbind(x, -1) 34 | for dim in range(1, len(real.size())): 35 | n_shift = real.size(dim)//2 36 | if real.size(dim) % 2 != 0: 37 | n_shift += 1 # for odd-sized images 38 | real = roll_n(real, axis=dim, n=n_shift) 39 | imag = roll_n(imag, axis=dim, n=n_shift) 40 | return torch.stack((real, imag), -1) # last dim=2 (real&imag) 41 | 42 | def batch_ifftshift2d(x): 43 | real, imag = torch.unbind(x, -1) 44 | for dim in range(len(real.size()) - 1, 0, -1): 45 | real = roll_n(real, axis=dim, n=real.size(dim)//2) 46 | imag = roll_n(imag, axis=dim, n=imag.size(dim)//2) 47 | return torch.stack((real, imag), -1) # last dim=2 (real&imag) 48 | 49 | ################################################################################ 50 | ################################################################################ 51 | 52 | def prepare_grid(m, n): 53 | x = np.linspace(-(m // 2)/(m / 2), (m // 2)/(m / 2) - (1 - m % 2)*2/m, num=m) 54 | y = np.linspace(-(n // 2)/(n / 2), (n // 2)/(n / 2) - (1 - n % 2)*2/n, num=n) 55 | xv, yv = np.meshgrid(y, x) 56 | angle = np.arctan2(yv, xv) 57 | rad = np.sqrt(xv**2 + yv**2) 58 | rad[m//2][n//2] = rad[m//2][n//2 - 1] 59 | log_rad = np.log2(rad) 60 | return 
log_rad, angle 61 | 62 | def rcosFn(width, position): 63 | N = 256 # arbitrary 64 | X = np.pi * np.array(range(-N-1, 2))/2/N 65 | Y = np.cos(X)**2 66 | Y[0] = Y[1] 67 | Y[N+2] = Y[N+1] 68 | X = position + 2*width/np.pi*(X + np.pi/4) 69 | return X, Y 70 | 71 | def pointOp(im, Y, X): 72 | out = np.interp(im.flatten(), X, Y) 73 | return np.reshape(out, im.shape) 74 | 75 | def getlist(coeff): 76 | straight = [bands for scale in coeff[1:-1] for bands in scale] 77 | straight = [coeff[0]] + straight + [coeff[-1]] 78 | return straight 79 | 80 | ################################################################################ 81 | # NumPy reference implementation (fftshift and ifftshift) 82 | 83 | # def fftshift(x, axes=None): 84 | # """ 85 | # Shift the zero-frequency component to the center of the spectrum. 86 | # This function swaps half-spaces for all axes listed (defaults to all). 87 | # Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even. 88 | # Parameters 89 | # """ 90 | # x = np.asarray(x) 91 | # if axes is None: 92 | # axes = tuple(range(x.ndim)) 93 | # shift = [dim // 2 for dim in x.shape] 94 | # shift = [x.shape[ax] // 2 for ax in axes] 95 | # return np.roll(x, shift, axes) 96 | # 97 | # def ifftshift(x, axes=None): 98 | # """ 99 | # The inverse of `fftshift`. Although identical for even-length `x`, the 100 | # functions differ by one sample for odd-length `x`.
101 | # """ 102 | # x = np.asarray(x) 103 | # if axes is None: 104 | # axes = tuple(range(x.ndim)) 105 | # shift = [-(dim // 2) for dim in x.shape] 106 | # shift = [-(x.shape[ax] // 2) for ax in axes] 107 | # return np.roll(x, shift, axes) 108 | -------------------------------------------------------------------------------- /steerable/utils.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-10 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import numpy as np 21 | import skimage.io 22 | import matplotlib.pyplot as plt 23 | 24 | import torch 25 | import torchvision 26 | 27 | ################################################################################ 28 | 29 | ToPIL = torchvision.transforms.ToPILImage() 30 | Grayscale = torchvision.transforms.Grayscale() 31 | RandomCrop = torchvision.transforms.RandomCrop 32 | 33 | def get_device(device='cuda:0'): 34 | assert isinstance(device, str) 35 | num_cuda = torch.cuda.device_count() 36 | 37 | if 'cuda' in device: 38 | if num_cuda > 0: 39 | # Found CUDA device, use the GPU 40 | return torch.device(device) 41 | # Fallback to CPU 42 | print('No CUDA devices found, falling back to CPU') 43 | device = 'cpu' 44 | 45 | if not torch.backends.mkl.is_available(): 46 | raise NotImplementedError( 47 | 'torch.fft on the CPU
requires MKL back-end. ' + 48 | 'Please recompile your PyTorch distribution.') 49 | return torch.device('cpu') 50 | 51 | def load_image_batch(image_file, batch_size, image_size=200): 52 | if not os.path.isfile(image_file): 53 | raise FileNotFoundError('Image file not found on disk: {}'.format(image_file)) 54 | im = ToPIL(skimage.io.imread(image_file)) 55 | im = Grayscale(im) 56 | im_batch = np.zeros((batch_size, image_size, image_size), np.float32) 57 | for i in range(batch_size): 58 | im_batch[i] = RandomCrop(image_size)(im) 59 | # insert channels dim and rescale 60 | return im_batch[:,None,:,:]/255. 61 | 62 | def show_image_batch(im_batch): 63 | assert isinstance(im_batch, torch.Tensor) 64 | im_batch = torchvision.utils.make_grid(im_batch.cpu()).numpy() 65 | im_batch = np.transpose(im_batch, (1,2,0)) # [C,H,W] => [H,W,C] 66 | plt.imshow(im_batch) 67 | plt.axis('off') 68 | plt.tight_layout() 69 | plt.show() 70 | return im_batch 71 | 72 | def extract_from_batch(coeff_batch, example_idx=0): 73 | ''' 74 | Given the batched Complex Steerable Pyramid, extract the coefficients 75 | for a single example from the batch. Additionally, it converts all 76 | torch.Tensor objects to np.ndarray objects and builds complex-valued 77 | arrays for all the orientation bands. 78 | 79 | Args: 80 | coeff_batch (list): list containing low-pass, high-pass and pyr levels 81 | example_idx (int, optional): Defaults to 0.
index in batch to extract 82 | 83 | Returns: 84 | list: list containing low-pass, high-pass and pyr levels as np.ndarray 85 | ''' 86 | if not isinstance(coeff_batch, list): 87 | raise ValueError('Batch of coefficients must be a list') 88 | coeff = [] # coefficients for a single example 89 | for coeff_level in coeff_batch: 90 | if isinstance(coeff_level, torch.Tensor): 91 | # Low- or High-Pass 92 | coeff_level_numpy = coeff_level[example_idx].cpu().numpy() 93 | coeff.append(coeff_level_numpy) 94 | elif isinstance(coeff_level, list): 95 | coeff_orientations_numpy = [] 96 | for coeff_orientation in coeff_level: 97 | coeff_orientation_numpy = coeff_orientation[example_idx].cpu().numpy() 98 | coeff_orientation_numpy = coeff_orientation_numpy[:,:,0] + 1j*coeff_orientation_numpy[:,:,1] 99 | coeff_orientations_numpy.append(coeff_orientation_numpy) 100 | coeff.append(coeff_orientations_numpy) 101 | else: 102 | raise ValueError('coeff level must be of type (list, torch.Tensor)') 103 | return coeff 104 | 105 | ################################################################################ 106 | 107 | def make_grid_coeff(coeff, normalize=True): 108 | ''' 109 | Visualization function for building a large image that contains the 110 | low-pass, high-pass and all intermediate levels in the steerable pyramid. 111 | For the complex intermediate bands, the real part is visualized. 112 | 113 | Args: 114 | coeff (list): complex pyramid stored as list containing all levels 115 | normalize (bool, optional): Defaults to True.
Whether to normalize each band 116 | 117 | Returns: 118 | np.ndarray: large image that contains grid of all bands and orientations 119 | ''' 120 | M, N = coeff[1][0].shape 121 | Norients = len(coeff[1]) 122 | out = np.zeros((M * 2 - coeff[-1].shape[0], Norients * N)) 123 | currentx, currenty = 0, 0 124 | 125 | for i in range(1, len(coeff[:-1])): 126 | for j in range(len(coeff[1])): 127 | tmp = coeff[i][j].real 128 | m, n = tmp.shape 129 | if normalize: 130 | tmp = 255 * tmp/tmp.max() 131 | tmp[m-1,:] = 255 132 | tmp[:,n-1] = 255 133 | out[currentx:currentx+m,currenty:currenty+n] = tmp 134 | currenty += n 135 | currentx += coeff[i][0].shape[0] 136 | currenty = 0 137 | 138 | m, n = coeff[-1].shape 139 | out[currentx: currentx+m, currenty: currenty+n] = 255 * coeff[-1]/coeff[-1].max() 140 | out[0,:] = 255 141 | out[:,0] = 255 142 | return out.astype(np.uint8) 143 | -------------------------------------------------------------------------------- /tests/test_ifft.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 
11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-07 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import torch 21 | import cv2 22 | 23 | import steerable.math_utils as fft_utils 24 | import matplotlib.pyplot as plt 25 | 26 | ################################################################################ 27 | 28 | tolerance = 1e-6 29 | 30 | image_file = './assets/lena.jpg' 31 | im = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE) 32 | im = cv2.resize(im, dsize=(200, 200)) 33 | im = im.astype(np.float64)/255. # note the np.float64 34 | 35 | ################################################################################ 36 | # NumPy 37 | 38 | fft_numpy = np.fft.fft2(im) 39 | fft_numpy = np.fft.fftshift(fft_numpy) 40 | 41 | fft_numpy_mag_viz = np.log10(np.abs(fft_numpy)) 42 | fft_numpy_ang_viz = np.angle(fft_numpy) 43 | 44 | ifft_numpy1 = np.fft.ifftshift(fft_numpy) 45 | ifft_numpy = np.fft.ifft2(ifft_numpy1) 46 | 47 | ################################################################################ 48 | # Torch 49 | 50 | device = torch.device('cpu') 51 | 52 | im_torch = torch.from_numpy(im[None,:,:]) # add batch dim 53 | im_torch = im_torch.to(device) 54 | 55 | # fft = complex-to-complex, rfft = real-to-complex 56 | fft_torch = torch.rfft(im_torch, signal_ndim=2, onesided=False) 57 | fft_torch = fft_utils.batch_fftshift2d(fft_torch) 58 | 59 | ifft_torch = fft_utils.batch_ifftshift2d(fft_torch) 60 | ifft_torch = torch.ifft(ifft_torch, signal_ndim=2, normalized=False) 61 | 62 | ifft_torch_to_numpy = ifft_torch.numpy() 63 | ifft_torch_to_numpy = np.split(ifft_torch_to_numpy, 2, -1) # complex => real/imag 64 | ifft_torch_to_numpy = np.squeeze(ifft_torch_to_numpy, -1) 65 | ifft_torch_to_numpy = ifft_torch_to_numpy[0] + 1j*ifft_torch_to_numpy[1] 66 | all_close_ifft = np.allclose(ifft_numpy, ifft_torch_to_numpy, atol=tolerance) 67 | print('ifft all close:
', all_close_ifft) 68 | 69 | fft_torch = fft_torch.cpu().numpy().squeeze() 70 | fft_torch = np.split(fft_torch, 2, -1) # complex => real/imag 71 | fft_torch = np.squeeze(fft_torch, -1) 72 | fft_torch = fft_torch[0] + 1j*fft_torch[1] 73 | 74 | ifft_torch = ifft_torch.cpu().numpy().squeeze() 75 | ifft_torch = np.split(ifft_torch, 2, -1) # complex => real/imag 76 | ifft_torch = np.squeeze(ifft_torch, -1) 77 | ifft_torch = ifft_torch[0] + 1j*ifft_torch[1] 78 | 79 | fft_torch_mag_viz = np.log10(np.abs(fft_torch)) 80 | fft_torch_ang_viz = np.angle(fft_torch) 81 | 82 | ################################################################################ 83 | # Tolerance checking 84 | 85 | all_close_real = np.allclose(np.real(fft_numpy), np.real(fft_torch), atol=tolerance) 86 | all_close_imag = np.allclose(np.imag(fft_numpy), np.imag(fft_torch), atol=tolerance) 87 | print('fft allclose real: {}'.format(all_close_real)) 88 | print('fft allclose imag: {}'.format(all_close_imag)) 89 | 90 | all_close_real = np.allclose(np.real(ifft_numpy), np.real(ifft_torch), atol=tolerance) 91 | all_close_imag = np.allclose(np.imag(ifft_numpy), np.imag(ifft_torch), atol=tolerance) 92 | print('ifft allclose real: {}'.format(all_close_real)) 93 | print('ifft allclose imag: {}'.format(all_close_imag)) 94 | 95 | ################################################################################ 96 | 97 | fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(12,6)) 98 | 99 | # Plotting NumPy results 100 | ax[0][0].imshow(im, cmap='gray') 101 | 102 | ax[0][1].imshow(fft_numpy_mag_viz, cmap='gray') 103 | ax[0][1].set_title('NumPy fft magnitude') 104 | ax[0][2].imshow(fft_numpy_ang_viz, cmap='gray') 105 | ax[0][2].set_title('NumPy fft spectrum') 106 | ax[0][3].imshow(ifft_numpy.real, cmap='gray') 107 | ax[0][3].set_title('NumPy ifft real') 108 | ax[0][4].imshow(ifft_numpy.imag, cmap='gray') 109 | ax[0][4].set_title('NumPy ifft imag') 110 | 111 | # Plotting PyTorch results 112 | ax[1][0].imshow(im, cmap='gray') 
ax[1][1].imshow(fft_torch_mag_viz, cmap='gray')
ax[1][1].set_title('PyTorch fft magnitude')
ax[1][2].imshow(fft_torch_ang_viz, cmap='gray')
ax[1][2].set_title('PyTorch fft phase')
ax[1][3].imshow(ifft_torch.real, cmap='gray')
ax[1][3].set_title('PyTorch ifft real')
ax[1][4].imshow(ifft_torch.imag, cmap='gray')
ax[1][4].set_title('PyTorch ifft imag')

for cur_ax in ax.flatten():
    cur_ax.axis('off')
plt.tight_layout()
plt.show()

--------------------------------------------------------------------------------
/tests/test_torch_fft.py:
--------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-12-04

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch

################################################################################

x = np.linspace(-10, 10, 100)
y = np.cos(np.pi*x)

################################################################################
# NumPy

y_fft_numpy = np.fft.fft(y)

################################################################################
# Torch

# use the first GPU when available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

y_torch = torch.from_numpy(y[None, :])
y_torch = y_torch.to(device)

# fft = complex-to-complex, rfft = real-to-complex
y_fft_torch = torch.rfft(y_torch, signal_ndim=1, onesided=False)
y_fft_torch = y_fft_torch.cpu().numpy().squeeze()
y_fft_torch = y_fft_torch[:, 0] + 1j*y_fft_torch[:, 1]

tolerance = 1e-6
all_close = np.allclose(y_fft_numpy, y_fft_torch, atol=tolerance)
print('numpy', y_fft_numpy.shape, y_fft_numpy.dtype)
print('torch', y_fft_torch.shape, y_fft_torch.dtype)
print('Successful: {}'.format(all_close))

--------------------------------------------------------------------------------
/tests/test_torch_fftshift.py:
--------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-12-04

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

################################################################################

x = np.linspace(-10, 10, 100)
y = np.cos(np.pi*x)

################################################################################
# NumPy

y_fft_numpy = np.fft.fft(y)
y_fft_numpy_shift = np.fft.fftshift(y_fft_numpy)
fft_numpy_real = np.real(y_fft_numpy_shift)
fft_numpy_imag = np.imag(y_fft_numpy_shift)

################################################################################
# Torch

# use the first GPU when available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

y_torch = torch.from_numpy(y[None, :])
y_torch = y_torch.to(device)

# fft = complex-to-complex, rfft = real-to-complex
y_fft_torch = torch.rfft(y_torch, signal_ndim=1, onesided=False)  # [1, 100, 2]
# fftshift along the frequency axis (dim 1); the last axis holds real/imag
y_fft_torch = torch.roll(y_fft_torch, shifts=y_fft_torch.size(1) // 2, dims=1)
y_fft_torch = y_fft_torch.cpu().numpy().squeeze()
fft_torch_real = y_fft_torch[:,0]
fft_torch_imag = y_fft_torch[:,1]

tolerance = 1e-6

all_close_real = np.allclose(fft_numpy_real, fft_torch_real, atol=tolerance)
all_close_imag = np.allclose(fft_numpy_imag, fft_torch_imag, atol=tolerance)

print('fftshift allclose real: {}'.format(all_close_real))
print('fftshift allclose imag: {}'.format(all_close_imag))

################################################################################

# colors from matplotlib's default property cycle
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(10,6))

# Plotting NumPy results
ax[0][0].plot(x, y, color=colors[3])
ax[0][1].plot(np.real(y_fft_numpy), color=colors[0])
ax[0][1].plot(fft_numpy_real, color=colors[1])
ax[0][1].set_title('NumPy Real')
ax[0][2].plot(np.imag(y_fft_numpy), color=colors[0])
ax[0][2].plot(fft_numpy_imag, color=colors[1])
ax[0][2].set_title('NumPy Imag')

# Plotting PyTorch results
ax[1][0].plot(x, y, color=colors[3])
ax[1][1].plot(fft_torch_real, color=colors[1])
ax[1][1].set_title('Torch Real')
ax[1][2].plot(fft_torch_imag, color=colors[1])
ax[1][2].set_title('Torch Imag')

plt.show()

--------------------------------------------------------------------------------
/tests/test_torch_fftshift2d.py:
--------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-12-04

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import cv2

from steerable.math_utils import batch_fftshift2d

################################################################################

image_file = './assets/lena.jpg'
im = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
im = cv2.resize(im, dsize=(200, 200))
im = im.astype(np.float64)/255.  # note the np.float64

################################################################################
# NumPy

fft_numpy = np.fft.fft2(im)
fft_numpy = np.fft.fftshift(fft_numpy)

fft_numpy_mag_viz = np.log10(np.abs(fft_numpy))
fft_numpy_ang_viz = np.angle(fft_numpy)

################################################################################
# Torch

# use the first GPU when available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

im_torch = torch.from_numpy(im[None,:,:])  # add batch dim
im_torch = im_torch.to(device)

# fft = complex-to-complex, rfft = real-to-complex
fft_torch = torch.rfft(im_torch, signal_ndim=2, onesided=False)  # [1, H, W, 2]
fft_torch = batch_fftshift2d(fft_torch)  # same call as in test_ifft.py
fft_torch = fft_torch.cpu().numpy().squeeze()
print(fft_torch.shape)
fft_torch = np.split(fft_torch, 2, -1)  # complex => real/imag
fft_torch = np.squeeze(fft_torch, -1)
fft_torch = fft_torch[0] + 1j*fft_torch[1]

print('fft_torch', fft_torch.shape, fft_torch.dtype)
fft_torch_mag_viz = np.log10(np.abs(fft_torch))
fft_torch_ang_viz = np.angle(fft_torch)

tolerance = 1e-6

print(fft_numpy.dtype, fft_torch.dtype)
all_close_real = np.allclose(np.real(fft_numpy), np.real(fft_torch), atol=tolerance)
all_close_imag = np.allclose(np.imag(fft_numpy), np.imag(fft_torch), atol=tolerance)

print('fftshift allclose real: {}'.format(all_close_real))
print('fftshift allclose imag: {}'.format(all_close_imag))

################################################################################

fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(10,6))

# Plotting NumPy results
ax[0][0].imshow(im, cmap='gray')

ax[0][1].imshow(fft_numpy_mag_viz, cmap='gray')
ax[0][1].set_title('NumPy Magnitude Spectrum')
ax[0][2].imshow(fft_numpy_ang_viz, cmap='gray')
ax[0][2].set_title('NumPy Phase Spectrum')

# Plotting PyTorch results
ax[1][0].imshow(im, cmap='gray')
ax[1][1].imshow(fft_torch_mag_viz, cmap='gray')
ax[1][1].set_title('Torch Magnitude Spectrum')
ax[1][2].imshow(fft_torch_ang_viz, cmap='gray')
ax[1][2].set_title('Torch Phase Spectrum')

for cur_ax in ax.flatten():
    cur_ax.axis('off')
plt.tight_layout()
plt.show()

--------------------------------------------------------------------------------
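The test scripts repeatedly move between the legacy PyTorch spectrum layout (a real tensor whose trailing axis of size 2 holds the real and imaginary parts, as returned by `torch.rfft(..., onesided=False)`) and NumPy complex arrays, and most of them also center the spectrum with an fftshift. The following is a minimal NumPy-only sketch of both steps, assuming an `[N, H, W, 2]` layout. The helper names (`complex_to_stacked`, `stacked_to_complex`, `fftshift2d_stacked`, `ifftshift2d_stacked`) are hypothetical, and this is not the `steerable.math_utils` implementation; it only checks itself against `np.fft.fftshift` and `np.fft.ifftshift`, including an odd image height where an off-by-one shift would show up.

```python
import numpy as np


def complex_to_stacked(z):
    """Complex array -> real array with a trailing axis of size 2 (real, imag)."""
    return np.stack((z.real, z.imag), axis=-1)


def stacked_to_complex(x):
    """Real array with a trailing (real, imag) axis -> complex array."""
    return x[..., 0] + 1j * x[..., 1]


def fftshift2d_stacked(x):
    """Center the zero frequency of an [N, H, W, 2] spectrum.

    Rolls H and W forward by floor(size / 2), the same shift that
    np.fft.fftshift applies, so odd sizes behave identically.
    """
    return np.roll(x, (x.shape[1] // 2, x.shape[2] // 2), axis=(1, 2))


def ifftshift2d_stacked(x):
    """Undo fftshift2d_stacked by rolling H and W back by floor(size / 2)."""
    return np.roll(x, (-(x.shape[1] // 2), -(x.shape[2] // 2)), axis=(1, 2))


# Usage: a batch of 3 random "images" with an odd height and an even width
rng = np.random.default_rng(0)
images = rng.random((3, 5, 6))

spectrum = np.fft.fft2(images, axes=(1, 2))
stacked = complex_to_stacked(spectrum)  # shape (3, 5, 6, 2)

shifted = fftshift2d_stacked(stacked)
same_as_numpy = np.allclose(stacked_to_complex(shifted),
                            np.fft.fftshift(spectrum, axes=(1, 2)))

# Shift back and invert: the original images should be recovered
restored = stacked_to_complex(ifftshift2d_stacked(shifted))
round_trip_ok = np.allclose(np.fft.ifft2(restored, axes=(1, 2)).real, images)
print(same_as_numpy, round_trip_ok)
```

Shifting by `floor(size / 2)` forward and back is what keeps even and odd sizes consistent with NumPy; a PyTorch port of the same idea could use `torch.roll` with identical shifts on the spatial dimensions.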