├── .gitignore
├── LICENSE
├── README.md
├── SignalProcessing
├── Q11333 - STFT - Beats in spectrogram of pure sines
│ └── main.py
├── Q29509 - DFT - Extract sine phase and amplitude
│ ├── estimators.py
│ └── main.py
├── Q76329 - CWT, Wavelets - Scale vs frequency
│ ├── im0.png
│ ├── im1.png
│ ├── im10.png
│ ├── im11.png
│ ├── im12.png
│ ├── im13.png
│ ├── im14.png
│ ├── im15.png
│ ├── im16.png
│ ├── im3.png
│ ├── im4.png
│ ├── im5.png
│ ├── im6.png
│ ├── im7.png
│ ├── im8.png
│ ├── im9.png
│ ├── main.py
│ └── wavgif.gif
├── Q76463 - FM - Frequency Modulating a Signal
│ ├── im0.png
│ ├── im1.png
│ ├── im2.png
│ └── main.py
├── Q76560 - Analyticity - Symmetries
│ ├── main.py
│ ├── morlets0.png
│ ├── morlets1.png
│ └── rands0.png
├── Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass
│ ├── WGN.png
│ ├── im2.png
│ ├── ims.gif
│ └── main.py
├── Q76644 - estimation - Estimate sine frequency under white noise
│ ├── _optimized.pyx
│ ├── benchmarks.py
│ ├── estimators.py
│ ├── main.py
│ ├── optimized.py
│ ├── setup.py
│ └── utils76644.py
├── Q76754 - Hilbert - Alleviate boundary effects
│ ├── data.npy
│ └── main.py
├── Q78512 - Wavelet Scattering explanation
│ ├── cwt.gif
│ ├── cwt.py
│ ├── cwt_phi.gif
│ ├── cwt_phi.py
│ ├── reconstruction.gif
│ ├── reconstruction.py
│ ├── scat_impulse.gif
│ ├── scat_impulse_anim.py
│ ├── scat_impulse_resc.gif
│ ├── scat_impulse_unscaled_anim.gif
│ ├── second_order.py
│ ├── sine_warp.png
│ ├── tshift_invar.py
│ ├── warp_cwt_anim.py
│ ├── warp_cwt_echirp.gif
│ ├── warp_cwt_sine.gif
│ ├── warp_scat_T.gif
│ ├── warp_scat_T_anim.py
│ ├── warp_scat_anim.py
│ ├── warp_scat_tau.gif
│ ├── warp_scat_tau_anim.py
│ ├── warp_scat_tau_echirp.gif
│ └── warp_scat_tau_pulse.gif
├── Q78644 - Joint Time-Frequency Scattering explanation
│ ├── data_3d_anim.py
│ ├── echirp_2d.py
│ ├── echirp_2d_anim.py
│ ├── echirp_3d_anim.py
│ ├── fdts.py
│ ├── filterbank.gif
│ ├── filterbank.py
│ ├── freq_tp.gif
│ ├── freq_tp.py
│ ├── jtfs3d_echirp_full.gif
│ ├── jtfs3d_echirp_spinned.gif
│ ├── jtfs3d_seiz.gif
│ ├── jtfs3d_trumpet.gif
│ ├── jtfs_echirp.gif
│ ├── jtfs_echirp_wavelets.gif
│ ├── jtfs_sine.gif
│ ├── jtfs_ts.gif
│ └── sine_2d_anim.py
├── Q80918 - STFT - Overlap, window length, uncertainty principle
│ └── main.py
├── Q81624 - STFT, optimization - Optimize window length for STFT
│ └── main.py
├── Q85745 - Energy, power, relation to sampling rate and duration
│ └── main.py
├── Q86181 - CWT Power and Energy
│ ├── cwt_power_example.m
│ ├── cwt_power_example.py
│ └── cwt_power_validation.py
├── Q86726 - STFT - Amplitude extraction
│ └── main.py
├── Q86937 - STFT - Equivalence of Windowed Fourier and Convolutions
│ └── main.py
├── Q87355 - audio, algorithms - Detecting abrupt changes
│ ├── T2_D_00004323.wav
│ ├── T2_D_00004324.wav
│ ├── T2_D_00004326.wav
│ ├── T2_D_00004328.wav
│ ├── T2_D_00004330.wav
│ ├── T2_D_00004331.wav
│ ├── T2_D_00004333.wav
│ ├── T2_D_00004335.wav
│ ├── T2_D_00004338.wav
│ ├── T2_D_00004339.wav
│ ├── T2_T_00043244.wav
│ ├── T2_T_00043245.wav
│ ├── T2_T_00043246.wav
│ ├── example.wav
│ ├── im00.png
│ ├── im01.png
│ ├── im02.png
│ ├── im03.png
│ ├── im04.png
│ ├── im05.png
│ ├── im06.png
│ ├── im07.png
│ ├── im08.png
│ ├── im09.png
│ ├── im10.png
│ ├── im11.png
│ ├── im12.png
│ ├── im13.png
│ ├── main.py
│ ├── preds.gif
│ ├── test_algo.py
│ └── utils87355.py
├── Q87706 - aliasing - How to measure aliasing
│ └── main.py
├── Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation
│ ├── covid.png
│ ├── covid_target.png
│ ├── covid_target2.png
│ ├── covid_template.png
│ ├── main.py
│ └── validation.py
├── Q87781 - cross-correlation, FFT - 2D Cross-Correlation using FFT in Python
│ ├── cc2d.py
│ └── main.py
├── Q87926 - DFT of a sine, closed form solution and insights
│ ├── main.py
│ └── solutions.py
└── Q88042 - sampling - Is downsampling LTI for bandlimited inputs
│ └── main.py
└── StackOverflow
└── Q76812752 - MATLAB dictionary comprehension
├── ifelse.m
├── py_dictcomp.m
└── test_py_dictcomp.m
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 OverLordGoldDragon
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StackExchange Answers
2 |
3 | [![DOI](https://zenodo.org/badge/387701666.svg)](https://zenodo.org/badge/latestdoi/387701666)
4 |
5 | "Watch" to be notified of new posts.
6 |
7 | Question titles may be paraphrased, but IDs are exact.
8 |
9 | ### Citing
10 |
11 | If using significant portions of code, please cite this repository. Short form (year = latest, or when the cited content was published):
12 |
13 | > John Muradeli, StackExchange Answers, 2023. GitHub repository, https://github.com/OverLordGoldDragon/StackExchangeAnswers/. DOI: 10.5281/zenodo.5115534
14 |
15 | BibTeX
16 |
17 | ```bibtex
18 | @article{OverLordGoldDragon2020StackExchangeAnswers,
19 | title={StackExchange Answers},
20 | author={John Muradeli},
21 | journal={GitHub. Note: https://github.com/OverLordGoldDragon/StackExchangeAnswers/},
22 | year={2023},
23 | doi={10.5281/zenodo.5115534},
24 | }
25 | ```
26 |
27 |
28 |
29 | ### Signal Processing / Feature Engineering
30 |
31 | Code-accompanied answers say "code Q...", found in respective folders, or "code included".
32 |
33 | ✯ = exceptional quality; 🡡 = great quality (brag!)
34 |
35 | #### Feature Engineering (applied) / Problem Solving
36 |
37 | 1. ✯ Identifying abrupt changes in audio - [Q&A](https://dsp.stackexchange.com/a/87512/50076), code Q87355
38 | 2. 🡡 Locating non-homogeneous areas in an image - [Q&A](https://dsp.stackexchange.com/a/75636/50076), code included
39 | 3. 🡡 Estimate sine frequency under white noise - [Q&A](https://dsp.stackexchange.com/a/88711/50076), code Q76644
40 | 4. Extract sine phase and amplitude, accurate and noise-robust method - [Q&A](https://dsp.stackexchange.com/a/88833/50076), code Q29509
41 | 5. Nonlinear filtering and inspection of noisy vibration signal - [Q&A](https://dsp.stackexchange.com/a/84095/50076), code included
42 |
43 | #### Intuition / "Deep" Explanations
44 |
45 | 1. ✯ Physical significance of negative frequencies - [Q&A](https://dsp.stackexchange.com/a/87994/50076)
46 | 2. ✯ Wavelet "center frequency" explanation, relation to CWT scales - [Q&A](https://dsp.stackexchange.com/a/76371/50076), code Q76329
47 | 3. 🡡 DFT coefficients meaning? - [Q&A](https://dsp.stackexchange.com/a/70395/50076)
48 | 4. 🡡 Why is `|x|^2` power? - [Q&A](https://dsp.stackexchange.com/a/86919/50076)
49 | 5. 🡡 Does zero-padding distort the spectrum? - [Q&A](https://dsp.stackexchange.com/a/70498/50076)
50 | 6. How is a 1D signal considered a "vector"? - [Q&A](https://dsp.stackexchange.com/a/87057/50076)
51 | 7. Are colored images 2D or 3D? - [Q&A](https://dsp.stackexchange.com/a/81831/50076)
52 | 8. Are real-world signals bandlimited? - [Q&A](https://dsp.stackexchange.com/a/75147/50076)
53 | 9. What can be meant by "bandlimited"? - [Q&A](https://dsp.stackexchange.com/a/70949/50076)
54 |
55 | #### Measures / Heuristics
56 |
57 | 1. ✯ How to validate a wavelet filterbank? - [Q&A](https://dsp.stackexchange.com/a/86069/50076), code to-be-released
58 | 2. 🡡 How to measure aliasing? - [Q&A](https://dsp.stackexchange.com/a/87706/50076), code Q87706
59 |
60 | #### Time-Frequency Analysis / Feature Engineering
61 |
62 | 1. ✯ Wavelet Scattering explanation - [Q&A](https://dsp.stackexchange.com/a/78513/50076), code Q78512
63 | 2. ✯ Joint Time-Frequency Scattering explanation - [Q&A](https://dsp.stackexchange.com/a/78623/50076), code Q78644
64 | 3. ✯ Synchrosqueezing explanation - [Q&A](https://dsp.stackexchange.com/a/71399/50076)
65 | 4. ✯ Wavelet Scattering properties & implementation - [Q&A](https://dsp.stackexchange.com/a/78515/50076), code Q78512
66 | 5. 🡡 Equivalence between "windowed Fourier transform" and STFT as convolutions/filtering - [Q&A](https://dsp.stackexchange.com/a/86938/50076), code Q86937
67 | 6. 🡡 What is "spin" for the 2D (separable) Morlet? - [Q&A](https://dsp.stackexchange.com/a/86013/50076), code to-be-released
68 | 7. 🡡 Amplitude extraction using STFT - [Q&A](https://dsp.stackexchange.com/a/86731/50076), code Q86726
69 | 8. 🡡 Why are there beats in STFT of sines? - [Q&A](https://dsp.stackexchange.com/a/87086/50076), code Q11333
70 | 9. 🡡 Power/Energy from CWT - [Q&A](https://dsp.stackexchange.com/a/86182/50076), code Q86181
71 | 10. 🡡 CWT vs DWT - [Q&A](https://dsp.stackexchange.com/a/76639/50076)
72 | 11. 🡡 One-integral inverse CWT - [Q&A](https://dsp.stackexchange.com/a/76239/50076), code included
73 | 12. 🡡 Joint Time-Frequency Scattering structure & implementation - [Q&A](https://dsp.stackexchange.com/a/78625/50076), code Q78644
74 | 13. What do CWT values correspond to? - [Q&A](https://dsp.stackexchange.com/a/86939/50076), code included
75 | 14. Role of window length and overlap in uncertainty principle - [Q&A](https://dsp.stackexchange.com/a/80920/50076), code Q80918
76 | 15. Inverting a scalogram - [Q&A](https://dsp.stackexchange.com/a/78531/50076), code Q78530
77 | 16. Importance of overlap in STFT - [Q&A](https://dsp.stackexchange.com/a/76507/50076)
78 | 17. Advantages of complex wavelets over real-valued - [Q&A](https://dsp.stackexchange.com/a/76093/50076)
79 | 18. Bandpass filtering for amplitude - [Q&A](https://dsp.stackexchange.com/a/75819/50076), code included
80 | 19. How is wavelet time & frequency resolution computed? - [Q&A](https://dsp.stackexchange.com/a/72043/50076)
81 | 20. How is wavelet center frequency computed? - [Q&A](https://dsp.stackexchange.com/a/70837/50076)
82 |
83 | #### Signal Theory
84 |
85 | 1. ✯ DFT of a sine, closed form solution and insights - [Q&A](https://dsp.stackexchange.com/a/88365/50076), code Q87926
86 | 2. 🡡 Subsampling in frequency domain, Effect of sampling rate on spectrum - [Q&A](https://dsp.stackexchange.com/a/87121/50076), code included
87 | 3. 🡡 Why is SNR periodic/sinusoidal for streamed windowing of a sine? - [Q&A](https://dsp.stackexchange.com/a/87933/50076)
88 | 4. 🡡 DFT modulus of a sine, closed form solution and insights - [Q&A](https://dsp.stackexchange.com/a/88400/50076), code Q87926
89 | 5. Relationship between energy, power, and sampling rate - [Q&A](https://dsp.stackexchange.com/a/86132/50076), code Q85745
90 | 6. Hilbert transform for amplitude extraction/demodulation of broadband waveforms - [Q&A](https://dsp.stackexchange.com/a/83257/50076)
91 | 7. Other end of Nyquist limit - [Q&A](https://dsp.stackexchange.com/a/84746/50076)
92 | 8. Instantaneous frequency of a signal - [Q&A](https://dsp.stackexchange.com/a/84144/50076)
93 | 9. DFT filter bank interpretation - [Q&A](https://dsp.stackexchange.com/a/83878/50076)
94 | 10. `ifft` first half of time domain ("DFT unpad property") - [Q&A](https://dsp.stackexchange.com/a/81994/50076), code included
95 | 11. Role of padding in convolution / Alleviating edge effects in Hilbert transform - [Q&A](https://dsp.stackexchange.com/a/76792/50076), code Q76754
96 | 12. How to frequency modulate? - [Q&A](https://dsp.stackexchange.com/a/76482/50076)
97 | 13. Symmetries of analyticity - [Q&A](https://dsp.stackexchange.com/q/76560/50076)
98 | 14. Effect of sampling rate and duration on discrete parameters of sine (spectrum) - [Q&A](https://dsp.stackexchange.com/a/88389/50076)
99 |
100 | #### Algorithms
101 |
102 | 1. 🡡 2D FFT cross-correlation in Python - [Q&A](https://dsp.stackexchange.com/a/87782/50076), code Q87781
103 | 2. Efficient 2D convolution/cross-correlation along only one axis (1D output) - [Q&A](https://dsp.stackexchange.com/a/87563/50076), code included
104 | 3. Optimizing window length in STFT via gradient descent - [Q&A](https://dsp.stackexchange.com/a/81627/50076), code Q81624
105 |
106 | #### Proofs / Mathy
107 |
108 | 1. 🡡 Amplitude extraction / demodulation criteria for Hilbert transform - [Q&A](https://dsp.stackexchange.com/a/83299/50076)
109 | 2. 🡡 Is downsampling LTI for bandlimited inputs? - [Q&A](https://dsp.stackexchange.com/a/88055/50076), code Q88042
110 | 3. 🡡 Sine DFT: N/4-symmetry, bin sinusoidal modulation vs time shifts - [Q&A](https://dsp.stackexchange.com/a/88421/50076), code Q87926
111 | 4. Inverse CWT derivation - [Q&A](https://dsp.stackexchange.com/a/71148/50076)
112 |
113 |
114 |
115 |
116 | ### License
117 |
118 | MIT-licensed.
119 |
--------------------------------------------------------------------------------
/SignalProcessing/Q11333 - STFT - Beats in spectrogram of pure sines/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/11333/50076
3 | # Some code copied from my answer to https://dsp.stackexchange.com/q/85745/50076
4 | import numpy as np
5 | from scipy.signal.windows import dpss
6 | from numpy.fft import fft, ifft, fftshift, ifftshift
7 | from ssqueezepy import stft
8 | from ssqueezepy.visuals import plot, plotscat
9 |
10 | #%%############################################################################
11 | # Helpers
12 | # -------
13 | def _pad_window(w, padded_len):
14 | pleft = (padded_len - len(w)) // 2
15 | pright = padded_len - pleft - len(w)
16 | return np.pad(w, [pleft, pright])
17 |
def get_adj_windows(window, n_fft, N):
    """Zero-pad `window` to two lengths and center the peak of each result.

    Parameters
    ----------
    window : 1D array
        Window to pad.
    n_fft : int
        Length of the second returned window.
    N : int
        Signal length; the first returned window has length `N + n_fft - 1`.

    Returns
    -------
    window_adj : np.ndarray
        `window` padded to `N + n_fft - 1` samples.
    window_adj_f : np.ndarray
        `window` padded to `n_fft` samples.

    Both outputs have their (first) maximum at index `len // 2`, i.e. at
    index 0 after `ifftshift`.
    """
    def _center(w):
        # shortcut for later examples to spare code; ensure ifftshift center
        # at idx 0. `np.roll` moves samples *forward* by `shift`, so bringing
        # the peak from index `p` to index 0 requires `shift = -p`.
        w = ifftshift(w)
        w = np.roll(w, -np.argmax(w))
        return fftshift(w)

    padded_len_conv = N + n_fft - 1
    window_adj = _center(_pad_window(window, padded_len_conv))
    window_adj_f = _center(_pad_window(window, n_fft))
    return window_adj, window_adj_f
34 |
35 | #%%############################################################################
36 | # Row-wise STFT implementation
37 | # ----------------------------
def cisoid(N, f):
    """Return `N` samples of the complex sinusoid `cos + 1j*sin` that completes
    `f` cycles over the unit interval `[0, 1)` (endpoint excluded).
    """
    t = np.linspace(0, 1, N, endpoint=False)
    phase = 2*np.pi * f * t
    real_part = np.cos(phase)
    imag_part = np.sin(phase)
    return real_part + imag_part * 1j
42 |
43 |
def stft_rowwise(x, window, n_fft):
    """STFT computed one frequency row at a time, as circular convolutions of
    `x` with a bank of windowed complex sinusoids.

    Parameters
    ----------
    x : 1D array
        Input signal.
    window : 1D array
        Window; must have exactly `n_fft` samples.
    n_fft : int
        Window length; must not exceed `len(x)`.

    Returns
    -------
    Sx : np.ndarray, shape `(n_fft//2 + 1, len(x))`
        One row per filter.
    fbank_f : np.ndarray, same shape as `Sx`
        Conjugated DFT of each zero-padded filter.
    """
    assert len(window) == n_fft and n_fft <= len(x)

    N = len(x)
    n_zeros = N - len(window)

    # pad such that `argmax(window)` for `Sx[:, 0]` is at `Sx[:, 0]`
    # (DFT-center the convolution kernel)
    # note, due to how scipy generates windows and standard stft handles
    # padding, this still won't yield zero-phase for majority of signals, but
    # it's both fixable and can be very close; ideally just pass in
    # `len(window)==len(x)`.
    if n_fft % 2 == 1:
        wpad_right = int(np.floor(n_zeros / 2))
    else:
        wpad_right = int(np.ceil(n_zeros / 2))
    wpad_left = n_zeros - wpad_right

    # filterbank: window modulated by cisoids of `0, 1, ..., n_fft//2` cycles
    cisoids = np.array([cisoid(n_fft, k) for k in range(n_fft//2 + 1)])
    fbank = ifftshift(window) * cisoids
    # only the time axis needs shifting: padding leaves the row count
    # unchanged, so a shift/unshift pair along rows would cancel out
    fbank = fftshift(fbank, axes=-1)
    fbank = np.pad(fbank, [[0, 0], [wpad_left, wpad_right]])
    fbank = ifftshift(fbank, axes=-1)
    fbank_f = fft(fbank, axis=-1).conj()

    # circular convolution via pointwise product of DFTs
    Sx = ifft(fbank_f * fft(x)[None], axis=-1)
    return Sx, fbank_f
73 |
74 | #%%############################################################################
75 | # Current answer -- messy code
76 | import matplotlib.pyplot as plt
77 | from scipy.signal.windows import gaussian
78 | from ssqueezepy import ssq_cwt, TestSignals, Wavelet
79 | from ssqueezepy.visuals import imshow
80 |
# Test signal: cosine completing 125 cycles over `N` samples
N = 256
n_fft = N
t = np.linspace(0, 1, N, 0)  # 4th positional arg of `linspace` is `endpoint`
x = np.cos(2*np.pi * 125 * t)
window = gaussian(N, 6, 0)  # positional args: length, std, sym
window_adj, window_adj_f = get_adj_windows(window, n_fft, N)

# only the `n_fft`-long window is used for the row-wise STFT
Sx, fbank_f = stft_rowwise(x, window_adj_f, n_fft)

# show every 4th filter's frequency response
plot(fbank_f[::4].T, abs=1, color='tab:blue', h=.7, w=1.1,
     vlines=(128, {'linewidth': 3, 'color': 'tab:red'}),
     title="STFT filterbank, frequency-domain | Gaussian", show=1)
93 | #%%
# For two filters of the bank, show: the filter against the signal spectrum,
# their product, the filtered signal, and the spectrum of its envelope.
for case in (0, 1):
    # copies, since `a` is normalized in-place below
    if case == 0:
        a = fbank_f[-8].copy()
    else:
        a = fbank_f[-1].copy()
    title = "Filter peak freq = {}".format(np.argmax(np.abs(a)))

    fig, axes = plt.subplots(4, 1, figsize=(8, 24))

    b = fft(x)
    # normalize both spectra to unit peak magnitude
    a /= max(abs(a))
    b /= max(abs(b))
    pkw = dict(ylims=(-.01, 1.03), abs=1, fig=fig)
    plotscat(b, abs=1, color='k', ax=axes[0])
    plot(a, **pkw, title=title, ax=axes[0])
    plotscat(a * b, **pkw, ax=axes[1], title="fft(filter) * fft(x)")

    # filtering done as a product in the frequency domain
    x_filt = ifft(a * b)
    plot(x_filt, complex=1, ax=axes[2], fig=fig,)
    plot(x_filt, abs=1, color='k', linestyle='--', ax=axes[2], fig=fig,
         title="x_filt = ifft(fft(filter) * fft(x))")

    plotscat(fft(abs(x_filt)), abs=1, ax=axes[3], fig=fig,
             title="fft(envelope) | envelope = abs(x_filt)")
    fig.subplots_adjust(hspace=.13)
119 |
120 | #%%
# Modulus, real and imaginary parts of the STFT; rows are reversed for
# display, with tick labels reversed to match
ikw = dict(w=1.2, h=.35, yticks=np.arange(len(Sx))[::-1])
imshow(Sx[::-1], abs=1, **ikw, title="|STFT|")
imshow(Sx.real[::-1], **ikw, title="STFT.real")
imshow(Sx.imag[::-1], **ikw, title="STFT.imag")
125 |
126 | #%%
# Compare SSQ_CWT / CWT with the wavelet as generated vs. with a modified
# (non-analytic) wavelet
ts = TestSignals(2048)
x = ts.hchirp(fmin=.11)[0]
t = np.linspace(0, 1, len(x), 0)
x += x[::-1]  # add the time-reversed signal

for analytic in (0, 1):
    wavelet = Wavelet(('gmw', {'gamma': 1, 'beta': 1}))
    Tx, Wx, _, scales, *_ = ssq_cwt(x, wavelet, scales='log')

    if not analytic:
        # build the wavelets at 4x length, keep every other time-domain
        # sample, and overwrite the wavelet's stored frequency-domain filters
        psiup_t = ifft(wavelet(N=len(x)*4, scale=scales), axis=-1)
        wavelet._Psih = fft(psiup_t[:, ::2])
        Tx, Wx, *_ = ssq_cwt(x, wavelet, scales=scales)

    title = "analytic" if analytic else "non-analytic"
    for i, g in enumerate([Tx, Wx]):
        # NOTE(review): this first assignment is dead; it is overwritten on
        # the next line
        scl = .5 if i == 0 else .5
        scl = 240 if i == 0 else 3
        # presumably `norm` sets the color limits; upper limit scales with
        # the mean modulus -- TODO confirm against `imshow`
        imshow(g[:], abs=1, w=.6, h=.45, norm=(0, np.abs(g).mean()*scl),
               yticks=0,
               title=title + (" |SSQ_CWT|" if i == 0 else " |CWT|"))
148 |
149 | #%%
# Plot one wavelet in frequency and time domains for two `beta` values; in
# the second case the upper half of its spectrum is also zeroed
for case in (0, 1):
    beta = 10 if case == 1 else 60
    wavelet = Wavelet(('gmw', {'gamma': 3, 'beta': beta}))
    # presumably run only to populate `wavelet._Psih`, read right below
    _ = ssq_cwt(np.arange(8192), wavelet)

    pf = wavelet._Psih[20].copy()
    # move the spectral peak to the middle index
    pf = np.roll(pf, -np.argmax(pf) + len(pf)//2)
    if case == 1:
        # zero everything above the peak and halve the peak bin itself
        pf[len(pf)//2 + 1:] = 0
        pf[len(pf)//2] /= 2
    pt = ifftshift(ifft(pf))

    fig, axes = plt.subplots(3, 1, figsize=(8, 18))
    plot(pf, abs=1, ticks=0, fig=fig, ax=axes[0], title="freq-domain")
    plot(pt, complex=1, fig=fig, ax=axes[1])
    plot(pt, abs=1, linestyle='--', color='k', ticks=0,
         fig=fig, ax=axes[1], title="time-domain")

    # zoom: `zm` samples on each side of the time-domain center
    ctr = len(pt)//2
    zm = 40
    slc = pt[ctr-zm:ctr+zm+1]
    plot(slc, complex=1, fig=fig, ax=axes[2])
    plot(slc, abs=1, linestyle='--', color='k', show=0, ticks=0,
         fig=fig, ax=axes[2], title="time-domain, zoomed")
    fig.subplots_adjust(hspace=.1)
    plt.show()
176 |
--------------------------------------------------------------------------------
/SignalProcessing/Q29509 - DFT - Extract sine phase and amplitude/estimators.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/29509/50076
3 | import numpy as np
4 | from numpy.fft import rfft
5 |
6 |
def est_amp_phase_cedron_2bin(x):
    """Estimate amplitude and phase of a real sinusoid from two DFT bins.

    "A Two Bin Solution", Cedron Dawg, Eqs 25 & 26
    https://www.dsprelated.com/showarticle/1284.php

    The frequency is estimated first (two-bin formula if `len(x) == 3`,
    three-bin formula otherwise). The bins just below and at the peak are then
    fit in the least-squares sense with two basis vectors, `A` and `B`, built
    from that frequency.

    Parameters
    ----------
    x : 1D real array

    Returns
    -------
    amplitude, phase : float
        `phase` is in radians, as returned by `np.arctan2`.
    """
    # not performance-optimized

    # get DFT & compute freq
    N = len(x)
    xf = rfft(x)

    if N == 3:
        # only bins 0 and 1 exist
        kmax = 1
        bins = xf
        f_N = _est_f_cedron_2bin(bins, kmax, N)
    else:
        # see other referenced answer
        kmax = np.argmax(abs(xf[1:-1])) + 1
        bins = xf[kmax-1:kmax+2]
        f_N = _est_f_cedron_3bin(bins, kmax, N)

    # run two bin calculations, on bins `kmax - 1` and `kmax`
    alpha = f_N * (2*np.pi)
    alphaN = alpha * N
    cosa, sina = np.cos(alpha), np.sin(alpha)

    betas = 2*np.pi/N * np.array([kmax - 1, kmax])
    cosbj, cosbk = np.cos(betas)
    sinbj, sinbk = np.sin(betas)

    fj = 1 / (2 * N * (cosa - cosbj))
    fk = 1 / (2 * N * (cosa - cosbk))

    def _basis(U, V):
        # (real, imag) entries for the lower bin, then for the peak bin
        return np.array([
            fj * (U * cosbj - V),
            fj * (U * sinbj),
            fk * (U * cosbk - V),
            fk * (U * sinbk),
        ])

    A = _basis(np.cos(alphaN) - 1, np.cos(alphaN - alpha) - cosa)
    B = _basis(np.sin(alphaN), np.sin(alphaN - alpha) + sina)

    # "unfurl" the two complex bins into a real 4-vector
    z = np.array([bins[0].real, bins[0].imag,
                  bins[1].real, bins[1].imag]) / N

    # solve the 2x2 normal equations for the coefficients of `A` and `B`
    AA, BB, AB = A@A, B@B, A@B
    AZ, BZ = A@z, B@z
    det = AA*BB - AB*AB
    ca = (BB*AZ - AB*BZ) / det
    cb = (AA*BZ - AB*AZ) / det

    amplitude = np.sqrt(ca**2 + cb**2)
    phase = np.arctan2(-cb, ca)
    return amplitude, phase
76 |
77 |
78 | ### Frequency case ###########################################################
79 | def _get_cedron_basics(Z, k, N):
80 | re = Z.real
81 | im = Z.imag
82 |
83 | rt2 = np.sqrt(2)
84 | betas = np.array([k - 1, k, k + 1]) * 2*np.pi/N
85 | cosb = np.cos(betas)
86 | sinb = np.sin(betas)
87 | return re, im, cosb, sinb, rt2
88 |
89 | def _cedron_bin_finish(A, B, C):
90 | P = C / np.linalg.norm(C)
91 | D = A + B
92 | K = D - (D @ P)*P
93 |
94 | num = K @ B # dot product
95 | den = K @ A
96 | ratio = max(min(num/den, 1), -1) # handles float issues
97 | f = np.arccos(ratio) / (2*np.pi)
98 | return f
99 |
100 |
def est_f_cedron_3bin_complex(Z, k, N):
    """Frequency of a pure complex tone from three adjacent DFT bins.

    "Three Bin Exact Frequency Formulas for a Pure Complex Tone in a DFT",
    Cedron Dawg, Eq 19 (via Eqs 35, 31, 20, 16)
    https://www.dsprelated.com/showarticle/1043.php

    Parameters
    ----------
    Z : complex array, length 3
        DFT bins `k-1, k, k+1`.
    k : int
        Index of the center bin; indices above `N//2` are treated as
        negative frequencies.
    N : int
        DFT length.

    Returns
    -------
    f : float
        `k/N` plus an offset from the phase of a ratio of dot products.
    """
    twiddle = np.exp(-1j*1*2*np.pi/N)
    DZ = Z * np.array([1/twiddle, 1, twiddle])
    G = np.conj(Z + DZ)
    K = G - np.mean(G)

    ratio = (K @ Z) / (K @ DZ)
    k_signed = k - N if k > N//2 else k
    offset = np.arctan2(ratio.imag, ratio.real) / (2*np.pi)
    return k_signed/N + offset
119 |
120 |
def est_f_cedron_3bin(x):
    """Estimate the frequency of real sinusoid `x` via the three-bin formula,
    using the largest-magnitude `rfft` bin and its two neighbors.
    """
    spectrum = rfft(x)
    # first and last bins are excluded so that both neighbors always exist
    peak = 1 + np.argmax(abs(spectrum[1:-1]))
    return _est_f_cedron_3bin(spectrum[peak-1:peak+2], peak, len(x))
129 |
130 |
def _est_f_cedron_3bin(Z, k, N):
    """Three-bin frequency formula for a pure real tone.

    "Improved Three Bin Exact Frequency Formula for a Pure Real Tone in a DFT",
    Cedron Dawg, Eq 9
    https://www.dsprelated.com/showarticle/1108.php

    `Z` holds DFT bins `k-1, k, k+1` of a length-`N` DFT.
    """
    re, im, cosb, sinb, rt2 = _get_cedron_basics(Z, k, N)

    def _stack(real_part, imag_part):
        # center-minus-neighbor differences of the real parts (scaled by
        # 1/sqrt(2)), followed by the three imaginary-part entries
        return np.array([(real_part[1] - real_part[0])/rt2,
                         (real_part[1] - real_part[2])/rt2,
                         imag_part[0],
                         imag_part[1],
                         imag_part[2]])

    A = _stack(re, im)
    B = _stack(cosb*re[:3], cosb*im[:3])
    C = _stack(cosb, sinb)
    return _cedron_bin_finish(A, B, C)
157 |
158 |
def est_f_cedron_2bin(x):
    """Estimate the frequency of real sinusoid `x` via the two-bin formula.

    The largest-magnitude `rfft` bin at index >= 1 and the bin just below it
    are passed to `_est_f_cedron_2bin`.

    Parameters
    ----------
    x : 1D real array, at least 2 samples

    Returns
    -------
    f : float
        Frequency estimate, as returned by `_est_f_cedron_2bin`.

    Raises
    ------
    ValueError
        If `x` yields fewer than two `rfft` bins.
    """
    N = len(x)
    xf = rfft(x)
    if len(xf) < 2:
        raise ValueError("need at least two DFT bins, got len(x)=%d" % N)

    # The peak bin `kmax` is paired with bin `kmax - 1`, so bin 0 must not be
    # selected: `kmax == 0` would make `xf[kmax-1:kmax+1]` the empty slice
    # `xf[-1:1]`. For `N == 3` this always selects `kmax = 1`, `Z = xf`.
    # (not necessarily optimal scheme)
    kmax = np.argmax(np.abs(xf[1:])) + 1
    Z = xf[kmax-1:kmax+1]

    f = _est_f_cedron_2bin(Z, kmax, N)
    return f
173 |
174 |
def _est_f_cedron_2bin(Z, k, N):
    """Two-bin frequency formula for a pure real tone.

    "A Two Bin Solution", Cedron Dawg, Eq 14
    https://www.dsprelated.com/showarticle/1284.php

    `Z[0]` and `Z[1]` are DFT bins `k-1` and `k` of a length-`N` DFT.
    """
    re, im, cosb, sinb, rt2 = _get_cedron_basics(Z, k, N)

    def _triple(real_part, imag_part):
        # difference of the real parts (scaled by 1/sqrt(2)), then the
        # imaginary-part entries of bin `k` and bin `k-1`
        return np.array([(real_part[1] - real_part[0])/rt2,
                         imag_part[1],
                         imag_part[0]])

    A = _triple(re, im)
    B = _triple(cosb[:2]*re[:2], cosb[:2]*im[:2])
    C = _triple(cosb, sinb)
    return _cedron_bin_finish(A, B, C)
194 |
--------------------------------------------------------------------------------
/SignalProcessing/Q29509 - DFT - Extract sine phase and amplitude/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/29509/50076
3 | import warnings
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 |
# ensure the files can be found
import sys
from pathlib import Path
# directory containing this script; `estimators` is imported from it below
_dir = Path(__file__).parent
assert _dir.is_file() or _dir.is_dir(), str(_dir)
# case-insensitive comparison, to avoid inserting a duplicate entry
if not any(str(_dir).lower() == p.lower() for p in sys.path):
    sys.path.insert(0, str(_dir))
14 |
15 | from estimators import est_amp_phase_cedron_2bin, est_f_cedron_2bin
16 |
17 | #%% Helpers ##################################################################
def cisoid(N, f, phi=0):
    """Return `N` samples of `cos + 1j*sin` of the phase
    `2*pi*f*n/N + phi`, `n = 0, ..., N-1`.
    """
    arg = 2*np.pi*f*np.arange(N)/N + phi
    real_part = np.cos(arg)
    imag_part = np.sin(arg)
    return real_part + imag_part*1j
21 |
def est_and_append_errs(x, f_N, A, phi, errs_A_alt, errs_phi,
                        errs_A=None, errs_f=None):
    """Estimate sine parameters from `x` and append squared errors, computed
    against the true `f_N`, `A`, `phi`, to the given lists (in-place).

    `errs_A_alt` (normalized amplitude error) and `errs_phi` are always
    appended to. `errs_A` and `errs_f` are optional; the frequency is only
    estimated if `errs_f` is given.
    """
    want_f = errs_f is not None
    f_est = est_f_cedron_2bin(x) if want_f else None

    A_est, phi_est = est_amp_phase_cedron_2bin(x)
    err_A, err_A_alt, err_phi = get_errs_A_phi(A_est, phi_est, A, phi)

    if errs_A is not None:
        errs_A.append(err_A)
    if want_f:
        errs_f.append((f_est - f_N)**2)
    # normalized
    errs_A_alt.append(err_A_alt)
    errs_phi.append(err_phi)
39 |
40 |
def get_errs_A_phi(A_est, phi_est, A, phi):
    """Squared errors of amplitude and phase estimates.

    Returns
    -------
    err_A : `(A_est - A)**2`
    err_A_alt : `(1 - A_est/A)**2`, i.e. error relative to `A`
    err_phi : squared phase difference; magnitudes are compared instead if
        `phi_est` is within .001 of `+/-pi`
    """
    # note inherent ambiguity; discard sign in this case
    near_pi = abs(abs(phi_est) - np.pi) < .001
    if near_pi:
        phi_diff = abs(phi_est) - abs(phi)
    else:
        phi_diff = phi_est - phi
    return (A_est - A)**2, (1 - A_est/A)**2, phi_diff**2
51 |
52 |
53 | #%% Manual testing ###########################################################
# Single noiseless example; each printed ratio is estimate / true value
N = 3
f = 0.0035035*N  # cycles over the `N` samples
phi = -1
A = 1

x = A*cisoid(N, f, phi).real
f_N = f/N  # cycles per sample

f_est = est_f_cedron_2bin(x)
A_est, phi_est = est_amp_phase_cedron_2bin(x)

print(f_est / f_N)
print(A_est / A)
print(phi_est / phi)
68 |
69 | #%% Test full (noiseless) ####################################################
# For each frequency, sweep amplitudes and phases and record the worst-case
# (maximum) squared error of each estimated parameter
N = 3

errs_f_fmax, errs_A_fmax, errs_phi_fmax = [], [], []
errs_A_alt_fmax = []
# 0.5 takes more coding but can be handled

# make freqs
f_N_all = []
for f_N in np.linspace(0, 0.5, 203, endpoint=False):
    # skip near-integer frequency per numeric instability
    # (robust version not implemented)
    if (f_N * N) % 1 < .01:
        continue
    f_N_all.append(f_N)
f_N_all = np.array(f_N_all)

# make amplitudes
A_all = np.logspace(np.log10(1e-3), np.log10(1000), 100)
# make phases
phi_all = np.linspace(-np.pi, np.pi, 100)

for f_N in f_N_all:
    # per-frequency error lists, reduced to their maxima below
    errs_f, errs_A, errs_phi = [], [], []
    errs_A_alt = []
    f = f_N * N

    for A in A_all:
        for phi in phi_all:
            x = A*cisoid(N, f, phi).real
            est_and_append_errs(x, f_N, A, phi, errs_A_alt, errs_phi,
                                errs_A, errs_f)

    errs_f_fmax.append(np.max(errs_f))
    errs_A_fmax.append(np.max(errs_A))
    errs_phi_fmax.append(np.max(errs_phi))
    errs_A_alt_fmax.append(np.max(errs_A_alt))
106 |
107 | #%% Visualize ################################################################
# Worst-case squared errors (log10) vs frequency; the amplitude curve uses
# the normalized error, `errs_A_alt_fmax`
fig, ax = plt.subplots(layout='constrained', figsize=(12, 8))

ax.plot(f_N_all, np.log10(errs_f_fmax))
ax.plot(f_N_all, np.log10(errs_A_alt_fmax))
ax.plot(f_N_all, np.log10(errs_phi_fmax))

title = ("Max errors: frequency, amplitude, phase | N={}\n"
         "n_freqs, n_amps, n_phases = {}, {}, {}"
         ).format(N, len(f_N_all), len(A_all), len(phi_all))

ax.set_title(title, weight='bold', fontsize=24, loc='left')
ax.set_ylabel("Squared Error (log10)", fontsize=22)
ax.set_xlabel("f/N", fontsize=22)
ax.legend(["f", "A", "phi"], fontsize=22)

plt.show()
124 |
125 | #%% Test full (noisy) ########################################################
# For each SNR, average the amplitude and phase squared errors over
# frequencies, amplitudes, and `n_trials` random phases / noise realizations
np.random.seed(0)
N = 300
n_trials = 50

snrs = np.linspace(-10, 50, 25)

# make freqs
f_N_all = []
for f_N in np.linspace(1/N, 0.5-1/N, 42):
    # skip near-integer frequency per numeric instability
    # (robust version not implemented)
    if (f_N * N) % 1 < .01:
        continue
    f_N_all.append(f_N)
f_N_all = np.array(f_N_all)

# make amplitudes
A_all = np.logspace(np.log10(1e-3), np.log10(1000), 10)

errs_A_alt_all, errs_phi_all = [], []
for snr in snrs:
    # per-SNR error lists, reduced to their means below
    errs_A_alt, errs_phi = [], []
    for f_N in f_N_all:
        f = f_N * N
        for A in A_all:
            for _ in range(n_trials):
                phi = np.random.uniform(-np.pi, np.pi)
                xo = A * cisoid(N, f, phi).real
                # `snr` is in dB, relative to the variance of the clean signal
                noise_var = xo.var() / 10**(snr/10)
                noise = np.random.randn(N) * np.sqrt(noise_var)
                x = xo + noise

                # warnings are recorded into `ws` rather than shown
                with warnings.catch_warnings(record=True) as ws:
                    est_and_append_errs(x, f_N, A, phi, errs_A_alt, errs_phi)
                # if len(ws) > 0:
                #     print(snr, f_N, f_N*N, A, phi, sep='\n')
                #     1/0
    errs_A_alt_all.append(np.mean(errs_A_alt))
    errs_phi_all.append(np.mean(errs_phi))
    # progress indicator, one dot per SNR
    print(end='.', flush=True)
166 |
167 | #%% Visualize ################################################################
168 | fig, ax = plt.subplots(layout='constrained', figsize=(12, 8))
169 |
170 | ax.plot(snrs, np.log10(errs_A_alt_all))  # plot order must match the legend order below
171 | ax.plot(snrs, np.log10(errs_phi_all))
172 |
173 | title = ("N={}, f/N=lin sweep, phase=randu(-pi, pi)\n"
174 | "n_freqs={}, n_amps={}, n_trials_per_f_and_A={}"
175 | ).format(N, len(f_N_all), len(A_all), n_trials)
176 |
177 | ax.set_title(title, weight='bold', fontsize=24, loc='left')
178 | ax.set_ylabel("MSE (log10)", fontsize=22)
179 | ax.set_xlabel("SNR [dB]", fontsize=22)
180 | ax.legend(["A", "phi"], fontsize=22)
181 | ax.set_ylim(-10, 1)  # fixed log10 range; values outside it are clipped from view
182 |
183 | plt.show()
184 |
#%% OP's case ################################################################
# Monte-Carlo estimate of the amplitude/phase MSE for one fixed sinusoid
# (A, f, N below) with a random phase per trial and additive Gaussian noise.
np.random.seed(0)
A = 1
f = 30.1
N = 300
t = np.linspace(0, 1, N, endpoint=False)
n_trials = 10000
# standard deviation of the additive noise; defined once so that the noise
# generation and the reported SNR cannot drift apart
noise_std = 0.05

errs_A_alt, errs_phi = [], []
for _ in range(n_trials):
    phi = np.random.uniform(-np.pi, np.pi)
    xo = A*np.cos(2*np.pi*f*t + phi)
    noise = np.random.normal(0, noise_std, len(t))
    x = xo + noise

    A_est, phi_est = est_amp_phase_cedron_2bin(x)

    _, err_A_alt, err_phi = get_errs_A_phi(A_est, phi_est, A, phi)
    errs_A_alt.append(err_A_alt)
    errs_phi.append(err_phi)

# nominal power of A*cos(.) is A**2/2 (`A/2` is only correct for A == 1);
# noise power is noise_std**2
SNR = 10*np.log10((A**2/2) / noise_std**2)
print("MSE: A={:.3g}, phi={:.3g} -- SNR={:.3g}, N={}, n_trials={}".format(
    np.mean(errs_A_alt), np.mean(errs_phi), SNR, N, n_trials))
209 |
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im0.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im1.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im10.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im11.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im12.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im13.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im14.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im15.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im16.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im3.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im4.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im5.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im6.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im7.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im8.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im9.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/76329/50076 ################################
3 | import numpy as np
4 | from numpy.fft import ifft, ifftshift
5 | from ssqueezepy.visuals import plot, scat, imshow
6 | from ssqueezepy.utils import make_scales, cwt_scalebounds
7 | from ssqueezepy import ssq_cwt, cwt, Wavelet
8 | from kymatio.scattering1d.filter_bank import morlet_1d
9 |
10 | #%%# Helper methods ##########################################################
def viz0(xf, xref=None, center=True):
    """Plot a spectrum and the real part of its inverse DFT.

    Parameters
    ----------
    xf : array
        Frequency-domain sequence; its first 32 entries are scatter-plotted.
    xref : array, optional
        Reference curve drawn (dashed, in `grey`) behind the time-domain
        signal. A rescaled copy is drawn so that its peak magnitude matches
        that of the signal; the caller's array is left untouched.
    center : bool
        If truthy, apply `ifftshift` to the time-domain signal before
        plotting.
    """
    x = ifft(xf)
    if center:
        x = ifftshift(x)
    scat(xf[:32], show=1)
    if xref is not None:
        # out-of-place multiply: avoids mutating the argument and also works
        # for integer-typed references, where `*=` with a float would fail
        xref = xref * (np.abs(x).max() / np.abs(xref).max())
        plot(xref, color=grey, linestyle='--')
    plot(x.real, show=1)
20 |
21 | def morlet(N, xi, sigma=.1/2, viz=1, trim=True, nonhalved=False):  # returns `(x, xf)`: time-domain wavelet and its (possibly trimmed) spectrum
22 | xf = morlet_1d(N, xi=xi, sigma=sigma)  # spectrum from kymatio's `morlet_1d`; inverse-transformed below
23 |
24 | trim_idx = int(xi * N) + 1  # first bin after bin `int(xi * N)`
25 | if trim:
26 | xf[trim_idx:] = 0  # zero every bin after bin `int(xi * N)`
27 | if not nonhalved:  # NOTE(review): indentation is lost in this listing -- confirm whether this test is nested under `if trim`
28 | xf[trim_idx - 1] /= 2  # halve bin `int(xi * N)`
29 | x = ifftshift(ifft(xf))  # time-domain wavelet, shifted with `ifftshift`
30 |
31 | if viz:
32 | scat(xf, show=1)
33 | plot(x, complex=1)
34 | plot(x, abs=1, color='k', linestyle='--', show=1)  # magnitude drawn as a dashed black curve
35 | return x, xf
36 |
def ref_sine(xi, x=None, N=None, endpoint=False, zoom=None):
    """Build a reference cosine with `xi * N` cycles over `N` samples.

    Parameters
    ----------
    xi : float
        Normalized frequency; the cosine completes `xi * N` cycles over the
        unit interval.
    x : array, optional
        Signal to compare against. When given, the reference is rescaled by
        `x.real.max() / xref.max()`, and `N` defaults to `len(x)`.
    N : int, optional
        Number of samples; required when `x` is None.
    endpoint : bool
        Forwarded to `np.linspace`.
    zoom : int or tuple of int, optional
        Half-widths (left, right) of a window around `N // 2` to plot.
        An int is used for both sides. No plotting is done when None.

    Returns
    -------
    xref : array
        The full-length (unzoomed) reference cosine.
    """
    if N is None:
        N = len(x)
    have_signal = x is not None

    time = np.linspace(0, 1, N, endpoint=endpoint)
    xref = np.cos(2*np.pi * (xi * N) * time)
    if have_signal:
        xref *= x.real.max() / xref.max()

    if zoom is None:
        return xref

    # normalize `zoom` to a (left, right) pair of half-widths
    span = zoom if isinstance(zoom, tuple) else (zoom, zoom)
    mid = N // 2
    a, b = mid - span[0], mid + span[1] + 1
    idxs = np.arange(a, b)

    plot(idxs, xref[a:b], color=grey, linestyle='--')
    if have_signal:
        plot(idxs, x.real[a:b], show=1, title="zoomed, real part")
    return xref
54 |
55 | grey = (.1, .1, .1, .4)  # 4-tuple passed as `color` for the reference curves
56 |
57 | #%%# Simple sinusoid #########################################################
58 | N = 128
59 | t = np.linspace(0, 1, N, 0)  # 4th positional argument is `endpoint` (0 -> excluded)
60 | xf = np.zeros(N)
61 |
62 | xf[16] = 1  # single nonzero DFT bin
63 | xref = ifft(xf).real  # reused as the reference curve in later cells
64 | viz0(xf, center=0)
65 |
66 | #%%# Add lateral peaks
67 | xf = np.zeros(N)
68 | xf[15] = .5  # the two neighbors of bin 16
69 | xf[17] = .5
70 | viz0(xf, center=0)
71 | #%%# Both at once
72 | xf[16] = 1
73 | viz0(xf, xref)
74 | #%%# Positive + negative example
75 | xf = np.zeros(N)
76 | xf[15] = .5
77 | xf[17] = -.5
78 | viz0(xf, center=0)
79 |
80 | #%%# GIF #####################################################################
81 | xf_full = morlet_1d(N, xi=16/128, sigma=.1/8)
82 | xf = np.zeros(N)
83 | xf[16] = xf_full[16]  # start from bin 16 only
84 | viz0(xf)
85 |
86 | for idx in (17, 18, 19, 20, 21):  # each pass adds one more bin on either side of bin 16
87 | xf[idx] = xf_full[idx]
88 | xf[16 + (16 - idx)] = xf_full[16 + (16 - idx)]  # mirror of `idx` about bin 16
89 | viz0(xf, xref)
90 |
91 | #%%# CWT near Nyquist ########################################################
92 | N = 129
93 | t = np.linspace(0, 1, N, 1)  # 4th positional argument is `endpoint` (1 -> included), so the step is 1/128
94 | x = np.cos(2*np.pi * 64 * t)  # equals cos(pi * n): alternates +1/-1 sample to sample
95 |
96 | wavelet = Wavelet('morlet')
97 | min_scale, max_scale = cwt_scalebounds(wavelet, N=N, preset='minimal')
98 | # adjust such that highest freq wavelet's peak is at Nyquist
99 | scales0 = make_scales(N, min_scale, max_scale, wavelet=wavelet) * 1.12
100 | scales = scales0  # `scales0` is read again by the "For spectrum" cell further down
101 | Wx, _ = cwt(x, wavelet, scales=scales)
102 |
103 | imshow(Wx, abs=1, title="abs(CWT)", xlabel="time", ylabel="scale",
104 | yticks=scales)
105 | plot(Wx[:, 0], abs=1, title="abs(CWT) at a time slice (same for all slices)",
106 | xlabel="scale", show=1, xticks=scales)
107 | imshow(Wx.real)
108 |
109 | #%%# Morlet near nyquist #####################################################
110 | N = 128
111 | xf = morlet(N, xi=.5, viz=1)  # NOTE(review): `morlet` returns `(x, xf)`, so `xf` holds the whole tuple here
112 |
113 | #%%# Trimmed peak but higher sampling rate
114 | N = 512
115 | xi = .5 / 4
116 | x, xf = morlet(N, xi=xi, sigma=.1/4, viz=1, nonhalved=True)
117 | x, xf = morlet(N, xi=xi, sigma=.1/4, viz=1)  # this second (halved) result is what later cells use
118 |
119 | #%%# Zoomed sinusoid at center frequency
120 | xref = ref_sine(xi, x=x, zoom=20)
121 | xref = ref_sine(xi, x=x, zoom=(80, -40))  # window [ctr - 80, ctr - 40], i.e. entirely left of center
122 |
123 | #%%# SSQ_CWT of wavelet ######################################################
124 | # extreme time-localization
125 | wavelet1 = Wavelet(('gmw', {'beta': 1, 'gamma': 1}))
126 |
127 | Tx0, Wx0, *_ = ssq_cwt(x.real, wavelet1, scales='log')  # `Tx0`, `Wx0` are read again in the last cells
128 | mx = np.abs(Tx0).max() * .6  # upper color limit at 60% of the peak magnitude
129 | # zoomed
130 | imshow(Tx0, abs=1, title="abs(SSQ_CWT) of wavelet.real", norm=(0, mx))
131 | imshow(Tx0[20:160], abs=1, norm=(0, mx))
132 |
133 | #%%# CWT
134 | # to take perfect CWT of wavelet, tweak slightly to account for padding
135 | xref = np.cos(2*np.pi * (N//8) * np.linspace(0, 1, N + 1, 1))  # N + 1 samples with the endpoint included
136 | Tx1, Wx1, *_ = ssq_cwt(xref, wavelet, scales='log')
137 | imshow(Tx1, abs=1, title="abs(SSQ_CWT) of sinusoid at peak center freq")
138 |
139 | #%%# Peak center frequency ###################################################
140 | N = 512
141 | xi = .5 / 4
142 | x, xf = morlet(N, xi=xi, sigma=.1/4, viz=0, nonhalved=True)
143 | scat(xf)
144 | f = int(N * xi)  # bin index corresponding to `xi`
145 | scat(np.array([f]), xf[f], color='red', s=30, show=1)  # mark that bin in red
146 |
147 | #%%
148 | xref = ref_sine(xi, x=x)
149 | plot(xref, color=grey)
150 | plot(x.real, show=1, title="peak center frequency")
151 |
152 | #%%
153 | t = np.linspace(0, 1, N, 1)
154 | xref = np.cos(2*np.pi * 64 * np.linspace(0, 1, 513, 1))
155 |
156 | wavelet = Wavelet('morlet')
157 | min_scale, max_scale = cwt_scalebounds(wavelet, N=N, preset='minimal')
158 | scales = make_scales(N, min_scale, max_scale, wavelet=wavelet) * 1.12
159 | Wx, scales = cwt(xref, wavelet, scales=scales)  # rebinds `scales` to the second value returned by `cwt`
160 |
161 | #%%# Energy center frequency #################################################
162 | axf = np.abs(xf)
163 | axfs = axf**2
164 | fmean_l1 = np.sum(np.arange(len(x)) * axf / axf.sum())  # magnitude-weighted mean bin index; not used again in this file
165 | fmean = np.sum(np.arange(len(x)) * axfs / axfs.sum())  # energy-weighted mean bin index
166 | xref = np.cos(2*np.pi * fmean * np.linspace(0, 1, N, 0))
167 | xref *= x.real.max() / xref.max()  # match the peak of the wavelet's real part
168 |
169 | plot(xref, color=grey)
170 | plot(x.real, show=1, title="energy center frequency")
171 |
172 | #%% For spectrum use one near Nyquist
173 | xref = ref_sine(xi=.5, N=129, endpoint=1)
174 | Wx, _ = cwt(xref, 'morlet', scales=scales0)  # `scales0` comes from the "CWT near Nyquist" cell
175 | imshow(Wx, abs=1)
176 |
177 | slc = np.abs(Wx[:, 0])  # magnitude across scales at the first time sample
178 | fmean_nyq = np.sum(np.arange(len(slc)) * (slc**2) / (slc**2).sum())  # energy-weighted mean index along the scale axis
179 | scat(slc, vlines=(fmean_nyq, {'color': 'tab:red'}), show=1)
180 | imshow(Wx, abs=1)
181 |
182 | #%%# Central instantaneous center frequency ##################################
183 | # determine central instantaneous frequency from CWT since SSQ
184 | # reassigns nonlinearly
185 | N = len(Tx0[0])
186 | finst_idx = np.argmax(np.abs(Wx0[:, 256]))  # row of largest magnitude at column 256 (presumably the center sample -- TODO confirm)
187 | psihs = wavelet1.Psih()
188 | finst = np.argmax(np.abs(psihs[finst_idx])) // 2 # x2 pad doubles index
189 |
190 | #%%
191 | xref = np.cos(2*np.pi * finst * np.linspace(0, 1, N, 0))
192 | xref *= x.real.max() / xref.max()
193 | ctr = N // 2
194 | d = 9
195 | a, b = ctr - d, ctr + d + 1
196 | xref = xref[a:b] * N / (a - b)  # NOTE(review): `a - b` is negative (-(2*d + 1)), so this also flips the sign -- confirm intent
197 | _t = np.arange(a, b)
198 |
199 | plot(_t, xref, color=grey, auto_xlims=0)
200 | plot(x.real, show=1, title="central instantaneous center frequency")
201 |
--------------------------------------------------------------------------------
/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/wavgif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/wavgif.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/im0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/im0.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/im1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/im1.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/im2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/im2.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/76463/50076 ################################
3 | import numpy as np
4 | from numpy.fft import fft, ifft
5 | from scipy.io import wavfile
6 | from ssqueezepy import ssq_cwt, Wavelet
7 | from ssqueezepy.visuals import imshow, plot
8 |
9 | #%%# Helper methods ##########################################################
10 | def frequency_modulate(slc, fc=None, b=.3):  # returns `(x, t, phi0, phi1, b_effective)`; reads module-level `start`, `end`, `fs`
11 | N = len(slc)
12 | if fc is None:
13 | fc = N / 18 # arbitrary
14 | # track actual `b` for demodulation purposes
15 | b_effective = b
16 |
17 | t_min, t_max = start / fs, end / fs  # `start`, `end`, `fs` are globals, not parameters
18 | t = np.linspace(t_min, t_max, N, endpoint=False)
19 | assert np.allclose(fs, 1 / np.diff(t))  # spacing of `t` must equal 1/fs, i.e. `end - start` must match `N`
20 |
21 | x0 = slc[:N]  # basic slice: shares memory with `slc` when `slc` is an ndarray
22 | # ensure it's [-.5, .5] so diff(phi) is b*[-pi, pi]
23 | x0max = np.abs(x0).max()
24 | x0 /= (2*x0max)  # in-place: for an ndarray `slc` this also rescales the caller's data
25 | b_effective /= (2*x0max)  # fold the same scaling into `b_effective`
26 |
27 | # generate phase
28 | phi0 = 2*np.pi * fc * t  # carrier term
29 | phi1 = 2*np.pi * b * np.cumsum(x0)  # running sum of the (rescaled) data, scaled by `b`
30 | phi = phi0 + phi1
31 | diffmax = np.abs(np.diff(phi)).max()
32 | # `b` correction
33 | if diffmax > np.pi or np.allclose(phi, np.pi):  # NOTE(review): second test compares every phase sample to pi; presumably `diffmax` was intended -- confirm
34 | diffmax0 = np.abs(np.diff(phi0)).max()
35 | diffmax1 = np.abs(np.diff(phi1)).max()
36 | # epsilon term for stable inversion / pi-unambiguity
37 | eps = 1e-7
38 | factor = ((np.pi - diffmax0 - eps) / diffmax1)  # shrink `phi1` so the two maximum steps sum to just under pi
39 | phi1 *= factor
40 | b_effective *= factor  # keep `b_effective` consistent with the rescaled `phi1`
41 | phi = phi0 + phi1
42 | assert np.abs(np.diff(phi)).max() <= np.pi
43 |
44 | # modulate
45 | x = np.cos(phi)
46 | return x, t, phi0, phi1, b_effective
47 |
def analytic(x):
    """Return the analytic signal of the real sequence `x` via the DFT.

    Positive-frequency bins are doubled, negative-frequency bins are zeroed,
    and DC is left unscaled. The Nyquist bin (index `N//2`) exists only for
    even `N` and is then also left unscaled; for odd `N` index `N//2` is an
    ordinary positive-frequency bin and must stay doubled, otherwise the
    real part no longer reproduces `x`.

    Parameters
    ----------
    x : array
        Real-valued 1D input.

    Returns
    -------
    xa : complex array
        Analytic signal, with `xa.real` equal to `x` up to float error.

    Raises
    ------
    AssertionError
        If `xa.real` does not reproduce `x` (e.g. `x` is not real-valued).
    """
    N = len(x)
    xf = fft(x)

    n_pos = N//2 + 1  # DC + positive frequencies (+ Nyquist if N is even)
    xaf = np.zeros(N, dtype='complex128')
    xaf[:n_pos] = 2 * xf[:n_pos]
    # undo the doubling for bins that have no negative-frequency counterpart
    xaf[0] /= 2
    if N % 2 == 0:
        xaf[N//2] /= 2

    xa = ifft(xaf)
    assert np.allclose(xa.real, x)
    return xa
60 |
61 | #%%# Load data, select slice #################################################
62 | fs, data = wavfile.read(r"C:\Desktop\recording.wav")  # machine-specific path; point it at a local WAV file to run
63 | data = data.astype('float64')
64 | data /= (2*np.abs(data).max())  # normalize so that the peak magnitude is 0.5
65 |
66 | start, end = 0, fs  # sample range of length `fs`; also read as globals by `frequency_modulate`
67 | slc = data[start:end]
68 |
69 | #%%# Modulate & validate #####################################################
70 | x, t, phi0, phi10, b_effective = frequency_modulate(slc)
71 |
72 | # extreme time localization
73 | wavelet = Wavelet(('gmw', {'gamma': 1, 'beta': 1}))
74 | # synchrosqueezed CWT
75 | Tx, Wx, ssq_freqs, *_ = ssq_cwt(x, wavelet)
76 | # ints for better plotting
77 | ssq_freqs = (ssq_freqs * fs).astype(int)  # scaled by `fs`, then truncated to integers
79 | #%%# Visualize ##############################################################
def viz(t_min=None, t_max=None, f_min=None, f_max=None, show_original=False,
        show_modulated=False):
    """Show `abs(SSQ_CWT)` over an optional time/frequency window.

    Reads the module-level `Tx`, `ssq_freqs`, `fs`, `t`, `kw`, `slc`, `x`.

    Parameters
    ----------
    t_min, t_max : float, optional
        Time bounds; multiplied by `fs` and truncated to get column indices.
        Omitted bounds leave that side uncropped.
    f_min, f_max : float, optional
        Frequency bounds; each is mapped to the row whose entry in the
        flipped `ssq_freqs` is nearest. Omitted bounds leave that side
        uncropped.
    show_original, show_modulated : bool
        Additionally plot `slc` / `x` over the same time window.
    """
    freqs = ssq_freqs[::-1]

    # time bounds -> column indices
    col_lo = 0 if t_min is None else int(t_min * fs)
    col_hi = None if t_max is None else int(t_max * fs)
    # frequency bounds -> row indices (`f_max` gives the first row)
    row_hi = None if f_min is None else np.argmin(np.abs(freqs - f_min))
    row_lo = 0 if f_max is None else np.argmin(np.abs(freqs - f_max))

    cols = slice(col_lo, col_hi)
    rows = slice(row_lo, row_hi)
    imshow(Tx[rows, cols], xticks=t[cols], yticks=freqs[rows], **kw)

    if show_original:
        plot(t[cols], slc[cols], xlabel="time [sec]", title="original data",
             show=1)
    if show_modulated:
        plot(t[cols], x[cols], xlabel="time [sec]", title="modulated data",
             show=1)
96 |
97 | mx = np.abs(Tx).max() * .4  # upper color limit at 40% of the peak magnitude
98 | kw = dict(abs=1, norm=(0, mx), title="abs(SSQ_CWT)",
99 | xlabel="time [sec]", ylabel="frequency [Hz]")  # read as a global by `viz`
100 | viz()
101 | viz(0, .2, 300, 8000, show_original=1)
102 | viz(0, .02, 300, 8000, show_original=1, show_modulated=1)
103 |
104 | #%%
105 | # plot(x, show=1)
106 | # plot(slc, show=1)
107 | phie = np.arccos(x)[:2048]  # not used again in this file
108 |
109 | phi = np.unwrap(np.angle(analytic(x)))[:2048]  # phase estimate: unwrapped angle of the analytic signal
110 | phi_exact = (phi0 + phi10)[:2048]  # phase actually used by `frequency_modulate`
111 | phi[:10] = phi_exact[:10]  # first 10 samples are overwritten with exact values (presumably to exclude edge error -- TODO confirm)
112 | ae = np.abs(phi - phi_exact)
113 | print(ae.mean(), ae.max(), ae.min())
114 |
115 | phi1 = (phi - phi0[:2048])  # remove the carrier term
116 | x_inv_cs = phi1 / (2*np.pi * b_effective)  # undo the `2*pi*b` scaling applied to the running sum
117 | x_inv = np.diff(x_inv_cs)  # differencing undoes `np.cumsum`; result is one sample shorter
118 | plot(x_inv)
119 |
--------------------------------------------------------------------------------
/SignalProcessing/Q76560 - Analyticity - Symmetries/main.py:
--------------------------------------------------------------------------------
1 | # https://dsp.stackexchange.com/q/76560/50076 ################################
2 | # -*- coding: utf-8 -*-
3 | import numpy as np
4 | from numpy.fft import ifft, ifftshift
5 | from kymatio.scattering1d.filter_bank import morlet_1d
6 | from ssqueezepy.visuals import plot as _plot, plotscat
7 |
def plot(*args, **kw):
    """Wrap `ssqueezepy`'s `plot` with a fixed title size and symmetric y-limits.

    A `title` keyword, if present, is re-packed as `(title, {'fontsize': 18})`.
    `ylims` is always set to +/- 1.05 times the largest absolute real or
    imaginary value of the first positional argument.
    """
    lim = re_im_max(args[0]) * 1.05
    styled = {**kw, 'ylims': (-lim, lim)}
    if 'title' in kw:
        styled['title'] = (kw['title'], {'fontsize': 18})
    _plot(*args, **styled)
14 |
def re_im_max(x):
    """Return the larger of the peak absolute real and imaginary parts of `x`."""
    peak_re = np.abs(x.real).max()
    peak_im = np.abs(x.imag).max()
    return max(peak_re, peak_im)
17 |
def re_im_parts(x, y):
    """Return the four real-valued terms of the complex product `x * y`.

    With `A, B, C, D` as returned, `(x * y).real == A + B` and
    `(x * y).imag == C + D`.
    """
    xr, xi = x.real, x.imag
    yr, yi = y.real, y.imag
    return xr * yr, -xi * yi, xr * yi, xi * yr
24 |
def l2(x):
    """Return the Euclidean (L2) norm of `x`: sqrt of the summed squared magnitudes."""
    energy = np.sum(np.abs(x)**2)
    return np.sqrt(energy)
27 |
28 | #%%# Morlet visuals ##########################################################
29 | N = 256
30 | pf = morlet_1d(N, xi=.025, sigma=.1/20)
31 | pt = ifftshift(ifft(pf))  # time-domain wavelet, shifted with `ifftshift`
32 | #%%
33 | kw = dict(complex=1, show=1, ticks=(1, 0))  # shared plot options for every figure below
34 | plot(pt, **kw, title="psi: analytic Morlet")
35 |
36 | #%%
37 | A, B, C, D = re_im_parts(pt, pt)  # (pt*pt).real == A + B, (pt*pt).imag == C + D
38 | plot(A + 1j*B, **kw, title="(psi * psi).real")  # the two terms are packed into one complex array for the `complex=1` plot
39 | plot(C + 1j*D, **kw, title="(psi * psi).imag")
40 | plot(pt * pt, **kw,
41 | title="(psi * psi).real + (psi * psi).imag = psi * psi")
42 |
43 | #%%
44 | pta = np.conj(pt)  # complex conjugate of `pt`
45 | plot(pta, **kw, title="psia: anti-analytic Morlet")
46 |
47 | #%%
48 | plot(pt.imag + 1j*pta.imag, **kw, title="psia = psi.real - psi.imag")  # plots the two imaginary parts against each other
49 |
50 | #%%
51 | A, B, C, D = re_im_parts(pt, pta)
52 | plot(A + 1j*B, **kw, title="(psi * psia).real")
53 | plot(C + 1j*D, **kw, title="(psi * psia).imag")
54 | plot(pt * pta, **kw,
55 | title="(psi * psia).real + (psi * psia).imag = psi * psia")
56 |
57 | #%%# Random sequence visuals #################################################
def viz_seq(x, show=1, complex=1, **kw):
    """Plot sequence `x` with `plotscat`, symmetric y-limits and a red zero line.

    The y-limits are +/- 1.1 times the largest absolute real or imaginary
    value of `x`; remaining keyword arguments are forwarded to `plotscat`.
    """
    lim = re_im_max(x) * 1.1
    zero_line = (0, {'color': 'tab:red', 'linewidth': 1})
    plotscat(x, complex=complex, ylims=(-lim, lim), show=show, **kw,
             hlines=zero_line, ticks=(1, 0), auto_xlims=0)
63 |
def zero_sum_seq(N):
    """Draw a random complex sequence with zero mean and balanced parts.

    The second half is the reversed complex conjugate of the first half.
    The mean is then subtracted, and the real part is rescaled so that its
    L2 norm equals that of the imaginary part.
    """
    half = N//2
    # two separate draws, real part first, to keep the RNG stream order
    seq = np.random.randn(N) + 1j*np.random.randn(N)
    seq[half:] = 0

    mirrored = seq[:half][::-1]
    seq[half:] = mirrored.real - 1j * mirrored.imag
    seq -= seq.mean()

    # equalize the L2 norms of the real and imaginary parts
    norm_im = np.sqrt(np.sum(np.abs(seq.imag)**2))
    norm_re = np.sqrt(np.sum(np.abs(seq.real)**2))
    seq.real *= (norm_im / norm_re)
    return seq
74 |
75 | #%%
76 | np.random.seed(0)  # fixed seed so the two random sequences are reproducible
77 | N = 12
78 | x = zero_sum_seq(N)
79 | y = zero_sum_seq(N)
80 |
81 | viz_seq(x, title="x | sum = {:.2f}".format(x.sum()))  # each title reports the sum of the plotted sequence
82 | viz_seq(y, title="y | sum = {:.2f}".format(y.sum()))
83 | viz_seq(x * x, title="x * x | sum = {:.2f}".format((x*x).sum()))
84 | viz_seq(y * y, title="y * y | sum = {:.2f}".format((y*y).sum()))
85 | viz_seq(x * y, title="x * y | sum = {:.2f}".format((x * y).sum()))
86 |
--------------------------------------------------------------------------------
/SignalProcessing/Q76560 - Analyticity - Symmetries/morlets0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76560 - Analyticity - Symmetries/morlets0.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76560 - Analyticity - Symmetries/morlets1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76560 - Analyticity - Symmetries/morlets1.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76560 - Analyticity - Symmetries/rands0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76560 - Analyticity - Symmetries/rands0.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass/WGN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass/WGN.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass/im2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass/im2.png
--------------------------------------------------------------------------------
/SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass/ims.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass/ims.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 |
def plot(x, title, scatter=True):
    """Plot the magnitude of `x` in a new figure and return that figure.

    Parameters
    ----------
    x : array
        Sequence whose absolute value is drawn as a line.
    title : str
        Left-aligned bold figure title.
    scatter : bool
        Whether to overlay one marker per sample. Defaults to True, which
        matches the previous unconditional behavior. The parameter exists
        because `plot_T` calls `plot(x, title, scatter=1)`, which raised
        `TypeError` while `plot` accepted only two arguments.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The newly created figure, left open for the caller to save or close.
    """
    fig = plt.figure()
    magnitude = np.abs(x)
    plt.plot(magnitude)
    if scatter:
        plt.scatter(np.arange(len(x)), magnitude, s=10)
    plt.title(title, weight='bold', fontsize=18, loc='left')
    return fig
11 |
def plot_T(x, Tmax):
    """Plot `x` under a title describing the filter up to lag `Tmax`, then save it.

    The figure is written to `im{Tmax}.png` in the working directory with
    y-limits fixed to (-.05, 1.05), and is closed afterwards.
    """
    # lags 0..2 have fully spelled-out titles; larger lags use an ellipsis
    spelled_out = {
        0: "|H(w)|: x(n)",
        1: "|H(w)|: x(n) - x(n - 1)",
        2: "|H(w)|: x(n) - x(n - 1) + x(n - 2)",
    }
    title = spelled_out.get(Tmax)
    if title is None:
        title = "|H(w)|: x(n) - x(n - 1) + x(n - 2) - ... x(n - %s)" % Tmax

    figure = plot(x, title, scatter=1)
    plt.ylim(-.05, 1.05)

    plt.savefig(f'im{Tmax}.png', bbox_inches='tight')
    plt.close(figure)
27 |
def csoid(f):
    """Return `cos(2*pi*f*t) - 1j*sin(2*pi*f*t)` over the module-level `t`."""
    phase = 2*np.pi* f * t
    return np.cos(phase) - np.sin(phase) * 1j
31 |
32 | #%%# Direct frequency response ###############################################
33 | N = 32
34 | t = np.linspace(0, 1, N, 0)  # 4th positional argument is `endpoint` (0 -> excluded); read as a global by `csoid`
35 |
36 | for Tmax in range(N):  # one saved figure per maximum lag
37 | x = np.sum([(-1)**T * csoid(T) for T in range(Tmax + 1)], axis=0)  # alternating-sign sum of `csoid(0..Tmax)`
38 | x /= np.abs(x).max()  # normalize the peak magnitude to 1
39 | plot_T(x, Tmax)
40 |
41 | #%%# WGN example #############################################################
def plot_and_save(x, title, savepath):
    """Draw `x` with `plot`, write the figure to `savepath`, then close it."""
    figure = plot(x, title)
    plt.savefig(savepath, bbox_inches='tight')
    plt.close(figure)
46 |
47 | np.random.seed(69)  # fixed seed so the noise realization is reproducible
48 | x = np.random.randn(32)
49 | xf0 = np.fft.fft(x)  # spectrum before filtering
50 | x = x - np.roll(x, 1) + np.roll(x, 2)  # x(n) - x(n - 1) + x(n - 2) with circular shifts (`np.roll` wraps around)
51 | xf1 = np.fft.fft(x)  # spectrum after filtering
52 |
53 | plot_and_save(xf0, "|X(w)|: x(n)", "WGN0.png")
54 | plot_and_save(xf1, "|X(w)|: x(n) - x(n - 1) + x(n - 2)", "WGN1.png")
55 |
--------------------------------------------------------------------------------
/SignalProcessing/Q76644 - estimation - Estimate sine frequency under white noise/_optimized.pyx:
--------------------------------------------------------------------------------
1 | import cython
2 | from libc.math cimport atan2
3 |
4 |
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef double kay_weighted_complex(double[:] R, double[:] I, double[:] W):
    """Return the `W`-weighted sum of phase steps between adjacent samples.

    With `z[k] = R[k] + 1j*I[k]`, each term is
    `W[k] * angle(conj(z[k]) * z[k + 1])`, summed over the `len(R) - 1`
    adjacent pairs. Bounds checking and negative-index wraparound are
    disabled, so `I` must be at least as long as `R`, and `W` must have at
    least `len(R) - 1` entries.
    """
    cdef Py_ssize_t n_pairs = R.shape[0] - 1
    cdef Py_ssize_t k = 0
    cdef double cross = 0
    cdef double dot = 0
    cdef double total = 0

    for k in range(n_pairs):
        # imaginary and real parts of conj(z[k]) * z[k + 1]
        cross = R[k]*I[k + 1] - I[k]*R[k + 1]
        dot = R[k]*R[k + 1] + I[k]*I[k + 1]
        total += W[k] * atan2(cross, dot)

    return total
19 |
20 |
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int abs_argmax(double[:] R, double[:] I):
    """Return the index of the largest `R[k]**2 + I[k]**2`.

    Squared magnitudes are compared, so no square root is taken. The first
    index wins on ties, and 0 is returned when no entry exceeds zero.
    Bounds checking is disabled, so `I` must be at least as long as `R`.
    """
    cdef Py_ssize_t n = R.shape[0]
    cdef Py_ssize_t k = 0
    cdef int best_idx = 0
    cdef double best_sq = 0
    cdef double sq = 0

    for k in range(n):
        sq = R[k]*R[k] + I[k]*I[k]
        if sq > best_sq:
            best_sq = sq
            best_idx = k

    return best_idx
39 |
--------------------------------------------------------------------------------
/SignalProcessing/Q76644 - estimation - Estimate sine frequency under white noise/benchmarks.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/76644/50076
3 | import os
4 | os.environ['NUMEXPR_NUM_THREADS'] = '1'  # thread-count variables are set to '1' before numpy is imported below
5 | os.environ['OMP_NUM_THREADS'] = '1'
6 | os.environ['OPENBLAS_NUM_THREADS'] = '1'
7 | os.environ['MKL_NUM_THREADS'] = '1'
8 | os.environ["VECLIB_MAXIMUM_THREADS"] = '1'
9 | os.environ["NUMEXPR_NUM_THREADS"] = '1'  # NOTE(review): duplicates the first assignment above
10 |
11 | import numpy as np
12 | import matplotlib.pyplot as plt
13 |
14 | # ensure the files can be found
15 | import sys
16 | from pathlib import Path
17 | _dir = Path(__file__).parent  # directory containing this script
18 | assert _dir.is_file() or _dir.is_dir(), str(_dir)  # i.e. the path must exist
19 | if not any(str(_dir).lower() == p.lower() for p in sys.path):  # case-insensitive check against entries already on `sys.path`
20 | sys.path.insert(0, str(_dir))
21 |
22 | from optimized import Kay2Complex, Cedron3BinComplex
23 | from utils76644 import make_x, timeit
24 |
#%% Estimator instances, keyed by the names used throughout this script
kay2_complex_optimized = Kay2Complex()
cedron_3bin_complex_optimized = Cedron3BinComplex()
fns = dict(kay_2=kay2_complex_optimized,
           cedron_3bin=cedron_3bin_complex_optimized)

#%% Manual testing ###########################################################
np.random.seed(0)
# change the index to pick which estimator to try
name = ('cedron_3bin', 'kay_2')[1]
N, snr = 100, 100
f = 0.153123 * N
x = make_x(N, f, snr, real=False)

fn = fns[name]
f_est = fn(x)

# ratio of the estimate to the true normalized frequency `f/N`
print(f_est / (f/N), sep='\n')
45 |
#%% Benchmark ################################################################
np.random.seed(0)
# times[name][N] -> whatever `timeit` returns for that estimator and length
times = {'cedron_3bin': {}, 'kay_2': {}}
n_trials, repeats = 400, 20
Npow2 = True

# signal lengths 2**7 ... 2**20
N_all = 2**np.arange(7, 21)
if not Npow2:
    # keep only the two leading digits of each length
    for i, N in enumerate(N_all):
        N_all[i] = round(N, -(len(str(N)) - 2))

for N in N_all:
    # complex white noise; both estimators are timed on the same input
    x = np.random.randn(N) + 1j*np.random.randn(N)
    for name in times:
        fn = fns[name]
        # warmup
        for _ in range(10):
            fn(x)
        # bench
        times[name][N] = timeit(fn, x, n_trials, repeats)
        # progress marker
        print(end='.', flush=True)

print(times)
70 |
#%% Visualize ################################################################
# timings per estimator, in the order the lengths were inserted
data = [np.array(list(times[key].values()))
        for key in ('cedron_3bin', 'kay_2')]
N_all = np.array(list(times['cedron_3bin']))

fig, ax = plt.subplots(1, 1, layout='constrained', figsize=(9.5, 8))

# Kay-over-Cedron time ratio: a line with markers drawn on top
ax.plot(N_all, data[1] / data[0])
ax.scatter(N_all, data[1] / data[0], s=40)
ax.set_xscale('log')

ax.set_xlabel("N", size=20)
ax.set_ylabel("t_ratio", size=20)
title = ("Time ratios (Kay2_complex/Cedron_complex)\n"
         "n_trials={}, n_repeats={}\n"
         "Npow2={}, FFTW=False"
         ).format(n_trials, repeats, Npow2)
ax.set_title(title, weight='bold', fontsize=24, loc='left')

# fixed y-range with a reference line at equal run time
ax.set_ylim(0, 3)
ax.axhline(1, color='tab:red')

plt.show()
94 |
--------------------------------------------------------------------------------
/SignalProcessing/Q76644 - estimation - Estimate sine frequency under white noise/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/76644/50076
3 | # All code is written for readability, not performance.
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 |
7 | # ensure the files can be found
8 | import sys
9 | from pathlib import Path
# directory holding this script; put it on `sys.path` (unless it is already
# there, compared case-insensitively) so the sibling modules import
_dir = Path(__file__).parent
assert _dir.is_file() or _dir.is_dir(), str(_dir)
if str(_dir).lower() not in (p.lower() for p in sys.path):
    sys.path.insert(0, str(_dir))
14 |
15 | from estimators import est_freq, estimator_fns
16 | from utils76644 import (
17 | make_x, run_test, run_test_multitone, snrs_db_practical, snrs_db_wide,
18 | run_viz, run_viz2, run_viz_multitone)
19 |
# show the names registered in `estimator_fns`, one per line
print("Available estimators:\n " + "\n ".join(estimator_fns))

#%% Configurations ###########################################################
# USER -----------------------------------------------------------------------
# nonzero prints test progress
VERBOSE = 1
# other configs in `utils` file

# Default normalized frequencies (f/N). "ints" vs "nonints" presumably refers
# to whether f = f_N * N lands on an integer for the N in use (e.g. 0.053 * N
# is fractional for N=100 but whole for N=10000), which is why the large-N
# integer set reuses the small-N non-integer one.
f_N_all_nonints_large_N = (0.05393, 0.10696, 0.25494, 0.46595)
f_N_all_nonints_small_N = (0.053, 0.106, 0.254, 0.465)
f_N_all_ints_small_N = (0.05, 0.10, 0.25, 0.46)
f_N_all_ints_large_N = f_N_all_nonints_small_N
33 |
#%% Manual testing ###########################################################
np.random.seed(0)
# change the index to pick which estimator to try
name = ('cedron_3bin', 'kay_2', 'dft_quadratic')[0]
# name = ('cedron_3bin_complex',)[0]
N = 10000
f = N*0.053123
phi = 1

# real cosine with `f` cycles over the `N` samples and initial phase `phi`
x = np.cos(2*np.pi * f * np.arange(N)/N + phi)
# complex variant:
# x = x + 1j*np.sin(2*np.pi * f * np.arange(N)/N)
# additive Gaussian noise, standard deviation 0.1
x += np.random.randn(N) * .1

# ratio of the estimate to the true normalized frequency `f/N`
print(est_freq(x, name) / (f/N), sep='\n')
46 |
#%% Full testing #############################################################
# configure
seed = 0
N = 100
n_trials = 2000
real = True
sweep_mode = ('practical', 'wide')[0]
# the two estimators compared against each other
name0, name1 = 'cedron_3bin', 'kay_2'
# name0, name1 = 'cedron_3bin', 'dft_quadratic'
# name0, name1 = 'cedron_3bin_complex', 'kay_2'
# change the index to pick one of the default frequency sets
f_N_all = (f_N_all_nonints_small_N, f_N_all_nonints_large_N,
           f_N_all_ints_small_N, f_N_all_ints_large_N)[0]
# f_N_all = np.linspace(1/N, .5-1/N, 50)
# f_N_all = np.linspace(-.5+1/N, .5-1/N, 100)

test_out = run_test(f_N_all, N, n_trials, name0, name1, real, seed,
                    sweep_mode, verbose=VERBOSE)
errs0_all, errs1_all, snrs, crlbs = test_out

#%% Visualize ################################################################
# display names, in the same order as (name0, name1)
names = ("Cedron", "Kay_2")
# names = ("Cedron", "DFT_quadratic")

args = (errs0_all, errs1_all, snrs, crlbs, f_N_all, N, n_trials, names)
run_viz(*args)
# run_viz2(*args)
72 |
#%% Multi-tone Example #######################################################
seed = 0
N = 10000
n_trials = 2000
name0, name1 = 'cedron_3bin', 'dft_quadratic'
# include integer case
f_N_all = (0.05305, 0.10605, 0.254, 0.46505)
# f_N_all = (0.10601, 0.10644, 0.10696, 0.10747)
A_all = (0.5, 0.8, 1.2, 1.5)  # mean=1

# each `A` will have different SNR, so extend the range so we can plot all
# under a common snr
snrs_bounds = (snrs_db_practical[0], snrs_db_practical[-1])
snr_lo, snr_hi = snrs_bounds
snrs = np.linspace(snr_lo - 10, snr_hi + 15,
                   int(len(snrs_db_practical)*1.25))

test_out = run_test_multitone(f_N_all, A_all, N, n_trials, name0, name1,
                              snrs, seed, verbose=VERBOSE)
errs0_all, errs1_all, snrs_f_N, crlbs = test_out

#%% Visualize ################################################################
# display names, in the same order as (name0, name1)
names = ("Cedron", "DFT_quadratic")
run_viz_multitone(errs0_all, errs1_all, snrs_f_N, crlbs,
                  f_N_all, A_all, N, n_trials, snrs_bounds, names=names)
95 |
#%% "Extreme" example ########################################################
N = 100000
f_N = .053056
snr = -30
f = f_N * N

# squared error of the normalized-frequency estimate over 2000 fresh signals
errs = []
for _ in range(2000):
    x = make_x(N, f, snr)
    f_est = est_freq(x, 'cedron_3bin')
    errs.append((f_N - f_est)**2)
mse = np.mean(errs)

print(mse)

#%%
# draw one more realization together with its `get_xo` companion and show
# the first 500 samples of each side by side
x, xo = make_x(N, f, snr, get_xo=True)
fig, axes = plt.subplots(1, 2, figsize=(15.5, 6.5), layout='constrained')
panels = ((xo, "Original signal, zoomed | N=100000"),
          (x, "Noisy, zoomed | SNR=1/1000 (-30dB)"))
for ax, (sig, panel_title) in zip(axes, panels):
    ax.plot(sig[:500])
    ax.set_title(panel_title, weight='bold', loc='left', fontsize=24)
120 |
--------------------------------------------------------------------------------
/SignalProcessing/Q76644 - estimation - Estimate sine frequency under white noise/optimized.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Performance-optimized estimators."""
3 | # https://dsp.stackexchange.com/q/76644/50076
4 | import numpy as np
5 | from scipy.fft import fft
6 | from _optimized import kay_weighted_complex, abs_argmax
7 |
8 |
class Kay2Complex():
    """Weighted phase-difference frequency estimator for complex signals.

    The weights depend only on the signal length, so they are built once per
    length and cached on the instance in `weights_N`.
    """

    def __init__(self):
        # signal length -> weights array with `N - 1` entries
        self.weights_N = {}

    def __call__(self, x, N=None):
        """Return the estimate for complex array `x`.

        The `N` argument is ignored: it is overwritten with `len(x)`.
        The weights carry a `1 / (2*pi)` factor, so the result is presumably
        in cycles/sample -- confirm against the callers.
        """
        N = len(x)
        weights = self.weights_N.get(N)
        if weights is None:
            # parabolic window over the `N - 1` consecutive-sample steps
            idxs = np.arange(N - 1)
            parabola = 1 - ((idxs - (N/2 - 1)) / (N/2))**2
            weights = 1.5*N / (N**2 - 1) * parabola / (2*np.pi)
            self.weights_N[N] = weights

        return kay_weighted_complex(x.real, x.imag, weights)
25 |
26 |
class Cedron3BinComplex():
    """Timing stand-in for the complex 3-bin DFT estimator.

    Only the FFT and the peak search are performed; the returned value is a
    constant placeholder, not a frequency estimate.
    """

    def __call__(self, x):
        spectrum = fft(x)
        # computed for its run-time cost only; the index is not used further
        kmax = abs_argmax(spectrum.real, spectrum.imag)

        # didn't bother implementing the O(1) finishing step.
        # it'd only slightly change the N=100 result, and not at all for
        # anything else.
        # the handicap of not using FFTW is far greater
        return 5
37 |
--------------------------------------------------------------------------------
/SignalProcessing/Q76644 - estimation - Estimate sine frequency under white noise/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # python setup.py build_ext --inplace
# Windows/MSVC workaround: map the 'win-amd64' platform onto the plain
# 'amd64' vcvars configuration. Applied on a best-effort basis, because
# `distutils` is no longer part of the standard library on newer Python
# versions and its `_msvccompiler` module may be missing, fail to import,
# or lack the `PLAT_TO_VCVARS` table; none of those should abort the build.
try:
    from distutils import _msvccompiler
except ImportError:
    _msvccompiler = None
if _msvccompiler is not None and hasattr(_msvccompiler, 'PLAT_TO_VCVARS'):
    _msvccompiler.PLAT_TO_VCVARS['win-amd64'] = 'amd64'

from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy as np

# build with: python setup.py build_ext --inplace
setup(
    ext_modules=cythonize(Extension("_optimized", ["_optimized.pyx"]),
                          language_level=3),
    include_dirs=[np.get_include()],
)
15 |
--------------------------------------------------------------------------------
/SignalProcessing/Q76754 - Hilbert - Alleviate boundary effects/data.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76754 - Hilbert - Alleviate boundary effects/data.npy
--------------------------------------------------------------------------------
/SignalProcessing/Q76754 - Hilbert - Alleviate boundary effects/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/76754/50076 ################################
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from numpy.fft import fft, ifft
6 |
def analytic(x):
    """Return the analytic signal of the real 1D sequence `x`.

    Computed in the frequency domain: the strictly positive frequencies are
    doubled, the negative ones are zeroed, and DC (plus the Nyquist bin for
    even lengths) is left untouched. The real part of the result reproduces
    `x`, which is verified with an assertion.

    Parameters
    ----------
    x : 1D sequence of real numbers, non-empty.

    Returns
    -------
    Complex array of the same length as `x`.
    """
    N = len(x)
    xf = fft(x)
    # Positive-frequency bins are 1 .. (N + 1)//2 - 1: this stops before the
    # Nyquist bin for even N, includes the topmost bin for odd N, and is
    # empty for N == 1, where the only bin is DC and must not be doubled.
    xf[1:(N + 1)//2] *= 2
    # negative frequencies
    xf[N//2 + 1:] = 0
    xa = ifft(xf)
    assert np.allclose(xa.real, x)
    return xa
17 |
def plot(x, title=None, show=0):
    """Draw `x` on the current pyplot axes.

    `title`, when given, is set as a bold, left-aligned title. A truthy
    `show` displays the figure; otherwise later calls keep drawing onto the
    same axes.
    """
    plt.plot(x)
    if title is not None:
        title_kw = dict(loc='left', weight='bold', fontsize=17)
        plt.title(title, **title_kw)
    if not show:
        return
    plt.show()
24 |
#%%###########################################################################
# load (path is relative to the current working directory)
x = np.load('data.npy')
N = len(x)
# analytic signal of the raw data, for comparison
xa = analytic(x)

# reflect-pad by a full signal length on each side, then transform
xp = np.pad(x, N, mode='reflect')
xpa = analytic(xp)
# drop the padding again
xpu, xpau = xp[N:-N], xpa[N:-N]

#%% visualize
# each pair overlays the signal with the modulus of its analytic signal
plot(x, title="original")
plot(np.abs(xa), show=1)

plot(xpu, title="reflect-padded + unpadded")
plot(np.abs(xpau), show=1)
44 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/cwt.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/cwt.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/cwt_phi.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/cwt_phi.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/cwt_phi.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Continuous Wavelet Transform w/ lowpassing (first-order scattering) GIF."""
3 | # https://dsp.stackexchange.com/q/78512/50076 ################################
4 | # https://dsp.stackexchange.com/q/78514/50076 ################################
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import matplotlib.animation as animation
8 | from numpy.fft import fft, ifft, fftshift, ifftshift
9 | from kymatio.numpy import Scattering1D
10 | from kymatio.toolkit import echirp
11 |
#%%## Set params & create scattering object ##################################
N = 512
J, Q = 4, 3
T = 2**(J + 1)
# unaveraged and heavily oversampled, i.e. a CWT
average, oversampling = False, 999
ts = Scattering1D(shape=N, J=J, Q=Q, T=T, max_order=1, out_type='list',
                  average=average, oversampling=oversampling,
                  max_pad_factor=None)
# drop last for niceness
del ts.psi1_f[-1]

#%%# Create signal ###########################################################
x = echirp(N, fmin=16, fmax=N/2)
# time axis over [0, 1], endpoint included
t = np.linspace(0, 1, N, endpoint=True)

# first-order center frequencies, scaled by N
meta = ts.meta()
freqs = N * meta['xi'][meta['order'] == 1][:, 0]
28 |
#%%# Animate #################################################################
class CWTAnimation(animation.TimedAnimation):
    """Animate lowpassing of `|CWT(x)|` by `phi` (first-order scattering).

    The left panel shows `|psi * x|` with the lowpass filter `phi` sliding
    across it; the right panel reveals the lowpassed output up to the
    current position.

    Parameters
    ----------
    ts : scattering object providing `psi1_f`, `phi_f`, `ind_start`,
        `ind_end`, `pad_left` and `pad_right`.
    x : 1D signal of even length.
    t : time axis of `x`; used for tick labels and for the shift shown in
        the text.
    freqs : frequencies of the rows, used for the y tick labels
        (displayed divided by `len(x)`).
    stride : number of samples advanced per frame.
    alpha : stored as `self.alpha`. It used to be overwritten with the
        hard-coded default `'6f'`; the passed value is now kept.
    """

    def __init__(self, ts, x, t, freqs, stride=1, alpha='6f'):
        self.ts = ts
        self.x = x
        self.t = t
        self.freqs = freqs
        self.stride = stride
        self.N = len(x)
        assert self.N % 2 == 0

        # unpack filters
        self.psi_fs = [p[0] for p in ts.psi1_f]
        self.phi_f = ts.phi_f[0]
        self.phi_t = ifft(self.phi_f).real
        # conjugate since conv == cross-correlation with conjugate
        psi_ts = np.array([np.conj(ifft(p)) for p in self.psi_fs])
        self.psi_ts = psi_ts / psi_ts.real.max()

        # compute filter params
        self.js = [p['j'] for p in ts.psi1_f]
        self.lens = [int(1.2*p['support'][0]) for p in ts.psi1_f]
        self.n_psis = len(self.psi_fs)
        self.Np = len(self.psi_fs[0])
        self.trim = ts.ind_start[0]
        # center `phi`, cut `trim` samples off both ends, move peak back to 0
        self.phi_t_trim = fftshift(ifftshift(self.phi_t)[self.trim:-self.trim])
        # padded x to do CWT with
        self.xpf = fft(np.pad(x, [ts.pad_left, ts.pad_right], mode='reflect'))

        # take CWT without modulus
        self.Wx = np.array([ifft(p * self.xpf) for p in self.psi_fs])
        self.Ux = np.abs(self.Wx)
        self.Sx = ifft(fft(self.Ux) * self.phi_f).real
        # unpad
        start, end = ts.ind_start[0], ts.ind_end[0]
        self.Sx = self.Sx[:, start:end]
        self.Ux = self.Ux[:, start:end]
        # suppress boundary effects for visuals
        mx = self.Sx.max()
        mx_avg = self.Sx[5].max()
        self.Sx[0] *= (mx_avg / mx) * 1.3
        self.Sx[-1] *= (mx_avg / mx) * 1.3
        # spike causes graphical glitch in converting to gif -- flatten
        mx = self.Ux.max() * 1.02
        self.Ux[0] *= .58 / .61
        self.Ux[0, -1] = mx

        # honor the constructor argument rather than a fixed value
        self.alpha = alpha
        self.title_kw = dict(weight='bold', fontsize=16, loc='left')
        self.label_kw = dict(weight='bold', fontsize=14, labelpad=3)
        self.txt_kw = dict(x=0, y=1.015, s="", ha="left")
        self._prev_psi_idx = -1

        fig, axes = plt.subplots(1, 2, figsize=(18/1.7, 8/1.7))

        # |CWT|*phi_f ####################################################
        ax = axes[0]
        self.lines0 = []

        ax.imshow(self.Ux, cmap='turbo', aspect='auto', interpolation='none')
        self._format_ticks(ax)

        # `phi`, scaled and offset so it draws as a curve over the image
        phi_t = 6 - 600*self.phi_t_trim
        ax.plot(np.arange(self.N), phi_t, color='white', linewidth=2)
        self.lines0.append(ax.lines[-1])

        self.txt0 = ax.text(transform=ax.transAxes, **self.txt_kw, fontsize=15)

        # output #########################################################
        ax = axes[1]
        im = ax.imshow(self.Sx, cmap='turbo', animated=True, aspect='auto',
                       interpolation='none')
        # start from an all-zero image; frames copy columns of `Sx` into it
        self.Sx_now = 0 * self.Sx
        im.set_array(self.Sx_now)
        self.ims1 = [im]

        self._format_ticks(ax)
        ax.set_title(r"$|\psi \star x| \star \phi$", **self.title_kw)

        # finalize #######################################################
        fig.subplots_adjust(left=.05, right=.99, bottom=.06, top=.94,
                            wspace=.15)
        animation.TimedAnimation.__init__(self, fig, interval=50, blit=True)

    def _format_ticks(self, ax):
        """Put six time ticks on x and six frequency ticks on y of `ax`."""
        ax.set_xticks(np.linspace(0, self.N, 6, endpoint=True))
        xticklabels = np.linspace(self.t.min(), self.t.max(), 6, endpoint=True)
        ax.set_xticklabels(["%.1f" % xt for xt in xticklabels])
        yticks = np.linspace(0, self.n_psis - 1, 6, endpoint=True)
        ax.set_yticks(yticks)
        ax.set_yticklabels(["%.2f" % (self.freqs[int(yt)] / self.N)
                            for yt in yticks])

    def _draw_frame(self, frame_idx):
        """Move `phi` to the frame's position and reveal the output so far."""
        step = self.stride * frame_idx % self.N
        total_step = self.stride * frame_idx
        psi_idx = int(total_step / self.N)
        # at right bound: the modulo wrapped around, so pin to the last column
        if self.stride != 1 and (step < self.stride - 1 and frame_idx > 1):
            step = self.N
            psi_idx -= 1

        # |CWT|*phi_f ####################################################
        # show only the part of `phi` within half a support of the position
        T = self.ts.phi_f['support']
        start, end = max(0, step - T//2), min(self.N, step + T//2)
        phi_t = 6 - 600*self.phi_t_trim
        phi_t = np.roll(phi_t, step)[start:end]
        self.lines0[0].set_data(np.arange(start, end), phi_t)

        # current shift, expressed on the time axis `t`
        tau = "%.2f" % ((step / self.N) * (self.t.max() - self.t.min()))
        self.txt0.set_text(r"$\phi(t - {}),\ |\psi \star x|$".format(tau))

        # output #########################################################
        self.Sx_now[:, start:step] = self.Sx[:, start:step]
        self.ims1[0].set_array(self.Sx_now)

        # finalize #######################################################
        self._prev_psi_idx = psi_idx
        self._drawn_artists = [*self.lines0, self.txt0, *self.ims1]

    def new_frame_seq(self):
        """Frame indices for one pass over the signal."""
        return iter(range(1 * int(self.N // self.stride + 1)))

    def _init_draw(self):
        pass


# two samples per frame; export, then display
ani = CWTAnimation(ts, x, t, freqs, stride=2)
ani.save('cwt_phi.mp4', fps=60)
plt.show()
161 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/reconstruction.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/reconstruction.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/scat_impulse.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/scat_impulse.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/scat_impulse_resc.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/scat_impulse_resc.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/scat_impulse_unscaled_anim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/scat_impulse_unscaled_anim.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/second_order.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Second-order unaveraged scattering on A.M. cosine and White Gaussian Noise."""
3 | # https://dsp.stackexchange.com/q/78512/50076 ################################
4 | # https://dsp.stackexchange.com/q/78514/50076 ################################
5 | import numpy as np
6 | from numpy.fft import fft, ifft
7 | import matplotlib.pyplot as plt
8 | from kymatio.numpy import Scattering1D
9 | from kymatio.visuals import (plot, imshow,
10 | filterbank_heatmap, filterbank_scattering)
11 |
12 | #%%
def cwt_and_viz(x, ts, show_max_rows=False):
    """Compute unaveraged first- and second-order coefficients and plot them.

    `x` is reflect-padded, filtered with every `ts.psi1_f` (modulus taken),
    then every resulting row is filtered with every `ts.psi2_f` (modulus
    taken), and both results are unpadded. `|CWT(x)|` is shown as one
    heatmap and nine second-order slices as a 3x3 grid. With
    `show_max_rows`, the maximum-energy row of each order is also plotted.

    Relies on the module-level `N` (used as sampling rate) and `t`.
    Returns nothing.
    """
    # do manually to override `j2 > j1` and see all coeffs
    pad_widths = [ts.pad_left, ts.pad_right]
    U0_f = fft(np.pad(x, pad_widths, mode='reflect'))

    # first order: |CWT|
    U1 = np.array([np.abs(ifft(U0_f * p1[0])) for p1 in ts.psi1_f])
    U1_f = fft(U1, axis=-1)

    # second order: every first-order row through every second-order filter
    U2 = np.array([np.abs(ifft(U1_f * p2[0])) for p2 in ts.psi2_f])

    # unpad all
    s, e = ts.ind_start[0], ts.ind_end[0]
    U1 = U1[..., s:e]
    U2 = U2[..., s:e]

    # viz first-order ########################################################
    fs = N
    xi1s = fs * np.array([p['xi'] for p in ts.psi1_f])
    imshow(U1, abs=1, title="|CWT(x)|", yticks=xi1s, ylabel="frequency [Hz]",
           xlabel="time", xticks=t, w=.8, h=.6)

    if show_max_rows:
        # row with the largest energy
        mx_idx = np.argmax(np.sum(U1**2, axis=-1))
        xi1 = xi1s[mx_idx]
        title = "|CWT(x, xi1={:.1f} Hz)| (max row)".format(xi1)
        plot(t, U1[mx_idx], title=title, ylims=(0, None), show=1,
             ylabel="A.M. rate [Hz]", xlabel="time [sec]", w=.7, h=.9)

    # viz second-order #######################################################
    fig, axes = plt.subplots(3, 3, figsize=(16, 16))
    xi2s = fs * np.array([p['xi'] for p in ts.psi2_f])

    # shared color scale across the panels
    mx = np.max(U2) * .9
    # panels show second-order slices 1..9; slice 0 is skipped
    for U2_idx, ax in enumerate(axes.flat, start=1):
        title = "xi2 = {:.1f} Hz".format(xi2s[U2_idx])
        imshow(U2[U2_idx], abs=1, title=title, norm=(0, mx),
               show=0, ax=ax, ticks=0)

    # label only the left column and the bottom row
    label_kw = dict(weight='bold', fontsize=17)
    for row in range(3):
        axes[row, 0].set_ylabel("xi1", **label_kw)
    for col in range(3):
        axes[2, col].set_xlabel("time", **label_kw)

    plt.subplots_adjust(wspace=.02, hspace=.1)
    plt.show()

    if show_max_rows:
        # max-energy second-order slice, then its max-energy row
        mx_idx2 = np.argmax(np.sum(U2**2, axis=(1, 2)))
        mx_idx1 = np.argmax(np.sum(U2[mx_idx2]**2, axis=-1))
        xi1, xi2 = xi1s[mx_idx1], xi2s[mx_idx2]
        title = "|CWT(|CWT(x, xi1={:.1f} Hz)|, xi2={:.1f} Hz)| (max row)".format(
            xi1, xi2)
        U2_max_row = U2[mx_idx2, mx_idx1]

        plot(t, U2_max_row, title=title, show=1, xlabel="time [sec]",
             ylabel="Rate of A.M. rate [Hz^2]",
             ylims=(0, U2_max_row.max() * 1.03))
89 |
#%% Scattering object; only its filters and padding settings are used below
N = 2049
kw = dict(shape=N, J=9, Q=8, max_pad_factor=None, oversampling=999, max_order=1)
ts = Scattering1D(**kw)

#%% Filterbank visuals
filterbank_scattering(ts, second_order=1, zoom=-1)
filterbank_heatmap(ts, second_order=True)

#%%# AM cosine ###############################################################
# carrier at `f0`, amplitude-modulated at `f1`
f0, f1 = 64, 3
t = np.linspace(0, 1, N, 1)
c = np.cos(2*np.pi * f0 * t)
a = (1 + (np.cos(2*np.pi * f1 * t))) / 2
x = a * c

# raw string: `\c` and `\p` are not valid escape sequences in a plain literal
title = r"$\cos(2\pi {} t) \cdot (1 + \cos(2\pi {} t))/2$".format(f0, f1)
plot(t, x, show=1, title=title, xlabel="time [sec]")
#%%
cwt_and_viz(x, ts, show_max_rows=1)

#%%# WGN #####################################################################
np.random.seed(0)
x = np.random.randn(N)
plot(t, x, show=1, title="White Gaussian Noise", xlabel="time [sec]")

#%%
cwt_and_viz(x, ts, show_max_rows=0)
118 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/sine_warp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/sine_warp.png
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/tshift_invar.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Illustrating time-shift invariance."""
3 | # https://dsp.stackexchange.com/q/78512/50076 ################################
4 | # https://dsp.stackexchange.com/q/78514/50076 ################################
5 | import numpy as np
6 | import os
7 | import matplotlib.pyplot as plt
8 | from kymatio.numpy import Scattering1D
9 | from kymatio.visuals import plot, make_gif, plotscat
10 | from kymatio.toolkit import l2
11 | from scipy.signal.windows import tukey
12 |
#%%## Set params & create scattering object ##################################
N = 2048
J, Q = 8, 10
T = 2**J
# unaveraged and heavily oversampled, i.e. a CWT
average, oversampling = False, 999
ts = Scattering1D(shape=N, J=J, Q=Q, T=T, max_order=1, out_type='list',
                  average=average, oversampling=oversampling,
                  max_pad_factor=None)
meta = ts.meta()

#%%# Create signal & shift it ################################################
f = 128
width = N//4
shift = width // 4

# Tukey window of length `width`, zero-padded on both sides
window = np.pad(tukey(width), (N - width)//2)
t = np.linspace(0, 1, N, endpoint=False)
x = window * np.cos(2*np.pi * f * t)

# two copies, rolled in opposite directions
x0 = np.roll(x, -shift//2)
x1 = np.roll(x, shift//2)

#%%# CWT
_Scx0, _Scx1 = ts(x0), ts(x1)
# stack the coefficient list into a 2D array, keeping first-order rows only
Scx0, Scx1 = (np.vstack([c['coef'] for c in S])[meta['order'] == 1]
              for S in (_Scx0, _Scx1))
freqs = N * meta['xi'][meta['order'] == 1][:, 0]

x_all = [x0, x1]
cwt_all = [Scx0, Scx1]
44 |
#%% make scattering objects
# two averaging scales: a short one and one four times longer
T0 = 2**(J - 2) * 2
T1 = 2**J * 2
kw = dict(shape=N, J=J, Q=Q, average=True,
          out_type='list', max_order=1, max_pad_factor=3)
ts0, ts1 = (Scattering1D(**kw, T=T_avg, oversampling=0)
            for T_avg in (T0, T1))
meta = ts0.meta()

#%% scatter
# naming: first digit = scattering object (T0/T1), second digit = signal
_Scx00, _Scx01, _Scx10, _Scx11 = ts0(x0), ts0(x1), ts1(x0), ts1(x1)
# stack each coefficient list into a 2D array, keeping first-order rows only
Scx00, Scx01, Scx10, Scx11 = (
    np.vstack([c['coef'] for c in S])[meta['order'] == 1]
    for S in (_Scx00, _Scx01, _Scx10, _Scx11))
freqs = N * meta['xi'][meta['order'] == 1][:, 0]

x_all = [x0, x1]
# ordered so that consecutive pairs share the input signal
Scx_all = [Scx00, Scx10, Scx01, Scx11]
64 |
#%%# Time-shift GIF ##########################################################
# configure
base_name = "tshift"
images_ext = ".png"
savedir = r"C:\Desktop\School\Deep Learning\DL_Code\signals\viz_gen"
overwrite = True

# common plot kwargs
title_kw = dict(weight='bold', fontsize=20, loc='left')
label_kw = dict(weight='bold', fontsize=18)
imshow_kw = dict(cmap='turbo', aspect='auto', vmin=0)
# color limits shared by both frames so the images are comparable
vmax00 = np.array(cwt_all).max()*.95
vmax0 = max(Scx00.max(), Scx01.max())
vmax1 = max(Scx10.max(), Scx11.max())

# one frame per signal: i=0 -> x0, i=2 -> x1 (`i` also indexes `Scx_all`)
for i in (0, 2):
    fig, axes = plt.subplots(2, 2, figsize=(18, 14))

    # cwt ####################################################################
    # top-left: the signal itself
    ax = axes[0, 0]
    ax.plot(t, x_all[i//2])
    txt = "" if i == 0 else " - %.2f" % (shift/N)
    ax.set_title("x(t%s)" % txt, **title_kw)
    # five evenly spaced ticks plus the last sample
    xticks = np.array([*np.linspace(0, N, 6, endpoint=True)[:-1], N - 1])
    xticklabels = ['%.2f' % tk for tk in t[np.round(xticks).astype(int)]]
    ax.set_xticks(xticks / N)
    ax.set_xticklabels(xticklabels, fontsize=14)
    ax.set_xlim(-.03, 1.03)
    ax.set_yticks([])

    # top-right: its CWT modulus
    ax = axes[0, 1]
    ax.imshow(cwt_all[i//2], **imshow_kw, vmax=vmax00)
    ax.set_title("|CWT(x)|", **title_kw)
    ax.set_yticks([])
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels, fontsize=14)

    # scattering #############################################################
    Scx0, Scx1 = Scx_all[i], Scx_all[i + 1]

    # bottom row: scattering with T0 (left) and with T1 (right)
    for ax, Scx, vmax, T_avg in zip(axes[1], (Scx0, Scx1), (vmax0, vmax1),
                                    (T0, T1)):
        ax.imshow(Scx, **imshow_kw, vmax=vmax)
        ax.set_title("S(x) | T=%.2f sec" % (T_avg / N), **title_kw)
        ax.set_xticks([])
        ax.set_yticks([])

    # finalize
    plt.subplots_adjust(hspace=.15, wspace=.04)

    # save; an existing image is replaced only when `overwrite` is set
    path = os.path.join(savedir, f'{base_name}{i}{images_ext}')
    if overwrite and os.path.isfile(path):
        os.unlink(path)
    if not os.path.isfile(path):
        fig.savefig(path, bbox_inches='tight')
    plt.close()

# make into gif
savepath = os.path.join(savedir, f'{base_name}.gif')
make_gif(loaddir=savedir, savepath=savepath, ext=images_ext, duration=750,
         overwrite=overwrite, delimiter=base_name, verbose=1, start_end_pause=0,
         delete_images=True)
140 |
141 | #%%# plot superimposed #######################################################
def plot_superimposed(g0, g1, T, ax):
    """Scatter-plot `g0` and `g1` on `ax`, joining matching samples.

    A dashed red segment connects `g0[k]` and `g1[k]` at every index `k`.
    The title shows `T / N` (module-level `N`) and the `l2` distance of the
    two sequences.
    """
    # vertical connector at every sample index
    for t_idx in np.arange(len(g0)):
        plot([t_idx, t_idx], [g0[t_idx], g1[t_idx]],
             color='tab:red', linestyle='--', ax=ax)

    reldist = float(l2(g0, g1))
    title = ("T={:.2f} | reldist={:.2f}".format(T / N, reldist),
             {'fontsize': 20})
    plotscat(g0, title=title, ax=ax)

    # axis limits with a small margin around the data
    n_pts = len(g0)
    top = max(g0.max(), g1.max())
    plotscat(g1, xlims=(-n_pts/40, 1.03*(n_pts - 1)),
             ylims=(0, 1.03*top), ax=ax)
153 |
# row with the largest summed coefficients
mx_idx = np.argmax(Scx00.sum(axis=-1))

# compare that row between the two shifted signals, for each averaging scale
fig, axes = plt.subplots(1, 2, figsize=(18, 7))
plot_superimposed(Scx00[mx_idx], Scx01[mx_idx], T0, axes[0])
plot_superimposed(Scx10[mx_idx], Scx11[mx_idx], T1, axes[1])

plt.subplots_adjust(wspace=.09)
plt.show()

#%% Reldist vs T #############################################################
kw = dict(shape=N, J=6, Q=8, average=True,
          out_type='array', max_order=1, max_pad_factor=4)
# implementation switches to simple average with unpad which
# detracts from curve slightly in log space, so do N-1 instead of N
T_all = [*range(1, N + 1, 16), N - 1]
ts_all = [Scattering1D(**kw, T=T) for T in T_all]

#%%
# scatter both shifted signals at every T, then measure their distance
Scx0_all = [ts(x0) for ts in ts_all]
Scx1_all = [ts(x1) for ts in ts_all]
reldist_all = [l2(S0, S1).squeeze() for S0, S1 in zip(Scx0_all, Scx1_all)]

#%% plot
T_all_sec = np.array(T_all) / N

fig, axes = plt.subplots(1, 2, figsize=(18, 7))
# left: linear axes, with a dashed reference line at zero
zero_line = (0, dict(color='tab:red', linestyle='--', linewidth=1))
plot(T_all_sec, reldist_all, ax=axes[0],
     hlines=zero_line,
     title=("Scattering coefficient relative distance", {'fontsize': 20}),
     xlabel="T [sec]", ylabel="reldist")

# right: the same curve on log-scaled axes
T_all_sec_log = np.log2(T_all_sec)
reldist_all_log = np.log10(reldist_all)
plot(T_all_sec_log, reldist_all_log, ax=axes[1],
     title=("logscaled", {'fontsize': 20}),
     xlabel="log2(T) [sec]", ylabel="log10(reldist)")

plt.subplots_adjust(wspace=.1)
plt.show()
193 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_cwt_anim.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Sinusoid and echirp warping CWT GIF."""
3 | # https://dsp.stackexchange.com/q/78512/50076 ################################
4 | # https://dsp.stackexchange.com/q/78514/50076 ################################
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import matplotlib.animation as animation
8 | from kymatio.torch import Scattering1D
9 | from kymatio.visuals import imshow
10 | from kymatio.toolkit import _echirp_fn
11 |
def tau(t, K=.08):
    """Time-warp function: `K * cos(2*pi * t**2)`."""
    phase = 2*np.pi * t**2
    return np.cos(phase) * K
14 |
def adtau(t, K=.08):
    """Absolute value of the derivative of `K * cos(2*pi * t**2)` w.r.t. `t`."""
    phase = 2*np.pi * t**2
    # sign of the true derivative is dropped by the `abs` below
    slope = np.sin(phase) * t * 4*np.pi*K
    return np.abs(slope)
17 |
18 | import torch
# `bool()` of a non-empty string ('cuda' or 'cpu') is always True, which would
# call `.cuda()` on CPU-only machines; use the availability flag directly.
USE_GPU = torch.cuda.is_available()
20 |
21 | #%% Create scattering object #################################################
22 | N, f = 2048, 64
23 | J, Q = 8, 8
24 | T = 512
25 | # increase freq width to improve temporal localization
26 | # to not confuse ill low-freq localization with warp effects, for echirp
27 | ts = Scattering1D(shape=N, J=J, Q=Q, average=False, oversampling=999,
28 | T=T, out_type='list', r_psi=.82)
29 | if USE_GPU:
30 | ts.cuda()
31 |
32 | #%%# configure time-warps ####################################################
33 | K_init = .01
34 | n_pts = 64
35 | t = np.linspace(0, 1, N, 0)
36 |
37 | p = ts.psi1_f[0]
38 | QF = p['xi'] / p['sigma']
39 |
40 | adtau_init_max = adtau(t, K=K_init).max()
41 | K_min = (1/QF) / 10
42 | K_max = (1 / adtau_init_max) * K_init
43 |
44 | K_all = np.logspace(np.log10(K_min), np.log10(K_max), n_pts - 1, 1)
45 | K_all = np.hstack([0, K_all])
46 | tau_all = np.vstack([tau(t, K=k) for k in K_all])
47 | adtau_max_all = np.vstack([adtau(t, K=k).max() for k in K_all])
48 |
49 | assert adtau_max_all.max() <= 1, adtau_max_all.max()
50 |
51 | #%% Animate #################################################################
class PlotImshowAnimation(animation.TimedAnimation):
    """Two-panel animation: a line plot (left) next to a heatmap (right).

    Frame `i` shows `plot_frames[i]` as the line and `imshow_frames[i]` as
    the image; the text above the left panel reports `adtau_max_all[i]`
    (frame 0 is labeled as the zero-warp case).
    """

    def __init__(self, plot_frames, imshow_frames, t, freqs, adtau_max_all):
        self.plot_frames, self.imshow_frames = plot_frames, imshow_frames
        self.t, self.freqs = t, freqs
        self.adtau_max_all = adtau_max_all

        self.N = len(t)
        self.n_rows = len(imshow_frames[0])
        self.n_frames = len(imshow_frames)

        self.title_kw = {'weight': 'bold', 'fontsize': 15, 'loc': 'left'}
        self.label_kw = {'weight': 'bold', 'fontsize': 14, 'labelpad': 3}
        self.txt_kw = {'x': 0, 'y': 1.015, 's': "", 'ha': "left"}

        fig, (ax_line, ax_img) = plt.subplots(1, 2, figsize=(10.6, 4.7))

        # left panel: line plot of the first frame ###########################
        self.lines0 = [ax_line.plot(plot_frames[0])[-1]]
        self._set_time_ticks(ax_line)
        self.txt0 = ax_line.text(transform=ax_line.transAxes, fontsize=15,
                                 **self.txt_kw)

        # right panel: heatmap of the first frame ############################
        first_im = ax_img.imshow(self.imshow_frames[0], cmap='turbo',
                                 animated=True, aspect='auto')
        self.ims1 = [first_im]
        self._set_time_ticks(ax_img)

        # six row ticks, labeled with `freqs[row] / N`
        row_ticks = np.linspace(0, self.n_rows - 1, 6, endpoint=True)
        ax_img.set_yticks(row_ticks)
        ax_img.set_yticklabels(["%.2f" % (self.freqs[int(row)] / self.N)
                                for row in row_ticks])
        ax_img.set_title("|CWT(x(t))|", **self.title_kw)

        # finalize ###########################################################
        fig.subplots_adjust(left=.05, right=.99, bottom=.06, top=.94,
                            wspace=.15)
        super().__init__(fig, interval=50, blit=True)

    def _set_time_ticks(self, ax):
        """Put six x-ticks on `ax`, labeled with values spanning `self.t`."""
        ax.set_xticks(np.linspace(0, self.N, 6, endpoint=True))
        tick_values = np.linspace(self.t.min(), self.t.max(), 6,
                                  endpoint=True)
        ax.set_xticklabels(["%.1f" % value for value in tick_values])

    def _draw_frame(self, frame_idx):
        """Update line, caption and image to show frame `frame_idx`."""
        self.lines0[0].set_ydata(self.plot_frames[frame_idx])

        if frame_idx != 0:
            caption = (r"$x(t) \ ... \ |\tau'(t)| < %.3f$"
                       % self.adtau_max_all[frame_idx])
        else:
            caption = r"$x(t) \ ... \ |\tau'(t)| = 0$"
        self.txt0.set_text(caption)

        self.ims1[0].set_array(self.imshow_frames[frame_idx])

        # artists to redraw when blitting
        self._drawn_artists = [*self.lines0, self.txt0, *self.ims1]

    def new_frame_seq(self):
        """Iterate over frame indices `0 .. n_frames - 1`."""
        return iter(range(self.n_frames))

    def _init_draw(self):
        """No-op: everything is already drawn in `__init__`."""
        pass
126 |
127 | #%%# Run on each signal ######################################################
def extend(x):
    """Return `x` as an array with its last entry repeated 6 more times."""
    tail = [x[-1]] * 6
    return np.array([*x, *tail])
130 |
def run(x_all, adtau_max_all, ts, name):
    """Transform every signal in `x_all` with `ts`, animate, and save.

    The first-order rows of each output are stacked into one array of
    frames, all frame sequences are padded via `extend`, and the animation
    is written to `warp_cwt_{name}.mp4`. Uses module-level `N` and `t`.
    """
    meta = ts.meta()
    is_order1 = meta['order'] == 1
    freqs = N * meta['xi'][is_order1][:, 0]

    outputs = [ts(signal) for signal in x_all]
    stacked = []
    for output in outputs:
        # each entry holds a tensor under 'coef'; join all rows, then keep
        # only the first-order ones and add a leading frame axis
        coeffs = np.vstack([entry['coef'].cpu().numpy() for entry in output])
        stacked.append(coeffs[is_order1][None])
    Scx_all = np.vstack(stacked)

    # animate
    plot_frames = extend(x_all)
    imshow_frames = extend(Scx_all)
    adtau_frames = extend(adtau_max_all)

    ani = PlotImshowAnimation(plot_frames, imshow_frames, t, freqs,
                              adtau_frames)
    ani.save(f'warp_cwt_{name}.mp4', fps=10)
    plt.show()
149 |
150 | #%% echirp
151 | _t = (t - tau_all)
152 | a0 = np.cos(2*np.pi * 3.7 * _t)
153 | x_all = np.cos(_echirp_fn(fmin=40, fmax=N/8)(_t)) * a0
154 | #%%
155 | run(x_all, adtau_max_all, ts, "echirp")
156 | #%% visualize warps for sine case with lower freq
157 | def _tt(title):
158 | return (title, {'fontsize': 20})
159 |
160 | fig, axes = plt.subplots(1, 2, figsize=(18, 7))
161 | x_all_viz = x_all#np.cos(2*np.pi * 64 * (t - tau_all))
162 | imshow(x_all_viz, title=_tt("x(t - tau(t)) for all tau"), show=0, ax=axes[0],
163 | xlabel="time", ylabel="max(|tau'|)", yticks=adtau_max_all)
164 | imshow(tau_all, title=_tt("all tau(t)"), show=0, ax=axes[1],
165 | xlabel="time", yticks=0)
166 |
167 | for ax in axes:
168 | ax.set_xticks(np.linspace(0, len(t), 6))
169 | ax.set_xticklabels([0, .2, .4, .6, .8, 1])
170 |
171 | plt.subplots_adjust(wspace=.03)
172 | fig.savefig("sine_warp.png", bbox_inches='tight')
173 | plt.show()
174 | plt.close(fig)
175 |
176 | #%% sine
177 | x_all = np.cos(2*np.pi * f * (t - tau_all))
178 | run(x_all, adtau_max_all, ts, "sine")
179 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_cwt_echirp.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_cwt_echirp.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_cwt_sine.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_cwt_sine.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_T.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_T.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_T_anim.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Scattering unwarped vs warped echirp GIF, sweeping `T`."""
3 | # https://dsp.stackexchange.com/q/78512/50076 ################################
4 | # https://dsp.stackexchange.com/q/78514/50076 ################################
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import matplotlib.animation as animation
8 | from kymatio.torch import Scattering1D
9 | from kymatio.visuals import plot
10 | from kymatio.toolkit import _echirp_fn, l2
11 |
def tau(t, K=.08):
    """Time-warp function: `K * cos(2*pi * t**2)`."""
    warp_phase = 2*np.pi * t**2
    return np.cos(warp_phase) * K
14 |
def adtau(t, K=.08):
    """Absolute value of the derivative of `K * cos(2*pi * t**2)` w.r.t. `t`."""
    warp_phase = 2*np.pi * t**2
    # sign of the true derivative is dropped by the `abs` below
    slope = np.sin(warp_phase) * t * 4*np.pi*K
    return np.abs(slope)
17 |
18 | import torch
# `bool()` of a non-empty string ('cuda' or 'cpu') is always True, which would
# call `.cuda()` on CPU-only machines; use the availability flag directly.
USE_GPU = torch.cuda.is_available()
20 |
21 | #%%###########################################################################
22 | N = 2048
23 | J, Q = 8, 8
# `N//8` values of T, log2-spaced from 1 to N - 1 (inclusive)
T_all = np.logspace(np.log2(1), np.log2(N - 1), N//8, base=2, endpoint=True)
26 |
27 | kw = dict(shape=N, J=J, Q=Q, average=True, oversampling=999, out_type='array',
28 | max_order=1, max_pad_factor=None, r_psi=.82)
29 | ts_all = [Scattering1D(**kw, T=T) for T in T_all]
30 | meta_all = [ts.meta() for ts in ts_all]
31 |
32 | if USE_GPU:
33 | for ts in ts_all:
34 | ts.cuda()
35 |
36 | #%%###########################################################################
37 | meta = meta_all[0]
38 | freqs = N * meta['xi'][meta['order'] == 1][:, 0]
39 | #%%
40 | t = np.linspace(0, 1, N, 0)
41 | _tau = tau(t, K=.012)
42 | a0 = np.cos(2*np.pi * 3.7 * (t - _tau))
43 | x0 = np.cos(_echirp_fn(fmin=40, fmax=N/8)(t)) * a0
44 | x1 = np.cos(_echirp_fn(fmin=40, fmax=N/8)(t - _tau)) * a0
45 | #%%
46 | Scx_all0 = np.array([ts(x0)[1:].cpu().numpy() for ts in ts_all])
47 | Scx_all1 = np.array([ts(x1)[1:].cpu().numpy() for ts in ts_all])
48 | #%%
49 | dists = np.array([l2(S0, S1) for S0, S1 in zip(Scx_all0, Scx_all1)]).squeeze()
50 | plot(T_all, dists, ylims=(0, None), show=1)
51 |
52 | #%% extend frames
def extend(x):
    """Return `x` as an array padded with 6 copies of each end entry."""
    head, tail = x[0], x[-1]
    return np.array([head] * 6 + list(x) + [tail] * 6)
55 |
56 | Scx_all0, Scx_all1, dists, T_all = [extend(h) for h in
57 | (Scx_all0, Scx_all1, dists, T_all)]
58 |
59 | #%% Animate #################################################################
class PlotImshowAnimation(animation.TimedAnimation):
    """Animate two heatmaps side by side above a progressively drawn curve.

    Frame `i` shows `imshow_frames0[i]` (left) and `imshow_frames1[i]`
    (right); the bottom panel reveals `plot_frame` against
    `log2(T_all / N)` up to and including point `i`. `N` is read from
    module scope.
    """

    def __init__(self, imshow_frames0, imshow_frames1, plot_frame, T_all, vmax):
        """Build the figure and draw the first frame.

        `vmax` is the shared upper color limit of both heatmaps (lower
        limit is 0); `T_all[i]` is the value reported in frame `i`'s titles.
        """
        self.imshow_frames0 = imshow_frames0
        self.imshow_frames1 = imshow_frames1
        self.plot_frame = plot_frame
        self.T_all = T_all
        self.vmax = vmax

        self.N = imshow_frames0[0].shape[-1]
        # x-coordinates of the bottom curve; uses the module-level `N`
        self.T_all_sec_log = np.log2(self.T_all / N)
        self.n_frames = len(imshow_frames0)

        self.label_kw = dict(weight='bold', fontsize=15, labelpad=3)
        self.txt_kw = dict(x=0, y=1.015, s="", ha="left")

        fig = plt.figure(figsize=(18/1.5, 15/1.8))
        # two heatmaps in the first row of a 2x2 grid; the curve spans the
        # whole second row of a 3x2 grid
        ax0 = fig.add_subplot(2, 2, 1)
        ax1 = fig.add_subplot(2, 2, 2)
        ax2 = fig.add_subplot(3, 2, (3, 4))
        axes = [ax0, ax1, ax2]

        # imshow0 ############################################################
        ax = axes[0]
        imshow_kw = dict(cmap='turbo', aspect='auto', animated=True,
                         vmin=0, vmax=self.vmax)
        im = ax.imshow(self.imshow_frames0[0], **imshow_kw)
        self.ims0 = [im]

        self.txt0 = ax.text(transform=ax.transAxes, **self.txt_kw, fontsize=18)
        ax.set_xticks([]); ax.set_yticks([])

        # imshow1 ############################################################
        ax = axes[1]
        im = ax.imshow(self.imshow_frames1[0], **imshow_kw)
        self.ims1 = [im]

        self.txt1 = ax.text(transform=ax.transAxes, **self.txt_kw, fontsize=18)
        ax.set_xticks([]); ax.set_yticks([])

        # plot ###############################################################
        ax = axes[2]
        # plot the full curve once so the x-limits are autoscaled to it,
        # then shrink the line to its first point
        line = ax.plot(self.T_all_sec_log, self.plot_frame)[0]
        # length-1 slices rather than scalars: newer Matplotlib requires
        # sequences in `set_data`
        line.set_data(self.T_all_sec_log[:1], self.plot_frame[:1])
        self.lines0 = [line]

        ax.set_ylim(0, np.max(self.plot_frame)*1.04)
        ax.set_xlabel('log2(T) [sec]', **self.label_kw)
        ax.set_ylabel('reldist', **self.label_kw)

        # finalize ###########################################################
        fig.subplots_adjust(left=.05, right=.99, bottom=-.49, top=.96,
                            wspace=.02, hspace=.7)
        super().__init__(fig, interval=50, blit=True)

    def _draw_frame(self, frame_idx):
        """Update both heatmaps, their titles and the curve for `frame_idx`."""
        def _txt(txt_idx, T_sec):
            # closing parenthesis of `x(...)` was missing, which rendered an
            # unbalanced `S_1(x(t - tau(t))`
            fill = (r"x(t)" if txt_idx == 0 else
                    r"x(t - \tau(t))")
            if frame_idx == 0:
                txt = r"$S_1(%s)$ | unaveraged" % fill
            else:
                txt = r"$S_1(%s)\ |\ T=%.3f$ sec " % (fill, T_sec)
            return txt

        T_sec = self.T_all[frame_idx] / N
        # imshow0 ############################################################
        self.ims0[0].set_array(self.imshow_frames0[frame_idx])
        self.txt0.set_text(_txt(0, T_sec))

        # imshow1 ############################################################
        self.ims1[0].set_array(self.imshow_frames1[frame_idx])
        self.txt1.set_text(_txt(1, T_sec))

        # plot ###############################################################
        self.lines0[0].set_data(self.T_all_sec_log[:frame_idx + 1],
                                self.plot_frame[:frame_idx + 1])

        # finalize ###########################################################
        # artists to redraw when blitting
        self._drawn_artists = [*self.ims0, self.txt0, *self.ims1, self.txt1,
                               *self.lines0]

    def new_frame_seq(self):
        """Iterate over frame indices `0 .. n_frames - 1`."""
        return iter(range(self.n_frames))

    def _init_draw(self):
        """No-op: everything is already drawn in `__init__`."""
        pass
146 |
147 | imshow_frames0, imshow_frames1 = Scx_all0, Scx_all1
148 | plot_frame = dists
149 | vmax = max(Scx_all0.max(), Scx_all1.max())*.98
150 |
151 | ani = PlotImshowAnimation(imshow_frames0, imshow_frames1, plot_frame, T_all,
152 | vmax)
153 | ani.save('warp_scat_T.mp4', fps=15, savefig_kwargs=dict(pad_inches=0))
154 | plt.show()
155 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_anim.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Scattering unwarped vs warped echirp GIF."""
3 | # https://dsp.stackexchange.com/q/78512/50076 ################################
4 | # https://dsp.stackexchange.com/q/78514/50076 ################################
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import matplotlib.animation as animation
8 | from kymatio.torch import Scattering1D
9 | from kymatio.visuals import plot
10 | from kymatio.toolkit import _echirp_fn, l2
11 |
def tau(t, K=.08):
    """Time-warp function: `K * cos(2*pi * t**2)`."""
    angle = 2*np.pi * t**2
    return np.cos(angle) * K
14 |
def adtau(t, K=.08):
    """Absolute value of the derivative of `K * cos(2*pi * t**2)` w.r.t. `t`."""
    angle = 2*np.pi * t**2
    # sign of the true derivative is dropped by the `abs` below
    slope = np.sin(angle) * t * 4*np.pi*K
    return np.abs(slope)
17 |
18 | import torch
# `bool()` of a non-empty string ('cuda' or 'cpu') is always True, which would
# call `.cuda()` on CPU-only machines; use the availability flag directly.
USE_GPU = torch.cuda.is_available()
20 |
21 | #%%###########################################################################
22 | N, f = 2048, 64
23 |
24 | #%%###########################################################################
25 | J, Q = 8, 8
26 | T = 512
# `N//8` values of T, log2-spaced from 1 to N - 1 (inclusive)
T_all = np.logspace(np.log2(1), np.log2(N - 1), N//8, base=2, endpoint=True)
29 |
30 | kw = dict(shape=N, J=J, Q=Q, average=True, oversampling=999, out_type='array',
31 | max_order=1, max_pad_factor=None)
32 | ts_all = [Scattering1D(**kw, T=T) for T in T_all]
33 | meta_all = [ts.meta() for ts in ts_all]
34 |
35 | if USE_GPU:
36 | for ts in ts_all:
37 | ts.cuda()
38 |
39 | #%%###########################################################################
40 | meta = meta_all[0]
41 | freqs = N * meta['xi'][meta['order'] == 1][:, 0]
42 | #%%
43 | t = np.linspace(0, 1, N, 0)
44 | _tau = tau(t, K=.012)
45 | x0 = np.cos(_echirp_fn(fmin=32, fmax=N/8)(t))
46 | x1 = np.cos(_echirp_fn(fmin=32, fmax=N/8)(t - _tau))
47 | #%%
48 | Scx_all0 = np.array([ts(x0)[1:].cpu().numpy() for ts in ts_all])
49 | Scx_all1 = np.array([ts(x1)[1:].cpu().numpy() for ts in ts_all])
50 | #%%
51 | dists = np.array([l2(S0, S1) for S0, S1 in zip(Scx_all0, Scx_all1)]).squeeze()
52 | plot(T_all, dists, ylims=(0, None), show=1)
53 |
54 | #%% extend
def extend(x):
    """Return `x` as an array padded with 6 copies of each end entry."""
    first, last = x[0], x[-1]
    return np.array([first] * 6 + list(x) + [last] * 6)
57 |
58 | Scx_all0, Scx_all1, dists, T_all = [extend(h) for h in
59 | (Scx_all0, Scx_all1, dists, T_all)]
60 |
61 | #%% Animate #################################################################
class PlotImshowAnimation(animation.TimedAnimation):
    """Animate two heatmaps side by side above a progressively drawn curve.

    Frame `i` shows `imshow_frames0[i]` (left) and `imshow_frames1[i]`
    (right); the bottom panel reveals `plot_frame` against
    `log2(T_all / N)` up to and including point `i`. `N` is read from
    module scope.
    """

    def __init__(self, imshow_frames0, imshow_frames1, plot_frame, T_all, vmax):
        """Build the figure and draw the first frame.

        `vmax` is the shared upper color limit of both heatmaps (lower
        limit is 0); `T_all[i]` is the value reported in frame `i`'s titles.
        """
        self.imshow_frames0 = imshow_frames0
        self.imshow_frames1 = imshow_frames1
        self.plot_frame = plot_frame
        self.T_all = T_all
        self.vmax = vmax

        self.N = imshow_frames0[0].shape[-1]
        # x-coordinates of the bottom curve; uses the module-level `N`
        self.T_all_sec_log = np.log2(self.T_all / N)
        self.n_frames = len(imshow_frames0)

        self.label_kw = dict(weight='bold', fontsize=15, labelpad=3)
        self.txt_kw = dict(x=0, y=1.015, s="", ha="left")

        fig = plt.figure(figsize=(18/1.5, 15/1.8))
        # two heatmaps in the first row of a 2x2 grid; the curve spans the
        # whole second row of a 3x2 grid
        ax0 = fig.add_subplot(2, 2, 1)
        ax1 = fig.add_subplot(2, 2, 2)
        ax2 = fig.add_subplot(3, 2, (3, 4))
        axes = [ax0, ax1, ax2]

        # imshow0 ############################################################
        ax = axes[0]
        imshow_kw = dict(cmap='turbo', aspect='auto', animated=True,
                         vmin=0, vmax=self.vmax)
        im = ax.imshow(self.imshow_frames0[0], **imshow_kw)
        self.ims0 = [im]

        self.txt0 = ax.text(transform=ax.transAxes, **self.txt_kw, fontsize=18)
        ax.set_xticks([]); ax.set_yticks([])

        # imshow1 ############################################################
        ax = axes[1]
        im = ax.imshow(self.imshow_frames1[0], **imshow_kw)
        self.ims1 = [im]

        self.txt1 = ax.text(transform=ax.transAxes, **self.txt_kw, fontsize=18)
        ax.set_xticks([]); ax.set_yticks([])

        # plot ###############################################################
        ax = axes[2]
        # plot the full curve once so the x-limits are autoscaled to it,
        # then shrink the line to its first point
        line = ax.plot(self.T_all_sec_log, self.plot_frame)[0]
        # length-1 slices rather than scalars: newer Matplotlib requires
        # sequences in `set_data`
        line.set_data(self.T_all_sec_log[:1], self.plot_frame[:1])
        self.lines0 = [line]

        ax.set_ylim(0, np.max(self.plot_frame)*1.04)
        ax.set_xlabel('log2(T) [sec]', **self.label_kw)
        ax.set_ylabel('reldist', **self.label_kw)

        # finalize ###########################################################
        fig.subplots_adjust(left=.05, right=.99, bottom=-.49, top=.96,
                            wspace=.02, hspace=.7)
        super().__init__(fig, interval=50, blit=True)

    def _draw_frame(self, frame_idx):
        """Update both heatmaps, their titles and the curve for `frame_idx`."""
        def _txt(txt_idx, T_sec):
            # closing parenthesis of `x(...)` was missing, which rendered an
            # unbalanced `S_1(x(t - tau(t))`
            fill = (r"x(t)" if txt_idx == 0 else
                    r"x(t - \tau(t))")
            if frame_idx == 0:
                txt = r"$S_1(%s)$ | unaveraged" % fill
            else:
                txt = r"$S_1(%s)\ |\ T=%.3f$ sec " % (fill, T_sec)
            return txt

        T_sec = self.T_all[frame_idx] / N
        # imshow0 ############################################################
        self.ims0[0].set_array(self.imshow_frames0[frame_idx])
        self.txt0.set_text(_txt(0, T_sec))

        # imshow1 ############################################################
        self.ims1[0].set_array(self.imshow_frames1[frame_idx])
        self.txt1.set_text(_txt(1, T_sec))

        # plot ###############################################################
        # include the current point (`+ 1`): otherwise frame 0 blanks the
        # line drawn in `__init__` and the final point is never shown
        self.lines0[0].set_data(self.T_all_sec_log[:frame_idx + 1],
                                self.plot_frame[:frame_idx + 1])

        # finalize ###########################################################
        # artists to redraw when blitting
        self._drawn_artists = [*self.ims0, self.txt0, *self.ims1, self.txt1,
                               *self.lines0]

    def new_frame_seq(self):
        """Iterate over frame indices `0 .. n_frames - 1`."""
        return iter(range(self.n_frames))

    def _init_draw(self):
        """No-op: everything is already drawn in `__init__`."""
        pass
148 |
149 | imshow_frames0, imshow_frames1 = Scx_all0, Scx_all1
150 | plot_frame = dists
151 | vmax = max(Scx_all0.max(), Scx_all1.max())*.98
152 |
153 | ani = PlotImshowAnimation(imshow_frames0, imshow_frames1, plot_frame, T_all,
154 | vmax)
155 | ani.save('warp_scat.mp4', fps=15, savefig_kwargs=dict(pad_inches=0))
156 | plt.show()
157 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_tau.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_tau.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_tau_echirp.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_tau_echirp.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_tau_pulse.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_tau_pulse.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/data_3d_anim.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Visualize JTFS of real data: as GIF of 3D slices unfolded over time."""
3 | # https://dsp.stackexchange.com/q/78622/50076 ################################
4 | import numpy as np
5 | import torch
6 | import librosa
7 | from timeit import default_timer as dtime
8 |
9 | from kymatio import TimeFrequencyScattering1D
10 | from kymatio.toolkit import pack_coeffs_jtfs, jtfs_to_numpy, normalize
11 | from kymatio.visuals import gif_jtfs_3d, make_gif
12 |
13 | #%% load data ################################################################
14 | x, sr = librosa.load(librosa.ex('trumpet'))
15 | x = x[:81920]
16 |
17 | #%% create scattering object, move to GPU ####################################
18 | N = len(x)
19 | J = int(np.log2(N) - 3)
20 | Q = (16, 1)
21 | Q_fr = 2
22 | J_fr = 4
23 | T = 2**(J - 4)
24 | F = 4
25 |
26 | jtfs = TimeFrequencyScattering1D(shape=N, J=J, J_fr=J_fr, Q=Q, Q_fr=Q_fr,
27 | T=T, F=F, average_fr=True,
28 | max_pad_factor=None, max_pad_factor_fr=None,
29 | out_type='dict:array', frontend='torch')
30 | jmeta = jtfs.meta()
31 | for a in ('N_frs', 'J_pad_frs'):
32 | print(getattr(jtfs, a), '--', a)
33 | jtfs.cuda()
34 | torch.cuda.empty_cache()
35 |
36 | #%% scatter ##################################################################
37 | t0 = dtime()
38 | Scxt = jtfs(x)
39 | Scx = jtfs_to_numpy(Scxt)
40 |
41 | #%% pack
42 | packed = pack_coeffs_jtfs(Scx, jmeta, structure=2, separate_lowpass=True,
43 | sampling_psi_fr=jtfs.sampling_psi_fr)
44 | packed_spinned = packed[0]
45 | packed_spinned = packed_spinned.transpose(-1, 0, 1, 2)
46 |
47 | #%% normalize
48 | s = packed_spinned.shape
49 | packed_spinned = normalize(packed_spinned.reshape(1, s[0], -1)).reshape(*s)
50 |
51 | #%% make smooth camera transition ############################################
52 | packed_viz = packed_spinned
53 | print(packed_viz.shape)
54 | n_pts = len(packed_viz)
55 | extend_edge = int(.35 * n_pts)
56 |
def gauss(n_pts, mn, mx, width=20):
    """Gaussian bump sampled at `n_pts` points of [0, 1], centered at 0.5.

    The unit-height bump `exp(-width * (t - .5)**2)` is scaled by
    `mx - mn` and offset by `mn`, so the value at the center is `mx`.
    """
    grid = np.linspace(0, 1, n_pts)
    bump = np.exp(-(grid - .5)**2 * width)
    return bump * (mx - mn) + mn
63 |
# camera coordinates per frame: each follows a `gauss` bump between the given
# (mn, mx) pair; z is constant since mn == mx
x, y, z = [gauss(n_pts, mn, mx) for (mn, mx)
           in [(2.5, 8.5), (0.3, 6.3), (2, 2)]]
70 |
71 | eyes = np.vstack([x, y, z]).T
72 | assert len(x) == len(packed_viz), (len(x), len(packed_viz))
73 |
74 | #%% Make gif ################################################################
75 | t0 = dtime()
76 | gif_jtfs_3d(packed_viz, base_name='jtfs3d_trumpet', images_ext='.png',
77 | overwrite=1, save_images=1, angles=eyes, gif_kw=dict(duration=50))
78 | print(dtime() - t0)
79 |
80 | #%% remake if needed
81 | # make_gif('', 'etc.gif', duration=33, delimiter='seiz3d', ext='.jpg',
82 | # overwrite=1, HD=True)
83 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/echirp_2d.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Visualize JTFS of exponential chirp: as GIF of coefficients
4 | superimposed with wavelets.
5 | """
6 | # https://dsp.stackexchange.com/q/78622/50076 ################################
7 | import numpy as np
8 | from kymatio.numpy import TimeFrequencyScattering1D
9 | from kymatio.toolkit import echirp, pack_coeffs_jtfs
10 | from kymatio.visuals import make_gif
11 | from numpy.fft import ifft, ifftshift
12 | import matplotlib.pyplot as plt
13 |
14 | #%% Generate echirp and create scattering object #############################
15 | N = 4096
16 | # span low to Nyquist; assume duration of 1 second
17 | x = echirp(N, fmin=64, fmax=N/2)
18 |
19 | #%% Show joint wavelets on a smaller filterbank ##############################
20 | o = (0, 999)[1]
21 | rp = np.sqrt(.5)
22 | jtfs = TimeFrequencyScattering1D(shape=N, J=5, Q=(16, 1), J_fr=3, Q_fr=1,
23 | sampling_filters_fr='resample',
24 | average=0, average_fr=0, F=4,
25 | r_psi=(rp, .9*rp, rp), out_type='dict:list',
26 | oversampling=o, oversampling_fr=o)
27 |
28 | #%% scatter
29 | Scx_orig = jtfs(x)
30 | jmeta = jtfs.meta()
31 |
32 | #%% pack
33 | Scx = pack_coeffs_jtfs(Scx_orig, jmeta, structure=2)
34 | cmx = Scx.max() * .5 # color max
35 |
36 | #%%
37 | n_n2s = sum(p['j'] > 0 for p in jtfs.psi2_f)
38 | n_n1_frs = len(jtfs.psi1_f_fr_up)
39 | # drop spin up
40 | Scx = Scx[:, n_n1_frs:]
41 | # drop spin down
42 | # Scx = Scx[:, :n_n1_frs + 1]
43 |
44 | #%%
45 | psis_down = jtfs.psi1_f_fr_down
46 | psi2s = [p for p in jtfs.psi2_f if p['j'] > 0]
47 | # reverse ordering
48 | psi2s = psi2s[::-1]
49 |
50 | # reverse order of fr wavelets
51 | psis_down = psis_down[::-1]
52 |
53 | # for spin down
54 | c_psi_dn = Scx[1:, 1:]
55 | c_phi_t = Scx[0, 1:]
56 | c_phi_f = Scx[1:, 0]
57 | c_phis = Scx[0, 0]
58 |
59 | # for spin up
60 | # c_psi_dn = Scx[1:, :-1]
61 | # c_phi_t = Scx[0, :-1]
62 | # c_phi_f = Scx[1:, -1]
63 | # c_phis = Scx[0, -1]
64 |
65 | #%% Visualize ################################################################
def no_border(ax):
    """Strip `ax` of all ticks and hide every one of its spines."""
    for set_ticks in (ax.set_xticks, ax.set_yticks):
        set_ticks([])
    for name in ax.spines:
        ax.spines[name].set_visible(False)
70 |
def to_time(p_f):
    """Inverse-FFT `p_f` and center the result with `ifftshift`.

    If `p_f` is a dict, its entry at key `0` is used; if that (or `p_f`
    itself) is a list, its first element is used.
    """
    # unwrap in this order: dict first, then list
    for container in (dict, list):
        if isinstance(p_f, container):
            p_f = p_f[0]
    return ifftshift(ifft(p_f.squeeze()))
77 |
78 | spin_up = False
79 |
80 | imshow_kw0 = dict(aspect='auto', cmap='bwr')
81 | imshow_kw1 = dict(aspect='auto', cmap='turbo')
82 |
83 | n_rows = len(psis_down) + 1
84 | n_cols = len(psi2s) + 1
85 | w = 13
86 | h = 13 * n_rows / n_cols
87 |
88 | fig0, axes0 = plt.subplots(n_rows, n_cols, figsize=(w, h))
89 | fig1, axes1 = plt.subplots(n_rows, n_cols, figsize=(w, h))
90 |
91 | # compute common params to zoom on wavelets based on largest wavelet
92 | pf_f = psis_down[0]
93 | pt_f = psi2s[0]
94 | # centers
95 | ct = len(pt_f[0]) // 2
96 | cf = len(pf_f[0]) // 2
97 | # supports
98 | st = int(pt_f['support'][0] / 1.5)
99 | sf = int(pf_f['support'][0] / 1.5)
100 |
101 | # coeff max
102 | cmx = max(c_phi_t.max(), c_phi_f.max(), c_psi_dn.max()) * .8
103 |
104 | # psi_t * psi_f_down
105 | for n2_idx, pt_f in enumerate(psi2s):
106 | for n1_fr_idx, pf_f in enumerate(psis_down):
107 | pt = to_time(pt_f)
108 | pf = to_time(pf_f)
109 | # trim to zoom on wavelet
110 | pt = pt[ct - st:ct + st + 1]
111 | pf = pf[cf - sf:cf + sf + 1]
112 |
113 | Psi = pf[:, None] * pt[None]
114 |
115 | a = n1_fr_idx if spin_up else n1_fr_idx + 1
116 | ax0 = axes0[a, n2_idx + 1]
117 | ax1 = axes1[a, n2_idx + 1]
118 |
119 | mx = np.abs(Psi).max()
120 | ax0.imshow(Psi.real, **imshow_kw0, vmin=-mx, vmax=mx)
121 | no_border(ax0)
122 |
123 | # coeffs
124 | c = c_psi_dn[n2_idx, n1_fr_idx]
125 | ax1.imshow(c, **imshow_kw1, vmin=0, vmax=cmx)
126 | no_border(ax1)
127 |
128 |
129 | # psi_t * phi_f
130 | phif = to_time(jtfs.phi_f_fr)
131 | phif = phif[cf - sf:cf + sf + 1]
132 | for n2_idx, pt_f in enumerate(psi2s):
133 | pt = to_time(pt_f)
134 | pt = pt[ct - st:ct + st + 1]
135 | Psi = phif[:, None] * pt[None]
136 |
137 | a = -1 if spin_up else 0
138 | ax0 = axes0[a, n2_idx + 1]
139 | ax1 = axes1[a, n2_idx + 1]
140 |
141 | mx = np.abs(Psi).max()
142 | ax0.imshow(Psi.real, **imshow_kw0, vmin=-mx, vmax=mx)
143 | no_border(ax0)
144 |
145 | # coeffs
146 | c = c_phi_f[n2_idx]
147 | ax1.imshow(c, **imshow_kw1, vmin=0, vmax=cmx)
148 | no_border(ax1)
149 |
150 | # phi_t * psi_f
151 | phit = to_time(jtfs.phi_f)
152 | phit = phit[ct - st:ct + st + 1]
153 | for n1_fr_idx, pf_f in enumerate(psis_down):
154 | pf = to_time(pf_f)
155 | pf = pf[cf - sf:cf + sf + 1]
156 |
157 | Psi = pf[:, None] * phit[None]
158 |
159 | a = n1_fr_idx if spin_up else n1_fr_idx + 1
160 | ax0 = axes0[a, 0]
161 | ax1 = axes1[a, 0]
162 |
163 | mx = np.abs(Psi).max()
164 | ax0.imshow(Psi.real, **imshow_kw0, vmin=-mx, vmax=mx)
165 | no_border(ax0)
166 |
167 | # coeffs
168 | c = c_phi_t[n1_fr_idx]
169 | ax1.imshow(c, **imshow_kw1, vmin=0, vmax=cmx)
170 | no_border(ax1)
171 |
172 | # phi_t * phi_f
173 | a = -1 if spin_up else 0
174 | ax0 = axes0[a, 0]
175 | ax1 = axes1[a, 0]
176 |
177 | Psi = phif[:, None] * phit[None]
178 | mx = np.abs(Psi).max()
179 | ax0.imshow(Psi.real, **imshow_kw0, vmin=-mx, vmax=mx)
180 | no_border(ax0)
181 |
182 | # coeffs
183 | c = c_phis
184 | ax1.imshow(c, **imshow_kw1, vmin=0, vmax=cmx)
185 | no_border(ax1)
186 |
187 | fig0.subplots_adjust(wspace=.02, hspace=.02)
188 | fig1.subplots_adjust(wspace=.02, hspace=.02)
189 |
190 | base_name = 'jtfs_echirp_wavelets'
191 | fig0.savefig(f'{base_name}0.png', bbox_inches='tight')
192 | fig1.savefig(f'{base_name}1.png', bbox_inches='tight')
193 | make_gif('', f'{base_name}.gif', duration=2000, start_end_pause=0,
194 | delimiter=base_name, overwrite=1)
195 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/echirp_2d_anim.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Visualize JTFS of exponential chirp: as GIF of joint slices (2D)."""
3 | # https://dsp.stackexchange.com/q/78622/50076 ################################
4 | import numpy as np
5 | from kymatio.numpy import TimeFrequencyScattering1D
6 | from kymatio.toolkit import echirp, energy
7 | from kymatio.visuals import plot, imshow
8 | from kymatio import visuals
9 |
#%% Generate echirp and create scattering object #############################
# power of 2, matching the "4096 samples" / "fmax=2048" figures quoted in the
# plot title below
N = 4096
# span low to Nyquist; assume duration of 1 second
# (this chirp is the signal every later cell analyzes; it must not be replaced)
x = echirp(N, fmin=64, fmax=N/2)

# 9 temporal octaves
# largest scale is 2**9 [samples] / 4096 [samples / sec] == 125 ms
J = 9
# 16 bandpass wavelets per octave
# J*Q ~= 144 total temporal coefficients in first-order scattering
Q = 16
# scale of temporal invariance, 31.25 ms
T = 2**7
# 4 frequential octaves
J_fr = 4
# 2 bandpass wavelets per octave
Q_fr = 2
# scale of frequential invariance, F/Q == 0.5 cycle per octave
F = 8
# do frequential averaging to enable 4D concatenation
average_fr = True
# frequential padding; 'zero' avoids few discretization artefacts for this example
pad_mode_fr = 'zero'
# return packed as dict keyed by pair names for easy inspection
out_type = 'dict:array'

params = dict(J=J, Q=Q, T=T, J_fr=J_fr, Q_fr=Q_fr, F=F, average_fr=average_fr,
              out_type=out_type, pad_mode_fr=pad_mode_fr, max_pad_factor=4,
              max_pad_factor_fr=4, oversampling=999, oversampling_fr=999)
jtfs = TimeFrequencyScattering1D(shape=N, **params)

#%% Take JTFS, print pair names and shapes ###################################
Scx = jtfs(x)
print("JTFS pairs:")
for pair in Scx:
    print("{:<12} -- {}".format(str(Scx[pair].shape), pair))

# ratio of energies in the two spinned pairs
E_up = energy(Scx['psi_t * psi_f_up'])
E_dn = energy(Scx['psi_t * psi_f_down'])
print("E_down / E_up = {:.1f}".format(E_dn / E_up))
#%% Show `x` and its (time-averaged) scalogram ###############################
plot(x, show=1, w=.7,
     xlabel="time [samples]",
     title="Exponential chirp | fmin=64, fmax=2048, 4096 samples")
#%%
freqs = jtfs.meta()['xi']['S1'][:, -1]
imshow(Scx['S1'].squeeze(), abs=1, w=.8, h=.67, yticks=freqs,
       xlabel="time [samples]",
       ylabel="frequency [frac of fs]",
       title="Scalogram, time-averaged (first-order scattering)")

#%% Create & save GIF ########################################################
# fetch meta (structural info)
jmeta = jtfs.meta()
# specify save folder
savedir = ''
# time between GIF frames (ms)
duration = 150
# name the output after the chirp so it does not overwrite the pure-sine GIF
visuals.gif_jtfs_2d(Scx, jmeta, savedir=savedir, base_name='jtfs_echirp',
                    save_images=0, overwrite=1, gif_kw={'duration': duration})
# Notice how -1 spin coefficients contain nearly all the energy
# and +1 barely any; this is FDTS discriminability.
# For ideal FDTS and JTFS, +1 will be all zeros.
# For down chirp, the case is reversed and +1 hoards the energy.
75 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/echirp_3d_anim.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Visualize JTFS of exponential chirp: as GIF of 3D slices unfolded over time."""
3 | # https://dsp.stackexchange.com/q/78622/50076 ################################
4 | from kymatio.numpy import TimeFrequencyScattering1D
5 | from kymatio.toolkit import echirp, pack_coeffs_jtfs
6 | from kymatio.visuals import gif_jtfs_3d
7 |
#%% Make scattering object ###################################################
N = 4096
jtfs = TimeFrequencyScattering1D(shape=N, J=8, Q=16, T=2**8, F=4, J_fr=4, Q_fr=2,
                                 max_pad_factor=None, max_pad_factor_fr=None,
                                 oversampling=1, out_type='dict:array',
                                 pad_mode_fr='zero', pad_mode='zero',
                                 average_fr=True, sampling_filters_fr='resample')
meta = jtfs.meta()

#%% Make echirp & scatter ####################################################
x = echirp(N, fmin=64, fmax=N/2)
Scx = jtfs(x)

#%% Make GIF with and without spinned ########################################
# One GIF per packing: 'full' (lowpass pairs packed together with the spinned
# ones) and 'spinned' (spinned pairs only).
for separate_lowpass in (False, True):
    # forward the loop variable so each pass really uses its own packing
    packed = pack_coeffs_jtfs(Scx, meta, structure=2,
                              separate_lowpass=separate_lowpass,
                              sampling_psi_fr=jtfs.sampling_psi_fr)
    if separate_lowpass:
        # NOTE(review): with `separate_lowpass=True` the packer is expected to
        # return a tuple whose first entry holds the spinned pairs -- confirm
        # against the `pack_coeffs_jtfs` docs.
        packed = packed[0]
    packed = packed.transpose(-1, 0, 1, 2)  # time first

    nm = 'spinned' if separate_lowpass else 'full'
    base_name = f'jtfs3d_echirp_{nm}'
    gif_jtfs_3d(packed, base_name=base_name, overwrite=1, cmap_norm=.5)
32 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/fdts.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """JTFS of Frequency-Dependent Time Shifts."""
3 | # https://dsp.stackexchange.com/q/78622/50076 ################################
4 | import numpy as np
5 | import torch
6 | from kymatio.torch import Scattering1D, TimeFrequencyScattering1D
7 | from kymatio.toolkit import fdts, l2
8 | from kymatio.visuals import make_gif
9 | import matplotlib.pyplot as plt
10 |
#%% Generate FDTS signal pair and its configuration ##########################
N = 4096
f0 = N // 24
n_partials = 5
total_shift = N//12
seg_len = N//6

# two signals from the `fdts` helper; presumably `xs` is the
# frequency-dependently shifted counterpart of `x` -- confirm against `fdts`
x, xs = fdts(N, n_partials, total_shift, f0, seg_len, partials_f_sep=1.7)

#%% Make scattering objects ##################################################
J = int(np.log2(N))  # max scale / global averaging
Q = (8, 2)
kw = dict(J=J, Q=Q, shape=N, max_pad_factor=4, pad_mode='zero')
cwt = Scattering1D(**kw, average=False, out_type='list', oversampling=999,
                   max_order=1)
ts = Scattering1D(out_type='array', max_order=2, **kw)
jtfs = TimeFrequencyScattering1D(Q_fr=2, J_fr=5, out_type='array', **kw,
                                 out_exclude=('S1', 'phi_t * psi_f',
                                              'phi_t * phi_f', 'psi_t * phi_f'),
                                 max_pad_factor_fr=4, pad_mode_fr='zero',
                                 sampling_filters_fr=('resample', 'resample'))
cwt, ts, jtfs = [s.cuda() for s in (cwt, ts, jtfs)]

#%% Scatter ##################################################################
# stack per-filter outputs into one array, dropping the first row
cwt_x = torch.vstack([c['coef'] for c in cwt(x)]).cpu().numpy()[1:]
cwt_xs = torch.vstack([c['coef'] for c in cwt(xs)]).cpu().numpy()[1:]

ts_x = ts(x).cpu().numpy()
ts_xs = ts(xs).cpu().numpy()

jtfs_x = jtfs(x).cpu().numpy()[0]
jtfs_xs = jtfs(xs).cpu().numpy()[0]

# distance between the coefficients of `x` and `xs` under each transform
l2_ts = float(l2(ts_x, ts_xs))
l2_jtfs = float(l2(jtfs_x, jtfs_xs))

print(("\nFDTS sensitivity:\n"
       "JTFS/TS = {:.1f}\n"
       "TS      = {:.4f}\n"
       "JTFS    = {:.4f}\n").format(l2_jtfs / l2_ts, l2_ts, l2_jtfs))

#%%# Make GIF ################################################################
data = {'x':    [x, xs],
        'cwt':  [cwt_x, cwt_xs],
        'ts':   [ts_x, ts_xs],
        'jtfs': [jtfs_x, jtfs_xs]}
# shared color / axis limits so both frames are directly comparable
xmx = max(abs(x).max(), abs(xs).max())*1.05
cmx = max(np.abs(cwt_x).max(), np.abs(cwt_xs).max())
tmx = max(np.abs(ts_x).max(), np.abs(ts_xs).max())
jmx = max(np.abs(jtfs_x).max(), np.abs(jtfs_xs).max()) * .95

for i in (0, 1):
    _x, _cwt, _ts, _jtfs = [data[k][i] for k in data]

    # exactly one figure per frame; it is closed after saving
    fig = plt.figure(figsize=(14, 14))
    ax0 = fig.add_subplot(12, 2, (1, 9))
    ax1 = fig.add_subplot(12, 2, (2, 10))
    ax2 = fig.add_subplot(14, 2, (13, 14))
    ax3 = fig.add_subplot(14, 2, (15, 16))
    axes = [ax0, ax1, ax2, ax3]

    imshow_kw = dict(cmap='turbo', aspect='auto', vmin=0)
    s, e = 1700, len(x) - 1200  # zoom
    tks = np.arange(s, e)
    ax0.plot(tks, _x[s:e])
    ax1.imshow(_cwt, **imshow_kw, vmax=cmx)
    ax2.imshow(_ts.T, **imshow_kw, vmax=tmx)
    ax3.imshow(_jtfs.T, **imshow_kw, vmax=jmx)

    # remove ticks
    for ax in (ax2, ax3):
        ax.set_xticks([])
        ax.set_yticks([])
    ax0.set_yticks([])
    ax1.set_yticks([])

    # set common limits
    ax0.set_ylim(-xmx, xmx)

    fig.subplots_adjust(hspace=.25, wspace=.02)

    # reduce y-spacing between bottom two subplots
    # NOTE(review): as written `y0` ends up larger than `y1` whenever `ywidth`
    # is positive -- confirm this is the intended box.
    yspace = .05
    pos_ax3 = ax3.get_position()
    pos_ax2 = ax2.get_position()
    ywidth = pos_ax3.y1 - pos_ax3.y0
    pos_ax3.y0 = pos_ax2.y0 - ywidth * yspace
    pos_ax3.y1 = pos_ax2.y0 - ywidth * (1 + yspace)
    ax3.set_position(pos_ax3)

    fig.savefig(f'jtfs_ts{i}.png', bbox_inches='tight')
    plt.close(fig)

make_gif('', 'jtfs_ts.gif', duration=1000, start_end_pause=0,
         delimiter='jtfs_ts', overwrite=1, verbose=1)
107 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/filterbank.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/filterbank.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/filterbank.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Visualize joint JTFS filterbank (2D)."""
3 | # https://dsp.stackexchange.com/q/78622/50076 ################################
4 | from kymatio.numpy import TimeFrequencyScattering1D
5 | from kymatio.visuals import make_gif
6 | from kymatio import visuals
7 | import matplotlib.pyplot as plt
8 |
#%% Viz frequential filterbank (1D) ##########################################
# arguments common to both scattering objects built in this script
shared_kw = dict(shape=512, Q=(16, 1), F=8)
jtfs = TimeFrequencyScattering1D(J=7, J_fr=5, Q_fr=2, normalize='l1',
                                 **shared_kw)
visuals.filterbank_jtfs_1d(jtfs, zoom=-1, lp_sum=1)

#%% Show joint wavelets on a smaller filterbank ##############################
jtfs = TimeFrequencyScattering1D(J=4, J_fr=3, Q_fr=1, **shared_kw)

#%%
# Drop leading filters to keep the joint plot compact:
#   - first `psi2_f` entry: nearly invisible, omit
#     (also unused in scattering per `j2 > 0`)
#   - next `psi2_f` entry and the first `psi1_f_fr_up` / `psi1_f_fr_down`
#     entries: not nearly invisible but still make the plot too big
for filters in (jtfs.psi2_f, jtfs.psi2_f,
                jtfs.psi1_f_fr_up, jtfs.psi1_f_fr_down):
    filters.pop(0)

#%% Make GIF ################################################################
base_name = 'filterbank'
# one frame per displayed part of the wavelets
for frame_idx, part in enumerate(('real', 'imag', 'complex')):
    fig, axes = visuals.filterbank_jtfs_2d(jtfs, part=part, labels=1,
                                           suptitle_y=1.03)
    frame_path = f'{base_name}{frame_idx}.png'
    fig.savefig(frame_path, bbox_inches='tight')
    plt.close(fig)

# stitch the saved frames into the GIF
make_gif('', f'{base_name}.gif', duration=2000, start_end_pause=0,
         delimiter=base_name, overwrite=1)
37 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/freq_tp.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/freq_tp.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/freq_tp.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Visualize log-frequency shift."""
3 | # https://dsp.stackexchange.com/q/78622/50076 ################################
4 | import numpy as np
5 | from kymatio.numpy import Scattering1D
6 | from kymatio.visuals import imshow, make_gif
7 | from kymatio.toolkit import fdts
8 | import matplotlib.pyplot as plt
9 |
#%% Make CWT object ##########################################################
N = 2048
cwt = Scattering1D(shape=N, J=7, Q=8, average=False, out_type='list',
                   r_psi=.85, oversampling=999, max_order=1)


def _cwt_rows(sig):
    """Stack the 'coef' entries of `cwt(sig)`, skipping the first, into one array."""
    return np.array([c['coef'].squeeze() for c in cwt(sig)[1:]])


#%% Make signals & take CWT ##################################################
# same construction, two different `f0`
x0 = fdts(N, n_partials=4, seg_len=N//4, f0=N/12)[0]
x1 = fdts(N, n_partials=4, seg_len=N//4, f0=N/20)[0]

Wx0 = _cwt_rows(x0)
Wx1 = _cwt_rows(x1)

#%% Make GIF #################################################################
fig0, ax0 = plt.subplots(1, 2, figsize=(12, 5))
fig1, ax1 = plt.subplots(1, 2, figsize=(12, 5))
# one (figure, axes, signal, transform) group per GIF frame
frames = ((fig0, ax0, x0, Wx0), (fig1, ax1, x1, Wx1))

# imshows (right panel)
kw = dict(abs=1, ticks=0, show=0)
for fig, ax, _, Wx in frames:
    imshow(Wx, ax=ax[1], fig=fig, **kw)

# plots (left panel)
s, e = N//3, -N//3  # zoom
for _, ax, x, _ in frames:
    ax[0].plot(x[s:e])

# ticks, shared ylims & titles
mx = max(np.abs(x0).max(), np.abs(x1).max()) * 1.03
title_kw = dict(weight='bold', fontsize=20, loc='left')
for _, ax, _, _ in frames:
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    ax[0].set_ylim(-mx, mx)
    ax[0].set_title("x (zoomed)", **title_kw)
    ax[1].set_title("|CWT(x)|", **title_kw)

# finalize & save
base_name = 'freq_tp'
for fig in (fig0, fig1):
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=.02)
for frame_idx, fig in enumerate((fig0, fig1)):
    fig.savefig(f'{base_name}{frame_idx}.png', bbox_inches='tight')
for fig in (fig0, fig1):
    plt.close(fig)

# make GIF
make_gif('', f'{base_name}.gif', duration=1000, start_end_pause=0,
         delimiter=base_name, verbose=1, overwrite=1)
58 |
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_echirp_full.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_echirp_full.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_echirp_spinned.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_echirp_spinned.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_seiz.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_seiz.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_trumpet.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_trumpet.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_echirp.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_echirp.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_echirp_wavelets.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_echirp_wavelets.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_sine.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_sine.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_ts.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_ts.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/sine_2d_anim.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Visualize JTFS of pure sine: as GIF of joint slices (2D)."""
3 | # https://dsp.stackexchange.com/q/78622/50076 ################################
4 | import numpy as np
5 | from kymatio.numpy import TimeFrequencyScattering1D
6 | from kymatio.toolkit import energy
7 | from kymatio.visuals import plot, imshow
8 | from kymatio import visuals
9 |
#%% Generate sine and create scattering object ###############################
# `pow2 + 1` with `endpoint==1` to pad to perfect sine
N = 4097
# sine frequency, in cycles over the unit-length time axis
f0 = N // 16
x = np.cos(2*np.pi * f0 * np.linspace(0, 1, N, endpoint=True))

# 9 temporal octaves
# largest scale is 2**9 [samples] / 4096 [samples / sec] == 125 ms
J = 9
# 16 bandpass wavelets per octave
# J*Q ~= 144 total temporal coefficients in first-order scattering
Q = 16
# scale of temporal invariance, 31.25 ms
T = 2**7
# 4 frequential octaves
J_fr = 4
# 2 bandpass wavelets per octave
Q_fr = 2
# scale of frequential invariance, F/Q == 0.5 cycle per octave
F = 8
# do frequential averaging to enable 4D concatenation
average_fr = True
# frequential padding; 'zero' avoids few discretization artefacts for this example
pad_mode_fr = 'zero'
# return packed as dict keyed by pair names for easy inspection
out_type = 'dict:array'

params = dict(J=J, Q=Q, T=T, J_fr=J_fr, Q_fr=Q_fr, F=F, average_fr=average_fr,
              out_type=out_type, pad_mode_fr=pad_mode_fr, max_pad_factor=4,
              max_pad_factor_fr=4, oversampling=999, oversampling_fr=999)
jtfs = TimeFrequencyScattering1D(shape=N, **params)

#%% Take JTFS, print pair names and shapes ###################################
Scx = jtfs(x)
print("JTFS pairs:")
for pair in Scx:
    print("{:<12} -- {}".format(str(Scx[pair].shape), pair))

# ratio of energies in the two spinned pairs
E_up = energy(Scx['psi_t * psi_f_up'])
E_dn = energy(Scx['psi_t * psi_f_down'])
print("E_down / E_up = {:.1f}".format(E_dn / E_up))
#%% Show `x` and its (time-averaged) scalogram ###############################
# title is built from the actual signal parameters so it cannot drift from them
plot(x, show=1, w=.7,
     xlabel="time [samples]",
     title=f"Pure sine | f={f0}, {N} samples")
#%%
freqs = jtfs.meta()['xi']['S1'][:, -1]
imshow(Scx['S1'].squeeze(), abs=1, w=.8, h=.67, yticks=freqs,
       xlabel="time [samples]",
       ylabel="frequency [frac of fs]",
       title="Scalogram, time-averaged (first-order scattering)")

#%% Create & save GIF ########################################################
# fetch meta (structural info)
jmeta = jtfs.meta()
# specify save folder
savedir = ''
# time between GIF frames (ms)
duration = 150
visuals.gif_jtfs_2d(Scx, jmeta, savedir=savedir, base_name='jtfs_sine',
                    save_images=0, overwrite=1, gif_kw={'duration': duration})
70 |
--------------------------------------------------------------------------------
/SignalProcessing/Q80918 - STFT - Overlap, window length, uncertainty principle/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/a/80920/50076 ################################
3 | import numpy as np
4 | from ssqueezepy import ssq_stft
5 | from ssqueezepy.visuals import imshow
6 | from scipy.signal.windows import dpss
7 |
# signal: cosine at `8*f` whose time argument is perturbed by a cosine at `f`
N = 4096
f = 32
t = np.linspace(0, 1, N, endpoint=False)[::-1]
fm = np.cos(2*np.pi * f * t) / (100*f)
x = np.cos(2*np.pi * f*8 * (t + fm))

# window
n_fft = 512
window = dpss(n_fft, n_fft//2 - 1)

# STFT & SSQ, at a hop of 1 and of 64 samples
stft_kw = dict(n_fft=n_fft, flipud=0)
Tx0, Sx0 = ssq_stft(x, window, hop_len=1, **stft_kw)[:2]
Tx1, Sx1 = ssq_stft(x, window, hop_len=64, **stft_kw)[:2]

# visualize ##################################################################
# cheat & drop boundary effects for clarity
Sx1, Tx1 = Sx1[:, 1:], Tx1[:, 1:]
kw = dict(abs=1, interpolation='none', w=.8, h=.6)
for hop_len, Sx, Tx in ((1, Sx0, Tx0), (64, Sx1, Tx1)):
    imshow(Sx, **kw, title=(f"|STFT| -- hop_len={hop_len}", {'fontsize': 18}))
    imshow(Tx, **kw, title=(f"|SSQ_STFT| -- hop_len={hop_len}",
                            {'fontsize': 18}))
31 |
--------------------------------------------------------------------------------
/SignalProcessing/Q81624 - STFT, optimization - Optimize window length for STFT/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Optimize window length for STFT, plain window demo."""
3 | # https://dsp.stackexchange.com/q/81624/50076 ################################
4 | import torch, torch.nn as nn, torch.nn.functional as F
5 | import matplotlib.pyplot as plt
6 |
def plot(x, title):
    """Line-plot `x` in a 9x7 inch figure with a bold, left-aligned title, then show it.

    Parameters
    ----------
    x : sequence
        Values to plot against their index.
    title : str
        Title text.
    """
    plt.plot(x)
    plt.title(title, loc='left', weight='bold', fontsize=18)
    plt.gcf().set_size_inches(9, 7)
    # `show` takes no data; anything passed positionally is not a plot input
    plt.show()
12 |
13 |
def hann(N):
    """Return `floor(N)` samples of `.5 * (1 - cos(2*pi * n / N))`.

    `N` is a scalar tensor and may be fractional. The sample count is its
    integer floor, while the divisor is derived from `N` itself, so the output
    remains a function of `N` for autograd.
    """
    n_samples = torch.floor(N).int()
    # numerically a no-op since floor(N) <= N < floor(N) + 1
    period = torch.clamp(N, min=n_samples, max=n_samples + 1)
    frac = torch.arange(n_samples) / period
    return .5 * (1 - torch.cos(2*torch.pi * frac))
20 |
21 |
# window length being optimized, deliberately started away from `N_ref`
N = nn.Parameter(torch.tensor(129.))

# optimal window length
N_ref = 160
w_ref = hann(torch.tensor(float(N_ref)))
# add room for overshoot
x = F.pad(w_ref, [40, 40])

LR = 100000
Ns, grads, outs = [], [], []
for _ in range(20):
    w = hann(N)
    w = w / torch.norm(w)  # L2 norm to ensure `max(conv)` peaks at `N_ref`
    conv = torch.conv1d(x[None, None], w[None, None])
    out = 1. / conv.max()  # inverse of peak cross-correlation to minimize
    out.backward()

    # plain gradient-descent step, kept out of the autograd graph
    with torch.no_grad():
        N -= LR * N.grad

    # log the state after the step
    Ns.append(N.item())
    grads.append(N.grad.item())
    outs.append(out.item())

    # manual `optimizer.zero_grad()`
    N.grad.requires_grad_(False)
    N.grad.zero_()

for series, label in ((Ns, "N"), (grads, "N.grad"), (outs, "loss")):
    plot(series, f"{label} vs iteration")
54 |
--------------------------------------------------------------------------------
/SignalProcessing/Q85745 - Energy, power, relation to sampling rate and duration/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/85745/50076
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | # from ?.toolkit import fft_upsample # to be released soon
6 |
def E(x):
    """Return the energy of `x`: the sum of its squared magnitudes."""
    magnitude = np.abs(x)
    return np.sum(magnitude**2)
9 |
def viz(t1, t2, x1, x2, mode='bar', title=True):
    """Show two signals side by side on a shared y-axis.

    Parameters
    ----------
    t1, t2 : array-like
        Time axes of the left and right signal.
    x1, x2 : array-like
        Signal values at `t1` and `t2`.
    mode : str
        'bar' draws bars 0.9 sampling periods wide (needs at least two time
        points per signal); any other value draws lines.
    title : bool / sequence of two str
        `True` titles each panel with `sum(|x|^2)` and the sample count;
        a two-item sequence gives the (left, right) titles; a falsy value
        omits titles.
    """
    tkw = dict(fontsize=18, weight='bold', loc='left')
    fig, axes = plt.subplots(1, 2, sharey=True, figsize=(15, 6))

    if mode == 'bar':
        Dt1 = t1[1] - t1[0]  # sampling period
        Dt2 = t2[1] - t2[0]
        axes[0].bar(t1, x1, .9*Dt1)
        axes[1].bar(t2, x2, .9*Dt2)
    else:
        axes[0].plot(t1, x1)
        axes[1].plot(t2, x2)

    if title:
        if title is True:
            title1 = "sum(|x|^2)={:.3g}, N={}".format(E(x1), len(x1))
            title2 = "sum(|x|^2)={:.3g}, N={}".format(E(x2), len(x2))
        else:
            # accept any two-item sequence (tuple, list, ...), not only tuples
            title1, title2 = title
        axes[0].set_title(title1, **tkw)
        axes[1].set_title(title2, **tkw)

    fig.subplots_adjust(wspace=.03)
    plt.show()
34 |
def viz_sines(N1, N2, T1, T2, f1=1, f2=1):
    """Plot two cosines side by side via `viz`.

    Signal `k` has `Nk` samples over `[0, Tk)` (endpoint excluded) and
    frequency `fk`.
    """
    times, waves = [], []
    for n_samples, span, freq in ((N1, T1, f1), (N2, T2, f2)):
        t = np.linspace(0, span, n_samples, endpoint=False)
        times.append(t)
        waves.append(np.cos(2*np.pi * freq * t))

    viz(*times, *waves, title=True)
42 |
#%%
# same duration for both panels, different sample counts; repeated for two durations
N1, N2 = 10, 20
for duration in (1, 1.25):
    viz_sines(N1=N1, N2=N2, T1=duration, T2=duration)

#%%
# same sample count for both panels, different durations and frequencies
N = 20
T1, T2 = 1, 2
f1, f2 = 2, 1
viz_sines(N1=N, N2=N, T1=T1, T2=T2, f1=f1, f2=f2)
53 |
54 | #%%
55 | # t = np.linspace(0, 1.25, 20, endpoint=False)
56 | # x = np.cos(2*np.pi * t)
57 | # x_up = fft_upsample(x, factor=128, time_to_time=True, real=True)
58 | # t_up = np.linspace(0, 1.25, len(x_up), endpoint=False)
59 | # x_up4 = np.hstack([x_up]*4)
60 | # t_up4 = np.linspace(0, 1.25*4, len(x_up)*4, endpoint=False)
61 |
62 | # viz(t_up, t_up4, x_up, x_up4, mode='plot',
63 | # title=("x_upsampled", "x_upsampled, 4 periods"))
64 |
--------------------------------------------------------------------------------
/SignalProcessing/Q86181 - CWT Power and Energy/cwt_power_example.m:
--------------------------------------------------------------------------------
% Answer to https://dsp.stackexchange.com/q/86181/50076
% and to https://dsp.stackexchange.com/q/82985/50076
% Computes the energy and mean power of a white-noise signal between two
% frequencies from its CWT, and reports them next to the signal's own totals.
%% Configure ################################################################
fs = 400; % (in Hz) anything works
duration = 5; % (in sec) anything works
padtype = 'reflection'; % anything (supported) works

% get power between these frequencies
freq_min = 50; % (in Hz)
freq_max = 150; % (in Hz)

%% Obtain transform #########################################################
% make signal & filterbank
% assume this is Amperes passing through 1 Ohm resistor; P = I^2*R, E = P * T
% check actual physical units for your specific application and adjust accordingly
rng('default')
x = randn(1, fs * duration);
fb = cwtfilterbank('Wavelet', 'amor', 'SignalLength', fs * duration, ...
                   'VoicesPerOctave', 11, 'SamplingFrequency', fs, ...
                   'Boundary', padtype);

% transform, get frequencies
[Wx, freqs] = fb.wt(x);

% fetch coefficients according to `freq_min` and `freq_max`
% (strict inequalities: rows exactly at either bound are excluded)
Wx_spec = Wx((freq_min < freqs) & (freqs < freq_max), :);

%% "Adjustments" ############################################################
% See "Practical adjustments" in the answer

% fetch wavelets in freq domain, compute ET & ES transfer funcs, fetch maxima
psi_fs = fb.PsiDFT; % fetch wavelets in freq domain
% ET: sum of squared magnitudes over filters; ES: squared magnitude of the sum
ET_tfn = sum(abs(psi_fs).^2, 1);
ES_tfn = abs(sum(psi_fs, 1)).^2;
% real-valued case adjustment:
% - ET since we operate on half of spectrum (half as many coeffs)
% - ES since `.real` halves spectrum, quartering energy on real side;
% note, for this to work right, the filterbank must be halved at Nyquist
% by design (which should also be done for sake of temporal decay)
ET_adj = max(ET_tfn) / 2;
ES_adj = max(ES_tfn) / 4;

%% Compute energy & power ###################################################
% compute energy & power (discrete)
% ET: from the coefficients directly; ES: from the real part of their sum
% over the selected rows
ET_disc = sum(abs(Wx_spec).^2, 'all') / ET_adj;
ES_disc = sum(abs(real(sum(Wx_spec, 1))).^2, 'all') / ES_adj;
PT_disc = ET_disc / length(x);
PS_disc = ES_disc / length(x);

% compute energy & power (physical); estimate underlying continuous waveform via
% Riemann integration
sampling_period = 1 / fs;
ET_phys = ET_disc * sampling_period;
ES_phys = ES_disc * sampling_period;
PT_phys = ET_phys / duration;
PS_phys = ES_phys / duration;

% repeat for original signal
Ex_disc = sum(abs(x).^2, 'all');
Px_disc = Ex_disc / length(x);
Ex_phys = Ex_disc * sampling_period;
Px_phys = Ex_phys / duration;

%% Report ###################################################################
% in the reports below, "(transform)" rows are the ET/PT values and
% "(signal)" rows are the ES/PS values
s = ['Between %d and %d Hz, DISCRETE:\n'...
     '%-9.6g -- energy (transform)\n'...
     '%-9.6g -- energy (signal)\n'...
     '%-9.6g -- mean power (transform)\n'...
     '%-9.6g -- mean power (signal)\n\n'];
fprintf(s, freq_min, freq_max, ET_disc, ES_disc, PT_disc, PS_disc)


s = ['Between %d and %d Hz, PHYSICAL (via Riemann integration):\n'...
     '%-9.6g Joules -- energy (transform)\n'...
     '%-9.6g Joules -- energy (signal)\n'...
     '%-9.6g Watts -- mean power (transform)\n'...
     '%-9.6g Watts -- mean power (signal)\n\n'];
fprintf(s, freq_min, freq_max, ET_phys, ES_phys, PT_phys, PS_phys)

s = ['Original signal:\n'...
     '%-9.6g -- energy (discrete)\n'...
     '%-9.6g -- mean power (discrete)\n'...
     '%-9.6g Joules -- energy (physical)\n'...
     '%-9.6g Watts -- mean power (physical)\n\n'];
fprintf(s, Ex_disc, Px_disc, Ex_phys, Px_phys)
86 |
--------------------------------------------------------------------------------
/SignalProcessing/Q86181 - CWT Power and Energy/cwt_power_example.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Answer to https://dsp.stackexchange.com/q/86181/50076
3 | import numpy as np
4 | from ssqueezepy import cwt, Wavelet
5 | from ssqueezepy.experimental import scale_to_freq
6 |
def E(x):
    """Return the energy of `x`: the sum of its squared magnitudes."""
    magnitude = np.abs(x)
    return np.sum(magnitude**2)
9 |
#%% Configure ################################################################
fs = 400             # (in Hz) anything works
duration = 5         # (in sec) anything works
padtype = 'reflect'  # anything (supported) works

# get power between these frequencies
freq_min = 50   # (in Hz)
freq_max = 150  # (in Hz)

#%% Obtain transform #########################################################
# make signal & wavelet
# assume this is Amperes passing through 1 Ohm resistor; P = I^2*R, E = P * T
# check actual physical units for your specific application and adjust accordingly
np.random.seed(0)
x = np.random.randn(fs * duration)
n_samples = len(x)
wavelet = Wavelet()

# transform, get frequencies
Wx, scales = cwt(x, wavelet, padtype=padtype)
freqs = scale_to_freq(scales, wavelet, n_samples, fs=fs, padtype=padtype)

# fetch coefficients according to `freq_min` and `freq_max`
# (strict inequalities: rows exactly at either bound are excluded)
in_band = (freq_min < freqs) & (freqs < freq_max)
Wx_spec = Wx[in_band]

#%% Normalization ############################################################
# See "Normalization" in the answer

# fetch wavelets in freq domain, compute ET & ES transfer funcs, fetch maxima
psi_fs = wavelet._Psih  # fetch wavelets in freq domain
# ET: sum of squared magnitudes over filters; ES: squared magnitude of the sum
ET_tfn = np.sum(np.abs(psi_fs)**2, axis=0)
ES_tfn = np.abs(np.sum(psi_fs, axis=0))**2
# real-valued case adjustment:
#   - ET since we operate on half of spectrum (half as many coeffs)
#   - ES since `.real` halves spectrum, quartering energy on real side;
#     note, for this to work right, the filterbank must be halved at Nyquist
#     by design (which should also be done for sake of temporal decay)
ET_adj = ET_tfn.max() / 2
ES_adj = ES_tfn.max() / 4

#%% Compute energy & power ###################################################
# compute energy & power (discrete)
# ET: from the coefficients directly; ES: from the real part of their sum
# over the selected rows
ET_disc = E(Wx_spec) / ET_adj
ES_disc = E(np.sum(Wx_spec, axis=0).real) / ES_adj
PT_disc = ET_disc / n_samples
PS_disc = ES_disc / n_samples

# compute energy & power (physical); estimate underlying continuous waveform via
# Riemann integration
sampling_period = 1 / fs
ET_phys = ET_disc * sampling_period
ES_phys = ES_disc * sampling_period
PT_phys = ET_phys / duration
PS_phys = ES_phys / duration

# repeat for original signal
Ex_disc = E(x)
Px_disc = Ex_disc / n_samples
Ex_phys = Ex_disc * sampling_period
Px_phys = Ex_phys / duration

#%% Report ###################################################################
# in the reports below, "(transform)" rows are the ET/PT values and
# "(signal)" rows are the ES/PS values
report_disc = ("Between {:d} and {:d} Hz, DISCRETE:\n"
               "{:<9.6g} -- energy (transform)\n"
               "{:<9.6g} -- energy (signal)\n"
               "{:<9.6g} -- mean power (transform)\n"
               "{:<9.6g} -- mean power (signal)\n")
print(report_disc.format(freq_min, freq_max,
                         ET_disc, ES_disc, PT_disc, PS_disc))

report_phys = ("Between {:d} and {:d} Hz, PHYSICAL (via Riemann integration):\n"
               "{:<9.6g} Joules -- energy (transform)\n"
               "{:<9.6g} Joules -- energy (signal)\n"
               "{:<9.6g} Watts -- mean power (transform)\n"
               "{:<9.6g} Watts -- mean power (signal)\n")
print(report_phys.format(freq_min, freq_max,
                         ET_phys, ES_phys, PT_phys, PS_phys))

report_orig = ("Original signal:\n"
               "{:<9.6g} -- energy (discrete)\n"
               "{:<9.6g} -- mean power (discrete)\n"
               "{:<9.6g} Joules -- energy (physical)\n"
               "{:<9.6g} Watts -- mean power (physical)\n")
print(report_orig.format(Ex_disc, Px_disc, Ex_phys, Px_phys))
91 |
--------------------------------------------------------------------------------
/SignalProcessing/Q86181 - CWT Power and Energy/cwt_power_validation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Answer to https://dsp.stackexchange.com/q/86181/50076
3 | import numpy as np
4 | from numpy.fft import fft, ifft
5 |
def E(x):
    """Energy of `x`, i.e. `sum(|x|^2)`; works for real or complex input."""
    squared_magnitude = np.abs(x) ** 2
    return np.sum(squared_magnitude)
8 |
9 | #%% Generate filterbank ######################################################
10 | # show that it works with any filters
11 | # note ES still works despite violating condition 3, but condition 3
12 | # is still needed for valid interpretation (sum won't invert to signal)
13 | psi_fs = np.random.randn(51, 256) + 1j*np.random.randn(51, 256)
14 | # wavelets are zero-mean
15 | psi_fs[:, 0] = 0
16 | # "any" but still analytic here, need more code for "any any"
17 | psi_fs[:, 129:] = 0
18 | # nyquist imaginary part must be zero else inverse can't be real-valued
19 | psi_fs[:, 128].imag = 0
20 |
21 |
22 | #%% Compute "transfer functions" #############################################
23 | ET_tfn = np.sum(np.abs(psi_fs)**2, axis=0)
24 | ES_tfn = np.abs(np.sum(psi_fs, axis=0))**2
25 |
26 | #%% Run tests ################################################################
27 | for case in ('real', 'complex'):
28 | # generate signal
29 | M = 256
30 | x = np.random.randn(M)
31 | if case == 'complex':
32 | x = x + 1j * np.random.randn(M)
33 | xf = fft(x)
34 |
35 | # compute CWT
36 | out = ifft(xf * psi_fs)
37 |
38 | # invert
39 | x_inv = out.sum(axis=0)
40 | if case == 'real':
41 | x_inv = x_inv.real
42 |
43 | # compute energies via transfer fns
44 | xfe = np.abs(xf)**2
45 | ET = xfe * ET_tfn / M
46 | ES = xfe * ES_tfn / M
47 | if case == 'real':
48 | ES[1:M//2] /= 2
49 | ET, ES = ET.sum(), ES.sum()
50 |
51 | # assert agreement
52 | assert np.allclose(ET, E(out))
53 | assert np.allclose(ES, E(x_inv))
54 |
--------------------------------------------------------------------------------
/SignalProcessing/Q86726 - STFT - Amplitude extraction/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/86726/50076
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from scipy.signal.windows import hann, dpss
6 | from numpy.fft import fft, fftshift, ifftshift
7 | from ssqueezepy import stft, ssq_stft
8 | from ssqueezepy.visuals import plotscat, imshow
9 |
10 | #%%############################################################################
11 | # Main example
12 | # ------------
13 |
def viz(Sx, Sxs):
    """Plot `Sxs` and its first time column, then print two summaries of it.

    `Sx` is accepted for call-site symmetry but is not used here. Reads the
    module-level `window` for the normalization in the second printout.
    """
    first_column = Sxs[:, 0]

    # heatmap of the full (scaled) transform, then the first column as complex
    imshow(Sxs, abs=1, w=.7, h=.55, interpolation='none')
    plotscat(first_column, complex=1, w=.6, h=.82, show=1)

    print("sum(Sx[:, 0]) =", first_column.sum(), flush=True)
    peak_over_window = abs(first_column).max() / window.sum()
    print("max(abs(Sx[:, 0])) / sum(window) =", peak_over_window, flush=True)
21 |
22 | # gen signal
23 | N = 256
24 | t = np.linspace(0, 1, N, 0)
25 | x = np.cos(2*np.pi * 60 * t)
26 |
27 | # gen window
28 | window = dpss(N//2, N//8 - 1, sym=False)
29 | # this is always done under the hood, `len(window) == n_fft`
30 | window = np.pad(window, (N - len(window)) // 2)
31 |
32 | # match input length so `x` is seen as perfect sinusoid by DFT;
33 | # can also be integer fraction of `N` but then `x` can't be of all int freqs
34 | n_fft = N
35 | # in general should be 'reflect' but slightly complicates example here, simplify
36 | padtype = 'wrap'
37 |
38 | Sx = stft(x, window, n_fft=n_fft, hop_len=1, padtype=padtype, dtype='float64')
39 | Sxs = Sx.copy()
40 | Sxs[1:-1] *= 2
41 | viz(Sx, Sxs)
42 |
43 | #%%
44 | Tx = ssq_stft(x, window, n_fft=len(window), padtype=padtype)[0]
45 | # ... *almost* all the work
46 | Txs = Tx.copy()
47 | Txs[1:-1] *= 2
48 | viz(Txs, Txs)
49 |
50 | # ignore the ssqueezepy warning (as of writing, it is a false positive)
51 |
52 | #%%############################################################################
53 | # Criteria demos
54 | # --------------
55 |
56 | # helper
def viz2(S, w, wf, sig_title, tm_idx=128):
    """Plot `S` next to one of its time columns and print amplitude summaries.

    Parameters
    ----------
    S : 2D array, indexed as `S[:, tm_idx]`
        Transform to visualize.
    w : array
        Window whose sum normalizes the column maximum ("time norm").
    wf : array
        Window whose `sum(abs(fft(wf)))` normalizes the column sum
        ("freq norm").
    sig_title : str
        Appended to the heatmap title.
    tm_idx : int
        Column of `S` to slice, plot and summarize.

    Notes
    -----
    Reads the module-level signal `x` for the `max(abs(x))` line.
    """
    slc = S[:, tm_idx]
    fig, axes = plt.subplots(1, 2, figsize=(15, 7))
    imshow(S, abs=1, ax=axes[0], fig=fig, show=0,
           title="abs(Sx_adj) | " + sig_title, interpolation='none')
    plotscat(slc, abs=1, ax=axes[1], fig=fig,
             title=f"abs(Sx_adj[:, {tm_idx}])")
    fig.subplots_adjust(wspace=.1)
    plt.show()

    # computed once so every printed line is guaranteed to use the same values
    col_abs = abs(slc)
    col_max = max(col_abs)
    col_sum = sum(col_abs)

    # positional fields: {2} is labeled "max" and {3} is labeled "sum", so the
    # arguments must be passed in that order (they were previously swapped)
    print(
        ("max(abs(x)) = {1:.3g}\n"
         "max(abs(Sx_adj[:, {0}])) = {2:.3g}\n"
         "sum(abs(Sx_adj[:, {0}])) = {3:.3g}\n"
         "max(abs(Sx_adj[:, {0}])) / sum(window_adj) = {4:.3g} -- time norm\n"
         "sum(abs(Sx_adj[:, {0}])) / sum(abs(fft(window_adj_f))) = {5:.3g} "
         "-- freq norm\n"
         ).format(
             tm_idx,
             abs(x).max(),
             col_max,
             col_sum,
             col_max / sum(w),
             col_sum / sum(abs(fft(wf))),
        )
    )
83 |
84 |
85 | def _pad_window(w, padded_len):
86 | pleft = (padded_len - len(w)) // 2
87 | pright = padded_len - pleft - len(w)
88 | return np.pad(w, [pleft, pright])
89 |
def get_adj_windows(window, n_fft, N):
    """Return `window` zero-padded to two lengths and re-centered.

    Returns
    -------
    window_adj : array of length `N + n_fft - 1`
    window_adj_f : array of length `n_fft`
    """
    def _recenter(arr):
        # shortcut for later examples to spare code: ifftshift, roll by the
        # index of the maximum, then fftshift back.
        # NOTE(review): the roll is by +argmax, which is a no-op when the peak
        # already lands on index 0 after ifftshift -- confirm the intent for
        # windows whose peak is elsewhere.
        unshifted = ifftshift(arr)
        peak_idx = np.argmax(unshifted)
        return fftshift(np.roll(unshifted, peak_idx))

    conv_len = N + n_fft - 1
    padded_conv = _pad_window(window, conv_len)
    padded_fft = _pad_window(window, n_fft)
    return _recenter(padded_conv), _recenter(padded_fft)
106 |
107 | N = 256
108 | t = np.linspace(0, 1, N, 0)
109 |
110 | #%%############################################################################
111 | # Everything's right
112 | # ^^^^^^^^^^^^^^^^^^
113 | n_fft = N
114 | window = dpss(N//8, N//8//2 - 1, sym=False)
115 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N)
116 |
117 | f = 60
118 | x = np.cos(2*np.pi * f * t)
119 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype)
120 | Sx_adj = Sx.copy()
121 | Sx_adj[1:-1] *= 2
122 |
123 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{f}*t)")
124 |
125 | #%%############################################################################
126 | # Nonstationary case
127 | # ^^^^^^^^^^^^^^^^^^
128 | f = 60
129 | x = np.cos(2*np.pi * f * t**2)
130 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype)
131 | Sx_adj = Sx.copy()
132 | Sx_adj[1:-1] *= 2
133 |
134 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{f}*t**2)")
135 |
136 | #%%############################################################################
137 | # Too close to Nyquist
138 | # ^^^^^^^^^^^^^^^^^^^^
139 | f = N//2 - 10
140 | x = np.cos(2*np.pi * f * t)
141 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype)
142 | Sx_adj = Sx.copy()
143 | Sx_adj[1:-1] *= 2
144 |
145 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{f}*t)", 32)
146 |
147 | #%%############################################################################
148 | # Too close to DC
149 | # ^^^^^^^^^^^^^^^
150 | f = 10
151 | x = np.cos(2*np.pi * f * t)
152 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype, dtype='float64')
153 | Sx_adj = Sx.copy()
154 | Sx_adj[1:-1] *= 2
155 |
156 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{f}*t)", 32)
157 |
158 | #%%############################################################################
159 | # Multi-component intersection
160 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
161 | f, fm = 60, 50
162 | x = np.cos(2*np.pi * f * t) + np.cos(2*np.pi * fm * t**2)
163 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype)
164 | Sxs = Sx.copy()
165 | Sxs[1:-1] *= 2
166 |
167 | viz2(Sxs, window_adj, window_adj_f, f"\ncos(2*pi*{f}*t) + cos(2*pi*{fm}*t**2)")
168 |
169 | #%%############################################################################
170 | # Insufficient `n_fft`
171 | # ^^^^^^^^^^^^^^^^^^^^
172 | f = 60
173 | x = np.cos(2*np.pi * f * t)
174 | n_fft = 36
175 | window = hann(n_fft, sym=False)
176 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N)
177 |
178 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype, dtype='float64')
179 | Sx_adj = Sx.copy()
180 | Sx_adj[1:-1] *= 2
181 |
182 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{f}*t)")
183 |
184 | #%%############################################################################
185 | # Excessive `hop_size`
186 | # ^^^^^^^^^^^^^^^^^^^^
187 | fc, fa = 60, 2
188 | x = np.sin(2*np.pi * fa * t) * np.cos(2*np.pi * fc * t)
189 | n_fft = len(x)
190 | window = dpss(n_fft, n_fft//2 - 1, sym=False)
191 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N)
192 |
193 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype, dtype='float64',
194 | hop_len=64)
195 | Sx_adj = Sx.copy()
196 | Sx_adj[1:-1] *= 2
197 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{fc}*t) * sin(2*pi*{fa}*t)",
198 | 2)
199 |
200 | #%%############################################################################
201 | # Minimal `hop_size`
202 | # ^^^^^^^^^^^^^^^^^^
203 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype, dtype='float64',
204 | hop_len=1)
205 | Sx_adj = Sx.copy()
206 | Sx_adj[1:-1] *= 2
207 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{fc}*t) * sin(2*pi*{fa}*t)",
208 | 96)
209 |
210 | #%%############################################################################
211 | # Non-localized window
212 | # ^^^^^^^^^^^^^^^^^^^^
213 | np.random.seed(0)
214 | n_fft = N
215 | window = np.abs(np.random.randn(n_fft//4) + .01)
216 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N)
217 |
218 | f = 60
219 | x = np.cos(2*np.pi * f * t)
220 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype)
221 | Sxs = Sx.copy()
222 | Sxs[1:-1] *= 2
223 |
224 | viz2(Sxs, window_adj, window_adj_f, f"cos(2*pi*{f}*t)")
225 |
--------------------------------------------------------------------------------
/SignalProcessing/Q86937 - STFT - Equivalence of Windowed Fourier and Convolutions/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/85745/50076
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from scipy.signal.windows import dpss
6 | from numpy.fft import fft, ifft, fftshift, ifftshift
7 |
8 | from ssqueezepy import stft
9 | from ssqueezepy.visuals import plot, plotscat, imshow
10 |
11 | #%%############################################################################
12 | # Helpers
13 | # -------
14 | def _pad_window(w, padded_len):
15 | pleft = (padded_len - len(w)) // 2
16 | pright = padded_len - pleft - len(w)
17 | return np.pad(w, [pleft, pright])
18 |
def get_adj_windows(window, n_fft, N):
    """Pad `window` to the convolution length and to `n_fft`, then re-center.

    Returns `(window_adj, window_adj_f)` with lengths `N + n_fft - 1` and
    `n_fft` respectively.
    """
    def _recenter(arr):
        # shortcut for later examples to spare code: ifftshift, roll by the
        # index of the maximum, then fftshift back.
        # NOTE(review): rolling by +argmax only leaves the array unchanged when
        # the peak is already at index 0 after ifftshift -- confirm the intent
        # for other peak positions.
        shifted = ifftshift(arr)
        return fftshift(np.roll(shifted, np.argmax(shifted)))

    padded_len_conv = N + n_fft - 1
    conv_window = _pad_window(window, padded_len_conv)
    fft_window = _pad_window(window, n_fft)

    window_adj = _recenter(conv_window)
    window_adj_f = _recenter(fft_window)
    return window_adj, window_adj_f
35 |
36 | #%%############################################################################
37 | # Row-wise STFT implementation
38 | # ----------------------------
def cisoid(N, f):
    """Return `cos + 1j*sin` at frequency `f`, on `N` samples of `t` in [0, 1)."""
    t = np.linspace(0, 1, N, endpoint=False)
    phase = 2*np.pi * f * t
    real_part = np.cos(phase)
    imag_part = np.sin(phase)
    return real_part + imag_part * 1j
43 |
44 |
def stft_rowwise(x, window, n_fft):
    """
    STFT computed row by row, as a circular convolution with a filterbank.

      - no hop or pad support
      - equivalent to `'wrap'` (periodic) pad if `len(window) < len(x)//2`
      - real-valued `x` only
      - returns `Sx`, `fbank_f`
    """
    assert len(window) == n_fft and n_fft <= len(x)

    # filters are zero-padded up to the signal length
    sig_len = len(x)
    n_pad = sig_len - len(window)

    # pad such that `argmax(window)` for `Sx[:, 0]` is at `Sx[:, 0]`
    # (DFT-center the convolution kernel)
    # note, due to how scipy generates windows and standard stft handles padding,
    # this still won't yield zero-phase for majority of signals, but it's both
    # fixable and can be very close; ideally just pass in `len(window)==len(x)`.
    half_pad = n_pad / 2
    if n_fft % 2 == 1:
        pad_right = int(np.floor(half_pad))
    else:
        pad_right = int(np.ceil(half_pad))
    pad_left = n_pad - pad_right

    # generate filterbank: one modulated window per frequency 0..n_fft//2
    n_rows = n_fft//2 + 1
    modulators = np.array([cisoid(n_fft, k) for k in range(n_rows)])
    kernels = ifftshift(window) * modulators
    # `fftshift`/`ifftshift` are called without `axes`, so they also shift the
    # row axis; the row count is the same for both calls, so that part cancels
    kernels = np.pad(fftshift(kernels), [[0, 0], [pad_left, pad_right]])
    kernels = ifftshift(kernels)
    fbank_f = fft(kernels, axis=-1).conj()

    # circular convolution, done as a product in the frequency domain
    Sx = ifft(fbank_f * fft(x)[None], axis=-1)
    return Sx, fbank_f
80 |
81 |
82 | #%% first, validate correctness
83 | for N in (128, 129):
84 | x = np.random.randn(N)
85 | for n_fft in (55, 56):
86 | window = np.abs(np.random.randn(n_fft) + .01)
87 |
88 | Sx0 = stft(x, window, n_fft=n_fft, modulated=True,
89 | padtype='wrap', hop_len=1, dtype='float64')
90 | Sx1, _ = stft_rowwise(x, window, n_fft)
91 |
92 | assert np.allclose(Sx0, Sx1)
93 |
94 | #%% Demo filterbanks
N = 256
x = np.random.randn(N)
pkw = dict(color='tab:blue', w=1, h=.6)
# name the window parameters so the plot title reports what is actually
# generated (the second argument evaluates to 9, not the 9.5 the title claimed)
win_len, win_nw = N//6, N//6//4 - 1
window = dpss(win_len, win_nw, sym=False)

plotscat(window, show=1,
         title=f"scipy.signal.windows.dpss({win_len}, {win_nw}, sym=False)")
101 |
102 | # full spectrum
103 | n_fft = N
104 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N)
105 | _, fbank_f0 = stft_rowwise(x, window_adj_f, n_fft)
106 | plot(fbank_f0.T.real, title="fbank_f0 | n_fft=N", **pkw, show=1)
107 |
108 | # integer subset
109 | n_fft = N//4
110 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N)
111 | _, fbank_f1 = stft_rowwise(x, window_adj_f, n_fft)
112 | assert np.allclose(fbank_f1, fbank_f0[::4])
113 | plot(fbank_f1.T.real, title="fbank_f1 = fbank_f0[::4] | n_fft=N/4", **pkw,
114 | show=1)
115 |
116 | # fractional subset
117 | n_fft = int(N/5.5)
118 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N)
119 | _, fbank_f2 = stft_rowwise(x, window_adj_f, n_fft)
120 | plot(fbank_f2.T.real,
121 | title="fbank_f2, fbank_f1 | n_fft=floor(N/5.5), xlims=(100, 150)", **pkw)
122 | plot(fbank_f1.T.real, show=1, color='tab:orange', xlims=(100, 150))
123 |
124 | # time-domain examples
def plot_complex(pt, f):
    """Plot the complex filter `pt`, then overlay its modulus as a dashed line.

    `f` is only used to label the figure title.
    """
    plot(pt, complex=1)

    modulus_style = dict(abs=1, linestyle='--', color='k', w=.5, h=.8)
    plot(pt, title=f"STFT filter | f={f}", show=1, ylims=(-1.03, 1.03),
         **modulus_style)
129 |
130 | f0, f1 = 20, 40
131 | pt0 = ifftshift(ifft(fbank_f0[f0]))
132 | pt1 = ifftshift(ifft(fbank_f0[f1]))
133 | plot_complex(pt0, f0)
134 | plot_complex(pt1, f1)
135 |
136 | #%% Demo sines
137 | N = 128
138 | t = np.linspace(0, 1, 128, 1)
139 | x = np.cos(2*np.pi * t * 4)
140 | x += np.cos(2*np.pi * t * 20)
141 | x += np.cos(2*np.pi * t * 40)
142 | x += np.cos(2*np.pi * t * 60)
143 |
144 | window = dpss(128, 128//8, sym=False)
145 |
146 | for n_fft in (128, 256):
147 | Sx0 = stft(x, modulated=0, n_fft=n_fft, window=window)[::-1]
148 | Sx1 = stft(x, modulated=1, n_fft=n_fft, window=window)[::-1]
149 |
150 | fig, axes = plt.subplots(1, 2, sharey=True, figsize=(14, 6),
151 | layout='constrained')
152 | imshow(Sx0.real, fig=fig, ax=axes[0], show=0,
153 | title=f"STFT.real, n_fft={n_fft} | standard")
154 | imshow(Sx1.real, fig=fig, ax=axes[1], show=1,
155 | title=f"STFT.real, n_fft={n_fft} | improved")
156 |
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004323.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004323.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004324.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004324.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004326.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004326.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004328.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004328.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004330.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004330.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004331.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004331.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004333.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004333.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004335.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004335.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004338.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004338.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004339.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004339.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_T_00043244.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_T_00043244.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_T_00043245.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_T_00043245.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_T_00043246.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_T_00043246.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/example.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/example.wav
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im00.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im01.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im02.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im03.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im04.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im05.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im06.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im07.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im07.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im08.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im09.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im09.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im10.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im11.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im12.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im13.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/preds.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/preds.gif
--------------------------------------------------------------------------------
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/test_algo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/87355/50076
3 | # USER CONFIGS
4 | TOLERANCE = 0.05
5 | VIZ_PREDS = 1
6 | GPU = 0
7 | FULL_GPU = 0
8 | PRINT_TIMES = 1
9 |
10 | import os
11 | os.environ['SSQ_GPU'] = '1' if GPU else '0'
12 | os.environ['FULL_GPU'] = '1' if FULL_GPU else '0'
13 |
14 | import numpy as np
15 | from utils87355 import (
16 | find_audio_change_timestamps, handle_wavelet, pad_input_make_t,
17 | load_data, data_labels, data_dir,
18 | )
19 | from timeit import default_timer as dtime
20 |
21 | #%%############################################################################
22 | # Configure
23 | # ---------
24 | cfg = dict(
25 | wavelet='freq',
26 | stft_prefilter_scaling=1,
27 | carve_th_scaling=1/3,
28 | fmax_idx_frac=400/471,
29 | silence_th_scaling=50,
30 | final_pred_n_peaks=1,
31 | escaling=(1, 1, 1),
32 | )
33 |
34 | #%%############################################################################
35 | # Make reusables
36 | # --------------
37 | Nmax = max(len(load_data(i)[0]) for i in range(len(data_labels)))
38 | fs = load_data(0)[1] # assumes same for all
39 | Mmax = len(pad_input_make_t(np.arange(Nmax), fs)[0])
40 | N_ref = Nmax
41 | reusables = handle_wavelet(
42 | wavelet=cfg['wavelet'],
43 | fs=fs,
44 | M=Mmax,
45 | ssq_precfg=dict(scales='log', padtype=None),
46 | fmax_idx_frac=cfg['fmax_idx_frac'],
47 | silence_interval_samples=int(.2*fs),
48 | escaling=cfg['escaling'],
49 | )
50 | other_cfg = dict(reusables=reusables, fs=fs, N_ref=N_ref)
51 |
52 | #%%############################################################################
53 | # Run
54 | # ---
55 | results_train = {'preds': []}
56 | results_test = {'preds': []}
57 |
58 | idxs_train = tuple(range(0, 4))
59 | idxs_test = tuple(range(4, len(data_labels)))
60 | idxs_all = idxs_train + idxs_test
61 |
62 | # get predictions
63 | for example_idx in idxs_all:
64 | # load data
65 | x, fs, labels = load_data(example_idx)
66 |
67 | # make predictions
68 | viz_labels = ((example_idx, labels) if VIZ_PREDS else
69 | None)
70 | if PRINT_TIMES:
71 | t0 = dtime()
72 | preds = find_audio_change_timestamps(
73 | x, **other_cfg, **cfg, viz_labels=viz_labels)[0]
74 | if PRINT_TIMES:
75 | print("%.3g" % (dtime() - t0) + " sec", flush=True)
76 | else:
77 | print(end=".", flush=True)
78 |
79 | # append
80 | d = (results_train if example_idx in idxs_train else
81 | results_test)
82 | d['preds'].append(preds)
83 |
84 | #%%############################################################################
85 | # Calculate score
86 | # ---------------
def is_success(preds, label, tolerance=TOLERANCE):
    """Return True if any prediction is strictly within `tolerance` of `label`.

    `preds` may be a list/tuple of predictions; anything else is treated as
    a single prediction.
    """
    candidates = preds if isinstance(preds, (list, tuple)) else [preds]
    for pred in candidates:
        if abs(pred - label) < tolerance:
            return True
    return False
91 |
92 | for d in (results_train, results_test):
93 | for k in ('scores', 'scores_flattened'):
94 | d[k] = []
95 |
96 | tolerance = TOLERANCE
97 |
98 | printed_line = False
99 | for example_idx in idxs_all:
100 | d = (results_train if example_idx in idxs_train else
101 | results_test)
102 | pred_idx = (example_idx if example_idx in idxs_train else
103 | example_idx - len(idxs_train))
104 | preds = d['preds'][pred_idx]
105 | labels = load_data(example_idx)[-1]
106 |
107 | if len(labels) == 2:
108 | score0 = is_success(preds, labels[0], tolerance)
109 | score1 = is_success(preds, labels[1], tolerance)
110 | scores_packed = (score0, score1)
111 | else:
112 | score0 = is_success(preds[0], labels[0], tolerance)
113 | scores_packed = (score0,)
114 |
115 | # append
116 | d['scores'].append(scores_packed)
117 | d['scores_flattened'].extend(scores_packed)
118 | # print(end=".", flush=True)
119 |
120 | if example_idx not in idxs_train and not printed_line:
121 | print("="*80)
122 | printed_line = True
123 |
124 | print()
125 | preds = sorted([float("%.3g" % p) for p in preds])
126 | print(tuple(np.array(list(scores_packed)).astype(int)), '--', example_idx)
127 | print(tuple(preds))
128 | print(tuple(labels))
129 |
130 | # finalize
131 | accuracy_train = np.mean(results_train['scores_flattened'])
132 | accuracy_test = np.mean(results_test['scores_flattened'])
133 |
134 | print("Accuracy (train, test): {:.3f}, {:.3f}".format(
135 | accuracy_train, accuracy_test))
136 |
137 | #%%
138 | # from ? import make_gif
139 | # from pathlib import Path
140 | # make_gif(data_dir, str(Path(data_dir, "preds.gif")),
141 | # duration=1000, overwrite=True, delimiter="im", HD=True, verbose=True)
142 |
143 | #%%############################################################################
144 | # Last run's output
145 | # -----------------
146 | """
147 | (1, 1) -- 0
148 | (1.56, 3.54)
149 | (1.55, 3.5)
150 |
151 | (1, 1) -- 1
152 | (1.22, 1.74)
153 | (1.21, 1.74)
154 |
155 | (1, 1) -- 2
156 | (0.958, 1.57)
157 | (0.94, 1.57)
158 |
159 | (1, 1) -- 3
160 | (1.42, 1.87)
161 | (1.42, 1.85)
162 | ================================================================================
163 |
164 | (1,) -- 4
165 | (0.783, 1.12)
166 | (0.76,)
167 |
168 | (1, 1) -- 5
169 | (0.818, 1.69)
170 | (0.79, 1.68)
171 |
172 | (1, 1) -- 6
173 | (2.89, 3.3)
174 | (2.87, 3.28)
175 |
176 | (0,) -- 7
177 | (0.488, 0.738)
178 | (0.75,)
179 |
180 | (1,) -- 8
181 | (0.623, 1.77)
182 | (0.63,)
183 |
184 | (1,) -- 9
185 | (0.211, 0.468)
186 | (0.46,)
187 |
188 | (1,) -- 10
189 | (4.69, 4.97)
190 | (4.97,)
191 |
192 | (1,) -- 11
193 | (0.543, 0.956)
194 | (0.94,)
195 |
196 | (1, 1) -- 12
197 | (2.35, 3.0)
198 | (2.37, 2.99)
199 |
200 | (1, 0) -- 13
201 | (2.06, 2.44)
202 | (2.03, 2.73)
203 | Accuracy (train, test): 1.000, 0.857
204 | """
205 |
--------------------------------------------------------------------------------
/SignalProcessing/Q87706 - aliasing - How to measure aliasing/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/87705/50076
3 | import numpy as np
4 | from numpy.fft import fft, ifft
5 | import matplotlib.pyplot as plt
6 | import scipy.signal
7 |
try:
    from ssqueezepy.visuals import plotscat as _plotscat, plot
except Exception:
    # Plotting is optional: fall back to no-op stand-ins that accept any
    # arguments. `Exception` (rather than a bare `except`) keeps the
    # best-effort behavior without swallowing KeyboardInterrupt/SystemExit.
    print("Plotting requires ssqueezepy; won't plot")
    _plotscat = lambda *a, **k: 1
    plot = lambda *a, **k: 1
14 |
def plotscat(*a, **k):
    """Wrap `_plotscat`, supplying default `ylims` derived from the data.

    The first positional argument is the data; if `abs=` is truthy its
    modulus is plotted. Unless the caller passes `ylims`, the limits are
    symmetric about zero when the data minimum is below -0.03, else they
    start at zero; the upper bound is 3% above the largest magnitude.
    """
    data, rest = a[0], a[1:]
    if k.get('abs', 0):
        data = abs(data)

    lo, hi = data.min(), data.max()
    peak = 1.03*max(abs(lo), abs(hi))
    if 'ylims' not in k:
        if lo < -.03:
            k['ylims'] = (-peak, peak)
        else:
            k['ylims'] = (0, peak)
    _plotscat(data, *rest, **k)
24 |
25 |
def energy(x):
    """Return the energy of `x`: the sum of its squared magnitudes."""
    return np.sum(abs(x)**2)


def get_xf_sub_worst(xf, M):
    """Return the spectrum of subsampling by `M`, with all folds adding up.

    Real and imaginary parts are made non-negative before the `M` spectral
    segments are averaged, so segments cannot cancel one another.
    """
    return (abs(xf.real) + 1j*abs(xf.imag)).reshape(M, -1).mean(axis=0)


def measure(xf, M):
    """Return an aliasing score (in percent) for subsampling `xf` by `M`.

    Parameters
    ----------
    xf : array-like, 1D
        Spectrum (DFT) to score. It is not modified.
    M : int
        Subsampling factor; `len(xf)` must be divisible by `M`.

    Returns
    -------
    alias : float
        `100 * (r - 1) / (M - 1)` rounded to 3 decimals, where `r` is the
        ratio of `xf`'s energy to that of an ideal box spectrum (ones over
        the lowest `len(xf)/M` bins) after subsampling, divided by the same
        ratio before subsampling. The box itself scores 0; an all-ones
        spectrum scores 100.
    """
    xf = np.asarray(xf)
    N = len(xf)
    hbw = N//M//2
    xf_ref = np.zeros(N)
    xf_ref[:hbw] = 1
    xf_ref[-hbw:] = 1

    # normalize into a new array: an in-place `/=` would rescale the caller's
    # spectrum and fail outright for integer-typed input
    xf = xf / abs(xf).max()

    xf_sub = get_xf_sub_worst(xf, M)
    xf_ref_sub = xf_ref.reshape(M, -1).mean(axis=0)

    ratio_before = energy(xf) / energy(xf_ref)
    ratio_after = energy(xf_sub) / energy(xf_ref_sub)
    r = ratio_after / ratio_before

    alias = round(100*(r - 1) / (M - 1), 3)
    return alias
54 |
55 | #%% Make signals
56 | N = 256
57 | M = 8
58 |
59 | # populate unaliased signal
60 | hbw = N//M//2
61 | xf_ref = np.zeros(N)
62 | xf_ref[:hbw] = 1
63 | xf_ref[-hbw:] = 1
64 |
65 | # populate referenced signal
66 | xf_full = np.ones(N)
67 |
68 | # make in-between signal
69 | xf0 = xf_ref.copy()
70 | xf0[hbw*2] = 1
71 | xf0[-(hbw*2)] = 1
72 |
73 | #%% Plot
def viz(xf, name, aval=0, do_measure=0, worst=0, cval=0, ylims01=True,
        show_ref=False):
    """Plot spectrum `xf` (left) next to its `M`-subsampled version (right).

    `worst` selects `get_xf_sub_worst` over a plain average of the `M`
    segments; `do_measure` puts the `measure` score in the right title
    instead of the subsampled peak magnitude; `aval`/`cval` are forwarded
    as `abs`/`complex`; `ylims01` pins the y-limits to (0, 1.03);
    `show_ref` adds a horizontal line at `1/M` on the right plot.
    """
    fig, axes = plt.subplots(1, 2, layout='constrained', figsize=(12, 5))
    plot_kw = {'fig': fig, 'abs': aval, 'complex': cval}
    if ylims01:
        plot_kw['ylims'] = (0, 1.03)

    sub = (get_xf_sub_worst(xf, M) if worst else
           xf.reshape(M, -1).mean(axis=0))

    bar = "|" if aval else ""
    if do_measure:
        info = "alias={:.3g}%".format(measure(xf, M))
    else:
        info = "abs_max={:.3g}".format(max(abs(sub)))
    title_full = "{}fft({}){}".format(bar, name, bar)
    sub_fmt = ("{}fft({}_sub{}_worst){} -- {}" if worst else
               "{}fft({}[::{}]){} -- {}")
    title_sub = sub_fmt.format(bar, name, M, bar, info)

    plotscat(xf, **plot_kw, ax=axes[0], title=title_full)
    if show_ref:
        plot_kw['hlines'] = (1/M, {'color': 'tab:red', 'linewidth': 1})
    plotscat(sub, **plot_kw, ax=axes[1], title=title_sub)
    plt.show()
98 |
99 | viz(xf_ref, "x_ref")
100 | viz(xf_full, "x_full")
101 | viz(xf0, "x0")
102 |
103 | #%% Applying the metric
def scipy_decimate_filter(N, M):
    """Minimally reproduce the filter used by
    `scipy.signal.decimate(ftype='fir')`. This is rigorously tested elsewhere.

    Builds a Hamming-windowed sinc with `20*M + 1` taps and cutoff `1/M`,
    normalized to unit sum, zero-padded to length `N` and rotated so that
    its peak sits at index 0.
    """
    numtaps = int(2*(10*M)) + 1
    cutoff = 1. / M
    window = scipy.signal.get_window("hamming", numtaps, fftbins=False)

    # windowed sinc sampled symmetrically about the middle tap
    offsets = np.arange(0, numtaps) - 0.5 * (numtaps - 1)
    taps = window * cutoff * np.sinc(cutoff * offsets)
    taps /= taps.sum()  # L1 norm

    # pad to `N`, then move the peak to the first sample
    taps = np.pad(taps, [0, N - len(taps)])
    return np.roll(taps, -np.argmax(taps))
126 |
127 | # moving average
128 | h = np.zeros(N)
129 | h[:M//2 + 1] = 1/M
130 | h[-(M//2 - 1):] = 1/M
131 | hf = fft(h)
132 |
133 | # scipy
134 | h_scipy = scipy_decimate_filter(N, M)
135 | hf_scipy = fft(h_scipy)
136 |
137 | #%% Show results
138 | ckw = dict(aval=1, do_measure=1, worst=1, show_ref=1, ylims01=0)
139 | viz(hf, "h", **ckw)
140 | viz(hf_scipy, "h_scipy", **ckw)
141 |
142 | #%% Effect
def dft_upsample(xf, M):
    """Upsample by `M` in the DFT domain via zero-insertion at Nyquist.

    Parameters
    ----------
    xf : array-like, 1D
        DFT of the signal; its length `L` must be even.
    M : int
        Upsampling factor, `>= 1`.

    Returns
    -------
    np.ndarray
        Length `L*M` spectrum, scaled by `M`. The Nyquist bin is split in
        half between the positive and negative side, with `L*M - L - 1`
        zeros in between.
    """
    xf = np.asarray(xf)
    if M == 1:
        # nothing to insert
        return xf.copy()

    L = len(xf)
    half = L//2
    nyq = xf[half]
    zeros = np.zeros(L*M - L - 1)
    # `xf[half + 1:]` is the negative-frequency side; unlike
    # `xf[-(half - 1):]` it is correctly empty when `L == 2`
    return np.hstack([xf[:half],
                      nyq/2, zeros, np.conj(nyq)/2,
                      xf[half + 1:]]) * M
151 |
def rel_l2(x0, x1):
    """Return the L2 distance between `x0` and `x1` relative to `x0`'s norm."""
    residual = np.linalg.norm(x0 - x1)
    reference = np.linalg.norm(x0)
    return residual / reference
154 |
def generate_case(M, seed=0):
    """Filter seeded white noise with `hf` and `hf_scipy`, subsample, recover.

    Returns `(x0, x1, x0_nosub, x1_nosub)`: the signals recovered after
    subsampling by `M` and DFT-upsampling (`hf`, `hf_scipy` respectively),
    followed by the references that were never subsampled. The `hf`
    reference additionally has bins `hbw:-hbw` zeroed.
    """
    np.random.seed(seed)
    noise_f = fft(np.random.randn(N))
    filtered0 = noise_f * hf
    filtered1 = noise_f * hf_scipy

    def recover(spec):
        # subsample in frequency (average the M segments), then upsample back
        folded = spec.reshape(M, -1).mean(axis=0)
        return ifft(dft_upsample(folded, M))

    x0 = recover(filtered0)
    x1 = recover(filtered1)

    # zeroing is done in place on `filtered0`, which is no longer needed
    filtered0[hbw:-hbw] = 0
    x0_nosub = ifft(filtered0)
    x1_nosub = ifft(filtered1)

    return x0, x1, x0_nosub, x1_nosub
174 |
175 | #%% Viz
176 | x0, x1, x0_nosub, x1_nosub = generate_case(M, seed=42)
177 |
178 | fig, axes = plt.subplots(1, 2, layout='constrained', figsize=(12, 5),
179 | sharey=True)
180 | pkw = dict(fig=fig)
181 | plot(x0_nosub.real, ax=axes[0])
182 | plot(x0.real, ax=axes[0], title="hf: recovered")
183 | plot(x1_nosub.real, ax=axes[1])
184 | plot(x1.real, ax=axes[1], title="hf_scipy: recovered")
185 | print(rel_l2(x0, x0_nosub), rel_l2(x1, x1_nosub), sep='\n')
186 |
187 | #%% Survey
188 | dist_hf, dist_hf_scipy = [], []
189 | for seed in range(1000000):
190 | x0, x1, x0_nosub, x1_nosub = generate_case(M, seed)
191 | dist_hf.append(rel_l2(x0, x0_nosub))
192 | dist_hf_scipy.append(rel_l2(x1, x1_nosub))
193 | dist_hf, dist_hf_scipy = [np.array(d) for d in (dist_hf, dist_hf_scipy)]
194 |
195 | #%% Viz survey
196 | plot(dist_hf)
197 | plot(dist_hf_scipy)
198 |
199 | fmt = ("(min, max, mean) = ({:.3g}, {:.3g}, {:.3g})\n"
200 | "(min_idx, max_idx) = ({}, {})\n")
201 | ops = (np.min, np.max, np.mean, np.argmin, np.argmax)
202 | print(fmt.format(*[op(dist_hf) for op in ops]))
203 | print(fmt.format(*[op(dist_hf_scipy) for op in ops]))
204 |
--------------------------------------------------------------------------------
/SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid_target.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid_target.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid_target2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid_target2.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid_template.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid_template.png
--------------------------------------------------------------------------------
/SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/87774/50076
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import scipy.signal
6 | from numpy.fft import fft2, ifft2, ifftshift
7 | from PIL import Image
8 | from matplotlib.patches import Circle
9 |
def rand(*s, ival=True):
    """Draw standard-normal noise of shape `s`; purely imaginary if `ival`."""
    noise = np.random.randn(*s)
    if ival:
        return 1j*noise
    return noise
13 |
def cross_correlate_2d(x, h):
    """Circular 2D cross-correlation of `x` with a centered template `h`.

    `h` is first shifted so its center moves to index (0, 0); the result is
    computed as the inverse FFT of `fft2(x) * conj(fft2(h))`.
    """
    h_centered = ifftshift(h, axes=(0, 1))
    spectrum = fft2(x) * np.conj(fft2(h_centered))
    return ifft2(spectrum)
17 |
def load_image(path):
    """Load the image at `path` as greyscale in [0, 1], mapping pure white to 0."""
    grey = Image.open(path).convert("L")
    img = np.array(grey) / 255.
    # pixels equal to 1 (white) are treated as empty background
    return np.where(img == 1, 0., img)
22 |
def imshow(x, title=None, show_argmax=False, fig=None, ax=None, show=True):
    """Show `|x|` as an image, optionally circling its maximum.

    Parameters
    ----------
    x : np.ndarray, 2D
        Image; its modulus is displayed.
    title : str / None
        Bold, left-aligned title, if given.
    show_argmax : bool
        Whether to draw a red circle around the largest `|x|`. With tied
        maxima, the first one in row-major order is used.
    fig, ax : matplotlib Figure / Axes / None
        Where to draw. A new figure is made if neither is given; if only
        `fig` is given, its current axes are used.
    show : bool
        Whether to call `plt.show()` at the end.
    """
    xa = np.abs(x)
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots()
        else:
            ax = fig.gca()
    ax.imshow(xa)
    ax.set_xticks([])
    ax.set_yticks([])
    if title is not None:
        ax.set_title(title, weight='bold', loc='left', fontsize=20)
    if show_argmax:
        # scalar indices of a single maximum; `np.where(xa == xa.max())`
        # would return arrays, of length > 1 when the maximum is tied
        hmax, wmax = np.unravel_index(np.argmax(xa), xa.shape)
        size = len(x) // 7 // 2
        circ = Circle((wmax, hmax), size, fill=False, color='tab:red',
                      linewidth=2)
        ax.add_patch(circ)
    if show:
        plt.show()
40 |
41 |
def run_example(x, h, show_h=False):
    """Cross-correlate `x` with `h` via scipy and via FFT, and plot both.

    The input(s) are plotted first (`x` alone, or `x` above `h` when
    `show_h`), then each output with its maximum circled.
    """
    # compute
    results = (
        ("scipy.signal.correlate2d", scipy.signal.correlate2d(x, h)),
        ("cross_correlate_2d", cross_correlate_2d(x, h)),
    )

    # plot inputs
    if not show_h:
        imshow(x, "|x|; x = image + iWGN/5")
    else:
        fig, axes = plt.subplots(2, 1, figsize=(5.7, 5.7), layout='constrained')
        imshow(x, "|x|, |h|", fig=fig, ax=axes[0], show=False)
        imshow(h, fig=fig, ax=axes[1])

    # plot outputs
    nm = "x, h" if show_h else "x, x"
    for fn_name, out in results:
        imshow(out, f"|{fn_name}({nm})|", show_argmax=True)
58 |
59 |
60 | #%% Example: self-cc #########################################################
61 | # load image as greyscale
62 | np.random.seed(0)
63 | x = load_image("covid.png")
64 |
65 | # subsample & add noise
66 | x = x[::4, ::4]
67 | x = x + rand(*x.shape) / 5
68 |
69 | run_example(x, x)
70 |
71 | #%% Example: flipped false positive ##########################################
72 | np.random.seed(0)
73 | x = load_image("covid_target.png")
74 | h = load_image("covid_template.png")
75 |
76 | x, h = x[::9, ::9], h[::9, ::9]
77 | x = x + rand(*x.shape) / 10
78 | h = h + rand(*h.shape) / 10
79 |
80 | run_example(x, h, show_h=True)
81 |
82 | #%% Example: conjugated false positive########################################
83 | np.random.seed(0)
84 | x = load_image("covid_target2.png")
85 | h = load_image("covid_template.png")
86 |
87 | x, h = 1j*x[::9, ::9], 1j*h[::9, ::9]
88 | M, N = x.shape
89 | x[:, N//2:] = np.conj(x[:, N//2:])
90 |
91 | x = x + rand(*x.shape) / 10 * 1j
92 | h = h + rand(*h.shape) / 10 * 1j
93 |
94 | run_example(x, h, show_h=True)
95 |
--------------------------------------------------------------------------------
/SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/validation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
def cross_correlate_2d(x, h):
    """Circular 2D cross-correlation of `x` with a centered template `h`."""
    # move the template's center to the origin, one axis at a time
    template = h
    for axis in (0, 1):
        template = ifftshift(template, axes=axis)
    template_f = fft2(template)
    return ifft2(fft2(x) * template_f.conj())
5 |
6 | ##############################################################################
7 | import numpy as np
8 | from numpy.fft import fft2, ifft2, ifftshift
9 |
def crand(*s):
    """Draw complex standard-normal noise of shape `s` (real part first)."""
    re = np.random.randn(*s)
    im = np.random.randn(*s)
    return re + 1j*im
12 |
13 | for M in (64, 65, 99):
14 | for N in (64, 65, 99):
15 | # case 1 -------------------------------------------------------------
16 | x = crand(M, N)
17 |
18 | o = cross_correlate_2d(x, x)
19 | hmax, wmax = np.where(abs(o) == abs(o).max())
20 | assert hmax == M//2, (hmax, M//2)
21 | assert wmax == N//2, (wmax, N//2)
22 |
23 | # case 2 -------------------------------------------------------------
24 | x = np.zeros((M, N), dtype='complex128')
25 | h = np.zeros((M, N), dtype='complex128')
26 |
27 | hctr, wctr = M//8, N//8
28 | hsize, wsize = hctr*2, wctr*2
29 | target_loc = slice(0, hsize), slice(0, wsize)
30 | false_positive_loc = slice(M//2, M//2 + hsize), slice(-wsize, None)
31 | target = crand(hsize, wsize)
32 |
33 | x[target_loc] = target
34 | x[false_positive_loc] = target[::-1, ::-1]
35 | h[M//2 - hctr:M//2 + hctr, N//2 - wctr:N//2 + wctr] = target
36 |
37 | o = cross_correlate_2d(x, h)
38 | hmax, wmax = np.where(abs(o) == abs(o).max())
39 | assert hmax == hctr, (hmax, hctr)
40 | assert wmax == wctr, (wmax, wctr)
41 |
42 | # case 3 -------------------------------------------------------------
43 | x[false_positive_loc] = np.conj(target)
44 |
45 | o = cross_correlate_2d(x, h)
46 | hmax, wmax = np.where(abs(o) == abs(o).max())
47 | assert hmax == hctr, (hmax, hctr)
48 | assert wmax == wctr, (wmax, wctr)
49 |
--------------------------------------------------------------------------------
/SignalProcessing/Q87781 - cross-correlation, FFT - 2D Cross-Correlation using FFT in Python/cc2d.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/87781/50076
3 | import numpy as np
4 | from scipy.fft import next_fast_len, fft2, ifft2
5 |
6 |
def _cc2d_plan(x_shape, h_shape, mode):
    """Return `(swap, padded_shape, unpad_h, unpad_w)` for the given shapes."""
    if mode not in ('full', 'same', 'valid'):
        raise ValueError("`mode` must be 'full', 'same', or 'valid'; "
                         "got {!r}".format(mode))

    xs, hs = tuple(x_shape), tuple(h_shape)
    h_not_smaller = all(hs[i] >= xs[i] for i in (0, 1))
    x_not_smaller = all(xs[i] >= hs[i] for i in (0, 1))
    if mode == 'valid' and not (h_not_smaller or x_not_smaller):
        raise ValueError(
            "For `mode='valid'`, every axis in `x` must be at least "
            "as long as in `h`, or vice versa. Got x:{}, h:{}".format(
                str(xs), str(hs)))

    # in 'valid' mode the larger input acts as the signal
    swap = bool(mode == 'valid' and not x_not_smaller)
    if swap:
        xs, hs = hs, xs

    # pad quantities
    full_len_h = xs[0] + hs[0] - 1
    full_len_w = xs[1] + hs[1] - 1
    padded_shape = (next_fast_len(full_len_h), next_fast_len(full_len_w))

    # unpad indices
    if mode == 'full':
        offset_h, offset_w = 0, 0
        len_h, len_w = full_len_h, full_len_w
    elif mode == 'same':
        len_h, len_w = xs
        offset_h, offset_w = [g//2 for g in hs]
    else:  # 'valid'
        ax_pairs = ((xs[0], hs[0]), (xs[1], hs[1]))
        len_h, len_w = [max(g) - min(g) + 1 for g in ax_pairs]
        offset_h, offset_w = [min(g) - 1 for g in ax_pairs]
    unpad_h = slice(offset_h, offset_h + len_h)
    unpad_w = slice(offset_w, offset_w + len_w)
    return swap, padded_shape, unpad_h, unpad_w


def _cc2d_flip_conj(a, real, inplace):
    """Return `a` flipped along both axes, and conjugated unless `real`."""
    flipped = a[::-1, ::-1]
    if real:
        return flipped
    if inplace:
        np.conj(flipped, out=a)
        return a
    return np.conj(flipped)


def cross_correlate_2d(x, h, mode='same', real=True, get_reusables=False,
                       inplace=True, workers=-1):
    """2D cross-correlation, replicating `scipy.signal.correlate2d`.

    Parameters
    ----------
    x : np.ndarray, 2D
        Input.

    h : np.ndarray, 2D / tuple
        Filter/template, or the `reusables` returned by an earlier call.

    mode : str
        `'full'`, `'same'`, or `'valid'` (see scipy docs). Ignored if `h` is
        `reusables`, which already fixes the output region.

    real : bool (default True)
        Whether to assume `x` and `h` are real-valued, which is faster.

    get_reusables : bool (default False)
        Whether to return `out, reusables`.

        If `h` is same and `x` has same shape, pass in `reusables` as `h` for
        speedup: the transform of `h` is then not recomputed.

    inplace : bool (default True)
        If True, is faster but may alter `x` and/or `h` that are passed in
        (`h` cannot be altered if passed via `reusables`).

    workers : int
        Number of CPU cores to use with FFT. Defaults to `-1`, which is all.

    Returns
    -------
    out : np.ndarray, 2D
        The cross-correlation; `(out, reusables)` if `get_reusables`.

    Raises
    ------
    ValueError
        If `mode` is not recognized, or if `mode='valid'` and neither input
        is at least as large as the other along every axis.
    """
    if isinstance(h, tuple):
        # `h` is reusables
        hf, swap, padded_shape, unpad_h, unpad_w = h
    else:
        swap, padded_shape, unpad_h, unpad_w = _cc2d_plan(
            x.shape, h.shape, mode)
        # `hf` is always derived from `h`, so that it remains valid for any
        # later `x` of the same shape; if swapped, `h` acts as the signal and
        # the flip (and conjugation) is applied to `x` below instead
        h_src = h if swap else _cc2d_flip_conj(h, real, inplace)
        hf = fft2(h_src, padded_shape, workers=workers)

    # FFT convolution
    x_src = _cc2d_flip_conj(x, real, inplace) if swap else x
    xf = fft2(x_src, padded_shape, workers=workers)
    if inplace:
        np.multiply(xf, hf, out=xf)
    else:
        xf = xf * hf
    out = ifft2(xf, workers=workers)
    if real:
        out = out.real

    # unpad, unswap
    out = out[unpad_h, unpad_w]
    if swap:
        out = out[::-1, ::-1]

    if get_reusables:
        return out, (hf, swap, padded_shape, unpad_h, unpad_w)
    return out
118 |
--------------------------------------------------------------------------------
/SignalProcessing/Q87781 - cross-correlation, FFT - 2D Cross-Correlation using FFT in Python/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/87781/50076
3 | import numpy as np
4 | import scipy.signal
5 | from cc2d import cross_correlate_2d
6 |
7 | #%% Testing ##################################################################
8 | np.random.seed(0)
9 |
10 | rand = lambda M, N, real: (np.random.randn(M, N) if real else
11 | (np.random.randn(M, N) + 1j*np.random.randn(M, N)))
12 |
13 | lengths = (1, 2, 3, 4, 5, 6, 7, 15, 50)
14 |
15 | for real in (True, False):
16 | for mode in ('full', 'same', 'valid'):
17 | for M0 in lengths:
18 | for N0 in lengths:
19 | x = rand(M0, N0, real)
20 | for M1 in lengths:
21 | for N1 in lengths:
22 | h = rand(M1, N1, real)
23 |
24 | fn0 = lambda: cross_correlate_2d(
25 | x.copy(), h.copy(), mode, real=real)
26 | fn1 = lambda: scipy.signal.correlate2d(
27 | x.copy(), h.copy(), mode=mode)
28 |
29 | # compute
30 | try:
31 | out0 = fn0()
32 | except ValueError:
33 | try:
34 | fn1()
35 | except ValueError:
36 | continue
37 | except:
38 | raise AssertionError
39 | out1 = fn1()
40 |
41 | # assert equality
42 | cfg = (real, mode, M0, N0, M1, N1)
43 | assert out0.shape == out1.shape, cfg
44 | assert np.allclose(out0, out1), cfg
45 |
46 |
47 | # check reusables
48 | out0, reusables = cross_correlate_2d(x, h, get_reusables=True)
49 | out1 = cross_correlate_2d(x, reusables)
50 | assert np.allclose(out0, out1)
51 |
--------------------------------------------------------------------------------
/SignalProcessing/Q87926 - DFT of a sine, closed form solution and insights/solutions.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/87926/50076
3 | import numpy as np
4 |
def _get_UVk(N, f, phi):
    """Return the terms `U`, `V` and bin indices `k` of the fractional-`f`
    closed-form solutions. `f` must not be an integer.
    """
    assert not float(f).is_integer(), f
    theta_end = 2*np.pi*f + phi
    step = 2*np.pi*f/N
    U = np.cos(theta_end) - np.cos(phi)
    V = np.cos(theta_end - step) - np.cos(phi - step)
    return U, V, np.arange(N)
11 |
12 |
def sine_dft(N, f, phi):
    """Closed-form length-`N` DFT of a sinusoid with frequency `f` and phase
    `phi`; dispatches on whether `f` is an integer.

    Solution by Cedron Dawg https://www.dsprelated.com/showarticle/771.php
    """
    solver = _sine_dft_int if float(f).is_integer() else _sine_dft_frac
    return solver(N, f, phi)
19 |
def sine_dft_modulus(N, f, phi):
    """Closed-form modulus of the length-`N` DFT of a sinusoid with frequency
    `f` and phase `phi`; dispatches on whether `f` is an integer.

    Modulus analysis & insights by John Muradeli # TODO URL
    """
    solver = (_sine_dft_modulus_int if float(f).is_integer() else
              _sine_dft_modulus_frac)
    return solver(N, f, phi)
26 |
def _sine_dft_frac(N, f, phi):
    """Closed-form DFT for non-integer `f`, evaluated at all `N` bins."""
    U, V, k = _get_UVk(N, f, phi)

    numerator = U*np.exp(1j*2*np.pi*k/N) - V
    denominator = np.cos(2*np.pi*f/N) - np.cos(2*np.pi*k/N)
    return 1/2 * numerator / denominator
34 |
def _sine_dft_modulus_frac(N, f, phi, get_params=False):
    """Closed-form DFT modulus for non-integer `f`, evaluated at all `N` bins.

    With `get_params`, returns `(modulus, (U, V))` instead.
    """
    U, V, k = _get_UVk(N, f, phi)

    numerator = np.sqrt(U**2 + V**2 - 2*U*V*np.cos(2*np.pi*k/N))
    denominator = abs(np.cos(2*np.pi*f/N) - np.cos(2*np.pi*k/N))
    modulus = 1/2 * numerator / denominator

    return (modulus, (U, V)) if get_params else modulus
45 |
def _delta(x):
    """Kronecker Delta (discrete unit impulse). Handles numeric precision.

    Returns a float array that is 1 where `x` is zero and 0 elsewhere. For
    floating-point and complex input, "zero" means a magnitude no greater
    than the dtype's machine epsilon; for every other dtype (signed,
    unsigned, bool) the comparison is exact.
    """
    x = np.asarray(x)
    # `np.finfo` only accepts inexact dtypes, so check the dtype kind rather
    # than its name (`'uint8'` and `'bool'` don't start with `'int'`)
    if np.issubdtype(x.dtype, np.inexact):
        eps = np.finfo(x.dtype).eps
    else:
        eps = 0
    return (abs(x) <= eps).astype(float)
52 |
def _sine_dft_int(N, f, phi):
    """Closed-form DFT for integer `f`: impulses at bins `f` and `-f` (mod
    `N`), weighted by `exp(+1j*phi)` and `exp(-1j*phi)` and scaled by `N/2`.
    """
    bins = np.arange(N)
    positive = np.exp(1j*phi ) * _delta(np.mod(bins - f, N))
    negative = np.exp(-1j*phi) * _delta(np.mod(bins + f, N))
    return (N/2) * (positive + negative)
58 |
def _sine_dft_modulus_int(N, f, phi):
    """Closed-form DFT modulus for integer `f`.

    The cross term only contributes where bins `f` and `-f` (mod `N`)
    coincide.
    """
    bins = np.arange(N)
    at_pos = _delta(np.mod(bins - f, N))
    at_neg = _delta(np.mod(bins + f, N))

    squared = at_neg + at_pos + 2*np.cos(2*phi)*at_neg*at_pos
    return (N/2) * np.sqrt(squared)
65 |
66 |
def sine_stft(N, M, H, f, phi):
    """Closed-form STFT of a length-`N` sinusoid, one `sine_dft` per hop.

    Uses length-`M` segments spaced `H` samples apart. Returns a complex
    array of shape `(M, n_hops)` with `n_hops = (N - M)//H + 1`.
    """
    assert M <= N and 1 <= H <= N, (N, M, H)
    n_hops = (N - M)//H + 1
    stft = np.zeros((M, n_hops), dtype='complex128')

    # segment `hop` starts `hop*H` samples in, which advances the phase
    hop = 0
    while hop < n_hops:
        phase = phi + 2*np.pi*f*hop*H/N
        stft[:, hop] = sine_dft(M, f*M/N, phase)
        hop += 1
    return stft
76 |
--------------------------------------------------------------------------------
/SignalProcessing/Q88042 - sampling - Is downsampling LTI for bandlimited inputs/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # https://dsp.stackexchange.com/q/88042/50076
3 | import numpy as np
4 | from numpy.fft import fft, ifft
5 |
try:
    from .toolkit import fft_upsample
    got_fft_upsample = True
except Exception:
    # `fft_upsample` is optional; the tests that need it are skipped.
    # `Exception` (rather than a bare `except`) keeps the fallback without
    # swallowing KeyboardInterrupt/SystemExit.
    got_fft_upsample = False
    import warnings
    warnings.warn("`fft_upsample` not found! Some tests skipped.")
13 |
14 | # NOTE: linearity isn't tested!
15 |
16 | #%%###########################################################################
17 | # Helpers
18 | # -------
def crandn(N):
    """Draw `N` complex standard-normal samples (real part drawn first)."""
    re = np.random.randn(N)
    im = np.random.randn(N)
    return re + 1j*im
21 |
22 | # note, this `op` is deliberately incomplete relative to the answer's formulation
23 | # (i.e. omits `fft_upsample`, so that we can test non-upsampled cases)
def op(a, b, d, x_only):
    """Subsample by `d`, then either pass `a` through an FFT round trip
    (`x_only`) or circularly convolve it with the subsampled, `d`-scaled `b`.
    """
    spectrum = fft(a[::d])
    if not x_only:
        spectrum = spectrum * fft(b[::d] * d)
    return ifft(spectrum)
29 |
30 | #%%###########################################################################
31 | # Testing
32 | # -------
33 | # np.random.seed(0) # uncomment for reproducible debugging
34 | optional_successes = {i: [] for i in range(5)}
35 | n_cases = 0
36 |
37 | for real in (True, False):
38 | for x_bandlimited in (True, False):
39 | for h_bandlimited in (True, False):
40 | for x_only in (True, False):
41 | if x_only and not h_bandlimited:
42 | # duplicate case
43 | continue
44 | for N in (16, 40, 128):
45 | for d in range(2, N//2 + 1, 2):
46 | for s in range(1, N):
47 | # prepare --------------------------------------------------------
48 | # check if loop is valid
49 | if (N / d) % 2 != 0:
50 | # `fft_upsample` doesn't handle this case, nor does it seem
51 | # that many of such cases can be handled
52 | continue
53 |
54 | # store reusables
55 | both_bandlimited = (x_bandlimited and h_bandlimited)
56 | one_bandlimited = (x_bandlimited or h_bandlimited)
57 | not_bandlimited = (not x_bandlimited and not h_bandlimited)
58 | bandlimited = ((x_only and x_bandlimited) or
59 | (not x_only and both_bandlimited))
60 |
61 | # generate filter, signal
62 | x, h = crandn(N), crandn(N)
63 | if real:
64 | x, h = x.real, h.real
65 | xf, hf = fft(x), fft(h)
66 | zero_idxs = (slice(N//d//2, -(N//d//2 - 1)) if N//d != 2 else
67 | slice(1, None))
68 | if x_bandlimited:
69 | xf[zero_idxs] = 0
70 | if h_bandlimited:
71 | hf[zero_idxs] = 0
72 | x, h = ifft(xf), ifft(hf)
73 | if real:
74 | assert np.allclose(max(abs(x.imag)), 0)
75 | assert np.allclose(max(abs(h.imag)), 0)
76 | x, h = x.real, h.real
77 |
78 | # execute --------------------------------------------------------
79 | # get outputs
80 | o0a = op(np.roll(x, 0), h, 1, x_only)
81 | o0b = op(np.roll(x, s), h, 1, x_only)
82 | o1a = op(np.roll(x, 0), h, d, x_only)
83 | o1b = op(np.roll(x, s), h, d, x_only)
84 |
85 | # get upsampled outputs
86 | if got_fft_upsample:
87 | o1au, o1bu = [fft_upsample(o, d, time_to_time=True)
88 | for o in (o1a, o1b)]
89 |
90 | # validate -------------------------------------------------------
91 | # prepare info in case assert fails
92 | cfg = dict(real=real,
93 | x_bandlimited=x_bandlimited,
94 | h_bandlimited=h_bandlimited,
95 | x_only=x_only, N=N, d=d, s=s)
96 | info = "\n " + "\n ".join(f"{k}={v}" for k, v in cfg.items())
97 | _err = lambda i: f'{i}:{info}'
98 |
99 | # fetch expectations
100 | cond = {
101 | 0: True,
102 | 1: False,
103 | 2: (s / d).is_integer(),
104 | 3: (s / d).is_integer() or bandlimited,
105 | 4: bandlimited,
106 | }
107 | optional = {
108 | 0: False,
109 | 1: True,
110 | # only integer `s/d` must pass
111 | 2: False if cond[2] else True,
112 | # only integer `s/d` or bandlimited must pass
113 | 3: False if cond[3] else True,
114 | # only bandlimited must be roLTI
115 | 4: False if cond[4] else True,
116 | }
117 | results = {}
118 |
119 | # gather successes/failures
120 | # [LTI] (no subsampling)
121 | # shifted input <=> shifted output
122 | results[0] = (cond[0] == np.allclose(o0b, np.roll(o0a, s)))
123 | # [LTI]
124 | # subsampling of shift <=> shift of subsampling
125 | results[1] = (cond[1] == np.allclose(o1b, np.roll(o1a, s//d)))
126 | # [fLTI]
127 | # subsampling of shift <=> shift of subsampling; integer `s/d`
128 | results[2] = (cond[2] == np.allclose(o1b, np.roll(o1a, s//d)))
129 | # [rLTI]
130 | # upsampled subsampling of shift <=> shift of upsampled subsampling
131 | if got_fft_upsample:
132 | results[3] = (cond[3] == np.allclose(o1bu, np.roll(o1au, s)))
133 | # [roLTI] upsampled subsampling of shift <=> shift
134 | if got_fft_upsample:
135 | results[4] = (cond[4] == np.allclose(o1bu, o0b))
136 |
137 | # run assertions on non-optionals, append info about the rest
138 | for i in results:
139 | if i != 0:
140 | if not optional[i]:
141 | assert results[i], _err(i)
142 | else:
143 | if results[i]:
144 | optional_successes[i].append(tuple(cfg.values()))
145 | else:
146 | assert results[0], (
147 | "0:\n" "the operation, unrelated to subsampling, "
148 | "isn't LTI")
149 | n_cases += 1
150 |
151 | #%%###########################################################################
152 | # Trivia
153 | # ------
154 | print("Number of optional (unexpected) successes:\n " +
155 | "\n ".join("Case {}: {:>4} ({:>4.3g}%)".format(
156 | i, len(v), 100*len(v)/n_cases)
157 | for i, v in optional_successes.items()))
158 |
159 | print("\nExample optional successes for case 2:\n " +
160 | ", ".join(list(cfg)) + "\n " +
161 | "\n ".join(str(v) for v in optional_successes[2][:5]))
162 |
--------------------------------------------------------------------------------
/StackOverflow/Q76812752 - MATLAB dictionary comprehension/ifelse.m:
--------------------------------------------------------------------------------
1 | % https://stackoverflow.com/q/76812752/10133797
2 | % Thanks @user16372530 https://stackoverflow.com/a/73751467/10133797
function out = ifelse(a, cond, b)
% IFELSE  Lazy ternary: evaluate `a()` if `cond` holds, else `b()`.
%   `a` and `b` are function handles; only the selected one is called.
    chosen = b;
    if cond
        chosen = a;
    end
    out = chosen();
end
10 |
--------------------------------------------------------------------------------
/StackOverflow/Q76812752 - MATLAB dictionary comprehension/py_dictcomp.m:
--------------------------------------------------------------------------------
1 | % https://stackoverflow.com/q/76812752/10133797
function out = py_dictcomp(fn_k, fn_v, iterable, cond, cell_values)
    % py_dictcomp
    %   Mimics Python's
    %
    %       {f(k): g(v) for k, v in dict.items() if cond(k, v)}
    %
    %   where `f = fn_k` and `g = fn_v`,
    %
    %   or, if `iterable` isn't `dictionary`,
    %
    %       {f(i): g(x) for i, x in enumerate(iterable) if cond(i, x)}
    %
    %   with `i` starting at 1.
    %
    %   `cell_values=true` makes `out` like `dictionary(a={1})`. By default,
    %   `false` is attempted first, with an automatic retry using `true` if
    %   that fails. Has no effect if `iterable` is `dictionary`, in which case
    %   the value type of `iterable` decides.
    %
    %   Default `cond = @(a, b) true`.
    %
    if nargin < 4
        cond = @(a, b) true;
    end
    user_set_cell_values = (nargin == 5);
    if ~user_set_cell_values
        cell_values = false;
    end

    out = dictionary();
    if isa(iterable, "dictionary")
        % the declared value type, not the values themselves, tells whether
        % entries are cell-wrapped
        [~, value_type] = types(iterable);
        cell_values = strcmp(value_type, "cell");
        ks = keys(iterable);
        vs = values(iterable);
        % brace indexing only works on cell and string arrays; other key
        % types (e.g. numeric) need parentheses
        brace_keys = iscell(ks) || isstring(ks);

        for l=1:numEntries(iterable)
            if brace_keys
                k = ks{l};
            else
                k = ks(l);
            end
            if cell_values
                v = vs{l};
            else
                v = vs(l);
            end
            if cond(k, v)
                k = fn_k(k);
                v = fn_v(v);
                if cell_values
                    out{k} = v;
                else
                    out(k) = v;
                end
            end
        end
    else
        is_cell = iscell(iterable);
        for i=1:numel(iterable)
            if is_cell
                x = iterable{i};
            else
                x = iterable(i);
            end

            if cond(i, x)
                k = fn_k(i);
                v = fn_v(x);

                if cell_values
                    out{k} = v;
                else
                    try
                        out(k) = v;
                        % e.g. first populate with `double` values, then
                        % `string`, the `string` becomes `NaN` unless cell
                        if isnan(out(k)) && ~isnan(v)
                            error("NaN `out(k)` without NAN `v`")
                        end
                    catch e
                        if user_set_cell_values
                            rethrow(e)
                        else
                            % the retry builds the complete result; stop
                            % here rather than keep looping over it
                            out = py_dictcomp(fn_k, fn_v, iterable, cond, true);
                            return
                        end
                    end
                end
            end
        end
    end
end
87 |
--------------------------------------------------------------------------------
/StackOverflow/Q76812752 - MATLAB dictionary comprehension/test_py_dictcomp.m:
--------------------------------------------------------------------------------
% https://stackoverflow.com/q/76812752/10133797
% Checks for `py_dictcomp`. Each section title is the Python comprehension
% that the call below it is meant to reproduce (Python indices are 0-based,
% the MATLAB calls are 1-based).

%% {i: 5 for i, x in enumerate([1, 2, 3])}
out = py_dictcomp(@(i)i, @(v)5, [1 2 3]);

ref = dictionary([1 2 3], [5 5 5]);
assert(isequal(out, ref))

%% {i: x for i, x in enumerate([1, 2, "a"])}
out = py_dictcomp(@(i)i, @(v)v, [1 2 "a"]);

% expected result is cell-valued, holding the string form of each element
ref = dictionary([1 2 3], {"1", "2", "a"});
assert(isequal(out, ref))

%% {2*i: x-1 for i, x in enumerate([1, 2, 3]) if (x != 2 and i > 0)}
out = py_dictcomp(@(i)2*i, @(x)x - 1, [1 2 3], @(i, x)x ~= 2 && i > 1);

% only the third element passes the filter
ref = dictionary(6, 2);
assert(isequal(out, ref))

%% {i: ("a", x**2) for i, x in enumerate([1, np.array([2, 3])])}
out = py_dictcomp(@(i)i, @(x){"a", x.^2}, {1, [2 3]});

src = {1, [2 3]};
ref = dictionary();
for n = 1:numel(src)
    ref{n} = {"a", src{n}.^2};
end
assert(isequal(out, ref))

%% {k: v**2 for k, v in dict(a=1, b=-2).items()}
out = py_dictcomp(@(k)k, @(v)v^2, dictionary(a=1, b=-2));
ref = dictionary(["a" "b"], [1 4]);
assert(isequal(out, ref))

%% {k: v**2 for k, v in dict(a=1, b=-2).items() if v > 0}
out = py_dictcomp(@(k)k, @(v)v^2, dictionary(a=1, b=-2), @(k, v)v > 0);
ref = dictionary("a", 1);
assert(isequal(out, ref))

%% {("we have" if i == 0 else i - 1):
%   ((x + "ogs") if isinstance(x, str) else (x + 1))
%   for i, x in enumerate([2, "d", 3.5])
%   if not isinstance(x, float)}
% (no reference comparison here: this section only checks that the call runs)
out = py_dictcomp(@(i)ifelse("we have", i == 1, i - 1), ...
                  @(x)ifelse(x + "ogs", isstring(x), x + 1), ...
                  {2, "d", 3.5}, ...
                  @(i, x)~(isnumeric(x) && mod(x, 1) ~= 0));

%% invalid `cell_values=false`
% With `cell_values` passed explicitly the storage error must propagate.
% If nothing is thrown, `assert(false)` lands in the `catch` as well and the
% identifier check then fails.
try
    out = py_dictcomp(@(i)i, @(x)x, {1, [2, 3]}, @(a, b)true, false);
    assert(false);
catch e
    assert(contains(e.identifier, "KeyValueDimsMustMatch"))
end
60 |
--------------------------------------------------------------------------------