├── .gitignore ├── LICENSE ├── README.md ├── evaluate.py ├── onsets_and_frames ├── __init__.py ├── constants.py ├── dataset.py ├── decoding.py ├── harmonic_layers │ ├── __init__.py │ ├── conv3d_layer.py │ └── harmo_dilated.py ├── mel.py ├── midi.py ├── network_utils.py ├── transcriber.py └── utils.py ├── prepare_maestro.sh ├── requirements.txt ├── train.py └── transcribe.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | data 3 | runs 4 | venv -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019 Jong Wook Kim 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch code for _TriAD_: Capturing harmonics with 3D convolutions 2 | 3 | This is the code accompanying **TriAD: Capturing harmonics with 3D convolutions**. 4 | It is mostly based on Jong Wook's [repository](https://github.com/jongwook/onsets-and-frames). 5 | 6 | ### Downloading the Datasets 7 | You need to get two datasets: MAESTRO and MAPS. MAESTRO is hosted on Google's servers, and you can download it and 8 | parse it using the `prepare_maestro.sh` script. 9 | When calling the script, use `-s` to indicate where MAESTRO will be downloaded; a symbolic link `data/MAESTRO` will be 10 | created pointing at the location where MAESTRO was downloaded and unzipped. 11 | The script also takes care of resampling the audio and encoding it as FLAC. 12 | 13 | If you already have MAESTRO on your computer, you can just use the bash script in Jong Wook's 14 | [repository](https://github.com/jongwook/onsets-and-frames). 15 | 16 | To obtain the MAPS dataset, download it from Jong Wook's 17 | [repository](https://github.com/jongwook/onsets-and-frames) and place it in `data/MAPS`. 18 | 19 | ### Training 20 | 21 | All package requirements are contained in `requirements.txt`. To train the model, run: 22 | 23 | ```bash 24 | pip install -r requirements.txt 25 | python train.py 26 | ``` 27 | 28 | `train.py` is written using [sacred](https://sacred.readthedocs.io/), and accepts configuration options such as: 29 | 30 | ```bash 31 | python train.py with logdir=runs/model iterations=1000000 32 | ``` 33 | 34 | Trained models will be saved in the specified `logdir`, otherwise in a timestamped directory under `runs/`.
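Any other value defined in the sacred config of `train.py` (for example `model`, `train_on`, `batch_size` or `sequence_length`) can be overridden the same way. A minimal sketch — the option names come from this repository's `train.py`, while the values below are only illustrative:

```bash
# Train the 3D variant on MAPS with a smaller batch and 5-second excerpts (illustrative values)
python train.py with model=HPPNetDDD train_on=MAPS batch_size=2 sequence_length=80000
```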
35 | 36 | ### Testing 37 | 38 | To evaluate a trained model on the MAPS dataset, run the following command to calculate the note and frame metrics (`<model_file>` is the path to a checkpoint saved during training): 39 | 40 | ```bash 41 | python evaluate.py <model_file> 42 | ``` 43 | 44 | Specifying `--save-path` will output the transcribed MIDI file along with the piano roll images: 45 | 46 | ```bash 47 | python evaluate.py <model_file> --save-path output/ 48 | ``` 49 | 50 | To test on the MAESTRO dataset's test split instead of the MAPS dataset, run: 51 | 52 | ```bash 53 | python evaluate.py <model_file> MAESTRO test 54 | ``` 55 | 56 | ## Citing 57 | 58 | If you use this repository or the model, please consider citing: 59 | ```text 60 | @inproceedings{Perez2023triad, 61 | author = {Perez, Miguel and Kirchhoff, Holger and Serra, Xavier}, 62 | title = {TriAD: Capturing harmonics with 3D convolutions}, 63 | booktitle = {Proceedings of the 24th International Society for Music Information 64 | Retrieval Conference, {ISMIR} 2023, Milan, November 5-9, 2023}, 65 | } 66 | ``` 67 | 68 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | from mir_eval.multipitch import evaluate as evaluate_frames 8 | from mir_eval.transcription import precision_recall_f1_overlap as evaluate_notes 9 | from mir_eval.transcription_velocity import precision_recall_f1_overlap as evaluate_notes_with_velocity 10 | from mir_eval.util import midi_to_hz 11 | from scipy.stats import hmean 12 | from tqdm import tqdm 13 | 14 | from typing import Optional, Iterable 15 | 16 | import pandas as pd 17 | 18 | import onsets_and_frames.dataset as dataset_module 19 | from onsets_and_frames import * 20 | 21 | from sklearn.metrics import auc 22 | 23 | eps = sys.float_info.epsilon 24 | 25 | 26 | def evaluate(data, model, onset_threshold: float = 0.5, frame_threshold: float = 0.5, save_path=None, pr_au_thresholds: Optional[Iterable[float]] = np.arange(0.1, 0.9, 0.05)): 27 | metrics = defaultdict(list) 28 | 29 | for label in data: 30 | pred, losses = model.run_on_batch(label) 31 | 32 | for key, loss in losses.items(): 33 | metrics[key].append(loss.item()) 34 | 35 | for key, value in pred.items(): 36 | value.squeeze_(0).relu_() 37 | 38 | p_ref, i_ref, v_ref = extract_notes(label['onset'], label['frame'], label['velocity']) 39 | p_est, i_est, v_est = extract_notes(pred['onset'], pred['frame'], pred['velocity'], onset_threshold, frame_threshold) 40 | 41 | t_ref, f_ref = notes_to_frames(p_ref, i_ref, label['frame'].shape) 42 | t_est, f_est = notes_to_frames(p_est, i_est, pred['frame'].shape) 43 | 44 | scaling = HOP_LENGTH / SAMPLE_RATE 45 | 46 | i_ref = (i_ref * scaling).reshape(-1, 2) 47 | p_ref = np.array([midi_to_hz(MIN_MIDI + midi) for midi in p_ref]) 48 | i_est = (i_est * scaling).reshape(-1, 2) 49 | p_est = np.array([midi_to_hz(MIN_MIDI + midi) for midi in p_est]) 50 | 51 | t_ref = t_ref.astype(np.float64) * scaling 52 | f_ref = [np.array([midi_to_hz(MIN_MIDI + midi) for midi in freqs]) for freqs in f_ref] 53 | t_est = t_est.astype(np.float64) * scaling 54 | f_est = [np.array([midi_to_hz(MIN_MIDI + midi) for midi in freqs]) for freqs in f_est] 55 | 56 | p, r, f, o = evaluate_notes(i_ref, p_ref, i_est, p_est, offset_ratio=None) 57 | metrics['metric/note/precision'].append(p) 58 | metrics['metric/note/recall'].append(r) 59 | metrics['metric/note/f1'].append(f) 60 | 
metrics['metric/note/overlap'].append(o) 61 | 62 | p, r, f, o = evaluate_notes(i_ref, p_ref, i_est, p_est) 63 | metrics['metric/note-with-offsets/precision'].append(p) 64 | metrics['metric/note-with-offsets/recall'].append(r) 65 | metrics['metric/note-with-offsets/f1'].append(f) 66 | metrics['metric/note-with-offsets/overlap'].append(o) 67 | 68 | p, r, f, o = evaluate_notes_with_velocity(i_ref, p_ref, v_ref, i_est, p_est, v_est, 69 | offset_ratio=None, velocity_tolerance=0.1) 70 | metrics['metric/note-with-velocity/precision'].append(p) 71 | metrics['metric/note-with-velocity/recall'].append(r) 72 | metrics['metric/note-with-velocity/f1'].append(f) 73 | metrics['metric/note-with-velocity/overlap'].append(o) 74 | 75 | p, r, f, o = evaluate_notes_with_velocity(i_ref, p_ref, v_ref, i_est, p_est, v_est, velocity_tolerance=0.1) 76 | metrics['metric/note-with-offsets-and-velocity/precision'].append(p) 77 | metrics['metric/note-with-offsets-and-velocity/recall'].append(r) 78 | metrics['metric/note-with-offsets-and-velocity/f1'].append(f) 79 | metrics['metric/note-with-offsets-and-velocity/overlap'].append(o) 80 | 81 | frame_metrics = evaluate_frames(t_ref, f_ref, t_est, f_est) 82 | metrics['metric/frame/f1'].append(hmean([frame_metrics['Precision'] + eps, frame_metrics['Recall'] + eps]) - eps) 83 | 84 | for key, loss in frame_metrics.items(): 85 | metrics['metric/frame/' + key.lower().replace(' ', '_')].append(loss) 86 | 87 | if save_path is not None: 88 | os.makedirs(save_path, exist_ok=True) 89 | label_path = os.path.join(save_path, os.path.basename(label['path']) + '.label.png') 90 | save_pianoroll(label_path, label['onset'], label['frame']) 91 | pred_path = os.path.join(save_path, os.path.basename(label['path']) + '.pred.png') 92 | save_pianoroll(pred_path, pred['onset'], pred['frame']) 93 | midi_path = os.path.join(save_path, os.path.basename(label['path']) + '.pred.mid') 94 | save_midi(midi_path, p_est, i_est, v_est) 95 | 96 | if pr_au_thresholds is not None: 97 | p_frames = np.zeros_like(pr_au_thresholds) 98 | r_frames = np.zeros_like(pr_au_thresholds) 99 | for idx, threshold in enumerate(pr_au_thresholds): 100 | p_est, i_est, v_est = extract_notes(pred['onset'], pred['frame'], pred['velocity'], threshold, threshold) 101 | t_est, f_est = notes_to_frames(p_est, i_est, pred['frame'].shape) 102 | 103 | i_est = (i_est * scaling).reshape(-1, 2) 104 | p_est = np.array([midi_to_hz(MIN_MIDI + midi) for midi in p_est]) 105 | t_est = t_est.astype(np.float64) * scaling 106 | f_est = [np.array([midi_to_hz(MIN_MIDI + midi) for midi in freqs]) for freqs in f_est] 107 | 108 | frame_metrics = evaluate_frames(t_ref, f_ref, t_est, f_est) 109 | p_frames[idx] = frame_metrics['Precision'] 110 | r_frames[idx] = frame_metrics['Recall'] 111 | metrics['metric/frame/pr-auc'].append((auc(r_frames, p_frames))) 112 | 113 | return metrics 114 | 115 | 116 | @torch.no_grad() 117 | def evaluate_file(model_file, dataset, dataset_group, sequence_length, save_path, 118 | onset_threshold, frame_threshold, device, output_dir: Optional[str] = None): 119 | dataset_class = getattr(dataset_module, dataset) 120 | kwargs = {'sequence_length': sequence_length, 'device': device} 121 | if dataset_group is not None: 122 | kwargs['groups'] = [dataset_group] 123 | dataset = dataset_class(**kwargs) 124 | 125 | model = torch.load(model_file, map_location=device).eval() 126 | summary(model) 127 | 128 | if output_dir: 129 | os.makedirs(output_dir, exist_ok=False) 130 | df = pd.DataFrame(columns=['category', 'name', 'mean', 'std']) 131 | 
csv_file = os.path.join(output_dir, 'metrics.csv') 132 | summary_file = os.path.join(output_dir, 'model.txt') 133 | with open(summary_file, "w") as f: 134 | summary(model=model, file=f) 135 | 136 | metrics = evaluate(tqdm(dataset), model, onset_threshold, frame_threshold, save_path) 137 | 138 | for key, values in metrics.items(): 139 | if key.startswith('metric/'): 140 | _, category, name = key.split('/') 141 | print(f'{category:>32} {name:25}: {np.mean(values):.3f} ± {np.std(values):.3f}') 142 | if output_dir: 143 | df = pd.concat([df, pd.DataFrame([{'category': category, 'name': name, 'mean': np.mean(values), 'std': np.std(values)}])], ignore_index=True) 144 | if output_dir: 145 | df.to_csv(csv_file, index=False) 146 | 147 | 148 | if __name__ == '__main__': 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument('model_file', type=str) 151 | parser.add_argument('dataset', nargs='?', default='MAPS') 152 | parser.add_argument('dataset_group', nargs='?', default=None) 153 | parser.add_argument('--save-path', default=None) 154 | parser.add_argument('--sequence-length', default=None, type=int) 155 | parser.add_argument('--onset-threshold', default=0.5, type=float) 156 | parser.add_argument('--frame-threshold', default=0.5, type=float) 157 | parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu') 158 | parser.add_argument('--output_dir', default=None) 159 | 160 | with torch.inference_mode(): 161 | evaluate_file(**vars(parser.parse_args())) 162 | -------------------------------------------------------------------------------- /onsets_and_frames/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import * 2 | from .dataset import MAPS, MAESTRO 3 | from .decoding import extract_notes, notes_to_frames 4 | from .midi import save_midi 5 | from .transcriber import HPPNet, HPPNetDDD, HPPNetLess 6 | from .utils import summary, save_pianoroll, cycle 7 | -------------------------------------------------------------------------------- /onsets_and_frames/constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | SAMPLE_RATE = 16000 5 | HOP_LENGTH = SAMPLE_RATE * 32 // 1000 6 | ONSET_LENGTH = SAMPLE_RATE * 32 // 1000 7 | OFFSET_LENGTH = SAMPLE_RATE * 32 // 1000 8 | HOPS_IN_ONSET = ONSET_LENGTH // HOP_LENGTH 9 | HOPS_IN_OFFSET = OFFSET_LENGTH // HOP_LENGTH 10 | MIN_MIDI = 21 11 | MAX_MIDI = 108 12 | 13 | N_MELS = 229 14 | MEL_FMIN = 30 15 | MEL_FMAX = SAMPLE_RATE // 2 16 | WINDOW_LENGTH = 2048 17 | 18 | DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | -------------------------------------------------------------------------------- /onsets_and_frames/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from abc import abstractmethod 4 | from glob import glob 5 | 6 | import numpy as np 7 | import soundfile 8 | from torch.utils.data import Dataset 9 | from tqdm import tqdm 10 | 11 | from .constants import * 12 | from .midi import parse_midi 13 | 14 | import pandas as pd 15 | 16 | 17 | class PianoRollAudioDataset(Dataset): 18 | def __init__(self, path, groups=None, sequence_length=None, seed=42, device=DEFAULT_DEVICE): 19 | self.path = path 20 | self.groups = groups if groups is not None else self.available_groups() 21 | self.sequence_length = sequence_length 22 | self.device = device 23 | self.random = np.random.RandomState(seed) 24 | self.preload = 
False 25 | 26 | # If preload is True, this will contain all the data, if not it will contain the paths to the data 27 | self.data = [] 28 | print(f"Loading {len(groups)} group{'s' if len(groups) > 1 else ''} " 29 | f"of {self.__class__.__name__} at {path}") 30 | for group in groups: 31 | for input_files in tqdm(self.files(group), desc='Loading group %s' % group): 32 | self.data.append(self.load(*input_files)) 33 | 34 | def __getitem__(self, index): 35 | if self.preload: 36 | data = self.data[index] 37 | else: 38 | data = torch.load(self.data[index]) 39 | result = dict(path=data['path']) 40 | 41 | if self.sequence_length is not None: 42 | audio_length = len(data['audio']) 43 | step_begin = self.random.randint(audio_length - self.sequence_length) // HOP_LENGTH 44 | n_steps = self.sequence_length // HOP_LENGTH 45 | step_end = step_begin + n_steps 46 | 47 | begin = step_begin * HOP_LENGTH 48 | end = begin + self.sequence_length 49 | 50 | result['audio'] = data['audio'][begin:end].to(self.device) 51 | result['label'] = data['label'][step_begin:step_end, :].to(self.device) 52 | result['velocity'] = data['velocity'][step_begin:step_end, :].to(self.device) 53 | else: 54 | result['audio'] = data['audio'].to(self.device) 55 | result['label'] = data['label'].to(self.device) 56 | result['velocity'] = data['velocity'].to(self.device).float() 57 | 58 | result['audio'] = result['audio'].float().div_(32768.0) 59 | result['onset'] = (result['label'] == 3).float() 60 | result['offset'] = (result['label'] == 1).float() 61 | result['frame'] = (result['label'] > 1).float() 62 | result['velocity'] = result['velocity'].float().div_(128.0) 63 | 64 | return result 65 | 66 | def __len__(self): 67 | return len(self.data) 68 | 69 | @classmethod 70 | @abstractmethod 71 | def available_groups(cls): 72 | """return the names of all available groups""" 73 | raise NotImplementedError 74 | 75 | @abstractmethod 76 | def files(self, group): 77 | """return the list of input files (audio_filename, tsv_filename) for this group""" 78 | raise NotImplementedError 79 | 80 | def load(self, audio_path, tsv_path): 81 | """ 82 | load an audio track and the corresponding labels 83 | 84 | Returns 85 | ------- 86 | A dictionary containing the following data: 87 | 88 | path: str 89 | the path to the audio file 90 | 91 | audio: torch.ShortTensor, shape = [num_samples] 92 | the raw waveform 93 | 94 | label: torch.ByteTensor, shape = [num_steps, midi_bins] 95 | a matrix that contains the onset/offset/frame labels encoded as: 96 | 3 = onset, 2 = frames after onset, 1 = offset, 0 = all else 97 | 98 | velocity: torch.ByteTensor, shape = [num_steps, midi_bins] 99 | a matrix that contains MIDI velocity values at the frame locations 100 | """ 101 | saved_data_path = audio_path.replace('.flac', '.pt').replace('.wav', '.pt') 102 | if os.path.exists(saved_data_path): 103 | if self.preload: 104 | return torch.load(saved_data_path) 105 | else: 106 | return saved_data_path 107 | 108 | audio, sr = soundfile.read(audio_path, dtype='int16') 109 | assert sr == SAMPLE_RATE 110 | 111 | audio = torch.ShortTensor(audio) 112 | audio_length = len(audio) 113 | 114 | n_keys = MAX_MIDI - MIN_MIDI + 1 115 | n_steps = (audio_length - 1) // HOP_LENGTH + 1 116 | 117 | label = torch.zeros(n_steps, n_keys, dtype=torch.uint8) 118 | velocity = torch.zeros(n_steps, n_keys, dtype=torch.uint8) 119 | 120 | tsv_path = tsv_path 121 | midi = np.loadtxt(tsv_path, delimiter='\t', skiprows=1) 122 | 123 | for onset, offset, note, vel in midi: 124 | left = int(round(onset * SAMPLE_RATE / 
HOP_LENGTH)) 125 | onset_right = min(n_steps, left + HOPS_IN_ONSET) 126 | frame_right = int(round(offset * SAMPLE_RATE / HOP_LENGTH)) 127 | frame_right = min(n_steps, frame_right) 128 | offset_right = min(n_steps, frame_right + HOPS_IN_OFFSET) 129 | 130 | f = int(note) - MIN_MIDI 131 | label[left:onset_right, f] = 3 132 | label[onset_right:frame_right, f] = 2 133 | label[frame_right:offset_right, f] = 1 134 | velocity[left:frame_right, f] = vel 135 | 136 | data = dict(path=audio_path, audio=audio, label=label, velocity=velocity) 137 | torch.save(data, saved_data_path) 138 | if self.preload: 139 | return data 140 | else: 141 | return saved_data_path 142 | 143 | 144 | class MAESTRO(PianoRollAudioDataset): 145 | 146 | def __init__(self, path='data/MAESTRO', groups=None, sequence_length=None, seed=42, device=DEFAULT_DEVICE): 147 | super().__init__(path, groups if groups is not None else ['train'], sequence_length, seed, device) 148 | 149 | @classmethod 150 | def available_groups(cls): 151 | return ['train', 'validation', 'test'] 152 | 153 | def files(self, group): 154 | if group not in self.available_groups(): 155 | # year-based grouping 156 | flacs = sorted(glob(os.path.join(self.path, group, '*.flac'))) 157 | if len(flacs) == 0: 158 | flacs = sorted(glob(os.path.join(self.path, group, '*.wav'))) 159 | 160 | midis = sorted(glob(os.path.join(self.path, group, '*.midi'))) 161 | files = list(zip(flacs, midis)) 162 | if len(files) == 0: 163 | raise RuntimeError(f'Group {group} is empty') 164 | else: 165 | metadata = pd.read_json(os.path.join(self.path, 'maestro-v3.0.0.json')) 166 | files = metadata[metadata.split == group][['audio_filename', 'midi_filename']].values 167 | 168 | files = [(audio if os.path.exists(audio) else audio.replace('.flac', '.wav'), midi) for audio, midi in files] 169 | files = [(self.path + "/" + audio, self.path + "/" + midi) for audio, midi in files] 170 | 171 | result = [] 172 | for audio_path, midi_path in files: 173 | tsv_filename = midi_path.replace('.midi', '.tsv').replace('.mid', '.tsv') 174 | if not os.path.exists(tsv_filename): 175 | midi = parse_midi(midi_path) 176 | np.savetxt(tsv_filename, midi, fmt='%.6f', delimiter='\t', header='onset,offset,note,velocity') 177 | result.append((audio_path, tsv_filename)) 178 | return result 179 | 180 | 181 | class MAPS(PianoRollAudioDataset): 182 | def __init__(self, path='data/MAPS', groups=None, sequence_length=None, seed=42, device=DEFAULT_DEVICE): 183 | super().__init__(path, groups if groups is not None else ['ENSTDkAm', 'ENSTDkCl'], sequence_length, seed, device) 184 | 185 | @classmethod 186 | def available_groups(cls): 187 | return ['AkPnBcht', 'AkPnBsdf', 'AkPnCGdD', 'AkPnStgb', 'ENSTDkAm', 'ENSTDkCl', 'SptkBGAm', 'SptkBGCl', 'StbgTGd2'] 188 | 189 | def files(self, group): 190 | flacs = glob(os.path.join(self.path, 'flac', '*_%s.flac' % group)) 191 | tsvs = [f.replace('/flac/', '/tsv/matched/').replace('.flac', '.tsv') for f in flacs] 192 | 193 | assert(all(os.path.isfile(flac) for flac in flacs)) 194 | assert(all(os.path.isfile(tsv) for tsv in tsvs)) 195 | 196 | return sorted(zip(flacs, tsvs)) -------------------------------------------------------------------------------- /onsets_and_frames/decoding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def extract_notes(onsets, frames, velocity, onset_threshold=0.5, frame_threshold=0.5): 6 | """ 7 | Finds the note timings based on the onsets and frames information 8 | 9 | Parameters 
10 | ---------- 11 | onsets: torch.FloatTensor, shape = [frames, bins] 12 | frames: torch.FloatTensor, shape = [frames, bins] 13 | velocity: torch.FloatTensor, shape = [frames, bins] 14 | onset_threshold: float 15 | frame_threshold: float 16 | 17 | Returns 18 | ------- 19 | pitches: np.ndarray of bin_indices 20 | intervals: np.ndarray of rows containing (onset_index, offset_index) 21 | velocities: np.ndarray of velocity values 22 | """ 23 | onsets = (onsets > onset_threshold).cpu().to(torch.uint8) 24 | frames = (frames > frame_threshold).cpu().to(torch.uint8) 25 | onset_diff = torch.cat([onsets[:1, :], onsets[1:, :] - onsets[:-1, :]], dim=0) == 1 26 | 27 | pitches = [] 28 | intervals = [] 29 | velocities = [] 30 | 31 | for nonzero in onset_diff.nonzero(): 32 | frame = nonzero[0].item() 33 | pitch = nonzero[1].item() 34 | 35 | onset = frame 36 | offset = frame 37 | velocity_samples = [] 38 | 39 | while onsets[offset, pitch].item() or frames[offset, pitch].item(): 40 | if onsets[offset, pitch].item(): 41 | velocity_samples.append(velocity[offset, pitch].item()) 42 | offset += 1 43 | if offset == onsets.shape[0]: 44 | break 45 | 46 | if offset > onset: 47 | pitches.append(pitch) 48 | intervals.append([onset, offset]) 49 | velocities.append(np.mean(velocity_samples) if len(velocity_samples) > 0 else 0) 50 | 51 | return np.array(pitches), np.array(intervals), np.array(velocities) 52 | 53 | 54 | def notes_to_frames(pitches, intervals, shape): 55 | """ 56 | Takes lists specifying notes sequences and return 57 | 58 | Parameters 59 | ---------- 60 | pitches: list of pitch bin indices 61 | intervals: list of [onset, offset] ranges of bin indices 62 | shape: the shape of the original piano roll, [n_frames, n_bins] 63 | 64 | Returns 65 | ------- 66 | time: np.ndarray containing the frame indices 67 | freqs: list of np.ndarray, each containing the frequency bin indices 68 | """ 69 | roll = np.zeros(tuple(shape)) 70 | for pitch, (onset, offset) in zip(pitches, intervals): 71 | roll[onset:offset, pitch] = 1 72 | 73 | time = np.arange(roll.shape[0]) 74 | freqs = [roll[t, :].nonzero()[0] for t in time] 75 | return time, freqs 76 | -------------------------------------------------------------------------------- /onsets_and_frames/harmonic_layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv3d_layer import HarmConvBlock 2 | from .harmo_dilated import HarmonicDilatedConv 3 | -------------------------------------------------------------------------------- /onsets_and_frames/harmonic_layers/conv3d_layer.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..network_utils import CircularOctavePadding, Conv2D, Conv3D 6 | 7 | from typing import Optional, Iterable 8 | 9 | 10 | class HarmConvBlock(nn.Module): 11 | """ 12 | For our 3D harmonic convolutions 13 | """ 14 | def __init__(self, n_in_channels: int, n_out_channels: int, octave_depth: int = 3, 15 | dilation_rates: Optional[Iterable[int]] = None, time_width: int = 1, special_padding: bool = True, depthwise: bool = False, 16 | ): 17 | super(HarmConvBlock, self).__init__() 18 | if dilation_rates is None: 19 | dilation_rates = [0, 28, 16] 20 | self.dilation_rates = dilation_rates 21 | self.n_in_channels = n_in_channels 22 | self.n_out_channels = n_out_channels 23 | self.octave_depth = octave_depth 24 | self.time_width = time_width 25 | self.special_padding = special_padding 
26 | self.use_always_3d = False 27 | self.using_3d_convolutions = False 28 | self.depthwise = depthwise 29 | 30 | if self.special_padding: 31 | padding = 0 32 | padding_layer = CircularOctavePadding 33 | else: 34 | padding = "same" 35 | padding_layer = nn.Identity 36 | 37 | if self.time_width > 1 or self.use_always_3d: 38 | self.using_3d_convolutions = True 39 | 40 | module_list = [] 41 | for dl in dilation_rates: 42 | if dl == 0: 43 | kernel_size_h = 1 44 | dilation = 1 45 | else: 46 | kernel_size_h = 2 47 | dilation = dl 48 | 49 | if self.using_3d_convolutions: 50 | module_list.append( 51 | nn.Sequential( 52 | padding_layer(kernel_size=(octave_depth, kernel_size_h, time_width), pitch_class_dilation=dilation), 53 | Conv3D(n_in_channels, n_out_channels, padding_mode="circular", kernel_size=(octave_depth, kernel_size_h, time_width), padding=padding, dilation=(1, dilation, 1), depthwise=depthwise) 54 | ) 55 | ) 56 | else: 57 | module_list.append( 58 | nn.Sequential( 59 | padding_layer(kernel_size=(octave_depth, kernel_size_h, time_width), pitch_class_dilation=dilation), 60 | Conv2D(n_in_channels, n_out_channels, padding_mode="circular", kernel_size=(octave_depth, kernel_size_h), padding=padding, dilation=(1, dilation), depthwise=depthwise) 61 | ) 62 | ) 63 | 64 | self.module_list = nn.ModuleList(module_list) 65 | 66 | def forward_2d(self, x): 67 | batch, channels, octaves, pitch_classes, frames = x.size() 68 | x = x.permute([0, 4, 1, 2, 3]).reshape([batch * frames, channels, octaves, pitch_classes]) # Stack frames at batch dimension 69 | outputs = None 70 | for module in self.module_list: 71 | if outputs is None: 72 | outputs = module(x) 73 | else: 74 | outputs += module(x) 75 | # Return to the original shape 76 | outputs = outputs.reshape([batch, frames, self.n_out_channels, octaves, pitch_classes]).permute([0, 2, 3, 4, 1]) 77 | return outputs 78 | 79 | def forward_3d(self, x): 80 | outputs = None 81 | for module in self.module_list: 82 | if outputs is None: 83 | outputs = module(x) 84 | else: 85 | outputs += module(x) 86 | return outputs 87 | 88 | def forward(self, x): 89 | if self.using_3d_convolutions: 90 | output = self.forward_3d(x) 91 | else: 92 | output = self.forward_2d(x) 93 | return F.relu(output) -------------------------------------------------------------------------------- /onsets_and_frames/harmonic_layers/harmo_dilated.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ..network_utils import Conv2D 3 | import torch.nn.functional as F 4 | 5 | class HarmonicDilatedConv(nn.Module): 6 | """ 7 | From the HPPNet original code. 
It is fixed to 4bins per semitone (see dilation) 8 | """ 9 | 10 | def __init__(self, c_in, c_out, depthwise: bool = False) -> None: 11 | super(HarmonicDilatedConv, self).__init__() 12 | super().__init__() 13 | self.conv_1 = Conv2D(c_in, c_out, depthwise = depthwise, kernel_size = [3, 1], padding="same", dilation=[48, 1]) 14 | self.conv_2 = Conv2D(c_in, c_out, depthwise = depthwise, kernel_size = [3, 1], padding="same", dilation=[76, 1]) 15 | self.conv_3 = Conv2D(c_in, c_out, depthwise = depthwise, kernel_size = [3, 1], padding="same", dilation=[96, 1]) 16 | self.conv_4 = Conv2D(c_in, c_out, depthwise = depthwise, kernel_size = [3, 1], padding="same", dilation=[111, 1]) 17 | self.conv_5 = Conv2D(c_in, c_out, depthwise = depthwise, kernel_size = [3, 1], padding="same", dilation=[124, 1]) 18 | self.conv_6 = Conv2D(c_in, c_out, depthwise = depthwise, kernel_size = [3, 1], padding="same", dilation=[135, 1]) 19 | self.conv_7 = Conv2D(c_in, c_out, depthwise = depthwise, kernel_size = [3, 1], padding="same", dilation=[144, 1]) 20 | self.conv_8 = Conv2D(c_in, c_out, depthwise = depthwise, kernel_size = [3, 1], padding="same", dilation=[152, 1]) 21 | 22 | def forward(self, x): 23 | x = ( 24 | self.conv_1(x) 25 | + self.conv_2(x) 26 | + self.conv_3(x) 27 | + self.conv_4(x) 28 | + self.conv_5(x) 29 | + self.conv_6(x) 30 | + self.conv_7(x) 31 | + self.conv_8(x) 32 | ) 33 | return F.relu(x) -------------------------------------------------------------------------------- /onsets_and_frames/mel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn.functional as F 3 | from librosa.filters import mel 4 | from librosa.util import pad_center 5 | from scipy.signal import get_window 6 | from torch.autograd import Variable 7 | 8 | from .constants import * 9 | 10 | 11 | class STFT(torch.nn.Module): 12 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 13 | def __init__(self, filter_length, hop_length, win_length=None, window='hann'): 14 | super(STFT, self).__init__() 15 | if win_length is None: 16 | win_length = filter_length 17 | 18 | self.filter_length = filter_length 19 | self.hop_length = hop_length 20 | self.win_length = win_length 21 | self.window = window 22 | self.forward_transform = None 23 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 24 | 25 | cutoff = int((self.filter_length / 2 + 1)) 26 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 27 | np.imag(fourier_basis[:cutoff, :])]) 28 | 29 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 30 | 31 | if window is not None: 32 | assert(filter_length >= win_length) 33 | # get window and zero center pad it to filter_length 34 | fft_window = get_window(window, win_length, fftbins=True) 35 | fft_window = pad_center(fft_window, filter_length) 36 | fft_window = torch.from_numpy(fft_window).float() 37 | 38 | # window the bases 39 | forward_basis *= fft_window 40 | 41 | self.register_buffer('forward_basis', forward_basis.float()) 42 | 43 | def forward(self, input_data): 44 | num_batches = input_data.size(0) 45 | num_samples = input_data.size(1) 46 | 47 | # similar to librosa, reflect-pad the input 48 | input_data = input_data.view(num_batches, 1, num_samples) 49 | input_data = F.pad( 50 | input_data.unsqueeze(1), 51 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 52 | mode='reflect') 53 | input_data = input_data.squeeze(1) 54 | 55 | forward_transform = F.conv1d( 56 | input_data, 57 | 
Variable(self.forward_basis, requires_grad=False), 58 | stride=self.hop_length, 59 | padding=0) 60 | 61 | cutoff = int((self.filter_length / 2) + 1) 62 | real_part = forward_transform[:, :cutoff, :] 63 | imag_part = forward_transform[:, cutoff:, :] 64 | 65 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 66 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 67 | 68 | return magnitude, phase 69 | 70 | 71 | class MelSpectrogram(torch.nn.Module): 72 | def __init__(self, n_mels, sample_rate, filter_length, hop_length, 73 | win_length=None, mel_fmin=0.0, mel_fmax=None): 74 | super(MelSpectrogram, self).__init__() 75 | self.stft = STFT(filter_length, hop_length, win_length) 76 | 77 | mel_basis = mel(sample_rate, filter_length, n_mels, mel_fmin, mel_fmax, htk=True) 78 | mel_basis = torch.from_numpy(mel_basis).float() 79 | self.register_buffer('mel_basis', mel_basis) 80 | 81 | def forward(self, y): 82 | """Computes mel-spectrograms from a batch of waves 83 | PARAMS 84 | ------ 85 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 86 | RETURNS 87 | ------- 88 | mel_output: torch.FloatTensor of shape (B, T, n_mels) 89 | """ 90 | assert(torch.min(y.data) >= -1) 91 | assert(torch.max(y.data) <= 1) 92 | 93 | magnitudes, phases = self.stft(y) 94 | magnitudes = magnitudes.data 95 | mel_output = torch.matmul(self.mel_basis, magnitudes) 96 | mel_output = torch.log(torch.clamp(mel_output, min=1e-5)) 97 | return mel_output 98 | 99 | 100 | # the default melspectrogram converter across the project 101 | melspectrogram = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, HOP_LENGTH, mel_fmin=MEL_FMIN, mel_fmax=MEL_FMAX) 102 | melspectrogram.to(DEFAULT_DEVICE) 103 | -------------------------------------------------------------------------------- /onsets_and_frames/midi.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import sys 3 | 4 | import mido 5 | import numpy as np 6 | from joblib import Parallel, delayed 7 | from mido import Message, MidiFile, MidiTrack 8 | from mir_eval.util import hz_to_midi 9 | from tqdm import tqdm 10 | 11 | 12 | def parse_midi(path): 13 | """open midi file and return np.array of (onset, offset, note, velocity) rows""" 14 | midi = mido.MidiFile(path) 15 | 16 | time = 0 17 | sustain = False 18 | events = [] 19 | for message in midi: 20 | time += message.time 21 | 22 | if message.type == 'control_change' and message.control == 64 and (message.value >= 64) != sustain: 23 | # sustain pedal state has just changed 24 | sustain = message.value >= 64 25 | event_type = 'sustain_on' if sustain else 'sustain_off' 26 | event = dict(index=len(events), time=time, type=event_type, note=None, velocity=0) 27 | events.append(event) 28 | 29 | if 'note' in message.type: 30 | # MIDI offsets can be either 'note_off' events or 'note_on' with zero velocity 31 | velocity = message.velocity if message.type == 'note_on' else 0 32 | event = dict(index=len(events), time=time, type='note', note=message.note, velocity=velocity, sustain=sustain) 33 | events.append(event) 34 | 35 | notes = [] 36 | for i, onset in enumerate(events): 37 | if onset['velocity'] == 0: 38 | continue 39 | 40 | # find the next note_off message 41 | offset = next(n for n in events[i + 1:] if n['note'] == onset['note'] or n is events[-1]) 42 | 43 | if offset['sustain'] and offset is not events[-1]: 44 | # if the sustain pedal is active at offset, find when the sustain ends 45 | offset = next(n for n in events[offset['index'] + 
1:] 46 | if n['type'] == 'sustain_off' or n['note'] == onset['note'] or n is events[-1]) 47 | 48 | note = (onset['time'], offset['time'], onset['note'], onset['velocity']) 49 | notes.append(note) 50 | 51 | return np.array(notes) 52 | 53 | 54 | def save_midi(path, pitches, intervals, velocities): 55 | """ 56 | Save extracted notes as a MIDI file 57 | Parameters 58 | ---------- 59 | path: the path to save the MIDI file 60 | pitches: np.ndarray of bin_indices 61 | intervals: list of (onset_index, offset_index) 62 | velocities: list of velocity values 63 | """ 64 | file = MidiFile() 65 | track = MidiTrack() 66 | file.tracks.append(track) 67 | ticks_per_second = file.ticks_per_beat * 2.0 68 | 69 | events = [] 70 | for i in range(len(pitches)): 71 | events.append(dict(type='on', pitch=pitches[i], time=intervals[i][0], velocity=velocities[i])) 72 | events.append(dict(type='off', pitch=pitches[i], time=intervals[i][1], velocity=velocities[i])) 73 | events.sort(key=lambda row: row['time']) 74 | 75 | last_tick = 0 76 | for event in events: 77 | current_tick = int(event['time'] * ticks_per_second) 78 | velocity = int(event['velocity'] * 127) 79 | if velocity > 127: 80 | velocity = 127 81 | pitch = int(round(hz_to_midi(event['pitch']))) 82 | track.append(Message('note_' + event['type'], note=pitch, velocity=velocity, time=current_tick - last_tick)) 83 | last_tick = current_tick 84 | 85 | file.save(path) 86 | 87 | 88 | if __name__ == '__main__': 89 | 90 | def process(input_file, output_file): 91 | midi_data = parse_midi(input_file) 92 | np.savetxt(output_file, midi_data, '%.6f', '\t', header='onset\toffset\tnote\tvelocity') 93 | 94 | 95 | def files(): 96 | for input_file in tqdm(sys.argv[1:]): 97 | if input_file.endswith('.mid'): 98 | output_file = input_file[:-4] + '.tsv' 99 | elif input_file.endswith('.midi'): 100 | output_file = input_file[:-5] + '.tsv' 101 | else: 102 | print('ignoring non-MIDI file %s' % input_file, file=sys.stderr) 103 | continue 104 | 105 | yield (input_file, output_file) 106 | 107 | Parallel(n_jobs=multiprocessing.cpu_count())(delayed(process)(in_file, out_file) for in_file, out_file in files()) 108 | -------------------------------------------------------------------------------- /onsets_and_frames/network_utils.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import Iterable, Optional, Tuple 5 | 6 | 7 | def compute_hopsize_cqt(fs_cqt_target, sr=16000): 8 | """ 9 | Computes the necessary CQT hopsize to approximate a desired feature rate fs_cqt_target 10 | Args: 11 | fs_cqt_target: desired frame rate in Hz 12 | fs: audio sampling rate 13 | num_octaves: number of octaves for the CQT 14 | Returns: 15 | hopsize_cqt: CQT hopsize in samples 16 | fs_cqt: resulting CQT frame rate in Hz 17 | """ 18 | hopsize_target = sr // fs_cqt_target 19 | return hopsize_target 20 | 21 | 22 | class CircularOctavePadding(nn.Module): 23 | def __init__( 24 | self, kernel_size: Tuple[int], pitch_class_dilation: int, strides=Optional[None] 25 | ) -> None: 26 | super(CircularOctavePadding, self).__init__() 27 | self.kernel_size = kernel_size 28 | self.pitch_class_dilation = pitch_class_dilation 29 | self.strides = strides # Not implemented 30 | self.dummy_padding = nn.ConstantPad1d(1, 0) 31 | self.pitch_class_required_padding = ( 32 | 0 if kernel_size[1] == 1 else self.pitch_class_dilation 33 | ) 34 | 35 | def forward(self, x): 36 | try: # Full 3D convolution 37 | batch, 
channels, octaves, pitch_classes, frames = x.size() 38 | pitch_class_padding = x[:, :, :, :self.pitch_class_required_padding, :].roll(-1, dims=2) 39 | pitch_class_padding[:, :, -1, :, :] = 0 40 | octave_padding = t.zeros( 41 | (batch, channels, self.kernel_size[0]-1, pitch_classes + self.pitch_class_required_padding, frames), 42 | device=x.device 43 | ) 44 | 45 | if self.pitch_class_required_padding > 0: 46 | padded_x = t.concat([x, pitch_class_padding], dim=-2) 47 | padded_x = t.concat([padded_x, octave_padding], dim=-3) 48 | else: 49 | padded_x = t.concat([x, octave_padding], dim=-3) 50 | except: # 2D trick 51 | batch, channels, octaves, pitch_classes = x.size() 52 | pitch_class_padding = x[:, :, :, :self.pitch_class_required_padding].roll(-1, dims=2) 53 | pitch_class_padding[:, :, -1, :] = 0 54 | octave_padding = t.zeros( 55 | (batch, channels, self.kernel_size[0]-1, pitch_classes + self.pitch_class_required_padding), 56 | device=x.device 57 | ) 58 | 59 | if self.pitch_class_required_padding > 0: 60 | padded_x = t.concat([x, pitch_class_padding], dim=-1) 61 | padded_x = t.concat([padded_x, octave_padding], dim=-2) 62 | else: 63 | padded_x = t.concat([x, octave_padding], dim=-2) 64 | return padded_x 65 | 66 | 67 | class MultiRateConv(nn.Module): 68 | """ 69 | For HarmoF0 70 | """ 71 | def __init__(self, n_in_channels, n_out_channels, dilation_rates: Optional[Iterable[int]]): 72 | super(MultiRateConv, self).__init__() 73 | if dilation_rates is None: 74 | dilations_rates = [0, 28, 76] 75 | self.dilation_rates = dilation_rates 76 | self.n_in_channels = n_in_channels 77 | self.n_out_channels = n_out_channels 78 | 79 | module_list = [] 80 | for _ in dilations_rates: 81 | module_list.append( 82 | nn.Conv2d(n_in_channels, n_out_channels, padding_mode="circular", kernel_size=(3, 1), padding="same", dilation=(1, 1)) 83 | ) 84 | self.module_list = nn.ModuleList(module_list) 85 | 86 | def forward(self, x): 87 | outputs = [module(x) for module in self.module_list] 88 | for idx, shift in enumerate(self.dilation_rates): 89 | outputs[idx] = t.roll(outputs[idx], -shift, dims=1) 90 | if shift > 0: 91 | outputs[idx][:, :, -shift:, :] = 0 92 | return t.stack(outputs, dim=1).sum(dim=1) 93 | 94 | 95 | class From2Dto3D(nn.Module): 96 | def __init__(self, bins_per_octave: int, n_octaves: int): 97 | super(From2Dto3D, self).__init__() 98 | self.bins_per_octave = bins_per_octave 99 | self.n_octaves = n_octaves 100 | self.total_bins = int(self.n_octaves * self.bins_per_octave) 101 | 102 | def forward(self, cqt): 103 | padding_needed = self.total_bins - cqt.shape[2] 104 | cqt = F.pad(cqt, (0, 0, 0, padding_needed)) 105 | batch, channels, bins, frames = cqt.size() 106 | octave_pc_spectrum = cqt.reshape([batch, channels, self.n_octaves, self.bins_per_octave, frames]) 107 | return octave_pc_spectrum 108 | 109 | 110 | class From3Dto2D(nn.Module): 111 | def __init__(self, bins_per_octave: int, n_octaves: int): 112 | super(From3Dto2D, self).__init__() 113 | self.bins_per_octave = bins_per_octave 114 | self.n_octaves = n_octaves 115 | self.total_bins = int(self.n_octaves * self.bins_per_octave) 116 | 117 | def forward(self, octave_pc_spectrum): 118 | batch, channels, octaves, pitch_classes, frames = octave_pc_spectrum.size() 119 | cqt = octave_pc_spectrum.reshape([batch, channels, self.total_bins, frames]) 120 | return cqt 121 | 122 | 123 | class FGLSTM(nn.Module): 124 | """ 125 | From hhpnet code 126 | """ 127 | 128 | def __init__(self, channel_in, channel_out, lstm_size) -> None: 129 | super().__init__() 130 | 131 | 
self.channel_out = channel_out 132 | 133 | self.lstm = BiLSTM(channel_in, lstm_size // 2) 134 | self.linear = nn.Linear(lstm_size, channel_out) 135 | 136 | def forward(self, x): 137 | # inputs: [b x c_in x freq x T] 138 | # outputs: [b x c_out x T x freq] 139 | 140 | b, c_in, n_freq, frames = x.size() 141 | 142 | # => [b x freq x T x c_in] 143 | x = t.permute(x, [0, 3, 2, 1]) 144 | 145 | # => [(b*freq) x T x c_in] 146 | x = x.reshape([b * n_freq, frames, c_in]) 147 | # => [(b*freq) x T x lstm_size] 148 | x = self.lstm(x) 149 | # => [(b*freq) x T x c_out] 150 | x = self.linear(x) 151 | # => [b x freq x T x c_out] 152 | x = x.reshape([b, n_freq, frames, self.channel_out]) 153 | # => [b x c_out x T x freq] 154 | x = t.permute(x, [0, 3, 2, 1]) 155 | return x.squeeze(1) 156 | 157 | 158 | class BiLSTM(nn.Module): 159 | inference_chunk_length = 512 160 | 161 | def __init__(self, input_features, recurrent_features): 162 | super().__init__() 163 | self.rnn = nn.LSTM(input_features, recurrent_features, batch_first=True, bidirectional=True) 164 | 165 | def forward(self, x): 166 | if self.training: 167 | return self.rnn(x)[0] 168 | else: 169 | # evaluation mode: support for longer sequences that do not fit in memory 170 | batch_size, sequence_length, input_features = x.shape 171 | hidden_size = self.rnn.hidden_size 172 | num_directions = 2 if self.rnn.bidirectional else 1 173 | 174 | h = t.zeros(num_directions, batch_size, hidden_size, device=x.device) 175 | c = t.zeros(num_directions, batch_size, hidden_size, device=x.device) 176 | output = t.zeros(batch_size, sequence_length, num_directions * hidden_size, device=x.device) 177 | 178 | # forward direction 179 | slices = range(0, sequence_length, self.inference_chunk_length) 180 | for start in slices: 181 | end = start + self.inference_chunk_length 182 | output[:, start:end, :], (h, c) = self.rnn(x[:, start:end, :], (h, c)) 183 | 184 | # reverse direction 185 | if self.rnn.bidirectional: 186 | # h.zero_() 187 | # c.zero_() 188 | # ONNX does not support tensor.zero_(), so use following: 189 | h.fill_(0) 190 | c.fill_(0) 191 | 192 | for start in reversed(slices): 193 | end = start + self.inference_chunk_length 194 | result, (h, c) = self.rnn(x[:, start:end, :], (h, c)) 195 | output[:, start:end, hidden_size:] = result[:, :, hidden_size:] 196 | 197 | return output 198 | 199 | 200 | class Conv2D(nn.Module): 201 | def __init__(self, in_channels: int, out_channels: int, depthwise: bool = False, **kwargs) -> None: 202 | super(Conv2D, self).__init__() 203 | self.depthwise = depthwise 204 | self.in_channels = in_channels 205 | self.out_channels = out_channels 206 | self.out_in_ratio = None 207 | self.groups = 1 208 | 209 | if depthwise: 210 | self.n_groups = in_channels 211 | self.out_in_ratio = in_channels/out_channels 212 | self.block = nn.Sequential( 213 | nn.Conv2d(in_channels=in_channels, out_channels=in_channels, groups=self.groups, **kwargs), 214 | nn.ReLU(), 215 | nn.InstanceNorm2d(in_channels), 216 | nn.Conv2d(in_channels, out_channels, groups=self.groups, kernel_size=(1, 1)) 217 | ) 218 | else: 219 | self.block = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, groups=self.groups, **kwargs) 220 | 221 | def forward(self, x): 222 | return self.block(x) 223 | 224 | 225 | class Conv3D(nn.Module): 226 | def __init__(self, in_channels: int, out_channels: int, depthwise: bool = False, **kwargs) -> None: 227 | super(Conv2D, self).__init__() 228 | self.depthwise = depthwise 229 | self.in_channels = in_channels 230 | self.out_channels = 
out_channels 231 | self.out_in_ratio = None 232 | self.groups = 1 233 | 234 | if depthwise: 235 | self.n_groups = in_channels 236 | self.out_in_ratio = in_channels/out_channels 237 | self.block = nn.Sequential( 238 | nn.Conv3d(in_channels=in_channels, out_channels=in_channels, groups=self.groups, **kwargs), 239 | nn.ReLU(), 240 | nn.InstanceNorm2d(in_channels), 241 | nn.Conv3d(in_channels, out_channels, groups=self.groups, kernel_size=(1, 1, 1)) 242 | ) 243 | else: 244 | self.block = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, groups=self.groups, **kwargs) 245 | 246 | def forward(self, x): 247 | return self.block(x) -------------------------------------------------------------------------------- /onsets_and_frames/transcriber.py: -------------------------------------------------------------------------------- 1 | """ 2 | A rough translation of Magenta's Onsets and Frames implementation [1]. 3 | 4 | [1] https://github.com/tensorflow/magenta/blob/master/magenta/models/onsets_frames_transcription/model.py 5 | """ 6 | 7 | import torch as t 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | from nnAudio.features import CQT 12 | 13 | from .network_utils import From3Dto2D, From2Dto3D, FGLSTM 14 | from .harmonic_layers import HarmonicDilatedConv, HarmConvBlock 15 | from .constants import HOP_LENGTH 16 | 17 | from typing import Tuple 18 | 19 | 20 | class HPPNet(nn.Module): 21 | def __init__(self, bins_per_octave: int = 48, desired_frame_rate: int = 50, sr: int = 16000, n_dilated_conv_layers: int = 3, 22 | convblock_length: int = 3, add_dilated_convblock: bool = True, post_dilated_convblock_length: int = 3, channel_sizes: Tuple[int, int, int] = [16, 128, 128]): 23 | super(HPPNet, self).__init__() 24 | self.sr = sr 25 | self.num_octaves = 8 26 | self.bins_per_octave = bins_per_octave 27 | self.n_bins_in = self.bins_per_octave * self.num_octaves 28 | self.n_dilated_conv_layers = n_dilated_conv_layers 29 | self.convblock_length = convblock_length 30 | self.add_dilated_convblock = add_dilated_convblock 31 | self.post_dilated_convblock_length = post_dilated_convblock_length 32 | self.channel_sizes = channel_sizes 33 | 34 | # CQT extractor 35 | self.hopsize = HOP_LENGTH 36 | self.feature_rate = int(self.sr/self.hopsize) 37 | self.cqt_layer = CQT(bins_per_octave=self.bins_per_octave, n_bins=352, hop_length=self.hopsize, sr=sr, 38 | pad_mode="constant", center=False, trainable=False, verbose=False) 39 | 40 | convblock_0 = nn.Sequential( 41 | nn.Conv2d(kernel_size=(7, 7), in_channels=1, out_channels=self.channel_sizes[0], groups=1, padding="same"), 42 | nn.ReLU(), 43 | nn.InstanceNorm2d(self.channel_sizes[0]), 44 | ) 45 | 46 | convblock = nn.Sequential( 47 | nn.Conv2d(kernel_size=(7, 7), in_channels=self.channel_sizes[0], out_channels=self.channel_sizes[0], groups=1, padding="same"), 48 | nn.ReLU(), 49 | nn.InstanceNorm2d(self.channel_sizes[0]), 50 | ) 51 | 52 | self.convblock = nn.Sequential() 53 | for i, _ in enumerate(range(self.convblock_length)): 54 | if i == 0: 55 | self.convblock = self.convblock.append(convblock_0) 56 | else: 57 | self.convblock = self.convblock.append(convblock) 58 | 59 | self.harmonic_block = HarmonicDilatedConv(c_in=self.channel_sizes[0], c_out=self.channel_sizes[1]) 60 | 61 | dilated_convblock_base = nn.Sequential( 62 | nn.Conv2d(kernel_size=(3, 1), in_channels=self.channel_sizes[1], out_channels=self.channel_sizes[2], groups=1, padding="same", 63 | dilation=(48, 1)), 64 | nn.ReLU(), 65 | nn.MaxPool2d(kernel_size=(4, 1)), 66 | 
nn.InstanceNorm2d(self.channel_sizes[2]), 67 | 68 | nn.Conv2d(kernel_size=(3, 1), in_channels=self.channel_sizes[2], out_channels=self.channel_sizes[2], groups=1, padding="same", 69 | dilation=(12, 1)), 70 | nn.ReLU(), 71 | nn.InstanceNorm2d(self.channel_sizes[2]), 72 | ) 73 | 74 | dilated_convblock_extension = nn.Sequential( 75 | nn.Conv2d(kernel_size=(1, 5), in_channels=self.channel_sizes[2], out_channels=self.channel_sizes[2], groups=1, padding="same"), 76 | nn.ReLU(), 77 | nn.InstanceNorm2d(self.channel_sizes[2]), 78 | ) 79 | 80 | dilated_convblock = dilated_convblock_base if self.add_dilated_convblock else nn.Sequential(nn.MaxPool2d(kernel_size=(4, 1))) 81 | 82 | for _ in range(self.post_dilated_convblock_length): 83 | dilated_convblock = dilated_convblock.append(dilated_convblock_extension) 84 | 85 | self.dilated_convblock = dilated_convblock 86 | 87 | self.fglstm_frames = FGLSTM(channel_in=self.channel_sizes[2], channel_out=1, lstm_size=self.channel_sizes[2]) 88 | self.fglstm_onsets = FGLSTM(channel_in=self.channel_sizes[2], channel_out=1, lstm_size=self.channel_sizes[2]) 89 | self.fglstm_offsets = FGLSTM(channel_in=self.channel_sizes[2], channel_out=1, lstm_size=self.channel_sizes[2]) 90 | self.fglstm_velocities = FGLSTM(channel_in=self.channel_sizes[2], channel_out=1, lstm_size=self.channel_sizes[2]) 91 | 92 | def obtain_cqt(self, x): 93 | kernel_size = self.cqt_layer.cqt_kernels_imag.shape[-1] 94 | x = F.pad(input=x, pad=(kernel_size//2, kernel_size//2), mode="constant", value=0) 95 | cqt = self.cqt_layer(x) 96 | return t.log10(cqt + 1) 97 | 98 | def neural_processing(self, cqt): 99 | output = cqt.unsqueeze(1) 100 | output = self.convblock(output) 101 | output = self.harmonic_block(output) 102 | output = self.dilated_convblock(output) 103 | frames = self.fglstm_frames(output) 104 | onsets = self.fglstm_onsets(output) 105 | offsets = self.fglstm_offsets(output) 106 | velocities = self.fglstm_velocities(output) 107 | return frames, onsets, offsets, velocities 108 | 109 | def forward(self, x): 110 | cqt = self.obtain_cqt(x) 111 | frames, onsets, offsets, velocities = self.neural_processing(cqt) 112 | 113 | frames = t.sigmoid(frames) 114 | onsets = t.sigmoid(onsets) 115 | offsets = t.sigmoid(offsets) 116 | velocities = t.sigmoid(velocities) 117 | return frames, onsets, offsets, velocities 118 | 119 | def run_on_batch(self, batch): 120 | audio_label = batch['audio'] 121 | onset_label = batch['onset'] 122 | offset_label = batch['offset'] 123 | frame_label = batch['frame'] 124 | velocity_label = batch['velocity'] 125 | 126 | frames = onset_label.size()[-2] 127 | frame_pred, onset_pred, offset_pred, velocity_pred = self.forward(audio_label) 128 | frame_pred, onset_pred, offset_pred, velocity_pred = frame_pred[..., :frames, :], onset_pred[..., :frames, :], offset_pred[..., :frames, :], velocity_pred[..., :frames, :] 129 | 130 | predictions = { 131 | 'onset': onset_pred.reshape(*onset_label.shape), 132 | 'offset': offset_pred.reshape(*offset_label.shape), 133 | 'frame': frame_pred.reshape(*frame_label.shape), 134 | 'velocity': velocity_pred.reshape(*velocity_label.shape) 135 | } 136 | 137 | losses = { 138 | 'loss/onset': (-2 * onset_label * t.log(predictions["onset"]) - (1 - onset_label) * t.log(1-predictions["onset"])).mean(), 139 | 'loss/offset': F.binary_cross_entropy(predictions['offset'], offset_label), 140 | 'loss/frame': F.binary_cross_entropy(predictions['frame'], frame_label), 141 | 'loss/velocity': self.velocity_loss(predictions['velocity'], velocity_label, onset_label) 142 | } 
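        # The 'loss/onset' entry above is a hand-written binary cross-entropy whose positive (onset) term is
        # weighted by a factor of 2, emphasizing the sparse onset frames relative to the background; the offset
        # and frame heads use the standard unweighted BCE, and the velocity head uses an onset-masked squared error.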
143 | 144 | return predictions, losses 145 | 146 | def velocity_loss(self, velocity_pred, velocity_label, onset_label): 147 | denominator = onset_label.sum() 148 | if denominator.item() == 0: 149 | return denominator 150 | else: 151 | return (onset_label * (velocity_label - velocity_pred) ** 2).sum() / denominator 152 | 153 | 154 | 155 | class HPPNetLess(HPPNet): 156 | """ 157 | HPPNet without harmonic knowledge 158 | """ 159 | def __init__(self, **kwargs): 160 | super(HPPNetLess, self).__init__(**kwargs) 161 | self.harmonic_block = nn.Conv2d(kernel_size=(3, 1), in_channels=16, out_channels=128, groups=1, dilation=(1, 1), padding="same") 162 | 163 | ############ 3D Models 164 | 165 | class HPPNetDDD(HPPNet): 166 | def __init__(self, **kwargs): 167 | super(HPPNetDDD, self).__init__(**kwargs) 168 | 169 | self.cqt_layer = CQT(bins_per_octave=self.bins_per_octave, fmax=self.sr//2, hop_length=self.hopsize, sr=self.sr, 170 | pad_mode="constant", center=False, trainable=False, verbose=False) 171 | 172 | self.harmonic_block = nn.Sequential( 173 | From2Dto3D(bins_per_octave=self.bins_per_octave, n_octaves=self.num_octaves), 174 | HarmConvBlock(n_in_channels=self.channel_sizes[0], n_out_channels=self.channel_sizes[1], octave_depth=3, dilation_rates=[4*0, 4*7, 4*4, 4*10]), 175 | From3Dto2D(bins_per_octave=self.bins_per_octave, n_octaves=self.num_octaves), 176 | ) 177 | 178 | def neural_processing(self, cqt): 179 | output = cqt.unsqueeze(1) 180 | output = self.convblock(output) 181 | output = self.harmonic_block(output) 182 | output = self.dilated_convblock(output) 183 | output = output[:, :, 4:-4, :] # Output 3rd is dimension 96, should be 88 for piano 184 | frames = self.fglstm_frames(output) 185 | onsets = self.fglstm_onsets(output) 186 | offsets = self.fglstm_offsets(output) 187 | velocities = self.fglstm_velocities(output) 188 | return frames, onsets, offsets, velocities 189 | -------------------------------------------------------------------------------- /onsets_and_frames/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from functools import reduce 3 | 4 | import torch 5 | from PIL import Image 6 | from torch.nn.modules.module import _addindent 7 | 8 | 9 | def cycle(iterable): 10 | while True: 11 | for item in iterable: 12 | yield item 13 | 14 | 15 | def summary(model, file=sys.stdout): 16 | def repr(model): 17 | # We treat the extra repr like the sub-module, one item per line 18 | extra_lines = [] 19 | extra_repr = model.extra_repr() 20 | # empty string will be split into list [''] 21 | if extra_repr: 22 | extra_lines = extra_repr.split('\n') 23 | child_lines = [] 24 | total_params = 0 25 | for key, module in model._modules.items(): 26 | mod_str, num_params = repr(module) 27 | mod_str = _addindent(mod_str, 2) 28 | child_lines.append('(' + key + '): ' + mod_str) 29 | total_params += num_params 30 | lines = extra_lines + child_lines 31 | 32 | for name, p in model._parameters.items(): 33 | if hasattr(p, 'shape'): 34 | total_params += reduce(lambda x, y: x * y, p.shape) 35 | 36 | main_str = model._get_name() + '(' 37 | if lines: 38 | # simple one-liner info, which most builtin Modules will use 39 | if len(extra_lines) == 1 and not child_lines: 40 | main_str += extra_lines[0] 41 | else: 42 | main_str += '\n ' + '\n '.join(lines) + '\n' 43 | 44 | main_str += ')' 45 | if file is sys.stdout: 46 | main_str += ', \033[92m{:,}\033[0m params'.format(total_params) 47 | else: 48 | main_str += ', {:,} params'.format(total_params) 49 | return 
main_str, total_params 50 | 51 | string, count = repr(model) 52 | if file is not None: 53 | if isinstance(file, str): 54 | file = open(file, 'w') 55 | print(string, file=file) 56 | file.flush() 57 | 58 | return count 59 | 60 | 61 | def save_pianoroll(path, onsets, frames, onset_threshold=0.5, frame_threshold=0.5, zoom=4): 62 | """ 63 | Saves a piano roll diagram 64 | 65 | Parameters 66 | ---------- 67 | path: str 68 | onsets: torch.FloatTensor, shape = [frames, bins] 69 | frames: torch.FloatTensor, shape = [frames, bins] 70 | onset_threshold: float 71 | frame_threshold: float 72 | zoom: int 73 | """ 74 | onsets = (1 - (onsets.t() > onset_threshold).to(torch.uint8)).cpu() 75 | frames = (1 - (frames.t() > frame_threshold).to(torch.uint8)).cpu() 76 | both = (1 - (1 - onsets) * (1 - frames)) 77 | image = torch.stack([onsets, frames, both], dim=2).flip(0).mul(255).numpy() 78 | image = Image.fromarray(image, 'RGB') 79 | image = image.resize((image.size[0], image.size[1] * zoom)) 80 | image.save(path) 81 | -------------------------------------------------------------------------------- /prepare_maestro.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # set -e 3 | 4 | USAGE="Usage: $0 -s [-m /dev/null 32 | then 33 | echo "ffmpeg could not be found" 34 | if [[ "$OSTYPE" == "linux-gnu"* ]]; then 35 | echo "Installing ffmpeg ..." 36 | sudo apt-get install ffmpeg 37 | elif [[ "$OSTYPE" == "darwin"* ]]; then 38 | echo "Installing ffmpeg ..." 39 | brew install ffmpeg 40 | else 41 | echo "Please install ffmpeg manually." 42 | exit 1 43 | fi 44 | fi 45 | 46 | # Get the directory where the data will be stored, if not specified, throw an error. 47 | if [ -z "$DATA_DIR" ]; then 48 | echo "Please specify the directory where the data will be downloaded." 49 | echo $USAGE 50 | exit 1 51 | fi 52 | 53 | # Check if the directory exists already 54 | if [ -d "$DATA_DIR/MAESTRO" && $download_dataset = "False" ]; then 55 | echo "The directory MAESTRO in $DATA_DIR already exists." 56 | echo "Please specify a new directory." 57 | exit 1 58 | else 59 | mkdir $DATA_DIR 60 | fi 61 | 62 | 63 | # Create a symbolic link to the data directory 64 | ln -s $DATA_DIR ./data 65 | echo "[WARNING] ./data/ is now a symbolic link to $DATA_DIR. If this directory is an external drive it might change its location in the future." 66 | echo "If you have problems with the data in the future, please check if the symbolic link is still valid." 67 | 68 | 69 | 70 | # If the variable download_dataset is True, Download the data inside the symbolic link 71 | if [ $download_dataset = "True" ]; then 72 | echo "Downloading the MAESTRO dataset" 73 | curl -O "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.zip" -o ./data/maestro-v3.0.0.zip 74 | 75 | echo "Extracting the files ..." 76 | unzip -o ./data/maestro-v3.0.0.zip | awk 'BEGIN{ORS=""} {print "\rExtracting " NR "/2383 ..."; system("")} END {print "\ndone\n"}' 77 | 78 | rm ./data/maestro-v3.0.0.zip 79 | mv ./data/maestro-v3.0.0 ./data/MAESTRO 80 | fi 81 | 82 | echo Converting the audio files to FLAC ... 83 | COUNTER=0 84 | for f in ./data/MAESTRO/*/*.wav; do 85 | COUNTER=$((COUNTER + 1)) 86 | echo -ne "\rConverting ($COUNTER/1184) ..." 87 | ffmpeg -y -loglevel fatal -i $f -ac 1 -ar 16000 ${f/\.wav/.flac} 88 | rm $f 89 | done 90 | 91 | echo 92 | echo Preparation complete! 
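# Example invocation (the path is illustrative; see the "Downloading the Datasets" section of the README):
#   ./prepare_maestro.sh -s /path/to/large/storage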
93 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.0.0 2 | nnAudio>=0.3.2 3 | scipy>=1.1.0 4 | torch>=1.10.0 5 | torchaudio>=0.10.0 6 | torchmetrics==0.10.1 7 | SoundFile>=0.10.2 8 | sacred>=0.8.3 9 | librosa>=0.6.2 10 | numpy>=1.20.0,<1.24 11 | tqdm>=4.64.1 12 | git+https://github.com/craffel/mir_eval.git 13 | mido>=1.2.9 14 | Pillow>=6.2.0 15 | tensorboard>=2.10 16 | pandas>=1.2 17 | scikit-learn>=1.0 18 | resampy>=0.2.2 19 | soxr -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | import numpy as np 5 | from sacred import Experiment 6 | from sacred.commands import print_config 7 | from sacred.observers import FileStorageObserver 8 | from torch.nn.utils import clip_grad_norm_ 9 | from torch.optim.lr_scheduler import StepLR 10 | from torch.utils.data import DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | from tqdm import tqdm 13 | 14 | from evaluate import evaluate 15 | from onsets_and_frames import * 16 | import onsets_and_frames.transcriber as nnmodels 17 | 18 | ex = Experiment('train_transcriber') 19 | 20 | 21 | class EarlyStopper: 22 | # From https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch. 23 | # Adapted for the F1 score: training should stop when it is no longer increasing. 24 | # The original is designed for losses and stops when the loss no longer decreases. 25 | def __init__(self, patience=10, min_delta=0): 26 | self.patience = patience 27 | self.min_delta = min_delta 28 | self.counter = 0 29 | self.max_validation_f1 = 0 30 | 31 | def early_stop(self, f1): 32 | if f1 > self.max_validation_f1: 33 | self.max_validation_f1 = f1 34 | self.counter = 0 35 | elif f1 < (self.max_validation_f1 + self.min_delta): 36 | self.counter += 1 37 | if self.counter >= self.patience: 38 | return True 39 | return False 40 | 41 | 42 | @ex.config 43 | def config(): 44 | logdir = 'runs/transcriber-' + datetime.now().strftime('%y%m%d-%H%M%S') 45 | device = 'cpu' # DEFAULT_DEVICE 46 | iterations = 500000 47 | resume_iteration = None 48 | checkpoint_interval = 1000 49 | train_on = 'MAESTRO' 50 | model = "HPPNetDDD" 51 | preload_dataset = True 52 | 53 | batch_size = 4 54 | sequence_length = SAMPLE_RATE*5 55 | 56 | if torch.cuda.is_available() and torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory < 10e9: 57 | batch_size //= 2 58 | sequence_length //= 2 59 | print(f'Reducing batch size to {batch_size} and sequence_length to {sequence_length} to save memory') 60 | 61 | learning_rate = 0.0006 62 | learning_rate_decay_steps = 10000 63 | learning_rate_decay_rate = 0.98 64 | 65 | leave_one_out = None 66 | 67 | clip_gradient_norm = None 68 | 69 | convblock_length = 3 70 | add_dilated_convblock = True 71 | validation_length = sequence_length 72 | validation_interval = 500 73 | n_dilated_conv_layers = 3 74 | 75 | ex.observers.append(FileStorageObserver.create(logdir)) 76 | 77 | 78 | @ex.automain 79 | def train(logdir, device, iterations, resume_iteration, checkpoint_interval, train_on, batch_size, sequence_length, preload_dataset, 80 | model, learning_rate, learning_rate_decay_steps, learning_rate_decay_rate, leave_one_out, 81 | clip_gradient_norm, validation_length, validation_interval, n_dilated_conv_layers, convblock_length, 
add_dilated_convblock): 82 | print_config(ex.current_run) 83 | 84 | os.makedirs(logdir, exist_ok=True) 85 | writer = SummaryWriter(logdir) 86 | early_stopper = EarlyStopper(patience=15, min_delta=0.0005) 87 | stop_training = False 88 | train_groups, validation_groups = ['train'], ['validation'] 89 | 90 | if leave_one_out is not None: 91 | all_years = {'2004', '2006', '2008', '2009', '2011', '2013', '2014', '2015', '2017'} 92 | train_groups = list(all_years - {str(leave_one_out)}) 93 | validation_groups = [str(leave_one_out)] 94 | 95 | if train_on == 'MAESTRO': 96 | dataset = MAESTRO(groups=train_groups, sequence_length=sequence_length, device=device) 97 | validation_dataset = MAESTRO(groups=validation_groups, sequence_length=sequence_length, device=device) 98 | else: 99 | dataset = MAPS(groups=['AkPnBcht', 'AkPnBsdf', 'AkPnCGdD', 'AkPnStgb', 'SptkBGAm', 'SptkBGCl', 'StbgTGd2'], sequence_length=sequence_length, device=device) 100 | validation_dataset = MAPS(groups=['ENSTDkAm', 'ENSTDkCl'], sequence_length=validation_length, device=device) 101 | 102 | loader = DataLoader(dataset, batch_size, shuffle=True, drop_last=True, num_workers=4) 103 | 104 | if resume_iteration is None: 105 | model = getattr(nnmodels, model)(n_dilated_conv_layers=n_dilated_conv_layers, convblock_length=convblock_length, add_dilated_convblock=add_dilated_convblock).to(device) 106 | optimizer = torch.optim.Adam(model.parameters(), learning_rate) 107 | resume_iteration = 0 108 | else: 109 | model_path = os.path.join(logdir, f'model-{resume_iteration}.pt') 110 | model = torch.load(model_path) 111 | optimizer = torch.optim.Adam(model.parameters(), learning_rate) 112 | optimizer.load_state_dict(torch.load(os.path.join(logdir, 'last-optimizer-state.pt'))) 113 | 114 | summary(model) 115 | scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps, gamma=learning_rate_decay_rate) 116 | 117 | progress_bar_metrics = { 118 | 'loss': np.nan, 119 | "note_f1": np.nan, 120 | } 121 | 122 | loop = tqdm(range(resume_iteration + 1, iterations + 1)) 123 | for i, batch in zip(loop, cycle(loader)): 124 | predictions, losses = model.run_on_batch(batch) 125 | 126 | loss = sum(losses.values()) 127 | optimizer.zero_grad() 128 | loss.backward() 129 | if clip_gradient_norm: # gradient clipping must happen between backward() and step() 130 | clip_grad_norm_(model.parameters(), clip_gradient_norm) 131 | optimizer.step() 132 | scheduler.step() 133 | 134 | 135 | for key, value in {'loss': loss, **losses}.items(): 136 | writer.add_scalar(key, value.item(), global_step=i) 137 | 138 | progress_bar_metrics["loss"] = f"{loss.item():4.3f}" 139 | loop.set_postfix(progress_bar_metrics) 140 | 141 | if i % validation_interval == 0: 142 | model.eval() 143 | with torch.no_grad(): 144 | for key, value in evaluate(validation_dataset, model, pr_au_thresholds=None).items(): 145 | writer.add_scalar('validation/' + key.replace(' ', '_'), np.mean(value), global_step=i) 146 | 147 | # Early stopping. 
Stop if f1 does not increase anymore 148 | if key.replace(' ', '_') == "metric/note/f1": 149 | if early_stopper.early_stop(np.mean(value)): 150 | stop_training = True 151 | 152 | progress_bar_metrics["note_f1"] = f"{np.mean(value):.4f}" 153 | loop.set_postfix(progress_bar_metrics) 154 | 155 | model.train() 156 | 157 | if i % checkpoint_interval == 0: 158 | torch.save(model, os.path.join(logdir, f'model-{i}.pt')) 159 | torch.save(optimizer.state_dict(), os.path.join(logdir, 'last-optimizer-state.pt')) 160 | 161 | if stop_training: 162 | break 163 | 164 | torch.save(model, os.path.join(logdir, f"model-after_training.pt")) 165 | -------------------------------------------------------------------------------- /transcribe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | import soundfile 7 | from mir_eval.util import midi_to_hz 8 | 9 | from onsets_and_frames import * 10 | 11 | 12 | def load_and_process_audio(flac_path, sequence_length, device): 13 | 14 | random = np.random.RandomState(seed=42) 15 | 16 | audio, sr = soundfile.read(flac_path, dtype='int16') 17 | assert sr == SAMPLE_RATE 18 | 19 | audio = torch.ShortTensor(audio) 20 | 21 | if sequence_length is not None: 22 | audio_length = len(audio) 23 | step_begin = random.randint(audio_length - sequence_length) // HOP_LENGTH 24 | n_steps = sequence_length // HOP_LENGTH 25 | 26 | begin = step_begin * HOP_LENGTH 27 | end = begin + sequence_length 28 | 29 | audio = audio[begin:end].to(device) 30 | else: 31 | audio = audio.to(device) 32 | 33 | audio = audio.float().div_(32768.0) 34 | 35 | return audio 36 | 37 | 38 | def transcribe(model, audio): 39 | 40 | mel = melspectrogram(audio.reshape(-1, audio.shape[-1])[:, :-1]).transpose(-1, -2) 41 | onset_pred, offset_pred, _, frame_pred, velocity_pred = model(mel) 42 | 43 | predictions = { 44 | 'onset': onset_pred.reshape((onset_pred.shape[1], onset_pred.shape[2])), 45 | 'offset': offset_pred.reshape((offset_pred.shape[1], offset_pred.shape[2])), 46 | 'frame': frame_pred.reshape((frame_pred.shape[1], frame_pred.shape[2])), 47 | 'velocity': velocity_pred.reshape((velocity_pred.shape[1], velocity_pred.shape[2])) 48 | } 49 | 50 | return predictions 51 | 52 | 53 | def transcribe_file(model_file, flac_paths, save_path, sequence_length, 54 | onset_threshold, frame_threshold, device): 55 | 56 | model = torch.load(model_file, map_location=device).eval() 57 | summary(model) 58 | 59 | for flac_path in flac_paths: 60 | print(f'Processing {flac_path}...', file=sys.stderr) 61 | audio = load_and_process_audio(flac_path, sequence_length, device) 62 | predictions = transcribe(model, audio) 63 | 64 | p_est, i_est, v_est = extract_notes(predictions['onset'], predictions['frame'], predictions['velocity'], onset_threshold, frame_threshold) 65 | 66 | scaling = HOP_LENGTH / SAMPLE_RATE 67 | 68 | i_est = (i_est * scaling).reshape(-1, 2) 69 | p_est = np.array([midi_to_hz(MIN_MIDI + midi) for midi in p_est]) 70 | 71 | os.makedirs(save_path, exist_ok=True) 72 | pred_path = os.path.join(save_path, os.path.basename(flac_path) + '.pred.png') 73 | save_pianoroll(pred_path, predictions['onset'], predictions['frame']) 74 | midi_path = os.path.join(save_path, os.path.basename(flac_path) + '.pred.mid') 75 | save_midi(midi_path, p_est, i_est, v_est) 76 | 77 | 78 | if __name__ == '__main__': 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument('model_file', type=str) 81 | parser.add_argument('flac_paths', type=str, 
nargs='+') 82 | parser.add_argument('--save-path', type=str, default='.') 83 | parser.add_argument('--sequence-length', default=None, type=int) 84 | parser.add_argument('--onset-threshold', default=0.5, type=float) 85 | parser.add_argument('--frame-threshold', default=0.5, type=float) 86 | parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu') 87 | 88 | with torch.no_grad(): 89 | transcribe_file(**vars(parser.parse_args())) 90 | --------------------------------------------------------------------------------
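The masked velocity regression implemented by the `velocity_loss` method in `onsets_and_frames/transcriber.py` can be sanity-checked in isolation. The snippet below is a minimal sketch with hypothetical toy tensors (not the model's real batch layout); it reproduces the same masking logic to show that frames without a labelled onset contribute nothing to the loss, and that the error is averaged over onset positions only.

```python
import torch


def velocity_loss(velocity_pred, velocity_label, onset_label):
    # Same masking logic as the velocity_loss method in transcriber.py:
    # the squared error is averaged only over positions with a labelled onset.
    denominator = onset_label.sum()
    if denominator.item() == 0:
        return denominator
    return (onset_label * (velocity_label - velocity_pred) ** 2).sum() / denominator


# Hypothetical toy tensors of shape [1, 4]; only the two onset positions count.
onsets = torch.tensor([[1., 0., 1., 0.]])
vel_ref = torch.tensor([[0.8, 0.0, 0.5, 0.0]])
vel_est = torch.tensor([[0.6, 0.9, 0.5, 0.3]])
print(velocity_loss(vel_est, vel_ref, onsets))  # tensor(0.0200) = (0.8 - 0.6) ** 2 / 2
```

Normalising by the onset count rather than by the number of frames keeps the loss scale independent of how sparse the onsets are in a given training segment.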