├── .gitignore ├── LICENSE ├── README.md ├── SignalProcessing ├── Q11333 - STFT - Beats in spectrogram of pure sines │ └── main.py ├── Q29509 - DFT - Extract sine phase and amplitude │ ├── estimators.py │ └── main.py ├── Q76329 - CWT, Wavelets - Scale vs frequency │ ├── im0.png │ ├── im1.png │ ├── im10.png │ ├── im11.png │ ├── im12.png │ ├── im13.png │ ├── im14.png │ ├── im15.png │ ├── im16.png │ ├── im3.png │ ├── im4.png │ ├── im5.png │ ├── im6.png │ ├── im7.png │ ├── im8.png │ ├── im9.png │ ├── main.py │ └── wavgif.gif ├── Q76463 - FM - Frequency Modulating a Signal │ ├── im0.png │ ├── im1.png │ ├── im2.png │ └── main.py ├── Q76560 - Analyticity - Symmetries │ ├── main.py │ ├── morlets0.png │ ├── morlets1.png │ └── rands0.png ├── Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass │ ├── WGN.png │ ├── im2.png │ ├── ims.gif │ └── main.py ├── Q76644 - estimation - Estimate sine frequency under white noise │ ├── _optimized.pyx │ ├── benchmarks.py │ ├── estimators.py │ ├── main.py │ ├── optimized.py │ ├── setup.py │ └── utils76644.py ├── Q76754 - Hilbert - Alleviate boundary effects │ ├── data.npy │ └── main.py ├── Q78512 - Wavelet Scattering explanation │ ├── cwt.gif │ ├── cwt.py │ ├── cwt_phi.gif │ ├── cwt_phi.py │ ├── reconstruction.gif │ ├── reconstruction.py │ ├── scat_impulse.gif │ ├── scat_impulse_anim.py │ ├── scat_impulse_resc.gif │ ├── scat_impulse_unscaled_anim.gif │ ├── second_order.py │ ├── sine_warp.png │ ├── tshift_invar.py │ ├── warp_cwt_anim.py │ ├── warp_cwt_echirp.gif │ ├── warp_cwt_sine.gif │ ├── warp_scat_T.gif │ ├── warp_scat_T_anim.py │ ├── warp_scat_anim.py │ ├── warp_scat_tau.gif │ ├── warp_scat_tau_anim.py │ ├── warp_scat_tau_echirp.gif │ └── warp_scat_tau_pulse.gif ├── Q78644 - Joint Time-Frequency Scattering explanation │ ├── data_3d_anim.py │ ├── echirp_2d.py │ ├── echirp_2d_anim.py │ ├── echirp_3d_anim.py │ ├── fdts.py │ ├── filterbank.gif │ ├── filterbank.py │ ├── freq_tp.gif │ ├── freq_tp.py │ ├── jtfs3d_echirp_full.gif │ ├── 
jtfs3d_echirp_spinned.gif │ ├── jtfs3d_seiz.gif │ ├── jtfs3d_trumpet.gif │ ├── jtfs_echirp.gif │ ├── jtfs_echirp_wavelets.gif │ ├── jtfs_sine.gif │ ├── jtfs_ts.gif │ └── sine_2d_anim.py ├── Q80918 - STFT - Overlap, window length, uncertainty principle │ └── main.py ├── Q81624 - STFT, optimization - Optimize window length for STFT │ └── main.py ├── Q85745 - Energy, power, relation to sampling rate and duration │ └── main.py ├── Q86181 - CWT Power and Energy │ ├── cwt_power_example.m │ ├── cwt_power_example.py │ └── cwt_power_validation.py ├── Q86726 - STFT - Amplitude extraction │ └── main.py ├── Q86937 - STFT - Equivalence of Windowed Fourier and Convolutions │ └── main.py ├── Q87355 - audio, algorithms - Detecting abrupt changes │ ├── T2_D_00004323.wav │ ├── T2_D_00004324.wav │ ├── T2_D_00004326.wav │ ├── T2_D_00004328.wav │ ├── T2_D_00004330.wav │ ├── T2_D_00004331.wav │ ├── T2_D_00004333.wav │ ├── T2_D_00004335.wav │ ├── T2_D_00004338.wav │ ├── T2_D_00004339.wav │ ├── T2_T_00043244.wav │ ├── T2_T_00043245.wav │ ├── T2_T_00043246.wav │ ├── example.wav │ ├── im00.png │ ├── im01.png │ ├── im02.png │ ├── im03.png │ ├── im04.png │ ├── im05.png │ ├── im06.png │ ├── im07.png │ ├── im08.png │ ├── im09.png │ ├── im10.png │ ├── im11.png │ ├── im12.png │ ├── im13.png │ ├── main.py │ ├── preds.gif │ ├── test_algo.py │ └── utils87355.py ├── Q87706 - aliasing - How to measure aliasing │ └── main.py ├── Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation │ ├── covid.png │ ├── covid_target.png │ ├── covid_target2.png │ ├── covid_template.png │ ├── main.py │ └── validation.py ├── Q87781 - cross-correlation, FFT - 2D Cross-Correlation using FFT in Python │ ├── cc2d.py │ └── main.py ├── Q87926 - DFT of a sine, closed form solution and insights │ ├── main.py │ └── solutions.py └── Q88042 - sampling - Is downsampling LTI for bandlimited inputs │ └── main.py └── StackOverflow └── Q76812752 - MATLAB dictionary comprehension ├── ifelse.m ├── py_dictcomp.m 
└── test_py_dictcomp.m /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OverLordGoldDragon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StackExchange Answers 2 | 3 | [![DOI](https://zenodo.org/badge/387701666.svg)](https://zenodo.org/badge/latestdoi/387701666) 4 | 5 | "Watch" to be notified of new posts. 6 | 7 | Question titles may be paraphrased, but IDs are exact. 8 | 9 | ### Citing 10 | 11 | If using significant portions of code, please cite this repository. Short form (year = latest, or when the cited content was published): 12 | 13 | > John Muradeli, StackExchange Answers, 2023. GitHub repository, https://github.com/OverLordGoldDragon/StackExchangeAnswers/. DOI: 10.5281/zenodo.5115534 14 | 15 | BibTeX 16 | 17 | ```bibtex 18 | @article{OverLordGoldDragon2020StackExchangeAnswers, 19 | title={StackExchange Answers}, 20 | author={John Muradeli}, 21 | journal={GitHub. Note: https://github.com/OverLordGoldDragon/StackExchangeAnswers/}, 22 | year={2023}, 23 | doi={10.5281/zenodo.5115534}, 24 | } 25 | ``` 26 | 27 |
28 | 29 | ### Signal Processing / Feature Engineering 30 | 31 | Code-accompanied answers say "code Q...", found in respective folders, or "code included". 32 | 33 | ✯ = exceptional quality; 🡡 = great quality (brag!) 34 | 35 | #### Feature Engineering (applied) / Problem Solving 36 | 37 | 1. ✯ Identifying abrupt changes in audio - [Q&A](https://dsp.stackexchange.com/a/87512/50076), code Q87355 38 | 2. 🡡 Locating non-homogeneous areas in an image - [Q&A](https://dsp.stackexchange.com/a/75636/50076), code included 39 | 3. 🡡 Estimate sine frequency under white noise - [Q&A](https://dsp.stackexchange.com/a/88711/50076), code Q76644 40 | 4. Extract sine phase and amplitude, accurate and noise-robust method - [Q&A](https://dsp.stackexchange.com/a/88833/50076), code Q29509 41 | 5. Nonlinear filtering and inspection of noisy vibration signal - [Q&A](https://dsp.stackexchange.com/a/84095/50076), code included 42 | 43 | #### Intuition / "Deep" Explanations 44 | 45 | 1. ✯ Physical significance of negative frequencies - [Q&A](https://dsp.stackexchange.com/a/87994/50076) 46 | 2. ✯ Wavelet "center frequency" explanation, relation to CWT scales - [Q&A](https://dsp.stackexchange.com/a/76371/50076), code Q76329 47 | 3. 🡡 DFT coefficients meaning? - [Q&A](https://dsp.stackexchange.com/a/70395/50076) 48 | 4. 🡡 Why is `|x|^2` power? - [Q&A](https://dsp.stackexchange.com/a/86919/50076) 49 | 5. 🡡 Does zero-padding distort the spectrum? - [Q&A](https://dsp.stackexchange.com/a/70498/50076) 50 | 6. How is a 1D signal considered a "vector"? - [Q&A](https://dsp.stackexchange.com/a/87057/50076) 51 | 7. Are colored images 2D or 3D? - [Q&A](https://dsp.stackexchange.com/a/81831/50076) 52 | 8. Are real-world signals bandlimited? - [Q&A](https://dsp.stackexchange.com/a/75147/50076) 53 | 9. What can be meant by "bandlimited"? - [Q&A](https://dsp.stackexchange.com/a/70949/50076) 54 | 55 | #### Measures / Heuristics 56 | 57 | 1. ✯ How to validate a wavelet filterbank? 
- [Q&A](https://dsp.stackexchange.com/a/86069/50076), code to-be-released 58 | 2. 🡡 How to measure aliasing? - [Q&A](https://dsp.stackexchange.com/a/87706/50076), code Q87706 59 | 60 | #### Time-Frequency Analysis / Feature Engineering 61 | 62 | 1. ✯ Wavelet Scattering explanation - [Q&A](https://dsp.stackexchange.com/a/78513/50076), code Q78512 63 | 2. ✯ Joint Time-Frequency Scattering explanation - [Q&A](https://dsp.stackexchange.com/a/78623/50076), code Q78644 64 | 3. ✯ Synchrosqueezing explanation - [Q&A](https://dsp.stackexchange.com/a/71399/50076) 65 | 4. ✯ Wavelet Scattering properties & implementation - [Q&A](https://dsp.stackexchange.com/a/78515/50076), code Q78512 66 | 5. 🡡 Equivalence between "windowed Fourier transform" and STFT as convolutions/filtering - [Q&A](https://dsp.stackexchange.com/a/86938/50076), code Q86937 67 | 6. 🡡 What is "spin" for the 2D (separable) Morlet? - [Q&A](https://dsp.stackexchange.com/a/86013/50076), code to-be-released 68 | 7. 🡡 Amplitude extraction using STFT - [Q&A](https://dsp.stackexchange.com/a/86731/50076), code Q86726 69 | 8. 🡡 Why are there beats in STFT of sines? - [Q&A](https://dsp.stackexchange.com/a/87086/50076), code Q11333 70 | 9. 🡡 Power/Energy from CWT - [Q&A](https://dsp.stackexchange.com/a/86182/50076), code Q86181 71 | 10. 🡡 CWT vs DWT - [Q&A](https://dsp.stackexchange.com/a/76639/50076) 72 | 11. 🡡 One-integral inverse CWT - [Q&A](https://dsp.stackexchange.com/a/76239/50076), code included 73 | 12. 🡡 Joint Time-Frequency Scattering structure & implementation - [Q&A](https://dsp.stackexchange.com/a/78625/50076), code Q78644 74 | 13. What do CWT values correspond to? - [Q&A](https://dsp.stackexchange.com/a/86939/50076), code included 75 | 14. Role of window length and overlap in uncertainty principle - [Q&A](https://dsp.stackexchange.com/a/80920/50076), code Q80918 76 | 15. Inverting a scalogram - [Q&A](https://dsp.stackexchange.com/a/78531/50076), code Q78530 77 | 16.
Importance of overlap in STFT - [Q&A](https://dsp.stackexchange.com/a/76507/50076) 78 | 17. Advantages of complex wavelets over real-valued - [Q&A](https://dsp.stackexchange.com/a/76093/50076) 79 | 18. Bandpass filtering for amplitude - [Q&A](https://dsp.stackexchange.com/a/75819/50076), code included 80 | 19. How is wavelet time & frequency resolution computed? - [Q&A](https://dsp.stackexchange.com/a/72043/50076) 81 | 20. How is wavelet center frequency computed? - [Q&A](https://dsp.stackexchange.com/a/70837/50076) 82 | 83 | #### Signal Theory 84 | 85 | 1. ✯ DFT of a sine, closed form solution and insights - [Q&A](https://dsp.stackexchange.com/a/88365/50076), code Q87926 86 | 2. 🡡 Subsampling in frequency domain, Effect of sampling rate on spectrum - [Q&A](https://dsp.stackexchange.com/a/87121/50076), code included 87 | 3. 🡡 Why is SNR periodic/sinusoidal for streamed windowing of a sine? - [Q&A](https://dsp.stackexchange.com/a/87933/50076) 88 | 4. 🡡 DFT modulus of a sine, closed form solution and insights - [Q&A](https://dsp.stackexchange.com/a/88400/50076), code Q87926 89 | 5. Relationship between energy, power, and sampling rate - [Q&A](https://dsp.stackexchange.com/a/86132/50076), code Q85745 90 | 6. Hilbert transform for amplitude extraction/demodulation of broadband waveforms - [Q&A](https://dsp.stackexchange.com/a/83257/50076) 91 | 7. Other end of Nyquist limit - [Q&A](https://dsp.stackexchange.com/a/84746/50076) 92 | 8. Instantaneous frequency of a signal - [Q&A](https://dsp.stackexchange.com/a/84144/50076) 93 | 9. DFT filter bank interpretation - [Q&A](https://dsp.stackexchange.com/a/83878/50076) 94 | 10. `ifft` first half of time domain ("DFT unpad property") - [Q&A](https://dsp.stackexchange.com/a/81994/50076), code included 95 | 11. Role of padding in convolution / Alleviating edge effects in Hilbert transform - [Q&A](https://dsp.stackexchange.com/a/76792/50076), code Q76754 96 | 12. How to frequency modulate? 
- [Q&A](https://dsp.stackexchange.com/a/76482/50076), code Q76463 97 | 13. Symmetries of analyticity - [Q&A](https://dsp.stackexchange.com/q/76560/50076), code Q76560 98 | 14. Effect of sampling rate and duration on discrete parameters of sine (spectrum) - [Q&A](https://dsp.stackexchange.com/a/88389/50076) 99 | 100 | #### Algorithms 101 | 102 | 1. 🡡 2D FFT cross-correlation in Python - [Q&A](https://dsp.stackexchange.com/a/87782/50076), code Q87781 103 | 2. Efficient 2D convolution/cross-correlation along only one axis (1D output) - [Q&A](https://dsp.stackexchange.com/a/87563/50076), code included 104 | 3. Optimizing window length in STFT via gradient descent - [Q&A](https://dsp.stackexchange.com/a/81627/50076), code Q81624 105 | 106 | #### Proofs / Mathy 107 | 108 | 1. 🡡 Amplitude extraction / demodulation criteria for Hilbert transform - [Q&A](https://dsp.stackexchange.com/a/83299/50076) 109 | 2. 🡡 Is downsampling LTI for bandlimited inputs? - [Q&A](https://dsp.stackexchange.com/a/88055/50076), code Q88042 110 | 3. 🡡 Sine DFT: N/4-symmetry, bin sinusoidal modulation vs time shifts - [Q&A](https://dsp.stackexchange.com/a/88421/50076), code Q87926 111 | 4. Inverse CWT derivation - [Q&A](https://dsp.stackexchange.com/a/71148/50076) 112 | 113 | 114 |
115 | 116 | ### License 117 | 118 | MIT-licensed. 119 | -------------------------------------------------------------------------------- /SignalProcessing/Q11333 - STFT - Beats in spectrogram of pure sines/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/11333/50076 3 | # Some code copied from my answer to https://dsp.stackexchange.com/q/85745/50076 4 | import numpy as np 5 | from scipy.signal.windows import dpss 6 | from numpy.fft import fft, ifft, fftshift, ifftshift 7 | from ssqueezepy import stft 8 | from ssqueezepy.visuals import plot, plotscat 9 | 10 | #%%############################################################################ 11 | # Helpers 12 | # ------- 13 | def _pad_window(w, padded_len): 14 | pleft = (padded_len - len(w)) // 2 15 | pright = padded_len - pleft - len(w) 16 | return np.pad(w, [pleft, pright]) 17 | 18 | def get_adj_windows(window, n_fft, N): 19 | padded_len_conv = N + n_fft - 1 20 | # window_adj = np.pad(window, (padded_len_conv - len(window))//2) 21 | # window_adj_f = np.pad(window, (n_fft - len(window))//2) 22 | window_adj = _pad_window(window, padded_len_conv) 23 | window_adj_f = _pad_window(window, n_fft) 24 | 25 | # shortcut for later examples to spare code; ensure ifftshift center at idx 0 26 | def _center(w): 27 | w = ifftshift(w) 28 | w = fftshift(np.roll(w, np.argmax(w))) 29 | return w 30 | 31 | window_adj = _center(window_adj) 32 | window_adj_f = _center(window_adj_f) 33 | return window_adj, window_adj_f 34 | 35 | #%%############################################################################ 36 | # Row-wise STFT implementation 37 | # ---------------------------- 38 | def cisoid(N, f): 39 | t = np.linspace(0, 1, N, endpoint=False) 40 | return (np.cos(2*np.pi * f * t) + 41 | np.sin(2*np.pi * f * t) * 1j) 42 | 43 | 44 | def stft_rowwise(x, window, n_fft): 45 | assert len(window) == n_fft and n_fft <= len(x) 46 | 47 | # 
compute some params 48 | xp = x 49 | N = len(x) 50 | padded_len = N 51 | 52 | # pad such that `armgax(window)` for `Sx[:, 0]` is at `Sx[:, 0]` 53 | # (DFT-center the convolution kernel) 54 | # note, due to how scipy generates windows and standard stft handles padding, 55 | # this still won't yield zero-phase for majority of signals, but it's both 56 | # fixable and can be very close; ideally just pass in `len(window)==len(x)`. 57 | _wpad_right = (padded_len - len(window)) / 2 58 | wpad_right = (int(np.floor(_wpad_right)) if n_fft % 2 == 1 else 59 | int(np.ceil(_wpad_right))) 60 | wpad_left = padded_len - len(window) - wpad_right 61 | 62 | # generate filterbank 63 | cisoids = np.array([cisoid(n_fft, f) for f in range(n_fft//2 + 1)]) 64 | fbank = ifftshift(window) * cisoids 65 | fbank = np.pad(fftshift(fbank), [[0, 0], [wpad_left, wpad_right]]) 66 | fbank = ifftshift(fbank) 67 | fbank_f = fft(fbank, axis=-1).conj() 68 | 69 | # circular convolution 70 | prod = fbank_f * fft(xp)[None] 71 | Sx = ifft(prod, axis=-1) 72 | return Sx, fbank_f 73 | 74 | #%%############################################################################ 75 | # Current answer -- messy code 76 | import matplotlib.pyplot as plt 77 | from scipy.signal.windows import gaussian 78 | from ssqueezepy import ssq_cwt, TestSignals, Wavelet 79 | from ssqueezepy.visuals import imshow 80 | 81 | N = 256 82 | n_fft = N 83 | t = np.linspace(0, 1, N, 0) 84 | x = np.cos(2*np.pi * 125 * t) 85 | window = gaussian(N, 6, 0) 86 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N) 87 | 88 | Sx, fbank_f = stft_rowwise(x, window_adj_f, n_fft) 89 | 90 | plot(fbank_f[::4].T, abs=1, color='tab:blue', h=.7, w=1.1, 91 | vlines=(128, {'linewidth': 3, 'color': 'tab:red'}), 92 | title="STFT filterbank, frequency-domain | Gaussian", show=1) 93 | #%% 94 | for case in (0, 1): 95 | if case == 0: 96 | a = fbank_f[-8].copy() 97 | else: 98 | a = fbank_f[-1].copy() 99 | title = "Filter peak freq = {}".format(np.argmax(np.abs(a))) 
100 | 101 | fig, axes = plt.subplots(4, 1, figsize=(8, 24)) 102 | 103 | b = fft(x) 104 | a /= max(abs(a)) 105 | b /= max(abs(b)) 106 | pkw = dict(ylims=(-.01, 1.03), abs=1, fig=fig) 107 | plotscat(b, abs=1, color='k', ax=axes[0]) 108 | plot(a, **pkw, title=title, ax=axes[0]) 109 | plotscat(a * b, **pkw, ax=axes[1], title="fft(filter) * fft(x)") 110 | 111 | x_filt = ifft(a * b) 112 | plot(x_filt, complex=1, ax=axes[2], fig=fig,) 113 | plot(x_filt, abs=1, color='k', linestyle='--', ax=axes[2], fig=fig, 114 | title="x_filt = ifft(fft(filter) * fft(x))") 115 | 116 | plotscat(fft(abs(x_filt)), abs=1, ax=axes[3], fig=fig, 117 | title="fft(envelope) | envelope = abs(x_filt)") 118 | fig.subplots_adjust(hspace=.13) 119 | 120 | #%% 121 | ikw = dict(w=1.2, h=.35, yticks=np.arange(len(Sx))[::-1]) 122 | imshow(Sx[::-1], abs=1, **ikw, title="|STFT|") 123 | imshow(Sx.real[::-1], **ikw, title="STFT.real") 124 | imshow(Sx.imag[::-1], **ikw, title="STFT.imag") 125 | 126 | #%% 127 | ts = TestSignals(2048) 128 | x = ts.hchirp(fmin=.11)[0] 129 | t = np.linspace(0, 1, len(x), 0) 130 | x += x[::-1] 131 | 132 | for analytic in (0, 1): 133 | wavelet = Wavelet(('gmw', {'gamma': 1, 'beta': 1})) 134 | Tx, Wx, _, scales, *_ = ssq_cwt(x, wavelet, scales='log') 135 | 136 | if not analytic: 137 | psiup_t = ifft(wavelet(N=len(x)*4, scale=scales), axis=-1) 138 | wavelet._Psih = fft(psiup_t[:, ::2]) 139 | Tx, Wx, *_ = ssq_cwt(x, wavelet, scales=scales) 140 | 141 | title = "analytic" if analytic else "non-analytic" 142 | for i, g in enumerate([Tx, Wx]): 143 | scl = .5 if i == 0 else .5 144 | scl = 240 if i == 0 else 3 145 | imshow(g[:], abs=1, w=.6, h=.45, norm=(0, np.abs(g).mean()*scl), 146 | yticks=0, 147 | title=title + (" |SSQ_CWT|" if i == 0 else " |CWT|")) 148 | 149 | #%% 150 | for case in (0, 1): 151 | beta = 10 if case == 1 else 60 152 | wavelet = Wavelet(('gmw', {'gamma': 3, 'beta': beta})) 153 | _ = ssq_cwt(np.arange(8192), wavelet) 154 | 155 | pf = wavelet._Psih[20].copy() 156 | pf = 
np.roll(pf, -np.argmax(pf) + len(pf)//2) 157 | if case == 1: 158 | pf[len(pf)//2 + 1:] = 0 159 | pf[len(pf)//2] /= 2 160 | pt = ifftshift(ifft(pf)) 161 | 162 | fig, axes = plt.subplots(3, 1, figsize=(8, 18)) 163 | plot(pf, abs=1, ticks=0, fig=fig, ax=axes[0], title="freq-domain") 164 | plot(pt, complex=1, fig=fig, ax=axes[1]) 165 | plot(pt, abs=1, linestyle='--', color='k', ticks=0, 166 | fig=fig, ax=axes[1], title="time-domain") 167 | 168 | ctr = len(pt)//2 169 | zm = 40 170 | slc = pt[ctr-zm:ctr+zm+1] 171 | plot(slc, complex=1, fig=fig, ax=axes[2]) 172 | plot(slc, abs=1, linestyle='--', color='k', show=0, ticks=0, 173 | fig=fig, ax=axes[2], title="time-domain, zoomed") 174 | fig.subplots_adjust(hspace=.1) 175 | plt.show() 176 | -------------------------------------------------------------------------------- /SignalProcessing/Q29509 - DFT - Extract sine phase and amplitude/estimators.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/29509/50076 3 | import numpy as np 4 | from numpy.fft import rfft 5 | 6 | 7 | def est_amp_phase_cedron_2bin(x): 8 | """Amplitude & phase. 
9 | "A Two Bin Solution", Cedron Dawg, Eqs 25 & 26 10 | https://www.dsprelated.com/showarticle/1284.php 11 | """ 12 | # not performance-optimized 13 | 14 | # get DFT & compute freq 15 | N = len(x) 16 | xf = rfft(x) 17 | 18 | if N == 3: 19 | kmax = 1 20 | Z = xf 21 | f_N = _est_f_cedron_2bin(Z, kmax, N) 22 | else: 23 | kmax = np.argmax(abs(xf[1:-1])) + 1 # see other referenced answer 24 | Z = xf[kmax-1:kmax+2] 25 | f_N = _est_f_cedron_3bin(Z, kmax, N) 26 | 27 | # run two bin calculations 28 | alpha = f_N * (2*np.pi) 29 | alphaN = alpha * N 30 | 31 | kj = kmax - 1 32 | kk = kmax 33 | cosa = np.cos(alpha) 34 | sina = np.sin(alpha) 35 | cosbj, cosbk = np.cos(2*np.pi/N * np.array([kj, kk])) 36 | sinbj, sinbk = np.sin(2*np.pi/N * np.array([kj, kk])) 37 | 38 | UA = np.cos(alphaN) - 1 39 | VA = np.cos(alphaN - alpha) - cosa 40 | 41 | UB = np.sin(alphaN) 42 | VB = np.sin(alphaN - alpha) + sina 43 | 44 | fj = 1 / (2 * N * (cosa - cosbj)) 45 | fk = 1 / (2 * N * (cosa - cosbk)) 46 | 47 | A = np.array([ 48 | fj * (UA * cosbj - VA), 49 | fj * (UA * sinbj), 50 | fk * (UA * cosbk - VA), 51 | fk * (UA * sinbk), 52 | ]) 53 | B = np.array([ 54 | fj * (UB * cosbj - VB), 55 | fj * (UB * sinbj), 56 | fk * (UB * cosbk - VB), 57 | fk * (UB * sinbk), 58 | ]) 59 | 60 | # "unfurl" 61 | Z = np.array([Z[0].real, Z[0].imag, Z[1].real, Z[1].imag]) / N 62 | 63 | AB = A@B 64 | AZ = A@Z 65 | BZ = B@Z 66 | AA = A@A 67 | BB = B@B 68 | 69 | d = AA*BB - AB*AB 70 | ca = (BB*AZ - AB*BZ) / d 71 | cb = (AA*BZ - AB*AZ) / d 72 | 73 | amplitude = np.sqrt(ca**2 + cb**2) 74 | phase = np.arctan2(-cb, ca) 75 | return amplitude, phase 76 | 77 | 78 | ### Frequency case ########################################################### 79 | def _get_cedron_basics(Z, k, N): 80 | re = Z.real 81 | im = Z.imag 82 | 83 | rt2 = np.sqrt(2) 84 | betas = np.array([k - 1, k, k + 1]) * 2*np.pi/N 85 | cosb = np.cos(betas) 86 | sinb = np.sin(betas) 87 | return re, im, cosb, sinb, rt2 88 | 89 | def _cedron_bin_finish(A, B, C): 90 | P = 
C / np.linalg.norm(C) 91 | D = A + B 92 | K = D - (D @ P)*P 93 | 94 | num = K @ B # dot product 95 | den = K @ A 96 | ratio = max(min(num/den, 1), -1) # handles float issues 97 | f = np.arccos(ratio) / (2*np.pi) 98 | return f 99 | 100 | 101 | def est_f_cedron_3bin_complex(Z, k, N): 102 | """ 103 | "Three Bin Exact Frequency Formulas for a Pure Complex Tone in a DFT", 104 | Cedron Dawg, Eq 19 (via Eqs 35, 31, 20, 16) 105 | https://www.dsprelated.com/showarticle/1043.php 106 | """ 107 | R1 = np.exp(-1j*1*2*np.pi/N) 108 | DZ = Z * np.array([1/R1, 1, R1]) 109 | G = np.conj(Z + DZ) 110 | K = G - np.mean(G) 111 | 112 | num = K @ Z 113 | den = K @ DZ 114 | ratio = num / den 115 | if k > N//2: 116 | k = -(N - k) 117 | f = k/N + np.arctan2(ratio.imag, ratio.real) / (2*np.pi) 118 | return f 119 | 120 | 121 | def est_f_cedron_3bin(x): 122 | N = len(x) 123 | xf = rfft(x) 124 | kmax = np.argmax(abs(xf[1:-1])) + 1 125 | Z = xf[kmax-1:kmax+2] 126 | 127 | f = _est_f_cedron_3bin(Z, kmax, N) 128 | return f 129 | 130 | 131 | def _est_f_cedron_3bin(Z, k, N): 132 | """ 133 | "Improved Three Bin Exact Frequency Formula for a Pure Real Tone in a DFT", 134 | Cedron Dawg, Eq 9 135 | https://www.dsprelated.com/showarticle/1108.php 136 | """ 137 | re, im, cosb, sinb, rt2 = _get_cedron_basics(Z, k, N) 138 | 139 | A = np.array([(re[1] - re[0])/rt2, 140 | (re[1] - re[2])/rt2, 141 | im[0], 142 | im[1], 143 | im[2]]) 144 | B = np.array([(cosb[1]*re[1] - cosb[0]*re[0])/rt2, 145 | (cosb[1]*re[1] - cosb[2]*re[2])/rt2, 146 | cosb[0]*im[0], 147 | cosb[1]*im[1], 148 | cosb[2]*im[2]]) 149 | C = np.array([(cosb[1] - cosb[0])/rt2, 150 | (cosb[1] - cosb[2])/rt2, 151 | sinb[0], 152 | sinb[1], 153 | sinb[2]]) 154 | 155 | f = _cedron_bin_finish(A, B, C) 156 | return f 157 | 158 | 159 | def est_f_cedron_2bin(x): 160 | N = len(x) 161 | xf = rfft(x) 162 | 163 | if N == 3: 164 | # not necessarily optimal scheme 165 | kmax = 1 166 | Z = xf 167 | else: 168 | kmax = np.argmax(np.abs(xf)) 169 | Z = xf[kmax-1:kmax+1] 
170 | 171 | f = _est_f_cedron_2bin(Z, kmax, N) 172 | return f 173 | 174 | 175 | def _est_f_cedron_2bin(Z, k, N): 176 | """ 177 | "A Two Bin Solution", Cedron Dawg, Eq 14 178 | https://www.dsprelated.com/showarticle/1284.php 179 | """ 180 | re, im, cosb, sinb, rt2 = _get_cedron_basics(Z, k, N) 181 | 182 | A = np.array([(re[1] - re[0])/rt2, 183 | im[1], 184 | im[0]]) 185 | B = np.array([(cosb[1]*re[1] - cosb[0]*re[0])/rt2, 186 | cosb[1]*im[1], 187 | cosb[0]*im[0]]) 188 | C = np.array([(cosb[1] - cosb[0])/rt2, 189 | sinb[1], 190 | sinb[0]]) 191 | 192 | f = _cedron_bin_finish(A, B, C) 193 | return f 194 | -------------------------------------------------------------------------------- /SignalProcessing/Q29509 - DFT - Extract sine phase and amplitude/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/29509/50076 3 | import warnings 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | # ensure the files can be found 8 | import sys 9 | from pathlib import Path 10 | _dir = Path(__file__).parent 11 | assert _dir.is_file() or _dir.is_dir(), str(_dir) 12 | if not any(str(_dir).lower() == p.lower() for p in sys.path): 13 | sys.path.insert(0, str(_dir)) 14 | 15 | from estimators import est_amp_phase_cedron_2bin, est_f_cedron_2bin 16 | 17 | #%% Helpers ################################################################## 18 | def cisoid(N, f, phi=0): 19 | return (np.cos(2*np.pi*f*np.arange(N)/N + phi) + 20 | np.sin(2*np.pi*f*np.arange(N)/N + phi)*1j) 21 | 22 | def est_and_append_errs(x, f_N, A, phi, errs_A_alt, errs_phi, 23 | errs_A=None, errs_f=None): 24 | if errs_f is not None: 25 | f_est = est_f_cedron_2bin(x) 26 | 27 | A_est, phi_est = est_amp_phase_cedron_2bin(x) 28 | err_A, err_A_alt, err_phi = get_errs_A_phi(A_est, phi_est, A, phi) 29 | 30 | if errs_A is not None: 31 | errs_A.append(err_A) 32 | if errs_f is not None: 33 | errs_f.append((f_est - f_N)**2) 34 | 
35 | # normalized 36 | errs_A_alt.append(err_A_alt) 37 | 38 | errs_phi.append(err_phi) 39 | 40 | 41 | def get_errs_A_phi(A_est, phi_est, A, phi): 42 | err_A = (A_est - A)**2 43 | err_A_alt = (1 - A_est/A)**2 44 | 45 | # note inherent ambiguity; discard sign in this case 46 | if abs(abs(phi_est) - np.pi) < .001: 47 | err_phi = (abs(phi_est) - abs(phi))**2 48 | else: 49 | err_phi = (phi_est - phi)**2 50 | return err_A, err_A_alt, err_phi 51 | 52 | 53 | #%% Manual testing ########################################################### 54 | N = 3 55 | f = 0.0035035*N 56 | phi = -1 57 | A = 1 58 | 59 | x = A*cisoid(N, f, phi).real 60 | f_N = f/N 61 | 62 | f_est = est_f_cedron_2bin(x) 63 | A_est, phi_est = est_amp_phase_cedron_2bin(x) 64 | 65 | print(f_est / f_N) 66 | print(A_est / A) 67 | print(phi_est / phi) 68 | 69 | #%% Test full (noiseless) #################################################### 70 | N = 3 71 | 72 | errs_f_fmax, errs_A_fmax, errs_phi_fmax = [], [], [] 73 | errs_A_alt_fmax = [] 74 | # 0.5 takes more coding but can be handled 75 | 76 | # make freqs 77 | f_N_all = [] 78 | for f_N in np.linspace(0, 0.5, 203, endpoint=False): 79 | # skip near-integer frequency per numeric instability 80 | # (robust version not implemented) 81 | if (f_N * N) % 1 < .01: 82 | continue 83 | f_N_all.append(f_N) 84 | f_N_all = np.array(f_N_all) 85 | 86 | # make amplitudes 87 | A_all = np.logspace(np.log10(1e-3), np.log10(1000), 100) 88 | # make phases 89 | phi_all = np.linspace(-np.pi, np.pi, 100) 90 | 91 | for f_N in f_N_all: 92 | errs_f, errs_A, errs_phi = [], [], [] 93 | errs_A_alt = [] 94 | f = f_N * N 95 | 96 | for A in A_all: 97 | for phi in phi_all: 98 | x = A*cisoid(N, f, phi).real 99 | est_and_append_errs(x, f_N, A, phi, errs_A_alt, errs_phi, 100 | errs_A, errs_f) 101 | 102 | errs_f_fmax.append(np.max(errs_f)) 103 | errs_A_fmax.append(np.max(errs_A)) 104 | errs_phi_fmax.append(np.max(errs_phi)) 105 | errs_A_alt_fmax.append(np.max(errs_A_alt)) 106 | 107 | #%% Visualize 
################################################################ 108 | fig, ax = plt.subplots(layout='constrained', figsize=(12, 8)) 109 | 110 | ax.plot(f_N_all, np.log10(errs_f_fmax)) 111 | ax.plot(f_N_all, np.log10(errs_A_alt_fmax)) 112 | ax.plot(f_N_all, np.log10(errs_phi_fmax)) 113 | 114 | title = ("Max errors: frequency, amplitude, phase | N={}\n" 115 | "n_freqs, n_amps, n_phases = {}, {}, {}" 116 | ).format(N, len(f_N_all), len(A_all), len(phi_all)) 117 | 118 | ax.set_title(title, weight='bold', fontsize=24, loc='left') 119 | ax.set_ylabel("Squared Error (log10)", fontsize=22) 120 | ax.set_xlabel("f/N", fontsize=22) 121 | ax.legend(["f", "A", "phi"], fontsize=22) 122 | 123 | plt.show() 124 | 125 | #%% Test full (noisy) ######################################################## 126 | np.random.seed(0) 127 | N = 300 128 | n_trials = 50 129 | 130 | snrs = np.linspace(-10, 50, 25) 131 | 132 | # make freqs 133 | f_N_all = [] 134 | for f_N in np.linspace(1/N, 0.5-1/N, 42): 135 | # skip near-integer frequency per numeric instability 136 | # (robust version not implemented) 137 | if (f_N * N) % 1 < .01: 138 | continue 139 | f_N_all.append(f_N) 140 | f_N_all = np.array(f_N_all) 141 | 142 | # make amplitudes 143 | A_all = np.logspace(np.log10(1e-3), np.log10(1000), 10) 144 | 145 | errs_A_alt_all, errs_phi_all = [], [] 146 | for snr in snrs: 147 | errs_A_alt, errs_phi = [], [] 148 | for f_N in f_N_all: 149 | f = f_N * N 150 | for A in A_all: 151 | for _ in range(n_trials): 152 | phi = np.random.uniform(-np.pi, np.pi) 153 | xo = A * cisoid(N, f, phi).real 154 | noise_var = xo.var() / 10**(snr/10) 155 | noise = np.random.randn(N) * np.sqrt(noise_var) 156 | x = xo + noise 157 | 158 | with warnings.catch_warnings(record=True) as ws: 159 | est_and_append_errs(x, f_N, A, phi, errs_A_alt, errs_phi) 160 | # if len(ws) > 0: 161 | # print(snr, f_N, f_N*N, A, phi, sep='\n') 162 | # 1/0 163 | errs_A_alt_all.append(np.mean(errs_A_alt)) 164 | errs_phi_all.append(np.mean(errs_phi)) 165 
| print(end='.', flush=True) 166 | 167 | #%% Visualize ################################################################ 168 | fig, ax = plt.subplots(layout='constrained', figsize=(12, 8)) 169 | 170 | ax.plot(snrs, np.log10(errs_A_alt_all)) 171 | ax.plot(snrs, np.log10(errs_phi_all)) 172 | 173 | title = ("N={}, f/N=lin sweep, phase=randu(-pi, pi)\n" 174 | "n_freqs={}, n_amps={}, n_trials_per_f_and_A={}" 175 | ).format(N, len(f_N_all), len(A_all), n_trials) 176 | 177 | ax.set_title(title, weight='bold', fontsize=24, loc='left') 178 | ax.set_ylabel("MSE (log10)", fontsize=22) 179 | ax.set_xlabel("SNR [dB]", fontsize=22) 180 | ax.legend(["A", "phi"], fontsize=22) 181 | ax.set_ylim(-10, 1) 182 | 183 | plt.show() 184 | 185 | #%% OP's case ################################################################ 186 | np.random.seed(0) 187 | A = 1 188 | f = 30.1 189 | N = 300 190 | t = np.linspace(0, 1, N, endpoint=False) 191 | n_trials = 10000 192 | 193 | errs_A_alt, errs_phi = [], [] 194 | for _ in range(n_trials): 195 | phi = np.random.uniform(-np.pi, np.pi) 196 | xo = A*np.cos(2*np.pi*f*t + phi) 197 | noise = np.random.normal(0, 0.05, len(t)) 198 | x = xo + noise 199 | 200 | A_est, phi_est = est_amp_phase_cedron_2bin(x) 201 | 202 | _, err_A_alt, err_phi = get_errs_A_phi(A_est, phi_est, A, phi) 203 | errs_A_alt.append(err_A_alt) 204 | errs_phi.append(err_phi) 205 | 206 | SNR = 10*np.log10((A/2)/.05**2) 207 | print("MSE: A={:.3g}, phi={:.3g} -- SNR={:.3g}, N={}, n_trials={}".format( 208 | np.mean(errs_A_alt), np.mean(errs_phi), SNR, N, n_trials)) 209 | -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im0.png 
-------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im1.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im10.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im11.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im12.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im13.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im13.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im14.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im15.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im16.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im3.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs 
frequency/im4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im4.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im5.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im6.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im7.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im8.png 
-------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/im9.png -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/76329/50076 ################################ 3 | import numpy as np 4 | from numpy.fft import ifft, ifftshift 5 | from ssqueezepy.visuals import plot, scat, imshow 6 | from ssqueezepy.utils import make_scales, cwt_scalebounds 7 | from ssqueezepy import ssq_cwt, cwt, Wavelet 8 | from kymatio.scattering1d.filter_bank import morlet_1d 9 | 10 | #%%# Helper methods ########################################################## 11 | def viz0(xf, xref=None, center=True): 12 | x = ifft(xf) 13 | if center: 14 | x = ifftshift(x) 15 | scat(xf[:32], show=1) 16 | if xref is not None: 17 | xref *= np.abs(x).max() / np.abs(xref).max() 18 | plot(xref, color=grey, linestyle='--') 19 | plot(x.real, show=1) 20 | 21 | def morlet(N, xi, sigma=.1/2, viz=1, trim=True, nonhalved=False): 22 | xf = morlet_1d(N, xi=xi, sigma=sigma) 23 | 24 | trim_idx = int(xi * N) + 1 25 | if trim: 26 | xf[trim_idx:] = 0 27 | if not nonhalved: 28 | xf[trim_idx - 1] /= 2 29 | x = ifftshift(ifft(xf)) 30 | 31 | if viz: 32 | scat(xf, show=1) 33 | plot(x, complex=1) 34 | plot(x, abs=1, color='k', linestyle='--', show=1) 35 | return x, xf 36 | 37 | def ref_sine(xi, x=None, N=None, endpoint=False, zoom=None): 38 | N = N if N is not None else len(x) 39 | t = np.linspace(0, 1, N, endpoint=endpoint) 40 | xref = 
np.cos(2*np.pi * (xi * N) * t) 41 | if x is not None: 42 | xref *= x.real.max() / xref.max() 43 | 44 | if zoom is not None: 45 | ctr = N // 2 46 | zoom = (zoom, zoom) if not isinstance(zoom, tuple) else zoom 47 | a, b = ctr - zoom[0], ctr + zoom[1] + 1 48 | _t = np.arange(a, b) 49 | 50 | plot(_t, xref[a:b], color=grey, linestyle='--') 51 | if x is not None: 52 | plot(_t, x.real[a:b], show=1, title="zoomed, real part") 53 | return xref 54 | 55 | grey = (.1, .1, .1, .4) 56 | 57 | #%%# Simple sinusoid ######################################################### 58 | N = 128 59 | t = np.linspace(0, 1, N, 0) 60 | xf = np.zeros(N) 61 | 62 | xf[16] = 1 63 | xref = ifft(xf).real 64 | viz0(xf, center=0) 65 | 66 | #%%# Add lateral peaks 67 | xf = np.zeros(N) 68 | xf[15] = .5 69 | xf[17] = .5 70 | viz0(xf, center=0) 71 | #%%# Both at once 72 | xf[16] = 1 73 | viz0(xf, xref) 74 | #%%# Positive + negative example 75 | xf = np.zeros(N) 76 | xf[15] = .5 77 | xf[17] = -.5 78 | viz0(xf, center=0) 79 | 80 | #%%# GIF ##################################################################### 81 | xf_full = morlet_1d(N, xi=16/128, sigma=.1/8) 82 | xf = np.zeros(N) 83 | xf[16] = xf_full[16] 84 | viz0(xf) 85 | 86 | for idx in (17, 18, 19, 20, 21): 87 | xf[idx] = xf_full[idx] 88 | xf[16 + (16 - idx)] = xf_full[16 + (16 - idx)] 89 | viz0(xf, xref) 90 | 91 | #%%# CWT near Nyquist ######################################################## 92 | N = 129 93 | t = np.linspace(0, 1, N, 1) 94 | x = np.cos(2*np.pi * 64 * t) 95 | 96 | wavelet = Wavelet('morlet') 97 | min_scale, max_scale = cwt_scalebounds(wavelet, N=N, preset='minimal') 98 | # adjust such that highest freq wavelet's peak is at Nyquist 99 | scales0 = make_scales(N, min_scale, max_scale, wavelet=wavelet) * 1.12 100 | scales = scales0 101 | Wx, _ = cwt(x, wavelet, scales=scales) 102 | 103 | imshow(Wx, abs=1, title="abs(CWT)", xlabel="time", ylabel="scale", 104 | yticks=scales) 105 | plot(Wx[:, 0], abs=1, title="abs(CWT) at a time slice (same for 
all slices)", 106 | xlabel="scale", show=1, xticks=scales) 107 | imshow(Wx.real) 108 | 109 | #%%# Morlet near nyquist ##################################################### 110 | N = 128 111 | xf = morlet(N, xi=.5, viz=1) 112 | 113 | #%%# Trimmed peak but higher sampling rate 114 | N = 512 115 | xi = .5 / 4 116 | x, xf = morlet(N, xi=xi, sigma=.1/4, viz=1, nonhalved=True) 117 | x, xf = morlet(N, xi=xi, sigma=.1/4, viz=1) 118 | 119 | #%%# Zoomed sinusoid at center frequency 120 | xref = ref_sine(xi, x=x, zoom=20) 121 | xref = ref_sine(xi, x=x, zoom=(80, -40)) 122 | 123 | #%%# SSQ_CWT of wavelet ###################################################### 124 | # extreme time-localization 125 | wavelet1 = Wavelet(('gmw', {'beta': 1, 'gamma': 1})) 126 | 127 | Tx0, Wx0, *_ = ssq_cwt(x.real, wavelet1, scales='log') 128 | mx = np.abs(Tx0).max() * .6 129 | # zoomed 130 | imshow(Tx0, abs=1, title="abs(SSQ_CWT) of wavelet.real", norm=(0, mx)) 131 | imshow(Tx0[20:160], abs=1, norm=(0, mx)) 132 | 133 | #%%# CWT 134 | # to take perfect CWT of wavelet, tweak slightly to account for padding 135 | xref = np.cos(2*np.pi * (N//8) * np.linspace(0, 1, N + 1, 1)) 136 | Tx1, Wx1, *_ = ssq_cwt(xref, wavelet, scales='log') 137 | imshow(Tx1, abs=1, title="abs(SSQ_CWT) of sinusoid at peak center freq") 138 | 139 | #%%# Peak center frequency ################################################### 140 | N = 512 141 | xi = .5 / 4 142 | x, xf = morlet(N, xi=xi, sigma=.1/4, viz=0, nonhalved=True) 143 | scat(xf) 144 | f = int(N * xi) 145 | scat(np.array([f]), xf[f], color='red', s=30, show=1) 146 | 147 | #%% 148 | xref = ref_sine(xi, x=x) 149 | plot(xref, color=grey) 150 | plot(x.real, show=1, title="peak center frequency") 151 | 152 | #%% 153 | t = np.linspace(0, 1, N, 1) 154 | xref = np.cos(2*np.pi * 64 * np.linspace(0, 1, 513, 1)) 155 | 156 | wavelet = Wavelet('morlet') 157 | min_scale, max_scale = cwt_scalebounds(wavelet, N=N, preset='minimal') 158 | scales = make_scales(N, min_scale, max_scale, 
wavelet=wavelet) * 1.12 159 | Wx, scales = cwt(xref, wavelet, scales=scales) 160 | 161 | #%%# Energy center frequency ################################################# 162 | axf = np.abs(xf) 163 | axfs = axf**2 164 | fmean_l1 = np.sum(np.arange(len(x)) * axf / axf.sum()) 165 | fmean = np.sum(np.arange(len(x)) * axfs / axfs.sum()) 166 | xref = np.cos(2*np.pi * fmean * np.linspace(0, 1, N, 0)) 167 | xref *= x.real.max() / xref.max() 168 | 169 | plot(xref, color=grey) 170 | plot(x.real, show=1, title="energy center frequency") 171 | 172 | #%% For spectrum use one near Nyquist 173 | xref = ref_sine(xi=.5, N=129, endpoint=1) 174 | Wx, _ = cwt(xref, 'morlet', scales=scales0) 175 | imshow(Wx, abs=1) 176 | 177 | slc = np.abs(Wx[:, 0]) 178 | fmean_nyq = np.sum(np.arange(len(slc)) * (slc**2) / (slc**2).sum()) 179 | scat(slc, vlines=(fmean_nyq, {'color': 'tab:red'}), show=1) 180 | imshow(Wx, abs=1) 181 | 182 | #%%# Central instantaneous center frequency ################################## 183 | # determine central instantaneous frequency from CWT since SSQ 184 | # reassigns nonlinearly 185 | N = len(Tx0[0]) 186 | finst_idx = np.argmax(np.abs(Wx0[:, 256])) 187 | psihs = wavelet1.Psih() 188 | finst = np.argmax(np.abs(psihs[finst_idx])) // 2 # x2 pad doubles index 189 | 190 | #%% 191 | xref = np.cos(2*np.pi * finst * np.linspace(0, 1, N, 0)) 192 | xref *= x.real.max() / xref.max() 193 | ctr = N // 2 194 | d = 9 195 | a, b = ctr - d, ctr + d + 1 196 | xref = xref[a:b] * N / (a - b) 197 | _t = np.arange(a, b) 198 | 199 | plot(_t, xref, color=grey, auto_xlims=0) 200 | plot(x.real, show=1, title="central instantaneous center frequency") 201 | -------------------------------------------------------------------------------- /SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/wavgif.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76329 - CWT, Wavelets - Scale vs frequency/wavgif.gif -------------------------------------------------------------------------------- /SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/im0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/im0.png -------------------------------------------------------------------------------- /SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/im1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/im1.png -------------------------------------------------------------------------------- /SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/im2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/im2.png -------------------------------------------------------------------------------- /SignalProcessing/Q76463 - FM - Frequency Modulating a Signal/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/76463/50076 ################################ 3 | import numpy as np 4 | from numpy.fft import fft, ifft 5 | from scipy.io import wavfile 6 | from ssqueezepy import ssq_cwt, Wavelet 7 | from ssqueezepy.visuals import imshow, plot 8 | 9 | #%%# Helper methods 
########################################################## 10 | def frequency_modulate(slc, fc=None, b=.3): 11 | N = len(slc) 12 | if fc is None: 13 | fc = N / 18 # arbitrary 14 | # track actual `b` for demodulation purposes 15 | b_effective = b 16 | 17 | t_min, t_max = start / fs, end / fs 18 | t = np.linspace(t_min, t_max, N, endpoint=False) 19 | assert np.allclose(fs, 1 / np.diff(t)) 20 | 21 | x0 = slc[:N].copy() # copy so the caller's array (here a view of `data`) isn't rescaled in place 22 | # ensure it's [-.5, .5] so diff(phi) is b*[-pi, pi] 23 | x0max = np.abs(x0).max() 24 | x0 /= (2*x0max) 25 | b_effective /= (2*x0max) 26 | 27 | # generate phase 28 | phi0 = 2*np.pi * fc * t 29 | phi1 = 2*np.pi * b * np.cumsum(x0) 30 | phi = phi0 + phi1 31 | diffmax = np.abs(np.diff(phi)).max() 32 | # `b` correction 33 | if diffmax > np.pi or np.isclose(diffmax, np.pi): # rescale when the per-sample phase step exceeds or sits at pi 34 | diffmax0 = np.abs(np.diff(phi0)).max() 35 | diffmax1 = np.abs(np.diff(phi1)).max() 36 | # epsilon term for stable inversion / pi-unambiguity 37 | eps = 1e-7 38 | factor = ((np.pi - diffmax0 - eps) / diffmax1) 39 | phi1 *= factor 40 | b_effective *= factor 41 | phi = phi0 + phi1 42 | assert np.abs(np.diff(phi)).max() <= np.pi 43 | 44 | # modulate 45 | x = np.cos(phi) 46 | return x, t, phi0, phi1, b_effective 47 | 48 | def analytic(x): 49 | N = len(x) 50 | xf = fft(x) 51 | 52 | xaf = np.zeros(N, dtype='complex128') 53 | xaf[:N//2 + 1] = 2 * xf[:N//2 + 1] 54 | xaf[0] /= 2 55 | xaf[N//2] /= 2 56 | 57 | xa = ifft(xaf) 58 | assert np.allclose(xa.real, x) 59 | return xa 60 | 61 | #%%# Load data, select slice ################################################# 62 | fs, data = wavfile.read(r"C:\Desktop\recording.wav") 63 | data = data.astype('float64') 64 | data /= (2*np.abs(data).max()) 65 | 66 | start, end = 0, fs 67 | slc = data[start:end] 68 | 69 | #%%# Modulate & validate ##################################################### 70 | x, t, phi0, phi10, b_effective = frequency_modulate(slc) 71 | 72 | # extreme time localization 73 | wavelet = Wavelet(('gmw', {'gamma': 1, 'beta': 1})) 74 | #
synchrosqueezed CWT 75 | Tx, Wx, ssq_freqs, *_ = ssq_cwt(x, wavelet) 76 | # ints for better plotting 77 | ssq_freqs = (ssq_freqs * fs).astype(int) 78 | 79 | #%%# Visualize ############################################################## 80 | def viz(t_min=None, t_max=None, f_min=None, f_max=None, show_original=False, 81 | show_modulated=False): 82 | freqs = ssq_freqs[::-1] 83 | a = int(t_min * fs) if t_min is not None else 0 84 | b = int(t_max * fs) if t_max is not None else None 85 | d = np.argmin(np.abs(freqs - f_min)) if f_min is not None else None 86 | c = np.argmin(np.abs(freqs - f_max)) if f_max is not None else 0 87 | 88 | imshow(Tx[c:d, a:b], xticks=t[a:b], yticks=freqs[c:d], **kw) 89 | 90 | if show_original: 91 | plot(t[a:b], slc[a:b], xlabel="time [sec]", title="original data", 92 | show=1) 93 | if show_modulated: 94 | plot(t[a:b], x[a:b], xlabel="time [sec]", title="modulated data", 95 | show=1) 96 | 97 | mx = np.abs(Tx).max() * .4 98 | kw = dict(abs=1, norm=(0, mx), title="abs(SSQ_CWT)", 99 | xlabel="time [sec]", ylabel="frequency [Hz]") 100 | viz() 101 | viz(0, .2, 300, 8000, show_original=1) 102 | viz(0, .02, 300, 8000, show_original=1, show_modulated=1) 103 | 104 | #%% 105 | # plot(x, show=1) 106 | # plot(slc, show=1) 107 | phie = np.arccos(x)[:2048] 108 | 109 | phi = np.unwrap(np.angle(analytic(x)))[:2048] 110 | phi_exact = (phi0 + phi10)[:2048] 111 | phi[:10] = phi_exact[:10] 112 | ae = np.abs(phi - phi_exact) 113 | print(ae.mean(), ae.max(), ae.min()) 114 | 115 | phi1 = (phi - phi0[:2048]) 116 | x_inv_cs = phi1 / (2*np.pi * b_effective) 117 | x_inv = np.diff(x_inv_cs) 118 | plot(x_inv) 119 | -------------------------------------------------------------------------------- /SignalProcessing/Q76560 - Analyticity - Symmetries/main.py: -------------------------------------------------------------------------------- 1 | # https://dsp.stackexchange.com/q/76560/50076 ################################ 2 | # -*- coding: utf-8 -*- 3 | import numpy as np 4 | 
from numpy.fft import ifft, ifftshift 5 | from kymatio.scattering1d.filter_bank import morlet_1d 6 | from ssqueezepy.visuals import plot as _plot, plotscat 7 | 8 | def plot(*args, **kw): 9 | if 'title' in kw: 10 | kw['title'] = (kw['title'], {'fontsize': 18}) 11 | mx = re_im_max(args[0]) * 1.05 12 | kw['ylims'] = (-mx, mx) 13 | _plot(*args, **kw) 14 | 15 | def re_im_max(x): 16 | return max(np.abs(x.real).max(), np.abs(x.imag).max()) 17 | 18 | def re_im_parts(x, y): 19 | A = x.real * y.real 20 | B = -x.imag * y.imag 21 | C = x.real * y.imag 22 | D = x.imag * y.real 23 | return A, B, C, D 24 | 25 | def l2(x): 26 | return np.sqrt(np.sum(np.abs(x)**2)) 27 | 28 | #%%# Morlet visuals ########################################################## 29 | N = 256 30 | pf = morlet_1d(N, xi=.025, sigma=.1/20) 31 | pt = ifftshift(ifft(pf)) 32 | #%% 33 | kw = dict(complex=1, show=1, ticks=(1, 0)) 34 | plot(pt, **kw, title="psi: analytic Morlet") 35 | 36 | #%% 37 | A, B, C, D = re_im_parts(pt, pt) 38 | plot(A + 1j*B, **kw, title="(psi * psi).real") 39 | plot(C + 1j*D, **kw, title="(psi * psi).imag") 40 | plot(pt * pt, **kw, 41 | title="(psi * psi).real + (psi * psi).imag = psi * psi") 42 | 43 | #%% 44 | pta = np.conj(pt) 45 | plot(pta, **kw, title="psia: anti-analytic Morlet") 46 | 47 | #%% 48 | plot(pt.imag + 1j*pta.imag, **kw, title="psia = psi.real - psi.imag") 49 | 50 | #%% 51 | A, B, C, D = re_im_parts(pt, pta) 52 | plot(A + 1j*B, **kw, title="(psi * psia).real") 53 | plot(C + 1j*D, **kw, title="(psi * psia).imag") 54 | plot(pt * pta, **kw, 55 | title="(psi * psia).real + (psi * psia).imag = psi * psia") 56 | 57 | #%%# Random sequence visuals ################################################# 58 | def viz_seq(x, show=1, complex=1, **kw): 59 | mx = re_im_max(x) * 1.1 60 | plotscat(x, complex=complex, ylims=(-mx, mx), show=show, **kw, 61 | hlines=(0, {'color': 'tab:red', 'linewidth': 1}), 62 | ticks=(1, 0), auto_xlims=0) 63 | 64 | def zero_sum_seq(N): 65 | x = np.random.randn(N) + 
1j*np.random.randn(N) 66 | x[N//2:] = 0 67 | 68 | slc = x[:N//2][::-1] 69 | x[N//2:] = slc.real - 1j * slc.imag 70 | x -= x.mean() 71 | 72 | x.real *= (l2(x.imag) / l2(x.real)) 73 | return x 74 | 75 | #%% 76 | np.random.seed(0) 77 | N = 12 78 | x = zero_sum_seq(N) 79 | y = zero_sum_seq(N) 80 | 81 | viz_seq(x, title="x | sum = {:.2f}".format(x.sum())) 82 | viz_seq(y, title="y | sum = {:.2f}".format(y.sum())) 83 | viz_seq(x * x, title="x * x | sum = {:.2f}".format((x*x).sum())) 84 | viz_seq(y * y, title="y * y | sum = {:.2f}".format((y*y).sum())) 85 | viz_seq(x * y, title="x * y | sum = {:.2f}".format((x * y).sum())) 86 | -------------------------------------------------------------------------------- /SignalProcessing/Q76560 - Analyticity - Symmetries/morlets0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76560 - Analyticity - Symmetries/morlets0.png -------------------------------------------------------------------------------- /SignalProcessing/Q76560 - Analyticity - Symmetries/morlets1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76560 - Analyticity - Symmetries/morlets1.png -------------------------------------------------------------------------------- /SignalProcessing/Q76560 - Analyticity - Symmetries/rands0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76560 - Analyticity - Symmetries/rands0.png -------------------------------------------------------------------------------- /SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) 
lowpass/WGN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass/WGN.png -------------------------------------------------------------------------------- /SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass/im2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass/im2.png -------------------------------------------------------------------------------- /SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass/ims.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass/ims.gif -------------------------------------------------------------------------------- /SignalProcessing/Q76636 - filters - Why is x(n) - x(n - 1) + x(n + 2) lowpass/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def plot(x, title): 6 | fig = plt.figure() 7 | plt.plot(np.abs(x)) 8 | plt.scatter(np.arange(len(x)), np.abs(x), s=10) 9 | plt.title(title, weight='bold', fontsize=18, loc='left') 10 | return fig 11 | 12 | def plot_T(x, Tmax): 13 | if Tmax == 0: 14 | title = "|H(w)|: x(n)" 15 | elif Tmax == 1: 16 | title = "|H(w)|: x(n) - x(n - 1)" 17 | elif Tmax == 2: 18 | title = "|H(w)|: x(n) - x(n - 1) + x(n - 2)" 19 | else: 20 | title = "|H(w)|: x(n) - x(n - 1) + x(n - 2) - ... 
x(n - %s)" % Tmax 21 | 22 | fig = plot(x, title, scatter=1) 23 | plt.ylim(-.05, 1.05) 24 | 25 | plt.savefig(f'im{Tmax}.png', bbox_inches='tight') 26 | plt.close(fig) 27 | 28 | def csoid(f): 29 | return (np.cos(2*np.pi* f * t) - 30 | np.sin(2*np.pi* f * t) * 1j) 31 | 32 | #%%# Direct frequency response ############################################### 33 | N = 32 34 | t = np.linspace(0, 1, N, 0) 35 | 36 | for Tmax in range(N): 37 | x = np.sum([(-1)**T * csoid(T) for T in range(Tmax + 1)], axis=0) 38 | x /= np.abs(x).max() 39 | plot_T(x, Tmax) 40 | 41 | #%%# WGN example ############################################################# 42 | def plot_and_save(x, title, savepath): 43 | fig = plot(x, title) 44 | plt.savefig(savepath, bbox_inches='tight') 45 | plt.close(fig) 46 | 47 | np.random.seed(69) 48 | x = np.random.randn(32) 49 | xf0 = np.fft.fft(x) 50 | x = x - np.roll(x, 1) + np.roll(x, 2) 51 | xf1 = np.fft.fft(x) 52 | 53 | plot_and_save(xf0, "|X(w)|: x(n)", "WGN0.png") 54 | plot_and_save(xf1, "|X(w)|: x(n) - x(n - 1) + x(n - 2)", "WGN1.png") 55 | -------------------------------------------------------------------------------- /SignalProcessing/Q76644 - estimation - Estimate sine frequency under white noise/_optimized.pyx: -------------------------------------------------------------------------------- 1 | import cython 2 | from libc.math cimport atan2 3 | 4 | 5 | @cython.boundscheck(False) 6 | @cython.wraparound(False) 7 | cpdef double kay_weighted_complex(double[:] R, double[:] I, double[:] W): 8 | # initialize variables 9 | cdef Py_ssize_t N = R.shape[0] 10 | cdef Py_ssize_t i = 0 11 | cdef double f_est = 0 12 | 13 | # main loop 14 | for i in range(N - 1): 15 | f_est += W[i] * atan2(R[i]*I[i + 1] - I[i]*R[i + 1], 16 | R[i]*R[i + 1] + I[i]*I[i + 1]) 17 | 18 | return f_est 19 | 20 | 21 | @cython.boundscheck(False) 22 | @cython.wraparound(False) 23 | cpdef int abs_argmax(double[:] R, double[:] I): 24 | # initialize variables 25 | cdef Py_ssize_t N = R.shape[0] 26 | 
cdef Py_ssize_t i = 0 27 | cdef int max_idx = 0 28 | cdef double current_max = 0 29 | cdef double current_abs2 = 0 30 | 31 | # main loop 32 | for i in range(N): 33 | current_abs2 = R[i]*R[i] + I[i]*I[i] 34 | if current_abs2 > current_max: 35 | max_idx = i 36 | current_max = current_abs2 37 | 38 | return max_idx 39 | -------------------------------------------------------------------------------- /SignalProcessing/Q76644 - estimation - Estimate sine frequency under white noise/benchmarks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/76644/50076 3 | import os 4 | os.environ['NUMEXPR_NUM_THREADS'] = '1' 5 | os.environ['OMP_NUM_THREADS'] = '1' 6 | os.environ['OPENBLAS_NUM_THREADS'] = '1' 7 | os.environ['MKL_NUM_THREADS'] = '1' 8 | os.environ["VECLIB_MAXIMUM_THREADS"] = '1' 9 | os.environ["NUMEXPR_NUM_THREADS"] = '1' 10 | 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | 14 | # ensure the files can be found 15 | import sys 16 | from pathlib import Path 17 | _dir = Path(__file__).parent 18 | assert _dir.is_file() or _dir.is_dir(), str(_dir) 19 | if not any(str(_dir).lower() == p.lower() for p in sys.path): 20 | sys.path.insert(0, str(_dir)) 21 | 22 | from optimized import Kay2Complex, Cedron3BinComplex 23 | from utils76644 import make_x, timeit 24 | 25 | #%% 26 | kay2_complex_optimized = Kay2Complex() 27 | cedron_3bin_complex_optimized = Cedron3BinComplex() 28 | fns = { 29 | 'kay_2': kay2_complex_optimized, 30 | 'cedron_3bin': cedron_3bin_complex_optimized, 31 | } 32 | 33 | #%% Manual testing ########################################################### 34 | np.random.seed(0) 35 | name = ('cedron_3bin', 'kay_2')[1] 36 | N = 100 37 | f = N*0.153123 38 | snr = 100 39 | x = make_x(N, f, snr, real=False) 40 | 41 | fn = fns[name] 42 | f_est = fn(x) 43 | 44 | print(f_est / (f/N), sep='\n') 45 | 46 | #%% Benchmark 
################################################################ 47 | np.random.seed(0) 48 | times = {'cedron_3bin': {}, 'kay_2': {}} 49 | n_trials = 400 50 | repeats = 20 51 | Npow2 = True 52 | 53 | N_all = 2**np.arange(7, 21) 54 | if not Npow2: 55 | for i, N in enumerate(N_all): 56 | N_all[i] = round(N, -(len(str(N)) - 2)) 57 | 58 | for N in N_all: 59 | x = np.random.randn(N) + 1j*np.random.randn(N) 60 | for name in times: 61 | fn = fns[name] 62 | # warmup 63 | for _ in range(10): 64 | fn(x) 65 | # bench 66 | times[name][N] = timeit(fn, x, n_trials, repeats) 67 | print(end='.', flush=True) 68 | 69 | print(times) 70 | 71 | #%% Visualize ################################################################ 72 | data = [np.array(list(times['cedron_3bin'].values())), 73 | np.array(list(times['kay_2'].values()))] 74 | N_all = np.array(list(times['cedron_3bin'])) 75 | 76 | fig, ax = plt.subplots(1, 1, layout='constrained', figsize=(9.5, 8)) 77 | 78 | ax.plot( N_all, data[1] / data[0]) 79 | ax.scatter(N_all, data[1] / data[0], s=40) 80 | 81 | ax.set_xscale('log') 82 | 83 | ax.set_xlabel("N", size=20) 84 | ax.set_ylabel("t_ratio", size=20) 85 | title = ("Time ratios (Kay2_complex/Cedron_complex)\n" 86 | "n_trials={}, n_repeats={}\n" 87 | "Npow2={}, FFTW=False" 88 | ).format(n_trials, repeats, Npow2) 89 | ax.set_title(title, weight='bold', fontsize=24, loc='left') 90 | ax.set_ylim(0, 3) 91 | ax.axhline(1, color='tab:red') 92 | 93 | plt.show() 94 | -------------------------------------------------------------------------------- /SignalProcessing/Q76644 - estimation - Estimate sine frequency under white noise/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/76644/50076 3 | # All code is written for readability, not performance. 
4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | # ensure the files can be found 8 | import sys 9 | from pathlib import Path 10 | _dir = Path(__file__).parent 11 | assert _dir.is_file() or _dir.is_dir(), str(_dir) 12 | if not any(str(_dir).lower() == p.lower() for p in sys.path): 13 | sys.path.insert(0, str(_dir)) 14 | 15 | from estimators import est_freq, estimator_fns 16 | from utils76644 import ( 17 | make_x, run_test, run_test_multitone, snrs_db_practical, snrs_db_wide, 18 | run_viz, run_viz2, run_viz_multitone) 19 | 20 | print("Available estimators:\n " + "\n ".join(list(estimator_fns))) 21 | 22 | #%% Configurations ########################################################### 23 | # USER ----------------------------------------------------------------------- 24 | # prints test progress 25 | VERBOSE = 1 26 | # other configs in `utils` file 27 | 28 | # set certain defaults 29 | f_N_all_nonints_large_N = (0.05393, 0.10696, 0.25494, 0.46595) 30 | f_N_all_nonints_small_N = (0.053, 0.106, 0.254, 0.465) 31 | f_N_all_ints_small_N = (0.05, 0.10, 0.25, 0.46) 32 | f_N_all_ints_large_N = f_N_all_nonints_small_N  # with N a multiple of 1000, f = f_N*N is an integer bin for all four 33 | 34 | #%% Manual testing ########################################################### 35 | np.random.seed(0) 36 | name = ('cedron_3bin', 'kay_2', 'dft_quadratic')[0] 37 | # name = ('cedron_3bin_complex',)[0] 38 | N = 10000 39 | f = N*0.053123 40 | phi = 1 41 | x = np.cos(2*np.pi * f * np.arange(N)/N + phi) 42 | # x = x + 1j*np.sin(2*np.pi * f * np.arange(N)/N) 43 | x += np.random.randn(N) * .1 44 | 45 | print(est_freq(x, name) / (f/N), sep='\n') 46 | 47 | #%% Full testing ############################################################# 48 | # configure 49 | seed = 0 50 | N = 100 51 | n_trials = 2000 52 | real = True 53 | sweep_mode = ('practical', 'wide')[0] 54 | name0, name1 = 'cedron_3bin', 'kay_2' 55 | # name0, name1 = 'cedron_3bin', 'dft_quadratic' 56 | # name0, name1 = 'cedron_3bin_complex', 'kay_2' 57 | f_N_all = (f_N_all_nonints_small_N,
f_N_all_nonints_large_N, 58 | f_N_all_ints_small_N, f_N_all_ints_large_N)[0] 59 | # f_N_all = np.linspace(1/N, .5-1/N, 50) 60 | # f_N_all = np.linspace(-.5+1/N, .5-1/N, 100) 61 | 62 | errs0_all, errs1_all, snrs, crlbs = run_test( 63 | f_N_all, N, n_trials, name0, name1, real, seed, sweep_mode, verbose=VERBOSE) 64 | 65 | #%% Visualize ################################################################ 66 | names = ("Cedron", "Kay_2") 67 | # names = ("Cedron", "DFT_quadratic") 68 | 69 | args = (errs0_all, errs1_all, snrs, crlbs, f_N_all, N, n_trials, names) 70 | run_viz(*args) 71 | # run_viz2(*args) 72 | 73 | #%% Multi-tone Example ####################################################### 74 | seed = 0 75 | N = 10000 76 | n_trials = 2000 77 | name0, name1 = 'cedron_3bin', 'dft_quadratic' 78 | # include integer case 79 | f_N_all = (0.05305, 0.10605, 0.254, 0.46505) 80 | # f_N_all = (0.10601, 0.10644, 0.10696, 0.10747) 81 | A_all = (0.5, 0.8, 1.2, 1.5) # mean=1 82 | # each `A` will have different SNR, so extend the range so we can plot all 83 | # under a common snr 84 | snrs_bounds = (snrs_db_practical[0], snrs_db_practical[-1]) 85 | snrs = np.linspace(snrs_bounds[0] - 10, snrs_bounds[1] + 15, 86 | int(len(snrs_db_practical)*1.25)) 87 | 88 | errs0_all, errs1_all, snrs_f_N, crlbs = run_test_multitone( 89 | f_N_all, A_all, N, n_trials, name0, name1, snrs, seed, verbose=VERBOSE) 90 | 91 | #%% Visualize ################################################################ 92 | names = ("Cedron", "DFT_quadratic") 93 | run_viz_multitone(errs0_all, errs1_all, snrs_f_N, crlbs, 94 | f_N_all, A_all, N, n_trials, snrs_bounds, names=names) 95 | 96 | #%% "Extreme" example ######################################################## 97 | N = 100000 98 | f_N = .053056 99 | snr = -30 100 | f = f_N * N 101 | 102 | errs = [] 103 | for _ in range(2000): 104 | x = make_x(N, f, snr) 105 | f_est = est_freq(x, 'cedron_3bin') 106 | errs.append((f_N - f_est)**2)  # squared error in normalized frequency (est_freq returns f/N units, cf. line 45) 107 | mse = np.mean(errs) 108 | 109 |
print(mse) 110 | 111 | #%% 112 | x, xo = make_x(N, f, snr, get_xo=True) 113 | fig, axes = plt.subplots(1, 2, figsize=(15.5, 6.5), layout='constrained') 114 | axes[0].plot(xo[:500]) 115 | axes[0].set_title("Original signal, zoomed | N=100000", 116 | weight='bold', loc='left', fontsize=24) 117 | axes[1].plot(x[:500]) 118 | axes[1].set_title("Noisy, zoomed | SNR=1/1000 (-30dB)", 119 | weight='bold', loc='left', fontsize=24) 120 | -------------------------------------------------------------------------------- /SignalProcessing/Q76644 - estimation - Estimate sine frequency under white noise/optimized.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Performance-optimized estimators.""" 3 | # https://dsp.stackexchange.com/q/76644/50076 4 | import numpy as np 5 | from scipy.fft import fft 6 | from _optimized import kay_weighted_complex, abs_argmax 7 | 8 | 9 | class Kay2Complex(): 10 | def __init__(self): 11 | self.weights_N = {} 12 | 13 | def __call__(self, x, N=None): 14 | N = len(x) 15 | if N in self.weights_N: 16 | weights = self.weights_N[N] 17 | else: 18 | idxs = np.arange(N - 1) 19 | weights = 1.5*N / (N**2 - 1) * (1 - ((idxs - (N/2 - 1)) / (N/2))**2 20 | ) / (2*np.pi) 21 | self.weights_N[N] = weights 22 | 23 | f_est = kay_weighted_complex(x.real, x.imag, weights) 24 | return f_est 25 | 26 | 27 | class Cedron3BinComplex(): 28 | def __call__(self, x): 29 | xf = fft(x) 30 | kmax = abs_argmax(xf.real, xf.imag) 31 | return 5 32 | 33 | # didn't bother implementing the O(1) finishing step. 34 | # it'd only slightly change the N=100 result, and not at all for 35 | # anything else. 
36 | # the handicap of not using FFTW is far greater 37 | -------------------------------------------------------------------------------- /SignalProcessing/Q76644 - estimation - Estimate sine frequency under white noise/setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # python setup.py build_ext --inplace 3 | from distutils import _msvccompiler  # NOTE(review): stdlib distutils was removed in Python 3.12; relies on setuptools' shim there 4 | _msvccompiler.PLAT_TO_VCVARS['win-amd64'] = 'amd64' 5 | 6 | from setuptools import setup, Extension 7 | from Cython.Build import cythonize 8 | import numpy as np 9 | 10 | setup( 11 | ext_modules=cythonize(Extension("_optimized", ["_optimized.pyx"]), 12 | language_level=3), 13 | include_dirs=[np.get_include()], 14 | ) 15 | -------------------------------------------------------------------------------- /SignalProcessing/Q76754 - Hilbert - Alleviate boundary effects/data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q76754 - Hilbert - Alleviate boundary effects/data.npy -------------------------------------------------------------------------------- /SignalProcessing/Q76754 - Hilbert - Alleviate boundary effects/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/76754/50076 ################################ 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from numpy.fft import fft, ifft 6 | 7 | def analytic(x): 8 | N = len(x) 9 | xf = fft(x) 10 | xf[1:N//2] *= 2  # double positive frequencies; DC (and Nyquist for even N) are kept as-is 11 | if N % 2 == 1: 12 | xf[N//2] *= 2 13 | xf[N//2 + 1:] = 0  # zero the negative-frequency bins 14 | xa = ifft(xf) 15 | assert np.allclose(xa.real, x) 16 | return xa 17 | 18 | def plot(x, title=None, show=0): 19 | plt.plot(x) 20 | if title is not None: 21 | plt.title(title, loc='left', weight='bold', fontsize=17) 22 | if show: 23 |
plt.show() 24 | 25 | #%%########################################################################### 26 | # load 27 | x = np.load('data.npy') 28 | N = len(x) 29 | xa = analytic(x) 30 | 31 | # pad + take hilbert 32 | xp = np.pad(x, N, mode='reflect')  # N samples on each side -> length 3N 33 | xpa = analytic(xp) 34 | # unpad 35 | xpu = xp[N:-N] 36 | xpau = xpa[N:-N] 37 | 38 | #%% visualize 39 | plot(x, title="original") 40 | plot(np.abs(xa), show=1) 41 | 42 | plot(xpu, title="reflect-padded + unpadded") 43 | plot(np.abs(xpau), show=1) 44 | -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/cwt.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/cwt.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/cwt_phi.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/cwt_phi.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/cwt_phi.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Continuous Wavelet Transform w/ lowpassing (first-order scattering) GIF.""" 3 | # https://dsp.stackexchange.com/q/78512/50076 ################################ 4 | # https://dsp.stackexchange.com/q/78514/50076 ################################ 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import matplotlib.animation as animation 8 | from numpy.fft import fft, ifft, fftshift,
ifftshift 9 | from kymatio.numpy import Scattering1D 10 | from kymatio.toolkit import echirp 11 | 12 | #%%## Set params & create scattering object ################################## 13 | N = 512 14 | J, Q = 4, 3 15 | T = 2**(J + 1) 16 | # make CWT 17 | average, oversampling = False, 999 18 | ts = Scattering1D(shape=N, J=J, Q=Q, average=average, oversampling=oversampling, 19 | out_type='list', max_order=1, T=T, max_pad_factor=None) 20 | ts.psi1_f.pop(-1) # drop last for niceness 21 | 22 | #%%# Create signal & warp it ################################################# 23 | x = echirp(N, fmin=16, fmax=N/2.0) 24 | t = np.linspace(0, 1, N, 1) 25 | 26 | meta = ts.meta() 27 | freqs = N * meta['xi'][meta['order'] == 1][:, 0] 28 | 29 | #%%# Animate ################################################################# 30 | class CWTAnimation(animation.TimedAnimation): 31 | def __init__(self, ts, x, t, freqs, stride=1, alpha='6f'): 32 | self.ts = ts 33 | self.x = x 34 | self.t = t 35 | self.freqs = freqs 36 | self.stride = stride 37 | self.N = len(x) 38 | assert self.N % 2 == 0 39 | 40 | # unpack filters 41 | self.psi_fs = [p[0] for p in ts.psi1_f] 42 | self.phi_f = ts.phi_f[0] 43 | self.phi_t = ifft(self.phi_f).real 44 | # conjugate since conv == cross-correlation with conjugate 45 | psi_ts = np.array([np.conj(ifft(p)) for p in self.psi_fs]) 46 | self.psi_ts = psi_ts / psi_ts.real.max() 47 | 48 | # compute filter params 49 | self.js = [p['j'] for p in ts.psi1_f] 50 | self.lens = [int(1.2*p['support'][0]) for p in ts.psi1_f] 51 | self.n_psis = len(self.psi_fs) 52 | self.Np = len(self.psi_fs[0]) 53 | self.trim = ts.ind_start[0] 54 | self.phi_t_trim = fftshift(ifftshift(self.phi_t)[self.trim:-self.trim]) 55 | # padded x to do CWT with 56 | self.xpf = fft(np.pad(x, [ts.pad_left, ts.pad_right], mode='reflect')) 57 | 58 | # take CWT without modulus 59 | self.Wx = np.array([ifft(p * self.xpf) for p in self.psi_fs]) 60 | self.Ux = np.abs(self.Wx) 61 | self.Sx = ifft(fft(self.Ux) * 
self.phi_f).real 62 | start, end = ts.ind_start[0], ts.ind_end[0] 63 | self.Sx = self.Sx[:, start:end] 64 | self.Ux = self.Ux[:, start:end] 65 | # suppress boundary effects for visuals 66 | mx = self.Sx.max() 67 | mx_avg = self.Sx[5].max() 68 | self.Sx[0] *= (mx_avg / mx) * 1.3 69 | self.Sx[-1] *= (mx_avg / mx) * 1.3 70 | # spike causes graphical glitch in converting to gif -- flatten 71 | mx = self.Ux.max() * 1.02 72 | self.Ux[0] *= .58 / .61 73 | self.Ux[0, -1] = mx 74 | 75 | self.alpha = '6f' 76 | self.title_kw = dict(weight='bold', fontsize=16, loc='left') 77 | self.label_kw = dict(weight='bold', fontsize=14, labelpad=3) 78 | self.txt_kw = dict(x=0, y=1.015, s="", ha="left") 79 | self._prev_psi_idx = -1 80 | 81 | fig, axes = plt.subplots(1, 2, figsize=(18/1.7, 8/1.7)) 82 | 83 | # |CWT|*phi_f #################################################### 84 | ax = axes[0] 85 | self.lines0 = [] 86 | 87 | ax.imshow(self.Ux, cmap='turbo', aspect='auto', interpolation='none') 88 | ax.set_xticks(np.linspace(0, self.N, 6, endpoint=True)) 89 | xticklabels = np.linspace(t.min(), t.max(), 6, endpoint=True) 90 | ax.set_xticklabels(["%.1f" % xt for xt in xticklabels]) 91 | yticks = np.linspace(0, self.n_psis - 1, 6, endpoint=True) 92 | ax.set_yticks(yticks) 93 | ax.set_yticklabels(["%.2f" % (self.freqs[int(yt)] / self.N) 94 | for yt in yticks]) 95 | 96 | phi_t = 6 - 600*self.phi_t_trim 97 | ax.plot(np.arange(self.N), phi_t, color='white', linewidth=2) 98 | self.lines0.append(ax.lines[-1]) 99 | 100 | self.txt0 = ax.text(transform=ax.transAxes, **self.txt_kw, fontsize=15) 101 | 102 | # output ######################################################### 103 | ax = axes[1] 104 | im = ax.imshow(self.Sx, cmap='turbo', animated=True, aspect='auto', 105 | interpolation='none') 106 | self.Sx_now = 0 * self.Sx 107 | im.set_array(self.Sx_now) 108 | self.ims1 = [im] 109 | 110 | ax.set_xticks(np.linspace(0, self.N, 6, endpoint=True)) 111 | xticklabels = np.linspace(t.min(), t.max(), 6, 
endpoint=True) 112 | ax.set_xticklabels(["%.1f" % xt for xt in xticklabels]) 113 | yticks = np.linspace(0, self.n_psis - 1, 6, endpoint=True) 114 | ax.set_yticks(yticks) 115 | ax.set_yticklabels(["%.2f" % (self.freqs[int(yt)] / self.N) 116 | for yt in yticks]) 117 | ax.set_title(r"$|\psi \star x| \star \phi$", **self.title_kw) 118 | 119 | # finalize ####################################################### 120 | fig.subplots_adjust(left=.05, right=.99, bottom=.06, top=.94, 121 | wspace=.15) 122 | animation.TimedAnimation.__init__(self, fig, interval=50, blit=True) 123 | 124 | def _draw_frame(self, frame_idx): 125 | step = self.stride * frame_idx % self.N 126 | total_step = self.stride * frame_idx 127 | psi_idx = int(total_step / self.N) 128 | # at right bound 129 | if self.stride != 1 and (step < self.stride - 1 and frame_idx > 1): 130 | step = self.N 131 | psi_idx -= 1 132 | 133 | # |CWT|*phi_f #################################################### 134 | T = self.ts.phi_f['support'] 135 | start, end = max(0, step - T//2), min(self.N, step + T//2) 136 | phi_t = 6 - 600*self.phi_t_trim 137 | phi_t = np.roll(phi_t, step)[start:end] 138 | self.lines0[0].set_data(np.arange(start, end), phi_t) 139 | 140 | tau = "%.2f" % ((step / self.N) * (self.t.max() - self.t.min())) 141 | self.txt0.set_text(r"$\phi(t - {}),\ |\psi \star x|$".format(tau)) 142 | 143 | # output ######################################################### 144 | self.Sx_now[:, start:step] = self.Sx[:, start:step] 145 | self.ims1[0].set_array(self.Sx_now) 146 | 147 | # finalize ########################################################### 148 | self._prev_psi_idx = psi_idx 149 | self._drawn_artists = [*self.lines0, self.txt0, *self.ims1] 150 | 151 | def new_frame_seq(self): 152 | return iter(range(1 * int(self.N // self.stride + 1))) 153 | 154 | def _init_draw(self): 155 | pass 156 | 157 | 158 | ani = CWTAnimation(ts, x, t, freqs, stride=2) 159 | ani.save('cwt_phi.mp4', fps=60) 160 | plt.show() 161 | 
-------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/reconstruction.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/reconstruction.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/scat_impulse.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/scat_impulse.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/scat_impulse_resc.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/scat_impulse_resc.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/scat_impulse_unscaled_anim.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/scat_impulse_unscaled_anim.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/second_order.py: -------------------------------------------------------------------------------- 1 | # -*- 
coding: utf-8 -*- 2 | """Second-order unaveraged scattering on A.M. cosine and White Gaussian Noise.""" 3 | # https://dsp.stackexchange.com/q/78512/50076 ################################ 4 | # https://dsp.stackexchange.com/q/78514/50076 ################################ 5 | import numpy as np 6 | from numpy.fft import fft, ifft 7 | import matplotlib.pyplot as plt 8 | from kymatio.numpy import Scattering1D 9 | from kymatio.visuals import (plot, imshow, 10 | filterbank_heatmap, filterbank_scattering) 11 | 12 | #%% 13 | def cwt_and_viz(x, ts, show_max_rows=False): 14 | # do manually to override `j2 > j1` and see all coeffs 15 | xp = np.pad(x, [ts.pad_left, ts.pad_right], mode='reflect') 16 | U0_f = fft(xp) 17 | 18 | # |CWT| 19 | U1 = [] 20 | for p1 in ts.psi1_f: 21 | U1.append(np.abs(ifft(U0_f * p1[0]))) 22 | U1 = np.array(U1) 23 | U1_f = fft(U1, axis=-1) 24 | 25 | # U2 26 | U2 = [] 27 | for p2 in ts.psi2_f: 28 | U2.append(np.abs(ifft(U1_f * p2[0])))  # broadcasts over all first-order rows: U2[psi2_idx, psi1_idx, time] 29 | U2 = np.array(U2) 30 | 31 | # S1 = ifft(U1_f * ts.phi_f[0]).real 32 | # U2_f = fft(U2, axis=-1) 33 | # S2 = ifft(U2_f * ts.phi_f[0][None, None]).real 34 | 35 | # unpad all 36 | s, e = ts.ind_start[0], ts.ind_end[0] 37 | U1, U2 = [g[..., s:e] for g in (U1, U2)] 38 | # S1, S2 = [g[..., s:e] for g in (S1, S2)] 39 | 40 | # viz first-order ######################################################## 41 | fs = N  # NOTE(review): relies on module-level `N` (and `t` below), not on arguments 42 | xi1s = np.array([p['xi'] for p in ts.psi1_f]) * fs 43 | imshow(U1, abs=1, title="|CWT(x)|", yticks=xi1s, ylabel="frequency [Hz]", 44 | xlabel="time", xticks=t, w=.8, h=.6) 45 | 46 | if show_max_rows: 47 | mx_idx = np.argmax(np.sum(U1**2, axis=-1)) 48 | xi1 = xi1s[mx_idx] 49 | title = "|CWT(x, xi1={:.1f} Hz)| (max row)".format(xi1) 50 | plot(t, U1[mx_idx], title=title, ylims=(0, None), show=1, 51 | ylabel="A.M.
rate [Hz]", xlabel="time [sec]", w=.7, h=.9) 52 | 53 | # viz second-order ######################################################## 54 | fig, axes = plt.subplots(3, 3, figsize=(16, 16)) 55 | xi2s = np.array([p['xi'] for p in ts.psi2_f]) * fs 56 | 57 | # U2 /= (U1 + U1.max()/10000)[None] 58 | # U2 = np.log(1 + U2) 59 | mx = np.max(U2) * .9 60 | for U2_idx, ax in enumerate(axes.flat): 61 | U2_idx += 1  # skip psi2 index 0; the 3x3 grid shows indices 1..9 62 | xi2 = xi2s[U2_idx] 63 | title = "xi2 = {:.1f} Hz".format(xi2) 64 | imshow(U2[U2_idx], abs=1, title=title, norm=(0, mx), 65 | show=0, ax=ax, ticks=0) 66 | 67 | label_kw = dict(weight='bold', fontsize=17) 68 | axes[0, 0].set_ylabel("xi1", **label_kw) 69 | axes[1, 0].set_ylabel("xi1", **label_kw) 70 | axes[2, 0].set_ylabel("xi1", **label_kw) 71 | axes[2, 0].set_xlabel("time", **label_kw) 72 | axes[2, 1].set_xlabel("time", **label_kw) 73 | axes[2, 2].set_xlabel("time", **label_kw) 74 | 75 | plt.subplots_adjust(wspace=.02, hspace=.1) 76 | plt.show() 77 | 78 | if show_max_rows: 79 | mx_idx2 = np.argmax(np.sum(U2**2, axis=(1, 2))) 80 | mx_idx1 = np.argmax(np.sum(U2[mx_idx2]**2, axis=-1)) 81 | xi1, xi2 = xi1s[mx_idx1], xi2s[mx_idx2] 82 | title = "|CWT(|CWT(x, xi1={:.1f} Hz)|, xi2={:.1f} Hz)| (max row)".format( 83 | xi1, xi2) 84 | U2_max_row = U2[mx_idx2, mx_idx1] 85 | 86 | plot(t, U2_max_row, title=title, show=1, xlabel="time [sec]", 87 | ylabel="Rate of A.M.
rate [Hz^2]", 88 | ylims=(0, U2_max_row.max() * 1.03)) 89 | 90 | #%% 91 | N = 2049 92 | kw = dict(shape=N, J=9, Q=8, max_pad_factor=None, oversampling=999, max_order=1) 93 | ts = Scattering1D(**kw) 94 | 95 | #%% 96 | filterbank_scattering(ts, second_order=1, zoom=-1) 97 | filterbank_heatmap(ts, second_order=True) 98 | 99 | #%%# AM cosine ############################################################### 100 | f0, f1 = 64, 3 101 | t = np.linspace(0, 1, N, 1) 102 | c = np.cos(2*np.pi * f0 * t) 103 | a = (1 + (np.cos(2*np.pi * f1 * t))) / 2 104 | x = a * c 105 | 106 | title = "$\cos(2\pi {} t) \cdot (1 + \cos(2\pi {} t))/2$".format(f0, f1) 107 | plot(t, x, show=1, title=title, xlabel="time [sec]") 108 | #%% 109 | cwt_and_viz(x, ts, show_max_rows=1) 110 | 111 | #%%# WGN ##################################################################### 112 | np.random.seed(0) 113 | x = np.random.randn(N) 114 | plot(t, x, show=1, title="White Gaussian Noise", xlabel="time [sec]") 115 | 116 | #%% 117 | cwt_and_viz(x, ts, show_max_rows=0) 118 | -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/sine_warp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/sine_warp.png -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/tshift_invar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Illustrating time-shift invariance.""" 3 | # https://dsp.stackexchange.com/q/78512/50076 ################################ 4 | # https://dsp.stackexchange.com/q/78514/50076 ################################ 5 | import numpy as np 6 | import os 7 | import 
matplotlib.pyplot as plt 8 | from kymatio.numpy import Scattering1D 9 | from kymatio.visuals import plot, make_gif, plotscat 10 | from kymatio.toolkit import l2 11 | from scipy.signal.windows import tukey 12 | 13 | #%%## Set params & create scattering object ################################## 14 | N = 2048 15 | J, Q = 8, 10 16 | T = 2**J 17 | # make CWT 18 | average, oversampling = False, 999 19 | ts = Scattering1D(shape=N, J=J, Q=Q, average=average, oversampling=oversampling, 20 | out_type='list', max_order=1, T=T, max_pad_factor=None) 21 | meta = ts.meta() 22 | 23 | #%%# Create signal & warp it ################################################# 24 | f = 128 25 | width = N//4 26 | shift = width // 4 27 | 28 | window = tukey(width) 29 | window = np.pad(window, (N - width)//2) 30 | t = np.linspace(0, 1, N, 0) 31 | x = np.cos(2*np.pi * f * t) * window 32 | 33 | x0 = np.roll(x, -shift//2) 34 | x1 = np.roll(x, shift//2)  # x1 is x0 delayed by `shift` samples (shift is even here) 35 | 36 | #%%# CWT 37 | _Scx0, _Scx1 = ts(x0), ts(x1) 38 | Scx0 = np.vstack([c['coef'] for c in _Scx0])[meta['order'] == 1] 39 | Scx1 = np.vstack([c['coef'] for c in _Scx1])[meta['order'] == 1] 40 | freqs = N * meta['xi'][meta['order'] == 1][:, 0] 41 | 42 | x_all = [x0, x1] 43 | cwt_all = [Scx0, Scx1] 44 | 45 | #%% make scattering objects 46 | T0 = 2**(J - 2) * 2 47 | T1 = 2**J * 2 48 | kw = dict(shape=N, J=J, Q=Q, average=True, 49 | out_type='list', max_order=1, max_pad_factor=3) 50 | ts0 = Scattering1D(**kw, T=T0, oversampling=0) 51 | ts1 = Scattering1D(**kw, T=T1, oversampling=0) 52 | meta = ts0.meta() 53 | 54 | #%% scatter 55 | _Scx00, _Scx01, _Scx10, _Scx11 = ts0(x0), ts0(x1), ts1(x0), ts1(x1) 56 | Scx00 = np.vstack([c['coef'] for c in _Scx00])[meta['order'] == 1] 57 | Scx01 = np.vstack([c['coef'] for c in _Scx01])[meta['order'] == 1] 58 | Scx10 = np.vstack([c['coef'] for c in _Scx10])[meta['order'] == 1] 59 | Scx11 = np.vstack([c['coef'] for c in _Scx11])[meta['order'] == 1] 60 | freqs = N * meta['xi'][meta['order'] == 1][:, 0] 61 | 62 | x_all =
[x0, x1] 63 | Scx_all = [Scx00, Scx10, Scx01, Scx11]  # ordered so Scx_all[i], Scx_all[i + 1] are (T0, T1) for the same input 64 | 65 | #%%# Time-shift GIF ########################################################## 66 | # configure 67 | base_name = "tshift" 68 | images_ext = ".png" 69 | savedir = r"C:\Desktop\School\Deep Learning\DL_Code\signals\viz_gen"  # NOTE(review): machine-specific path; adjust before running 70 | overwrite = True 71 | 72 | # common plot kwargs 73 | title_kw = dict(weight='bold', fontsize=20, loc='left') 74 | label_kw = dict(weight='bold', fontsize=18) 75 | imshow_kw = dict(cmap='turbo', aspect='auto', vmin=0) 76 | vmax00 = np.array(cwt_all).max()*.95 77 | vmax0 = max(Scx00.max(), Scx01.max()) 78 | vmax1 = max(Scx10.max(), Scx11.max()) 79 | 80 | for i in (0, 2): 81 | fig, axes = plt.subplots(2, 2, figsize=(18, 14)) 82 | 83 | # cwt #################################################################### 84 | # plot 85 | ax = axes[0, 0] 86 | ax.plot(t, x_all[i//2]) 87 | # style 88 | txt = ("" if i == 0 else 89 | " - %.2f" % (shift/N)) 90 | ax.set_title("x(t%s)" % txt, **title_kw) 91 | xticks = np.array([*np.linspace(0, N, 6, 1)[:-1], N - 1]) 92 | xticklabels = ['%.2f' % tk for tk in t[np.round(xticks).astype(int)]] 93 | ax.set_xticks(xticks / N) 94 | ax.set_xticklabels(xticklabels, fontsize=14) 95 | ax.set_xlim(-.03, 1.03) 96 | ax.set_yticks([]) 97 | 98 | # imshow 99 | ax = axes[0, 1] 100 | ax.imshow(cwt_all[i//2], **imshow_kw, vmax=vmax00) 101 | # style 102 | ax.set_title("|CWT(x)|", **title_kw) 103 | ax.set_yticks([]) 104 | ax.set_xticks(xticks) 105 | ax.set_xticklabels(xticklabels, fontsize=14) 106 | 107 | # scattering ############################################################# 108 | Scx0, Scx1 = Scx_all[i], Scx_all[i + 1] 109 | 110 | # imshow 111 | ax = axes[1, 0] 112 | ax.imshow(Scx0, **imshow_kw, vmax=vmax0) 113 | # style 114 | ax.set_title("S(x) | T=%.2f sec" % (T0 / N), **title_kw) 115 | ax.set_xticks([]); ax.set_yticks([]) 116 | 117 | # # imshow 118 | ax = axes[1, 1] 119 | ax.imshow(Scx1, **imshow_kw, vmax=vmax1) 120 | # style 121 | ax.set_title("S(x) | T=%.2f sec" %
(T1 / N), **title_kw) 122 | ax.set_xticks([]); ax.set_yticks([]) 123 | 124 | # finalize 125 | plt.subplots_adjust(hspace=.15, wspace=.04) 126 | 127 | # save 128 | path = os.path.join(savedir, f'{base_name}{i}{images_ext}') 129 | if os.path.isfile(path) and overwrite: 130 | os.unlink(path) 131 | if not os.path.isfile(path): 132 | fig.savefig(path, bbox_inches='tight') 133 | plt.close() 134 | 135 | # make into gif 136 | savepath = os.path.join(savedir, f'{base_name}.gif') 137 | make_gif(loaddir=savedir, savepath=savepath, ext=images_ext, duration=750, 138 | overwrite=overwrite, delimiter=base_name, verbose=1, start_end_pause=0, 139 | delete_images=True) 140 | 141 | #%%# plot superimposed ####################################################### 142 | def plot_superimposed(g0, g1, T, ax):  # dashed red segments link g0[i] to g1[i]; both are then scatter-plotted 143 | t_idxs = np.arange(len(g0)) 144 | for t_idx in t_idxs: 145 | plot([t_idx, t_idx], [g0[t_idx], g1[t_idx]], 146 | color='tab:red', linestyle='--', ax=ax) 147 | 148 | title = ("T={:.2f} | reldist={:.2f}".format(T / N, float(l2(g0, g1))), 149 | {'fontsize': 20}) 150 | plotscat(g0, title=title, ax=ax) 151 | plotscat(g1, xlims=(-len(g0)/40, 1.03*(len(g0) - 1)), 152 | ylims=(0, 1.03*max(g0.max(), g1.max())), ax=ax) 153 | 154 | mx_idx = np.argmax(Scx00.sum(axis=-1)) 155 | 156 | fig, axes = plt.subplots(1, 2, figsize=(18, 7)) 157 | plot_superimposed(Scx00[mx_idx], Scx01[mx_idx], T0, axes[0]) 158 | plot_superimposed(Scx10[mx_idx], Scx11[mx_idx], T1, axes[1]) 159 | 160 | plt.subplots_adjust(wspace=.09) 161 | plt.show() 162 | 163 | #%% Reldist vs T ############################################################# 164 | kw = dict(shape=N, J=6, Q=8, average=True, 165 | out_type='array', max_order=1, max_pad_factor=4) 166 | # implementation switches to simple average with unpad which 167 | # detracts from curve slightly in log space, so do N-1 instead of N 168 | T_all = list(range(1, N + 1, 16)) + [N - 1] 169 | ts_all = [Scattering1D(**kw, T=T) for T in T_all] 170 | 171 | #%% 172 | Scx0_all =
[ts(x0) for ts in ts_all] 173 | Scx1_all = [ts(x1) for ts in ts_all] 174 | reldist_all = [l2(S0, S1).squeeze() for S0, S1 in zip(Scx0_all, Scx1_all)] 175 | 176 | #%% plot 177 | T_all_sec = np.array(T_all) / N 178 | 179 | fig, axes = plt.subplots(1, 2, figsize=(18, 7)) 180 | plot(T_all_sec, reldist_all, ax=axes[0], 181 | hlines=(0, dict(color='tab:red', linestyle='--', linewidth=1)), 182 | title=("Scattering coefficient relative distance", {'fontsize': 20}), 183 | xlabel="T [sec]", ylabel="reldist") 184 | 185 | T_all_sec_log = np.log2(T_all_sec) 186 | reldist_all_log = np.log10(reldist_all) 187 | plot(T_all_sec_log, reldist_all_log, ax=axes[1], 188 | title=("logscaled", {'fontsize': 20}), 189 | xlabel="log2(T) [sec]", ylabel="log10(reldist)") 190 | 191 | plt.subplots_adjust(wspace=.1) 192 | plt.show() 193 | -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_cwt_anim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Sinusoid and echirp warping CWT GIF.""" 3 | # https://dsp.stackexchange.com/q/78512/50076 ################################ 4 | # https://dsp.stackexchange.com/q/78514/50076 ################################ 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import matplotlib.animation as animation 8 | from kymatio.torch import Scattering1D 9 | from kymatio.visuals import imshow 10 | from kymatio.toolkit import _echirp_fn 11 | 12 | def tau(t, K=.08): 13 | return np.cos(2*np.pi * t**2) * K 14 | 15 | def adtau(t, K=.08): 16 | return np.abs(np.sin(2*np.pi * t**2) * t * 4*np.pi*K) 17 | 18 | import torch 19 | USE_GPU = bool('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | #%% Create scattering object ################################################# 22 | N, f = 2048, 64 23 | J, Q = 8, 8 24 | T = 512 25 | # increase freq width to improve temporal localization 26 | # to not confuse 
ill low-freq localization with warp effects, for echirp 27 | ts = Scattering1D(shape=N, J=J, Q=Q, average=False, oversampling=999, 28 | T=T, out_type='list', r_psi=.82) 29 | if USE_GPU: 30 | ts.cuda() 31 | 32 | #%%# configure time-warps #################################################### 33 | K_init = .01 34 | n_pts = 64 35 | t = np.linspace(0, 1, N, 0) 36 | 37 | p = ts.psi1_f[0] 38 | QF = p['xi'] / p['sigma'] 39 | 40 | adtau_init_max = adtau(t, K=K_init).max() 41 | K_min = (1/QF) / 10 42 | K_max = (1 / adtau_init_max) * K_init  # adtau is linear in K, so this K gives max|tau'| = 1 43 | 44 | K_all = np.logspace(np.log10(K_min), np.log10(K_max), n_pts - 1, 1) 45 | K_all = np.hstack([0, K_all])  # prepend the unwarped case (K=0) 46 | tau_all = np.vstack([tau(t, K=k) for k in K_all]) 47 | adtau_max_all = np.vstack([adtau(t, K=k).max() for k in K_all]) 48 | 49 | assert adtau_max_all.max() <= 1, adtau_max_all.max()  # |tau'| <= 1 keeps the warp t - tau(t) non-decreasing 50 | 51 | #%% Animate ################################################################# 52 | class PlotImshowAnimation(animation.TimedAnimation): 53 | def __init__(self, plot_frames, imshow_frames, t, freqs, adtau_max_all): 54 | self.plot_frames = plot_frames 55 | self.imshow_frames = imshow_frames 56 | self.t = t 57 | self.freqs = freqs 58 | self.adtau_max_all = adtau_max_all 59 | 60 | self.N = len(t) 61 | self.n_rows = len(imshow_frames[0]) 62 | self.n_frames = len(imshow_frames) 63 | 64 | self.title_kw = dict(weight='bold', fontsize=15, loc='left') 65 | self.label_kw = dict(weight='bold', fontsize=14, labelpad=3) 66 | self.txt_kw = dict(x=0, y=1.015, s="", ha="left") 67 | 68 | fig, axes = plt.subplots(1, 2, figsize=(10.6, 4.7)) 69 | 70 | # plots ############################################################## 71 | ax = axes[0] 72 | self.lines0 = [] 73 | 74 | ax.plot(plot_frames[0]) 75 | self.lines0.append(ax.lines[-1]) 76 | 77 | ax.set_xticks(np.linspace(0, self.N, 6, endpoint=True)) 78 | xticklabels = np.linspace(t.min(), t.max(), 6, endpoint=True) 79 | ax.set_xticklabels(["%.1f" % xt for xt in xticklabels]) 80 | 81 | self.txt0 =
ax.text(transform=ax.transAxes, **self.txt_kw, fontsize=15) 82 | 83 | # imshows ############################################################ 84 | ax = axes[1] 85 | im = ax.imshow(self.imshow_frames[0], cmap='turbo', animated=True, 86 | aspect='auto') 87 | self.ims1 = [im] 88 | 89 | ax.set_xticks(np.linspace(0, self.N, 6, endpoint=True)) 90 | xticklabels = np.linspace(t.min(), t.max(), 6, endpoint=True) 91 | ax.set_xticklabels(["%.1f" % xt for xt in xticklabels]) 92 | yticks = np.linspace(0, self.n_rows - 1, 6, endpoint=True) 93 | ax.set_yticks(yticks) 94 | ax.set_yticklabels(["%.2f" % (self.freqs[int(yt)] / self.N) 95 | for yt in yticks]) 96 | 97 | ax.set_title("|CWT(x(t))|", **self.title_kw) 98 | 99 | # finalize ####################################################### 100 | fig.subplots_adjust(left=.05, right=.99, bottom=.06, top=.94, 101 | wspace=.15) 102 | animation.TimedAnimation.__init__(self, fig, interval=50, blit=True) 103 | 104 | def _draw_frame(self, frame_idx): 105 | # plots ############################################################## 106 | self.lines0[0].set_ydata(self.plot_frames[frame_idx]) 107 | 108 | if frame_idx == 0: 109 | txt = r"$x(t) \ ... \ |\tau'(t)| = 0$" 110 | else: 111 | txt = r"$x(t) \ ...
\ |\tau'(t)| < %.3f$" % self.adtau_max_all[ 112 | frame_idx] 113 | self.txt0.set_text(txt) 114 | 115 | # imshows ############################################################ 116 | self.ims1[0].set_array(self.imshow_frames[frame_idx]) 117 | 118 | # finalize ########################################################### 119 | self._drawn_artists = [*self.lines0, self.txt0, *self.ims1] 120 | 121 | def new_frame_seq(self): 122 | return iter(range(self.n_frames)) 123 | 124 | def _init_draw(self): 125 | pass 126 | 127 | #%%# Run on each signal ###################################################### 128 | def extend(x): 129 | return np.array(list(x) + 6*[x[-1]])  # append 6 copies of the last entry (holds the final frame) 130 | 131 | def run(x_all, adtau_max_all, ts, name): 132 | meta = ts.meta() 133 | freqs = N * meta['xi'][meta['order'] == 1][:, 0] 134 | 135 | Scx_all0 = [ts(x) for x in x_all] 136 | Scx_all = np.vstack([(np.vstack([c['coef'].cpu().numpy() for c in Scx] 137 | )[meta['order'] == 1])[None] 138 | for Scx in Scx_all0]) 139 | 140 | # animate 141 | x_all, Scx_all, adtau_max_all = [extend(h) for h in 142 | (x_all, Scx_all, adtau_max_all)] 143 | plot_frames, imshow_frames = x_all, Scx_all 144 | 145 | ani = PlotImshowAnimation(plot_frames, imshow_frames, t, freqs, 146 | adtau_max_all) 147 | ani.save(f'warp_cwt_{name}.mp4', fps=10) 148 | plt.show() 149 | 150 | #%% echirp 151 | _t = (t - tau_all) 152 | a0 = np.cos(2*np.pi * 3.7 * _t) 153 | x_all = np.cos(_echirp_fn(fmin=40, fmax=N/8)(_t)) * a0 154 | #%% 155 | run(x_all, adtau_max_all, ts, "echirp") 156 | #%% visualize warps for sine case with lower freq 157 | def _tt(title): 158 | return (title, {'fontsize': 20}) 159 | 160 | fig, axes = plt.subplots(1, 2, figsize=(18, 7)) 161 | x_all_viz = x_all#np.cos(2*np.pi * 64 * (t - tau_all)) 162 | imshow(x_all_viz, title=_tt("x(t - tau(t)) for all tau"), show=0, ax=axes[0], 163 | xlabel="time", ylabel="max(|tau'|)", yticks=adtau_max_all) 164 | imshow(tau_all, title=_tt("all tau(t)"), show=0, ax=axes[1], 165 | xlabel="time", yticks=0) 166
| 167 | for ax in axes: 168 | ax.set_xticks(np.linspace(0, len(t), 6)) 169 | ax.set_xticklabels([0, .2, .4, .6, .8, 1]) 170 | 171 | plt.subplots_adjust(wspace=.03) 172 | fig.savefig("sine_warp.png", bbox_inches='tight') 173 | plt.show() 174 | plt.close(fig) 175 | 176 | #%% sine 177 | x_all = np.cos(2*np.pi * f * (t - tau_all)) 178 | run(x_all, adtau_max_all, ts, "sine") 179 | -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_cwt_echirp.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_cwt_echirp.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_cwt_sine.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_cwt_sine.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_T.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_T.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_T_anim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Scattering unwarped vs warped echirp GIF, sweeping 
`T`.""" 3 | # https://dsp.stackexchange.com/q/78512/50076 ################################ 4 | # https://dsp.stackexchange.com/q/78514/50076 ################################ 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import matplotlib.animation as animation 8 | from kymatio.torch import Scattering1D 9 | from kymatio.visuals import plot 10 | from kymatio.toolkit import _echirp_fn, l2 11 | 12 | def tau(t, K=.08): 13 | return np.cos(2*np.pi * t**2) * K 14 | 15 | def adtau(t, K=.08): 16 | return np.abs(np.sin(2*np.pi * t**2) * t * 4*np.pi*K) 17 | 18 | import torch 19 | USE_GPU = bool('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | #%%########################################################################### 22 | N = 2048 23 | J, Q = 8, 8 24 | T_all = np.arange(1, N/2 + 1, 8) 25 | T_all = np.logspace(np.log2(1), np.log2(N - 1), N//8, base=2, endpoint=True) 26 | 27 | kw = dict(shape=N, J=J, Q=Q, average=True, oversampling=999, out_type='array', 28 | max_order=1, max_pad_factor=None, r_psi=.82) 29 | ts_all = [Scattering1D(**kw, T=T) for T in T_all] 30 | meta_all = [ts.meta() for ts in ts_all] 31 | 32 | if USE_GPU: 33 | for ts in ts_all: 34 | ts.cuda() 35 | 36 | #%%########################################################################### 37 | meta = meta_all[0] 38 | freqs = N * meta['xi'][meta['order'] == 1][:, 0] 39 | #%% 40 | t = np.linspace(0, 1, N, 0) 41 | _tau = tau(t, K=.012) 42 | a0 = np.cos(2*np.pi * 3.7 * (t - _tau)) 43 | x0 = np.cos(_echirp_fn(fmin=40, fmax=N/8)(t)) * a0 44 | x1 = np.cos(_echirp_fn(fmin=40, fmax=N/8)(t - _tau)) * a0 45 | #%% 46 | Scx_all0 = np.array([ts(x0)[1:].cpu().numpy() for ts in ts_all]) 47 | Scx_all1 = np.array([ts(x1)[1:].cpu().numpy() for ts in ts_all]) 48 | #%% 49 | dists = np.array([l2(S0, S1) for S0, S1 in zip(Scx_all0, Scx_all1)]).squeeze() 50 | plot(T_all, dists, ylims=(0, None), show=1) 51 | 52 | #%% extend frames 53 | def extend(x): 54 | return np.array(6*[x[0]] + list(x) + 6*[x[-1]]) 55 | 56 | 
Scx_all0, Scx_all1, dists, T_all = [extend(h) for h in 57 | (Scx_all0, Scx_all1, dists, T_all)] 58 | 59 | #%% Animate ################################################################# 60 | class PlotImshowAnimation(animation.TimedAnimation): 61 | def __init__(self, imshow_frames0, imshow_frames1, plot_frame, T_all, vmax): 62 | self.imshow_frames0 = imshow_frames0 63 | self.imshow_frames1 = imshow_frames1 64 | self.plot_frame = plot_frame 65 | self.T_all = T_all 66 | self.vmax = vmax 67 | 68 | self.N = imshow_frames0[0].shape[-1] 69 | self.T_all_sec_log = np.log2(self.T_all / N) 70 | self.n_frames = len(imshow_frames0) 71 | 72 | self.label_kw = dict(weight='bold', fontsize=15, labelpad=3) 73 | self.txt_kw = dict(x=0, y=1.015, s="", ha="left") 74 | 75 | fig = plt.figure(figsize=(18/1.5, 15/1.8)) 76 | ax0 = fig.add_subplot(2, 2, 1) 77 | ax1 = fig.add_subplot(2, 2, 2) 78 | ax2 = fig.add_subplot(3, 2, (3, 4)) 79 | axes = [ax0, ax1, ax2] 80 | 81 | # imshow1 ############################################################ 82 | ax = axes[0] 83 | imshow_kw = dict(cmap='turbo', aspect='auto', animated=True, 84 | vmin=0, vmax=self.vmax) 85 | im = ax.imshow(self.imshow_frames0[0], **imshow_kw) 86 | self.ims0 = [im] 87 | 88 | self.txt0 = ax.text(transform=ax.transAxes, **self.txt_kw, fontsize=18) 89 | ax.set_xticks([]); ax.set_yticks([]) 90 | 91 | # imshow2 ############################################################ 92 | ax = axes[1] 93 | im = ax.imshow(self.imshow_frames1[0], **imshow_kw) 94 | self.ims1 = [im] 95 | 96 | self.txt1 = ax.text(transform=ax.transAxes, **self.txt_kw, fontsize=18) 97 | ax.set_xticks([]); ax.set_yticks([]) 98 | 99 | # plot ############################################################### 100 | ax = axes[2] 101 | line = ax.plot(self.T_all_sec_log, self.plot_frame)[0] 102 | line.set_data(self.T_all_sec_log[0], self.plot_frame[0]) 103 | self.lines0 = [line] 104 | 105 | ax.set_ylim(0, np.max(self.plot_frame)*1.04) 106 | ax.set_xlabel('log2(T) [sec]', 
**self.label_kw) 107 | ax.set_ylabel('reldist', **self.label_kw) 108 | 109 | # finalize ####################################################### 110 | fig.subplots_adjust(left=.05, right=.99, bottom=-.49, top=.96, 111 | wspace=.02, hspace=.7) 112 | animation.TimedAnimation.__init__(self, fig, interval=50, blit=True) 113 | 114 | def _draw_frame(self, frame_idx): 115 | def _txt(txt_idx, T_sec): 116 | fill = (r"x(t)" if txt_idx == 0 else 117 | r"x(t - \tau(t)") 118 | if frame_idx == 0: 119 | txt = r"$S_1(%s)$ | unaveraged" % fill 120 | else: 121 | txt = r"$S_1(%s)\ |\ T=%.3f$ sec " % (fill, T_sec) 122 | return txt 123 | 124 | T_sec = self.T_all[frame_idx] / N 125 | # imshow0 ############################################################ 126 | self.ims0[0].set_array(self.imshow_frames0[frame_idx]) 127 | self.txt0.set_text(_txt(0, T_sec)) 128 | 129 | # imshow1 ############################################################ 130 | self.ims1[0].set_array(self.imshow_frames1[frame_idx]) 131 | self.txt1.set_text(_txt(1, T_sec)) 132 | 133 | # plot ############################################################### 134 | self.lines0[0].set_data(self.T_all_sec_log[:frame_idx + 1], 135 | self.plot_frame[:frame_idx + 1]) 136 | 137 | # finalize ########################################################### 138 | self._drawn_artists = [*self.ims0, self.txt0, *self.ims1, self.txt1, 139 | *self.lines0] 140 | 141 | def new_frame_seq(self): 142 | return iter(range(self.n_frames)) 143 | 144 | def _init_draw(self): 145 | pass 146 | 147 | imshow_frames0, imshow_frames1 = Scx_all0, Scx_all1 148 | plot_frame = dists 149 | vmax = max(Scx_all0.max(), Scx_all1.max())*.98 150 | 151 | ani = PlotImshowAnimation(imshow_frames0, imshow_frames1, plot_frame, T_all, 152 | vmax) 153 | ani.save('warp_scat_T.mp4', fps=15, savefig_kwargs=dict(pad_inches=0)) 154 | plt.show() 155 | -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering 
explanation/warp_scat_anim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Scattering unwarped vs warped echirp GIF.""" 3 | # https://dsp.stackexchange.com/q/78512/50076 ################################ 4 | # https://dsp.stackexchange.com/q/78514/50076 ################################ 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import matplotlib.animation as animation 8 | from kymatio.torch import Scattering1D 9 | from kymatio.visuals import plot 10 | from kymatio.toolkit import _echirp_fn, l2 11 | 12 | def tau(t, K=.08): 13 | return np.cos(2*np.pi * t**2) * K 14 | 15 | def adtau(t, K=.08): 16 | return np.abs(np.sin(2*np.pi * t**2) * t * 4*np.pi*K) 17 | 18 | import torch 19 | USE_GPU = bool('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | #%%########################################################################### 22 | N, f = 2048, 64 23 | 24 | #%%########################################################################### 25 | J, Q = 8, 8 26 | T = 512 27 | T_all = np.arange(1, N/2 + 1, 8) 28 | T_all = np.logspace(np.log2(1), np.log2(N - 1), N//8, base=2, endpoint=True) 29 | 30 | kw = dict(shape=N, J=J, Q=Q, average=True, oversampling=999, out_type='array', 31 | max_order=1, max_pad_factor=None) 32 | ts_all = [Scattering1D(**kw, T=T) for T in T_all] 33 | meta_all = [ts.meta() for ts in ts_all] 34 | 35 | if USE_GPU: 36 | for ts in ts_all: 37 | ts.cuda() 38 | 39 | #%%########################################################################### 40 | meta = meta_all[0] 41 | freqs = N * meta['xi'][meta['order'] == 1][:, 0] 42 | #%% 43 | t = np.linspace(0, 1, N, 0) 44 | _tau = tau(t, K=.012) 45 | x0 = np.cos(_echirp_fn(fmin=32, fmax=N/8)(t)) 46 | x1 = np.cos(_echirp_fn(fmin=32, fmax=N/8)(t - _tau)) 47 | #%% 48 | Scx_all0 = np.array([ts(x0)[1:].cpu().numpy() for ts in ts_all]) 49 | Scx_all1 = np.array([ts(x1)[1:].cpu().numpy() for ts in ts_all]) 50 | #%% 51 | dists = 
np.array([l2(S0, S1) for S0, S1 in zip(Scx_all0, Scx_all1)]).squeeze() 52 | plot(T_all, dists, ylims=(0, None), show=1) 53 | 54 | #%% extend 55 | def extend(x): 56 | return np.array(6*[x[0]] + list(x) + 6*[x[-1]]) 57 | 58 | Scx_all0, Scx_all1, dists, T_all = [extend(h) for h in 59 | (Scx_all0, Scx_all1, dists, T_all)] 60 | 61 | #%% Animate ################################################################# 62 | class PlotImshowAnimation(animation.TimedAnimation): 63 | def __init__(self, imshow_frames0, imshow_frames1, plot_frame, T_all, vmax): 64 | self.imshow_frames0 = imshow_frames0 65 | self.imshow_frames1 = imshow_frames1 66 | self.plot_frame = plot_frame 67 | self.T_all = T_all 68 | self.vmax = vmax 69 | 70 | self.N = imshow_frames0[0].shape[-1] 71 | self.T_all_sec_log = np.log2(self.T_all / N) 72 | self.n_frames = len(imshow_frames0) 73 | 74 | self.label_kw = dict(weight='bold', fontsize=15, labelpad=3) 75 | self.txt_kw = dict(x=0, y=1.015, s="", ha="left") 76 | 77 | fig = plt.figure(figsize=(18/1.5, 15/1.8)) 78 | ax0 = fig.add_subplot(2, 2, 1) 79 | ax1 = fig.add_subplot(2, 2, 2) 80 | ax2 = fig.add_subplot(3, 2, (3, 4)) 81 | axes = [ax0, ax1, ax2] 82 | 83 | # imshow1 ############################################################ 84 | ax = axes[0] 85 | imshow_kw = dict(cmap='turbo', aspect='auto', animated=True, 86 | vmin=0, vmax=self.vmax) 87 | im = ax.imshow(self.imshow_frames0[0], **imshow_kw) 88 | self.ims0 = [im] 89 | 90 | self.txt0 = ax.text(transform=ax.transAxes, **self.txt_kw, fontsize=18) 91 | ax.set_xticks([]); ax.set_yticks([]) 92 | 93 | # imshow2 ############################################################ 94 | ax = axes[1] 95 | im = ax.imshow(self.imshow_frames1[0], **imshow_kw) 96 | self.ims1 = [im] 97 | 98 | self.txt1 = ax.text(transform=ax.transAxes, **self.txt_kw, fontsize=18) 99 | ax.set_xticks([]); ax.set_yticks([]) 100 | 101 | # plot ############################################################### 102 | ax = axes[2] 103 | line = 
ax.plot(self.T_all_sec_log, self.plot_frame)[0] 104 | line.set_data(self.T_all_sec_log[0], self.plot_frame[0]) 105 | self.lines0 = [line] 106 | 107 | ax.set_ylim(0, np.max(self.plot_frame)*1.04) 108 | ax.set_xlabel('log2(T) [sec]', **self.label_kw) 109 | ax.set_ylabel('reldist', **self.label_kw) 110 | 111 | # finalize ####################################################### 112 | fig.subplots_adjust(left=.05, right=.99, bottom=-.49, top=.96, 113 | wspace=.02, hspace=.7) 114 | animation.TimedAnimation.__init__(self, fig, interval=50, blit=True) 115 | 116 | def _draw_frame(self, frame_idx): 117 | def _txt(txt_idx, T_sec): 118 | fill = (r"x(t)" if txt_idx == 0 else 119 | r"x(t - \tau(t)") 120 | if frame_idx == 0: 121 | txt = r"$S_1(%s)$ | unaveraged" % fill 122 | else: 123 | txt = r"$S_1(%s)\ |\ T=%.3f$ sec " % (fill, T_sec) 124 | return txt 125 | 126 | T_sec = self.T_all[frame_idx] / N 127 | # imshow0 ############################################################ 128 | self.ims0[0].set_array(self.imshow_frames0[frame_idx]) 129 | self.txt0.set_text(_txt(0, T_sec)) 130 | 131 | # imshow1 ############################################################ 132 | self.ims1[0].set_array(self.imshow_frames1[frame_idx]) 133 | self.txt1.set_text(_txt(1, T_sec)) 134 | 135 | # plot ############################################################### 136 | self.lines0[0].set_data(self.T_all_sec_log[:frame_idx], 137 | self.plot_frame[:frame_idx]) 138 | 139 | # finalize ########################################################### 140 | self._drawn_artists = [*self.ims0, self.txt0, *self.ims1, self.txt1, 141 | *self.lines0] 142 | 143 | def new_frame_seq(self): 144 | return iter(range(self.n_frames)) 145 | 146 | def _init_draw(self): 147 | pass 148 | 149 | imshow_frames0, imshow_frames1 = Scx_all0, Scx_all1 150 | plot_frame = dists 151 | vmax = max(Scx_all0.max(), Scx_all1.max())*.98 152 | 153 | ani = PlotImshowAnimation(imshow_frames0, imshow_frames1, plot_frame, T_all, 154 | vmax) 155 | 
ani.save('warp_scat.mp4', fps=15, savefig_kwargs=dict(pad_inches=0)) 156 | plt.show() 157 | -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_tau.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_tau.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_tau_echirp.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_tau_echirp.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_tau_pulse.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78512 - Wavelet Scattering explanation/warp_scat_tau_pulse.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/data_3d_anim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Visualize JTFS of real data: as GIF of 3D slices unfolded over time.""" 3 | # https://dsp.stackexchange.com/q/78622/50076 ################################ 4 | import numpy as np 5 | import torch 6 | import librosa 7 | from timeit import default_timer as dtime 8 | 9 | from kymatio import 
TimeFrequencyScattering1D 10 | from kymatio.toolkit import pack_coeffs_jtfs, jtfs_to_numpy, normalize 11 | from kymatio.visuals import gif_jtfs_3d, make_gif 12 | 13 | #%% load data ################################################################ 14 | x, sr = librosa.load(librosa.ex('trumpet')) 15 | x = x[:81920] 16 | 17 | #%% create scattering object, move to GPU #################################### 18 | N = len(x) 19 | J = int(np.log2(N) - 3) 20 | Q = (16, 1) 21 | Q_fr = 2 22 | J_fr = 4 23 | T = 2**(J - 4) 24 | F = 4 25 | 26 | jtfs = TimeFrequencyScattering1D(shape=N, J=J, J_fr=J_fr, Q=Q, Q_fr=Q_fr, 27 | T=T, F=F, average_fr=True, 28 | max_pad_factor=None, max_pad_factor_fr=None, 29 | out_type='dict:array', frontend='torch') 30 | jmeta = jtfs.meta() 31 | for a in ('N_frs', 'J_pad_frs'): 32 | print(getattr(jtfs, a), '--', a) 33 | jtfs.cuda() 34 | torch.cuda.empty_cache() 35 | 36 | #%% scatter ################################################################## 37 | t0 = dtime() 38 | Scxt = jtfs(x) 39 | Scx = jtfs_to_numpy(Scxt) 40 | 41 | #%% pack 42 | packed = pack_coeffs_jtfs(Scx, jmeta, structure=2, separate_lowpass=True, 43 | sampling_psi_fr=jtfs.sampling_psi_fr) 44 | packed_spinned = packed[0] 45 | packed_spinned = packed_spinned.transpose(-1, 0, 1, 2) 46 | 47 | #%% normalize 48 | s = packed_spinned.shape 49 | packed_spinned = normalize(packed_spinned.reshape(1, s[0], -1)).reshape(*s) 50 | 51 | #%% make smooth camera transition ############################################ 52 | packed_viz = packed_spinned 53 | print(packed_viz.shape) 54 | n_pts = len(packed_viz) 55 | extend_edge = int(.35 * n_pts) 56 | 57 | def gauss(n_pts, mn, mx, width=20): 58 | t = np.linspace(0, 1, n_pts) 59 | g = np.exp(-(t - .5)**2 * width) 60 | g *= (mx - mn) 61 | g += mn 62 | return g 63 | 64 | x = np.logspace(np.log10(2.5), np.log10(8.5), n_pts, endpoint=1) 65 | y = np.logspace(np.log10(0.3), np.log10(6.3), n_pts, endpoint=1) 66 | z = np.logspace(np.log10(2.0), np.log10(2.0), n_pts, 
endpoint=1) 67 | 68 | x, y, z = [gauss(n_pts, mn, mx) for (mn, mx) 69 | in [(2.5, 8.5), (0.3, 6.3), (2, 2)]] 70 | 71 | eyes = np.vstack([x, y, z]).T 72 | assert len(x) == len(packed_viz), (len(x), len(packed_viz)) 73 | 74 | #%% Make gif ################################################################ 75 | t0 = dtime() 76 | gif_jtfs_3d(packed_viz, base_name='jtfs3d_trumpet', images_ext='.png', 77 | overwrite=1, save_images=1, angles=eyes, gif_kw=dict(duration=50)) 78 | print(dtime() - t0) 79 | 80 | #%% remake if needed 81 | # make_gif('', 'etc.gif', duration=33, delimiter='seiz3d', ext='.jpg', 82 | # overwrite=1, HD=True) 83 | -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/echirp_2d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Visualize JTFS of exponential chirp: as GIF of coefficients 4 | superimposed with wavelets. 
5 | """ 6 | # https://dsp.stackexchange.com/q/78622/50076 ################################ 7 | import numpy as np 8 | from kymatio.numpy import TimeFrequencyScattering1D 9 | from kymatio.toolkit import echirp, pack_coeffs_jtfs 10 | from kymatio.visuals import make_gif 11 | from numpy.fft import ifft, ifftshift 12 | import matplotlib.pyplot as plt 13 | 14 | #%% Generate echirp and create scattering object ############################# 15 | N = 4096 16 | # span low to Nyquist; assume duration of 1 second 17 | x = echirp(N, fmin=64, fmax=N/2) 18 | 19 | #%% Show joint wavelets on a smaller filterbank ############################## 20 | o = (0, 999)[1] 21 | rp = np.sqrt(.5) 22 | jtfs = TimeFrequencyScattering1D(shape=N, J=5, Q=(16, 1), J_fr=3, Q_fr=1, 23 | sampling_filters_fr='resample', 24 | average=0, average_fr=0, F=4, 25 | r_psi=(rp, .9*rp, rp), out_type='dict:list', 26 | oversampling=o, oversampling_fr=o) 27 | 28 | #%% scatter 29 | Scx_orig = jtfs(x) 30 | jmeta = jtfs.meta() 31 | 32 | #%% pack 33 | Scx = pack_coeffs_jtfs(Scx_orig, jmeta, structure=2) 34 | cmx = Scx.max() * .5 # color max 35 | 36 | #%% 37 | n_n2s = sum(p['j'] > 0 for p in jtfs.psi2_f) 38 | n_n1_frs = len(jtfs.psi1_f_fr_up) 39 | # drop spin up 40 | Scx = Scx[:, n_n1_frs:] 41 | # drop spin down 42 | # Scx = Scx[:, :n_n1_frs + 1] 43 | 44 | #%% 45 | psis_down = jtfs.psi1_f_fr_down 46 | psi2s = [p for p in jtfs.psi2_f if p['j'] > 0] 47 | # reverse ordering 48 | psi2s = psi2s[::-1] 49 | 50 | # reverse order of fr wavelets 51 | psis_down = psis_down[::-1] 52 | 53 | # for spin down 54 | c_psi_dn = Scx[1:, 1:] 55 | c_phi_t = Scx[0, 1:] 56 | c_phi_f = Scx[1:, 0] 57 | c_phis = Scx[0, 0] 58 | 59 | # for spin up 60 | # c_psi_dn = Scx[1:, :-1] 61 | # c_phi_t = Scx[0, :-1] 62 | # c_phi_f = Scx[1:, -1] 63 | # c_phis = Scx[0, -1] 64 | 65 | #%% Visualize ################################################################ 66 | def no_border(ax): 67 | ax.set_xticks([]); ax.set_yticks([]) 68 | for spine in ax.spines: 69 | 
ax.spines[spine].set_visible(False) 70 | 71 | def to_time(p_f): 72 | if isinstance(p_f, dict): 73 | p_f = p_f[0] 74 | if isinstance(p_f, list): 75 | p_f = p_f[0] 76 | return ifftshift(ifft(p_f.squeeze())) 77 | 78 | spin_up = False 79 | 80 | imshow_kw0 = dict(aspect='auto', cmap='bwr') 81 | imshow_kw1 = dict(aspect='auto', cmap='turbo') 82 | 83 | n_rows = len(psis_down) + 1 84 | n_cols = len(psi2s) + 1 85 | w = 13 86 | h = 13 * n_rows / n_cols 87 | 88 | fig0, axes0 = plt.subplots(n_rows, n_cols, figsize=(w, h)) 89 | fig1, axes1 = plt.subplots(n_rows, n_cols, figsize=(w, h)) 90 | 91 | # compute common params to zoom on wavelets based on largest wavelet 92 | pf_f = psis_down[0] 93 | pt_f = psi2s[0] 94 | # centers 95 | ct = len(pt_f[0]) // 2 96 | cf = len(pf_f[0]) // 2 97 | # supports 98 | st = int(pt_f['support'][0] / 1.5) 99 | sf = int(pf_f['support'][0] / 1.5) 100 | 101 | # coeff max 102 | cmx = max(c_phi_t.max(), c_phi_f.max(), c_psi_dn.max()) * .8 103 | 104 | # psi_t * psi_f_down 105 | for n2_idx, pt_f in enumerate(psi2s): 106 | for n1_fr_idx, pf_f in enumerate(psis_down): 107 | pt = to_time(pt_f) 108 | pf = to_time(pf_f) 109 | # trim to zoom on wavelet 110 | pt = pt[ct - st:ct + st + 1] 111 | pf = pf[cf - sf:cf + sf + 1] 112 | 113 | Psi = pf[:, None] * pt[None] 114 | 115 | a = n1_fr_idx if spin_up else n1_fr_idx + 1 116 | ax0 = axes0[a, n2_idx + 1] 117 | ax1 = axes1[a, n2_idx + 1] 118 | 119 | mx = np.abs(Psi).max() 120 | ax0.imshow(Psi.real, **imshow_kw0, vmin=-mx, vmax=mx) 121 | no_border(ax0) 122 | 123 | # coeffs 124 | c = c_psi_dn[n2_idx, n1_fr_idx] 125 | ax1.imshow(c, **imshow_kw1, vmin=0, vmax=cmx) 126 | no_border(ax1) 127 | 128 | 129 | # psi_t * phi_f 130 | phif = to_time(jtfs.phi_f_fr) 131 | phif = phif[cf - sf:cf + sf + 1] 132 | for n2_idx, pt_f in enumerate(psi2s): 133 | pt = to_time(pt_f) 134 | pt = pt[ct - st:ct + st + 1] 135 | Psi = phif[:, None] * pt[None] 136 | 137 | a = -1 if spin_up else 0 138 | ax0 = axes0[a, n2_idx + 1] 139 | ax1 = axes1[a, 
n2_idx + 1] 140 | 141 | mx = np.abs(Psi).max() 142 | ax0.imshow(Psi.real, **imshow_kw0, vmin=-mx, vmax=mx) 143 | no_border(ax0) 144 | 145 | # coeffs 146 | c = c_phi_f[n2_idx] 147 | ax1.imshow(c, **imshow_kw1, vmin=0, vmax=cmx) 148 | no_border(ax1) 149 | 150 | # phi_t * psi_f 151 | phit = to_time(jtfs.phi_f) 152 | phit = phit[ct - st:ct + st + 1] 153 | for n1_fr_idx, pf_f in enumerate(psis_down): 154 | pf = to_time(pf_f) 155 | pf = pf[cf - sf:cf + sf + 1] 156 | 157 | Psi = pf[:, None] * phit[None] 158 | 159 | a = n1_fr_idx if spin_up else n1_fr_idx + 1 160 | ax0 = axes0[a, 0] 161 | ax1 = axes1[a, 0] 162 | 163 | mx = np.abs(Psi).max() 164 | ax0.imshow(Psi.real, **imshow_kw0, vmin=-mx, vmax=mx) 165 | no_border(ax0) 166 | 167 | # coeffs 168 | c = c_phi_t[n1_fr_idx] 169 | ax1.imshow(c, **imshow_kw1, vmin=0, vmax=cmx) 170 | no_border(ax1) 171 | 172 | # phi_t * phi_f 173 | a = -1 if spin_up else 0 174 | ax0 = axes0[a, 0] 175 | ax1 = axes1[a, 0] 176 | 177 | Psi = phif[:, None] * phit[None] 178 | mx = np.abs(Psi).max() 179 | ax0.imshow(Psi.real, **imshow_kw0, vmin=-mx, vmax=mx) 180 | no_border(ax0) 181 | 182 | # coeffs 183 | c = c_phis 184 | ax1.imshow(c, **imshow_kw1, vmin=0, vmax=cmx) 185 | no_border(ax1) 186 | 187 | fig0.subplots_adjust(wspace=.02, hspace=.02) 188 | fig1.subplots_adjust(wspace=.02, hspace=.02) 189 | 190 | base_name = 'jtfs_echirp_wavelets' 191 | fig0.savefig(f'{base_name}0.png', bbox_inches='tight') 192 | fig1.savefig(f'{base_name}1.png', bbox_inches='tight') 193 | make_gif('', f'{base_name}.gif', duration=2000, start_end_pause=0, 194 | delimiter=base_name, overwrite=1) 195 | -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/echirp_2d_anim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Visualize JTFS of exponential chirp: as GIF of joint slices (2D).""" 3 | # 
https://dsp.stackexchange.com/q/78622/50076 ################################ 4 | import numpy as np 5 | from kymatio.numpy import TimeFrequencyScattering1D 6 | from kymatio.toolkit import echirp, energy 7 | from kymatio.visuals import plot, imshow 8 | from kymatio import visuals 9 | 10 | #%% Generate echirp and create scattering object ############################# 11 | N = 4097 12 | # span low to Nyquist; assume duration of 1 second. NOTE: the echirp is overwritten by a pure sine on the next line 13 | x = echirp(N, fmin=64, fmax=N/2) 14 | x = np.cos(2*np.pi * (N//16) * np.linspace(0, 1, N, 1)) 15 | 16 | # 9 temporal octaves 17 | # largest scale is 2**9 [samples] / 4096 [samples / sec] == 125 ms 18 | J = 9 19 | # 16 bandpass wavelets per octave 20 | # J*Q ~= 144 total temporal coefficients in first-order scattering 21 | Q = 16 22 | # scale of temporal invariance, 31.25 ms 23 | T = 2**7 24 | # 4 frequential octaves 25 | J_fr = 4 26 | # 2 bandpass wavelets per octave 27 | Q_fr = 2 28 | # scale of frequential invariance, F/Q == 0.5 cycle per octave 29 | F = 8 30 | # do frequential averaging to enable 4D concatenation 31 | average_fr = True 32 | # frequential padding; 'zero' avoids few discretization artefacts for this example 33 | pad_mode_fr = 'zero' 34 | # return packed as dict keyed by pair names for easy inspection 35 | out_type = 'dict:array' 36 | 37 | params = dict(J=J, Q=Q, T=T, J_fr=J_fr, Q_fr=Q_fr, F=F, average_fr=average_fr, 38 | out_type=out_type, pad_mode_fr=pad_mode_fr, max_pad_factor=4, 39 | max_pad_factor_fr=4, oversampling=999, oversampling_fr=999) 40 | jtfs = TimeFrequencyScattering1D(shape=N, **params) 41 | 42 | #%% Take JTFS, print pair names and shapes ################################### 43 | Scx = jtfs(x) 44 | print("JTFS pairs:") 45 | for pair in Scx: 46 | print("{:<12} -- {}".format(str(Scx[pair].shape), pair)) 47 | 48 | E_up = energy(Scx['psi_t * psi_f_up']) 49 | E_dn = energy(Scx['psi_t * psi_f_down']) 50 | print("E_down / E_up = {:.1f}".format(E_dn / E_up)) 51 | #%% Show `x` and its (time-averaged)
scalogram ############################### 52 | plot(x, show=1, w=.7, 53 | xlabel="time [samples]", 54 | title="Exponential chirp | fmin=64, fmax=2048, 4096 samples") 55 | #%% 56 | freqs = jtfs.meta()['xi']['S1'][:, -1] 57 | imshow(Scx['S1'].squeeze(), abs=1, w=.8, h=.67, yticks=freqs, 58 | xlabel="time [samples]", 59 | ylabel="frequency [frac of fs]", 60 | title="Scalogram, time-averaged (first-order scattering)") 61 | 62 | #%% Create & save GIF ######################################################## 63 | # fetch meta (structural info) 64 | jmeta = jtfs.meta() 65 | # specify save folder 66 | savedir = '' 67 | # time between GIF frames (ms) 68 | duration = 150 69 | visuals.gif_jtfs_2d(Scx, jmeta, savedir=savedir, base_name='jtfs_sine', 70 | save_images=0, overwrite=1, gif_kw={'duration': duration}) 71 | # Notice how -1 spin coefficients contain nearly all the energy 72 | # and +1 barely any; this is FDTS discriminability. 73 | # For ideal FDTS and JTFS, +1 will be all zeros. 74 | # For down chirp, the case is reversed and +1 hoards the energy. 
75 | -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/echirp_3d_anim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Visualize JTFS of exponential chirp: as GIF of 3D slices unfolded over time.""" 3 | # https://dsp.stackexchange.com/q/78622/50076 ################################ 4 | from kymatio.numpy import TimeFrequencyScattering1D 5 | from kymatio.toolkit import echirp, pack_coeffs_jtfs 6 | from kymatio.visuals import gif_jtfs_3d 7 | 8 | #%% Make scattering object ################################################### 9 | N = 4096 10 | jtfs = TimeFrequencyScattering1D(shape=N, J=8, Q=16, T=2**8, F=4, J_fr=4, Q_fr=2, 11 | max_pad_factor=None, max_pad_factor_fr=None, 12 | oversampling=1, out_type='dict:array', 13 | pad_mode_fr='zero', pad_mode='zero', 14 | average_fr=True, sampling_filters_fr='resample') 15 | meta = jtfs.meta() 16 | 17 | #%% Make echirp & scatter #################################################### 18 | x = echirp(N, fmin=64, fmax=N/2) 19 | Scx = jtfs(x) 20 | 21 | #%% Make GIF with and without spinned ######################################## 22 | for separate_lowpass in (False, True): 23 | packed = pack_coeffs_jtfs(Scx, meta, structure=2, separate_lowpass=True, 24 | sampling_psi_fr=jtfs.sampling_psi_fr) 25 | if separate_lowpass: 26 | packed = packed[0] 27 | packed = packed.transpose(-1, 0, 1, 2) # time first 28 | 29 | nm = 'spinned' if separate_lowpass else 'full' 30 | base_name = f'jtfs3d_echirp_{nm}' 31 | gif_jtfs_3d(packed, base_name=base_name, overwrite=1, cmap_norm=.5) 32 | -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/fdts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """JTFS of Frequency-Dependent 
Time Shifts.""" 3 | # https://dsp.stackexchange.com/q/78622/50076 ################################ 4 | import numpy as np 5 | import torch 6 | from kymatio.torch import Scattering1D, TimeFrequencyScattering1D 7 | from kymatio.toolkit import fdts, l2 8 | from kymatio.visuals import make_gif 9 | import matplotlib.pyplot as plt 10 | 11 | #%% Generate echirp and create scattering object ############################# 12 | N = 4096 13 | f0 = N // 24 14 | n_partials = 5 15 | total_shift = N//12 16 | seg_len = N//6 17 | 18 | x, xs = fdts(N, n_partials, total_shift, f0, seg_len, partials_f_sep=1.7) 19 | 20 | #%% Make scattering objects ################################################## 21 | J = int(np.log2(N)) # max scale / global averaging 22 | Q = (8, 2) 23 | kw = dict(J=J, Q=Q, shape=N, max_pad_factor=4, pad_mode='zero') 24 | cwt = Scattering1D(**kw, average=False, out_type='list', oversampling=999, 25 | max_order=1) 26 | ts = Scattering1D(out_type='array', max_order=2, **kw) 27 | jtfs = TimeFrequencyScattering1D(Q_fr=2, J_fr=5, out_type='array', **kw, 28 | out_exclude=('S1', 'phi_t * psi_f', 29 | 'phi_t * phi_f', 'psi_t * phi_f'), 30 | max_pad_factor_fr=4, pad_mode_fr='zero', 31 | sampling_filters_fr=('resample', 'resample')) 32 | cwt, ts, jtfs = [s.cuda() for s in (cwt, ts, jtfs)] 33 | 34 | #%% Scatter ################################################################## 35 | cwt_x = torch.vstack([c['coef'] for c in cwt(x)]).cpu().numpy()[1:] 36 | cwt_xs = torch.vstack([c['coef'] for c in cwt(xs)]).cpu().numpy()[1:] 37 | 38 | ts_x = ts(x).cpu().numpy() 39 | ts_xs = ts(xs).cpu().numpy() 40 | 41 | jtfs_x = jtfs(x).cpu().numpy()[0] 42 | jtfs_xs = jtfs(xs).cpu().numpy()[0] 43 | 44 | l2_ts = float(l2(ts_x, ts_xs)) 45 | l2_jtfs = float(l2(jtfs_x, jtfs_xs)) 46 | 47 | print(("\nFDTS sensitivity:\n" 48 | "JTFS/TS = {:.1f}\n" 49 | "TS = {:.4f}\n" 50 | "JTFS = {:.4f}\n").format(l2_jtfs / l2_ts, l2_ts, l2_jtfs)) 51 | 52 | #%%# Make GIF 
################################################################ 53 | data = {'x': [x, xs], 54 | 'cwt': [cwt_x, cwt_xs], 55 | 'ts': [ts_x, ts_xs], 56 | 'jtfs': [jtfs_x, jtfs_xs]} 57 | xmx = max(abs(x).max(), abs(xs).max())*1.05 58 | cmx = max(np.abs(cwt_x).max(), np.abs(cwt_xs).max()) 59 | tmx = max(np.abs(ts_x).max(), np.abs(ts_xs).max()) 60 | jmx = max(np.abs(jtfs_x).max(), np.abs(jtfs_xs).max()) * .95 61 | 62 | for i in (0, 1): 63 | _x, _cwt, _ts, _jtfs = [data[k][i] for k in data] 64 | 65 | # one figure per frame; it is closed after saving so none are leaked 66 | fig = plt.figure(figsize=(14, 14)) 67 | ax0 = fig.add_subplot(12, 2, (1, 9)) 68 | ax1 = fig.add_subplot(12, 2, (2, 10)) 69 | ax2 = fig.add_subplot(14, 2, (13, 14)) 70 | ax3 = fig.add_subplot(14, 2, (15, 16)) 71 | axes = [ax0, ax1, ax2, ax3] 72 | 73 | imshow_kw = dict(cmap='turbo', aspect='auto', vmin=0) 74 | s, e = 1700, len(x) - 1200 75 | tks = np.arange(s, e) 76 | ax0.plot(tks, _x[s:e]) 77 | ax1.imshow(_cwt, **imshow_kw, vmax=cmx) 78 | ax2.imshow(_ts.T, **imshow_kw, vmax=tmx) 79 | ax3.imshow(_jtfs.T, **imshow_kw, vmax=jmx) 80 | 81 | # remove ticks 82 | for ax in (ax2, ax3): 83 | ax.set_xticks([]) 84 | ax.set_yticks([]) 85 | ax0.set_yticks([]) 86 | ax1.set_yticks([]) 87 | 88 | # set common limits 89 | ax0.set_ylim(-xmx, xmx) 90 | 91 | fig.subplots_adjust(hspace=.25, wspace=.02) 92 | 93 | # reduce y-spacing between bottom two subplots 94 | yspace = .05 95 | pos_ax3 = ax3.get_position() 96 | pos_ax2 = ax2.get_position() 97 | ywidth = pos_ax3.y1 - pos_ax3.y0 98 | pos_ax3.y1 = pos_ax2.y0 - ywidth * yspace 99 | pos_ax3.y0 = pos_ax2.y0 - ywidth * (1 + yspace) 100 | ax3.set_position(pos_ax3) 101 | 102 | fig.savefig(f'jtfs_ts{i}.png', bbox_inches='tight') 103 | plt.close(fig) 104 | 105 | make_gif('', 'jtfs_ts.gif', duration=1000, start_end_pause=0, 106 | delimiter='jtfs_ts', overwrite=1, verbose=1) 107 | -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering
explanation/filterbank.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/filterbank.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/filterbank.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Visualize joint JTFS filterbank (2D).""" 3 | # https://dsp.stackexchange.com/q/78622/50076 ################################ 4 | from kymatio.numpy import TimeFrequencyScattering1D 5 | from kymatio.visuals import make_gif 6 | from kymatio import visuals 7 | import matplotlib.pyplot as plt 8 | 9 | #%% Viz frequential filterbank (1D) ########################################## 10 | jtfs = TimeFrequencyScattering1D(shape=512, J=7, Q=(16, 1), J_fr=5, Q_fr=2, F=8, 11 | normalize='l1') 12 | visuals.filterbank_jtfs_1d(jtfs, zoom=-1, lp_sum=1) 13 | 14 | #%% Show joint wavelets on a smaller filterbank ############################## 15 | jtfs = TimeFrequencyScattering1D(shape=512, J=4, Q=(16, 1), J_fr=3, Q_fr=1, F=8) 16 | 17 | #%% 18 | # nearly invisible, omit (also unused in scattering per `j2 > 0`) 19 | jtfs.psi2_f.pop(0) 20 | # not nearly invisible but still makes plot too big 21 | jtfs.psi2_f.pop(0) 22 | jtfs.psi1_f_fr_up.pop(0) 23 | jtfs.psi1_f_fr_down.pop(0) 24 | 25 | #%% Make GIF ################################################################ 26 | base_name = 'filterbank' 27 | for i, part in enumerate(('real', 'imag', 'complex')): 28 | fig, axes = visuals.filterbank_jtfs_2d(jtfs, part=part, labels=1, 29 | suptitle_y=1.03) 30 | # fig.show() 31 | # break 32 | fig.savefig(f'{base_name}{i}.png', bbox_inches='tight') 33 | plt.close(fig) 34 | 35 | make_gif('', f'{base_name}.gif', duration=2000, 
start_end_pause=0, 36 | delimiter=base_name, overwrite=1) 37 | -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/freq_tp.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/freq_tp.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/freq_tp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Visualize log-frequency shift.""" 3 | # https://dsp.stackexchange.com/q/78622/50076 ################################ 4 | import numpy as np 5 | from kymatio.numpy import Scattering1D 6 | from kymatio.visuals import imshow, make_gif 7 | from kymatio.toolkit import fdts 8 | import matplotlib.pyplot as plt 9 | 10 | #%% Make CWT object ########################################################## 11 | N = 2048 12 | cwt = Scattering1D(shape=N, J=7, Q=8, average=False, out_type='list', 13 | r_psi=.85, oversampling=999, max_order=1) 14 | 15 | #%% Make signals & take CWT ################################################## 16 | x0 = fdts(N, n_partials=4, seg_len=N//4, f0=N/12)[0] 17 | x1 = fdts(N, n_partials=4, seg_len=N//4, f0=N/20)[0] 18 | 19 | Wx0 = np.array([c['coef'].squeeze() for c in cwt(x0)[1:]]) 20 | Wx1 = np.array([c['coef'].squeeze() for c in cwt(x1)[1:]]) 21 | 22 | #%% Make GIF ################################################################# 23 | fig0, ax0 = plt.subplots(1, 2, figsize=(12, 5)) 24 | fig1, ax1 = plt.subplots(1, 2, figsize=(12, 5)) 25 | # imshows 26 | kw = dict(abs=1, ticks=0, show=0) 27 | imshow(Wx0, ax=ax0[1], fig=fig0, **kw) 28 | imshow(Wx1, ax=ax1[1], fig=fig1, **kw) 29 | 30 | 
# plots 31 | s, e = N//3, -N//3 # zoom 32 | ax0[0].plot(x0[s:e]) 33 | ax1[0].plot(x1[s:e]) 34 | # ticks & ylims 35 | mx = max(np.abs(x0).max(), np.abs(x1).max()) * 1.03 36 | for ax in (ax0, ax1): 37 | ax[0].set_xticks([]) 38 | ax[0].set_yticks([]) 39 | ax[0].set_ylim(-mx, mx) 40 | # titles 41 | title_kw = dict(weight='bold', fontsize=20, loc='left') 42 | for ax in (ax0, ax1): 43 | ax[0].set_title("x (zoomed)", **title_kw) 44 | ax[1].set_title("|CWT(x)|", **title_kw) 45 | 46 | # finalize & save 47 | base_name = 'freq_tp' 48 | fig0.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=.02) 49 | fig1.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=.02) 50 | fig0.savefig(f'{base_name}0.png', bbox_inches='tight') 51 | fig1.savefig(f'{base_name}1.png', bbox_inches='tight') 52 | plt.close(fig0) 53 | plt.close(fig1) 54 | 55 | # make GIF 56 | make_gif('', f'{base_name}.gif', duration=1000, start_end_pause=0, 57 | delimiter=base_name, verbose=1, overwrite=1) 58 | -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_echirp_full.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_echirp_full.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_echirp_spinned.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_echirp_spinned.gif -------------------------------------------------------------------------------- 
/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_seiz.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_seiz.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_trumpet.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs3d_trumpet.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_echirp.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_echirp.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_echirp_wavelets.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_echirp_wavelets.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_sine.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_sine.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_ts.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/jtfs_ts.gif -------------------------------------------------------------------------------- /SignalProcessing/Q78644 - Joint Time-Frequency Scattering explanation/sine_2d_anim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Visualize JTFS of pure sine: as GIF of joint slices (2D).""" 3 | # https://dsp.stackexchange.com/q/78622/50076 ################################ 4 | import numpy as np 5 | from kymatio.numpy import TimeFrequencyScattering1D 6 | from kymatio.toolkit import energy 7 | from kymatio.visuals import plot, imshow 8 | from kymatio import visuals 9 | 10 | #%% Generate sine and create scattering object ############################### 11 | # `pow2 + 1` with `endpoint==1` to pad to perfect sine 12 | N = 4097 13 | x = np.cos(2*np.pi * (N//16) * np.linspace(0, 1, N, 1)) 14 | 15 | # 9 temporal octaves 16 | # largest scale is 2**9 [samples] / 4096 [samples / sec] == 125 ms 17 | J = 9 18 | # 16 bandpass wavelets per octave 19 | # J*Q ~= 144 total temporal coefficients in first-order scattering 20 | Q = 16 21 | # scale of temporal invariance, 31.25 ms 22 | T = 2**7 23 | # 4 frequential octaves 24 | J_fr = 4 25 | # 2 bandpass wavelets per octave 26 | Q_fr = 2 27 | # scale of frequential invariance, F/Q == 0.5 cycle per octave 28 | F = 8 29 | # do frequential averaging to
enable 4D concatenation 30 | average_fr = True 31 | # frequential padding; 'zero' avoids few discretization artefacts for this example 32 | pad_mode_fr = 'zero' 33 | # return packed as dict keyed by pair names for easy inspection 34 | out_type = 'dict:array' 35 | 36 | params = dict(J=J, Q=Q, T=T, J_fr=J_fr, Q_fr=Q_fr, F=F, average_fr=average_fr, 37 | out_type=out_type, pad_mode_fr=pad_mode_fr, max_pad_factor=4, 38 | max_pad_factor_fr=4, oversampling=999, oversampling_fr=999) 39 | jtfs = TimeFrequencyScattering1D(shape=N, **params) 40 | 41 | #%% Take JTFS, print pair names and shapes ################################### 42 | Scx = jtfs(x) 43 | print("JTFS pairs:") 44 | for pair in Scx: 45 | print("{:<12} -- {}".format(str(Scx[pair].shape), pair)) 46 | 47 | E_up = energy(Scx['psi_t * psi_f_up']) 48 | E_dn = energy(Scx['psi_t * psi_f_down']) 49 | print("E_down / E_up = {:.1f}".format(E_dn / E_up)) 50 | #%% Show `x` and its (time-averaged) scalogram ############################### 51 | plot(x, show=1, w=.7, 52 | xlabel="time [samples]", 53 | title="Pure sine | fmin=64, fmax=2048, 4096 samples") 54 | #%% 55 | freqs = jtfs.meta()['xi']['S1'][:, -1] 56 | imshow(Scx['S1'].squeeze(), abs=1, w=.8, h=.67, yticks=freqs, 57 | xlabel="time [samples]", 58 | ylabel="frequency [frac of fs]", 59 | title="Scalogram, time-averaged (first-order scattering)") 60 | 61 | #%% Create & save GIF ######################################################## 62 | # fetch meta (structural info) 63 | jmeta = jtfs.meta() 64 | # specify save folder 65 | savedir = '' 66 | # time between GIF frames (ms) 67 | duration = 150 68 | visuals.gif_jtfs_2d(Scx, jmeta, savedir=savedir, base_name='jtfs_sine', 69 | save_images=0, overwrite=1, gif_kw={'duration': duration}) 70 | -------------------------------------------------------------------------------- /SignalProcessing/Q80918 - STFT - Overlap, window length, uncertainty principle/main.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/a/80920/50076 ################################ 3 | import numpy as np 4 | from ssqueezepy import ssq_stft 5 | from ssqueezepy.visuals import imshow 6 | from scipy.signal.windows import dpss 7 | 8 | # signal 9 | N = 4096 10 | f = 32 11 | t = np.linspace(0, 1, N, 0)[::-1] 12 | fm = np.cos(2*np.pi * f * t) / (100*f) 13 | x = np.cos(2*np.pi * f*8 * (t + fm)) 14 | 15 | # window 16 | n_fft = 512 17 | window = dpss(n_fft, n_fft//2 - 1) 18 | 19 | # STFT & SSQ 20 | Tx0, Sx0, *_ = ssq_stft(x, window, n_fft=n_fft, flipud=0, hop_len=1) 21 | Tx1, Sx1, *_ = ssq_stft(x, window, n_fft=n_fft, flipud=0, hop_len=64) 22 | 23 | # visualize ################################################################## 24 | # cheat & drop boundary effects for clarity 25 | Sx1, Tx1 = Sx1[:, 1:], Tx1[:, 1:] 26 | kw = dict(abs=1, interpolation='none', w=.8, h=.6) 27 | imshow(Sx0, **kw, title=("|STFT| -- hop_len=1", {'fontsize': 18})) 28 | imshow(Tx0, **kw, title=("|SSQ_STFT| -- hop_len=1", {'fontsize': 18})) 29 | imshow(Sx1, **kw, title=("|STFT| -- hop_len=64", {'fontsize': 18})) 30 | imshow(Tx1, **kw, title=("|SSQ_STFT| -- hop_len=64", {'fontsize': 18})) 31 | -------------------------------------------------------------------------------- /SignalProcessing/Q81624 - STFT, optimization - Optimize window length for STFT/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Optimize window length for STFT, plain window demo.""" 3 | # https://dsp.stackexchange.com/q/81624/50076 ################################ 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import matplotlib.pyplot as plt 6 | 7 | def plot(x, title): 8 | plt.plot(x) 9 | plt.title(title, loc='left', weight='bold', fontsize=18) 10 | plt.gcf().set_size_inches(9, 7) 11 | plt.show(x) 12 | 13 | 14 | def hann(N): 15 | 
Nint = torch.floor(N).int() 16 | Nc = torch.clamp(N, min=Nint, max=Nint + 1) 17 | t = torch.arange(Nint) / Nc 18 | w = .5 * (1 - torch.cos(2*torch.pi *t)) 19 | return w 20 | 21 | 22 | N = nn.Parameter(torch.tensor(129.)) 23 | 24 | # optimal window length 25 | N_ref = 160 26 | w_ref = hann(torch.tensor(float(N_ref))) 27 | # add room for overshoot 28 | x = F.pad(w_ref, [40, 40]) 29 | 30 | LR = 100000 31 | Ns, grads, outs = [], [], [] 32 | for i in range(20): 33 | w = hann(N) 34 | w = w / torch.norm(w) # L2 norm to ensure `max(conv)` peaks at `N_ref` 35 | conv = torch.conv1d(x[None, None], w[None, None]) 36 | out = 1. / torch.max(conv) # inverse of peak cross-correlation to minimize 37 | out.backward() 38 | 39 | with torch.no_grad(): 40 | N -= LR * N.grad 41 | 42 | Ns.append(float(N.detach().numpy())) 43 | grads.append(float(N.grad.detach().numpy())) 44 | outs.append(float(out.detach().numpy())) 45 | 46 | 47 | # manual `optimizer.zero_grad()` 48 | N.grad.requires_grad_(False) 49 | N.grad.zero_() 50 | 51 | plot(Ns, "N vs iteration") 52 | plot(grads, "N.grad vs iteration") 53 | plot(outs, "loss vs iteration") 54 | -------------------------------------------------------------------------------- /SignalProcessing/Q85745 - Energy, power, relation to sampling rate and duration/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/85745/50076 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | # from ?.toolkit import fft_upsample # to be released soon 6 | 7 | def E(x): 8 | return np.sum(np.abs(x)**2) 9 | 10 | def viz(t1, t2, x1, x2, mode='bar', title=True): 11 | tkw = dict(fontsize=18, weight='bold', loc='left') 12 | fig, axes = plt.subplots(1, 2, sharey=True, figsize=(15, 6)) 13 | 14 | if mode == 'bar': 15 | Dt1 = t1[1] - t1[0] # sampling period 16 | Dt2 = t2[1] - t2[0] 17 | axes[0].bar(t1, x1, .9*Dt1) 18 | axes[1].bar(t2, x2, .9*Dt2) 19 | else: 20 | 
axes[0].plot(t1, x1) 21 | axes[1].plot(t2, x2) 22 | 23 | if title: 24 | if title is True: 25 | title1 = "sum(|x|^2)={:.3g}, N={}".format(E(x1), len(x1)) 26 | title2 = "sum(|x|^2)={:.3g}, N={}".format(E(x2), len(x2)) 27 | elif isinstance(title, tuple): 28 | title1, title2 = title 29 | axes[0].set_title(title1, **tkw) 30 | axes[1].set_title(title2, **tkw) 31 | 32 | fig.subplots_adjust(wspace=.03) 33 | plt.show() 34 | 35 | def viz_sines(N1, N2, T1, T2, f1=1, f2=1): 36 | t1 = np.linspace(0, T1, N1, endpoint=False) 37 | t2 = np.linspace(0, T2, N2, endpoint=False) 38 | x1 = np.cos(2*np.pi * f1 * t1) 39 | x2 = np.cos(2*np.pi * f2 * t2) 40 | 41 | viz(t1, t2, x1, x2, title=True) 42 | 43 | #%% 44 | N1, N2 = 10, 20 45 | for duration in (1, 1.25): 46 | viz_sines(N1, N2, duration, duration) 47 | 48 | #%% 49 | N = 20 50 | T1, T2 = 1, 2 51 | f1, f2 = 2, 1 52 | viz_sines(N, N, T1, T2, f1, f2) 53 | 54 | #%% 55 | # t = np.linspace(0, 1.25, 20, endpoint=False) 56 | # x = np.cos(2*np.pi * t) 57 | # x_up = fft_upsample(x, factor=128, time_to_time=True, real=True) 58 | # t_up = np.linspace(0, 1.25, len(x_up), endpoint=False) 59 | # x_up4 = np.hstack([x_up]*4) 60 | # t_up4 = np.linspace(0, 1.25*4, len(x_up)*4, endpoint=False) 61 | 62 | # viz(t_up, t_up4, x_up, x_up4, mode='plot', 63 | # title=("x_upsampled", "x_upsampled, 4 periods")) 64 | -------------------------------------------------------------------------------- /SignalProcessing/Q86181 - CWT Power and Energy/cwt_power_example.m: -------------------------------------------------------------------------------- 1 | % Answer to https://dsp.stackexchange.com/q/86181/50076 2 | % and to https://dsp.stackexchange.com/q/82985/50076 3 | %% Configure ################################################################ 4 | fs = 400; % (in Hz) anything works 5 | duration = 5; % (in sec) anything works 6 | padtype = 'reflection'; % anything (supported) works 7 | 8 | % get power between these frequencies 9 | freq_min = 50; % (in Hz) 10 | freq_max = 
150; % (in Hz) 11 | 12 | %% Obtain transform ######################################################### 13 | % make signal & filterbank 14 | % assume this is Amperes passing through 1 Ohm resistor; P = I^2*R, E = P * T 15 | % check actual physical units for your specific application and adjust accordingly 16 | rng('default') 17 | x = randn(1, fs * duration); 18 | fb = cwtfilterbank('Wavelet', 'amor', 'SignalLength', fs * duration, ... 19 | 'VoicesPerOctave', 11, 'SamplingFrequency', fs, ... 20 | 'Boundary', padtype); 21 | 22 | % transform, get frequencies 23 | [Wx, freqs] = fb.wt(x); 24 | 25 | % fetch coefficients according to `freq_min` and `freq_max` 26 | Wx_spec = Wx((freq_min < freqs) & (freqs < freq_max), :); 27 | 28 | %% "Adjustments" ############################################################ 29 | % See "Practical adjustments" in the answer 30 | 31 | % fetch wavelets in freq domain, compute ET & ES transfer funcs, fetch maxima 32 | psi_fs = fb.PsiDFT; % fetch wavelets in freq domain 33 | ET_tfn = sum(abs(psi_fs).^2, 1); 34 | ES_tfn = abs(sum(psi_fs, 1)).^2; 35 | % real-valued case adjustment: 36 | % - ET since we operate on half of spectrum (half as many coeffs) 37 | % - ES since `.real` halves spectrum, quartering energy on real side; 38 | % note, for this to work right, the filterbank must be halved at Nyquist 39 | % by design (which should also be done for sake of temporal decay) 40 | ET_adj = max(ET_tfn) / 2; 41 | ES_adj = max(ES_tfn) / 4; 42 | 43 | %% Compute energy & power ################################################### 44 | % compute energy & power (discrete) 45 | ET_disc = sum(abs(Wx_spec).^2, 'all') / ET_adj; 46 | ES_disc = sum(abs(real(sum(Wx_spec, 1))).^2, 'all') / ES_adj; 47 | PT_disc = ET_disc / length(x); 48 | PS_disc = ES_disc / length(x); 49 | 50 | % compute energy & power (physical); estimate underlying continuous waveform via 51 | % Riemann integration 52 | sampling_period = 1 / fs; 53 | ET_phys = ET_disc * sampling_period; 54 | ES_phys = ES_disc
* sampling_period; 55 | PT_phys = ET_phys / duration; 56 | PS_phys = ES_phys / duration; 57 | 58 | % repeat for original signal 59 | Ex_disc = sum(abs(x).^2, 'all'); 60 | Px_disc = Ex_disc / length(x); 61 | Ex_phys = Ex_disc * sampling_period; 62 | Px_phys = Ex_phys / duration; 63 | 64 | %% Report ################################################################### 65 | s = ['Between %d and %d Hz, DISCRETE:\n'... 66 | '%-9.6g -- energy (transform)\n'... 67 | '%-9.6g -- energy (signal)\n'... 68 | '%-9.6g -- mean power (transform)\n'... 69 | '%-9.6g -- mean power (signal)\n\n']; 70 | fprintf(s, freq_min, freq_max, ET_disc, ES_disc, PT_disc, PS_disc) 71 | 72 | 73 | s = ['Between %d and %d Hz, PHYSICAL (via Riemann integration):\n'... 74 | '%-9.6g Joules -- energy (transform)\n'... 75 | '%-9.6g Joules -- energy (signal)\n'... 76 | '%-9.6g Watts -- mean power (transform)\n'... 77 | '%-9.6g Watts -- mean power (signal)\n\n']; 78 | fprintf(s, freq_min, freq_max, ET_phys, ES_phys, PT_phys, PS_phys) 79 | 80 | s = ['Original signal:\n'... 81 | '%-9.6g -- energy (discrete)\n'... 82 | '%-9.6g -- mean power (discrete)\n'... 83 | '%-9.6g Joules -- energy (physical)\n'... 
84 | '%-9.6g Watts -- mean power (physical)\n\n']; 85 | fprintf(s, Ex_disc, Px_disc, Ex_phys, Px_phys) 86 | -------------------------------------------------------------------------------- /SignalProcessing/Q86181 - CWT Power and Energy/cwt_power_example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Answer to https://dsp.stackexchange.com/q/86181/50076 3 | import numpy as np 4 | from ssqueezepy import cwt, Wavelet 5 | from ssqueezepy.experimental import scale_to_freq 6 | 7 | def E(x): 8 | return np.sum(np.abs(x)**2) 9 | 10 | #%% Configure ################################################################ 11 | fs = 400 # (in Hz) anything works 12 | duration = 5 # (in sec) anything works 13 | padtype = 'reflect' # anything (supported) works 14 | 15 | # get power between these frequencies 16 | freq_min = 50 # (in Hz) 17 | freq_max = 150 # (in Hz) 18 | 19 | #%% Obtain transform ######################################################### 20 | # make signal & wavelet 21 | # assume this is Amperes passing through 1 Ohm resistor; P = I^2*R, E = P * T 22 | # check actual physical units for your specific application and adjust accordingly 23 | np.random.seed(0) 24 | x = np.random.randn(fs * duration) 25 | wavelet = Wavelet() 26 | 27 | # transform, get frequencies 28 | Wx, scales = cwt(x, wavelet, padtype=padtype) 29 | freqs = scale_to_freq(scales, wavelet, len(x), fs=fs, padtype=padtype) 30 | 31 | # fetch coefficients according to `freq_min` and `freq_max` 32 | Wx_spec = Wx[(freq_min < freqs) * (freqs < freq_max)] 33 | 34 | #%% Normalization ############################################################ 35 | # See "Normalization" in the answer 36 | 37 | # fetch wavelets in freq domain, compute ET & ES transfer funcs, fetch maxima 38 | psi_fs = wavelet._Psih # fetch wavelets in freq domain 39 | ET_tfn = np.sum(np.abs(psi_fs)**2, axis=0) 40 | ES_tfn = np.abs(np.sum(psi_fs, axis=0))**2 41 | # real-valued case 
adjustment: 42 | # - ET since we operate on half of spectrum (half as many coeffs) 43 | # - ES since `.real` halves spectrum, quartering energy on real side; 44 | # note, for this to work right, the filterbank must be halved at Nyquist 45 | # by design (which should also be done for sake of temporal decay) 46 | ET_adj = ET_tfn.max() / 2 47 | ES_adj = ES_tfn.max() / 4 48 | 49 | #%% Compute energy & power ################################################### 50 | # compute energy & power (discrete) 51 | ET_disc = np.sum(np.abs(Wx_spec)**2) / ET_adj 52 | ES_disc = np.sum(np.abs(np.sum(Wx_spec, axis=0).real)**2) / ES_adj 53 | PT_disc = ET_disc / len(x) 54 | PS_disc = ES_disc / len(x) 55 | 56 | # compute energy & power (physical); estimate underlying continuous waveform via 57 | # Riemann integration 58 | sampling_period = 1 / fs 59 | ET_phys = ET_disc * sampling_period 60 | ES_phys = ES_disc * sampling_period 61 | PT_phys = ET_phys / duration 62 | PS_phys = ES_phys / duration 63 | 64 | # repeat for original signal 65 | Ex_disc = E(x) 66 | Px_disc = Ex_disc / len(x) 67 | Ex_phys = Ex_disc * sampling_period 68 | Px_phys = Ex_phys / duration 69 | 70 | #%% Report ################################################################### 71 | print(("Between {:d} and {:d} Hz, DISCRETE:\n" 72 | "{:<9.6g} -- energy (transform)\n" 73 | "{:<9.6g} -- energy (signal)\n" 74 | "{:<9.6g} -- mean power (transform)\n" 75 | "{:<9.6g} -- mean power (signal)\n" 76 | ).format(freq_min, freq_max, ET_disc, ES_disc, PT_disc, PS_disc)) 77 | 78 | print(("Between {:d} and {:d} Hz, PHYSICAL (via Riemann integration):\n" 79 | "{:<9.6g} Joules -- energy (transform)\n" 80 | "{:<9.6g} Joules -- energy (signal)\n" 81 | "{:<9.6g} Watts -- mean power (transform)\n" 82 | "{:<9.6g} Watts -- mean power (signal)\n" 83 | ).format(freq_min, freq_max, ET_phys, ES_phys, PT_phys, PS_phys)) 84 | 85 | print(("Original signal:\n" 86 | "{:<9.6g} -- energy (discrete)\n" 87 | "{:<9.6g} -- mean power (discrete)\n" 88 | 
"{:<9.6g} Joules -- energy (physical)\n" 89 | "{:<9.6g} Watts -- mean power (physical)\n").format( 90 | Ex_disc, Px_disc, Ex_phys, Px_phys)) 91 | -------------------------------------------------------------------------------- /SignalProcessing/Q86181 - CWT Power and Energy/cwt_power_validation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Answer to https://dsp.stackexchange.com/q/86181/50076 3 | import numpy as np 4 | from numpy.fft import fft, ifft 5 | 6 | def E(x): 7 | return np.sum(np.abs(x)**2) 8 | 9 | #%% Generate filterbank ###################################################### 10 | # show that it works with any filters 11 | # note ES still works despite violating condition 3, but condition 3 12 | # is still needed for valid interpretation (sum won't invert to signal) 13 | psi_fs = np.random.randn(51, 256) + 1j*np.random.randn(51, 256) 14 | # wavelets are zero-mean 15 | psi_fs[:, 0] = 0 16 | # "any" but still analytic here, need more code for "any any" 17 | psi_fs[:, 129:] = 0 18 | # nyquist imaginary part must be zero else inverse can't be real-valued 19 | psi_fs[:, 128].imag = 0 20 | 21 | 22 | #%% Compute "transfer functions" ############################################# 23 | ET_tfn = np.sum(np.abs(psi_fs)**2, axis=0) 24 | ES_tfn = np.abs(np.sum(psi_fs, axis=0))**2 25 | 26 | #%% Run tests ################################################################ 27 | for case in ('real', 'complex'): 28 | # generate signal 29 | M = 256 30 | x = np.random.randn(M) 31 | if case == 'complex': 32 | x = x + 1j * np.random.randn(M) 33 | xf = fft(x) 34 | 35 | # compute CWT 36 | out = ifft(xf * psi_fs) 37 | 38 | # invert 39 | x_inv = out.sum(axis=0) 40 | if case == 'real': 41 | x_inv = x_inv.real 42 | 43 | # compute energies via transfer fns 44 | xfe = np.abs(xf)**2 45 | ET = xfe * ET_tfn / M 46 | ES = xfe * ES_tfn / M 47 | if case == 'real': 48 | ES[1:M//2] /= 2 49 | ET, ES = ET.sum(), ES.sum() 50 | 51 
| # assert agreement 52 | assert np.allclose(ET, E(out)) 53 | assert np.allclose(ES, E(x_inv)) 54 | -------------------------------------------------------------------------------- /SignalProcessing/Q86726 - STFT - Amplitude extraction/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/86726/50076 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from scipy.signal.windows import hann, dpss 6 | from numpy.fft import fft, fftshift, ifftshift 7 | from ssqueezepy import stft, ssq_stft 8 | from ssqueezepy.visuals import plotscat, imshow 9 | 10 | #%%############################################################################ 11 | # Main example 12 | # ------------ 13 | 14 | def viz(Sx, Sxs): 15 | slcs = Sxs[:, 0] 16 | imshow(Sxs, abs=1, w=.7, h=.55, interpolation='none') 17 | plotscat(slcs, complex=1, w=.6, h=.82, show=1) 18 | print("sum(Sx[:, 0]) =", slcs.sum(), flush=True) 19 | print("max(abs(Sx[:, 0])) / sum(window) =", abs(slcs).max() / window.sum(), 20 | flush=True) 21 | 22 | # gen signal 23 | N = 256 24 | t = np.linspace(0, 1, N, 0) 25 | x = np.cos(2*np.pi * 60 * t) 26 | 27 | # gen window 28 | window = dpss(N//2, N//8 - 1, sym=False) 29 | # this is always done under the hood, `len(window) == n_fft` 30 | window = np.pad(window, (N - len(window)) // 2) 31 | 32 | # match input length so `x` is seen as perfect sinusoid by DFT; 33 | # can also be integer fraction of `N` but then `x` can't be of all int freqs 34 | n_fft = N 35 | # in general should be 'reflect' but slightly complicates example here, simplify 36 | padtype = 'wrap' 37 | 38 | Sx = stft(x, window, n_fft=n_fft, hop_len=1, padtype=padtype, dtype='float64') 39 | Sxs = Sx.copy() 40 | Sxs[1:-1] *= 2 41 | viz(Sx, Sxs) 42 | 43 | #%% 44 | Tx = ssq_stft(x, window, n_fft=len(window), padtype=padtype)[0] 45 | # ... 
*almost* all the work 46 | Txs = Tx.copy() 47 | Txs[1:-1] *= 2 48 | viz(Txs, Txs) 49 | 50 | # ignore ssqueezepy warning (there is as of writing, false positive) 51 | 52 | #%%############################################################################ 53 | # Criteria demos 54 | # -------------- 55 | 56 | # helper 57 | def viz2(S, w, wf, sig_title, tm_idx=128): 58 | slc = S[:, tm_idx] 59 | fig, axes = plt.subplots(1, 2, figsize=(15, 7)) 60 | imshow(S, abs=1, ax=axes[0], fig=fig, show=0, 61 | title="abs(Sx_adj) | " + sig_title, interpolation='none') 62 | plotscat(slc, abs=1, ax=axes[1], fig=fig, 63 | title=f"abs(Sx_adj[:, {tm_idx}])") 64 | fig.subplots_adjust(wspace=.1) 65 | plt.show() 66 | 67 | print( 68 | ("max(abs(x)) = {1:.3g}\n" 69 | "max(abs(Sx_adj[:, {0}])) = {2:.3g}\n" 70 | "sum(abs(Sx_adj[:, {0}])) = {3:.3g}\n" 71 | "max(abs(Sx_adj[:, {0}])) / sum(window_adj) = {4:.3g} -- time norm\n" 72 | "sum(abs(Sx_adj[:, {0}])) / sum(abs(fft(window_adj_f))) = {5:.3g} " 73 | "-- freq norm\n" 74 | ).format( 75 | tm_idx, 76 | abs(x).max(), 77 | sum(abs(S[:, tm_idx])), 78 | max(abs(S[:, tm_idx])), 79 | max(abs(S[:, tm_idx])) / sum(w), 80 | sum(abs(S[:, tm_idx])) / sum(abs(fft(wf))), 81 | ) 82 | ) 83 | 84 | 85 | def _pad_window(w, padded_len): 86 | pleft = (padded_len - len(w)) // 2 87 | pright = padded_len - pleft - len(w) 88 | return np.pad(w, [pleft, pright]) 89 | 90 | def get_adj_windows(window, n_fft, N): 91 | padded_len_conv = N + n_fft - 1 92 | # window_adj = np.pad(window, (padded_len_conv - len(window))//2) 93 | # window_adj_f = np.pad(window, (n_fft - len(window))//2) 94 | window_adj = _pad_window(window, padded_len_conv) 95 | window_adj_f = _pad_window(window, n_fft) 96 | 97 | # shortcut for later examples to spare code; ensure ifftshift center at idx 0 98 | def _center(w): 99 | w = ifftshift(w) 100 | w = fftshift(np.roll(w, np.argmax(w))) 101 | return w 102 | 103 | window_adj = _center(window_adj) 104 | window_adj_f = _center(window_adj_f) 105 | return window_adj, 
window_adj_f 106 | 107 | N = 256 108 | t = np.linspace(0, 1, N, 0) 109 | 110 | #%%############################################################################ 111 | # Everything's right 112 | # ^^^^^^^^^^^^^^^^^^ 113 | n_fft = N 114 | window = dpss(N//8, N//8//2 - 1, sym=False) 115 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N) 116 | 117 | f = 60 118 | x = np.cos(2*np.pi * f * t) 119 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype) 120 | Sx_adj = Sx.copy() 121 | Sx_adj[1:-1] *= 2 122 | 123 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{f}*t)") 124 | 125 | #%%############################################################################ 126 | # Nonstationary case 127 | # ^^^^^^^^^^^^^^^^^^ 128 | f = 60 129 | x = np.cos(2*np.pi * f * t**2) 130 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype) 131 | Sx_adj = Sx.copy() 132 | Sx_adj[1:-1] *= 2 133 | 134 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{f}*t**2)") 135 | 136 | #%%############################################################################ 137 | # Too close to Nyquist 138 | # ^^^^^^^^^^^^^^^^^^^^ 139 | f = N//2 - 10 140 | x = np.cos(2*np.pi * f * t) 141 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype) 142 | Sx_adj = Sx.copy() 143 | Sx_adj[1:-1] *= 2 144 | 145 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{f}*t)", 32) 146 | 147 | #%%############################################################################ 148 | # Too close to DC 149 | # ^^^^^^^^^^^^^^^ 150 | f = 10 151 | x = np.cos(2*np.pi * f * t) 152 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype, dtype='float64') 153 | Sx_adj = Sx.copy() 154 | Sx_adj[1:-1] *= 2 155 | 156 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{f}*t)", 32) 157 | 158 | #%%############################################################################ 159 | # Multi-component intersection 160 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 161 | f, fm = 60, 50 162 | x = np.cos(2*np.pi * f * t) + np.cos(2*np.pi * fm * t**2) 163 | Sx = stft(x, window, 
n_fft=n_fft, padtype=padtype) 164 | Sxs = Sx.copy() 165 | Sxs[1:-1] *= 2 166 | 167 | viz2(Sxs, window_adj, window_adj_f, f"\ncos(2*pi*{f}*t) + cos(2*pi*{fm}*t**2)") 168 | 169 | #%%############################################################################ 170 | # Insufficient `n_fft` 171 | # ^^^^^^^^^^^^^^^^^^^^ 172 | f = 60 173 | x = np.cos(2*np.pi * f * t) 174 | n_fft = 36 175 | window = hann(n_fft, sym=False) 176 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N) 177 | 178 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype, dtype='float64') 179 | Sx_adj = Sx.copy() 180 | Sx_adj[1:-1] *= 2 181 | 182 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{f}*t)") 183 | 184 | #%%############################################################################ 185 | # Excessive `hop_size` 186 | # ^^^^^^^^^^^^^^^^^^^^ 187 | fc, fa = 60, 2 188 | x = np.sin(2*np.pi * fa * t) * np.cos(2*np.pi * fc * t) 189 | n_fft = len(x) 190 | window = dpss(n_fft, n_fft//2 - 1, sym=False) 191 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N) 192 | 193 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype, dtype='float64', 194 | hop_len=64) 195 | Sx_adj = Sx.copy() 196 | Sx_adj[1:-1] *= 2 197 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{fc}*t) * sin(2*pi*{fa}*t)", 198 | 2) 199 | 200 | #%%############################################################################ 201 | # Minimal `hop_size` 202 | # ^^^^^^^^^^^^^^^^^^ 203 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype, dtype='float64', 204 | hop_len=1) 205 | Sx_adj = Sx.copy() 206 | Sx_adj[1:-1] *= 2 207 | viz2(Sx_adj, window_adj, window_adj_f, f"cos(2*pi*{fc}*t) * sin(2*pi*{fa}*t)", 208 | 96) 209 | 210 | #%%############################################################################ 211 | # Non-localized window 212 | # ^^^^^^^^^^^^^^^^^^^^ 213 | np.random.seed(0) 214 | n_fft = N 215 | window = np.abs(np.random.randn(n_fft//4) + .01) 216 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N) 217 | 
218 | f = 60 219 | x = np.cos(2*np.pi * f * t) 220 | Sx = stft(x, window, n_fft=n_fft, padtype=padtype) 221 | Sxs = Sx.copy() 222 | Sxs[1:-1] *= 2 223 | 224 | viz2(Sxs, window_adj, window_adj_f, f"cos(2*pi*{f}*t)") 225 | -------------------------------------------------------------------------------- /SignalProcessing/Q86937 - STFT - Equivalence of Windowed Fourier and Convolutions/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/85745/50076 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from scipy.signal.windows import dpss 6 | from numpy.fft import fft, ifft, fftshift, ifftshift 7 | 8 | from ssqueezepy import stft 9 | from ssqueezepy.visuals import plot, plotscat, imshow 10 | 11 | #%%############################################################################ 12 | # Helpers 13 | # ------- 14 | def _pad_window(w, padded_len): 15 | pleft = (padded_len - len(w)) // 2 16 | pright = padded_len - pleft - len(w) 17 | return np.pad(w, [pleft, pright]) 18 | 19 | def get_adj_windows(window, n_fft, N): 20 | padded_len_conv = N + n_fft - 1 21 | # window_adj = np.pad(window, (padded_len_conv - len(window))//2) 22 | # window_adj_f = np.pad(window, (n_fft - len(window))//2) 23 | window_adj = _pad_window(window, padded_len_conv) 24 | window_adj_f = _pad_window(window, n_fft) 25 | 26 | # shortcut for later examples to spare code; ensure ifftshift center at idx 0 27 | def _center(w): 28 | w = ifftshift(w) 29 | w = fftshift(np.roll(w, np.argmax(w))) 30 | return w 31 | 32 | window_adj = _center(window_adj) 33 | window_adj_f = _center(window_adj_f) 34 | return window_adj, window_adj_f 35 | 36 | #%%############################################################################ 37 | # Row-wise STFT implementation 38 | # ---------------------------- 39 | def cisoid(N, f): 40 | t = np.linspace(0, 1, N, endpoint=False) 41 | return (np.cos(2*np.pi * f * t) + 42 | 
np.sin(2*np.pi * f * t) * 1j) 43 | 44 | 45 | def stft_rowwise(x, window, n_fft): 46 | """ 47 | - no hop or pad support 48 | - equivalent to `'wrap'` (periodic) pad if `len(window) < len(x)//2` 49 | - real-valued `x` only 50 | - returns `Sx`, `fbank_f` 51 | """ 52 | assert len(window) == n_fft and n_fft <= len(x) 53 | 54 | # compute some params 55 | xp = x 56 | N = len(x) 57 | padded_len = N 58 | 59 | # pad such that `armgax(window)` for `Sx[:, 0]` is at `Sx[:, 0]` 60 | # (DFT-center the convolution kernel) 61 | # note, due to how scipy generates windows and standard stft handles padding, 62 | # this still won't yield zero-phase for majority of signals, but it's both 63 | # fixable and can be very close; ideally just pass in `len(window)==len(x)`. 64 | _wpad_right = (padded_len - len(window)) / 2 65 | wpad_right = (int(np.floor(_wpad_right)) if n_fft % 2 == 1 else 66 | int(np.ceil(_wpad_right))) 67 | wpad_left = padded_len - len(window) - wpad_right 68 | 69 | # generate filterbank 70 | cisoids = np.array([cisoid(n_fft, f) for f in range(n_fft//2 + 1)]) 71 | fbank = ifftshift(window) * cisoids 72 | fbank = np.pad(fftshift(fbank), [[0, 0], [wpad_left, wpad_right]]) 73 | fbank = ifftshift(fbank) 74 | fbank_f = fft(fbank, axis=-1).conj() 75 | 76 | # circular convolution 77 | prod = fbank_f * fft(xp)[None] 78 | Sx = ifft(prod, axis=-1) 79 | return Sx, fbank_f 80 | 81 | 82 | #%% first, validate correctness 83 | for N in (128, 129): 84 | x = np.random.randn(N) 85 | for n_fft in (55, 56): 86 | window = np.abs(np.random.randn(n_fft) + .01) 87 | 88 | Sx0 = stft(x, window, n_fft=n_fft, modulated=True, 89 | padtype='wrap', hop_len=1, dtype='float64') 90 | Sx1, _ = stft_rowwise(x, window, n_fft) 91 | 92 | assert np.allclose(Sx0, Sx1) 93 | 94 | #%% Demo filterbanks 95 | N = 256 96 | x = np.random.randn(N) 97 | pkw = dict(color='tab:blue', w=1, h=.6) 98 | window = dpss(N//6, N//6//4 - 1, sym=False) 99 | 100 | plotscat(window, title="scipy.signal.windows.dpss(42, 9.5, sym=False)", 
show=1) 101 | 102 | # full spectrum 103 | n_fft = N 104 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N) 105 | _, fbank_f0 = stft_rowwise(x, window_adj_f, n_fft) 106 | plot(fbank_f0.T.real, title="fbank_f0 | n_fft=N", **pkw, show=1) 107 | 108 | # integer subset 109 | n_fft = N//4 110 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N) 111 | _, fbank_f1 = stft_rowwise(x, window_adj_f, n_fft) 112 | assert np.allclose(fbank_f1, fbank_f0[::4]) 113 | plot(fbank_f1.T.real, title="fbank_f1 = fbank_f0[::4] | n_fft=N/4", **pkw, 114 | show=1) 115 | 116 | # fractional subset 117 | n_fft = int(N/5.5) 118 | window_adj, window_adj_f = get_adj_windows(window, n_fft, N) 119 | _, fbank_f2 = stft_rowwise(x, window_adj_f, n_fft) 120 | plot(fbank_f2.T.real, 121 | title="fbank_f2, fbank_f1 | n_fft=floor(N/5.5), xlims=(100, 150)", **pkw) 122 | plot(fbank_f1.T.real, show=1, color='tab:orange', xlims=(100, 150)) 123 | 124 | # time-domain examples 125 | def plot_complex(pt, f): 126 | plot(pt, complex=1) 127 | plot(pt, abs=1, linestyle='--', color='k', w=.5, h=.8, 128 | title=f"STFT filter | f={f}", show=1, ylims=(-1.03, 1.03)) 129 | 130 | f0, f1 = 20, 40 131 | pt0 = ifftshift(ifft(fbank_f0[f0])) 132 | pt1 = ifftshift(ifft(fbank_f0[f1])) 133 | plot_complex(pt0, f0) 134 | plot_complex(pt1, f1) 135 | 136 | #%% Demo sines 137 | N = 128 138 | t = np.linspace(0, 1, 128, 1) 139 | x = np.cos(2*np.pi * t * 4) 140 | x += np.cos(2*np.pi * t * 20) 141 | x += np.cos(2*np.pi * t * 40) 142 | x += np.cos(2*np.pi * t * 60) 143 | 144 | window = dpss(128, 128//8, sym=False) 145 | 146 | for n_fft in (128, 256): 147 | Sx0 = stft(x, modulated=0, n_fft=n_fft, window=window)[::-1] 148 | Sx1 = stft(x, modulated=1, n_fft=n_fft, window=window)[::-1] 149 | 150 | fig, axes = plt.subplots(1, 2, sharey=True, figsize=(14, 6), 151 | layout='constrained') 152 | imshow(Sx0.real, fig=fig, ax=axes[0], show=0, 153 | title=f"STFT.real, n_fft={n_fft} | standard") 154 | imshow(Sx1.real, fig=fig, 
ax=axes[1], show=1, 155 | title=f"STFT.real, n_fft={n_fft} | improved") 156 | -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004323.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004323.wav -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004324.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004324.wav -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004326.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004326.wav -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004328.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004328.wav -------------------------------------------------------------------------------- 
/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004330.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004330.wav -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004331.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004331.wav -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004333.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004333.wav -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004335.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004335.wav -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004338.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004338.wav -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004339.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_D_00004339.wav -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_T_00043244.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_T_00043244.wav -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_T_00043245.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_T_00043245.wav -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/T2_T_00043246.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting 
abrupt changes/T2_T_00043246.wav -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/example.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/example.wav -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im00.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im01.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im02.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im03.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im03.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im04.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im05.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im06.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - 
Detecting abrupt changes/im07.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im08.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im09.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im10.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im11.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im12.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im12.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/im13.png -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/preds.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/preds.gif -------------------------------------------------------------------------------- /SignalProcessing/Q87355 - audio, algorithms - Detecting abrupt changes/test_algo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/87355/50076 3 | # USER CONFIGS 4 | TOLERANCE = 0.05 5 | VIZ_PREDS = 1 6 | GPU = 0 7 | FULL_GPU = 0 8 | PRINT_TIMES = 1 9 | 10 | import os 11 | os.environ['SSQ_GPU'] = '1' if GPU else '0' 12 | os.environ['FULL_GPU'] = '1' if FULL_GPU else '0' 13 | 14 | import numpy as np 15 | from utils87355 import ( 16 | find_audio_change_timestamps, handle_wavelet, pad_input_make_t, 17 | load_data, data_labels, data_dir, 18 | ) 19 | from timeit import default_timer as dtime 20 | 21 | 
#%%############################################################################ 22 | # Configure 23 | # --------- 24 | cfg = dict( 25 | wavelet='freq', 26 | stft_prefilter_scaling=1, 27 | carve_th_scaling=1/3, 28 | fmax_idx_frac=400/471, 29 | silence_th_scaling=50, 30 | final_pred_n_peaks=1, 31 | escaling=(1, 1, 1), 32 | ) 33 | 34 | #%%############################################################################ 35 | # Make reusables 36 | # -------------- 37 | Nmax = max(len(load_data(i)[0]) for i in range(len(data_labels))) 38 | fs = load_data(0)[1] # assumes same for all 39 | Mmax = len(pad_input_make_t(np.arange(Nmax), fs)[0]) 40 | N_ref = Nmax 41 | reusables = handle_wavelet( 42 | wavelet=cfg['wavelet'], 43 | fs=fs, 44 | M=Mmax, 45 | ssq_precfg=dict(scales='log', padtype=None), 46 | fmax_idx_frac=cfg['fmax_idx_frac'], 47 | silence_interval_samples=int(.2*fs), 48 | escaling=cfg['escaling'], 49 | ) 50 | other_cfg = dict(reusables=reusables, fs=fs, N_ref=N_ref) 51 | 52 | #%%############################################################################ 53 | # Run 54 | # --- 55 | results_train = {'preds': []} 56 | results_test = {'preds': []} 57 | 58 | idxs_train = tuple(range(0, 4)) 59 | idxs_test = tuple(range(4, len(data_labels))) 60 | idxs_all = idxs_train + idxs_test 61 | 62 | # get predictions 63 | for example_idx in idxs_all: 64 | # load data 65 | x, fs, labels = load_data(example_idx) 66 | 67 | # make predictions 68 | viz_labels = ((example_idx, labels) if VIZ_PREDS else 69 | None) 70 | if PRINT_TIMES: 71 | t0 = dtime() 72 | preds = find_audio_change_timestamps( 73 | x, **other_cfg, **cfg, viz_labels=viz_labels)[0] 74 | if PRINT_TIMES: 75 | print("%.3g" % (dtime() - t0) + " sec", flush=True) 76 | else: 77 | print(end=".", flush=True) 78 | 79 | # append 80 | d = (results_train if example_idx in idxs_train else 81 | results_test) 82 | d['preds'].append(preds) 83 | 84 | #%%############################################################################ 85 | # 
Calculate score 86 | # --------------- 87 | def is_success(preds, label, tolerance=TOLERANCE): 88 | if not isinstance(preds, (list, tuple)): 89 | preds = [preds] 90 | return any(abs(p - label) < tolerance for p in preds) 91 | 92 | for d in (results_train, results_test): 93 | for k in ('scores', 'scores_flattened'): 94 | d[k] = [] 95 | 96 | tolerance = TOLERANCE 97 | 98 | printed_line = False 99 | for example_idx in idxs_all: 100 | d = (results_train if example_idx in idxs_train else 101 | results_test) 102 | pred_idx = (example_idx if example_idx in idxs_train else 103 | example_idx - len(idxs_train)) 104 | preds = d['preds'][pred_idx] 105 | labels = load_data(example_idx)[-1] 106 | 107 | if len(labels) == 2: 108 | score0 = is_success(preds, labels[0], tolerance) 109 | score1 = is_success(preds, labels[1], tolerance) 110 | scores_packed = (score0, score1) 111 | else: 112 | score0 = is_success(preds[0], labels[0], tolerance) 113 | scores_packed = (score0,) 114 | 115 | # append 116 | d['scores'].append(scores_packed) 117 | d['scores_flattened'].extend(scores_packed) 118 | # print(end=".", flush=True) 119 | 120 | if example_idx not in idxs_train and not printed_line: 121 | print("="*80) 122 | printed_line = True 123 | 124 | print() 125 | preds = sorted([float("%.3g" % p) for p in preds]) 126 | print(tuple(np.array(list(scores_packed)).astype(int)), '--', example_idx) 127 | print(tuple(preds)) 128 | print(tuple(labels)) 129 | 130 | # finalize 131 | accuracy_train = np.mean(results_train['scores_flattened']) 132 | accuracy_test = np.mean(results_test['scores_flattened']) 133 | 134 | print("Accuracy (train, test): {:.3f}, {:.3f}".format( 135 | accuracy_train, accuracy_test)) 136 | 137 | #%% 138 | # from ? 
import make_gif 139 | # from pathlib import Path 140 | # make_gif(data_dir, str(Path(data_dir, "preds.gif")), 141 | # duration=1000, overwrite=True, delimiter="im", HD=True, verbose=True) 142 | 143 | #%%############################################################################ 144 | # Last run's output 145 | # ----------------- 146 | """ 147 | (1, 1) -- 0 148 | (1.56, 3.54) 149 | (1.55, 3.5) 150 | 151 | (1, 1) -- 1 152 | (1.22, 1.74) 153 | (1.21, 1.74) 154 | 155 | (1, 1) -- 2 156 | (0.958, 1.57) 157 | (0.94, 1.57) 158 | 159 | (1, 1) -- 3 160 | (1.42, 1.87) 161 | (1.42, 1.85) 162 | ================================================================================ 163 | 164 | (1,) -- 4 165 | (0.783, 1.12) 166 | (0.76,) 167 | 168 | (1, 1) -- 5 169 | (0.818, 1.69) 170 | (0.79, 1.68) 171 | 172 | (1, 1) -- 6 173 | (2.89, 3.3) 174 | (2.87, 3.28) 175 | 176 | (0,) -- 7 177 | (0.488, 0.738) 178 | (0.75,) 179 | 180 | (1,) -- 8 181 | (0.623, 1.77) 182 | (0.63,) 183 | 184 | (1,) -- 9 185 | (0.211, 0.468) 186 | (0.46,) 187 | 188 | (1,) -- 10 189 | (4.69, 4.97) 190 | (4.97,) 191 | 192 | (1,) -- 11 193 | (0.543, 0.956) 194 | (0.94,) 195 | 196 | (1, 1) -- 12 197 | (2.35, 3.0) 198 | (2.37, 2.99) 199 | 200 | (1, 0) -- 13 201 | (2.06, 2.44) 202 | (2.03, 2.73) 203 | Accuracy (train, test): 1.000, 0.857 204 | """ 205 | -------------------------------------------------------------------------------- /SignalProcessing/Q87706 - aliasing - How to measure aliasing/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/87705/50076 3 | import numpy as np 4 | from numpy.fft import fft, ifft 5 | import matplotlib.pyplot as plt 6 | import scipy.signal 7 | 8 | try: 9 | from ssqueezepy.visuals import plotscat as _plotscat, plot 10 | except: 11 | print("Plotting requires ssqueezepy; won't plot") 12 | _plotscat = lambda *a, **k: 1 13 | plot = lambda *a, **k: 1 14 | 15 | def plotscat(*a, **k): 16 | x = 
a[0] 17 | if k.get('abs', 0): 18 | x = abs(x) 19 | mn, mx = x.min(), x.max() 20 | amx = 1.03*max(abs(mn), abs(mx)) 21 | if 'ylims' not in k: 22 | k['ylims'] = (-amx, amx) if mn < -.03 else (0, amx) 23 | _plotscat(x, *a[1:], **k) 24 | 25 | 26 | def energy(x): 27 | return np.sum(abs(x)**2) 28 | 29 | def get_xf_sub_worst(xf, M): 30 | return (abs(xf.real) + 1j*abs(xf.imag)).reshape(M, -1).mean(axis=0) 31 | 32 | def measure(xf, M): 33 | N = len(xf) 34 | hbw = N//M//2 35 | xf_ref = np.zeros(N) 36 | xf_ref[:hbw] = 1 37 | xf_ref[-hbw:] = 1 38 | 39 | xf /= abs(xf).max() # in case this step wasn't done already 40 | 41 | xf_sub = get_xf_sub_worst(xf, M) 42 | xf_ref_sub = xf_ref.reshape(M, -1).mean(axis=0) 43 | 44 | E_x_before = energy(xf) 45 | E_ref_before = energy(xf_ref) 46 | E_x_after = energy(xf_sub) 47 | E_ref_after = energy(xf_ref_sub) 48 | ratio_before = E_x_before / E_ref_before 49 | ratio_after = E_x_after / E_ref_after 50 | r = ratio_after / ratio_before 51 | 52 | alias = round(100*(r - 1) / (M - 1), 3) 53 | return alias 54 | 55 | #%% Make signals 56 | N = 256 57 | M = 8 58 | 59 | # populate unaliased signal 60 | hbw = N//M//2 61 | xf_ref = np.zeros(N) 62 | xf_ref[:hbw] = 1 63 | xf_ref[-hbw:] = 1 64 | 65 | # populate referenced signal 66 | xf_full = np.ones(N) 67 | 68 | # make in-between signal 69 | xf0 = xf_ref.copy() 70 | xf0[hbw*2] = 1 71 | xf0[-(hbw*2)] = 1 72 | 73 | #%% Plot 74 | def viz(xf, name, aval=0, do_measure=0, worst=0, cval=0, ylims01=True, 75 | show_ref=False): 76 | fig, axes = plt.subplots(1, 2, layout='constrained', figsize=(12, 5)) 77 | pkw = dict(fig=fig, abs=aval, complex=cval) 78 | if ylims01: 79 | pkw['ylims'] = (0, 1.03) 80 | if worst: 81 | sub = get_xf_sub_worst(xf, M) 82 | else: 83 | sub = xf.reshape(M, -1).mean(axis=0) 84 | l, r = ("|", "|") if aval else ("", "") 85 | info = ("abs_max={:.3g}".format(max(abs(sub))) if not do_measure else 86 | "alias={:.3g}%".format(measure(xf, M))) 87 | title1 = "{}fft({}){}".format(l, name, r) 88 | if 
worst: 89 | title2 = "{}fft({}_sub{}_worst){} -- {}".format(l, name, M, r, info) 90 | else: 91 | title2 = "{}fft({}[::{}]){} -- {}".format(l, name, M, r, info) 92 | 93 | plotscat(xf, **pkw, ax=axes[0], title=title1) 94 | if show_ref: 95 | pkw['hlines'] = (1/M, {'color': 'tab:red', 'linewidth': 1}) 96 | plotscat(sub, **pkw, ax=axes[1], title=title2) 97 | plt.show() 98 | 99 | viz(xf_ref, "x_ref") 100 | viz(xf_full, "x_full") 101 | viz(xf0, "x0") 102 | 103 | #%% Applying the metric 104 | def scipy_decimate_filter(N, M): 105 | """Minimally reproduce the filter used by 106 | `scipy.signal.decimate(ftype='fir')`. This is rigorously tested elsewhere. 107 | """ 108 | q = M 109 | half_len = 10*q 110 | n = int(2*half_len) 111 | cutoff = 1. / q 112 | numtaps = n + 1 113 | 114 | win = scipy.signal.get_window("hamming", numtaps, fftbins=False) 115 | 116 | # sample, window, & norm sinc 117 | alpha = 0.5 * (numtaps - 1) 118 | m = np.arange(0, numtaps) - alpha 119 | h = win * cutoff * np.sinc(cutoff * m) 120 | h /= h.sum() # L1 norm 121 | 122 | # pad and center 123 | h = np.pad(h, [0, N - len(h)]) 124 | h = np.roll(h, -np.argmax(h)) 125 | return h 126 | 127 | # moving average 128 | h = np.zeros(N) 129 | h[:M//2 + 1] = 1/M 130 | h[-(M//2 - 1):] = 1/M 131 | hf = fft(h) 132 | 133 | # scipy 134 | h_scipy = scipy_decimate_filter(N, M) 135 | hf_scipy = fft(h_scipy) 136 | 137 | #%% Show results 138 | ckw = dict(aval=1, do_measure=1, worst=1, show_ref=1, ylims01=0) 139 | viz(hf, "h", **ckw) 140 | viz(hf_scipy, "h_scipy", **ckw) 141 | 142 | #%% Effect 143 | def dft_upsample(xf, M): 144 | L = len(xf) 145 | n_to_add = L * M - L 146 | zeros = np.zeros(n_to_add - 1) 147 | nyq = xf[L//2] 148 | return np.hstack([xf[:L//2], 149 | nyq/2, zeros, np.conj(nyq)/2, 150 | xf[-(L//2 - 1):]]) * M 151 | 152 | def rel_l2(x0, x1): 153 | return np.linalg.norm(x0 - x1) / np.linalg.norm(x0) 154 | 155 | def generate_case(M, seed=0): 156 | np.random.seed(seed) 157 | x = np.random.randn(N) 158 | xf = fft(x) 159 | 
x0_convf = xf * hf 160 | x1_convf = xf * hf_scipy 161 | 162 | _x0f = x0_convf.reshape(M, -1).mean(axis=0) 163 | _x1f = x1_convf.reshape(M, -1).mean(axis=0) 164 | x0f = dft_upsample(_x0f, M) 165 | x1f = dft_upsample(_x1f, M) 166 | x0, x1 = ifft(x0f), ifft(x1f) 167 | 168 | x0_nosub = x0_convf 169 | x0_nosub[hbw:-hbw] = 0 170 | x0_nosub = ifft(x0_nosub) 171 | x1_nosub = ifft(x1_convf) 172 | 173 | return x0, x1, x0_nosub, x1_nosub 174 | 175 | #%% Viz 176 | x0, x1, x0_nosub, x1_nosub = generate_case(M, seed=42) 177 | 178 | fig, axes = plt.subplots(1, 2, layout='constrained', figsize=(12, 5), 179 | sharey=True) 180 | pkw = dict(fig=fig) 181 | plot(x0_nosub.real, ax=axes[0]) 182 | plot(x0.real, ax=axes[0], title="hf: recovered") 183 | plot(x1_nosub.real, ax=axes[1]) 184 | plot(x1.real, ax=axes[1], title="hf_scipy: recovered") 185 | print(rel_l2(x0, x0_nosub), rel_l2(x1, x1_nosub), sep='\n') 186 | 187 | #%% Survey 188 | dist_hf, dist_hf_scipy = [], [] 189 | for seed in range(1000000): 190 | x0, x1, x0_nosub, x1_nosub = generate_case(M, seed) 191 | dist_hf.append(rel_l2(x0, x0_nosub)) 192 | dist_hf_scipy.append(rel_l2(x1, x1_nosub)) 193 | dist_hf, dist_hf_scipy = [np.array(d) for d in (dist_hf, dist_hf_scipy)] 194 | 195 | #%% Viz survey 196 | plot(dist_hf) 197 | plot(dist_hf_scipy) 198 | 199 | fmt = ("(min, max, mean) = ({:.3g}, {:.3g}, {:.3g})\n" 200 | "(min_idx, max_idx) = ({}, {})\n") 201 | ops = (np.min, np.max, np.mean, np.argmin, np.argmax) 202 | print(fmt.format(*[op(dist_hf) for op in ops])) 203 | print(fmt.format(*[op(dist_hf_scipy) for op in ops])) 204 | -------------------------------------------------------------------------------- /SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87774 - 
cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid.png -------------------------------------------------------------------------------- /SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid_target.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid_target.png -------------------------------------------------------------------------------- /SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid_target2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid_target2.png -------------------------------------------------------------------------------- /SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid_template.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OverLordGoldDragon/StackExchangeAnswers/47a5fd462e506cd417c7112a9fff3300b489f0de/SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/covid_template.png -------------------------------------------------------------------------------- /SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/87774/50076 3 | import numpy as np 4 | import 
matplotlib.pyplot as plt 5 | import scipy.signal 6 | from numpy.fft import fft2, ifft2, ifftshift 7 | from PIL import Image 8 | from matplotlib.patches import Circle 9 | 10 | def rand(*s, ival=True): 11 | return ((1j*np.random.randn(*s)) if ival else 12 | np.random.randn(*s)) 13 | 14 | def cross_correlate_2d(x, h): 15 | h = ifftshift(ifftshift(h, axes=0), axes=1) 16 | return ifft2(fft2(x) * np.conj(fft2(h))) 17 | 18 | def load_image(path): 19 | img = np.array(Image.open(path).convert("L")) / 255. 20 | img[img==1] = 0 21 | return img 22 | 23 | def imshow(x, title=None, show_argmax=False, fig=None, ax=None, show=True): 24 | xa = np.abs(x) 25 | if fig is None: 26 | fig, ax = plt.subplots() 27 | ax.imshow(xa) 28 | ax.set_xticks([]) 29 | ax.set_yticks([]) 30 | if title is not None: 31 | ax.set_title(title, weight='bold', loc='left', fontsize=20) 32 | if show_argmax: 33 | hmax, wmax = np.where(xa == xa.max()) 34 | size = len(x) // 7 // 2 35 | circ = Circle((wmax, hmax), size, fill=False, color='tab:red', 36 | linewidth=2) 37 | ax.add_patch(circ) 38 | if show: 39 | plt.show() 40 | 41 | 42 | def run_example(x, h, show_h=False): 43 | # compute 44 | out0 = scipy.signal.correlate2d(x, h) 45 | out1 = cross_correlate_2d(x, h) 46 | 47 | # plot 48 | if show_h: 49 | fig, axes = plt.subplots(2, 1, figsize=(5.7, 5.7), layout='constrained') 50 | imshow(x, "|x|, |h|", fig=fig, ax=axes[0], show=False,) 51 | imshow(h, fig=fig, ax=axes[1]) 52 | else: 53 | imshow(x, "|x|; x = image + iWGN/5") 54 | 55 | nm = "x, x" if not show_h else "x, h" 56 | imshow(out0, f"|scipy.signal.correlate2d({nm})|", show_argmax=True) 57 | imshow(out1, f"|cross_correlate_2d({nm})|", show_argmax=True) 58 | 59 | 60 | #%% Example: self-cc ######################################################### 61 | # load image as greyscale 62 | np.random.seed(0) 63 | x = load_image("covid.png") 64 | 65 | # subsample & add noise 66 | x = x[::4, ::4] 67 | x = x + rand(*x.shape) / 5 68 | 69 | run_example(x, x) 70 | 71 | #%% 
Example: flipped false positive ########################################## 72 | np.random.seed(0) 73 | x = load_image("covid_target.png") 74 | h = load_image("covid_template.png") 75 | 76 | x, h = x[::9, ::9], h[::9, ::9] 77 | x = x + rand(*x.shape) / 10 78 | h = h + rand(*h.shape) / 10 79 | 80 | run_example(x, h, show_h=True) 81 | 82 | #%% Example: conjugated false positive######################################## 83 | np.random.seed(0) 84 | x = load_image("covid_target2.png") 85 | h = load_image("covid_template.png") 86 | 87 | x, h = 1j*x[::9, ::9], 1j*h[::9, ::9] 88 | M, N = x.shape 89 | x[:, N//2:] = np.conj(x[:, N//2:]) 90 | 91 | x = x + rand(*x.shape) / 10 * 1j 92 | h = h + rand(*h.shape) / 10 * 1j 93 | 94 | run_example(x, h, show_h=True) 95 | -------------------------------------------------------------------------------- /SignalProcessing/Q87774 - cross-correlation, FFT - Verifying a source on 2D FFT Cross-correlation/validation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | def cross_correlate_2d(x, h): 3 | h = ifftshift(ifftshift(h, axes=0), axes=1) 4 | return ifft2(fft2(x) * np.conj(fft2(h))) 5 | 6 | ############################################################################## 7 | import numpy as np 8 | from numpy.fft import fft2, ifft2, ifftshift 9 | 10 | def crand(*s): 11 | return np.random.randn(*s) + 1j*np.random.randn(*s) 12 | 13 | for M in (64, 65, 99): 14 | for N in (64, 65, 99): 15 | # case 1 ------------------------------------------------------------- 16 | x = crand(M, N) 17 | 18 | o = cross_correlate_2d(x, x) 19 | hmax, wmax = np.where(abs(o) == abs(o).max()) 20 | assert hmax == M//2, (hmax, M//2) 21 | assert wmax == N//2, (wmax, N//2) 22 | 23 | # case 2 ------------------------------------------------------------- 24 | x = np.zeros((M, N), dtype='complex128') 25 | h = np.zeros((M, N), dtype='complex128') 26 | 27 | hctr, wctr = M//8, N//8 28 | hsize, wsize = hctr*2, wctr*2 
29 | target_loc = slice(0, hsize), slice(0, wsize) 30 | false_positive_loc = slice(M//2, M//2 + hsize), slice(-wsize, None) 31 | target = crand(hsize, wsize) 32 | 33 | x[target_loc] = target 34 | x[false_positive_loc] = target[::-1, ::-1] 35 | h[M//2 - hctr:M//2 + hctr, N//2 - wctr:N//2 + wctr] = target 36 | 37 | o = cross_correlate_2d(x, h) 38 | hmax, wmax = np.where(abs(o) == abs(o).max()) 39 | assert hmax == hctr, (hmax, hctr) 40 | assert wmax == wctr, (wmax, wctr) 41 | 42 | # case 3 ------------------------------------------------------------- 43 | x[false_positive_loc] = np.conj(target) 44 | 45 | o = cross_correlate_2d(x, h) 46 | hmax, wmax = np.where(abs(o) == abs(o).max()) 47 | assert hmax == hctr, (hmax, hctr) 48 | assert wmax == wctr, (wmax, wctr) 49 | -------------------------------------------------------------------------------- /SignalProcessing/Q87781 - cross-correlation, FFT - 2D Cross-Correlation using FFT in Python/cc2d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/87781/50076 3 | import numpy as np 4 | from scipy.fft import next_fast_len, fft2, ifft2 5 | 6 | 7 | def cross_correlate_2d(x, h, mode='same', real=True, get_reusables=False, 8 | inplace=True, workers=-1): 9 | """2D cross-correlation, replicating `scipy.signal.correlate2d`. 10 | 11 | Parameters 12 | ---------- 13 | x : np.ndarray, 2D 14 | Input. 15 | 16 | h : np.ndarray, 2D 17 | Filter/template. 18 | 19 | mode : str 20 | `'full'`, `'same'`, or `'valid'` (see scipy docs). 21 | 22 | real : bool (default True) 23 | Whether to assume `x` and `h` are real-valued, which is faster. 24 | 25 | get_reusables : bool (default False) 26 | Whether to return `out, reusables`. 
27 | 28 | If `h` is same and `x` has same shape, pass in `reusables` as `h` for 29 | speedup, 30 | 31 | inplace : bool (default True) 32 | If True, is faster but may alter `x` and/or `h` that are passed in 33 | (unless via `reusables`). 34 | 35 | workers : int 36 | Number of CPU cores to use with FFT. Defaults to `-1`, which is all. 37 | """ 38 | # check if `h` is reusables 39 | if not isinstance(h, tuple): 40 | # fetch shapes, check inputs 41 | xs, hs = x.shape, h.shape 42 | h_not_smaller = all(hs[i] >= xs[i] for i in (0, 1)) 43 | x_not_smaller = all(xs[i] >= hs[i] for i in (1, 0)) 44 | if mode == 'valid' and not (h_not_smaller or x_not_smaller): 45 | raise ValueError( 46 | "For `mode='valid'`, every axis in `x` must be at least " 47 | "as long as in `h`, or vice versa. Got x:{}, h:{}".format( 48 | str(xs), str(hs))) 49 | 50 | # swap if needed 51 | swap = bool(mode == 'valid' and not x_not_smaller) 52 | if swap: 53 | xadj, hadj = h, x 54 | else: 55 | xadj, hadj = x, h 56 | xs, hs = xadj.shape, hadj.shape 57 | 58 | # compute pad quantities 59 | full_len_h = xs[0] + hs[0] - 1 60 | full_len_w = xs[1] + hs[1] - 1 61 | padded_len_h = next_fast_len(full_len_h) 62 | padded_len_w = next_fast_len(full_len_w) 63 | padded_shape = (padded_len_h, padded_len_w) 64 | 65 | # compute unpad indices 66 | if mode == 'full': 67 | offset_h, offset_w = 0, 0 68 | len_h, len_w = full_len_h, full_len_w 69 | elif mode == 'same': 70 | len_h, len_w = xs 71 | offset_h, offset_w = [g//2 for g in hs] 72 | elif mode == 'valid': 73 | ax_pairs = ((xs[0], hs[0]), (xs[1], hs[1])) 74 | len_h, len_w = [max(g) - min(g) + 1 for g in ax_pairs] 75 | offset_h, offset_w = [min(g) - 1 for g in ax_pairs] 76 | unpad_h = slice(offset_h, offset_h + len_h) 77 | unpad_w = slice(offset_w, offset_w + len_w) 78 | 79 | # handle filter / template 80 | if real: 81 | hadj = hadj[::-1, ::-1] 82 | else: 83 | if inplace: 84 | np.conj(hadj[::-1, ::-1], out=hadj) 85 | else: 86 | hadj = np.conj(hadj[::-1, ::-1]) 87 | hf = 
fft2(hadj, padded_shape, workers=workers) 88 | else: 89 | reusables = h 90 | (hf, swap, padded_shape, unpad_h, unpad_w) = reusables 91 | if swap: 92 | xadj, hadj = h, x 93 | else: 94 | xadj, hadj = x, h 95 | 96 | # FFT convolution 97 | xf = fft2(xadj, padded_shape, workers=workers) 98 | if inplace: 99 | np.multiply(xf, hf, out=xf) 100 | else: 101 | xf = xf * hf 102 | out = ifft2(xf, workers=workers) 103 | if real: 104 | out = out.real 105 | 106 | # unpad, unswap 107 | out = out[unpad_h, unpad_w] 108 | if swap: 109 | out = out[::-1, ::-1] 110 | 111 | # pack reusables 112 | if get_reusables: 113 | reusables = (hf, swap, padded_shape, unpad_h, unpad_w) 114 | 115 | # return 116 | return ((out, reusables) if get_reusables else 117 | out) 118 | -------------------------------------------------------------------------------- /SignalProcessing/Q87781 - cross-correlation, FFT - 2D Cross-Correlation using FFT in Python/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/87781/50076 3 | import numpy as np 4 | import scipy.signal 5 | from cc2d import cross_correlate_2d 6 | 7 | #%% Testing ################################################################## 8 | np.random.seed(0) 9 | 10 | rand = lambda M, N, real: (np.random.randn(M, N) if real else 11 | (np.random.randn(M, N) + 1j*np.random.randn(M, N))) 12 | 13 | lengths = (1, 2, 3, 4, 5, 6, 7, 15, 50) 14 | 15 | for real in (True, False): 16 | for mode in ('full', 'same', 'valid'): 17 | for M0 in lengths: 18 | for N0 in lengths: 19 | x = rand(M0, N0, real) 20 | for M1 in lengths: 21 | for N1 in lengths: 22 | h = rand(M1, N1, real) 23 | 24 | fn0 = lambda: cross_correlate_2d( 25 | x.copy(), h.copy(), mode, real=real) 26 | fn1 = lambda: scipy.signal.correlate2d( 27 | x.copy(), h.copy(), mode=mode) 28 | 29 | # compute 30 | try: 31 | out0 = fn0() 32 | except ValueError: 33 | try: 34 | fn1() 35 | except ValueError: 36 | continue 37 
| except: 38 | raise AssertionError 39 | out1 = fn1() 40 | 41 | # assert equality 42 | cfg = (real, mode, M0, N0, M1, N1) 43 | assert out0.shape == out1.shape, cfg 44 | assert np.allclose(out0, out1), cfg 45 | 46 | 47 | # check reusables 48 | out0, reusables = cross_correlate_2d(x, h, get_reusables=True) 49 | out1 = cross_correlate_2d(x, reusables) 50 | assert np.allclose(out0, out1) 51 | -------------------------------------------------------------------------------- /SignalProcessing/Q87926 - DFT of a sine, closed form solution and insights/solutions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/87926/50076 3 | import numpy as np 4 | 5 | def _get_UVk(N, f, phi): 6 | assert not float(f).is_integer(), f 7 | k = np.arange(N) 8 | U = np.cos(2*np.pi*f + phi) - np.cos(phi) 9 | V = np.cos(2*np.pi*f + phi - 2*np.pi*f/N) - np.cos(phi - 2*np.pi*f/N) 10 | return U, V, k 11 | 12 | 13 | def sine_dft(N, f, phi): 14 | """Solution by Cedron Dawg https://www.dsprelated.com/showarticle/771.php 15 | """ 16 | if float(f).is_integer(): 17 | return _sine_dft_int(N, f, phi) 18 | return _sine_dft_frac(N, f, phi) 19 | 20 | def sine_dft_modulus(N, f, phi): 21 | """Modulus analysis & insights by John Muradeli # TODO URL 22 | """ 23 | if float(f).is_integer(): 24 | return _sine_dft_modulus_int(N, f, phi) 25 | return _sine_dft_modulus_frac(N, f, phi) 26 | 27 | def _sine_dft_frac(N, f, phi): 28 | U, V, k = _get_UVk(N, f, phi) 29 | 30 | n = U*np.exp(1j*2*np.pi*k/N) - V 31 | d = np.cos(2*np.pi*f/N) - np.cos(2*np.pi*k/N) 32 | out = 1/2 * n / d 33 | return out 34 | 35 | def _sine_dft_modulus_frac(N, f, phi, get_params=False): 36 | U, V, k = _get_UVk(N, f, phi) 37 | 38 | n = np.sqrt(U**2 + V**2 - 2*U*V*np.cos(2*np.pi*k/N)) 39 | d = abs(np.cos(2*np.pi*f/N) - np.cos(2*np.pi*k/N)) 40 | out = 1/2 * n / d 41 | 42 | if get_params: 43 | return out, (U, V) 44 | return out 45 | 46 | def _delta(x): 47 | 
"""Kronecker Delta (discrete unit impulse). Handles numeric precision.""" 48 | x = np.asarray(x) 49 | eps = (0 if x.dtype.name.startswith('int') else 50 | np.finfo(x.dtype).eps) 51 | return (abs(x) <= eps).astype(float) 52 | 53 | def _sine_dft_int(N, f, phi): 54 | k = np.arange(N) 55 | a = np.exp(1j*phi ) * _delta(np.mod(k - f, N)) 56 | b = np.exp(-1j*phi) * _delta(np.mod(k + f, N)) 57 | return (N/2) * (a + b) 58 | 59 | def _sine_dft_modulus_int(N, f, phi): 60 | k = np.arange(N) 61 | dright = _delta(np.mod(k - f, N)) 62 | dleft = _delta(np.mod(k + f, N)) 63 | 64 | return (N/2) * np.sqrt(dleft + dright + 2*np.cos(2*phi)*dleft*dright) 65 | 66 | 67 | def sine_stft(N, M, H, f, phi): 68 | assert M <= N and 1 <= H <= N, (N, M, H) 69 | n_hops = (N - M)//H + 1 70 | out = np.zeros((M, n_hops), dtype='complex128') 71 | 72 | for tau in range(n_hops): 73 | phi_tau = phi + 2*np.pi*f*tau*H/N 74 | out[:, tau] = sine_dft(M, f*M/N, phi_tau) 75 | return out 76 | -------------------------------------------------------------------------------- /SignalProcessing/Q88042 - sampling - Is downsampling LTI for bandlimited inputs/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # https://dsp.stackexchange.com/q/88042/50076 3 | import numpy as np 4 | from numpy.fft import fft, ifft 5 | 6 | try: 7 | from .toolkit import fft_upsample 8 | got_fft_upsample = True 9 | except: 10 | got_fft_upsample = False 11 | import warnings 12 | warnings.warn("`fft_upsample` not found! Some tests skipped.") 13 | 14 | # NOTE: linearity isn't tested! 15 | 16 | #%%########################################################################### 17 | # Helpers 18 | # ------- 19 | def crandn(N): 20 | return np.random.randn(N) + 1j*np.random.randn(N) 21 | 22 | # note, this `op` is deliberately incomplete relative to the answer's formulation 23 | # (i.e. 
omits `fft_upsample`, so that we can test non-upsampled cases) 24 | def op(a, b, d, x_only): 25 | if x_only: 26 | return ifft(fft(a[::d])) 27 | else: 28 | return ifft(fft(a[::d]) * fft(b[::d] * d)) 29 | 30 | #%%########################################################################### 31 | # Testing 32 | # ------- 33 | # np.random.seed(0) # uncomment for reproducible debugging 34 | optional_successes = {i: [] for i in range(5)} 35 | n_cases = 0 36 | 37 | for real in (True, False): 38 | for x_bandlimited in (True, False): 39 | for h_bandlimited in (True, False): 40 | for x_only in (True, False): 41 | if x_only and not h_bandlimited: 42 | # duplicate case 43 | continue 44 | for N in (16, 40, 128): 45 | for d in range(2, N//2 + 1, 2): 46 | for s in range(1, N): 47 | # prepare -------------------------------------------------------- 48 | # check if loop is valid 49 | if (N / d) % 2 != 0: 50 | # `fft_upsample` doesn't handle this case, nor does it seem 51 | # that many of such cases can be handled 52 | continue 53 | 54 | # store reusables 55 | both_bandlimited = (x_bandlimited and h_bandlimited) 56 | one_bandlimited = (x_bandlimited or h_bandlimited) 57 | not_bandlimited = (not x_bandlimited and not h_bandlimited) 58 | bandlimited = ((x_only and x_bandlimited) or 59 | (not x_only and both_bandlimited)) 60 | 61 | # generate filter, signal 62 | x, h = crandn(N), crandn(N) 63 | if real: 64 | x, h = x.real, h.real 65 | xf, hf = fft(x), fft(h) 66 | zero_idxs = (slice(N//d//2, -(N//d//2 - 1)) if N//d != 2 else 67 | slice(1, None)) 68 | if x_bandlimited: 69 | xf[zero_idxs] = 0 70 | if h_bandlimited: 71 | hf[zero_idxs] = 0 72 | x, h = ifft(xf), ifft(hf) 73 | if real: 74 | assert np.allclose(max(abs(x.imag)), 0) 75 | assert np.allclose(max(abs(h.imag)), 0) 76 | x, h = x.real, h.real 77 | 78 | # execute -------------------------------------------------------- 79 | # get outputs 80 | o0a = op(np.roll(x, 0), h, 1, x_only) 81 | o0b = op(np.roll(x, s), h, 1, x_only) 82 | o1a = 
op(np.roll(x, 0), h, d, x_only) 83 | o1b = op(np.roll(x, s), h, d, x_only) 84 | 85 | # get upsampled outputs 86 | if got_fft_upsample: 87 | o1au, o1bu = [fft_upsample(o, d, time_to_time=True) 88 | for o in (o1a, o1b)] 89 | 90 | # validate ------------------------------------------------------- 91 | # prepare info in case assert fails 92 | cfg = dict(real=real, 93 | x_bandlimited=x_bandlimited, 94 | h_bandlimited=h_bandlimited, 95 | x_only=x_only, N=N, d=d, s=s) 96 | info = "\n " + "\n ".join(f"{k}={v}" for k, v in cfg.items()) 97 | _err = lambda i: f'{i}:{info}' 98 | 99 | # fetch expectations 100 | cond = { 101 | 0: True, 102 | 1: False, 103 | 2: (s / d).is_integer(), 104 | 3: (s / d).is_integer() or bandlimited, 105 | 4: bandlimited, 106 | } 107 | optional = { 108 | 0: False, 109 | 1: True, 110 | # only integer `s/d` must pass 111 | 2: False if cond[2] else True, 112 | # only integer `s/d` or bandlimited must pass 113 | 3: False if cond[3] else True, 114 | # only bandlimited must be roLTI 115 | 4: False if cond[4] else True, 116 | } 117 | results = {} 118 | 119 | # gather successes/failures 120 | # [LTI] (no subsampling) 121 | # shifted input <=> shifted output 122 | results[0] = (cond[0] == np.allclose(o0b, np.roll(o0a, s))) 123 | # [LTI] 124 | # subsampling of shift <=> shift of subsampling 125 | results[1] = (cond[1] == np.allclose(o1b, np.roll(o1a, s//d))) 126 | # [fLTI] 127 | # subsampling of shift <=> shift of subsampling; integer `s/d` 128 | results[2] = (cond[2] == np.allclose(o1b, np.roll(o1a, s//d))) 129 | # [rLTI] 130 | # upsampled subsampling of shift <=> shift of upsampled subsampling 131 | if got_fft_upsample: 132 | results[3] = (cond[3] == np.allclose(o1bu, np.roll(o1au, s))) 133 | # [roLTI] upsampled subsampling of shift <=> shift 134 | if got_fft_upsample: 135 | results[4] = (cond[4] == np.allclose(o1bu, o0b)) 136 | 137 | # run assertions on non-optionals, append info about the rest 138 | for i in results: 139 | if i != 0: 140 | if not 
optional[i]: 141 | assert results[i], _err(i) 142 | else: 143 | if results[i]: 144 | optional_successes[i].append(tuple(cfg.values())) 145 | else: 146 | assert results[0], ( 147 | "0:\n" "the operation, unrelated to subsampling, " 148 | "isn't LTI") 149 | n_cases += 1 150 | 151 | #%%########################################################################### 152 | # Trivia 153 | # ------ 154 | print("Number of optional (unexpected) successes:\n " + 155 | "\n ".join("Case {}: {:>4} ({:>4.3g}%)".format( 156 | i, len(v), 100*len(v)/n_cases) 157 | for i, v in optional_successes.items())) 158 | 159 | print("\nExample optional successes for case 2:\n " + 160 | ", ".join(list(cfg)) + "\n " + 161 | "\n ".join(str(v) for v in optional_successes[2][:5])) 162 | -------------------------------------------------------------------------------- /StackOverflow/Q76812752 - MATLAB dictionary comprehension/ifelse.m: -------------------------------------------------------------------------------- 1 | % https://stackoverflow.com/q/76812752/10133797 2 | % Thanks @user16372530 https://stackoverflow.com/a/73751467/10133797 3 | function out = ifelse(a, cond, b) 4 | if cond 5 | out = a(); 6 | else 7 | out = b(); 8 | end 9 | end 10 | -------------------------------------------------------------------------------- /StackOverflow/Q76812752 - MATLAB dictionary comprehension/py_dictcomp.m: -------------------------------------------------------------------------------- 1 | % https://stackoverflow.com/q/76812752/10133797 2 | function out = py_dictcomp(fn_k, fn_v, iterable, cond, cell_values) 3 | % py_dictcomp 4 | % Mimics Python's 5 | % 6 | % {f(k): g(v) for k, v in dict.items() if cond(k, v)} 7 | % 8 | % where [f(k), g(v)] = fn(k, v) 9 | % 10 | % or, if `iterable` isn't `dictionary`, 11 | % 12 | % {k: v for i, x in enumerate(iterable) if cond(i, x)} 13 | % 14 | % where `[k, v] = f(i, x)`. 15 | % 16 | % `cell_values=true` makes `out` like `dictionary(a={1})`. 
By default, 17 | % `false` is attempted first. Has no effect if `iterable` is `dictionary`. 18 | % 19 | % Default `cond = @(a, b) true`. 20 | % 21 | if nargin == 3 22 | cond = @(a, b) true; 23 | end 24 | if nargin <= 4 25 | cell_values = false; 26 | end 27 | user_set_cell_values = (nargin == 5); 28 | 29 | out = dictionary(); 30 | if strcmp(class(iterable), "dictionary") 31 | cell_values = strcmp(iterable.values, "cell"); 32 | keys = iterable.keys; 33 | values = iterable.values; 34 | 35 | for l=1:iterable.numEntries 36 | k = keys{l}; 37 | if cell_values 38 | v = values{l}; 39 | else 40 | v = values(l); 41 | end 42 | if cond(k, v) 43 | k = fn_k(k); 44 | v = fn_v(v); 45 | if cell_values 46 | out{k} = v; 47 | else 48 | out(k) = v; 49 | end 50 | end 51 | end 52 | else 53 | is_cell = iscell(iterable); 54 | for i=1:numel(iterable) 55 | if is_cell 56 | x = iterable{i}; 57 | else 58 | x = iterable(i); 59 | end 60 | 61 | if cond(i, x) 62 | k = fn_k(i); 63 | v = fn_v(x); 64 | 65 | if cell_values 66 | out{k} = v; 67 | else 68 | try 69 | out(k) = v; 70 | % e.g. 
first populate with `double` values, then 71 | % `string`, the `string` becomes `NaN` unless cell 72 | if isnan(out(k)) && ~isnan(v) 73 | error("NaN `out(k)` without NAN `v`") 74 | end 75 | catch e 76 | if user_set_cell_values 77 | throw(e) 78 | else 79 | out = py_dictcomp(fn_k, fn_v, iterable, cond, true); 80 | end 81 | end 82 | end 83 | end 84 | end 85 | end 86 | end 87 | -------------------------------------------------------------------------------- /StackOverflow/Q76812752 - MATLAB dictionary comprehension/test_py_dictcomp.m: -------------------------------------------------------------------------------- 1 | % https://stackoverflow.com/q/76812752/10133797 2 | %% {i: 5 for i, x in enumerate([1, 2, 3])} 3 | out = py_dictcomp(@(i)i, @(v)5, [1 2 3]); 4 | 5 | ref = dictionary(); 6 | for i=1:3 7 | ref(i) = 5; 8 | end 9 | assert(isequal(out, ref)) 10 | 11 | %% {i: x for i, x in enumerate([1, 2, "a"])} 12 | out = py_dictcomp(@(i)i, @(v)v, [1 2 "a"]); 13 | 14 | ref = dictionary(1, {"1"}, 2, {"2"}, 3, {"a"}); 15 | assert(isequal(out, ref)) 16 | 17 | %% {2*i: x-1 for i, x in enumerate([1, 2, 3]) if (x != 2 and i > 0)} 18 | out = py_dictcomp(@(i)2*i, @(x)x - 1, [1 2 3], @(i, x)x ~= 2 && i > 1); 19 | 20 | ref = dictionary(); 21 | ref(6) = 2; 22 | assert(isequal(out, ref)) 23 | 24 | %% {i: ("a", x**2) for i, x in enumerate([1, np.array([2, 3])])} 25 | out = py_dictcomp(@(i)i, @(x){"a", x.^2}, {1, [2 3]}); 26 | 27 | ref = dictionary(); 28 | cl = {1, [2 3]}; 29 | for i=1:2 30 | ref{i} = {"a", cl{i}.^2}; 31 | end 32 | assert(isequal(out, ref)) 33 | 34 | %% {k: v**2 for k, v in dict(a=1, b=-2).items()} 35 | out = py_dictcomp(@(k)k, @(v)v^2, dictionary(a=1, b=-2)); 36 | ref = dictionary(a=1^2, b=(-2)^2); 37 | assert(isequal(out, ref)) 38 | 39 | %% {k: v**2 for k, v in dict(a=1, b=-2).items() if v > 0} 40 | out = py_dictcomp(@(k)k, @(v)v^2, dictionary(a=1, b=-2), @(k, v)v > 0); 41 | ref = dictionary(a=1^2); 42 | assert(isequal(out, ref)) 43 | 44 | %% {("we have" if i == 0 else i 
- 1): 45 | % ((x + "ogs") if isinstance(x, str) else (x + 1)) 46 | % for i, x in enumerate([2, "d", 3.5]) 47 | % if not isinstance(x, float)} 48 | out = py_dictcomp(@(i)ifelse("we have", i == 1, i - 1), ... 49 | @(x)ifelse(x + "ogs", isstring(x), x + 1), ... 50 | {2, "d", 3.5}, ... 51 | @(i, x)~(isnumeric(x) && mod(x, 1) ~= 0)); 52 | 53 | %% invalid `cell_values=false` 54 | try 55 | out = py_dictcomp(@(i)i, @(x)x, {1, [2, 3]}, @(a, b)true, false); 56 | assert(false); 57 | catch e 58 | assert(contains(e.identifier, "KeyValueDimsMustMatch")) 59 | end 60 | --------------------------------------------------------------------------------