├── .gitignore ├── src ├── utils │ ├── mod.rs │ ├── peak.rs │ └── buffer.rs ├── float │ └── mod.rs ├── detector │ ├── mod.rs │ ├── autocorrelation.rs │ ├── mcleod.rs │ ├── yin.rs │ └── internals.rs └── lib.rs ├── tests ├── samples │ ├── LICENSE.txt │ ├── violin-D4.wav │ ├── violin-F4.wav │ ├── violin-G4.wav │ ├── tenor-trombone-B3.wav │ ├── tenor-trombone-C3.wav │ ├── tenor-trombone-Ab3.wav │ ├── tenor-trombone-Db3.wav │ └── README.md └── main.rs ├── .github └── workflows │ └── main.yaml ├── Cargo.toml ├── LICENSE ├── README.md └── benches └── utils_benchmark.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | **/*.rs.bk 3 | Cargo.lock 4 | -------------------------------------------------------------------------------- /src/utils/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod buffer; 2 | pub mod peak; 3 | -------------------------------------------------------------------------------- /tests/samples/LICENSE.txt: -------------------------------------------------------------------------------- 1 | All audio samples are in the public domain, licensed under CC-0 2 | -------------------------------------------------------------------------------- /tests/samples/violin-D4.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alesgenova/pitch-detection/HEAD/tests/samples/violin-D4.wav -------------------------------------------------------------------------------- /tests/samples/violin-F4.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alesgenova/pitch-detection/HEAD/tests/samples/violin-F4.wav -------------------------------------------------------------------------------- /tests/samples/violin-G4.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/alesgenova/pitch-detection/HEAD/tests/samples/violin-G4.wav -------------------------------------------------------------------------------- /tests/samples/tenor-trombone-B3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alesgenova/pitch-detection/HEAD/tests/samples/tenor-trombone-B3.wav -------------------------------------------------------------------------------- /tests/samples/tenor-trombone-C3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alesgenova/pitch-detection/HEAD/tests/samples/tenor-trombone-C3.wav -------------------------------------------------------------------------------- /tests/samples/tenor-trombone-Ab3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alesgenova/pitch-detection/HEAD/tests/samples/tenor-trombone-Ab3.wav -------------------------------------------------------------------------------- /tests/samples/tenor-trombone-Db3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alesgenova/pitch-detection/HEAD/tests/samples/tenor-trombone-Db3.wav -------------------------------------------------------------------------------- /tests/samples/README.md: -------------------------------------------------------------------------------- 1 | # Audio Samples 2 | 3 | Audio samples are released into the public domain under the CC-0 license unless otherwise specified. 4 | 5 | Tenor trombone samples are from the University of Iowa Electronic Music Studio: http://theremin.music.uiowa.edu/MIStenortrombone.html 6 | -------------------------------------------------------------------------------- /src/float/mod.rs: -------------------------------------------------------------------------------- 1 | //!
Generic [Float] type which acts as a stand-in for `f32` or `f64`. 2 | use rustfft::num_traits::float::FloatCore as NumFloatCore; 3 | use rustfft::FftNum; 4 | use std::fmt::{Debug, Display}; 5 | 6 | /// Signals are processed as arrays of [Float]s. A [Float] is normally `f32` or `f64`. 7 | pub trait Float: Display + Debug + NumFloatCore + FftNum {} 8 | 9 | impl Float for f64 {} 10 | impl Float for f32 {} 11 | -------------------------------------------------------------------------------- /.github/workflows/main.yaml: -------------------------------------------------------------------------------- 1 | name: main 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Build 20 | run: cargo build --verbose 21 | - name: Run tests 22 | run: cargo test --verbose 23 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "pitch-detection" 3 | description = "A collection of algorithms to determine the pitch of a sound sample."
4 | version = "0.3.0" 5 | authors = ["Alessandro Genova "] 6 | edition = "2021" 7 | license = "MIT OR Apache-2.0" 8 | homepage = "https://github.com/alesgenova/pitch-detection" 9 | repository = "https://github.com/alesgenova/pitch-detection" 10 | keywords = ["pitch", "frequency", "detection", "sound"] 11 | categories = ["algorithms", "multimedia::audio", "no-std"] 12 | readme = "README.md" 13 | 14 | [dependencies] 15 | rustfft = { version = "6.0.1", default-features = false } 16 | 17 | [dev-dependencies] 18 | criterion = "0.3" 19 | hound = { version = "3.4.0" } 20 | 21 | [[bench]] 22 | name = "utils_benchmark" 23 | harness = false 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019-2022 Alessandro Genova 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE.
20 | -------------------------------------------------------------------------------- /src/detector/mod.rs: -------------------------------------------------------------------------------- 1 | //! # Pitch Detectors 2 | //! Each detector implements a different pitch-detection algorithm. 3 | //! Every detector implements the standard [PitchDetector] trait. 4 | 5 | use crate::detector::internals::Pitch; 6 | use crate::float::Float; 7 | 8 | pub mod autocorrelation; 9 | #[doc(hidden)] 10 | pub mod internals; 11 | pub mod mcleod; 12 | pub mod yin; 13 | 14 | /// A uniform interface to all pitch-detection algorithms. 15 | pub trait PitchDetector 16 | where 17 | T: Float, 18 | { 19 | /// Get an estimate of the [Pitch] of the sound sample stored in `signal`. 20 | /// 21 | /// Arguments: 22 | /// 23 | /// * `signal`: The signal to be analyzed 24 | /// * `sample_rate`: The number of samples per second contained in the signal. 25 | /// * `power_threshold`: If the signal's power (for the detectors in this crate, the sum of its squared samples) is below 26 | /// this threshold, no attempt is made to find its pitch and `None` is returned. 27 | /// * `clarity_threshold`: The minimum clarity, a number between 0 and 1 reflecting the confidence 28 | /// the algorithm has in its estimate of the frequency, that an estimate must reach to be returned. 29 | /// Higher `clarity_threshold`s demand higher confidence and so return `None` more often. 30 | fn get_pitch( 31 | &mut self, 32 | signal: &[T], 33 | sample_rate: usize, 34 | power_threshold: T, 35 | clarity_threshold: T, 36 | ) -> Option>; 37 | } 38 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # Pitch Detection 2 | //! *pitch_detection* implements several algorithms for estimating the 3 | //! fundamental frequency of a sound wave stored in a buffer. It is designed 4 | //! to be usable in a WASM environment. 5 | //! 6 | //! # Detectors 7 | //! A *detector* is an implementation of a pitch detection algorithm. Each detector's tolerance 8 | //!
for noise and polyphonic sounds varies. 9 | //! 10 | //! * [AutocorrelationDetector][detector::autocorrelation] 11 | //! * [McLeodDetector][detector::mcleod] 12 | //! * [YINDetector][detector::yin] 13 | //! 14 | //! # Examples 15 | //! ``` 16 | //! use pitch_detection::detector::mcleod::McLeodDetector; 17 | //! use pitch_detection::detector::PitchDetector; 18 | //! 19 | //! fn main() { 20 | //! const SAMPLE_RATE: usize = 44100; 21 | //! const SIZE: usize = 1024; 22 | //! const PADDING: usize = SIZE / 2; 23 | //! const POWER_THRESHOLD: f64 = 5.0; 24 | //! const CLARITY_THRESHOLD: f64 = 0.7; 25 | //! 26 | //! // Signal coming from some source (microphone, generated, etc...) 27 | //! let dt = 1.0 / SAMPLE_RATE as f64; 28 | //! let freq = 300.0; 29 | //! let signal: Vec = (0..SIZE) 30 | //! .map(|x| (2.0 * std::f64::consts::PI * x as f64 * dt * freq).sin()) 31 | //! .collect(); 32 | //! 33 | //! let mut detector = McLeodDetector::new(SIZE, PADDING); 34 | //! 35 | //! let pitch = detector 36 | //! .get_pitch(&signal, SAMPLE_RATE, POWER_THRESHOLD, CLARITY_THRESHOLD) 37 | //! .unwrap(); 38 | //! 39 | //! println!("Frequency: {}, Clarity: {}", pitch.frequency, pitch.clarity); 40 | //! } 41 | //! 
``` 42 | 43 | pub use detector::internals::Pitch; 44 | 45 | pub mod detector; 46 | pub mod float; 47 | pub mod utils; 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![workflow status](https://github.com/alesgenova/pitch-detection/workflows/main/badge.svg?branch=master)](https://github.com/alesgenova/pitch-detection/actions?query=workflow%3Amain+branch%3Amaster) 2 | [![crates.io](https://img.shields.io/crates/v/pitch-detection.svg)](https://crates.io/crates/pitch-detection) 3 | 4 | # pitch_detection 5 | 6 | ## Usage 7 | ```rust 8 | use pitch_detection::detector::mcleod::McLeodDetector; 9 | use pitch_detection::detector::PitchDetector; 10 | 11 | fn main() { 12 | const SAMPLE_RATE: usize = 44100; 13 | const SIZE: usize = 1024; 14 | const PADDING: usize = SIZE / 2; 15 | const POWER_THRESHOLD: f64 = 5.0; 16 | const CLARITY_THRESHOLD: f64 = 0.7; 17 | 18 | // Signal coming from some source (microphone, generated, etc...) 19 | let dt = 1.0 / SAMPLE_RATE as f64; 20 | let freq = 300.0; 21 | let signal: Vec = (0..SIZE) 22 | .map(|x| (2.0 * std::f64::consts::PI * x as f64 * dt * freq).sin()) 23 | .collect(); 24 | 25 | let mut detector = McLeodDetector::new(SIZE, PADDING); 26 | 27 | let pitch = detector 28 | .get_pitch(&signal, SAMPLE_RATE, POWER_THRESHOLD, CLARITY_THRESHOLD) 29 | .unwrap(); 30 | 31 | println!("Frequency: {}, Clarity: {}", pitch.frequency, pitch.clarity); 32 | } 33 | ``` 34 | ## Live Demo 35 | [![Demo Page](https://raw.githubusercontent.com/alesgenova/pitch-detection-app/master/demo.png)](https://alesgenova.github.io/pitch-detection-app/) 36 | [Source](https://github.com/alesgenova/pitch-detection-app) 37 | 38 | ## Documentation 39 | LaTeX formulas can be used in documentation. This is enabled by a method outlined in [rust-latex-doc-minimal-example](https://github.com/victe/rust-latex-doc-minimal-example). 
To build the docs, use 40 | ``` 41 | cargo doc --no-deps 42 | ``` 43 | The `--no-deps` flag is needed because special headers are included to auto-process the math in the documentation. This 44 | header is specified using a relative path and so an error is produced if `cargo` tries to generate documentation for 45 | dependencies. -------------------------------------------------------------------------------- /benches/utils_benchmark.rs: -------------------------------------------------------------------------------- 1 | use std::f64::consts::PI; 2 | 3 | use criterion::{black_box, criterion_group, criterion_main, Criterion}; 4 | use pitch_detection::{ 5 | detector::{ 6 | autocorrelation::AutocorrelationDetector, mcleod::McLeodDetector, yin::YINDetector, 7 | PitchDetector, 8 | }, 9 | utils::peak::detect_peaks, 10 | }; 11 | /// Benchmarks `detect_peaks` on a 1024-sample sine wave. 12 | pub fn utils_benchmark(c: &mut Criterion) { 13 | let v = (0..1024) 14 | .into_iter() 15 | .map(|v| ((v as f64) / PI / 30.).sin()) 16 | .collect::>(); 17 | let vv = v.as_slice(); 18 | 19 | c.bench_function("detect_peaks", |b| { 20 | b.iter(|| detect_peaks(black_box(vv)).collect::>()) 21 | }); 22 | } 23 | /// Benchmarks `get_pitch` of each detector on a 1024-sample, 300 Hz sine sampled at 44.1 kHz. 24 | pub fn pitch_detect_benchmark(c: &mut Criterion) { 25 | const SAMPLE_RATE: usize = 44100; 26 | const SIZE: usize = 1024; 27 | const PADDING: usize = SIZE / 2; 28 | const POWER_THRESHOLD: f64 = 5.0; 29 | const CLARITY_THRESHOLD: f64 = 0.7; 30 | 31 | // Signal coming from some source (microphone, generated, etc...)
32 | let dt = 1.0 / SAMPLE_RATE as f64; 33 | let freq = 300.0; 34 | let signal: Vec = (0..SIZE) 35 | .map(|x| (2.0 * std::f64::consts::PI * x as f64 * dt * freq).sin()) 36 | .collect(); 37 | 38 | c.bench_function("McLeod get_pitch", |b| { 39 | let mut mcleod_detector = McLeodDetector::new(SIZE, PADDING); 40 | b.iter(|| { 41 | mcleod_detector 42 | .get_pitch( 43 | black_box(&signal), 44 | SAMPLE_RATE, 45 | POWER_THRESHOLD, 46 | CLARITY_THRESHOLD, 47 | ) 48 | .unwrap() 49 | }); 50 | }); 51 | 52 | c.bench_function("Autocorrelation get_pitch", |b| { 53 | let mut autocorrelation_detector = AutocorrelationDetector::new(SIZE, PADDING); 54 | b.iter(|| { 55 | autocorrelation_detector 56 | .get_pitch( 57 | black_box(&signal), 58 | SAMPLE_RATE, 59 | POWER_THRESHOLD, 60 | CLARITY_THRESHOLD, 61 | ) 62 | .unwrap() 63 | }); 64 | }); 65 | c.bench_function("YIN get_pitch", |b| { 66 | let mut yin_detector = YINDetector::new(SIZE, PADDING); 67 | b.iter(|| { 68 | yin_detector 69 | .get_pitch( 70 | black_box(&signal), 71 | SAMPLE_RATE, 72 | POWER_THRESHOLD, 73 | CLARITY_THRESHOLD, 74 | ) 75 | .unwrap() 76 | }); 77 | }); 78 | } 79 | 80 | criterion_group!(benches, pitch_detect_benchmark, utils_benchmark); 81 | criterion_main!(benches); 82 | -------------------------------------------------------------------------------- /src/detector/autocorrelation.rs: -------------------------------------------------------------------------------- 1 | //! Autocorrelation is one of the most basic forms of pitch detection. Let $S=(s_0,s_1,\ldots,s_N)$ 2 | //! be a discrete signal. Then, the autocorrelation function of $S$ at lag $t$ is 3 | //! $$ A_t(S) = \sum_{i=0}^{N-t} s_i s_{i+t}. $$ 4 | //! The autocorrelation function is largest when $t=0$. Subsequent peaks indicate when the signal 5 | //! is particularly well aligned with itself. Thus, lags $t>0$ at which $A_t(S)$ peaks are good candidates 6 | //! for the fundamental period of $S$ (the reciprocal of its fundamental frequency). 7 | //!
Unfortunately, autocorrelation-based pitch detection is prone to octave errors, since a signal 9 | //! may "line up" with itself better when shifted by a multiple of its fundamental period than by the period itself. 10 | //! Further, autocorrelation is a bad choice for situations where the fundamental frequency may not 11 | //! be the loudest frequency (which is common in telephone speech and for certain types of instruments). 12 | //! 13 | //! ## Implementation 14 | //! Rather than compute the autocorrelation function directly, an [FFT](https://en.wikipedia.org/wiki/Fast_Fourier_transform) 15 | //! is used, providing a dramatic speed increase for large buffers. 16 | 17 | use crate::detector::internals::pitch_from_peaks; 18 | use crate::detector::internals::DetectorInternals; 19 | use crate::detector::internals::Pitch; 20 | use crate::detector::PitchDetector; 21 | use crate::float::Float; 22 | use crate::utils::peak::PeakCorrection; 23 | use crate::{detector::internals::autocorrelation, utils::buffer::square_sum}; 24 | /// Pitch detector based on the autocorrelation function described in the module-level documentation. 25 | pub struct AutocorrelationDetector 26 | where 27 | T: Float, 28 | { 29 | internals: DetectorInternals, 30 | } 31 | 32 | impl AutocorrelationDetector 33 | where 34 | T: Float, 35 | { 36 | pub fn new(size: usize, padding: usize) -> Self { 37 | let internals = DetectorInternals::new(size, padding); 38 | AutocorrelationDetector { internals } 39 | } 40 | } 41 | 42 | impl PitchDetector for AutocorrelationDetector 43 | where 44 | T: Float + std::iter::Sum, 45 | { 46 | fn get_pitch( 47 | &mut self, 48 | signal: &[T], 49 | sample_rate: usize, 50 | power_threshold: T, 51 | clarity_threshold: T, 52 | ) -> Option> { 53 | assert_eq!(signal.len(), self.internals.size); 54 | // Skip near-silent input: `square_sum` is the sum of squared samples. 55 | if square_sum(signal) < power_threshold { 56 | return None; 57 | } 58 | 59 | let result_ref = self.internals.buffers.get_real_buffer(); 60 | let result = &mut result_ref.borrow_mut()[..]; 61 | // Fill `result` with the autocorrelation of `signal` (presumably indexed by lag, so `result[0]` is the lag-0 maximum). 62 | autocorrelation(signal, &mut self.internals.buffers, result); 63 | let clarity_threshold =
clarity_threshold * result[0]; 64 | // Scale the relative `clarity_threshold` by `result[0]`, presumably because `result` holds unnormalized correlation values. 65 | pitch_from_peaks(result, sample_rate, clarity_threshold, PeakCorrection::None) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/detector/mcleod.rs: -------------------------------------------------------------------------------- 1 | //! The McLeod pitch detection algorithm is based on the algorithm from the paper 2 | //! *[A Smarter Way To Find Pitch](https://www.researchgate.net/publication/230554927_A_smarter_way_to_find_pitch)*. 3 | //! It is efficient and offers an improvement over basic autocorrelation. 4 | //! 5 | //! The algorithm is based on finding peaks of the *normalized square difference* function. Let $S=(s_0,s_1,\ldots,s_N)$ 6 | //! be a discrete signal. The *square difference function* at lag $t$ is defined by 7 | //! $$ d\'(t) = \sum_{i=0}^{N-t} (s_i-s_{i+t})^2. $$ 8 | //! This function is close to zero when the signal "lines up" with itself. However, *close* is a relative term, 9 | //! and the value of $d\'(t)$ depends on volume, which should not affect the pitch of the signal. For this 10 | //! reason, the *normalized square difference function*, $n\'(t)$, is computed. 11 | //! $$ n\'(t) = \frac{d\'(t)}{\sum_{i=0}^{N-t} (s_i^2+s_{i+t}^2) } $$ 12 | //! The algorithm then searches for the first local minimum of $n\'(t)$ below a given threshold, called the 13 | //! *clarity threshold*. 14 | //! 15 | //! ## Implementation 16 | //! As outlined in *A Smarter Way To Find Pitch*, 17 | //! an [FFT](https://en.wikipedia.org/wiki/Fast_Fourier_transform) is used to greatly speed up the computation of 18 | //! the normalized square difference function. Further, the algorithm applies some algebraic tricks and actually 19 | //! searches for the *peaks* of $1-n\'(t)$, rather than minima of $n\'(t)$. 20 | //! 21 | //! After a peak is found, quadratic interpolation is applied to further refine the estimate.
22 | use crate::detector::internals::normalized_square_difference; 23 | use crate::detector::internals::pitch_from_peaks; 24 | use crate::detector::internals::DetectorInternals; 25 | use crate::detector::internals::Pitch; 26 | use crate::detector::PitchDetector; 27 | use crate::float::Float; 28 | use crate::utils::buffer::square_sum; 29 | use crate::utils::peak::PeakCorrection; 30 | /// Pitch detector implementing the McLeod algorithm described in the module-level documentation. 31 | pub struct McLeodDetector 32 | where 33 | T: Float + std::iter::Sum, 34 | { 35 | internals: DetectorInternals, 36 | } 37 | 38 | impl McLeodDetector 39 | where 40 | T: Float + std::iter::Sum, 41 | { 42 | pub fn new(size: usize, padding: usize) -> Self { 43 | let internals = DetectorInternals::new(size, padding); 44 | McLeodDetector { internals } 45 | } 46 | } 47 | 48 | impl PitchDetector for McLeodDetector 49 | where 50 | T: Float + std::iter::Sum, 51 | { 52 | fn get_pitch( 53 | &mut self, 54 | signal: &[T], 55 | sample_rate: usize, 56 | power_threshold: T, 57 | clarity_threshold: T, 58 | ) -> Option> { 59 | assert_eq!(signal.len(), self.internals.size); 60 | // Skip near-silent input: `square_sum` is the sum of squared samples. 61 | if square_sum(signal) < power_threshold { 62 | return None; 63 | } 64 | let result_ref = self.internals.buffers.get_real_buffer(); 65 | let result = &mut result_ref.borrow_mut()[..]; 66 | // Fill `result` with the normalized square difference; unlike the autocorrelation detector, `clarity_threshold` is passed on unscaled. 67 | normalized_square_difference(signal, &mut self.internals.buffers, result); 68 | pitch_from_peaks( 69 | result, 70 | sample_rate, 71 | clarity_threshold, 72 | PeakCorrection::Quadratic, 73 | ) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/detector/yin.rs: -------------------------------------------------------------------------------- 1 | //! The YIN pitch detection algorithm is based on the algorithm from the paper 2 | //! *[YIN, a fundamental frequency estimator for speech and music](http://recherche.ircam.fr/equipes/pcm/cheveign/ps/2002_JASA_YIN_proof.pdf)*. 3 | //! It is efficient and offers an improvement over basic autocorrelation. 4 | //! 5 | //!
The YIN pitch detection algorithm is similar to the [McLeod][crate::detector::mcleod] algorithm, but it is based on 6 | //! a different normalization of the *square difference function*. 7 | //! 8 | //! Let $S=(s_0,s_1,\ldots,s_N)$ be a discrete signal. The *square difference function* at lag $t$ 9 | //! is defined by 10 | //! $$ d(t) = \sum_{i=0}^{N-t} (s_i-s_{i+t})^2. $$ 11 | //! This function is close to zero when the signal "lines up" with itself. However, *close* is a relative term, 12 | //! and the value of $d(t)$ depends on volume, which should not affect the pitch of the signal. For this 13 | //! reason, the difference function is normalized. The YIN algorithm computes the *cumulative mean normalized difference function*, 14 | //! $$ d\'(t) = \begin{cases}1&\text{if }t=0\\\\ d(t) / \left[ \tfrac{1}{t}\sum_{i=0}^t d(i) \right] & \text{otherwise}\end{cases}. $$ 15 | //! Then, it searches for the first local minimum of $d\'(t)$ below a given threshold. 16 | //! 17 | //! ## Implementation 18 | //! Rather than compute the cumulative mean normalized difference function directly, 19 | //! an [FFT](https://en.wikipedia.org/wiki/Fast_Fourier_transform) is used, providing a dramatic speed increase for large buffers. 20 | //! 21 | //! After a candidate frequency is found, quadratic interpolation is applied to further refine the estimate. 22 | //! 23 | //! The current implementation does not perform *Step 6* of the algorithm specified in the YIN paper.
24 | 25 | use crate::detector::internals::pitch_from_peaks; 26 | use crate::detector::internals::Pitch; 27 | use crate::detector::PitchDetector; 28 | use crate::float::Float; 29 | use crate::utils::buffer::square_sum; 30 | use crate::utils::peak::PeakCorrection; 31 | 32 | use super::internals::{windowed_square_error, yin_normalize_square_error, DetectorInternals}; 33 | /// Pitch detector implementing the YIN algorithm described in the module-level documentation. 34 | pub struct YINDetector 35 | where 36 | T: Float + std::iter::Sum, 37 | { 38 | internals: DetectorInternals, 39 | } 40 | 41 | impl YINDetector 42 | where 43 | T: Float + std::iter::Sum, 44 | { 45 | pub fn new(size: usize, padding: usize) -> Self { 46 | let internals = DetectorInternals::::new(size, padding); 47 | YINDetector { internals } 48 | } 49 | } 50 | 51 | /// Pitch detection based on the YIN algorithm. See the `yin` module documentation for details. 52 | impl PitchDetector for YINDetector 53 | where 54 | T: Float + std::iter::Sum, 55 | { 56 | fn get_pitch( 57 | &mut self, 58 | signal: &[T], 59 | sample_rate: usize, 60 | power_threshold: T, 61 | clarity_threshold: T, 62 | ) -> Option> { 63 | // The YIN paper uses 0.1 as a threshold; TarsosDSP uses 0.2. `threshold` is not quite 64 | // the same thing as 1 - clarity, but it should be close enough. 65 | let threshold = T::one() - clarity_threshold; 66 | let window_size = signal.len() / 2; 67 | 68 | assert_eq!(signal.len(), self.internals.size); 69 | // Skip near-silent input: `square_sum` is the sum of squared samples. 70 | if square_sum(signal) < power_threshold { 71 | return None; 72 | } 73 | 74 | let result_ref = self.internals.buffers.get_real_buffer(); 75 | let result = &mut result_ref.borrow_mut()[..window_size]; 76 | 77 | // STEP 2: Calculate the difference function, d_t. 78 | windowed_square_error(signal, window_size, &mut self.internals.buffers, result); 79 | 80 | // STEP 3: Calculate the cumulative mean normalized difference function, d_t'. 81 | yin_normalize_square_error(result); 82 | 83 | // STEP 4: The absolute threshold. We want the first dip below `threshold`. 84 | // The YIN paper looks for minimum peaks.
Since `pitch_from_peaks` looks 85 | // for maxima, we take this opportunity to invert the signal. 86 | result.iter_mut().for_each(|val| *val = threshold - *val); 87 | // After inversion, positive values mark lags where d_t' < `threshold`. 88 | // STEP 5: Find the peak and use quadratic interpolation to fine-tune the result 89 | pitch_from_peaks(result, sample_rate, T::zero(), PeakCorrection::Quadratic).map(|pitch| { 90 | Pitch { 91 | frequency: pitch.frequency, 92 | // A `clarity` is not given by the YIN algorithm. However, we can 93 | // say a pitch has higher clarity if its YIN normalized square error is closer to zero. 94 | // We can then take 1 - YIN error and report that as `clarity`. NOTE(review): the expression below divides `pitch.clarity` by `threshold`; if `pitch.clarity` were the raw inverted value `threshold - d_t'`, this would not equal 1 - d_t' and could exceed 1 -- confirm against how `pitch_from_peaks` scales clarity. 95 | clarity: T::one() - threshold + pitch.clarity / threshold, 96 | } 97 | }) 98 | 99 | // STEP 6: TODO. Step 6 of the YIN paper can eke out a little more accuracy/consistency, but 100 | // it also involves computing over a much larger window. 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /src/utils/peak.rs: -------------------------------------------------------------------------------- 1 | use crate::float::Float; 2 | /// How `correct_peak` refines a detected peak: `Quadratic` fits a parabola through the peak and its two neighbours, `None` keeps the integer index and raw value. 3 | pub enum PeakCorrection { 4 | Quadratic, 5 | None, 6 | } 7 | 8 | /// An iterator that returns the positive peaks of `self.data`, 9 | /// skipping over any initial positive values (i.e., every peak 10 | /// is preceded by negative values).
11 | struct PeaksIter<'a, T> { 12 | index: usize, 13 | data: &'a [T], 14 | } 15 | 16 | impl<'a, T: Float> PeaksIter<'a, T> { 17 | fn new(arr: &'a [T]) -> Self { 18 | Self { 19 | data: arr, 20 | index: 0, 21 | } 22 | } 23 | } 24 | 25 | impl<'a, T: Float> Iterator for PeaksIter<'a, T> { 26 | type Item = (usize, T); 27 | /// Yields `(index, value)` of the maximum of the next run of values with a clear sign bit that is bounded by negative values on both sides. 28 | fn next(&mut self) -> Option<(usize, T)> { 29 | let mut idx = self.index; 30 | let mut max = -T::infinity(); 31 | let mut max_index = idx; 32 | 33 | if idx >= self.data.len() { 34 | return None; 35 | } 36 | 37 | if idx == 0 { 38 | // If we're first starting iteration, we want to skip over 39 | // any positive values at the start of `self.data`. These are not 40 | // relevant for auto-correlation algorithms (since self.data[0] will always 41 | // be a global maximum for an autocorrelation). 42 | idx += self 43 | .data 44 | .iter() 45 | // `!val.is_sign_negative()` is used instead of `val.is_sign_positive()` 46 | // to make sure that any spurious NaN at the start are also skipped (NaN 47 | // is not sign positive and is not sign negative). NOTE(review): Rust's std docs for `f32`/`f64` state that NaN carries a sign bit and these two methods are exact complements, so this only skips NaNs whose sign bit is clear -- confirm the intent. 48 | .take_while(|val| !val.is_sign_negative()) 49 | .count(); 50 | } 51 | 52 | // Skip over the negative parts because we're looking for a positive peak! 53 | idx += self.data[idx..] 54 | .iter() 55 | .take_while(|val| val.is_sign_negative()) 56 | .count(); 57 | 58 | // Record the local max and max_index for the next stretch of positive values. 59 | for val in &self.data[idx..] { 60 | if val.is_sign_negative() { 61 | break; 62 | } 63 | if *val > max { 64 | max = *val; 65 | max_index = idx; 66 | } 67 | idx += 1; 68 | } 69 | 70 | self.index = idx; 71 | 72 | // We may not have found a maximum; the only time when this happens is when we've 73 | // reached the end of `self.data`. Alternatively, if `self.data` ends in a positive 74 | // segment we don't want to count `max` as a real maximum (since the data 75 | // was probably truncated in some way).
In this case, we have `idx == self.data.len()`, 76 | // and so we terminate the iterator. NOTE(review): a run made only of NaNs also leaves `max == -T::infinity()` (NaN never compares greater), so it ends iteration early and later peaks are lost. 77 | if max == -T::infinity() || idx == self.data.len() { 78 | return None; 79 | } 80 | 81 | Some((max_index, max)) 82 | } 83 | } 84 | 85 | /// Find `(index, value)` of positive peaks in `arr`. Every positive peak is preceded and succeeded 86 | /// by negative values, so any initial positive segment of `arr` does not produce a peak. 87 | pub fn detect_peaks<'a, T: Float>(arr: &'a [T]) -> impl Iterator + 'a { 88 | PeaksIter::new(arr) 89 | } 90 | /// Return the first of `peaks` whose value is strictly greater than `threshold`, or `None` if there is none. 91 | pub fn choose_peak, T: Float>( 92 | mut peaks: I, 93 | threshold: T, 94 | ) -> Option<(usize, T)> { 95 | peaks.find(|p| p.1 > threshold) 96 | } 97 | /// Convert `peak` (an `(index, value)` pair in `data`) into a `(position, value)` pair, refined by a parabola fit when `correction` is `Quadratic`; that mode reads `data[peak.0 - 1]` and `data[peak.0 + 1]`, so it panics if the peak is the first or last sample (peaks from `detect_peaks` never are). 98 | pub fn correct_peak(peak: (usize, T), data: &[T], correction: PeakCorrection) -> (T, T) { 99 | match correction { 100 | PeakCorrection::Quadratic => { 101 | let idx = peak.0; 102 | let (x, y) = find_quadratic_peak(data[idx - 1], data[idx], data[idx + 1]); 103 | (x + T::from_usize(idx).unwrap(), y) 104 | } 105 | PeakCorrection::None => (T::from_usize(peak.0).unwrap(), peak.1), 106 | } 107 | } 108 | 109 | /// Use a quadratic interpolation to find the maximum of 110 | /// a parabola passing through `(-1, y0)`, `(0, y1)`, `(1, y2)`. 111 | /// 112 | /// The output is of the form `(x-offset, peak value)`.
113 | fn find_quadratic_peak(y0: T, y1: T, y2: T) -> (T, T) { 114 | // The quadratic ax^2+bx+c passing through 115 | // (-1, y0), (0, y1), (1, y2) 116 | // has coefficients 117 | // 118 | // a = y0/2 - y1 + y2/2 119 | // b = (y2 - y0)/2 120 | // c = y1 121 | // 122 | // and a maximum at x=-b/(2a) and y=-b^2/(4a) + c 123 | // when a < 0. If a == 0 the points are collinear, and if also b == 0 all three are equal. 124 | // Some constants 125 | let two = T::from_f64(2.).unwrap(); 126 | let four = T::from_f64(4.).unwrap(); 127 | 128 | let a = (y0 + y2) / two - y1; 129 | let b = (y2 - y0) / two; 130 | let c = y1; 131 | if a == T::zero() && b == T::zero() { return (T::zero(), c); } 132 | // Flat data is handled above (the vertex formula would compute 0/0); if we're concave up or linear, the maximum is at one of the end points 133 | if a >= T::zero() { 134 | if y0 > y2 { 135 | return (-T::one(), y0); 136 | } 137 | return (T::one(), y2); 138 | } 139 | // Here a < 0 (unless an input was NaN), so the vertex is the maximum. 140 | (-b / (two * a), -b * b / (four * a) + c) 141 | } 142 | 143 | #[cfg(test)] 144 | mod tests { 145 | use super::*; 146 | 147 | #[test] 148 | fn peak_correction() { 149 | fn quad1(x: f64) -> f64 { 150 | -x * x + 4.0 151 | } 152 | let (x, y) = find_quadratic_peak(quad1(-1.5), quad1(-0.5), quad1(0.5)); 153 | assert_eq!(x - 0.5, 0.0); 154 | assert_eq!(y, 4.0); 155 | let (x, y) = find_quadratic_peak(quad1(-3.5), quad1(-2.5), quad1(-1.5)); 156 | assert_eq!(x - 2.5, 0.0); 157 | assert_eq!(y, 4.0); 158 | 159 | fn quad2(x: f64) -> f64 { 160 | -2. * x * x + 3.
* x - 2.5 161 | } 162 | 163 | let (x, y) = find_quadratic_peak(quad2(-1.), quad2(0.), quad2(1.)); 164 | assert_eq!(x, 0.75); 165 | assert_eq!(y, -1.375); 166 | 167 | // Test of concave-up parabolas 168 | fn quad3(x: f64) -> f64 { 169 | x * x + 2.0 170 | } 171 | 172 | let (x, y) = find_quadratic_peak(quad3(0.), quad3(1.), quad3(2.)); 173 | assert_eq!(x + 1., 2.); 174 | assert_eq!(y, 6.); 175 | let (x, y) = find_quadratic_peak(quad3(-2.), quad3(-1.), quad3(0.)); 176 | assert_eq!(x - 1., -2.); 177 | assert_eq!(y, 6.); 178 | } 179 | #[test] fn flat_and_collinear_peak_correction() { assert_eq!(find_quadratic_peak(1.0_f64, 1.0, 1.0), (0.0, 1.0)); assert_eq!(find_quadratic_peak(0.0_f64, 1.0, 2.0), (1.0, 2.0)); assert_eq!(find_quadratic_peak(2.0_f64, 1.0, 0.0), (-1.0, 2.0)); } 180 | #[test] 181 | fn detect_peaks_test() { 182 | let v = vec![-2., -1., 1., 2., 1., -1., 4., -3., -2., 1., 1., -1.]; 183 | let peaks: Vec<(usize, f64)> = detect_peaks(v.as_slice()).collect(); 184 | assert_eq!(peaks, [(3, 2.), (6, 4.), (9, 1.)]); 185 | 186 | let v = vec![1., 2., 1., -1., 2., -3., -2., 1., 1., -1.]; 187 | let peaks: Vec<(usize, f64)> = detect_peaks(v.as_slice()).collect(); 188 | assert_eq!(peaks, [(4, 2.), (7, 1.)]); 189 | 190 | let v = vec![1., 2., 1., -1., 2., -3., -2., 1., 1.]; 191 | let peaks: Vec<(usize, f64)> = detect_peaks(v.as_slice()).collect(); 192 | assert_eq!(peaks, [(4, 2.)]); 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /src/utils/buffer.rs: -------------------------------------------------------------------------------- 1 | use rustfft::num_complex::Complex; 2 | use rustfft::num_traits::Zero; 3 | use std::{cell::RefCell, rc::Rc}; 4 | 5 | use crate::float::Float; 6 | /// Selects the real (`Re`) or imaginary (`Im`) part of a complex value. 7 | pub enum ComplexComponent { 8 | Re, 9 | Im, 10 | } 11 | /// Allocates a real buffer of `size` zeros. 12 | pub fn new_real_buffer(size: usize) -> Vec { 13 | vec![T::zero(); size] 14 | } 15 | /// Allocates a complex buffer of `size` zeros. 16 | pub fn new_complex_buffer(size: usize) -> Vec> { 17 | vec![Complex::zero(); size] 18 | } 19 | /// Copies `input` into the chosen component of `output`, zeroing the other component and every element past `input.len()`; panics if `input` is longer than `output`. 20 | pub fn copy_real_to_complex( 21 | input: &[T], 22 | output: &mut [Complex], 23 | component: ComplexComponent, 24 | ) { 25 | assert!(input.len() <= output.len()); 26 | match component { 27 | ComplexComponent::Re =>
input.iter().zip(output.iter_mut()).for_each(|(i, o)| { 28 | o.re = *i; 29 | o.im = T::zero(); 30 | }), 31 | ComplexComponent::Im => input.iter().zip(output.iter_mut()).for_each(|(i, o)| { 32 | o.im = *i; 33 | o.re = T::zero(); 34 | }), 35 | } 36 | output[input.len()..] 37 | .iter_mut() 38 | .for_each(|o| *o = Complex::zero()) 39 | } 40 | 41 | pub fn copy_complex_to_real( 42 | input: &[Complex], 43 | output: &mut [T], 44 | component: ComplexComponent, 45 | ) { 46 | assert!(input.len() <= output.len()); 47 | match component { 48 | ComplexComponent::Re => input 49 | .iter() 50 | .map(|c| c.re) 51 | .zip(output.iter_mut()) 52 | .for_each(|(i, o)| *o = i), 53 | ComplexComponent::Im => input 54 | .iter() 55 | .map(|c| c.im) 56 | .zip(output.iter_mut()) 57 | .for_each(|(i, o)| *o = i), 58 | } 59 | 60 | output[input.len()..] 61 | .iter_mut() 62 | .for_each(|o| *o = T::zero()); 63 | } 64 | 65 | /// Computes |x|^2 for each complex value x in `arr`. This function 66 | /// modifies `arr` in place and leaves the complex component zero. 67 | pub fn modulus_squared<'a, T: Float>(arr: &'a mut [Complex]) { 68 | for mut s in arr { 69 | s.re = s.re * s.re + s.im * s.im; 70 | s.im = T::zero(); 71 | } 72 | } 73 | 74 | /// Compute the sum of the square of each element of `arr`. 75 | pub fn square_sum(arr: &[T]) -> T 76 | where 77 | T: Float + std::iter::Sum, 78 | { 79 | arr.iter().map(|&s| s * s).sum::() 80 | } 81 | 82 | #[derive(Debug)] 83 | /// A pool of real/complex buffer objects. Buffers are dynamically created as needed 84 | /// and reused if previously `Drop`ed. Buffers are never freed. Instead buffers are kept 85 | /// in reserve and reused when a new buffer is requested. 
86 | /// 87 | /// ```rust 88 | /// use pitch_detection::utils::buffer::BufferPool; 89 | /// 90 | /// let mut buffers = BufferPool::new(3); 91 | /// let buf_cell1 = buffers.get_real_buffer(); 92 | /// { 93 | /// // This buffer won't be dropped until the end of the function 94 | /// let mut buf1 = buf_cell1.borrow_mut(); 95 | /// buf1[0] = 5.5; 96 | /// } 97 | /// { 98 | /// // This buffer will be dropped when the scope ends 99 | /// let buf_cell2 = buffers.get_real_buffer(); 100 | /// let mut buf2 = buf_cell2.borrow_mut(); 101 | /// buf2[1] = 6.6; 102 | /// } 103 | /// { 104 | /// // This buffer will be dropped when the scope ends 105 | /// // It is the same buffer that was just used (i.e., it's a reused buffer) 106 | /// let buf_cell3 = buffers.get_real_buffer(); 107 | /// let mut buf3 = buf_cell3.borrow_mut(); 108 | /// buf3[2] = 7.7; 109 | /// } 110 | /// // The first buffer we asked for should not have been reused. 111 | /// assert_eq!(&buf_cell1.borrow()[..], &[5.5, 0., 0.]); 112 | /// let buf_cell2 = buffers.get_real_buffer(); 113 | /// // The second buffer was reused because it was dropped and then another buffer was requested. 
114 | /// assert_eq!(&buf_cell2.borrow()[..], &[0.0, 6.6, 7.7]); 115 | /// ``` 116 | pub struct BufferPool { 117 | real_buffers: Vec>>>, 118 | complex_buffers: Vec>>>>, 119 | pub buffer_size: usize, 120 | } 121 | 122 | impl BufferPool { 123 | pub fn new(buffer_size: usize) -> Self { 124 | BufferPool { 125 | real_buffers: vec![], 126 | complex_buffers: vec![], 127 | buffer_size, 128 | } 129 | } 130 | fn add_real_buffer(&mut self) -> Rc>> { 131 | self.real_buffers 132 | .push(Rc::new(RefCell::new(new_real_buffer::( 133 | self.buffer_size, 134 | )))); 135 | Rc::clone(&self.real_buffers.last().unwrap()) 136 | } 137 | fn add_complex_buffer(&mut self) -> Rc>>> { 138 | self.complex_buffers 139 | .push(Rc::new(RefCell::new(new_complex_buffer::( 140 | self.buffer_size, 141 | )))); 142 | Rc::clone(&self.complex_buffers.last().unwrap()) 143 | } 144 | /// Get a reference to a buffer that can e used until it is `Drop`ed. Call 145 | /// `.borrow_mut()` to get a reference to a mutable version of the buffer. 146 | pub fn get_real_buffer(&mut self) -> Rc>> { 147 | self.real_buffers 148 | .iter() 149 | // If the Rc count is 1, we haven't loaned the buffer out yet. 150 | .find(|&buf| Rc::strong_count(buf) == 1) 151 | .map(|buf| Rc::clone(buf)) 152 | // If we haven't found a buffer we can reuse, create one. 153 | .unwrap_or_else(|| self.add_real_buffer()) 154 | } 155 | /// Get a reference to a buffer that can e used until it is `Drop`ed. Call 156 | /// `.borrow_mut()` to get a reference to a mutable version of the buffer. 157 | pub fn get_complex_buffer(&mut self) -> Rc>>> { 158 | self.complex_buffers 159 | .iter() 160 | // If the Rc count is 1, we haven't loaned the buffer out yet. 161 | .find(|&buf| Rc::strong_count(buf) == 1) 162 | .map(|buf| Rc::clone(buf)) 163 | // If we haven't found a buffer we can reuse, create one. 
164 | .unwrap_or_else(|| self.add_complex_buffer()) 165 | } 166 | } 167 | 168 | #[test] 169 | fn test_buffers() { 170 | let mut buffers = BufferPool::new(3); 171 | let buf_cell1 = buffers.get_real_buffer(); 172 | { 173 | // This buffer won't be dropped until the end of the function 174 | let mut buf1 = buf_cell1.borrow_mut(); 175 | buf1[0] = 5.5; 176 | } 177 | { 178 | // This buffer will be dropped when the scope ends 179 | let buf_cell2 = buffers.get_real_buffer(); 180 | let mut buf2 = buf_cell2.borrow_mut(); 181 | buf2[1] = 6.6; 182 | } 183 | { 184 | // This buffer will be dropped when the scope ends 185 | // It is the same buffer that was just used (i.e., it's a reused buffer) 186 | let buf_cell3 = buffers.get_real_buffer(); 187 | let mut buf3 = buf_cell3.borrow_mut(); 188 | buf3[2] = 7.7; 189 | } 190 | // We're peering into the internals of `BufferPool`. This shouldn't normally be done. 191 | assert_eq!(&buffers.real_buffers[0].borrow()[..], &[5.5, 0., 0.]); 192 | assert_eq!(&buffers.real_buffers[1].borrow()[..], &[0.0, 6.6, 7.7]); 193 | } 194 | -------------------------------------------------------------------------------- /src/detector/internals.rs: -------------------------------------------------------------------------------- 1 | use rustfft::FftPlanner; 2 | 3 | use crate::utils::buffer::ComplexComponent; 4 | use crate::utils::buffer::{copy_complex_to_real, square_sum}; 5 | use crate::utils::buffer::{copy_real_to_complex, BufferPool}; 6 | use crate::utils::peak::choose_peak; 7 | use crate::utils::peak::correct_peak; 8 | use crate::utils::peak::detect_peaks; 9 | use crate::utils::peak::PeakCorrection; 10 | use crate::{float::Float, utils::buffer::modulus_squared}; 11 | 12 | /// A pitch's `frequency` as well as `clarity`, which is a measure 13 | /// of confidence in the pitch detection. 
14 | pub struct Pitch 15 | where 16 | T: Float, 17 | { 18 | pub frequency: T, 19 | pub clarity: T, 20 | } 21 | 22 | /// Data structure to hold any buffers needed for pitch computation. 23 | /// For WASM it's best to allocate buffers once rather than allocate and 24 | /// free buffers repeatedly, so we use a `BufferPool` object to manage the buffers. 25 | pub struct DetectorInternals 26 | where 27 | T: Float, 28 | { 29 | pub size: usize, 30 | pub padding: usize, 31 | pub buffers: BufferPool, 32 | } 33 | 34 | impl DetectorInternals 35 | where 36 | T: Float, 37 | { 38 | pub fn new(size: usize, padding: usize) -> Self { 39 | let buffers = BufferPool::new(size + padding); 40 | 41 | DetectorInternals { 42 | size, 43 | padding, 44 | buffers, 45 | } 46 | } 47 | } 48 | 49 | /// Compute the autocorrelation of `signal` to `result`. All buffers but `signal` 50 | /// may be used as scratch. 51 | pub fn autocorrelation(signal: &[T], buffers: &mut BufferPool, result: &mut [T]) 52 | where 53 | T: Float, 54 | { 55 | let (ref1, ref2) = (buffers.get_complex_buffer(), buffers.get_complex_buffer()); 56 | let signal_complex = &mut ref1.borrow_mut()[..]; 57 | let scratch = &mut ref2.borrow_mut()[..]; 58 | 59 | let mut planner = FftPlanner::new(); 60 | let fft = planner.plan_fft_forward(signal_complex.len()); 61 | let inv_fft = planner.plan_fft_inverse(signal_complex.len()); 62 | 63 | // Compute the autocorrelation 64 | copy_real_to_complex(signal, signal_complex, ComplexComponent::Re); 65 | fft.process_with_scratch(signal_complex, scratch); 66 | modulus_squared(signal_complex); 67 | inv_fft.process_with_scratch(signal_complex, scratch); 68 | copy_complex_to_real(signal_complex, result, ComplexComponent::Re); 69 | } 70 | 71 | pub fn pitch_from_peaks( 72 | input: &[T], 73 | sample_rate: usize, 74 | clarity_threshold: T, 75 | correction: PeakCorrection, 76 | ) -> Option> 77 | where 78 | T: Float, 79 | { 80 | let sample_rate = T::from_usize(sample_rate).unwrap(); 81 | let peaks = 
detect_peaks(input); 82 | 83 | choose_peak(peaks, clarity_threshold) 84 | .map(|peak| correct_peak(peak, input, correction)) 85 | .map(|peak| Pitch { 86 | frequency: sample_rate / peak.0, 87 | clarity: peak.1 / input[0], 88 | }) 89 | } 90 | 91 | fn m_of_tau(signal: &[T], signal_square_sum: Option, result: &mut [T]) 92 | where 93 | T: Float + std::iter::Sum, 94 | { 95 | assert!(result.len() >= signal.len()); 96 | 97 | let signal_square_sum = signal_square_sum.unwrap_or_else(|| square_sum(signal)); 98 | 99 | let start = T::from_usize(2).unwrap() * signal_square_sum; 100 | result[0] = start; 101 | let last = result[1..] 102 | .iter_mut() 103 | .zip(signal) 104 | .fold(start, |old, (r, &s)| { 105 | *r = old - s * s; 106 | *r 107 | }); 108 | // Pad the end of `result` with the last value 109 | result[signal.len()..].iter_mut().for_each(|r| *r = last); 110 | } 111 | 112 | pub fn normalized_square_difference(signal: &[T], buffers: &mut BufferPool, result: &mut [T]) 113 | where 114 | T: Float + std::iter::Sum, 115 | { 116 | let two = T::from_usize(2).unwrap(); 117 | 118 | let scratch_ref = buffers.get_real_buffer(); 119 | let scratch = &mut scratch_ref.borrow_mut()[..]; 120 | 121 | autocorrelation(signal, buffers, result); 122 | m_of_tau(signal, Some(result[0]), scratch); 123 | result 124 | .iter_mut() 125 | .zip(scratch) 126 | .for_each(|(r, s)| *r = two * *r / *s) 127 | } 128 | 129 | /// Compute the windowed autocorrelation of `signal` and put the result in `result`. 130 | /// For a signal _x=(x_0,x_1,...)_, the windowed autocorrelation with window size _w_ is 131 | /// the function 132 | /// 133 | /// > r(t) = sum_{i=0}^{w-1} x_i*x_{i+t} 134 | /// 135 | /// This function assumes `window_size` is at most half of the length of `signal`. 
136 | pub fn windowed_autocorrelation( 137 | signal: &[T], 138 | window_size: usize, 139 | buffers: &mut BufferPool, 140 | result: &mut [T], 141 | ) where 142 | T: Float + std::iter::Sum, 143 | { 144 | assert!( 145 | buffers.buffer_size >= signal.len(), 146 | "Buffers must have a length at least equal to `signal`." 147 | ); 148 | 149 | let mut planner = FftPlanner::new(); 150 | let fft = planner.plan_fft_forward(signal.len()); 151 | let inv_fft = planner.plan_fft_inverse(signal.len()); 152 | 153 | let (scratch_ref1, scratch_ref2, scratch_ref3) = ( 154 | buffers.get_complex_buffer(), 155 | buffers.get_complex_buffer(), 156 | buffers.get_complex_buffer(), 157 | ); 158 | 159 | let signal_complex = &mut scratch_ref1.borrow_mut()[..signal.len()]; 160 | let truncated_signal_complex = &mut scratch_ref2.borrow_mut()[..signal.len()]; 161 | let scratch = &mut scratch_ref3.borrow_mut()[..signal.len()]; 162 | 163 | // To achieve the windowed autocorrelation, we compute the cross correlation between 164 | // the original signal and the signal truncated to lie in `0..window_size` 165 | copy_real_to_complex(signal, signal_complex, ComplexComponent::Re); 166 | copy_real_to_complex( 167 | &signal[..window_size], 168 | truncated_signal_complex, 169 | ComplexComponent::Re, 170 | ); 171 | fft.process_with_scratch(signal_complex, scratch); 172 | fft.process_with_scratch(truncated_signal_complex, scratch); 173 | // rustfft doesn't normalize when it computes the fft, so we need to normalize ourselves by 174 | // dividing by `sqrt(signal.len())` each time we take an fft or inverse fft. 175 | // Since the fft is linear and we are doing fft -> inverse fft, we can just divide by 176 | // `signal.len()` once. 
177 | let normalization_const = T::one() / T::from_usize(signal.len()).unwrap(); 178 | signal_complex 179 | .iter_mut() 180 | .zip(truncated_signal_complex.iter()) 181 | .for_each(|(a, b)| { 182 | *a = *a * normalization_const * b.conj(); 183 | }); 184 | inv_fft.process_with_scratch(signal_complex, scratch); 185 | 186 | // The result is valid only for `0..window_size` 187 | copy_complex_to_real(&signal_complex[..window_size], result, ComplexComponent::Re); 188 | } 189 | 190 | /// Compute the windowed square error, _d(t)_, of `signal`. For a window size of _w_ and a signal 191 | /// _x=(x_0,x_1,...)_, this is defined by 192 | /// 193 | /// > d(t) = sum_{i=0}^{w-1} (x_i - x_{i+t})^2 194 | /// 195 | /// This function is computed efficiently using an FFT. It is assumed that `window_size` is at most half 196 | /// the length of `signal`. 197 | pub fn windowed_square_error( 198 | signal: &[T], 199 | window_size: usize, 200 | buffers: &mut BufferPool, 201 | result: &mut [T], 202 | ) where 203 | T: Float + std::iter::Sum, 204 | { 205 | assert!( 206 | 2 * window_size <= signal.len(), 207 | "The window size cannot be more than half the signal length" 208 | ); 209 | 210 | let two = T::from_f64(2.).unwrap(); 211 | 212 | // The windowed square error function, d(t), can be computed 213 | // as d(t) = pow_0^w + pow_t^{t+w} - 2*windowed_autocorrelation(t) 214 | // where pow_a^b is the sum of the square of `signal` on the window `a..b` 215 | // We proceed accordingly. 
216 | windowed_autocorrelation(signal, window_size, buffers, result); 217 | let mut windowed_power = square_sum(&signal[..window_size]); 218 | let power = windowed_power; 219 | 220 | result.iter_mut().enumerate().for_each(|(i, a)| { 221 | // use the formula pow_0^w + pow_t^{t+w} - 2*windowed_autocorrelation(t) 222 | *a = power + windowed_power - two * *a; 223 | // Since we're processing everything in order, we can computed pow_{t+1}^{t+1+w} 224 | // directly from pow_t^{t+w} by adding and subtracting the boundary terms. 225 | windowed_power = windowed_power - signal[i] * signal[i] 226 | + signal[i + window_size] * signal[i + window_size]; 227 | }) 228 | } 229 | 230 | /// Calculate the "cumulative mean normalized difference function" as 231 | /// specified in the YIN paper. If _d(t)_ is the square error function, 232 | /// compute _d'(0) = 1_ and for _t > 0_ 233 | /// 234 | /// > d'(t) = d(t) / [ (1/t) * sum_{i=0}^t d(i) ] 235 | pub fn yin_normalize_square_error(square_error: &mut [T]) { 236 | let mut sum = T::zero(); 237 | square_error[0] = T::one(); 238 | // square_error[0] should always be zero, so we don't need to worry about 239 | // adding this to our sum. 
240 | square_error 241 | .iter_mut() 242 | .enumerate() 243 | .skip(1) 244 | .for_each(|(i, a)| { 245 | sum = sum + *a; 246 | *a = *a * T::from_usize(i + 1).unwrap() / sum; 247 | }); 248 | } 249 | 250 | #[cfg(test)] 251 | mod tests { 252 | use super::*; 253 | 254 | #[test] 255 | fn windowed_autocorrelation_test() { 256 | let signal: Vec = vec![0., 1., 2., 0., -1., -2.]; 257 | let window_size: usize = 3; 258 | 259 | let buffers = &mut BufferPool::new(signal.len()); 260 | 261 | let result: Vec = (0..window_size) 262 | .map(|i| { 263 | signal[..window_size] 264 | .iter() 265 | .zip(signal[i..(i + window_size)].iter()) 266 | .map(|(a, b)| *a * *b) 267 | .sum() 268 | }) 269 | .collect(); 270 | 271 | let mut computed_result = vec![0.; window_size]; 272 | windowed_autocorrelation(&signal, window_size, buffers, &mut computed_result); 273 | // Using an FFT loses precision; we don't care that much, so round generously. 274 | computed_result 275 | .iter_mut() 276 | .for_each(|x| *x = (*x * 100.).round() / 100.); 277 | 278 | assert_eq!(result, computed_result); 279 | } 280 | 281 | #[test] 282 | fn windowed_square_error_test() { 283 | let signal: Vec = vec![0., 1., 2., 0., -1., -2.]; 284 | let window_size: usize = 3; 285 | 286 | let buffers = &mut BufferPool::new(signal.len()); 287 | 288 | let result: Vec = (0..window_size) 289 | .map(|i| { 290 | signal[..window_size] 291 | .iter() 292 | .zip(signal[i..(i + window_size)].iter()) 293 | .map(|(x_j, x_j_tau)| (*x_j - *x_j_tau) * (*x_j - *x_j_tau)) 294 | .sum() 295 | }) 296 | .collect(); 297 | 298 | let mut computed_result = vec![0.; window_size]; 299 | windowed_square_error(&signal, window_size, buffers, &mut computed_result); 300 | // Using an FFT loses precision; we don't care that much, so round generously. 
301 | computed_result 302 | .iter_mut() 303 | .for_each(|x| *x = (*x * 100.).round() / 100.); 304 | 305 | assert_eq!(result, computed_result); 306 | } 307 | #[test] 308 | fn yin_normalized_square_error_test() { 309 | let signal: &mut Vec = &mut vec![0., 6., 14.]; 310 | let result = vec![1., 2., 3. * 14. / (6. + 14.)]; 311 | 312 | yin_normalize_square_error(signal); 313 | 314 | assert_eq!(result, *signal); 315 | } 316 | } 317 | -------------------------------------------------------------------------------- /tests/main.rs: -------------------------------------------------------------------------------- 1 | use std::path::PathBuf; 2 | 3 | use pitch_detection::detector::mcleod::McLeodDetector; 4 | use pitch_detection::detector::PitchDetector; 5 | use pitch_detection::detector::{autocorrelation::AutocorrelationDetector, yin::YINDetector}; 6 | use pitch_detection::float::Float; 7 | use pitch_detection::utils::buffer::new_real_buffer; 8 | 9 | // For reading in `.wav` files 10 | use hound; 11 | 12 | #[derive(Debug)] 13 | struct Signal { 14 | sample_rate: usize, 15 | data: Vec, 16 | } 17 | 18 | #[test] 19 | fn autocorrelation_sin_signal() { 20 | pure_frequency(String::from("Autocorrelation"), String::from("sin"), 440.0); 21 | } 22 | 23 | #[test] 24 | fn mcleod_sin_signal() { 25 | pure_frequency(String::from("McLeod"), String::from("sin"), 440.0); 26 | } 27 | 28 | #[test] 29 | fn yin_sin_signal() { 30 | pure_frequency(String::from("YIN"), String::from("sin"), 440.0); 31 | } 32 | 33 | #[test] 34 | fn autocorrelation_square_signal() { 35 | pure_frequency( 36 | String::from("Autocorrelation"), 37 | String::from("square"), 38 | 440.0, 39 | ); 40 | } 41 | 42 | #[test] 43 | fn mcleod_square_signal() { 44 | pure_frequency(String::from("McLeod"), String::from("square"), 440.0); 45 | } 46 | 47 | #[test] 48 | fn yin_square_signal() { 49 | pure_frequency(String::from("YIN"), String::from("square"), 440.0); 50 | } 51 | 52 | #[test] 53 | fn autocorrelation_triangle_signal() { 54 | 
pure_frequency( 55 | String::from("Autocorrelation"), 56 | String::from("triangle"), 57 | 440.0, 58 | ); 59 | } 60 | 61 | #[test] 62 | fn mcleod_triangle_signal() { 63 | pure_frequency(String::from("McLeod"), String::from("triangle"), 440.0); 64 | } 65 | 66 | #[test] 67 | fn yin_triangle_signal() { 68 | pure_frequency(String::from("YIN"), String::from("triangle"), 440.0); 69 | } 70 | 71 | #[test] 72 | fn autocorrelation_violin_d4() { 73 | let signal: Signal = wav_file_to_signal(samples_path("violin-D4.wav"), 0, 10 * 1024); 74 | 75 | raw_frequency("Autocorrelation".into(), signal, 293.); 76 | } 77 | 78 | #[test] 79 | fn mcleod_violin_d4() { 80 | let signal: Signal = wav_file_to_signal(samples_path("violin-D4.wav"), 0, 10 * 1024); 81 | 82 | raw_frequency("McLeod".into(), signal, 293.); 83 | } 84 | 85 | #[test] 86 | fn autocorrelation_violin_f4() { 87 | let signal: Signal = wav_file_to_signal(samples_path("violin-F4.wav"), 0, 10 * 1024); 88 | 89 | raw_frequency("Autocorrelation".into(), signal, 349.); 90 | } 91 | 92 | #[test] 93 | fn mcleod_violin_f4() { 94 | let signal: Signal = wav_file_to_signal(samples_path("violin-F4.wav"), 0, 10 * 1024); 95 | 96 | raw_frequency("McLeod".into(), signal, 349.); 97 | } 98 | 99 | #[test] 100 | fn autocorrelation_violin_g4() { 101 | let signal: Signal = wav_file_to_signal(samples_path("violin-G4.wav"), 0, 10 * 1024); 102 | 103 | raw_frequency("Autocorrelation".into(), signal, 392.); 104 | } 105 | 106 | #[test] 107 | fn mcleod_violin_g4() { 108 | let signal: Signal = wav_file_to_signal(samples_path("violin-G4.wav"), 0, 10 * 1024); 109 | 110 | raw_frequency("McLeod".into(), signal, 392.); 111 | } 112 | 113 | #[test] 114 | fn mcleod_tenor_trombone_c3() { 115 | let signal: Signal = 116 | wav_file_to_signal(samples_path("tenor-trombone-C3.wav"), 0, 10 * 1024); 117 | 118 | raw_frequency("McLeod".into(), signal, 130.); 119 | } 120 | 121 | #[test] 122 | fn mcleod_tenor_trombone_db3() { 123 | let signal: Signal = 124 | 
wav_file_to_signal(samples_path("tenor-trombone-Db3.wav"), 0, 10 * 1024); 125 | 126 | raw_frequency("McLeod".into(), signal, 138.); 127 | } 128 | 129 | #[test] 130 | fn mcleod_tenor_trombone_ab3() { 131 | let signal: Signal = 132 | wav_file_to_signal(samples_path("tenor-trombone-Ab3.wav"), 0, 10 * 1024); 133 | 134 | raw_frequency("McLeod".into(), signal, 207.); 135 | } 136 | 137 | #[test] 138 | fn mcleod_tenor_trombone_b3() { 139 | let signal: Signal = 140 | wav_file_to_signal(samples_path("tenor-trombone-B3.wav"), 0, 10 * 1024); 141 | 142 | raw_frequency("McLeod".into(), signal, 246.); 143 | } 144 | 145 | fn get_chunk(signal: &[T], start: usize, window: usize, output: &mut [T]) { 146 | let start = match signal.len() > start { 147 | true => start, 148 | false => signal.len(), 149 | }; 150 | 151 | let stop = match signal.len() >= start + window { 152 | true => start + window, 153 | false => signal.len(), 154 | }; 155 | 156 | for i in 0..stop - start { 157 | output[i] = signal[start + i]; 158 | } 159 | 160 | for i in stop - start..output.len() { 161 | output[i] = T::zero(); 162 | } 163 | } 164 | 165 | fn wav_file_to_signal( 166 | file_name: String, 167 | seek_start: usize, 168 | num_samples: usize, 169 | ) -> Signal { 170 | println!("Opening \"{}\"", file_name); 171 | let mut reader = hound::WavReader::open(file_name).unwrap(); 172 | let sample_rate = reader.spec().sample_rate as usize; 173 | let data: Vec = reader 174 | .samples::() 175 | .skip(seek_start) 176 | .map(|s| T::from_i32(s.unwrap()).unwrap()) 177 | .take(num_samples) 178 | .collect(); 179 | 180 | Signal { sample_rate, data } 181 | } 182 | 183 | /// Get the full path of `wav` file specified by `file_name`. 
184 | fn samples_path(file_name: &str) -> String { 185 | // `d` is an absolute path to the source directory of the project 186 | let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); 187 | // all audio samples are in this subfolder 188 | d.push("tests/samples"); 189 | d.push(file_name); 190 | 191 | d.to_str().unwrap().into() 192 | } 193 | 194 | fn sin_wave(freq: f64, size: usize, sample_rate: usize) -> Vec { 195 | let mut signal = new_real_buffer(size); 196 | let two_pi = 2.0 * std::f64::consts::PI; 197 | let dx = two_pi * freq / sample_rate as f64; 198 | for i in 0..size { 199 | let x = i as f64 * dx; 200 | let y = x.sin(); 201 | signal[i] = T::from(y).unwrap(); 202 | } 203 | signal 204 | } 205 | 206 | fn square_wave(freq: f64, size: usize, sample_rate: usize) -> Vec { 207 | let mut signal = new_real_buffer(size); 208 | let period = sample_rate as f64 / freq; 209 | 210 | for i in 0..size { 211 | let x = i as f64 / period; 212 | let frac = x - x.floor(); 213 | let y = match frac >= 0.5 { 214 | true => -1.0, 215 | false => 1.0, 216 | }; 217 | signal[i] = T::from(y).unwrap(); 218 | } 219 | signal 220 | } 221 | 222 | fn triangle_wave(freq: f64, size: usize, sample_rate: usize) -> Vec { 223 | let mut signal = new_real_buffer(size); 224 | let period = sample_rate as f64 / freq; 225 | 226 | for i in 0..size { 227 | let x = i as f64 / period; 228 | let frac = x - x.floor(); 229 | let y = match frac { 230 | f if f >= 0. && f < 0.25 => 4. * f, 231 | f if f >= 0.25 && f < 0.75 => 1. - 4. * (f - 0.25), 232 | f if f >= 0.75 && f < 1. => -1. + 4. 
* (f - 0.75), 233 | _ => panic!("Should be between 0 and 1"), 234 | }; 235 | signal[i] = T::from(y).unwrap(); 236 | } 237 | signal 238 | } 239 | 240 | fn saw_wave(freq: f64, size: usize, sample_rate: usize) -> Vec { 241 | let mut signal = new_real_buffer(size); 242 | let period = sample_rate as f64 / freq; 243 | 244 | for i in 0..size { 245 | let x = i as f64 / period; 246 | let frac = x - x.floor(); 247 | let y = match frac { 248 | f if f >= 0. && f < 0.25 => 4. * f, 249 | f if f >= 0.25 && f < 0.75 => -1. + 4. * (f - 0.25), 250 | f if f >= 0.75 && f < 1. => -1. + 4. * (f - 0.75), 251 | _ => panic!("Should be between 0 and 1"), 252 | }; 253 | signal[i] = T::from(y).unwrap(); 254 | } 255 | signal 256 | } 257 | 258 | fn detector_factory(name: String, window: usize, padding: usize) -> Box> { 259 | match name.as_ref() { 260 | "McLeod" => { 261 | return Box::new(McLeodDetector::::new(window, padding)); 262 | } 263 | "Autocorrelation" => { 264 | return Box::new(AutocorrelationDetector::::new(window, padding)); 265 | } 266 | "YIN" => { 267 | return Box::new(YINDetector::::new(window, padding)); 268 | } 269 | _ => { 270 | panic!("Unknown detector {}", name); 271 | } 272 | } 273 | } 274 | 275 | fn signal_factory(name: String, freq: f64, size: usize, sample_rate: usize) -> Vec { 276 | match name.as_ref() { 277 | "sin" => { 278 | return sin_wave(freq, size, sample_rate); 279 | } 280 | "square" => { 281 | return square_wave(freq, size, sample_rate); 282 | } 283 | "triangle" => { 284 | return triangle_wave(freq, size, sample_rate); 285 | } 286 | "saw" => { 287 | return saw_wave(freq, size, sample_rate); 288 | } 289 | _ => { 290 | panic!("Unknown wave function {}", name); 291 | } 292 | } 293 | } 294 | 295 | fn pure_frequency(detector_name: String, wave_name: String, freq_in: f64) { 296 | const SAMPLE_RATE: usize = 48000; 297 | const DURATION: f64 = 4.0; 298 | const SAMPLE_SIZE: usize = (SAMPLE_RATE as f64 * DURATION) as usize; 299 | const WINDOW: usize = 1024; 300 | const 
PADDING: usize = WINDOW / 2; 301 | const DELTA_T: usize = WINDOW / 4; 302 | const N_WINDOWS: usize = (SAMPLE_SIZE - WINDOW) / DELTA_T; 303 | const POWER_THRESHOLD: f64 = 300.0; 304 | const CLARITY_THRESHOLD: f64 = 0.6; 305 | 306 | let signal = signal_factory::(wave_name, freq_in, SAMPLE_SIZE, SAMPLE_RATE); 307 | 308 | let mut chunk = new_real_buffer(WINDOW); 309 | 310 | let mut detector = detector_factory(detector_name, WINDOW, PADDING); 311 | 312 | for i in 0..N_WINDOWS { 313 | let t: usize = i * DELTA_T; 314 | get_chunk(&signal, t, WINDOW, &mut chunk); 315 | 316 | let pitch = detector.get_pitch(&chunk, SAMPLE_RATE, POWER_THRESHOLD, CLARITY_THRESHOLD); 317 | 318 | match pitch { 319 | Some(pitch) => { 320 | let frequency = pitch.frequency; 321 | let clarity = pitch.clarity; 322 | let idx = SAMPLE_RATE as f64 / frequency; 323 | let epsilon = (SAMPLE_RATE as f64 / (idx - 1.0)) - frequency; 324 | println!( 325 | "Chosen Peak idx: {}; clarity: {}; freq: {} +/- {}", 326 | idx, clarity, frequency, epsilon 327 | ); 328 | assert!((frequency - freq_in).abs() < 2. * epsilon); 329 | } 330 | None => { 331 | println!("No peaks accepted."); 332 | assert!(false); 333 | } 334 | } 335 | } 336 | } 337 | 338 | /// Test if the signal in `signal` is reasonably close to `freq_in`. 
339 | fn raw_frequency(detector_name: String, signal: Signal, freq_in: f64) { 340 | const ERROR_TOLERANCE: f64 = 2.; 341 | let sample_rate = signal.sample_rate; 342 | let duration: f64 = signal.data.len() as f64 / sample_rate as f64; 343 | let sample_size: usize = (signal.sample_rate as f64 * duration) as usize; 344 | const WINDOW: usize = 1024; 345 | const PADDING: usize = WINDOW / 2; 346 | const DELTA_T: usize = WINDOW / 4; 347 | let n_windows: usize = (sample_size - WINDOW) / DELTA_T; 348 | const POWER_THRESHOLD: f64 = 300.0; 349 | const CLARITY_THRESHOLD: f64 = 0.6; 350 | 351 | let mut chunk = new_real_buffer(WINDOW); 352 | 353 | let mut detector = detector_factory(detector_name, WINDOW, PADDING); 354 | 355 | for i in 0..n_windows { 356 | let t: usize = i * DELTA_T; 357 | get_chunk(&signal.data, t, WINDOW, &mut chunk); 358 | 359 | let pitch = detector.get_pitch(&chunk, sample_rate, POWER_THRESHOLD, CLARITY_THRESHOLD); 360 | 361 | match pitch { 362 | Some(pitch) => { 363 | let frequency = pitch.frequency; 364 | let clarity = pitch.clarity; 365 | let idx = sample_rate as f64 / frequency; 366 | let epsilon = (sample_rate as f64 / (idx - 1.0)) - frequency; 367 | println!( 368 | "Chosen Peak idx: {}; clarity: {}; freq: {} +/- {}", 369 | idx, clarity, frequency, epsilon 370 | ); 371 | assert!((frequency - freq_in).abs() < ERROR_TOLERANCE * epsilon); 372 | } 373 | None => { 374 | println!("No peaks accepted."); 375 | assert!(false); 376 | } 377 | } 378 | } 379 | } 380 | --------------------------------------------------------------------------------