├── README.md
└── ptcrepe
    ├── crepe.py
    └── utils.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CREPE Pitch Tracker (PyTorch) #

- Original TensorFlow implementation: [https://github.com/marl/crepe](https://github.com/marl/crepe)

---
CREPE is a monophonic pitch tracker based on a deep convolutional neural network that operates directly on the time-domain waveform. The original implementation is written in TensorFlow, which can be inconvenient to use; this repository is a PyTorch port.

## Usage

```python
import torch
from ptcrepe import crepe

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cr = crepe.CREPE("full").to(device)
cr.process_file("path/to/audio.file", "path/to/output/directory/")
```
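If the audio is already loaded as a tensor, `predict` can be called directly; it returns per-frame timestamps, frequencies, confidences, and the raw activation map. A minimal sketch (the audio path is a placeholder):

```python
import torch
import torchaudio
from ptcrepe import crepe

cr = crepe.CREPE("full")                              # downloads weights on first use
audio, sr = torchaudio.load("path/to/audio.file")     # placeholder path
with torch.no_grad():
    time, frequency, confidence, activation = cr.predict(audio, sr, viterbi=False, step_size=10)
```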
## WIP


--------------------------------------------------------------------------------
/ptcrepe/crepe.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchaudio
import os, sys
from .utils import *
import numpy as np


class ConvBlock(nn.Module):
    def __init__(self, f, w, s, in_channels):
        super().__init__()
        # "same" padding along the time axis for a (w, 1) kernel
        p1 = (w - 1) // 2
        p2 = (w - 1) - p1
        self.pad = nn.ZeroPad2d((0, 0, p1, p2))

        self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=f, kernel_size=(w, 1), stride=s)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(f)
        self.pool = nn.MaxPool2d(kernel_size=(2, 1))
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pad(x)
        x = self.conv2d(x)
        x = self.relu(x)
        x = self.bn(x)
        x = self.pool(x)
        x = self.dropout(x)
        return x


class CREPE(nn.Module):
    def __init__(self, model_capacity="full"):
        super().__init__()

        capacity_multiplier = {
            'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32
        }[model_capacity]

        self.layers = [1, 2, 3, 4, 5, 6]
        filters = [n * capacity_multiplier for n in [32, 4, 4, 4, 8, 16]]
        filters = [1] + filters
        widths = [512, 64, 64, 64, 64, 64]
        strides = [(4, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]

        for i in range(len(self.layers)):
            f, w, s, in_channel = filters[i + 1], widths[i], strides[i], filters[i]
            self.add_module("conv%d" % i, ConvBlock(f, w, s, in_channel))

        self.linear = nn.Linear(64 * capacity_multiplier, 360)
        self.load_weight(model_capacity)
        self.eval()

    def load_weight(self, model_capacity):
        download_weights(model_capacity)
        package_dir = os.path.dirname(os.path.realpath(__file__))
        filename = "crepe-{}.pth".format(model_capacity)
        self.load_state_dict(torch.load(os.path.join(package_dir, filename)))

    def forward(self, x):
        # x : shape (batch, sample)
        x = x.view(x.shape[0], 1, -1, 1)
        for i in range(len(self.layers)):
            x = self.__getattr__("conv%d" % i)(x)

        x = x.permute(0, 3, 2, 1)
        x = x.reshape(x.shape[0], -1)
        x = self.linear(x)
        x = torch.sigmoid(x)
        return x

    def get_activation(self, audio, sr, center=True, step_size=10, batch_size=128):
        """
        audio : (N,) or (C, N)
        """
        if sr != 16000:
            rs = torchaudio.transforms.Resample(sr, 16000)
            audio = rs(audio)

        if len(audio.shape) == 2:
            if audio.shape[0] == 1:
                audio = audio[0]
            else:
                audio = audio.mean(dim=0)  # make mono

        def get_frame(audio, step_size, center):
            if center:
                audio = nn.functional.pad(audio, pad=(512, 512))
            # make 1024-sample frames of the audio with a hop of step_size milliseconds
            hop_length = int(16000 * step_size / 1000)
            n_frames = 1 + int((len(audio) - 1024) / hop_length)
            assert audio.dtype == torch.float32
            # torch.as_strided strides are counted in elements, not bytes
            frames = torch.as_strided(audio, size=(1024, n_frames), stride=(1, hop_length))
            frames = frames.transpose(0, 1).clone()

            # normalize each frame: zero mean, unit variance
            frames -= torch.mean(frames, dim=1, keepdim=True)
            frames /= torch.std(frames, dim=1, keepdim=True)
            return frames

        frames = get_frame(audio, step_size, center)
        activation_stack = []
        device = self.linear.weight.device

        for i in range(0, len(frames), batch_size):
            f = frames[i:min(i + batch_size, len(frames))]
            f = f.to(device)
            act = self.forward(f)
            activation_stack.append(act.cpu())
        activation = torch.cat(activation_stack, dim=0)
        return activation

    def predict(self, audio, sr, viterbi=False, center=True, step_size=10, batch_size=128):
        activation = self.get_activation(audio, sr, center=center, step_size=step_size, batch_size=batch_size)
        frequency = to_freq(activation, viterbi=viterbi)
        confidence = activation.max(dim=1)[0]
        time = torch.arange(confidence.shape[0]) * step_size / 1000.0
        return time, frequency, confidence, activation

    def process_file(self, file, output=None, viterbi=False,
                     center=True, step_size=10, save_plot=False, batch_size=128):
        try:
            audio, sr = torchaudio.load(file)
        except ValueError:
            print("CREPE-pytorch : Could not read", file, file=sys.stderr)
            return

        with torch.no_grad():
            time, frequency, confidence, activation = self.predict(
                audio, sr,
                viterbi=viterbi,
                center=center,
                step_size=step_size,
                batch_size=batch_size,
            )

        time, frequency, confidence, activation = time.numpy(), frequency.numpy(), confidence.numpy(), activation.numpy()

        f0_file = os.path.join(output, os.path.basename(os.path.splitext(file)[0])) + ".f0.csv"
        f0_data = np.vstack([time, frequency, confidence]).transpose()
        np.savetxt(f0_file, f0_data, fmt=['%.3f', '%.3f', '%.6f'], delimiter=',',
                   header='time,frequency,confidence', comments='')

        # save the salience visualization in a PNG file
        if save_plot:
            import matplotlib.cm
            from imageio import imwrite

            plot_file = os.path.join(output, os.path.basename(os.path.splitext(file)[0])) + ".activation.png"
            # flip so that low pitches are drawn at the bottom
            salience = np.flip(activation, axis=1)
            inferno = matplotlib.cm.get_cmap('inferno')
            image = inferno(salience.transpose())

            imwrite(plot_file, (255 * image).astype(np.uint8))


if __name__ == "__main__":
    cr = CREPE().cuda()
    import glob
    files = glob.glob("../../ddsp/data/violin/*.wav")
    # files = ["../../ddsp/data/violin/VI.+Double.wav"]
    target = "../../ddsp/data/violin/f0_0.004/"
    from tqdm import tqdm
    for file in tqdm(files):
        cr.process_file(file, target, step_size=4, viterbi=True)
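A note on the framing in `get_activation`: audio is resampled to 16 kHz, and `get_frame` slices it into overlapping 1024-sample windows spaced `step_size` milliseconds apart; with `center=True`, 512 zero samples on each side make frame `i` centered on sample `i * hop_length` of the original signal. A small sketch of that arithmetic (the 2-second clip length is just an example):

```python
# Frame/hop arithmetic used by get_frame (illustration only).
sr = 16000                                   # model sample rate
step_size = 10                               # hop in milliseconds
hop_length = int(sr * step_size / 1000)      # 160 samples per hop
n_samples = 2 * sr                           # e.g. a 2-second clip
padded = n_samples + 2 * 512                 # center=True pads 512 zeros on each side
n_frames = 1 + int((padded - 1024) / hop_length)
print(hop_length, n_frames)                  # 160 201  (i.e. 100 frames per second)
```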
--------------------------------------------------------------------------------
/ptcrepe/utils.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import os


def download_weights(model_capacity):
    try:
        from urllib.request import urlretrieve
    except ImportError:
        from urllib import urlretrieve

    weight_file = 'crepe-{}.pth'.format(model_capacity)
    base_url = 'https://github.com/sweetcocoa/crepe-pytorch/raw/models/'

    # download the weight file next to this module if it is not present yet
    package_dir = os.path.dirname(os.path.realpath(__file__))
    weight_path = os.path.join(package_dir, weight_file)
    if not os.path.isfile(weight_path):
        print('Downloading weight file {} from {} ...'.format(weight_path, base_url + weight_file))
        urlretrieve(base_url + weight_file, weight_path)


def to_local_average_cents(salience, center=None):
    """
    find the weighted average cents near the argmax bin
    """
    if not hasattr(to_local_average_cents, 'cents_mapping'):
        # the bin number-to-cents mapping (360 bins, 20 cents apart)
        to_local_average_cents.cents_mapping = (
            torch.tensor(np.linspace(0, 7180, 360)) + 1997.3794084376191)

    if isinstance(salience, np.ndarray):
        salience = torch.from_numpy(salience)

    if salience.ndim == 1:
        if center is None:
            center = int(torch.argmax(salience))
        start = max(0, center - 4)
        end = min(len(salience), center + 5)
        salience = salience[start:end]
        product_sum = torch.sum(
            salience * to_local_average_cents.cents_mapping[start:end])
        weight_sum = torch.sum(salience)
        return product_sum / weight_sum
    if salience.ndim == 2:
        return torch.tensor([to_local_average_cents(salience[i, :]) for i in
                             range(salience.shape[0])])

    raise Exception("salience should be either a 1-d or 2-d tensor")


def to_viterbi_cents(salience):
    """
    Find the Viterbi path using a transition prior that induces pitch
    continuity.

    * Note : This is NOT implemented with pytorch.
    """
    from hmmlearn import hmm

    # uniform prior on the starting pitch
    starting = np.ones(360) / 360

    # transition probabilities inducing continuous pitch
    xx, yy = np.meshgrid(range(360), range(360))
    transition = np.maximum(12 - abs(xx - yy), 0)
    transition = transition / np.sum(transition, axis=1)[:, None]

    # emission probability = fixed probability for self, evenly distribute the
    # others
    self_emission = 0.1
    emission = (np.eye(360) * self_emission + np.ones(shape=(360, 360)) *
                ((1 - self_emission) / 360))

    # fix the model parameters because we are not optimizing the model
    model = hmm.MultinomialHMM(360, starting, transition)
    model.startprob_, model.transmat_, model.emissionprob_ = \
        starting, transition, emission

    # find the Viterbi path
    observations = np.argmax(salience, axis=1)
    path = model.predict(observations.reshape(-1, 1), [len(observations)])

    return np.array([to_local_average_cents(salience[i, :], path[i]) for i in
                     range(len(observations))])


def to_freq(activation, viterbi=False):
    if viterbi:
        cents = to_viterbi_cents(activation.detach().numpy())
        cents = torch.tensor(cents)
    else:
        cents = to_local_average_cents(activation)

    # cents are measured relative to 10 Hz
    frequency = 10 * 2 ** (cents / 1200)
    frequency[torch.isnan(frequency)] = 0
    return frequency
--------------------------------------------------------------------------------
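For reference, the 360 model outputs correspond to 20-cent-wide pitch bins: `to_local_average_cents` maps bin `i` to `1997.3794 + 20 * i` cents above 10 Hz, and `to_freq` converts cents to Hz with `f = 10 * 2 ** (cents / 1200)`. A short sketch that reproduces this mapping and prints the frequency range spanned by the bins:

```python
import numpy as np

# bin index -> cents above 10 Hz, as in to_local_average_cents
cents = np.linspace(0, 7180, 360) + 1997.3794084376191   # 20-cent steps
freq = 10 * 2 ** (cents / 1200)                          # cents -> Hz, as in to_freq

print(freq[0], freq[-1])   # roughly 32 Hz up to about 2 kHz
```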