├── .gitignore
├── README.md
├── assets
│   ├── coeff.png
│   ├── lena.jpg
│   ├── patagonia.jpg
│   ├── runtime_benchmark.pdf
│   └── runtime_benchmark.png
├── examples
│   ├── benchmark_runtime.py
│   ├── compare_both.py
│   ├── example_lena.py
│   ├── example_numpy.py
│   └── example_pytorch.py
├── license.txt
├── requirements.txt
├── setup.py
├── steerable
│   ├── SCFpyr_NumPy.py
│   ├── SCFpyr_PyTorch.py
│   ├── __init__.py
│   ├── math_utils.py
│   └── utils.py
└── tests
    ├── test_ifft.py
    ├── test_torch_fft.py
    ├── test_torch_fftshift.py
    └── test_torch_fftshift2d.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom ignores
2 | .idea/
3 | .vscode
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # Distribution / packaging
11 | .Python
12 | env/
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *,cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 |
56 | # Sphinx documentation
57 | docs/_build/
58 |
59 | # PyBuilder
60 | target/
61 |
62 | #Ipython Notebook
63 | .ipynb_checkpoints
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Complex Steerable Pyramid in PyTorch
2 |
3 | This is a PyTorch implementation of the Complex Steerable Pyramid described in [Portilla and Simoncelli (IJCV, 2000)](http://www.cns.nyu.edu/~lcv/pubs/makeAbs.php?loc=Portilla99).
4 |
5 | It uses PyTorch's efficient spectral decomposition layers `torch.fft` and `torch.ifft`. Just like a normal convolution layer, the complex steerable pyramid expects a batch of images of shape `[N,C,H,W]` with current support only for grayscale images (`C=1`). It returns a `list` structure containing the low-pass, high-pass and intermediate levels of the pyramid for each image in the batch (as `torch.Tensor`). Computing the steerable pyramid is significantly faster on the GPU as can be observed from the runtime benchmark below.
6 |
7 |
8 |
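The nested list returned by the pyramid can be walked level by level. The sketch below assumes the layout that `examples/compare_both.py` checks for once a single example has been extracted from the batch: each entry is either one array (a high-pass or low-pass residual) or a list with one array per orientation band. The helper name `summarize_pyramid` and the stand-in shapes are hypothetical and only serve the illustration.

```python
import numpy as np

def summarize_pyramid(coeff):
    """Return one (level, band, shape) tuple per array in a coefficient list."""
    summary = []
    for level, entry in enumerate(coeff):
        if isinstance(entry, list):
            # Intermediate level: one array per orientation band
            for band, array in enumerate(entry):
                summary.append((level, band, array.shape))
        else:
            # High-pass or low-pass residual: a single array, no band index
            summary.append((level, None, entry.shape))
    return summary

# Stand-in data with made-up shapes; a real list would come from pyr.build(...)
fake_coeff = [
    np.zeros((8, 8)),
    [np.zeros((8, 8), dtype=complex) for _ in range(4)],
    [np.zeros((4, 4), dtype=complex) for _ in range(4)],
    np.zeros((2, 2)),
]

for row in summarize_pyramid(fake_coeff):
    print(row)
```

Checking the entry type rather than its position keeps the helper independent of how many levels the pyramid has.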
9 | ## Usage
10 |
11 | In addition to the PyTorch implementation in `SCFpyr_PyTorch`, the original NumPy/SciPy version is included in `SCFpyr_NumPy` for completeness and comparison. Because the GPU implementation benefits greatly from parallelization, the `build` method expects an image batch of shape `[N,1,H,W]`, i.e. `N` grayscale images of size `HxW`, and `reconstruct` maps the resulting coefficients back to a batch of images.
12 |
13 | ```python
14 | import cv2
15 | import torch
16 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
17 | import steerable.utils as utils
18 |
19 | # Requires PyTorch with MKL when setting to 'cpu'
20 | device = torch.device('cuda:0')
21 |
22 | # Load batch of images [N,1,H,W]
23 | im_batch_numpy = utils.load_image_batch(...)
24 | im_batch_torch = torch.from_numpy(im_batch_numpy).to(device)
25 |
26 | # Initialize Complex Steerable Pyramid
27 | pyr = SCFpyr_PyTorch(height=5, nbands=4, scale_factor=2, device=device)
28 |
29 | # Decompose entire batch of images, then reconstruct it again
30 | coeff = pyr.build(im_batch_torch)
31 | im_batch_reconstructed = pyr.reconstruct(coeff)
32 |
33 | # Visualization: extract the first example from the batch and plot it
34 | coeff_single = utils.extract_from_batch(coeff, 0)
35 | coeff_grid = utils.make_grid_coeff(coeff_single, normalize=True)
36 | cv2.imshow('Complex Steerable Pyramid', coeff_grid)
37 | cv2.waitKey(0)
38 | ```
39 |
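When the images are already in memory as separate 2-D arrays, the `[N,1,H,W]` layout used above can be assembled with plain NumPy. This is a minimal sketch, not the repository's `utils.load_image_batch`; the helper name `to_batch` is hypothetical, and any resizing or intensity scaling is left out.

```python
import numpy as np

def to_batch(images):
    """Stack equally sized 2-D grayscale arrays into one [N,1,H,W] float32 array."""
    arrays = [np.asarray(im, dtype=np.float32) for im in images]
    for array in arrays:
        # Only single-channel (grayscale) images are supported
        assert array.ndim == 2, 'each image must be a 2-D grayscale array'
    batch = np.stack(arrays, axis=0)   # [N,H,W]; raises if the sizes differ
    return batch[:, None, :, :]        # insert the channel axis (C=1)

batch = to_batch([np.zeros((4, 6)), np.ones((4, 6))])
print(batch.shape, batch.dtype)
```

The result can then be passed to `torch.from_numpy(...)` as in the usage example.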
40 | ## Benchmark
41 |
42 | Performing the CSP decomposition in parallel on the GPU using PyTorch results in a significant speed-up, and increasing the batch size lowers the runtime per image. The plot below compares the NumPy and PyTorch implementations as a function of the batch size `N` and the image size. These results were obtained on a Linux desktop with an NVIDIA Titan X GPU.
43 |
44 |
45 |
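The per-image numbers in such a benchmark are the mean duration over several runs divided by the batch size. The following sketch shows that bookkeeping with the standard library only (Python 3); `time_per_example` and the dummy workload are hypothetical stand-ins, not code from `examples/benchmark_runtime.py`.

```python
import time

def time_per_example(fn, batch_size, num_runs=5):
    """Average wall-clock time of fn() over num_runs, divided by batch_size."""
    durations = []
    for _ in range(num_runs):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return sum(durations) / len(durations) / batch_size

def dummy_workload():
    # Placeholder for a call such as pyr.build(im_batch)
    return sum(i * i for i in range(10000))

seconds = time_per_example(dummy_workload, batch_size=8, num_runs=3)
print('{:.6f} seconds per example'.format(seconds))
```

CUDA kernels are launched asynchronously, so a GPU measurement would typically call `torch.cuda.synchronize()` before reading the clock; otherwise the timer may stop before the work has finished.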
46 | ## Installation
47 |
48 | Clone and install:
49 |
50 | ```sh
51 | git clone https://github.com/tomrunia/PyTorchSteerablePyramid.git
52 | cd PyTorchSteerablePyramid
53 | pip install -r requirements.txt
54 | python setup.py install
55 | ```
56 |
57 | ## Requirements
58 |
59 | - Python 2.7 or 3.6 (other versions might also work)
60 | - Numpy (developed with 1.15.4)
61 | - Scipy (developed with 1.1.0)
62 | - PyTorch >= 0.4.0 (developed with 1.0.0; see note below)
63 |
64 | The steerable pyramid uses `torch.fft` and `torch.ifft` to perform operations in the spectral domain. At the moment, PyTorch only implements these operations for the GPU or with the MKL back-end on the CPU. Therefore, if you want to run the code on the CPU, you might need to compile PyTorch from source with MKL enabled. Use `torch.backends.mkl.is_available()` to check whether MKL is available.
65 |
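The back-end requirement above boils down to a small decision rule. The sketch below keeps that rule free of any PyTorch import so it can be read in isolation; `choose_device` is a hypothetical helper, and the two flags are assumed to come from `torch.cuda.is_available()` and `torch.backends.mkl.is_available()`.

```python
def choose_device(cuda_available, mkl_available):
    """Pick a device string for which the FFT operations are implemented."""
    if cuda_available:
        return 'cuda:0'
    if mkl_available:
        return 'cpu'
    raise RuntimeError('Neither CUDA nor an MKL-enabled CPU build is available.')

# With PyTorch installed, the flags would be obtained like this:
#   import torch
#   name = choose_device(torch.cuda.is_available(),
#                        torch.backends.mkl.is_available())
#   device = torch.device(name)
print(choose_device(cuda_available=False, mkl_available=True))
```

Failing early with a clear message avoids a less obvious error from the first FFT call.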
66 | ## References
67 |
68 | - [J. Portilla and E.P. Simoncelli, Complex Steerable Pyramid (IJCV, 2000)](http://www.cns.nyu.edu/pub/eero/portilla99-reprint.pdf)
69 | - [The Steerable Pyramid](http://www.cns.nyu.edu/~eero/steerpyr/)
70 | - [Official implementation: matPyrTools](http://www.cns.nyu.edu/~lcv/software.php)
71 | - [perceptual repository by Dzung Nguyen](https://github.com/andreydung/Steerable-filter)
72 |
73 | ## License
74 |
75 | MIT License
76 |
77 | Copyright (c) 2018 Tom Runia (tomrunia@gmail.com)
78 |
79 | Permission is hereby granted, free of charge, to any person obtaining a copy
80 | of this software and associated documentation files (the "Software"), to deal
81 | in the Software without restriction, including without limitation the rights
82 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
83 | copies of the Software, and to permit persons to whom the Software is
84 | furnished to do so, subject to the following conditions:
85 |
86 | The above copyright notice and this permission notice shall be included in all
87 | copies or substantial portions of the Software.
88 |
89 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
90 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
91 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
92 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
93 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
94 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
95 | SOFTWARE.
--------------------------------------------------------------------------------
/assets/coeff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomrunia/PyTorchSteerablePyramid/e54981e7fcfd24263354d9c11fe70cb44457a594/assets/coeff.png
--------------------------------------------------------------------------------
/assets/lena.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomrunia/PyTorchSteerablePyramid/e54981e7fcfd24263354d9c11fe70cb44457a594/assets/lena.jpg
--------------------------------------------------------------------------------
/assets/patagonia.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomrunia/PyTorchSteerablePyramid/e54981e7fcfd24263354d9c11fe70cb44457a594/assets/patagonia.jpg
--------------------------------------------------------------------------------
/assets/runtime_benchmark.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomrunia/PyTorchSteerablePyramid/e54981e7fcfd24263354d9c11fe70cb44457a594/assets/runtime_benchmark.pdf
--------------------------------------------------------------------------------
/assets/runtime_benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomrunia/PyTorchSteerablePyramid/e54981e7fcfd24263354d9c11fe70cb44457a594/assets/runtime_benchmark.png
--------------------------------------------------------------------------------
/examples/benchmark_runtime.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-10
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import time
20 | import argparse
21 |
22 | import numpy as np
23 | import torch
24 | import matplotlib.pyplot as plt
25 |
26 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
27 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy
28 | import steerable.utils as utils
29 |
30 | import cortex.plot
31 | cortex.plot.init_plotting()
32 | colors = cortex.plot.nature_colors()
33 |
34 |
35 | if __name__ == "__main__":
36 |
37 |     parser = argparse.ArgumentParser()
38 |     parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg')
39 |     parser.add_argument('--batch_sizes', type=str, default='1,8,16,32,64,128,256')
40 |     parser.add_argument('--image_sizes', type=str, default='128,256,512')
41 |     parser.add_argument('--num_runs', type=int, default=5)
42 |     parser.add_argument('--pyr_nlevels', type=int, default=5)
43 |     parser.add_argument('--pyr_nbands', type=int, default=4)
44 |     parser.add_argument('--pyr_scale_factor', type=int, default=2)
45 |     parser.add_argument('--device', type=str, default='cuda:0')
46 |     config = parser.parse_args()
47 |
48 |     config.batch_sizes = list(map(int, config.batch_sizes.split(',')))
49 |     config.image_sizes = list(map(int, config.image_sizes.split(',')))
50 |
51 |     device = utils.get_device(config.device)
52 |
53 |     ############################################################################
54 |
55 |     pyr_numpy = SCFpyr_NumPy(
56 |         height=config.pyr_nlevels,
57 |         nbands=config.pyr_nbands,
58 |         scale_factor=config.pyr_scale_factor,
59 |     )
60 |
61 |     pyr_torch = SCFpyr_PyTorch(
62 |         height=config.pyr_nlevels,
63 |         nbands=config.pyr_nbands,
64 |         scale_factor=config.pyr_scale_factor,
65 |         device=device
66 |     )
67 |
68 |     ############################################################################
69 |     # Run Benchmark
70 |
71 |     durations_numpy = np.zeros((len(config.batch_sizes), len(config.image_sizes), config.num_runs))
72 |     durations_torch = np.zeros((len(config.batch_sizes), len(config.image_sizes), config.num_runs))
73 |
74 |     for batch_idx, batch_size in enumerate(config.batch_sizes):
75 |
76 |         for size_idx, im_size in enumerate(config.image_sizes):
77 |
78 |             for run_idx in range(config.num_runs):
79 |
80 |                 im_batch_numpy = utils.load_image_batch(config.image_file, batch_size, im_size)
81 |                 im_batch_torch = torch.from_numpy(im_batch_numpy).to(device)
82 |
83 |                 # NumPy implementation (one image at a time)
84 |                 start_time = time.time()
85 |
86 |                 for image in im_batch_numpy:
87 |                     coeff = pyr_numpy.build(image[0,])
88 |
89 |                 duration = time.time()-start_time
90 |                 durations_numpy[batch_idx,size_idx,run_idx] = duration
91 |                 print('BatchSize: {batch_size} | ImSize: {im_size} | NumPy Run {curr_run}/{num_run} | Duration: {duration:.3f} seconds.'.format(
92 |                     batch_size=batch_size,
93 |                     im_size=im_size,
94 |                     curr_run=run_idx+1,
95 |                     num_run=config.num_runs,
96 |                     duration=duration
97 |                 ))
98 |
99 |                 # PyTorch implementation (timing includes the host-to-device copy)
100 |                 start_time = time.time()
101 |
102 |                 im_batch_torch = torch.from_numpy(im_batch_numpy).to(device)
103 |                 coeff = pyr_torch.build(im_batch_torch)
104 |
105 |                 duration = time.time()-start_time
106 |                 durations_torch[batch_idx,size_idx,run_idx] = duration
107 |                 print('BatchSize: {batch_size} | ImSize: {im_size} | Torch Run {curr_run}/{num_run} | Duration: {duration:.3f} seconds.'.format(
108 |                     batch_size=batch_size,
109 |                     im_size=im_size,
110 |                     curr_run=run_idx+1,
111 |                     num_run=config.num_runs,
112 |                     duration=duration
113 |                 ))
114 |
115 |     np.save('./assets/durations_numpy.npy', durations_numpy)
116 |     np.save('./assets/durations_torch.npy', durations_torch)
117 |
118 |     ############################################################################
119 |     # Plotting
120 |
121 |     durations_numpy = np.load('./assets/durations_numpy.npy')
122 |     durations_torch = np.load('./assets/durations_torch.npy')
123 |
124 |     for i, num_examples in enumerate(config.batch_sizes):
125 |         if num_examples == 8: continue
126 |         avg_durations_numpy = np.mean(durations_numpy[i,:], -1) / num_examples
127 |         avg_durations_torch = np.mean(durations_torch[i,:], -1) / num_examples
128 |         plt.plot(config.image_sizes, avg_durations_numpy, marker='o', linestyle='-', lw=1.4, color=colors[i], label='numpy (N = {})'.format(num_examples))
129 |         plt.plot(config.image_sizes, avg_durations_torch, marker='d', linestyle='--', lw=1.4, color=colors[i], label='pytorch (N = {})'.format(num_examples))
130 |
131 |     plt.title('Runtime Benchmark ({} levels, {} bands, average {numruns} runs)'.format(config.pyr_nlevels, config.pyr_nbands, numruns=config.num_runs))
132 |     plt.xlabel('Image Size (px)')
133 |     plt.ylabel('Time per Example (s)')
134 |     plt.xlim((100, 550))
135 |     plt.ylim((-0.01, 0.2))
136 |     plt.xticks(config.image_sizes)
137 |     plt.legend(ncol=2, loc='upper left')  # 'top left' is not a valid matplotlib location
138 |     plt.tight_layout()
139 |
140 |     plt.savefig('./assets/runtime_benchmark.png', dpi=600)
141 |     plt.savefig('./assets/runtime_benchmark.pdf')
142 |     plt.show()
143 |
--------------------------------------------------------------------------------
/examples/compare_both.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import torch
21 | import cv2
22 |
23 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy
24 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
25 | import steerable.utils as utils
26 |
27 | ################################################################################
28 | ################################################################################
29 | # Common
30 |
31 | image_file = './assets/lena.jpg'
32 | image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
33 | image = cv2.resize(image, (200,200))
34 |
35 | # Number of pyramid levels
36 | pyr_height = 5
37 |
38 | # Number of orientation bands
39 | pyr_nbands = 4
40 |
41 | # Tolerance for error checking
42 | tolerance = 1e-3
43 |
44 | ################################################################################
45 | # NumPy
46 |
47 | pyr_numpy = SCFpyr_NumPy(pyr_height, pyr_nbands, scale_factor=2)
48 | coeff_numpy = pyr_numpy.build(image)
49 | reconstruction_numpy = pyr_numpy.reconstruct(coeff_numpy)
50 | reconstruction_numpy = reconstruction_numpy.astype(np.uint8)
51 |
52 | print('#'*60)
53 |
54 | ################################################################################
55 | # PyTorch
56 |
57 | device = torch.device('cuda:0')
58 |
59 | im_batch = torch.from_numpy(image[None,None,:,:])
60 | im_batch = im_batch.to(device).float()
61 |
62 | pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device)
63 | coeff_torch = pyr_torch.build(im_batch)
64 | reconstruction_torch = pyr_torch.reconstruct(coeff_torch)
65 | reconstruction_torch = reconstruction_torch.cpu().numpy()[0,]
66 |
67 | # Extract first example from the batch and move to CPU
68 | coeff_torch = utils.extract_from_batch(coeff_torch, 0)
69 |
70 | ################################################################################
71 | # Check correctness
72 |
73 | print('#'*60)
74 | assert len(coeff_numpy) == len(coeff_torch)
75 |
76 | for level, _ in enumerate(coeff_numpy):
77 |
78 |     print('Pyramid Level {level}'.format(level=level))
79 |     coeff_level_numpy = coeff_numpy[level]
80 |     coeff_level_torch = coeff_torch[level]
81 |
82 |     assert isinstance(coeff_level_torch, type(coeff_level_numpy))
83 |
84 |     if isinstance(coeff_level_numpy, np.ndarray):
85 |
86 |         # Low- or High-Pass
87 |         print(' NumPy. min = {min:.3f}, max = {max:.3f},'
88 |               ' mean = {mean:.3f}, std = {std:.3f}'.format(
89 |                   min=np.min(coeff_level_numpy), max=np.max(coeff_level_numpy),
90 |                   mean=np.mean(coeff_level_numpy), std=np.std(coeff_level_numpy)))
91 |
92 |         print(' PyTorch. min = {min:.3f}, max = {max:.3f},'
93 |               ' mean = {mean:.3f}, std = {std:.3f}'.format(
94 |                   min=np.min(coeff_level_torch), max=np.max(coeff_level_torch),
95 |                   mean=np.mean(coeff_level_torch), std=np.std(coeff_level_torch)))
96 |
97 |         # Check numerical correctness
98 |         assert np.allclose(coeff_level_numpy, coeff_level_torch, atol=tolerance)
99 |
100 |     elif isinstance(coeff_level_numpy, list):
101 |
102 |         # Intermediate bands
103 |         for band, _ in enumerate(coeff_level_numpy):
104 |
105 |             band_numpy = coeff_level_numpy[band]
106 |             band_torch = coeff_level_torch[band]
107 |
108 |             print(' Orientation Band {}'.format(band))
109 |             print(' NumPy. min = {min:.3f}, max = {max:.3f},'
110 |                   ' mean = {mean:.3f}, std = {std:.3f}'.format(
111 |                       min=np.min(band_numpy), max=np.max(band_numpy),
112 |                       mean=np.mean(band_numpy), std=np.std(band_numpy)))
113 |
114 |             print(' PyTorch. min = {min:.3f}, max = {max:.3f},'
115 |                   ' mean = {mean:.3f}, std = {std:.3f}'.format(
116 |                       min=np.min(band_torch), max=np.max(band_torch),
117 |                       mean=np.mean(band_torch), std=np.std(band_torch)))
118 |
119 |             # Check numerical correctness
120 |             assert np.allclose(band_numpy, band_torch, atol=tolerance)
121 |
122 | ################################################################################
123 | # Visualize
124 |
125 | coeff_grid_numpy = utils.make_grid_coeff(coeff_numpy, normalize=False)
126 | coeff_grid_torch = utils.make_grid_coeff(coeff_torch, normalize=False)
127 |
128 | cv2.imshow('image', image)
129 | cv2.imshow('coeff numpy', np.ascontiguousarray(coeff_grid_numpy))
130 | cv2.imshow('coeff torch', np.ascontiguousarray(coeff_grid_torch))
131 | cv2.imshow('reconstruction numpy', reconstruction_numpy.astype(np.uint8))
132 | cv2.imshow('reconstruction torch', reconstruction_torch.astype(np.uint8))
133 |
134 | cv2.waitKey(0)
135 |
--------------------------------------------------------------------------------
/examples/example_lena.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import argparse
20 | import cv2
21 |
22 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy
23 | import steerable.utils as utils
24 |
25 | ################################################################################
26 | ################################################################################
27 |
28 | if __name__ == "__main__":
29 |
30 |     parser = argparse.ArgumentParser()
31 |     parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg')
32 |     parser.add_argument('--batch_size', type=int, default='1')
33 |     parser.add_argument('--image_size', type=int, default='200')
34 |     parser.add_argument('--pyr_nlevels', type=int, default='5')
35 |     parser.add_argument('--pyr_nbands', type=int, default='4')
36 |     parser.add_argument('--pyr_scale_factor', type=int, default='2')
37 |     parser.add_argument('--visualize', type=bool, default=True)
38 |     config = parser.parse_args()
39 |
40 |     ############################################################################
41 |     # Build the complex steerable pyramid
42 |
43 |     pyr = SCFpyr_NumPy(
44 |         height=config.pyr_nlevels,
45 |         nbands=config.pyr_nbands,
46 |         scale_factor=config.pyr_scale_factor,
47 |     )
48 |
49 |     ############################################################################
50 |     # Create a batch and feed-forward
51 |
52 |     image = cv2.imread('./assets/lena.jpg', cv2.IMREAD_GRAYSCALE)
53 |     image = cv2.resize(image, dsize=(200,200))
54 |
55 |     coeff = pyr.build(image)
56 |
57 |     grid = utils.make_grid_coeff(coeff, normalize=True)
58 |
59 |     cv2.imwrite('./assets/coeff.png', grid)
60 |     cv2.imshow('image', image)
61 |     cv2.imshow('coeff', grid)
62 |     cv2.waitKey(0)
63 |
64 |
--------------------------------------------------------------------------------
/examples/example_numpy.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import time
20 | import argparse
21 | import numpy as np
22 |
23 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy
24 | import steerable.utils as utils
25 |
26 | ################################################################################
27 | ################################################################################
28 |
29 | if __name__ == "__main__":
30 |
31 |     parser = argparse.ArgumentParser()
32 |     parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg')
33 |     parser.add_argument('--batch_size', type=int, default='1')
34 |     parser.add_argument('--image_size', type=int, default='200')
35 |     parser.add_argument('--pyr_nlevels', type=int, default='5')
36 |     parser.add_argument('--pyr_nbands', type=int, default='4')
37 |     parser.add_argument('--pyr_scale_factor', type=int, default='2')
38 |     parser.add_argument('--visualize', type=bool, default=True)
39 |     config = parser.parse_args()
40 |
41 |     ############################################################################
42 |     # Build the complex steerable pyramid
43 |
44 |     pyr = SCFpyr_NumPy(
45 |         height=config.pyr_nlevels,
46 |         nbands=config.pyr_nbands,
47 |         scale_factor=config.pyr_scale_factor,
48 |     )
49 |
50 |     ############################################################################
51 |     # Create a batch and feed-forward
52 |
53 |     start_time = time.time()
54 |
55 |     im_batch_numpy = utils.load_image_batch(config.image_file, config.batch_size, config.image_size)
56 |     im_batch_numpy = im_batch_numpy.squeeze(1)  # no channel dim for NumPy
57 |
58 |     # Compute Steerable Pyramid (timer restarts here, so loading is not measured)
59 |     start_time = time.time()
60 |     for image in im_batch_numpy:
61 |         coeff = pyr.build(image)
62 |
63 |     duration = time.time()-start_time
64 |     print('Finished decomposing {batch_size} images in {duration:.1f} seconds.'.format(
65 |         batch_size=config.batch_size,
66 |         duration=duration
67 |     ))
68 |
69 |     ############################################################################
70 |     # Visualization
71 |
72 |     if config.visualize:
73 |         import cv2
74 |         coeff_grid = utils.make_grid_coeff(coeff, normalize=True)
75 |         cv2.imshow('image', (im_batch_numpy[0,]*255.).astype(np.uint8))
76 |         cv2.imshow('coeff', coeff_grid)
77 |         cv2.waitKey(0)
78 |
79 |
--------------------------------------------------------------------------------
/examples/example_pytorch.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import argparse
20 | import time
21 | import numpy as np
22 | import torch
23 |
24 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
25 | import steerable.utils as utils
26 |
27 | ################################################################################
28 | ################################################################################
29 |
30 | if __name__ == "__main__":
31 |
32 |     parser = argparse.ArgumentParser()
33 |     parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg')
34 |     parser.add_argument('--batch_size', type=int, default='32')
35 |     parser.add_argument('--image_size', type=int, default='200')
36 |     parser.add_argument('--pyr_nlevels', type=int, default='5')
37 |     parser.add_argument('--pyr_nbands', type=int, default='4')
38 |     parser.add_argument('--pyr_scale_factor', type=int, default='2')
39 |     parser.add_argument('--device', type=str, default='cuda:0')
40 |     parser.add_argument('--visualize', type=bool, default=True)
41 |     config = parser.parse_args()
42 |
43 |     device = utils.get_device(config.device)
44 |
45 |     ############################################################################
46 |     # Build the complex steerable pyramid
47 |
48 |     pyr = SCFpyr_PyTorch(
49 |         height=config.pyr_nlevels,
50 |         nbands=config.pyr_nbands,
51 |         scale_factor=config.pyr_scale_factor,
52 |         device=device
53 |     )
54 |
55 |     ############################################################################
56 |     # Create a batch and feed-forward
57 |
58 |     start_time = time.time()
59 |
60 |     # Load Batch
61 |     im_batch_numpy = utils.load_image_batch(config.image_file, config.batch_size, config.image_size)
62 |     im_batch_torch = torch.from_numpy(im_batch_numpy).to(device)
63 |
64 |     # Compute Steerable Pyramid
65 |     coeff = pyr.build(im_batch_torch)
66 |
67 |     duration = time.time()-start_time
68 |     print('Finished decomposing {batch_size} images in {duration:.1f} seconds.'.format(
69 |         batch_size=config.batch_size,
70 |         duration=duration
71 |     ))
72 |
73 |     ############################################################################
74 |     # Visualization
75 |
76 |     # Just extract a single example from the batch
77 |     # Also moves the example to CPU and NumPy
78 |     coeff = utils.extract_from_batch(coeff, 0)
79 |
80 |     if config.visualize:
81 |         import cv2
82 |         coeff_grid = utils.make_grid_coeff(coeff, normalize=True)
83 |         cv2.imshow('image', (im_batch_numpy[0,0,]*255.).astype(np.uint8))
84 |         cv2.imshow('coeff', coeff_grid)
85 |         cv2.waitKey(0)
86 |
--------------------------------------------------------------------------------
/license.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Tom Runia (tomrunia@gmail.com)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | six
2 | numpy
3 | scipy
4 | torch >= 0.4.0
5 | matplotlib
6 | pillow
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from setuptools import setup
3 |
4 | setup(
5 |     name='steerable_pytorch',
6 |     version='0.1',
7 |     author='Tom Runia',
8 |     author_email='tomrunia@gmail.com',
9 |     url='https://github.com/tomrunia/PyTorchSteerablePyramid',
10 |     description='Complex Steerable Pyramids in PyTorch',
11 |     long_description='Fast CPU/CUDA implementation of the Complex Steerable Pyramid in PyTorch.',
12 |     license='MIT',
13 |     packages=['steerable'],  # must match the package directory in the repository
14 |     scripts=[]
15 | )
--------------------------------------------------------------------------------
/steerable/SCFpyr_NumPy.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | from scipy.special import factorial  # scipy.misc.factorial was removed in newer SciPy releases
21 |
22 | import steerable.math_utils as math_utils
23 | pointOp = math_utils.pointOp
24 |
25 | ################################################################################
26 |
27 | class SCFpyr_NumPy():
28 | '''
29 | This is a modified version of buildSFpyr, that constructs a
30 | complex-valued steerable pyramid using Hilbert-transform pairs
31 | of filters. Note that the imaginary parts will *not* be steerable.
32 |
33 | Description of this transform appears in: Portilla & Simoncelli,
34 | International Journal of Computer Vision, 40(1):49-71, Oct 2000.
35 | Further information: http://www.cns.nyu.edu/~eero/STEERPYR/
36 |
37 | Modified code from the perceptual repository:
38 | https://github.com/andreydung/Steerable-filter
39 |
40 | This code looks very similar to the original Matlab code:
41 | https://github.com/LabForComputationalVision/matlabPyrTools/blob/master/buildSCFpyr.m
42 |
43 | Also looks very similar to the original Python code presented here:
44 | https://github.com/LabForComputationalVision/pyPyrTools/blob/master/pyPyrTools/SCFpyr.py
45 |
46 | '''
47 |
48 | def __init__(self, height=5, nbands=4, scale_factor=2):
49 | self.nbands = nbands # number of orientation bands
50 | self.height = height # including low-pass and high-pass
51 | self.scale_factor = scale_factor
52 |
53 | # Cache constants
54 | self.lutsize = 1024
55 | self.Xcosn = np.pi * np.array(range(-(2*self.lutsize+1), (self.lutsize+2)))/self.lutsize
56 | self.alpha = (self.Xcosn + np.pi) % (2*np.pi) - np.pi
57 |
58 |
59 | ################################################################################
60 | # Construction of Steerable Pyramid
61 |
62 | def build(self, im):
63 | ''' Decomposes an image into its complex steerable pyramid.
64 | The pyramid typically has ~4 levels and 4-8 orientations.
65 |
66 | Args:
67 | im (np.ndarray): single grayscale image of shape [H,W]
68 |
69 | Returns:
70 | pyramid: list containing np.ndarray objects storing the pyramid
71 | '''
72 |
73 | assert len(im.shape) == 2, 'Input im must be grayscale'
74 | height, width = im.shape
75 |
76 | # Check whether image size is sufficient for number of levels
77 | if self.height > int(np.floor(np.log2(min(width, height))) - 2):
78 | raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height))
79 |
80 | # Prepare a grid
81 | log_rad, angle = math_utils.prepare_grid(height, width)
82 |
83 | # Radial transition function (a raised cosine in log-frequency):
84 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
85 | Yrcos = np.sqrt(Yrcos)
86 |
87 | YIrcos = np.sqrt(1 - Yrcos**2)
88 | lo0mask = pointOp(log_rad, YIrcos, Xrcos)
89 | hi0mask = pointOp(log_rad, Yrcos, Xrcos)
90 |
91 | # Shift the zero-frequency component to the center of the spectrum.
92 | imdft = np.fft.fftshift(np.fft.fft2(im))
93 |
94 | # Low-pass
95 | lo0dft = imdft * lo0mask
96 |
97 | # Recursively build the steerable pyramid
98 | coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height-1)
99 |
100 | # High-pass
101 | hi0dft = imdft * hi0mask
102 | hi0 = np.fft.ifft2(np.fft.ifftshift(hi0dft))
103 | coeff.insert(0, hi0.real)
104 | return coeff
105 |
106 |
107 | def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
108 |
109 | if height <= 1:
110 |
111 | # Low-pass
112 | lo0 = np.fft.ifftshift(lodft)
113 | lo0 = np.fft.ifft2(lo0)
114 | coeff = [lo0.real]
115 |
116 | else:
117 |
118 | Xrcos = Xrcos - np.log2(self.scale_factor)
119 |
120 | ####################################################################
121 | ####################### Orientation bandpass #######################
122 | ####################################################################
123 |
124 | himask = pointOp(log_rad, Yrcos, Xrcos)
125 |
126 | order = self.nbands - 1
127 | const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order))
128 | Ycosn = 2*np.sqrt(const) * np.power(np.cos(self.Xcosn), order) * (np.abs(self.alpha) < np.pi/2)
129 |
130 | # Loop through all orientation bands
131 | orientations = []
132 | for b in range(self.nbands):
133 | anglemask = pointOp(angle, Ycosn, self.Xcosn + np.pi*b/self.nbands)
134 | banddft = np.power(complex(0, -1), self.nbands - 1) * lodft * anglemask * himask
135 | band = np.fft.ifft2(np.fft.ifftshift(banddft))
136 | orientations.append(band)
137 |
138 | ####################################################################
139 | ######################## Subsample lowpass #########################
140 | ####################################################################
141 |
142 | dims = np.array(lodft.shape)
143 |
144 | # Both are integer arrays of length 2 (row and column index)
145 | low_ind_start = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(int)
146 | low_ind_end = (low_ind_start + np.ceil((dims-0.5)/2)).astype(int)
147 |
148 | # Selection
149 | log_rad = log_rad[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]]
150 | angle = angle[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]]
151 | lodft = lodft[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]]
152 |
153 | # Subsampling in frequency domain
154 | YIrcos = np.abs(np.sqrt(1 - Yrcos**2))
155 | lomask = pointOp(log_rad, YIrcos, Xrcos)
156 | lodft = lomask * lodft
157 |
158 | ####################################################################
159 | ####################### Recursion next level #######################
160 | ####################################################################
161 |
162 | coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos, height-1)
163 | coeff.insert(0, orientations)
164 |
165 | return coeff
166 |
167 | ############################################################################
168 | ########################### RECONSTRUCTION #################################
169 | ############################################################################
170 |
171 | def reconstruct(self, coeff):
172 |
173 | if self.nbands != len(coeff[1]):
174 | raise Exception("Unmatched number of orientations")
175 |
176 | height, width = coeff[0].shape
177 | log_rad, angle = math_utils.prepare_grid(height, width)
178 |
179 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
180 | Yrcos = np.sqrt(Yrcos)
181 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2))
182 |
183 | lo0mask = pointOp(log_rad, YIrcos, Xrcos)
184 | hi0mask = pointOp(log_rad, Yrcos, Xrcos)
185 |
186 | tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle)
187 |
188 | hidft = np.fft.fftshift(np.fft.fft2(coeff[0]))
189 | outdft = tempdft * lo0mask + hidft * hi0mask
190 |
191 | reconstruction = np.fft.ifftshift(outdft)
192 | reconstruction = np.fft.ifft2(reconstruction)
193 | reconstruction = reconstruction.real
194 |
195 | return reconstruction
196 |
197 | def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
198 |
199 | if len(coeff) == 1:
200 | dft = np.fft.fft2(coeff[0])
201 | dft = np.fft.fftshift(dft)
202 | return dft
203 |
204 | Xrcos = Xrcos - np.log2(self.scale_factor)
205 |
206 | ####################################################################
207 | ####################### Orientation Residue ########################
208 | ####################################################################
209 |
210 | himask = pointOp(log_rad, Yrcos, Xrcos)
211 |
212 | lutsize = 1024
213 | Xcosn = np.pi * np.array(range(-(2*lutsize+1), (lutsize+2)))/lutsize
214 | order = self.nbands - 1
215 | const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order))
216 | Ycosn = np.sqrt(const) * np.power(np.cos(Xcosn), order)
217 |
218 | orientdft = np.zeros(coeff[0][0].shape)
219 |
220 | for b in range(self.nbands):
221 | anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands)
222 | banddft = np.fft.fft2(coeff[0][b])
223 | banddft = np.fft.fftshift(banddft)
224 | orientdft = orientdft + np.power(complex(0, 1), order) * banddft * anglemask * himask
225 |
226 | ####################################################################
227 | ########## Lowpass components are upsampled and convolved ##########
228 | ####################################################################
229 |
230 | dims = np.array(coeff[0][0].shape)
231 |
232 | lostart = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(np.int32)
233 | loend = lostart + np.ceil((dims-0.5)/2).astype(np.int32)
234 |
235 | nlog_rad = log_rad[lostart[0]:loend[0], lostart[1]:loend[1]]
236 | nangle = angle[lostart[0]:loend[0], lostart[1]:loend[1]]
237 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2))
238 | lomask = pointOp(nlog_rad, YIrcos, Xrcos)
239 |
240 | ################################################################################
241 |
242 | # Recursive call for image reconstruction
243 | nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos, nangle)
244 |
245 | resdft = np.zeros(dims, 'complex')
246 | resdft[lostart[0]:loend[0], lostart[1]:loend[1]] = nresdft * lomask
247 |
248 | return resdft + orientdft
249 |
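The constant `const` used for the angular masks above can be checked in isolation. The following is a minimal sketch, not part of the repository: the helper names `angular_const` and `angular_power_sum` are hypothetical, and it assumes the two-sided masks `sqrt(const) * cos(theta)**order` as they appear in `_reconstruct_levels`. Under that assumption the squared masks of all `nbands` orientations should sum to one at every angle, which is what the constant is chosen for.

```python
import math


def angular_const(nbands):
    """Normalisation constant for the cos**(nbands - 1) angular masks."""
    order = nbands - 1
    numerator = 2 ** (2 * order) * math.factorial(order) ** 2
    return numerator / (nbands * math.factorial(2 * order))


def angular_power_sum(theta, nbands):
    """Sum over all bands of the squared two-sided angular mask at angle theta."""
    order = nbands - 1
    const = angular_const(nbands)
    return sum(
        const * math.cos(theta - math.pi * b / nbands) ** (2 * order)
        for b in range(nbands)
    )


if __name__ == "__main__":
    for nbands in (2, 4, 6, 8):
        # Largest deviation from 1 over a sweep of angles in [0, 2*pi)
        worst = max(
            abs(angular_power_sum(0.01 * k, nbands) - 1.0) for k in range(629)
        )
        print(nbands, angular_const(nbands), worst)
```

The build side uses `2 * sqrt(const) * cos(theta)**order` restricted to a half-plane instead, so this check applies to the reconstruction masks only.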
--------------------------------------------------------------------------------
/steerable/SCFpyr_PyTorch.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import torch
21 | from scipy.special import factorial
22 |
23 | import steerable.math_utils as math_utils
24 | pointOp = math_utils.pointOp
25 |
26 | ################################################################################
27 | ################################################################################
28 |
29 |
30 | class SCFpyr_PyTorch(object):
31 | '''
32 | This is a modified version of buildSFpyr, that constructs a
33 | complex-valued steerable pyramid using Hilbert-transform pairs
34 | of filters. Note that the imaginary parts will *not* be steerable.
35 |
36 | Description of this transform appears in: Portilla & Simoncelli,
37 | International Journal of Computer Vision, 40(1):49-71, Oct 2000.
38 | Further information: http://www.cns.nyu.edu/~eero/STEERPYR/
39 |
40 | Modified code from the perceptual repository:
41 | https://github.com/andreydung/Steerable-filter
42 |
43 | This code looks very similar to the original Matlab code:
44 | https://github.com/LabForComputationalVision/matlabPyrTools/blob/master/buildSCFpyr.m
45 |
46 | Also looks very similar to the original Python code presented here:
47 | https://github.com/LabForComputationalVision/pyPyrTools/blob/master/pyPyrTools/SCFpyr.py
48 |
49 | '''
50 |
51 | def __init__(self, height=5, nbands=4, scale_factor=2, device=None):
52 | self.height = height # including low-pass and high-pass
53 | self.nbands = nbands # number of orientation bands
54 | self.scale_factor = scale_factor
55 | self.device = torch.device('cpu') if device is None else device
56 |
57 | # Cache constants
58 | self.lutsize = 1024
59 | self.Xcosn = np.pi * np.array(range(-(2*self.lutsize+1), (self.lutsize+2)))/self.lutsize
60 | self.alpha = (self.Xcosn + np.pi) % (2*np.pi) - np.pi
61 | self.complex_fact_construct = np.power(complex(0, -1), self.nbands-1)
62 | self.complex_fact_reconstruct = np.power(complex(0, 1), self.nbands-1)
63 |
64 | ################################################################################
65 | # Construction of Steerable Pyramid
66 |
67 | def build(self, im_batch):
68 | ''' Decomposes a batch of images into a complex steerable pyramid.
69 | The pyramid typically has ~4 levels and 4-8 orientations.
70 |
71 | Args:
72 | im_batch (torch.Tensor): Batch of images of shape [N,C,H,W]
73 |
74 | Returns:
75 | pyramid: list containing torch.Tensor objects storing the pyramid
76 | '''
77 |
78 | assert im_batch.device == self.device, 'Devices invalid (pyr = {}, batch = {})'.format(self.device, im_batch.device)
79 | assert im_batch.dtype == torch.float32, 'Image batch must be torch.float32'
80 | assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]'
81 | assert im_batch.shape[1] == 1, 'Second dimension must be 1 encoding grayscale image'
82 |
83 | im_batch = im_batch.squeeze(1) # flatten channels dim
84 | height, width = im_batch.shape[1], im_batch.shape[2] # [N,H,W] after squeeze
85 |
86 | # Check whether image size is sufficient for number of levels
87 | if self.height > int(np.floor(np.log2(min(width, height))) - 2):
88 | raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height))
89 |
90 | # Prepare a grid
91 | log_rad, angle = math_utils.prepare_grid(height, width)
92 |
93 | # Radial transition function (a raised cosine in log-frequency):
94 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
95 | Yrcos = np.sqrt(Yrcos)
96 |
97 | YIrcos = np.sqrt(1 - Yrcos**2)
98 |
99 | lo0mask = pointOp(log_rad, YIrcos, Xrcos)
100 | hi0mask = pointOp(log_rad, Yrcos, Xrcos)
101 |
102 | # Note that we expand dims to support broadcasting later
103 | lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device)
104 | hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device)
105 |
106 | # Fourier transform (2D) and shifting
107 | batch_dft = torch.rfft(im_batch, signal_ndim=2, onesided=False)
108 | batch_dft = math_utils.batch_fftshift2d(batch_dft)
109 |
110 | # Low-pass
111 | lo0dft = batch_dft * lo0mask
112 |
113 | # Start recursively building the pyramids
114 | coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height-1)
115 |
116 | # High-pass
117 | hi0dft = batch_dft * hi0mask
118 | hi0 = math_utils.batch_ifftshift2d(hi0dft)
119 | hi0 = torch.ifft(hi0, signal_ndim=2)
120 | hi0_real = torch.unbind(hi0, -1)[0]
121 | coeff.insert(0, hi0_real)
122 | return coeff
123 |
124 | def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
125 |
126 | if height <= 1:
127 |
128 | # Low-pass
129 | lo0 = math_utils.batch_ifftshift2d(lodft)
130 | lo0 = torch.ifft(lo0, signal_ndim=2)
131 | lo0_real = torch.unbind(lo0, -1)[0]
132 | coeff = [lo0_real]
133 |
134 | else:
135 |
136 | Xrcos = Xrcos - np.log2(self.scale_factor)
137 |
138 | ####################################################################
139 | ####################### Orientation bandpass #######################
140 | ####################################################################
141 |
142 | himask = pointOp(log_rad, Yrcos, Xrcos)
143 | himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device)
144 |
145 | order = self.nbands - 1
146 | const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order))
147 | Ycosn = 2*np.sqrt(const) * np.power(np.cos(self.Xcosn), order) * (np.abs(self.alpha) < np.pi/2) # [n,]
148 |
149 | # Loop through all orientation bands
150 | orientations = []
151 | for b in range(self.nbands):
152 |
153 | anglemask = pointOp(angle, Ycosn, self.Xcosn + np.pi*b/self.nbands)
154 | anglemask = anglemask[None,:,:,None] # for broadcasting
155 | anglemask = torch.from_numpy(anglemask).float().to(self.device)
156 |
157 | # Bandpass filtering
158 | banddft = lodft * anglemask * himask
159 |
160 | # Now multiply with complex number
161 | # (x+yi)(u+vi) = (xu-yv) + (xv+yu)i
162 | banddft = torch.unbind(banddft, -1)
163 | banddft_real = self.complex_fact_construct.real*banddft[0] - self.complex_fact_construct.imag*banddft[1]
164 | banddft_imag = self.complex_fact_construct.real*banddft[1] + self.complex_fact_construct.imag*banddft[0]
165 | banddft = torch.stack((banddft_real, banddft_imag), -1)
166 |
167 | band = math_utils.batch_ifftshift2d(banddft)
168 | band = torch.ifft(band, signal_ndim=2)
169 | orientations.append(band)
170 |
171 | ####################################################################
172 | ######################## Subsample lowpass #########################
173 | ####################################################################
174 |
175 | # Don't consider batch_size and imag/real dim
176 | dims = np.array(lodft.shape[1:3])
177 |
178 | # Both are integer arrays of length 2 (row and column index)
179 | low_ind_start = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(int)
180 | low_ind_end = (low_ind_start + np.ceil((dims-0.5)/2)).astype(int)
181 |
182 | # Subsampling indices
183 | log_rad = log_rad[low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]]
184 | angle = angle[low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]]
185 |
186 | # Actual subsampling
187 | lodft = lodft[:,low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1],:]
188 |
189 | # Filtering
190 | YIrcos = np.abs(np.sqrt(1 - Yrcos**2))
191 | lomask = pointOp(log_rad, YIrcos, Xrcos)
192 | lomask = torch.from_numpy(lomask[None,:,:,None]).float()
193 | lomask = lomask.to(self.device)
194 |
195 | # Multiplication in the frequency domain (convolution in the spatial domain)
196 | lodft = lomask * lodft
197 |
198 | ####################################################################
199 | ####################### Recursion next level #######################
200 | ####################################################################
201 |
202 | coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos, height-1)
203 | coeff.insert(0, orientations)
204 |
205 | return coeff
206 |
207 | ############################################################################
208 | ########################### RECONSTRUCTION #################################
209 | ############################################################################
210 |
211 | def reconstruct(self, coeff):
212 |
213 | if self.nbands != len(coeff[1]):
214 | raise Exception("Unmatched number of orientations")
215 |
216 | height, width = coeff[0].shape[1], coeff[0].shape[2] # high-pass residual is [N,H,W]
217 | log_rad, angle = math_utils.prepare_grid(height, width)
218 |
219 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
220 | Yrcos = np.sqrt(Yrcos)
221 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2))
222 |
223 | lo0mask = pointOp(log_rad, YIrcos, Xrcos)
224 | hi0mask = pointOp(log_rad, Yrcos, Xrcos)
225 |
226 | # Note that we expand dims to support broadcasting later
227 | lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device)
228 | hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device)
229 |
230 | # Start recursive reconstruction
231 | tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle)
232 |
233 | hidft = torch.rfft(coeff[0], signal_ndim=2, onesided=False)
234 | hidft = math_utils.batch_fftshift2d(hidft)
235 |
236 | outdft = tempdft * lo0mask + hidft * hi0mask
237 |
238 | reconstruction = math_utils.batch_ifftshift2d(outdft)
239 | reconstruction = torch.ifft(reconstruction, signal_ndim=2)
240 | reconstruction = torch.unbind(reconstruction, -1)[0] # real
241 |
242 | return reconstruction
243 |
244 | def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
245 |
246 | if len(coeff) == 1:
247 | dft = torch.rfft(coeff[0], signal_ndim=2, onesided=False)
248 | dft = math_utils.batch_fftshift2d(dft)
249 | return dft
250 |
251 | Xrcos = Xrcos - np.log2(self.scale_factor)
252 |
253 | ####################################################################
254 | ####################### Orientation Residue ########################
255 | ####################################################################
256 |
257 | himask = pointOp(log_rad, Yrcos, Xrcos)
258 | himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device)
259 |
260 | lutsize = 1024
261 | Xcosn = np.pi * np.array(range(-(2*lutsize+1), (lutsize+2)))/lutsize
262 | order = self.nbands - 1
263 | const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order))
264 | Ycosn = np.sqrt(const) * np.power(np.cos(Xcosn), order)
265 |
266 | orientdft = torch.zeros_like(coeff[0][0])
267 | for b in range(self.nbands):
268 |
269 | anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands)
270 | anglemask = anglemask[None,:,:,None] # for broadcasting
271 | anglemask = torch.from_numpy(anglemask).float().to(self.device)
272 |
273 | banddft = torch.fft(coeff[0][b], signal_ndim=2)
274 | banddft = math_utils.batch_fftshift2d(banddft)
275 |
276 | banddft = banddft * anglemask * himask
277 | banddft = torch.unbind(banddft, -1)
278 | banddft_real = self.complex_fact_reconstruct.real*banddft[0] - self.complex_fact_reconstruct.imag*banddft[1]
279 | banddft_imag = self.complex_fact_reconstruct.real*banddft[1] + self.complex_fact_reconstruct.imag*banddft[0]
280 | banddft = torch.stack((banddft_real, banddft_imag), -1)
281 |
282 | orientdft = orientdft + banddft
283 |
284 | ####################################################################
285 | ########## Lowpass components are upsampled and convolved ##########
286 | ####################################################################
287 |
288 | dims = np.array(coeff[0][0].shape[1:3])
289 |
290 | lostart = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(np.int32)
291 | loend = lostart + np.ceil((dims-0.5)/2).astype(np.int32)
292 |
293 | nlog_rad = log_rad[lostart[0]:loend[0], lostart[1]:loend[1]]
294 | nangle = angle[lostart[0]:loend[0], lostart[1]:loend[1]]
295 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2))
296 | lomask = pointOp(nlog_rad, YIrcos, Xrcos)
297 |
298 | # Filtering
299 | lomask = pointOp(nlog_rad, YIrcos, Xrcos)
300 | lomask = torch.from_numpy(lomask[None,:,:,None])
301 | lomask = lomask.float().to(self.device)
302 |
303 | ################################################################################
304 |
305 | # Recursive call for image reconstruction
306 | nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos, nangle)
307 |
308 | resdft = torch.zeros_like(coeff[0][0]).to(self.device)
309 | resdft[:,lostart[0]:loend[0], lostart[1]:loend[1],:] = nresdft * lomask
310 |
311 | return resdft + orientdft
312 |
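The `low_ind_start` / `low_ind_end` expressions used for subsampling are easier to follow for a single axis. The sketch below is an illustration only: `lowpass_crop` is a hypothetical helper, not a function of this package, and it assumes the spectrum has already been centred with an fftshift so that the zero frequency sits at index `dim // 2`. It reproduces the same arithmetic with `math.ceil` and shows that the crop keeps `ceil(dim / 2)` samples around the centre.

```python
import math


def lowpass_crop(dim):
    """Return (start, end) of the central crop along one axis of length dim.

    Mirrors the low_ind_start / low_ind_end arithmetic for a single axis.
    """
    size = math.ceil((dim - 0.5) / 2)  # number of samples kept, ceil(dim / 2)
    start = math.ceil((dim + 0.5) / 2) - math.ceil((size + 0.5) / 2)
    return start, start + size


if __name__ == "__main__":
    for dim in (200, 201, 100, 7):
        start, end = lowpass_crop(dim)
        # Position of the zero frequency before and after cropping
        print(dim, (start, end), dim // 2 - start, (end - start) // 2)
```

For every length the zero-frequency sample lands at index `size // 2` of the cropped axis, so the cropped spectrum is again a centred spectrum and the same shift functions can be applied at the next level.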
--------------------------------------------------------------------------------
/steerable/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomrunia/PyTorchSteerablePyramid/e54981e7fcfd24263354d9c11fe70cb44457a594/steerable/__init__.py
--------------------------------------------------------------------------------
/steerable/math_utils.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import torch
21 |
22 | ################################################################################
23 | ################################################################################
24 |
25 | def roll_n(X, axis, n):
26 | f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim()))
27 | b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim()))
28 | front = X[f_idx]
29 | back = X[b_idx]
30 | return torch.cat([back, front], axis)
31 |
32 | def batch_fftshift2d(x):
33 | real, imag = torch.unbind(x, -1)
34 | for dim in range(1, len(real.size())):
35 | n_shift = real.size(dim)//2
36 | if real.size(dim) % 2 != 0:
37 | n_shift += 1 # for odd-sized images
38 | real = roll_n(real, axis=dim, n=n_shift)
39 | imag = roll_n(imag, axis=dim, n=n_shift)
40 | return torch.stack((real, imag), -1) # last dim=2 (real&imag)
41 |
42 | def batch_ifftshift2d(x):
43 | real, imag = torch.unbind(x, -1)
44 | for dim in range(len(real.size()) - 1, 0, -1):
45 | real = roll_n(real, axis=dim, n=real.size(dim)//2)
46 | imag = roll_n(imag, axis=dim, n=imag.size(dim)//2)
47 | return torch.stack((real, imag), -1) # last dim=2 (real&imag)
48 |
49 | ################################################################################
50 | ################################################################################
51 |
52 | def prepare_grid(m, n):
53 | x = np.linspace(-(m // 2)/(m / 2), (m // 2)/(m / 2) - (1 - m % 2)*2/m, num=m)
54 | y = np.linspace(-(n // 2)/(n / 2), (n // 2)/(n / 2) - (1 - n % 2)*2/n, num=n)
55 | xv, yv = np.meshgrid(y, x)
56 | angle = np.arctan2(yv, xv)
57 | rad = np.sqrt(xv**2 + yv**2)
58 | rad[m//2][n//2] = rad[m//2][n//2 - 1]
59 | log_rad = np.log2(rad)
60 | return log_rad, angle
61 |
62 | def rcosFn(width, position):
63 | N = 256 # arbitrary
64 | X = np.pi * np.array(range(-N-1, 2))/2/N
65 | Y = np.cos(X)**2
66 | Y[0] = Y[1]
67 | Y[N+2] = Y[N+1]
68 | X = position + 2*width/np.pi*(X + np.pi/4)
69 | return X, Y
70 |
71 | def pointOp(im, Y, X):
72 | out = np.interp(im.flatten(), X, Y)
73 | return np.reshape(out, im.shape)
74 |
75 | def getlist(coeff):
76 | straight = [bands for scale in coeff[1:-1] for bands in scale]
77 | straight = [coeff[0]] + straight + [coeff[-1]]
78 | return straight
79 |
80 | ################################################################################
81 | # NumPy reference implementation (fftshift and ifftshift)
82 |
83 | # def fftshift(x, axes=None):
84 | # """
85 | # Shift the zero-frequency component to the center of the spectrum.
86 | # This function swaps half-spaces for all axes listed (defaults to all).
87 | # Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even.
88 | # Parameters
89 | # """
90 | # x = np.asarray(x)
91 | # if axes is None:
92 | # axes = tuple(range(x.ndim))
93 | # shift = [dim // 2 for dim in x.shape]
94 | # shift = [x.shape[ax] // 2 for ax in axes]
95 | # return np.roll(x, shift, axes)
96 | #
97 | # def ifftshift(x, axes=None):
98 | # """
99 | # The inverse of `fftshift`. Although identical for even-length `x`, the
100 | # functions differ by one sample for odd-length `x`.
101 | # """
102 | # x = np.asarray(x)
103 | # if axes is None:
104 | # axes = tuple(range(x.ndim))
105 | # shift = [-(dim // 2) for dim in x.shape]
106 | # shift = [-(x.shape[ax] // 2) for ax in axes]
107 | # return np.roll(x, shift, axes)
108 |
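The shift amounts used by `batch_fftshift2d` and `batch_ifftshift2d` can be illustrated on plain lists. This is a minimal sketch with hypothetical helper names (`roll_left`, `fft_freqs`, `fftshift_1d`, `ifftshift_1d`), assuming that `roll_n` rotates an axis to the left by `n` positions, which is how it is written above. The forward shift needs one extra position for odd lengths, while the inverse shift always uses `len // 2`.

```python
def roll_left(seq, n):
    """Rotate a list left by n positions (what roll_n does along one axis)."""
    return seq[n:] + seq[:n]


def fft_freqs(n):
    """Integer frequency indices in unshifted FFT order, e.g. [0, 1, 2, -2, -1]."""
    return [k if k < (n + 1) // 2 else k - n for k in range(n)]


def fftshift_1d(seq):
    """Move the zero frequency to index len(seq) // 2."""
    n = len(seq)
    return roll_left(seq, n // 2 + n % 2)  # one extra position for odd lengths


def ifftshift_1d(seq):
    """Undo fftshift_1d."""
    return roll_left(seq, len(seq) // 2)


if __name__ == "__main__":
    for n in (5, 6):
        shifted = fftshift_1d(fft_freqs(n))
        print(n, shifted, ifftshift_1d(shifted))
```

For even lengths the two shifts are identical, which is why the odd-length branch in `batch_fftshift2d` is the only place where they differ.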
--------------------------------------------------------------------------------
/steerable/utils.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-10
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 | import numpy as np
21 | import skimage.io
22 | import matplotlib.pyplot as plt
23 |
24 | import torch
25 | import torchvision
26 |
27 | ################################################################################
28 |
29 | ToPIL = torchvision.transforms.ToPILImage()
30 | Grayscale = torchvision.transforms.Grayscale()
31 | RandomCrop = torchvision.transforms.RandomCrop
32 |
33 | def get_device(device='cuda:0'):
34 | assert isinstance(device, str)
35 | num_cuda = torch.cuda.device_count()
36 |
37 | if 'cuda' in device:
38 | if num_cuda > 0:
39 | # Found CUDA device, use the GPU
40 | return torch.device(device)
41 | # Fallback to CPU
42 | print('No CUDA devices found, falling back to CPU')
43 | device = 'cpu'
44 |
45 | if not torch.backends.mkl.is_available():
46 | raise NotImplementedError(
47 | 'torch.fft on the CPU requires MKL back-end. ' +
48 | 'Please recompile your PyTorch distribution.')
49 | return torch.device('cpu')
50 |
51 | def load_image_batch(image_file, batch_size, image_size=200):
52 | if not os.path.isfile(image_file):
53 | raise FileNotFoundError('Image file not found on disk: {}'.format(image_file))
54 | im = ToPIL(skimage.io.imread(image_file))
55 | im = Grayscale(im)
56 | im_batch = np.zeros((batch_size, image_size, image_size), np.float32)
57 | for i in range(batch_size):
58 | im_batch[i] = RandomCrop(image_size)(im)
59 | # insert channels dim and rescale
60 | return im_batch[:,None,:,:]/255.
61 |
62 | def show_image_batch(im_batch):
63 | assert isinstance(im_batch, torch.Tensor)
64 | im_batch = torchvision.utils.make_grid(im_batch).numpy()
65 | im_batch = np.transpose(im_batch, (1,2,0)) # make_grid returns [C,H,W]
66 | plt.imshow(im_batch)
67 | plt.axis('off')
68 | plt.tight_layout()
69 | plt.show()
70 | return im_batch
71 |
72 | def extract_from_batch(coeff_batch, example_idx=0):
73 | '''
74 | Given the batched Complex Steerable Pyramid, extract the coefficients
75 | for a single example from the batch. Additionally, it converts all
76 | torch.Tensor objects to np.ndarray and creates proper complex-valued
77 | arrays for all the orientation bands.
78 |
79 | Args:
80 | coeff_batch (list): list containing low-pass, high-pass and pyr levels
81 | example_idx (int, optional): Defaults to 0. index in batch to extract
82 |
83 | Returns:
84 | list: list containing low-pass, high-pass and pyr levels as np.ndarray
85 | '''
86 | if not isinstance(coeff_batch, list):
87 | raise ValueError('Batch of coefficients must be a list')
88 | coeff = [] # coefficient for single example
89 | for coeff_level in coeff_batch:
90 | if isinstance(coeff_level, torch.Tensor):
91 | # Low- or High-Pass
92 | coeff_level_numpy = coeff_level[example_idx].cpu().numpy()
93 | coeff.append(coeff_level_numpy)
94 | elif isinstance(coeff_level, list):
95 | coeff_orientations_numpy = []
96 | for coeff_orientation in coeff_level:
97 | coeff_orientation_numpy = coeff_orientation[example_idx].cpu().numpy()
98 | coeff_orientation_numpy = coeff_orientation_numpy[:,:,0] + 1j*coeff_orientation_numpy[:,:,1]
99 | coeff_orientations_numpy.append(coeff_orientation_numpy)
100 | coeff.append(coeff_orientations_numpy)
101 | else:
102 | raise ValueError('coeff level must be of type (list, torch.Tensor)')
103 | return coeff
104 |
105 | ################################################################################
106 |
107 | def make_grid_coeff(coeff, normalize=True):
108 | '''
109 | Visualization function for building a large image that contains the
110 | low-pass, high-pass and all intermediate levels in the steerable pyramid.
111 | For the complex intermediate bands, the real part is visualized.
112 |
113 | Args:
114 | coeff (list): complex pyramid stored as list containing all levels
115 | normalize (bool, optional): Defaults to True. Whether to normalize each band
116 |
117 | Returns:
118 | np.ndarray: large image that contains grid of all bands and orientations
119 | '''
120 | M, N = coeff[1][0].shape
121 | Norients = len(coeff[1])
122 | out = np.zeros((M * 2 - coeff[-1].shape[0], Norients * N))
123 | currentx, currenty = 0, 0
124 |
125 | for i in range(1, len(coeff[:-1])):
126 | for j in range(len(coeff[1])):
127 | tmp = coeff[i][j].real
128 | m, n = tmp.shape
129 | if normalize:
130 | tmp = 255 * tmp/tmp.max()
131 | tmp[m-1,:] = 255
132 | tmp[:,n-1] = 255
133 | out[currentx:currentx+m,currenty:currenty+n] = tmp
134 | currenty += n
135 | currentx += coeff[i][0].shape[0]
136 | currenty = 0
137 |
138 | m, n = coeff[-1].shape
139 | out[currentx: currentx+m, currenty: currenty+n] = 255 * coeff[-1]/coeff[-1].max()
140 | out[0,:] = 255
141 | out[:,0] = 255
142 | return out.astype(np.uint8)
143 |
--------------------------------------------------------------------------------
/tests/test_ifft.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-07
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import torch
21 | import cv2
22 |
23 | import steerable.math_utils as fft_utils  # the repo tree has no steerable/fft.py
24 | import matplotlib.pyplot as plt
25 |
26 | ################################################################################
27 |
28 | tolerance = 1e-6
29 |
30 | image_file = './assets/lena.jpg'
31 | im = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
32 | im = cv2.resize(im, dsize=(200, 200))
33 | im = im.astype(np.float64)/255. # note the np.float64
34 |
35 | ################################################################################
36 | # NumPy
37 |
38 | fft_numpy = np.fft.fft2(im)
39 | fft_numpy = np.fft.fftshift(fft_numpy)
40 |
41 | fft_numpy_mag_viz = np.log10(np.abs(fft_numpy))
42 | fft_numpy_ang_viz = np.angle(fft_numpy)
43 |
44 | ifft_numpy1 = np.fft.ifftshift(fft_numpy)
45 | ifft_numpy = np.fft.ifft2(ifft_numpy1)
46 |
47 | ################################################################################
48 | # Torch
49 |
50 | device = torch.device('cpu')
51 |
52 | im_torch = torch.from_numpy(im[None,:,:]) # add batch dim
53 | im_torch = im_torch.to(device)
54 |
55 | # fft = complex-to-complex, rfft = real-to-complex
56 | fft_torch = torch.rfft(im_torch, signal_ndim=2, onesided=False)
57 | fft_torch = fft_utils.batch_fftshift2d(fft_torch)
58 |
59 | ifft_torch = fft_utils.batch_ifftshift2d(fft_torch)
60 | ifft_torch = torch.ifft(ifft_torch, signal_ndim=2, normalized=False)
61 |
62 | ifft_torch_to_numpy = ifft_torch.numpy()
63 | ifft_torch_to_numpy = np.split(ifft_torch_to_numpy, 2, -1) # complex => real/imag
64 | ifft_torch_to_numpy = np.squeeze(ifft_torch_to_numpy, -1)
65 | ifft_torch_to_numpy = ifft_torch_to_numpy[0] + 1j*ifft_torch_to_numpy[1]
66 | all_close_ifft = np.allclose(ifft_numpy, ifft_torch_to_numpy, atol=tolerance)
67 | print('ifft all close: ', all_close_ifft)
68 |
69 | fft_torch = fft_torch.cpu().numpy().squeeze()
70 | fft_torch = np.split(fft_torch, 2, -1) # complex => real/imag
71 | fft_torch = np.squeeze(fft_torch, -1)
72 | fft_torch = fft_torch[0] + 1j*fft_torch[1]
73 |
74 | ifft_torch = ifft_torch.cpu().numpy().squeeze()
75 | ifft_torch = np.split(ifft_torch, 2, -1) # complex => real/imag
76 | ifft_torch = np.squeeze(ifft_torch, -1)
77 | ifft_torch = ifft_torch[0] + 1j*ifft_torch[1]
78 |
79 | fft_torch_mag_viz = np.log10(np.abs(fft_torch))
80 | fft_torch_ang_viz = np.angle(fft_torch)
81 |
82 | ################################################################################
83 | # Tolerance checking
84 |
85 | all_close_real = np.allclose(np.real(fft_numpy), np.real(fft_torch), atol=tolerance)
86 | all_close_imag = np.allclose(np.imag(fft_numpy), np.imag(fft_torch), atol=tolerance)
87 | print('fft allclose real: {}'.format(all_close_real))
88 | print('fft allclose imag: {}'.format(all_close_imag))
89 |
90 | all_close_real = np.allclose(np.real(ifft_numpy), np.real(ifft_torch), atol=tolerance)
91 | all_close_imag = np.allclose(np.imag(ifft_numpy), np.imag(ifft_torch), atol=tolerance)
92 | print('ifft allclose real: {}'.format(all_close_real))
93 | print('ifft allclose imag: {}'.format(all_close_imag))
94 |
95 | ################################################################################
96 |
97 | fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(12,6))
98 |
99 | # Plotting NumPy results
100 | ax[0][0].imshow(im, cmap='gray')
101 |
102 | ax[0][1].imshow(fft_numpy_mag_viz, cmap='gray')
103 | ax[0][1].set_title('NumPy fft magnitude')
104 | ax[0][2].imshow(fft_numpy_ang_viz, cmap='gray')
105 | ax[0][2].set_title('NumPy fft phase')
106 | ax[0][3].imshow(ifft_numpy.real, cmap='gray')
107 | ax[0][3].set_title('NumPy ifft real')
108 | ax[0][4].imshow(ifft_numpy.imag, cmap='gray')
109 | ax[0][4].set_title('NumPy ifft imag')
110 |
111 | # Plotting PyTorch results
112 | ax[1][0].imshow(im, cmap='gray')
113 | ax[1][1].imshow(fft_torch_mag_viz, cmap='gray')
114 | ax[1][1].set_title('PyTorch fft magnitude')
115 | ax[1][2].imshow(fft_torch_ang_viz, cmap='gray')
116 | ax[1][2].set_title('PyTorch fft phase')
117 | ax[1][3].imshow(ifft_torch.real, cmap='gray')
118 | ax[1][3].set_title('PyTorch ifft real')
119 | ax[1][4].imshow(ifft_torch.imag, cmap='gray')
120 | ax[1][4].set_title('PyTorch ifft imag')
121 |
122 | for cur_ax in ax.flatten():
123 | cur_ax.axis('off')
124 | plt.tight_layout()
125 | plt.show()
126 |
--------------------------------------------------------------------------------
/tests/test_torch_fft.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import torch
21 |
22 | ################################################################################
23 |
24 | x = np.linspace(-10, 10, 100)
25 | y = np.cos(np.pi*x)
26 |
27 | ################################################################################
28 | # NumPy
29 |
30 | y_fft_numpy = np.fft.fft(y)
31 |
32 | ################################################################################
33 | # Torch
34 |
35 | device = torch.device('cuda:0')
36 |
37 | y_torch = torch.from_numpy(y[None, :])
38 | y_torch = y_torch.to(device)
39 |
40 | # fft = complex-to-complex, rfft = real-to-complex
41 | y_fft_torch = torch.rfft(y_torch, signal_ndim=1, onesided=False)
42 | y_fft_torch = y_fft_torch.cpu().numpy().squeeze()
43 | y_fft_torch = y_fft_torch[:, 0] + 1j*y_fft_torch[:, 1]
44 |
45 | tolerance = 1e-6
46 | all_close = np.allclose(y_fft_numpy, y_fft_torch, atol=tolerance)
47 | print('numpy', y_fft_numpy.shape, y_fft_numpy.dtype)
48 | print('torch', y_fft_torch.shape, y_fft_torch.dtype)
49 | print('Successful: {}'.format(all_close))
50 |
--------------------------------------------------------------------------------
/tests/test_torch_fftshift.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import matplotlib.pyplot as plt
21 | import torch
22 | import torch.nn as nn
23 |
24 | from steerable.fft import fftshift
25 |
26 | ################################################################################
27 |
28 | x = np.linspace(-10, 10, 100)
29 | y = np.cos(np.pi*x)
30 |
31 | ################################################################################
32 | # NumPy
33 |
34 | y_fft_numpy = np.fft.fft(y)
35 | y_fft_numpy_shift = np.fft.fftshift(y_fft_numpy)
36 | fft_numpy_real = np.real(y_fft_numpy_shift)
37 | fft_numpy_imag = np.imag(y_fft_numpy_shift)
38 |
39 | ################################################################################
40 | # Torch
41 |
42 | device = torch.device('cuda:0')
43 |
44 | y_torch = torch.from_numpy(y[None, :])
45 | y_torch = y_torch.to(device)
46 |
47 | # fft = complex-to-complex, rfft = real-to-complex
48 | y_fft_torch = torch.rfft(y_torch, signal_ndim=1, onesided=False)
49 | y_fft_torch = fftshift(y_fft_torch[:,:,0], y_fft_torch[:,:,1])
50 | y_fft_torch = y_fft_torch.cpu().numpy().squeeze()
51 | fft_torch_real = y_fft_torch[:,0]
52 | fft_torch_imag = y_fft_torch[:,1]
53 |
54 | tolerance = 1e-6
55 |
56 | all_close_real = np.allclose(fft_numpy_real, fft_torch_real, atol=tolerance)
57 | all_close_imag = np.allclose(fft_numpy_imag, fft_torch_imag, atol=tolerance)
58 |
59 | print('fftshift allclose real: {}'.format(all_close_real))
60 | print('fftshift allclose imag: {}'.format(all_close_imag))
61 |
62 | ################################################################################
63 |
64 | import cortex.plot
65 | import cortex.plot.colors
66 | colors = cortex.plot.nature_colors()
67 |
68 | fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(10,6))
69 |
70 | # Plotting NumPy results
71 | ax[0][0].plot(x, y, color=colors[3])
72 | ax[0][1].plot(np.real(y_fft_numpy), color=colors[0])
73 | ax[0][1].plot(fft_numpy_real, color=colors[1])
74 | ax[0][1].set_title('NumPy Real')
75 | ax[0][2].plot(np.imag(y_fft_numpy), color=colors[0])
76 | ax[0][2].plot(fft_numpy_imag, color=colors[1])
77 | ax[0][2].set_title('NumPy Imag')
78 |
79 | # Plotting PyTorch results
80 | ax[1][0].plot(x, y, color=colors[3])
81 | ax[1][1].plot(fft_torch_real, color=colors[1])
82 | ax[1][1].set_title('Torch Real')
83 | ax[1][2].plot(fft_torch_imag, color=colors[1])
84 | ax[1][2].set_title('Torch Imag')
85 |
86 | plt.show()
87 |
--------------------------------------------------------------------------------
/tests/test_torch_fftshift2d.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import matplotlib.pyplot as plt
21 | import torch
22 | import torch.nn as nn
23 | import cv2
24 |
25 | from steerable.fft import fftshift
26 |
27 | ################################################################################
28 |
29 | image_file = './assets/lena.jpg'
30 | im = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
31 | im = cv2.resize(im, dsize=(200, 200))
32 | im = im.astype(np.float64)/255. # note the np.float64
33 |
34 | ################################################################################
35 | # NumPy
36 |
37 | fft_numpy = np.fft.fft2(im)
38 | fft_numpy = np.fft.fftshift(fft_numpy)
39 |
40 | fft_numpy_mag_viz = np.log10(np.abs(fft_numpy))
41 | fft_numpy_ang_viz = np.angle(fft_numpy)
42 |
43 | ################################################################################
44 | # Torch
45 |
46 | device = torch.device('cuda:0')
47 |
48 | im_torch = torch.from_numpy(im[None,:,:]) # add batch dim
49 | im_torch = im_torch.to(device)
50 |
51 | # fft = complex-to-complex, rfft = real-to-complex
52 | fft_torch = torch.rfft(im_torch, signal_ndim=2, onesided=False)
53 | fft_torch = fftshift(fft_torch[:,:,:,0], fft_torch[:,:,:,1])
54 | fft_torch = fft_torch.cpu().numpy().squeeze()
55 | print(fft_torch.shape)
56 | fft_torch = np.split(fft_torch, 2, -1) # complex => real/imag
57 | fft_torch = np.squeeze(fft_torch, -1)
58 | fft_torch = fft_torch[0] + 1j*fft_torch[1]
59 |
60 | print('fft_torch', fft_torch.shape, fft_torch.dtype)
61 | fft_torch_mag_viz = np.log10(np.abs(fft_torch))
62 | fft_torch_ang_viz = np.angle(fft_torch)
63 |
64 | tolerance = 1e-6
65 |
66 | print(fft_numpy.dtype, fft_torch.dtype)
67 | all_close_real = np.allclose(np.real(fft_numpy), np.real(fft_torch), atol=tolerance)
68 | all_close_imag = np.allclose(np.imag(fft_numpy), np.imag(fft_torch), atol=tolerance)
69 |
70 | print('fftshift allclose real: {}'.format(all_close_real))
71 | print('fftshift allclose imag: {}'.format(all_close_imag))
72 |
73 | ################################################################################
74 |
75 | fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(10,6))
76 |
77 | # Plotting NumPy results
78 | ax[0][0].imshow(im, cmap='gray')
79 |
80 | ax[0][1].imshow(fft_numpy_mag_viz, cmap='gray')
81 | ax[0][1].set_title('NumPy Magnitude Spectrum')
82 | ax[0][2].imshow(fft_numpy_ang_viz, cmap='gray')
83 | ax[0][2].set_title('NumPy Phase Spectrum')
84 |
85 | # Plotting PyTorch results
86 | ax[1][0].imshow(im, cmap='gray')
87 | ax[1][1].imshow(fft_torch_mag_viz, cmap='gray')
88 | ax[1][1].set_title('Torch Magnitude Spectrum')
89 | ax[1][2].imshow(fft_torch_ang_viz, cmap='gray')
90 | ax[1][2].set_title('Torch Phase Spectrum')
91 |
92 | for cur_ax in ax.flatten():
93 | cur_ax.axis('off')
94 | plt.tight_layout()
95 | plt.show()
96 |
--------------------------------------------------------------------------------