├── onnxcrepe ├── assets │ └── .gitkeep ├── __init__.py ├── load.py ├── session.py ├── configs │ └── default.json ├── convert.py ├── loudness.py ├── decode.py ├── filter.py ├── threshold.py ├── __main__.py └── core.py ├── tests ├── assets │ ├── test.wav │ ├── frames-crepe.npy │ ├── activation-full.npy │ └── activation-tiny.npy ├── test_decode.py ├── test_threshold.py ├── test_core.py └── conftest.py ├── samples ├── assets │ ├── xtgg_mono_16k_denoise.wav │ └── xtgg_mono_16k_original.wav └── demo.py ├── LICENSE ├── README.md └── .gitignore /onnxcrepe/assets/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/assets/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqzhishen/onnxcrepe/HEAD/tests/assets/test.wav -------------------------------------------------------------------------------- /tests/assets/frames-crepe.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqzhishen/onnxcrepe/HEAD/tests/assets/frames-crepe.npy -------------------------------------------------------------------------------- /tests/assets/activation-full.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqzhishen/onnxcrepe/HEAD/tests/assets/activation-full.npy -------------------------------------------------------------------------------- /tests/assets/activation-tiny.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqzhishen/onnxcrepe/HEAD/tests/assets/activation-tiny.npy -------------------------------------------------------------------------------- /samples/assets/xtgg_mono_16k_denoise.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqzhishen/onnxcrepe/HEAD/samples/assets/xtgg_mono_16k_denoise.wav -------------------------------------------------------------------------------- /samples/assets/xtgg_mono_16k_original.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yqzhishen/onnxcrepe/HEAD/samples/assets/xtgg_mono_16k_original.wav -------------------------------------------------------------------------------- /onnxcrepe/__init__.py: -------------------------------------------------------------------------------- 1 | from . import decode 2 | from .core import * 3 | from . import convert 4 | from . import filter 5 | from . import load 6 | from . import loudness 7 | from .session import CrepeInferenceSession 8 | from . import threshold 9 | -------------------------------------------------------------------------------- /onnxcrepe/load.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | 4 | 5 | def audio(filename): 6 | """Load audio from disk""" 7 | samples, sr = librosa.load(filename, sr=None) 8 | if len(samples.shape) > 1: 9 | # To mono 10 | samples = np.mean(samples, axis=1) 11 | 12 | return samples, sr 13 | -------------------------------------------------------------------------------- /onnxcrepe/session.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import onnxruntime as ort 4 | 5 | 6 | class CrepeInferenceSession(ort.InferenceSession): 7 | def __init__(self, model='full', sess_options=None, providers=None, provider_options=None, **kwargs): 8 | model_path = os.path.join(os.path.dirname(__file__), 'assets', f'{model}.onnx') 9 | super().__init__(model_path, sess_options, providers, provider_options, **kwargs) 10 | -------------------------------------------------------------------------------- /tests/test_decode.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnxcrepe 3 | 4 | 5 | ############################################################################### 6 | # Test decode.py 7 | ############################################################################### 8 | 9 | 10 | def test_weighted_argmax_decode(): 11 | """Tests that weighted argmax decode works without CUDA assertion error""" 12 | fake_logits = np.random.rand(8, 360, 128) 13 | decoded = onnxcrepe.decode.weighted_argmax(fake_logits) 14 | -------------------------------------------------------------------------------- /onnxcrepe/configs/default.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "full", 3 | "precision": 5.0, 4 | "batch_size": 512, 5 | "pad": true, 6 | "fmin": 50.0, 7 | "fmax": 1100.0, 8 | "decoder": "weighted_viterbi", 9 | "threshold": { 10 | "type": "hard", 11 | "arguments": { 12 | "at": 0.21 13 | } 14 | }, 15 | "providers": [ 16 | { 17 | "name": "CUDAExecutionProvider", 18 | "options": { 19 | "device_id": 0, 20 | "cudnn_conv_algo_search": "DEFAULT" 21 | } 22 | }, 23 | { 24 | "name": "DmlExecutionProvider", 25 | "options": { 26 | "device_id": 0 27 | } 28 | }, 29 | { 30 | "name": "CPUExecutionProvider", 31 | "options": {} 32 | } 33 | ] 34 | } 35 | -------------------------------------------------------------------------------- /tests/test_threshold.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import onnxcrepe 4 | 5 | 6 | ############################################################################### 7 | # Test threshold.py 8 | ############################################################################### 9 | 10 | 11 | def test_at(): 12 | """Test onnxcrepe.threshold.At""" 13 | input_pitch = np.array([100., 110., 120., 130., 140.]) 14 | periodicity = np.array([.19, .22, .25, .17, .30]) 15 | 16 | # Perform thresholding 17 | output_pitch = onnxcrepe.threshold.At(.20)(input_pitch, periodicity) 18 | 19 | # Ensure thresholding is not in-place 20 | assert not (input_pitch == output_pitch).all() 21 | 22 | # Ensure certain frames are marked as unvoiced 23 | isnan = np.isnan(output_pitch) 24 | assert isnan[0] and isnan[3] 25 | assert not isnan[1] and not isnan[2] and not isnan[4] 26 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnxcrepe 3 | 4 | 5 | ############################################################################### 6 | # Test core.py 7 | ############################################################################### 8 | 9 | 10 | def test_infer_tiny(frames, activation_tiny): 11 | """Test that inference is the same as the original crepe""" 12 | activation = onnxcrepe.infer( 13 | onnxcrepe.CrepeInferenceSession('tiny', providers=['CPUExecutionProvider']), frames) 14 | diff = np.abs(activation - activation_tiny) 15 | 16 | # ONNX output are not strictly the same as the original CREPE 17 | assert diff.max() < 0.5 and diff.mean() < 5e-3 18 | 19 | 20 | def test_infer_full(frames, activation_full): 21 | """Test that inference is the same as the original crepe""" 22 | activation = onnxcrepe.infer( 23 | onnxcrepe.CrepeInferenceSession('full', providers=['CPUExecutionProvider']), frames) 24 | diff = np.abs(activation - activation_full) 25 | 26 | # ONNX output are not strictly the same as the original CREPE 27 | assert diff.max() < 0.5 and diff.mean() < 5e-3 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 yqzhishen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /samples/demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import onnxcrepe 4 | from onnxcrepe.session import CrepeInferenceSession 5 | 6 | # Load audio 7 | audio, sr = onnxcrepe.load.audio(r'assets/xtgg_mono_16k_denoise.wav') 8 | 9 | # Here we'll use a 5 millisecond hop length 10 | precision = 5.0 11 | 12 | # Provide a sensible frequency range for your domain (upper limit is 2006 Hz) 13 | # This would be a reasonable range for speech 14 | fmin = 50 15 | fmax = 1100 16 | 17 | # Select a model capacity--one of "full", "large", "medium", "small" and "tiny" 18 | model = 'full' 19 | 20 | # Choose execution providers to use for inference 21 | providers = ['CUDAExecutionProvider', 'DmlExecutionProvider', 'CPUExecutionProvider'] 22 | 23 | # Pick a batch size that doesn't cause memory errors on your device 24 | batch_size = 1024 25 | 26 | # Create inference session 27 | session = CrepeInferenceSession( 28 | model='full', 29 | providers=providers) 30 | 31 | # Compute pitch using the default DirectML GPU or CPU 32 | pitch = onnxcrepe.predict(session, audio, sr, precision=precision, fmin=fmin, fmax=fmax, batch_size=batch_size) 33 | print(pitch.shape) 34 | print(np.mean(pitch)) 35 | print(np.var(pitch)) 36 | 37 | # Dispose inference session 38 | del session 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # onnxcrepe 2 | [![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) 3 | 4 | ONNX deployment of the CREPE [1] pitch tracker. The provided model weights and most of the codes in this repository were converted and migrated from the original TensorFlow implementation [here](https://github.com/marl/crepe/) and Max Morrison's [torchcrepe](https://github.com/maxrmorrison/torchcrepe), a PyTorch implementation of CREPE. 5 | 6 | 7 | ## Usage 8 | 9 | Download model weights from [releases](https://github.com/yqzhishen/onnxcrepe/releases) and put them into the `onnxcrepe/assets/` directory. See demo [here](samples/demo.py). 10 | 11 | Documentation of this repository is still a work in progress and is comming soon. 12 | 13 | 14 | ## Acknowledgements 15 | Codes and model weights in this repository are based on the following repos: 16 | - [torchcrepe](https://github.com/maxrmorrison/torchcrepe) for 'full' and 'tiny' model weights and most of the code implementation 17 | - [Weights_Keras_2_Pytorch](https://github.com/AgCl-LHY/Weights_Keras_2_Pytorch) for converting 'large', 'medium' and 'small' model weights from the original implementation 18 | - [PyTorch](https://github.com/pytorch/pytorch) for exporting onnx models 19 | - [onnx-optimizer](https://github.com/onnx/optimizer) and [onnx-simplifier](https://github.com/daquexian/onnx-simplifier) for optimizing performance 20 | - [onnxruntime](https://github.com/microsoft/onnxruntime) for execution and configurations 21 | 22 | 23 | ## References 24 | [1] J. W. Kim, J. Salamon, P. Li, and J. P. Bello, “[Crepe: A Convolutional Representation for Pitch Estimation](https://arxiv.org/abs/1802.06182),” in 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). 25 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | import onnxcrepe 7 | 8 | 9 | ############################################################################### 10 | # Testing fixtures 11 | ############################################################################### 12 | 13 | 14 | @pytest.fixture(scope='session') 15 | def activation_full(): 16 | """Retrieve the original crepe activation on the test audio""" 17 | return np.load(path('activation-full.npy')) 18 | 19 | 20 | @pytest.fixture(scope='session') 21 | def activation_tiny(): 22 | """Retrieve the original crepe activation on the test audio""" 23 | return np.load(path('activation-tiny.npy')) 24 | 25 | 26 | @pytest.fixture(scope='session') 27 | def audio(): 28 | """Retrieve the test audio""" 29 | audio, sample_rate = onnxcrepe.load.audio(path('test.wav')) 30 | if sample_rate != onnxcrepe.SAMPLE_RATE: 31 | audio = onnxcrepe.resample(audio, sample_rate) 32 | return audio 33 | 34 | 35 | @pytest.fixture(scope='session') 36 | def frames(): 37 | """Retrieve the preprocessed frames for inference 38 | 39 | Note: the reason we load frames from disk rather than compute ourselves 40 | is that the normalizing process in the preprocessing isn't numerically 41 | stable. Therefore, we use the exact same preprocessed features that were 42 | passed through crepe to retrieve the activations--thus bypassing the 43 | preprocessing step. 44 | """ 45 | return np.load(path('frames-crepe.npy')) 46 | 47 | 48 | ############################################################################### 49 | # Utilities 50 | ############################################################################### 51 | 52 | 53 | def path(file): 54 | """Retrieve the path to the test file""" 55 | return os.path.join(os.path.dirname(__file__), 'assets', file) 56 | -------------------------------------------------------------------------------- /onnxcrepe/convert.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | 4 | import onnxcrepe 5 | 6 | 7 | ############################################################################### 8 | # Pitch unit conversions 9 | ############################################################################### 10 | 11 | 12 | def bins_to_cents(bins, apply_dither=False): 13 | """Converts pitch bins to cents""" 14 | cents = onnxcrepe.CENTS_PER_BIN * bins + 1997.3794084376191 15 | 16 | # Trade quantization error for noise (disabled by default) 17 | return dither(cents) if apply_dither else cents 18 | 19 | 20 | def bins_to_frequency(bins, apply_dither=False): 21 | """Converts pitch bins to frequency in Hz""" 22 | return cents_to_frequency(bins_to_cents(bins, apply_dither=apply_dither)) 23 | 24 | 25 | def cents_to_bins(cents, quantize_fn=np.floor): 26 | """Converts cents to pitch bins""" 27 | bins = (cents - 1997.3794084376191) / onnxcrepe.CENTS_PER_BIN 28 | return quantize_fn(bins).astype(np.int64) 29 | 30 | 31 | def cents_to_frequency(cents): 32 | """Converts cents to frequency in Hz""" 33 | return 10 * 2 ** (cents / 1200) 34 | 35 | 36 | def frequency_to_bins(frequency, quantize_fn=np.floor): 37 | """Convert frequency in Hz to pitch bins""" 38 | return cents_to_bins(frequency_to_cents(frequency), quantize_fn) 39 | 40 | 41 | def frequency_to_cents(frequency): 42 | """Convert frequency in Hz to cents""" 43 | return 1200 * np.log2(frequency / 10.) 44 | 45 | 46 | ############################################################################### 47 | # Utilities 48 | ############################################################################### 49 | 50 | 51 | def dither(cents): 52 | """Dither the predicted pitch in cents to remove quantization error""" 53 | noise = scipy.stats.triang.rvs(c=0.5, 54 | loc=-onnxcrepe.CENTS_PER_BIN, 55 | scale=2 * onnxcrepe.CENTS_PER_BIN, 56 | size=cents.shape) 57 | return cents + noise.astype(cents.dtype) 58 | -------------------------------------------------------------------------------- /onnxcrepe/loudness.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import librosa 4 | import numpy as np 5 | 6 | import onnxcrepe 7 | 8 | 9 | ############################################################################### 10 | # Constants 11 | ############################################################################### 12 | 13 | 14 | # Minimum decibel level 15 | MIN_DB = -100. 16 | 17 | # Reference decibel level 18 | REF_DB = 20. 19 | 20 | 21 | ############################################################################### 22 | # A-weighted loudness 23 | ############################################################################### 24 | 25 | 26 | def a_weighted(audio, sample_rate, hop_length=None, pad=True): 27 | """Retrieve the per-frame loudness""" 28 | 29 | # Default hop length of 10 ms 30 | hop_length = sample_rate // 100 if hop_length is None else hop_length 31 | 32 | # Convert to numpy 33 | audio = audio.squeeze(0) 34 | 35 | # Resample 36 | if sample_rate != onnxcrepe.SAMPLE_RATE: 37 | audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=onnxcrepe.SAMPLE_RATE) 38 | hop_length = int(hop_length * onnxcrepe.SAMPLE_RATE / sample_rate) 39 | 40 | # Cache weights 41 | if not hasattr(a_weighted, 'weights'): 42 | a_weighted.weights = perceptual_weights() 43 | 44 | # Take stft 45 | stft = librosa.stft(audio, 46 | n_fft=onnxcrepe.WINDOW_SIZE, 47 | hop_length=hop_length, 48 | win_length=onnxcrepe.WINDOW_SIZE, 49 | center=pad, 50 | pad_mode='constant') 51 | 52 | # Compute magnitude on db scale 53 | db = librosa.amplitude_to_db(np.abs(stft)) 54 | 55 | # Apply A-weighting 56 | weighted = db + a_weighted.weights 57 | 58 | # Threshold 59 | weighted[weighted < MIN_DB] = MIN_DB 60 | 61 | # Average over weighted frequencies 62 | return weighted.mean(axis=0).astype(np.float32)[None] 63 | 64 | 65 | def perceptual_weights(): 66 | """A-weighted frequency-dependent perceptual loudness weights""" 67 | frequencies = librosa.fft_frequencies(sr=onnxcrepe.SAMPLE_RATE, 68 | n_fft=onnxcrepe.WINDOW_SIZE) 69 | 70 | # A warning is raised for nearly inaudible frequencies, but it ends up 71 | # defaulting to -100 db. That default is fine for our purposes. 72 | with warnings.catch_warnings(): 73 | warnings.simplefilter('ignore', RuntimeWarning) 74 | return librosa.A_weighting(frequencies)[:, None] - REF_DB 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # JetBrains PyCharm 132 | .idea/ 133 | 134 | # Model weights 135 | *.onnx 136 | -------------------------------------------------------------------------------- /onnxcrepe/decode.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | 4 | import onnxcrepe 5 | 6 | 7 | ############################################################################### 8 | # Probability sequence decoding methods 9 | ############################################################################### 10 | 11 | 12 | def argmax(logits): 13 | """Sample observations by taking the argmax""" 14 | bins = logits.argmax(axis=1) 15 | 16 | # Convert to frequency in Hz 17 | return bins, onnxcrepe.convert.bins_to_frequency(bins) 18 | 19 | 20 | def weighted_argmax(logits: np.ndarray): 21 | """Sample observations using weighted sum near the argmax""" 22 | # Find center of analysis window 23 | bins = logits.argmax(axis=1) 24 | 25 | return bins, _apply_weights(logits, bins) 26 | 27 | 28 | def viterbi(logits): 29 | """Sample observations using viterbi decoding""" 30 | # Create viterbi transition matrix 31 | if not hasattr(viterbi, 'transition'): 32 | xx, yy = np.meshgrid(range(360), range(360)) 33 | transition = np.maximum(12 - abs(xx - yy), 0) 34 | transition = transition / transition.sum(axis=1, keepdims=True) 35 | viterbi.transition = transition 36 | 37 | # Normalize logits (softmax) 38 | logits -= logits.max(axis=1) 39 | exp = np.exp(logits) 40 | probs = exp / np.sum(exp, axis=1) 41 | 42 | # Perform viterbi decoding 43 | bins = np.array([ 44 | librosa.sequence.viterbi(sequence, viterbi.transition).astype(np.int64) 45 | for sequence in probs]) 46 | 47 | # Convert to frequency in Hz 48 | return bins, onnxcrepe.convert.bins_to_frequency(bins) 49 | 50 | 51 | def weighted_viterbi(logits): 52 | """Sample observations combining viterbi decoding and weighted argmax""" 53 | bins, _ = viterbi(logits) 54 | 55 | return bins, _apply_weights(logits, bins) 56 | 57 | 58 | def _apply_weights(logits, bins): 59 | # Find bounds of analysis window 60 | start = np.maximum(0, bins - 4) 61 | end = np.minimum(logits.shape[1], bins + 5) 62 | 63 | # Mask out everything outside of window 64 | for batch in range(logits.shape[0]): 65 | for time in range(logits.shape[2]): 66 | logits[batch, :start[batch, time], time] = float('-inf') 67 | logits[batch, end[batch, time]:, time] = float('-inf') 68 | 69 | # Construct weights 70 | if not hasattr(_apply_weights, 'weights'): 71 | weights = onnxcrepe.convert.bins_to_cents(np.arange(360)) 72 | _apply_weights.weights = weights[None, :, None] 73 | 74 | # Convert to probabilities (ReLU) 75 | probs = np.maximum(0, logits) 76 | 77 | # Apply weights 78 | cents = (_apply_weights.weights * probs).sum(axis=1) / probs.sum(axis=1) 79 | 80 | # Convert to frequency in Hz 81 | return onnxcrepe.convert.cents_to_frequency(cents) 82 | -------------------------------------------------------------------------------- /onnxcrepe/filter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | ############################################################################### 5 | # Sequence filters 6 | ############################################################################### 7 | 8 | 9 | def mean(signals, win_length=9): 10 | """Averave filtering for signals containing nan values 11 | 12 | Arguments 13 | signals (numpy.ndarray (shape=(batch, time))) 14 | The signals to filter 15 | win_length 16 | The size of the analysis window 17 | 18 | Returns 19 | filtered (numpy.ndarray (shape=(batch, time))) 20 | """ 21 | return nanfilter(signals, win_length, nanmean) 22 | 23 | 24 | def median(signals, win_length): 25 | """Median filtering for signals containing nan values 26 | 27 | Arguments 28 | signals (numpy.ndarray (shape=(batch, time))) 29 | The signals to filter 30 | win_length 31 | The size of the analysis window 32 | 33 | Returns 34 | filtered (numpy.ndarray (shape=(batch, time))) 35 | """ 36 | return nanfilter(signals, win_length, nanmedian) 37 | 38 | 39 | ############################################################################### 40 | # Utilities 41 | ############################################################################### 42 | 43 | 44 | def nanfilter(signals, win_length, filter_fn): 45 | """Filters a sequence, ignoring nan values 46 | 47 | Arguments 48 | signals (numpy.ndarray (shape=(batch, time))) 49 | The signals to filter 50 | win_length 51 | The size of the analysis window 52 | filter_fn (function) 53 | The function to use for filtering 54 | 55 | Returns 56 | filtered (numpy.ndarray (shape=(batch, time))) 57 | """ 58 | # Output buffer 59 | filtered = np.empty_like(signals) 60 | 61 | # Loop over frames 62 | for i in range(signals.size(1)): 63 | 64 | # Get analysis window bounds 65 | start = max(0, i - win_length // 2) 66 | end = min(signals.size(1), i + win_length // 2 + 1) 67 | 68 | # Apply filter to window 69 | filtered[:, i] = filter_fn(signals[:, start:end]) 70 | 71 | return filtered 72 | 73 | 74 | def nanmean(signals): 75 | """Computes the mean, ignoring nans 76 | 77 | Arguments 78 | signals (numpy.ndarray [shape=(batch, time)]) 79 | The signals to filter 80 | 81 | Returns 82 | filtered (numpy.ndarray [shape=(batch, time)]) 83 | """ 84 | signals = signals.clone() 85 | 86 | # Find nans 87 | nans = np.isnan(signals) 88 | 89 | # Set nans to 0. 90 | signals[nans] = 0. 91 | 92 | # Compute average 93 | return signals.sum(axis=1) / (~nans).astype(np.float32).sum(axis=1) 94 | 95 | 96 | def nanmedian(signals): 97 | """Computes the median, ignoring nans 98 | 99 | Arguments 100 | signals (numpy.ndarray [shape=(batch, time)]) 101 | The signals to filter 102 | 103 | Returns 104 | filtered (numpy.ndarray [shape=(batch, time)]) 105 | """ 106 | # Find nans 107 | nans = np.isnan(signals) 108 | 109 | # Compute median for each slice 110 | medians = [nanmedian1d(signal[~nan]) for signal, nan in zip(signals, nans)] 111 | 112 | # Stack results 113 | return np.array(medians, dtype=signals.dtype) 114 | 115 | 116 | def nanmedian1d(signal): 117 | """Computes the median. If signal is empty, returns torch.nan 118 | 119 | Arguments 120 | signal (numpy.ndarray [shape=(time,)]) 121 | 122 | Returns 123 | median (numpy.ndarray [shape=(1,)]) 124 | """ 125 | return np.median(signal) if signal.size else np.nan 126 | -------------------------------------------------------------------------------- /onnxcrepe/threshold.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import onnxcrepe 4 | 5 | 6 | ############################################################################### 7 | # Pitch thresholding methods 8 | ############################################################################### 9 | 10 | 11 | class At: 12 | """Simple thresholding at a specified probability value""" 13 | 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def __call__(self, pitch, periodicity): 18 | # Make a copy to prevent in-place modification 19 | pitch = pitch.copy() 20 | 21 | # Threshold 22 | pitch[periodicity < self.value] = onnxcrepe.UNVOICED 23 | return pitch 24 | 25 | 26 | class Hysteresis: 27 | """Hysteresis thresholding""" 28 | 29 | def __init__(self, 30 | lower_bound=.19, 31 | upper_bound=.31, 32 | width=.2, 33 | stds=1.7, 34 | return_threshold=False): 35 | self.lower_bound = lower_bound 36 | self.upper_bound = upper_bound 37 | self.width = width 38 | self.stds = stds 39 | self.return_threshold = return_threshold 40 | 41 | def __call__(self, pitch, periodicity): 42 | 43 | # Perform hysteresis in log-2 space 44 | pitch = np.log2(pitch).flatten() 45 | 46 | # Flatten periodicity 47 | periodicity = periodicity.flatten() 48 | 49 | # Ignore confidently unvoiced pitch 50 | pitch[periodicity < self.lower_bound] = onnxcrepe.UNVOICED 51 | 52 | # Whiten pitch 53 | mean, std = np.nanmean(pitch), np.nanstd(pitch) 54 | pitch = (pitch - mean) / std 55 | 56 | # Require high confidence to make predictions far from the mean 57 | parabola = self.width * pitch ** 2 - self.width * self.stds ** 2 58 | threshold = self.lower_bound + np.clip(parabola, 0, 1 - self.lower_bound) 59 | threshold[np.isnan(threshold)] = self.lower_bound 60 | 61 | # Apply hysteresis to prevent short, unconfident voiced regions 62 | i = 0 63 | while i < len(periodicity) - 1: 64 | 65 | # Detect unvoiced to voiced transition 66 | if periodicity[i] < threshold[i] and periodicity[i + 1] > threshold[i + 1]: 67 | 68 | # Grow region until next unvoiced or end of array 69 | start, end, keep = i + 1, i + 1, False 70 | while end < len(periodicity) and periodicity[end] > threshold[end]: 71 | if periodicity[end] > self.upper_bound: 72 | keep = True 73 | end += 1 74 | 75 | # Force unvoiced if we didn't pass the confidence required by 76 | # the hysteresis 77 | if not keep: 78 | threshold[start:end] = 1 79 | 80 | i = end 81 | 82 | else: 83 | i += 1 84 | 85 | # Remove pitch with low periodicity 86 | pitch[periodicity < threshold] = onnxcrepe.UNVOICED 87 | 88 | # Unwhiten 89 | pitch = pitch * std + mean 90 | 91 | # Convert to Hz 92 | pitch = np.array(2 ** pitch)[None, :] 93 | 94 | # Optionally return threshold 95 | if self.return_threshold: 96 | return pitch, np.array(threshold) 97 | 98 | return pitch 99 | 100 | 101 | ############################################################################### 102 | # Periodicity thresholding methods 103 | ############################################################################### 104 | 105 | 106 | class Silence: 107 | """Set periodicity to zero in silent regions""" 108 | 109 | def __init__(self, value=-60): 110 | self.value = value 111 | 112 | def __call__(self, 113 | periodicity, 114 | audio, 115 | sample_rate=onnxcrepe.SAMPLE_RATE, 116 | precision=None, 117 | pad=True): 118 | # Don't modify in-place 119 | periodicity = periodicity.copy() 120 | 121 | # Compute loudness 122 | hop_length = sample_rate * precision // 1000 123 | loudness = onnxcrepe.loudness.a_weighted( 124 | audio, sample_rate, hop_length, pad) 125 | 126 | # Threshold silence 127 | periodicity[loudness < self.value] = 0. 128 | 129 | return periodicity 130 | -------------------------------------------------------------------------------- /onnxcrepe/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import onnxruntime as ort 6 | 7 | import onnxcrepe 8 | from onnxcrepe.session import CrepeInferenceSession 9 | 10 | 11 | ############################################################################### 12 | # Entry point 13 | ############################################################################### 14 | 15 | 16 | def parse_args(): 17 | """Parse command-line arguments""" 18 | parser = argparse.ArgumentParser() 19 | 20 | # Required arguments 21 | parser.add_argument( 22 | 'audio_files', 23 | nargs='+', 24 | help='The audio file to process') 25 | 26 | # Optional arguments 27 | parser.add_argument( 28 | '--output_directory', 29 | required=False, 30 | help='The directory to save output') # Defaults to directory of each audio file 31 | parser.add_argument( 32 | '--save_periodicity', 33 | required=False, 34 | action='store_true', 35 | help='Whether save periodicity') 36 | parser.add_argument( 37 | '--format', 38 | required=False, 39 | default='csv', 40 | help='Saving format of the result (csv or npy)') # Combined .csv or separated .npy 41 | parser.add_argument( 42 | '--config', 43 | required=False, 44 | default='default', 45 | help='Customized configurations') 46 | 47 | return parser.parse_args() 48 | 49 | 50 | def load_config(config: str): 51 | """Load configurations""" 52 | config_file = config if config.endswith('.json') else f'{config}.json' 53 | path = os.path.join(os.path.dirname(__file__), 'configs', config_file) 54 | 55 | with open(path, 'r', encoding='utf-8') as cfg: 56 | config = json.load(cfg) 57 | 58 | return config 59 | 60 | 61 | def main(): 62 | # Parse command-line arguments 63 | args = parse_args() 64 | 65 | # Check saving format 66 | if args.format != 'csv' and args.format != 'npy': 67 | raise NotImplementedError('Saving format must be \'csv\' or \'npy\'.') 68 | 69 | # Ensure output directory exist 70 | if args.output_directory is not None: 71 | os.makedirs(args.output_directory, exist_ok=True) 72 | 73 | # Load configurations 74 | config = load_config(args.config) 75 | 76 | # Check model capacity 77 | if config['model'] not in ['full', 'large', 'medium', 'small', 'tiny']: 78 | raise NotImplementedError( 79 | 'Model capacity must be \'full\', \'large\', \'medium\', \'small\' or \'tiny\'.') 80 | 81 | # Get decoder 82 | if config['decoder'] == 'argmax': 83 | decoder = onnxcrepe.decode.argmax 84 | elif config['decoder'] == 'weighted_argmax': 85 | decoder = onnxcrepe.decode.weighted_argmax 86 | elif config['decoder'] == 'viterbi': 87 | decoder = onnxcrepe.decode.viterbi 88 | elif config['decoder'] == 'weighted_viterbi': 89 | decoder = onnxcrepe.decode.weighted_viterbi 90 | else: 91 | raise NotImplementedError('Decoder must be \'argmax\', \'weighted_argmax\', \'viterbi\' or \'weighted_viterbi\'.') 92 | 93 | # Filter and parse providers 94 | available_providers_selected = [] 95 | for provider in config['providers']: 96 | if provider['name'] in ort.get_available_providers(): 97 | available_providers_selected.append(provider) 98 | else: 99 | print(f'{provider["name"]} is not available on this machine. Skipping.') 100 | 101 | if not available_providers_selected: 102 | raise NotImplementedError('None of the selected execution providers is available on this machine.') 103 | providers = [(provider['name'], provider['options']) for provider in available_providers_selected] 104 | 105 | # Create session options 106 | # DirectML does not support memory pattern optimizations or parallel execution in onnxruntime. See 107 | # https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html#configuration-options 108 | options = ort.SessionOptions() 109 | if available_providers_selected[0]['name'] == 'DmlExecutionProvider': 110 | options.enable_mem_pattern = False 111 | options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL 112 | 113 | # Create inference session 114 | session = CrepeInferenceSession(model=config['model'], sess_options=options, providers=providers) 115 | 116 | # Infer pitch and save to disk 117 | onnxcrepe.predict_from_files_to_files(session, 118 | args.audio_files, 119 | args.output_directory, 120 | args.save_periodicity, 121 | args.format, 122 | config['precision'], 123 | config['fmin'], 124 | config['fmax'] if config['fmax'] is not None else onnxcrepe.MAX_FMAX, 125 | decoder, 126 | config['batch_size'], 127 | config['pad']) 128 | 129 | 130 | # Run module entry point 131 | if __name__ == '__main__': 132 | main() 133 | -------------------------------------------------------------------------------- /onnxcrepe/core.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import librosa 4 | import numpy as np 5 | import tqdm 6 | 7 | import onnxcrepe 8 | 9 | __all__ = ['CENTS_PER_BIN', 10 | 'MAX_FMAX', 11 | 'PITCH_BINS', 12 | 'SAMPLE_RATE', 13 | 'WINDOW_SIZE', 14 | 'UNVOICED', 15 | 'predict', 16 | 'predict_from_file', 17 | 'predict_from_file_to_file', 18 | 'predict_from_files_to_files', 19 | 'preprocess', 20 | 'infer', 21 | 'postprocess', 22 | 'resample'] 23 | 24 | ############################################################################### 25 | # Constants 26 | ############################################################################### 27 | 28 | 29 | CENTS_PER_BIN = 20 # cents 30 | MAX_FMAX = 2006. # hz 31 | PITCH_BINS = 360 32 | SAMPLE_RATE = 16000 # hz 33 | WINDOW_SIZE = 1024 # samples 34 | UNVOICED = np.nan 35 | 36 | 37 | ############################################################################### 38 | # Crepe pitch prediction 39 | ############################################################################### 40 | 41 | 42 | def predict(session, 43 | audio, 44 | sample_rate, 45 | precision=None, 46 | fmin=50., 47 | fmax=MAX_FMAX, 48 | decoder=onnxcrepe.decode.weighted_viterbi, 49 | return_periodicity=False, 50 | batch_size=None, 51 | pad=True): 52 | """Performs pitch estimation 53 | 54 | Arguments 55 | session (onnxcrepe.CrepeInferenceSession) 56 | An onnxruntime.InferenceSession holding the CREPE model 57 | audio (numpy.ndarray [shape=(n_samples,)]) 58 | The audio signal 59 | sample_rate (int) 60 | The sampling rate in Hz 61 | precision (float) 62 | The precision in milliseconds, i.e. the length of each frame 63 | fmin (float) 64 | The minimum allowable frequency in Hz 65 | fmax (float) 66 | The maximum allowable frequency in Hz 67 | decoder (function) 68 | The decoder to use. See decode.py for decoders. 69 | return_periodicity (bool) 70 | Whether to also return the network confidence 71 | batch_size (int) 72 | The number of frames per batch 73 | pad (bool) 74 | Whether to zero-pad the audio 75 | 76 | Returns 77 | pitch (numpy.ndarray [shape=(1, 1 + int(time // precision))]) 78 | (Optional) periodicity (numpy.ndarray 79 | [shape=(1, 1 + int(time // precision))]) 80 | """ 81 | 82 | results = [] 83 | 84 | # Preprocess audio 85 | generator = preprocess(audio, 86 | sample_rate, 87 | precision, 88 | batch_size, 89 | pad) 90 | for frames in generator: 91 | 92 | # Infer independent probabilities for each pitch bin 93 | probabilities = infer(session, frames) # shape=(batch, 360) 94 | 95 | probabilities = probabilities.transpose(1, 0)[None] # shape=(1, 360, batch) 96 | 97 | # Convert probabilities to F0 and periodicity 98 | result = postprocess(probabilities, 99 | fmin, 100 | fmax, 101 | decoder, 102 | return_periodicity) 103 | 104 | # Place on same device as audio to allow very long inputs 105 | if isinstance(result, tuple): 106 | result = (result[0], result[1]) 107 | 108 | results.append(result) 109 | 110 | # Split pitch and periodicity 111 | if return_periodicity: 112 | pitch, periodicity = zip(*results) 113 | return np.concatenate(pitch, axis=1), np.concatenate(periodicity, axis=1) 114 | 115 | # Concatenate 116 | return np.concatenate(results, axis=1) 117 | 118 | 119 | def predict_from_file(session, 120 | audio_file, 121 | precision=None, 122 | fmin=50., 123 | fmax=MAX_FMAX, 124 | decoder=onnxcrepe.decode.weighted_viterbi, 125 | return_periodicity=False, 126 | batch_size=None, 127 | pad=True): 128 | """Performs pitch estimation from file on disk 129 | 130 | Arguments 131 | session (onnxcrepe.CrepeInferenceSession) 132 | An onnxruntime.InferenceSession holding the CREPE model 133 | audio_file (string) 134 | The file to perform pitch tracking on 135 | precision (float) 136 | The precision in milliseconds, i.e. the length of each frame 137 | fmin (float) 138 | The minimum allowable frequency in Hz 139 | fmax (float) 140 | The maximum allowable frequency in Hz 141 | decoder (function) 142 | The decoder to use. See decode.py for decoders. 143 | return_harmonicity (bool) [DEPRECATED] 144 | Whether to also return the network confidence 145 | return_periodicity (bool) 146 | Whether to also return the network confidence 147 | batch_size (int) 148 | The number of frames per batch 149 | device (string) 150 | The device used to run inference 151 | pad (bool) 152 | Whether to zero-pad the audio 153 | 154 | Returns 155 | pitch (numpy.ndarray [shape=(1, 1 + int(time // precision))]) 156 | (Optional) periodicity (torch.tensor 157 | [shape=(1, 1 + int(time // precision))]) 158 | """ 159 | # Load audio 160 | audio, sample_rate = onnxcrepe.load.audio(audio_file) 161 | 162 | # Predict 163 | return predict(session, audio, sample_rate, precision, fmin, fmax, decoder, return_periodicity, batch_size, pad) 164 | 165 | 166 | def predict_from_file_to_file(session, 167 | audio_file, 168 | output_directory=None, 169 | save_periodicity=False, 170 | format='csv', 171 | precision=None, 172 | fmin=50., 173 | fmax=MAX_FMAX, 174 | decoder=onnxcrepe.decode.weighted_viterbi, 175 | batch_size=None, 176 | pad=True): 177 | """Performs pitch estimation from file on disk 178 | 179 | Arguments 180 | session (onnxcrepe.CrepeInferenceSession) 181 | An onnxruntime.InferenceSession holding the CREPE model 182 | audio_file (string) 183 | The file to perform pitch tracking on 184 | output_directory (string or None) 185 | The directory to save results. 186 | None means saving results in the same directory as the audio file. 187 | save_periodicity (bool) 188 | Whether save predicted periodicity 189 | format (string) 190 | The output format. 'csv' means combined csv file and 191 | 'npy' means separated npy files (pitch and periodicity). 192 | precision (float) 193 | The precision in milliseconds, i.e. the length of each frame 194 | fmin (float) 195 | The minimum allowable frequency in Hz 196 | fmax (float) 197 | The maximum allowable frequency in Hz 198 | model (string) 199 | The model capacity. One of 'full' or 'tiny'. 200 | decoder (function) 201 | The decoder to use. See decode.py for decoders. 202 | batch_size (int) 203 | The number of frames per batch 204 | device (string) 205 | The device used to run inference 206 | pad (bool) 207 | Whether to zero-pad the audio 208 | """ 209 | 210 | # Predict from file 211 | prediction = predict_from_file(session, audio_file, precision, fmin, fmax, decoder, 212 | save_periodicity, batch_size, pad) 213 | 214 | # Get audio filename without extension 215 | title = os.path.basename(audio_file).rsplit('.', maxsplit=1)[0] 216 | 217 | # Get output directory 218 | if output_directory is None: 219 | output_directory = os.path.dirname(audio_file) 220 | 221 | # Save to disk 222 | if format == 'csv': 223 | with open(os.path.join(output_directory, f'{title}.pitch.csv'), 'w') as f: 224 | if save_periodicity: 225 | for i in range(prediction[0].shape[1]): 226 | # time, f0, periodicity 227 | print('%f,%f,%f' 228 | % (i * precision / 1000., prediction[0][0][i], prediction[1][0][i]), 229 | file=f) 230 | else: 231 | for i in range(prediction.shape[1]): 232 | # time, f0 233 | print('%f,%f' 234 | % (i * precision / 1000., prediction[0][i]), 235 | file=f) 236 | elif format == 'npy': 237 | np.save(os.path.join(output_directory, f'{title}.f0.npy'), prediction[0]) 238 | if save_periodicity: 239 | np.save(os.path.join(output_directory, f'{title}.periodicity.npy'), prediction[1]) 240 | 241 | 242 | def predict_from_files_to_files(session, 243 | audio_files, 244 | output_directory=None, 245 | save_periodicity=False, 246 | format='csv', 247 | precision=None, 248 | fmin=50., 249 | fmax=MAX_FMAX, 250 | decoder=onnxcrepe.decode.weighted_viterbi, 251 | batch_size=None, 252 | pad=True): 253 | """Performs pitch estimation from files on disk without reloading model 254 | 255 | Arguments 256 | session (onnxcrepe.CrepeInferenceSession) 257 | An onnxruntime.InferenceSession holding the CREPE model 258 | audio_files (list[string]) 259 | The files to perform pitch tracking on 260 | output_directory (string or None) 261 | The directory to save results. 262 | None means saving results in the same directory as each audio file. 263 | save_periodicity (bool) 264 | Whether save predicted periodicity 265 | format (string) 266 | The output format. 'csv' means combined csv file and 267 | 'npy' means separated npy files (pitch and periodicity). 268 | precision (float) 269 | The precision in milliseconds, i.e. the length of each frame 270 | fmin (float) 271 | The minimum allowable frequency in Hz 272 | fmax (float) 273 | The maximum allowable frequency in Hz 274 | model (string) 275 | The model capacity. One of 'full' or 'tiny'. 276 | decoder (function) 277 | The decoder to use. See decode.py for decoders. 278 | batch_size (int) 279 | The number of frames per batch 280 | device (string) 281 | The device used to run inference 282 | pad (bool) 283 | Whether to zero-pad the audio 284 | """ 285 | 286 | # Setup iterator 287 | iterator = tqdm.tqdm(audio_files, desc='onnxcrepe', dynamic_ncols=True) 288 | for audio_file in iterator: 289 | # Predict a file 290 | predict_from_file_to_file(session, audio_file, output_directory, save_periodicity, format, precision, fmin, 291 | fmax, decoder, batch_size, pad) 292 | 293 | 294 | ############################################################################### 295 | # Components for step-by-step prediction 296 | ############################################################################### 297 | 298 | 299 | def preprocess(audio, 300 | sample_rate, 301 | precision=None, 302 | batch_size=None, 303 | pad=True): 304 | """Convert audio to model input 305 | 306 | Arguments 307 | audio (numpy.ndarray [shape=(time,)]) 308 | The audio signals 309 | sample_rate (int) 310 | The sampling rate in Hz 311 | precision (float) 312 | The precision in milliseconds, i.e. the length of each frame 313 | batch_size (int) 314 | The number of frames per batch 315 | pad (bool) 316 | Whether to zero-pad the audio 317 | 318 | Returns 319 | frames (numpy.ndarray [shape=(1 + int(time // precision), 1024)]) 320 | """ 321 | # Resample 322 | if sample_rate != SAMPLE_RATE: 323 | audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=SAMPLE_RATE) 324 | 325 | # Default hop length of 10 ms 326 | hop_length = SAMPLE_RATE / 100 if precision is None else SAMPLE_RATE * precision / 1000 327 | 328 | # Get total number of frames 329 | 330 | # Maybe pad 331 | if pad: 332 | total_frames = 1 + int(audio.shape[0] / hop_length) 333 | audio = np.pad( 334 | audio, 335 | (WINDOW_SIZE // 2, WINDOW_SIZE // 2)) 336 | else: 337 | total_frames = 1 + int((audio.shape[0] - WINDOW_SIZE) / hop_length) 338 | 339 | # Default to running all frames in a single batch 340 | batch_size = total_frames if batch_size is None else batch_size 341 | 342 | # Generate batches 343 | for i in range(0, total_frames, batch_size): 344 | # Batch indices 345 | start = max(0, int(i * hop_length)) 346 | end = min(audio.shape[0], 347 | int((i + batch_size - 1) * hop_length) + WINDOW_SIZE) 348 | 349 | # Chunk 350 | n_bytes = audio.strides[-1] 351 | frames = np.lib.stride_tricks.as_strided( 352 | audio[start:end], 353 | shape=((end - start - WINDOW_SIZE) // int(hop_length) + 1, WINDOW_SIZE), 354 | strides=(int(hop_length) * n_bytes, n_bytes)) # shape=(batch, 1024) 355 | 356 | # Note: 357 | # Z-score standardization operations originally located here 358 | # (https://github.com/maxrmorrison/torchcrepe/blob/master/torchcrepe/core.py#L692) 359 | # are wrapped into the ONNX models for hardware acceleration. 360 | 361 | yield frames 362 | 363 | 364 | def infer(session, frames): 365 | """Forward pass through the model 366 | 367 | Arguments 368 | session (onnxcrepe.CrepeInferenceSession) 369 | An onnxruntime.InferenceSession holding the CREPE model 370 | frames (numpy.ndarray [shape=(time / precision, 1024)]) 371 | The network input 372 | 373 | Returns 374 | logits (numpy.ndarray [shape=(1 + int(time // precision), 360)]) 375 | """ 376 | # Apply model 377 | return session.run(None, {'frames': frames})[0] 378 | 379 | 380 | def postprocess(probabilities, 381 | fmin=0., 382 | fmax=MAX_FMAX, 383 | decoder=onnxcrepe.decode.weighted_viterbi, 384 | return_periodicity=False): 385 | """Convert model output to F0 and periodicity 386 | 387 | Arguments 388 | probabilities (numpy.ndarray [shape=(1, 360, time / precision)]) 389 | The probabilities for each pitch bin inferred by the network 390 | fmin (float) 391 | The minimum allowable frequency in Hz 392 | fmax (float) 393 | The maximum allowable frequency in Hz 394 | decoder (function) 395 | The decoder to use. See decode.py for decoders. 396 | return_periodicity (bool) 397 | Whether to also return the network confidence 398 | 399 | Returns 400 | pitch (numpy.ndarray [shape=(1, 1 + int(time // precision))]) 401 | periodicity (numpy.ndarray [shape=(1, 1 + int(time // precision))]) 402 | """ 403 | # Convert frequency range to pitch bin range 404 | minidx = onnxcrepe.convert.frequency_to_bins(fmin) 405 | maxidx = onnxcrepe.convert.frequency_to_bins(fmax, np.ceil) 406 | 407 | # Remove frequencies outside allowable range 408 | probabilities[:, :minidx] = float('-inf') 409 | probabilities[:, maxidx:] = float('-inf') 410 | 411 | # Perform argmax or viterbi sampling 412 | bins, pitch = decoder(probabilities) 413 | 414 | if not return_periodicity: 415 | return pitch 416 | 417 | # Compute periodicity from probabilities and decoded pitch bins 418 | return pitch, periodicity(probabilities, bins) 419 | 420 | 421 | ############################################################################### 422 | # Utilities 423 | ############################################################################### 424 | 425 | 426 | def periodicity(probabilities, bins): 427 | """Computes the periodicity from the network output and pitch bins""" 428 | # shape=(time / precision, 360) 429 | probs_stacked = probabilities.transpose(0, 2, 1).reshape(-1, PITCH_BINS) 430 | # shape=(time / precision, 1) 431 | bins_stacked = bins.reshape(-1, 1).astype(np.int64) 432 | 433 | # Use maximum logit over pitch bins as periodicity 434 | periodicity = np.take_along_axis(probs_stacked, bins_stacked, axis=1) 435 | 436 | # shape=(batch, time / precision) 437 | return periodicity.reshape(probabilities.shape[0], probabilities.shape[2]) 438 | 439 | 440 | def resample(audio, sample_rate): 441 | """Resample audio""" 442 | return librosa.resample(audio, orig_sr=sample_rate, target_sr=onnxcrepe.SAMPLE_RATE) 443 | --------------------------------------------------------------------------------