├── constants ├── midi.py └── datasets.py ├── feature ├── wrapper_func.py ├── hcfp.py └── cfp.py ├── generate_align.py ├── transcribe.py ├── generate_contour.py ├── requirements.txt ├── models ├── losses.py ├── u_net.py ├── t2t.py └── pyramid_net.py ├── vocal_contour ├── labels.py ├── inference.py ├── callbacks.py └── app.py ├── vocal ├── prediction.py ├── labels.py ├── inference.py └── app.py ├── train_contour.py ├── train_align.py ├── defaults ├── vocal_contour.yaml └── vocal.yaml ├── README.md ├── setting_loaders.py ├── train.py ├── utils.py └── base.py /constants/midi.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as jpath 3 | 4 | from librosa import note_to_midi 5 | 6 | MODULE_PATH = os.path.abspath(jpath(os.path.split(__file__)[0], '..')) 7 | 8 | LOWEST_MIDI_NOTE = note_to_midi("A0") 9 | HIGHEST_MIDI_NOTE = note_to_midi("C8") -------------------------------------------------------------------------------- /feature/wrapper_func.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | import os 5 | import sys 6 | MODULE_PATH = os.path.abspath(f"{os.path.split(__file__)[0]}/..") 7 | if sys.path[0] != MODULE_PATH: sys.path.insert(0, MODULE_PATH) 8 | 9 | from feature import cfp 10 | from feature import hcfp 11 | 12 | 13 | def extract_cfp_feature(audio_path, harmonic=False, harmonic_num=6, **kwargs): 14 | """Wrapper of CFP/HCFP feature extraction""" 15 | 16 | if harmonic: 17 | spec, gcos, ceps, _ = hcfp.extract_hcfp(audio_path, harmonic_num=harmonic_num, **kwargs) 18 | return np.dstack([spec, gcos, ceps]) 19 | 20 | z, spec, gcos, ceps, _ = cfp.extract_cfp(audio_path, **kwargs) 21 | return np.dstack([z.T, spec.T, gcos.T, ceps.T]) -------------------------------------------------------------------------------- /generate_align.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | # Add arguments 4 | ap = argparse.ArgumentParser() 5 | ap.add_argument('-i', '--input', required=True, help='Path to vocal alignment dataset') 6 | ap.add_argument('-o', '--output', required=False, help='Path to generated feature') 7 | args = ap.parse_args() 8 | 9 | from setting_loaders import VocalSettings 10 | from vocal import app 11 | 12 | # Change settings to match arguments 13 | config_path = 'defaults/vocal.yaml' 14 | config = VocalSettings(config_path) 15 | if args.output is not None: 16 | config.dataset.feature_save_path= args.output 17 | 18 | # Generate features 19 | va_transcription = app.VocalTranscription(config_path) 20 | va_transcription.generate_feature(dataset_path=args.input, vocal_settings=config) -------------------------------------------------------------------------------- /transcribe.py: -------------------------------------------------------------------------------- 1 | # Add arguments 2 | import argparse 3 | ap = argparse.ArgumentParser() 4 | ap.add_argument('-i', '--input', required=True, help='Path to input audio (wav)') 5 | ap.add_argument('-o', '--output', required=False, help='Path to the folder will contain predictions') 6 | ap.add_argument('-m', '--model', required=False, help='Path to the transcribe model') 7 | args = ap.parse_args() 8 | 9 | from vocal import app 10 | 11 | # Change settings to match arguments 12 | config_path = 'defaults/vocal.yaml' 13 | model = args.model if args.model is not None else None 14 | output = args.output if args.output is not None 
else "./" 15 | 16 | # Transcribe 17 | va_transcription = app.VocalTranscription(config_path) 18 | va_transcription.transcribe(input_audio=args.input, model_path=model, output=output) 19 | -------------------------------------------------------------------------------- /generate_contour.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | # Add arguments 4 | ap = argparse.ArgumentParser() 5 | ap.add_argument('-i', '--input', required=True, help='Path to vocal contour dataset') 6 | ap.add_argument('-o', '--output', required=False, help='Path to generated feature') 7 | args = ap.parse_args() 8 | 9 | from setting_loaders import VocalContourSettings 10 | from vocal_contour import app as vcapp 11 | 12 | # Change settings to match arguments 13 | config_path = 'defaults/vocal_contour.yaml' 14 | config = VocalContourSettings(config_path) 15 | if args.output is not None: 16 | config.dataset.feature_save_path = args.output 17 | 18 | # Generate features 19 | vc_transcription = vcapp.VocalContourTranscription(config_path) 20 | vc_transcription.generate_feature(dataset_path=args.input, vocalcontour_settings=config) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | async-generator==1.10 2 | attrs==21.2.0 3 | audioread==2.1.9 4 | cached-property==1.5.2 5 | cachetools==4.2.4 6 | certifi==2021.10.8 7 | cffi==1.15.0 8 | charset-normalizer==2.0.7 9 | click==7.1.2 10 | colorama==0.4.4 11 | contextvars==2.4 12 | decorator==5.1.0 13 | google-auth==2.3.3 14 | httpcore==0.13.3 15 | idna==3.3 16 | immutables==0.16 17 | importlib-metadata==3.10.1 18 | importlib-resources==4.1.1 19 | joblib==1.1.0 20 | jsonschema==3.2.0 21 | madmom==0.16.1 22 | mido==1.2.10 23 | mir-eval==0.6 24 | oauthlib==3.1.1 25 | packaging==21.2 26 | pandas==1.1.5 27 | pillow==8.4.0 28 | pooch==1.5.2 29 | pretty-midi==0.2.9 30 | protobuf==3.19.1 31 | pyfluidsynth==1.3.0 32 | pyparsing==2.4.7 33 | pyrsistent==0.18.0 34 | pytz==2021.3 35 | pyyaml==5.4.1 36 | ruamel.yaml==0.17.21 37 | requests-oauthlib==1.3.0 38 | requests==2.26.0 39 | resampy==0.2.2 40 | rsa==4.7.2 41 | scikit-learn==0.24.2 42 | scipy==1.5.4 43 | sniffio==1.2.0 44 | soundfile==0.10.3.post1 45 | tensorboard-plugin-wit==1.8.0 46 | tensorboard==2.7.0 47 | threadpoolctl==3.0.0 48 | tqdm==4.62.3 49 | urllib3==1.26.4 50 | vamp==1.1.0 51 | werkzeug==2.0.2 52 | wrapt==1.12.1 53 | zipp==3.6.0 54 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import array_ops 3 | 4 | 5 | def focal_loss(target_tensor, prediction_tensor, weights=None, alpha=0.25, gamma=2): 6 | """Compute focal loss for predictions 7 | 8 | Multi-labels Focal loss formula: FL = -\alpha * (z-p)^\gamma * \log{(p)} -(1-\alpha) * p^\gamma * \log{(1-p)} 9 | Which :`\alpha` = 0.25, `\gamma` = 2, p = sigmoid(x), z = target_tensor. 
10 | """ 11 | 12 | sigmoid_p = tf.nn.sigmoid(prediction_tensor) 13 | zeros = array_ops.zeros_like(sigmoid_p) 14 | pos_p_sub = array_ops.where(target_tensor >= sigmoid_p, target_tensor - sigmoid_p, zeros) 15 | neg_p_sub = array_ops.where(target_tensor > zeros, zeros, sigmoid_p) 16 | per_entry_cross_ent = -alpha * (pos_p_sub**gamma) * tf.math.log( # noqa: E226 17 | tf.clip_by_value(sigmoid_p, 1e-8, 1.0) 18 | ) - (1-alpha) * (neg_p_sub**gamma) * tf.math.log(tf.clip_by_value(1.0 - sigmoid_p, 1e-8, 1.0)) # noqa: E226 19 | 20 | if weights is not None: 21 | weights = tf.constant(weights, dtype=per_entry_cross_ent.dtype) 22 | per_entry_cross_ent *= weights 23 | 24 | return tf.reduce_mean(per_entry_cross_ent) 25 | -------------------------------------------------------------------------------- /vocal_contour/labels.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import numpy as np 3 | 4 | import os 5 | import sys 6 | MODULE_PATH = os.path.abspath(f"{os.path.split(__file__)[0]}/..") 7 | if sys.path[0] != MODULE_PATH: sys.path.insert(0, MODULE_PATH) 8 | 9 | from constants import datasets as dset 10 | from constants.midi import LOWEST_MIDI_NOTE 11 | 12 | 13 | class BaseLabelExtraction(metaclass=abc.ABCMeta): 14 | """Base class for extract label information""" 15 | 16 | @classmethod 17 | @abc.abstractmethod 18 | def load_label(cls, label_path): 19 | """Load the label file and parse information into ``Label`` class""" 20 | raise NotImplementedError 21 | 22 | @classmethod 23 | def extract_label(cls, label_path, t_unit=0.02): 24 | labels = cls.load_label(label_path) 25 | fs = round(1 / t_unit) 26 | 27 | max_time = max(label.end_time for label in labels) 28 | output = np.zeros((round(max_time * fs), 352)) 29 | for label in labels: 30 | start_idx = round(label.start_time * fs) 31 | end_idx = round(label.end_time * fs) 32 | pitch = round((label.note - LOWEST_MIDI_NOTE) * 4) 33 | output[start_idx:end_idx, pitch] = 1 34 | return output 35 | 36 | 37 | class VocalContourlabelExtraction(BaseLabelExtraction): 38 | """vocal contour datasets label extraction class""" 39 | @classmethod 40 | def load_label(cls, label_path): 41 | return dset.VocalContourStructure.load_label(label_path) 42 | -------------------------------------------------------------------------------- /vocal/prediction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import os 4 | import sys 5 | MODULE_PATH = os.path.abspath(f"{os.path.split(__file__)[0]}/..") 6 | if sys.path[0] != MODULE_PATH: sys.path.insert(0, MODULE_PATH) 7 | 8 | from utils import get_logger 9 | 10 | logger = get_logger("Vocal Predict") 11 | 12 | 13 | def create_batches(feature, ctx_len=9, batch_size=64): 14 | feat_pad = np.pad(feature, ((ctx_len, ctx_len), (0, 0), (0, 0))) 15 | 16 | slices = [feat_pad[idx - ctx_len:idx + ctx_len + 1] for idx in range(ctx_len, len(feat_pad) - ctx_len)] 17 | pad_size = batch_size - len(slices) % batch_size 18 | payload = np.zeros_like(slices[0]) 19 | for _ in range(pad_size): 20 | slices.append(payload) 21 | slices = np.array(slices) 22 | assert len(slices) % batch_size == 0 23 | 24 | batches = [slices[idx:idx + batch_size] for idx in range(0, len(slices), batch_size)] 25 | return np.array(batches, dtype=np.float32), pad_size 26 | 27 | 28 | def merge_batches(batch_pred): 29 | assert len(batch_pred.shape) == 4 30 | 31 | batches, batch_size, frm_len, out_classes = batch_pred.shape 32 | total_len = batches * batch_size + frm_len 
- 1 33 | output = np.zeros((total_len, out_classes)) 34 | for bidx, batch in enumerate(batch_pred): 35 | for fidx, frame in enumerate(batch): 36 | start_idx = bidx * batch_size + fidx 37 | output[start_idx:start_idx + frm_len] += frame 38 | 39 | max_len = min(frm_len - 1, len(output) - frm_len) 40 | output[max_len:-max_len] /= max_len + 1 41 | for idx in range(max_len): 42 | output[idx] /= idx + 1 43 | output[-1 - idx] /= idx + 1 44 | return output 45 | 46 | 47 | def predict(feature, model, ctx_len=9, batch_size=16): 48 | assert feature.shape[1:] == (174, 9) 49 | batches, pad_size = create_batches(feature, ctx_len=ctx_len, batch_size=batch_size) 50 | batch_pred = [] 51 | for idx, batch in enumerate(batches): 52 | print(f"Progress: {idx+1}/{len(batches)}", end="\r") 53 | batch_pred.append(model.predict(batch)) 54 | pred = merge_batches(np.array(batch_pred)) 55 | return pred[ctx_len:-pad_size - ctx_len] 56 | -------------------------------------------------------------------------------- /vocal_contour/inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.special import expit 3 | from librosa.core import midi_to_hz 4 | 5 | import os 6 | import sys 7 | MODULE_PATH = os.path.abspath(f"{os.path.split(__file__)[0]}/..") 8 | if sys.path[0] != MODULE_PATH: sys.path.insert(0, MODULE_PATH) 9 | 10 | from constants.midi import LOWEST_MIDI_NOTE 11 | 12 | def inference(feature, model, timestep=128, batch_size=10, feature_num=384): 13 | assert len(feature.shape) == 2 14 | # Padding 15 | total_samples = len(feature) 16 | pad_bottom = (feature_num - feature.shape[1]) // 2 17 | pad_top = feature_num - feature.shape[1] - pad_bottom 18 | pad_len = timestep - 1 19 | feature = np.pad(feature, ((pad_len, pad_len), (pad_bottom, pad_top))) 20 | 21 | # Prepare for prediction 22 | output = np.zeros(feature.shape + (2,)) 23 | total_batches = int(np.ceil(total_samples / batch_size)) 24 | last_batch_idx = len(feature) - pad_len 25 | for bidx in range(total_batches): 26 | print(f"batch: {bidx+1}/{total_batches}", end="\r") 27 | 28 | # Collect batch feature 29 | start_idx = bidx * batch_size 30 | end_idx = min(start_idx + batch_size, last_batch_idx) 31 | batch = np.array([feature[idx:idx+timestep] for idx in range(start_idx, end_idx)]) # noqa: E226 32 | batch = np.expand_dims(batch, axis=3) 33 | 34 | # Predict contour 35 | batch_pred = model.predict(batch) 36 | batch_pred = 1 / (1 + np.exp(-expit(batch_pred))) 37 | 38 | # Add the batch results to the output container. 
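        # Every original frame is covered by the same number of overlapping windows
        # (the feature is padded by timestep-1 frames on both sides), so simply summing
        # the window predictions keeps the relative ordering of activations; no extra
        # normalization is needed before the thresholding and per-frame argmax below.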
39 | for idx, pred in enumerate(batch_pred): 40 | slice_start = start_idx + idx 41 | slice_end = slice_start + timestep 42 | output[slice_start:slice_end] += pred 43 | output = output[pad_len:-pad_len, pad_bottom:-pad_top, 1] # Remove padding 44 | 45 | # Filter values 46 | avg_max_val = np.mean(np.max(output, axis=1)) 47 | output = np.where(output > avg_max_val, output, 0) 48 | 49 | # Generate final output F0 50 | f0 = [] # pylint: disable=invalid-name 51 | for pitches in output: 52 | if np.sum(pitches) > 0: 53 | pidx = np.argmax(pitches) 54 | f0.append(midi_to_hz(pidx / 4 + LOWEST_MIDI_NOTE)) 55 | else: 56 | f0.append(0) 57 | 58 | return np.array(f0) 59 | -------------------------------------------------------------------------------- /feature/hcfp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | 5 | MODULE_PATH = os.path.abspath(f"{os.path.split(__file__)[0]}/..") 6 | if sys.path[0] != MODULE_PATH: sys.path.insert(0, MODULE_PATH) 7 | 8 | from feature.cfp import extract_cfp 9 | from utils import get_logger 10 | 11 | 12 | logger = get_logger("HCFP Feature") 13 | 14 | 15 | def fetch_harmonic(data, cenf, ith_har, start_freq=27.5, num_per_octave=48, is_reverse=False): 16 | ith_har += 1 17 | if ith_har != 0 and is_reverse: 18 | ith_har = 1 / ith_har 19 | 20 | # harmonic_series = [12, 19, 24, 28, 31] 21 | bins_per_note = int(num_per_octave / 12) 22 | total_bins = int(bins_per_note * 88) 23 | 24 | hid = min(range(len(cenf)), key=lambda i: abs(cenf[i] - ith_har*start_freq)) # noqa: E226 25 | 26 | harmonic = np.zeros((total_bins, data.shape[1])) 27 | upper_bound = min(len(cenf) - 1, hid + total_bins) 28 | harmonic[:(upper_bound - hid)] = data[hid:upper_bound] 29 | 30 | return harmonic 31 | 32 | 33 | def extract_hcfp( 34 | filename, 35 | hop=0.02, # in seconds 36 | win_size=7939, 37 | fr=2.0, 38 | g=[0.24, 0.6, 1], 39 | bin_per_octave=48, 40 | down_fs=44100, 41 | max_sample=2000, 42 | harmonic_num=6, 43 | ): 44 | _, spec, gcos, ceps, cenf = extract_cfp( 45 | filename, 46 | hop=hop, 47 | win_size=win_size, 48 | fr=fr, 49 | fc=1.0, 50 | tc=1 / 22050, 51 | g=g, 52 | bin_per_octave=bin_per_octave, 53 | down_fs=down_fs, 54 | max_sample=max_sample, 55 | ) 56 | 57 | har = [] 58 | logger.debug("Fetching harmonics of spectrum") 59 | for i in range(harmonic_num + 1): 60 | har.append(fetch_harmonic(spec, cenf, i)) 61 | har_s = np.transpose(np.array(har), axes=(2, 1, 0)) 62 | 63 | # Harmonic GCoS 64 | har = [] 65 | logger.debug("Fetching harmonics of GCoS") 66 | for i in range(harmonic_num + 1): 67 | har.append(fetch_harmonic(gcos, cenf, i)) 68 | har_g = np.transpose(np.array(har), axes=(2, 1, 0)) 69 | 70 | # Harmonic cepstrum 71 | har = [] 72 | logger.debug("Fetching harmonics of cepstrum") 73 | for i in range(harmonic_num + 1): 74 | har.append(fetch_harmonic(ceps, cenf, i, is_reverse=True)) 75 | har_c = np.transpose(np.array(har), axes=(2, 1, 0)) 76 | 77 | return har_s, har_g, har_c, cenf 78 | -------------------------------------------------------------------------------- /train_contour.py: -------------------------------------------------------------------------------- 1 | # Add arguments 2 | import argparse 3 | ap = argparse.ArgumentParser() 4 | ap.add_argument('-f', '--feature', required=True, help='Path to the folder of extracted feature') 5 | ap.add_argument('-i', '--input_model', required=False, help='If given, the training will continue to fine-tune the pre-trained model') 6 | ap.add_argument('-o', 
'--output_model', required=False, help='Name for the output model') 7 | ap.add_argument('-e', '--epochs', required=False, help='Number of training epochs') 8 | ap.add_argument('-b', '--batch_size', required=False, help='Batch size of each training step') 9 | ap.add_argument('-s', '--steps', required=False, help='Number of step each training epochs (virtual epochs)') 10 | ap.add_argument('-vb', '--val_batch_size', required=False, help='Batch size of each validation step') 11 | ap.add_argument('-vs', '--val_steps', required=False, help='Number of step each validation epochs (virtual epochs)') 12 | ap.add_argument('-lr', '--learning_rate', required=False, help='Initial learning rate') 13 | ap.add_argument('--early_stop', required=False, help='Stop the training if validation accuracy does not improve over the given number of epochs') 14 | args = ap.parse_args() 15 | 16 | from setting_loaders import VocalContourSettings 17 | from vocal_contour import app as vcapp 18 | 19 | # Change settings to match arguments 20 | config_path = 'defaults/vocal_contour.yaml' 21 | config = VocalContourSettings(config_path) 22 | output_model = args.output_model if args.output_model is not None else None 23 | input_model = args.input_model if args.input_model is not None else None 24 | if args.epochs is not None: 25 | config.training.epoch = int(args.epochs) 26 | if args.batch_size is not None: 27 | config.training.batch_size = int(args.batch_size) 28 | if args.steps is not None: 29 | config.training.steps = int(args.steps) 30 | if args.val_batch_size is not None: 31 | config.training.val_batch_size = int(args.val_batch_size) 32 | if args.val_steps is not None: 33 | config.training.val_steps = int(args.val_steps) 34 | if args.learning_rate is not None: 35 | config.training.init_learning_rate = float(args.learning_rate) 36 | if args.early_stop is not None: 37 | config.training.early_stop = int(args.early_stop) 38 | 39 | # Training 40 | vc_transcription = vcapp.VocalContourTranscription(config_path) 41 | vc_transcription.train(feature_folder=args.feature, model_name=output_model, input_model_path=input_model, vocalcontour_settings=config) -------------------------------------------------------------------------------- /vocal/labels.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import numpy as np 3 | 4 | import os 5 | import sys 6 | MODULE_PATH = os.path.abspath(f"{os.path.split(__file__)[0]}/..") 7 | if sys.path[0] != MODULE_PATH: sys.path.insert(0, MODULE_PATH) 8 | 9 | from constants import datasets as dset 10 | 11 | 12 | class BaseLabelExtraction(metaclass=abc.ABCMeta): 13 | """Base class for extract label information""" 14 | 15 | @classmethod 16 | @abc.abstractmethod 17 | def load_label(cls, label_path): # -> list[Label] 18 | """Load the label file and parse information into ``Label`` class""" 19 | raise NotImplementedError 20 | 21 | @classmethod 22 | def extract_label(cls, label_path, t_unit=0.02): 23 | """Extract SDT label""" 24 | 25 | label_list = cls.load_label(label_path) 26 | 27 | max_sec = max([ll.end_time for ll in label_list]) 28 | num_frm = int(max_sec / t_unit) + 10 # Reserve additional 10 frames 29 | 30 | sdt_label = np.zeros((num_frm, 6)) 31 | frm_per_sec = round(1 / t_unit) 32 | clip = lambda v: np.clip(v, 0, num_frm - 1) 33 | for label in label_list: 34 | act_range = range( 35 | round(label.start_time*frm_per_sec), round(label.end_time*frm_per_sec) # noqa: E226 36 | ) 37 | on_range = range( 38 | round(label.start_time*frm_per_sec - 2), 
round(label.start_time*frm_per_sec + 4) # noqa: E226 39 | ) 40 | off_range = range( 41 | round(label.end_time*frm_per_sec - 2), round(label.end_time*frm_per_sec + 4) # noqa: E226 42 | ) 43 | if len(act_range) == 0: 44 | continue 45 | 46 | sdt_label[clip(act_range), 0] = 1 # activation 47 | sdt_label[clip(on_range), 2] = 1 # onset 48 | sdt_label[clip(off_range), 4] = 1 # offset 49 | 50 | sdt_label[:, 1] = 1 - sdt_label[:, 0] 51 | sdt_label[:, 3] = 1 - sdt_label[:, 2] 52 | sdt_label[:, 5] = 1 - sdt_label[:, 4] 53 | return sdt_label 54 | 55 | 56 | class VocalAlignLabelExtraction(BaseLabelExtraction): 57 | """Label extraction for vocal-semi datasets""" 58 | @classmethod 59 | def load_label(cls, label_path): 60 | return dset.VocalAlignStructure.load_label(label_path) 61 | 62 | 63 | class UnlabeledLabelExtraction(BaseLabelExtraction): 64 | """Label extraction for unlabeled datasets""" 65 | @classmethod 66 | def load_label(cls, label_path): 67 | return dset.UnlabeledStructure.load_label(label_path) -------------------------------------------------------------------------------- /train_align.py: -------------------------------------------------------------------------------- 1 | # Add arguments 2 | import argparse 3 | ap = argparse.ArgumentParser() 4 | ap.add_argument('-f', '--feature', required=True, help='Path to the folder of extracted feature') 5 | ap.add_argument('-fs', '--feature_semi', required=False, help='Path to the folder of extracted semi feature') 6 | ap.add_argument('-i', '--input_model', required=False, help='If given, the training will continue to fine-tune the pre-trained model') 7 | ap.add_argument('-o', '--output_model', required=False, help='Name for the output model') 8 | ap.add_argument('-e', '--epochs', required=False, help='Number of training epochs') 9 | ap.add_argument('-b', '--batch_size', required=False, help='Batch size of each training step') 10 | ap.add_argument('-s', '--steps', required=False, help='Number of step each training epochs (virtual epochs)') 11 | ap.add_argument('-vb', '--val_batch_size', required=False, help='Batch size of each validation step') 12 | ap.add_argument('-vs', '--val_steps', required=False, help='Number of step each validation epochs (virtual epochs)') 13 | ap.add_argument('-lr', '--learning_rate', required=False, help='Initial learning rate') 14 | ap.add_argument('--early_stop', required=False, help='Stop the training if validation accuracy does not improve over the given number of epochs') 15 | args = ap.parse_args() 16 | 17 | from setting_loaders import VocalSettings 18 | from vocal import app 19 | 20 | # Change settings to match arguments 21 | config_path = 'defaults/vocal.yaml' 22 | config = VocalSettings(config_path) 23 | feature_semi = args.feature_semi if args.feature_semi is not None else None 24 | output_model = args.output_model if args.output_model is not None else None 25 | input_model = args.input_model if args.input_model is not None else None 26 | if args.epochs is not None: 27 | config.training.epoch = int(args.epochs) 28 | if args.batch_size is not None: 29 | config.training.batch_size = int(args.batch_size) 30 | if args.steps is not None: 31 | config.training.steps = int(args.steps) 32 | if args.val_batch_size is not None: 33 | config.training.val_batch_size = int(args.val_batch_size) 34 | if args.val_steps is not None: 35 | config.training.val_steps = int(args.val_steps) 36 | if args.learning_rate is not None: 37 | config.training.init_learning_rate = float(args.learning_rate) 38 | if args.early_stop is not None: 39 | 
config.training.early_stop = int(args.early_stop) 40 | 41 | # Training 42 | va_transcription = app.VocalTranscription(config_path) 43 | va_transcription.train(feature_folder=args.feature, semi_feature_folder=feature_semi, model_name=output_model, input_model_path=input_model, vocal_settings=config) -------------------------------------------------------------------------------- /defaults/vocal_contour.yaml: -------------------------------------------------------------------------------- 1 | General: 2 | TranscriptionMode: 3 | Description: Mode of transcription by executing the `omnizart vocal-contour transribe` command. 4 | Type: String 5 | Value: VocalContour 6 | CheckpointPath: 7 | Description: Path to the pre-trained models. 8 | Type: Map 9 | SubType: [String, String] 10 | Value: 11 | VocalContour: pretrained_models/vocal_contour 12 | Feature: 13 | Description: Default settings of feature extraction 14 | Settings: 15 | HopSize: 16 | Description: Hop size in seconds with respect to sampling rate. 17 | Type: Float 18 | Value: 0.02 19 | SamplingRate: 20 | Description: Adjust input sampling rate to this value. 21 | Type: Integer 22 | Value: 16000 23 | WindowSize: 24 | Type: Integer 25 | Value: 2049 26 | Dataset: 27 | Description: Settings of datasets. 28 | Settings: 29 | SavePath: 30 | Description: Path for storing the downloaded datasets. 31 | Type: String 32 | Value: ./ 33 | FeatureSavePath: 34 | Description: Path for storing the extracted feature. Default to the path under the dataset folder. 35 | Type: String 36 | Value: + 37 | Model: 38 | Description: Default settings of training / testing the model. 39 | Settings: 40 | SavePrefix: 41 | Description: Prefix of the trained model's name to be saved. 42 | Type: String 43 | Value: vocal_contour 44 | SavePath: 45 | Description: Path to save the trained model. 46 | Type: String 47 | Value: ./checkpoints/vocal_contour 48 | Training: 49 | Description: Parameters for training 50 | Settings: 51 | Epoch: 52 | Description: Maximum number of epochs for training. 53 | Type: Integer 54 | Value: 20 55 | EarlyStop: 56 | Description: Terminate the training if the validation performance doesn't imrove after n epochs. 57 | Type: Integer 58 | Value: 10 59 | Steps: 60 | Description: Number of training steps for each epoch. 61 | Type: Integer 62 | Value: 1000 63 | ValSteps: 64 | Description: Number of validation steps after each training epoch. 65 | Type: Integer 66 | Value: 50 67 | BatchSize: 68 | Description: Batch size of each training step. 69 | Type: Integer 70 | Value: 32 71 | ValBatchSize: 72 | Description: Batch size of each validation step. 73 | Type: Integer 74 | Value: 32 75 | Timesteps: 76 | Description: Length of time axis of the input feature. 77 | Type: Integer 78 | Value: 128 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Singing Voice Transcription 2 | Singing Voice Transcription is a module that aim to automatic transcript singing voice into music note. Given a polyphonic music, it is able to find and extract main melody at note-level from this music singing voice. You can find some demo samples [here](https://drive.google.com/drive/folders/1o-FqYGEZao_5H8FRuiHVoU4RqJQnFOBo?usp=sharing). 
 3 | 
 4 | This module is referenced and improved from a part of [Omnizart](https://github.com/Music-and-Culture-Technology-Lab/omnizart) by [Music and Culture Technology (MCT) Lab](https://github.com/Music-and-Culture-Technology-Lab).
 5 | 
 6 | ## Installation
 7 | Before starting, make sure you have conda or [miniconda](https://docs.conda.io/en/latest/miniconda.html) installed, then create a new environment:
 8 | ```bash
 9 | conda create -n voice_transcription python=3.7
 10 | conda activate voice_transcription
 11 | ```
 12 | First, install [Spleeter by deezer](https://github.com/deezer/spleeter) and its required system packages:
 13 | ```bash
 14 | conda install -c conda-forge ffmpeg libsndfile fluidsynth
 15 | conda install numpy Cython
 16 | pip install spleeter==2.3.0
 17 | ```
 18 | Then install this module and its dependencies:
 19 | ```bash
 20 | git clone https://github.com/pthang23/Singing_Voice_Transcription
 21 | cd Singing_Voice_Transcription
 22 | pip install -r requirements.txt
 23 | ```
 24 | Finally, download the pretrained models from [here](https://drive.google.com/file/d/1y3M_rutkUW5xvp88z8eoFyzRbKwRhfa3/view?usp=sharing), put the archive in `Singing_Voice_Transcription` and unzip it. Then you can access all the module features.
 25 | 
 26 | ## Transcribe Music
 27 | Transcribe a single audio file by running the command below; the output will be saved in MIDI format with the same basename as the given audio:
 28 | ```bash
 29 | python transcribe.py -i <input_audio> -o <output_folder>
 30 | ```
 31 | You can also view more transcription options with the `--help` flag.
 32 | 
 33 | ## Training
 34 | You can train our module using your own custom dataset. This module contains 2 main parts:
 35 | - **Vocal Alignment**: aligns the onset and offset of each note
 36 | - **Vocal Contour Estimation**: segments the pitch line
 37 | 
 38 | Before training, you need to make sure the data structure looks like the following:
 39 | 
 40 |  |  dataset
 41 |  |  ├── audios
 42 |  |  │   └── audio1.wav ...
 43 |  |  └── labels
 44 |  |      └── label1.csv ...
 45 | 
 46 | **Vocal Alignment** label files contain 3 columns: onset, offset, midi_pitch
 47 | **Vocal Contour Estimation** label files contain 2 columns: onset, pitch (hz); example rows are shown below
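For illustration, the label files are plain CSV rows as parsed by `constants/datasets.py` (the contour loader expects every row to be numeric, so it is safest to omit a header). The file names in the tree above and the values below are made up purely as an example. A **Vocal Alignment** label lists one note per row (onset and offset in seconds, MIDI pitch):
```
0.52,0.98,60
1.03,1.61,62
1.70,2.35,64
```
A **Vocal Contour Estimation** label lists one analysis frame per row (time in seconds, frequency in Hz); rows whose frequency is 0 are treated as unvoiced and skipped by the loader:
```
0.000,0.0
0.006,220.0
0.012,220.5
```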
 48 | 
 49 | ### Training Vocal Alignment
 50 | First of all, generate the features that are necessary for training and testing. You can additionally use semi-supervised learning with an unlabeled dataset at the same time; for that, just leave out the **'labels'** part from the above data structure:
 51 | ```bash
 52 | python generate_align.py -i <dataset_path> -o <output_feature>
 53 | ```
 54 | The processed labeled features will be stored in `<output_feature>/train_feature` and `<output_feature>/test_feature`. The semi-supervised features will be stored in `<output_feature>/semi_feature`.
 55 | 
 56 | Then train a new model or continue training from a pretrained model:
 57 | ```bash
 58 | python train_align.py -f <feature_folder> -fs <semi_feature_folder> -i <pretrained_model>
 59 | ```
 60 | You can view more training options with the `--help` flag or by inspecting `defaults/vocal.yaml`.
 61 | 
 62 | ### Training Vocal Contour Estimation
 63 | You also need to generate the features first; the processed features will be stored in `<output_feature>/train_feature` and `<output_feature>/test_feature`:
 64 | ```bash
 65 | python generate_contour.py -i <dataset_path> -o <output_feature>
 66 | ```
 67 | Then train from scratch or fine-tune the contour model:
 68 | ```bash
 69 | python train_contour.py -f <feature_folder> -i <pretrained_model>
 70 | ```
 71 | Once again, check the `--help` flag or `defaults/vocal_contour.yaml` if you want to view more training options.
 72 | 
 73 | ## Reference
 74 | - [Omnizart](https://github.com/Music-and-Culture-Technology-Lab/omnizart) by Music and Culture Technology (MCT) Lab
 75 | - [Spleeter](https://github.com/deezer/spleeter) by deezer
76 | - [TONAS](https://zenodo.org/record/1290722#.Y6q2RadBxH4), [VOCADITO](https://zenodo.org/record/5578807#.Y6q2hKdBxH4), [CSD](https://zenodo.org/record/4785016#.Y6q22adBxH4), [MEDLEYDB](https://medleydb.weebly.com/) and [MIR1K](https://sites.google.com/site/unvoicedsoundseparation/mir-1k) dataset 77 | -------------------------------------------------------------------------------- /vocal_contour/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import abc 3 | import numpy as np 4 | 5 | from utils import write_yaml, get_logger, ensure_path_exists 6 | 7 | logger = get_logger("Callbacks") 8 | 9 | 10 | class Callback(metaclass=abc.ABCMeta): 11 | """Base class of all callback classes""" 12 | 13 | def __init__(self, monitor=None): 14 | if monitor is not None: 15 | self.monitor = monitor 16 | if "acc" in monitor: 17 | self.monitor_op = np.greater 18 | else: 19 | self.monitor_op = np.less 20 | 21 | def on_train_begin(self, history=None): 22 | pass 23 | 24 | def on_train_end(self, history=None): 25 | pass 26 | 27 | def on_epoch_begin(self, epoch, history=None): 28 | pass 29 | 30 | def on_epoch_end(self, epoch, history=None): 31 | pass 32 | 33 | def on_train_batch_begin(self, history=None): 34 | pass 35 | 36 | def on_train_batch_end(self, history=None): 37 | pass 38 | 39 | def on_test_batch_begin(self, history=None): 40 | pass 41 | 42 | def on_test_batch_end(self, history=None): 43 | pass 44 | 45 | def _set_model(self, model): 46 | self.model = model 47 | 48 | def _get_monitor_value(self, history, callback_name="Callback"): 49 | history = history or {"train": [], "validate": []} 50 | 51 | if self.monitor.startswith("val"): 52 | hist = history["validate"] 53 | else: 54 | hist = history["train"] 55 | 56 | if len(hist) > 0: 57 | current = hist[-1] 58 | 59 | metric = self.monitor.split("_")[-1] 60 | if metric == "acc": 61 | metric = "accuracy" 62 | score = current.get(metric) 63 | if score is None: 64 | logger.warning( 65 | "%s conditioned on metric %s " 66 | "which is not available. 
Available metrics are %s", 67 | callback_name, self.monitor, list(current.keys()) 68 | ) 69 | return score 70 | 71 | 72 | class EarlyStopping(Callback): 73 | """Early stop the training after no improvement on the monitor for a certain period""" 74 | 75 | def __init__(self, patience=5, monitor="val_acc"): 76 | super().__init__(monitor=monitor) 77 | self.patience = patience 78 | self.stopped_epoch = 0 79 | 80 | def on_train_begin(self, history=None): 81 | self.wait = 0 82 | self.best = np.Inf if self.monitor_op == np.less else -np.Inf 83 | 84 | def on_epoch_end(self, epoch, history=None): 85 | assert hasattr(self, "model") 86 | score = self._get_monitor_value(history, callback_name="Early stopping") 87 | if score is None: 88 | return 89 | 90 | if self.monitor_op(score, self.best): 91 | self.best = score 92 | self.wait = 0 93 | else: 94 | self.wait += 1 95 | 96 | if self.wait >= self.patience: 97 | self.model.stop_training = True 98 | self.stopped_epoch = epoch 99 | 100 | def on_train_end(self, history=None): 101 | if self.stopped_epoch > 0: 102 | print("Early stopped training") 103 | 104 | 105 | class ModelCheckpoint(Callback): 106 | """Saving the model during training, override the previous checkpoint""" 107 | 108 | def __init__(self, filepath, monitor='val_acc', save_best_only=False, save_weights_only=False): 109 | super().__init__(monitor=monitor) 110 | self.filepath = filepath 111 | self.save_best_only = save_best_only 112 | self.save_weights_only = save_weights_only 113 | 114 | def on_train_begin(self, history=None): 115 | self.best = np.Inf if self.monitor_op == np.less else -np.Inf 116 | 117 | def on_epoch_end(self, epoch, history=None): 118 | if self.save_best_only: 119 | score = self._get_monitor_value(history, callback_name="Model checkpoint") 120 | if score is None: 121 | return 122 | 123 | if self.monitor_op(score, self.best): 124 | self.best = score 125 | self._save_model() 126 | else: 127 | self._save_model() 128 | 129 | def _ensure_path_exists(self): 130 | if hasattr(self, "_path_checked") and self._path_checked: # pylint: disable=E0203 131 | return 132 | ensure_path_exists(self.filepath) 133 | self._path_checked = True 134 | 135 | def _save_model(self): 136 | self._ensure_path_exists() 137 | if not self.save_weights_only: 138 | write_yaml(self.model.to_yaml(), os.path.join(self.filepath, "arch.yaml"), dump=False) 139 | self.model.save_weights(os.path.join(self.filepath, "weights.h5")) -------------------------------------------------------------------------------- /setting_loaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from utils import load_yaml, json_serializable 4 | 5 | SETTING_DIR = f'{os.path.split(__file__)[0]}/defaults' 6 | 7 | 8 | class Settings: 9 | default_setting_file = None 10 | 11 | def __init__(self, conf_path=None): 12 | # Load default settings 13 | if conf_path is not None: 14 | self.from_json(load_yaml(conf_path)) # pylint: disable=E1101 15 | else: 16 | conf_path = os.path.join(SETTING_DIR, self.default_setting_file) 17 | self.from_json(load_yaml(conf_path)) # pylint: disable=E1101 18 | 19 | 20 | @json_serializable(key_path="./General", value_path="./Value") 21 | class VocalContourSettings(Settings): 22 | default_setting_file: str = "vocal_contour.yaml" 23 | 24 | def __init__(self, conf_path=None): 25 | self.transcription_mode: str = None 26 | self.checkpoint_path: str = None 27 | self.feature = self.VocalContourFeature() 28 | self.dataset = self.VocalContourDataset() 29 | self.model = 
self.VocalContourModel() 30 | self.training = self.VocalContourTraining() 31 | 32 | super().__init__(conf_path=conf_path) 33 | 34 | @json_serializable(key_path="./Settings", value_path="./Value") 35 | class VocalContourFeature(): 36 | def __init__(self): 37 | self.hop_size: float = None 38 | self.sampling_rate: int = None 39 | self.window_size: int = None 40 | 41 | @json_serializable(key_path="./Settings", value_path="./Value") 42 | class VocalContourDataset(): 43 | def __init__(self): 44 | self.save_path: str = None 45 | self.feature_save_path: str = None 46 | 47 | @json_serializable(key_path="./Settings", value_path="./Value") 48 | class VocalContourModel(): 49 | def __init__(self): 50 | self.save_prefix: str = None 51 | self.save_path: str = None 52 | 53 | @json_serializable(key_path="./Settings", value_path="./Value") 54 | class VocalContourTraining(): 55 | def __init__(self): 56 | self.epoch: int = None 57 | self.early_stop: int = None 58 | self.steps: int = None 59 | self.val_steps: int = None 60 | self.batch_size: int = None 61 | self.val_batch_size: int = None 62 | self.timesteps: int = None 63 | 64 | 65 | @json_serializable(key_path="./General", value_path="./Value") 66 | class VocalSettings(Settings): 67 | default_setting_file: str = "vocal.yaml" 68 | 69 | def __init__(self, conf_path=None): 70 | self.transcription_mode: str = None 71 | self.checkpoint_path: dict = None 72 | self.feature = self.VocalFeature() 73 | self.dataset = self.VocalDataset() 74 | self.model = self.VocalModel() 75 | self.inference = self.VocalInference() 76 | self.training = self.VocalTraining() 77 | 78 | super().__init__(conf_path=conf_path) 79 | 80 | @json_serializable(key_path="./Settings", value_path="./Value") 81 | class VocalFeature: 82 | def __init__(self): 83 | self.hop_size: float = None 84 | self.sampling_rate: int = None 85 | self.frequency_resolution: float = None 86 | self.frequency_center: float = None 87 | self.time_center: float = None 88 | self.gamma: list = None 89 | self.bins_per_octave: int = None 90 | 91 | @json_serializable(key_path="./Settings", value_path="./Value") 92 | class VocalDataset: 93 | def __init__(self): 94 | self.save_path: str = None 95 | self.feature_save_path: str = None 96 | 97 | @json_serializable(key_path="./Settings", value_path="./Value") 98 | class VocalModel: 99 | def __init__(self): 100 | self.save_prefix: str = None 101 | self.save_path: str = None 102 | self.min_kernel_size: int = None 103 | self.depth: int = None 104 | self.shake_drop: bool = True 105 | self.alpha: int = None 106 | self.semi_loss_weight: float = None 107 | self.semi_xi: float = None 108 | self.semi_epsilon: float = None 109 | self.semi_iterations: int = None 110 | 111 | @json_serializable(key_path="./Settings", value_path="./Value") 112 | class VocalInference: 113 | def __init__(self): 114 | self.context_length: int = None 115 | self.threshold: float = None 116 | self.min_duration: float = None 117 | self.pitch_model: str = None 118 | 119 | @json_serializable(key_path="./Settings", value_path="./Value") 120 | class VocalTraining: 121 | def __init__(self): 122 | self.epoch: int = None 123 | self.steps: int = None 124 | self.val_steps: int = None 125 | self.batch_size: int = None 126 | self.val_batch_size: int = None 127 | self.early_stop: int = None 128 | self.init_learning_rate: float = None 129 | self.context_length: int = None -------------------------------------------------------------------------------- /models/u_net.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Input, Model 3 | from tensorflow.keras.layers import ( 4 | BatchNormalization, 5 | Activation, 6 | Dropout, 7 | Conv2D, 8 | Conv2DTranspose, 9 | Add, 10 | Concatenate 11 | ) 12 | 13 | import os 14 | import sys 15 | MODULE_PATH = os.path.abspath(f"{os.path.split(__file__)[0]}/..") 16 | if sys.path[0] != MODULE_PATH: sys.path.insert(0, MODULE_PATH) 17 | 18 | from models.t2t import local_attention_2d, split_heads_2d, combine_heads_2d 19 | 20 | 21 | def conv_block(input_tensor, channel, kernel_size, strides=(2, 2), dilation_rate=1, dropout_rate=0.4): 22 | """Convolutional encoder block of U-net, encoder block does not downsample the input feature""" 23 | 24 | skip = input_tensor 25 | 26 | input_tensor = BatchNormalization()(Activation("relu")(input_tensor)) 27 | input_tensor = Dropout(dropout_rate)(input_tensor) 28 | input_tensor = Conv2D( 29 | channel, kernel_size, strides=strides, dilation_rate=dilation_rate, padding="same" 30 | )(input_tensor) 31 | 32 | input_tensor = BatchNormalization()(Activation("relu")(input_tensor)) 33 | input_tensor = Dropout(dropout_rate)(input_tensor) 34 | input_tensor = Conv2D( 35 | channel, kernel_size, strides=(1, 1), dilation_rate=dilation_rate, padding="same" 36 | )(input_tensor) 37 | 38 | if strides != (1, 1): 39 | skip = Conv2D(channel, (1, 1), strides=strides, padding="same")(skip) 40 | input_tensor = Add()([input_tensor, skip]) 41 | 42 | return input_tensor 43 | 44 | 45 | def transpose_conv_block(input_tensor, channel, kernel_size, strides=(2, 2), dropout_rate=0.4): 46 | skip = input_tensor 47 | 48 | input_tensor = BatchNormalization()(Activation("relu")(input_tensor)) 49 | input_tensor = Dropout(dropout_rate)(input_tensor) 50 | input_tensor = Conv2D(channel, kernel_size, strides=(1, 1), padding="same")(input_tensor) 51 | 52 | input_tensor = BatchNormalization()(Activation("relu")(input_tensor)) 53 | input_tensor = Dropout(dropout_rate)(input_tensor) 54 | input_tensor = Conv2DTranspose(channel, kernel_size, strides=strides, padding="same")(input_tensor) 55 | 56 | if strides != (1, 1): 57 | skip = Conv2DTranspose(channel, (1, 1), strides=strides, padding="same")(skip) 58 | input_tensor = Add()([input_tensor, skip]) 59 | 60 | return input_tensor 61 | 62 | 63 | def semantic_segmentation(feature_num=352, timesteps=256, multi_grid_layer_n=1, multi_grid_n=5, ch_num=1, out_class=2, dropout=0.4): 64 | """Improved U-net model with Atrous Spatial Pyramid Pooling (ASPP) block""" 65 | 66 | input_score = Input(shape=(timesteps, feature_num, ch_num), name="input_score_48") 67 | en = Conv2D(2**7, (7, 7), strides=(1, 1), padding="same")(input_score) 68 | 69 | en_l1 = conv_block(en, 2**7, (3, 3), strides=(2, 2)) 70 | en_l1 = conv_block(en_l1, 2**7, (3, 3), strides=(1, 1)) 71 | 72 | en_l2 = conv_block(en_l1, 2**7, (3, 3), strides=(2, 2)) 73 | en_l2 = conv_block(en_l2, 2**7, (3, 3), strides=(1, 1)) 74 | en_l2 = conv_block(en_l2, 2**7, (3, 3), strides=(1, 1)) 75 | 76 | en_l3 = conv_block(en_l2, 2**7, (3, 3), strides=(2, 2)) 77 | en_l3 = conv_block(en_l3, 2**7, (3, 3), strides=(1, 1)) 78 | en_l3 = conv_block(en_l3, 2**7, (3, 3), strides=(1, 1)) 79 | en_l3 = conv_block(en_l3, 2**7, (3, 3), strides=(1, 1)) 80 | 81 | en_l4 = conv_block(en_l3, 2**8, (3, 3), strides=(2, 2)) 82 | en_l4 = conv_block(en_l4, 2**8, (3, 3), strides=(1, 1)) 83 | en_l4 = conv_block(en_l4, 2**8, (3, 3), strides=(1, 1)) 84 | en_l4 = conv_block(en_l4, 2**8, (3, 3), 
strides=(1, 1)) 85 | en_l4 = conv_block(en_l4, 2**8, (3, 3), strides=(1, 1)) 86 | 87 | feature = en_l4 88 | for _ in range(multi_grid_layer_n): 89 | feature = BatchNormalization()(Activation("relu")(feature)) 90 | feature = Dropout(dropout)(feature) 91 | m = BatchNormalization()(Conv2D(2**9, (1, 1), strides=(1, 1), padding="same", activation="relu")(feature)) 92 | multi_grid = m 93 | for ii in range(multi_grid_n): 94 | m = BatchNormalization()( 95 | Conv2D(2**9, (3, 3), strides=(1, 1), dilation_rate=2**ii, padding="same", activation="relu")(feature) 96 | ) 97 | multi_grid = Concatenate()([multi_grid, m]) 98 | multi_grid = Dropout(dropout)(multi_grid) 99 | feature = Conv2D(2**9, (1, 1), strides=(1, 1), padding="same")(multi_grid) 100 | 101 | feature = BatchNormalization()(Activation("relu")(feature)) 102 | 103 | feature = Conv2D(2**8, (1, 1), strides=(1, 1), padding="same")(feature) 104 | feature = Add()([feature, en_l4]) 105 | de_l1 = transpose_conv_block(feature, 2**7, (3, 3), strides=(2, 2)) 106 | 107 | skip = de_l1 108 | de_l1 = BatchNormalization()(Activation("relu")(de_l1)) 109 | de_l1 = Concatenate()([de_l1, BatchNormalization()(Activation("relu")(en_l3))]) 110 | de_l1 = Dropout(dropout)(de_l1) 111 | de_l1 = Conv2D(2**7, (1, 1), strides=(1, 1), padding="same")(de_l1) 112 | de_l1 = Add()([de_l1, skip]) 113 | de_l2 = transpose_conv_block(de_l1, 2**7, (3, 3), strides=(2, 2)) 114 | 115 | skip = de_l2 116 | de_l2 = BatchNormalization()(Activation("relu")(de_l2)) 117 | de_l2 = Concatenate()([de_l2, BatchNormalization()(Activation("relu")(en_l2))]) 118 | de_l2 = Dropout(dropout)(de_l2) 119 | de_l2 = Conv2D(2**7, (1, 1), strides=(1, 1), padding="same")(de_l2) 120 | de_l2 = Add()([de_l2, skip]) 121 | de_l3 = transpose_conv_block(de_l2, 2**7, (3, 3), strides=(2, 2)) 122 | 123 | skip = de_l3 124 | de_l3 = BatchNormalization()(Activation("relu")(de_l3)) 125 | de_l3 = Concatenate()([de_l3, BatchNormalization()(Activation("relu")(en_l1))]) 126 | de_l3 = Dropout(dropout)(de_l3) 127 | de_l3 = Conv2D(2**7, (1, 1), strides=(1, 1), padding="same")(de_l3) 128 | de_l3 = Add()([de_l3, skip]) 129 | de_l4 = transpose_conv_block(de_l3, 2**7, (3, 3), strides=(2, 2)) 130 | 131 | de_l4 = BatchNormalization()(Activation("relu")(de_l4)) 132 | de_l4 = Dropout(dropout)(de_l4) 133 | out = Conv2D(out_class, (1, 1), strides=(1, 1), padding="same", name="prediction")(de_l4) 134 | 135 | return Model(inputs=input_score, outputs=out) -------------------------------------------------------------------------------- /vocal/inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pretty_midi 3 | from scipy.stats import norm 4 | 5 | import os 6 | import sys 7 | MODULE_PATH = os.path.abspath(f"{os.path.split(__file__)[0]}/..") 8 | if sys.path[0] != MODULE_PATH: sys.path.insert(0, MODULE_PATH) 9 | 10 | from utils import get_logger 11 | 12 | logger = get_logger("Vocal Inference") 13 | 14 | 15 | def _find_peaks(seq, ctx_len=2, threshold=0.5): 16 | # Discard the first and the last frames. 
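    # A frame counts as a peak when its value reaches the threshold, is strictly greater
    # than the preceding ctx_len values, and is greater than or equal to the following
    # ctx_len values; the first ctx_len frames and the last ctx_len+1 frames are never candidates.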
17 | peaks = [] 18 | for idx in range(ctx_len, len(seq) - ctx_len - 1): 19 | cur_val = seq[idx] 20 | if cur_val < threshold: 21 | continue 22 | if not all(cur_val > seq[idx - ctx_len:idx]): 23 | continue 24 | if not all(cur_val >= seq[idx + 1:idx + ctx_len + 1]): 25 | continue 26 | peaks.append(idx) 27 | return peaks 28 | 29 | def _find_first_bellow_th(seq, threshold=0.5): 30 | activate = False 31 | for idx, val in enumerate(seq): 32 | if val > threshold: 33 | activate = True 34 | if activate and val < threshold: 35 | return idx 36 | return 0 37 | 38 | def infer_interval(pred, ctx_len=2, threshold=0.5, min_dura=0.1, t_unit=0.02): 39 | """Infer the onset and offset time of notes from the raw prediction values""" 40 | 41 | on_peaks = _find_peaks(pred[:, 2], ctx_len=ctx_len, threshold=threshold) 42 | off_peaks = _find_peaks(pred[:, 4], ctx_len=ctx_len, threshold=threshold) 43 | if len(on_peaks) == 0 or len(off_peaks) == 0: 44 | return None 45 | 46 | # Clearing out offsets before first onset (since onset is more accurate) 47 | off_peaks = [idx for idx in off_peaks if idx > on_peaks[0]] 48 | 49 | on_peak_id = 0 50 | est_interval = [] 51 | min_len = min_dura / t_unit 52 | while on_peak_id < len(on_peaks) - 1: 53 | on_id = on_peaks[on_peak_id] 54 | next_on_id = on_peaks[on_peak_id + 1] 55 | 56 | off_peak_id = np.where(np.array(off_peaks) >= on_id + min_len)[0] 57 | if len(off_peak_id) == 0: 58 | off_id = _find_first_bellow_th(pred[on_id:, 0], threshold=threshold) 59 | else: 60 | off_id = off_peaks[off_peak_id[0]] 61 | 62 | if on_id < next_on_id < off_id \ 63 | and np.mean(pred[on_id:next_on_id, 1]) > np.mean(pred[on_id:next_on_id, 0]): 64 | # Discard current onset, since the duration between current and 65 | # next onset shows an inactive status. 66 | on_peak_id += 1 67 | continue 68 | 69 | if off_id > next_on_id: 70 | # Missing offset between current and next onset. 71 | if (off_id - next_on_id) < min_len: 72 | # Assign the offset after the next onset to the current onset. 73 | est_interval.append((on_id * t_unit, off_id * t_unit)) 74 | on_peak_id += 1 75 | else: 76 | # Insert an additional offset. 77 | est_interval.append((on_id * t_unit, next_on_id * t_unit)) 78 | on_peak_id += 1 79 | elif (off_id - on_id) >= min_len: 80 | # Normal case that one onset has a corressponding offset. 81 | est_interval.append((on_id * t_unit, off_id * t_unit)) 82 | on_peak_id += 1 83 | else: 84 | # Do nothing 85 | on_peak_id += 1 86 | 87 | # Deal with the border case, the last onset peak. 
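    # The loop above needs a following onset to compare against, so it stops at the
    # second-to-last onset. The final onset is closed here at the first frame where the
    # activation drops back below the threshold, and is kept only if it lasts at least min_dura.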
88 | on_id = on_peaks[-1] 89 | off_id = _find_first_bellow_th(pred[on_id:, 0], threshold=threshold) + on_id 90 | if off_id - on_id >= min_len: 91 | est_interval.append((on_id * t_unit, off_id * t_unit)) 92 | 93 | return np.array(est_interval) 94 | 95 | 96 | def _conclude_freq(freqs, std=2, min_count=3): 97 | """Conclude the average frequency with gaussian distribution weighting""" 98 | 99 | # Expect freqs contains zero 100 | half_len = len(freqs) // 2 101 | prob_func = lambda x: norm(0, std).pdf(x - half_len) 102 | weights = [prob_func(idx) for idx in range(len(freqs))] 103 | avg_freq = 0 104 | count = 0 105 | total_weight = 1e-8 106 | for weight, freq in zip(weights, freqs): 107 | if freq < 1e-6: 108 | continue 109 | 110 | avg_freq += weight * freq 111 | total_weight += weight 112 | count += 1 113 | 114 | return avg_freq / total_weight if count >= min_count else 0 115 | 116 | 117 | def infer_midi(interval, agg_f0, t_unit=0.02): 118 | """Inference the given interval and aggregated F0 to MIDI file""" 119 | 120 | fs = round(1 / t_unit) 121 | max_secs = max(record["end_time"] for record in agg_f0) 122 | total_frames = round(max_secs) * fs + 10 123 | flat_f0 = np.zeros(total_frames) 124 | for record in agg_f0: 125 | start_idx = int(round(record["start_time"] * fs)) 126 | end_idx = int(round(record["end_time"] * fs)) 127 | flat_f0[start_idx:end_idx] = record["frequency"] 128 | 129 | notes = [] 130 | drum_notes = [] 131 | skip_num = 0 132 | for onset, offset in interval: 133 | start_idx = int(round(onset * fs)) 134 | end_idx = int(round(offset * fs)) 135 | freqs = flat_f0[start_idx:end_idx] 136 | avg_hz = _conclude_freq(freqs) 137 | if avg_hz < 1e-6: 138 | skip_num += 1 139 | note = pretty_midi.Note(velocity=80, pitch=77, start=onset, end=offset) 140 | drum_notes.append(note) 141 | continue 142 | 143 | note_num = int(round(pretty_midi.hz_to_note_number(avg_hz))) 144 | if not (0 <= note_num <= 127): 145 | logger.warning("Caught invalid note number: %d (should be in range 0~127). Skipping.", note_num) 146 | skip_num += 1 147 | continue 148 | note = pretty_midi.Note(velocity=80, pitch=note_num, start=onset, end=offset) 149 | notes.append(note) 150 | 151 | if skip_num > 0: 152 | logger.warning("A total of %d notes are skipped due to lack of corressponding pitch information.", skip_num) 153 | 154 | inst = pretty_midi.Instrument(program=0) 155 | inst.notes += notes 156 | drum_inst = pretty_midi.Instrument(program=1, is_drum=True, name="Missing Notes") 157 | drum_inst.notes += drum_notes 158 | midi = pretty_midi.PrettyMIDI() 159 | midi.instruments.append(inst) 160 | midi.instruments.append(drum_inst) 161 | return midi 162 | -------------------------------------------------------------------------------- /defaults/vocal.yaml: -------------------------------------------------------------------------------- 1 | General: 2 | TranscriptionMode: 3 | Description: Mode of transcription by executing the `omnizart vocal transcribe` command. 4 | Type: String 5 | Value: Semi 6 | CheckpointPath: 7 | Description: Path to the pre-trained models. 8 | Type: Map 9 | SubType: [String, String] 10 | Value: 11 | Super: pretrained_models/vocal_super 12 | Semi: pretrained_models/vocal_semi 13 | Feature: 14 | Description: Default settings of feature extraction for drum transcription. 15 | Settings: 16 | HopSize: 17 | Description: Hop size in seconds with respect to sampling rate. 18 | Type: Float 19 | Value: 0.02 20 | SamplingRate: 21 | Description: Adjust input sampling rate to this value. 
22 | Type: Integer 23 | Value: 16000 24 | FrequencyResolution: 25 | Type: Float 26 | Value: 2.0 27 | FrequencyCenter: 28 | Description: Lowest frequency to extract. 29 | Type: Float 30 | Value: 80 31 | TimeCenter: 32 | Description: Highest frequency to extract (1/time_center). 33 | Type: Float 34 | Value: 0.001 35 | Gamma: 36 | Type: List 37 | SubType: Float 38 | Value: [0.24, 0.6, 1.0] 39 | BinsPerOctave: 40 | Description: Number of bins for each octave. 41 | Type: Integer 42 | Value: 48 43 | Dataset: 44 | Description: Settings of datasets. 45 | Settings: 46 | SavePath: 47 | Description: Path for storing the downloaded datasets. 48 | Type: String 49 | Value: ./ 50 | FeatureSavePath: 51 | Description: Path for storing the extracted feature. Default to the path under the dataset folder. 52 | Type: String 53 | Value: + 54 | Model: 55 | Description: Default settings of training / testing the model. 56 | Settings: 57 | SavePrefix: 58 | Description: Prefix of the trained model's name to be saved. 59 | Type: String 60 | Value: vocal 61 | SavePath: 62 | Description: Path to save the trained model. 63 | Type: String 64 | Value: ./checkpoints/vocal 65 | MinKernelSize: 66 | Description: Minimum kernel size of convolution layers in each pyramid block. 67 | Type: Integer 68 | Value: 16 69 | Depth: 70 | Description: Total number of pyramid blocks will be -> (Depth - 2) / 2 . 71 | Type: Integer 72 | Value: 110 73 | Alpha: 74 | Type: Integer 75 | Value: 270 76 | ShakeDrop: 77 | Description: Whether to leverage Shake Drop normalization when back propagation. 78 | Type: Bool 79 | Value: True 80 | SemiLossWeight: 81 | Description: Weighting factor of the semi-supervise loss. Supervised loss will not be affected by this parameter. 82 | Type: Float 83 | Value: 1.0 84 | SemiXi: 85 | Description: A small constant value for weighting the adverarial perturbation. 86 | Type: Float 87 | Value: 0.000001 88 | SemiEpsilon: 89 | Description: Weighting factor of the output adversarial perturbation. 90 | Type: Float 91 | Value: 8.0 92 | SemiIterations: 93 | Description: Number of iterations when generating the adversarial perturbation. 94 | Type: Integer 95 | Value: 2 96 | Inference: 97 | Description: Default settings when infering notes. 98 | Settings: 99 | ContextLength: 100 | Description: Length of context that will be used to find the peaks. 101 | Type: Integer 102 | Value: 2 103 | Threshold: 104 | Description: Threshold that will be applied to clip the predicted values to either 0 or 1. 105 | Type: Float 106 | Value: 0.5 107 | MinDuration: 108 | Description: Minimum required length of each note, in seconds. 109 | Type: Float 110 | Value: 0.1 111 | PitchModel: 112 | Description: The model for predicting the pitch contour. Default to use vocal-contour modeul. Could be path or mode name. 113 | Type: String 114 | Value: VocalContour 115 | Training: 116 | Description: Hyper parameters for training 117 | Settings: 118 | Epoch: 119 | Description: Maximum number of epochs for training. 120 | Type: Integer 121 | Value: 20 122 | Steps: 123 | Description: Number of training steps for each epoch. 124 | Type: Integer 125 | Value: 1000 126 | ValSteps: 127 | Description: Number of validation steps after each training epoch. 128 | Type: Integer 129 | Value: 50 130 | BatchSize: 131 | Description: Batch size of each training step. 132 | Type: Integer 133 | Value: 64 134 | ValBatchSize: 135 | Description: Batch size of each validation step. 
136 | Type: Integer 137 | Value: 64 138 | EarlyStop: 139 | Description: Terminate the training if the validation performance doesn't imrove after n epochs. 140 | Type: Integer 141 | Value: 10 142 | InitLearningRate: 143 | Descriptoin: Initial learning rate. 144 | Type: Float 145 | Value: 0.0001 146 | ContextLength: 147 | Description: Context to be considered before and after current timestamp. 148 | Type: Integer 149 | Value: 9 -------------------------------------------------------------------------------- /constants/datasets.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import csv 4 | import glob 5 | from os.path import join as jpath 6 | from shutil import copy 7 | 8 | import pretty_midi 9 | import numpy as np 10 | import re 11 | 12 | MODULE_PATH = os.path.abspath(jpath(os.path.split(__file__)[0], '..')) 13 | if sys.path[0] != MODULE_PATH: sys.path.append(MODULE_PATH) 14 | 15 | from base import Label 16 | from utils import get_logger 17 | 18 | 19 | logger = get_logger("Constant Datasets") 20 | 21 | 22 | def _get_file_list(dataset_path, dirs, ext): 23 | files = [] 24 | for _dir in dirs: 25 | files += glob.glob(os.path.join(dataset_path, _dir, "*" + ext)) 26 | return files 27 | 28 | 29 | class BaseStructure: 30 | """Defines the necessary attributes and common functions for each sub-dataset structure class""" 31 | 32 | # Is labeled or unlabeled dataset 33 | is_labeled = True 34 | 35 | # The URL for downloading the dataset. 36 | url = None 37 | 38 | # The extension of ground-truth files (e.g. .mid, .csv). 39 | label_ext = None 40 | 41 | # Record folders that contain trainig wav files 42 | train_wavs = None 43 | 44 | # Record folders that contain testing wav files 45 | test_wavs = None 46 | 47 | # Record folders that contains training labels 48 | train_labels = None 49 | 50 | # Records folders that contains testing labels 51 | test_labels = None 52 | 53 | @classmethod 54 | def _get_data_pair(cls, wavs, labels): 55 | label_path_mapping = {os.path.basename(label): label for label in labels} 56 | 57 | pair = [] 58 | for wav in wavs: 59 | basename = os.path.basename(wav) 60 | label_name = cls._name_transform(basename).replace(".wav", cls.label_ext) 61 | label_path = label_path_mapping[label_name] 62 | assert os.path.exists(label_path) 63 | pair.append((wav, label_path)) 64 | 65 | return pair 66 | 67 | # FIX 68 | @classmethod 69 | def _get_unlabeled_data_pair(cls, wavs): 70 | pair = [] 71 | for wav in wavs: 72 | pair.append((wav)) 73 | return pair 74 | 75 | @classmethod 76 | def get_train_data_pair(cls, dataset_path): 77 | """Get pair of training file and the coressponding label file path.""" 78 | # print(cls.get_train_wavs(dataset_path)) 79 | # print(cls.get_train_labels(dataset_path)) 80 | return cls._get_data_pair(cls.get_train_wavs(dataset_path), cls.get_train_labels(dataset_path)) if cls.is_labeled else cls._get_unlabeled_data_pair(cls.get_train_wavs(dataset_path)) 81 | 82 | @classmethod 83 | def get_test_data_pair(cls, dataset_path): 84 | """Get pair of testing file and the coressponding label file path.""" 85 | return cls._get_data_pair(cls.get_test_wavs(dataset_path), cls.get_test_labels(dataset_path)) if cls.is_labeled else cls._get_unlabeled_data_pair(cls.get_test_wavs(dataset_path)) 86 | 87 | @classmethod 88 | def _name_transform(cls, basename): 89 | # Transform the basename of wav file to the corressponding label file name. 
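        # Identity by default; dataset structures whose label files are not named exactly
        # like their wav files can override this classmethod to map between the two.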
90 | return basename 91 | 92 | @classmethod 93 | def get_train_wavs(cls, dataset_path): 94 | """Get list of complete train wav paths""" 95 | return _get_file_list(dataset_path, cls.train_wavs, ".wav") 96 | 97 | @classmethod 98 | def get_test_wavs(cls, dataset_path): 99 | """Get list of complete test wav paths""" 100 | return _get_file_list(dataset_path, cls.test_wavs, ".wav") 101 | 102 | @classmethod 103 | def get_train_labels(cls, dataset_path): 104 | """Get list of complete train label paths""" 105 | return _get_file_list(dataset_path, cls.train_labels, cls.label_ext) 106 | 107 | @classmethod 108 | def get_test_labels(cls, dataset_path): 109 | """Get list of complete test label paths""" 110 | return _get_file_list(dataset_path, cls.test_labels, cls.label_ext) 111 | 112 | @classmethod 113 | def load_label(cls, label_path): 114 | """Load and parse labels for the given label file path""" 115 | 116 | midi = pretty_midi.PrettyMIDI(label_path) 117 | labels = [] 118 | for inst in midi.instruments: 119 | if inst.is_drum: 120 | continue 121 | for note in inst.notes: 122 | label = Label( 123 | start_time=note.start, 124 | end_time=note.end, 125 | note=note.pitch, 126 | velocity=note.velocity, 127 | instrument=inst.program 128 | ) 129 | if label.note == -1: 130 | continue 131 | labels.append(label) 132 | return labels 133 | 134 | 135 | class VocalAlignStructure(BaseStructure): 136 | """Constant settings of vocal-semi datasets""" 137 | 138 | url = None 139 | 140 | # Label extension for note-level transcription. 141 | label_ext = ".csv" 142 | 143 | # Folder to train wavs 144 | train_wavs = ["audios"] 145 | 146 | # Folder to train labels 147 | train_labels = ["labels"] 148 | 149 | # Folder to test wavs 150 | test_wavs = [] 151 | 152 | # Folder to test labels 153 | test_labels = [] 154 | 155 | @classmethod 156 | def load_label(cls, label_path): 157 | with open(label_path, "r") as lin: 158 | lines = lin.readlines() 159 | 160 | labels = [] 161 | for line in lines: 162 | if not re.fullmatch("([\d\s\.]+,){2}[\d\s\.]+(,.+)*", line.strip()): 163 | continue 164 | 165 | onset, offset, note = [element.strip() for element in line.split(",")[:3]] 166 | 167 | labels.append(Label( 168 | start_time=float(onset), 169 | end_time=float(offset), 170 | note=round(float(note)) 171 | )) 172 | return labels 173 | 174 | 175 | class VocalContourStructure(BaseStructure): 176 | """Constant settings of vocal-semi datasets""" 177 | 178 | url = None 179 | 180 | # Label extension for note-level transcription. 
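    # Note: for this structure the .csv files hold frame-level F0 annotations
    # ("second,hz" per line) rather than note events; see load_label below.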
181 | label_ext = ".csv" 182 | 183 | # Folder to train wavs 184 | train_wavs = ["audios"] 185 | 186 | # Folder to train labels 187 | train_labels = ["labels"] 188 | 189 | # Folder to test wavs 190 | test_wavs = [] 191 | 192 | # Folder to test labels 193 | test_labels = [] 194 | 195 | @classmethod 196 | def load_label(cls, label_path): 197 | with open(label_path, "r") as fin: 198 | lines = fin.readlines() 199 | 200 | labels = [] 201 | t_unit = 256 / 44100 # ~= 0.0058 secs 202 | for line in lines: 203 | elems = line.strip().split(",") 204 | sec, hz = float(elems[0]), float(elems[1]) 205 | if hz < 1e-10: 206 | continue 207 | note = float(pretty_midi.hz_to_note_number(hz)) # Convert return type of np.float64 to float 208 | end_t = sec + t_unit 209 | labels.append(Label(start_time=sec, end_time=end_t, note=note)) 210 | 211 | return labels 212 | 213 | class UnlabeledStructure(BaseStructure): 214 | """Constant settings of unlabeled datasets""" 215 | 216 | is_labeled = False 217 | 218 | url = None 219 | 220 | # Folder to train wavs 221 | train_wavs = ["audios"] 222 | 223 | # Folder to test wavs 224 | test_wavs = [] -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | 4 | import tqdm 5 | 6 | def get_train_val_feat_file_list(feature_folder, split=0.9): 7 | feat_files = glob.glob(f"{feature_folder}/*.hdf") 8 | sidx = round(len(feat_files) * split) 9 | random.shuffle(feat_files) 10 | train_files = feat_files[:sidx] 11 | val_files = feat_files[sidx:] 12 | return train_files, val_files 13 | 14 | PROGRESS_BAR_FORMAT = "{desc} - {percentage:3.0f}% |{bar:40}| {n_fmt}/{total_fmt} \ 15 | [{elapsed}<{remaining},{rate_fmt}{postfix}]" 16 | 17 | def format_num(num, digit=4): 18 | """Formatting the float values as string.""" 19 | rounding = f".{digit}g" 20 | num_str = f"{num:{rounding}}".replace("+0", "+").replace("-0", "-") 21 | num = str(num) 22 | return num_str if len(num_str) < len(num) else num 23 | 24 | def gen_bar_postfix(history, targets=["loss", "accuracy"], name_transform=["loss", "acc"]): # pylint: disable=W0102 25 | """Generate string of metrics status to be appended to the end of the progress bar. 26 | Parameters 27 | ---------- 28 | history: dict 29 | History records generated by ``train_steps``. 30 | targets: list[str] 31 | List of metric's names to be extracted as the postfix. 32 | name_transform: list[str] 33 | The alias metric name that will be showed on the bar. 34 | Should be the same length, same order as ``targets``. 35 | Returns 36 | ------- 37 | postfix: str 38 | The extracted metrics information. 39 | """ 40 | info = [] 41 | for target, name in zip(targets, name_transform): 42 | if target not in history: 43 | continue 44 | val = history[target] 45 | val_str = format_num(val) 46 | info_str = f"{name}: {val_str}" 47 | info.append(info_str) 48 | return ", ".join(info) 49 | 50 | def train_steps(model, dataset, steps=None, bar_title=None, validate=False): 51 | """A single training epoch with multiple steps. 52 | Customized training epoch compared to the built-in ``.fit(...)`` function 53 | of tensorflow keras model. The major difference is that the ``.fit()`` 54 | requires the dataset to yield either (feature, target) or 55 | (feature, target, weight) pairs, which losses the flexibility of yielding 56 | different numbers of elements for each iteration. 
And thus we'd decide to 57 | implement our own training logic and relevant utilities same as provided 58 | in tensorflow like `callbacks`. 59 | Parameters 60 | ---------- 61 | model: 62 | Compiled tf.keras model. 63 | dataset: 64 | The loaded tf.data.Dataset object that yields (feature, target) pairs 65 | at the first two elements, indicating that you can yields more than 66 | two elements for each iteration, but only the first two will be used 67 | for training. 68 | steps: int 69 | Total number of steps that the dataset object will yield. This is used 70 | for visualizing the training progress. 71 | bar_title: str 72 | Additional title to be printed at the start of the progress bar. 73 | validate: bool 74 | Indicating whether it is now in validation stage or it is within 75 | training loop that should update the weights of the model. 76 | Returns 77 | ------- 78 | history: dict 79 | The history of scores for each metric during each epoch. 80 | """ 81 | iter_bar = tqdm.tqdm(dataset, total=steps, desc=bar_title, bar_format=PROGRESS_BAR_FORMAT) 82 | 83 | for iters, data in enumerate(iter_bar): 84 | feat, label = data[:2] # Assumed the first two elements are feature and label, respectively. 85 | if validate: 86 | step_result = model.test_on_batch(feat, label, return_dict=True) 87 | else: 88 | step_result = model.train_on_batch(feat, label, return_dict=True) 89 | 90 | if iters == 0: 91 | # model.metrics_names is only available after the first train_on_batch 92 | metrics = model.metrics_names 93 | history = {metric: 0 for metric in metrics} 94 | history.update({f"{metric}_sum": 0 for metric in metrics}) 95 | 96 | for metric in metrics: 97 | history[f"{metric}_sum"] += step_result[metric] 98 | history[metric] = history[f"{metric}_sum"] / (iters + 1) 99 | iter_bar.set_postfix_str(gen_bar_postfix(history)) 100 | 101 | # Remove metric_sum columns in the history 102 | history = {metric: history[metric] for metric in metrics} 103 | return history 104 | 105 | def execute_callbacks(callbacks, func_name, **kwargs): 106 | """Execute callbacks at different training stage.""" 107 | if callbacks is not None: 108 | for callback in callbacks: 109 | getattr(callback, func_name)(**kwargs) 110 | 111 | def train_epochs( 112 | model, 113 | train_dataset, 114 | validate_dataset=None, 115 | epochs=10, 116 | steps=100, 117 | val_steps=100, 118 | callbacks=None, 119 | **kwargs 120 | ): 121 | """Logic of training loop. 122 | The main loop of the training, with events-based life-cycle management 123 | that triggers different events for all callbacks. Event types are the 124 | same as the original tensorflow implementation. 125 | Event types and their order: 126 | .. code-block:: none 127 | 128 | | 129 | |-on_train_begin 130 | T| |-on_epoch_begin 131 | R| | 132 | A| L|-on_train_batch_begin 133 | I| O|-on_train_batch_end 134 | N| O| 135 | I| P|-on_test_batch_begin 136 | N| |-on_test_batch_end 137 | G| | 138 | | |-on_epoch_end 139 | |-on_train_end 140 | | 141 | 142 | Parameters 143 | ---------- 144 | model: 145 | Compiled tensorflow keras model. 146 | train_dataset: 147 | The tf.data.Dataset instance for training. 148 | validate_dataset: 149 | The tf.data.Dataset instance for validation. If not given, validation 150 | stage will be skipped. 151 | epochs: int 152 | Number of maximum training epochs. 153 | steps: int 154 | Number of training steps for each epoch. Should be the same as 155 | when initiating the dataset instance. 
156 | val_steps: int 157 | Number of validation steps for each epoch.Should be the same as 158 | when initiating the dataset instance. 159 | callbacks: 160 | List of callback instances. 161 | Returns 162 | ------- 163 | history: dict 164 | Score history of each metrics during each epoch of both training 165 | and validation. 166 | See Also 167 | -------- 168 | omn.callbacks: 169 | Implementation and available callbacks for training. 170 | """ 171 | history = {"train": [], "validate": []} 172 | execute_callbacks(callbacks, "_set_model", model=model) 173 | execute_callbacks(callbacks, "on_train_begin") 174 | for epoch_idx in range(epochs): 175 | # Epoch begin 176 | execute_callbacks(callbacks, "on_epoch_begin", epoch=epoch_idx+1) # noqa: E226 177 | if model.stop_training: 178 | break 179 | 180 | print(f"Epoch: {epoch_idx+1}/{epochs}") 181 | 182 | # Train batch begin 183 | execute_callbacks(callbacks, "on_train_batch_begin") 184 | results = train_steps(model, dataset=train_dataset, steps=steps, bar_title="Train ", **kwargs) 185 | 186 | # Train batch end 187 | execute_callbacks(callbacks, "on_train_batch_end") 188 | history["train"].append(results) 189 | 190 | # Test batch begin 191 | execute_callbacks(callbacks, "on_test_batch_begin") 192 | val_results = {} 193 | if validate_dataset is not None: 194 | val_results = train_steps( 195 | model, dataset=validate_dataset, steps=val_steps, validate=True, bar_title="Validate", **kwargs 196 | ) 197 | 198 | # Test batch end 199 | execute_callbacks(callbacks, "on_test_batch_end") 200 | history["validate"].append(val_results) 201 | 202 | # Epoch end 203 | execute_callbacks(callbacks, "on_epoch_end", epoch=epoch_idx+1, history=history) # noqa: E226 204 | 205 | execute_callbacks(callbacks, "on_train_end") 206 | return history -------------------------------------------------------------------------------- /models/t2t.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import operator 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | 8 | def shape_list(input_tensor): 9 | """Return list of dims, statically where possible""" 10 | 11 | tensor = tf.convert_to_tensor(input_tensor) 12 | 13 | # If unknown rank, return dynamic shape 14 | if tensor.get_shape().dims is None: 15 | return tf.shape(tensor) 16 | 17 | static = tensor.get_shape().as_list() 18 | shape = tf.shape(tensor) 19 | 20 | ret = [] 21 | for i, dim in enumerate(static): 22 | if dim is None: 23 | dim = shape[i] 24 | ret.append(dim) 25 | return ret 26 | 27 | 28 | def reshape_range(tensor, i, j, shape): 29 | """Reshapes a tensor between dimensions i and j""" 30 | 31 | t_shape = shape_list(tensor) 32 | target_shape = t_shape[:i] + shape + t_shape[j:] 33 | return tf.reshape(tensor, target_shape) 34 | 35 | 36 | def cast_like(x, y): 37 | """Cast x to y's dtype, if necessary.""" 38 | x = tf.convert_to_tensor(x) 39 | y = tf.convert_to_tensor(y) 40 | 41 | if x.dtype.base_dtype == y.dtype.base_dtype: 42 | return x 43 | 44 | cast_x = tf.cast(x, y.dtype) 45 | if cast_x.device != x.device: 46 | x_name = "(eager Tensor)" 47 | try: 48 | x_name = x.name 49 | except AttributeError: 50 | pass 51 | tf.compat.v1.logging.warning("Cast for %s may induce copy from '%s' to '%s'", x_name, x.device, cast_x.device) 52 | return cast_x 53 | 54 | 55 | def split_last_dimension(x, n): 56 | """Reshape x so that the last dimension becomes two dimensions""" 57 | 58 | x_shape = shape_list(x) 59 | m = x_shape[-1] 60 | if isinstance(m, int) and isinstance(n, int): 61 | 
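        # Both sizes are known statically, so verify the last dimension splits
        # evenly into n parts before reshaping to [..., n, m // n].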
assert m % n == 0 62 | return tf.reshape(x, x_shape[:-1] + [n, m // n]) 63 | 64 | 65 | def split_heads_2d(x, num_heads): 66 | """Split channels (dimension 3) into multiple heads (becomes dimension 1)""" 67 | return tf.transpose(split_last_dimension(x, num_heads), [0, 3, 1, 2, 4]) 68 | 69 | 70 | def pad_to_multiple_2d(x, block_shape): 71 | """Making sure x is a multiple of shape""" 72 | 73 | old_shape = x.get_shape().dims 74 | last = old_shape[-1] 75 | if len(old_shape) == 4: 76 | height_padding = -shape_list(x)[1] % block_shape[0] 77 | width_padding = -shape_list(x)[2] % block_shape[1] 78 | paddings = [[0, 0], [0, height_padding], [0, width_padding], [0, 0]] 79 | elif len(old_shape) == 5: 80 | height_padding = -shape_list(x)[2] % block_shape[0] 81 | width_padding = -shape_list(x)[3] % block_shape[1] 82 | paddings = [[0, 0], [0, 0], [0, height_padding], [0, width_padding], [0, 0]] 83 | 84 | padded_x = tf.pad(x, paddings) 85 | padded_shape = padded_x.get_shape().as_list() 86 | padded_shape = padded_shape[:-1] + [last] 87 | padded_x.set_shape(padded_shape) 88 | return padded_x 89 | 90 | 91 | def gather_indices_2d(x, block_shape, block_stride): 92 | """Getting gather indices.""" 93 | 94 | # making an identity matrix kernel 95 | kernel = tf.eye(block_shape[0] * block_shape[1]) 96 | kernel = reshape_range(kernel, 0, 1, [block_shape[0], block_shape[1], 1]) 97 | # making indices [1, h, w, 1] to appy convs 98 | x_shape = shape_list(x) 99 | indices = tf.range(x_shape[2] * x_shape[3]) 100 | indices = tf.reshape(indices, [1, x_shape[2], x_shape[3], 1]) 101 | indices = tf.nn.conv2d( 102 | tf.cast(indices, tf.float32), kernel, strides=[1, block_stride[0], block_stride[1], 1], padding="VALID" 103 | ) 104 | # making indices [num_blocks, dim] to gather 105 | dims = shape_list(indices)[:3] 106 | if all([isinstance(dim, int) for dim in dims]): 107 | num_blocks = functools.reduce(operator.mul, dims, 1) 108 | else: 109 | num_blocks = tf.reduce_prod(dims) 110 | indices = tf.reshape(indices, [num_blocks, -1]) 111 | return tf.cast(indices, tf.int32) 112 | 113 | 114 | def gather_blocks_2d(x, indices): 115 | """Gathers flattened blocks from x""" 116 | 117 | x_shape = shape_list(x) 118 | x = reshape_range(x, 2, 4, [tf.reduce_prod(x_shape[2:4])]) 119 | # [length, batch, heads, dim] 120 | x_t = tf.transpose(x, [2, 0, 1, 3]) 121 | x_new = tf.gather(x_t, indices) 122 | # returns [batch, heads, num_blocks, block_length ** 2, dim] 123 | return tf.transpose(x_new, [2, 3, 0, 1, 4]) 124 | 125 | 126 | def combine_last_two_dimensions(x): 127 | """Reshape x so that the last two dimension become one""" 128 | 129 | x_shape = shape_list(x) 130 | a, b = x_shape[-2:] 131 | return tf.reshape(x, x_shape[:-2] + [a*b]) # noqa: E226 132 | 133 | 134 | def combine_heads_2d(x): 135 | """Inverse of split_heads_2d""" 136 | return combine_last_two_dimensions(tf.transpose(x, [0, 2, 3, 1, 4])) 137 | 138 | 139 | def embedding_to_padding(emb): 140 | """Calculates the padding mask based on which embeddings are all zero""" 141 | 142 | emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1) 143 | return tf.compat.v1.to_float(tf.equal(emb_sum, 0.0)) 144 | 145 | 146 | def scatter_blocks_2d(x, indices, shape): 147 | """scatters blocks from x into shape with indices""" 148 | 149 | x_shape = shape_list(x) 150 | # [length, batch, heads, dim] 151 | x_t = tf.transpose(tf.reshape(x, [x_shape[0], x_shape[1], -1, x_shape[-1]]), [2, 0, 1, 3]) 152 | x_t_shape = shape_list(x_t) 153 | indices = tf.reshape(indices, [-1, 1]) 154 | scattered_x = tf.scatter_nd(indices, x_t, 
x_t_shape) 155 | scattered_x = tf.transpose(scattered_x, [1, 2, 0, 3]) 156 | return tf.reshape(scattered_x, shape) 157 | 158 | 159 | def mixed_precision_is_enabled(activation_dtype=None, weight_dtype=None, hparams=None): 160 | assert not ( 161 | hparams and (activation_dtype or weight_dtype) 162 | ), "Provide only hparams or activation_dtype and weight_dtype" 163 | if hparams and hasattr(hparams, "activation_dtype") and hasattr(hparams, "weight_dtype"): 164 | activation_dtype = hparams.activation_dtype 165 | weight_dtype = hparams.weight_dtype 166 | return activation_dtype == tf.float16 and weight_dtype == tf.float32 167 | 168 | 169 | def maybe_upcast(logits, activation_dtype=None, weight_dtype=None, hparams=None): 170 | if mixed_precision_is_enabled(activation_dtype, weight_dtype, hparams): 171 | return tf.cast(logits, tf.float32) 172 | return logits 173 | 174 | 175 | def dropout_with_broadcast_dims(x, keep_prob, broadcast_dims=None, **kwargs): 176 | """Like tf.nn.dropout but takes broadcast_dims instead of noise_shape""" 177 | assert "noise_shape" not in kwargs 178 | if broadcast_dims: 179 | shape = tf.shape(x) 180 | ndims = len(x.get_shape()) 181 | # Allow dimensions like "-1" as well. 182 | broadcast_dims = [dim + ndims if dim < 0 else dim for dim in broadcast_dims] 183 | kwargs["noise_shape"] = [1 if i in broadcast_dims else shape[i] for i in range(ndims)] 184 | return tf.compat.v1.nn.dropout(x, keep_prob, **kwargs) 185 | 186 | 187 | def dot_product_attention( 188 | q, 189 | k, 190 | v, 191 | bias, 192 | dropout_rate=0.0, 193 | name=None, 194 | save_weights_to=None, 195 | dropout_broadcast_dims=None, 196 | activation_dtype=None, 197 | weight_dtype=None, 198 | ): 199 | with tf.compat.v1.variable_scope(name, default_name="dot_product_attention", values=[q, k, v]) as scope: 200 | logits = tf.matmul(q, k, transpose_b=True) # [..., length_q, length_kv] 201 | if bias is not None: 202 | bias = cast_like(bias, logits) 203 | logits += bias 204 | # If logits are fp16, upcast before softmax 205 | logits = maybe_upcast(logits, activation_dtype, weight_dtype) 206 | weights = tf.nn.softmax(logits, name="attention_weights") 207 | weights = cast_like(weights, q) 208 | if save_weights_to is not None: 209 | save_weights_to[scope.name] = weights 210 | save_weights_to[scope.name + "/logits"] = logits 211 | # Drop out attention links for each head. 212 | weights = dropout_with_broadcast_dims(weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims) 213 | return tf.matmul(weights, v) 214 | 215 | 216 | def local_attention_2d(q, k, v, query_shape=(8, 16), memory_flange=(8, 16), name=None): 217 | """Strided block local self-attention""" 218 | 219 | with tf.compat.v1.variable_scope(name, default_name="local_self_attention_2d", values=[q, k, v]): 220 | v_shape = shape_list(v) 221 | 222 | # Pad query, key, value to ensure multiple of corresponding lengths. 223 | q = pad_to_multiple_2d(q, query_shape) 224 | k = pad_to_multiple_2d(k, query_shape) 225 | v = pad_to_multiple_2d(v, query_shape) 226 | paddings = [[0, 0], [0, 0], [memory_flange[0], memory_flange[1]], [memory_flange[0], memory_flange[1]], [0, 0]] 227 | k = tf.pad(k, paddings) 228 | v = tf.pad(v, paddings) 229 | 230 | # Set up query blocks. 231 | q_indices = gather_indices_2d(q, query_shape, query_shape) 232 | q_new = gather_blocks_2d(q, q_indices) 233 | 234 | # Set up key and value blocks. 
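        # Each key/value block covers its query block plus the memory flange on both
        # sides of each spatial axis, so keys and values are gathered with a larger
        # block shape (memory_shape) but the same stride as the query blocks.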
235 | memory_shape = (query_shape[0] + 2 * memory_flange[0], query_shape[1] + 2 * memory_flange[1]) 236 | k_and_v_indices = gather_indices_2d(k, memory_shape, query_shape) 237 | k_new = gather_blocks_2d(k, k_and_v_indices) 238 | v_new = gather_blocks_2d(v, k_and_v_indices) 239 | 240 | attention_bias = tf.expand_dims(tf.compat.v1.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2) 241 | output = dot_product_attention(q_new, k_new, v_new, attention_bias, dropout_rate=0.0, name="local_2d") 242 | # Put representations back into original shapes. 243 | padded_q_shape = shape_list(q) 244 | output = scatter_blocks_2d(output, q_indices, padded_q_shape) 245 | 246 | # Remove the padding if introduced. 247 | output = tf.slice(output, [0, 0, 0, 0, 0], [-1, -1, v_shape[2], v_shape[3], -1]) 248 | return output -------------------------------------------------------------------------------- /feature/cfp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from os.path import join as jpath 4 | import numpy as np 5 | import scipy 6 | 7 | MODULE_PATH = os.path.abspath(jpath(os.path.split(__file__)[0], '..')) 8 | if sys.path[0] != MODULE_PATH: sys.path.append(MODULE_PATH) 9 | 10 | from utils import load_audio, get_logger, parallel_generator 11 | 12 | logger = get_logger("CFP Feature") 13 | 14 | 15 | def STFT(x, fr, fs, Hop, h): 16 | t = np.arange(Hop, np.ceil(len(x) / float(Hop)) * Hop, Hop) 17 | N = int(fs / float(fr)) 18 | window_size = len(h) 19 | f = fs * np.linspace(0, 0.5, np.round(N / 2).astype("int"), endpoint=True) 20 | Lh = int(np.floor(float(window_size - 1) / 2)) 21 | tfr = np.zeros((int(N), len(t)), dtype=np.float) 22 | 23 | for icol, ti in enumerate(t): 24 | ti = int(ti) 25 | tau = np.arange(int(-min([round(N / 2.0) - 1, Lh, ti - 1])), int(min([round(N / 2.0) - 1, Lh, len(x) - ti]))) 26 | indices = np.mod(N + tau, N) + 1 27 | tfr[indices - 1, icol] = x[ti + tau - 1] * h[Lh + tau - 1] / np.linalg.norm(h[Lh + tau - 1]) 28 | 29 | tfr = abs(scipy.fftpack.fft(tfr, n=N, axis=0)) 30 | return tfr, f, t, N 31 | 32 | 33 | def nonlinear_func(X, g, cutoff): 34 | cutoff = int(cutoff) 35 | if g != 0: 36 | X[X < 0] = 0 37 | X[:cutoff, :] = 0 38 | X[-cutoff:, :] = 0 39 | X = np.power(X, g) 40 | else: 41 | X = np.log(X) 42 | X[:cutoff, :] = 0 43 | X[-cutoff:, :] = 0 44 | return X 45 | 46 | 47 | def freq_to_log_freq_mapping(tfr, f, fr, fc, tc, NumPerOct): 48 | StartFreq = fc 49 | StopFreq = 1 / tc 50 | Nest = int(np.ceil(np.log2(StopFreq / StartFreq)) * NumPerOct) 51 | central_freq = [] 52 | 53 | for i in range(0, Nest): 54 | cen_freq = StartFreq * pow(2, float(i) / NumPerOct) 55 | if cen_freq < StopFreq: 56 | central_freq.append(cen_freq) 57 | else: 58 | break 59 | 60 | Nest = len(central_freq) 61 | freq_band_transformation = np.zeros((Nest - 1, len(f)), dtype=np.float) 62 | for i in range(1, Nest - 1): 63 | left = int(round(central_freq[i - 1] / fr)) 64 | right = int(round(central_freq[i + 1] / fr) + 1) 65 | 66 | # rounding1 67 | if left >= right - 1: 68 | freq_band_transformation[i, left] = 1 69 | else: 70 | for j in range(left, right): 71 | if f[j] > central_freq[i - 1] and f[j] < central_freq[i]: 72 | freq_band_transformation[i, j] = (f[j] - central_freq[i-1]) / (central_freq[i] - central_freq[i-1]) 73 | elif f[j] > central_freq[i] and f[j] < central_freq[i+1]: 74 | freq_band_transformation[i, j] = (central_freq[i+1] - f[j]) / (central_freq[i+1] - central_freq[i]) 75 | tfrL = np.dot(freq_band_transformation, tfr) 76 | return tfrL, 
central_freq 77 | 78 | 79 | def quef_to_log_freq_mapping(ceps, q, fs, fc, tc, NumPerOct): 80 | StartFreq = fc 81 | StopFreq = 1 / tc 82 | Nest = int(np.ceil(np.log2(StopFreq / StartFreq)) * NumPerOct) 83 | central_freq = [] 84 | 85 | for i in range(0, Nest): 86 | cen_freq = StartFreq * pow(2, float(i) / NumPerOct) 87 | if cen_freq < StopFreq: 88 | central_freq.append(cen_freq) 89 | else: 90 | break 91 | f = 1 / (q+1e-9) 92 | Nest = len(central_freq) 93 | freq_band_transformation = np.zeros((Nest - 1, len(f)), dtype=np.float) 94 | for i in range(1, Nest - 1): 95 | for j in range(int(round(fs / central_freq[i + 1])), int(round(fs / central_freq[i - 1]) + 1)): 96 | if f[j] > central_freq[i - 1] and f[j] < central_freq[i]: 97 | freq_band_transformation[i, j] = (f[j] - central_freq[i - 1]) / (central_freq[i] - central_freq[i - 1]) 98 | elif f[j] > central_freq[i] and f[j] < central_freq[i + 1]: 99 | freq_band_transformation[i, j] = (central_freq[i + 1] - f[j]) / (central_freq[i + 1] - central_freq[i]) 100 | 101 | tfrL = np.dot(freq_band_transformation[:, :len(ceps)], ceps) 102 | return tfrL, central_freq 103 | 104 | 105 | def cfp_filterbank(x, fr, fs, Hop, h, fc, tc, g, bin_per_octave): 106 | NumofLayer = np.size(g) 107 | 108 | [tfr, f, t, N] = STFT(x, fr, fs, Hop, h) 109 | tfr = np.power(abs(tfr), g[0]) 110 | tfr0 = tfr # original STFT 111 | ceps = np.zeros(tfr.shape) 112 | 113 | if NumofLayer >= 2: 114 | for gc in range(1, NumofLayer): 115 | if np.remainder(gc, 2) == 1: 116 | tc_idx = round(fs * tc) 117 | ceps = np.real(np.fft.fft(tfr, axis=0)) / np.sqrt(N) 118 | ceps = nonlinear_func(ceps, g[gc], tc_idx) 119 | else: 120 | fc_idx = round(fc / fr) 121 | tfr = np.real(np.fft.fft(ceps, axis=0)) / np.sqrt(N) 122 | tfr = nonlinear_func(tfr, g[gc], fc_idx) 123 | 124 | tfr0 = tfr0[:int(round(N / 2)), :] 125 | tfr = tfr[:int(round(N / 2)), :] 126 | ceps = ceps[:int(round(N / 2)), :] 127 | 128 | HighFreqIdx = int(round((1/tc) / fr) + 1) 129 | f = f[:HighFreqIdx] 130 | tfr0 = tfr0[:HighFreqIdx, :] 131 | tfr = tfr[:HighFreqIdx, :] 132 | HighQuefIdx = int(round(fs / fc) + 1) 133 | q = np.arange(HighQuefIdx) / float(fs) 134 | ceps = ceps[:HighQuefIdx, :] 135 | 136 | tfrL0, central_frequencies = freq_to_log_freq_mapping(tfr0, f, fr, fc, tc, bin_per_octave) 137 | tfrLF, central_frequencies = freq_to_log_freq_mapping(tfr, f, fr, fc, tc, bin_per_octave) 138 | tfrLQ, central_frequencies = quef_to_log_freq_mapping(ceps, q, fs, fc, tc, bin_per_octave) 139 | 140 | return tfrL0, tfrLF, tfrLQ, f, q, t, central_frequencies 141 | 142 | 143 | def parallel_extract(x, samples, max_sample, fr, fs, Hop, h, fc, tc, g, bin_per_octave): 144 | freq_width = max_sample * Hop 145 | iters = np.ceil(samples / max_sample).astype("int") 146 | tmpL0, tmpLF, tmpLQ, tmpZ = {}, {}, {}, {} 147 | 148 | slice_list = [x[i * freq_width:(i+1) * freq_width] for i in range(iters)] 149 | 150 | feat_generator = enumerate( 151 | parallel_generator( 152 | cfp_filterbank, 153 | slice_list, 154 | fr=fr, 155 | fs=fs, 156 | Hop=Hop, 157 | h=h, 158 | fc=fc, 159 | tc=tc, 160 | g=g, 161 | bin_per_octave=bin_per_octave, 162 | max_workers=3) 163 | ) 164 | for idx, (feat_list, slice_idx) in feat_generator: 165 | logger.debug("Slice feature extracted: %d/%d", idx+1, len(slice_list)) 166 | tfrL0, tfrLF, tfrLQ, f, q, t, cen_freq = feat_list 167 | tmpL0[slice_idx] = tfrL0 168 | tmpLF[slice_idx] = tfrLF 169 | tmpLQ[slice_idx] = tfrLQ 170 | tmpZ[slice_idx] = tfrLF * tfrLQ 171 | return tmpL0, tmpLF, tmpLQ, tmpZ, f, q, t, cen_freq 172 | 173 | 174 | def 
spectral_flux(spec, invert=False, norm=True): 175 | flux = np.pad(np.diff(spec), ((0, 0), (1, 0))) 176 | if invert: 177 | flux *= -1.0 178 | 179 | flux[flux < 0] = 0.0 180 | if norm: 181 | flux = (flux - np.mean(flux)) / np.std(flux) 182 | 183 | return flux 184 | 185 | 186 | def _extract_cfp( 187 | x, 188 | fs, 189 | hop=0.02, # in seconds 190 | win_size=7939, 191 | fr=2.0, 192 | fc=27.5, 193 | tc=1/4487.0, 194 | g=[0.24, 0.6, 1], 195 | bin_per_octave=48, 196 | down_fs=44100, 197 | max_sample=2000, 198 | ): 199 | if fs != down_fs: 200 | x = scipy.signal.resample_poly(x, down_fs, fs) 201 | fs = down_fs 202 | 203 | Hop = round(down_fs * hop) 204 | x = x.astype("float32") 205 | h = scipy.signal.blackmanharris(win_size) # window size 206 | g = np.array(g) 207 | 208 | samples = np.floor(len(x) / Hop).astype("int") 209 | logger.debug("Sample number: %d", samples) 210 | logger.debug("Extracting CFP feature...") 211 | if samples > max_sample: 212 | tmpL0, tmpLF, tmpLQ, tmpZ, _, _, _, cen_freq = parallel_extract( 213 | x, samples, max_sample, fr, fs, Hop, h, fc, tc, g, bin_per_octave 214 | ) 215 | 216 | tfrL0 = tmpL0.pop(0) 217 | tfrLF = tmpLF.pop(0) 218 | tfrLQ = tmpLQ.pop(0) 219 | Z = tmpZ.pop(0) 220 | rr = len(tmpL0) 221 | for i in range(1, rr + 1, 1): 222 | tfrL0 = np.concatenate((tfrL0, tmpL0.pop(i)), axis=1) 223 | tfrLF = np.concatenate((tfrLF, tmpLF.pop(i)), axis=1) 224 | tfrLQ = np.concatenate((tfrLQ, tmpLQ.pop(i)), axis=1) 225 | Z = np.concatenate((Z, tmpZ.pop(i)), axis=1) 226 | else: 227 | tfrL0, tfrLF, tfrLQ, _, _, _, cen_freq = cfp_filterbank(x, fr, fs, Hop, h, fc, tc, g, bin_per_octave) 228 | Z = tfrLF * tfrLQ 229 | 230 | return Z, tfrL0, tfrLF, tfrLQ, cen_freq 231 | 232 | 233 | def extract_cfp(filename, down_fs=44100, **kwargs): 234 | """CFP feature extraction function""" 235 | 236 | logger.debug("Loading audio: %s", filename) 237 | x, fs = load_audio(filename, sampling_rate=down_fs) 238 | return _extract_cfp(x, fs, down_fs=fs, **kwargs) 239 | 240 | 241 | def _extract_vocal_cfp( 242 | x, 243 | fs, 244 | hop=0.02, 245 | fr=2.0, 246 | fc=80.0, 247 | tc=1/1000, 248 | **kwargs 249 | ): 250 | logger.debug("Extract three types of CFP with different window sizes.") 251 | high_z, high_spec, _, _, _ = _extract_cfp(x, fs, win_size=743, hop=hop, fr=fr, fc=fc, tc=tc, **kwargs) 252 | med_z, med_spec, _, _, _ = _extract_cfp(x, fs, win_size=372, hop=hop, fr=fr, fc=fc, tc=tc, **kwargs) 253 | low_z, low_spec, _, _, _ = _extract_cfp(x, fs, win_size=186, hop=hop, fr=fr, fc=fc, tc=tc, **kwargs) 254 | 255 | # Normalize Z 256 | high_z_norm = (high_z - np.mean(high_z)) / np.std(high_z) 257 | med_z_norm = (med_z - np.mean(med_z)) / np.std(med_z) 258 | low_z_norm = (low_z - np.mean(low_z)) / np.std(low_z) 259 | 260 | # Spectral flux 261 | high_flux = spectral_flux(high_spec) 262 | med_flux = spectral_flux(med_spec) 263 | low_flux = spectral_flux(low_spec) 264 | 265 | # Inverse spectral flux 266 | high_inv_flux = spectral_flux(high_spec, invert=True) 267 | med_inv_flux = spectral_flux(med_spec, invert=True) 268 | low_inv_flux = spectral_flux(low_spec, invert=True) 269 | 270 | # Collect and concat 271 | flux = np.dstack([low_flux, med_flux, high_flux]) 272 | inv_flux = np.dstack([low_inv_flux, med_inv_flux, high_inv_flux]) 273 | z_norm = np.dstack([low_z_norm, med_z_norm, high_z_norm]) 274 | 275 | output = np.dstack([flux, inv_flux, z_norm]) 276 | return np.transpose(output, axes=[1, 0, 2]) # time x feat x channel 277 | 278 | 279 | def extract_vocal_cfp(filename, down_fs=16000, **kwargs): 280 | """Specialized 
CFP feature extraction for vocal submodule.""" 281 | logger.debug("Loading audio: %s", filename) 282 | x, fs = load_audio(filename, sampling_rate=down_fs) 283 | logger.debug("Extracting vocal feature") 284 | return _extract_vocal_cfp(x, fs, **kwargs) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import types 4 | import logging 5 | import uuid 6 | import concurrent.futures 7 | from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor 8 | import importlib 9 | import csv 10 | from ruamel import yaml 11 | import librosa 12 | 13 | import jsonschema 14 | import pretty_midi 15 | import numpy as np 16 | 17 | 18 | def get_logger(name=None, level="warn"): 19 | """Get the logger for printing informations""" 20 | 21 | logger_name = str(uuid.uuid4())[:8] if name is None else name 22 | logger = logging.getLogger(logger_name) 23 | level = os.environ.get("LOG_LEVEL", level) 24 | 25 | msg_formats = { 26 | "debug": "%(asctime)s [%(levelname)s] %(message)s [at %(filename)s:%(lineno)d]", 27 | "info": "%(asctime)s %(message)s [at %(filename)s:%(lineno)d]", 28 | "warn": "%(asctime)s %(message)s", 29 | "warning": "%(asctime)s %(message)s", 30 | "error": "%(asctime)s [%(levelname)s] %(message)s [at %(filename)s:%(lineno)d]", 31 | "critical": "%(asctime)s [%(levelname)s] %(message)s [at %(filename)s:%(lineno)d]", 32 | } 33 | level_mapping = { 34 | "debug": logging.DEBUG, 35 | "info": logging.INFO, 36 | "warn": logging.INFO, 37 | "warning": logging.WARNING, 38 | "error": logging.ERROR, 39 | "critical": logging.CRITICAL, 40 | } 41 | 42 | date_format = "%Y-%m-%d %H:%M:%S" 43 | formatter = logging.Formatter(fmt=msg_formats[level.lower()], datefmt=date_format) 44 | handler = logging.StreamHandler() 45 | handler.setFormatter(formatter) 46 | if len(logger.handlers) > 0: 47 | rm_idx = [idx for idx, handler in enumerate(logger.handlers) if isinstance(handler, logging.StreamHandler)] 48 | for idx in rm_idx: 49 | del logger.handlers[idx] 50 | logger.addHandler(handler) 51 | logger.setLevel(level_mapping[level.lower()]) 52 | return logger 53 | 54 | logger = get_logger("Utils") 55 | 56 | 57 | class LazyLoader(types.ModuleType): 58 | """Lazily import a module""" 59 | 60 | def __init__(self, local_name, parent_module_globals, name, warning=None): 61 | self._local_name = local_name 62 | self._parent_module_globals = parent_module_globals 63 | self._warning = warning 64 | 65 | super().__init__(name) 66 | 67 | def _load(self): 68 | """Load the module and insert it into the parent's globals.""" 69 | module = importlib.import_module(self.__name__) 70 | self._parent_module_globals[self._local_name] = module 71 | 72 | if self._warning: 73 | logger.warning(self._warning) 74 | # Make sure to only warn once. 75 | self._warning = None 76 | 77 | self.__dict__.update(module.__dict__) 78 | return module 79 | 80 | def __getattr__(self, item): 81 | module = self._load() 82 | return getattr(module, item) 83 | 84 | def __dir__(self): 85 | module = self._load() 86 | return dir(module) 87 | 88 | 89 | # Lazy load the Spleeter pacakge for avoiding pulling large dependencies and boosting the import speed. 
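# The actual import of `spleeter.audio.adapter` happens only on the first attribute
# access of `adapter` (see LazyLoader.__getattr__ above), so merely importing this
# module does not pull in Spleeter.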
90 | adapter = LazyLoader("adapter", globals(), "spleeter.audio.adapter") 91 | 92 | 93 | def load_audio(audio_path, sampling_rate=44100, mono=True): 94 | audio, sr = librosa.load(audio_path, mono=mono, sr=sampling_rate) 95 | return audio, sr 96 | 97 | 98 | def load_yaml(yaml_path): 99 | return yaml.round_trip_load(open(yaml_path, "r"), preserve_quotes=True) 100 | 101 | 102 | def write_yaml(json_obj, output_path, dump=True): 103 | # json_obj should be yaml string already if dump is False 104 | out_str = yaml.round_trip_dump(json_obj, indent=4) if dump else json_obj 105 | open(output_path, "w").write(out_str) 106 | 107 | 108 | def write_agg_f0_results(agg_f0, output_path): 109 | """Write out aggregated F0 information as a CSV file""" 110 | 111 | fieldnames = ["start_time", "end_time", "frequency", "pitch"] 112 | 113 | # Check the format is correct 114 | if any(list(row.keys()) != fieldnames for row in agg_f0): 115 | raise ValueError(f"Fields inconsistent! Expected: {fieldnames}") 116 | 117 | with open(output_path, "w") as out: 118 | writer = csv.DictWriter(out, fieldnames=fieldnames) 119 | writer.writeheader() 120 | writer.writerows(agg_f0) 121 | 122 | 123 | def ensure_path_exists(path): 124 | if not os.path.exists(path): 125 | os.makedirs(path) 126 | 127 | 128 | def camel_to_snake(string): 129 | """Convert a camel case to snake case""" 130 | return re.sub(r"(? chunk_size: 229 | logger.warning( 230 | "Chunk size should larger than the maximum number of workers, or the parallel computation " 231 | "can do nothing helpful. Received max workers: %d, chunk size: %d", 232 | max_workers, chunk_size 233 | ) 234 | max_workers = chunk_size 235 | 236 | executor = ThreadPoolExecutor(max_workers=max_workers) \ 237 | if use_thread else ProcessPoolExecutor(max_workers=max_workers) 238 | 239 | chunks = 1 240 | slice_len = len(input_list) 241 | if chunk_size is not None: 242 | chunks = len(input_list) / chunk_size 243 | if int(chunks) < chunks: 244 | chunks = int(chunks) + 1 245 | slice_len = chunk_size 246 | 247 | for chunk_idx in range(int(chunks)): 248 | start_idx = chunk_idx * slice_len 249 | end_idx = (chunk_idx + 1) * slice_len 250 | future_to_input = {} 251 | for idx, _input in enumerate(input_list[start_idx:end_idx]): 252 | logger.debug("Parallel job submitted %s", func.__name__) 253 | future = executor.submit(func, _input, **kwargs) 254 | future_to_input[future] = idx + start_idx 255 | 256 | try: 257 | for future in concurrent.futures.as_completed(future_to_input, timeout=timeout): 258 | logger.debug("Yielded %s", func.__name__) 259 | yield future.result(), future_to_input[future] 260 | except KeyboardInterrupt as exp: 261 | for future in future_to_input: 262 | if future.cancel(): 263 | logger.info("Job cancelled") 264 | else: 265 | logger.warning("Fail to cancel job: %s", future) 266 | executor.shutdown() 267 | raise exp 268 | executor.shutdown() 269 | 270 | 271 | def resolve_dataset_type(dataset_path, keywords): 272 | low_path = os.path.basename(os.path.abspath(dataset_path)).lower() 273 | d_type = [val for key, val in keywords.items() if key in low_path] 274 | if len(d_type) == 0: 275 | return None 276 | 277 | assert len(set(d_type)) == 1 278 | return d_type[0] 279 | 280 | 281 | def get_filename(path): 282 | abspath = os.path.abspath(path) 283 | return os.path.splitext(os.path.basename(abspath))[0] 284 | 285 | 286 | def aggregate_f0_info(pred, t_unit): 287 | """Aggregation F0 contour to start time, end time, frequency, pitch""" 288 | 289 | results = [] 290 | 291 | cur_idx = 0 292 | start_idx = 0 
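    # The loop below merges consecutive frames with (nearly) identical frequency into
    # segments of start/end time, frequency, and the corresponding MIDI pitch.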
293 | last_hz = pred[0] 294 | eps = 1e-6 295 | pred = np.append(pred, 0) # Append an additional zero to the end temporarily. 296 | while cur_idx < len(pred): 297 | cur_hz = pred[cur_idx] 298 | if abs(cur_hz - last_hz) < eps: 299 | # Skip to the next index with different frequency. 300 | last_hz = cur_hz 301 | cur_idx += 1 302 | continue 303 | 304 | if last_hz < eps: 305 | # Almost equals to zero. Ignored. 306 | last_hz = cur_hz 307 | start_idx = cur_idx 308 | cur_idx += 1 309 | continue 310 | 311 | results.append({ 312 | "start_time": round(start_idx * t_unit, 6), 313 | "end_time": round(cur_idx * t_unit, 6), 314 | "frequency": last_hz, 315 | "pitch": pretty_midi.hz_to_note_number(last_hz) 316 | }) 317 | 318 | start_idx = cur_idx 319 | cur_idx += 1 320 | last_hz = cur_hz 321 | 322 | pred = pred[:-1] # Remove the additional ending zero. 323 | return results 324 | -------------------------------------------------------------------------------- /base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | from os.path import join as jpath 5 | from abc import ABCMeta, abstractmethod 6 | 7 | import h5py 8 | import tensorflow as tf 9 | from tensorflow.keras.models import model_from_yaml 10 | 11 | from utils import get_logger, ensure_path_exists, get_filename 12 | from constants.midi import LOWEST_MIDI_NOTE, HIGHEST_MIDI_NOTE 13 | 14 | MODULE_PATH = os.path.split(__file__)[0] 15 | 16 | 17 | logger = get_logger("Base Class") 18 | 19 | 20 | class BaseTranscription(metaclass=ABCMeta): 21 | """Base class of transcription applications.""" 22 | def __init__(self, setting_class, conf_path=None): 23 | self.setting_class = setting_class 24 | self.settings = setting_class(conf_path=conf_path) 25 | self.custom_objects = {} 26 | 27 | @abstractmethod 28 | def transcribe(self, input_audio, model_path, output="./"): 29 | raise NotImplementedError 30 | 31 | def get_model(self, settings): 32 | """Get the model from the python source file""" 33 | raise NotImplementedError 34 | 35 | def _load_model(self, model_path=None, custom_objects=None): 36 | if model_path in self.settings.checkpoint_path: 37 | # The given model_path is actually the 'transcription_mode'. 
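            # i.e. one of the mode names listed in settings.checkpoint_path, which maps
            # to a checkpoint shipped with the module rather than a user-given filesystem path.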
38 | default_path = self.settings.checkpoint_path[model_path] 39 | model_path = os.path.join(MODULE_PATH, default_path) 40 | logger.info("Using built-in model %s for transcription.", model_path) 41 | 42 | model_path, conf_path = self._resolve_model_path(model_path) 43 | settings = self.setting_class(conf_path=conf_path) 44 | 45 | try: 46 | model = tf.keras.models.load_model(model_path, custom_objects=custom_objects) 47 | except (OSError): 48 | raise FileNotFoundError( 49 | f"Checkpoint file not found: {model_path}/variables/variables.data*" 50 | ) 51 | 52 | return model, settings 53 | 54 | def _resolve_model_path(self, model_path=None): 55 | model_path = os.path.abspath(model_path) if model_path is not None else None 56 | logger.debug("Absolute path of the given model: %s", model_path) 57 | if model_path is None: 58 | default_path = self.settings.checkpoint_path[self.settings.transcription_mode] 59 | model_path = os.path.join(MODULE_PATH, default_path) 60 | logger.info("Using built-in model %s for transcription.", model_path) 61 | elif not os.path.exists(model_path): 62 | raise FileNotFoundError(f"The given path doesn't exist: {model_path}.") 63 | elif not os.path.basename(model_path).startswith(self.settings.model.save_prefix.lower()) \ 64 | and not set(["arch.yaml", "weights.h5", "configurations.yaml"]).issubset(os.listdir(model_path)): 65 | 66 | # Search checkpoint folders under the given path 67 | dirs = [c_dir for c_dir in os.listdir(model_path) if os.path.isdir(c_dir)] 68 | prefix = self.settings.model.save_prefix.lower() 69 | cand_dirs = [c_dir for c_dir in dirs if c_dir.startswith(prefix)] 70 | 71 | if len(cand_dirs) == 0: 72 | raise FileNotFoundError(f"No checkpoint of {prefix} found in {model_path}") 73 | elif len(cand_dirs) > 1: 74 | logger.warning("There are multiple checkpoints in the directory. Default to use %s", cand_dirs[0]) 75 | model_path = os.path.join(model_path, cand_dirs[0]) 76 | 77 | # There should be one configuration file of this checkpoint. 
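        # The configuration file is written next to the checkpoint at training time
        # (see the train() methods of the app modules, e.g. vocal_contour/app.py).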
78 | conf_path = os.path.join(model_path, "configurations.yaml") 79 | return model_path, conf_path 80 | 81 | def _get_model_from_yaml(self, arch_path, custom_objects=None): 82 | return model_from_yaml(open(arch_path, "r").read(), custom_objects=custom_objects) 83 | 84 | def _resolve_feature_output_path(self, dataset_path, settings): 85 | if settings.dataset.feature_save_path == "+": 86 | base_output_path = dataset_path 87 | settings.dataset.save_path = dataset_path 88 | else: 89 | base_output_path = settings.dataset.feature_save_path 90 | train_feat_out_path = jpath(base_output_path, "train_feature") 91 | test_feat_out_path = jpath(base_output_path, "test_feature") 92 | ensure_path_exists(train_feat_out_path) 93 | ensure_path_exists(test_feat_out_path) 94 | return train_feat_out_path, test_feat_out_path 95 | 96 | def _resolve_semi_feature_output_path(self, dataset_path, settings): 97 | if settings.dataset.feature_save_path == "+": 98 | base_output_path = dataset_path 99 | settings.dataset.save_path = dataset_path 100 | else: 101 | base_output_path = settings.dataset.feature_save_path 102 | train_feat_out_path = jpath(base_output_path, "semi_feature") 103 | ensure_path_exists(train_feat_out_path) 104 | return train_feat_out_path 105 | 106 | def _output_midi(self, output, input_audio, midi=None, verbose=True): 107 | if output is None: 108 | return None 109 | 110 | if os.path.isdir(output): 111 | output = jpath(output, get_filename(input_audio)) 112 | if midi is not None: 113 | out_path = output if output.endswith(".mid") else f"{output}.mid" 114 | midi.write(out_path) 115 | if verbose: 116 | logger.info("MIDI file has been written to %s.", out_path) 117 | return output 118 | 119 | def _validate_and_get_settings(self, setting_instance): 120 | if setting_instance is not None: 121 | assert isinstance(setting_instance, self.setting_class) 122 | return setting_instance 123 | return self.settings 124 | 125 | 126 | class Label: 127 | """Interface of different label format""" 128 | 129 | def __init__( 130 | self, 131 | start_time, 132 | end_time, 133 | note, 134 | instrument=0, 135 | velocity=64, 136 | start_beat=0, 137 | end_beat=10, 138 | note_value="", 139 | is_drum=False 140 | ): 141 | self.start_time = start_time 142 | self.end_time = end_time 143 | self.note = note 144 | self.velocity = velocity 145 | self.instrument = instrument 146 | self.start_beat = start_beat 147 | self.end_beat = end_beat 148 | self.note_value = note_value 149 | self.is_drum = is_drum 150 | 151 | def __eq__(self, val): 152 | if not isinstance(val, Label): 153 | return False 154 | 155 | epsilon = 1e-4 # Tolerance of time difference 156 | if abs(self.start_time - val.start_time) < epsilon \ 157 | and abs(self.end_time - val.end_time) < epsilon \ 158 | and abs(self.note - val.note) < epsilon \ 159 | and self.velocity == val.velocity \ 160 | and self.instrument == val.instrument \ 161 | and abs(self.start_beat - val.start_beat) < epsilon \ 162 | and abs(self.end_beat - val.end_beat) < epsilon \ 163 | and self.note_value == val.note_value \ 164 | and self.is_drum == val.is_drum: 165 | return True 166 | return False 167 | 168 | def __str__(self): 169 | msg = [ 170 | f"Start time: {self.start_time}", 171 | f"End time: {self.end_time}", 172 | f"Note number: {self.note}", 173 | f"Velocity: {self.velocity}", 174 | f"Instrument number: {self.instrument}", 175 | f"Start beat: {self.start_beat}", 176 | f"End beat: {self.end_beat}", 177 | f"Note value: {self.note_value}", 178 | f"Is drum: {self.is_drum}" 179 | ] 180 | return ", 
".join(msg) 181 | 182 | def __repr__(self): 183 | return self.__str__() 184 | 185 | @property 186 | def note(self): 187 | return self._note 188 | 189 | @note.setter 190 | def note(self, midi_num): 191 | if LOWEST_MIDI_NOTE <= midi_num <= HIGHEST_MIDI_NOTE: 192 | self._note = midi_num 193 | else: 194 | logger.warning( 195 | "The given midi number is out-of-bound and will be skipped. " 196 | "Received midi number: %d. Available: [%d - %d]", 197 | midi_num, LOWEST_MIDI_NOTE, HIGHEST_MIDI_NOTE 198 | ) 199 | self._note = -1 200 | 201 | @property 202 | def velocity(self): 203 | return self._velocity 204 | 205 | @velocity.setter 206 | def velocity(self, value): 207 | assert 0 <= value <= 127 208 | self._velocity = value 209 | 210 | 211 | class BaseDatasetLoader: 212 | """Base dataset loader for yielding training samples""" 213 | 214 | def __init__(self, is_labeled = True, feature_folder=None, feature_files=None, num_samples=100, slice_hop=1, feat_col_name="feature"): 215 | self.is_labeled = is_labeled 216 | 217 | if feature_files is None: 218 | assert feature_folder is not None 219 | self.hdf_files = glob.glob(f"{feature_folder}/*.hdf") 220 | else: 221 | self.hdf_files = feature_files 222 | 223 | if len(self.hdf_files) == 0: 224 | logger.warning("Warning! No feature file was found in the given path.") 225 | 226 | self.slice_hop = slice_hop 227 | self.feat_col_name = feat_col_name 228 | 229 | self.hdf_refs = {} 230 | for hdf in self.hdf_files: 231 | try: 232 | self.hdf_refs[hdf] = h5py.File(hdf, "r") 233 | except OSError: 234 | msg = f"Resource temporarily unavailable due to file being opened without closing. Resource: {hdf}" 235 | logger.error(msg) 236 | raise OSError(msg) 237 | self.num_samples = num_samples 238 | 239 | # Initialize indices of index-to-file mapping to ensure all samples 240 | # will be visited during training. 241 | length_map = {hdf: len(hdf_ref[feat_col_name]) for hdf, hdf_ref in self.hdf_refs.items()} 242 | self.total_length = sum(length_map.values()) 243 | self.start_idxs = list(range(0, self.total_length, slice_hop)) 244 | self.idx_to_hdf_map = {} 245 | cur_len = 0 246 | cur_iid = 0 247 | for hdf, length in length_map.items(): 248 | end_iid = length // slice_hop 249 | for iid in range(cur_iid, cur_iid+end_iid): # noqa: E226 250 | start_idx = self.start_idxs[iid] 251 | self.idx_to_hdf_map[start_idx] = (hdf, start_idx - cur_len) 252 | cur_len += end_iid * slice_hop 253 | cur_iid += end_iid 254 | diff = set(self.start_idxs) - set(self.idx_to_hdf_map.keys()) 255 | self.cut_idx = len(diff) 256 | if self.cut_idx > 0: 257 | self.start_idxs = self.start_idxs[:-self.cut_idx] 258 | random.shuffle(self.start_idxs) 259 | 260 | logger.info("Total samples: %s", len(self.start_idxs)) 261 | 262 | def __iter__(self): 263 | for _ in range(self.num_samples): 264 | if len(self.start_idxs) == 0: 265 | # Shuffle the indexes after visiting all the samples in the dataset. 
266 | self.start_idxs = list(range(0, self.total_length, self.slice_hop)) 267 | if self.cut_idx > 0: 268 | self.start_idxs = self.start_idxs[:-self.cut_idx] 269 | random.shuffle(self.start_idxs) 270 | 271 | start_idx = self.start_idxs.pop() 272 | hdf_name, slice_start = self.idx_to_hdf_map[start_idx] 273 | 274 | feat = self._get_feature(hdf_name, slice_start) 275 | if self.is_labeled: 276 | label = self._get_label(hdf_name, slice_start) 277 | feat, label = self._pre_yield(feat, label) 278 | yield feat, label 279 | else: 280 | yield feat 281 | 282 | def _get_feature(self, hdf_name, slice_start): 283 | return self.hdf_refs[hdf_name][self.feat_col_name][slice_start:slice_start + self.slice_hop].squeeze() 284 | 285 | def _get_label(self, hdf_name, slice_start): 286 | return self.hdf_refs[hdf_name]["label"][slice_start:slice_start + self.slice_hop].squeeze() 287 | 288 | def _pre_yield(self, feature, label): 289 | return feature, label 290 | 291 | def get_dataset(self, batch_size, output_types=None, output_shapes=None): 292 | def gen_wrapper(): 293 | for data in self: 294 | yield data 295 | 296 | return tf.data.Dataset.from_generator( 297 | gen_wrapper, output_types=output_types, output_shapes=output_shapes 298 | ) \ 299 | .batch(batch_size, drop_remainder=True) \ 300 | .prefetch(tf.data.experimental.AUTOTUNE) 301 | -------------------------------------------------------------------------------- /vocal_contour/app.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as jpath 3 | from datetime import datetime 4 | 5 | import numpy as np 6 | from scipy.io.wavfile import write as wavwrite 7 | import h5py 8 | import tensorflow as tf 9 | from tensorflow.keras.utils import to_categorical 10 | from mir_eval import sonify 11 | 12 | import sys 13 | MODULE_PATH = os.path.abspath(f"{os.path.split(__file__)[0]}/..") 14 | if sys.path[0] != MODULE_PATH: sys.path.insert(0, MODULE_PATH) 15 | 16 | from base import BaseTranscription, BaseDatasetLoader 17 | from setting_loaders import VocalContourSettings 18 | from feature.wrapper_func import extract_cfp_feature 19 | from utils import write_yaml, write_agg_f0_results, get_logger, ensure_path_exists, parallel_generator, resolve_dataset_type, aggregate_f0_info 20 | from train import train_epochs, get_train_val_feat_file_list 21 | from vocal_contour.callbacks import EarlyStopping, ModelCheckpoint 22 | from vocal_contour.inference import inference 23 | from vocal_contour import labels as lextor 24 | from constants import datasets as d_struct 25 | from models.u_net import semantic_segmentation 26 | from models.losses import focal_loss 27 | 28 | 29 | logger = get_logger("Vocal Contour") 30 | 31 | 32 | class VocalContourTranscription(BaseTranscription): 33 | """Application class for vocal-contour transcription""" 34 | 35 | def __init__(self, conf_path=None): 36 | super().__init__(VocalContourSettings, conf_path=conf_path) 37 | 38 | def transcribe(self, input_audio, model_path=None, output="./"): 39 | """Transcribe frame-level fundamental frequency of vocal from the given audio""" 40 | 41 | if not os.path.isfile(input_audio): 42 | raise FileNotFoundError(f"The given audio path does not exist. 
Path: {input_audio}") 43 | 44 | logger.info("Loading model...") 45 | model, model_settings = self._load_model(model_path) 46 | 47 | logger.info("Extracting feature...") 48 | feature = extract_cfp_feature( 49 | input_audio, 50 | hop=model_settings.feature.hop_size, 51 | win_size=model_settings.feature.window_size, 52 | down_fs=model_settings.feature.sampling_rate 53 | ) 54 | 55 | logger.info("Predicting...") 56 | f0 = inference(feature[:, :, 0], model, timestep=model_settings.training.timesteps) 57 | agg_f0 = aggregate_f0_info(f0, t_unit=model_settings.feature.hop_size) 58 | 59 | timestamp = np.arange(len(f0)) * model_settings.feature.hop_size 60 | wav = sonify.pitch_contour( 61 | timestamp, f0, model_settings.feature.sampling_rate, amplitudes=0.5 * np.ones(len(f0)) 62 | ) 63 | 64 | output = self._output_midi(output, input_audio, verbose=False) 65 | if output is not None: 66 | write_agg_f0_results(agg_f0, f"{output}_f0.csv") 67 | wavwrite(f"{output}_trans.wav", model_settings.feature.sampling_rate, wav) 68 | logger.info("Text and Wav files have been written to %s", os.path.abspath(os.path.dirname(output))) 69 | 70 | logger.info("Transcription finished") 71 | return agg_f0 72 | 73 | def generate_feature(self, dataset_path, vocalcontour_settings=None, num_threads=4): 74 | """Extract the feature from the given dataset""" 75 | 76 | settings = self._validate_and_get_settings(vocalcontour_settings) 77 | 78 | train_feat_out_path, test_feat_out_path = self._resolve_feature_output_path(dataset_path, settings) 79 | 80 | struct = d_struct.VocalContourStructure 81 | label_extractor = lextor.VocalContourlabelExtraction 82 | 83 | train_data_pair = struct.get_train_data_pair(dataset_path=dataset_path) 84 | logger.info( 85 | "Start extract training feature of the dataset. " 86 | "This may take time to finish and affect the computer's performance" 87 | ) 88 | _parallel_feature_extraction( 89 | train_data_pair, train_feat_out_path, label_extractor, settings.feature, num_threads=num_threads 90 | ) 91 | 92 | test_data_pair = struct.get_test_data_pair(dataset_path=dataset_path) 93 | logger.info( 94 | "Start extract testing feature of the dataset. 
" 95 | "This may take time to finish and affect the computer's performance" 96 | ) 97 | _parallel_feature_extraction( 98 | test_data_pair, test_feat_out_path, label_extractor, settings.feature, num_threads=num_threads 99 | ) 100 | 101 | # Writing out the settings 102 | write_yaml(settings.to_json(), jpath(train_feat_out_path, ".success.yaml")) 103 | write_yaml(settings.to_json(), jpath(test_feat_out_path, ".success.yaml")) 104 | logger.info("All done") 105 | 106 | def train(self, feature_folder, model_name=None, input_model_path=None, vocalcontour_settings=None): 107 | """Train model""" 108 | 109 | settings = self._validate_and_get_settings(vocalcontour_settings) 110 | 111 | if input_model_path is not None: 112 | logger.info("Continue to train one model: %s", input_model_path) 113 | model, prev_set = self._load_model(input_model_path) 114 | settings.training.timesteps = prev_set.training.timesteps 115 | settings.model.save_path = prev_set.model.save_path 116 | 117 | logger.info("Constructing dataset instance") 118 | split = settings.training.steps / (settings.training.steps + settings.training.val_steps) 119 | train_feat_files, val_feat_files = get_train_val_feat_file_list(feature_folder, split=split) 120 | 121 | output_types = (tf.float32, tf.float32) 122 | train_dataset = VocalContourDatasetLoader( 123 | feature_files=train_feat_files, 124 | num_samples=settings.training.batch_size * settings.training.steps, 125 | timesteps=settings.training.timesteps 126 | ).get_dataset(settings.training.batch_size, output_types=output_types) 127 | 128 | val_dataset = VocalContourDatasetLoader( 129 | feature_files=val_feat_files, 130 | num_samples=settings.training.val_batch_size * settings.training.val_steps, 131 | timesteps=settings.training.timesteps 132 | ).get_dataset(settings.training.val_batch_size, output_types=output_types) 133 | 134 | if input_model_path is None: 135 | logger.info("Constructing new model") 136 | # Note: The default value of dropout rate for ConvBlock is different in VocalSeg which is 0.2. 
137 | model = semantic_segmentation( 138 | multi_grid_layer_n=1, feature_num=384, ch_num=1, timesteps=settings.training.timesteps 139 | ) 140 | model.compile(optimizer="adam", loss=focal_loss, metrics=['accuracy']) 141 | 142 | logger.info("Resolving model output path") 143 | if model_name is None: 144 | model_name = str(datetime.now()).replace(" ", "_") 145 | if not model_name.startswith(settings.model.save_prefix): 146 | model_name = settings.model.save_prefix + "_" + model_name 147 | 148 | model_save_path = jpath(settings.model.save_path, model_name) 149 | ensure_path_exists(model_save_path) 150 | write_yaml(settings.to_json(), jpath(model_save_path, "configurations.yaml")) 151 | write_yaml(model.to_yaml(), jpath(model_save_path, "arch.yaml"), dump=False) 152 | logger.info("Model output to: %s", model_save_path) 153 | 154 | logger.info("Constructing callbacks") 155 | callbacks = [ 156 | EarlyStopping(patience=settings.training.early_stop), 157 | ModelCheckpoint(model_save_path, save_weights_only=True) 158 | ] 159 | logger.info("Callback list: %s", callbacks) 160 | 161 | logger.info("Start training") 162 | history = train_epochs( 163 | model, 164 | train_dataset, 165 | validate_dataset=val_dataset, 166 | epochs=settings.training.epoch, 167 | steps=settings.training.steps, 168 | val_steps=settings.training.val_steps, 169 | callbacks=callbacks 170 | ) 171 | 172 | return model_save_path, history 173 | 174 | 175 | def _all_in_one_extract(data_pair, label_extractor, t_unit, **kwargs): 176 | feat = extract_cfp_feature(data_pair[0], **kwargs) 177 | label = label_extractor.extract_label(data_pair[1], t_unit=t_unit) 178 | flen = len(feat) 179 | llen = len(label) 180 | if flen > llen: 181 | diff = flen - llen 182 | label = np.pad(label, ((0, diff), (0, 0)), constant_values=0) 183 | elif llen > flen: 184 | label = label[:flen] 185 | return feat, label 186 | 187 | 188 | def _parallel_feature_extraction(data_pair, out_path, label_extractor, feat_settings, num_threads=4): 189 | feat_extract_params = { 190 | "hop": feat_settings.hop_size, 191 | "down_fs": feat_settings.sampling_rate, 192 | "win_size": feat_settings.window_size 193 | } 194 | 195 | iters = enumerate( 196 | parallel_generator( 197 | _all_in_one_extract, 198 | data_pair, 199 | max_workers=num_threads, 200 | use_thread=True, 201 | chunk_size=num_threads, 202 | label_extractor=label_extractor, 203 | t_unit=feat_settings.hop_size, 204 | **feat_extract_params 205 | ) 206 | ) 207 | 208 | for idx, ((feature, label), audio_idx) in iters: 209 | audio = data_pair[audio_idx][0] 210 | 211 | print(f"Progress: {idx+1}/{len(data_pair)} - {audio}" + " "*6, end="\r") # noqa: E226 212 | 213 | filename, _ = os.path.splitext(os.path.basename(audio)) 214 | out_hdf = jpath(out_path, filename + ".hdf") 215 | saved = False 216 | retry_times = 5 217 | for retry in range(retry_times): 218 | if saved: 219 | break 220 | try: 221 | with h5py.File(out_hdf, "w") as out_f: 222 | out_f.create_dataset("feature", data=feature) 223 | out_f.create_dataset("label", data=label) 224 | saved = True 225 | except OSError as exp: 226 | logger.warning("OSError occurred, retrying %d times. 
Reason: %s", retry + 1, str(exp)) 227 | if not saved: 228 | logger.error("H5py failed to save the feature file after %d retries.", retry_times) 229 | raise OSError 230 | print("") 231 | 232 | 233 | class VocalContourDatasetLoader(BaseDatasetLoader): 234 | """Data loader for training the mdoel of `vocal-contour`""" 235 | def __init__( 236 | self, 237 | feature_folder=None, 238 | feature_files=None, 239 | num_samples=100, 240 | timesteps=128, 241 | channels=0, 242 | feature_num=384 243 | ): 244 | super().__init__( 245 | feature_folder=feature_folder, feature_files=feature_files, num_samples=num_samples, slice_hop=timesteps 246 | ) 247 | 248 | self.feature_folder = feature_folder 249 | self.feature_files = feature_files 250 | self.num_samples = num_samples 251 | self.timesteps = timesteps 252 | self.channels = channels 253 | self.feature_num = feature_num 254 | 255 | self.hdf_refs = {} 256 | for hdf in self.hdf_files: 257 | ref = h5py.File(hdf, "r") 258 | self.hdf_refs[hdf] = ref 259 | 260 | def _pad(self, data): 261 | pad_bottom = (self.feature_num - data.shape[1]) // 2 262 | pad_top = self.feature_num - data.shape[1] - pad_bottom 263 | paddings = ((0, 0), (pad_bottom, pad_top)) 264 | if len(data.shape) == 3: 265 | paddings += ((0, 0),) 266 | return np.pad(data, paddings) 267 | 268 | def _get_feature(self, hdf_name, slice_start): 269 | feat = self.hdf_refs[hdf_name]["feature"] 270 | feat = feat[:, :, self.channels] 271 | feat = self._pad(feat) 272 | feat = feat[slice_start:slice_start + self.slice_hop] 273 | return feat.reshape(self.timesteps, self.feature_num, 1) 274 | 275 | def _get_label(self, hdf_name, slice_start): 276 | label = self.hdf_refs[hdf_name]["label"] 277 | label = self._pad(label) 278 | label = label[slice_start:slice_start + self.slice_hop] 279 | return to_categorical(label, num_classes=2) 280 | 281 | def _pre_yield(self, feature, label): 282 | feat_len = len(feature) 283 | label_len = len(label) 284 | 285 | if (feat_len == self.timesteps) and (label_len == self.timesteps): 286 | # All normal 287 | return feature, label 288 | 289 | # The length of feature and label are inconsistent. Trim to the same size as the shorter one. 
290 | if feat_len > label_len: 291 | feature = feature[:label_len] 292 | feat_len = len(feature) 293 | else: 294 | label = label[:feat_len] 295 | label_len = len(label) 296 | 297 | return feature, label -------------------------------------------------------------------------------- /models/pyramid_net.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | 4 | import tensorflow as tf 5 | import logging 6 | logging.disable(logging.WARNING) 7 | 8 | 9 | class ShakeDrop(tf.keras.layers.Layer): 10 | """Shake drop layer""" 11 | 12 | def __init__(self, prob, min_alpha=-1, max_alpha=1, min_beta=0, max_beta=1, **kwargs): 13 | super().__init__(**kwargs) 14 | 15 | self.prob = prob 16 | self.min_alpha = min_alpha 17 | self.max_alpha = max_alpha 18 | self.min_beta = min_beta 19 | self.max_beta = max_beta 20 | 21 | def call(self, inputs, is_training=True): 22 | if is_training: 23 | in_shape = tf.shape(inputs) 24 | random_tensor = self.prob 25 | random_tensor += tf.random.uniform(in_shape, dtype=tf.float32) 26 | binary_tensor = tf.floor(random_tensor) 27 | 28 | alpha_values = tf.random.uniform([in_shape[0], 1, 1, 1], minval=self.min_alpha, maxval=self.max_alpha) 29 | beta_values = tf.random.uniform([in_shape[0], 1, 1, 1], minval=self.min_beta, maxval=self.max_beta) 30 | rand_forward = binary_tensor + alpha_values - binary_tensor * alpha_values 31 | rand_backward = binary_tensor + beta_values - binary_tensor * beta_values 32 | outputs = inputs * rand_backward + tf.stop_gradient( 33 | inputs*rand_forward - inputs*rand_backward # noqa: E226 34 | ) 35 | return outputs 36 | 37 | expected_alpha = (self.min_alpha + self.max_alpha) / 2 38 | return (self.prob + expected_alpha - self.prob * expected_alpha) * inputs 39 | 40 | def get_config(self): 41 | config = super().get_config().copy() 42 | config.update({ 43 | "prob": self.prob, 44 | "min_alpha": self.min_alpha, 45 | "max_alpha": self.max_alpha, 46 | "min_beta": self.min_beta, 47 | "max_beta": self.max_beta 48 | }) 49 | return config 50 | 51 | 52 | class PyramidBlock(tf.keras.layers.Layer): 53 | """Pyramid block for building pyramid net""" 54 | 55 | def __init__(self, out_channel, stride=1, padding="same", prob=1.0, shakedrop=True, **kwargs): 56 | super().__init__(**kwargs) 57 | 58 | self.shakedrop = shakedrop 59 | self.padding = padding 60 | self.stride = stride 61 | self.prob = prob 62 | self.downsample = stride == 2 63 | self.out_channel = int(out_channel) 64 | 65 | conv_init = tf.keras.initializers.VarianceScaling(scale=0.1, seed=int(time.time())) 66 | 67 | self.avgpool = tf.keras.layers.AveragePooling2D(strides=2, padding='same') 68 | self.batch_norm_1 = tf.keras.layers.BatchNormalization() 69 | self.conv_1 = tf.keras.layers.Conv2D( 70 | out_channel, 3, strides=stride, padding=padding, kernel_initializer=conv_init, activation='relu' 71 | ) 72 | self.batch_norm_2 = tf.keras.layers.BatchNormalization() 73 | self.relu = tf.keras.layers.ReLU() 74 | self.conv_2 = tf.keras.layers.Conv2D( 75 | out_channel, 3, strides=1, padding=padding, use_bias=False, kernel_initializer=conv_init, activation='relu' 76 | ) 77 | self.batch_norm_3 = tf.keras.layers.BatchNormalization() 78 | 79 | if shakedrop: 80 | self.shakedrop_layer = ShakeDrop(prob) 81 | 82 | def call(self, inputs, is_training=True): 83 | res = self._shortcut(inputs) 84 | output = self.batch_norm_1(inputs) 85 | output = self.conv_1(output) 86 | output = self.batch_norm_2(output) 87 | output = self.relu(output) 88 | output = self.conv_2(output) 89 
| output = self.batch_norm_3(output) 90 | if self.shakedrop: 91 | output = self.shakedrop_layer(output, is_training=is_training) 92 | 93 | return output + res 94 | 95 | def _shortcut(self, inputs): 96 | out = inputs 97 | if self.downsample: 98 | out = self.avgpool(inputs) 99 | 100 | num_filters = inputs.shape[3] 101 | if num_filters != self.out_channel: 102 | diff = self.out_channel - num_filters 103 | assert diff > 0 104 | padding = [[0, 0], [0, 0], [0, 0], [0, diff]] 105 | out = tf.pad(out, padding) 106 | return out 107 | 108 | def get_config(self): 109 | config = super().get_config().copy() 110 | config.update({ 111 | "prob": self.prob, 112 | "out_channel": self.out_channel, 113 | "stride": self.stride, 114 | "padding": self.padding, 115 | "shakedrop": self.shakedrop 116 | }) 117 | return config 118 | 119 | 120 | def _make_blocks(num_blocks, kernel_sizes, probs, stride=2, shakedrop=True): 121 | assert len(kernel_sizes) == num_blocks, f"{num_blocks} {len(kernel_sizes)}" 122 | assert len(probs) == num_blocks, f"{num_blocks} {len(probs)}" 123 | blocks = [] 124 | for kernel_size, prob in zip(kernel_sizes, probs): 125 | blocks.append(PyramidBlock(kernel_size, prob=prob, shakedrop=shakedrop, stride=stride)) 126 | stride = 1 127 | return blocks 128 | 129 | 130 | class PyramidNet(tf.keras.Model): 131 | """Pyramid Net with shake drop layer""" 132 | 133 | def __init__( 134 | self, 135 | out_classes=6, 136 | min_kernel_size=16, 137 | depth=110, 138 | alpha=270, 139 | shakedrop=True, 140 | semi_loss_weight=1, 141 | semi_xi=1e-6, 142 | semi_epsilon=40, 143 | semi_iters=2, 144 | **kwargs 145 | ): 146 | super().__init__(**kwargs) 147 | 148 | self.out_classes = out_classes 149 | self.min_kernel_size = min_kernel_size 150 | self.depth = depth 151 | self.alpha = alpha 152 | self.shakedrop = shakedrop 153 | self.semi_loss_weight = semi_loss_weight 154 | self.semi_xi = semi_xi 155 | self.semi_epsilon = semi_epsilon 156 | self.semi_iters = semi_iters 157 | 158 | self.kl_loss = tf.keras.losses.KLDivergence() 159 | self.loss_tracker = tf.keras.metrics.Mean(name="loss") 160 | 161 | if (depth - 2) % 6 != 0: 162 | raise ValueError(f"Value of 'depth' - 2 should be divisible by 6. 
Received: {depth}.") 163 | 164 | n_units = (depth - 2) // 6 165 | self.kernel_sizes = [min_kernel_size] + list(map( 166 | lambda x: math.ceil(alpha * (x + 1)) / (3 * n_units) + min_kernel_size, 167 | list(range(n_units * 3)) 168 | )) 169 | 170 | self.conv_1 = tf.keras.layers.Conv2D( 171 | self.kernel_sizes[0], 172 | (7, 7), 173 | strides=(2, 2), 174 | use_bias=False, 175 | activation='relu', 176 | kernel_initializer=tf.keras.initializers.HeNormal() 177 | ) 178 | self.batch_norm_1 = tf.keras.layers.BatchNormalization(name="batch_norm_1") 179 | self.relu_1 = tf.keras.layers.ReLU(name="relu_1") 180 | self.maxpool = tf.keras.layers.MaxPool2D(strides=2, name="max_pool") 181 | self.batch_norm_out = tf.keras.layers.BatchNormalization(name="batch_norm_2") 182 | 183 | total_blocks = n_units * 3 184 | calc_prob = lambda cur_layer: 1 - (cur_layer + 1) / total_blocks * 0.5 185 | self.kernel_sizes = self.kernel_sizes[1:] 186 | self.blocks = _make_blocks( 187 | n_units, 188 | kernel_sizes=self.kernel_sizes[:n_units], 189 | probs=[calc_prob(idx) for idx in range(n_units)], 190 | shakedrop=shakedrop, 191 | stride=1 192 | ) 193 | self.blocks += _make_blocks( 194 | n_units, 195 | kernel_sizes=self.kernel_sizes[n_units:n_units * 2], 196 | probs=[calc_prob(idx) for idx in range(n_units, n_units * 2)], 197 | shakedrop=shakedrop, 198 | stride=2 199 | ) 200 | self.blocks += _make_blocks( 201 | n_units, 202 | kernel_sizes=self.kernel_sizes[n_units*2:n_units*3], # noqa: E226 203 | probs=[calc_prob(idx) for idx in range(n_units*2, n_units*3)], # noqa: E226 204 | shakedrop=shakedrop, 205 | stride=2 206 | ) 207 | 208 | self.relu_out = tf.keras.layers.ReLU(name="relu_out") 209 | self.avgpool = tf.keras.layers.AveragePooling2D(pool_size=(1, 11), name="avg_pool") 210 | self.flatten = tf.keras.layers.Flatten(name="flatten") 211 | self.dense = tf.keras.layers.Dense(out_classes * 19, activation='sigmoid', name="dense_out") 212 | self.reshape = tf.keras.layers.Reshape((19, out_classes)) 213 | 214 | def call(self, inputs, is_training=True): # pylint: disable=W0221 215 | enc = self.conv_1(inputs) 216 | enc = self.batch_norm_1(enc) 217 | enc = self.relu_1(enc) 218 | b_out = self.maxpool(enc) 219 | 220 | for block in self.blocks: 221 | b_out = block(b_out, is_training=is_training) 222 | 223 | output = self.relu_out(b_out) 224 | output = self.avgpool(output) 225 | output = self.flatten(output) 226 | output = self.dense(output) 227 | return self.reshape(output) 228 | 229 | def train_step(self, data): 230 | data1, data2 = data 231 | semi = False 232 | if isinstance(data1, tuple): 233 | # Semi-supervise learning 234 | assert isinstance(data2, tf.Tensor) 235 | super_feat, super_label = data1 236 | # unsup_feat, _ = data2 237 | unsup_feat = data2 238 | semi = True 239 | else: 240 | super_feat = data1 241 | super_label = data2 242 | 243 | with tf.GradientTape() as tape: 244 | super_pred = self(super_feat, is_training=True) 245 | loss = self._compute_supervised_loss(super_label, super_pred) 246 | if semi: 247 | loss += self._compute_unsupervised_loss(unsup_feat) * self.semi_loss_weight 248 | 249 | trainable_vars = self.trainable_variables 250 | grads = tape.gradient(loss, trainable_vars) 251 | # grads = self.optimizer._clip_gradients(grads) 252 | self.optimizer.apply_gradients(zip(grads, trainable_vars)) 253 | self.loss_tracker.update_state(loss) 254 | self.compiled_metrics.update_state(super_label, super_pred) 255 | result = {m.name: m.result() for m in self.metrics} 256 | result.update({"loss": self.loss_tracker.result()}) 257 | 
return result 258 | 259 | def test_step(self, data): 260 | data1, data2 = data 261 | semi = False 262 | if isinstance(data1, tuple): 263 | # Semi-supervise learning 264 | assert isinstance(data2, tf.Tensor) 265 | super_feat, super_label = data1 266 | unsup_feat = data2 267 | semi = True 268 | else: 269 | super_feat = data1 270 | super_label = data2 271 | 272 | super_pred = self(super_feat, is_training=False) 273 | loss = self._compute_supervised_loss(super_label, super_pred) 274 | if semi: 275 | loss += self._compute_unsupervised_loss(unsup_feat) * self.semi_loss_weight 276 | 277 | self.compiled_metrics.update_state(super_label, super_pred) 278 | self.loss_tracker.update_state(loss) 279 | result = {m.name: m.result() for m in self.metrics} 280 | result.update({"loss": self.loss_tracker.result()}) 281 | return result 282 | 283 | def _compute_supervised_loss(self, label, pred): 284 | loss = self.compiled_loss(label, pred) 285 | empahsize_channel = [1, 2, 4] 286 | weight = 0.7 287 | emp_loss = 0 288 | for channel in empahsize_channel: 289 | emp_loss += self.compiled_loss(label[:, :, channel], pred[:, :, channel]) 290 | return loss * (1 - weight) + emp_loss * weight 291 | 292 | def _compute_unsupervised_loss(self, unsup_feat): 293 | """Computes VAT loss""" 294 | unsup_pred = self(unsup_feat) 295 | r_adv = self._gen_virtual_adv_perturbation(unsup_feat, unsup_pred) 296 | tf.stop_gradient(unsup_pred) 297 | unsup_pred_copy = unsup_pred 298 | adv_pred = self(unsup_feat + r_adv) 299 | loss = self.kl_loss(unsup_pred_copy, adv_pred) 300 | return tf.identity(loss) 301 | 302 | def _gen_virtual_adv_perturbation(self, unsup_feat, unsup_pred): 303 | self._switch_batch_norm_trainable_stat() 304 | 305 | perturb = tf.random.normal(tf.shape(unsup_feat)) 306 | for _ in range(self.semi_iters): 307 | perturb = self.semi_xi * _normalize(perturb) 308 | unsup_pred_copy = unsup_pred 309 | perturb_pred = self(unsup_feat + perturb) 310 | dist = self.kl_loss(unsup_pred_copy, perturb_pred) 311 | grad = tf.gradients(dist, [perturb], aggregation_method=2)[0] 312 | perturb = tf.stop_gradient(grad) 313 | 314 | self._switch_batch_norm_trainable_stat() 315 | return self.semi_epsilon * _normalize(perturb) 316 | 317 | def _switch_batch_norm_trainable_stat(self): 318 | for layer in self.layers: 319 | if isinstance(layer, tf.keras.layers.BatchNormalization): 320 | layer.trainable ^= True 321 | 322 | def get_config(self): 323 | return { 324 | "class_name": self.__class__.__name__, 325 | "config": { 326 | "out_classes": self.out_classes, 327 | "min_kernel_size": self.min_kernel_size, 328 | "depth": self.depth, 329 | "alpha": self.alpha, 330 | "shakedrop": self.shakedrop, 331 | "semi_loss_weight": self.semi_loss_weight, 332 | "semi_xi": self.semi_xi, 333 | "semi_epsilon": self.semi_epsilon, 334 | "semi_iters": self.semi_iters 335 | } 336 | } 337 | 338 | 339 | def _normalize(tensor): 340 | tensor /= (1e-12 + tf.reduce_max(tf.abs(tensor), range(1, len(tensor.shape)), keepdims=True)) 341 | tensor /= tf.sqrt(1e-6 + tf.reduce_sum(tensor**2, range(1, len(tensor.shape)), keepdims=True)) 342 | return tensor 343 | -------------------------------------------------------------------------------- /vocal/app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import shutil 4 | import subprocess 5 | from os.path import join as jpath 6 | from collections import OrderedDict 7 | from datetime import datetime 8 | 9 | import sys 10 | MODULE_PATH = 
os.path.abspath(f"{os.path.split(__file__)[0]}/..") 11 | if sys.path[0] != MODULE_PATH: sys.path.insert(0, MODULE_PATH) 12 | 13 | import tensorflow as tf 14 | tf.get_logger().setLevel('ERROR') 15 | 16 | import h5py 17 | import numpy as np 18 | from spleeter.separator import Separator 19 | from spleeter.utils.logging import logger as sp_logger 20 | 21 | from utils import load_audio, write_yaml, get_logger, resolve_dataset_type, parallel_generator, ensure_path_exists, get_filename, LazyLoader 22 | from constants import datasets as d_struct 23 | from base import BaseTranscription, BaseDatasetLoader 24 | from feature.cfp import extract_vocal_cfp, _extract_vocal_cfp 25 | from setting_loaders import VocalSettings 26 | from vocal import labels as lextor 27 | from vocal.prediction import predict 28 | from vocal.inference import infer_interval, infer_midi 29 | from train import get_train_val_feat_file_list 30 | pyramid_net = LazyLoader('pyramid_net', globals(), 'models.pyramid_net') 31 | 32 | from vocal_contour.app import VocalContourTranscription 33 | vcapp = VocalContourTranscription(jpath(MODULE_PATH, 'defaults', 'vocal_contour.yaml')) 34 | 35 | logger = get_logger("Vocal Transcription") 36 | 37 | 38 | class SpleeterError(Exception): 39 | """Wrapper exception class around Spleeter errors""" 40 | pass 41 | 42 | 43 | class VocalTranscription(BaseTranscription): 44 | """Application class for vocal note transcription""" 45 | 46 | def __init__(self, conf_path=None): 47 | super().__init__(VocalSettings, conf_path=conf_path) 48 | 49 | # Disable logging information of Spleeter 50 | sp_logger.setLevel(40) 51 | 52 | def transcribe(self, input_audio, model_path=None, output="./"): 53 | """Transcribe vocal notes in the audio""" 54 | 55 | logger.info("Separating vocal track from the audio...") 56 | command = ["spleeter", "separate", input_audio, "-o", "./"] 57 | process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 58 | _, error = process.communicate() 59 | if process.returncode != 0: 60 | raise SpleeterError(error.decode("utf-8")) 61 | 62 | # Resolve the path of separated output files 63 | folder_path = jpath("./", get_filename(input_audio)) 64 | vocal_wav_path = jpath(folder_path, "vocals.wav") 65 | wav, fs = load_audio(vocal_wav_path) 66 | 67 | # Clean out the output files 68 | shutil.rmtree(folder_path) 69 | 70 | logger.info("Loading model...") 71 | model, model_settings = self._load_model(model_path) 72 | 73 | logger.info("Extracting feature...") 74 | feature = _extract_vocal_cfp( 75 | wav, 76 | fs, 77 | down_fs=model_settings.feature.sampling_rate, 78 | hop=model_settings.feature.hop_size, 79 | fr=model_settings.feature.frequency_resolution, 80 | fc=model_settings.feature.frequency_center, 81 | tc=model_settings.feature.time_center, 82 | g=model_settings.feature.gamma, 83 | bin_per_octave=model_settings.feature.bins_per_octave 84 | ) 85 | 86 | logger.info("Predicting...") 87 | pred = predict(feature, model) 88 | 89 | logger.info("Infering notes...") 90 | interval = infer_interval( 91 | pred, 92 | ctx_len=model_settings.inference.context_length, 93 | threshold=model_settings.inference.threshold, 94 | min_dura=model_settings.inference.min_duration, 95 | t_unit=model_settings.feature.hop_size 96 | ) 97 | 98 | logger.info("Extracting pitch contour") 99 | agg_f0 = vcapp.transcribe(input_audio=input_audio, model_path=model_settings.inference.pitch_model, output=output) 100 | 101 | logger.info("Inferencing MIDI...") 102 | midi = infer_midi(interval, agg_f0, 
t_unit=model_settings.feature.hop_size) 103 | 104 | self._output_midi(output=output, input_audio=input_audio, midi=midi) 105 | logger.info("Transcription finished") 106 | return midi 107 | 108 | def generate_feature(self, dataset_path, vocal_settings=None, num_threads=4): 109 | """Extract the feature of the whole dataset""" 110 | 111 | settings = self._validate_and_get_settings(vocal_settings) 112 | 113 | # Check labeled data or not 114 | is_labeled = False 115 | if os.path.isdir(jpath(dataset_path, 'labels')): 116 | is_labeled = True 117 | 118 | if is_labeled: 119 | # Build instance mapping 120 | struct = d_struct.VocalAlignStructure 121 | label_extractor = lextor.VocalAlignLabelExtraction 122 | 123 | # Fetching wav files 124 | train_data = struct.get_train_data_pair(dataset_path=dataset_path) 125 | test_data = struct.get_test_data_pair(dataset_path=dataset_path) 126 | logger.info("Number of total training wavs: %d", len(train_data)) 127 | logger.info("Number of total testing wavs: %d", len(test_data)) 128 | 129 | # Resolve feature output path 130 | train_feat_out_path, test_feat_out_path = self._resolve_feature_output_path(dataset_path, settings) 131 | logger.info("Output training feature to %s", train_feat_out_path) 132 | logger.info("Output testing feature to %s", test_feat_out_path) 133 | 134 | # Feature extraction 135 | logger.info( 136 | "Start extract training feature. " 137 | "This may take time to finish and affect the computer's performance.", 138 | ) 139 | wav_paths = _vocal_separation([data[0] for data in train_data], jpath(dataset_path, "train_wavs_spleeter")) 140 | train_data = _validate_order_and_get_new_pair(wav_paths, train_data) 141 | _parallel_feature_extraction(train_data, label_extractor, train_feat_out_path, settings.feature, num_threads=num_threads) 142 | 143 | # Feature extraction 144 | logger.info( 145 | "Start extract testing feature. " 146 | "This may take time to finish and affect the computer's performance." 147 | ) 148 | wav_paths = _vocal_separation([data[0] for data in test_data], jpath(dataset_path, "test_wavs_spleeter")) 149 | test_data = _validate_order_and_get_new_pair(wav_paths, test_data) 150 | _parallel_feature_extraction(test_data, label_extractor, test_feat_out_path, settings.feature, num_threads=num_threads) 151 | 152 | write_yaml(settings.to_json(), jpath(train_feat_out_path, ".success.yaml")) 153 | write_yaml(settings.to_json(), jpath(test_feat_out_path, ".success.yaml")) 154 | else: 155 | # Build instance mapping 156 | struct = d_struct.UnlabeledStructure 157 | label_extractor = lextor.UnlabeledLabelExtraction 158 | 159 | # Fetching wav files 160 | train_data = struct.get_train_data_pair(dataset_path=dataset_path) 161 | logger.info("Number of total semi training wavs: %d", len(train_data)) 162 | 163 | # Resolve feature output path 164 | train_feat_out_path = self._resolve_semi_feature_output_path(dataset_path, settings) 165 | logger.info("Output semi feature to %s", train_feat_out_path) 166 | 167 | # Feature extraction 168 | logger.info( 169 | "Start extract semi feature. " 170 | "This may take time to finish and affect the computer's performance." 
171 | ) 172 | wav_paths = _vocal_separation([data for data in train_data], jpath(dataset_path, "train_wavs_spleeter")) 173 | train_data = _semi_validate_order_and_get_new_pair(wav_paths, train_data) 174 | _semi_parallel_feature_extraction(train_data, label_extractor, train_feat_out_path, settings.feature, num_threads=num_threads) 175 | 176 | write_yaml(settings.to_json(), jpath(train_feat_out_path, ".success.yaml")) 177 | logger.info("All done") 178 | 179 | def train(self, feature_folder, semi_feature_folder=None, model_name=None, input_model_path=None, vocal_settings=None): 180 | """Train model""" 181 | 182 | settings = self._validate_and_get_settings(vocal_settings) 183 | 184 | if input_model_path is not None: 185 | logger.info("Continue to train on model: %s", input_model_path) 186 | model, prev_set = self._load_model(input_model_path) 187 | settings.model.save_path = prev_set.model.save_path 188 | 189 | logger.info("Constructing dataset instance") 190 | split = settings.training.steps / (settings.training.steps + settings.training.val_steps) 191 | train_feat_files, val_feat_files = get_train_val_feat_file_list(feature_folder, split=split) 192 | 193 | output_types = (tf.float32, tf.float32) 194 | output_shapes = ((settings.training.context_length*2 + 1, 174, 9), (19, 6)) 195 | train_dataset = VocalDatasetLoader( 196 | ctx_len=settings.training.context_length, 197 | feature_files=train_feat_files, 198 | num_samples=settings.training.epoch * settings.training.batch_size * settings.training.steps 199 | ) \ 200 | .get_dataset(settings.training.batch_size, output_types=output_types, output_shapes=output_shapes) 201 | 202 | val_dataset = VocalDatasetLoader( 203 | ctx_len=settings.training.context_length, 204 | feature_files=val_feat_files, 205 | num_samples=settings.training.epoch * settings.training.val_batch_size * settings.training.val_steps 206 | ) \ 207 | .get_dataset(settings.training.val_batch_size, output_types=output_types, output_shapes=output_shapes) 208 | if semi_feature_folder is not None: 209 | # Semi-supervise learning dataset. 
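# The unlabeled features are loaded with `is_labeled=False`, so the loader yields
# only feature tensors. Zipping them with the supervised dataset makes every training
# batch a ((feature, label), unlabeled_feature) pair, which PyramidNet.train_step
# unpacks to add the virtual-adversarial (VAT) consistency loss, weighted by
# `semi_loss_weight`, on top of the supervised loss.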
210 | feat_files = glob.glob(f"{semi_feature_folder}/*.hdf") 211 | semi_dataset = VocalDatasetLoader( 212 | is_labeled = False, 213 | ctx_len=settings.training.context_length, 214 | feature_files=feat_files, 215 | num_samples=settings.training.epoch * settings.training.batch_size * settings.training.steps 216 | ) \ 217 | .get_dataset(settings.training.batch_size, output_types=(output_types[0]), output_shapes=(output_shapes[0])) 218 | train_dataset = tf.data.Dataset.zip((train_dataset, semi_dataset)) 219 | 220 | if input_model_path is None: 221 | logger.info("Constructing new model") 222 | model = self.get_model(settings) 223 | 224 | optimizer = tf.keras.optimizers.Adam(learning_rate=settings.training.init_learning_rate) 225 | model.compile(optimizer=optimizer, loss='bce', metrics=['accuracy', 'binary_accuracy']) 226 | 227 | logger.info("Resolving model output path") 228 | if model_name is None: 229 | model_name = str(datetime.now()).replace(" ", "-").replace(":", "_") 230 | if not model_name.startswith(settings.model.save_prefix): 231 | model_name = settings.model.save_prefix + "_" + model_name 232 | model_save_path = jpath(settings.model.save_path, model_name) 233 | ensure_path_exists(model_save_path) 234 | write_yaml(settings.to_json(), jpath(model_save_path, "configurations.yaml")) 235 | logger.info("Model output to: %s", model_save_path) 236 | 237 | logger.info("Constructing callbacks") 238 | callbacks = [ 239 | tf.keras.callbacks.EarlyStopping(patience=settings.training.early_stop, monitor="val_loss"), 240 | tf.keras.callbacks.ModelCheckpoint(model_save_path, monitor="val_loss") 241 | ] 242 | logger.info("Callback list: %s", callbacks) 243 | 244 | logger.info("Start training") 245 | history = model.fit( 246 | train_dataset, 247 | validation_data=val_dataset, 248 | epochs=settings.training.epoch, 249 | steps_per_epoch=settings.training.steps, 250 | validation_steps=settings.training.val_steps, 251 | callbacks=callbacks, 252 | use_multiprocessing=True, 253 | workers=8 254 | ) 255 | return model_save_path, history 256 | 257 | def get_model(self, settings): 258 | """Get the Pyramid model""" 259 | 260 | return pyramid_net.PyramidNet( 261 | out_classes=6, 262 | min_kernel_size=settings.model.min_kernel_size, 263 | depth=settings.model.depth, 264 | alpha=settings.model.alpha, 265 | shakedrop=settings.model.shake_drop, 266 | semi_loss_weight=settings.model.semi_loss_weight, 267 | semi_xi=settings.model.semi_xi, 268 | semi_epsilon=settings.model.semi_epsilon, 269 | semi_iters=settings.model.semi_iterations 270 | ) 271 | 272 | 273 | def _validate_order_and_get_new_pair(wav_paths, data_pair): 274 | wavs = [os.path.basename(wav) for wav in wav_paths] 275 | ori_wavs = [os.path.basename(data[0]) for data in data_pair] 276 | assert wavs == ori_wavs 277 | return [(wav_path, label_path) for wav_path, (_, label_path) in zip(wav_paths, data_pair)] 278 | 279 | # For semi features 280 | def _semi_validate_order_and_get_new_pair(wav_paths, data_pair): 281 | wavs = [os.path.basename(wav) for wav in wav_paths] 282 | ori_wavs = [os.path.basename(data) for data in data_pair] 283 | assert wavs == ori_wavs 284 | 285 | return [(wav_path) for wav_path, _ in zip(wav_paths, data_pair)] 286 | 287 | 288 | def _vocal_separation(wav_list, out_folder): 289 | wavs = OrderedDict({os.path.basename(wav): wav for wav in wav_list}) 290 | if os.path.exists(out_folder): 291 | # There are already some separated audio. 
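# Compare the requested file names against what is already present in `out_folder`
# and only run Spleeter on the remaining ones, so an interrupted separation run can
# be resumed without redoing finished tracks.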
292 | sep_wavs = set(os.listdir(out_folder)) 293 | diff_wavs = set(wavs.keys()) - sep_wavs 294 | logger.debug("Audio to be separated: %s", diff_wavs) 295 | 296 | # Check the difference of the separated audio and the received audio list. 297 | done_wavs = set(wavs.keys()) - diff_wavs 298 | wavs_copy = wavs.copy() 299 | for dwav in done_wavs: 300 | del wavs_copy[dwav] 301 | wav_list = list(wavs_copy.values()) 302 | 303 | out_list = [jpath(out_folder, wav) for wav in wavs] 304 | if len(wav_list) > 0: 305 | separator = Separator('spleeter:2stems') 306 | separator._params["stft_backend"] = "librosa" 307 | for idx, wav_path in enumerate(wav_list, 1): 308 | logger.info("Separation Progress: %d/%d - %s", idx, len(wav_list), wav_path) 309 | separator.separate_to_file(wav_path, out_folder) 310 | 311 | # The separated tracks are stored in sub-folders. 312 | # Move the vocal track to the desired folder and rename them. 313 | fname, _ = os.path.splitext(os.path.basename(wav_path)) 314 | sep_folder = jpath(out_folder, fname) 315 | vocal_track = jpath(sep_folder, "vocals.wav") 316 | shutil.move(vocal_track, jpath(out_folder, fname + ".wav")) 317 | shutil.rmtree(sep_folder) 318 | return out_list 319 | 320 | 321 | def _all_in_one_extract(data_pair, label_extractor, t_unit, **feat_kargs): 322 | wav, label = data_pair 323 | logger.debug("Extracting vocal CFP feature") 324 | feature = extract_vocal_cfp(wav, **feat_kargs) 325 | logger.debug("Extracting label") 326 | label = label_extractor.extract_label(label, t_unit=t_unit) 327 | return feature, label 328 | 329 | # For semi features 330 | def _semi_all_in_one_extract(data_pair, label_extractor, t_unit, **feat_kargs): 331 | wav = data_pair 332 | logger.debug("Extracting vocal CFP feature") 333 | feature = extract_vocal_cfp(wav, **feat_kargs) 334 | return feature 335 | 336 | 337 | def _parallel_feature_extraction( 338 | data_pair, label_extractor, out_path, feat_settings, num_threads=4 339 | ): 340 | feat_extract_params = { 341 | "hop": feat_settings.hop_size, 342 | "fr": feat_settings.frequency_resolution, 343 | "fc": feat_settings.frequency_center, 344 | "tc": feat_settings.time_center, 345 | "g": feat_settings.gamma, 346 | "bin_per_octave": feat_settings.bins_per_octave 347 | } 348 | 349 | iters = enumerate( 350 | parallel_generator( 351 | _all_in_one_extract, 352 | data_pair, 353 | max_workers=num_threads, 354 | chunk_size=num_threads, 355 | label_extractor=label_extractor, 356 | t_unit=feat_settings.hop_size, 357 | **feat_extract_params 358 | ) 359 | ) 360 | for idx, ((feature, label), audio_idx) in iters: 361 | audio = data_pair[audio_idx][0] 362 | logger.info("Progress: %s/%s - %s", idx + 1, len(data_pair), audio) 363 | 364 | # Trim to the same length 365 | max_len = min(len(feature), len(label)) 366 | feature = feature[:max_len] 367 | label = label[:max_len] 368 | 369 | basename = os.path.basename(audio) 370 | filename, _ = os.path.splitext(basename) 371 | out_hdf = jpath(out_path, filename + ".hdf") 372 | with h5py.File(out_hdf, "w") as out_f: 373 | out_f.create_dataset("feature", data=feature, compression="gzip", compression_opts=3) 374 | out_f.create_dataset("label", data=label, compression="gzip", compression_opts=3) 375 | 376 | # For semi features 377 | def _semi_parallel_feature_extraction(data_pair, label_extractor, out_path, feat_settings, num_threads=4): 378 | feat_extract_params = { 379 | "hop": feat_settings.hop_size, 380 | "fr": feat_settings.frequency_resolution, 381 | "fc": feat_settings.frequency_center, 382 | "tc": 
feat_settings.time_center, 383 | "g": feat_settings.gamma, 384 | "bin_per_octave": feat_settings.bins_per_octave 385 | } 386 | 387 | iters = enumerate( 388 | parallel_generator( 389 | _semi_all_in_one_extract, 390 | data_pair, 391 | max_workers=num_threads, 392 | chunk_size=num_threads, 393 | label_extractor=label_extractor, 394 | t_unit=feat_settings.hop_size, 395 | **feat_extract_params 396 | ) 397 | ) 398 | for idx, ((feature), audio_idx) in iters: 399 | audio = data_pair[audio_idx] 400 | logger.info("Progress: %s/%s - %s", idx + 1, len(data_pair), audio) 401 | 402 | basename = os.path.basename(audio) 403 | filename, _ = os.path.splitext(basename) 404 | out_hdf = jpath(out_path, filename + ".hdf") 405 | with h5py.File(out_hdf, "w") as out_f: 406 | out_f.create_dataset("feature", data=feature, compression="gzip", compression_opts=3) 407 | 408 | 409 | class VocalDatasetLoader(BaseDatasetLoader): 410 | """Dataset loader of 'vocal' module""" 411 | 412 | def __init__(self, is_labeled=True, ctx_len=9, feature_folder=None, feature_files=None, num_samples=100, slice_hop=1): 413 | super().__init__( 414 | is_labeled=is_labeled, 415 | feature_folder=feature_folder, 416 | feature_files=feature_files, 417 | num_samples=num_samples, 418 | slice_hop=slice_hop 419 | ) 420 | self.ctx_len = ctx_len 421 | 422 | def _get_feature(self, hdf_name, slice_start): 423 | feat = self.hdf_refs[hdf_name]["feature"] 424 | 425 | pad_left = 0 426 | if slice_start - self.ctx_len < 0: 427 | pad_left = self.ctx_len - slice_start 428 | 429 | pad_right = 0 430 | if slice_start + self.ctx_len + 1 > len(feat): 431 | pad_right = slice_start + self.ctx_len + 1 - len(feat) 432 | 433 | start = max(slice_start - self.ctx_len, 0) 434 | end = min(slice_start + self.ctx_len + 1, len(feat)) 435 | feat = feat[start:end] 436 | if (pad_left > 0) or (pad_right > 0): 437 | feat = np.pad(feat, ((pad_left, pad_right), (0, 0), (0, 0))) 438 | 439 | return feat # Time x Freq x 9 440 | 441 | def _get_label(self, hdf_name, slice_start): 442 | label = self.hdf_refs[hdf_name]["label"] 443 | 444 | pad_left = 0 445 | if slice_start - self.ctx_len < 0: 446 | pad_left = self.ctx_len - slice_start 447 | 448 | pad_right = 0 449 | if slice_start + self.ctx_len + 1 > len(label): 450 | pad_right = slice_start + self.ctx_len + 1 - len(label) 451 | 452 | start = max(slice_start - self.ctx_len, 0) 453 | end = min(slice_start + self.ctx_len + 1, len(label)) 454 | label = label[start:end] 455 | if (pad_left > 0) or (pad_right > 0): 456 | label = np.pad(label, ((pad_left, pad_right), (0, 0))) 457 | 458 | return label # Time x 6 459 | --------------------------------------------------------------------------------
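The context-window logic in `VocalDatasetLoader._get_feature` above is easiest to see in a small standalone sketch. The helper below is not part of the repository; it simply mirrors the loader's padding and slicing so the resulting shape can be checked in isolation (`ctx_len=9` matches the `(context_length*2 + 1, 174, 9)` shape used in `VocalTranscription.train`).

import numpy as np


def context_window(feat, slice_start, ctx_len=9):
    """Mirror of VocalDatasetLoader._get_feature: return a (2*ctx_len + 1)-frame
    window centred on `slice_start`, zero-padded where it crosses the edges."""
    pad_left = max(ctx_len - slice_start, 0)
    pad_right = max(slice_start + ctx_len + 1 - len(feat), 0)
    start = max(slice_start - ctx_len, 0)
    end = min(slice_start + ctx_len + 1, len(feat))
    window = feat[start:end]
    if pad_left or pad_right:
        window = np.pad(window, ((pad_left, pad_right), (0, 0), (0, 0)))
    return window


# 100 frames of a hypothetical 174-bin, 9-channel vocal CFP feature.
feat = np.random.rand(100, 174, 9)
assert context_window(feat, 0).shape == (19, 174, 9)   # left edge: zero-padded
assert context_window(feat, 50).shape == (19, 174, 9)  # interior: plain slice
assert context_window(feat, 99).shape == (19, 174, 9)  # right edge: zero-padded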