├── .gitignore ├── README.md ├── assets ├── runtime_versus_signal_length.pdf ├── runtime_versus_signal_length.png ├── scalogram_comparison.pdf └── scalogram_comparison.png ├── examples ├── __init__.py ├── benchmark.py ├── plot.py ├── plot_comparison.py └── simple_example.py ├── license.txt ├── requirements.txt ├── setup.py └── wavelets_pytorch ├── __init__.py ├── network.py ├── transform.py └── wavelets.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom ignores 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # PyBuilder 59 | target/ 60 | 61 | #Ipython Notebook 62 | .ipynb_checkpoints 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Continuous Wavelet Transforms in PyTorch 2 | 3 | This is a PyTorch implementation for the wavelet analysis outlined in [Torrence 4 | and Compo (BAMS, 1998)](http://paos.colorado.edu/research/wavelets/). 
The code builds upon the excellent [implementation](https://github.com/aaren/wavelets/) 5 | of Aaron O'Leary by adding a PyTorch filter bank wrapper to enable fast convolution on the GPU. Specifically, the code was written to speed-up the CWT computation for a large number of 1D signals and relies on `torch.nn.Conv1d` for convolution. 6 | 7 | ![PyTorch Wavelets](/assets/scalogram_comparison.png "Scalogram Comparison") 8 | 9 | ## Citation 10 | 11 | If you found this code useful, please cite our paper [Repetition Estimation](https://link.springer.com/article/10.1007/s11263-019-01194-0) (IJCV, 2019): 12 | 13 | @article{runia2019repetition, 14 | title={Repetition estimation}, 15 | author={Runia, Tom FH and Snoek, Cees GM and Smeulders, Arnold WM}, 16 | journal={International Journal of Computer Vision}, 17 | volume={127}, 18 | number={9}, 19 | pages={1361--1383}, 20 | year={2019}, 21 | publisher={Springer} 22 | } 23 | 24 | ## Usage 25 | 26 | In addition to the PyTorch implementation defined in `WaveletTransformTorch` the original SciPy version is also included in `WaveletTransform` for completeness. As the GPU implementation highly benefits from parallelization, the `cwt` and `power` methods expect signal batches of shape `[num_signals,signal_length]` instead of individual signals. 
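The batch layout described above can be sketched with NumPy alone. The snippet below is a minimal illustration and not part of the package: the helper name `make_batch` and the chosen frequencies are assumptions made for this example. It stacks equally long 1D signals into one array of shape `[num_signals, signal_length]`, the layout that `cwt` and `power` expect.

```python
import numpy as np


def make_batch(signals):
    """Stack equally long 1D signals into a [num_signals, signal_length] array.

    Hypothetical helper for illustration; it is not part of wavelets_pytorch.
    """
    signals = [np.asarray(s, dtype=np.float64) for s in signals]
    shapes = {s.shape for s in signals}
    # Exactly one distinct shape is allowed, and it must be one-dimensional
    if len(shapes) != 1 or signals[0].ndim != 1:
        raise ValueError('All signals must be 1D and have the same length.')
    return np.stack(signals, axis=0)


dt = 0.1                         # sample spacing in seconds
t = dt * np.arange(100)          # 100 time steps
frequencies = [0.5, 1.0, 2.0]    # one sinusoid per frequency (Hz)
batch = make_batch([np.sin(2 * np.pi * f * t) for f in frequencies])
print(batch.shape)               # (3, 100)
```

Signals of different lengths would have to be zero-padded to a common length before stacking, since the helper rejects ragged input.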
27 | 28 | ```python 29 | import numpy as np 30 | from wavelets_pytorch.transform import WaveletTransform # SciPy version 31 | from wavelets_pytorch.transform import WaveletTransformTorch # PyTorch version 32 | 33 | dt = 0.1 # sample spacing (1 / sampling frequency) 34 | dj = 0.125 # scale distribution parameter 35 | batch_size = 32 # how many signals to process in parallel 36 | 37 | # Batch of signals to process, shape [batch_size, signal_length] (random placeholder data) 38 | batch = np.random.randn(batch_size, 100) 39 | 40 | # Initialize wavelet filter banks (scipy and torch implementation) 41 | wa_scipy = WaveletTransform(dt, dj) 42 | wa_torch = WaveletTransformTorch(dt, dj, cuda=True) 43 | 44 | # Perform the wavelet transform (use power() instead of cwt() for the scalogram) 45 | cwt_scipy = wa_scipy.cwt(batch) 46 | cwt_torch = wa_torch.cwt(batch) 47 | 48 | # For plotting, see the examples/plot.py function. 49 | # ... 50 | ``` 51 | 52 | ## Supported Wavelets 53 | 54 | The wavelet implementations are taken from [here](https://github.com/aaren/wavelets/blob/master/wavelets/wavelets.py). Default is the Morlet wavelet. 55 | 56 | ## Benchmark 57 | 58 | Computing the CWT in parallel on the GPU with PyTorch results in a significant speed-up. Because the GPU implementation benefits from parallelization, larger batches give a larger speed-up. The plot below shows a comparison between the `scipy` and `torch` implementations as a function of the batch size `N` and the input signal length. These results were obtained on a powerful Linux desktop with an NVIDIA Titan X GPU. 59 | 60 | 61 | 62 | ## Installation 63 | 64 | Clone and install: 65 | 66 | ```sh 67 | git clone https://github.com/tomrunia/PyTorchWavelets.git 68 | cd PyTorchWavelets 69 | pip install -r requirements.txt 70 | python setup.py install 71 | ``` 72 | 73 | ## Requirements 74 | 75 | - Python 2.7 or 3.6 (other versions might also work) 76 | - Numpy (developed with 1.14.1) 77 | - Scipy (developed with 1.0.0) 78 | - PyTorch >= 0.4.0 79 | 80 | The core of the PyTorch implementation relies on the `torch.nn.Conv1d` module. 
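`torch.nn.Conv1d` slides each filter over the zero-padded input and takes a dot product at every position (a cross-correlation, so the kernel is not flipped). The pure-Python sketch below illustrates that operation for one signal and one odd-length filter, using the same `(kernel_size - 1) // 2` padding as `_get_padding` in `network.py`. The function name `correlate_same` is hypothetical, and the code only illustrates the idea; it is not how PyTorch implements the layer.

```python
def correlate_same(signal, kernel):
    """Cross-correlate `signal` with an odd-length `kernel` using zero padding.

    Hypothetical illustration of 'SAME' padding: the output has the same
    length as the input because (len(kernel) - 1) // 2 zeros are added on
    both sides before the kernel slides over the signal.
    """
    if len(kernel) % 2 == 0:
        raise ValueError('This sketch requires an odd kernel size.')
    pad = (len(kernel) - 1) // 2
    padded = [0.0] * pad + [float(v) for v in signal] + [0.0] * pad
    # One dot product per output position, kernel applied without flipping
    return [
        sum(k * padded[i + j] for j, k in enumerate(kernel))
        for i in range(len(signal))
    ]


result = correlate_same([1, 2, 3, 4, 5], [1, 0, -1])
print(result)  # [-2.0, -2.0, -2.0, -2.0, 4.0]
```

With an even kernel size the same padding would produce one output sample too few, which is consistent with the filter construction in `transform.py` trimming every filter to an odd length.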
81 | 82 | ## License 83 | 84 | MIT License 85 | 86 | Copyright (c) 2018 Tom Runia (tomrunia@gmail.com) 87 | 88 | Permission is hereby granted, free of charge, to any person obtaining a copy 89 | of this software and associated documentation files (the "Software"), to deal 90 | in the Software without restriction, including without limitation the rights 91 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 92 | copies of the Software, and to permit persons to whom the Software is 93 | furnished to do so, subject to the following conditions: 94 | 95 | The above copyright notice and this permission notice shall be included in all 96 | copies or substantial portions of the Software. 97 | 98 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 99 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 100 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 101 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 102 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 103 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 104 | SOFTWARE. 
105 | -------------------------------------------------------------------------------- /assets/runtime_versus_signal_length.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomrunia/PyTorchWavelets/8c6c093d890140483ee42f3b090f308c2a666e76/assets/runtime_versus_signal_length.pdf -------------------------------------------------------------------------------- /assets/runtime_versus_signal_length.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomrunia/PyTorchWavelets/8c6c093d890140483ee42f3b090f308c2a666e76/assets/runtime_versus_signal_length.png -------------------------------------------------------------------------------- /assets/scalogram_comparison.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomrunia/PyTorchWavelets/8c6c093d890140483ee42f3b090f308c2a666e76/assets/scalogram_comparison.pdf -------------------------------------------------------------------------------- /assets/scalogram_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomrunia/PyTorchWavelets/8c6c093d890140483ee42f3b090f308c2a666e76/assets/scalogram_comparison.png -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2017 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the 
Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-04-16 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function -------------------------------------------------------------------------------- /examples/benchmark.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-04-16 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import time 20 | import numpy as np 21 | 22 | from wavelets_pytorch.transform import WaveletTransform 23 | from wavelets_pytorch.transform import WaveletTransformTorch 24 | 25 | ###################################### 26 | 27 | fps = 20 28 | dt = 1.0/fps 29 | dj = 0.125 30 | unbias = False 31 | t_min = 0 32 | 33 | batch_sizes = np.asarray([1,8,16,32,64,128,256,512], np.int32) 34 | durations = np.asarray([5,10,25,50,100], np.int32) 35 | signal_lengths = durations*fps 36 | 37 | num_runs = 5 38 | 39 | runtimes_scipy = np.zeros((len(batch_sizes), len(signal_lengths), num_runs), np.float32) 40 | runtimes_torch = np.zeros((len(batch_sizes), len(signal_lengths), num_runs), np.float32) 41 | 42 | for batch_ind, batch_size in enumerate(batch_sizes): 43 | 44 | for length_ind, t_max in enumerate(durations): 45 | 46 | t = 
np.linspace(t_min, t_max, (t_max-t_min)*fps) 47 | print('#'*60) 48 | print('Benchmarking | BatchSize = {}, SignalLength = {}'.format(batch_size, signal_lengths[length_ind])) 49 | 50 | for run_ind in range(num_runs): 51 | 52 | random_frequencies = np.random.uniform(-0.5, 4.0, size=batch_size) 53 | batch = np.asarray([np.sin(2*np.pi*f*t) for f in random_frequencies]) 54 | 55 | # Perform batch computation of SciPy implementation 56 | t_start = time.time() 57 | wa = WaveletTransform(dt, dj, unbias=unbias) 58 | power = wa.power(batch) 59 | runtimes_scipy[batch_ind,length_ind,run_ind] = time.time() - t_start 60 | #print(" Run {}/{} | SciPy: {:.2f}s".format(run_ind+1, num_runs, runtimes[batch_ind,length_ind,run_ind,0])) 61 | 62 | # Perform batch computation of Torch implementation 63 | t_start = time.time() 64 | wa = WaveletTransformTorch(dt, dj, unbias=unbias) 65 | power = wa.power(batch) 66 | runtimes_torch[batch_ind,length_ind,run_ind] = time.time() - t_start 67 | #print(" Run {}/{} | Torch: {:.2f}s".format(run_ind+1, num_runs, runtimes[batch_ind,length_ind,run_ind,1])) 68 | 69 | avg_scipy = np.mean(runtimes_scipy[batch_ind,length_ind,:]) 70 | avg_torch = np.mean(runtimes_torch[batch_ind,length_ind,:]) 71 | print(' Average SciPy: {:.2f}s'.format(avg_scipy)) 72 | print(' Average Torch: {:.2f}s'.format(avg_torch)) 73 | 74 | np.save('./runtimes_scipy.npy', runtimes_scipy) 75 | np.save('./runtimes_torch.npy', runtimes_torch) 76 | -------------------------------------------------------------------------------- /examples/plot.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-04-16 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | 21 | import matplotlib.pyplot as plt 22 | import matplotlib.ticker as ticker 23 | from mpl_toolkits.axes_grid1 import make_axes_locatable 24 | 25 | 26 | def plot_scalogram(power, scales, t, normalize_columns=True, cmap=None, ax=None, scale_legend=True): 27 | """ 28 | Plot the wavelet power spectrum (scalogram). 29 | 30 | :param power: np.ndarray, CWT power spectrum of shape [n_scales,signal_length] 31 | :param scales: np.ndarray, scale distribution of shape [n_scales] 32 | :param t: np.ndarray, temporal range of shape [signal_length] 33 | :param normalize_columns: boolean, whether to normalize spectrum per timestep 34 | :param cmap: matplotlib cmap, please refer to their documentation 35 | :param ax: matplotlib axis object, if None creates a new subplot 36 | :param scale_legend: boolean, whether to include scale legend on the right 37 | :return: ax, matplotlib axis object that contains the scalogram 38 | """ 39 | 40 | if not cmap: cmap = plt.get_cmap("PuBu_r") 41 | if ax is None: fig, ax = plt.subplots() 42 | if normalize_columns: power = power/np.max(power, axis=0) 43 | 44 | T, S = np.meshgrid(t, scales) 45 | cnt = ax.contourf(T, S, power, 100, cmap=cmap) 46 | 47 | # Fix for saving as PDF (aliasing) 48 | for c in cnt.collections: 49 | c.set_edgecolor("face") 50 | 51 | ax.set_yscale('log') 52 | ax.set_ylabel("Scale (Log Scale)") 53 | ax.set_xlabel("Time (s)") 54 | ax.set_title("Wavelet Power Spectrum") 55 | 56 | if scale_legend: 57 | def format_axes_label(x, pos): 58 | return "{:.2f}".format(x) 59 | divider = make_axes_locatable(ax) 60 | cax = divider.append_axes("right", size="5%", pad=0.05) 61 
| plt.colorbar(cnt, cax=cax, ticks=[np.min(power), 0, np.max(power)], 62 | format=ticker.FuncFormatter(format_axes_label)) 63 | 64 | return ax -------------------------------------------------------------------------------- /examples/plot_comparison.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-04-16 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | 22 | from wavelets_pytorch.wavelets import Morlet, Ricker, DOG 23 | from wavelets_pytorch.transform import WaveletTransform, WaveletTransformTorch 24 | from examples.plot import plot_scalogram 25 | 26 | """ 27 | Example script to plot SciPy and PyTorch implementation outputs side-to-side. 
28 | """ 29 | 30 | fps = 20 31 | dt = 1.0/fps 32 | dj = 0.125 33 | unbias = False 34 | batch_size = 32 35 | wavelet = Morlet(w0=6) 36 | 37 | t_min = 0 38 | t_max = 10 39 | t = np.linspace(t_min, t_max, (t_max-t_min)*fps) 40 | 41 | ###################################### 42 | # Generating batch of random sinusoidals 43 | 44 | random_frequencies = np.random.uniform(0.5, 4.0, size=batch_size) 45 | batch = np.asarray([np.sin(2*np.pi*f*t) for f in random_frequencies]) 46 | batch += np.random.normal(0, 0.2, batch.shape) # Gaussian noise 47 | 48 | ###################################### 49 | # Performing wavelet transform 50 | 51 | wa = WaveletTransform(dt, dj, wavelet, unbias=unbias) 52 | wa_torch = WaveletTransformTorch(dt, dj, wavelet, unbias=unbias, cuda=True) 53 | 54 | power = wa.power(batch) 55 | power_torch = wa_torch.power(batch) 56 | 57 | ###################################### 58 | # Plotting 59 | 60 | fig, ax = plt.subplots(1, 3, figsize=(12,3)) 61 | ax = ax.flatten() 62 | ax[0].plot(t, batch[0]) 63 | ax[0].set_title(r'$f(t) = \sin(2\pi \cdot f t) + \mathcal{N}(\mu,\,\sigma^{2})$') 64 | ax[0].set_xlabel('Time (s)') 65 | 66 | # Plot scalogram for SciPy implementation 67 | plot_scalogram(power[0], wa.fourier_periods, t, ax=ax[1], scale_legend=False) 68 | ax[1].axhline(1.0 / random_frequencies[0], lw=1, color='k') 69 | ax[1].set_title('Scalogram (SciPy)'.format(1.0/random_frequencies[0])) 70 | 71 | # Plot scalogram for PyTorch implementation 72 | plot_scalogram(power_torch[0], wa_torch.fourier_periods, t, ax=ax[2]) 73 | ax[2].axhline(1.0 / random_frequencies[0], lw=1, color='k') 74 | ax[2].set_title('Scalogram (Torch)'.format(1.0/random_frequencies[0])) 75 | ax[2].set_ylabel('') 76 | ax[2].set_yticks([]) 77 | 78 | plt.tight_layout() 79 | plt.show() -------------------------------------------------------------------------------- /examples/simple_example.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # 
Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-04-16 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | from wavelets_pytorch.transform import WaveletTransform # SciPy version 21 | from wavelets_pytorch.transform import WaveletTransformTorch # PyTorch version 22 | 23 | from wavelets_pytorch.wavelets import Morlet, Ricker 24 | 25 | """ 26 | Example script to demonstrate the CWT on a batch of random sinusoidal signals. 27 | We compare both the SciPy implementation and the PyTorch implementation. 
28 | """ 29 | 30 | dt = 0.1 # sample spacing (1 / sampling frequency) 31 | dj = 0.125 # scale distribution parameter 32 | batch_size = 32 # how many signals to process in parallel 33 | cuda = True # enable GPU 34 | 35 | t = np.linspace(0., 10., int(10./dt)) 36 | 37 | # Run once with a complex wavelet (Morlet) and once with a real wavelet (Ricker) 38 | for wavelet in [Morlet(), Ricker()]: 39 | 40 | # Sinusoids with random frequency 41 | frequencies = np.random.uniform(-0.5, 2.0, size=batch_size) 42 | batch = np.asarray([np.sin(2*np.pi*f*t) for f in frequencies]) 43 | 44 | # Initialize wavelet filter banks (scipy and torch implementation) 45 | wa_scipy = WaveletTransform(dt, dj, wavelet) 46 | wa_torch = WaveletTransformTorch(dt, dj, wavelet, cuda=cuda) 47 | 48 | # Perform the wavelet transform (use power() instead of cwt() for the scalogram) 49 | cwt_scipy = wa_scipy.cwt(batch) 50 | cwt_torch = wa_torch.cwt(batch) 51 | 52 | print(cwt_scipy.shape) 53 | print(cwt_torch.shape) 54 | 55 | # For plotting, see the examples/plot.py function. 56 | # ... -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tom Runia (tomrunia@gmail.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | six 2 | numpy 3 | scipy 4 | torch >= 0.4.0 5 | matplotlib -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from setuptools import setup 3 | 4 | setup( 5 | name='wavelets_pytorch', 6 | version='0.1', 7 | author='Tom Runia', 8 | author_email='tomrunia@gmail.com', 9 | url='https://github.com/tomrunia/PyTorchWavelets', 10 | description='Wavelet Transform in PyTorch', 11 | long_description='Fast CPU/CUDA implementation of the Continuous Wavelet Transform in PyTorch.', 12 | license='MIT', 13 | packages=['wavelets_pytorch'], 14 | scripts=[] 15 | ) -------------------------------------------------------------------------------- /wavelets_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2017 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and 
to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-04-16 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function -------------------------------------------------------------------------------- /wavelets_pytorch/network.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-04-16 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | 25 | class TorchFilterBank(nn.Module): 26 | 27 | def __init__(self, filters=None, cuda=True): 28 | """ 29 | Temporal filter bank in PyTorch storing a collection of nn.Conv1d filters. 30 | When cuda=True, the convolutions are performed on the GPU. If initialized with 31 | filters=None, the set_filters() method has to be called before actual running 32 | the convolutions. 
33 | 34 | :param filters: list, collection of variable sized 1D filters (default: []) 35 | :param cuda: boolean, whether to run on GPU or not (default: True) 36 | """ 37 | super(TorchFilterBank, self).__init__() 38 | self._cuda = cuda 39 | self._filters = [] # populated by set_filters(), which returns None 40 | if filters: self.set_filters(filters) 41 | def forward(self, x): 42 | """ 43 | Takes a batch of signals and convolves each signal with all elements in the filter 44 | bank. After convolving with the entire filter bank, the method returns a tensor of 45 | shape [N,N_scales,1/2,T] where the 1/2 number of channels depends on whether 46 | the filter bank is composed of real or complex filters. If the filters are 47 | complex the 2 channels represent [real, imag] parts. 48 | 49 | :param x: torch.Variable, batch of input signals of shape [N,1,T] 50 | :return: torch.Variable, batch of outputs of size [N,N_scales,1/2,T] 51 | """ 52 | 53 | if not self._filters: 54 | raise ValueError('PyTorch filters not initialized. Please call set_filters() first.') 55 | 56 | results = [None]*len(self._filters) 57 | for ind, conv in enumerate(self._filters): 58 | results[ind] = conv(x) 59 | results = torch.stack(results) # [n_scales,n_batch,2,t] 60 | results = results.permute(1,0,2,3) # [n_batch,n_scales,2,t] 61 | return results 62 | 63 | def set_filters(self, filters, padding_type='SAME'): 64 | """ 65 | Given a list of temporal 1D filters of variable size, this method creates a 66 | list of nn.Conv1d objects that collectively form the filter bank. 
67 | 68 | :param filters: list, collection of filters each a np.ndarray 69 | :param padding_type: str, should be SAME or VALID 70 | :return: 71 | """ 72 | 73 | assert isinstance(filters, list) 74 | assert padding_type in ['SAME', 'VALID'] 75 | 76 | self._filters = [None]*len(filters) 77 | for ind, filt in enumerate(filters): 78 | 79 | assert filt.dtype in (np.float32, np.float64, np.complex64, np.complex128) 80 | 81 | if np.iscomplex(filt).any(): 82 | chn_out = 2 83 | filt_weights = np.asarray([np.real(filt), np.imag(filt)], np.float32) 84 | else: 85 | chn_out = 1 86 | filt_weights = filt.astype(np.float32)[None,:] 87 | 88 | filt_weights = np.expand_dims(filt_weights, 1) # append chn_in dimension 89 | filt_size = filt_weights.shape[-1] # filter length 90 | padding = self._get_padding(padding_type, filt_size) 91 | 92 | conv = nn.Conv1d(1, chn_out, kernel_size=filt_size, padding=padding, bias=False) 93 | conv.weight.data = torch.from_numpy(filt_weights) 94 | conv.weight.requires_grad_(False) 95 | 96 | if self._cuda: conv.cuda() 97 | self._filters[ind] = conv 98 | 99 | @staticmethod 100 | def _get_padding(padding_type, kernel_size): 101 | assert isinstance(kernel_size, int) 102 | assert padding_type in ['SAME', 'VALID'] 103 | if padding_type == 'SAME': 104 | return (kernel_size - 1) // 2 105 | return 0 -------------------------------------------------------------------------------- /wavelets_pytorch/transform.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # 
furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-04-15 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import six 20 | from abc import ABCMeta, abstractmethod 21 | 22 | import numpy as np 23 | import scipy.signal 24 | import scipy.optimize 25 | 26 | import torch 27 | from torch.autograd import Variable 28 | 29 | from wavelets_pytorch.wavelets import Morlet 30 | from wavelets_pytorch.network import TorchFilterBank 31 | 32 | ########################################################################################## 33 | 34 | @six.add_metaclass(ABCMeta) 35 | class WaveletTransformBase(object): 36 | """ 37 | 38 | Base class for the Continuous Wavelet Transform as described in: 39 | "Torrence & Compo, A Practical Guide to Wavelet Analysis (BAMS, 1998)" 40 | 41 | This class is an abstract superclass for the child classes: 42 | WaveletTransform => implements CWT in SciPy 43 | WaveletTransformTorch => implements CWT in PyTorch 44 | 45 | For a more detailed explanation of the parameters, the original code serves as reference: 46 | https://github.com/aaren/wavelets/blob/master/wavelets/transform.py#L145 47 | 48 | """ 49 | 50 | def __init__(self, dt=1.0, dj=0.125, wavelet=Morlet(), unbias=False): 51 | """ 52 | :param dt: float, sample spacing 53 | :param dj: float, scale distribution parameter 54 | :param wavelet: wavelet object, see 'wavelets.py' 55 | :param unbias: boolean, whether to unbias the power spectrum 56 | """ 57 | self._dt = dt 58 | self._dj = dj 59 | self._wavelet = wavelet 60 | self._unbias = unbias 61 | self._scale_minimum = self.compute_minimum_scale() 62 | self._signal_length = None # initialize on first call 63 | self._scales = None # initialize on first call 64 | self._filters = None # initialize on first call 65 | 66 | @abstractmethod 67 | def cwt(self, x): 68 | raise NotImplementedError 69 | 70 | def 
_build_filters(self): 71 | """ 72 | Determines the optimal scale distribution (see Torrence & Compo, Eq. 9-10), 73 | and then initializes the filter bank consisting of rescaled versions 74 | of the mother wavelet. Also includes normalization. Code is based on: 75 | https://github.com/aaren/wavelets/blob/master/wavelets/transform.py#L88 76 | """ 77 | self._scale_minimum = self.compute_minimum_scale() 78 | self._scales = self.compute_optimal_scales() 79 | 80 | self._filters = [None]*len(self.scales) 81 | for scale_idx, scale in enumerate(self._scales): 82 | # Number of points needed to capture wavelet 83 | M = 10 * scale / self.dt 84 | # Times to use, centred at zero 85 | t = np.arange((-M + 1) / 2., (M + 1) / 2.) * self.dt 86 | if len(t) % 2 == 0: t = t[0:-1] # requires odd filter size 87 | # Sample wavelet and normalise 88 | norm = (self.dt / scale) ** .5 89 | self._filters[scale_idx] = norm * self.wavelet(t, scale) 90 | 91 | def compute_optimal_scales(self): 92 | """ 93 | Determines the optimal scale distribution (see Torrence & Compo, Eq. 9-10). 94 | :return: np.ndarray, collection of scales 95 | """ 96 | if self.signal_length is None: 97 | raise ValueError('Please specify signal_length before computing optimal scales.') 98 | J = int((1 / self.dj) * np.log2(self.signal_length * self.dt / self._scale_minimum)) 99 | scales = self._scale_minimum * 2 ** (self.dj * np.arange(0, J + 1)) 100 | return scales 101 | 102 | def compute_minimum_scale(self): 103 | """ 104 | Choose s0 so that the equivalent Fourier period is 2 * dt. 105 | See Torrence & Compo, Sections 3f and 3h. 106 | :return: float, minimum scale level 107 | """ 108 | dt = self.dt 109 | def func_to_solve(s): 110 | return self.fourier_period(s) - 2 * dt 111 | return scipy.optimize.fsolve(func_to_solve, 1)[0] 112 | 113 | def power(self, x): 114 | """ 115 | Performs CWT and converts to a power spectrum (scalogram). 116 | See Torrence & Compo, Section 4d. 
117 | :param x: np.ndarray, batch of input signals of shape [n_batch,signal_length] 118 | :return: np.ndarray, scalogram for each signal [n_batch,n_scales,signal_length] 119 | """ 120 | if self.unbias: 121 | return np.abs(self.cwt(x)) ** 2 / self.scales[:, None] # divide each scale row by its scale; also valid for batched output 122 | else: 123 | return np.abs(self.cwt(x)) ** 2 124 | 125 | @property 126 | def dt(self): 127 | return self._dt 128 | 129 | @dt.setter 130 | def dt(self, value): 131 | # Needs to recompute scale distribution and filters 132 | self._dt = value 133 | self._build_filters() 134 | 135 | @property 136 | def signal_length(self): 137 | return self._signal_length 138 | 139 | @signal_length.setter 140 | def signal_length(self, value): 141 | # Needs to recompute scale distribution and filters 142 | self._signal_length = value 143 | self._build_filters() 144 | 145 | @property 146 | def wavelet(self): 147 | return self._wavelet 148 | 149 | @property 150 | def fourier_period(self): 151 | """ Return a function that calculates the equivalent Fourier period. """ 152 | return getattr(self.wavelet, 'fourier_period') 153 | 154 | @property 155 | def scale_from_period(self): 156 | """ Return a function that calculates the wavelet scale from the Fourier period. """ 157 | return getattr(self.wavelet, 'scale_from_period') 158 | 159 | @property 160 | def fourier_periods(self): 161 | """ Return the equivalent Fourier periods for the scales used. """ 162 | assert self._scales is not None, 'Wavelet scales are not initialized.' 163 | return self.fourier_period(self.scales) 164 | 165 | @property 166 | def fourier_frequencies(self): 167 | """ Return the equivalent frequencies. 
""" 168 | return np.reciprocal(self.fourier_periods) 169 | 170 | @property 171 | def scales(self): 172 | return self._scales 173 | 174 | @property 175 | def dj(self): 176 | return self._dj 177 | 178 | @property 179 | def wavelet(self): 180 | return self._wavelet 181 | 182 | @property 183 | def unbias(self): 184 | return self._unbias 185 | 186 | @property 187 | def complex_wavelet(self): 188 | return np.iscomplexobj(self._filters[0]) 189 | 190 | @property 191 | def output_dtype(self): 192 | return np.complex128 if self.complex_wavelet else np.float64 193 | 194 | ########################################################################################## 195 | 196 | class WaveletTransform(WaveletTransformBase): 197 | 198 | def __init__(self, dt=1.0, dj=0.125, wavelet=Morlet(), unbias=False): 199 | """ 200 | This is the SciPy version of the CWT filter bank. The main work for this filter bank 201 | is performed by the convolution implemented in 'scipy.signal.convolve'. 202 | 203 | :param dt: float, sample spacing 204 | :param dj: float, scale distribution parameter 205 | :param wavelet: wavelet object, see 'wavelets.py' 206 | :param unbias: boolean, whether to unbias the power spectrum 207 | """ 208 | super(WaveletTransform,self).__init__(dt, dj, wavelet, unbias) 209 | 210 | def cwt(self, x): 211 | """ 212 | Implements the continuous wavelet transform on a batch of signals. All signals 213 | in the batch must have the same length, otherwise manual zero padding has to be 214 | applied. On the first call, the signal length is used to determine the optimal 215 | scale distribution, which is then used to initialize the wavelet filter bank. 216 | If there is only one example in the batch, the batch dimension is squeezed.
217 | 218 | :param x: np.ndarray, batch of signals of shape [n_batch,signal_length] 219 | :return: np.ndarray, CWT for each signal in the batch [n_batch,n_scales,signal_length] 220 | """ 221 | 222 | # Prepend batch dimension 223 | if x.ndim == 1: 224 | x = x[None,:] 225 | 226 | num_examples = x.shape[0] 227 | signal_length = x.shape[-1] 228 | 229 | if signal_length != self.signal_length or not self._filters: 230 | # First call initialization, or change in signal length. Note that setting 231 | # this also determines the optimal scales and initializes the filter bank. 232 | self.signal_length = signal_length 233 | 234 | # Wavelets can be complex so output may be complex (np.float64 or np.complex128) 235 | cwt = np.zeros((num_examples, len(self.scales), x.shape[-1]), self.output_dtype) 236 | for example_idx in range(num_examples): 237 | cwt[example_idx] = self._compute_single(x[example_idx]) 238 | 239 | # Squeeze batch dimension if single example 240 | if num_examples == 1: 241 | cwt = cwt.squeeze(0) 242 | return cwt 243 | 244 | def _compute_single(self, x): 245 | assert x.ndim == 1, 'Input signal must have single dimension.' 246 | output = np.zeros((len(self.scales), len(x)), self.output_dtype) 247 | for scale_idx, filt in enumerate(self._filters): 248 | output[scale_idx,:] = scipy.signal.convolve(x, filt, mode='same') 249 | return output 250 | 251 | ########################################################################################## 252 | 253 | class WaveletTransformTorch(WaveletTransformBase): 254 | 255 | def __init__(self, dt=1.0, dj=0.125, wavelet=Morlet(), unbias=False, cuda=True): 256 | """ 257 | This is the PyTorch version of the CWT filter bank. The main work for this filter bank 258 | is performed by the convolution implemented in 'torch.nn.Conv1d'. Actual 259 | convolutions are performed by the helper class defined in 'network.py' which 260 | implements a 'torch.nn.Module' that contains the convolution filters.
261 | 262 | :param dt: float, sample spacing 263 | :param dj: float, scale distribution parameter 264 | :param wavelet: wavelet object, see 'wavelets.py' 265 | :param unbias: boolean, whether to unbias the power spectrum 266 | :param cuda: boolean, whether to run convolutions on the GPU 267 | """ 268 | super(WaveletTransformTorch, self).__init__(dt, dj, wavelet, unbias) 269 | self._cuda = cuda 270 | self._extractor = TorchFilterBank(self._filters, cuda) 271 | 272 | def cwt(self, x): 273 | """ 274 | Implements the continuous wavelet transform on a batch of signals. All signals 275 | in the batch must have the same length, otherwise manual zero padding has to be 276 | applied. On the first call, the signal length is used to determine the optimal 277 | scale distribution, which is then used to initialize the wavelet filter bank. 278 | If there is only one example in the batch, the batch dimension is squeezed. 279 | 280 | :param x: np.ndarray, batch of signals of shape [n_batch,signal_length] 281 | :return: np.ndarray, CWT for each signal in the batch [n_batch,n_scales,signal_length] 282 | """ 283 | 284 | if x.ndim == 1: 285 | # Add batch_size and chn_in dimensions 286 | # [signal_length] => [n_batch,1,signal_length] 287 | x = x[None,None,:] 288 | elif x.ndim == 2: 289 | # Just add chn_in dimension 290 | # [n_batch,signal_length] => [n_batch,1,signal_length] 291 | x = x[:,None,:] 292 | 293 | num_examples = x.shape[0] 294 | signal_length = x.shape[-1] 295 | 296 | if signal_length != self.signal_length or not self._filters: 297 | # First call initialization, or change in signal length. Note that setting 298 | # this also determines the optimal scales and initializes the filter bank.
299 | self.signal_length = signal_length 300 | 301 | # Move to GPU and perform CWT computation 302 | x = torch.from_numpy(x).type(torch.FloatTensor) 303 | x.requires_grad_(requires_grad=False) 304 | 305 | if self._cuda: x = x.cuda() 306 | cwt = self._extractor(x) 307 | 308 | # Move back to CPU 309 | cwt = cwt.detach() 310 | if self._cuda: cwt = cwt.cpu() 311 | cwt = cwt.numpy() 312 | 313 | if self.complex_wavelet: 314 | # Combine real and imag parts, returns object of shape 315 | # [n_batch,n_scales,signal_length] of type np.complex128 316 | cwt = (cwt[:,:,0,:] + cwt[:,:,1,:]*1j).astype(self.output_dtype) 317 | else: 318 | # Just squeeze the chn_out dimension (=1) to obtain an object of shape 319 | # [n_batch,n_scales,signal_length] of type np.float64 320 | cwt = np.squeeze(cwt, 2).astype(self.output_dtype) 321 | 322 | # Squeeze batch dimension if single example 323 | if num_examples == 1: 324 | cwt = cwt.squeeze(0) 325 | return cwt 326 | 327 | @property 328 | def dt(self): 329 | return self._dt 330 | 331 | @dt.setter 332 | def dt(self, value): 333 | super(WaveletTransformTorch, self.__class__).dt.fset(self, value) 334 | self._extractor.set_filters(self._filters) 335 | 336 | @property 337 | def signal_length(self): 338 | return self._signal_length 339 | 340 | @signal_length.setter 341 | def signal_length(self, value): 342 | super(WaveletTransformTorch, self.__class__).signal_length.fset(self, value) 343 | self._extractor.set_filters(self._filters) -------------------------------------------------------------------------------- /wavelets_pytorch/wavelets.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, 
modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Aaron O'Leary (dev@aaren.me) 13 | # Date Created: 2016-02-28 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import scipy 21 | import scipy.signal 22 | import scipy.optimize 23 | import scipy.special 24 | from scipy.special import factorial # scipy.misc.factorial was removed in newer SciPy releases 25 | 26 | __all__ = ['Morlet', 'Paul', 'DOG', 'Ricker', 'Marr', 'Mexican_hat'] 27 | 28 | 29 | class Morlet(object): 30 | def __init__(self, w0=6): 31 | """w0 is the nondimensional frequency constant. If this is 32 | set too low then the wavelet does not sample very well: a 33 | value over 5 should be ok; Torrence and Compo set it to 6. 34 | """ 35 | self.w0 = w0 36 | if w0 == 6: 37 | # value of C_d from TC98 38 | self.C_d = 0.776 39 | 40 | def __call__(self, *args, **kwargs): 41 | return self.time(*args, **kwargs) 42 | 43 | def time(self, t, s=1.0, complete=True): 44 | """ 45 | Complex Morlet wavelet, centred at zero. 46 | 47 | Parameters 48 | ---------- 49 | t : float 50 | Time. If s is not specified, this can be used as the 51 | non-dimensional time t/s. 52 | s : float 53 | Scaling factor. Default is 1. 54 | complete : bool 55 | Whether to use the complete or the standard version. 56 | 57 | Returns 58 | ------- 59 | out : complex 60 | Value of the Morlet wavelet at the given time 61 | 62 | See Also 63 | -------- 64 | scipy.signal.gausspulse 65 | 66 | Notes 67 | ----- 68 | The standard version:: 69 | 70 | pi**-0.25 * exp(1j*w*x) * exp(-0.5*(x**2)) 71 | 72 | This commonly used wavelet is often referred to simply as the 73 | Morlet wavelet. Note that this simplified version can cause 74 | admissibility problems at low values of `w`.
75 | 76 | The complete version:: 77 | 78 | pi**-0.25 * (exp(1j*w*x) - exp(-0.5*(w**2))) * exp(-0.5*(x**2)) 79 | 80 | The complete version of the Morlet wavelet, with a correction 81 | term to improve admissibility. For `w` greater than 5, the 82 | correction term is negligible. 83 | 84 | Note that the energy of the return wavelet is not normalised 85 | according to `s`. 86 | 87 | The fundamental frequency of this wavelet in Hz is given 88 | by ``f = 2*s*w*r / M`` where r is the sampling rate. 89 | 90 | """ 91 | w = self.w0 92 | 93 | x = t / s 94 | 95 | output = np.exp(1j * w * x) 96 | 97 | if complete: 98 | output -= np.exp(-0.5 * (w ** 2)) 99 | 100 | output *= np.exp(-0.5 * (x ** 2)) * np.pi ** (-0.25) 101 | 102 | return output 103 | 104 | # Fourier wavelengths 105 | def fourier_period(self, s): 106 | """Equivalent Fourier period of Morlet""" 107 | return 4 * np.pi * s / (self.w0 + (2 + self.w0 ** 2) ** .5) 108 | 109 | def scale_from_period(self, period): 110 | """ 111 | Compute the scale from the fourier period. 112 | Returns the scale 113 | """ 114 | # Solve 4 * np.pi * scale / (w0 + (2 + w0 ** 2) ** .5) 115 | # for s to obtain this formula 116 | coeff = np.sqrt(self.w0 * self.w0 + 2) 117 | return (period * (coeff + self.w0)) / (4. * np.pi) 118 | 119 | # Frequency representation 120 | def frequency(self, w, s=1.0): 121 | """Frequency representation of Morlet. 122 | 123 | Parameters 124 | ---------- 125 | w : float 126 | Angular frequency. If `s` is not specified, i.e. set to 1, 127 | this can be used as the non-dimensional angular 128 | frequency w * s. 129 | s : float 130 | Scaling factor. Default is 1. 
131 | 132 | Returns 133 | ------- 134 | out : complex 135 | Value of the Morlet wavelet at the given frequency 136 | """ 137 | x = w * s 138 | # Heaviside mock 139 | Hw = np.array(w) 140 | Hw[w <= 0] = 0 141 | Hw[w > 0] = 1 142 | return np.pi ** -.25 * Hw * np.exp((-(x - self.w0) ** 2) / 2) 143 | 144 | def coi(self, s): 145 | """The e folding time for the autocorrelation of wavelet 146 | power at each scale, i.e. the timescale over which an edge 147 | effect decays by a factor of 1/e^2. 148 | 149 | This can be worked out analytically by solving 150 | 151 | |Y_0(T)|^2 / |Y_0(0)|^2 = 1 / e^2 152 | """ 153 | return 2 ** .5 * s 154 | 155 | 156 | class Paul(object): 157 | def __init__(self, m=4): 158 | """Initialise a Paul wavelet function of order `m`. 159 | """ 160 | self.m = m 161 | 162 | def __call__(self, *args, **kwargs): 163 | return self.time(*args, **kwargs) 164 | 165 | def time(self, t, s=1.0): 166 | """ 167 | Complex Paul wavelet, centred at zero. 168 | 169 | Parameters 170 | ---------- 171 | t : float 172 | Time. If `s` is not specified, i.e. set to 1, this can be 173 | used as the non-dimensional time t/s. 174 | s : float 175 | Scaling factor. Default is 1. 176 | 177 | Returns 178 | ------- 179 | out : complex 180 | Value of the Paul wavelet at the given time 181 | 182 | The Paul wavelet is defined (in time) as:: 183 | 184 | (2 ** m * i ** m * m!) / (pi * (2 * m)!) 
\ 185 | * (1 - i * t / s) ** -(m + 1) 186 | 187 | """ 188 | m = self.m 189 | x = t / s 190 | 191 | const = (2 ** m * 1j ** m * factorial(m)) \ 192 | / (np.pi * factorial(2 * m)) ** .5 193 | functional_form = (1 - 1j * x) ** -(m + 1) 194 | 195 | output = const * functional_form 196 | 197 | return output 198 | 199 | # Fourier wavelengths 200 | def fourier_period(self, s): 201 | """Equivalent Fourier period of Paul""" 202 | return 4 * np.pi * s / (2 * self.m + 1) 203 | 204 | def scale_from_period(self, period): 205 | raise NotImplementedError() 206 | 207 | # Frequency representation 208 | def frequency(self, w, s=1.0): 209 | """Frequency representation of Paul. 210 | 211 | Parameters 212 | ---------- 213 | w : float 214 | Angular frequency. If `s` is not specified, i.e. set to 1, 215 | this can be used as the non-dimensional angular 216 | frequency w * s. 217 | s : float 218 | Scaling factor. Default is 1. 219 | 220 | Returns 221 | ------- 222 | out : complex 223 | Value of the Paul wavelet at the given frequency 224 | 225 | """ 226 | m = self.m 227 | x = w * s 228 | # Heaviside mock 229 | Hw = 0.5 * (np.sign(x) + 1) 230 | 231 | # prefactor 232 | const = 2 ** m / (m * factorial(2 * m - 1)) ** .5 233 | 234 | functional_form = Hw * (x) ** m * np.exp(-x) 235 | 236 | output = const * functional_form 237 | 238 | return output 239 | 240 | def coi(self, s): 241 | """The e folding time for the autocorrelation of wavelet 242 | power at each scale, i.e. the timescale over which an edge 243 | effect decays by a factor of 1/e^2. 
244 | 245 | This can be worked out analytically by solving 246 | 247 | |Y_0(T)|^2 / |Y_0(0)|^2 = 1 / e^2 248 | """ 249 | return s / 2 ** .5 250 | 251 | 252 | class DOG(object): 253 | def __init__(self, m=2): 254 | """Initialise a Derivative of Gaussian wavelet of order `m`.""" 255 | if m == 2: 256 | # value of C_d from TC98 257 | self.C_d = 3.541 258 | elif m == 6: 259 | self.C_d = 1.966 260 | else: 261 | pass 262 | self.m = m 263 | 264 | def __call__(self, *args, **kwargs): 265 | return self.time(*args, **kwargs) 266 | 267 | def time(self, t, s=1.0): 268 | """ 269 | Return a Derivative of Gaussian wavelet, 270 | 271 | When m = 2, this is also known as the "Mexican hat", "Marr" 272 | or "Ricker" wavelet. 273 | 274 | It models the function:: 275 | 276 | ``A d^m/dx^m exp(-x^2 / 2)``, 277 | 278 | where ``A = (-1)^(m+1) / (gamma(m + 1/2))^.5`` 279 | and ``x = t / s``. 280 | 281 | Note that the energy of the return wavelet is not normalised 282 | according to `s`. 283 | 284 | Parameters 285 | ---------- 286 | t : float 287 | Time. If `s` is not specified, this can be used as the 288 | non-dimensional time t/s. 289 | s : scalar 290 | Width parameter of the wavelet. 291 | 292 | Returns 293 | ------- 294 | out : float 295 | Value of the DOG wavelet at the given time 296 | 297 | Notes 298 | ----- 299 | The derivative of the Gaussian has a polynomial representation: 300 | 301 | from http://en.wikipedia.org/wiki/Gaussian_function: 302 | 303 | "Mathematically, the derivatives of the Gaussian function can be 304 | represented using Hermite functions. The n-th derivative of the 305 | Gaussian is the Gaussian function itself multiplied by the n-th 306 | Hermite polynomial, up to scale." 
307 | 308 | http://en.wikipedia.org/wiki/Hermite_polynomial 309 | 310 | Here, we want the 'probabilists' Hermite polynomial (He_n), 311 | which is computed by scipy.special.hermitenorm 312 | 313 | """ 314 | x = t / s 315 | m = self.m 316 | 317 | # compute the Hermite polynomial (used to evaluate the 318 | # derivative of a Gaussian) 319 | He_n = scipy.special.hermitenorm(m) 320 | gamma = scipy.special.gamma 321 | 322 | const = (-1) ** (m + 1) / gamma(m + 0.5) ** .5 323 | function = He_n(x) * np.exp(-x ** 2 / 2) 324 | 325 | return const * function 326 | 327 | def fourier_period(self, s): 328 | """Equivalent Fourier period of derivative of Gaussian""" 329 | return 2 * np.pi * s / (self.m + 0.5) ** .5 330 | 331 | def scale_from_period(self, period): 332 | raise NotImplementedError() 333 | 334 | def frequency(self, w, s=1.0): 335 | """Frequency representation of derivative of Gaussian. 336 | 337 | Parameters 338 | ---------- 339 | w : float 340 | Angular frequency. If `s` is not specified, i.e. set to 1, 341 | this can be used as the non-dimensional angular 342 | frequency w * s. 343 | s : float 344 | Scaling factor. Default is 1. 345 | 346 | Returns 347 | ------- 348 | out : complex 349 | Value of the derivative of Gaussian wavelet at the 350 | given frequency 351 | """ 352 | m = self.m 353 | x = s * w 354 | gamma = scipy.special.gamma 355 | const = -1j ** m / gamma(m + 0.5) ** .5 356 | function = x ** m * np.exp(-x ** 2 / 2) 357 | return const * function 358 | 359 | def coi(self, s): 360 | """The e folding time for the autocorrelation of wavelet 361 | power at each scale, i.e. the timescale over which an edge 362 | effect decays by a factor of 1/e^2. 363 | 364 | This can be worked out analytically by solving 365 | 366 | |Y_0(T)|^2 / |Y_0(0)|^2 = 1 / e^2 367 | """ 368 | return 2 ** .5 * s 369 | 370 | 371 | class Ricker(DOG): 372 | def __init__(self): 373 | """The Ricker, aka Marr / Mexican Hat, wavelet is a 374 | derivative of Gaussian of order 2.
375 | """ 376 | DOG.__init__(self, m=2) 377 | # value of C_d from TC98 378 | self.C_d = 3.541 379 | 380 | 381 | # aliases for DOG2 382 | Marr = Ricker 383 | Mexican_hat = Ricker --------------------------------------------------------------------------------
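The scale selection performed by `compute_minimum_scale` and `compute_optimal_scales` (Torrence & Compo, Eq. 9-10) can be illustrated in isolation. The following is a minimal sketch for the Morlet wavelet only, not the repository's implementation: the function names are hypothetical, and it assumes the closed-form inverse from `Morlet.scale_from_period` in place of the `scipy.optimize.fsolve` call used in `transform.py`.

```python
import numpy as np


def morlet_fourier_period(s, w0=6.0):
    # Same formula as Morlet.fourier_period in wavelets.py
    return 4 * np.pi * s / (w0 + np.sqrt(2 + w0 ** 2))


def morlet_scale_from_period(period, w0=6.0):
    # Closed-form inverse of morlet_fourier_period
    # (assumption: used here instead of the numerical root finder)
    return period * (w0 + np.sqrt(2 + w0 ** 2)) / (4 * np.pi)


def optimal_scales(signal_length, dt=1.0, dj=0.125, w0=6.0):
    # Smallest scale s0: equivalent Fourier period equals 2 * dt
    s0 = morlet_scale_from_period(2 * dt, w0)
    # Largest scale index J (Eq. 10), then s_j = s0 * 2^(j * dj) (Eq. 9)
    J = int((1 / dj) * np.log2(signal_length * dt / s0))
    return s0 * 2 ** (dj * np.arange(0, J + 1))


scales = optimal_scales(signal_length=512)
periods = morlet_fourier_period(scales)
print(len(scales), scales[0], periods[0])
```

Consecutive scales differ by a constant factor of `2 ** dj`, so a smaller `dj` gives a finer scale resolution at the cost of more filters in the bank.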