├── eval └── .gitkeep ├── runs └── .gitkeep ├── config ├── .gitkeep ├── fcnf0++.py ├── fcnf0++-mdb.py ├── fcnf0++-ptdb.py ├── dio.py ├── fcnf0++-ablate-decoder.py ├── fcnf0++-ablate-chunkviterbi-normal.py ├── crepe++.py ├── deepf0++.py ├── fcnf0++-ablate-batchsize.py ├── fcnf0++-ablate-loss.py ├── pyin.py ├── fcnf0++-ablate-layernorm.py ├── fcnf0++-ablate-unvoiced.py ├── fcnf0++-ablate-chunkviterbi.py ├── fcnf0++-ablate-inputnorm.py ├── fcnf0++-ablate-quantization.py ├── fcnf0++-ablate-earlystop.py ├── deepf0.py ├── fcnf0.py ├── crepe.py └── torchcrepe.py ├── results └── .gitkeep ├── data ├── cache │ └── .gitkeep ├── datasets │ └── .gitkeep └── sources │ └── .gitkeep ├── penn ├── config │ ├── __init__.py │ ├── static.py │ └── defaults.py ├── train │ ├── __init__.py │ ├── __main__.py │ └── core.py ├── partition │ ├── __init__.py │ ├── __main__.py │ └── core.py ├── plot │ ├── density │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── core.py │ ├── logits │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── core.py │ ├── threshold │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── core.py │ └── __init__.py ├── data │ ├── download │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── core.py │ ├── preprocess │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── core.py │ ├── __init__.py │ ├── loader.py │ ├── sampler.py │ └── dataset.py ├── dsp │ ├── __init__.py │ ├── dio.py │ └── pyin.py ├── evaluate │ ├── __init__.py │ ├── __main__.py │ ├── metrics.py │ └── core.py ├── model │ ├── core.py │ ├── __init__.py │ ├── fcnf0.py │ ├── crepe.py │ └── deepf0.py ├── load.py ├── __init__.py ├── voicing.py ├── periodicity.py ├── __main__.py ├── convert.py ├── assets │ └── partitions │ │ └── mdb.json ├── decode.py └── core.py ├── test ├── assets │ ├── gershwin.wav │ ├── beethoven.wav │ └── 500Hz_stereo.wav ├── conftest.py ├── test_core.py └── test_convert.py ├── .gitignore ├── .github └── workflows │ └── run-tests.yml ├── LICENSE ├── setup.py ├── run.sh └── README.md /eval/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /runs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/cache/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/datasets/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/sources/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /penn/config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/fcnf0++.py: 
-------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | -------------------------------------------------------------------------------- /penn/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /penn/partition/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /penn/plot/density/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /penn/plot/logits/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /penn/data/download/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /penn/data/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /penn/plot/threshold/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /penn/dsp/__init__.py: -------------------------------------------------------------------------------- 1 | from . import dio 2 | from . import pyin 3 | -------------------------------------------------------------------------------- /penn/plot/__init__.py: -------------------------------------------------------------------------------- 1 | from . import density 2 | from . import logits 3 | from . import threshold 4 | -------------------------------------------------------------------------------- /test/assets/gershwin.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interactiveaudiolab/penn/HEAD/test/assets/gershwin.wav -------------------------------------------------------------------------------- /test/assets/beethoven.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interactiveaudiolab/penn/HEAD/test/assets/beethoven.wav -------------------------------------------------------------------------------- /test/assets/500Hz_stereo.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interactiveaudiolab/penn/HEAD/test/assets/500Hz_stereo.wav -------------------------------------------------------------------------------- /penn/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from . import metrics 2 | from .core import * 3 | from .metrics import Metrics, PitchMetrics 4 | -------------------------------------------------------------------------------- /penn/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import download 2 | from . 
import preprocess 3 | from .dataset import Dataset 4 | from .loader import loader 5 | from .sampler import sampler 6 | -------------------------------------------------------------------------------- /config/fcnf0++-mdb.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'fcnf0++-mdb' 5 | 6 | # The decoder to use for postprocessing 7 | DECODER = 'argmax' 8 | 9 | # Whether to perform local expected value decoding of pitch 10 | LOCAL_EXPECTED_VALUE = False 11 | -------------------------------------------------------------------------------- /config/fcnf0++-ptdb.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'fcnf0++-ptdb' 5 | 6 | # The decoder to use for postprocessing 7 | DECODER = 'argmax' 8 | 9 | # Whether to perform local expected value decoding of pitch 10 | LOCAL_EXPECTED_VALUE = False 11 | -------------------------------------------------------------------------------- /config/dio.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'dio' 5 | 6 | # Distance between adjacent frames 7 | HOPSIZE = 160 # samples 8 | 9 | # The pitch estimation method to use 10 | METHOD = 'dio' 11 | 12 | # Audio sample rate 13 | SAMPLE_RATE = 16000 # hz 14 | -------------------------------------------------------------------------------- /config/fcnf0++-ablate-decoder.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'fcnf0++-ablate-decoder' 5 | 6 | # Whether to perform local expected value decoding of pitch 7 | LOCAL_EXPECTED_VALUE = False 8 | 9 | # The decoder to use for postprocessing 10 | DECODER = 'argmax' 11 | -------------------------------------------------------------------------------- /config/fcnf0++-ablate-chunkviterbi-normal.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'fcnf0++-ablate-chunkviterbi-normal' 5 | 6 | # The decoder to use for postprocessing 7 | DECODER = 'viterbi' 8 | 9 | # Maximum chunk size for chunked Viterbi decoding 10 | VITERBI_MIN_CHUNK_SIZE = 64 11 | -------------------------------------------------------------------------------- /penn/model/core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import penn 4 | 5 | 6 | class Flatten(torch.nn.Module): 7 | 8 | def forward(self, x): 9 | return x.reshape(x.shape[0], -1) 10 | 11 | 12 | class Normalize(torch.nn.Module): 13 | 14 | def forward(self, frames): 15 | return penn.normalize(frames) 16 | -------------------------------------------------------------------------------- /config/crepe++.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'crepe++' 5 | 6 | # The decoder to use for postprocessing 7 | DECODER = 'argmax' 8 | 9 | # Whether to perform local expected value decoding of pitch 10 | LOCAL_EXPECTED_VALUE = False 11 | 12 | # The name of the model to use for training 13 | MODEL = 'crepe' 14 | -------------------------------------------------------------------------------- /config/deepf0++.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 
| 3 | # Configuration name 4 | CONFIG = 'deepf0++' 5 | 6 | # The decoder to use for postprocessing 7 | DECODER = 'argmax' 8 | 9 | # Whether to perform local expected value decoding of pitch 10 | LOCAL_EXPECTED_VALUE = False 11 | 12 | # The name of the model to use for training 13 | MODEL = 'deepf0' 14 | -------------------------------------------------------------------------------- /config/fcnf0++-ablate-batchsize.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'fcnf0++-ablate-batchsize' 5 | 6 | # Batch size 7 | BATCH_SIZE = 32 8 | 9 | # The decoder to use for postprocessing 10 | DECODER = 'argmax' 11 | 12 | # Whether to perform local expected value decoding of pitch 13 | LOCAL_EXPECTED_VALUE = False 14 | -------------------------------------------------------------------------------- /config/fcnf0++-ablate-loss.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'fcnf0++-ablate-loss' 5 | 6 | # The decoder to use for postprocessing 7 | DECODER = 'argmax' 8 | 9 | # Whether to perform local expected value decoding of pitch 10 | LOCAL_EXPECTED_VALUE = False 11 | 12 | # Loss function 13 | LOSS = 'binary_cross_entropy' 14 | -------------------------------------------------------------------------------- /config/pyin.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'pyin' 5 | 6 | # The decoder to use for postprocessing 7 | DECODER = 'pyin' 8 | 9 | # Distance between adjacent frames 10 | HOPSIZE = 160 # samples 11 | 12 | # The pitch estimation method to use 13 | METHOD = 'pyin' 14 | 15 | # Audio sample rate 16 | SAMPLE_RATE = 16000 # hz 17 | -------------------------------------------------------------------------------- /config/fcnf0++-ablate-layernorm.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'fcnf0++-ablate-layernorm' 5 | 6 | # The decoder to use for postprocessing 7 | DECODER = 'argmax' 8 | 9 | # Whether to perform local expected value decoding of pitch 10 | LOCAL_EXPECTED_VALUE = False 11 | 12 | # Type of model normalization 13 | NORMALIZATION = 'batch' 14 | -------------------------------------------------------------------------------- /config/fcnf0++-ablate-unvoiced.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'fcnf0++-ablate-unvoiced' 5 | 6 | # The decoder to use for postprocessing 7 | DECODER = 'argmax' 8 | 9 | # Whether to perform local expected value decoding of pitch 10 | LOCAL_EXPECTED_VALUE = False 11 | 12 | # Whether to only use voiced start frames 13 | VOICED_ONLY = True 14 | 15 | -------------------------------------------------------------------------------- /config/fcnf0++-ablate-chunkviterbi.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'fcnf0++-ablate-chunkviterbi' 5 | 6 | # The decoder to use for postprocessing 7 | DECODER = 'viterbi' 8 | 9 | # Whether to perform local expected value decoding of pitch 10 | LOCAL_EXPECTED_VALUE = False 11 | 12 | # Maximum chunk size for chunked Viterbi decoding 13 | VITERBI_MIN_CHUNK_SIZE = 8 14 | 
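An illustrative sketch (not a file in this repository) of the 'argmax' decoding that these configurations select, written with the conversion utilities from penn/convert.py; the project's actual decoders live in penn/decode.py, which is not shown in this section, so treat this as an illustration of the idea rather than the project's implementation, and the logits shape noted below is inferred from penn/plot/logits/core.py.

import penn

def argmax_decode(logits):
    """Convert a pitch posteriorgram to frequency via per-frame argmax

    logits: (frames, penn.PITCH_BINS, 1) network output
    """
    # Most likely pitch bin for each frame
    bins = logits.argmax(dim=1)

    # Map bin indices to frequency in Hz
    return penn.convert.bins_to_frequency(bins)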
-------------------------------------------------------------------------------- /config/fcnf0++-ablate-inputnorm.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'fcnf0++-ablate-inputnorm' 5 | 6 | # The decoder to use for postprocessing 7 | DECODER = 'argmax' 8 | 9 | # Whether to perform local expected value decoding of pitch 10 | LOCAL_EXPECTED_VALUE = False 11 | 12 | # Whether to normalize input audio to mean zero and variance one 13 | NORMALIZE_INPUT = True 14 | -------------------------------------------------------------------------------- /config/fcnf0++-ablate-quantization.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'fcnf0++-ablate-quantization' 5 | 6 | # Width of a pitch bin 7 | CENTS_PER_BIN = 12.5 # cents 8 | 9 | # The decoder to use for postprocessing 10 | DECODER = 'argmax' 11 | 12 | # Whether to perform local expected value decoding of pitch 13 | LOCAL_EXPECTED_VALUE = False 14 | 15 | # Number of pitch bins to predict 16 | PITCH_BINS = 486 17 | -------------------------------------------------------------------------------- /penn/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | from .crepe import Crepe 3 | from .deepf0 import Deepf0 4 | from .fcnf0 import Fcnf0 5 | 6 | import penn 7 | 8 | 9 | def Model(name=penn.MODEL): 10 | """Create a model""" 11 | if name == 'crepe': 12 | return Crepe() 13 | if name == 'deepf0': 14 | return Deepf0() 15 | if name == 'fcnf0': 16 | return Fcnf0() 17 | raise ValueError(f'Model {name} is not defined') 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Directories for large, project-specific files 2 | data/cache/* 3 | !data/cache/.gitkeep 4 | data/datasets/* 5 | !data/datasets/.gitkeep 6 | data/sources/* 7 | !data/sources/.gitkeep 8 | eval/* 9 | !eval/.gitkeep 10 | penn/assets/checkpoints/* 11 | !penn/assets/checkpoints/.gitkeep 12 | results/* 13 | !results/.gitkeep 14 | runs/* 15 | !runs/.gitkeep 16 | 17 | __pycache__/ 18 | .DS_Store 19 | ._.DS_Store 20 | .ipynb_checkpoints/ 21 | .vscode/ 22 | *.egg-info/ 23 | dist/ 24 | -------------------------------------------------------------------------------- /config/fcnf0++-ablate-earlystop.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'fcnf0++-ablate-earlystop' 5 | 6 | # The decoder to use for postprocessing 7 | DECODER = 'argmax' 8 | 9 | # Whether to stop training when validation loss stops improving 10 | EARLY_STOPPING = True 11 | 12 | # Whether to perform local expected value decoding of pitch 13 | LOCAL_EXPECTED_VALUE = False 14 | 15 | # Number of steps between logging to Tensorboard 16 | LOG_INTERVAL = 500 # steps 17 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | import penn 6 | 7 | 8 | ############################################################################### 9 | # Testing fixtures 10 | ############################################################################### 11 | 12 | 13 | 
@pytest.fixture(scope='session') 14 | def audio(): 15 | """Retrieve the test audio""" 16 | return penn.load.audio(Path(__file__).parent / 'assets' / 'gershwin.wav') 17 | 18 | 19 | @pytest.fixture(scope='session') 20 | def audio_stereo(): 21 | """Retrieve the test audio""" 22 | return penn.load.audio(Path(__file__).parent / 'assets' / '500Hz_stereo.wav') 23 | -------------------------------------------------------------------------------- /penn/data/download/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import penn 4 | 5 | 6 | ############################################################################### 7 | # Download datasets 8 | ############################################################################### 9 | 10 | 11 | def parse_args(): 12 | """Parse command-line arguments""" 13 | parser = argparse.ArgumentParser(description='Download datasets') 14 | parser.add_argument( 15 | '--datasets', 16 | nargs='+', 17 | default=penn.DATASETS, 18 | help='The datasets to download') 19 | return parser.parse_args() 20 | 21 | 22 | penn.data.download.datasets(**vars(parse_args())) 23 | -------------------------------------------------------------------------------- /penn/partition/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import penn 4 | 5 | 6 | ############################################################################### 7 | # Partition datasets 8 | ############################################################################### 9 | 10 | 11 | def parse_args(): 12 | """Parse command-line arguments""" 13 | parser = argparse.ArgumentParser(description='Partition datasets') 14 | parser.add_argument( 15 | '--datasets', 16 | nargs='+', 17 | default=['mdb', 'ptdb'], 18 | help='The datasets to partition') 19 | return parser.parse_args() 20 | 21 | 22 | penn.partition.datasets(**vars(parse_args())) 23 | -------------------------------------------------------------------------------- /penn/load.py: -------------------------------------------------------------------------------- 1 | import json 2 | import warnings 3 | 4 | import torchaudio 5 | 6 | import penn 7 | 8 | 9 | def audio(file): 10 | """Load audio from disk""" 11 | audio, sample_rate = torchaudio.load(file) 12 | 13 | # If audio is stereo, convert to mono 14 | if audio.size(0) == 2: 15 | warnings.warn(f'Converting stereo audio to mono: {file}') 16 | audio = audio.mean(dim=0, keepdim=True) 17 | 18 | # Maybe resample 19 | return penn.resample(audio, sample_rate) 20 | 21 | 22 | def partition(dataset): 23 | """Load partitions for dataset""" 24 | with open(penn.PARTITION_DIR / f'{dataset}.json') as file: 25 | return json.load(file) 26 | -------------------------------------------------------------------------------- /penn/data/preprocess/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import penn 4 | 5 | 6 | ############################################################################### 7 | # Preprocess datasets 8 | ############################################################################### 9 | 10 | 11 | def parse_args(): 12 | """Parse command-line arguments""" 13 | parser = argparse.ArgumentParser(description='Preprocess datasets') 14 | parser.add_argument( 15 | '--datasets', 16 | nargs='+', 17 | default=['mdb', 'ptdb'], 18 | help='The datasets to preprocess') 19 | return parser.parse_known_args()[0] 20 | 21 | 22 | 
penn.data.preprocess.datasets(**vars(parse_args())) 23 | -------------------------------------------------------------------------------- /penn/data/loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import penn 4 | 5 | 6 | def loader(datasets, partition, hparam_search=False): 7 | """Retrieve a data loader""" 8 | # Create dataset 9 | dataset = penn.data.Dataset(datasets, partition, hparam_search) 10 | 11 | # Create sampler 12 | sampler = penn.data.sampler(dataset, partition) 13 | 14 | # Get batch size 15 | if partition == 'test' or (partition == 'valid' and hparam_search): 16 | batch_size = 1 17 | elif partition in ['train', 'valid']: 18 | batch_size = penn.BATCH_SIZE 19 | else: 20 | raise ValueError(f'Partition {partition} is not defined') 21 | 22 | # Create data loader 23 | return torch.utils.data.DataLoader( 24 | dataset=dataset, 25 | batch_size=batch_size, 26 | num_workers=penn.NUM_WORKERS, 27 | pin_memory=True, 28 | sampler=sampler) 29 | -------------------------------------------------------------------------------- /test/test_core.py: -------------------------------------------------------------------------------- 1 | import penn 2 | 3 | 4 | ############################################################################### 5 | # Test core.py 6 | ############################################################################### 7 | 8 | 9 | def test_infer(audio): 10 | """Test that inference produces the correct shape""" 11 | pitch, periodicity = penn.from_audio( 12 | audio, 13 | penn.SAMPLE_RATE, 14 | center='half-hop') 15 | shape = (1, audio.shape[1] // penn.HOPSIZE) 16 | assert pitch.shape == periodicity.shape == shape 17 | 18 | 19 | def test_infer_stereo(audio_stereo): 20 | """Test that inference on stereo audio produces the correct shape""" 21 | pitch, periodicity = penn.from_audio( 22 | audio_stereo, 23 | penn.SAMPLE_RATE, 24 | center='half-hop') 25 | shape = (1, audio_stereo.shape[1] // penn.HOPSIZE) 26 | assert pitch.shape == periodicity.shape == shape 27 | -------------------------------------------------------------------------------- /.github/workflows/run-tests.yml: -------------------------------------------------------------------------------- 1 | name: run 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.8", "3.10"] 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Set up Python ${{ matrix.python-version }} 15 | uses: actions/setup-python@v5 16 | with: 17 | python-version: ${{ matrix.python-version }} 18 | - name: Install base dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install -e . 
22 | pip install soundfile pytest 23 | - name: Run core tests 24 | run: | 25 | python -m pytest --ignore test/test_convert.py 26 | - name: Install additional test dependencies 27 | run: | 28 | pip install -e '.[test]' 29 | - name: Run all tests 30 | run: | 31 | python -m pytest 32 | -------------------------------------------------------------------------------- /penn/evaluate/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import penn 5 | 6 | 7 | ############################################################################### 8 | # Evaluate pitch and periodicity estimation 9 | ############################################################################### 10 | 11 | 12 | def parse_args(): 13 | """Parse command-line arguments""" 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | '--datasets', 17 | nargs='+', 18 | default=penn.EVALUATION_DATASETS, 19 | help='The datasets to evaluate on') 20 | parser.add_argument( 21 | '--checkpoint', 22 | type=Path, 23 | help='The checkpoint file to evaluate') 24 | parser.add_argument( 25 | '--gpu', 26 | type=int, 27 | help='The index of the GPU to use for evaluation') 28 | 29 | return parser.parse_known_args()[0] 30 | 31 | 32 | penn.evaluate.datasets(**vars(parse_args())) 33 | -------------------------------------------------------------------------------- /config/deepf0.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'deepf0' 5 | 6 | # Batch size 7 | BATCH_SIZE = 32 8 | 9 | # Width of a pitch bin 10 | CENTS_PER_BIN = 20. # cents 11 | 12 | # The decoder to use for postprocessing 13 | DECODER = 'argmax' 14 | 15 | # Whether to stop training when validation loss stops improving 16 | EARLY_STOPPING = True 17 | 18 | # Distance between adjacent frames 19 | HOPSIZE = 160 # samples 20 | 21 | # Number of steps between logging to Tensorboard 22 | LOG_INTERVAL = 500 # steps 23 | 24 | # Loss function 25 | LOSS = 'binary_cross_entropy' 26 | 27 | # The name of the model to use for training 28 | MODEL = 'deepf0' 29 | 30 | # Type of model normalization 31 | NORMALIZATION = 'weight' 32 | 33 | # Whether to peak-normalize CREPE input audio 34 | NORMALIZE_INPUT = True 35 | 36 | # Number of pitch bins to predict 37 | PITCH_BINS = 360 38 | 39 | # Audio sample rate 40 | SAMPLE_RATE = 16000 # hz 41 | 42 | # Whether to only use voiced start frames 43 | VOICED_ONLY = True 44 | -------------------------------------------------------------------------------- /config/fcnf0.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'fcnf0' 5 | 6 | # Batch size 7 | BATCH_SIZE = 32 8 | 9 | # Width of a pitch bin 10 | CENTS_PER_BIN = 12.5 # cents 11 | 12 | # The decoder to use for postprocessing 13 | DECODER = 'argmax' 14 | 15 | # Whether to stop training when validation loss stops improving 16 | EARLY_STOPPING = True 17 | 18 | # Minimum representable frequency 19 | FMIN = 30. 
# Hz 20 | 21 | # Whether to perform local expected value decoding of pitch 22 | LOCAL_EXPECTED_VALUE = False 23 | 24 | # Number of steps between logging to Tensorboard 25 | LOG_INTERVAL = 500 # steps 26 | 27 | # Loss function 28 | LOSS = 'binary_cross_entropy' 29 | 30 | # The name of the model to use for training 31 | MODEL = 'fcnf0' 32 | 33 | # Whether to peak-normalize CREPE input audio 34 | NORMALIZE_INPUT = True 35 | 36 | # Type of model normalization 37 | NORMALIZATION = 'batch' 38 | 39 | # Number of pitch bins to predict 40 | PITCH_BINS = 486 41 | 42 | # Whether to only use voiced start frames 43 | VOICED_ONLY = True 44 | -------------------------------------------------------------------------------- /config/crepe.py: -------------------------------------------------------------------------------- 1 | MODULE = 'penn' 2 | 3 | # Configuration name 4 | CONFIG = 'crepe' 5 | 6 | # Batch size 7 | BATCH_SIZE = 32 8 | 9 | # Width of a pitch bin 10 | CENTS_PER_BIN = 20. # cents 11 | 12 | # The decoder to use for postprocessing 13 | DECODER = 'argmax' 14 | 15 | # The dropout rate. Set to None to turn off dropout. 16 | DROPOUT = .25 17 | 18 | # Whether to stop training when validation loss stops improving 19 | EARLY_STOPPING = True 20 | 21 | # Distance between adjacent frames 22 | HOPSIZE = 160 # samples 23 | 24 | # Number of steps between logging to Tensorboard 25 | LOG_INTERVAL = 500 # steps 26 | 27 | # Loss function 28 | LOSS = 'binary_cross_entropy' 29 | 30 | # The name of the model to use for training 31 | MODEL = 'crepe' 32 | 33 | # Type of model normalization 34 | NORMALIZATION = 'batch' 35 | 36 | # Whether to peak-normalize CREPE input audio 37 | NORMALIZE_INPUT = True 38 | 39 | # Number of pitch bins to predict 40 | PITCH_BINS = 360 41 | 42 | # Audio sample rate 43 | SAMPLE_RATE = 16000 # hz 44 | 45 | # Whether to only use voiced start frames 46 | VOICED_ONLY = True 47 | -------------------------------------------------------------------------------- /penn/plot/threshold/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import penn 5 | 6 | 7 | ############################################################################### 8 | # Periodicity threshold figure 9 | ############################################################################### 10 | 11 | 12 | def parse_args(): 13 | """Parse command-line arguments""" 14 | parser = argparse.ArgumentParser( 15 | description='Create periodicity threshold figure') 16 | parser.add_argument( 17 | '--names', 18 | required=True, 19 | nargs='+', 20 | help='Corresponding labels for each evaluation') 21 | parser.add_argument( 22 | '--evaluations', 23 | type=Path, 24 | required=True, 25 | nargs='+', 26 | help='The evaluations to plot') 27 | parser.add_argument( 28 | '--output_file', 29 | type=Path, 30 | required=True, 31 | help='The output jpg file') 32 | return parser.parse_known_args()[0] 33 | 34 | 35 | penn.plot.threshold.from_evaluations(**vars(parse_args())) 36 | -------------------------------------------------------------------------------- /penn/__init__.py: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | # - interpolate unvoiced 3 | 4 | 5 | ############################################################################### 6 | # Configuration 7 | ############################################################################### 8 | 9 | 10 | # Default configuration parameters to be modified 11 | from 
.config import defaults 12 | 13 | # Modify configuration 14 | import yapecs 15 | yapecs.configure('penn', defaults) 16 | 17 | # Import configuration parameters 18 | from .config.defaults import * 19 | from .config.static import * 20 | 21 | 22 | ############################################################################### 23 | # Module imports 24 | ############################################################################### 25 | 26 | 27 | from .core import * 28 | from .model import Model 29 | from .train import loss, train 30 | from . import convert 31 | from . import data 32 | from . import decode 33 | from . import dsp 34 | from . import evaluate 35 | from . import load 36 | from . import partition 37 | from . import periodicity 38 | from . import plot 39 | from . import train 40 | from . import voicing 41 | -------------------------------------------------------------------------------- /penn/partition/core.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | import penn 5 | 6 | 7 | ############################################################################### 8 | # Dataset-specific 9 | ############################################################################### 10 | 11 | 12 | def datasets(datasets): 13 | """Partition datasets""" 14 | for name in datasets: 15 | dataset(name) 16 | 17 | 18 | def dataset(name): 19 | """Partition dataset""" 20 | # Get dataset stems 21 | stems = sorted([ 22 | file.stem[:-6] for file in 23 | (penn.CACHE_DIR / name).glob('*-audio.npy')]) 24 | random.seed(penn.RANDOM_SEED) 25 | random.shuffle(stems) 26 | 27 | # Get split points 28 | left, right = int(.70 * len(stems)), int(.85 * len(stems)) 29 | 30 | # Perform partition 31 | partition = { 32 | 'train': sorted(stems[:left]), 33 | 'valid': sorted(stems[left:right]), 34 | 'test': sorted(stems[right:])} 35 | 36 | # Write partition file 37 | with open(penn.PARTITION_DIR / f'{name}.json', 'w') as file: 38 | json.dump(partition, file, indent=4) 39 | -------------------------------------------------------------------------------- /penn/plot/density/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import penn 5 | 6 | 7 | ############################################################################### 8 | # Create figure 9 | ############################################################################### 10 | 11 | 12 | def parse_args(): 13 | """Parse command-line arguments""" 14 | parser = argparse.ArgumentParser(description='Create density figure') 15 | parser.add_argument( 16 | '--datasets', 17 | nargs='+', 18 | required=True, 19 | help='Datasets to use for density figure') 20 | parser.add_argument( 21 | '--output_file', 22 | required=True, 23 | type=Path, 24 | help='The jpg file to save the plot') 25 | parser.add_argument( 26 | '--checkpoint', 27 | type=Path, 28 | help='The checkpoint file to use for inference') 29 | parser.add_argument( 30 | '--gpu', 31 | type=int, 32 | help='The index of the GPU to use for inference') 33 | return parser.parse_known_args()[0] 34 | 35 | penn.plot.density.to_file(**vars(parse_args())) 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Interactive Audio Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 
| of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /penn/plot/logits/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import penn 5 | 6 | 7 | ############################################################################### 8 | # Create figure 9 | ############################################################################### 10 | 11 | 12 | def parse_args(): 13 | """Parse command-line arguments""" 14 | parser = argparse.ArgumentParser(description='Create logits figure') 15 | parser.add_argument( 16 | '--audio_file', 17 | required=True, 18 | type=Path, 19 | help='The audio file to plot the logits of') 20 | parser.add_argument( 21 | '--output_file', 22 | required=True, 23 | type=Path, 24 | help='The jpg file to save the plot') 25 | parser.add_argument( 26 | '--checkpoint', 27 | type=Path, 28 | help='The checkpoint file to use for inference') 29 | parser.add_argument( 30 | '--gpu', 31 | type=int, 32 | help='The index of the GPU to use for inference') 33 | return parser.parse_known_args()[0] 34 | 35 | 36 | penn.plot.logits.from_file_to_file(**vars(parse_args())) 37 | -------------------------------------------------------------------------------- /penn/config/static.py: -------------------------------------------------------------------------------- 1 | """Config parameters whose values depend on other config parameters""" 2 | import penn 3 | 4 | 5 | ############################################################################### 6 | # Audio parameters 7 | ############################################################################### 8 | 9 | 10 | # Maximum representable frequency 11 | FMAX = \ 12 | penn.FMIN * 2 ** (penn.PITCH_BINS * penn.CENTS_PER_BIN / penn.OCTAVE) 13 | 14 | # Hopsize in seconds 15 | HOPSIZE_SECONDS = penn.HOPSIZE / penn.SAMPLE_RATE 16 | 17 | 18 | ############################################################################### 19 | # Directories 20 | ############################################################################### 21 | 22 | 23 | # Location to save dataset partitions 24 | PARTITION_DIR = penn.ASSETS_DIR / 'partitions' 25 | 26 | 27 | ############################################################################### 28 | # Training parameters 29 | ############################################################################### 30 | 31 | 32 | # Number of samples used during training 33 | NUM_TRAINING_SAMPLES = \ 34 | 
(penn.NUM_TRAINING_FRAMES - 1) * penn.HOPSIZE + penn.WINDOW_SIZE 35 | -------------------------------------------------------------------------------- /penn/voicing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import penn 4 | 5 | 6 | ############################################################################### 7 | # Voiced/unvoiced 8 | ############################################################################### 9 | 10 | 11 | def interpolate(pitch, periodicity, value): 12 | """Fill unvoiced regions via linear interpolation""" 13 | # Threshold periodicity 14 | voiced = threshold(periodicity, value) 15 | 16 | # Handle no voiced frames 17 | if not voiced.any(): 18 | return pitch 19 | 20 | # Pitch is linear in base-2 log-space 21 | pitch = torch.log2(pitch) 22 | 23 | # Anchor endpoints 24 | pitch[..., 0] = pitch[voiced][..., 0] 25 | pitch[..., -1] = pitch[voiced][..., -1] 26 | voiced[..., 0] = True 27 | voiced[..., -1] = True 28 | 29 | # Interpolate 30 | pitch[~voiced] = penn.interpolate( 31 | torch.where(~voiced[0])[0][None], 32 | torch.where(voiced[0])[0][None], 33 | pitch[voiced][None]) 34 | 35 | return 2 ** pitch 36 | 37 | 38 | def threshold(periodicity, value): 39 | """Threshold periodicity to produce voiced/unvoiced classifications""" 40 | return periodicity > value 41 | -------------------------------------------------------------------------------- /test/test_convert.py: -------------------------------------------------------------------------------- 1 | import penn 2 | 3 | import librosa 4 | import torch 5 | 6 | 7 | ############################################################################### 8 | # Test convert.py 9 | ############################################################################### 10 | 11 | 12 | def test_convert_frequency_to_midi(): 13 | """Test that conversion from Hz to MIDI matches librosa implementation""" 14 | sample_data = torch.tensor([110.0, 220.0, 440.0, 500.0, 880.0]) 15 | 16 | # Convert 17 | penn_midi = penn.convert.frequency_to_midi(sample_data) 18 | librosa_midi = librosa.hz_to_midi(sample_data.numpy()) 19 | 20 | # Compare 21 | assert torch.allclose( 22 | penn_midi, 23 | torch.tensor(librosa_midi, dtype=torch.float32)) 24 | 25 | 26 | def test_convert_midi_to_frequency(): 27 | """Test that conversion from MIDI to Hz matches librosa implementation""" 28 | sample_data = torch.tensor([45.0, 57.0, 69.0, 71.2131, 81.0]) 29 | 30 | # Convert 31 | penn_frequency = penn.convert.midi_to_frequency(sample_data) 32 | librosa_frequency = librosa.midi_to_hz(sample_data.numpy()) 33 | 34 | # Compare 35 | assert torch.allclose( 36 | penn_frequency, 37 | torch.tensor(librosa_frequency, dtype=torch.float32)) 38 | -------------------------------------------------------------------------------- /penn/periodicity.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | import penn 6 | 7 | 8 | ############################################################################### 9 | # Methods for extracting a periodicity estimate from pitch posteriorgram logits 10 | ############################################################################### 11 | 12 | 13 | def entropy(logits): 14 | """Entropy-based periodicity""" 15 | distribution = torch.nn.functional.softmax(logits, dim=1) 16 | return ( 17 | 1 + 1 / math.log(penn.PITCH_BINS) * \ 18 | (distribution * torch.log(distribution + 1e-7)).sum(dim=1)) 19 | 20 | 21 | def max(logits): 22 | 
"""Periodicity as the maximum confidence""" 23 | if penn.LOSS == 'binary_cross_entropy': 24 | return torch.sigmoid(logits).max(dim=1).values 25 | elif penn.LOSS == 'categorical_cross_entropy': 26 | return torch.nn.functional.softmax( 27 | logits, dim=1).max(dim=1).values 28 | raise ValueError(f'Loss function {penn.LOSS} is not implemented') 29 | 30 | 31 | def sum(logits): 32 | """Periodicity as the sum of the distribution 33 | 34 | This is really just for PYIN, which performs a masking of the distribution 35 | probabilities so that it does not always add to one. 36 | """ 37 | return torch.clip(torch.exp(logits).sum(dim=1), 0, 1) 38 | -------------------------------------------------------------------------------- /penn/data/download/core.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import torchutil 4 | 5 | import penn 6 | 7 | 8 | ############################################################################### 9 | # Download datasets 10 | ############################################################################### 11 | 12 | 13 | @torchutil.notify('download') 14 | def datasets(datasets): 15 | """Download datasets""" 16 | if 'mdb' in datasets: 17 | mdb() 18 | 19 | if 'ptdb' in datasets: 20 | ptdb() 21 | 22 | 23 | ############################################################################### 24 | # Individual datasets 25 | ############################################################################### 26 | 27 | 28 | def mdb(): 29 | """Download mdb dataset""" 30 | torchutil.download.targz( 31 | 'https://zenodo.org/record/1481172/files/MDB-stem-synth.tar.gz', 32 | penn.DATA_DIR) 33 | 34 | # Delete previous directory 35 | shutil.rmtree(penn.DATA_DIR / 'mdb', ignore_errors=True) 36 | 37 | # Rename directory 38 | shutil.move(penn.DATA_DIR / 'MDB-stem-synth', penn.DATA_DIR / 'mdb') 39 | 40 | 41 | def ptdb(): 42 | """Download ptdb dataset""" 43 | directory = penn.DATA_DIR / 'ptdb' 44 | directory.mkdir(exist_ok=True, parents=True) 45 | torchutil.download.zip( 46 | 'https://www2.spsc.tugraz.at/databases/PTDB-TUG/SPEECH_DATA_ZIPPED.zip', 47 | directory) 48 | -------------------------------------------------------------------------------- /penn/train/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | from pathlib import Path 4 | 5 | import torchutil 6 | 7 | import penn 8 | 9 | 10 | ############################################################################### 11 | # Entry point 12 | ############################################################################### 13 | 14 | 15 | def main(config, datasets, gpu): 16 | # Create output directory 17 | directory = penn.RUNS_DIR / config.stem 18 | directory.mkdir(parents=True, exist_ok=True) 19 | 20 | # Save configuration 21 | shutil.copyfile(config, directory / config.name) 22 | 23 | # Train 24 | penn.train(datasets, directory, gpu) 25 | 26 | # Get latest checkpoint 27 | checkpoint = torchutil.checkpoint.latest_path(directory) 28 | 29 | # Evaluate 30 | penn.evaluate.datasets(penn.EVALUATION_DATASETS, checkpoint, gpu) 31 | 32 | 33 | def parse_args(): 34 | """Parse command-line arguments""" 35 | parser = argparse.ArgumentParser(description='Train a model') 36 | parser.add_argument( 37 | '--config', 38 | type=Path, 39 | required=True, 40 | help='The configuration file') 41 | parser.add_argument( 42 | '--datasets', 43 | nargs='+', 44 | default=penn.DATASETS, 45 | help='The datasets to train on') 46 | 
parser.add_argument( 47 | '--gpu', 48 | type=int, 49 | help='The GPU index') 50 | return parser.parse_args() 51 | 52 | 53 | main(**vars(parse_args())) 54 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | with open('README.md') as file: 5 | long_description = file.read() 6 | 7 | 8 | setup( 9 | name='penn', 10 | description='Pitch Estimating Neural Networks (PENN)', 11 | version='1.0.0', 12 | author='Max Morrison, Caedon Hsieh, Nathan Pruyne, and Bryan Pardo', 13 | author_email='interactiveaudiolab@gmail.com', 14 | url='https://github.com/interactiveaudiolab/penn', 15 | extras_require={ 16 | 'train': [ 17 | 'librosa', # 0.9.1 18 | 'matplotlib', # 3.6.1 19 | 'pyworld', # 0.3.2 20 | 'scipy', # 1.9.3 21 | 'torchcrepe' # 0.0.17 22 | ], 23 | 'test': [ 24 | 'librosa', # 0.9.1 25 | 'pytest', # 8.2.2 26 | ] 27 | }, 28 | install_requires=[ 29 | 'huggingface_hub', # 0.11.1 30 | 'numpy', # 1.23.4 31 | 'torbi', # 0.0.1 32 | 'torch', # 1.12.1+cu113 33 | 'torchaudio', # 0.12.1+cu113 34 | 'torchutil', # 0.0.7 35 | 'yapecs' # 0.0.6 36 | ], 37 | packages=find_packages(), 38 | package_data={'penn': ['assets/*', 'assets/*/*']}, 39 | long_description=long_description, 40 | long_description_content_type='text/markdown', 41 | keywords=['audio', 'frequency', 'music', 'periodicity', 'pitch', 'speech'], 42 | classifiers=['License :: OSI Approved :: MIT License'], 43 | license='MIT') 44 | -------------------------------------------------------------------------------- /penn/data/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import penn 4 | 5 | 6 | ############################################################################### 7 | # Batch sampler 8 | ############################################################################### 9 | 10 | 11 | def sampler(dataset, partition): 12 | """Create batch index sampler""" 13 | # Get sampler indices 14 | indices = ( 15 | dataset.voiced_indices() if penn.VOICED_ONLY and partition == 'train' 16 | else list(range(len(dataset)))) 17 | 18 | # Maybe use distributed sampler for training 19 | if partition == 'train': 20 | return Sampler(indices) 21 | 22 | # Possibly deterministic random sampler for validation 23 | elif partition == 'valid': 24 | return Sampler(indices) 25 | 26 | # Sample test data sequentially 27 | elif partition == 'test': 28 | return torch.utils.data.SequentialSampler(dataset) 29 | 30 | else: 31 | raise ValueError(f'Partition {partition} is not implemented') 32 | 33 | 34 | ############################################################################### 35 | # Custom samplers 36 | ############################################################################### 37 | 38 | 39 | class Sampler: 40 | 41 | def __init__(self, indices): 42 | self.indices = indices 43 | self.epoch = 0 44 | 45 | def __iter__(self): 46 | generator = torch.Generator() 47 | generator.manual_seed(penn.RANDOM_SEED + self.epoch) 48 | for i in torch.randperm(len(self.indices), generator=generator): 49 | yield self.indices[i] 50 | 51 | def __len__(self): 52 | return len(self.indices) 53 | 54 | def set_epoch(self, epoch): 55 | self.epoch = epoch 56 | -------------------------------------------------------------------------------- /config/torchcrepe.py: -------------------------------------------------------------------------------- 1 | import torchcrepe 2 
| 3 | MODULE = 'penn' 4 | 5 | # Configuration name 6 | # Note - We're not actually training torchcrepe. We only use this for 7 | # evaluation, and only use the FMIN in order to precisely align 8 | # predictions with ground truth pitch bins. The other arguments are 9 | # for completeness. The public crepe (and torchcrepe) model was trained 10 | # on a set of six datasets, five of which are not considered in the 11 | # current project. 12 | CONFIG = 'torchcrepe' 13 | 14 | # Batch size 15 | BATCH_SIZE = 32 16 | 17 | # Width of a pitch bin 18 | CENTS_PER_BIN = 20. # cents 19 | 20 | # The decoder to use for postprocessing 21 | DECODER = 'argmax' 22 | 23 | # The dropout rate. Set to None to turn off dropout. 24 | DROPOUT = .25 25 | 26 | # Whether to stop training when validation loss stops improving 27 | EARLY_STOPPING = True 28 | 29 | # Exactly align pitch bins 30 | FMIN = torchcrepe.convert.cents_to_frequency(1997.3794084376191) 31 | 32 | # Distance between adjacent frames 33 | HOPSIZE = 160 # samples 34 | 35 | # Whether to perform local expected value decoding of pitch 36 | LOCAL_EXPECTED_VALUE = False 37 | 38 | # Number of steps between logging to Tensorboard 39 | LOG_INTERVAL = 500 # steps 40 | 41 | # Loss function 42 | LOSS = 'binary_cross_entropy' 43 | 44 | # The pitch estimation method to use 45 | METHOD = 'torchcrepe' 46 | 47 | # The name of the model to use for training 48 | MODEL = 'crepe' 49 | 50 | # Type of model normalization 51 | NORMALIZATION = 'batch' 52 | 53 | # Whether to peak-normalize CREPE input audio 54 | NORMALIZE_INPUT = True 55 | 56 | # Number of pitch bins to predict 57 | PITCH_BINS = 360 58 | 59 | # Audio sample rate 60 | SAMPLE_RATE = 16000 # hz 61 | 62 | # Whether to only use voiced start frames 63 | VOICED_ONLY = True 64 | -------------------------------------------------------------------------------- /penn/model/fcnf0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import penn 4 | 5 | 6 | class Fcnf0(torch.nn.Sequential): 7 | 8 | def __init__(self): 9 | layers = (penn.model.Normalize(),) if penn.NORMALIZE_INPUT else () 10 | layers += ( 11 | Block(1, 256, 481, (2, 2)), 12 | Block(256, 32, 225, (2, 2)), 13 | Block(32, 32, 97, (2, 2)), 14 | Block(32, 128, 66), 15 | Block(128, 256, 35), 16 | Block(256, 512, 4), 17 | torch.nn.Conv1d(512, penn.PITCH_BINS, 4)) 18 | super().__init__(*layers) 19 | 20 | def forward(self, frames): 21 | # shape=(batch, 1, penn.WINDOW_SIZE) => 22 | # shape=(batch, penn.PITCH_BINS, penn.NUM_TRAINING_FRAMES) 23 | return super().forward(frames[:, :, 16:-15]) 24 | 25 | 26 | class Block(torch.nn.Sequential): 27 | 28 | def __init__( 29 | self, 30 | in_channels, 31 | out_channels, 32 | length=1, 33 | pooling=None, 34 | kernel_size=32): 35 | layers = ( 36 | torch.nn.Conv1d(in_channels, out_channels, kernel_size), 37 | torch.nn.ReLU()) 38 | 39 | # Maybe add pooling 40 | if pooling is not None: 41 | layers += (torch.nn.MaxPool1d(*pooling),) 42 | 43 | # Maybe add normalization 44 | if penn.NORMALIZATION == 'batch': 45 | layers += (torch.nn.BatchNorm1d(out_channels, momentum=.01),) 46 | elif penn.NORMALIZATION == 'instance': 47 | layers += (torch.nn.InstanceNorm1d(out_channels),) 48 | elif penn.NORMALIZATION == 'layer': 49 | layers += (torch.nn.LayerNorm((out_channels, length)),) 50 | else: 51 | raise ValueError( 52 | f'Normalization method {penn.NORMALIZATION} is not defined') 53 | 54 | # Maybe add dropout 55 | if penn.DROPOUT is not None: 56 | layers += (torch.nn.Dropout(penn.DROPOUT),) 57 | 
58 | super().__init__(*layers) 59 | -------------------------------------------------------------------------------- /penn/plot/threshold/core.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | 4 | import penn 5 | 6 | 7 | ############################################################################### 8 | # Create figure 9 | ############################################################################### 10 | 11 | 12 | def from_evaluations(names, evaluations, output_file): 13 | """Plot periodicity thresholds""" 14 | import matplotlib.pyplot as plt 15 | 16 | # Create plot 17 | figure, axis = plt.subplots(figsize=(7, 3)) 18 | 19 | # Make pretty 20 | axis.spines['top'].set_visible(False) 21 | axis.spines['right'].set_visible(False) 22 | axis.spines['bottom'].set_visible(False) 23 | axis.spines['left'].set_visible(False) 24 | ticks = [0., .25, .5, .75, 1.] 25 | axis.set_xlim([0., 1.]) 26 | axis.get_xaxis().set_ticks(ticks) 27 | axis.get_yaxis().set_ticks(ticks) 28 | axis.tick_params(axis=u'both', which=u'both',length=0) 29 | axis.set_xlabel('Unvoiced threshold') 30 | axis.set_ylabel('F1') 31 | for tick in ticks: 32 | axis.axhline(tick, color='gray', linestyle='--', linewidth=.8) 33 | 34 | # Iterate over evaluations to plot 35 | for name, evaluation in zip(names, evaluations): 36 | directory = penn.EVAL_DIR / evaluation 37 | 38 | # Load results 39 | with open(directory / 'overall.json') as file: 40 | results = json.load(file)['aggregate'] 41 | with open(directory / 'periodicity.json') as file: 42 | optimal = json.load(file)['entropy'] 43 | 44 | # Get thresholds and corresponding F1 values 45 | x, y = zip(* 46 | [(key, val) for key, val in results.items() if key.startswith('f1')]) 47 | x = [float(item[3:]) for item in x] + [1] 48 | y = [0 if math.isnan(item) else item for item in y] + [0] 49 | 50 | # Plot 51 | line = axis.plot(x, y, label=name) 52 | color = line[0].get_color() 53 | axis.plot(optimal['threshold'], optimal['f1'], marker='*', color=color) 54 | 55 | # Add legend 56 | axis.legend(frameon=False, loc='upper right') 57 | 58 | # Save 59 | figure.savefig(output_file, bbox_inches='tight', pad_inches=0, dpi=300) 60 | -------------------------------------------------------------------------------- /penn/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import penn 4 | from pathlib import Path 5 | 6 | 7 | ############################################################################### 8 | # Entry point 9 | ############################################################################### 10 | 11 | 12 | def parse_args(): 13 | """Parse command-line arguments""" 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | '--files', 17 | nargs='+', 18 | required=True, 19 | type=Path, 20 | help='The audio files to process') 21 | parser.add_argument( 22 | '--output_prefixes', 23 | nargs='+', 24 | type=Path, 25 | help=( 26 | 'The files to save pitch and periodicity without extension. ' 27 | 'Defaults to audio_files without extensions.')) 28 | parser.add_argument( 29 | '--hopsize', 30 | type=float, 31 | default=penn.HOPSIZE_SECONDS, 32 | help=( 33 | 'The hopsize in seconds. ' 34 | f'Defaults to {penn.HOPSIZE_SECONDS} seconds.')) 35 | parser.add_argument( 36 | '--fmin', 37 | type=float, 38 | default=penn.FMIN, 39 | help=( 40 | 'The minimum frequency allowed in Hz. 
' 41 | f'Defaults to {penn.FMIN} Hz.')) 42 | parser.add_argument( 43 | '--fmax', 44 | type=float, 45 | default=penn.FMAX, 46 | help=( 47 | 'The maximum frequency allowed in Hz. ' 48 | f'Defaults to {penn.FMAX} Hz.')) 49 | parser.add_argument( 50 | '--checkpoint', 51 | type=Path, 52 | help=( 53 | 'The model checkpoint file. ' 54 | f'Defaults to pretrained FCNF0++.')) 55 | parser.add_argument( 56 | '--batch_size', 57 | type=int, 58 | default=penn.EVALUATION_BATCH_SIZE, 59 | help=( 60 | 'The number of frames per batch. ' 61 | f'Defaults to {penn.EVALUATION_BATCH_SIZE}.')) 62 | parser.add_argument( 63 | '--center', 64 | choices=['half-window', 'half-hop', 'zero'], 65 | default='half-window', 66 | help='Padding options') 67 | parser.add_argument( 68 | '--decoder', 69 | choices=['argmax', 'pyin', 'viterbi'], 70 | default=penn.DECODER, 71 | help='Posteriorgram decoder') 72 | parser.add_argument( 73 | '--interp_unvoiced_at', 74 | type=float, 75 | help='Specifies voicing threshold for interpolation') 76 | parser.add_argument( 77 | '--num_workers', 78 | type=int, 79 | default=0, 80 | help='Number of CPU threads for async data I/O') 81 | parser.add_argument( 82 | '--gpu', 83 | type=int, 84 | help='The index of the gpu to perform inference on. Defaults to CPU.') 85 | 86 | return parser.parse_known_args()[0] 87 | 88 | 89 | penn.from_files_to_files(**vars(parse_args())) 90 | -------------------------------------------------------------------------------- /penn/model/crepe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import penn 4 | 5 | 6 | ############################################################################### 7 | # Crepe 8 | ############################################################################### 9 | 10 | 11 | class Crepe(torch.nn.Sequential): 12 | 13 | def __init__(self): 14 | super().__init__() 15 | in_channels = [1, 1024, 128, 128, 128, 256] 16 | out_channels = [1024, 128, 128, 128, 256, 512] 17 | kernels = [512] + 5 * [64] 18 | strides = [4] + 5 * [1] 19 | padding = [(254, 254)] + 5 * [(31, 32)] 20 | lengths = [256, 128, 64, 32, 16, 8] 21 | super().__init__(*( 22 | ([penn.model.Normalize()] if penn.NORMALIZE_INPUT else []) + 23 | [ 24 | Block(i, o, k, s, p, l) for i, o, k, s, p, l in 25 | zip( 26 | in_channels, 27 | out_channels, 28 | kernels, 29 | strides, 30 | padding, 31 | lengths 32 | ) 33 | ] + 34 | [ 35 | penn.model.Flatten(), 36 | torch.nn.Linear( 37 | in_features=2048, 38 | out_features=penn.PITCH_BINS) 39 | ] 40 | )) 41 | 42 | def forward(self, frames): 43 | # shape=(batch, 1, penn.WINDOW_SIZE) => 44 | # shape=(batch, penn.PITCH_BINS, penn.NUM_TRAINING_FRAMES) 45 | return super().forward(frames)[:, :, None] 46 | 47 | 48 | ############################################################################### 49 | # Utilities 50 | ############################################################################### 51 | 52 | 53 | class Block(torch.nn.Sequential): 54 | 55 | def __init__( 56 | self, 57 | in_channels, 58 | out_channels, 59 | kernel_size, 60 | stride, 61 | padding, 62 | length): 63 | layers = ( 64 | torch.nn.ConstantPad1d(padding, 0), 65 | torch.nn.Conv1d( 66 | in_channels, 67 | out_channels, 68 | kernel_size, 69 | stride), 70 | torch.nn.ReLU()) 71 | 72 | # Maybe add normalization 73 | if penn.NORMALIZATION == 'batch': 74 | layers += (torch.nn.BatchNorm1d(out_channels, momentum=.01),) 75 | elif penn.NORMALIZATION == 'instance': 76 | layers += (torch.nn.InstanceNorm1d(out_channels),) 77 | elif penn.NORMALIZATION 
== 'layer': 78 | layers += (torch.nn.LayerNorm((out_channels, length)),) 79 | else: 80 | raise ValueError( 81 | f'Normalization method {penn.NORMALIZATION} is not defined') 82 | 83 | # Add max pooling 84 | layers += (torch.nn.MaxPool1d(2, 2),) 85 | 86 | # Maybe add dropout 87 | if penn.DROPOUT is not None: 88 | layers += (torch.nn.Dropout(penn.DROPOUT),) 89 | 90 | super().__init__(*layers) 91 | -------------------------------------------------------------------------------- /penn/plot/logits/core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import penn 4 | 5 | 6 | ############################################################################### 7 | # Create figure 8 | ############################################################################### 9 | 10 | 11 | def from_audio( 12 | audio, 13 | sample_rate, 14 | checkpoint=None, 15 | gpu=None): 16 | """Plot logits with pitch overlay""" 17 | import matplotlib 18 | import matplotlib.pyplot as plt 19 | 20 | logits = [] 21 | 22 | # Change font size 23 | matplotlib.rcParams.update({'font.size': 16}) 24 | 25 | # Preprocess audio 26 | for frames, _ in penn.preprocess( 27 | audio, 28 | sample_rate, 29 | batch_size=penn.EVALUATION_BATCH_SIZE, 30 | center='half-hop' 31 | ): 32 | 33 | # Copy to device 34 | frames = frames.to('cpu' if gpu is None else f'cuda:{gpu}') 35 | 36 | # Infer 37 | logits.append(penn.infer(frames, checkpoint=checkpoint).detach()) 38 | 39 | # Concatenate results 40 | logits = torch.cat(logits) 41 | 42 | # Convert to distribution 43 | # NOTE - We use softmax even if the loss is BCE for more comparable 44 | # visualization. Otherwise, the variance of models trained with 45 | # BCE looks erroneously lower. 46 | distributions = torch.nn.functional.softmax(logits, dim=1) 47 | 48 | # Take the log again for display 49 | distributions = torch.log(distributions) 50 | distributions[torch.isinf(distributions)] = \ 51 | distributions[~torch.isinf(distributions)].min() 52 | 53 | # Prepare for plotting 54 | distributions = distributions.cpu().squeeze(2).T 55 | 56 | # Setup figure 57 | figure, axis = plt.subplots(figsize=(18, 2)) 58 | 59 | # Make pretty 60 | axis.spines['top'].set_visible(False) 61 | axis.spines['right'].set_visible(False) 62 | axis.spines['bottom'].set_visible(False) 63 | axis.spines['left'].set_visible(False) 64 | xticks = torch.arange(0, len(logits), int(penn.SAMPLE_RATE / penn.HOPSIZE)) 65 | xlabels = xticks // 100 66 | axis.get_xaxis().set_ticks(xticks.tolist(), xlabels.tolist()) 67 | yticks = torch.linspace(0, penn.PITCH_BINS - 1, 5) 68 | ylabels = penn.convert.bins_to_frequency(yticks) 69 | ylabels = ylabels.round().int().tolist() 70 | axis.get_yaxis().set_ticks(yticks, ylabels) 71 | axis.set_xlabel('Time (seconds)') 72 | axis.set_ylabel('Frequency (Hz)') 73 | 74 | # Plot pitch posteriorgram 75 | axis.imshow(distributions, aspect='auto', origin='lower') 76 | 77 | return figure 78 | 79 | 80 | def from_file(audio_file, checkpoint=None, gpu=None): 81 | """Plot logits and optional pitch""" 82 | # Load audio 83 | audio = penn.load.audio(audio_file) 84 | 85 | # Plot 86 | return from_audio(audio, penn.SAMPLE_RATE, checkpoint, gpu) 87 | 88 | 89 | def from_file_to_file(audio_file, output_file, checkpoint=None, gpu=None): 90 | """Plot pitch and periodicity and save to disk""" 91 | # Plot 92 | figure = from_file(audio_file, checkpoint, gpu) 93 | 94 | # Save to disk 95 | figure.savefig(output_file, bbox_inches='tight', pad_inches=0, dpi=900) 96 | 
-------------------------------------------------------------------------------- /penn/convert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import penn 4 | 5 | 6 | ############################################################################### 7 | # Pitch conversions 8 | ############################################################################### 9 | 10 | 11 | def bins_to_cents(bins): 12 | """Converts pitch bins to cents""" 13 | return penn.CENTS_PER_BIN * bins 14 | 15 | 16 | def bins_to_frequency(bins): 17 | """Converts pitch bins to frequency in Hz""" 18 | return cents_to_frequency(bins_to_cents(bins)) 19 | 20 | 21 | def cents_to_bins(cents, quantize_fn=torch.floor): 22 | """Converts cents to pitch bins""" 23 | bins = quantize_fn(cents / penn.CENTS_PER_BIN).long() 24 | bins[bins < 0] = 0 25 | bins[bins >= penn.PITCH_BINS] = penn.PITCH_BINS - 1 26 | return bins 27 | 28 | 29 | def cents_to_frequency(cents): 30 | """Converts cents to frequency in Hz""" 31 | return penn.FMIN * 2 ** (cents / penn.OCTAVE) 32 | 33 | 34 | def frequency_to_bins(frequency, quantize_fn=torch.floor): 35 | """Convert frequency in Hz to pitch bins""" 36 | return cents_to_bins(frequency_to_cents(frequency), quantize_fn) 37 | 38 | 39 | def frequency_to_cents(frequency): 40 | """Convert frequency in Hz to cents""" 41 | return penn.OCTAVE * torch.log2(frequency / penn.FMIN) 42 | 43 | 44 | def frequency_to_samples(frequency, sample_rate=penn.SAMPLE_RATE): 45 | """Convert frequency in Hz to number of samples per period""" 46 | return sample_rate / frequency 47 | 48 | 49 | def frequency_to_midi(frequency): 50 | """ 51 | Convert frequency to MIDI note number(s) 52 | Based on librosa.hz_to_midi(frequencies) implementation 53 | https://librosa.org/doc/main/_modules/librosa/core/convert.html#hz_to_midi 54 | """ 55 | return 12 * (torch.log2(frequency) - torch.log2(torch.tensor(440.0))) + 69 56 | 57 | 58 | def midi_to_frequency(midi): 59 | """ 60 | Convert MIDI note number to frequency 61 | Based on librosa.midi_to_hz(notes) implementation 62 | https://librosa.org/doc/main/_modules/librosa/core/convert.html#midi_to_hz 63 | """ 64 | return 440.0 * (2.0 ** ((midi - 69.0) / 12.0)) 65 | 66 | 67 | ############################################################################### 68 | # Time conversions 69 | ############################################################################### 70 | 71 | 72 | def frames_to_samples(frames): 73 | """Convert number of frames to samples""" 74 | return frames * penn.HOPSIZE 75 | 76 | 77 | def frames_to_seconds(frames): 78 | """Convert number of frames to seconds""" 79 | return frames * penn.HOPSIZE_SECONDS 80 | 81 | 82 | def seconds_to_frames(seconds): 83 | """Convert seconds to number of frames""" 84 | return samples_to_frames(seconds_to_samples(seconds)) 85 | 86 | 87 | def seconds_to_samples(seconds, sample_rate=penn.SAMPLE_RATE): 88 | """Convert seconds to number of samples""" 89 | return seconds * sample_rate 90 | 91 | 92 | def samples_to_frames(samples): 93 | """Convert samples to number of frames""" 94 | return samples // penn.HOPSIZE 95 | 96 | 97 | def samples_to_seconds(samples, sample_rate=penn.SAMPLE_RATE): 98 | """Convert number of samples to seconds""" 99 | return samples / sample_rate 100 | -------------------------------------------------------------------------------- /penn/model/deepf0.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import penn 4 | 5 | 6 | 
############################################################################### 7 | # DeepF0 model implementation 8 | ############################################################################### 9 | 10 | 11 | class Deepf0(torch.nn.Sequential): 12 | 13 | def __init__(self, channels=128, kernel_size=64): 14 | layers = (penn.model.Normalize(),) if penn.NORMALIZE_INPUT else () 15 | layers += ( 16 | CausalConv1d(1, channels, kernel_size), 17 | Block(channels, channels, kernel_size, 1), 18 | Block(channels, channels, kernel_size, 2), 19 | Block(channels, channels, kernel_size, 4), 20 | Block(channels, channels, kernel_size, 8), 21 | torch.nn.AvgPool1d(kernel_size), 22 | penn.model.Flatten(), 23 | torch.nn.Linear(2048, penn.PITCH_BINS)) 24 | super().__init__(*layers) 25 | 26 | def forward(self, frames): 27 | # shape=(batch, 1, penn.WINDOW_SIZE) => 28 | # shape=(batch, penn.PITCH_BINS, penn.NUM_TRAINING_FRAMES) 29 | return super().forward(frames)[:, :, None] 30 | 31 | 32 | ############################################################################### 33 | # Utilities 34 | ############################################################################### 35 | 36 | 37 | class Block(torch.nn.Sequential): 38 | 39 | def __init__( 40 | self, 41 | input_channels, 42 | output_channels, 43 | kernel_size, 44 | dilation): 45 | if penn.NORMALIZATION == 'weight': 46 | norm_conv = [ 47 | torch.nn.utils.weight_norm( 48 | torch.nn.Conv1d(output_channels, output_channels, 1))] 49 | elif penn.NORMALIZATION == 'layer': 50 | norm_conv = [ 51 | torch.nn.LayerNorm( 52 | (output_channels, penn.NUM_TRAINING_SAMPLES)), 53 | torch.nn.Conv1d(output_channels, output_channels, 1)] 54 | else: 55 | raise ValueError( 56 | f'Normalization method {penn.NORMALIZATION} is not defined') 57 | 58 | super().__init__(*( 59 | [ 60 | CausalConv1d( 61 | input_channels, 62 | output_channels, 63 | kernel_size, 64 | dilation=dilation), 65 | torch.nn.ReLU() 66 | ] + norm_conv)) 67 | 68 | def forward(self, x): 69 | return torch.nn.functional.relu(super().forward(x) + x) 70 | 71 | 72 | class CausalConv1d(torch.nn.Conv1d): 73 | 74 | def __init__( 75 | self, 76 | in_channels, 77 | out_channels, 78 | kernel_size, 79 | stride=1, 80 | dilation=1, 81 | groups=1, 82 | bias=True): 83 | self.pad = (kernel_size - 1) * dilation 84 | super().__init__( 85 | in_channels, 86 | out_channels, 87 | kernel_size=kernel_size, 88 | stride=stride, 89 | padding=self.pad, 90 | dilation=dilation, 91 | groups=groups, 92 | bias=bias) 93 | 94 | def forward(self, input): 95 | result = super().forward(input) 96 | if self.pad != 0: 97 | return result[:, :, :-self.pad] 98 | return result 99 | -------------------------------------------------------------------------------- /penn/dsp/dio.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import multiprocessing as mp 3 | 4 | import numpy as np 5 | import torch 6 | import torchutil 7 | 8 | import penn 9 | 10 | 11 | ############################################################################### 12 | # DIO (from pyworld) 13 | ############################################################################### 14 | 15 | 16 | def from_audio( 17 | audio, 18 | sample_rate=penn.SAMPLE_RATE, 19 | hopsize=penn.HOPSIZE_SECONDS, 20 | fmin=penn.FMIN, 21 | fmax=penn.FMAX): 22 | """Estimate pitch and periodicity with dio""" 23 | with torchutil.time.context('infer'): 24 | 25 | import pyworld 26 | 27 | # Convert to numpy (np.float was removed from NumPy; use the builtin float alias) 28 | audio = audio.numpy().squeeze().astype(float) 29 | 30 | # Get
pitch 31 | pitch, times = pyworld.dio( 32 | audio[penn.WINDOW_SIZE // 2:-penn.WINDOW_SIZE // 2], 33 | sample_rate, 34 | fmin, 35 | fmax, 36 | frame_period=1000 * hopsize) 37 | 38 | # Refine pitch 39 | pitch = pyworld.stonemask( 40 | audio, 41 | pitch, 42 | times, 43 | sample_rate) 44 | 45 | # Interpolate unvoiced tokens 46 | pitch, _ = penn.data.preprocess.interpolate_unvoiced(pitch) 47 | 48 | # Convert to torch 49 | return torch.from_numpy(pitch)[None] 50 | 51 | 52 | def from_file( 53 | file, 54 | hopsize=penn.HOPSIZE_SECONDS, 55 | fmin=penn.FMIN, 56 | fmax=penn.FMAX): 57 | """Estimate pitch and periodicity with dio from audio on disk""" 58 | # Load 59 | with torchutil.time.context('load'): 60 | audio = penn.load.audio(file) 61 | 62 | # Infer 63 | return from_audio(audio, penn.SAMPLE_RATE, hopsize, fmin, fmax) 64 | 65 | 66 | def from_file_to_file( 67 | file, 68 | output_prefix=None, 69 | hopsize=penn.HOPSIZE_SECONDS, 70 | fmin=penn.FMIN, 71 | fmax=penn.FMAX): 72 | """Estimate pitch and periodicity with dio and save to disk""" 73 | # Infer 74 | results = from_file(file, hopsize, fmin, fmax) 75 | 76 | # Save to disk 77 | with torchutil.time.context('save'): 78 | 79 | # Maybe use same filename with new extension 80 | if output_prefix is None: 81 | output_prefix = file.parent / file.stem 82 | 83 | # Save pitch 84 | torch.save(results[0], f'{output_prefix}-pitch.pt') 85 | 86 | # Maybe save periodicity 87 | if len(results) > 1: 88 | torch.save(results[1], f'{output_prefix}-periodicity.pt') 89 | 90 | 91 | def from_files_to_files( 92 | files, 93 | output_prefixes=None, 94 | hopsize=penn.HOPSIZE_SECONDS, 95 | fmin=penn.FMIN, 96 | fmax=penn.FMAX): 97 | """Estimate pitch and periodicity with dio and save to disk""" 98 | pitch_fn = functools.partial( 99 | from_file_to_file, 100 | hopsize=hopsize, 101 | fmin=fmin, 102 | fmax=fmax) 103 | iterator = zip(files, output_prefixes) 104 | 105 | # Turn off multiprocessing for benchmarking 106 | if penn.BENCHMARK: 107 | for item in torchutil.iterator( 108 | iterator, 109 | f'{penn.CONFIG}', 110 | total=len(files) 111 | ): 112 | pitch_fn(*item) 113 | else: 114 | with mp.get_context('spawn').Pool() as pool: 115 | pool.starmap(pitch_fn, iterator) 116 | -------------------------------------------------------------------------------- /penn/plot/density/core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import penn 4 | 5 | 6 | ############################################################################### 7 | # Constants 8 | ############################################################################### 9 | 10 | 11 | # Amount of bin downsampling 12 | DOWNSAMPLE_RATE = penn.PITCH_BINS // 90 13 | 14 | 15 | ############################################################################### 16 | # Plot dataset density vs true positive density 17 | ############################################################################### 18 | 19 | 20 | def to_file( 21 | datasets, 22 | output_file, 23 | checkpoint=None, 24 | gpu=None): 25 | """Plot ground truth and true positive densities""" 26 | import matplotlib 27 | import matplotlib.pyplot as plt 28 | matplotlib.rcParams.update({'font.size': 20}) 29 | figure, axis = plt.subplots() 30 | axis.set_axis_off() 31 | 32 | # Plot true data density 33 | x = torch.arange(0, penn.PITCH_BINS, DOWNSAMPLE_RATE) 34 | y_true, y_pred = histograms(datasets, checkpoint, gpu) 35 | y_true = y_true.reshape(-1, DOWNSAMPLE_RATE).sum(-1) 36 | axis.bar( 37 | x, 38 | y_true, 39 | width=1.05 * 
DOWNSAMPLE_RATE, 40 | label=f'Data distribution') 41 | 42 | # Plot our guesses 43 | y_pred = y_pred.reshape(-1, DOWNSAMPLE_RATE).sum(-1) 44 | axis.bar( 45 | x, 46 | y_pred, 47 | width=1.05 * DOWNSAMPLE_RATE, 48 | label='Inferred distribution') 49 | 50 | # Plot overlap 51 | overlap = torch.minimum(y_true, y_pred) 52 | axis.bar( 53 | x, 54 | overlap, 55 | color='gray', 56 | width=1.05 * DOWNSAMPLE_RATE, 57 | label='Overlap') 58 | 59 | # Add legend 60 | axis.legend(frameon=False, prop={'size': 10}) 61 | 62 | # Save plot 63 | figure.savefig(output_file, bbox_inches='tight', pad_inches=0, dpi=300) 64 | 65 | 66 | def histograms(datasets, checkpoint=None, gpu=None): 67 | """Get histogram of true positives from datasets and model checkpoint""" 68 | device = torch.device('cpu' if gpu is None else f'cuda:{gpu}') 69 | 70 | # Initialize counts 71 | true_result = torch.zeros((penn.PITCH_BINS,)) 72 | infer_result = torch.zeros((penn.PITCH_BINS,)) 73 | 74 | # Setup loader 75 | loader = penn.data.loader(datasets, 'test') 76 | 77 | # Update counts 78 | for audio, bins, pitch, voiced, _ in loader: 79 | 80 | # Preprocess audio 81 | batch_size = \ 82 | None if gpu is None else penn.EVALUATION_BATCH_SIZE 83 | for i, (frames, size) in enumerate( 84 | penn.preprocess( 85 | audio[0], 86 | penn.SAMPLE_RATE, 87 | batch_size=batch_size, 88 | center='half-hop' 89 | ) 90 | ): 91 | 92 | # Copy to device 93 | frames = frames.to(device) 94 | 95 | # Slice features and copy to GPU 96 | start = i * penn.EVALUATION_BATCH_SIZE 97 | end = start + size 98 | batch_bins = bins[:, start:end].to(device) 99 | batch_pitch = pitch[:, start:end].to(device) 100 | batch_voiced = voiced[:, start:end].to(device) 101 | 102 | # Infer 103 | batch_logits = penn.infer(frames, checkpoint).detach() 104 | 105 | # Get predicted bins 106 | batch_predicted, _, _ = penn.postprocess(batch_logits) 107 | 108 | # Get true positives 109 | true_all = batch_bins[batch_voiced] 110 | pred_all = batch_predicted[batch_voiced] 111 | 112 | # Update counts 113 | indices = torch.arange( 114 | penn.PITCH_BINS + 1, 115 | dtype=torch.float, 116 | device=device) 117 | true_result += torch.histogram(true_all.cpu().float(), indices.cpu())[0] 118 | infer_result += torch.histogram(pred_all.cpu().float(), indices.cpu())[0] 119 | 120 | return true_result, infer_result 121 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # Runs all experiments in the paper 2 | # "Cross-domain Neural Pitch and Periodicity Estimation" 3 | 4 | # Args 5 | # $1 - index of GPU to use 6 | 7 | # Download datasets 8 | python -m penn.data.download 9 | 10 | # Setup and run 16 kHz experiments 11 | python -m penn.data.preprocess --config config/crepe.py 12 | python -m penn.partition 13 | python -m penn.train --config config/crepe.py --gpu $1 14 | python -m penn.train --config config/deepf0.py --gpu $1 15 | 16 | # Evaluate baselines at 16 kHz 17 | python -m penn.evaluate --gpu $1 --config config/dio.py 18 | python -m penn.evaluate --gpu $1 --config config/pyin.py 19 | python -m penn.evaluate \ 20 | --gpu $1 \ 21 | --method torchcrepe \ 22 | --config config/torchcrepe.py 23 | 24 | # Setup 8 kHz data 25 | python -m penn.data.preprocess 26 | 27 | # Run 8 kHz experiments 28 | python -m penn.train --config config/crepe++.py --gpu $1 29 | python -m penn.train --config config/deepf0++.py --gpu $1 30 | python -m penn.train --config config/fcnf0.py --gpu $1 31 | python -m penn.train 
--config config/fcnf0++.py --gpu $1 32 | 33 | # Train on individual datasets 34 | python -m penn.train \ 35 | --config config/fcnf0++-mdb.py \ 36 | --datasets mdb \ 37 | --gpu $1 38 | python -m penn.train \ 39 | --config config/fcnf0++-ptdb.py \ 40 | --datasets ptdb \ 41 | --gpu $1 42 | 43 | # Train ablations 44 | python -m penn.train --config config/fcnf0++-ablate-batchsize.py --gpu $1 45 | python -m penn.train --config config/fcnf0++-ablate-earlystop.py --gpu $1 46 | python -m penn.train --config config/fcnf0++-ablate-inputnorm.py --gpu $1 47 | python -m penn.train --config config/fcnf0++-ablate-layernorm.py --gpu $1 48 | python -m penn.train --config config/fcnf0++-ablate-loss.py --gpu $1 49 | python -m penn.train --config config/fcnf0++-ablate-quantization.py --gpu $1 50 | python -m penn.train --config config/fcnf0++-ablate-unvoiced.py --gpu $1 51 | 52 | # Evaluate locally normal decoding 53 | python -m penn.evaluate \ 54 | --config config/fcnf0++-ablate-decoder.py \ 55 | --checkpoint runs/fcnf0++/00250000.pt \ 56 | --gpu $1 57 | 58 | # Plot data and inference distributions 59 | python -m penn.plot.density \ 60 | --datasets mdb \ 61 | --output_file results/mdb_on_mdb.pdf \ 62 | --checkpoint runs/fcnf0++-mdb/00250000.pt \ 63 | --gpu $1 64 | python -m penn.plot.density \ 65 | --datasets ptdb \ 66 | --output_file results/mdb_on_ptdb.pdf \ 67 | --checkpoint runs/fcnf0++-mdb/00250000.pt \ 68 | --gpu $1 69 | python -m penn.plot.density \ 70 | --datasets ptdb \ 71 | --output_file results/ptdb_on_ptdb.pdf \ 72 | --checkpoint runs/fcnf0++-ptdb/00250000.pt \ 73 | --gpu $1 74 | python -m penn.plot.density \ 75 | --datasets mdb \ 76 | --output_file results/ptdb_on_mdb.pdf \ 77 | --checkpoint runs/fcnf0++-ptdb/00250000.pt \ 78 | --gpu $1 79 | python -m penn.plot.density \ 80 | --datasets mdb \ 81 | --output_file results/both_on_mdb.pdf \ 82 | --checkpoint runs/fcnf0++/00250000.pt \ 83 | --gpu $1 84 | python -m penn.plot.density \ 85 | --datasets ptdb \ 86 | --output_file results/both_on_ptdb.pdf \ 87 | --checkpoint runs/fcnf0++/00250000.pt \ 88 | --gpu $1 89 | 90 | # Plot voiced/unvoiced threshold landscape 91 | python -m penn.plot.threshold \ 92 | --names FCNF0++ "FCNF0++ (voiced only)" \ 93 | --evaluations fcnf0++-ablate-decoder fcnf0++-ablate-unvoiced \ 94 | --output_file results/threshold.pdf 95 | 96 | # Plot pitch posteriorgram figures 97 | python -m penn.plot.logits \ 98 | --config config/fcnf0++.py \ 99 | --audio_file test/assets/gershwin.wav \ 100 | --output_file results/fcnf0++-gershwin.pdf \ 101 | --checkpoint runs/fcnf0++/00250000.pt \ 102 | --gpu $1 103 | # Note - You will need to replace this checkpoint with the final checkpoint 104 | # that was produced during fcnf0 training 105 | python -m penn.plot.logits \ 106 | --config config/fcnf0.py \ 107 | --audio_file test/assets/gershwin.wav \ 108 | --output_file results/fcnf0-gershwin.pdf \ 109 | --checkpoint runs/fcnf0/00071000.pt \ 110 | --gpu $1 111 | -------------------------------------------------------------------------------- /penn/config/defaults.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | 5 | ############################################################################### 6 | # Metadata 7 | ############################################################################### 8 | 9 | 10 | # Configuration name 11 | CONFIG = 'fcnf0++' 12 | 13 | 14 | ############################################################################### 15 | # Audio 
parameters 16 | ############################################################################### 17 | 18 | 19 | # Width of a pitch bin 20 | CENTS_PER_BIN = 5. # cents 21 | 22 | # Whether to trade quantization error for noise during inference 23 | DITHER = False 24 | 25 | # Minimum representable frequency 26 | FMIN = 31. # Hz 27 | 28 | # Distance between adjacent frames 29 | HOPSIZE = 80 # samples 30 | 31 | # Whether to normalize input audio to mean zero and variance one 32 | NORMALIZE_INPUT = False 33 | 34 | # Number of spectrogram frequency bins 35 | NUM_FFT = 1024 36 | 37 | # One octave in cents 38 | OCTAVE = 1200 # cents 39 | 40 | # Number of pitch bins to predict 41 | PITCH_BINS = 1440 42 | 43 | # Audio sample rate 44 | SAMPLE_RATE = 8000 # Hz 45 | 46 | # Size of the analysis window 47 | WINDOW_SIZE = 1024 # samples 48 | 49 | 50 | ############################################################################### 51 | # Decoder parameters 52 | ############################################################################### 53 | 54 | 55 | # The decoder to use for postprocessing. One of ['argmax', 'pyin', 'viterbi']. 56 | DECODER = 'viterbi' 57 | 58 | # Whether to perform local expected value decoding of pitch 59 | LOCAL_EXPECTED_VALUE = True 60 | 61 | # The size of the window used for local expected value pitch decoding 62 | LOCAL_PITCH_WINDOW_SIZE = 19 63 | 64 | # Pitch velocity constraint for viterbi decoding 65 | MAX_OCTAVES_PER_SECOND = 32. 66 | 67 | # Maximum chunk size for chunked Viterbi decoding 68 | VITERBI_MIN_CHUNK_SIZE = None 69 | 70 | 71 | ############################################################################### 72 | # Directories 73 | ############################################################################### 74 | 75 | 76 | # Location to save assets to be bundled with pip release 77 | ASSETS_DIR = Path(__file__).parent.parent / 'assets' 78 | 79 | # Location of preprocessed features 80 | CACHE_DIR = Path(__file__).parent.parent.parent / 'data' / 'cache' 81 | 82 | # Location of datasets on disk 83 | DATA_DIR = Path(__file__).parent.parent.parent / 'data' / 'datasets' 84 | 85 | # Location to save evaluation artifacts 86 | EVAL_DIR = Path(__file__).parent.parent.parent / 'eval' 87 | 88 | # Location to save training and adaptation artifacts 89 | RUNS_DIR = Path(__file__).parent.parent.parent / 'runs' 90 | 91 | # Location of compressed datasets on disk 92 | SOURCE_DIR = Path(__file__).parent.parent.parent / 'data' / 'sources' 93 | 94 | 95 | ############################################################################### 96 | # Evaluation parameters 97 | ############################################################################### 98 | 99 | 100 | # Whether to perform benchmarking 101 | BENCHMARK = False 102 | 103 | # Number of steps between saving checkpoints 104 | CHECKPOINT_INTERVAL = 25000 # steps 105 | 106 | # List of all datasets 107 | DATASETS = ['mdb', 'ptdb'] 108 | 109 | # Method to use for evaluation 110 | METHOD = 'penn' 111 | 112 | # Batch size to use for evaluation 113 | EVALUATION_BATCH_SIZE = 2048 114 | 115 | # Datasets to use for evaluation 116 | EVALUATION_DATASETS = DATASETS 117 | 118 | # Number of steps between logging to Tensorboard 119 | LOG_INTERVAL = 2500 # steps 120 | 121 | # Number of batches to use for validation 122 | LOG_STEPS = 64 123 | 124 | # Method to use for periodicity extraction 125 | PERIODICITY = 'entropy' 126 | 127 | 128 | ############################################################################### 129 | # Model parameters 130 |
############################################################################### 131 | 132 | 133 | # The dropout rate. Set to None to turn off dropout. 134 | DROPOUT = None 135 | 136 | # The name of the model to use for training 137 | MODEL = 'fcnf0' 138 | 139 | # Type of model normalization 140 | NORMALIZATION = 'layer' 141 | 142 | 143 | ############################################################################### 144 | # Training parameters 145 | ############################################################################### 146 | 147 | 148 | # Batch size 149 | BATCH_SIZE = 128 150 | 151 | # Whether to stop training when validation loss stops improving 152 | EARLY_STOPPING = False 153 | 154 | # Stop after this number of log intervals without validation improvements 155 | EARLY_STOPPING_STEPS = 32 156 | 157 | # Whether to apply Gaussian blur to binary cross-entropy loss targets 158 | GAUSSIAN_BLUR = True 159 | 160 | # Optimizer learning rate 161 | LEARNING_RATE = 2e-4 162 | 163 | # Loss function 164 | LOSS = 'categorical_cross_entropy' 165 | 166 | # Number of training steps 167 | STEPS = 250000 168 | 169 | # Number of frames used during training 170 | NUM_TRAINING_FRAMES = 1 171 | 172 | # Number of data loading worker threads 173 | NUM_WORKERS = os.cpu_count() // 4 174 | 175 | # Seed for all random number generators 176 | RANDOM_SEED = 1234 177 | 178 | # Whether to only use voiced start frames 179 | VOICED_ONLY = False 180 | -------------------------------------------------------------------------------- /penn/assets/partitions/mdb.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": [ 3 | "000000", 4 | "000002", 5 | "000005", 6 | "000006", 7 | "000009", 8 | "000010", 9 | "000012", 10 | "000014", 11 | "000015", 12 | "000019", 13 | "000020", 14 | "000024", 15 | "000026", 16 | "000027", 17 | "000028", 18 | "000030", 19 | "000031", 20 | "000032", 21 | "000033", 22 | "000034", 23 | "000035", 24 | "000037", 25 | "000039", 26 | "000040", 27 | "000041", 28 | "000043", 29 | "000044", 30 | "000045", 31 | "000047", 32 | "000048", 33 | "000049", 34 | "000050", 35 | "000052", 36 | "000053", 37 | "000054", 38 | "000055", 39 | "000056", 40 | "000057", 41 | "000058", 42 | "000059", 43 | "000061", 44 | "000062", 45 | "000064", 46 | "000065", 47 | "000066", 48 | "000067", 49 | "000070", 50 | "000071", 51 | "000072", 52 | "000073", 53 | "000074", 54 | "000075", 55 | "000076", 56 | "000078", 57 | "000079", 58 | "000080", 59 | "000081", 60 | "000082", 61 | "000083", 62 | "000084", 63 | "000086", 64 | "000087", 65 | "000089", 66 | "000091", 67 | "000092", 68 | "000093", 69 | "000094", 70 | "000095", 71 | "000097", 72 | "000098", 73 | "000099", 74 | "000100", 75 | "000101", 76 | "000102", 77 | "000103", 78 | "000104", 79 | "000105", 80 | "000106", 81 | "000107", 82 | "000108", 83 | "000109", 84 | "000110", 85 | "000111", 86 | "000113", 87 | "000114", 88 | "000115", 89 | "000117", 90 | "000120", 91 | "000122", 92 | "000125", 93 | "000126", 94 | "000127", 95 | "000129", 96 | "000132", 97 | "000133", 98 | "000134", 99 | "000136", 100 | "000137", 101 | "000139", 102 | "000140", 103 | "000141", 104 | "000143", 105 | "000144", 106 | "000145", 107 | "000146", 108 | "000147", 109 | "000150", 110 | "000151", 111 | "000152", 112 | "000154", 113 | "000155", 114 | "000156", 115 | "000158", 116 | "000160", 117 | "000161", 118 | "000162", 119 | "000163", 120 | "000164", 121 | "000166", 122 | "000168", 123 | "000169", 124 | "000173", 125 | "000174", 126 | "000175", 127 | 
"000176", 128 | "000178", 129 | "000179", 130 | "000180", 131 | "000181", 132 | "000183", 133 | "000184", 134 | "000185", 135 | "000186", 136 | "000187", 137 | "000188", 138 | "000189", 139 | "000190", 140 | "000191", 141 | "000192", 142 | "000193", 143 | "000197", 144 | "000198", 145 | "000200", 146 | "000202", 147 | "000203", 148 | "000204", 149 | "000205", 150 | "000208", 151 | "000210", 152 | "000212", 153 | "000213", 154 | "000215", 155 | "000218", 156 | "000219", 157 | "000220", 158 | "000221", 159 | "000222", 160 | "000223", 161 | "000224", 162 | "000228", 163 | "000229" 164 | ], 165 | "valid": [ 166 | "000011", 167 | "000013", 168 | "000017", 169 | "000018", 170 | "000022", 171 | "000036", 172 | "000042", 173 | "000051", 174 | "000068", 175 | "000069", 176 | "000077", 177 | "000085", 178 | "000096", 179 | "000116", 180 | "000119", 181 | "000121", 182 | "000130", 183 | "000131", 184 | "000135", 185 | "000138", 186 | "000142", 187 | "000148", 188 | "000153", 189 | "000167", 190 | "000170", 191 | "000172", 192 | "000194", 193 | "000195", 194 | "000207", 195 | "000209", 196 | "000211", 197 | "000216", 198 | "000217", 199 | "000226" 200 | ], 201 | "test": [ 202 | "000001", 203 | "000003", 204 | "000004", 205 | "000007", 206 | "000008", 207 | "000016", 208 | "000021", 209 | "000023", 210 | "000025", 211 | "000029", 212 | "000038", 213 | "000046", 214 | "000060", 215 | "000063", 216 | "000088", 217 | "000090", 218 | "000112", 219 | "000118", 220 | "000123", 221 | "000124", 222 | "000128", 223 | "000149", 224 | "000157", 225 | "000159", 226 | "000165", 227 | "000171", 228 | "000177", 229 | "000182", 230 | "000196", 231 | "000199", 232 | "000201", 233 | "000206", 234 | "000214", 235 | "000225", 236 | "000227" 237 | ] 238 | } -------------------------------------------------------------------------------- /penn/evaluate/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchutil 3 | 4 | import penn 5 | 6 | 7 | ############################################################################### 8 | # Constants 9 | ############################################################################### 10 | 11 | 12 | # Evaluation threshold for RPA and RCA 13 | THRESHOLD = 50 # cents 14 | 15 | 16 | ############################################################################### 17 | # Aggregate metric 18 | ############################################################################### 19 | 20 | 21 | class Metrics: 22 | 23 | def __init__(self): 24 | self.accuracy = torchutil.metrics.Accuracy() 25 | self.f1 = F1() 26 | self.loss = Loss() 27 | self.pitch_metrics = PitchMetrics() 28 | 29 | def __call__(self): 30 | return ( 31 | { 32 | 'accuracy': self.accuracy(), 33 | 'loss': self.loss() 34 | } | 35 | self.f1() | 36 | self.pitch_metrics()) 37 | 38 | def update(self, logits, bins, target, voiced): 39 | # Detach from graph 40 | logits = logits.detach() 41 | 42 | # Update loss 43 | self.loss.update(logits[:, :penn.PITCH_BINS], bins.T) 44 | 45 | # Decode bins, pitch, and periodicity 46 | with torchutil.time.context('decode'): 47 | predicted, pitch, periodicity = penn.postprocess(logits) 48 | 49 | # Update bin accuracy 50 | self.accuracy.update(predicted[voiced], bins[voiced]) 51 | 52 | # Update pitch metrics 53 | self.pitch_metrics.update(pitch, target, voiced) 54 | 55 | # Update periodicity metrics 56 | self.f1.update(periodicity, voiced) 57 | 58 | def reset(self): 59 | self.accuracy.reset() 60 | self.f1.reset() 61 | self.loss.reset() 62 | 
self.pitch_metrics.reset() 63 | 64 | 65 | class PitchMetrics: 66 | 67 | def __init__(self): 68 | self.l1 = L1() 69 | self.rca = RCA() 70 | self.rmse = RMSE() 71 | self.rpa = RPA() 72 | 73 | def __call__(self): 74 | return { 75 | 'l1': self.l1(), 76 | 'rca': self.rca(), 77 | 'rmse': self.rmse(), 78 | 'rpa': self.rpa()} 79 | 80 | def update(self, pitch, target, voiced): 81 | # Mask unvoiced 82 | pitch, target = pitch[voiced], target[voiced] 83 | 84 | # Update metrics 85 | self.l1.update(pitch, target) 86 | self.rca.update(pitch, target) 87 | self.rmse.update(pitch, target) 88 | self.rpa.update(pitch, target) 89 | 90 | def reset(self): 91 | self.l1.reset() 92 | self.rca.reset() 93 | self.rmse.reset() 94 | self.rpa.reset() 95 | 96 | 97 | ############################################################################### 98 | # Individual metrics 99 | ############################################################################### 100 | 101 | 102 | class F1: 103 | 104 | def __init__(self, thresholds=None): 105 | if thresholds is None: 106 | thresholds = sorted(list(set( 107 | [2 ** -i for i in range(1, 11)] + 108 | [.1 * i for i in range(10)]))) 109 | self.thresholds = thresholds 110 | self.precision = [ 111 | torchutil.metrics.Precision() for _ in range(len(thresholds))] 112 | self.recall = [ 113 | torchutil.metrics.Recall() for _ in range(len(thresholds))] 114 | 115 | def __call__(self): 116 | result = {} 117 | for threshold, precision, recall in zip( 118 | self.thresholds, 119 | self.precision, 120 | self.recall 121 | ): 122 | precision = precision() 123 | recall = recall() 124 | try: 125 | f1 = 2 * precision * recall / (precision + recall) 126 | except ZeroDivisionError: 127 | f1 = 0. 128 | result |= { 129 | f'f1-{threshold:.6f}': f1, 130 | f'precision-{threshold:.6f}': precision, 131 | f'recall-{threshold:.6f}': recall} 132 | return result 133 | 134 | def update(self, periodicity, voiced): 135 | for threshold, precision, recall in zip( 136 | self.thresholds, 137 | self.precision, 138 | self.recall 139 | ): 140 | predicted = penn.voicing.threshold(periodicity, threshold) 141 | precision.update(predicted, voiced) 142 | recall.update(predicted, voiced) 143 | 144 | def reset(self): 145 | """Reset the F1 score""" 146 | for precision, recall in zip(self.precision, self.recall): 147 | precision.reset() 148 | recall.reset() 149 | 150 | 151 | class L1(torchutil.metrics.L1): 152 | """L1 pitch distance in cents""" 153 | def update(self, predicted, target): 154 | super().update( 155 | penn.OCTAVE * torch.log2(predicted), 156 | penn.OCTAVE * torch.log2(target)) 157 | 158 | 159 | class Loss(torchutil.metrics.Average): 160 | """Batch-updating loss""" 161 | def update(self, logits, bins): 162 | super().update(penn.loss(logits, bins), bins.shape[0]) 163 | 164 | 165 | class RCA(torchutil.metrics.Average): 166 | """Raw chroma accuracy""" 167 | def update(self, predicted, target): 168 | # Compute pitch difference in cents 169 | difference = penn.cents(predicted, target) 170 | 171 | # Forgive octave errors 172 | difference[difference > (penn.OCTAVE - THRESHOLD)] -= penn.OCTAVE 173 | difference[difference < -(penn.OCTAVE - THRESHOLD)] += penn.OCTAVE 174 | 175 | # Count predictions that are within 50 cents of target 176 | super().update( 177 | (torch.abs(difference) < THRESHOLD).sum(), 178 | predicted.numel()) 179 | 180 | 181 | class RMSE(torchutil.metrics.RMSE): 182 | """Root mean square error of pitch distance in cents""" 183 | def update(self, predicted, target): 184 | super().update( 185 | penn.OCTAVE * 
torch.log2(predicted), 186 | penn.OCTAVE * torch.log2(target)) 187 | 188 | 189 | class RPA(torchutil.metrics.Average): 190 | """Raw prediction accuracy""" 191 | def update(self, predicted, target): 192 | difference = penn.cents(predicted, target) 193 | super().update( 194 | (torch.abs(difference) < THRESHOLD).sum(), 195 | predicted.numel()) 196 | -------------------------------------------------------------------------------- /penn/data/dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | 3 | import numpy as np 4 | import torch 5 | 6 | import penn 7 | 8 | 9 | ############################################################################### 10 | # Dataset 11 | ############################################################################### 12 | 13 | 14 | class Dataset(torch.utils.data.Dataset): 15 | """PyTorch dataset 16 | 17 | Arguments 18 | names - list[str] 19 | The names of datasets to load from 20 | partition - string 21 | The name of the data partition 22 | """ 23 | 24 | def __init__(self, names, partition, hparam_search=False): 25 | self.partition = partition 26 | self.hparam_search = hparam_search 27 | self.datasets = [Metadata(name, partition) for name in names] 28 | 29 | def __getitem__(self, index): 30 | if ( 31 | self.partition == 'test' or 32 | (self.partition == 'valid' and self.hparam_search) 33 | ): 34 | return self.load_inference(index) 35 | return self.load_training(index) 36 | 37 | def __len__(self): 38 | """Length of the dataset""" 39 | if ( 40 | self.partition == 'test' or 41 | (self.partition == 'valid' and self.hparam_search) 42 | ): 43 | return sum(len(dataset.files) for dataset in self.datasets) 44 | return sum(dataset.total for dataset in self.datasets) 45 | 46 | def load_inference(self, index): 47 | """Load item for inference""" 48 | # Get dataset to query 49 | i = 0 50 | dataset = self.datasets[i] 51 | upper_bound = len(dataset.files) 52 | while index >= upper_bound: 53 | i += 1 54 | dataset = self.datasets[i] 55 | upper_bound += len(dataset.files) 56 | 57 | # Get index into dataset 58 | index -= (upper_bound - len(dataset.files)) 59 | 60 | # Get stem 61 | stem = dataset.stems[index] 62 | 63 | # Load from cache 64 | directory = penn.CACHE_DIR / dataset.name 65 | audio = np.load(directory / f'{stem}-audio.npy') 66 | pitch = np.load(directory / f'{stem}-pitch.npy') 67 | voiced = np.load(directory / f'{stem}-voiced.npy') 68 | 69 | # Convert to torch 70 | audio = torch.from_numpy(audio)[None] 71 | pitch = torch.from_numpy(pitch) 72 | voiced = torch.from_numpy(voiced) 73 | 74 | # Convert to pitch bin categories 75 | bins = penn.convert.frequency_to_bins(pitch) 76 | 77 | # Set unvoiced bins to random values 78 | bins = torch.where( 79 | ~voiced, 80 | torch.randint(0, penn.PITCH_BINS, bins.shape, dtype=torch.long), 81 | bins) 82 | 83 | return audio, bins, pitch, voiced, stem 84 | 85 | def load_training(self, index): 86 | """Load item for training""" 87 | # Get dataset to query 88 | i = 0 89 | dataset = self.datasets[i] 90 | upper_bound = dataset.total 91 | while index >= upper_bound: 92 | i += 1 93 | dataset = self.datasets[i] 94 | upper_bound += dataset.total 95 | 96 | # Get index into dataset 97 | index -= (upper_bound - dataset.total) 98 | 99 | # Get stem 100 | stem_index = bisect.bisect(dataset.offsets, index) 101 | stem = dataset.stems[stem_index] 102 | 103 | # Get start and end frames 104 | start = \ 105 | index - (0 if stem_index == 0 else dataset.offsets[stem_index - 1]) 106 | end = start + 
penn.NUM_TRAINING_FRAMES 107 | 108 | # Get start and end samples 109 | start_sample = \ 110 | penn.convert.frames_to_samples(start) - \ 111 | (penn.WINDOW_SIZE - penn.HOPSIZE) // 2 112 | end_sample = start_sample + penn.NUM_TRAINING_SAMPLES 113 | 114 | # Load from cache 115 | directory = penn.CACHE_DIR / dataset.name 116 | waveform = np.load(directory / f'{stem}-audio.npy', mmap_mode='r') 117 | pitch = np.load(directory / f'{stem}-pitch.npy', mmap_mode='r') 118 | voiced = np.load(directory / f'{stem}-voiced.npy', mmap_mode='r') 119 | 120 | # Slice audio 121 | if start_sample < 0: 122 | audio = torch.zeros( 123 | (penn.NUM_TRAINING_SAMPLES,), 124 | dtype=torch.float) 125 | audio[-start_sample:] = torch.from_numpy( 126 | waveform[:end_sample].copy()) 127 | elif end_sample > len(waveform): 128 | audio = torch.zeros( 129 | (penn.NUM_TRAINING_SAMPLES,), 130 | dtype=torch.float) 131 | audio[:len(waveform) - end_sample] = torch.from_numpy( 132 | waveform[start_sample:].copy()) 133 | else: 134 | audio = torch.from_numpy( 135 | waveform[start_sample:end_sample].copy()) 136 | 137 | # Slice pitch and voicing 138 | pitch = torch.from_numpy(pitch[start:end].copy()) 139 | voiced = torch.from_numpy(voiced[start:end].copy()) 140 | 141 | # Convert to pitch bin categories 142 | bins = penn.convert.frequency_to_bins(pitch) 143 | 144 | # Set unvoiced bins to random values 145 | bins = torch.where( 146 | ~voiced, 147 | torch.randint(0, penn.PITCH_BINS, bins.shape, dtype=torch.long), 148 | bins) 149 | 150 | return audio[None], bins, pitch, voiced, stem 151 | 152 | def voiced_indices(self): 153 | """Retrieve the indices with voiced start frames""" 154 | offset = 0 155 | indices = [] 156 | for dataset in self.datasets: 157 | indices += [index + offset for index in dataset.voiced_indices()] 158 | offset += dataset.total 159 | return indices 160 | 161 | 162 | ############################################################################### 163 | # Metadata 164 | ############################################################################### 165 | 166 | 167 | class Metadata: 168 | 169 | def __init__(self, name, partition): 170 | self.name = name 171 | self.stems = penn.load.partition(name)[partition] 172 | self.files = [ 173 | penn.CACHE_DIR / name / f'{stem}-audio.npy' 174 | for stem in self.stems] 175 | 176 | # Get number of frames in each file 177 | self.frames = [ 178 | penn.convert.samples_to_frames(len(np.load(file, mmap_mode='r'))) 179 | for file in self.files] 180 | 181 | # We require all files to be at least as large as the analysis window 182 | assert all(frame >= penn.NUM_TRAINING_FRAMES for frame in self.frames) 183 | 184 | # Remove invalid center points 185 | self.frames = [ 186 | frame - (penn.NUM_TRAINING_FRAMES - 1) for frame in self.frames] 187 | 188 | # Save frame offsets 189 | self.offsets = np.cumsum(self.frames) 190 | 191 | # Total number of valid start points 192 | self.total = self.offsets[-1] 193 | 194 | def voiced_indices(self): 195 | """Retrieve the indices with voiced start frames""" 196 | # Get voicing files 197 | files = [ 198 | penn.CACHE_DIR / self.name / f'{stem}-voiced.npy' 199 | for stem in self.stems] 200 | 201 | offset = 0 202 | indices = [] 203 | for file in files: 204 | 205 | # Load 206 | voiced = np.load(file) 207 | 208 | # Remove invalid center points 209 | if penn.NUM_TRAINING_FRAMES > 1: 210 | voiced = voiced[:-(penn.NUM_TRAINING_FRAMES - 1)] 211 | 212 | # Update 213 | indices.extend(list(voiced.nonzero()[0] + offset)) 214 | offset += len(voiced) 215 | 216 | return indices 217 
| -------------------------------------------------------------------------------- /penn/decode.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import functools 3 | 4 | import torbi 5 | import torch 6 | 7 | import penn 8 | 9 | 10 | ############################################################################### 11 | # Base pitch posteriorgram decoder 12 | ############################################################################### 13 | 14 | 15 | class Decoder(abc.ABC): 16 | """Base decoder""" 17 | 18 | def __init__(self, local_expected_value=True): 19 | self.local_expected_value = local_expected_value 20 | 21 | @abc.abstractmethod 22 | def __call__(self, logits): 23 | """Perform decoding""" 24 | pass 25 | 26 | 27 | ############################################################################### 28 | # Derived pitch posteriorgram decoders 29 | ############################################################################### 30 | 31 | 32 | class Argmax(Decoder): 33 | """Decode pitch using argmax""" 34 | 35 | def __init__(self, local_expected_value=penn.LOCAL_EXPECTED_VALUE): 36 | super().__init__(local_expected_value) 37 | 38 | def __call__(self, logits): 39 | # Get pitch bins 40 | bins = logits.argmax(dim=1) 41 | 42 | # Convert to frequency in Hz 43 | if self.local_expected_value: 44 | 45 | # Decode using an assumption of normality around the argmax path 46 | pitch = local_expected_value_from_bins(bins, logits) 47 | 48 | else: 49 | 50 | # Linearly interpolate unvoiced regions 51 | pitch = penn.convert.bins_to_frequency(bins) 52 | 53 | return bins, pitch 54 | 55 | 56 | class PYIN(Decoder): 57 | """Decode pitch via peak picking + Viterbi. Used by PYIN.""" 58 | 59 | def __init__(self, local_expected_value=False): 60 | super().__init__(local_expected_value) 61 | 62 | def __call__(self, logits): 63 | """PYIN decoding""" 64 | periodicity = penn.periodicity.sum(logits).T 65 | unvoiced = ( 66 | (1 - periodicity) / penn.PITCH_BINS).repeat(penn.PITCH_BINS, 1) 67 | distributions = torch.cat( 68 | (torch.exp(logits.permute(2, 1, 0)), unvoiced[None]), 69 | dim=1) 70 | 71 | # Viterbi decoding 72 | gpu = ( 73 | None if distributions.device.type == 'cpu' 74 | else distributions.device.index) 75 | bins = torbi.from_probabilities( 76 | observation=distributions[0].T.unsqueeze(dim=0), 77 | transition=self.transition, 78 | initial=self.initial, 79 | gpu=gpu) 80 | 81 | # Convert to frequency in Hz 82 | if self.local_expected_value: 83 | 84 | # Decode using an assumption of normality around the viterbi path 85 | pitch = local_expected_value_from_bins(bins.T, logits).T 86 | 87 | else: 88 | 89 | # Argmax decoding 90 | pitch = penn.convert.bins_to_frequency(bins) 91 | 92 | # Linearly interpolate unvoiced regions 93 | pitch[bins >= penn.PITCH_BINS] = 0 94 | pitch = penn.data.preprocess.interpolate_unvoiced(pitch.numpy())[0] 95 | pitch = torch.from_numpy(pitch).to(logits.device) 96 | bins = bins.to(logits.device) 97 | 98 | return bins.T, pitch.T 99 | 100 | @functools.cached_property 101 | def initial(self): 102 | """Create initial probability matrix for PYIN""" 103 | initial = torch.zeros(2 * penn.PITCH_BINS) 104 | initial[penn.PITCH_BINS:] = 1 / penn.PITCH_BINS 105 | return initial 106 | 107 | @functools.cached_property 108 | def transition(self): 109 | """Create the Viterbi transition matrix for PYIN""" 110 | transition = triangular_transition_matrix() 111 | 112 | # Add unvoiced probabilities 113 | transition = torch.kron( 114 | torch.tensor([[.99, .01], [.01, 
.99]]), 115 | transition) 116 | return transition 117 | 118 | 119 | class Viterbi(Decoder): 120 | 121 | def __init__(self, local_expected_value=True): 122 | super().__init__(local_expected_value) 123 | 124 | def __call__(self, logits): 125 | """Decode pitch using viterbi decoding (from librosa)""" 126 | distributions = torch.nn.functional.softmax(logits, dim=1) 127 | distributions = distributions.permute(2, 1, 0) # F x C x 1 -> 1 x C x F 128 | 129 | # Viterbi decoding 130 | gpu = ( 131 | None if distributions.device.type == 'cpu' 132 | else distributions.device.index) 133 | bins = torbi.from_probabilities( 134 | observation=distributions[0].T.unsqueeze(dim=0), 135 | transition=self.transition, 136 | initial=self.initial, 137 | gpu=gpu) 138 | 139 | # Convert to frequency in Hz 140 | if self.local_expected_value: 141 | 142 | # Decode using an assumption of normality around the viterbi path 143 | pitch = local_expected_value_from_bins(bins.T, logits).T 144 | 145 | else: 146 | 147 | # Argmax decoding 148 | pitch = penn.convert.bins_to_frequency(bins) 149 | 150 | return bins.T, pitch.T 151 | 152 | @functools.cached_property 153 | def initial(self): 154 | """Create uniform initial probabilities""" 155 | return torch.full((penn.PITCH_BINS,), 1 / penn.PITCH_BINS) 156 | 157 | @functools.cached_property 158 | def transition(self): 159 | """Create Viterbi transition probability matrix""" 160 | return triangular_transition_matrix() 161 | 162 | 163 | ############################################################################### 164 | # Utilities 165 | ############################################################################### 166 | 167 | 168 | def expected_value(logits, cents): 169 | """Expected value computation from logits""" 170 | # Get local distributions 171 | if penn.LOSS == 'categorical_cross_entropy': 172 | distributions = torch.nn.functional.softmax(logits, dim=1) 173 | elif penn.LOSS == 'binary_cross_entropy': 174 | distributions = torch.sigmoid(logits) 175 | else: 176 | raise ValueError(f'Loss {penn.LOSS} is not defined') 177 | 178 | # Pitch is expected value in cents 179 | pitch = (distributions * cents).sum(dim=1, keepdims=True) 180 | 181 | # BCE requires normalization 182 | if penn.LOSS == 'binary_cross_entropy': 183 | pitch = pitch / distributions.sum(dim=1) 184 | 185 | # Convert to hz 186 | return penn.convert.cents_to_frequency(pitch) 187 | 188 | 189 | def local_expected_value_from_bins( 190 | bins, 191 | logits, 192 | window=penn.LOCAL_PITCH_WINDOW_SIZE): 193 | """Decode pitch using normal assumption around argmax from bin indices""" 194 | # Pad 195 | padded = torch.nn.functional.pad( 196 | logits.squeeze(2), 197 | (window // 2, window // 2), 198 | value=-float('inf')) 199 | 200 | # Get indices 201 | indices = \ 202 | bins.repeat(1, window) + torch.arange(window, device=bins.device)[None] 203 | 204 | # Get values in cents 205 | cents = penn.convert.bins_to_cents(torch.clip(indices - window // 2, 0)) 206 | 207 | # Decode using local expected value 208 | return expected_value(torch.gather(padded, 1, indices), cents) 209 | 210 | 211 | def triangular_transition_matrix(): 212 | """Create a triangular distribution transition matrix""" 213 | xx, yy = torch.meshgrid( 214 | torch.arange(penn.PITCH_BINS), 215 | torch.arange(penn.PITCH_BINS), 216 | indexing='ij') 217 | bins_per_octave = penn.OCTAVE / penn.CENTS_PER_BIN 218 | max_octaves_per_frame = \ 219 | penn.MAX_OCTAVES_PER_SECOND * penn.HOPSIZE / penn.SAMPLE_RATE 220 | max_bins_per_frame = max_octaves_per_frame * bins_per_octave + 1 221 
| transition = torch.clip(max_bins_per_frame - (xx - yy).abs(), 0) 222 | return transition / transition.sum(dim=1, keepdims=True) 223 | -------------------------------------------------------------------------------- /penn/data/preprocess/core.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import warnings 3 | 4 | import numpy as np 5 | import torchaudio 6 | import torchutil 7 | 8 | import penn 9 | 10 | 11 | ############################################################################### 12 | # Constants 13 | ############################################################################### 14 | 15 | 16 | # MDB analysis parameters 17 | MDB_HOPSIZE = 128 # samples 18 | MDB_SAMPLE_RATE = 44100 # samples per second 19 | 20 | # PTDB analysis parameters 21 | PTDB_HOPSIZE = 160 # samples 22 | PTDB_SAMPLE_RATE = 16000 # samples per second 23 | PTDB_WINDOW_SIZE = 512 # samples 24 | PTDB_HOPSIZE_SECONDS = PTDB_HOPSIZE / PTDB_SAMPLE_RATE 25 | 26 | 27 | ############################################################################### 28 | # Preprocess datasets 29 | ############################################################################### 30 | 31 | 32 | @torchutil.notify('preprocess') 33 | def datasets(datasets): 34 | """Preprocess datasets""" 35 | if 'mdb' in datasets: 36 | mdb() 37 | 38 | if 'ptdb' in datasets: 39 | ptdb() 40 | 41 | 42 | ############################################################################### 43 | # Individual datasets 44 | ############################################################################### 45 | 46 | 47 | def mdb(): 48 | """Preprocess mdb dataset""" 49 | # Get audio files 50 | audio_files = (penn.DATA_DIR / 'mdb'/ 'audio_stems').glob('*.wav') 51 | audio_files = sorted([ 52 | file for file in audio_files if not file.stem.startswith('._')]) 53 | 54 | # Get pitch files 55 | pitch_files = [ 56 | file.parent.parent / 57 | 'annotation_stems' / 58 | file.with_suffix('.csv').name 59 | for file in audio_files] 60 | 61 | # Create cache 62 | output_directory = penn.CACHE_DIR / 'mdb' 63 | output_directory.mkdir(exist_ok=True, parents=True) 64 | 65 | # Write audio and pitch to cache 66 | for i, (audio_file, pitch_file) in torchutil.iterator( 67 | enumerate(zip(audio_files, pitch_files)), 68 | 'Preprocessing mdb', 69 | total=len(audio_files) 70 | ): 71 | stem = f'{i:06d}' 72 | 73 | # Load and resample audio 74 | audio = penn.load.audio(audio_file) 75 | 76 | # Save as numpy array for fast memory-mapped reads 77 | np.save( 78 | output_directory / f'{stem}-audio.npy', 79 | audio.numpy().squeeze()) 80 | 81 | # Save audio for listening and evaluation 82 | torchaudio.save( 83 | output_directory / f'{stem}.wav', 84 | audio, 85 | penn.SAMPLE_RATE) 86 | 87 | # Load pitch 88 | annotations = np.loadtxt(open(pitch_file), delimiter=',') 89 | times, pitch = annotations[:, 0], annotations[:, 1] 90 | 91 | # Fill unvoiced regions via linear interpolation 92 | pitch, voiced = interpolate_unvoiced(pitch) 93 | 94 | # Get target number of frames 95 | frames = penn.convert.samples_to_frames(audio.shape[-1]) 96 | 97 | # Linearly interpolate to target number of frames 98 | new_times = penn.HOPSIZE_SECONDS * np.arange(0, frames) 99 | new_times += penn.HOPSIZE_SECONDS / 2. 100 | pitch = 2. 
** np.interp(new_times, times, np.log2(pitch)) 101 | 102 | # Linearly interpolate voiced/unvoiced tokens 103 | voiced = np.interp(new_times, times, voiced) > .5 104 | 105 | # Check shapes 106 | assert ( 107 | penn.convert.samples_to_frames(audio.shape[-1]) == 108 | pitch.shape[-1] == 109 | voiced.shape[-1]) 110 | 111 | # Save to cache 112 | np.save(output_directory / f'{stem}-pitch.npy', pitch) 113 | np.save(output_directory / f'{stem}-voiced.npy', voiced) 114 | 115 | 116 | def ptdb(): 117 | """Preprocessing ptdb dataset""" 118 | # Get audio files 119 | directory = penn.DATA_DIR / 'ptdb' / 'SPEECH DATA' 120 | male = (directory / 'MALE' / 'MIC').rglob('*.wav') 121 | female = (directory / 'FEMALE' / 'MIC').rglob('*.wav') 122 | audio_files = sorted(itertools.chain(male, female)) 123 | 124 | # Get pitch files 125 | pitch_files = [ 126 | file.parent.parent.parent / 127 | 'REF' / 128 | file.parent.name / 129 | file.with_suffix('.f0').name.replace('mic', 'ref') 130 | for file in audio_files] 131 | 132 | # Create cache 133 | output_directory = penn.CACHE_DIR / 'ptdb' 134 | output_directory.mkdir(exist_ok=True, parents=True) 135 | 136 | # Write audio and pitch to cache 137 | for i, (audio_file, pitch_file) in torchutil.iterator( 138 | enumerate(zip(audio_files, pitch_files)), 139 | 'Preprocessing ptdb', 140 | total=len(audio_files) 141 | ): 142 | stem = f'{i:06d}' 143 | 144 | # Load and resample to PTDB sample rate 145 | audio, sample_rate = torchaudio.load(audio_file) 146 | audio = penn.resample(audio, sample_rate, PTDB_SAMPLE_RATE) 147 | 148 | # Remove padding 149 | offset = PTDB_WINDOW_SIZE - PTDB_HOPSIZE // 2 150 | if (audio.shape[-1] - 2 * offset) % PTDB_HOPSIZE == 0: 151 | offset += PTDB_HOPSIZE // 2 152 | audio = audio[:, offset:-offset] 153 | 154 | # Resample to pitch estimation sample rate 155 | audio = penn.resample(audio, PTDB_SAMPLE_RATE) 156 | 157 | # Save as numpy array for fast memory-mapped read 158 | np.save( 159 | output_directory / f'{stem}-audio.npy', 160 | audio.numpy().squeeze()) 161 | 162 | # Save audio for listening and evaluation 163 | torchaudio.save( 164 | output_directory / f'{stem}.wav', 165 | audio, 166 | penn.SAMPLE_RATE) 167 | 168 | # Load pitch 169 | pitch = np.loadtxt(open(pitch_file), delimiter=' ')[:, 0] 170 | 171 | # Fill unvoiced regions via linear interpolation 172 | pitch, voiced = interpolate_unvoiced(pitch) 173 | 174 | # Get target number of frames 175 | frames = penn.convert.samples_to_frames(audio.shape[-1]) 176 | 177 | # Get original times 178 | times = PTDB_HOPSIZE_SECONDS * np.arange(0, len(pitch)) 179 | times += PTDB_HOPSIZE_SECONDS / 2 180 | 181 | # Linearly interpolate to target number of frames 182 | new_times = penn.HOPSIZE_SECONDS * np.arange(0, frames) 183 | new_times += penn.HOPSIZE_SECONDS / 2. 184 | 185 | pitch = 2. 
** np.interp(new_times, times, np.log2(pitch)) 186 | 187 | # Linearly interpolate voiced/unvoiced tokens 188 | voiced = np.interp(new_times, times, voiced) > .5 189 | 190 | # Check shapes 191 | assert ( 192 | penn.convert.samples_to_frames(audio.shape[-1]) == 193 | pitch.shape[-1] == 194 | voiced.shape[-1]) 195 | 196 | # Save to cache 197 | np.save(output_directory / f'{stem}-pitch.npy', pitch) 198 | np.save(output_directory / f'{stem}-voiced.npy', voiced) 199 | 200 | 201 | ############################################################################### 202 | # Utilities 203 | ############################################################################### 204 | 205 | 206 | def interpolate_unvoiced(pitch): 207 | """Fill unvoiced regions via linear interpolation""" 208 | unvoiced = pitch == 0 209 | 210 | # Ignore warning of log setting unvoiced regions (zeros) to nan 211 | with warnings.catch_warnings(): 212 | warnings.simplefilter('ignore') 213 | 214 | # Pitch is linear in base-2 log-space 215 | pitch = np.log2(pitch) 216 | 217 | try: 218 | 219 | # Interpolate 220 | pitch[unvoiced] = np.interp( 221 | np.where(unvoiced)[0], 222 | np.where(~unvoiced)[0], 223 | pitch[~unvoiced]) 224 | 225 | except ValueError: 226 | 227 | # Allow all unvoiced 228 | pass 229 | 230 | return 2 ** pitch, ~unvoiced 231 | -------------------------------------------------------------------------------- /penn/train/core.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torchutil 5 | 6 | import penn 7 | 8 | 9 | ############################################################################### 10 | # Training 11 | ############################################################################### 12 | 13 | 14 | @torchutil.notify('train') 15 | def train(datasets, directory, gpu=None): 16 | """Train a model""" 17 | device = torch.device('cpu' if gpu is None else f'cuda:{gpu}') 18 | model = penn.model.Model().to(device) 19 | 20 | ####################### 21 | # Create data loaders # 22 | ####################### 23 | 24 | torch.manual_seed(penn.RANDOM_SEED) 25 | train_loader = penn.data.loader(datasets, 'train') 26 | valid_loader = penn.data.loader(datasets, 'valid') 27 | 28 | #################### 29 | # Create optimizer # 30 | #################### 31 | 32 | optimizer = torch.optim.Adam(model.parameters(), lr=penn.LEARNING_RATE) 33 | 34 | ############################## 35 | # Maybe load from checkpoint # 36 | ############################## 37 | 38 | path = torchutil.checkpoint.latest_path(directory) 39 | 40 | if path is not None: 41 | 42 | # Load model 43 | model, optimizer, state = torchutil.checkpoint.load( 44 | path, 45 | model, 46 | optimizer) 47 | step, epoch = state['step'], state['epoch'] 48 | 49 | else: 50 | 51 | # Train from scratch 52 | step, epoch = 0, 0 53 | 54 | ############################## 55 | # Maybe setup early stopping # 56 | ############################## 57 | 58 | if penn.EARLY_STOPPING: 59 | counter = penn.EARLY_STOPPING_STEPS 60 | best_accuracy = 0. 
61 | stop = False 62 | 63 | ######### 64 | # Train # 65 | ######### 66 | 67 | # Automatic mixed precision (amp) gradient scaler 68 | scaler = torch.cuda.amp.GradScaler() 69 | 70 | # Setup progress bar 71 | progress = torchutil.iterator( 72 | range(step, penn.STEPS), 73 | f'Training {penn.CONFIG}', 74 | step, 75 | penn.STEPS) 76 | while step < penn.STEPS and (not penn.EARLY_STOPPING or not stop): 77 | 78 | # Seed sampler 79 | train_loader.sampler.set_epoch(epoch) 80 | 81 | for batch in train_loader: 82 | 83 | # Unpack batch 84 | audio, bins, *_ = batch 85 | 86 | with torch.autocast(device.type): 87 | 88 | # Forward pass 89 | logits = model(audio.to(device)) 90 | 91 | # Compute losses 92 | losses = loss(logits, bins.to(device)) 93 | 94 | ################## 95 | # Optimize model # 96 | ################## 97 | 98 | optimizer.zero_grad() 99 | 100 | # Backward pass 101 | scaler.scale(losses).backward() 102 | 103 | # Update weights 104 | scaler.step(optimizer) 105 | 106 | # Update gradient scaler 107 | scaler.update() 108 | 109 | ############## 110 | # Evaluation # 111 | ############## 112 | 113 | # Save checkpoint 114 | if step and step % penn.CHECKPOINT_INTERVAL == 0: 115 | torchutil.checkpoint.save( 116 | directory / f'{step:08d}.pt', 117 | model, 118 | optimizer, 119 | step=step, 120 | epoch=epoch) 121 | 122 | # Evaluate 123 | if step % penn.LOG_INTERVAL == 0: 124 | evaluate_fn = functools.partial( 125 | evaluate, 126 | directory, 127 | step, 128 | model, 129 | gpu) 130 | evaluate_fn('train', train_loader) 131 | valid_accuracy = evaluate_fn('valid', valid_loader) 132 | 133 | # Maybe stop training 134 | if penn.EARLY_STOPPING: 135 | counter -= 1 136 | 137 | # Update best validation loss 138 | if valid_accuracy > best_accuracy: 139 | best_accuracy = valid_accuracy 140 | counter = penn.EARLY_STOPPING_STEPS 141 | 142 | # Stop training 143 | elif counter == 0: 144 | stop = True 145 | 146 | # Update training step count 147 | if step >= penn.STEPS or (penn.EARLY_STOPPING and stop): 148 | break 149 | step += 1 150 | 151 | # Update progress bar 152 | progress.update() 153 | 154 | # Update epoch 155 | epoch += 1 156 | 157 | # Close progress bar 158 | progress.close() 159 | 160 | # Save final model 161 | torchutil.checkpoint.save( 162 | directory / f'{step:08d}.pt', 163 | model, 164 | optimizer, 165 | step=step, 166 | epoch=epoch) 167 | 168 | 169 | ############################################################################### 170 | # Evaluation 171 | ############################################################################### 172 | 173 | 174 | def evaluate(directory, step, model, gpu, condition, loader): 175 | """Perform model evaluation""" 176 | device = torch.device('cpu' if gpu is None else f'cuda:{gpu}') 177 | 178 | # Setup evaluation metrics 179 | metrics = penn.evaluate.Metrics() 180 | 181 | # Prepare model for inference 182 | with penn.inference_context(model): 183 | 184 | # Unpack batch 185 | for i, (audio, bins, pitch, voiced, *_) in enumerate(loader): 186 | 187 | # Forward pass 188 | logits = model(audio.to(device)) 189 | 190 | # Update metrics 191 | metrics.update( 192 | logits.to(device), 193 | bins.T.to(device), 194 | pitch.T.to(device), 195 | voiced.T.to(device)) 196 | 197 | # Stop when we exceed some number of batches 198 | if i + 1 == penn.LOG_STEPS: 199 | break 200 | 201 | # Format results 202 | scalars = { 203 | f'{key}/{condition}': value for key, value in metrics().items()} 204 | 205 | # Write to tensorboard 206 | torchutil.tensorboard.update(directory, step, scalars=scalars) 207 
| 208 | return scalars[f'accuracy/{condition}'] 209 | 210 | 211 | ############################################################################### 212 | # Loss function 213 | ############################################################################### 214 | 215 | 216 | def loss(logits, bins): 217 | """Compute loss function""" 218 | # Reshape inputs 219 | logits = logits.permute(0, 2, 1).reshape(-1, penn.PITCH_BINS) 220 | bins = bins.flatten() 221 | 222 | # Maybe blur target 223 | if penn.GAUSSIAN_BLUR: 224 | 225 | # Cache cents values to evaluate distributions at 226 | if not hasattr(loss, 'cents'): 227 | loss.cents = penn.convert.bins_to_cents( 228 | torch.arange(penn.PITCH_BINS))[:, None] 229 | 230 | # Ensure values are on correct device (no-op if devices are the same) 231 | loss.cents = loss.cents.to(bins.device) 232 | 233 | # Create normal distributions 234 | distributions = torch.distributions.Normal( 235 | penn.convert.bins_to_cents(bins), 236 | 25) 237 | 238 | # Sample normal distributions 239 | bins = torch.exp(distributions.log_prob(loss.cents)).permute(1, 0) 240 | 241 | # Normalize 242 | bins = bins / (bins.max(dim=1, keepdims=True).values + 1e-8) 243 | 244 | else: 245 | 246 | # One-hot encoding 247 | bins = torch.nn.functional.one_hot(bins, penn.PITCH_BINS).float() 248 | 249 | if penn.LOSS == 'binary_cross_entropy': 250 | 251 | # Compute binary cross-entropy loss 252 | return torch.nn.functional.binary_cross_entropy_with_logits( 253 | logits, 254 | bins) 255 | 256 | elif penn.LOSS == 'categorical_cross_entropy': 257 | 258 | # Compute categorical cross-entropy loss 259 | return torch.nn.functional.cross_entropy(logits, bins) 260 | 261 | else: 262 | 263 | raise ValueError(f'Loss {penn.LOSS} is not implemented') 264 | -------------------------------------------------------------------------------- /penn/dsp/pyin.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import multiprocessing as mp 3 | 4 | import numpy as np 5 | import torch 6 | import torchutil 7 | 8 | import penn 9 | 10 | 11 | ############################################################################### 12 | # PYIN (from librosa) 13 | ############################################################################### 14 | 15 | 16 | def from_audio( 17 | audio, 18 | sample_rate=penn.SAMPLE_RATE, 19 | hopsize=penn.HOPSIZE_SECONDS, 20 | fmin=penn.FMIN, 21 | fmax=penn.FMAX): 22 | """Estimate pitch and periodicity with pyin""" 23 | # Pad 24 | pad = int( 25 | penn.WINDOW_SIZE - penn.convert.seconds_to_samples(hopsize)) // 2 26 | audio = torch.nn.functional.pad(audio, (pad, pad)) 27 | 28 | # Infer pitch bin probabilities 29 | with torchutil.time.context('infer'): 30 | logits = infer(audio, sample_rate, hopsize, fmin, fmax) 31 | 32 | # Decode pitch and periodicity 33 | with torchutil.time.context('postprocess'): 34 | return penn.postprocess(logits)[1:] 35 | 36 | 37 | def from_file( 38 | file, 39 | hopsize=penn.HOPSIZE_SECONDS, 40 | fmin=penn.FMIN, 41 | fmax=penn.FMAX): 42 | """Estimate pitch and periodicity with pyin from audio on disk""" 43 | # Load 44 | with torchutil.time.context('load'): 45 | audio = penn.load.audio(file) 46 | 47 | # Infer 48 | return from_audio(audio, penn.SAMPLE_RATE, hopsize, fmin, fmax) 49 | 50 | 51 | def from_file_to_file( 52 | file, 53 | output_prefix=None, 54 | hopsize=penn.HOPSIZE_SECONDS, 55 | fmin=penn.FMIN, 56 | fmax=penn.FMAX): 57 | """Estimate pitch and periodicity with pyin and save to disk""" 58 | # Infer 59 | pitch, periodicity = 
from_file(file, hopsize, fmin, fmax) 60 | 61 | # Save to disk 62 | with torchutil.time.context('save'): 63 | 64 | # Maybe use same filename with new extension 65 | if output_prefix is None: 66 | output_prefix = file.parent / file.stem 67 | 68 | # Save 69 | torch.save(pitch, f'{output_prefix}-pitch.pt') 70 | torch.save(periodicity, f'{output_prefix}-periodicity.pt') 71 | 72 | 73 | def from_files_to_files( 74 | files, 75 | output_prefixes=None, 76 | hopsize=penn.HOPSIZE_SECONDS, 77 | fmin=penn.FMIN, 78 | fmax=penn.FMAX): 79 | """Estimate pitch and periodicity with pyin and save to disk""" 80 | pitch_fn = functools.partial( 81 | from_file_to_file, 82 | hopsize=hopsize, 83 | fmin=fmin, 84 | fmax=fmax) 85 | iterator = zip(files, output_prefixes) 86 | 87 | # Turn off multiprocessing for benchmarking 88 | if penn.BENCHMARK: 89 | for item in torchutil.iterator( 90 | iterator, 91 | f'{penn.CONFIG}', 92 | total=len(files) 93 | ): 94 | pitch_fn(*item) 95 | else: 96 | with mp.get_context('spawn').Pool() as pool: 97 | pool.starmap(pitch_fn, iterator) 98 | 99 | 100 | ############################################################################### 101 | # Utilities 102 | ############################################################################### 103 | 104 | 105 | def cumulative_mean_normalized_difference(frames, min_period, max_period): 106 | import librosa 107 | 108 | a = np.fft.rfft(frames, 2 * penn.WINDOW_SIZE, axis=-2) 109 | b = np.fft.rfft( 110 | frames[..., penn.WINDOW_SIZE:0:-1, :], 111 | 2 * penn.WINDOW_SIZE, 112 | axis=-2) 113 | acf_frames = np.fft.irfft( 114 | a * b, 2 * penn.WINDOW_SIZE, axis=-2)[..., penn.WINDOW_SIZE:, :] 115 | acf_frames[np.abs(acf_frames) < 1e-6] = 0 116 | 117 | # Energy terms 118 | energy_frames = np.cumsum(frames ** 2, axis=-2) 119 | energy_frames = ( 120 | energy_frames[..., penn.WINDOW_SIZE:, :] - 121 | energy_frames[..., :-penn.WINDOW_SIZE, :]) 122 | energy_frames[np.abs(energy_frames) < 1e-6] = 0 123 | 124 | # Difference function 125 | yin_frames = energy_frames[..., :1, :] + energy_frames - 2 * acf_frames 126 | 127 | # Cumulative mean normalized difference function 128 | yin_numerator = yin_frames[..., min_period: max_period + 1, :] 129 | 130 | # Broadcast to have leading ones 131 | tau_range = librosa.util.expand_to( 132 | np.arange(1, max_period + 1), ndim=yin_frames.ndim, axes=-2) 133 | 134 | cumulative_mean = ( 135 | np.cumsum(yin_frames[..., 1: max_period + 1, :], axis=-2) / tau_range) 136 | 137 | yin_denominator = cumulative_mean[..., min_period - 1: max_period, :] 138 | yin_frames = yin_numerator / \ 139 | (yin_denominator + librosa.util.tiny(yin_denominator)) 140 | 141 | return yin_frames 142 | 143 | 144 | def infer( 145 | audio, 146 | sample_rate=penn.SAMPLE_RATE, 147 | hopsize=penn.HOPSIZE_SECONDS, 148 | fmin=penn.FMIN, 149 | fmax=penn.FMAX): 150 | hopsize = int(penn.convert.seconds_to_samples(hopsize)) 151 | import scipy 152 | 153 | # Pad audio to center-align frames 154 | pad = penn.WINDOW_SIZE // 2 155 | padded = torch.nn.functional.pad(audio, (0, 2 * pad)) 156 | 157 | # Slice and chunk audio 158 | frames = torch.nn.functional.unfold( 159 | padded[:, None, None], 160 | kernel_size=(1, 2 * penn.WINDOW_SIZE), 161 | stride=(1, penn.HOPSIZE))[0] 162 | 163 | # Calculate minimum and maximum periods 164 | min_period = max(int(np.floor(sample_rate / fmax)), 1) 165 | max_period = min( 166 | int(np.ceil(sample_rate / fmin)), 167 | penn.WINDOW_SIZE - 1) 168 | 169 | # Calculate cumulative mean normalized difference function 170 | yin_frames = 
cumulative_mean_normalized_difference( 171 | frames.numpy(), 172 | min_period, 173 | max_period) 174 | 175 | # Parabolic interpolation 176 | parabolic_shifts = parabolic_interpolation(yin_frames) 177 | 178 | # Find Yin candidates and probabilities. 179 | # The implementation here follows the official pYIN software which 180 | # differs from the method described in the paper. 181 | # 1. Define the prior over the thresholds. 182 | thresholds = np.linspace(0, 1, 100 + 1) 183 | beta_cdf = scipy.stats.beta.cdf(thresholds, 2, 18) 184 | beta_probs = np.diff(beta_cdf) 185 | 186 | def _helper(a, b): 187 | return pyin_helper( 188 | a, 189 | b, 190 | thresholds, 191 | 2, 192 | beta_probs, 193 | .01, 194 | min_period, 195 | (penn.OCTAVE / 12) / penn.CENTS_PER_BIN) 196 | 197 | helper = np.vectorize(_helper, signature="(f,t),(k,t)->(1,d,t)") 198 | probs = helper(yin_frames, parabolic_shifts) 199 | probs = torch.from_numpy(probs) 200 | return torch.log(probs).permute(2, 1, 0) 201 | 202 | 203 | def parabolic_interpolation(frames): 204 | """Piecewise parabolic interpolation for yin and pyin""" 205 | import librosa 206 | 207 | parabolic_shifts = np.zeros_like(frames) 208 | parabola_a = ( 209 | frames[..., :-2, :] + 210 | frames[..., 2:, :] - 211 | 2 * frames[..., 1:-1, :] 212 | ) / 2 213 | parabola_b = (frames[..., 2:, :] - frames[..., :-2, :]) / 2 214 | parabolic_shifts[..., 1:-1, :] = \ 215 | -parabola_b / (2 * parabola_a + librosa.util.tiny(parabola_a)) 216 | parabolic_shifts[np.abs(parabolic_shifts) > 1] = 0 217 | return parabolic_shifts 218 | 219 | 220 | def pyin_helper( 221 | frames, 222 | parabolic_shifts, 223 | thresholds, 224 | boltzmann_parameter, 225 | beta_probs, 226 | no_trough_prob, 227 | min_period, 228 | n_bins_per_semitone): 229 | import librosa 230 | import scipy 231 | 232 | 233 | yin_probs = np.zeros_like(frames) 234 | 235 | for i, yin_frame in enumerate(frames.T): 236 | # 2. For each frame find the troughs. 237 | is_trough = librosa.util.localmin(yin_frame) 238 | 239 | is_trough[0] = yin_frame[0] < yin_frame[1] 240 | (trough_index,) = np.nonzero(is_trough) 241 | 242 | if len(trough_index) == 0: 243 | continue 244 | 245 | # 3. Find the troughs below each threshold. 246 | # these are the local minima of the frame, could get them directly 247 | # without the trough index 248 | trough_heights = yin_frame[trough_index] 249 | trough_thresholds = np.less.outer(trough_heights, thresholds[1:]) 250 | 251 | # 4. Define the prior over the troughs. 252 | # Smaller periods are weighted more. 253 | trough_positions = np.cumsum(trough_thresholds, axis=0) - 1 254 | n_troughs = np.count_nonzero(trough_thresholds, axis=0) 255 | 256 | trough_prior = scipy.stats.boltzmann.pmf( 257 | trough_positions, 258 | boltzmann_parameter, 259 | n_troughs) 260 | 261 | trough_prior[~trough_thresholds] = 0 262 | 263 | # 5. For each threshold add probability to global minimum if no trough 264 | # is below threshold, else add probability to each trough below 265 | # threshold biased by prior. 266 | probs = trough_prior.dot(beta_probs) 267 | 268 | global_min = np.argmin(trough_heights) 269 | n_thresholds_below_min = np.count_nonzero( 270 | ~trough_thresholds[global_min, :]) 271 | probs[global_min] += no_trough_prob * np.sum( 272 | beta_probs[:n_thresholds_below_min]) 273 | 274 | yin_probs[trough_index, i] = probs 275 | 276 | yin_period, frame_index = np.nonzero(yin_probs) 277 | 278 | # Refine peak by parabolic interpolation. 
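    # Note: yin_period indexes rows of the cumulative mean normalized
    # difference, which begin at min_period, so min_period + yin_period is the
    # integer lag in samples; the parabolic shift adds a fractional correction
    # before the lag is converted to frequency as penn.SAMPLE_RATE / period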
279 | period_candidates = min_period + yin_period 280 | period_candidates = period_candidates + \ 281 | parabolic_shifts[yin_period, frame_index] 282 | f0_candidates = penn.SAMPLE_RATE / period_candidates 283 | 284 | # Find pitch bin corresponding to each f0 candidate. 285 | bin_index = 12 * n_bins_per_semitone * np.log2(f0_candidates / penn.FMIN) 286 | bin_index = np.clip( 287 | np.round(bin_index), 288 | 0, 289 | penn.PITCH_BINS - 1).astype(int) 290 | 291 | # Observation probabilities. 292 | observation_probs = np.zeros((penn.PITCH_BINS, frames.shape[1])) 293 | observation_probs[bin_index, frame_index] = \ 294 | yin_probs[yin_period, frame_index] 295 | 296 | return observation_probs[np.newaxis] 297 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Pitch-Estimating Neural Networks (PENN)
2 |
3 | 4 | [![PyPI](https://img.shields.io/pypi/v/penn.svg)](https://pypi.python.org/pypi/penn) 5 | [![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) 6 | [![Downloads](https://static.pepy.tech/badge/penn)](https://pepy.tech/project/penn) 7 | 8 |
9 | 10 | Training, evaluation, and inference of neural pitch and periodicity estimators in PyTorch. Includes the original code for the paper ["Cross-domain Neural Pitch and Periodicity Estimation"](https://arxiv.org/abs/2301.12258). 11 | 12 | 13 | ## Table of contents 14 | 15 | - [Installation](#installation) 16 | - [Inference](#inference) 17 | * [Application programming interface](#application-programming-interface) 18 | * [`penn.from_audio`](#pennfrom_audio) 19 | * [`penn.from_file`](#pennfrom_file) 20 | * [`penn.from_file_to_file`](#pennfrom_file_to_file) 21 | * [`penn.from_files_to_files`](#pennfrom_files_to_files) 22 | * [Command-line interface](#command-line-interface) 23 | - [Training](#training) 24 | * [Download](#download) 25 | * [Preprocess](#preprocess) 26 | * [Partition](#partition) 27 | * [Train](#train) 28 | * [Monitor](#monitor) 29 | - [Evaluation](#evaluation) 30 | * [Evaluate](#evaluate) 31 | * [Plot](#plot) 32 | - [Citation](#citation) 33 | 34 | 35 | ## Installation 36 | 37 | If you want to perform pitch estimation using a pretrained FCNF0++ model, run 38 | `pip install penn` 39 | 40 | If you want to train or use your own models, run 41 | `pip install penn[train]` 42 | 43 | 44 | ## Inference 45 | 46 | Perform inference using FCNF0++ 47 | 48 | ``` 49 | import penn 50 | 51 | # Load audio 52 | audio, sample_rate = torchaudio.load('test/assets/gershwin.wav') 53 | 54 | # Here we'll use a 10 millisecond hopsize 55 | hopsize = .01 56 | 57 | # Provide a sensible frequency range given your domain and model 58 | fmin = 30. 59 | fmax = 1000. 60 | 61 | # Choose a gpu index to use for inference. Set to None to use cpu. 62 | gpu = 0 63 | 64 | # If you are using a gpu, pick a batch size that doesn't cause memory errors 65 | # on your gpu 66 | batch_size = 2048 67 | 68 | # Select a checkpoint to use for inference. Selecting None will 69 | # download and use FCNF0++ pretrained on MDB-stem-synth and PTDB 70 | checkpoint = None 71 | 72 | # Centers frames at hopsize / 2, 3 * hopsize / 2, 5 * hopsize / 2, ... 73 | center = 'half-hop' 74 | 75 | # (Optional) Linearly interpolate unvoiced regions below periodicity threshold 76 | interp_unvoiced_at = .065 77 | 78 | # (Optional) Select a decoding method. One of ['argmax', 'pyin', 'viterbi']. 79 | decoder = 'viterbi' 80 | 81 | # Infer pitch and periodicity 82 | pitch, periodicity = penn.from_audio( 83 | audio, 84 | sample_rate, 85 | hopsize=hopsize, 86 | fmin=fmin, 87 | fmax=fmax, 88 | checkpoint=checkpoint, 89 | batch_size=batch_size, 90 | center=center, 91 | decoder=decoder, 92 | interp_unvoiced_at=interp_unvoiced_at, 93 | gpu=gpu) 94 | ``` 95 | 96 | Note that pitch estimation is performed independently on each frame of audio. Then, a _decoding_ step occurs, which may or may not be computed independently on each frame. Most often, Viterbi decoding is used (as in, e.g., PYIN and CREPE). However, Viterbi decoding is slow. We made a fast Viterbi decoder called [torbi](https://github.com/maxrmorrison/torbi), which [we are working on adding to PyTorch](https://github.com/pytorch/pytorch/issues/121160). Until `torbi` is integrated into PyTorch (or otherwise made pip-installable), it is recommended to use the `dev` branch of `penn`, which uses `torbi` decoding by default, but is not pip-installable. Our paper [_Fine-Grained and Interpretable Neural Speech Editing_](https://www.maxrmorrison.com/sites/promonet/) introduces and demonstrates the efficacy of `torbi` for pitch decoding. 
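If Viterbi decoding is too slow for your use case and `torbi` is not yet available, argmax decoding selects the most likely pitch bin independently for each frame and avoids the cost of Viterbi decoding, typically trading away some accuracy. Below is a minimal sketch of the same inference call using argmax decoding; note that the example above also assumes `import torchaudio`.

```
import torchaudio

import penn

# Load audio
audio, sample_rate = torchaudio.load('test/assets/gershwin.wav')

# Argmax decoding picks the most likely pitch bin independently per frame
pitch, periodicity = penn.from_audio(
    audio,
    sample_rate,
    hopsize=.01,
    fmin=30.,
    fmax=1000.,
    decoder='argmax',
    gpu=None)
```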
97 | 98 | 99 | ### Application programming interface 100 | 101 | #### `penn.from_audio` 102 | 103 | ``` 104 | def from_audio( 105 | audio: torch.Tensor, 106 | sample_rate: int = penn.SAMPLE_RATE, 107 | hopsize: float = penn.HOPSIZE_SECONDS, 108 | fmin: float = penn.FMIN, 109 | fmax: float = penn.FMAX, 110 | checkpoint: Optional[Path] = None, 111 | batch_size: Optional[int] = None, 112 | center: str = 'half-window', 113 | decoder: str = penn.DECODER, 114 | interp_unvoiced_at: Optional[float] = None, 115 | gpu: Optional[int] = None 116 | ) -> Tuple[torch.Tensor, torch.Tensor]: 117 | """Perform pitch and periodicity estimation 118 | 119 | Args: 120 | audio: The audio to extract pitch and periodicity from 121 | sample_rate: The audio sample rate 122 | hopsize: The hopsize in seconds 123 | fmin: The minimum allowable frequency in Hz 124 | fmax: The maximum allowable frequency in Hz 125 | checkpoint: The checkpoint file 126 | batch_size: The number of frames per batch 127 | center: Padding options. One of ['half-window', 'half-hop', 'zero']. 128 | interp_unvoiced_at: Specifies voicing threshold for interpolation 129 | gpu: The index of the gpu to run inference on 130 | 131 | Returns: 132 | pitch: torch.tensor( 133 | shape=(1, int(samples // penn.seconds_to_sample(hopsize)))) 134 | periodicity: torch.tensor( 135 | shape=(1, int(samples // penn.seconds_to_sample(hopsize)))) 136 | """ 137 | ``` 138 | 139 | 140 | #### `penn.from_file` 141 | 142 | ``` 143 | def from_file( 144 | file: Path, 145 | hopsize: float = penn.HOPSIZE_SECONDS, 146 | fmin: float = penn.FMIN, 147 | fmax: float = penn.FMAX, 148 | checkpoint: Optional[Path] = None, 149 | batch_size: Optional[int] = None, 150 | center: str = 'half-window', 151 | decoder: str = penn.DECODER, 152 | interp_unvoiced_at: Optional[float] = None, 153 | gpu: Optional[int] = None 154 | ) -> Tuple[torch.Tensor, torch.Tensor]: 155 | """Perform pitch and periodicity estimation from audio on disk 156 | 157 | Args: 158 | file: The audio file 159 | hopsize: The hopsize in seconds 160 | fmin: The minimum allowable frequency in Hz 161 | fmax: The maximum allowable frequency in Hz 162 | checkpoint: The checkpoint file 163 | batch_size: The number of frames per batch 164 | center: Padding options. One of ['half-window', 'half-hop', 'zero']. 
165 | interp_unvoiced_at: Specifies voicing threshold for interpolation 166 | gpu: The index of the gpu to run inference on 167 | 168 | Returns: 169 | pitch: torch.tensor(shape=(1, int(samples // hopsize))) 170 | periodicity: torch.tensor(shape=(1, int(samples // hopsize))) 171 | """ 172 | ``` 173 | 174 | 175 | #### `penn.from_file_to_file` 176 | 177 | ``` 178 | def from_file_to_file( 179 | file: Path, 180 | output_prefix: Optional[Path] = None, 181 | hopsize: float = penn.HOPSIZE_SECONDS, 182 | fmin: float = penn.FMIN, 183 | fmax: float = penn.FMAX, 184 | checkpoint: Optional[Path] = None, 185 | batch_size: Optional[int] = None, 186 | center: str = 'half-window', 187 | decoder: str = penn.DECODER, 188 | interp_unvoiced_at: Optional[float] = None, 189 | gpu: Optional[int] = None 190 | ) -> None: 191 | """Perform pitch and periodicity estimation from audio on disk and save 192 | 193 | Args: 194 | file: The audio file 195 | output_prefix: The file to save pitch and periodicity without extension 196 | hopsize: The hopsize in seconds 197 | fmin: The minimum allowable frequency in Hz 198 | fmax: The maximum allowable frequency in Hz 199 | checkpoint: The checkpoint file 200 | batch_size: The number of frames per batch 201 | center: Padding options. One of ['half-window', 'half-hop', 'zero']. 202 | interp_unvoiced_at: Specifies voicing threshold for interpolation 203 | gpu: The index of the gpu to run inference on 204 | """ 205 | ``` 206 | 207 | 208 | #### `penn.from_files_to_files` 209 | 210 | ``` 211 | def from_files_to_files( 212 | files: List[Path], 213 | output_prefixes: Optional[List[Path]] = None, 214 | hopsize: float = penn.HOPSIZE_SECONDS, 215 | fmin: float = penn.FMIN, 216 | fmax: float = penn.FMAX, 217 | checkpoint: Optional[Path] = None, 218 | batch_size: Optional[int] = None, 219 | center: str = 'half-window', 220 | decoder: str = penn.DECODER, 221 | interp_unvoiced_at: Optional[float] = None, 222 | num_workers: int = penn.NUM_WORKERS, 223 | gpu: Optional[int] = None 224 | ) -> None: 225 | """Perform pitch and periodicity estimation from files on disk and save 226 | 227 | Args: 228 | files: The audio files 229 | output_prefixes: Files to save pitch and periodicity without extension 230 | hopsize: The hopsize in seconds 231 | fmin: The minimum allowable frequency in Hz 232 | fmax: The maximum allowable frequency in Hz 233 | checkpoint: The checkpoint file 234 | batch_size: The number of frames per batch 235 | center: Padding options. One of ['half-window', 'half-hop', 'zero']. 236 | interp_unvoiced_at: Specifies voicing threshold for interpolation 237 | num_workers: Number of CPU threads for async data I/O 238 | gpu: The index of the gpu to run inference on 239 | """ 240 | ``` 241 | 242 | 243 | ### Command-line interface 244 | 245 | ``` 246 | python -m penn 247 | --files FILES [FILES ...] 248 | [-h] 249 | [--config CONFIG] 250 | [--output_prefixes OUTPUT_PREFIXES [OUTPUT_PREFIXES ...]] 251 | [--hopsize HOPSIZE] 252 | [--fmin FMIN] 253 | [--fmax FMAX] 254 | [--checkpoint CHECKPOINT] 255 | [--batch_size BATCH_SIZE] 256 | [--center {half-window,half-hop,zero}] 257 | [--decoder {argmax,pyin,viterbi}] 258 | [--interp_unvoiced_at INTERP_UNVOICED_AT] 259 | [--num_workers NUM_WORKERS] 260 | [--gpu GPU] 261 | 262 | required arguments: 263 | --files FILES [FILES ...] 264 | The audio files to process 265 | 266 | optional arguments: 267 | -h, --help 268 | show this help message and exit 269 | --config CONFIG 270 | The configuration file. Defaults to using FCNF0++. 
271 | --output_prefixes OUTPUT_PREFIXES [OUTPUT_PREFIXES ...] 272 | The files to save pitch and periodicity without extension. 273 | Defaults to files without extensions. 274 | --hopsize HOPSIZE 275 | The hopsize in seconds. Defaults to 0.01 seconds. 276 | --fmin FMIN 277 | The minimum frequency allowed in Hz. Defaults to 31.0 Hz. 278 | --fmax FMAX 279 | The maximum frequency allowed in Hz. Defaults to 1984.0 Hz. 280 | --checkpoint CHECKPOINT 281 | The model checkpoint file. Defaults to ./penn/assets/checkpoints/fcnf0++.pt. 282 | --batch_size BATCH_SIZE 283 | The number of frames per batch. Defaults to 2048. 284 | --center {half-window,half-hop,zero} 285 | Padding options 286 | --decoder {argmax,pyin,viterbi} 287 | Posteriorgram decoder 288 | --interp_unvoiced_at INTERP_UNVOICED_AT 289 | Specifies voicing threshold for interpolation. Defaults to 0.1625. 290 | --num_workers 291 | Number of CPU threads for async data I/O 292 | --gpu GPU 293 | The index of the gpu to perform inference on. Defaults to CPU. 294 | ``` 295 | 296 | 297 | ## Training 298 | 299 | ### Download 300 | 301 | `python -m penn.data.download` 302 | 303 | Downloads and uncompresses the `mdb` and `ptdb` datasets used for training. 304 | 305 | 306 | ### Preprocess 307 | 308 | `python -m penn.data.preprocess --config ` 309 | 310 | Converts each dataset to a common format on disk ready for training. You 311 | can optionally pass a configuration file to override the default configuration. 312 | 313 | 314 | ### Partition 315 | 316 | `python -m penn.partition` 317 | 318 | Generates `train`, `valid`, and `test` partitions for `mdb` and `ptdb`. 319 | Partitioning is deterministic given the same random seed. You do not need to 320 | run this step, as the original partitions are saved in 321 | `penn/assets/partitions`. 322 | 323 | 324 | ### Train 325 | 326 | `python -m penn.train --config --gpu ` 327 | 328 | Trains a model according to a given configuration on the `mdb` and `ptdb` 329 | datasets. 330 | 331 | 332 | ### Monitor 333 | 334 | You can monitor training via `tensorboard`. 335 | 336 | ``` 337 | tensorboard --logdir runs/ --port --load_fast true 338 | ``` 339 | 340 | To use the `torchutil` notification system to receive notifications for long 341 | jobs (download, preprocess, train, and evaluate), set the 342 | `PYTORCH_NOTIFICATION_URL` environment variable to a supported webhook as 343 | explained in [the Apprise documentation](https://pypi.org/project/apprise/). 344 | 345 | 346 | ## Evaluation 347 | 348 | ### Evaluate 349 | 350 | ``` 351 | python -m penn.evaluate \ 352 | --config \ 353 | --checkpoint \ 354 | --gpu 355 | ``` 356 | 357 | Evaluate a model. `` is the checkpoint file to evaluate and `` 358 | is the GPU index. 359 | 360 | 361 | ### Plot 362 | 363 | ``` 364 | python -m penn.plot.density \ 365 | --config \ 366 | --true_datasets \ 367 | --inference_datasets \ 368 | --output_file \ 369 | --checkpoint \ 370 | --gpu 371 | ``` 372 | 373 | Plot the data distribution and inferred distribution for a given dataset and 374 | save to a jpg file. 375 | 376 | ``` 377 | python -m penn.plot.logits \ 378 | --config \ 379 | --audio_file \ 380 | --output_file \ 381 | --checkpoint \ 382 | --gpu 383 | ``` 384 | 385 | Plot the pitch posteriorgram of an audio file and save to a jpg file. 
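For example, a hypothetical invocation of `penn.plot.logits` (the configuration path and checkpoint filename below are placeholders for your own training run):

```
python -m penn.plot.logits \
    --config config/fcnf0++.py \
    --audio_file test/assets/gershwin.wav \
    --output_file logits.jpg \
    --checkpoint runs/fcnf0++/00250000.pt \
    --gpu 0
```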
386 | 387 | ``` 388 | python -m penn.plot.threshold \ 389 | --names \ 390 | --evaluations \ 391 | --output_file 392 | ``` 393 | 394 | Plot the periodicity performance (voiced/unvoiced F1) over mdb and ptdb as a 395 | function of the voiced/unvoiced threshold. `names` are the plot labels to give 396 | each evaluation. `evaluations` are the names of the evaluations to plot. 397 | 398 | 399 | ## Citation 400 | 401 | ### IEEE 402 | M. Morrison, C. Hsieh, N. Pruyne, and B. Pardo, "Cross-domain Neural Pitch and Periodicity Estimation," arXiv preprint arXiv:2301.12258, 2023. 403 | 404 | 405 | ### BibTex 406 | 407 | ``` 408 | @inproceedings{morrison2023cross, 409 | title={Cross-domain Neural Pitch and Periodicity Estimation}, 410 | author={Morrison, Max and Hsieh, Caedon and Pruyne, Nathan and Pardo, Bryan}, 411 | booktitle={arXiv preprint arXiv:2301.12258}, 412 | year={2023} 413 | } 414 | -------------------------------------------------------------------------------- /penn/evaluate/core.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import tempfile 4 | import time 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch 9 | import torchutil 10 | 11 | import penn 12 | 13 | 14 | ############################################################################### 15 | # Evaluate 16 | ############################################################################### 17 | 18 | 19 | @torchutil.notify('evaluate') 20 | def datasets( 21 | datasets=penn.EVALUATION_DATASETS, 22 | checkpoint=None, 23 | gpu=None): 24 | """Perform evaluation""" 25 | # Make output directory 26 | directory = penn.EVAL_DIR / penn.CONFIG 27 | directory.mkdir(exist_ok=True, parents=True) 28 | 29 | # Evaluate pitch estimation quality and save logits 30 | pitch_quality(directory, datasets, checkpoint, gpu) 31 | 32 | with tempfile.TemporaryDirectory() as directory: 33 | directory = Path(directory) 34 | 35 | # Get periodicity methods 36 | if penn.METHOD == 'dio': 37 | periodicity_fns = {} 38 | elif penn.METHOD == 'pyin': 39 | periodicity_fns = {'sum': penn.periodicity.sum} 40 | else: 41 | periodicity_fns = { 42 | 'entropy': penn.periodicity.entropy, 43 | 'max': penn.periodicity.max} 44 | 45 | # Evaluate periodicity 46 | periodicity_results = {} 47 | for key, val in periodicity_fns.items(): 48 | periodicity_results[key] = periodicity_quality( 49 | directory, 50 | val, 51 | datasets, 52 | checkpoint=checkpoint, 53 | gpu=gpu) 54 | 55 | # Write periodicity results 56 | file = penn.EVAL_DIR / penn.CONFIG / 'periodicity.json' 57 | with open(file, 'w') as file: 58 | json.dump(periodicity_results, file, indent=4) 59 | 60 | # Perform benchmarking on CPU 61 | benchmark_results = {'cpu': benchmark(datasets, checkpoint)} 62 | 63 | # PYIN and DIO do not have GPU support 64 | if penn.METHOD not in ['dio', 'pyin']: 65 | benchmark_results ['gpu'] = benchmark(datasets, checkpoint, gpu) 66 | 67 | # Write benchmarking information 68 | with open(penn.EVAL_DIR / penn.CONFIG / 'time.json', 'w') as file: 69 | json.dump(benchmark_results, file, indent=4) 70 | 71 | 72 | ############################################################################### 73 | # Individual evaluations 74 | ############################################################################### 75 | 76 | 77 | def benchmark( 78 | datasets=penn.EVALUATION_DATASETS, 79 | checkpoint=None, 80 | gpu=None): 81 | """Perform benchmarking""" 82 | # Get audio files 83 | dataset_stems = { 84 | dataset: 
penn.load.partition(dataset)['test'] for dataset in datasets} 85 | files = [ 86 | penn.CACHE_DIR / dataset / f'{stem}.wav' 87 | for dataset, stems in dataset_stems.items() 88 | for stem in stems] 89 | 90 | # Setup temporary directory 91 | with tempfile.TemporaryDirectory() as directory: 92 | directory = Path(directory) 93 | 94 | # Create output directories 95 | for dataset in datasets: 96 | (directory / dataset).mkdir(exist_ok=True, parents=True) 97 | 98 | # Get output prefixes 99 | output_prefixes = [ 100 | directory / file.parent.name / file.stem for file in files] 101 | 102 | # Start benchmarking 103 | penn.BENCHMARK = True 104 | torchutil.time.reset() 105 | start_time = time.time() 106 | 107 | # Infer to temporary storage 108 | if penn.METHOD == 'penn': 109 | penn.from_files_to_files( 110 | files, 111 | output_prefixes, 112 | checkpoint=checkpoint, 113 | batch_size=penn.EVALUATION_BATCH_SIZE, 114 | center='half-hop', 115 | gpu=gpu) 116 | 117 | elif penn.METHOD == 'torchcrepe': 118 | 119 | import torchcrepe 120 | 121 | # Get output file paths 122 | pitch_files = [ 123 | file.parent / f'{file.stem}-pitch.pt' 124 | for file in output_prefixes] 125 | periodicity_files = [ 126 | file.parent / f'{file.stem}-periodicity.pt' 127 | for file in output_prefixes] 128 | 129 | # Infer 130 | # Note - this does not perform the correct padding, but suffices 131 | # for benchmarking purposes 132 | torchcrepe.predict_from_files_to_files( 133 | files, 134 | pitch_files, 135 | output_periodicity_files=periodicity_files, 136 | hop_length=penn.HOPSIZE, 137 | decoder=torchcrepe.decode.argmax, 138 | batch_size=penn.EVALUATION_BATCH_SIZE, 139 | device='cpu' if gpu is None else f'cuda:{gpu}', 140 | pad=False) 141 | elif penn.METHOD == 'dio': 142 | penn.dsp.dio.from_files_to_files(files, output_prefixes) 143 | elif penn.METHOD == 'pyin': 144 | penn.dsp.pyin.from_files_to_files(files, output_prefixes) 145 | 146 | # Turn off benchmarking 147 | penn.BENCHMARK = False 148 | 149 | # Get benchmarking information 150 | benchmark = torchutil.time.results() 151 | benchmark['total'] = time.time() - start_time 152 | 153 | # Get total number of samples and seconds in test data 154 | samples = sum([ 155 | len(np.load(file.parent / f'{file.stem}-audio.npy', mmap_mode='r')) 156 | for file in files]) 157 | seconds = penn.convert.samples_to_seconds(samples) 158 | 159 | # Format benchmarking results 160 | return { 161 | key: { 162 | 'real-time-factor': value / seconds, 163 | 'samples': samples, 164 | 'samples-per-second': samples / value, 165 | 'seconds': value 166 | } for key, value in benchmark.items()} 167 | 168 | 169 | def periodicity_quality( 170 | directory, 171 | periodicity_fn, 172 | datasets=penn.EVALUATION_DATASETS, 173 | steps=8, 174 | checkpoint=None, 175 | gpu=None): 176 | """Fine-grained periodicity estimation quality evaluation""" 177 | device = torch.device('cpu' if gpu is None else f'cuda:{gpu}') 178 | 179 | # Evaluate each dataset 180 | for dataset in datasets: 181 | 182 | # Create output directory 183 | (directory / dataset).mkdir(exist_ok=True, parents=True) 184 | 185 | # Iterate over validation set 186 | for audio, _, _, voiced, stem in torchutil.iterator( 187 | penn.data.loader([dataset], 'valid', True), 188 | f'Evaluating {penn.CONFIG} periodicity quality on {dataset}' 189 | ): 190 | 191 | if penn.METHOD == 'penn': 192 | 193 | # Accumulate logits 194 | logits = [] 195 | 196 | # Preprocess audio 197 | batch_size = \ 198 | None if gpu is None else penn.EVALUATION_BATCH_SIZE 199 | for frames in penn.preprocess( 
200 | audio[0], 201 | penn.SAMPLE_RATE, 202 | batch_size=batch_size, 203 | center='half-hop' 204 | ): 205 | 206 | # Copy to device 207 | frames = frames.to(device) 208 | 209 | # Infer 210 | batch_logits = penn.infer(frames, checkpoint).detach() 211 | 212 | # Accumulate logits 213 | logits.append(batch_logits) 214 | 215 | logits = torch.cat(logits) 216 | 217 | elif penn.METHOD == 'torchcrepe': 218 | 219 | import torchcrepe 220 | 221 | # Accumulate logits 222 | logits = [] 223 | 224 | # Postprocessing breaks gradients, so just don't compute them 225 | with torch.no_grad(): 226 | 227 | # Preprocess audio 228 | batch_size = \ 229 | None if gpu is None else penn.EVALUATION_BATCH_SIZE 230 | pad = (penn.WINDOW_SIZE - penn.HOPSIZE) // 2 231 | generator = torchcrepe.preprocess( 232 | torch.nn.functional.pad(audio, (pad, pad))[0], 233 | penn.SAMPLE_RATE, 234 | penn.HOPSIZE, 235 | batch_size, 236 | device, 237 | False) 238 | for frames in generator: 239 | 240 | # Infer independent probabilities for each pitch bin 241 | batch_logits = torchcrepe.infer( 242 | frames.to(device))[:, :, None] 243 | 244 | # Accumulate logits 245 | logits.append(batch_logits) 246 | logits = torch.cat(logits) 247 | 248 | elif penn.METHOD == 'pyin': 249 | 250 | # Pad 251 | pad = (penn.WINDOW_SIZE - penn.HOPSIZE) // 2 252 | audio = torch.nn.functional.pad(audio, (pad, pad)) 253 | 254 | # Infer 255 | logits = penn.dsp.pyin.infer(audio[0]) 256 | 257 | # Save to temporary storage 258 | file = directory / dataset / f'{stem[0]}-logits.pt' 259 | torch.save(logits, file) 260 | 261 | # Default values 262 | best_threshold = .5 263 | best_value = 0. 264 | stepsize = .05 265 | 266 | # Setup metrics 267 | metrics = penn.evaluate.metrics.F1() 268 | 269 | step = 0 270 | while step < steps: 271 | 272 | for dataset in datasets: 273 | 274 | # Setup loader 275 | loader = penn.data.loader([dataset], 'valid', True) 276 | 277 | # Iterate over validation set 278 | for _, _, _, voiced, stem in loader: 279 | 280 | # Load logits 281 | logits = torch.load(directory / dataset / f'{stem[0]}-logits.pt') 282 | 283 | # Decode periodicity 284 | periodicity = periodicity_fn(logits.to(device)).T 285 | 286 | # Update metrics 287 | metrics.update(periodicity, voiced.to(device)) 288 | 289 | # Get best performing threshold 290 | results = { 291 | key: val for key, val in metrics().items() if key.startswith('f1') 292 | and not math.isnan(val)} 293 | key = max(results, key=results.get) 294 | threshold = float(key[3:]) 295 | value = results[key] 296 | if value > best_value: 297 | best_value = value 298 | best_threshold = threshold 299 | 300 | # Reinitialize metrics with new thresholds 301 | metrics = penn.evaluate.metrics.F1( 302 | [best_threshold - stepsize, best_threshold + stepsize]) 303 | 304 | # Binary search for optimal threshold 305 | stepsize /= 2 306 | step += 1 307 | 308 | # Setup metrics with optimal threshold 309 | metrics = penn.evaluate.metrics.F1([best_threshold]) 310 | 311 | # Setup test loader 312 | loader = penn.data.loader(datasets, 'test') 313 | 314 | # Iterate over test set 315 | for audio, _, _, voiced, stem in loader: 316 | 317 | if penn.METHOD == 'penn': 318 | 319 | # Accumulate logits 320 | logits = [] 321 | 322 | # Preprocess audio 323 | batch_size = \ 324 | None if gpu is None else penn.EVALUATION_BATCH_SIZE 325 | for frames in penn.preprocess( 326 | audio[0], 327 | penn.SAMPLE_RATE, 328 | batch_size=batch_size, 329 | center='half-hop' 330 | ): 331 | 332 | # Copy to device 333 | frames = frames.to(device) 334 | 335 | # Infer 336 | 
batch_logits = penn.infer(frames, checkpoint).detach() 337 | 338 | # Accumulate logits 339 | logits.append(batch_logits) 340 | 341 | logits = torch.cat(logits) 342 | 343 | elif penn.METHOD == 'torchcrepe': 344 | 345 | import torchcrepe 346 | 347 | # Accumulate logits 348 | logits = [] 349 | 350 | # Postprocessing breaks gradients, so just don't compute them 351 | with torch.no_grad(): 352 | 353 | # Preprocess audio 354 | batch_size = \ 355 | None if gpu is None else penn.EVALUATION_BATCH_SIZE 356 | pad = (penn.WINDOW_SIZE - penn.HOPSIZE) // 2 357 | generator = torchcrepe.preprocess( 358 | torch.nn.functional.pad(audio, (pad, pad))[0], 359 | penn.SAMPLE_RATE, 360 | penn.HOPSIZE, 361 | batch_size, 362 | device, 363 | False) 364 | for frames in generator: 365 | 366 | # Infer independent probabilities for each pitch bin 367 | batch_logits = torchcrepe.infer( 368 | frames.to(device))[:, :, None] 369 | 370 | # Accumulate logits 371 | logits.append(batch_logits) 372 | logits = torch.cat(logits) 373 | 374 | elif penn.METHOD == 'pyin': 375 | 376 | # Pad 377 | pad = (penn.WINDOW_SIZE - penn.HOPSIZE) // 2 378 | audio = torch.nn.functional.pad(audio, (pad, pad)) 379 | 380 | # Infer 381 | logits = penn.dsp.pyin.infer(audio[0]).to(device) 382 | 383 | # Decode periodicity 384 | periodicity = periodicity_fn(logits).T 385 | 386 | # Update metrics 387 | metrics.update(periodicity, voiced.to(device)) 388 | 389 | # Get F1 score on test set 390 | score = metrics()[f'f1-{best_threshold:.6f}'] 391 | 392 | return {'threshold': best_threshold, 'f1': score} 393 | 394 | 395 | def pitch_quality( 396 | directory, 397 | datasets=penn.EVALUATION_DATASETS, 398 | checkpoint=None, 399 | gpu=None): 400 | """Evaluate pitch estimation quality""" 401 | device = torch.device('cpu' if gpu is None else f'cuda:{gpu}') 402 | 403 | # Containers for results 404 | overall, granular = {}, {} 405 | 406 | # Get metric class 407 | metric_fn = ( 408 | penn.evaluate.PitchMetrics if penn.METHOD == 'dio' else 409 | penn.evaluate.Metrics) 410 | 411 | # Per-file metrics 412 | file_metrics = metric_fn() 413 | 414 | # Per-dataset metrics 415 | dataset_metrics = metric_fn() 416 | 417 | # Aggregate metrics over all datasets 418 | aggregate_metrics = metric_fn() 419 | 420 | # Evaluate each dataset 421 | for dataset in datasets: 422 | 423 | # Reset dataset metrics 424 | dataset_metrics.reset() 425 | 426 | # Iterate over test set 427 | for audio, bins, pitch, voiced, stem in torchutil.iterator( 428 | penn.data.loader([dataset], 'test'), 429 | f'Evaluating {penn.CONFIG} pitch quality on {dataset}' 430 | ): 431 | 432 | # Reset file metrics 433 | file_metrics.reset() 434 | 435 | if penn.METHOD == 'penn': 436 | 437 | # Accumulate logits 438 | logits = [] 439 | 440 | # Preprocess audio 441 | batch_size = \ 442 | None if gpu is None else penn.EVALUATION_BATCH_SIZE 443 | for i, frames in enumerate( 444 | penn.preprocess( 445 | audio[0], 446 | penn.SAMPLE_RATE, 447 | batch_size=batch_size, 448 | center='half-hop' 449 | ) 450 | ): 451 | 452 | # Copy to device 453 | frames = frames.to(device) 454 | 455 | # Slice features and copy to GPU 456 | start = i * penn.EVALUATION_BATCH_SIZE 457 | end = start + len(frames) 458 | batch_bins = bins[:, start:end].to(device) 459 | batch_pitch = pitch[:, start:end].to(device) 460 | batch_voiced = voiced[:, start:end].to(device) 461 | 462 | # Infer 463 | batch_logits = penn.infer(frames, checkpoint).detach() 464 | 465 | # Update metrics 466 | args = ( 467 | batch_logits, 468 | batch_bins, 469 | batch_pitch, 470 | batch_voiced) 
471 | file_metrics.update(*args) 472 | dataset_metrics.update(*args) 473 | aggregate_metrics.update(*args) 474 | 475 | # Accumulate logits 476 | logits.append(batch_logits) 477 | logits = torch.cat(logits) 478 | 479 | elif penn.METHOD == 'torchcrepe': 480 | 481 | import torchcrepe 482 | 483 | # Accumulate logits 484 | logits = [] 485 | 486 | # Postprocessing breaks gradients, so just don't compute them 487 | with torch.no_grad(): 488 | 489 | # Preprocess audio 490 | batch_size = \ 491 | None if gpu is None else penn.EVALUATION_BATCH_SIZE 492 | pad = (penn.WINDOW_SIZE - penn.HOPSIZE) // 2 493 | generator = torchcrepe.preprocess( 494 | torch.nn.functional.pad(audio, (pad, pad))[0], 495 | penn.SAMPLE_RATE, 496 | penn.HOPSIZE, 497 | batch_size, 498 | device, 499 | False) 500 | for i, frames in enumerate(generator): 501 | 502 | # Infer independent probabilities for each pitch bin 503 | batch_logits = torchcrepe.infer(frames.to(device))[:, :, None] 504 | 505 | # Slice features and copy to GPU 506 | start = i * penn.EVALUATION_BATCH_SIZE 507 | end = start + frames.shape[0] 508 | batch_bins = bins[:, start:end].to(device) 509 | batch_pitch = pitch[:, start:end].to(device) 510 | batch_voiced = voiced[:, start:end].to(device) 511 | 512 | # Update metrics 513 | args = ( 514 | batch_logits, 515 | batch_bins, 516 | batch_pitch, 517 | batch_voiced) 518 | file_metrics.update(*args) 519 | dataset_metrics.update(*args) 520 | aggregate_metrics.update(*args) 521 | 522 | # Accumulate logits 523 | logits.append(batch_logits) 524 | logits = torch.cat(logits) 525 | 526 | elif penn.METHOD == 'dio': 527 | 528 | # Pad 529 | pad = (penn.WINDOW_SIZE - penn.HOPSIZE) // 2 530 | audio = torch.nn.functional.pad(audio, (pad, pad)) 531 | 532 | # Infer 533 | predicted = penn.dsp.dio.from_audio(audio[0]) 534 | 535 | # Update metrics 536 | args = predicted, pitch, voiced 537 | file_metrics.update(*args) 538 | dataset_metrics.update(*args) 539 | aggregate_metrics.update(*args) 540 | 541 | elif penn.METHOD == 'pyin': 542 | 543 | # Pad 544 | pad = (penn.WINDOW_SIZE - penn.HOPSIZE) // 2 545 | audio = torch.nn.functional.pad(audio, (pad, pad)) 546 | 547 | # Infer 548 | logits = penn.dsp.pyin.infer(audio[0]) 549 | 550 | # Update metrics 551 | args = logits, bins, pitch, voiced 552 | file_metrics.update(*args) 553 | dataset_metrics.update(*args) 554 | aggregate_metrics.update(*args) 555 | 556 | # Copy results 557 | granular[f'{dataset}/{stem[0]}'] = file_metrics() 558 | overall[dataset] = dataset_metrics() 559 | overall['aggregate'] = aggregate_metrics() 560 | 561 | # Write to json files 562 | directory = penn.EVAL_DIR / penn.CONFIG 563 | with open(directory / 'overall.json', 'w') as file: 564 | json.dump(overall, file, indent=4) 565 | with open(directory / 'granular.json', 'w') as file: 566 | json.dump(granular, file, indent=4) 567 | -------------------------------------------------------------------------------- /penn/core.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import functools 3 | import math 4 | import time 5 | from pathlib import Path 6 | from typing import List, Optional, Tuple 7 | 8 | import huggingface_hub 9 | import torch 10 | import torch.multiprocessing as mp 11 | import torchaudio 12 | import torchutil 13 | 14 | import penn 15 | 16 | 17 | ############################################################################### 18 | # Pitch and periodicity estimation 19 | ############################################################################### 20 | 21 | 
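# The public entry points below share one pipeline: preprocess() slices audio
# into batches of model input frames, infer() runs the cached network to
# produce pitch-bin logits, and postprocess() decodes those logits into pitch
# and periodicity using the selected decoder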
22 | def from_audio( 23 | audio: torch.Tensor, 24 | sample_rate: int = penn.SAMPLE_RATE, 25 | hopsize: float = penn.HOPSIZE_SECONDS, 26 | fmin: float = penn.FMIN, 27 | fmax: float = penn.FMAX, 28 | checkpoint: Optional[Path] = None, 29 | batch_size: Optional[int] = None, 30 | center: str = 'half-window', 31 | decoder: str = penn.DECODER, 32 | interp_unvoiced_at: Optional[float] = None, 33 | gpu: Optional[int] = None 34 | ) -> Tuple[torch.Tensor, torch.Tensor]: 35 | """Perform pitch and periodicity estimation 36 | 37 | Args: 38 | audio: The audio to extract pitch and periodicity from 39 | sample_rate: The audio sample rate 40 | hopsize: The hopsize in seconds 41 | fmin: The minimum allowable frequency in Hz 42 | fmax: The maximum allowable frequency in Hz 43 | checkpoint: The checkpoint file 44 | batch_size: The number of frames per batch 45 | center: Padding options. One of ['half-window', 'half-hop', 'zero']. 46 | decoder: Posteriorgram decoder. One of ['argmax', 'pyin', 'viterbi']. 47 | interp_unvoiced_at: Specifies voicing threshold for interpolation 48 | gpu: The index of the gpu to run inference on 49 | 50 | Returns: 51 | pitch: torch.tensor( 52 | shape=(1, int(samples // penn.seconds_to_sample(hopsize)))) 53 | periodicity: torch.tensor( 54 | shape=(1, int(samples // penn.seconds_to_sample(hopsize)))) 55 | """ 56 | device = 'cpu' if gpu is None else f'cuda:{gpu}' 57 | 58 | # Storage for batching 59 | if batch_size is not None: 60 | if decoder == 'argmax': 61 | pitch, periodicity = [], [] 62 | else: 63 | logits = [] 64 | 65 | # Preprocess audio 66 | for frames in preprocess( 67 | audio, 68 | sample_rate, 69 | hopsize, 70 | batch_size, 71 | center 72 | ): 73 | 74 | # Copy to device 75 | with torchutil.time.context('copy-to'): 76 | frames = frames.to(device) 77 | 78 | # Infer 79 | inferred = infer(frames, checkpoint).detach() 80 | 81 | if batch_size is None: 82 | 83 | # Postprocess full file 84 | with torchutil.time.context('postprocess'): 85 | _, pitch, periodicity = postprocess( 86 | inferred, 87 | fmin, 88 | fmax, 89 | decoder) 90 | 91 | elif decoder == 'argmax': 92 | 93 | # Postprocess partial file 94 | with torchutil.time.context('postprocess'): 95 | result = postprocess(inferred, fmin, fmax, decoder) 96 | pitch.append(result[1]) 97 | periodicity.append(result[2]) 98 | 99 | else: 100 | 101 | # Save logits off GPU for later decoding 102 | logits.append(inferred.cpu()) 103 | 104 | if batch_size is not None: 105 | 106 | if decoder == 'argmax': 107 | 108 | # Concatenate results 109 | pitch = torch.cat(pitch, 1) 110 | periodicity = torch.cat(periodicity, 1) 111 | 112 | else: 113 | 114 | # Postprocess full file 115 | _, pitch, periodicity = postprocess( 116 | torch.cat(logits, 0).to(device), 117 | fmin, 118 | fmax, 119 | decoder) 120 | 121 | # Maybe interpolate unvoiced regions 122 | if interp_unvoiced_at is not None: 123 | pitch = penn.voicing.interpolate( 124 | pitch, 125 | periodicity, 126 | interp_unvoiced_at) 127 | 128 | return pitch, periodicity 129 | 130 | 131 | def from_file( 132 | file: Path, 133 | hopsize: float = penn.HOPSIZE_SECONDS, 134 | fmin: float = penn.FMIN, 135 | fmax: float = penn.FMAX, 136 | checkpoint: Optional[Path] = None, 137 | batch_size: Optional[int] = None, 138 | center: str = 'half-window', 139 | decoder: str = penn.DECODER, 140 | interp_unvoiced_at: Optional[float] = None, 141 | gpu: Optional[int] = None 142 | ) -> Tuple[torch.Tensor, torch.Tensor]: 143 | """Perform pitch and periodicity estimation from audio on disk 144 | 145 | Args: 146 | file: The audio 
file 147 | hopsize: The hopsize in seconds 148 | fmin: The minimum allowable frequency in Hz 149 | fmax: The maximum allowable frequency in Hz 150 | checkpoint: The checkpoint file 151 | batch_size: The number of frames per batch 152 | center: Padding options. One of ['half-window', 'half-hop', 'zero']. 153 | decoder: Posteriorgram decoder. One of ['argmax', 'pyin', 'viterbi']. 154 | interp_unvoiced_at: Specifies voicing threshold for interpolation 155 | gpu: The index of the gpu to run inference on 156 | 157 | Returns: 158 | pitch: torch.tensor(shape=(1, int(samples // hopsize))) 159 | periodicity: torch.tensor(shape=(1, int(samples // hopsize))) 160 | """ 161 | # Load audio 162 | with torchutil.time.context('load'): 163 | audio, sample_rate = torchaudio.load(file) 164 | 165 | # Inference 166 | return from_audio( 167 | audio, 168 | sample_rate, 169 | hopsize, 170 | fmin, 171 | fmax, 172 | checkpoint, 173 | batch_size, 174 | center, 175 | decoder, 176 | interp_unvoiced_at, 177 | gpu) 178 | 179 | 180 | def from_file_to_file( 181 | file: Path, 182 | output_prefix: Optional[Path] = None, 183 | hopsize: float = penn.HOPSIZE_SECONDS, 184 | fmin: float = penn.FMIN, 185 | fmax: float = penn.FMAX, 186 | checkpoint: Optional[Path] = None, 187 | batch_size: Optional[int] = None, 188 | center: str = 'half-window', 189 | decoder: str = penn.DECODER, 190 | interp_unvoiced_at: Optional[float] = None, 191 | gpu: Optional[int] = None 192 | ) -> None: 193 | """Perform pitch and periodicity estimation from audio on disk and save 194 | 195 | Args: 196 | file: The audio file 197 | output_prefix: The file to save pitch and periodicity without extension 198 | hopsize: The hopsize in seconds 199 | fmin: The minimum allowable frequency in Hz 200 | fmax: The maximum allowable frequency in Hz 201 | checkpoint: The checkpoint file 202 | batch_size: The number of frames per batch 203 | center: Padding options. One of ['half-window', 'half-hop', 'zero']. 204 | decoder: Posteriorgram decoder. One of ['argmax', 'pyin', 'viterbi']. 
205 | interp_unvoiced_at: Specifies voicing threshold for interpolation 206 | gpu: The index of the gpu to run inference on 207 | """ 208 | # Inference 209 | pitch, periodicity = from_file( 210 | file, 211 | hopsize, 212 | fmin, 213 | fmax, 214 | checkpoint, 215 | batch_size, 216 | center, 217 | decoder, 218 | interp_unvoiced_at, 219 | gpu) 220 | 221 | # Move to cpu 222 | with torchutil.time.context('copy-from'): 223 | pitch, periodicity = pitch.cpu(), periodicity.cpu() 224 | 225 | # Save to disk 226 | with torchutil.time.context('save'): 227 | 228 | # Maybe use same filename with new extension 229 | if output_prefix is None: 230 | output_prefix = file.parent / file.stem 231 | 232 | # Save 233 | torch.save(pitch, f'{output_prefix}-pitch.pt') 234 | torch.save(periodicity, f'{output_prefix}-periodicity.pt') 235 | 236 | 237 | def from_files_to_files( 238 | files: List[Path], 239 | output_prefixes: Optional[List[Path]] = None, 240 | hopsize: float = penn.HOPSIZE_SECONDS, 241 | fmin: float = penn.FMIN, 242 | fmax: float = penn.FMAX, 243 | checkpoint: Optional[Path] = None, 244 | batch_size: Optional[int] = None, 245 | center: str = 'half-window', 246 | decoder: str = penn.DECODER, 247 | interp_unvoiced_at: Optional[float] = None, 248 | num_workers: int = 0, 249 | gpu: Optional[int] = None 250 | ) -> None: 251 | """Perform pitch and periodicity estimation from files on disk and save 252 | 253 | Args: 254 | files: The audio files 255 | output_prefixes: Files to save pitch and periodicity without extension 256 | hopsize: The hopsize in seconds 257 | fmin: The minimum allowable frequency in Hz 258 | fmax: The maximum allowable frequency in Hz 259 | checkpoint: The checkpoint file 260 | batch_size: The number of frames per batch 261 | center: Padding options. One of ['half-window', 'half-hop', 'zero']. 262 | decoder: Posteriorgram decoder. One of ['argmax', 'pyin', 'viterbi']. 
263 | interp_unvoiced_at: Specifies voicing threshold for interpolation 264 | num_workers: Number of CPU threads for async data I/O 265 | gpu: The index of the gpu to run inference on 266 | """ 267 | # Maybe use default output filenames 268 | if output_prefixes is None: 269 | output_prefixes = [file.parent / file.stem for file in files] 270 | 271 | # Single-threaded 272 | if num_workers == 0: 273 | 274 | # Iterate over files 275 | for i, (file, output_prefix) in torchutil.iterator( 276 | enumerate(zip(files, output_prefixes)), 277 | f'{penn.CONFIG}', 278 | total=len(files) 279 | ): 280 | 281 | # Infer 282 | from_file_to_file( 283 | file, 284 | output_prefix, 285 | hopsize, 286 | fmin, 287 | fmax, 288 | checkpoint, 289 | batch_size, 290 | center, 291 | decoder, 292 | interp_unvoiced_at, 293 | gpu) 294 | 295 | # Multi-threaded 296 | else: 297 | 298 | # Initialize multi-threaded dataloader 299 | loader = inference_loader( 300 | files, 301 | hopsize, 302 | batch_size, 303 | center, 304 | int(math.ceil(num_workers / 2))) 305 | 306 | # Maintain file correspondence 307 | output_prefixes = { 308 | file: output_prefix 309 | for file, output_prefix in zip(files, output_prefixes)} 310 | 311 | # Setup multiprocessing 312 | futures = [] 313 | pool = mp.get_context('spawn').Pool(max(1, num_workers // 2)) 314 | 315 | # Setup progress bar 316 | progress = torchutil.iterator( 317 | range(len(files)), 318 | penn.CONFIG, 319 | total=len(files)) 320 | 321 | try: 322 | 323 | device = 'cpu' if gpu is None else f'cuda:{gpu}' 324 | 325 | # Track residual to fill batch 326 | residual_files = [] 327 | residual_frames = torch.zeros((0, 1, 1024)) 328 | residual_lengths = torch.zeros((0,), dtype=torch.long) 329 | 330 | # Storage for batching within files 331 | if batch_size is not None: 332 | if decoder == 'argmax': 333 | pitch, periodicity = torch.zeros((1, 0)), torch.zeros((1, 0)) 334 | else: 335 | logits = torch.zeros((1, 0, 0)) 336 | 337 | # Iterate over data 338 | num_inferred_unsaved = 0 339 | for frames, lengths, input_files in loader: 340 | 341 | # Prepend residual 342 | if residual_files: 343 | frames = torch.cat((residual_frames, frames), dim=0) 344 | lengths = torch.cat((residual_lengths, lengths)) 345 | input_files = residual_files + input_files 346 | 347 | i = 0 348 | while batch_size is None or i + batch_size <= len(frames): 349 | 350 | # Copy to device 351 | size = len(frames) if batch_size is None else batch_size 352 | batch_frames = frames[i:i + size].to(device) 353 | 354 | # Infer 355 | inferred = infer(batch_frames, checkpoint).detach() 356 | i += len(batch_frames) 357 | num_inferred_unsaved += len(batch_frames) 358 | 359 | if batch_size is None: 360 | 361 | # Postprocess full file 362 | _, pitch, periodicity = postprocess( 363 | inferred, 364 | fmin, 365 | fmax, 366 | decoder) 367 | break 368 | 369 | elif decoder == 'argmax': 370 | 371 | # Postprocess partial file 372 | results = postprocess(inferred, fmin, fmax, decoder) 373 | pitch = torch.cat((pitch, results[1].cpu()), dim=1) 374 | periodicity = torch.cat( 375 | (periodicity, results[2].cpu()), 376 | dim=1) 377 | 378 | else: 379 | 380 | # Save logits for later decoding 381 | # NOTE - This differs from from_audio and does not 382 | # handle large files that do not fit on GPU. 383 | # However, it saves a GPU -> CPU -> GPU copy. 
384 | logits = torch.cat((logits, inferred), dim=0) 385 | 386 | # Save to disk 387 | j, k = 0, 0 388 | for length, file in zip(lengths, input_files): 389 | 390 | # Slice and save in another process 391 | if j + length <= num_inferred_unsaved: 392 | 393 | if batch_size is not None: 394 | 395 | if decoder == 'argmax': 396 | 397 | # Slice results 398 | save_pitch = pitch[:, j:j + length] 399 | save_periodicity = periodicity[:, j:j + length] 400 | 401 | else: 402 | 403 | # Postprocess full file 404 | _, save_pitch, save_periodicity = postprocess( 405 | logits[j:j + length], 406 | fmin, 407 | fmax, 408 | decoder) 409 | 410 | # Async save 411 | futures.append( 412 | pool.apply_async( 413 | save_worker, 414 | args=( 415 | output_prefixes[file], 416 | save_pitch, 417 | save_periodicity, 418 | interp_unvoiced_at))) 419 | while len(futures) > 100: 420 | futures = [f for f in futures if not f.ready()] 421 | time.sleep(.1) 422 | 423 | j += length 424 | k += 1 425 | progress.update() 426 | else: 427 | break 428 | 429 | # Setup residual for next iteration 430 | num_inferred_unsaved -= j 431 | pitch = pitch[:, j:] 432 | periodicity = periodicity[:, j:] 433 | logits = logits[j:] 434 | residual_files = input_files[k:] 435 | residual_lengths = lengths[k:] 436 | residual_frames = frames[i:] 437 | 438 | # Handle final files 439 | if residual_frames.numel(): 440 | 441 | # Copy to device 442 | batch_frames = residual_frames.to(device) 443 | 444 | # Infer 445 | inferred = infer(batch_frames, checkpoint).detach() 446 | num_inferred_unsaved += len(batch_frames) 447 | 448 | if decoder == 'argmax': 449 | 450 | # Postprocess partial file 451 | results = postprocess(inferred, fmin, fmax, decoder) 452 | pitch = torch.cat((pitch, results[1].cpu()), dim=1) 453 | periodicity = torch.cat( 454 | (periodicity, results[2].cpu()), 455 | dim=1) 456 | 457 | else: 458 | 459 | # Save logits for later decoding 460 | # NOTE - This differs from from_audio and does not 461 | # handle large files that do not fit on GPU. 462 | # However, it saves a GPU -> CPU -> GPU copy. 
463 | logits = torch.cat((logits, inferred), dim=0) 464 | 465 | # Save 466 | i = 0 467 | for length, file in zip(residual_lengths, residual_files): 468 | 469 | if decoder == 'argmax': 470 | 471 | # Slice results 472 | save_pitch = pitch[:, i:i + length] 473 | save_periodicity = periodicity[:, i:i + length] 474 | 475 | else: 476 | 477 | # Postprocess full file 478 | _, save_pitch, save_periodicity = postprocess( 479 | logits[i:i + length], 480 | fmin, 481 | fmax, 482 | decoder) 483 | 484 | # Slice and save in another process 485 | if i + length <= num_inferred_unsaved: 486 | futures.append( 487 | pool.apply_async( 488 | save_worker, 489 | args=( 490 | output_prefixes[file], 491 | save_pitch, 492 | save_periodicity, 493 | interp_unvoiced_at))) 494 | while len(futures) > 100: 495 | futures = [f for f in futures if not f.ready()] 496 | time.sleep(.1) 497 | i += length 498 | progress.update() 499 | 500 | # Wait 501 | for future in futures: 502 | future.wait() 503 | 504 | finally: 505 | 506 | # Shutdown multiprocessing 507 | pool.close() 508 | pool.join() 509 | 510 | # Close progress bar 511 | progress.close() 512 | 513 | 514 | ############################################################################### 515 | # Inference pipeline stages 516 | ############################################################################### 517 | 518 | 519 | def infer(frames, checkpoint=None): 520 | """Forward pass through the model""" 521 | # Time model loading 522 | with torchutil.time.context('model'): 523 | 524 | # Load and cache model 525 | if ( 526 | not hasattr(infer, 'model') or 527 | infer.checkpoint != checkpoint or 528 | infer.device != frames.device 529 | ): 530 | 531 | # Initialize model 532 | model = penn.Model() 533 | 534 | # Maybe download from HuggingFace 535 | if checkpoint is None: 536 | checkpoint = huggingface_hub.hf_hub_download( 537 | 'maxrmorrison/fcnf0-plus-plus', 538 | 'fcnf0++.pt') 539 | infer.checkpoint = None 540 | else: 541 | infer.checkpoint = checkpoint 542 | checkpoint = torch.load(checkpoint, map_location='cpu') 543 | 544 | # Load from disk 545 | model.load_state_dict(checkpoint['model']) 546 | infer.device = frames.device 547 | 548 | # Move model to correct device (no-op if devices are the same) 549 | infer.model = model.to(frames.device) 550 | 551 | # Time inference 552 | with torchutil.time.context('infer'): 553 | 554 | # Prepare model for inference 555 | with inference_context(infer.model): 556 | 557 | # Infer 558 | logits = infer.model(frames) 559 | 560 | # If we're benchmarking, make sure inference finishes within timer 561 | if penn.BENCHMARK and logits.device.type == 'cuda': 562 | torch.cuda.synchronize(logits.device) 563 | 564 | return logits 565 | 566 | 567 | def postprocess(logits, fmin=penn.FMIN, fmax=penn.FMAX, decoder=penn.DECODER): 568 | """Convert model output to pitch and periodicity""" 569 | # Cache decoder 570 | if ( 571 | not hasattr(postprocess, 'decoder') or 572 | postprocess.decoder_name != decoder 573 | ): 574 | if decoder == 'argmax': 575 | postprocess.decoder = penn.decode.Argmax() 576 | elif decoder == 'pyin': 577 | postprocess.decoder = penn.decode.PYIN() 578 | elif decoder == 'viterbi': 579 | postprocess.decoder = penn.decode.Viterbi() 580 | else: 581 | raise ValueError(f'Decoder method {decoder} is not defined') 582 | postprocess.decoder_name = decoder 583 | 584 | # Turn off gradients 585 | with torch.inference_mode(): 586 | 587 | # Convert frequency range to pitch bin range 588 | minidx = penn.convert.frequency_to_bins(torch.tensor(fmin)) 589 | 
maxidx = penn.convert.frequency_to_bins( 590 | torch.tensor(fmax), 591 | torch.ceil) 592 | 593 | # Remove frequencies outside of allowable range 594 | logits[:, :minidx] = -float('inf') 595 | logits[:, maxidx:] = -float('inf') 596 | 597 | # Decode pitch from logits 598 | bins, pitch = postprocess.decoder(logits) 599 | 600 | # Decode periodicity from logits 601 | if penn.PERIODICITY == 'entropy': 602 | periodicity = penn.periodicity.entropy(logits) 603 | elif penn.PERIODICITY == 'max': 604 | periodicity = penn.periodicity.max(logits) 605 | elif penn.PERIODICITY == 'sum': 606 | periodicity = penn.periodicity.sum(logits) 607 | else: 608 | raise ValueError( 609 | f'Periodicity method {penn.PERIODICITY} is not defined') 610 | 611 | return bins.T, pitch.T, periodicity.T 612 | 613 | 614 | def preprocess( 615 | audio, 616 | sample_rate=penn.SAMPLE_RATE, 617 | hopsize=penn.HOPSIZE_SECONDS, 618 | batch_size=None, 619 | center='half-window'): 620 | """Convert audio to model input""" 621 | # Get number of frames 622 | total_frames = expected_frames( 623 | audio.shape[-1], 624 | sample_rate, 625 | hopsize, 626 | center) 627 | 628 | # Maybe resample 629 | if sample_rate != penn.SAMPLE_RATE: 630 | audio = resample(audio, sample_rate) 631 | 632 | # Maybe pad audio 633 | hopsize = penn.convert.seconds_to_samples(hopsize) 634 | if center in ['half-hop', 'zero']: 635 | if center == 'half-hop': 636 | padding = int((penn.WINDOW_SIZE - hopsize) / 2) 637 | else: 638 | padding = int(penn.WINDOW_SIZE / 2) 639 | audio = torch.nn.functional.pad( 640 | audio, 641 | (padding, padding), 642 | mode='reflect') 643 | 644 | # Integer hopsizes permit a speedup using torch.unfold 645 | if isinstance(hopsize, int) or hopsize.is_integer(): 646 | hopsize = int(round(hopsize)) 647 | start_idxs = None 648 | 649 | else: 650 | 651 | # Find start indices 652 | start_idxs = torch.round( 653 | torch.tensor([hopsize * i for i in range(total_frames + 1)]) 654 | ).int() 655 | 656 | # Default to running all frames in a single batch 657 | batch_size = total_frames if batch_size is None else batch_size 658 | 659 | # Generate batches 660 | for i in range(0, total_frames, batch_size): 661 | 662 | # Size of this batch 663 | batch = min(total_frames - i, batch_size) 664 | 665 | # Fast implementation for integer hopsizes 666 | if start_idxs is None: 667 | 668 | # Batch indices 669 | start = i * hopsize 670 | end = start + int((batch - 1) * hopsize) + penn.WINDOW_SIZE 671 | end = min(end, audio.shape[-1]) 672 | batch_audio = audio[:, start:end] 673 | 674 | # Maybe pad to a single frame 675 | if end - start < penn.WINDOW_SIZE: 676 | padding = penn.WINDOW_SIZE - (end - start) 677 | 678 | # Handle multiple of hopsize 679 | remainder = (end - start) % hopsize 680 | if remainder: 681 | padding += end - start - hopsize 682 | 683 | # Pad 684 | batch_audio = torch.nn.functional.pad( 685 | batch_audio, 686 | (0, padding)) 687 | 688 | # Slice and chunk audio 689 | frames = torch.nn.functional.unfold( 690 | batch_audio[:, None, None], 691 | kernel_size=(1, penn.WINDOW_SIZE), 692 | stride=(1, hopsize)).permute(2, 0, 1) 693 | 694 | # Slow implementation for floating-point hopsizes 695 | else: 696 | 697 | # Allocate frames 698 | frames = torch.zeros(batch, 1, penn.WINDOW_SIZE) 699 | 700 | # Fill each frame with a window starting at the start index 701 | for j in range(batch): 702 | start = start_idxs[i + j] 703 | end = min(start + penn.WINDOW_SIZE, audio.shape[-1]) 704 | frames[j, :, : end - start] = audio[:, start:end] 705 | 706 | yield frames 707 | 708 | 709 
| ############################################################################### 710 | # Inference acceleration 711 | ############################################################################### 712 | 713 | 714 | def inference_collate(batch): 715 | frames, lengths, files = zip(*batch) 716 | return ( 717 | torch.cat(frames, dim=0), 718 | torch.tensor(lengths, dtype=torch.long), 719 | files) 720 | 721 | 722 | def inference_loader( 723 | files, 724 | hopsize=penn.HOPSIZE_SECONDS, 725 | batch_size=None, 726 | center='half-window', 727 | num_workers=penn.NUM_WORKERS // 2 728 | ): 729 | dataset = InferenceDataset(files, hopsize, batch_size, center) 730 | return torch.utils.data.DataLoader( 731 | dataset, 732 | batch_sampler=InferenceSampler(dataset), 733 | num_workers=num_workers, 734 | collate_fn=inference_collate) 735 | 736 | 737 | def save_worker(prefix, pitch, periodicity, interp_unvoiced_at=None): 738 | """Save pitch and periodicity to disk""" 739 | # Maybe interpolate unvoiced regions 740 | if interp_unvoiced_at is not None: 741 | pitch = penn.voicing.interpolate( 742 | pitch, 743 | periodicity, 744 | interp_unvoiced_at) 745 | 746 | # Save 747 | Path(prefix).parent.mkdir(exist_ok=True, parents=True) 748 | torch.save(pitch, f'{prefix}-pitch.pt') 749 | torch.save(periodicity, f'{prefix}-periodicity.pt') 750 | 751 | # Clean-up 752 | del pitch 753 | del periodicity 754 | 755 | 756 | class InferenceDataset(torch.utils.data.Dataset): 757 | 758 | def __init__( 759 | self, 760 | files, 761 | hopsize=penn.HOPSIZE_SECONDS, 762 | batch_size=None, 763 | center='half-window'): 764 | self.files = files 765 | self.batch_size = batch_size 766 | self.lengths = [] 767 | for file in files: 768 | info = torchaudio.info(file) 769 | self.lengths.append( 770 | expected_frames( 771 | info.num_frames, 772 | info.sample_rate, 773 | hopsize, 774 | center)) 775 | self.preprocess_fn = functools.partial( 776 | preprocess, 777 | hopsize=hopsize, 778 | batch_size=batch_size, 779 | center=center) 780 | 781 | def __getitem__(self, index): 782 | frames = torch.cat( 783 | [ 784 | frame for frame in 785 | self.preprocess_fn(*torchaudio.load(self.files[index])) 786 | ], 787 | dim=0) 788 | return frames, self.lengths[index], self.files[index] 789 | 790 | def __len__(self): 791 | return len(self.files) 792 | 793 | 794 | class InferenceSampler(torch.utils.data.Sampler): 795 | 796 | def __init__(self, dataset): 797 | self.dataset = dataset 798 | 799 | def __iter__(self): 800 | return iter(self.batch) 801 | 802 | def __len__(self): 803 | return len(self.batch) 804 | 805 | @functools.cached_property 806 | def batch(self): 807 | count = 0 808 | batch, batches = [], [] 809 | for i, length in enumerate(self.dataset.lengths): 810 | batch.append(i) 811 | if self.dataset.batch_size is None: 812 | batches.append(batch) 813 | batch = [] 814 | else: 815 | count += length 816 | if count >= self.dataset.batch_size: 817 | batches.append(batch) 818 | batch = [] 819 | count = 0 820 | if batch: 821 | batches.append(batch) 822 | return batches 823 | 824 | 825 | ############################################################################### 826 | # Utilities 827 | ############################################################################### 828 | 829 | 830 | def cents(a, b): 831 | """Compute pitch difference in cents""" 832 | return penn.OCTAVE * torch.log2(a / b) 833 | 834 | 835 | def expected_frames(samples, sample_rate, hopsize, center): 836 | """Compute expected number of output frames""" 837 | # Calculate expected number of frames 838 | 
hopsize_resampled = penn.convert.seconds_to_samples( 839 | hopsize, 840 | sample_rate) 841 | if center == 'half-window': 842 | window_size_resampled = \ 843 | penn.WINDOW_SIZE / penn.SAMPLE_RATE * sample_rate 844 | samples = samples - (window_size_resampled - hopsize_resampled) 845 | elif center == 'half-hop': 846 | samples = samples 847 | elif center == 'zero': 848 | samples = samples + hopsize_resampled 849 | else: 850 | raise ValueError(f'Centering method {center} is not defined') 851 | return max(1, int(samples / hopsize_resampled)) 852 | 853 | 854 | @contextlib.contextmanager 855 | def inference_context(model): 856 | device_type = next(model.parameters()).device.type 857 | 858 | # Prepare model for evaluation 859 | model.eval() 860 | 861 | # Turn off gradient computation 862 | with torch.inference_mode(): 863 | 864 | if device_type == 'cuda': 865 | with torch.autocast(device_type): 866 | yield 867 | else: 868 | yield 869 | 870 | # Prepare model for training 871 | model.train() 872 | 873 | 874 | def interpolate(x, xp, fp): 875 | """1D linear interpolation for monotonically increasing sample points""" 876 | # Handle edge cases 877 | if xp.shape[-1] == 0: 878 | return x 879 | if xp.shape[-1] == 1: 880 | return torch.full( 881 | x.shape, 882 | fp.squeeze(), 883 | device=fp.device, 884 | dtype=fp.dtype) 885 | 886 | # Get slope and intercept using right-side first-differences 887 | m = (fp[:, 1:] - fp[:, :-1]) / (xp[:, 1:] - xp[:, :-1]) 888 | b = fp[:, :-1] - (m.mul(xp[:, :-1])) 889 | 890 | # Get indices to sample slope and intercept; line_idx is all zeros, so a single batch row is assumed 891 | indices = torch.sum(torch.ge(x[:, :, None], xp[:, None, :]), -1) - 1 892 | indices = torch.clamp(indices, 0, m.shape[-1] - 1) 893 | line_idx = torch.linspace( 894 | 0, 895 | indices.shape[0], 896 | 1, 897 | device=indices.device).to(torch.long).expand(indices.shape) 898 | 899 | # Interpolate 900 | return m[line_idx, indices].mul(x) + b[line_idx, indices] 901 | 902 | 903 | def normalize(frames): 904 | """Normalize audio frames to have mean zero and std dev one""" 905 | # Mean-center 906 | frames -= frames.mean(dim=2, keepdim=True) 907 | 908 | # Scale 909 | frames /= torch.max( 910 | torch.tensor(1e-10, device=frames.device), 911 | frames.std(dim=2, keepdim=True)) 912 | 913 | return frames 914 | 915 | 916 | def resample(audio, sample_rate, target_rate=penn.SAMPLE_RATE): 917 | """Perform audio resampling""" 918 | if sample_rate == target_rate: 919 | return audio 920 | resampler = torchaudio.transforms.Resample(sample_rate, target_rate) 921 | resampler = resampler.to(audio.device) 922 | return resampler(audio) 923 | --------------------------------------------------------------------------------
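The stage functions defined above (preprocess, infer, postprocess) can also be chained directly, without the file-handling wrappers. The following sketch is illustrative and not part of the repository: it mirrors the frame-wise (argmax) path of from_files_to_files, assumes penn and its dependencies are installed and that these functions are importable from penn.core, and uses a hypothetical mono input file 'audio.wav'; with checkpoint=None, infer downloads the default FCNF0++ weights from HuggingFace on first use.

import torch
import torchaudio

from penn.core import infer, postprocess, preprocess

# Load a hypothetical mono recording
audio, sample_rate = torchaudio.load('audio.wav')

pitch, periodicity = [], []

# preprocess() yields batches of framed audio shaped (frames, 1, WINDOW_SIZE)
for frames in preprocess(audio, sample_rate, batch_size=512):

    # Forward pass; the model is loaded and cached on the first call
    logits = infer(frames).detach()

    # Per-batch decoding is only valid for the frame-wise argmax decoder;
    # sequence decoders (e.g., Viterbi) should see a whole file at once
    _, batch_pitch, batch_periodicity = postprocess(logits, decoder='argmax')
    pitch.append(batch_pitch)
    periodicity.append(batch_periodicity)

# Concatenate to (1, frames) tensors of pitch (Hz) and periodicity
pitch = torch.cat(pitch, dim=1)
periodicity = torch.cat(periodicity, dim=1)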