├── tests ├── __init__.py ├── gpu_test.py └── algorithms_test.py ├── pitch_detectors ├── evaluation │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── base.py │ │ ├── mir_1k.py │ │ └── mdb_stem_synth.py │ ├── __main__.py │ └── table.py ├── __init__.py ├── config.py ├── algorithms │ ├── world.py │ ├── rapt.py │ ├── swipe.py │ ├── crepe.py │ ├── torchyin.py │ ├── yin.py │ ├── pyin.py │ ├── praatac.py │ ├── praatcc.py │ ├── praatshs.py │ ├── yaapt.py │ ├── reaper.py │ ├── piptrack.py │ ├── __init__.py │ ├── torchcrepe.py │ ├── penn.py │ ├── spice.py │ ├── base.py │ └── ensemble.py ├── schemas.py └── util.py ├── .gitignore ├── data ├── b1a5da49d564a7341e7e1327aa3f229a.png └── b1a5da49d564a7341e7e1327aa3f229a.wav ├── .github ├── dependabot.yml └── workflows │ └── ci.yml ├── scripts ├── download_models.py └── run_algorithm.py ├── Dockerfile ├── .gitlab-ci.yml ├── .pre-commit-config.yaml ├── Taskfile.yml ├── pyproject.toml └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pitch_detectors/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pitch_detectors/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.7.0' 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb 2 | dist 3 | *.pyc 4 | *.egg-info 5 | -------------------------------------------------------------------------------- /pitch_detectors/config.py: -------------------------------------------------------------------------------- 1 | HZ_MIN = 75 # default lower bound (Hz) of the pitch search range; the detectors use it as their hz_min default 2 | HZ_MAX = 600 # default upper bound (Hz) of the pitch search range; the detectors use it as their hz_max default 3 | 
-------------------------------------------------------------------------------- /data/b1a5da49d564a7341e7e1327aa3f229a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tandav/pitch-detectors/HEAD/data/b1a5da49d564a7341e7e1327aa3f229a.png -------------------------------------------------------------------------------- /data/b1a5da49d564a7341e7e1327aa3f229a.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tandav/pitch-detectors/HEAD/data/b1a5da49d564a7341e7e1327aa3f229a.wav -------------------------------------------------------------------------------- /pitch_detectors/evaluation/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from pitch_detectors.evaluation.datasets.mdb_stem_synth import MDBStemSynth 2 | from pitch_detectors.evaluation.datasets.mir_1k import Mir1K 3 | 4 | all_datasets = MDBStemSynth, Mir1K 5 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" # See documentation for possible values 4 | directory: "/" # Location of package manifests 5 | schedule: 6 | interval: "weekly" 7 | - package-ecosystem: github-actions 8 | directory: "/" 9 | schedule: 10 | interval: "weekly" 11 | -------------------------------------------------------------------------------- /scripts/download_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import s3fs 4 | 5 | s3 = s3fs.S3FileSystem( 6 | endpoint_url=os.environ['AWS_ENDPOINT_URL'], 7 | key=os.environ['AWS_ACCESS_KEY_ID'], 8 | secret=os.environ['AWS_SECRET_ACCESS_KEY'], 9 | ) 10 | 11 | s3.get('pitchtrack/spice_model', os.environ['PITCH_DETECTORS_SPICE_MODEL_PATH'], 
recursive=True) 12 | s3.get('pitchtrack/fcnf0++.pt', os.environ['PITCH_DETECTORS_PENN_CHECKPOINT_PATH']) 13 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/world.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors.algorithms.base import PitchDetector 4 | 5 | 6 | class World(PitchDetector): 7 | """https://github.com/JeremyCCHsu/Python-Wrapper-for-World-Vocoder""" 8 | 9 | def __init__( 10 | self, 11 | a: np.ndarray, 12 | fs: int, 13 | ): 14 | import pyworld 15 | super().__init__(a, fs) 16 | f0, sp, ap = pyworld.wav2world(a.astype(float), fs) 17 | f0[f0 == 0] = np.nan 18 | self.f0 = f0 19 | self.t = np.linspace(0, self.seconds, f0.shape[0]) 20 | -------------------------------------------------------------------------------- /scripts/run_algorithm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from pitch_detectors import algorithms 5 | from pitch_detectors import util 6 | 7 | 8 | def main(audio_path: str, algorithm: str) -> None: 9 | fs, a = util.load_wav(audio_path) 10 | alg = getattr(algorithms, algorithm)(a, fs) 11 | assert alg.f0.shape == alg.t.shape 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--audio-path', type=str, default=os.environ.get('PITCH_DETECTORS_AUDIO_PATH')) 17 | parser.add_argument('--algorithm', type=str, default=os.environ.get('PITCH_DETECTORS_ALGORITHM')) 18 | args = parser.parse_args() 19 | main(**vars(args)) 20 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/rapt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors import config 4 | from pitch_detectors.algorithms.base import PitchDetector 5 | 6 | 7 | class 
Rapt(PitchDetector): 8 | """https://pysptk.readthedocs.io/en/stable/generated/pysptk.sptk.rapt.html""" 9 | 10 | def __init__( 11 | self, 12 | a: np.ndarray, 13 | fs: int, 14 | hz_min: float = config.HZ_MIN, 15 | hz_max: float = config.HZ_MAX, 16 | ): 17 | import pysptk 18 | super().__init__(a, fs) 19 | f0 = pysptk.sptk.rapt(self.a, fs=self.fs, min=hz_min, max=hz_max, hopsize=250) 20 | f0[f0 == 0] = np.nan # 0 Hz frames are treated as unvoiced 21 | self.f0 = f0 22 | self.t = np.linspace(0, self.seconds, f0.shape[0]) # frame times spread evenly over the signal duration 23 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/swipe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors import config 4 | from pitch_detectors.algorithms.base import PitchDetector 5 | 6 | 7 | class Swipe(PitchDetector): 8 | """https://pysptk.readthedocs.io/en/stable/generated/pysptk.sptk.swipe.html""" 9 | 10 | def __init__( 11 | self, 12 | a: np.ndarray, 13 | fs: int, 14 | hz_min: float = config.HZ_MIN, 15 | hz_max: float = config.HZ_MAX, 16 | ): 17 | import pysptk 18 | super().__init__(a, fs) 19 | f0 = pysptk.sptk.swipe(self.a, fs=self.fs, min=hz_min, max=hz_max, hopsize=250) # same arguments and hopsize as Rapt 20 | f0[f0 == 0] = np.nan # 0 Hz frames are treated as unvoiced 21 | self.f0 = f0 22 | self.t = np.linspace(0, self.seconds, f0.shape[0]) 23 | -------------------------------------------------------------------------------- /tests/gpu_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | import pytest 5 | 6 | 7 | @pytest.mark.order(0) 8 | @pytest.mark.skipif(os.environ.get('PITCH_DETECTORS_GPU') == 'false', reason='gpu is not used') # skipped only when the variable is exactly 'false' 9 | def test_nvidia_smi(): 10 | subprocess.check_call('/usr/bin/nvidia-smi') 11 | 12 | 13 | @pytest.mark.order(1) 14 | @pytest.mark.skipif(os.environ.get('PITCH_DETECTORS_GPU') == 'false', reason='gpu is not used') 15 | def test_tensorflow(): 16 | import tensorflow as tf 17 | assert
tf.config.experimental.list_physical_devices('GPU') 18 | 19 | 20 | @pytest.mark.order(2) 21 | @pytest.mark.skipif(os.environ.get('PITCH_DETECTORS_GPU') == 'false', reason='gpu is not used') 22 | def test_pytorch(): 23 | import torch 24 | assert torch.cuda.is_available() 25 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/crepe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors.algorithms.base import PitchDetector 4 | from pitch_detectors.algorithms.base import TensorflowGPU 5 | 6 | 7 | class Crepe(TensorflowGPU, PitchDetector): 8 | """https://github.com/marl/crepe""" 9 | 10 | def __init__( 11 | self, 12 | a: np.ndarray, 13 | fs: int, 14 | confidence_threshold: float = 0.8, 15 | gpu: bool | None = None, 16 | ): 17 | TensorflowGPU.__init__(self, gpu) 18 | PitchDetector.__init__(self, a, fs) 19 | import crepe 20 | 21 | self.t, self.f0, self.confidence, self.activation = crepe.predict(self.a, sr=self.fs, viterbi=True) 22 | self.f0[self.confidence < confidence_threshold] = np.nan # low-confidence frames become unvoiced (NaN) 23 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/torchyin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors import config 4 | from pitch_detectors.algorithms.base import PitchDetector 5 | 6 | 7 | class TorchYin(PitchDetector): 8 | """https://github.com/brentspell/torch-yin""" 9 | 10 | uses_gpu_framework = True # NOTE(review): base.py is not shown here; confirm how this flag relates to gpu_capable, which algorithms/__init__.py filters on 11 | 12 | def __init__( 13 | self, 14 | a: np.ndarray, 15 | fs: int, 16 | hz_min: float = config.HZ_MIN, 17 | hz_max: float = config.HZ_MAX, 18 | ): 19 | import torch 20 | import torchyin 21 | super().__init__(a, fs) 22 | _a = torch.from_numpy(a) 23 | f0 = torchyin.estimate(_a, sample_rate=self.fs, pitch_min=hz_min, pitch_max=hz_max).numpy() 24 | f0[f0 == 0] = np.nan 25 | self.f0 = f0[:-1] # last frame dropped; t below drops its first entry so both keep f0.shape[0] - 1 entries 26 | 
self.t = np.linspace(0, self.seconds, f0.shape[0])[1:] # first timestamp dropped to pair with f0[:-1] 27 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/yin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors import config 4 | from pitch_detectors.algorithms.base import PitchDetector 5 | 6 | 7 | class Yin(PitchDetector): 8 | """https://librosa.org/doc/latest/generated/librosa.yin.html#librosa.yin""" 9 | 10 | def __init__( 11 | self, 12 | a: np.ndarray, 13 | fs: int, 14 | hz_min: float = config.HZ_MIN, 15 | hz_max: float = config.HZ_MAX, 16 | trough_threshold: float = 0.1, 17 | ): 18 | import librosa 19 | super().__init__(a, fs) 20 | f0 = librosa.yin( 21 | self.a, sr=self.fs, fmin=hz_min, fmax=hz_max, 22 | frame_length=2048, 23 | trough_threshold=trough_threshold, 24 | ) 25 | self.f0 = f0 # NOTE(review): unlike most detectors here, no frame is set to NaN — confirm that is intended 26 | self.t = np.linspace(0, self.seconds, f0.shape[0]) 27 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/pyin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors import config 4 | from pitch_detectors.algorithms.base import PitchDetector 5 | 6 | 7 | class Pyin(PitchDetector): 8 | """https://librosa.org/doc/latest/generated/librosa.pyin.html""" 9 | 10 | def __init__( 11 | self, 12 | a: np.ndarray, 13 | fs: int, 14 | hz_min: float = config.HZ_MIN, 15 | hz_max: float = config.HZ_MAX, 16 | ): 17 | import librosa 18 | super().__init__(a, fs) 19 | f0, voiced_flag, voiced_probs = librosa.pyin( # voiced_flag and voiced_probs are not used below 20 | self.a, sr=self.fs, fmin=hz_min, fmax=hz_max, 21 | resolution=0.1, # Resolution of the pitch bins. 0.01 corresponds to cents.
22 | frame_length=2048, 23 | ) 24 | self.f0 = f0 25 | self.t = np.linspace(0, self.seconds, f0.shape[0]) 26 | -------------------------------------------------------------------------------- /pitch_detectors/schemas.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pydantic import BaseModel 3 | from pydantic import ConfigDict 4 | from pydantic import model_validator 5 | from typing_extensions import Self 6 | 7 | 8 | class ArbitraryBaseModel(BaseModel): 9 | model_config = ConfigDict(arbitrary_types_allowed=True) # lets subclasses declare np.ndarray fields 10 | 11 | 12 | class Wav(ArbitraryBaseModel): 13 | fs: int 14 | a: np.ndarray 15 | 16 | 17 | class F0(ArbitraryBaseModel): 18 | t: np.ndarray 19 | f0: np.ndarray 20 | 21 | @model_validator(mode='after') 22 | def check_shape(self) -> Self: # runs after field validation; rejects mismatched t / f0 shapes 23 | if self.t.shape != self.f0.shape: 24 | raise ValueError('t and f0 must have the same shape') 25 | return self 26 | 27 | 28 | class Record(ArbitraryBaseModel): 29 | fs: int | None = None 30 | a: np.ndarray | None = None 31 | t: np.ndarray | None = None 32 | f0: np.ndarray | None = None 33 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/praatac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors import config 4 | from pitch_detectors.algorithms.base import PitchDetector 5 | 6 | 7 | class PraatAC(PitchDetector): 8 | """https://parselmouth.readthedocs.io/en/stable/api_reference.html#parselmouth.Sound.to_pitch_ac""" 9 | 10 | def __init__( 11 | self, 12 | a: np.ndarray, 13 | fs: int, 14 | hz_min: float = config.HZ_MIN, 15 | hz_max: float = config.HZ_MAX, 16 | ): 17 | import parselmouth 18 | super().__init__(a, fs) 19 | self.signal = parselmouth.Sound(self.a, sampling_frequency=self.fs) 20 | self.pitch_obj = self.signal.to_pitch_ac(pitch_floor=hz_min, pitch_ceiling=hz_max, very_accurate=True) 21 | self.f0 =
self.pitch_obj.selected_array['frequency'] 22 | self.f0[self.f0 == 0] = np.nan 23 | self.t = self.pitch_obj.xs() 24 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/praatcc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors import config 4 | from pitch_detectors.algorithms.base import PitchDetector 5 | 6 | 7 | class PraatCC(PitchDetector): 8 | """https://parselmouth.readthedocs.io/en/stable/api_reference.html#parselmouth.Sound.to_pitch_cc""" 9 | 10 | def __init__( 11 | self, 12 | a: np.ndarray, 13 | fs: int, 14 | hz_min: float = config.HZ_MIN, 15 | hz_max: float = config.HZ_MAX, 16 | ): 17 | import parselmouth 18 | super().__init__(a, fs) 19 | self.signal = parselmouth.Sound(self.a, sampling_frequency=self.fs) 20 | self.pitch_obj = self.signal.to_pitch_cc(pitch_floor=hz_min, pitch_ceiling=hz_max, very_accurate=True) 21 | self.f0 = self.pitch_obj.selected_array['frequency'] 22 | self.f0[self.f0 == 0] = np.nan 23 | self.t = self.pitch_obj.xs() 24 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/praatshs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors import config 4 | from pitch_detectors.algorithms.base import PitchDetector 5 | 6 | 7 | class PraatSHS(PitchDetector): 8 | """https://parselmouth.readthedocs.io/en/stable/api_reference.html#parselmouth.Sound.to_pitch_shs""" 9 | 10 | def __init__( 11 | self, 12 | a: np.ndarray, 13 | fs: int, 14 | hz_min: float = config.HZ_MIN, 15 | hz_max: float = config.HZ_MAX, 16 | ): 17 | import parselmouth 18 | super().__init__(a, fs) 19 | self.signal = parselmouth.Sound(self.a, sampling_frequency=self.fs) 20 | self.pitch_obj = self.signal.to_pitch_shs(minimum_pitch=hz_min, maximum_frequency_component=hz_max) 21 | self.f0 = 
self.pitch_obj.selected_array['frequency'] 22 | self.f0[self.f0 == 0] = np.nan 23 | self.t = self.pitch_obj.xs() 24 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/yaapt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors import config 4 | from pitch_detectors.algorithms.base import PitchDetector 5 | 6 | 7 | class Yaapt(PitchDetector): 8 | """http://bjbschmitt.github.io/AMFM_decompy/pYAAPT.html#amfm_decompy.pYAAPT.yaapt""" 9 | 10 | def __init__( 11 | self, 12 | a: np.ndarray, 13 | fs: int, 14 | hz_min: float = config.HZ_MIN, 15 | hz_max: float = config.HZ_MAX, 16 | ): 17 | import amfm_decompy.basic_tools as basic 18 | from amfm_decompy import pYAAPT 19 | super().__init__(a, fs) 20 | self.signal = basic.SignalObj(data=self.a, fs=self.fs) 21 | f0 = pYAAPT.yaapt(self.signal, f0_min=hz_min, f0_max=hz_max, frame_length=15) 22 | f0 = f0.samp_values 23 | f0[f0 == 0] = np.nan 24 | self.f0 = f0 25 | self.t = np.linspace(0, self.seconds, f0.shape[0]) 26 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | 2 | ARG BASE_IMAGE_CUDA=scratch 3 | ARG PYTHON_VERSION=3.12 4 | ARG UV_IMAGE=ghcr.io/astral-sh/uv:latest 5 | FROM python:${PYTHON_VERSION}-slim AS python 6 | FROM ${UV_IMAGE} AS uv 7 | FROM ${BASE_IMAGE_CUDA} 8 | # disable banner from nvidia/cuda entrypoint 9 | ENTRYPOINT [] 10 | 11 | # Copy Python binaries and libraries from the official Python image 12 | COPY --from=python /usr/local /usr/local 13 | ENV PATH="/usr/local/bin:$PATH" 14 | 15 | COPY --from=uv /uv /uvx /bin/ 16 | 17 | WORKDIR /app 18 | COPY pitch_detectors /app/pitch_detectors 19 | 20 | ENV UV_LINK_MODE=copy \ 21 | UV_FROZEN=1 22 | 23 | RUN --mount=type=cache,target=/root/.cache/uv \ 24 | --mount=type=bind,source=uv.lock,target=uv.lock \ 
25 | --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ 26 | uv sync 27 | 28 | COPY tests /app/tests 29 | COPY scripts/ /app/scripts 30 | COPY data /app/data 31 | 32 | ENV PATH="/app/.venv/bin:$PATH" 33 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/reaper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors import config 4 | from pitch_detectors.algorithms.base import PitchDetector 5 | 6 | 7 | class Reaper(PitchDetector): 8 | """https://github.com/r9y9/pyreaper""" 9 | 10 | def __init__( 11 | self, 12 | a: np.ndarray, 13 | fs: int, 14 | hz_min: float = config.HZ_MIN, 15 | hz_max: float = config.HZ_MAX, 16 | ): 17 | import pyreaper 18 | from dsplib.scale import minmax_scaler_array 19 | int16_info = np.iinfo(np.int16) 20 | a = minmax_scaler_array(a, np.min(a), np.max(a), int16_info.min, int16_info.max).round().astype(np.int16) # rescale to the full int16 range and cast; NOTE(review): a constant signal has a zero min-max range here 21 | super().__init__(a, fs) # the base class is given the rescaled int16 copy 22 | pm_times, pm, f0_times, f0, corr = pyreaper.reaper(self.a, fs=self.fs, minf0=hz_min, maxf0=hz_max, frame_period=0.01) 23 | f0[f0 == -1] = np.nan # -1 frames are treated as unvoiced 24 | self.f0 = f0 25 | self.t = f0_times # pyreaper returns its own frame times, so no linspace is needed 26 | -------------------------------------------------------------------------------- /pitch_detectors/evaluation/datasets/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections.abc import Iterator 3 | from pathlib import Path 4 | 5 | from pitch_detectors.schemas import F0 6 | from pitch_detectors.schemas import Wav 7 | 8 | 9 | class Dataset(abc.ABC): # interface implemented by the evaluation datasets (MDBStemSynth, Mir1K) 10 | 11 | @classmethod 12 | def name(cls) -> str: 13 | return cls.__name__ 14 | 15 | @classmethod 16 | @abc.abstractmethod 17 | def dataset_dir(cls) -> Path: 18 | ... 19 | 20 | @classmethod 21 | @abc.abstractmethod 22 | def wav_dir(cls) -> Path: 23 | ...
24 | 25 | @classmethod 26 | def iter_wav_files(cls) -> Iterator[Path]: 27 | yield from cls.wav_dir().glob('*.wav') 28 | 29 | @classmethod 30 | @abc.abstractmethod 31 | def load_wav(cls, wav_path: Path) -> Wav: 32 | ... 33 | 34 | @classmethod 35 | @abc.abstractmethod 36 | def load_true(cls, wav_path: Path, seconds: float) -> F0: 37 | ... 38 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/piptrack.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors import config 4 | from pitch_detectors.algorithms.base import PitchDetector 5 | 6 | 7 | class PipTrack(PitchDetector): 8 | """https://librosa.org/doc/latest/generated/librosa.piptrack.html""" 9 | 10 | def __init__( 11 | self, 12 | a: np.ndarray, 13 | fs: int, 14 | hz_min: float = config.HZ_MIN, 15 | hz_max: float = config.HZ_MAX, 16 | threshold: float = 0.1, 17 | ): 18 | import librosa 19 | super().__init__(a, fs) 20 | pitches, magnitudes = librosa.piptrack( 21 | y=a, 22 | sr=fs, 23 | fmin=hz_min, 24 | fmax=hz_max, 25 | threshold=threshold, 26 | ) 27 | max_indexes = np.argmax(magnitudes, axis=0) 28 | f0 = pitches[max_indexes, range(magnitudes.shape[1])] 29 | f0[f0 == 0] = np.nan 30 | self.f0 = f0 31 | self.t = np.linspace(0, self.seconds, self.f0.shape[0]) 32 | -------------------------------------------------------------------------------- /pitch_detectors/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import math 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from dsplib.scale import minmax_scaler_array 7 | from scipy.io import wavfile 8 | 9 | 10 | def nan_to_none(x: list[float]) -> list[float | None]: 11 | return [None if math.isnan(v) else v for v in x] 12 | 13 | 14 | def none_to_nan(x: list[float | None]) -> list[float]: 15 | return [float('nan') if v is None else v for v in x] 16 | 17 | 18 | 
def load_wav(path: Path | str, rescale: float = 100000) -> tuple[int, np.ndarray]: 19 | fs, a = wavfile.read(path) 20 | a = minmax_scaler_array(a, a.min(), a.max(), -rescale, rescale).astype(np.float32) 21 | return fs, a 22 | 23 | 24 | def source_hashes() -> dict[str, str]: 25 | alg_dir = Path(__file__).parent / 'algorithms' 26 | base = alg_dir / 'base.py' 27 | base_bytes = base.read_bytes() 28 | hashes = {} 29 | for p in alg_dir.glob('*.py'): 30 | if p.name in {'__init__.py', 'base.py'}: 31 | continue 32 | h = hashlib.sha256() 33 | h.update(base_bytes) 34 | h.update(p.read_bytes()) 35 | hashes[p.stem] = h.hexdigest() 36 | return hashes 37 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: push 4 | 5 | jobs: 6 | # test: 7 | # runs-on: [self-hosted, gpu] 8 | # steps: 9 | # - uses: actions/checkout@v4 10 | 11 | # - name: test gpu is available 12 | # run: nvidia-smi 13 | 14 | # - name: build image 15 | # run: make build 16 | 17 | # - name: test-no-docker 18 | # run: make test-no-docker 19 | 20 | # - name: test 21 | # run: make test 22 | 23 | # - name: test-no-gpu 24 | # run: make test-no-gpu 25 | 26 | publish-to-pypi-and-github-release: 27 | if: "startsWith(github.ref, 'refs/tags/')" 28 | runs-on: ubuntu-latest 29 | # needs: test 30 | steps: 31 | - uses: actions/checkout@v4 32 | - name: Set up Python 33 | uses: actions/setup-python@v5 34 | with: 35 | python-version: '3.12' 36 | 37 | - name: Install pypa/build 38 | run: python -m pip install --upgrade setuptools build twine 39 | 40 | - name: Build a source tarball and wheel 41 | run: python -m build . 
42 | 43 | - name: Publish package to PyPI 44 | env: 45 | TWINE_USERNAME: __token__ 46 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 47 | run: python -m twine upload dist/* 48 | 49 | # doesn't work in private repos 50 | # - name: Github Release 51 | # uses: softprops/action-gh-release@v1 52 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | from pitch_detectors.algorithms.crepe import Crepe 2 | from pitch_detectors.algorithms.ensemble import Ensemble # noqa: F401 3 | from pitch_detectors.algorithms.penn import Penn 4 | from pitch_detectors.algorithms.piptrack import PipTrack 5 | from pitch_detectors.algorithms.praatac import PraatAC 6 | from pitch_detectors.algorithms.praatcc import PraatCC 7 | from pitch_detectors.algorithms.praatshs import PraatSHS 8 | from pitch_detectors.algorithms.pyin import Pyin 9 | from pitch_detectors.algorithms.rapt import Rapt 10 | from pitch_detectors.algorithms.reaper import Reaper 11 | from pitch_detectors.algorithms.spice import Spice 12 | from pitch_detectors.algorithms.swipe import Swipe 13 | from pitch_detectors.algorithms.torchcrepe import TorchCrepe 14 | from pitch_detectors.algorithms.torchyin import TorchYin 15 | from pitch_detectors.algorithms.world import World 16 | from pitch_detectors.algorithms.yaapt import Yaapt 17 | from pitch_detectors.algorithms.yin import Yin 18 | 19 | ALGORITHMS = ( # detector classes; Ensemble is imported above but not listed here 20 | PraatAC, 21 | PraatCC, 22 | PraatSHS, 23 | Pyin, 24 | Yin, 25 | Reaper, 26 | Yaapt, 27 | Crepe, 28 | TorchCrepe, 29 | Swipe, 30 | Rapt, 31 | World, 32 | TorchYin, 33 | Spice, 34 | Penn, 35 | PipTrack, 36 | ) 37 | 38 | cpu_algorithms = tuple(a.name() for a in ALGORITHMS if not a.gpu_capable) # type: ignore 39 | gpu_algorithms = tuple(a.name() for a in ALGORITHMS if a.gpu_capable) # type: ignore 40 | algorithms = cpu_algorithms + gpu_algorithms # detector names, non-gpu_capable ones first 41 |
-------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | stages: 2 | - test 3 | 4 | test: 5 | stage: test 6 | needs: [] 7 | tags: 8 | - u60 9 | - docker 10 | - gpu 11 | image: tandav/pitch-detectors:12.6.2-cudnn-devel-ubuntu24.04-python3.12@sha256:665b04e9a86d95b1b57175e582d66f0b6cdcda2accd7593845697a4021bd4c46 12 | cache: 13 | key: $CI_PROJECT_NAME 14 | paths: 15 | - .cache/ 16 | variables: 17 | PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip" 18 | PRE_COMMIT_HOME: "$CI_PROJECT_DIR/.cache/pre-commit" 19 | RUFF_CACHE_DIR: "$CI_PROJECT_DIR/.cache/ruff_cache" 20 | MYPY_CACHE_DIR: "$CI_PROJECT_DIR/.cache/mypy_cache" 21 | PITCH_DETECTORS_SPICE_MODEL_PATH: /models/spice_model 22 | PITCH_DETECTORS_PENN_CHECKPOINT_PATH: /models/fcnf0++.pt 23 | script: 24 | - export $(grep -v '^#' $S3_ENV | xargs) && python scripts/download_models.py 25 | - pytest --cov pitch_detectors --cov-report term --cov-report xml --junitxml report.xml 26 | coverage: '/(?i)total.*?
(100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/' 27 | artifacts: 28 | when: always 29 | expire_in: 1 week 30 | reports: 31 | coverage_report: 32 | coverage_format: cobertura 33 | path: coverage.xml 34 | junit: report.xml 35 | 36 | lint: 37 | stage: test 38 | tags: 39 | - u61 40 | - docker 41 | needs: [] 42 | image: python:3.12@sha256:fce9bc7648ef917a5ab67176cf1c7eb41b110452e259736144bc22f32f3aa622 43 | variables: 44 | PIP_INDEX_URL: https://pypi.tandav.me/index/ 45 | script: 46 | - pip install pre-commit 47 | - pre-commit run --all-files 48 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/torchcrepe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pitch_detectors import config 4 | from pitch_detectors.algorithms.base import PitchDetector 5 | from pitch_detectors.algorithms.base import TorchGPU 6 | 7 | 8 | class TorchCrepe(TorchGPU, PitchDetector): 9 | """https://github.com/maxrmorrison/torchcrepe""" 10 | 11 | def __init__( 12 | self, 13 | a: np.ndarray, 14 | fs: int, 15 | hz_min: float = config.HZ_MIN, 16 | hz_max: float = config.HZ_MAX, 17 | confidence_threshold: float = 0.8, 18 | batch_size: int = 2048, 19 | gpu: bool | None = None, 20 | ): 21 | import torch 22 | import torchcrepe 23 | 24 | TorchGPU.__init__(self, gpu) 25 | PitchDetector.__init__(self, a, fs) 26 | 27 | f0, confidence = torchcrepe.predict( 28 | torch.from_numpy(a[np.newaxis, ...]), 29 | fs, 30 | hop_length=int(fs / 100), # 10 ms 31 | fmin=hz_min, 32 | fmax=hz_max, 33 | batch_size=batch_size, 34 | device='cuda:0' if self.gpu else 'cpu', 35 | return_periodicity=True, 36 | ) 37 | win_length = 3 38 | f0 = torchcrepe.filter.mean(f0, win_length) 39 | confidence = torchcrepe.filter.median(confidence, win_length) 40 | 41 | f0 = f0.ravel().numpy() 42 | confidence = confidence.ravel().numpy() 43 | f0[confidence < confidence_threshold] = np.nan 44 | self.f0 = f0 45 | self.t = 
np.linspace(0, self.seconds, f0.shape[0]) 46 | torch.cuda.empty_cache() 47 | -------------------------------------------------------------------------------- /pitch_detectors/evaluation/datasets/mir_1k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | from musiclib.pitch import Pitch 6 | from scipy.io import wavfile 7 | 8 | from pitch_detectors.evaluation.datasets.base import Dataset 9 | from pitch_detectors.schemas import F0 10 | from pitch_detectors.schemas import Wav 11 | 12 | 13 | class Mir1K(Dataset): 14 | """https://www.kaggle.com/datasets/datongmuyuyi/mir1k""" 15 | 16 | @classmethod 17 | def dataset_dir(cls) -> Path: 18 | return Path(os.environ.get('DATASET_DIR_MIR_1K', 'f0-datasets/mir-1k/MIR-1K')) 19 | 20 | @classmethod 21 | def wav_dir(cls) -> Path: 22 | return cls.dataset_dir() / 'Wavfile' 23 | 24 | @classmethod 25 | def load_wav(cls, wav_path: Path) -> Wav: 26 | fs, a = wavfile.read(wav_path) 27 | a = a[:, 1].astype(np.float32) 28 | return Wav(fs=fs, a=a) 29 | 30 | @classmethod 31 | def load_true(cls, wav_path: Path, seconds: float) -> F0: 32 | p = Pitch() 33 | pitch_label_dir = wav_path.parent.parent / 'PitchLabel' 34 | f0_path = (pitch_label_dir / wav_path.stem).with_suffix('.pv') 35 | f0 = [] 36 | with open(f0_path) as f: 37 | for _line in f: 38 | line = _line.strip() 39 | if line == '0': 40 | f0.append(float('nan')) 41 | else: 42 | f0.append(p.note_i_to_hz(float(line))) 43 | f0 = np.array(f0) 44 | # t = np.arange(0.02, seconds - 0.02, 0.02) 45 | # assert t.shape == f0.shape 46 | t = np.linspace(0.02, seconds, len(f0)) 47 | return F0(t=t, f0=f0) 48 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: 
check-added-large-files 6 | - id: check-yaml 7 | - id: check-json 8 | - id: check-ast 9 | - id: check-byte-order-marker 10 | - id: check-builtin-literals 11 | - id: check-case-conflict 12 | - id: check-docstring-first 13 | - id: debug-statements 14 | - id: end-of-file-fixer 15 | - id: mixed-line-ending 16 | - id: trailing-whitespace 17 | - id: check-merge-conflict 18 | - id: detect-private-key 19 | - id: double-quote-string-fixer 20 | - id: name-tests-test 21 | 22 | - repo: https://github.com/asottile/add-trailing-comma 23 | rev: v3.1.0 24 | hooks: 25 | - id: add-trailing-comma 26 | 27 | - repo: https://github.com/asottile/pyupgrade 28 | rev: v3.16.0 29 | hooks: 30 | - id: pyupgrade 31 | 32 | - repo: https://github.com/hhatto/autopep8 33 | rev: v2.3.1 34 | hooks: 35 | - id: autopep8 36 | 37 | - repo: https://github.com/PyCQA/autoflake 38 | rev: v2.3.1 39 | hooks: 40 | - id: autoflake 41 | 42 | - repo: https://github.com/astral-sh/ruff-pre-commit 43 | rev: v0.5.0 44 | hooks: 45 | - id: ruff 46 | args: [--fix, --exit-non-zero-on-fix] 47 | 48 | - repo: https://github.com/PyCQA/pylint 49 | rev: v3.2.4 50 | hooks: 51 | - id: pylint 52 | additional_dependencies: ["pylint-per-file-ignores"] 53 | 54 | - repo: https://github.com/pre-commit/mirrors-mypy 55 | rev: v1.10.1 56 | hooks: 57 | - id: mypy 58 | additional_dependencies: [types-redis, types-tabulate, pydantic] 59 | -------------------------------------------------------------------------------- /pitch_detectors/evaluation/datasets/mdb_stem_synth.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from collections.abc import Iterator 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | from scipy.io import wavfile 8 | 9 | from pitch_detectors.evaluation.datasets.base import Dataset 10 | from pitch_detectors.schemas import F0 11 | from pitch_detectors.schemas import Wav 12 | 13 | 14 | class MDBStemSynth(Dataset): 15 |
"""https://zenodo.org/record/1481172""" 16 | 17 | @classmethod 18 | def dataset_dir(cls) -> Path: 19 | return Path(os.environ.get('DATASET_DIR_MDB_STEM_SYNTH', 'f0-datasets/mdb-stem-synth/MDB-stem-synth')) 20 | 21 | @classmethod 22 | def wav_dir(cls) -> Path: 23 | return cls.dataset_dir() / 'audio_stems' 24 | 25 | @classmethod 26 | def iter_wav_files(cls) -> Iterator[Path]: 27 | it = cls.wav_dir().glob('*.wav') 28 | it = (f for f in it if not f.name.startswith('.')) # skip hidden files (names starting with '.') 29 | yield from it 30 | 31 | @classmethod 32 | def load_wav(cls, wav_path: Path) -> Wav: 33 | fs, a = wavfile.read(wav_path) 34 | a = a.astype(np.float32) 35 | return Wav(fs=fs, a=a) 36 | 37 | @classmethod 38 | def load_true(cls, wav_path: Path, seconds: float) -> F0: # seconds is not used here; the CSV provides its own timestamps 39 | file = (cls.dataset_dir() / 'annotation_stems' / wav_path.name).with_suffix('.csv') 40 | t_list, f0_list = [], [] 41 | with open(file, newline='') as csvfile: 42 | reader = csv.reader(csvfile, quoting=csv.QUOTE_NONNUMERIC) # unquoted fields are parsed as float 43 | for _t, _f0 in reader: 44 | t_list.append(_t) 45 | f0_list.append(_f0) 46 | t = np.array(t_list) 47 | f0 = np.array(f0_list) 48 | f0[f0 == 0] = np.nan # 0 Hz annotations are treated as unvoiced 49 | return F0(t=t, f0=f0) 50 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/penn.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | 6 | from pitch_detectors import config 7 | from pitch_detectors.algorithms.base import PitchDetector 8 | from pitch_detectors.algorithms.base import TorchGPU 9 | 10 | 11 | class Penn(TorchGPU, PitchDetector): 12 | """https://github.com/interactiveaudiolab/penn""" 13 | 14 | def __init__( 15 | self, 16 | a: np.ndarray, 17 | fs: int, 18 | hz_min: float = config.HZ_MIN, 19 | hz_max: float = config.HZ_MAX, 20 | periodicity_threshold: float = 0.1, 21 | checkpoint: str | None = None, 22 | gpu: bool | None = None, 23 | ): 24 | import torch 25 | from penn.core import
from_audio 26 | 27 | TorchGPU.__init__(self, gpu) 28 | PitchDetector.__init__(self, a, fs) 29 | 30 | if checkpoint is None: 31 | checkpoint = os.environ.get('PITCH_DETECTORS_PENN_CHECKPOINT_PATH', '/fcnf0++.pt') # env var first, then a fixed default path 32 | checkpoint_path = Path(checkpoint) 33 | 34 | f0, periodicity = from_audio( 35 | audio=torch.tensor(a.reshape(1, -1)), 36 | sample_rate=fs, 37 | fmin=hz_min, 38 | fmax=hz_max, 39 | checkpoint=checkpoint_path, 40 | gpu=0 if self.gpu else None, 41 | ) 42 | if self.gpu: 43 | f0 = f0.cpu() 44 | periodicity = periodicity.cpu() 45 | periodicity = periodicity.numpy().ravel() 46 | f0 = f0.numpy().ravel() 47 | f0[periodicity < periodicity_threshold] = np.nan # low-periodicity frames become unvoiced (NaN) 48 | self.f0 = f0 49 | self.periodicity = periodicity 50 | self.t = np.linspace(0, self.seconds, self.f0.shape[0]) 51 | torch.cuda.empty_cache() 52 | -------------------------------------------------------------------------------- /Taskfile.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | dotenv: ['.env'] 3 | env: 4 | UBUNTU_VERSION: 24.04 5 | CUDA_VERSION: 12.6.2 6 | PYTHON_VERSION: 3.12 7 | UV_IMAGE: ghcr.io/astral-sh/uv:0.5.4@sha256:5436c72d52c9c0d011010ce68f4c399702b3b0764adcf282fe0e546f20ebaef6 8 | BASE_IMAGE_CUDA_SHA256: 431b2307f69f41ca51503a6103be3b5c52dcfad18b201af7f12349a0cca35a4e 9 | BASE_IMAGE_CUDA: nvidia/cuda:{{.CUDA_VERSION}}-cudnn-devel-ubuntu{{.UBUNTU_VERSION}}@sha256:{{.BASE_IMAGE_CUDA_SHA256}} 10 | IMAGE: tandav/pitch-detectors:{{.CUDA_VERSION}}-cudnn-devel-ubuntu{{.UBUNTU_VERSION}}-python{{.PYTHON_VERSION}} 11 | MODEL_PATH_SPICE: /mnt/sg8tb1/downloads-archive/libmv-data/spice_model 12 | MODEL_PATH_PENN: /mnt/sg8tb1/downloads-archive/libmv-data/fcnf0++.pt 13 | tasks: 14 | build: 15 | cmd: > 16 | docker build 17 | --build-arg BASE_IMAGE_CUDA=$BASE_IMAGE_CUDA 18 | --build-arg PYTHON_VERSION=$PYTHON_VERSION 19 | --build-arg UV_IMAGE=$UV_IMAGE 20 | --tag $IMAGE 21 | .
22 | 23 | push: 24 | cmd: docker push $IMAGE 25 | 26 | test: 27 | deps: [build] 28 | cmd: > 29 | docker run --rm -t --gpus all 30 | -v $MODEL_PATH_SPICE:/spice_model:ro 31 | -v $MODEL_PATH_PENN:/fcnf0++.pt:ro 32 | $IMAGE 33 | uv run 34 | pytest -v 35 | 36 | test-no-docker: 37 | cmd: uv run pytest -v 38 | 39 | bumpver: 40 | desc: 'Bump version. Pass --. Usage example: task bumpver -- --minor' 41 | cmds: 42 | - uv run bumpver update --no-fetch {{.CLI_ARGS}} 43 | 44 | evaluation: 45 | deps: [build] 46 | cmd: > 47 | docker run --rm -t --gpus all 48 | -e PITCH_DETECTORS_GPU=true 49 | -e REDIS_URL={{.REDIS_URL}} 50 | -v /media/tandav/sg8tb1/downloads-archive/f0-datasets:/app/f0-datasets:ro 51 | {{.IMAGE}} 52 | python -m pitch_detectors.evaluation 53 | 54 | table: 55 | cmd: uv run python -m pitch_detectors.evaluation.table 56 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/spice.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from pitch_detectors.algorithms.base import PitchDetector 6 | from pitch_detectors.algorithms.base import TensorflowGPU 7 | 8 | 9 | class Spice(TensorflowGPU, PitchDetector): 10 | """ 11 | https://ai.googleblog.com/2019/11/spice-self-supervised-pitch-estimation.html 12 | https://blog.tensorflow.org/2020/06/estimating-pitch-with-spice-and-tensorflow-hub.html 13 | https://github.com/tensorflow/docs/blob/master/site/en/hub/tutorials/spice.ipynb 14 | https://www.kaggle.com/models/google/spice 15 | https://www.kaggle.com/models/google/spice/tensorFlow1/spice/2 16 | """ 17 | 18 | def __init__( 19 | self, 20 | a: np.ndarray, 21 | fs: int, 22 | confidence_threshold: float = 0.8, 23 | expected_sample_rate: int = 16000, 24 | spice_model_path: str | None = None, 25 | gpu: bool | None = None, 26 | ): 27 | 28 | import resampy 29 | 30 | a = resampy.resample(a, fs, expected_sample_rate) 31 | 
TensorflowGPU.__init__(self, gpu) 32 | PitchDetector.__init__(self, a, fs) 33 | 34 | import tensorflow_hub as hub 35 | 36 | if spice_model_path is None: 37 | spice_model_path = os.environ.get('PITCH_DETECTORS_SPICE_MODEL_PATH', '/spice_model') 38 | 39 | model = hub.load(spice_model_path) 40 | model_output = model.signatures['serving_default'](self.tf.constant(a, self.tf.float32)) 41 | confidence = 1.0 - model_output['uncertainty'] 42 | self.f0 = self.output2hz(model_output['pitch'].numpy()) 43 | self.f0[confidence < confidence_threshold] = np.nan 44 | self.t = np.linspace(0, self.seconds, self.f0.shape[0]) 45 | 46 | @staticmethod 47 | def output2hz( 48 | pitch_output: np.ndarray, 49 | pt_offset: float = 25.58, 50 | pt_slope: float = 63.07, 51 | fmin: float = 10.0, 52 | bins_per_octave: float = 12.0, 53 | ) -> np.ndarray: 54 | """convert pitch from the model output [0.0, 1.0] range to absolute values in Hz.""" 55 | cqt_bin = pitch_output * pt_slope + pt_offset 56 | return fmin * 2.0 ** (1.0 * cqt_bin / bins_per_octave) 57 | -------------------------------------------------------------------------------- /tests/algorithms_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | 5 | import pytest 6 | 7 | from pitch_detectors import algorithms 8 | 9 | 10 | @pytest.mark.order(3) 11 | @pytest.mark.parametrize('gpu', ['true', 'false'], ids=['gpu', 'cpu']) 12 | @pytest.mark.parametrize('algorithm', (*algorithms.algorithms, 'Ensemble')) 13 | def test_detection(algorithm, gpu): 14 | env = { 15 | 'PITCH_DETECTORS_GPU_MEMORY_LIMIT': 'true', 16 | 'PITCH_DETECTORS_AUDIO_PATH': 'data/b1a5da49d564a7341e7e1327aa3f229a.wav', 17 | 'PATH': '', # for some reason this line prevents SIGSEGV for Spice algorithm 18 | # 'PATH': '/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin', # this is from docker history of base cuda image 
https://hub.docker.com/layers/nvidia/cuda/12.4.1-cudnn-devel-ubuntu22.04/images/sha256-0a1cb6e7bd047a1067efe14efdf0276352d5ca643dfd77963dab1a4f05a003a4?context=explore 19 | 'PITCH_DETECTORS_ALGORITHM': algorithm, 20 | 'PITCH_DETECTORS_GPU': gpu, 21 | } 22 | if 'PITCH_DETECTORS_PENN_CHECKPOINT_PATH' in os.environ: 23 | env['PITCH_DETECTORS_PENN_CHECKPOINT_PATH'] = os.environ['PITCH_DETECTORS_PENN_CHECKPOINT_PATH'] 24 | if 'PITCH_DETECTORS_SPICE_MODEL_PATH' in os.environ: 25 | env['PITCH_DETECTORS_SPICE_MODEL_PATH'] = os.environ['PITCH_DETECTORS_SPICE_MODEL_PATH'] 26 | subprocess.check_call([sys.executable, 'scripts/run_algorithm.py'], env=env) 27 | 28 | 29 | def test_uses_gpu_framework(): 30 | assert algorithms.Crepe.uses_gpu_framework is True 31 | assert algorithms.Ensemble .uses_gpu_framework is True 32 | assert algorithms.Penn.uses_gpu_framework is True 33 | assert algorithms.PipTrack.uses_gpu_framework is False 34 | assert algorithms.PraatAC.uses_gpu_framework is False 35 | assert algorithms.PraatCC.uses_gpu_framework is False 36 | assert algorithms.PraatSHS.uses_gpu_framework is False 37 | assert algorithms.Pyin.uses_gpu_framework is False 38 | assert algorithms.Rapt.uses_gpu_framework is False 39 | assert algorithms.Reaper.uses_gpu_framework is False 40 | assert algorithms.Spice.uses_gpu_framework is True 41 | assert algorithms.Swipe.uses_gpu_framework is False 42 | assert algorithms.TorchCrepe.uses_gpu_framework is True 43 | assert algorithms.TorchYin.uses_gpu_framework is True 44 | assert algorithms.World.uses_gpu_framework is False 45 | assert algorithms.Yaapt.uses_gpu_framework is False 46 | assert algorithms.Yin.uses_gpu_framework is False 47 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from pitch_detectors import util 6 | 7 | 8 | class PitchDetector: 9 
| gpu_capable = False 10 | uses_gpu_framework = False 11 | 12 | def __init__(self, a: np.ndarray, fs: int): 13 | self.a = a 14 | self.fs = fs 15 | self.seconds = len(a) / fs 16 | self.f0: np.ndarray 17 | self.t: np.ndarray 18 | 19 | def dict(self) -> dict[str, list[float | None]]: 20 | return {'f0': util.nan_to_none(self.f0.tolist()), 't': self.t.tolist()} 21 | 22 | @classmethod 23 | def name(cls) -> str: 24 | return cls.__name__ 25 | 26 | 27 | class UsesGPU: 28 | gpu_capable = True 29 | uses_gpu_framework = True 30 | memory_limit_initialized = False 31 | 32 | def __init__(self, gpu: bool | None = None) -> None: 33 | self.gpu = gpu or os.environ.get('PITCH_DETECTORS_GPU') == 'true' 34 | if self.gpu and not self.gpu_available(): 35 | raise ConnectionError('gpu must be available') 36 | if not self.gpu: 37 | self.disable_gpu() 38 | if self.gpu_available(): 39 | raise ConnectionError('gpu must not be available') 40 | 41 | def gpu_available(self) -> bool: 42 | return False 43 | 44 | def disable_gpu(self) -> None: 45 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 46 | 47 | 48 | class TensorflowGPU(UsesGPU): 49 | 50 | def __init__(self, gpu: bool | None = None) -> None: 51 | 52 | import tensorflow as tf 53 | self.tf = tf 54 | super().__init__(gpu) 55 | if self.gpu_available() and os.environ.get('PITCH_DETECTORS_GPU_MEMORY_LIMIT') == 'true': 56 | self.set_memory_limit() 57 | 58 | def set_memory_limit(self) -> None: 59 | if TensorflowGPU.memory_limit_initialized: 60 | return 61 | gpus = self.gpus 62 | for gpu in gpus: 63 | self.tf.config.experimental.set_memory_growth(gpu, True) 64 | TensorflowGPU.memory_limit_initialized = True 65 | 66 | @property 67 | def gpus(self) -> list[str]: 68 | return self.tf.config.experimental.list_physical_devices('GPU') # type: ignore 69 | 70 | def gpu_available(self) -> bool: 71 | return bool(self.gpus) 72 | 73 | 74 | class TorchGPU(UsesGPU): 75 | 76 | def __init__(self, gpu: bool | None = None) -> None: 77 | import torch 78 | self.torch = torch 
79 | super().__init__(gpu) 80 | if self.gpu_available() and os.environ.get('PITCH_DETECTORS_GPU_MEMORY_LIMIT') == 'true': 81 | self.set_memory_limit() 82 | 83 | def set_memory_limit(self) -> None: 84 | if TorchGPU.memory_limit_initialized: 85 | return 86 | self.torch.cuda.set_per_process_memory_fraction(fraction=1 / 8, device=0) 87 | TorchGPU.memory_limit_initialized = True 88 | 89 | def gpu_available(self) -> bool: 90 | return self.torch.cuda.is_available() # type: ignore 91 | -------------------------------------------------------------------------------- /pitch_detectors/algorithms/ensemble.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from typing import TypeAlias 3 | 4 | import numpy as np 5 | 6 | from pitch_detectors.algorithms.base import PitchDetector 7 | from pitch_detectors.algorithms.base import TensorflowGPU 8 | from pitch_detectors.algorithms.base import TorchGPU 9 | from pitch_detectors.schemas import F0 10 | 11 | PDT: TypeAlias = type[PitchDetector] 12 | 13 | 14 | def vote_and_median( 15 | algorithms: dict[str, F0], 16 | seconds: float, 17 | pitch_fs: int = 1024, 18 | min_duration: float = 1, 19 | min_algorithms: int = 3, 20 | # algorithm_weights: dict[str, float] = {}, 21 | ) -> F0: 22 | if len(algorithms) < min_algorithms: 23 | raise ValueError(f'at least {min_algorithms} algorithms must be provided, because min_algorithms={min_algorithms}') 24 | 25 | single_n = int(seconds * pitch_fs) 26 | t_resampled = np.linspace(0, seconds, single_n) 27 | f0_resampled = {} 28 | F0_arr = np.empty((len(algorithms), single_n)) 29 | for i, (name, data) in enumerate(algorithms.items()): 30 | t = data.t 31 | f0 = data.f0 32 | if len(f0) == 0: 33 | raise ValueError(f'algorithm {name} returned an empty f0 array') 34 | f0_resampled[name] = np.full_like(t_resampled, fill_value=np.nan) 35 | notna_slices = np.ma.clump_unmasked(np.ma.masked_invalid(f0)) 36 | 37 | for sl in notna_slices: 38 | t_slice = 
t[sl] 39 | f0_slice = f0[sl] 40 | t_start, t_stop = t_slice[0], t_slice[-1] 41 | duration = t_stop - t_start 42 | if duration < min_duration: 43 | continue 44 | mask = (t_start < t_resampled) & (t_resampled < t_stop) 45 | t_interp = t_resampled[mask] 46 | f0_interp = np.interp(t_interp, t_slice, f0_slice) 47 | f0_resampled[name][mask] = f0_interp 48 | F0_arr[i] = f0_resampled[name] 49 | 50 | F0_mask = np.isfinite(F0_arr).astype(int) 51 | F0_mask_sum = F0_mask.sum(axis=0) 52 | min_alg_mask = F0_mask_sum > min_algorithms 53 | f0_mean = np.full_like(t_resampled, fill_value=np.nan) 54 | f0_mean[min_alg_mask] = np.nanmedian(F0_arr[:, min_alg_mask], axis=0) 55 | return F0(t=t_resampled, f0=f0_mean) 56 | 57 | 58 | class Ensemble(TensorflowGPU, TorchGPU, PitchDetector): 59 | """https://github.com/tandav/pitch-detectors/blob/master/pitch_detectors/algorithms/ensemble.py""" 60 | 61 | def __init__( 62 | self, 63 | a: np.ndarray, 64 | fs: int, 65 | algorithms: tuple[PDT, ...] | None = None, 66 | algorithms_kwargs: dict[PDT, dict[str, Any]] | None = None, 67 | gpu: bool | None = None, 68 | vote_and_median_kwargs: dict[str, Any] | None = None, 69 | ): 70 | TensorflowGPU.__init__(self, gpu) 71 | TorchGPU.__init__(self, gpu) 72 | PitchDetector.__init__(self, a, fs) 73 | 74 | if algorithms is None: 75 | from pitch_detectors.algorithms import ALGORITHMS as algorithms_ 76 | else: 77 | algorithms_ = algorithms 78 | 79 | self._algorithms = {} 80 | algorithms_kwargs = algorithms_kwargs or {} 81 | 82 | for cls in algorithms_: 83 | self._algorithms[cls] = cls(a, fs, **algorithms_kwargs.get(cls, {})) 84 | 85 | f0 = vote_and_median( 86 | {k.name(): F0(t=v.t, f0=v.f0) for k, v in self._algorithms.items()}, 87 | self.seconds, 88 | **(vote_and_median_kwargs or {}), 89 | ) 90 | self.t = f0.t 91 | self.f0 = f0.f0 92 | -------------------------------------------------------------------------------- /pitch_detectors/evaluation/__main__.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import mir_eval 6 | import numpy as np 7 | import tqdm 8 | from dsplib.scale import minmax_scaler_array 9 | from redis import Redis 10 | 11 | from pitch_detectors import algorithms 12 | from pitch_detectors import util 13 | from pitch_detectors.algorithms.base import PitchDetector 14 | from pitch_detectors.evaluation import datasets 15 | from pitch_detectors.evaluation.datasets.base import Dataset 16 | 17 | 18 | def resample_f0( 19 | pitch: PitchDetector, 20 | t_resampled: np.ndarray, 21 | ) -> np.ndarray: 22 | f0_resampled = np.full_like(t_resampled, fill_value=np.nan) 23 | notna_slices = np.ma.clump_unmasked(np.ma.masked_invalid(pitch.f0)) 24 | for slice_ in notna_slices: 25 | t_slice = pitch.t[slice_] 26 | f0_slice = pitch.f0[slice_] 27 | t_start, t_stop = t_slice[0], t_slice[-1] 28 | mask = (t_start < t_resampled) & (t_resampled < t_stop) 29 | t_interp = t_resampled[mask] 30 | f0_interp = np.interp(t_interp, t_slice, f0_slice) 31 | f0_resampled[mask] = f0_interp 32 | return f0_resampled 33 | 34 | 35 | def raw_pitch_accuracy( 36 | ref_f0: np.ndarray, 37 | est_f0: np.ndarray, 38 | cent_tolerance: float = 50, 39 | ) -> float: 40 | ref_voicing = np.isfinite(ref_f0) 41 | est_voicing = np.isfinite(est_f0) 42 | ref_cent = mir_eval.melody.hz2cents(ref_f0) 43 | est_cent = mir_eval.melody.hz2cents(est_f0) 44 | score: float = mir_eval.melody.raw_pitch_accuracy(ref_voicing, ref_cent, est_voicing, est_cent, cent_tolerance) 45 | return score 46 | 47 | 48 | def evaluate_one( 49 | redis: Redis, # type: ignore 50 | algorithm: type[PitchDetector], 51 | wav_path: Path, 52 | source_hashes: dict[str, str], 53 | dataset: type[Dataset], 54 | ) -> str: 55 | source_hash = source_hashes[algorithm.name().lower()] 56 | key = f'pitch_detectors:evaluation:{dataset.name()}:{wav_path.stem}:{algorithm.name()}:{source_hash}' 57 | if 
redis.exists(key): 58 | return key 59 | wav = dataset.load_wav(wav_path) 60 | seconds = len(wav.a) / wav.fs 61 | rescale = 100000 62 | a = minmax_scaler_array(wav.a, wav.a.min(), wav.a.max(), -rescale, rescale).astype(np.float32) 63 | true = dataset.load_true(wav_path, seconds) 64 | pitch = algorithm(a, wav.fs) 65 | f0 = resample_f0(pitch, t_resampled=true.t) 66 | score = raw_pitch_accuracy(true.f0, f0) 67 | redis.set(key, score) 68 | return key 69 | 70 | 71 | def evaluate_all(redis: Redis, source_hashes, dataset: type[Dataset]) -> None: # type: ignore 72 | t = tqdm.tqdm(sorted(dataset.iter_wav_files())) 73 | for wav_path in t: 74 | for algorithm in tqdm.tqdm(algorithms.ALGORITHMS, leave=False): 75 | key = evaluate_one(redis, algorithm, wav_path, source_hashes, dataset) 76 | t.set_description(key) 77 | 78 | 79 | if __name__ == '__main__': 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--algorithm', type=str) 82 | parser.add_argument('--file', type=str) 83 | parser.add_argument('--dataset', type=str) 84 | args = parser.parse_args() 85 | if (args.algorithm is None) ^ (args.file is None): 86 | raise ValueError('you must specify both algorithm and file or neither') 87 | 88 | if args.dataset is None: 89 | _datasets = datasets.all_datasets 90 | else: 91 | _datasets = (getattr(datasets, args.dataset),) # type: ignore 92 | 93 | redis = Redis.from_url(os.environ['REDIS_URL'], decode_responses=True) 94 | source_hashes = util.source_hashes() 95 | redis.hset('pitch_detectors:source_hashes', mapping=source_hashes) # type: ignore 96 | 97 | for _dataset in _datasets: 98 | if args.algorithm is not None and args.file is not None: 99 | evaluate_one( 100 | redis, 101 | algorithm=getattr(algorithms, args.algorithm), 102 | wav_path=_dataset.wav_dir() / args.file, 103 | source_hashes=source_hashes, 104 | dataset=_dataset, # type: ignore 105 | ) 106 | raise SystemExit(0) 107 | evaluate_all(redis, source_hashes, _dataset) # type: ignore 108 | 
-------------------------------------------------------------------------------- /pitch_detectors/evaluation/table.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import os 3 | import re 4 | from collections.abc import Iterable 5 | from pathlib import Path 6 | from typing import Any 7 | from typing import TypeAlias 8 | 9 | import numpy as np 10 | import pipe21 as P 11 | from redis import Redis 12 | from tabulate import tabulate 13 | 14 | from pitch_detectors import algorithms 15 | from pitch_detectors import util 16 | from pitch_detectors.evaluation import datasets 17 | 18 | DictStr: TypeAlias = dict[str, str] 19 | 20 | 21 | def get_key_fields(s: str) -> DictStr: 22 | _, _, dataset, wav_path, algorithm, algorithm_sha256 = s.split(':') 23 | return { 24 | 'dataset': dataset, 25 | 'wav_path': wav_path, 26 | 'algorithm': algorithm, 27 | 'algorithm_sha256': algorithm_sha256, 28 | } 29 | 30 | 31 | def delete_stale_metrics(redis: Redis) -> int: # type: ignore 32 | keys = redis.keys('pitch_detectors:evaluation:*') 33 | source_hashes = util.source_hashes() 34 | pipeline = redis.pipeline() 35 | for key in keys: 36 | key_dict = get_key_fields(key) 37 | if source_hashes[key_dict['algorithm'].lower()] != key_dict['algorithm_sha256']: 38 | pipeline.delete(key) 39 | return sum(pipeline.execute()) 40 | 41 | 42 | def update_readme( 43 | repl: str, 44 | file: str = 'README.md', 45 | start: str = '', 46 | stop: str = '', 47 | ) -> None: 48 | _file = Path(file) 49 | text = _file.read_text() 50 | new = re.sub(fr'{start}(.*){stop}', f'{start}\n{repl}\n{stop}', text, flags=re.DOTALL) 51 | _file.write_text(new) 52 | 53 | 54 | def main() -> None: 55 | redis = Redis.from_url(os.environ['REDIS_URL'], decode_responses=True) 56 | delete_stale_metrics(redis) 57 | 58 | keys = redis.keys('pitch_detectors:evaluation:*') 59 | pipeline = redis.pipeline() 60 | for key in keys: 61 | pipeline.get(key) 62 | scores = pipeline.execute() 63 | 
key_score = list(zip(keys, scores, strict=True)) 64 | 65 | def group_key(x: DictStr) -> tuple[str, str]: 66 | return x['algorithm'], x['dataset'] 67 | 68 | algorithm_scores = ( 69 | key_score 70 | | P.MapKeys(get_key_fields) 71 | | P.MapValues(lambda x: {'raw_pitch_accuracy': float(x)}) 72 | | P.Map(lambda kv: kv[0] | kv[1]) 73 | | P.MapApply(lambda x: x.pop('algorithm_sha256')) 74 | | P.Sorted(key=group_key) 75 | | P.GroupBy(group_key) 76 | | P.MapValues(lambda it: it | P.Map(lambda x: x['raw_pitch_accuracy']) | P.Pipe(list)) 77 | | P.Pipe(list) 78 | ) 79 | 80 | def datasets_stats(it: Iterable[dict[str, Any]]) -> DictStr: 81 | out = {} 82 | for x in it: 83 | cls = getattr(datasets, x['dataset']) 84 | dataset = x['dataset'] 85 | out[f'[{dataset}]({cls.__doc__}) accuracy'] = f"{x.pop('mean'):<1.3f} ± {x.pop('std'):<1.3f}" 86 | return out 87 | 88 | def add_cls(kv: DictStr) -> DictStr: 89 | cls = getattr(algorithms, kv['algorithm']) 90 | kv['algorithm'] = f'[{cls.name()}]({cls.__doc__})' 91 | kv['cpu'] = '✓' 92 | kv['gpu'] = '✓' if cls.gpu_capable else '' 93 | return kv 94 | 95 | def sort_keys(kv: DictStr) -> DictStr: 96 | keys = kv.keys() 97 | to_sort = ['algorithm', 'cpu', 'gpu'] 98 | rest_keys = sorted(keys - set(to_sort)) 99 | return {k: kv[k] for k in to_sort + rest_keys} 100 | 101 | table = ( 102 | algorithm_scores 103 | | P.MapValues(lambda x: {'mean': np.mean(x).round(3), 'std': np.std(x).round(3)}) 104 | | P.Map(lambda kv: {'algorithm': kv[0][0], 'dataset': kv[0][1]} | kv[1]) 105 | | P.Sorted(key=operator.itemgetter('algorithm', 'dataset')) 106 | | P.GroupBy(operator.itemgetter('algorithm')) 107 | | P.MapValues(datasets_stats) 108 | | P.Map(lambda kv: {'algorithm': kv[0]} | kv[1]) 109 | | P.Map(add_cls) 110 | | P.Map(sort_keys) 111 | | P.Pipe(list) 112 | ) 113 | table = tabulate(table, headers='keys', tablefmt='github') 114 | print(table) 115 | update_readme(table) 116 | 117 | 118 | if __name__ == '__main__': 119 | print(os.environ['REDIS_URL']) 120 | main() 
121 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "pitch-detectors" 3 | version = "0.7.0" 4 | authors = [ 5 | {name = "Alexander Rodionov", email = "tandav@tandav.me"}, 6 | ] 7 | description = "collection of pitch detection algorithms with unified interface" 8 | requires-python = ">=3.12,<3.13" 9 | dependencies = [ 10 | "AMFM-decompy", 11 | "tensorflow", 12 | "tensorflow-hub", 13 | "dsplib==0.10.1", 14 | "librosa", 15 | "scipy==1.14.0", 16 | "numpy>=2.0", 17 | "praat-parselmouth>=0.4.3", 18 | "pyreaper>=0.0.9", 19 | "pysptk", 20 | "pyworld>=0.3.2", 21 | "pydantic>=2.0", 22 | "crepe>=0.0.16", 23 | "resampy==0.4.3", 24 | "torch", 25 | "torch-yin>=0.1.3", 26 | "torchcrepe>=0.0.18", 27 | "penn>=0.0.14", 28 | ] 29 | 30 | [dependency-groups] 31 | dev = [ 32 | "bumpver", 33 | "pre-commit", 34 | "pytest==8.3.2", 35 | "pytest-order==1.2.1", 36 | "pytest-cov", 37 | "pytest-env", 38 | "mir_eval", 39 | "tqdm", 40 | "redis", 41 | "musiclib==2.1.0", 42 | "python-dotenv", 43 | "tabulate", 44 | "s3fs", 45 | ] 46 | 47 | [project.urls] 48 | source = "https://github.com/tandav/pitch-detectors" 49 | # docs = "https://tandav.github.io/pitch-detectors/" 50 | issues = "https://github.com/tandav/pitch-detectors/issues" 51 | "release notes" = "https://github.com/tandav/pitch-detectors/releases" 52 | 53 | # ============================================================================== 54 | 55 | [[tool.uv.index]] 56 | name = "local-cache" 57 | url = "https://pypi.tandav.me/index/" 58 | 59 | # ============================================================================== 60 | 61 | [build-system] 62 | requires = ["setuptools"] 63 | build-backend = "setuptools.build_meta" 64 | 65 | # [tool.setuptools] 66 | # packages = ["pitch_detectors"] 67 | 68 | [tool.setuptools.packages.find] 69 | exclude = ["data*"] 70 | 71 | # 
============================================================================== 72 | 73 | [tool.bumpver] 74 | current_version = "v0.7.0" 75 | version_pattern = "vMAJOR.MINOR.PATCH" 76 | commit_message = "bump version {old_version} -> {new_version}" 77 | commit = true 78 | tag = true 79 | 80 | [tool.bumpver.file_patterns] 81 | "pyproject.toml" = [ 82 | '^version = "{pep440_version}"', 83 | '^current_version = "{version}"', 84 | ] 85 | "pitch_detectors/__init__.py" = [ 86 | "^__version__ = '{pep440_version}'", 87 | ] 88 | 89 | # ============================================================================== 90 | 91 | [tool.mypy] 92 | # todo: review this 93 | pretty = true 94 | show_traceback = true 95 | color_output = true 96 | allow_redefinition = false 97 | check_untyped_defs = true 98 | disallow_any_generics = true 99 | disallow_incomplete_defs = true 100 | disallow_untyped_defs = true 101 | ignore_missing_imports = true 102 | implicit_reexport = true 103 | no_implicit_optional = true 104 | show_column_numbers = true 105 | show_error_codes = true 106 | show_error_context = true 107 | strict_equality = true 108 | strict_optional = true 109 | warn_no_return = true 110 | warn_redundant_casts = true 111 | warn_return_any = true 112 | warn_unreachable = true 113 | warn_unused_configs = true 114 | warn_unused_ignores = true 115 | plugins = [ 116 | "pydantic.mypy", 117 | ] 118 | 119 | [[tool.mypy.overrides]] 120 | module = ["tests.*"] 121 | disallow_untyped_defs = false 122 | 123 | # ============================================================================== 124 | 125 | [tool.ruff.lint] 126 | extend-select = [ 127 | "W", 128 | "C", 129 | "I", 130 | "SIM", 131 | "TCH", 132 | "C4", 133 | "S", 134 | "BLE", 135 | "B", 136 | "T10", 137 | "INP", 138 | "PIE", 139 | "PL", 140 | "RUF", 141 | ] 142 | ignore = [ 143 | "E501", # line too long 144 | "PLR0913", 145 | "TCH003", 146 | "S603", 147 | ] 148 | 149 | [tool.ruff.lint.per-file-ignores] 150 | "examples/*" = ["INP001"] 151 | 
"scripts/*" = ["INP001", "S101"] 152 | "tests/*" = ["S101"] 153 | 154 | [tool.ruff.lint.isort] 155 | force-single-line = true 156 | 157 | # ============================================================================== 158 | 159 | [tool.pylint.MASTER] 160 | load-plugins=[ 161 | "pylint_per_file_ignores", 162 | ] 163 | 164 | [tool.pylint.messages-control] 165 | disable = [ 166 | "missing-function-docstring", 167 | "missing-class-docstring", 168 | "missing-module-docstring", 169 | "line-too-long", 170 | "import-outside-toplevel", 171 | "unused-variable", 172 | "too-many-arguments", 173 | "import-error", 174 | "too-few-public-methods", 175 | "unspecified-encoding", 176 | "redefined-outer-name", 177 | "too-many-locals", 178 | "invalid-name", 179 | "protected-access", 180 | "cyclic-import", 181 | ] 182 | 183 | [tool.pylint-per-file-ignores] 184 | "/tests/" = "redefined-outer-name" 185 | 186 | # ============================================================================== 187 | 188 | [tool.autopep8] 189 | ignore="E501,E701" 190 | recursive = true 191 | aggressive = 3 192 | 193 | # ============================================================================== 194 | 195 | [tool.pytest.ini_options] 196 | filterwarnings = [ 197 | "ignore:Type google._upb._message.ScalarMapContainer uses PyType_Spec with a metaclass that has custom tp_new.:DeprecationWarning", 198 | "ignore:Type google._upb._message.MessageMapContainer uses PyType_Spec with a metaclass that has custom tp_new.:DeprecationWarning", 199 | ] 200 | 201 | # ============================================================================== 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![pipeline status](https://gitlab.tandav.me/pitchtrack/pitch-detectors/badges/master/pipeline.svg)](https://gitlab.tandav.me/pitchtrack/pitch-detectors/-/commits/master) 2 | 3 | # pitch-detectors 4 | 
collection of pitch (f0, fundamental frequency) detection algorithms with unified interface 5 | 6 | ## list of algorithms 7 | 8 | | algorithm | cpu | gpu | [MDBStemSynth](https://zenodo.org/record/1481172) accuracy | [Mir1K](https://www.kaggle.com/datasets/datongmuyuyi/mir1k) accuracy | 9 | |------------------------------------------------------------------------------------------------------------|-------|-------|--------------------------------------------------------------|------------------------------------------------------------------------| 10 | | [Crepe](https://github.com/marl/crepe) | ✓ | ✓ | 0.886 ± 0.059 | 0.759 ± 0.073 | 11 | | [Penn](https://github.com/interactiveaudiolab/penn) | ✓ | ✓ | 0.699 ± 0.263 | 0.660 ± 0.090 | 12 | | [PipTrack](https://librosa.org/doc/latest/generated/librosa.piptrack.html) | ✓ | | 0.524 ± 0.286 | 0.864 ± 0.071 | 13 | | [PraatAC](https://parselmouth.readthedocs.io/en/stable/api_reference.html#parselmouth.Sound.to_pitch_ac) | ✓ | | 0.777 ± 0.302 | 0.859 ± 0.074 | 14 | | [PraatCC](https://parselmouth.readthedocs.io/en/stable/api_reference.html#parselmouth.Sound.to_pitch_cc) | ✓ | | 0.776 ± 0.300 | 0.872 ± 0.068 | 15 | | [PraatSHS](https://parselmouth.readthedocs.io/en/stable/api_reference.html#parselmouth.Sound.to_pitch_shs) | ✓ | | 0.534 ± 0.238 | 0.578 ± 0.169 | 16 | | [Pyin](https://librosa.org/doc/latest/generated/librosa.pyin.html) | ✓ | | 0.722 ± 0.252 | 0.888 ± 0.047 | 17 | | [Rapt](https://pysptk.readthedocs.io/en/stable/generated/pysptk.sptk.rapt.html) | ✓ | | 0.791 ± 0.282 | 0.827 ± 0.078 | 18 | | [Reaper](https://github.com/r9y9/pyreaper) | ✓ | | 0.792 ± 0.261 | 0.795 ± 0.083 | 19 | | [Spice](https://ai.googleblog.com/2019/11/spice-self-supervised-pitch-estimation.html) | ✓ | ✓ | 0.024 ± 0.034 | 0.889 ± 0.046 | 20 | | [Swipe](https://pysptk.readthedocs.io/en/stable/generated/pysptk.sptk.swipe.html) | ✓ | | 0.784 ± 0.276 | 0.796 ± 0.062 | 21 | | [TorchCrepe](https://github.com/maxrmorrison/torchcrepe) | ✓ | ✓ | 0.764 
± 0.273 | 0.774 ± 0.084 | 22 | | [TorchYin](https://github.com/brentspell/torch-yin) | ✓ | | 0.735 ± 0.268 | 0.866 ± 0.052 | 23 | | [World](https://github.com/JeremyCCHsu/Python-Wrapper-for-World-Vocoder) | ✓ | | 0.869 ± 0.135 | 0.842 ± 0.075 | 24 | | [Yaapt](http://bjbschmitt.github.io/AMFM_decompy/pYAAPT.html#amfm_decompy.pYAAPT.yaapt) | ✓ | | 0.740 ± 0.271 | 0.744 ± 0.107 | 25 | | [Yin](https://librosa.org/doc/latest/generated/librosa.yin.html#librosa.yin) | ✓ | | 0.749 ± 0.269 | 0.884 ± 0.043 | 26 | 27 | 28 | Accuracy is the mean ± standard deviation of [raw pitch accuracy](http://craffel.github.io/mir_eval/#mir_eval.melody.raw_pitch_accuracy) over the files of each dataset. 29 | 30 | ## install 31 | ```bash 32 | pip install pitch-detectors 33 | ``` 34 | 35 | All algorithms are tested on Python 3.12, which is the recommended Python version. 36 | 37 | ## usage 38 | 39 | ```python 40 | from scipy.io import wavfile 41 | from pitch_detectors import algorithms 42 | import matplotlib.pyplot as plt 43 | 44 | fs, a = wavfile.read('data/b1a5da49d564a7341e7e1327aa3f229a.wav') 45 | pitch = algorithms.Crepe(a, fs) 46 | plt.plot(pitch.t, pitch.f0) 47 | plt.show() 48 | ``` 49 | 50 | ![Alt text](data/b1a5da49d564a7341e7e1327aa3f229a.png) 51 | [Colab notebook with plots for all algorithms/models](https://colab.research.google.com/drive/1PVsk4ygDZIhIO3GEIukQJOKkgibqoG1n) 52 | 53 | 54 | ## additional features 55 | - [ ] robust (vote + median) ensemble algorithm using all models 56 | - [ ] json import/export 57 | 58 | ## notes 59 | Tests run in a subprocess (using `scripts/run_algorithm.py`) to avoid PyTorch CUDA import caching. 60 | It is difficult to disable the GPU after it has been initialized (https://github.com/pytorch/pytorch/issues/9158). 61 | It is also difficult to set the correct PATH and LD_LIBRARY_PATH without a subprocess. 62 | --------------------------------------------------------------------------------