├── .formatter.exs ├── .github └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── README.md ├── guides ├── fft.livemd ├── filtering.livemd └── spectrogram.livemd ├── lib ├── nx_signal.ex └── nx_signal │ ├── convolution.ex │ ├── filters.ex │ ├── internal.ex │ ├── peak_finding.ex │ ├── transforms.ex │ ├── waveforms.ex │ └── windows.ex ├── mix.exs ├── mix.lock └── test ├── nx_signal ├── convolutions_test.exs ├── filters_test.exs ├── internal_test.exs ├── peak_finding_test.exs ├── transforms_test.exs ├── waveforms_test.exs └── windows_test.exs ├── nx_signal_test.exs ├── support └── nx_signal_case.ex └── test_helper.exs /.formatter.exs: -------------------------------------------------------------------------------- 1 | # Used by "mix format" 2 | [ 3 | import_deps: [:nx], 4 | inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] 5 | ] 6 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | jobs: 8 | main: 9 | name: Linux 10 | runs-on: ubuntu-latest 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | elixir: ["1.18.0"] 15 | otp: ["27.0"] 16 | env: 17 | MIX_ENV: test 18 | steps: 19 | - uses: actions/checkout@v4 20 | - uses: erlef/setup-beam@v1 21 | with: 22 | otp-version: ${{ matrix.otp }} 23 | elixir-version: ${{ matrix.elixir }} 24 | - name: Retrieve dependencies cache 25 | uses: actions/cache@v4 26 | id: mix-cache # id to use in retrieve action 27 | with: 28 | path: deps 29 | key: v1-${{ runner.os }}-${{ matrix.otp }}-${{ matrix.elixir }}-mix-${{ hashFiles('mix.lock') }} 30 | - name: Install dependencies 31 | if: ${{ steps.mix-cache.outputs.cache-hit != 'true' }} 32 | run: mix deps.get 33 | - name: Compile and check warnings 34 | run: mix compile --warnings-as-errors 35 | - name: Check formatting 36 | run: mix format
--check-formatted 37 | - name: Run tests 38 | run: mix test 39 | win: 40 | name: Windows 41 | runs-on: windows-latest 42 | strategy: 43 | fail-fast: false 44 | matrix: 45 | elixir: ["1.18.0"] 46 | otp: ["27.0"] 47 | env: 48 | MIX_ENV: test 49 | steps: 50 | - name: Configure Git 51 | run: git config --global core.autocrlf input 52 | working-directory: . 53 | - uses: actions/checkout@v4 54 | - uses: ilammy/msvc-dev-cmd@v1 55 | - uses: erlef/setup-beam@v1 56 | with: 57 | otp-version: ${{ matrix.otp }} 58 | elixir-version: ${{ matrix.elixir }} 59 | - name: Retrieve dependencies cache 60 | uses: actions/cache@v4 61 | id: mix-cache # id to use in retrieve action 62 | with: 63 | path: deps 64 | key: v1-${{ runner.os }}-${{ matrix.otp }}-${{ matrix.elixir }}-mix-${{ hashFiles('mix.lock') }} 65 | - name: Install dependencies 66 | if: ${{ steps.mix-cache.outputs.cache-hit != 'true' }} 67 | run: mix deps.get 68 | - name: Compile and check warnings 69 | run: mix compile --warnings-as-errors 70 | - name: Check formatting 71 | run: mix format --check-formatted 72 | - name: Run tests 73 | run: mix test 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # The directory Mix will write compiled artifacts to. 2 | /_build/ 3 | 4 | # If you run "mix test --cover", coverage assets end up here. 5 | /cover/ 6 | 7 | # The directory Mix downloads your dependencies sources to. 8 | /deps/ 9 | 10 | # Where third-party dependencies like ExDoc output generated docs. 11 | /doc/ 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | # Ignore package tarball (built via "mix hex.build").
23 | nx_signal-*.tar 24 | 25 | # Temporary files, for example, from tests. 26 | /tmp/ 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2022 Paulo Valente 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NxSignal 2 | 3 | DSP (Digital Signal Processing) with [Nx](https://github.com/elixir-nx/nx) 4 | 5 | ## Why NxSignal? 6 | 7 | This library comes from the author's urge to experiment with audio processing in Elixir through Nx. 8 | However, the scope is not limited to audio signals. This library aims to provide the tooling for 9 | a more classical approach to dealing with time series, through Fourier Transforms, FIR filters, 10 | IIR filters and similar mathematical tools. 11 | 12 | ## Getting Started 13 | 14 | In order to use `NxSignal`, you need Elixir installed. 
Then, you can add `NxSignal` as a dependency 15 | to your Mix project: 16 | 17 | ```elixir 18 | def deps do 19 | [ 20 | {:nx_signal, "~> 0.1"} 21 | ] 22 | end 23 | ``` 24 | 25 | You can also use `Mix.install` for standalone development: 26 | 27 | ```elixir 28 | Mix.install([ 29 | {:nx_signal, "~> 0.1"} 30 | ]) 31 | ``` 32 | 33 | By default, `NxSignal` only depends directly on `Nx` itself. If you wish to use separate backends 34 | such as `Torchx` or `EXLA`, you need to explicitly depend on them. 35 | 36 | All of `NxSignal`'s functionality is provided through `Nx.Defn`, so things should work out of the 37 | box with different backends and compilers. 38 | 39 | ## Guides (Livebook) 40 | 41 | Check out the "guides" directory in the repository for examples. 42 | 43 | ## Contributing 44 | 45 | Contributions are more than welcome! 46 | 47 | Firstly, please make sure you check the issues tracker and the pull requests list for 48 | a similar feature or bugfix to what you wish to contribute. 49 | If there aren't any mentions to be found, open up an issue so that we can discuss the 50 | feature beforehand. 51 | -------------------------------------------------------------------------------- /guides/fft.livemd: -------------------------------------------------------------------------------- 1 | # The Discrete Fourier Transform (DFT) 2 | 3 | ```elixir 4 | Mix.install([ 5 | {:nx_signal, "~> 0.2"}, 6 | {:vega_lite, "~> 0.1"}, 7 | {:kino_vega_lite, "~> 0.1"} 8 | ]) 9 | ``` 10 | 11 | ## What is the Discrete Fourier Transform (DFT)? 12 | 13 | This livebook will show you how to use the Discrete Fourier Transform (DFT) to analyze a signal. 14 | 15 | Suppose we have a periodic signal which we want to analyze. 16 | 17 | We will run a Fast Fourier Transform (FFT), which is a fast algorithm to compute the DFT. 18 | It transforms a time-domain function into the frequency domain.
19 | 20 | 21 | 22 | ## Building the signal 23 | 24 | Let's build a known signal that we will decompose and analyze later on. 25 | 26 | The signal will be the sum of two sinusoidal signals, one at 5Hz and one at 20Hz, with the corresponding amplitudes (1, 0.5). 27 | 28 | $f(t) = \sin(2\pi \cdot 5 t) + \frac{1}{2} \sin(2\pi \cdot 20 t)$ 29 | 30 | Suppose we can sample at `fs=50Hz` (meaning 50 samples per second) and our acquisition time is `duration = 1s`. 31 | 32 | We build a time series of `t` equally spaced points with the given `duration` interval with `Nx.linspace`. 33 | 34 | For each value of this series (the discrete time $t$), we will synthesize the signal $f(t)$ through the module below. 35 | 36 | ```elixir 37 | defmodule Signal do 38 | import Nx.Defn 39 | import Nx.Constants, only: [pi: 0] 40 | 41 | defn source(t) do 42 | f1 = 5 43 | f2 = 20 44 | Nx.sin(2 * pi() * f1 * t) + 1/2 * Nx.sin(2 * pi() * f2 * t) 45 | end 46 | 47 | def sample(opts) do 48 | fs = opts[:fs] 49 | duration = opts[:duration] 50 | t = Nx.linspace(0, duration, n: trunc(duration * fs), endpoint: false, type: {:f, 32}) 51 | source(t) 52 | end 53 | end 54 | ``` 55 | 56 | We sample our signal at fs=50Hz during 1s: 57 | 58 | ```elixir 59 | fs = 50; duration = 1 60 | 61 | sample = Signal.sample(fs: fs, duration: duration) 62 | ``` 63 | 64 | ## Analyzing the signal with the DFT 65 | 66 | Because our signal contains many periods of the underlying function, the DFT results will contain some noise. 67 | This noise can stem both from the fact that we're likely cutting off the signal in the middle of a period 68 | and from the fact that we have a specific frequency resolution which ends up grouping our individual components into frequency bins. 69 | The latter isn't really a problem as we have chosen `fs` to be [fast enough](https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem).
70 | 71 | > The number at the index $i$ of the DFT results gives an approximation of the amplitude and phase of the sampled signal at the frequency $i \cdot f_s / n$ (here $f_s / n = 1$ Hz, so simply $i$ Hz). 72 | 73 | In other words, `Nx.fft(sample)` returns a tensor of complex numbers indexed by frequency bin. 74 | 75 | ```elixir 76 | dft = Nx.fft(sample) 77 | ``` 78 | 79 | We will limit our study points to the first half of `dft`: since our signal is real-valued, the upper half mirrors the lower half (it is its complex conjugate). 80 | The phase doesn't really matter to us because we don't wish to reconstruct the signal, nor find possible discontinuities, 81 | so we'll use `Nx.abs` to obtain the absolute values at each point. 82 | 83 | ```elixir 84 | n = Nx.size(dft) 85 | 86 | max_freq_index = div(n, 2) 87 | 88 | amplitudes = Nx.abs(dft)[0..max_freq_index] 89 | 90 | # the frequency bins, "n" of them spaced by fs/n=1 unit: 91 | frequencies = NxSignal.fft_frequencies(fs, fft_length: n)[0..max_freq_index] 92 | 93 | data1 = %{ 94 | frequencies: Nx.to_list(frequencies), 95 | amplitudes: Nx.to_list(amplitudes) 96 | } 97 | 98 | VegaLite.new(width: 700, height: 300) 99 | |> VegaLite.data_from_values(data1) 100 | |> VegaLite.mark(:bar, tooltip: true) 101 | |> VegaLite.encode_field(:x, "frequencies", 102 | type: :quantitative, 103 | title: "frequency (Hz)", 104 | scale: [domain: [0, 50]] 105 | ) 106 | |> VegaLite.encode_field(:y, "amplitudes", 107 | type: :quantitative, 108 | title: "amplitude", 109 | scale: [domain: [0, 30]] 110 | ) 111 | ``` 112 | 113 | We see the peaks at 5Hz, 20Hz with their amplitudes (the second is half the first). 114 | 115 | This is indeed our synthesized signal 🎉 116 | 117 | We can confirm this visual inspection with a peek into our data. We use the `Nx.top_k` function. 118 | 119 | ```elixir 120 | {values, indices} = Nx.top_k(amplitudes, k: 5) 121 | 122 | { 123 | values, 124 | frequencies[indices] 125 | } 126 | ``` 127 | 128 | ### Visualizing the original signal and the Inverse Discrete Fourier Transform 129 | 130 | Let's visualize our incoming signal over 400ms.
This corresponds to 2 periods of our 5Hz component and 8 periods of our 20Hz component. 131 | 132 | We compute 200 points to have a smooth curve, thus every (400/200=) 2ms. 133 | 134 | We also add the reconstructed signal via the **Inverse Discrete Fourier Transform** available as `Nx.ifft`. 135 | 136 | This gives us 50 values spaced by 1000ms / 50 = 20ms. 137 | 138 | Below, we display them as a bar chart under the line representing the ideal signal. 139 | 140 | ```elixir 141 | #----------- REAL SIGNAL 142 | # compute 200 points of the "real" signal during 2/5=400ms (twice the main period) 143 | 144 | t = Nx.linspace(0, 0.4, n: trunc(0.4 * 500)) 145 | sample = Signal.source(t) 146 | 147 | #----------- RECONSTRUCTED IFFT 148 | yr = Nx.ifft(dft) |> Nx.real() 149 | fs = 50 150 | tr = Nx.linspace(0, 1, n: 1 * fs, endpoint: false) 151 | 152 | idx = Nx.less_equal(tr, 0.4) 153 | xr = Nx.select(idx, tr, :nan) 154 | yr = Nx.select(idx, yr, :nan) 155 | #---------------- 156 | 157 | 158 | data = %{ 159 | x: Nx.to_list(t), 160 | y: Nx.to_list(sample) 161 | } 162 | 163 | data_r = %{ 164 | yr: Nx.to_list(yr), 165 | xr: Nx.to_list(xr) 166 | } 167 | 168 | VegaLite.new(width: 600, height: 300) 169 | |> VegaLite.layers([ 170 | VegaLite.new() 171 | |> VegaLite.data_from_values(data) 172 | |> VegaLite.mark(:line, tooltip: true) 173 | |> VegaLite.encode_field(:x, "x", type: :quantitative, title: "time (s)", scale: [domain: [0, 0.4]]) 174 | |> VegaLite.encode_field(:y, "y", type: :quantitative, title: "signal") 175 | |> VegaLite.encode_field(:order, "x"), 176 | VegaLite.new() 177 | |> VegaLite.data_from_values(data_r) 178 | |> VegaLite.mark(:bar, tooltip: true) 179 | |> VegaLite.encode_field(:x, "xr", type: :quantitative, scale: [domain: [0, 0.4]]) 180 | |> VegaLite.encode_field(:y, "yr", type: :quantitative, title: "reconstructed") 181 | |> VegaLite.encode_field(:order, "xr") 182 | ]) 183 | ``` 184 | 185 | We see that during 400ms, we have 2 periods of a longer period signal, and 8 of
a shorter and smaller perturbation period signal. 186 | -------------------------------------------------------------------------------- /guides/filtering.livemd: -------------------------------------------------------------------------------- 1 | # Filtering 2 | 3 | ```elixir 4 | Mix.install([ 5 | {:nx_signal, "~> 0.2"}, 6 | {:vega_lite, "~> 0.1"}, 7 | {:kino_vega_lite, "~> 0.1"} 8 | ]) 9 | ``` 10 | 11 | ## Prepare the data 12 | 13 | ```elixir 14 | fs = 16.0e3 15 | window_duration_seconds = 100.0e-3 16 | window_length = 2 ** ceil(:math.log2(fs * window_duration_seconds)) 17 | signal_duration_seconds = 3 18 | signal_length = ceil(signal_duration_seconds * fs) 19 | 20 | sin = fn freq, n -> 21 | Nx.sin(Nx.multiply(2 * :math.pi() * freq / fs, n)) 22 | end 23 | 24 | half_n = Nx.iota({div(signal_length, 2)}) 25 | 26 | sin220 = sin.(220, half_n) 27 | sin440 = sin.(440, half_n) 28 | sin1000 = sin.(440 * 5 / 2, half_n) 29 | sin3000 = sin.(220 * 4 / 3 * 4, half_n) 30 | 31 | n = Nx.iota({signal_length}) 32 | t = Nx.divide(n, fs) 33 | data = Nx.concatenate([Nx.add(sin440, sin1000), Nx.add(sin220, sin3000)]) 34 | 35 | # Data for plotting 36 | 37 | slice = (signal_length - 3000)..(signal_length - 2750) 38 | plot_data = %{y: Nx.to_flat_list(data[[slice]]), x: Nx.to_flat_list(t[[slice]])} 39 | ``` 40 | 41 | 42 | 43 | ```elixir 44 | VegaLite.new(width: 600, height: 400, title: "Signal sample") 45 | |> VegaLite.data_from_values(plot_data, only: ["x", "y"]) 46 | |> VegaLite.mark(:line) 47 | |> VegaLite.encode_field(:x, "x", type: :quantitative) 48 | |> VegaLite.encode_field(:y, "y", type: :quantitative) 49 | ``` 50 | 51 | ## Preparing the Filter 52 | 53 | ```elixir 54 | # The filter is the ideal low-pass (sinc) impulse response, tapered by a Hann window 55 | 56 | # Cutoff frequency in Hz 57 | fc = 600 58 | filter_indices = Nx.iota({window_length}) |> Nx.subtract(div(window_length, 2)) 59 | 60 | h_ideal = 61 | Nx.multiply(2 * fc / fs, NxSignal.Filters.sinc(Nx.multiply(2 * fc / fs, filter_indices))) 62
| 63 | window = NxSignal.Windows.hann(n: window_length) 64 | 65 | h = Nx.multiply(h_ideal, window) 66 | 67 | hfft = 68 | h 69 | |> Nx.fft(length: window_length) 70 | |> Nx.abs() 71 | |> Nx.add(1.0e-10) 72 | 73 | hfft_power = 74 | hfft 75 | |> Nx.log() 76 | |> Nx.divide(Nx.log(10)) 77 | |> Nx.multiply(20) 78 | 79 | f_idx = Nx.iota({window_length}) |> Nx.subtract(div(window_length, 2)) 80 | 81 | f = Nx.multiply(f_idx, fs / window_length) 82 | 83 | plot_data = %{ 84 | n: Enum.to_list(0..(window_length - 1)), 85 | h: Nx.to_flat_list(h), 86 | hfft: 87 | Nx.to_flat_list( 88 | Nx.take_along_axis( 89 | hfft_power, 90 | Nx.select(Nx.less(f_idx, 0), Nx.add(f_idx, window_length), f_idx) 91 | ) 92 | ), 93 | f: Nx.to_flat_list(f) 94 | } 95 | ``` 96 | 97 | 98 | 99 | ```elixir 100 | VegaLite.new( 101 | width: 600, 102 | height: 400, 103 | title: "600 Hz Low-Pass filter (Hann window); L = 2048" 104 | ) 105 | |> VegaLite.data_from_values(plot_data, only: ["n", "h"]) 106 | |> VegaLite.mark(:line) 107 | |> VegaLite.encode_field(:x, "n", type: :quantitative) 108 | |> VegaLite.encode_field(:y, "h", type: :quantitative) 109 | ``` 110 | 111 | 112 | 113 | ```elixir 114 | VegaLite.new( 115 | width: 600, 116 | height: 400, 117 | title: "FFT - 600 Hz Low-Pass filter (Hann window); L = 2048" 118 | ) 119 | |> VegaLite.data_from_values(plot_data, only: ["f", "hfft"]) 120 | |> VegaLite.mark(:line) 121 | |> VegaLite.encode_field(:x, "f", type: :quantitative) 122 | |> VegaLite.encode_field(:y, "hfft", type: :quantitative) 123 | ``` 124 | 125 | ## Filtering the data 126 | 127 | ```elixir 128 | # Now that we have our filter, instead of convolving the filter 129 | # with the time representation of the signal, we can multiply each 130 | # STFT frame by the DFT of the filter (represented by hfft) 131 | 132 | {z, t, f} = 133 | NxSignal.stft(data, window, fft_length: window_length, sampling_rate: fs, scaling: :spectrum) 134 | ``` 135 | 136 | ```elixir 137 | # Filter 138 | z_filtered = Nx.multiply(z,
hfft) 139 | 140 | max_f = 141 | Nx.select(Nx.greater_equal(f, fs / 2), Nx.iota(f.shape), Nx.size(f) + 1) 142 | |> Nx.argmin() 143 | |> Nx.to_number() 144 | 145 | spectrogram = z |> Nx.slice([0, 0], [Nx.size(t), max_f]) |> Nx.abs() |> Nx.pow(2) 146 | 147 | filtered_spectrogram = 148 | z_filtered |> Nx.slice([0, 0], [Nx.size(t), max_f]) |> Nx.abs() |> Nx.pow(2) 149 | 150 | # Reconstruct the time signal 151 | data_out = 152 | z_filtered 153 | |> NxSignal.istft(window, fft_length: window_length, scaling: :spectrum, sampling_rate: fs) 154 | |> Nx.as_type(data.type) 155 | ``` 156 | 157 | ```elixir 158 | plot_data = 159 | for t_idx <- 0..(Nx.size(t) - 1), 160 | f_idx <- 0..(max_f - 1)//1, 161 | Nx.to_number(f[[f_idx]]) <= 4000, 162 | reduce: %{"t" => [], "f" => [], "s" => [], "filtered_s" => []} do 163 | %{"t" => t_acc, "f" => f_acc, "s" => s_acc, "filtered_s" => filtered_s_acc} -> 164 | %{ 165 | "t" => [Nx.to_number(t[[t_idx]]) | t_acc], 166 | "f" => [Float.round(Nx.to_number(f[[f_idx]]), 3) | f_acc], 167 | "s" => [Nx.to_number(spectrogram[[t_idx, f_idx]]) | s_acc], 168 | "filtered_s" => [Nx.to_number(filtered_spectrogram[[t_idx, f_idx]]) | filtered_s_acc] 169 | } 170 | end 171 | ``` 172 | 173 | ```elixir 174 | defmodule Spectrogram do 175 | alias VegaLite, as: Vl 176 | 177 | def plot(title, dataset) do 178 | Vl.new(title: title, width: 500, height: 500) 179 | |> Vl.mark(:rect) 180 | |> Vl.data_from_values(dataset) 181 | |> Vl.encode_field(:x, "t", 182 | type: :quantitative, 183 | title: "Time (seconds)", 184 | axis: [tick_min_step: 0.1], 185 | grid: false 186 | ) 187 | |> Vl.encode_field(:y, "f", 188 | type: :quantitative, 189 | sort: "-x", 190 | title: "Frequency (Hz)", 191 | axis: [tick_count: 25], 192 | grid: false 193 | ) 194 | |> Vl.encode_field(:color, "s", 195 | aggregate: :max, 196 | type: :quantitative, 197 | scale: [scheme: "viridis"], 198 | legend: [title: "Power"] 199 | ) 200 | |> Vl.config(view: [stroke: nil]) 201 | end 202 | end 203 | ``` 204 | 205 | ```elixir 206 |
Spectrogram.plot("Spectrogram", Map.take(plot_data, ["t", "f", "s"])) 207 | ``` 208 | 209 | ```elixir 210 | Spectrogram.plot( 211 | "Filtered Spectrogram", 212 | Map.take(plot_data, ["t", "f"]) |> Map.put("s", plot_data["filtered_s"]) 213 | ) 214 | ``` 215 | 216 | ```elixir 217 | plot_data = %{ 218 | y: Nx.to_flat_list(data_out[[slice]]), 219 | x: Nx.to_flat_list(Nx.divide(n, fs)[[slice]]) 220 | } 221 | ``` 222 | 223 | 224 | 225 | ```elixir 226 | VegaLite.new(width: 600, height: 400, title: "Reconstructed Signal") 227 | |> VegaLite.data_from_values(plot_data, only: ["x", "y"]) 228 | |> VegaLite.mark(:line) 229 | |> VegaLite.encode_field(:x, "x", type: :quantitative) 230 | |> VegaLite.encode_field(:y, "y", type: :quantitative) 231 | ``` 232 | 233 | As we can see, the same original slice is now a pure 220Hz sine wave, because we got rid of the higher-frequency (~1173Hz) component, which lies above the 600Hz cutoff. 234 | -------------------------------------------------------------------------------- /guides/spectrogram.livemd: -------------------------------------------------------------------------------- 1 | # Spectrogram Plotting 2 | 3 | ```elixir 4 | Mix.install([ 5 | {:nx_signal, "~> 0.2"}, 6 | {:vega_lite, "~> 0.1.4"}, 7 | {:kino_vega_lite, "~> 0.1.1"} 8 | ]) 9 | ``` 10 | 11 | ## Generating the audio data 12 | 13 | ```elixir 14 | # You can load an audio file here.
For this example, 15 | # we're producing 3 seconds of 220Hz, 440Hz, 1kHz and 3kHz sine waves 16 | 17 | fs = 44.1e3 18 | t_max = 3 19 | 20 | full_n = ceil(fs * t_max) 21 | half_n = div(full_n, 2) 22 | 23 | # samples/sec * sec = samples 24 | n = Nx.iota({half_n}) 25 | 26 | sin = fn freq, n -> 27 | Nx.sin(Nx.multiply(2 * :math.pi() * freq / fs, n)) 28 | end 29 | 30 | sin220 = sin.(220, n) 31 | sin440 = sin.(440, n) 32 | sin1000 = sin.(1000, n) 33 | sin3000 = sin.(3000, n) 34 | 35 | data = Nx.concatenate([Nx.add(sin440, sin1000), Nx.add(sin220, sin3000)]) 36 | n = Nx.iota({full_n}) 37 | 38 | d = %{data: Nx.to_flat_list(data[[1000..1250]]), n: Nx.to_flat_list(n[[1000..1250]])} 39 | ``` 40 | 41 | 42 | 43 | ```elixir 44 | VegaLite.new(width: 600, height: 600, title: "Audio Sample") 45 | |> VegaLite.data_from_values(d, only: ["n", "data"]) 46 | |> VegaLite.mark(:line) 47 | |> VegaLite.encode_field(:x, "n", type: :ordinal) 48 | |> VegaLite.encode_field(:y, "data", type: :quantitative) 49 | ``` 50 | 51 | ```elixir 52 | defmodule Spectrogram do 53 | alias VegaLite, as: Vl 54 | import Nx.Defn 55 | 56 | def calculate_stft_and_plot_spectrogram( 57 | input, 58 | fs, 59 | window_duration_seconds, 60 | plot_cutoff_frequency \\ 4000 61 | ) do 62 | n_window = ceil(fs * window_duration_seconds) 63 | {spectrogram, f, t, max_f} = stft(input, fs: fs, n_window: n_window) 64 | 65 | max_f = Nx.to_number(max_f) 66 | spectrogram = Nx.slice(spectrogram, [0, 0], [Nx.size(t), max_f]) 67 | f = Nx.slice(f, [0], [max_f]) 68 | 69 | spectrogram 70 | |> to_plot_data(f, t, plot_cutoff_frequency) 71 | |> plot() 72 | end 73 | 74 | defn stft(input, opts) do 75 | fs = opts[:fs] 76 | n_window = opts[:n_window] 77 | 78 | # periodic Hann window spanning n_window samples 79 | window = NxSignal.Windows.hann(n: n_window, is_periodic: true) 80 | 81 | # default 50% overlap; :power_of_two zero-pads each frame instead of truncating it 82 | {s, t, f} = NxSignal.stft(input, window, sampling_rate: fs, fft_length: :power_of_two) 83 | 84 | max_f = 85 | Nx.select(f >= fs / 2, Nx.iota(f.shape), Nx.size(f) + 1) 86 | |>
Nx.argmin() 87 | 88 | spectrogram = Nx.abs(s) 89 | # to dBFS 90 | spectrogram = 20 * Nx.log(spectrogram / Nx.reduce_max(spectrogram)) / Nx.log(10) 91 | 92 | {spectrogram, f, t, max_f} 93 | end 94 | 95 | defp to_plot_data(s, f, t, plot_cutoff_frequency) do 96 | for t_idx <- 0..(Nx.size(t) - 1), 97 | f_idx <- 0..(Nx.size(f) - 1), 98 | Nx.to_number(f[[f_idx]]) <= plot_cutoff_frequency, 99 | reduce: %{"t" => [], "f" => [], "s" => []} do 100 | %{"t" => t_acc, "f" => f_acc, "s" => s_acc} -> 101 | %{ 102 | "t" => [Nx.to_number(t[[t_idx]]) | t_acc], 103 | "f" => [Float.round(Nx.to_number(f[[f_idx]]), 3) | f_acc], 104 | "s" => [Nx.to_number(s[[t_idx, f_idx]]) | s_acc] 105 | } 106 | end 107 | end 108 | 109 | defp plot(dataset) do 110 | Vl.new(title: "Spectrogram", width: 500, height: 500) 111 | |> Vl.mark(:rect) 112 | |> Vl.data_from_values(dataset) 113 | |> Vl.encode_field(:x, "t", 114 | type: :quantitative, 115 | title: "Time (seconds)", 116 | axis: [tick_min_step: 0.1], 117 | grid: false 118 | ) 119 | |> Vl.encode_field(:y, "f", 120 | type: :quantitative, 121 | sort: "-x", 122 | title: "Frequency (Hz)", 123 | axis: [tick_count: 25], 124 | grid: false 125 | ) 126 | |> Vl.encode_field(:color, "s", 127 | aggregate: :max, 128 | type: :quantitative, 129 | scale: [scheme: "viridis"], 130 | legend: [title: "dBFS"] 131 | ) 132 | |> Vl.config(view: [stroke: nil]) 133 | end 134 | end 135 | ``` 136 | 137 | ```elixir 138 | Spectrogram.calculate_stft_and_plot_spectrogram(data, fs, 50.0e-3) 139 | ``` 140 | 141 | Notice how the first half of the spectrogram looks cleaner than the second one. This is due to the window length, which determines the trade-off between our time and frequency resolutions. 142 | 143 | Below we can see what happens if we use different window durations (150ms, 100ms and 25ms respectively).
144 | 145 | ```elixir 146 | Spectrogram.calculate_stft_and_plot_spectrogram(data, fs, 150.0e-3) 147 | ``` 148 | 149 | ```elixir 150 | Spectrogram.calculate_stft_and_plot_spectrogram(data, fs, 100.0e-3) 151 | ``` 152 | 153 | ```elixir 154 | Spectrogram.calculate_stft_and_plot_spectrogram(data, fs, 25.0e-3) 155 | ``` 156 | -------------------------------------------------------------------------------- /lib/nx_signal.ex: -------------------------------------------------------------------------------- 1 | defmodule NxSignal do 2 | @moduledoc """ 3 | Nx library extension for digital signal processing. 4 | """ 5 | 6 | import Nx.Defn 7 | 8 | @doc ~S""" 9 | Computes the Short-Time Fourier Transform of a tensor. 10 | 11 | Returns the complex spectrum Z, the time in seconds for 12 | each frame and the frequency bins in Hz. 13 | 14 | The STFT is parameterized through: 15 | 16 | * $k$: length of the Discrete Fourier Transform (DFT) 17 | * $N$: length of each frame 18 | * $H$: hop (in samples) between frames (calculated as $H = N - \text{overlap\\_length}$) 19 | * $M$: number of frames 20 | * $x[n]$: the input time-domain signal 21 | * $w[n]$: the window function to be applied to each frame 22 | 23 | $$ 24 | DFT(x, w) := \sum_{n=0}^{N - 1} x[n]w[n]e^\frac{-2 \pi i k n}{N} \\\\ 25 | X[m, k] = DFT(x[mH..(mH + N - 1)], w) 26 | $$ 27 | 28 | where $m$ assumes all values in the interval $[0, M - 1]$ 29 | 30 | See also: `NxSignal.Windows`, `istft/3`, `stft_to_mel/3` 31 | 32 | ## Options 33 | 34 | * `:sampling_rate` - the sampling frequency $F_s$ for the input in Hz. Defaults to `1000`. 35 | * `:fft_length` - the DFT length that will be passed to `Nx.fft/2`. Defaults to `:power_of_two`. 36 | * `:overlap_length` - the number of samples for the overlap between frames. 37 | Defaults to half the window size. 38 | * `:window_padding` - `:reflect`, `:zeros` or `nil`. See `as_windowed/3` for more details. 39 | * `:scaling` - `nil`, `:spectrum` or `:psd`. 
40 | * `:spectrum` - each frame is divided by $\sum_{i} window[i]$. 41 | * `nil` - No scaling is applied. 42 | * `:psd` - each frame is divided by $\sqrt{F\_s\sum_{i} window[i]^2}$. 43 | 44 | ## Examples 45 | 46 | iex> {z, t, f} = NxSignal.stft(Nx.iota({4}), NxSignal.Windows.rectangular(2), overlap_length: 1, fft_length: 2, sampling_rate: 400) 47 | iex> z 48 | #Nx.Tensor< 49 | c64[frames: 3][frequencies: 2] 50 | [ 51 | [1.0+0.0i, -1.0+0.0i], 52 | [3.0+0.0i, -1.0+0.0i], 53 | [5.0+0.0i, -1.0+0.0i] 54 | ] 55 | > 56 | iex> t 57 | #Nx.Tensor< 58 | f32[frames: 3] 59 | [0.0024999999441206455, 0.004999999888241291, 0.007499999832361937] 60 | > 61 | iex> f 62 | #Nx.Tensor< 63 | f32[frequencies: 2] 64 | [0.0, 200.0] 65 | > 66 | """ 67 | @doc type: :time_frequency 68 | deftransform stft(data, window, opts \\ []) do 69 | {frame_length} = Nx.shape(window) 70 | 71 | opts = 72 | Keyword.validate!(opts, [ 73 | :overlap_length, 74 | :window, 75 | :scaling, 76 | window_padding: :valid, 77 | sampling_rate: 100, 78 | fft_length: :power_of_two 79 | ]) 80 | 81 | sampling_rate = opts[:sampling_rate] || raise ArgumentError, "missing sampling_rate option" 82 | 83 | overlap_length = opts[:overlap_length] || div(frame_length, 2) 84 | 85 | stft_n(data, window, sampling_rate, Keyword.put(opts, :overlap_length, overlap_length)) 86 | end 87 | 88 | defnp stft_n(data, window, sampling_rate, opts) do 89 | {frame_length} = Nx.shape(window) 90 | padding = opts[:window_padding] 91 | fft_length = opts[:fft_length] 92 | overlap_length = opts[:overlap_length] 93 | 94 | spectrum = 95 | data 96 | |> as_windowed( 97 | padding: padding, 98 | window_length: frame_length, 99 | stride: frame_length - overlap_length 100 | ) 101 | |> Nx.multiply(window) 102 | |> Nx.fft(length: fft_length) 103 | 104 | {num_frames, fft_length} = Nx.shape(spectrum) 105 | 106 | frequencies = fft_frequencies(sampling_rate, fft_length: fft_length) 107 | 108 | # assign the middle of the equivalent time window as the time for the given 
frame 109 | time_step = frame_length / (2 * sampling_rate) 110 | last_frame = time_step * num_frames 111 | times = Nx.linspace(time_step, last_frame, n: num_frames, name: :frames) 112 | 113 | output = 114 | case opts[:scaling] do 115 | :spectrum -> 116 | spectrum / Nx.sum(window) 117 | 118 | :psd -> 119 | spectrum / Nx.sqrt(sampling_rate * Nx.sum(window ** 2)) 120 | 121 | nil -> 122 | spectrum 123 | 124 | scaling -> 125 | raise ArgumentError, 126 | "invalid :scaling, expected one of :spectrum, :psd or nil, got: #{inspect(scaling)}" 127 | end 128 | 129 | {Nx.reshape(output, spectrum.shape, names: [:frames, :frequencies]), times, frequencies} 130 | end 131 | 132 | @doc """ 133 | Computes the frequency bins for a FFT with given options. 134 | 135 | ## Arguments 136 | 137 | * `sampling_rate` - Sampling frequency in Hz. 138 | 139 | ## Options 140 | 141 | * `:fft_length` - Number of FFT frequency bins. 142 | * `:type` - Optional output type. Defaults to `{:f, 32}` 143 | * `:name` - Optional axis name for the tensor. Defaults to `:frequencies` 144 | 145 | ## Examples 146 | 147 | iex> NxSignal.fft_frequencies(1.6e4, fft_length: 10) 148 | #Nx.Tensor< 149 | f32[frequencies: 10] 150 | [0.0, 1.6e3, 3.2e3, 4.8e3, 6.4e3, 8.0e3, 9.6e3, 1.12e4, 1.28e4, 1.44e4] 151 | > 152 | """ 153 | @doc type: :time_frequency 154 | defn fft_frequencies(sampling_rate, opts \\ []) do 155 | opts = keyword!(opts, [:fft_length, type: {:f, 32}, name: :frequencies, endpoint: false]) 156 | fft_length = opts[:fft_length] 157 | 158 | step = sampling_rate / fft_length 159 | 160 | Nx.linspace(0, step * fft_length, 161 | n: fft_length, 162 | type: opts[:type], 163 | name: opts[:name], 164 | endpoint: opts[:endpoint] 165 | ) 166 | end 167 | 168 | @doc """ 169 | Returns a tensor of K windows of length N 170 | 171 | ## Options 172 | 173 | * `:window_length` - the number of samples in a window 174 | * `:stride` - The number of samples to skip between windows. Defaults to `1`. 
175 | * `:padding` - Padding mode, can be `:reflect` or a valid padding as per `Nx.pad/3` over the 176 | input tensor's shape. Defaults to `:valid`. If `:reflect` or `:same`, the first window will be centered 177 | at the start of the signal. The padding is applied for the whole input, rather than individual 178 | windows. For `:zeros`, effectively each incomplete window will be zero-padded. 179 | 180 | ## Examples 181 | 182 | iex> NxSignal.as_windowed(Nx.tensor([0, 1, 2, 3, 4, 10, 11, 12]), window_length: 4) 183 | #Nx.Tensor< 184 | s32[5][4] 185 | [ 186 | [0, 1, 2, 3], 187 | [1, 2, 3, 4], 188 | [2, 3, 4, 10], 189 | [3, 4, 10, 11], 190 | [4, 10, 11, 12] 191 | ] 192 | > 193 | 194 | iex> NxSignal.as_windowed(Nx.tensor([0, 1, 2, 3, 4, 10, 11, 12]), window_length: 3) 195 | #Nx.Tensor< 196 | s32[6][3] 197 | [ 198 | [0, 1, 2], 199 | [1, 2, 3], 200 | [2, 3, 4], 201 | [3, 4, 10], 202 | [4, 10, 11], 203 | [10, 11, 12] 204 | ] 205 | > 206 | 207 | iex> NxSignal.as_windowed(Nx.tensor([0, 1, 2, 3, 4, 10, 11]), window_length: 2, stride: 2, padding: [{0, 3}]) 208 | #Nx.Tensor< 209 | s32[5][2] 210 | [ 211 | [0, 1], 212 | [2, 3], 213 | [4, 10], 214 | [11, 0], 215 | [0, 0] 216 | ] 217 | > 218 | 219 | iex> t = Nx.iota({7}); 220 | iex> NxSignal.as_windowed(t, window_length: 6, padding: :reflect, stride: 1) 221 | #Nx.Tensor< 222 | s32[8][6] 223 | [ 224 | [3, 2, 1, 0, 1, 2], 225 | [2, 1, 0, 1, 2, 3], 226 | [1, 0, 1, 2, 3, 4], 227 | [0, 1, 2, 3, 4, 5], 228 | [1, 2, 3, 4, 5, 6], 229 | [2, 3, 4, 5, 6, 5], 230 | [3, 4, 5, 6, 5, 4], 231 | [4, 5, 6, 5, 4, 3] 232 | ] 233 | > 234 | 235 | iex> NxSignal.as_windowed(Nx.iota({10}), window_length: 6, padding: :reflect, stride: 2) 236 | #Nx.Tensor< 237 | s32[6][6] 238 | [ 239 | [3, 2, 1, 0, 1, 2], 240 | [1, 0, 1, 2, 3, 4], 241 | [1, 2, 3, 4, 5, 6], 242 | [3, 4, 5, 6, 7, 8], 243 | [5, 6, 7, 8, 9, 8], 244 | [7, 8, 9, 8, 7, 6] 245 | ] 246 | > 247 | """ 248 | @doc type: :windowing 249 | deftransform as_windowed(tensor, opts \\ []) do 250 | if 
opts[:padding] == :reflect do 251 | as_windowed_reflect_padding(tensor, opts) 252 | else 253 | as_windowed_non_reflect_padding(tensor, opts) 254 | end 255 | end 256 | 257 | deftransformp as_windowed_parse_reflect_opts(shape, opts) do 258 | window_length = opts[:window_length] 259 | 260 | as_windowed_parse_non_reflect_opts( 261 | shape, 262 | Keyword.put(opts, :padding, [{div(window_length, 2), div(window_length, 2)}]) 263 | ) 264 | end 265 | 266 | deftransformp as_windowed_parse_non_reflect_opts(shape, opts) do 267 | opts = Keyword.validate!(opts, [:window_length, padding: :valid, stride: 1]) 268 | window_length = opts[:window_length] 269 | window_dimensions = {window_length} 270 | 271 | padding = opts[:padding] 272 | 273 | [stride] = 274 | strides = 275 | case opts[:stride] do 276 | stride when is_list(stride) -> 277 | stride 278 | 279 | stride when is_integer(stride) and stride >= 1 -> 280 | [stride] 281 | 282 | stride -> 283 | raise ArgumentError, 284 | "expected an integer >= 1 or a list of integers, got: #{inspect(stride)}" 285 | end 286 | 287 | padding_config = as_windowed_to_padding_config(shape, window_dimensions, padding) 288 | 289 | # trick so that we can get Nx to calculate the pooled shape for us 290 | %{shape: pooled_shape} = 291 | Nx.window_max( 292 | Nx.iota(shape, backend: Nx.Defn.Expr), 293 | window_dimensions, 294 | padding: padding, 295 | strides: strides 296 | ) 297 | 298 | output_shape = {Tuple.product(pooled_shape), window_length} 299 | 300 | {window_length, stride, padding_config, output_shape} 301 | end 302 | 303 | defp as_windowed_to_padding_config(shape, kernel_size, mode) do 304 | case mode do 305 | :valid -> 306 | List.duplicate({0, 0, 0}, tuple_size(shape)) 307 | 308 | :same -> 309 | Enum.zip_with(Tuple.to_list(shape), Tuple.to_list(kernel_size), fn dim, k -> 310 | padding_size = max(dim - 1 + k - dim, 0) 311 | {floor(padding_size / 2), ceil(padding_size / 2), 0} 312 | end) 313 | 314 | config when is_list(config) -> 315 | 
315 | Enum.map(config, fn 316 | {x, y} when is_integer(x) and is_integer(y) -> 317 | {x, y, 0} 318 | 319 | _other -> 320 | raise ArgumentError, 321 | "padding must be a list of {low, high} tuples, where each element is an integer. " <> 322 | "Got: #{inspect(config)}" 323 | end) 324 | 325 | mode -> 326 | raise ArgumentError, 327 | "invalid padding mode specified, padding must be one" <> 328 | " of :valid, :same, or a padding configuration, got:" <> 329 | " #{inspect(mode)}" 330 | end 331 | end 332 | 333 | defnp as_windowed_non_reflect_padding(tensor, opts \\ []) do 334 | # current implementation only supports windowing 1D tensors 335 | {window_length, stride, padding, output_shape} = 336 | as_windowed_parse_non_reflect_opts(Nx.shape(tensor), opts) 337 | 338 | tensor = Nx.pad(tensor, 0, padding) 339 | 340 | as_windowed_apply(tensor, stride, output_shape, window_length) 341 | end 342 | 343 | defnp as_windowed_reflect_padding(tensor, opts \\ []) do 344 | # current implementation only supports windowing 1D tensors 345 | {window_length, stride, _padding, output_shape} = 346 | as_windowed_parse_reflect_opts(Nx.shape(tensor), opts) 347 | 348 | half_window = div(window_length, 2) 349 | tensor = Nx.reflect(tensor, padding_config: [{half_window, half_window}]) 350 | 351 | as_windowed_apply(tensor, stride, output_shape, window_length) 352 | end 353 | 354 | defnp as_windowed_apply(tensor, stride, output_shape, window_length) do 355 | {num_windows, _} = output_shape 356 | 357 | window_start = Nx.iota({num_windows}) * stride 358 | window_start = Nx.vectorize(window_start, :window) 359 | output = Nx.slice(tensor, [window_start], [window_length]) 360 | 361 | output 362 | |> Nx.devectorize(keep_names: false) 363 | |> Nx.vectorize(tensor.vectorized_axes) 364 | end 365 | 366 | @doc """ 367 | Generates weights for converting an STFT representation into MEL-scale.
368 | 369 | See also: `stft/3`, `istft/3`, `stft_to_mel/3` 370 | 371 | ## Arguments 372 | 373 | * `fft_length` - Number of FFT bins 374 | * `mel_bins` - Number of target MEL bins 375 | * `sampling_rate` - Sampling frequency in Hz 376 | 377 | ## Options 378 | * `:max_mel` - the pitch for the last MEL bin before log scaling. Defaults to 3016 379 | * `:mel_frequency_spacing` - the distance in Hz between two MEL bins before log scaling. Defaults to 66.6 380 | * `:type` - Target output type. Defaults to `{:f, 32}` 381 | 382 | ## Examples 383 | 384 | iex> NxSignal.mel_filters(10, 5, 8.0e3) 385 | #Nx.Tensor< 386 | f32[mels: 5][frequencies: 10] 387 | [ 388 | [0.0, 8.129207999445498e-4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 389 | [0.0, 9.972016559913754e-4, 2.1870288765057921e-4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 390 | [0.0, 0.0, 9.510891977697611e-4, 4.150509194005281e-4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 391 | [0.0, 0.0, 0.0, 4.035891906823963e-4, 5.276656011119485e-4, 2.574124082457274e-4, 0.0, 0.0, 0.0, 0.0], 392 | [0.0, 0.0, 0.0, 0.0, 7.329034269787371e-5, 2.342205698369071e-4, 3.8295105332508683e-4, 2.8712040511891246e-4, 1.9128978601656854e-4, 9.545915963826701e-5] 393 | ] 394 | > 395 | """ 396 | @doc type: :time_frequency 397 | deftransform mel_filters(fft_length, mel_bins, sampling_rate, opts \\ []) do 398 | opts = 399 | Keyword.validate!(opts, 400 | max_mel: 3016, 401 | mel_frequency_spacing: 200 / 3, 402 | type: {:f, 32} 403 | ) 404 | 405 | mel_filters_n(sampling_rate, opts[:max_mel], opts[:mel_frequency_spacing], 406 | type: opts[:type], 407 | fft_length: fft_length, 408 | mel_bins: mel_bins 409 | ) 410 | end 411 | 412 | defnp mel_filters_n(sampling_rate, max_mel, f_sp, opts) do 413 | fft_length = opts[:fft_length] 414 | mel_bins = opts[:mel_bins] 415 | type = opts[:type] 416 | 417 | fftfreqs = fft_frequencies(sampling_rate, type: type, fft_length: fft_length) 418 | 419 | mels = Nx.linspace(0, max_mel / f_sp, type: type, n: mel_bins + 2, name: :mels) 420 | 
freqs = f_sp * mels 421 | 422 | min_log_hz = 1_000 423 | min_log_mel = min_log_hz / f_sp 424 | 425 | # numpy uses the f64 value by default 426 | logstep = Nx.log(6.4) / 27 427 | 428 | log_t = mels >= min_log_mel 429 | 430 | # This is the same as freqs[log_t] = min_log_hz * Nx.exp(logstep * (mels[log_t] - min_log_mel)) 431 | # notice that since freqs and mels are indexed by the same conditional tensor, we don't 432 | # need to slice either of them 433 | mel_f = Nx.select(log_t, min_log_hz * Nx.exp(logstep * (mels - min_log_mel)), freqs) 434 | 435 | fdiff = Nx.new_axis(mel_f[1..-1//1] - mel_f[0..-2//1], 1) 436 | ramps = Nx.new_axis(mel_f, 1) - fftfreqs 437 | 438 | lower = -ramps[0..(mel_bins - 1)] / fdiff[0..(mel_bins - 1)] 439 | upper = ramps[2..(mel_bins + 1)//1] / fdiff[1..mel_bins] 440 | weights = Nx.max(0, Nx.min(lower, upper)) 441 | 442 | enorm = 2.0 / (mel_f[2..(mel_bins + 1)] - mel_f[0..(mel_bins - 1)]) 443 | 444 | weights * Nx.new_axis(enorm, 1) 445 | end 446 | 447 | @doc """ 448 | Converts a given STFT time-frequency spectrum into a MEL-scale time-frequency spectrum. 449 | 450 | See also: `stft/3`, `istft/3`, `mel_filters/4` 451 | 452 | ## Arguments 453 | 454 | * `z` - STFT spectrum 455 | * `sampling_rate` - Sampling frequency in Hz 456 | 457 | ## Options 458 | 459 | * `:fft_length` - Number of FFT bins 460 | * `:mel_bins` - Number of target MEL bins. Defaults to 128 461 | * `:type` - Target output type. 
Defaults to `{:f, 32}` 462 | 463 | ## Examples 464 | 465 | iex> fft_length = 16 466 | iex> sampling_rate = 8.0e3 467 | iex> {z, _, _} = NxSignal.stft(Nx.iota({10}), NxSignal.Windows.hann(4), overlap_length: 2, fft_length: fft_length, sampling_rate: sampling_rate, window_padding: :reflect) 468 | iex> Nx.axis_size(z, :frequencies) 469 | 16 470 | iex> Nx.axis_size(z, :frames) 471 | 6 472 | iex> NxSignal.stft_to_mel(z, sampling_rate, fft_length: fft_length, mel_bins: 4) 473 | #Nx.Tensor< 474 | f32[frames: 6][mel: 4] 475 | [ 476 | [0.2900530695915222, 0.17422175407409668, 0.18422472476959229, 0.09807997941970825], 477 | [0.6093881130218506, 0.5647397041320801, 0.4353824257850647, 0.08635270595550537], 478 | [0.7584103345870972, 0.7085014581680298, 0.5636920928955078, 0.179118812084198], 479 | [0.8461772203445435, 0.7952491044998169, 0.6470762491226196, 0.2520409822463989], 480 | [0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721], 481 | [0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721] 482 | ] 483 | > 484 | """ 485 | @doc type: :time_frequency 486 | defn stft_to_mel(z, sampling_rate, opts \\ []) do 487 | opts = 488 | keyword!(opts, [:fft_length, :mel_bins, :max_mel, :mel_frequency_spacing, type: {:f, 32}]) 489 | 490 | magnitudes = Nx.abs(z) ** 2 491 | 492 | filters = 493 | mel_filters(opts[:fft_length], opts[:mel_bins], sampling_rate, mel_filters_opts(opts)) 494 | 495 | freq_size = div(opts[:fft_length], 2) 496 | 497 | real_freqs_mag = Nx.slice_along_axis(magnitudes, 0, freq_size, axis: :frequencies) 498 | real_freqs_filters = Nx.slice_along_axis(filters, 0, freq_size, axis: :frequencies) 499 | 500 | mel_spec = 501 | Nx.dot( 502 | real_freqs_mag, 503 | [:frequencies], 504 | real_freqs_filters, 505 | [:frequencies] 506 | ) 507 | 508 | mel_spec = Nx.reshape(mel_spec, Nx.shape(mel_spec), names: [:frames, :mel]) 509 | 510 | log_spec = Nx.log(Nx.clip(mel_spec, 1.0e-10, :infinity)) / Nx.log(10) 511 | log_spec = 
Nx.max(log_spec, Nx.reduce_max(log_spec) - 8) 512 | (log_spec + 4) / 4 513 | end 514 | 515 | deftransformp mel_filters_opts(opts) do 516 | Keyword.take(opts, [:max_mel, :mel_frequency_spacing, :type]) 517 | end 518 | 519 | @doc ~S""" 520 | Computes the Inverse Short-Time Fourier Transform of a tensor. 521 | 522 | Returns a tensor of M time-domain frames of length `fft_length`. 523 | 524 | See also: `NxSignal.Windows`, `stft/3` 525 | 526 | ## Options 527 | 528 | * `:fft_length` - the DFT length that will be passed to `Nx.fft/2`. Defaults to `:power_of_two`. 529 | * `:overlap_length` - the number of samples for the overlap between frames. 530 | Defaults to half the window size. 531 | * `:sampling_rate` - the sampling rate $F_s$ in Hz. Defaults to `1000`. 532 | * `:scaling` - `nil`, `:spectrum` or `:psd`. 533 | * `:spectrum` - each frame is multiplied by $\sum_{i} window[i]$. 534 | * `nil` - No scaling is applied. 535 | * `:psd` - each frame is multiplied by $\sqrt{F\_s\sum_{i} window[i]^2}$. 536 | 537 | ## Examples 538 | 539 | In general, `istft/3` takes in the same parameters and window as the `stft/3` that generated the spectrum. 540 | In the first example, we can notice that the reconstruction is mostly perfect, aside from the first sample. 541 | 542 | This is because the Hann window only ensures perfect reconstruction in overlapping regions, so the edges 543 | of the signal end up being distorted. 544 | 545 | iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20]) 546 | iex> w = NxSignal.Windows.hann(4) 547 | iex> opts = [sampling_rate: 1, fft_length: 4] 548 | iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts) 549 | iex> result = NxSignal.istft(z, w, opts) 550 | iex> Nx.as_type(result, Nx.type(t)) 551 | #Nx.Tensor< 552 | s32[8] 553 | [0, 10, 1, 0, 10, 10, 2, 20] 554 | > 555 | 556 | Different scaling options are available (see `stft/3` for a more detailed explanation). 
557 | For perfect reconstruction, you want to use the same scaling as the STFT: 558 | 559 | iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20]) 560 | iex> w = NxSignal.Windows.hann(4) 561 | iex> opts = [scaling: :spectrum, sampling_rate: 1, fft_length: 4] 562 | iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts) 563 | iex> result = NxSignal.istft(z, w, opts) 564 | iex> Nx.as_type(result, Nx.type(t)) 565 | #Nx.Tensor< 566 | s32[8] 567 | [0, 10, 1, 0, 10, 10, 2, 20] 568 | > 569 | 570 | iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20], type: :f32) 571 | iex> w = NxSignal.Windows.hann(4) 572 | iex> opts = [scaling: :psd, sampling_rate: 1, fft_length: 4] 573 | iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts) 574 | iex> result = NxSignal.istft(z, w, opts) 575 | iex> Nx.as_type(result, Nx.type(t)) 576 | #Nx.Tensor< 577 | f32[8] 578 | [0.0, 10.0, 0.9999999403953552, -2.1900146407460852e-7, 10.0, 10.0, 2.000000238418579, 20.0] 579 | > 580 | """ 581 | @doc type: :time_frequency 582 | defn istft(data, window, opts) do 583 | opts = keyword!(opts, [:fft_length, :overlap_length, :scaling, sampling_rate: 1000]) 584 | 585 | fft_length = 586 | case opts[:fft_length] do 587 | nil -> 588 | :power_of_two 589 | 590 | fft_length -> 591 | fft_length 592 | end 593 | 594 | overlap_length = 595 | case opts[:overlap_length] do 596 | nil -> 597 | div(Nx.size(window), 2) 598 | 599 | overlap_length -> 600 | overlap_length 601 | end 602 | 603 | sampling_rate = 604 | case {opts[:scaling], opts[:sampling_rate]} do 605 | {:psd, nil} -> raise ArgumentError, ":sampling_rate is mandatory if scaling is :psd" 606 | {_, sampling_rate} -> sampling_rate 607 | end 608 | 609 | frames = Nx.ifft(data, length: fft_length) 610 | 611 | frames_rescaled = 612 | case opts[:scaling] do 613 | :spectrum -> 614 | frames * Nx.sum(window) 615 | 616 | :psd -> 617 | frames * Nx.sqrt(sampling_rate * Nx.sum(window ** 2)) 618 | 619 | nil -> 620 | frames 621 | 622 | scaling -> 623 | raise ArgumentError, 624 | "invalid 
:scaling, expected one of :spectrum, :psd or nil, got: #{inspect(scaling)}" 625 | end 626 | 627 | result_non_normalized = 628 | overlap_and_add(frames_rescaled * window, overlap_length: overlap_length) 629 | 630 | normalization_factor = 631 | overlap_and_add(Nx.broadcast(Nx.abs(window) ** 2, data.shape), 632 | overlap_length: overlap_length 633 | ) 634 | 635 | normalization_factor = Nx.select(normalization_factor > 1.0e-10, normalization_factor, 1.0) 636 | 637 | result_non_normalized / normalization_factor 638 | end 639 | 640 | @doc """ 641 | Performs the overlap-and-add algorithm over 642 | an {..., M, N}-shaped tensor, where M is the number of 643 | windows and N is the window size. 644 | 645 | The tensor is zero-padded on the right so 646 | the last window fully appears in the result. 647 | 648 | ## Options 649 | 650 | * `:overlap_length` - The number of overlapping samples between windows 651 | * `:type` - output type for casting the accumulated result. 652 | If not given, defaults to the type of the input tensor.
653 | 654 | ## Examples 655 | 656 | iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_length: 0) 657 | #Nx.Tensor< 658 | s32[12] 659 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] 660 | > 661 | 662 | iex> NxSignal.overlap_and_add(Nx.iota({3, 4}), overlap_length: 3) 663 | #Nx.Tensor< 664 | s32[6] 665 | [0, 5, 15, 18, 17, 11] 666 | > 667 | 668 | iex> t = Nx.tensor([[[[0, 1, 2, 3], [4, 5, 6, 7]]], [[[10, 11, 12, 13], [14, 15, 16, 17]]]]) |> Nx.vectorize(x: 2, y: 1) 669 | iex> NxSignal.overlap_and_add(t, overlap_length: 3) 670 | #Nx.Tensor< 671 | vectorized[x: 2][y: 1] 672 | s32[5] 673 | [ 674 | [ 675 | [0, 5, 7, 9, 7] 676 | ], 677 | [ 678 | [10, 25, 27, 29, 17] 679 | ] 680 | ] 681 | > 682 | """ 683 | @doc type: :windowing 684 | defn overlap_and_add(tensor, opts \\ []) do 685 | opts = keyword!(opts, [:overlap_length, type: Nx.type(tensor)]) 686 | overlap_length = opts[:overlap_length] 687 | 688 | %{vectorized_axes: vectorized_axes, shape: input_shape} = tensor 689 | num_windows = Nx.axis_size(tensor, -2) 690 | window_length = Nx.axis_size(tensor, -1) 691 | 692 | if overlap_length >= window_length do 693 | raise ArgumentError, 694 | "overlap_length must be a number less than the window size #{window_length}, got: #{inspect(overlap_length)}" 695 | end 696 | 697 | tensor = 698 | Nx.revectorize(tensor, [condensed_vectors: :auto, windows: num_windows], 699 | target_shape: {window_length} 700 | ) 701 | 702 | stride = window_length - overlap_length 703 | output_holder_shape = {num_windows * stride + overlap_length} 704 | 705 | out = 706 | Nx.broadcast( 707 | Nx.tensor(0, type: tensor.type), 708 | output_holder_shape 709 | ) 710 | 711 | idx_template = Nx.iota({window_length, 1}, vectorized_axes: [windows: 1]) 712 | i = Nx.iota({num_windows}) |> Nx.vectorize(:windows) 713 | idx = idx_template + i * stride 714 | 715 | [%{vectorized_axes: [condensed_vectors: n, windows: _]} = tensor, idx] = 716 | Nx.broadcast_vectors([tensor, idx]) 717 | 718 | tensor = Nx.revectorize(tensor,
[condensed_vectors: n], target_shape: {:auto}) 719 | idx = Nx.revectorize(idx, [condensed_vectors: n], target_shape: {:auto, 1}) 720 | 721 | out_shape = overlap_and_add_output_shape(out.shape, input_shape) 722 | 723 | out 724 | |> Nx.indexed_add(idx, tensor) 725 | |> Nx.as_type(opts[:type]) 726 | |> Nx.revectorize(vectorized_axes, target_shape: out_shape) 727 | end 728 | 729 | deftransformp overlap_and_add_output_shape({out_len}, in_shape) do 730 | idx = tuple_size(in_shape) - 2 731 | 732 | in_shape 733 | |> Tuple.delete_at(idx) 734 | |> put_elem(idx, out_len) 735 | end 736 | end 737 | -------------------------------------------------------------------------------- /lib/nx_signal/convolution.ex: -------------------------------------------------------------------------------- 1 | defmodule NxSignal.Convolution do 2 | @moduledoc """ 3 | Convolution functions through various methods. 4 | 5 | Follows the `scipy.signal` conventions. 6 | """ 7 | 8 | import Nx.Defn 9 | import NxSignal.Transforms 10 | 11 | @doc """ 12 | Computes the convolution of two tensors. 13 | 14 | Given $f[n]$ of length $N$ and $k[n]$ of length ${K}$, we define the convolution $g[n] = (f * k)[n]$ by 15 | 16 | $$ 17 | g[n] = (f * k)[n] = \\sum_{m=0}^{K-1} f[n-m]k[m], 18 | $$ 19 | 20 | where $f[n]$ and $k[n]$ are assumed to be zero outside of their definition boundaries. 21 | 22 | ## Options 23 | 24 | * `:method` - One of `:fft` or `:direct`. Defaults to `:direct`. 25 | * `:mode` - One of `:full`, `:valid`, or `:same`. Defaults to `:full`. 26 | * `:full` returns all $N + K - 1$ samples. 27 | * `:same` returns the center $N$ samples. 28 | * `:valid` returns the center $N - K + 1$ samples. 
29 | 30 | ## Examples 31 | 32 | iex> NxSignal.Convolution.convolve(Nx.tensor([1,2,3]), Nx.tensor([3,4,5])) 33 | #Nx.Tensor< 34 | f32[5] 35 | [3.0, 10.0, 22.0, 22.0, 15.0] 36 | > 37 | """ 38 | deftransform convolve(in1, in2, opts \\ []) do 39 | opts = Keyword.validate!(opts, mode: :full, method: :direct) 40 | 41 | if opts[:mode] not in [:full, :same, :valid] do 42 | raise ArgumentError, 43 | "expected mode to be one of [:full, :same, :valid], got: #{inspect(opts[:mode])}" 44 | end 45 | 46 | if opts[:method] not in [:direct, :fft] do 47 | raise ArgumentError, 48 | "expected method to be one of [:direct, :fft], got: #{inspect(opts[:method])}" 49 | end 50 | 51 | case opts[:method] do 52 | :direct -> 53 | direct_convolve(in1, in2, opts) 54 | 55 | :fft -> 56 | fftconvolve(in1, in2, opts) 57 | end 58 | end 59 | 60 | @doc ~S""" 61 | Computes the correlation of two tensors. 62 | 63 | Given $f[n]$ of length $N$ and $k[n]$ of length ${K}$, we define the correlation $g[n] = (f \star k)[n]$ by 64 | 65 | $$ 66 | g[n] = (f \star k)[n] = \\sum_{m = 0}^{K - 1}f[n - m]k^\*[K - 1 - m] 67 | $$ 68 | 69 | where $k^\*[n]$ is the complex conjugate of $k[n]$. 70 | 71 | ## Options 72 | 73 | * `:method` - One of `:fft` or `:direct`. Defaults to `:direct`. 74 | * `:mode` - One of `:full`, `:valid`, or `:same`. Defaults to `:full`. 75 | * `:full` returns all $N + K - 1$ samples. 76 | * `:same` returns the center $N$ samples. 77 | * `:valid` returns the center $N - K + 1$ samples. 
78 | 79 | ## Examples 80 | 81 | iex> NxSignal.Convolution.correlate(Nx.tensor([1,2,3]), Nx.tensor([3,4,5])) 82 | #Nx.Tensor< 83 | f32[5] 84 | [5.0, 14.0, 26.0, 18.0, 9.0] 85 | > 86 | """ 87 | defn correlate(in1, in2, opts \\ []) do 88 | if Nx.type(in2) |> Nx.Type.complex?() do 89 | convolve(in1, Nx.conjugate(Nx.reverse(in2)), opts) 90 | else 91 | convolve(in1, Nx.reverse(in2), opts) 92 | end 93 | end 94 | 95 | deftransformp direct_convolve(in1, in2, opts) do 96 | input_rank = 97 | case {Nx.rank(in1), Nx.rank(in2)} do 98 | {0, 0} -> 99 | 0 100 | 101 | {0, r} -> 102 | raise ArgumentError, message: "Incompatible ranks: {0, #{r}}" 103 | 104 | {r, 0} -> 105 | raise ArgumentError, message: "Incompatible ranks: {#{r}, 0}" 106 | 107 | {r, r} -> 108 | r 109 | 110 | {r1, r2} -> 111 | raise ArgumentError, 112 | "NxSignal.convolve/3 requires both inputs to have the same rank or one of them to be a scalar, got #{r1} and #{r2}" 113 | end 114 | 115 | zipped = Enum.zip(Tuple.to_list(Nx.shape(in1)), Tuple.to_list(Nx.shape(in2))) 116 | 117 | ok1 = Enum.all?(for {i, j} <- zipped, do: i >= j) 118 | ok2 = Enum.all?(for {i, j} <- zipped, do: i <= j) 119 | 120 | {in1, in2} = 121 | cond do 122 | opts[:mode] != :valid -> 123 | {in1, in2} 124 | 125 | ok1 -> 126 | {in1, in2} 127 | 128 | ok2 -> 129 | {in2, in1} 130 | 131 | true -> 132 | raise ArgumentError, 133 | message: 134 | "For :valid mode, one must be at least as large as the other in every dimension" 135 | end 136 | 137 | kernel = Nx.reverse(in2) 138 | 139 | kernel_shape = 140 | case Nx.shape(kernel) do 141 | {} -> {1, 1, 1, 1} 142 | {n} -> {1, 1, 1, n} 143 | shape -> List.to_tuple([1, 1 | Tuple.to_list(shape)]) 144 | end 145 | 146 | kernel = Nx.reshape(kernel, kernel_shape) 147 | 148 | volume_shape = 149 | case Nx.shape(in1) do 150 | {} -> {1, 1, 1, 1} 151 | {n} -> {1, 1, 1, n} 152 | shape -> List.to_tuple([1, 1 | Tuple.to_list(shape)]) 153 | end 154 | 155 | volume = Nx.reshape(in1, volume_shape) 156 | 157 | opts = 158 | case 
opts[:mode] do 159 | :same -> 160 | kernel_spatial_shape = 161 | Nx.shape(kernel) 162 | |> Tuple.to_list() 163 | |> Enum.drop(2) 164 | 165 | padding = 166 | Enum.map(kernel_spatial_shape, fn k -> 167 | pad_total = k - 1 168 | # integer division for right side 169 | pad_right = div(pad_total, 2) 170 | # put the extra padding on the left 171 | pad_left = pad_total - pad_right 172 | {pad_left, pad_right} 173 | end) 174 | 175 | [padding: padding] 176 | 177 | :full -> 178 | kernel_spatial_shape = 179 | Nx.shape(kernel) 180 | |> Tuple.to_list() 181 | |> Enum.drop(2) 182 | 183 | padding = 184 | Enum.map(kernel_spatial_shape, fn k -> 185 | {k - 1, k - 1} 186 | end) 187 | 188 | [padding: padding] 189 | 190 | :valid -> 191 | [padding: :valid] 192 | end 193 | 194 | out = Nx.conv(volume, kernel, opts) 195 | 196 | squeeze_axes = 197 | case input_rank do 198 | 0 -> 199 | [0, 1, 2, 3] 200 | 201 | 1 -> 202 | [0, 1, 2] 203 | 204 | _ -> 205 | [0, 1] 206 | end 207 | 208 | out 209 | |> Nx.squeeze(axes: squeeze_axes) 210 | |> slice_valid(Nx.shape(volume), Nx.shape(kernel), opts[:mode]) 211 | end 212 | 213 | deftransformp slice_valid(out, in1_shape, in2_shape, :valid) do 214 | select = 215 | [in1_shape, in2_shape] 216 | |> Enum.zip_with(fn [i, j] -> 217 | 0..(i - j) 218 | end) 219 | 220 | out[select] 221 | end 222 | 223 | deftransformp slice_valid(out, _, _, _), do: out 224 | 225 | @doc """ 226 | Computes the convolution of two tensors via FFT. 227 | 228 | Given signals $f[n]$, with length $N$, and $k[n]$, with length $K$, we define the convolution $g[n] = (f * k)[n]$ by 229 | 230 | $$ 231 | g[n] = \\text{FFT}^{-1}(\\text{FFT}(f[n]) \\cdot \\text{FFT}(k[n])) 232 | $$ 233 | 234 | where $f[n]$ and $k[n]$ have their DFTs calculated with $N + K - 1$ samples. 235 | The output is sliced in accordance to the `mode` option, as described below. 236 | 237 | ## Options 238 | 239 | * `:mode` - One of `:full`, `:valid`, or `:same`. Defaults to `:full`. 
240 | * `:full` returns all $N + K - 1$ samples. 241 | * `:same` returns the center $N$ samples. 242 | * `:valid` returns the center $N - K + 1$ samples. 243 | 244 | ## Examples 245 | 246 | iex> NxSignal.Convolution.fftconvolve(Nx.tensor([1,2,3]), Nx.tensor([3,4,5])) 247 | #Nx.Tensor< 248 | f32[5] 249 | [3.0000007152557373, 10.0, 22.0, 22.0, 15.0] 250 | > 251 | """ 252 | deftransform fftconvolve(in1, in2, opts \\ []) do 253 | opts = Keyword.validate!(opts, mode: :full, method: :direct) 254 | 255 | case {Nx.rank(in1), Nx.rank(in2)} do 256 | {a, b} when a == b -> 257 | s1 = Nx.shape(in1) |> Tuple.to_list() 258 | s2 = Nx.shape(in2) |> Tuple.to_list() 259 | 260 | lengths = 261 | Enum.zip_with(s1, s2, fn ax1, ax2 -> 262 | ax1 + ax2 - 1 263 | end) 264 | 265 | axes = 266 | [s1, s2, Nx.axes(in1)] 267 | |> Enum.zip_with(fn [ax1, ax2, axis] -> 268 | if ax1 != 1 and ax2 != 1 do 269 | axis 270 | end 271 | end) 272 | |> Enum.filter(& &1) 273 | 274 | lengths = Enum.map(axes, &Enum.fetch!(lengths, &1)) 275 | 276 | sp1 = 277 | fft_nd(in1, axes: axes, lengths: lengths) 278 | 279 | sp2 = 280 | fft_nd(in2, axes: axes, lengths: lengths) 281 | 282 | c = Nx.multiply(sp1, sp2) 283 | 284 | out = ifft_nd(c, axes: axes) 285 | 286 | out = 287 | if Nx.Type.merge(Nx.type(in1), Nx.type(in2)) |> Nx.Type.complex?() do 288 | out 289 | else 290 | Nx.real(out) 291 | end 292 | 293 | apply_mode(out, s1, s2, opts[:mode]) 294 | 295 | _ -> 296 | raise ArgumentError, message: "Rank of in1 and in2 must be equal." 
297 | end 298 | end 299 | 300 | deftransformp apply_mode(out, _s1, _s2, :full) do 301 | out 302 | end 303 | 304 | deftransformp apply_mode(out, s1, _s2, :same) do 305 | centered(out, s1) 306 | end 307 | 308 | deftransformp apply_mode(out, s1, s2, :valid) do 309 | {s1, s2} = swap_axes(s1, s2) 310 | 311 | shape_valid = 312 | for {a, b} <- Enum.zip(s1, s2) do 313 | a - b + 1 314 | end 315 | 316 | centered(out, shape_valid) 317 | end 318 | 319 | deftransformp centered(out, new_shape) do 320 | start_indices = 321 | out 322 | |> Nx.shape() 323 | |> Tuple.to_list() 324 | |> Enum.zip_with(new_shape, fn current, new -> 325 | div(current - new, 2) 326 | end) 327 | 328 | Nx.slice(out, start_indices, new_shape) 329 | end 330 | 331 | defp swap_axes(s1, s2) do 332 | ok1 = Enum.zip_reduce(s1, s2, true, fn a, b, acc -> acc and a >= b end) 333 | ok2 = Enum.zip_reduce(s2, s1, true, fn a, b, acc -> acc and a >= b end) 334 | 335 | cond do 336 | ok1 -> 337 | {s1, s2} 338 | 339 | ok2 -> 340 | {s2, s1} 341 | 342 | true -> 343 | raise ArgumentError, 344 | message: 345 | "For 'valid' mode, one must be at least as large as the other in every dimension." 346 | end 347 | end 348 | end 349 | -------------------------------------------------------------------------------- /lib/nx_signal/filters.ex: -------------------------------------------------------------------------------- 1 | defmodule NxSignal.Filters do 2 | @moduledoc """ 3 | Common filter functions. 4 | """ 5 | import Nx.Defn 6 | import NxSignal.Convolution 7 | 8 | @doc ~S""" 9 | Performs a median filter on a tensor. 10 | 11 | ## Options 12 | 13 | * `:kernel_shape` - the shape of the sliding window. 14 | It must be compatible with the shape of the tensor. 
15 | """ 16 | @doc type: :filters 17 | defn median(t, opts) do 18 | validate_median_opts!(t, opts) 19 | 20 | idx = 21 | t 22 | |> idx_tensor() 23 | |> Nx.vectorize(:elements) 24 | 25 | t 26 | |> Nx.slice(start_indices(t, idx), kernel_lengths(opts[:kernel_shape])) 27 | |> Nx.median() 28 | |> Nx.devectorize(keep_names: false) 29 | |> Nx.reshape(t.shape) 30 | |> Nx.as_type({:f, 32}) 31 | end 32 | 33 | deftransformp validate_median_opts!(t, opts) do 34 | Keyword.validate!(opts, [:kernel_shape]) 35 | 36 | if Nx.rank(t) != Nx.rank(opts[:kernel_shape]) do 37 | raise ArgumentError, message: "kernel shape must be of the same rank as the tensor" 38 | end 39 | end 40 | 41 | deftransformp idx_tensor(t) do 42 | t 43 | |> Nx.axes() 44 | |> Enum.map(&Nx.iota(t.shape, axis: &1)) 45 | |> Nx.stack(axis: -1) 46 | |> Nx.reshape({:auto, length(Nx.axes(t))}) 47 | end 48 | 49 | deftransformp start_indices(t, idx_tensor) do 50 | t 51 | |> Nx.axes() 52 | |> Enum.map(&idx_tensor[&1]) 53 | end 54 | 55 | deftransformp kernel_lengths(kernel_shape), do: Tuple.to_list(kernel_shape) 56 | 57 | @doc """ 58 | Applies a Wiener filter to the given Nx tensor. 59 | 60 | ## Options 61 | 62 | * `:kernel_size` - filter size given either a number or a tuple. 63 | If a number is given, a kernel with the given size, and same number of axes 64 | as the input tensor will be used. Defaults to `3`. 65 | * `:noise` - noise power, given as a scalar. This will be estimated based on the input tensor if `nil`. Defaults to `nil`. 
66 | 67 | ## Examples 68 | 69 | iex> t = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) 70 | iex> NxSignal.Filters.wiener(t, kernel_size: {2, 2}, noise: 10) 71 | #Nx.Tensor< 72 | f32[3][3] 73 | [ 74 | [0.25, 0.75, 1.25], 75 | [1.25, 3.0, 4.0], 76 | [2.75, 6.0, 7.0] 77 | ] 78 | > 79 | """ 80 | @doc type: :filters 81 | deftransform wiener(t, opts \\ []) do 82 | # Validate and extract options 83 | opts = Keyword.validate!(opts, noise: nil, kernel_size: 3) 84 | 85 | rank = Nx.rank(t) 86 | kernel_size = Keyword.fetch!(opts, :kernel_size) 87 | noise = Keyword.fetch!(opts, :noise) 88 | 89 | # Ensure `kernel_size` is a tuple 90 | kernel_size = 91 | cond do 92 | is_integer(kernel_size) -> Tuple.duplicate(kernel_size, rank) 93 | is_tuple(kernel_size) -> kernel_size 94 | true -> raise ArgumentError, "kernel_size must be an integer or tuple" 95 | end 96 | 97 | # Convert `nil` noise to `0.0` so it's always a valid tensor 98 | noise_t = if is_nil(noise), do: Nx.tensor(0.0), else: Nx.tensor(noise) 99 | 100 | # Compute filter window size 101 | size = Tuple.to_list(kernel_size) |> Enum.reduce(1, &*/2) 102 | 103 | # Ensure the kernel is the same size as the filter window 104 | kernel = Nx.broadcast(1.0, kernel_size) 105 | 106 | t 107 | |> Nx.as_type(:f64) 108 | |> wiener_n(kernel, noise_t, calculate_noise: is_nil(noise), size: size) 109 | |> Nx.as_type(Nx.type(t)) 110 | end 111 | 112 | defnp wiener_n(t, kernel, noise, opts) do 113 | size = opts[:size] 114 | 115 | # Compute local mean using "same" mode in correlation 116 | l_mean = correlate(t, kernel, mode: :same) / size 117 | 118 | # Compute local variance 119 | l_var = 120 | correlate(t ** 2, kernel, mode: :same) 121 | |> Nx.divide(size) 122 | |> Nx.subtract(l_mean ** 2) 123 | 124 | # Ensure `noise` is a tensor to avoid `nil` issues in `defnp` 125 | noise = 126 | case opts[:calculate_noise] do 127 | true -> Nx.mean(l_var) 128 | false -> noise 129 | end 130 | 131 | # Apply Wiener filter formula 132 | res = (t - 
l_mean) * (1 - noise / l_var) 133 | Nx.select(l_var < noise, l_mean, res + l_mean) 134 | end 135 | end 136 | -------------------------------------------------------------------------------- /lib/nx_signal/internal.ex: -------------------------------------------------------------------------------- 1 | defmodule NxSignal.Internal do 2 | @moduledoc false 3 | import Nx.Defn 4 | 5 | @omega 0.56714329040978387299997 6 | @expn1 0.36787944117144232159553 7 | 8 | deftransform lambert_w(z, k, opts \\ []) do 9 | opts = Keyword.validate!(opts, tol: 1.0e-8) 10 | 11 | z = 12 | if Nx.Type.complex?(Nx.type(z)) do 13 | Nx.as_type(z, :c128) 14 | else 15 | Nx.complex(Nx.as_type(z, :f64), 0) 16 | end 17 | 18 | lambert_w_n(z, k, tol: opts[:tol]) 19 | end 20 | 21 | defnp lambert_w_n(z, k, opts) do 22 | tol = opts[:tol] 23 | 24 | rz = Nx.real(z) 25 | 26 | cond do 27 | Nx.is_infinity(rz) and rz > 0 -> 28 | z + 2.0 * Nx.Constants.pi() * k * Nx.Constants.i() 29 | 30 | Nx.is_infinity(rz) and rz < 0 -> 31 | -z + 2.0 * Nx.Constants.pi() * k * Nx.Constants.i() 32 | 33 | z == 0 and k == 0 -> 34 | z 35 | 36 | z == 0 -> 37 | Nx.Constants.neg_infinity(:f64) 38 | 39 | Nx.equal(z, 1) and k == 0 -> 40 | @omega 41 | 42 | true -> 43 | halleys_method(z, k, tol) 44 | end 45 | end 46 | 47 | defnp halleys_method(z, k, tol) do 48 | absz = Nx.abs(z) 49 | 50 | w = 51 | cond do 52 | k == 0 -> 53 | cond do 54 | Nx.abs(z + @expn1) < 0.3 -> 55 | lambertw_branchpt(z) 56 | 57 | -1.0 < Nx.real(z) and Nx.real(z) < 1.5 and Nx.abs(Nx.imag(z)) < 1.0 and 58 | -2.5 * Nx.abs(Nx.imag(z)) - 0.2 < Nx.real(z) -> 59 | lambertw_pade0(z) 60 | 61 | true -> 62 | lambertw_asy(z, k) 63 | end 64 | 65 | k == -1 and absz <= @expn1 and Nx.imag(z) == 0.0 and Nx.real(z) < 0.0 -> 66 | Nx.log(-Nx.real(z)) 67 | 68 | k == -1 -> 69 | lambertw_asy(z, k) 70 | 71 | true -> 72 | lambertw_asy(z, k) 73 | end 74 | 75 | # Halley's Method 76 | cond do 77 | Nx.real(w) >= 0 -> 78 | {w, _} = 79 | while {w, {z, tol, i = 0}}, i < 100 do 80 | ew = Nx.exp(-w) 
81 | wewz = w - z * ew 82 | wn = w - wewz / (w + 1.0 - (w + 2.0) * wewz / (2.0 * w + 2.0)) 83 | 84 | if Nx.abs(wn - w) <= tol * Nx.abs(wn) do 85 | {wn, {z, tol, 100}} 86 | else 87 | {wn, {z, tol, i + 1}} 88 | end 89 | end 90 | 91 | w 92 | 93 | true -> 94 | {w, _} = 95 | while {w, {z, tol, i = 0}}, i < 100 do 96 | ew = Nx.exp(w) 97 | wew = w * ew 98 | wewz = wew - z 99 | wn = w - wewz / (wew + ew - (w + 2.0) * wewz / (2.0 * w + 2.0)) 100 | 101 | if Nx.abs(wn - w) <= tol * Nx.abs(wn) do 102 | {wn, {z, tol, 100}} 103 | else 104 | {wn, {z, tol, i + 1}} 105 | end 106 | end 107 | 108 | w 109 | end 110 | end 111 | 112 | defnp lambertw_branchpt(z) do 113 | m_e = 114 | Nx.Constants.e() 115 | 116 | p = Nx.sqrt(2.0 * (m_e * z + 1.0)) 117 | 118 | cevalpoly_2(p, -1.0 / 3.0, 1.0, -1.0) 119 | end 120 | 121 | defnp lambertw_pade0(z) do 122 | z * cevalpoly_2(z, 12.85106382978723404255, 12.34042553191489361902, 1.0) / 123 | cevalpoly_2(z, 32.53191489361702127660, 14.34042553191489361702, 1.0) 124 | end 125 | 126 | defnp lambertw_asy(z, k) do 127 | w = 128 | Nx.log(z) + 2.0 * Nx.Constants.pi() * k * Nx.Constants.i() 129 | 130 | w - Nx.log(w) 131 | end 132 | 133 | defnp cevalpoly_2(z, c0, c1, c2) do 134 | s = Nx.abs(z) ** 2 135 | r = 2 * Nx.real(z) 136 | b = -s * c0 + c2 137 | a = r * c0 + c1 138 | z * a + b 139 | end 140 | end 141 | -------------------------------------------------------------------------------- /lib/nx_signal/peak_finding.ex: -------------------------------------------------------------------------------- 1 | defmodule NxSignal.PeakFinding do 2 | @moduledoc """ 3 | Peak finding algorithms. 4 | """ 5 | 6 | import Nx.Defn 7 | import Nx, only: [u8: 1, s64: 1] 8 | 9 | @doc """ 10 | Finds a relative minimum along the selected `:axis`. 11 | 12 | A relative minimum is defined by the element being less 13 | than its neighbors along the axis `:axis`.
14 | 15 | Returns a map in the following format: 16 | 17 | %{ 18 | indices: #Nx.Tensor<...>, 19 | valid_indices: #Nx.Tensor<...> 20 | } 21 | 22 | * `:indices` - the `{n, rank}` tensor of indices. 23 | Contains `-1` as a placeholder for invalid indices. 24 | 25 | * `:valid_indices` - the number of valid indices that lead the tensor. 26 | 27 | ## Options 28 | 29 | * `:axis` - the axis along which to do comparisons. Defaults to 0. 30 | * `:order` - the number of neighbor samples considered for the 31 | comparison in each direction. Defaults to 1. 32 | 33 | ## Examples 34 | 35 | iex> x = Nx.tensor([2, 1, 2, 3, 2, 0, 1, 0]) 36 | iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmin(x) 37 | iex> valid_indices 38 | #Nx.Tensor< 39 | u32 40 | 2 41 | > 42 | iex> indices 43 | #Nx.Tensor< 44 | s32[8][1] 45 | [ 46 | [1], 47 | [5], 48 | [-1], 49 | [-1], 50 | [-1], 51 | [-1], 52 | [-1], 53 | [-1] 54 | ] 55 | > 56 | iex> Nx.slice_along_axis(indices, 0, Nx.to_number(valid_indices), axis: 0) 57 | #Nx.Tensor< 58 | s32[2][1] 59 | [ 60 | [1], 61 | [5] 62 | ] 63 | > 64 | 65 | For the same tensor in the previous example, we can use `:order` to check if 66 | the relative minima are extrema in a wider neighborhood.
67 | 68 | iex> x = Nx.tensor([2, 1, 2, 3, 2, 0, 1, 0]) 69 | iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmin(x, order: 3) 70 | iex> valid_indices 71 | #Nx.Tensor< 72 | u32 73 | 1 74 | > 75 | iex> indices 76 | #Nx.Tensor< 77 | s32[8][1] 78 | [ 79 | [1], 80 | [-1], 81 | [-1], 82 | [-1], 83 | [-1], 84 | [-1], 85 | [-1], 86 | [-1] 87 | ] 88 | > 89 | iex> Nx.slice_along_axis(indices, 0, Nx.to_number(valid_indices), axis: 0) 90 | #Nx.Tensor< 91 | s32[1][1] 92 | [ 93 | [1] 94 | ] 95 | > 96 | 97 | We can also apply this function to tensors with a larger rank: 98 | 99 | iex> x = Nx.tensor([[1, 2, 1, 2], [6, 2, 0, 0], [5, 3, 4, 4]]) 100 | iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmin(x) 101 | iex> valid_indices 102 | #Nx.Tensor< 103 | u32 104 | 2 105 | > 106 | iex> indices[0..1] 107 | #Nx.Tensor< 108 | s32[2][2] 109 | [ 110 | [1, 2], 111 | [1, 3] 112 | ] 113 | > 114 | iex> %{indices: indices} = NxSignal.PeakFinding.argrelmin(x, axis: 1) 115 | iex> valid_indices 116 | #Nx.Tensor< 117 | u32 118 | 2 119 | > 120 | iex> indices[0..1] 121 | #Nx.Tensor< 122 | s32[2][2] 123 | [ 124 | [0, 2], 125 | [2, 1] 126 | ] 127 | > 128 | 129 | """ 130 | @doc type: :peak_finding 131 | defn argrelmin(data, opts \\ []) do 132 | opts = keyword!(opts, axis: 0, order: 1) 133 | argrelextrema(data, &Nx.less/2, opts) 134 | end 135 | 136 | @doc """ 137 | Finds a relative maximum along the selected `:axis`. 138 | 139 | A relative maximum is defined by the element being greater 140 | than its neighbors along the axis `:axis`. 141 | 142 | Returns a map in the following format: 143 | 144 | %{ 145 | indices: #Nx.Tensor<...>, 146 | valid_indices: #Nx.Tensor<...> 147 | } 148 | 149 | * `:indices` - the `{n, rank}` tensor of indices. 150 | Contains `-1` as a placeholder for invalid indices. 151 | 152 | * `:valid_indices` - the number of valid indices that lead the tensor. 
153 | 154 | ## Options 155 | 156 | * `:axis` - the axis along which to do comparisons. Defaults to 0. 157 | * `:order` - the number of neighbor samples considered for the 158 | comparison in each direction. Defaults to 1. 159 | 160 | ## Examples 161 | 162 | iex> x = Nx.tensor([2, 1, 2, 3, 2, 0, 1, 0]) 163 | iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmax(x) 164 | iex> valid_indices 165 | #Nx.Tensor< 166 | u32 167 | 2 168 | > 169 | iex> indices 170 | #Nx.Tensor< 171 | s32[8][1] 172 | [ 173 | [3], 174 | [6], 175 | [-1], 176 | [-1], 177 | [-1], 178 | [-1], 179 | [-1], 180 | [-1] 181 | ] 182 | > 183 | iex> Nx.slice_along_axis(indices, 0, Nx.to_number(valid_indices), axis: 0) 184 | #Nx.Tensor< 185 | s32[2][1] 186 | [ 187 | [3], 188 | [6] 189 | ] 190 | > 191 | 192 | For the same tensor in the previous example, we can use `:order` to check if 193 | the relative maxima are extrema in a wider neighborhood. 194 | 195 | iex> x = Nx.tensor([2, 1, 2, 3, 2, 0, 1, 0]) 196 | iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmax(x, order: 3) 197 | iex> valid_indices 198 | #Nx.Tensor< 199 | u32 200 | 1 201 | > 202 | iex> indices 203 | #Nx.Tensor< 204 | s32[8][1] 205 | [ 206 | [3], 207 | [-1], 208 | [-1], 209 | [-1], 210 | [-1], 211 | [-1], 212 | [-1], 213 | [-1] 214 | ] 215 | > 216 | iex> Nx.slice_along_axis(indices, 0, Nx.to_number(valid_indices), axis: 0) 217 | #Nx.Tensor< 218 | s32[1][1] 219 | [ 220 | [3] 221 | ] 222 | > 223 | 224 | We can also apply this function to tensors with a larger rank: 225 | 226 | iex> x = Nx.tensor([[1, 2, 1, 2], [6, 2, 0, 0], [5, 3, 4, 4]]) 227 | iex> %{indices: indices, valid_indices: valid_indices} = NxSignal.PeakFinding.argrelmax(x) 228 | iex> valid_indices 229 | #Nx.Tensor< 230 | u32 231 | 1 232 | > 233 | iex> indices[0] 234 | #Nx.Tensor< 235 | s32[2] 236 | [1, 0] 237 | > 238 | iex> %{indices: indices} = NxSignal.PeakFinding.argrelmax(x, axis: 1) 239 | iex> valid_indices 240 
| #Nx.Tensor< 241 | u32 242 | 1 243 | > 244 | iex> indices[0] 245 | #Nx.Tensor< 246 | s32[2] 247 | [0, 1] 248 | > 249 | 250 | """ 251 | @doc type: :peak_finding 252 | defn argrelmax(data, opts \\ []) do 253 | opts = keyword!(opts, axis: 0, order: 1) 254 | argrelextrema(data, &Nx.greater/2, opts) 255 | end 256 | 257 | @doc """ 258 | Finds relative extrema along the selected `:axis`. 259 | 260 | A relative extremum is defined by the given `comparator_fn`, 261 | an arity-2 function that returns a boolean tensor. 262 | 263 | This is the function upon which `argrelmax/2` and `argrelmin/2` 264 | are implemented. 265 | 266 | Returns a map in the following format: 267 | 268 | %{ 269 | indices: #Nx.Tensor<...>, 270 | valid_indices: #Nx.Tensor<...> 271 | } 272 | 273 | * `:indices` - the `{n, rank}` tensor of indices. 274 | Contains `-1` as a placeholder for invalid indices. 275 | 276 | * `:valid_indices` - the number of valid indices that lead the tensor. 277 | 278 | ## Options 279 | 280 | * `:axis` - the axis along which to do comparisons. Defaults to 0. 281 | * `:order` - the number of neighbor samples considered for the 282 | comparison in each direction. Defaults to 1. 283 | 284 | ## Examples 285 | 286 | First, do read the examples on `argrelmax/2` keeping in mind that 287 | it is equivalent to `argrelextrema(&1, &Nx.greater/2, &2)`, as well 288 | as `argrelmin/2` which is equivalent to `argrelextrema(&1, &Nx.less/2, &2)`. 289 | 290 | Having that in mind, we will expand on those concepts by using a custom function.
291 | For instance, we can change the definition of a relative maximum to one where 292 | a number is a relative maximum if it is greater than or equal to the double of its 293 | neighbors, as follows: 294 | 295 | iex> comparator = fn x, y -> Nx.greater_equal(x, Nx.multiply(y, 2)) end 296 | iex> x = Nx.tensor([0, 1, 3, 2, 0, 1, 0, 0, 0, 2, 1]) 297 | iex> result = NxSignal.PeakFinding.argrelextrema(x, comparator) 298 | iex> result.valid_indices 299 | #Nx.Tensor< 300 | u32 301 | 3 302 | > 303 | iex> result.indices[0..2] 304 | #Nx.Tensor< 305 | s32[3][1] 306 | [ 307 | [5], 308 | [7], 309 | [9] 310 | ] 311 | > 312 | 313 | Same applies for finding local minima. In the next example, we 314 | find all local minima (i.e. `&Nx.less/2`) that are 315 | different to the global minimum. 316 | 317 | iex> x = Nx.tensor([0, 1, 0, 2, 1, 3, 0, 1]) 318 | iex> global_minimum = Nx.reduce_min(x) 319 | iex> comparator = fn x, y -> 320 | ...> x_not_global = Nx.not_equal(x, global_minimum) 321 | ...> y_not_global = Nx.not_equal(y, global_minimum) 322 | ...> both_not_global = Nx.logical_and(x_not_global, y_not_global) 323 | ...> Nx.logical_and(Nx.less(x, y), both_not_global) 324 | ...> end 325 | iex> result = NxSignal.PeakFinding.argrelextrema(x, comparator) 326 | iex> result.valid_indices 327 | #Nx.Tensor< 328 | u32 329 | 1 330 | > 331 | iex> result.indices[0..0] 332 | #Nx.Tensor< 333 | s32[1][1] 334 | [ 335 | [4] 336 | ] 337 | > 338 | """ 339 | @doc type: :peak_finding 340 | defn argrelextrema(data, comparator_fn, opts \\ []) do 341 | opts = keyword!(opts, axis: 0, order: 1) 342 | 343 | data 344 | |> boolrelextrema(comparator_fn, opts) 345 | |> nonzero() 346 | end 347 | 348 | defnp boolrelextrema(data, comparator_fn, opts \\ []) do 349 | axis = opts[:axis] 350 | order = opts[:order] 351 | locs = Nx.iota({Nx.axis_size(data, axis)}) 352 | 353 | ones = Nx.broadcast(u8(1), data.shape) 354 | [ones, _] = Nx.broadcast_vectors([ones, data]) 355 | 356 | {results, _} = 357 | while {results = ones, 
{data, locs, halt = u8(0), shift = s64(1)}}, 358 | not halt and shift < order + 1 do 359 | plus = Nx.take(data, Nx.clip(locs + shift, 0, Nx.size(locs) - 1), axis: axis) 360 | minus = Nx.take(data, Nx.clip(locs - shift, 0, Nx.size(locs) - 1), axis: axis) 361 | results = comparator_fn.(data, plus) and results 362 | results = comparator_fn.(data, minus) and results 363 | 364 | {results, {data, locs, not Nx.any(results), shift + 1}} 365 | end 366 | 367 | results 368 | end 369 | 370 | deftransformp nonzero(data) do 371 | flat_data = Nx.reshape(data, {:auto, 1}) 372 | 373 | indices = 374 | for axis <- 0..(Nx.rank(data) - 1), 375 | reduce: Nx.broadcast(0, {Nx.axis_size(flat_data, 0), Nx.rank(data)}) do 376 | %{shape: {n, _}} = indices -> 377 | iota = data.shape |> Nx.iota(axis: axis) |> Nx.reshape({n, 1}) 378 | Nx.put_slice(indices, [0, axis], iota) 379 | end 380 | 381 | indices_with_mask = 382 | Nx.select( 383 | Nx.broadcast(flat_data, indices.shape), 384 | indices, 385 | Nx.broadcast(-1, indices.shape) 386 | ) 387 | 388 | order = Nx.argsort(Nx.squeeze(flat_data, axes: [1]), axis: 0, direction: :desc) 389 | 390 | %{indices: Nx.take(indices_with_mask, order), valid_indices: Nx.sum(flat_data)} 391 | end 392 | end 393 | -------------------------------------------------------------------------------- /lib/nx_signal/transforms.ex: -------------------------------------------------------------------------------- 1 | defmodule NxSignal.Transforms do 2 | @moduledoc false 3 | import Nx.Defn 4 | 5 | deftransform fft_nd(tensor, opts \\ []) do 6 | axes = Keyword.get(opts, :axes, [-1]) 7 | lengths = Keyword.get(opts, :lengths) || List.duplicate(nil, length(axes)) 8 | 9 | Enum.zip_reduce(axes, lengths, tensor, fn axis, len, acc -> 10 | Nx.fft(acc, axis: axis, length: len) 11 | end) 12 | end 13 | 14 | deftransform ifft_nd(tensor, opts \\ []) do 15 | axes = Keyword.get(opts, :axes, [-1]) 16 | lengths = Keyword.get(opts, :lengths) || List.duplicate(nil, length(axes)) 17 | 18 | 
Enum.zip_reduce(axes, lengths, tensor, fn axis, len, acc -> 19 | Nx.ifft(acc, axis: axis, length: len) 20 | end) 21 | end 22 | end 23 | -------------------------------------------------------------------------------- /lib/nx_signal/waveforms.ex: -------------------------------------------------------------------------------- 1 | defmodule NxSignal.Waveforms do 2 | @moduledoc """ 3 | Functions that calculate waveforms given a time tensor. 4 | """ 5 | import Nx.Defn 6 | import Nx.Constants, only: [pi: 0] 7 | 8 | @doc ~S""" 9 | Periodic sawtooth or triangular waveform. 10 | 11 | The wave has a period of $2\pi$, rising from -1 to 1 12 | in the interval $[0, 2\pi\cdot\text{width}]$ and dropping from 13 | 1 to -1 in the interval $[2\pi \cdot \text{width}, 2\pi]$. 14 | 15 | ## Options 16 | 17 | * `:width` - the width of the sawtooth. Must be a number 18 | between 0 and 1 (both inclusive). Defaults to 1. 19 | 20 | ## Examples 21 | 22 | A 5Hz waveform sampled at 500Hz for 1 second can be defined as: 23 | 24 | t = Nx.linspace(0, 1, n: 500) 25 | n = Nx.multiply(2 * :math.pi() * 5, t) 26 | wave = NxSignal.Waveforms.sawtooth(n) 27 | """ 28 | @doc type: :waveforms 29 | defn sawtooth(t, opts \\ []) do 30 | opts = keyword!(opts, width: 1) 31 | 32 | width = opts[:width] 33 | 34 | if width < 0 or width > 1 do 35 | raise ArgumentError, "width must be between 0 and 1, inclusive. Got: #{inspect(width)}" 36 | end 37 | 38 | tmod = Nx.remainder(t, 2 * pi()) 39 | 40 | cond do 41 | width == 1 -> 42 | tmod / (pi() * width) - 1 43 | 44 | width == 0 -> 45 | (pi() * (width + 1) - tmod) / (pi() * (1 - width)) 46 | 47 | true -> 48 | Nx.select( 49 | tmod < 2 * pi() * width, 50 | tmod / (pi() * width) - 1, 51 | (pi() * (width + 1) - tmod) / (pi() * (1 - width)) 52 | ) 53 | end 54 | end 55 | 56 | @doc """ 57 | A periodic square wave with period $2\\pi$. 58 | 59 | Evaluates to 1 in the interval $[0, 2\\pi\\text{duty}]$ 60 | and -1 in the interval $[2\\pi\\text{duty}, 2\\pi]$.
61 | 62 | ## Options 63 | 64 | * `:duty` - a number or tensor representing the duty cycle. 65 | If a tensor is given, the waveform changes over time, and it 66 | must have the same length as the `t` input. Defaults to `0.5`. 67 | 68 | ## Examples 69 | 70 | iex> t = Nx.iota({10}) |> Nx.multiply(:math.pi() * 2 / 10) 71 | iex> NxSignal.Waveforms.square(t, duty: 0.1) 72 | #Nx.Tensor< 73 | s32[10] 74 | [1, -1, -1, -1, -1, -1, -1, -1, -1, -1] 75 | > 76 | iex> NxSignal.Waveforms.square(t, duty: 0.5) 77 | #Nx.Tensor< 78 | s32[10] 79 | [1, 1, 1, 1, 1, -1, -1, -1, -1, -1] 80 | > 81 | iex> NxSignal.Waveforms.square(t, duty: 1) 82 | #Nx.Tensor< 83 | s32[10] 84 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 85 | > 86 | 87 | iex> t = Nx.iota({10}) |> Nx.multiply(:math.pi() * 2 / 10) 88 | iex> duty = Nx.tensor([0.1, 0, 0.3, 0, 0.5, 0, 0.7, 0, 0.9, 0]) 89 | iex> NxSignal.Waveforms.square(t, duty: duty) 90 | #Nx.Tensor< 91 | s32[10] 92 | [1, -1, 1, -1, 1, -1, 1, -1, 1, -1] 93 | > 94 | """ 95 | @doc type: :waveforms 96 | deftransform square(t, opts \\ []) do 97 | opts = Keyword.validate!(opts, duty: 0.5) 98 | square_n(t, opts[:duty]) 99 | end 100 | 101 | defnp square_n(t, duty) do 102 | tmod = Nx.remainder(t, 2 * pi()) 103 | Nx.select(tmod < duty * 2 * pi(), 1, -1) 104 | end 105 | 106 | @doc ~S""" 107 | Gaussian modulated sinusoid. 108 | 109 | The returned value follows the formula: 110 | 111 | $$ 112 | f(t) = e^{-at^2}(\cos(2 \pi f_c t) + i\sin(2 \pi f_c t)) 113 | $$ 114 | 115 | Where the exponential envelope is returned as `envelope`, 116 | and the real and imaginary parts of $f(t)$ are returned as 117 | `in_phase` and `quadrature` in the output map. 118 | 119 | Note that `in_phase` and `quadrature` are equivalent to 120 | $\operatorname{Re} \lbrace f(t) \rbrace$ and $\operatorname{Im} \lbrace f(t) \rbrace$ respectively.
121 | 122 | ## Examples 123 | 124 | iex> t = Nx.linspace(0, 1, n: 4) 125 | iex> pulse = NxSignal.Waveforms.gaussian_pulse(t, center_frequency: 4) 126 | iex> pulse.envelope 127 | #Nx.Tensor< 128 | f32[4] 129 | [1.0, 0.2044311761856079, 0.0017465798882767558, 6.236266472114949e-7] 130 | > 131 | iex> pulse.in_phase 132 | #Nx.Tensor< 133 | f32[4] 134 | [1.0, -0.10221562534570694, -8.73289187438786e-4, 6.236266472114949e-7] 135 | > 136 | iex> pulse.quadrature 137 | #Nx.Tensor< 138 | f32[4] 139 | [0.0, 0.17704255878925323, -0.001512582995928824, 4.3615339204855497e-13] 140 | > 141 | 142 | iex> t = Nx.linspace(0, 1, n: 4) 143 | iex> pulse = NxSignal.Waveforms.gaussian_pulse(t, center_frequency: 4, bandwidth: 0.25) 144 | iex> pulse.envelope 145 | #Nx.Tensor< 146 | f32[4] 147 | [1.0, 0.6724140644073486, 0.2044311761856079, 0.028101608157157898] 148 | > 149 | iex> pulse.in_phase 150 | #Nx.Tensor< 151 | f32[4] 152 | [1.0, -0.33620715141296387, -0.10221550613641739, 0.028101608157157898] 153 | > 154 | iex> pulse.quadrature 155 | #Nx.Tensor< 156 | f32[4] 157 | [0.0, 0.5823275446891785, -0.1770426332950592, 1.965376483781256e-8] 158 | > 159 | """ 160 | @doc type: :waveforms 161 | defn gaussian_pulse(t, opts \\ []) do 162 | opts = 163 | keyword!(opts, 164 | center_frequency: 1000, 165 | bandwidth: 0.5, 166 | bandwidth_reference_level: -6 167 | ) 168 | 169 | fc = opts[:center_frequency] 170 | bw = opts[:bandwidth] 171 | bwr = opts[:bandwidth_reference_level] 172 | 173 | if fc < 0 do 174 | raise ArgumentError, 175 | "Center frequency must be greater than or equal to 0, got: #{inspect(fc)}" 176 | end 177 | 178 | if bw <= 0 do 179 | raise ArgumentError, 180 | "Bandwidth must be greater than 0, got: #{inspect(bw)}" 181 | end 182 | 183 | if bwr >= 0 do 184 | raise ArgumentError, 185 | "Bandwidth reference level must be less than 0, got: #{inspect(bwr)}" 186 | end 187 | 188 | ref = 10 ** (bwr / 20) 189 | 190 | a = -((pi() * fc * bw) ** 2) / (4.0 * Nx.log(ref)) 191 | 192 | yenv = 
Nx.exp(-a * t * t) 193 | yarg = 2 * pi() * fc * t 194 | yI = yenv * Nx.cos(yarg) 195 | yQ = yenv * Nx.sin(yarg) 196 | 197 | %{envelope: yenv, in_phase: yI, quadrature: yQ} 198 | end 199 | 200 | @doc """ 201 | Chirp function. 202 | 203 | Starts at time 0 with frequency `f0` and ends at time `t1` with 204 | frequency `f1`. 205 | 206 | ## Options 207 | 208 | * `:phi` - phase shift for the chirp, in radians. Defaults to 0. 209 | * `:vertex_zero` - determines the position of the parabolic vertex 210 | for when `method: :quadratic`. Defaults to `true`. 211 | * `:method` - the frequency interpolation method. Defaults to `:linear`. One of: 212 | * `:linear` - linear interpolation. 213 | * `:quadratic` - parabolic interpolation with vertex at `t1` or at time 0, 214 | depending on whether `vertex_zero: false` or `vertex_zero: true`, respectively. 215 | * `:hyperbolic` - hyperbolic interpolation. 216 | * `:logarithmic` - logarithmic (also known as geometric or exponential) 217 | interpolation. `f0` and `f1` must be non-zero and have the same sign, otherwise the output is NaN. 218 | 219 | ## Examples 220 | 221 | iex> t = Nx.linspace(0, 10, n: 5) 222 | iex> NxSignal.Waveforms.chirp(t, 10, 10, 1, method: :linear) 223 | #Nx.Tensor< 224 | f32[5] 225 | [1.0, 0.38268470764160156, 3.795033308051643e-6, -0.382683128118515, 1.0] 226 | > 227 | iex> NxSignal.Waveforms.chirp(t, 10, 10, 1, method: :quadratic) 228 | #Nx.Tensor< 229 | f32[5] 230 | [1.0, -0.9807833433151245, -9.958475288840418e-8, -0.5555803775787354, 1.0] 231 | > 232 | iex> NxSignal.Waveforms.chirp(t, 10, 10, 1, method: :quadratic, vertex_zero: false) 233 | #Nx.Tensor< 234 | f32[5] 235 | [1.0, 0.5555850863456726, -7.490481493732659e-6, 0.98078453540802, 1.0] 236 | > 237 | iex> NxSignal.Waveforms.chirp(t, 10, 10, 1, method: :hyperbolic) 238 | #Nx.Tensor< 239 | f32[5] 240 | [1.0, 0.8229323029518127, 0.9335360527038574, 0.013466471806168556, -0.8630329966545105] 241 | > 242 | iex> NxSignal.Waveforms.chirp(t, 10, 10, 1, method: :logarithmic) 243 | #Nx.Tensor< 244 | f32[5] 245 | [1.0, 0.9989554286003113,
-0.33371755480766296, -0.2700612545013428, 0.8558982610702515] 246 | > 247 | """ 248 | @doc type: :waveforms 249 | defn chirp(t, f0, t1, f1, opts \\ []) do 250 | opts = keyword!(opts, phi: 0, vertex_zero: true, method: :linear) 251 | 252 | phase = 253 | case {chirp_validate_method(opts[:method]), opts[:vertex_zero]} do 254 | {:linear, _} -> 255 | beta = (f1 - f0) / t1 256 | 2 * pi() * (f0 * t + 0.5 * beta * t ** 2) 257 | 258 | {:quadratic, true} -> 259 | beta = (f1 - f0) / t1 ** 2 260 | 2 * pi() * (f0 * t + beta * t ** 3 / 3) 261 | 262 | {:quadratic, _} -> 263 | beta = (f1 - f0) / t1 ** 2 264 | 2 * pi() * (f1 * t + beta * ((t1 - t) ** 3 - t1 ** 3) / 3) 265 | 266 | {:logarithmic, _} -> 267 | cond do 268 | f0 * f1 <= 0 -> 269 | Nx.broadcast(:nan, t.shape) 270 | 271 | f0 == f1 -> 272 | 2 * pi() * f0 * t 273 | 274 | true -> 275 | beta = t1 / Nx.log(f1 / f0) 276 | 2 * pi() * beta * f0 * ((f1 / f0) ** (t / t1) - 1.0) 277 | end 278 | 279 | {:hyperbolic, _} -> 280 | if f0 == f1 do 281 | 2 * pi() * f0 * t 282 | else 283 | singular_point = -f1 * t1 / (f0 - f1) 284 | 2 * pi() * (-singular_point * f0) * Nx.log(Nx.abs(1 - t / singular_point)) 285 | end 286 | end 287 | 288 | Nx.cos(phase + opts[:phi]) 289 | end 290 | 291 | deftransformp chirp_validate_method(method) do 292 | valid_methods = [:linear, :quadratic, :logarithmic, :hyperbolic] 293 | 294 | if method not in valid_methods do 295 | raise ArgumentError, 296 | "invalid method, must be one of #{inspect(valid_methods)}, got: #{inspect(method)}" 297 | end 298 | 299 | method 300 | end 301 | 302 | @doc """ 303 | Frequency-swept cosine generator, with a time-dependent frequency. 304 | 305 | This function generates a sinusoidal function whose instantaneous 306 | frequency varies with time. The frequency at time `t` is given by 307 | the polynomial specified by the coefficients contained in `coefs`. 
308 | 309 | See also: `chirp/5` 310 | 311 | ## Options 312 | 313 | * `:phi` - phase shift to be applied before calculating the `Nx.cos` 314 | for the output. Defaults to 0. 315 | * `:phi_unit` - determines if `:phi` is given in `:radians` or `:degrees`. 316 | Defaults to `:radians`. 317 | 318 | ## Examples 319 | 320 | iex> t = Nx.linspace(0, 10, n: 5) 321 | iex> NxSignal.Waveforms.polynomial_sweep(t, Nx.tensor([2, 0, 1])) 322 | #Nx.Tensor< 323 | f32[5] 324 | [1.0, 0.866027295589447, -0.500006377696991, 1.7942518752533942e-5, -0.49998921155929565] 325 | > 326 | iex> NxSignal.Waveforms.polynomial_sweep(t, Nx.tensor([2, 0, 1]), phi: :math.pi() / 2) 327 | #Nx.Tensor< 328 | f32[5] 329 | [-4.371138828673793e-8, 0.499999463558197, -0.8660194873809814, 1.0, 0.8660338521003723] 330 | > 331 | iex> NxSignal.Waveforms.polynomial_sweep(t, Nx.tensor([1, 0])) 332 | #Nx.Tensor< 333 | f32[5] 334 | [1.0, 0.7071065306663513, -1.0, 0.7071084976196289, 1.0] 335 | > 336 | iex> NxSignal.Waveforms.polynomial_sweep(t, Nx.tensor([1, 0]), phi: 180, phi_unit: :degrees) 337 | #Nx.Tensor< 338 | f32[5] 339 | [-1.0, -0.7071069478988647, 1.0, -0.7071129679679871, -1.0] 340 | > 341 | """ 342 | @doc type: :waveforms 343 | defn polynomial_sweep(t, coefs, opts \\ []) do 344 | opts = keyword!(opts, phi: 0, phi_unit: :radians) 345 | {n} = Nx.shape(coefs) 346 | # assumes t is of shape {m} 347 | iota = n - Nx.iota({n}) 348 | t_poly = t ** Nx.new_axis(iota, 1) 349 | 350 | int_coefs = coefs / iota 351 | 352 | phase = Nx.dot(int_coefs, t_poly) 353 | 354 | phi = 355 | case {opts[:phi], opts[:phi_unit]} do 356 | {phi, :radians} -> phi 357 | {phi, :degrees} -> phi * pi() / 180 358 | end 359 | 360 | Nx.cos(2 * pi() * phase + phi) 361 | end 362 | 363 | @doc """ 364 | Discrete delta function or unit basis vector. 365 | 366 | ## Options 367 | 368 | * `:index` - one of number, numerical tensor 369 | with length equal to the rank of the given 370 | shape, or `:midpoint`. 
`index: :midpoint`, 371 | is a shortcut for inserting the impulse 372 | at the index which corresponds to half of 373 | each dimension. A scalar index is used for every axis. Defaults to 0. 374 | 375 | * `:type` - datatype for the output. Defaults to `:f32`. 376 | 377 | ## Examples 378 | 379 | iex> NxSignal.Waveforms.unit_impulse({2}) 380 | #Nx.Tensor< 381 | f32[2] 382 | [1.0, 0.0] 383 | > 384 | 385 | iex> NxSignal.Waveforms.unit_impulse({3, 5}, type: :s32, index: :midpoint) 386 | #Nx.Tensor< 387 | s32[3][5] 388 | [ 389 | [0, 0, 0, 0, 0], 390 | [0, 0, 1, 0, 0], 391 | [0, 0, 0, 0, 0] 392 | ] 393 | > 394 | 395 | iex> NxSignal.Waveforms.unit_impulse({3, 5}, index: Nx.tensor([[2, 3]]), type: :s32) 396 | #Nx.Tensor< 397 | s32[3][5] 398 | [ 399 | [0, 0, 0, 0, 0], 400 | [0, 0, 0, 0, 0], 401 | [0, 0, 0, 1, 0] 402 | ] 403 | > 404 | """ 405 | @doc type: :waveforms 406 | deftransform unit_impulse(shape, opts \\ []) do 407 | opts = Keyword.validate!(opts, index: 0, type: :f32) 408 | index = unit_impulse_index(shape, opts[:index]) 409 | 410 | unit_impulse_n(index, Keyword.put(opts, :shape, shape)) 411 | end 412 | 413 | defnp unit_impulse_n(index, opts \\ []) do 414 | shape = opts[:shape] 415 | type = opts[:type] 416 | 417 | zero = Nx.tensor(0, type: type) 418 | 419 | zeros = Nx.broadcast(zero, shape) 420 | 421 | Nx.indexed_put(zeros, index, 1) 422 | end 423 | 424 | deftransformp unit_impulse_index(shape, index) do 425 | n = Nx.rank(shape) 426 | 427 | case index do 428 | :midpoint -> 429 | shape 430 | |> Tuple.to_list() 431 | |> Enum.map(&div(&1, 2)) 432 | |> then(&Nx.tensor(&1)) 433 | # a scalar index (e.g. the default 0) is broadcast to every axis; reshaping it to {n} fails for rank > 1 434 | index -> 435 | if Nx.rank(index) == 0, do: Nx.broadcast(index, {n}), else: Nx.reshape(index, {n}) 436 | end 437 | end 438 | 439 | @doc ~S""" 440 | Calculates the normalized sinc function $sinc(t) = \frac{sin(\pi t)}{\pi t}$ 441 | 442 | ## Examples 443 | 444 | iex> NxSignal.Waveforms.sinc(Nx.tensor([0, 0.25, 1])) 445 | #Nx.Tensor< 446 | f32[3] 447 | [1.0, 0.9003162980079651, -2.7827534054836178e-8] 448 | > 449 | """ 450 | @doc type: :waveforms 451 | defn sinc(t) do 452 | t = t * pi() 453 |
zero_idx = Nx.equal(t, 0) 454 | 455 | # Define sinc(0) = 1 456 | Nx.select(zero_idx, 1, Nx.sin(t) / t) 457 | end 458 | end 459 | -------------------------------------------------------------------------------- /lib/nx_signal/windows.ex: -------------------------------------------------------------------------------- 1 | defmodule NxSignal.Windows do 2 | @moduledoc """ 3 | Common window functions. 4 | """ 5 | import Nx.Defn 6 | 7 | @pi :math.pi() 8 | 9 | @doc """ 10 | Rectangular window. 11 | 12 | Useful for when no window function should be applied. 13 | 14 | ## Options 15 | 16 | * `:type` - the output type. Defaults to `s64` 17 | 18 | ## Examples 19 | 20 | iex> NxSignal.Windows.rectangular(5) 21 | #Nx.Tensor< 22 | s64[5] 23 | [1, 1, 1, 1, 1] 24 | > 25 | 26 | iex> NxSignal.Windows.rectangular(5, type: :f32) 27 | #Nx.Tensor< 28 | f32[5] 29 | [1.0, 1.0, 1.0, 1.0, 1.0] 30 | > 31 | """ 32 | @doc type: :windowing 33 | deftransform rectangular(n, opts \\ []) when is_integer(n) do 34 | opts = Keyword.validate!(opts, type: :s64) 35 | Nx.broadcast(Nx.tensor(1, type: opts[:type]), {n}) 36 | end 37 | 38 | @doc """ 39 | Bartlett triangular window. 40 | 41 | See also: `triangular/1` 42 | 43 | ## Options 44 | 45 | * `:type` - the output type for the window. Defaults to `{:f, 32}` 46 | * `:name` - the axis name. 
Defaults to `nil` 47 | 48 | ## Examples 49 | 50 | iex> NxSignal.Windows.bartlett(3) 51 | #Nx.Tensor< 52 | f32[3] 53 | [0.0, 0.6666666865348816, 0.6666666269302368] 54 | > 55 | """ 56 | @doc type: :windowing 57 | deftransform bartlett(n, opts \\ []) when is_integer(n) do 58 | opts = Keyword.validate!(opts, [:name, type: {:f, 32}]) 59 | bartlett_n(Keyword.put(opts, :n, n)) 60 | end 61 | 62 | defnp bartlett_n(opts) do 63 | n = opts[:n] 64 | name = opts[:name] 65 | type = opts[:type] 66 | 67 | n_on_2 = div(n, 2) 68 | left_size = n_on_2 + rem(n, 2) 69 | left_idx = Nx.iota({left_size}, names: [name], type: type) 70 | right_idx = Nx.iota({n_on_2}, names: [name], type: type) + left_size 71 | 72 | Nx.concatenate([ 73 | left_idx * 2 / n, 74 | 2 - right_idx * 2 / n 75 | ]) 76 | end 77 | 78 | @doc """ 79 | Triangular window. 80 | 81 | See also: `bartlett/1` 82 | 83 | ## Options 84 | 85 | * `n` - the window length, given as the first (positional) argument. 86 | * `:type` - the output type for the window. Defaults to `{:f, 32}` 87 | * `:name` - the axis name.
Defaults to `nil` 88 | 89 | ## Examples 90 | 91 | iex> NxSignal.Windows.triangular(3) 92 | #Nx.Tensor< 93 | f32[3] 94 | [0.5, 1.0, 0.5] 95 | > 96 | """ 97 | @doc type: :windowing 98 | deftransform triangular(n, opts \\ []) when is_integer(n) do 99 | opts = Keyword.validate!(opts, [:name, type: {:f, 32}]) 100 | triangular_n(Keyword.put(opts, :n, n)) 101 | end 102 | 103 | defnp triangular_n(opts) do 104 | n = opts[:n] 105 | name = opts[:name] 106 | type = opts[:type] 107 | 108 | case rem(n, 2) do 109 | 1 -> 110 | # odd case 111 | n_on_2 = div(n + 1, 2) 112 | 113 | idx = Nx.iota({n_on_2}, names: [name], type: type) + 1 114 | 115 | left = idx * 2 / (n + 1) 116 | Nx.concatenate([left, left |> Nx.reverse() |> Nx.slice([1], [Nx.size(left) - 1])]) 117 | 118 | 0 -> 119 | # even case 120 | n_on_2 = div(n + 1, 2) 121 | 122 | idx = Nx.iota({n_on_2}, names: [name], type: type) + 1 123 | 124 | left = (2 * idx - 1) / n 125 | Nx.concatenate([left, Nx.reverse(left)]) 126 | end 127 | end 128 | 129 | @doc """ 130 | Blackman window. 131 | 132 | ## Options 133 | 134 | * `:is_periodic` - If `true`, produces a periodic window, 135 | otherwise produces a symmetric window. Defaults to `true` 136 | * `:type` - the output type for the window. Defaults to `{:f, 32}` 137 | * `:name` - the axis name. 
Defaults to `nil` 138 | 139 | ## Examples 140 | 141 | iex> NxSignal.Windows.blackman(5, is_periodic: false) 142 | #Nx.Tensor< 143 | f32[5] 144 | [-1.4901161193847656e-8, 0.3400000333786011, 0.9999999403953552, 0.3400000333786011, -1.4901161193847656e-8] 145 | > 146 | 147 | iex> NxSignal.Windows.blackman(5, is_periodic: true) 148 | #Nx.Tensor< 149 | f32[5] 150 | [-1.4901161193847656e-8, 0.20077012479305267, 0.8492299318313599, 0.8492299318313599, 0.20077012479305267] 151 | > 152 | 153 | iex> NxSignal.Windows.blackman(6, is_periodic: true, type: {:f, 32}) 154 | #Nx.Tensor< 155 | f32[6] 156 | [-1.4901161193847656e-8, 0.12999999523162842, 0.6299999952316284, 0.9999999403953552, 0.6299999952316284, 0.12999999523162842] 157 | > 158 | """ 159 | @doc type: :windowing 160 | deftransform blackman(n, opts \\ []) when is_integer(n) do 161 | opts = Keyword.validate!(opts, [:name, is_periodic: true, type: {:f, 32}]) 162 | blackman_n(Keyword.put(opts, :n, n)) 163 | end 164 | 165 | defnp blackman_n(opts) do 166 | n = opts[:n] 167 | name = opts[:name] 168 | type = opts[:type] 169 | is_periodic = opts[:is_periodic] 170 | 171 | l = 172 | if is_periodic do 173 | n + 1 174 | else 175 | n 176 | end 177 | 178 | m = 179 | integer_div_ceil(l, 2) 180 | 181 | n = Nx.iota({m}, names: [name], type: type) 182 | 183 | left = 184 | 0.42 - 0.5 * Nx.cos(2 * @pi * n / (l - 1)) + 185 | 0.08 * Nx.cos(4 * @pi * n / (l - 1)) 186 | 187 | window = 188 | if rem(l, 2) == 0 do 189 | Nx.concatenate([left, Nx.reverse(left)]) 190 | else 191 | Nx.concatenate([left, left |> Nx.reverse() |> Nx.slice([1], [Nx.size(left) - 1])]) 192 | end 193 | 194 | if is_periodic do 195 | Nx.slice(window, [0], [Nx.size(window) - 1]) 196 | else 197 | window 198 | end 199 | end 200 | 201 | @doc """ 202 | Hamming window. 203 | 204 | ## Options 205 | 206 | * `:is_periodic` - If `true`, produces a periodic window, 207 | otherwise produces a symmetric window. Defaults to `true` 208 | * `:type` - the output type for the window. 
Defaults to `{:f, 32}` 209 | * `:name` - the axis name. Defaults to `nil` 210 | 211 | ## Examples 212 | 213 | iex> NxSignal.Windows.hamming(5, is_periodic: true) 214 | #Nx.Tensor< 215 | f32[5] 216 | [0.08000001311302185, 0.39785221219062805, 0.9121478796005249, 0.9121478199958801, 0.3978521227836609] 217 | > 218 | iex> NxSignal.Windows.hamming(5, is_periodic: false) 219 | #Nx.Tensor< 220 | f32[5] 221 | [0.08000001311302185, 0.5400000214576721, 1.0, 0.5400000214576721, 0.08000001311302185] 222 | > 223 | """ 224 | @doc type: :windowing 225 | deftransform hamming(n, opts \\ []) when is_integer(n) do 226 | opts = Keyword.validate!(opts, [:name, is_periodic: true, type: {:f, 32}]) 227 | hamming_n(Keyword.put(opts, :n, n)) 228 | end 229 | 230 | defnp hamming_n(opts) do 231 | n = opts[:n] 232 | name = opts[:name] 233 | type = opts[:type] 234 | is_periodic = opts[:is_periodic] 235 | 236 | l = 237 | if is_periodic do 238 | n + 1 239 | else 240 | n 241 | end 242 | 243 | n = Nx.iota({l}, names: [name], type: type) 244 | 245 | window = 0.54 - 0.46 * Nx.cos(2 * @pi * n / (l - 1)) 246 | 247 | if is_periodic do 248 | Nx.slice(window, [0], [l - 1]) 249 | else 250 | window 251 | end 252 | end 253 | 254 | @doc """ 255 | Hann window. 256 | 257 | ## Options 258 | 259 | * `:is_periodic` - If `true`, produces a periodic window, 260 | otherwise produces a symmetric window. Defaults to `true` 261 | * `:type` - the output type for the window. Defaults to `{:f, 32}` 262 | * `:name` - the axis name. 
Defaults to `nil` 263 | 264 | ## Examples 265 | 266 | iex> NxSignal.Windows.hann(5, is_periodic: false) 267 | #Nx.Tensor< 268 | f32[5] 269 | [0.0, 0.5, 1.0, 0.5, 0.0] 270 | > 271 | iex> NxSignal.Windows.hann(5, is_periodic: true) 272 | #Nx.Tensor< 273 | f32[5] 274 | [0.0, 0.34549152851104736, 0.9045085310935974, 0.9045084714889526, 0.3454914391040802] 275 | > 276 | """ 277 | @doc type: :windowing 278 | deftransform hann(n, opts \\ []) when is_integer(n) do 279 | opts = Keyword.validate!(opts, [:name, is_periodic: true, type: {:f, 32}]) 280 | hann_n(Keyword.put(opts, :n, n)) 281 | end 282 | 283 | defnp hann_n(opts) do 284 | n = opts[:n] 285 | name = opts[:name] 286 | type = opts[:type] 287 | is_periodic = opts[:is_periodic] 288 | 289 | l = 290 | if is_periodic do 291 | n + 1 292 | else 293 | n 294 | end 295 | 296 | n = Nx.iota({l}, names: [name], type: type) 297 | 298 | window = 0.5 * (1 - Nx.cos(2 * @pi * n / (l - 1))) 299 | 300 | if is_periodic do 301 | Nx.slice(window, [0], [l - 1]) 302 | else 303 | window 304 | end 305 | end 306 | 307 | @doc """ 308 | Creates a Kaiser window of size `n`. 309 | 310 | The Kaiser window is a taper formed by using a Bessel function. 311 | 312 | ## Options 313 | 314 | * `:is_periodic` - If `true`, produces a periodic window, 315 | otherwise produces a symmetric window. Defaults to `true` 316 | * `:type` - the output type for the window. Defaults to `{:f, 32}` 317 | * `:beta` - Shape parameter for the window; it sets the trade-off between main-lobe width and side-lobe level. As beta increases, the window narrows. Defaults to 12.0. 318 | * `:eps` - Lower bound applied to the square-root argument `1 - x**2`, to avoid division by zero. Defaults to 1.0e-7. 319 | * `:name` - the axis name.
Defaults to `nil` 320 | 321 | ## Examples 322 | iex> NxSignal.Windows.kaiser(4, beta: 12.0, is_periodic: true) 323 | #Nx.Tensor< 324 | f32[4] 325 | [5.2776191296288744e-5, 0.21566666662693024, 1.0, 0.21566666662693024] 326 | > 327 | 328 | iex> NxSignal.Windows.kaiser(5, beta: 12.0, is_periodic: true) 329 | #Nx.Tensor< 330 | f32[5] 331 | [5.2776191296288744e-5, 0.10171464085578918, 0.7929369807243347, 0.7929369807243347, 0.10171464085578918] 332 | > 333 | 334 | iex> NxSignal.Windows.kaiser(4, beta: 12.0, is_periodic: false) 335 | #Nx.Tensor< 336 | f32[4] 337 | [5.2776191296288744e-5, 0.5188394784927368, 0.5188390612602234, 5.2776191296288744e-5] 338 | > 339 | """ 340 | @doc type: :windowing 341 | deftransform kaiser(n, opts \\ []) when is_integer(n) do 342 | opts = 343 | Keyword.validate!(opts, [:name, eps: 1.0e-7, beta: 12.0, is_periodic: true, type: {:f, 32}]) 344 | 345 | kaiser_n(Keyword.put(opts, :n, n)) 346 | end 347 | 348 | defnp kaiser_n(opts) do 349 | n = opts[:n] 350 | name = opts[:name] 351 | type = opts[:type] 352 | beta = opts[:beta] 353 | eps = opts[:eps] 354 | is_periodic = opts[:is_periodic] 355 | 356 | window_length = if is_periodic, do: n + 1, else: n 357 | 358 | ratio = Nx.linspace(-1, 1, n: window_length, endpoint: true, type: type, name: name) 359 | sqrt_arg = Nx.max(1 - ratio ** 2, eps) 360 | r = beta * Nx.sqrt(sqrt_arg) 361 | 362 | window = kaiser_bessel_i0(r) / kaiser_bessel_i0(beta) 363 | 364 | if is_periodic do 365 | Nx.slice(window, [0], [n]) 366 | else 367 | window 368 | end 369 | end 370 | 371 | defnp kaiser_bessel_i0(x) do 372 | abs_x = Nx.abs(x) 373 | 374 | small_x_result = 375 | 1 + 376 | abs_x ** 2 / 4 + 377 | abs_x ** 4 / 64 + 378 | abs_x ** 6 / 2304 + 379 | abs_x ** 8 / 147_456 380 | 381 | large_x_result = 382 | Nx.exp(abs_x) / Nx.sqrt(2 * Nx.Constants.pi() * abs_x) * 383 | (1 + 1 / (8 * abs_x) + 9 / (128 * Nx.pow(abs_x, 2))) 384 | 385 | Nx.select(abs_x < 3.75, small_x_result, large_x_result) 386 | end 387 | 388 | deftransformp 
integer_div_ceil(num, den) when is_integer(num) and is_integer(den) do 389 | rem = rem(num, den) 390 | 391 | if rem == 0 do 392 | div(num, den) 393 | else 394 | div(num, den) + 1 395 | end 396 | end 397 | end 398 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule NxSignal.MixProject do 2 | use Mix.Project 3 | 4 | @source_url "https://github.com/elixir-nx/nx_signal" 5 | @version "0.2.0" 6 | 7 | def project do 8 | [ 9 | app: :nx_signal, 10 | version: @version, 11 | elixir: "~> 1.14", 12 | start_permanent: Mix.env() == :prod, 13 | elixirc_paths: elixirc_paths(Mix.env()), 14 | deps: deps(), 15 | docs: docs(), 16 | name: "NxSignal", 17 | description: "Digital Signal Processing extension for Nx", 18 | package: package(), 19 | preferred_cli_env: [ 20 | docs: :docs, 21 | "hex.publish": :docs 22 | ] 23 | ] 24 | end 25 | 26 | defp elixirc_paths(:test), do: ["test/support", "lib"] 27 | defp elixirc_paths(_), do: ["lib"] 28 | 29 | # Run "mix help compile.app" to learn about applications. 30 | def application do 31 | [ 32 | extra_applications: [:logger] 33 | ] 34 | end 35 | 36 | defp docs do 37 | [ 38 | main: "NxSignal", 39 | source_url_pattern: "#{@source_url}/blob/v#{@version}/nx-signal/%{path}#L%{line}", 40 | before_closing_body_tag: &before_closing_body_tag/1, 41 | extras: [ 42 | "guides/filtering.livemd", 43 | "guides/spectrogram.livemd" 44 | ], 45 | groups_for_extras: [ 46 | Guides: Path.wildcard("guides/*.livemd") 47 | ], 48 | groups_for_functions: [ 49 | "Functions: Time-Frequency": &(&1[:type] == :time_frequency), 50 | "Functions: Windowing": &(&1[:type] == :windowing), 51 | "Functions: Filters": &(&1[:type] == :filters), 52 | "Functions: Waveforms": &(&1[:type] == :waveforms) 53 | ] 54 | ] 55 | end 56 | 57 | # Run "mix help deps" to learn about dependencies. 
58 | defp deps do 59 | [ 60 | {:nx, github: "elixir-nx/nx", sparse: "nx"}, 61 | {:ex_doc, "~> 0.29", only: :docs} 62 | ] 63 | end 64 | 65 | defp package do 66 | [ 67 | maintainers: ["Paulo Valente"], 68 | licenses: ["Apache-2.0"], 69 | links: %{"GitHub" => @source_url} 70 | ] 71 | end 72 | 73 | defp before_closing_body_tag(:html) do 74 | """ 75 | 76 | 77 | 78 | 88 | """ 89 | end 90 | 91 | defp before_closing_body_tag(_), do: "" 92 | end 93 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"}, 3 | "earmark_parser": {:hex, :earmark_parser, "1.4.29", "149d50dcb3a93d9f3d6f3ecf18c918fb5a2d3c001b5d3305c926cddfbd33355b", [:mix], [], "hexpm", "4902af1b3eb139016aed210888748db8070b8125c2342ce3dcae4f38dcc63503"}, 4 | "elixir_make": {:hex, :elixir_make, "0.7.3", "c37fdae1b52d2cc51069713a58c2314877c1ad40800a57efb213f77b078a460d", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "24ada3e3996adbed1fa024ca14995ef2ba3d0d17b678b0f3f2b1f66e6ce2b274"}, 5 | "ex_doc": {:hex, :ex_doc, "0.29.1", "b1c652fa5f92ee9cf15c75271168027f92039b3877094290a75abcaac82a9f77", [:mix], [{:earmark_parser, "~> 1.4.19", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "b7745fa6374a36daf484e2a2012274950e084815b936b1319aeebcf7809574f6"}, 6 | "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", 
"0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, 7 | "makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"}, 8 | "makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"}, 9 | "nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"}, 10 | "nx": {:git, "https://github.com/elixir-nx/nx.git", "c6fc98df9ff36b27727c800524b36e92f56dd7ba", [sparse: "nx"]}, 11 | "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, 12 | "xla": {:hex, :xla, "0.4.3", "cf6201aaa44d990298996156a83a16b9a87c5fbb257758dbf4c3e83c5e1c4b96", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "caae164b56dcaec6fbcabcd7dea14303afde07623b0cfa4a3cd2576b923105f5"}, 13 | } 14 | -------------------------------------------------------------------------------- /test/nx_signal/convolutions_test.exs: -------------------------------------------------------------------------------- 1 | defmodule NxSignal.ConvolutionTest do 2 | use NxSignal.Case, async: true, validate_doc_metadata: false 3 | doctest NxSignal.Convolution 4 | 5 | describe "convolve/3" do 6 | # These tests were adapted from 
https://github.com/numpy/numpy/blob/v2.1.0/numpy/_core/tests/test_numeric.py#L3573 7 | test "numpy object" do 8 | d = Nx.tensor(List.duplicate(1.0, 100)) 9 | k = Nx.tensor(List.duplicate(1.0, 3)) 10 | c = NxSignal.Convolution.convolve(d, k)[2..-3//1] 11 | o = Nx.tensor(List.duplicate(3, 98)) 12 | assert_all_close(c, o) 13 | end 14 | 15 | # These tests were adapted from https://github.com/scipy/scipy/blob/v1.14.1/scipy/signal/tests/test_signaltools.py 16 | test "basic" do 17 | a = Nx.tensor([3, 4, 5, 6, 5, 4]) 18 | b = Nx.tensor([1, 2, 3]) 19 | c = NxSignal.Convolution.convolve(a, b, mode: :full) 20 | assert c == Nx.as_type(Nx.tensor([3, 10, 22, 28, 32, 32, 23, 12]), {:f, 32}) 21 | end 22 | 23 | test "same" do 24 | a = Nx.tensor([3, 4, 5]) 25 | b = Nx.tensor([1, 2, 3, 4]) 26 | c = NxSignal.Convolution.convolve(a, b, mode: :same) 27 | assert c == Nx.as_type(Nx.tensor([10, 22, 34]), {:f, 32}) 28 | end 29 | 30 | test "same eq" do 31 | a = Nx.tensor([3, 4, 5]) 32 | b = Nx.tensor([1, 2, 3]) 33 | c = NxSignal.Convolution.convolve(a, b, mode: :same) 34 | assert c == Nx.as_type(Nx.tensor([10, 22, 22]), {:f, 32}) 35 | end 36 | 37 | test "complex" do 38 | a = Nx.tensor([Complex.new(1, 1), Complex.new(2, 1), Complex.new(3, 1)]) 39 | b = Nx.tensor([Complex.new(1, 1), Complex.new(2, 1)]) 40 | c = NxSignal.Convolution.convolve(a, b) 41 | 42 | assert c == 43 | Nx.tensor([ 44 | Complex.new(0, 2), 45 | Complex.new(2, 6), 46 | Complex.new(5, 8), 47 | Complex.new(5, 5) 48 | ]) 49 | end 50 | 51 | test "zero rank" do 52 | a = Nx.tensor(1289) 53 | b = Nx.tensor(4567) 54 | c = NxSignal.Convolution.convolve(a, b) 55 | assert c == Nx.as_type(Nx.multiply(a, b), {:f, 32}) 56 | end 57 | 58 | test "complex simple" do 59 | a = Nx.tensor([Complex.new(1, 1)]) 60 | b = Nx.tensor([Complex.new(3, 4)]) 61 | c = NxSignal.Convolution.convolve(a, b) 62 | assert c == Nx.tensor([Complex.new(-1, 7)]) 63 | end 64 | 65 | test "fft_nd" do 66 | a = Nx.tensor([[1, 2, 3], [4, 5, 6]]) 67 | c = 
NxSignal.Transforms.fft_nd(a, axes: [0, 1], lengths: [2, 3]) 68 | 69 | z = 70 | Nx.tensor([[21, Complex.new(-3, 1.732), Complex.new(-3, -1.732)], [-9, 0, 0]]) 71 | 72 | assert_all_close( 73 | c, 74 | z 75 | ) 76 | end 77 | 78 | test "fft_nd with padding" do 79 | a = Nx.tensor([[1, 2, 3], [4, 5, 6]]) 80 | c = NxSignal.Transforms.fft_nd(a, axes: [0, 1], lengths: [3, 3]) 81 | 82 | z = 83 | Nx.tensor([ 84 | [2.1e1, Complex.new(-3, 1.732), Complex.new(-3, -1.732)], 85 | [Complex.new(-1.5, -12.99), Complex.new(-1.11e-16, 1.732), Complex.new(-1.5, 0.866)], 86 | [Complex.new(-1.5, 12.99), Complex.new(-1.5, -0.866), Complex.new(-1.11e-16, -1.732)] 87 | ]) 88 | 89 | assert_all_close( 90 | c, 91 | z 92 | ) 93 | end 94 | 95 | test "broadcastable" do 96 | a = Nx.iota({3, 3, 3}) 97 | b = Nx.iota({1, 1, 3}) 98 | 99 | x = NxSignal.Convolution.convolve(a, b, method: :direct) 100 | y = NxSignal.Convolution.convolve(a, b, method: :fft) 101 | 102 | expected = 103 | Nx.tensor([ 104 | [[0, 0, 1, 4, 4], [0, 3, 10, 13, 10], [0, 6, 19, 22, 16]], 105 | [[0, 9, 28, 31, 22], [0, 12, 37, 40, 28], [0, 15, 46, 49, 34]], 106 | [[0, 18, 55, 58, 40], [0, 21, 64, 67, 46], [0, 24, 73, 76, 52]] 107 | ]) 108 | 109 | assert_all_close(x, expected) 110 | assert_all_close(y, expected) 111 | 112 | b = Nx.reshape(b, {1, 3, 1}) 113 | 114 | x = NxSignal.Convolution.convolve(a, b, method: :direct) 115 | y = NxSignal.Convolution.convolve(a, b, method: :fft) 116 | 117 | expected = 118 | Nx.tensor([ 119 | [[0, 0, 0], [0, 1, 2], [3, 6, 9], [12, 15, 18], [12, 14, 16]], 120 | [[0, 0, 0], [9, 10, 11], [30, 33, 36], [39, 42, 45], [30, 32, 34]], 121 | [[0, 0, 0], [18, 19, 20], [57, 60, 63], [66, 69, 72], [48, 50, 52]] 122 | ]) 123 | 124 | assert_all_close(x, expected) 125 | assert_all_close(y, expected) 126 | 127 | b = Nx.reshape(b, {3, 1, 1}) 128 | 129 | x = NxSignal.Convolution.convolve(a, b, method: :direct) 130 | y = NxSignal.Convolution.convolve(a, b, method: :fft) 131 | 132 | expected = 133 | Nx.tensor([ 134 | 
[[0, 0, 0], [0, 0, 0], [0, 0, 0]], 135 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]], 136 | [[9, 12, 15], [18, 21, 24], [27, 30, 33]], 137 | [[36, 39, 42], [45, 48, 51], [54, 57, 60]], 138 | [[36, 38, 40], [42, 44, 46], [48, 50, 52]] 139 | ]) 140 | 141 | assert_all_close(x, expected) 142 | assert_all_close(y, expected) 143 | end 144 | 145 | test "single element" do 146 | a = Nx.tensor([4967]) 147 | b = Nx.tensor([3920]) 148 | c = NxSignal.Convolution.convolve(a, b) 149 | assert c == Nx.as_type(Nx.multiply(a, b), {:f, 32}) 150 | end 151 | 152 | test "2d arrays" do 153 | a = Nx.tensor([[1, 2, 3], [3, 4, 5]]) 154 | b = Nx.tensor([[2, 3, 4], [4, 5, 6]]) 155 | c = NxSignal.Convolution.convolve(a, b) 156 | 157 | d = 158 | Nx.tensor([[2, 7, 16, 17, 12], [10, 30, 62, 58, 38], [12, 31, 58, 49, 30]]) 159 | |> Nx.as_type({:f, 32}) 160 | 161 | assert c == d 162 | end 163 | 164 | test "input swapping" do 165 | small = 166 | 0..(8 - 1) 167 | |> Enum.to_list() 168 | |> Nx.tensor() 169 | |> Nx.reshape({2, 2, 2}) 170 | 171 | big = 172 | 0..(27 - 1) 173 | |> Enum.to_list() 174 | |> Enum.map(&Complex.new(0, &1)) 175 | |> Nx.tensor() 176 | |> Nx.reshape({3, 3, 3}) 177 | 178 | big_add = 179 | 0..(27 - 1) 180 | |> Enum.to_list() 181 | |> Enum.reverse() 182 | |> Nx.tensor() 183 | |> Nx.reshape({3, 3, 3}) 184 | 185 | big = Nx.add(big, big_add) 186 | 187 | out_array = 188 | [ 189 | [ 190 | [Complex.new(0, 0), Complex.new(26, 0), Complex.new(25, 1), Complex.new(24, 2)], 191 | [Complex.new(52, 0), Complex.new(151, 5), Complex.new(145, 11), Complex.new(93, 11)], 192 | [Complex.new(46, 6), Complex.new(133, 23), Complex.new(127, 29), Complex.new(81, 23)], 193 | [Complex.new(40, 12), Complex.new(98, 32), Complex.new(93, 37), Complex.new(54, 24)] 194 | ], 195 | [ 196 | [ 197 | Complex.new(104, 0), 198 | Complex.new(247, 13), 199 | Complex.new(237, 23), 200 | Complex.new(135, 21) 201 | ], 202 | [ 203 | Complex.new(282, 30), 204 | Complex.new(632, 96), 205 | Complex.new(604, 124), 206 | Complex.new(330, 86) 
207 | ], 208 | [ 209 | Complex.new(246, 66), 210 | Complex.new(548, 180), 211 | Complex.new(520, 208), 212 | Complex.new(282, 134) 213 | ], 214 | [ 215 | Complex.new(142, 66), 216 | Complex.new(307, 161), 217 | Complex.new(289, 179), 218 | Complex.new(153, 107) 219 | ] 220 | ], 221 | [ 222 | [ 223 | Complex.new(68, 36), 224 | Complex.new(157, 103), 225 | Complex.new(147, 113), 226 | Complex.new(81, 75) 227 | ], 228 | [ 229 | Complex.new(174, 138), 230 | Complex.new(380, 348), 231 | Complex.new(352, 376), 232 | Complex.new(186, 230) 233 | ], 234 | [ 235 | Complex.new(138, 174), 236 | Complex.new(296, 432), 237 | Complex.new(268, 460), 238 | Complex.new(138, 278) 239 | ], 240 | [ 241 | Complex.new(70, 138), 242 | Complex.new(145, 323), 243 | Complex.new(127, 341), 244 | Complex.new(63, 197) 245 | ] 246 | ], 247 | [ 248 | [ 249 | Complex.new(32, 72), 250 | Complex.new(68, 166), 251 | Complex.new(59, 175), 252 | Complex.new(30, 100) 253 | ], 254 | [ 255 | Complex.new(68, 192), 256 | Complex.new(139, 433), 257 | Complex.new(117, 455), 258 | Complex.new(57, 255) 259 | ], 260 | [ 261 | Complex.new(38, 222), 262 | Complex.new(73, 499), 263 | Complex.new(51, 521), 264 | Complex.new(21, 291) 265 | ], 266 | [ 267 | Complex.new(12, 144), 268 | Complex.new(20, 318), 269 | Complex.new(7, 331), 270 | Complex.new(0, 182) 271 | ] 272 | ] 273 | ] 274 | |> Nx.tensor() 275 | 276 | assert NxSignal.Convolution.convolve(small, big, mode: :full) == out_array 277 | assert NxSignal.Convolution.convolve(big, small, mode: :full) == out_array 278 | 279 | assert NxSignal.Convolution.convolve(small, big, mode: :same) == 280 | out_array[[1..2, 1..2, 1..2]] 281 | 282 | assert NxSignal.Convolution.convolve(big, small, mode: :same) == 283 | out_array[[0..2, 0..2, 0..2]] 284 | 285 | assert NxSignal.Convolution.convolve(small, big, mode: :valid) == 286 | out_array[[1..2, 1..2, 1..2]] 287 | 288 | assert NxSignal.Convolution.convolve(big, small, mode: :valid) == 289 | out_array[[1..2, 1..2, 1..2]] 290 | 
end 291 | 292 | test "invalid params" do 293 | a = Nx.tensor([3, 4, 5]) 294 | b = Nx.tensor([1, 2, 3]) 295 | 296 | assert_raise( 297 | ArgumentError, 298 | "expected mode to be one of [:full, :same, :valid], got: :spam", 299 | fn -> 300 | NxSignal.Convolution.convolve(a, b, mode: :spam) 301 | end 302 | ) 303 | 304 | assert_raise( 305 | ArgumentError, 306 | "expected mode to be one of [:full, :same, :valid], got: :eggs", 307 | fn -> 308 | NxSignal.Convolution.convolve(a, b, mode: :eggs, method: :fft) 309 | end 310 | ) 311 | 312 | assert_raise( 313 | ArgumentError, 314 | "expected mode to be one of [:full, :same, :valid], got: :ham", 315 | fn -> 316 | NxSignal.Convolution.convolve(a, b, mode: :ham, method: :direct) 317 | end 318 | ) 319 | 320 | assert_raise( 321 | ArgumentError, 322 | "expected method to be one of [:direct, :fft], got: :bacon", 323 | fn -> 324 | NxSignal.Convolution.convolve(a, b, mode: :full, method: :bacon) 325 | end 326 | ) 327 | 328 | assert_raise( 329 | ArgumentError, 330 | "expected method to be one of [:direct, :fft], got: :bacon", 331 | fn -> 332 | NxSignal.Convolution.convolve(a, b, mode: :same, method: :bacon) 333 | end 334 | ) 335 | end 336 | 337 | test "valid mode 2.1" do 338 | a = Nx.tensor([1, 2, 3, 6, 5, 3]) 339 | b = Nx.tensor([2, 3, 4, 5, 3, 4, 2, 2, 1]) 340 | expected = Nx.tensor([70, 78, 73, 65]) |> Nx.as_type({:f, 32}) 341 | 342 | out = NxSignal.Convolution.convolve(a, b, mode: :valid) 343 | assert out == expected 344 | 345 | out = NxSignal.Convolution.convolve(b, a, mode: :valid) 346 | assert out == expected 347 | end 348 | 349 | test "valid mode 2.2" do 350 | a = Nx.tensor([Complex.new(1, 5), Complex.new(2, -1), Complex.new(3, 0)]) 351 | b = Nx.tensor([Complex.new(2, -3), Complex.new(1, 0)]) 352 | expected = Nx.tensor([Complex.new(2, -3), Complex.new(8, -10)]) 353 | 354 | out = NxSignal.Convolution.convolve(a, b, mode: :valid) 355 | assert out == expected 356 | 357 | out = NxSignal.Convolution.convolve(b, a, mode: :valid) 358 | 
assert out == expected 359 | end 360 | 361 | test "same mode" do 362 | a = Nx.tensor([1, 2, 3, 3, 1, 2]) 363 | b = Nx.tensor([1, 4, 3, 4, 5, 6, 7, 4, 3, 2, 1, 1, 3]) 364 | 365 | c = NxSignal.Convolution.convolve(a, b, mode: :same) 366 | d = Nx.tensor([57, 61, 63, 57, 45, 36]) |> Nx.as_type({:f, 32}) 367 | assert c == d 368 | end 369 | 370 | test "invalid shapes" do 371 | a = 372 | 1..6 373 | |> Enum.to_list() 374 | |> Nx.tensor() 375 | |> Nx.reshape({2, 3}) 376 | 377 | b = 378 | -6..-1 379 | |> Enum.to_list() 380 | |> Nx.tensor() 381 | |> Nx.reshape({3, 2}) 382 | 383 | assert_raise(ArgumentError, fn -> 384 | NxSignal.Convolution.convolve(a, b, mode: :valid) 385 | end) 386 | 387 | assert_raise(ArgumentError, fn -> 388 | NxSignal.Convolution.convolve(b, a, mode: :valid) 389 | end) 390 | end 391 | 392 | test "don't complexify" do 393 | a = Nx.tensor([1, 2, 3]) 394 | b = Nx.tensor([4, 5, 6]) 395 | types = [{:f, 32}, {:c, 64}] 396 | 397 | for t1 <- types, t2 <- types do 398 | aT = Nx.as_type(a, t1) 399 | bT = Nx.as_type(b, t2) 400 | 401 | outD = NxSignal.Convolution.convolve(aT, bT, method: :direct) 402 | outF = NxSignal.Convolution.convolve(aT, bT, method: :fft) 403 | 404 | assert_all_close(outD, outF) 405 | 406 | case {t1, t2} do 407 | {{ts1, _}, {ts2, _}} when ts1 == :c or ts2 == :c -> 408 | assert {:c, 64} == Nx.type(outF) 409 | assert {:c, 64} == Nx.type(outD) 410 | 411 | _el -> 412 | assert {:f, 32} == Nx.type(outF) 413 | assert {:f, 32} == Nx.type(outD) 414 | end 415 | end 416 | end 417 | 418 | test "mismatched dims" do 419 | assert_raise(ArgumentError, fn -> 420 | NxSignal.Convolution.convolve(Nx.tensor([1]), Nx.tensor(2), method: :direct) 421 | end) 422 | 423 | assert_raise(ArgumentError, fn -> 424 | NxSignal.Convolution.convolve(Nx.tensor(1), Nx.tensor([2]), method: :direct) 425 | end) 426 | 427 | assert_raise(ArgumentError, fn -> 428 | NxSignal.Convolution.convolve(Nx.tensor([1]), Nx.tensor(2), method: :fft) 429 | end) 430 | 431 | assert_raise(ArgumentError, 
fn -> 432 | NxSignal.Convolution.convolve(Nx.tensor(1), Nx.tensor([2]), method: :fft) 433 | end) 434 | 435 | assert_raise(ArgumentError, fn -> 436 | NxSignal.Convolution.convolve(Nx.tensor([1]), Nx.tensor([[2]])) 437 | end) 438 | 439 | assert_raise(ArgumentError, fn -> 440 | NxSignal.Convolution.convolve(Nx.tensor([3]), Nx.tensor(2)) 441 | end) 442 | end 443 | 444 | test "2d valid mode" do 445 | e = Nx.tensor([[2, 3, 4, 5, 6, 7, 8], [4, 5, 6, 7, 8, 9, 10]]) 446 | f = Nx.tensor([[1, 2, 3], [3, 4, 5]]) 447 | h = Nx.tensor([[62, 80, 98, 116, 134]]) |> Nx.as_type({:f, 32}) 448 | g = NxSignal.Convolution.convolve(e, f, mode: :valid) 449 | assert g == h 450 | 451 | g = NxSignal.Convolution.convolve(f, e, mode: :valid) 452 | assert g == h 453 | end 454 | 455 | test "FFT real" do 456 | a = Nx.tensor([1, 2, 3]) 457 | expected = Nx.tensor([1, 4, 10, 12, 9.0]) 458 | out = NxSignal.Convolution.convolve(a, a, method: :fft) 459 | assert_all_close(out, expected) 460 | end 461 | 462 | # test "FFT real axes" do 463 | # # This test relies on specifying axes to convolve which we don't support. 
464 | # a = Nx.tensor([1, 2, 3]) 465 | # expected = Nx.tensor([1, 4, 10, 12, 9.0]) 466 | 467 | # a = Nx.tile(a, [2, 1]) 468 | # expected = Nx.tile(expected, [2, 1]) 469 | # out = NxSignal.Convolution.convolve(a, a, method: :fft) 470 | # assert_all_close(out, expected) 471 | # end 472 | 473 | test "FFT complex" do 474 | a = Nx.tensor([Complex.new(1, 1), Complex.new(2, 2), Complex.new(3, 3)]) 475 | 476 | expected = 477 | Nx.tensor([ 478 | Complex.new(0, 2), 479 | Complex.new(0, 8), 480 | Complex.new(0, 20), 481 | Complex.new(0, 24), 482 | Complex.new(0, 18) 483 | ]) 484 | 485 | out = NxSignal.Convolution.convolve(a, a, method: :fft) 486 | assert_all_close(out, expected) 487 | end 488 | 489 | test "FFT 2d real same" do 490 | a = Nx.tensor([[1, 2, 3], [4, 5, 6]]) 491 | expected = Nx.tensor([[1, 4, 10, 12, 9], [8, 26, 56, 54, 36], [16, 40, 73, 60, 36]]) 492 | out = NxSignal.Convolution.convolve(a, a, method: :fft) 493 | assert_all_close(out, expected) 494 | end 495 | 496 | test "FFT 2d complex same" do 497 | a = 498 | Nx.tensor([ 499 | [Complex.new(1, 2), Complex.new(3, 4), Complex.new(5, 6)], 500 | [Complex.new(2, 1), Complex.new(4, 3), Complex.new(6, 5)] 501 | ]) 502 | 503 | expected = 504 | Nx.tensor([ 505 | [ 506 | Complex.new(-3, 4), 507 | Complex.new(-10, 20), 508 | Complex.new(-21, 56), 509 | Complex.new(-18, 76), 510 | Complex.new(-11, 60) 511 | ], 512 | [ 513 | Complex.new(0, 10), 514 | Complex.new(0, 44), 515 | Complex.new(0, 118), 516 | Complex.new(0, 156), 517 | Complex.new(0, 122) 518 | ], 519 | [ 520 | Complex.new(3, 4), 521 | Complex.new(10, 20), 522 | Complex.new(21, 56), 523 | Complex.new(18, 76), 524 | Complex.new(11, 60) 525 | ] 526 | ]) 527 | 528 | out = NxSignal.Convolution.convolve(a, a, method: :fft) 529 | assert_all_close(out, expected) 530 | end 531 | 532 | test "FFT real same mode" do 533 | a = Nx.tensor([1, 2, 3]) 534 | b = Nx.tensor([3, 3, 5, 6, 8, 7, 9, 0, 1]) 535 | expected_1 = Nx.tensor([35.0, 41.0, 47.0]) 536 | expected_2 = 
Nx.tensor([9.0, 20.0, 25.0, 35.0, 41.0, 47.0, 39.0, 28.0, 2.0]) 537 | 538 | out = NxSignal.Convolution.convolve(a, b, method: :fft, mode: :same) 539 | 540 | assert_all_close(out, expected_1) 541 | 542 | out = NxSignal.Convolution.convolve(b, a, method: :fft, mode: :same) 543 | 544 | assert_all_close(out, expected_2) 545 | end 546 | 547 | test "FFT valid mode real" do 548 | a = Nx.tensor([3, 2, 1]) 549 | b = Nx.tensor([3, 3, 5, 6, 8, 7, 9, 0, 1]) 550 | 551 | expected = Nx.tensor([24.0, 31.0, 41.0, 43.0, 49.0, 25.0, 12.0]) 552 | 553 | out = NxSignal.Convolution.convolve(a, b, method: :fft, mode: :valid) 554 | 555 | assert_all_close(out, expected) 556 | 557 | out = NxSignal.Convolution.convolve(b, a, method: :fft, mode: :valid) 558 | 559 | assert_all_close(out, expected) 560 | end 561 | end 562 | 563 | describe "correlate/3" do 564 | def setup_rank1() do 565 | a = Nx.linspace(0, 3, n: 4) 566 | b = Nx.linspace(1, 2, n: 2) 567 | 568 | y = Nx.tensor([0, 2, 5, 8, 3]) 569 | 570 | {a, b, y} 571 | end 572 | 573 | test "rank 1 valid" do 574 | {a, b, y_r} = setup_rank1() 575 | y = NxSignal.Convolution.correlate(a, b, mode: :valid) 576 | assert_all_close(y, y_r[1..3]) 577 | 578 | y = NxSignal.Convolution.correlate(b, a, mode: :valid) 579 | assert_all_close(y, Nx.reverse(y_r[1..3], axes: [0])) 580 | end 581 | 582 | test "rank 1 same" do 583 | {a, b, y_r} = setup_rank1() 584 | y = NxSignal.Convolution.correlate(a, b, mode: :same) 585 | assert_all_close(y, y_r[0..-2//1]) 586 | end 587 | 588 | test "rank 1 full" do 589 | {a, b, y_r} = setup_rank1() 590 | y = NxSignal.Convolution.correlate(a, b, mode: :full) 591 | assert_all_close(y, y_r) 592 | end 593 | 594 | defp setup_rank1_complex(mode) do 595 | key = Nx.Random.key(9) 596 | 597 | {a, key} = Nx.Random.normal(key, shape: {10}, type: :c64) 598 | {a2, key} = Nx.Random.normal(key, shape: {10}, type: :c64) 599 | a = Nx.add(a, Nx.multiply(Complex.new(0, 1), a2)) 600 | 601 | {b, key} = Nx.Random.normal(key, shape: {8}, type: :c64) 602 | 
{b2, key} = Nx.Random.normal(key, shape: {8}, type: :c64) 603 | b = Nx.add(b, Nx.multiply(Complex.new(0, 1), b2)) 604 | 605 | y_r = 606 | Nx.add( 607 | NxSignal.Convolution.correlate(Nx.real(a), Nx.real(b), mode: mode), 608 | NxSignal.Convolution.correlate(Nx.imag(a), Nx.imag(b), mode: mode) 609 | ) 610 | 611 | y_r = 612 | Nx.add( 613 | y_r, 614 | Nx.multiply( 615 | Complex.new(0, 1), 616 | Nx.add( 617 | Nx.multiply(-1, NxSignal.Convolution.correlate(Nx.real(a), Nx.imag(b), mode: mode)), 618 | NxSignal.Convolution.correlate(Nx.imag(a), Nx.real(b), mode: mode) 619 | ) 620 | ) 621 | ) 622 | 623 | {key, a, b, y_r} 624 | end 625 | 626 | test "complex rank 1 valid" do 627 | {_key, a, b, y_r} = setup_rank1_complex(:valid) 628 | y = NxSignal.Convolution.correlate(a, b, mode: :valid) 629 | assert_all_close(y, y_r) 630 | end 631 | end 632 | 633 | # describe "oaconvolve/3" do 634 | # def gen_oa_shapes_eq(list) do 635 | # for a <- list, b <- list, a >= b do 636 | # {a, b} 637 | # end 638 | # end 639 | 640 | # test "real many lens" do 641 | # inputs = gen_oa_shapes_eq(Enum.to_list(0..99) ++ Enum.to_list(100..999//23)) 642 | # key = Nx.Random.key(123) 643 | 644 | # for {shape_a, shape_b} <- inputs do 645 | # {a, key} = Nx.Random.uniform(key, 0, 1, shape: {shape_a}) 646 | # {b, key} = Nx.Random.uniform(key, 0, 1, shape: {shape_b}) 647 | 648 | # expected = convolve(a, b, method: "fft") 649 | # out = oaconvolve(a, b) 650 | 651 | # assert_all_close(expected, out) 652 | # end 653 | # end 654 | # end 655 | end 656 | -------------------------------------------------------------------------------- /test/nx_signal/filters_test.exs: -------------------------------------------------------------------------------- 1 | defmodule NxSignal.FiltersTest do 2 | use NxSignal.Case 3 | doctest NxSignal.Filters 4 | 5 | describe "median/2" do 6 | test "performs 1D median filter" do 7 | t = Nx.tensor([10, 9, 8, 7, 1, 4, 5, 3, 2, 6]) 8 | opts = [kernel_shape: {3}] 9 | expected = Nx.tensor([9.0, 8.0, 
7.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0]) 10 | 11 | assert NxSignal.Filters.median(t, opts) == expected 12 | end 13 | 14 | test "performs 2D median filter" do 15 | t = 16 | Nx.tensor([ 17 | [31, 11, 17, 13, 1], 18 | [1, 3, 19, 23, 29], 19 | [19, 5, 7, 37, 2] 20 | ]) 21 | 22 | opts = [kernel_shape: {3, 3}] 23 | 24 | expected = 25 | Nx.tensor([ 26 | [11.0, 13.0, 17.0, 17.0, 17.0], 27 | [11.0, 13.0, 17.0, 17.0, 17.0], 28 | [11.0, 13.0, 17.0, 17.0, 17.0] 29 | ]) 30 | 31 | assert NxSignal.Filters.median(t, opts) == expected 32 | end 33 | 34 | test "performs n-dim median filter" do 35 | t = 36 | Nx.tensor([ 37 | [ 38 | [31, 11, 17, 13, 1], 39 | [1, 3, 19, 23, 29], 40 | [19, 5, 7, 37, 2] 41 | ], 42 | [ 43 | [19, 5, 7, 37, 2], 44 | [1, 3, 19, 23, 29], 45 | [31, 11, 17, 13, 1] 46 | ], 47 | [ 48 | [1, 3, 19, 23, 29], 49 | [31, 11, 17, 13, 1], 50 | [19, 5, 7, 37, 2] 51 | ] 52 | ]) 53 | 54 | k1 = {3, 3, 1} 55 | k2 = {3, 3, 3} 56 | 57 | expected1 = 58 | Nx.tensor([ 59 | [ 60 | [19.0, 5.0, 17.0, 23.0, 2.0], 61 | [19.0, 5.0, 17.0, 23.0, 2.0], 62 | [19.0, 5.0, 17.0, 23.0, 2.0] 63 | ], 64 | [ 65 | [19.0, 5.0, 17.0, 23.0, 2.0], 66 | [19.0, 5.0, 17.0, 23.0, 2.0], 67 | [19.0, 5.0, 17.0, 23.0, 2.0] 68 | ], 69 | [ 70 | [19.0, 5.0, 17.0, 23.0, 2.0], 71 | [19.0, 5.0, 17.0, 23.0, 2.0], 72 | [19.0, 5.0, 17.0, 23.0, 2.0] 73 | ] 74 | ]) 75 | 76 | expected2 = 77 | Nx.tensor([ 78 | [ 79 | [11.0, 13.0, 17.0, 17.0, 17.0], 80 | [11.0, 13.0, 17.0, 17.0, 17.0], 81 | [11.0, 13.0, 17.0, 17.0, 17.0] 82 | ], 83 | [ 84 | [11.0, 13.0, 17.0, 17.0, 17.0], 85 | [11.0, 13.0, 17.0, 17.0, 17.0], 86 | [11.0, 13.0, 17.0, 17.0, 17.0] 87 | ], 88 | [ 89 | [11.0, 13.0, 17.0, 17.0, 17.0], 90 | [11.0, 13.0, 17.0, 17.0, 17.0], 91 | [11.0, 13.0, 17.0, 17.0, 17.0] 92 | ] 93 | ]) 94 | 95 | assert NxSignal.Filters.median(t, kernel_shape: k1) == expected1 96 | assert NxSignal.Filters.median(t, kernel_shape: k2) == expected2 97 | end 98 | 99 | test "raises if kernel_shape is not compatible" do 100 | t1 = Nx.iota({10}) 101 | 
opts1 = [kernel_shape: {5, 5}] 102 | 103 | assert_raise( 104 | ArgumentError, 105 | "kernel shape must be of the same rank as the tensor", 106 | fn -> NxSignal.Filters.median(t1, opts1) end 107 | ) 108 | 109 | t2 = Nx.iota({5, 5}) 110 | opts2 = [kernel_shape: {5, 5, 5}] 111 | 112 | assert_raise( 113 | ArgumentError, 114 | "kernel shape must be of the same rank as the tensor", 115 | fn -> NxSignal.Filters.median(t2, opts2) end 116 | ) 117 | end 118 | end 119 | 120 | describe "wiener/2" do 121 | test "performs n-dim wiener filter with calculated noise" do 122 | im = 123 | Nx.tensor( 124 | [ 125 | [1.0, 2.0, 3.0, 4.0, 5.0], 126 | [6.0, 7.0, 8.0, 9.0, 10.0], 127 | [11.0, 12.0, 13.0, 14.0, 15.0] 128 | ], 129 | type: :f64 130 | ) 131 | 132 | kernel_size = {3, 3} 133 | 134 | expected = 135 | Nx.tensor( 136 | [ 137 | [ 138 | 1.7777777777777777, 139 | 3.0, 140 | 3.6666666666666665, 141 | 4.333333333333333, 142 | 3.111111111111111 143 | ], 144 | [4.3366520642506305, 7.0, 8.0, 9.0, 7.58637597408283], 145 | [ 146 | 4.692197051420351, 147 | 7.261706150595039, 148 | 8.748939779474131, 149 | 10.157992415073023, 150 | 9.813815742524799 151 | ] 152 | ], 153 | type: :f64 154 | ) 155 | 156 | assert NxSignal.Filters.wiener(im, kernel_size: kernel_size) == expected 157 | assert NxSignal.Filters.wiener(im, kernel_size: 3) == expected 158 | 159 | assert NxSignal.Filters.wiener(Nx.as_type(im, :f32), kernel_size: kernel_size) == 160 | Nx.tensor([ 161 | [ 162 | 1.7777777910232544, 163 | 3.0, 164 | 3.6666667461395264, 165 | 4.333333492279053, 166 | 3.1111111640930176 167 | ], 168 | [4.3366522789001465, 7.0, 8.0, 9.0, 7.586376190185547], 169 | [ 170 | 4.692196846008301, 171 | 7.261706352233887, 172 | 8.748939514160156, 173 | 10.157992362976074, 174 | 9.81381607055664 175 | ] 176 | ]) 177 | end 178 | 179 | test "performs n-dim wiener filter with parameterized noise" do 180 | im = 181 | Nx.tensor( 182 | [ 183 | [1.0, 2.0, 3.0, 4.0, 5.0], 184 | [6.0, 7.0, 8.0, 9.0, 10.0], 185 | [11.0, 12.0, 
13.0, 14.0, 15.0] 186 | ], 187 | type: :f64 188 | ) 189 | 190 | kernel_size = {3, 3} 191 | 192 | assert NxSignal.Filters.wiener(im, kernel_size: kernel_size, noise: 10) == 193 | Nx.tensor( 194 | [ 195 | [ 196 | 1.7777777777777777, 197 | 3.0, 198 | 3.5882352941176467, 199 | 4.238095238095238, 200 | 3.7397034596375622 201 | ], 202 | [5.193548387096774, 7.0, 8.0, 9.0, 8.829787234042554], 203 | [ 204 | 7.941747572815534, 205 | 9.702702702702702, 206 | 10.938931297709924, 207 | 12.137254901960784, 208 | 12.485549132947977 209 | ] 210 | ], 211 | type: :f64 212 | ) 213 | 214 | assert NxSignal.Filters.wiener(Nx.as_type(im, :f32), kernel_size: kernel_size, noise: 10) == 215 | Nx.tensor([ 216 | [ 217 | 1.7777777910232544, 218 | 3.0, 219 | 3.588235378265381, 220 | 4.238095283508301, 221 | 3.739703416824341 222 | ], 223 | [5.193548202514648, 7.0, 8.0, 9.0, 8.829787254333496], 224 | [ 225 | 7.941747665405273, 226 | 9.702702522277832, 227 | 10.938931465148926, 228 | 12.13725471496582, 229 | 12.485548973083496 230 | ] 231 | ]) 232 | 233 | assert NxSignal.Filters.wiener(im, kernel_size: kernel_size, noise: 0) == 234 | Nx.tensor( 235 | [ 236 | [1.0, 2.0, 3.0, 4.0, 5.0], 237 | [6.0, 7.0, 8.0, 9.0, 10.0], 238 | [11.0, 12.0, 13.0, 14.0, 15.0] 239 | ], 240 | type: :f64 241 | ) 242 | end 243 | end 244 | end 245 | -------------------------------------------------------------------------------- /test/nx_signal/internal_test.exs: -------------------------------------------------------------------------------- 1 | defmodule NxSignal.InternalTest do 2 | use NxSignal.Case, async: true, validate_doc_metadata: false 3 | 4 | describe "lambert_w/3" do 5 | test "scipy values" do 6 | for {a, b, y} <- [ 7 | {0, 0, 0}, 8 | {Complex.new(0, 0), 0, 0}, 9 | {Nx.Constants.infinity(:f64), 0, Complex.new(:infinity, 0)}, 10 | {0, -1, Complex.new(:neg_infinity, 0)}, 11 | {0, 1, Complex.new(:neg_infinity, 0)}, 12 | {0, 3, Complex.new(:neg_infinity, 0)}, 13 | {Nx.to_number(Nx.Constants.e({:f, 64})), 0, 1}, 14 | 
{1, 0, 0.567143290409783873}, 15 | {-Nx.to_number(Nx.Constants.pi(:f64)) / 2, 0, 16 | Complex.new(0, Nx.to_number(Nx.Constants.pi(:f64)) / 2)}, 17 | {-:math.log(2.0) / 2, 0, -:math.log(2)}, 18 | {0.25, 0, 0.203888354702240164}, 19 | {-0.25, 0, -0.357402956181388903}, 20 | {-1.0 / 10000, 0, -0.000100010001500266719}, 21 | {-0.25, -1, -2.15329236411034965}, 22 | {0.25, -1, Complex.new(-3.00899800997004620, -4.07652978899159763)}, 23 | {-0.25, -1, -2.15329236411034965}, 24 | {0.25, 1, Complex.new(-3.00899800997004620, 4.07652978899159763)}, 25 | {-0.25, 1, Complex.new(-3.48973228422959210, 7.41405453009603664)}, 26 | {-4, 0, Complex.new(0.67881197132094523, 1.91195078174339937)}, 27 | {-4, 1, Complex.new(-0.6674310712980098, 7.76827456802783084)}, 28 | {-4, -1, Complex.new(0.67881197132094523, -1.91195078174339937)}, 29 | {1000, 0, 5.24960285240159623}, 30 | {1000, 1, Complex.new(4.91492239981054535, 5.44652615979447070)}, 31 | {1000, -1, Complex.new(4.91492239981054535, -5.44652615979447070)}, 32 | {1000, 5, Complex.new(3.5010625305312892, 29.9614548941181328)}, 33 | {Complex.new(3, 4), 0, Complex.new(1.281561806123775878, 0.533095222020971071)}, 34 | {Complex.new(-0.4, 0.4), 0, Complex.new(-0.10396515323290657, 0.61899273315171632)}, 35 | {Complex.new(3, 4), 1, Complex.new(-0.11691092896595324, 5.61888039871282334)}, 36 | {Complex.new(3, 4), -1, Complex.new(0.25856740686699742, -3.85211668616143559)}, 37 | {-0.5, -1, Complex.new(-0.794023632344689368, -0.770111750510379110)}, 38 | {-1.0 / 10000, 1, Complex.new(-11.82350837248724344, 6.80546081842002101)}, 39 | {-1.0 / 10000, -1, -11.6671145325663544}, 40 | {-1.0 / 10000, -2, Complex.new(-11.82350837248724344, -6.80546081842002101)}, 41 | {-1.0 / 100_000, 4, Complex.new(-14.9186890769540539, 26.1856750178782046)}, 42 | {-1.0 / 100_000, 5, Complex.new(-15.0931437726379218666, 32.5525721210262290086)}, 43 | {Complex.divide(Complex.new(2, 1), 10), 0, 44 | Complex.new(0.173704503762911669, 0.071781336752835511)}, 45 | 
{Complex.divide(Complex.new(2, 1), 10), 1, 46 | Complex.new(-3.21746028349820063, 4.56175438896292539)}, 47 | {Complex.divide(Complex.new(2, 1), 10), -1, 48 | Complex.new(-3.03781405002993088, -3.53946629633505737)}, 49 | {Complex.divide(Complex.new(2, 1), 10), 4, 50 | Complex.new(-4.6878509692773249, 23.8313630697683291)}, 51 | {Complex.divide(Complex.new(-2, -1), 10), 0, 52 | Complex.new(-0.226933772515757933, -0.164986470020154580)}, 53 | {Complex.divide(Complex.new(-2, -1), 10), 1, 54 | Complex.new(-2.43569517046110001, 0.76974067544756289)}, 55 | {Complex.divide(Complex.new(-2, -1), 10), -1, 56 | Complex.new(-3.54858738151989450, -6.91627921869943589)}, 57 | {Complex.divide(Complex.new(-2, -1), 10), 4, 58 | Complex.new(-4.5500846928118151, 20.6672982215434637)}, 59 | {Nx.Constants.pi(:f64), 0, 1.073658194796149172092178407024821347547745350410314531}, 60 | {Complex.new(-0.5, 0.002), 0, 61 | Complex.new(-0.78917138132659918344, 0.76743539379990327749)}, 62 | {Complex.new(-0.5, -0.002), 0, 63 | Complex.new(-0.78917138132659918344, -0.76743539379990327749)}, 64 | {Complex.new(-0.448, 0.4), 0, 65 | Complex.new(-0.11855133765652382241, 0.66570534313583423116)}, 66 | {Complex.new(-0.448, -0.4), 0, 67 | Complex.new(-0.11855133765652382241, -0.66570534313583423116)} 68 | ] do 69 | x = NxSignal.Internal.lambert_w(a, b) 70 | assert_all_close(x, as_tensor(y), atol: 1.0e-13, rtol: 1.0e-10) 71 | end 72 | end 73 | end 74 | 75 | defp as_tensor(a) when is_struct(a, Complex) do 76 | Nx.tensor(a, type: :c128) 77 | end 78 | 79 | defp as_tensor(a) when is_struct(a, Nx.Tensor) do 80 | a 81 | end 82 | 83 | defp as_tensor(a) do 84 | Nx.f64(a) 85 | end 86 | end 87 | -------------------------------------------------------------------------------- /test/nx_signal/peak_finding_test.exs: -------------------------------------------------------------------------------- 1 | defmodule NxSignal.PeakFindingTest do 2 | use NxSignal.Case, async: true 3 | doctest NxSignal.PeakFinding 4 | end 5 | 
# /test/nx_signal/transforms_test.exs
defmodule NxSignal.TransformsTest do
  # Tests for the N-dimensional FFT helpers imported from NxSignal.Transforms.
  # Doc-metadata validation is disabled for this module (as before).
  use NxSignal.Case, async: true, validate_doc_metadata: false
  import NxSignal.Transforms

  describe "fft_nd/2" do
    test "equivalent to fft" do
      # With default options, fft_nd/1 should agree with Nx.fft/1.
      a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
      assert Nx.fft(a) == fft_nd(a)
    end

    test "all axes" do
      a = Nx.tensor([[1, 0], [0, 1]])

      out =
        Nx.tensor([
          [2, Complex.new(0, 0)],
          [Complex.new(0, 0), 2]
        ])

      assert fft_nd(a, axes: Nx.axes(a)) == out
    end
  end

  describe "ifft_nd/2" do
    test "equivalent to ifft" do
      # With default options, ifft_nd/1 should agree with Nx.ifft/1.
      a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
      assert Nx.ifft(a) == ifft_nd(a)
    end

    test "all axes" do
      a =
        Nx.tensor([
          [2, Complex.new(0, 0)],
          [Complex.new(0, 0), 2]
        ])

      out =
        Nx.tensor([[Complex.new(1, 0), 0], [0, 1]])

      assert ifft_nd(a, axes: Nx.axes(a)) == out
    end
  end
end

# /test/nx_signal/waveforms_test.exs
defmodule NxSignal.WaveformsTest do
  use NxSignal.Case, async: true
  doctest NxSignal.Waveforms

  test "sawtooth scipy regression" do
    t = Nx.linspace(0, 6, n: 20)

    width_1 =
      Nx.tensor([
        -1.0,
        -0.8994811177253723,
        -0.7989621758460999,
        -0.6984432935714722,
        -0.5979243516921997,
        -0.497405469417572,
        -0.39688658714294434,
        -0.2963676452636719,
        -0.1958487629890442,
        -0.0953298807144165,
        0.005189061164855957,
        0.10570800304412842,
        0.20622682571411133,
        0.30674564838409424,
        0.40726470947265625,
        0.5077836513519287,
        0.6083024740219116,
        0.7088212966918945,
        0.809340238571167,
        0.909859299659729
      ])

    assert NxSignal.Waveforms.sawtooth(t) == width_1
    assert NxSignal.Waveforms.sawtooth(t, width: 1) == width_1

    assert NxSignal.Waveforms.sawtooth(t, width: 0.2) ==
             Nx.tensor([
               -1.0,
               -0.497405469417572,
               0.005189061164855957,
               0.5077835321426392,
               0.997405469417572,
               0.8717568516731262,
               0.7461082339286804,
               0.6204595565795898,
               0.49481093883514404,
               0.36916232109069824,
               0.24351368844509125,
               0.11786506325006485,
               -0.007783569395542145,
               -0.13343210518360138,
               -0.25908082723617554,
               -0.3847295641899109,
               -0.5103781223297119,
               -0.6360266208648682,
               -0.7616753578186035,
               -0.8873240947723389
             ])

    assert NxSignal.Waveforms.sawtooth(t, width: 0) ==
             Nx.tensor([
               1.0,
               0.8994811177253723,
               0.7989621758460999,
               0.6984432935714722,
               0.5979243516921997,
               0.497405469417572,
               0.39688655734062195,
               0.29636767506599426,
               0.1958487629890442,
               0.09532985836267471,
               -0.005189046263694763,
               -0.10570795089006424,
               -0.20622685551643372,
               -0.3067456781864166,
               -0.40726467967033386,
               -0.5077836513519287,
               -0.6083024740219116,
               -0.7088212966918945,
               -0.8093402981758118,
               -0.9098592400550842
             ])
  end
end

# /test/nx_signal/windows_test.exs
defmodule NxSignal.WindowsTest do
  use NxSignal.Case
  doctest NxSignal.Windows
end

# /test/nx_signal_test.exs
defmodule NxSignalTest do
  use NxSignal.Case
  doctest NxSignal
end
# /test/support/nx_signal_case.ex
defmodule NxSignal.Case do
  @moduledoc """
  Case template shared by the NxSignal test modules.

  Modules that `use NxSignal.Case` get the helpers defined here imported.
  Unless `validate_doc_metadata: false` is passed to `use`, an extra test named
  `"defines doc :type"` is injected that runs `validate_doc_metadata/1` on the
  using module.
  """
  use ExUnit.CaseTemplate
  import ExUnit.Assertions

  using opts do
    validate_doc_metadata = Keyword.get(opts, :validate_doc_metadata, true)

    quote do
      import NxSignal.Case

      if unquote(validate_doc_metadata) do
        test "defines doc :type" do
          validate_doc_metadata(__MODULE__)
        end
      end
    end
  end

  # Values accepted for the `:type` key of a function's `@doc` metadata.
  @doctypes [
    :time_frequency,
    :windowing,
    :filters,
    :waveforms,
    :peak_finding
  ]

  @doc """
  Flunks unless every documented function of the module under test declares a
  `@doc type:` from the accepted list.

  The module under test is derived from the test module `module` by removing
  one trailing `"Test"` from its last alias segment (for example
  `NxSignal.WindowsTest` checks `NxSignal.Windows`). Only `:function` entries
  whose docs are a non-empty map are checked, so undocumented and
  `@doc false` functions are skipped. Also flunks if the module's Elixir
  markdown docs cannot be fetched with `Code.fetch_docs/1`.
  """
  def validate_doc_metadata(module) do
    mod = module_under_test(module)

    case Code.fetch_docs(mod) do
      {:docs_v1, _, :elixir, "text/markdown", _docs, _metadata, entries} ->
        for {{:function, name, arity}, _ann, _signature, docs, metadata} <- entries,
            is_map(docs) and map_size(docs) > 0,
            metadata[:type] not in @doctypes do
          flunk("invalid @doc type: #{inspect(metadata[:type])} for #{name}/#{arity}")
        end

      other ->
        # Fail with a readable reason instead of an opaque MatchError.
        flunk("could not fetch Elixir docs for #{inspect(mod)}: #{inspect(other)}")
    end
  end

  # Maps a test module alias to the module it tests, e.g. `Foo.BarTest` -> `Foo.Bar`.
  # Exactly one trailing "Test" is removed, so `Foo.TestTest` maps to `Foo.Test`.
  defp module_under_test(module) do
    {last, parents} = module |> Module.split() |> List.pop_at(-1)
    Module.concat(parents ++ [String.replace_suffix(last, "Test", "")])
  end

  @doc """
  Asserts `lhs` is close to `rhs` according to `Nx.all_close/3`.

  ## Options

    * `:atol` - absolute tolerance, defaults to `1.0e-4`
    * `:rtol` - relative tolerance, defaults to `1.0e-4`
    * `:equal_nan` - forwarded to `Nx.all_close/3`, defaults to `false`

  Returns `nil` when the values are close; otherwise flunks with both values
  inspected.
  """
  def assert_all_close(lhs, rhs, opts \\ []) do
    atol = opts[:atol] || 1.0e-4
    rtol = opts[:rtol] || 1.0e-4
    equal_nan = Keyword.get(opts, :equal_nan, false)

    all_close = Nx.all_close(lhs, rhs, atol: atol, rtol: rtol, equal_nan: equal_nan)

    # Compare the scalar result as a plain number: comparing tensor structs with
    # `!=` also compares backend-specific data, which can differ even when the
    # value is 1 (e.g. when the tensors are not on the binary backend).
    if Nx.to_number(all_close) != 1 do
      flunk("""
      expected

      #{inspect(lhs)}

      to be within tolerance of

      #{inspect(rhs)}
      """)
    end
  end
end

# /test/test_helper.exs
ExUnit.start()