├── .gitignore
├── README.md
├── assets
    └── exp.png
├── configs
    └── config.yaml
└── train
    ├── dataset
        ├── midi.py
        └── slakh2100.py
    ├── inference.py
    ├── loss
        ├── deep_cluster_loss.py
        └── mask_inference_loss.py
    ├── metrics
        └── transcript_metric.py
    ├── network
        ├── cerberus.py
        ├── cerberus_wrapper.py
        ├── clustering_head.py
        ├── separation_head.py
        ├── shared_body.py
        ├── transcription_head.py
        └── transform_layer.py
    ├── optimizer
        └── radam.py
    ├── requirements.txt
    ├── train.py
    └── utils
        ├── debug.py
        ├── decoding.py
        └── dsp.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.ipynb_checkpoints/
2 | **/__pycache__/
3 | **/lightning_logs/
4 | **/telegram.yaml
5 | .vscode/
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CERBERUS Network for Music Separation & Transcription
2 | ---
3 | This is a rough implementation of
4 | [Simultaneous Separation and Transcription of Mixtures with Multiple Polyphonic and Percussive Instruments](https://arxiv.org/abs/1910.12621) (Ethan Manilow et al., ICASSP 2020)
5 |
6 | ## Note
7 |
8 | This implementation does not reach the performance reported in the paper.
9 |
10 |
11 | ## Demo (Source Separation)
12 |
13 | [![cerberus](http://img.youtube.com/vi/59uTEk0ZamE/0.jpg)](https://youtu.be/59uTEk0ZamE)
14 |
15 |
16 | ## Quantitative Evaluation (Transcription)
17 |
18 | |       | Precision | Recall | Accuracy |
19 | | ----- | --------- | ------ | -------- |
20 | | Piano | 0.585 | 0.566 | 0.460 |
21 | | Bass | 0.797 | 0.817 | 0.747 |
22 | | Drums | 0.230 | 0.417 | 0.133 |
23 |
24 | - Note: There is no benchmark dataset. These results were measured on data I randomly created from the test set of the Slakh2100 dataset, so it is not appropriate to compare them quantitatively with the numbers reported in the paper.
25 |
26 | ## Pretrained Network & Config
27 |
28 | - [Weight(ckpt)](https://github.com/sweetcocoa/cerberus-pytorch/raw/weights/weights/last.ckpt)
29 | - [Config(yaml)](https://github.com/sweetcocoa/cerberus-pytorch/raw/weights/weights/hparams.yaml)
30 |
31 |
32 | ## Inference
33 |
34 | ```bash
35 | python inference.py hparams.yaml weight.ckpt input.wav output_dir/
36 | ```
37 |
38 | - Expected Results:
39 |
40 | ![results](./assets/exp.png)
41 |
42 |
43 | ## Training with Slakh2100 Dataset
44 |
45 | 1. Get the Slakh2100 dataset (see: [Slakh2100 Project](http://www.slakh.com/))
46 | 2. Downsample the audio to 16 kHz
47 | 3. Modify configs/config.yaml
48 | ```yaml
49 | data_dir: "/path/to/slakh2100_flac_16k/"
50 |
51 | # see: validation_epoch_end() in network/cerberus_wrapper.py
52 | sample_audio:
53 |     path: "/path/to/sample/audio/sample_rate_16k.wav"
54 |     offset: 1264000
55 |     num_frames: 160000
56 | ```
57 | 4. Run training
58 | ```bash
59 | python train.py
60 | ```
61 |
62 |
63 | ## Contact
64 | - Jongho Choi (sweetcocoa@snu.ac.kr)
65 | - Jiwon Kim
66 | - Ahyeon Choi
67 |
--------------------------------------------------------------------------------
/assets/exp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sweetcocoa/cerberus-pytorch/3ae3cc296e0134da01a2fd8fe086a48c435bcd30/assets/exp.png
--------------------------------------------------------------------------------
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 40
2 | experiment_name: cerberus
3 | gpu: 0
4 | num_gpu: 1
5 | heads: ['sep', 'dc', 'tr']
6 | # heads: ['sep', 'tr']
7 | lr: 1.0e-04
8 | lr_decay: 0.98
9 | lr_min: 1.0e-07
10 | lr_scheduler: multistep
11 | num_workers: 24
12 | optimizer: rmsprop
13 | sr: 16000
14 | seed: 940513
15 | data_dir: "/mnt/ssd3/mlproject/data/slakh2100_flac_16k/"
16 |
17 | loss_alpha: 0.00001
18 | loss_beta: 0.1
19 | loss_gamma: 0.8
20 | num_pitches: 88
21 | midi_min: 21
22 | transcription_threshold: 0.4
23 | n_fft: 1024
24 | hop_length: 256
25 |
26 | lstm_hidden_size: 300
27 | lstm_num_layers: 4
28 | lstm_bidirectional: True
29 | embedding_size: 20
30 | dropout_rate: 0.3
31 |
32 | num_inst: 3
33 | # num_inst: 4
34 | inst: ["Piano", "Bass", "Drums"]
35 | # inst: ["Piano", "Bass", "Guitar"]
36 | # inst: ["Piano", "Bass", "Drums", "Guitar"]
37 | metrics:
38 |     - valid_total_loss: 0.
39 |     - train_total_loss: 0.
40 |     - train_mask_inference_loss: 0.
41 |     - valid_mask_inference_loss: 0.
42 |     - train_tr_loss: 0.
43 |     - valid_tr_loss: 0.
44 | duration: 2.0
45 |
46 | num_epochs: 3000
47 | gradient_clip_val: 3.0
48 | find_lr: false
49 |
50 | sample_audio:
51 |     path: "/mnt/ssd3/mlproject/data/beatles_16000.wav"
52 |     offset: 1264000
53 |     num_frames: 160000
54 |
55 | check_val_every_n_epoch: 5
--------------------------------------------------------------------------------
/train/dataset/midi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pretty_midi
3 | import numpy as np
4 | from pretty_midi.containers import PitchBend
5 | from pretty_midi.utilities import pitch_bend_to_semitones, note_number_to_hz
6 | import mido
7 |
8 | program_dict = dict(
9 |     Piano=2,
10 |     Bass=35,
11 |     Drums=119,
12 |     Guitar=28
13 | )
14 |
15 | def piano_roll_to_pretty_midi(piano_roll, fs=100, program=0):
16 |     '''Convert a Piano Roll array into a PrettyMidi object
17 |     with a single instrument.
18 |     Parameters
19 |     ----------
20 |     piano_roll : np.ndarray, shape=(128,frames), dtype=int
21 |         Piano roll of one instrument
22 |     fs : int
23 |         Sampling frequency of the columns, i.e. each column is spaced apart
24 |         by ``1./fs`` seconds.
25 |     program : int
26 |         The program number of the instrument.
27 |     Returns
28 |     -------
29 |     midi_object : pretty_midi.PrettyMIDI
30 |         A pretty_midi.PrettyMIDI class instance describing
31 |         the piano roll.
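    Examples
    --------
    A minimal sketch with toy values (not taken from the repository's tests):

    >>> roll = np.zeros((128, 200), dtype=int)
    >>> roll[60, 50:150] = 80   # one note (MIDI pitch 60) held for 100 frames
    >>> midi_obj = piano_roll_to_pretty_midi(roll, fs=100, program=0)
    >>> len(midi_obj.instruments[0].notes)
    1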
32 | ''' 33 | notes, frames = piano_roll.shape 34 | pm = pretty_midi.PrettyMIDI() 35 | instrument = pretty_midi.Instrument(program=program) 36 | 37 | # pad 1 column of zeros so we can acknowledge inital and ending events 38 | piano_roll = np.pad(piano_roll, [(0, 0), (1, 1)], 'constant') 39 | 40 | # use changes in velocities to find note on / note off events 41 | velocity_changes = np.nonzero(np.diff(piano_roll).T) 42 | 43 | # keep track on velocities and note on times 44 | prev_velocities = np.zeros(notes, dtype=int) 45 | note_on_time = np.zeros(notes) 46 | 47 | for time, note in zip(*velocity_changes): 48 | # use time + 1 because of padding above 49 | velocity = piano_roll[note, time + 1] 50 | time = time / fs 51 | if velocity > 0: 52 | if prev_velocities[note] == 0: 53 | note_on_time[note] = time 54 | prev_velocities[note] = velocity 55 | else: 56 | pm_note = pretty_midi.Note( 57 | velocity=prev_velocities[note], 58 | pitch=note, 59 | start=note_on_time[note], 60 | end=time) 61 | instrument.notes.append(pm_note) 62 | prev_velocities[note] = 0 63 | pm.instruments.append(instrument) 64 | return pm 65 | 66 | 67 | def parse_midi(path): 68 | """ 69 | Original Source : https://github.com/jongwook/onsets-and-frames/blob/master/onsets_and_frames/midi.py 70 | """ 71 | 72 | """open midi file and return np.array of (onset, offset, note, velocity) rows""" 73 | midi = mido.MidiFile(path) 74 | 75 | time = 0 76 | sustain = False 77 | events = [] 78 | for message in midi: 79 | time += message.time 80 | 81 | if message.type == 'control_change' and message.control == 64 and (message.value >= 64) != sustain: 82 | # sustain pedal state has just changed 83 | sustain = message.value >= 64 84 | event_type = 'sustain_on' if sustain else 'sustain_off' 85 | event = dict(index=len(events), time=time, type=event_type, note=None, velocity=0) 86 | events.append(event) 87 | 88 | if 'note' in message.type: 89 | # MIDI offsets can be either 'note_off' events or 'note_on' with zero velocity 90 | velocity = message.velocity if message.type == 'note_on' else 0 91 | event = dict(index=len(events), time=time, type='note', note=message.note, velocity=velocity, sustain=sustain) 92 | events.append(event) 93 | 94 | notes = [] 95 | for i, onset in enumerate(events): 96 | if onset['velocity'] == 0: 97 | continue 98 | 99 | if len(events) == i + 1: 100 | continue 101 | 102 | # find the next note_off message 103 | offset = next(n for n in events[i + 1:] if n['note'] == onset['note'] or n is events[-1]) 104 | 105 | if offset['sustain'] and offset is not events[-1]: 106 | # if the sustain pedal is active at offset, find when the sustain ends 107 | offset = next(n for n in events[offset['index'] + 1:] if n['type'] == 'sustain_off' or n is events[-1]) 108 | 109 | note = (onset['time'], offset['time'], onset['note'], onset['velocity']) 110 | notes.append(note) 111 | 112 | return np.array(notes) 113 | 114 | 115 | def parsed_midi_to_roll(mid_np, audio_length, hop_length, sample_rate, num_pitches, midi_min, hops_in_onset=1, hops_in_offset=1): 116 | n_keys = num_pitches 117 | n_steps = (audio_length - 1) // hop_length + 1 118 | 119 | label = torch.zeros(n_steps, n_keys, dtype=torch.uint8) 120 | velocity = torch.zeros(n_steps, n_keys, dtype=torch.uint8) 121 | 122 | for onset, offset, note, vel in mid_np: 123 | left = int(round(onset * sample_rate / hop_length)) 124 | onset_right = min(n_steps, left + hops_in_onset) 125 | frame_right = int(round(offset * sample_rate / hop_length)) 126 | frame_right = min(n_steps, frame_right) 127 | offset_right 
= min(n_steps, frame_right + hops_in_offset) 128 | 129 | f = int(note) - midi_min 130 | 131 | if f >= n_keys: 132 | continue 133 | 134 | label[left:onset_right, f] = 3 135 | label[onset_right:frame_right, f] = 2 136 | label[frame_right:offset_right, f] = 1 137 | velocity[left:frame_right, f] = vel 138 | 139 | data = dict(label=label, velocity=velocity) 140 | return data -------------------------------------------------------------------------------- /train/dataset/slakh2100.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from collections import defaultdict 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torchaudio 8 | import numpy as np 9 | from omegaconf import OmegaConf 10 | from dataset.midi import parse_midi, parsed_midi_to_roll 11 | 12 | torchaudio.set_audio_backend("sox_io") 13 | 14 | class Slakh2100(torch.utils.data.Dataset): 15 | def __init__(self, 16 | data_dir, 17 | phase, 18 | inst=['Piano', 'Bass', 'Drums'], 19 | sr=44100, 20 | duration=2.0, 21 | hop_length=256, 22 | num_pitches=88, 23 | use_cache=True, 24 | random=True, 25 | n_fft=1024, 26 | midi_min=21, 27 | ): 28 | """ 29 | data_dir : "/mnt/ssd3/mlproject/data/slakh2100_flac/" 30 | phase : "train" 31 | """ 32 | super().__init__() 33 | self.tracks = sorted(glob.glob(os.path.join(data_dir, phase) + "/*")) 34 | self.sr = sr 35 | self.inst = inst 36 | self.phase = phase 37 | self.data_dir = data_dir 38 | self.duration = duration 39 | self.hop_length = hop_length 40 | self.random = random 41 | self.num_pitches = num_pitches 42 | self.meta = None 43 | self.midi_min = midi_min 44 | 45 | x = torch.zeros(int(self.sr * self.duration)) 46 | X = torch.stft(x, n_fft, hop_length) 47 | self.num_tr_frames = X.shape[1] 48 | 49 | cache_location = os.path.join(data_dir, f"meta_{phase}_{inst}_v1.pth") 50 | if os.path.exists(cache_location) and use_cache: 51 | self.meta = torch.load(cache_location) 52 | else: 53 | print(f"Creating slakh2100 Cache for {inst}..") 54 | self.meta = self.get_meta(self.tracks) 55 | torch.save(self.meta, cache_location) 56 | print("Done") 57 | 58 | self.labels = [None for track in range(len(self.tracks))] 59 | 60 | def get_label(self, idx): 61 | """ 62 | onset / offset / frame piano roll of the midi file, 63 | See : parse_midi, parsed_miti_to_roll 64 | """ 65 | if self.labels[idx] is None: 66 | track = self.tracks[idx] 67 | info = self.meta[idx]['info'] 68 | inst_track = self.meta[idx]['inst_track'] 69 | track_labels = dict() # track_labels['Piano'] = np.array(n_steps, pitch) 70 | 71 | for inst in self.inst: 72 | if inst in inst_track: 73 | stem = inst_track[inst] 74 | midi_path = track + f"/MIDI/{stem}.mid" 75 | midi_cache = midi_path.replace(".mid", ".pt") 76 | if os.path.exists(midi_cache): 77 | label = torch.load(midi_cache) 78 | else: 79 | mid_np = parse_midi(midi_path) 80 | label = parsed_midi_to_roll(mid_np, info.num_frames, hop_length=self.hop_length, sample_rate=self.sr, num_pitches=self.num_pitches, midi_min=self.midi_min, hops_in_offset=2, hops_in_onset=2) 81 | torch.save(label, midi_cache) 82 | track_labels[inst] = label 83 | self.labels[idx] = track_labels 84 | return track_labels 85 | else: 86 | return self.labels[idx] 87 | 88 | def get_meta(self, tracks): 89 | """ 90 | cfg['info'] = torchaudio.info (cfg['info'].num_frames = duration * sample_rate ) 91 | cfg['inst_track'] 92 | cfg['inst_track']['Piano'] = "S01" 93 | cfg['inst_track']['Bass'] = "S02" 94 | cfg['inst_track']['Drums'] = "S04" 95 | """ 96 | 97 | cfgs = [] 98 | for 
k, track in enumerate(tracks): 99 | info = torchaudio.info(track + "/mix.flac") 100 | cfg = dict(info=info) 101 | 102 | track_cfg = OmegaConf.load(track + "/metadata.yaml") 103 | inst_track = dict() 104 | for i in track_cfg.stems: 105 | if track_cfg.stems[i].audio_rendered and track_cfg.stems[i].inst_class in self.inst and track_cfg.stems[i].midi_saved: 106 | if track_cfg.stems[i].inst_class in inst_track: 107 | prev_i = inst_track[track_cfg.stems[i].inst_class] 108 | if track_cfg.stems[prev_i].program_num > track_cfg.stems[i].program_num: 109 | inst_track[track_cfg.stems[i].inst_class] = i 110 | else: 111 | inst_track[track_cfg.stems[i].inst_class] = i 112 | cfg['inst_track'] = inst_track 113 | cfgs.append(cfg) 114 | return cfgs 115 | 116 | def __getitem__(self, idx): 117 | track = self.tracks[idx] 118 | info = self.meta[idx]['info'] 119 | inst_track = self.meta[idx]['inst_track'] 120 | 121 | if self.random: 122 | step_begin = np.random.randint(info.num_frames - int(self.duration*self.sr) - 1) // self.hop_length 123 | else: 124 | frame_offset = (info.num_frames - int(self.duration*self.sr) - 1)//3 125 | step_begin = frame_offset // self.hop_length 126 | 127 | n_steps = self.num_tr_frames 128 | step_end = step_begin + n_steps 129 | 130 | raw_begin = step_begin * self.hop_length 131 | 132 | ys = [] 133 | transcripts = [] 134 | onsets = [] 135 | offsets = [] 136 | 137 | for inst in self.inst: 138 | if inst in inst_track: 139 | stem = inst_track[inst] 140 | y, sr = torchaudio.load(track + f"/stems/{stem}.flac", frame_offset=raw_begin, num_frames=int(self.duration*self.sr)) 141 | assert sr == self.sr 142 | 143 | if len(y[0]) < int(self.sr * self.duration): 144 | y = nn.functional.pad(y, (0, int(self.sr * self.duration) - len(y[0]))) 145 | 146 | label = self.get_label(idx)[inst] # (steps, pitch) 147 | midi_piece = label['label'][step_begin:step_end, :] 148 | 149 | ys.append(y) 150 | transcripts.append((midi_piece > 1).float().T) 151 | onsets.append((midi_piece == 3).float().T) 152 | offsets.append((midi_piece == 1).float().T) 153 | else: 154 | # 0 tensor if no source is available on that track 155 | ys.append(torch.zeros((1, int(self.sr * self.duration)))) 156 | roll = np.zeros((self.num_pitches , self.num_tr_frames), dtype=np.float32) 157 | transcripts.append(torch.Tensor(roll)) 158 | onsets.append(torch.Tensor(roll)) 159 | offsets.append(torch.Tensor(roll)) 160 | 161 | separation_gt = torch.cat(ys) 162 | transcripts = torch.stack(transcripts) 163 | onsets = torch.stack(onsets) 164 | offsets = torch.stack(offsets) 165 | mix = torch.sum(separation_gt, dim=0) 166 | 167 | sample = dict( 168 | separation_gt=separation_gt, 169 | mix=mix, 170 | transcripts_gt=transcripts, 171 | onsets_gt=onsets, 172 | offsets_gt=offsets 173 | ) 174 | 175 | return sample 176 | 177 | def __len__(self): 178 | return len(self.tracks) -------------------------------------------------------------------------------- /train/inference.py: -------------------------------------------------------------------------------- 1 | # python inference.py config_path weight_ckpt input_wav output_dir 2 | 3 | import os 4 | import sys 5 | 6 | import pytorch_lightning as pl 7 | import torch 8 | import torchaudio 9 | import matplotlib.pyplot as plt 10 | 11 | import librosa 12 | import librosa.display 13 | from omegaconf import OmegaConf 14 | 15 | import pretty_midi as pm 16 | import numpy as np 17 | 18 | from network.cerberus_wrapper import CerberusWrapper 19 | 20 | config_path = sys.argv[1] # 
"lightning_logs/experiment_name/version_0/hparams.yaml" 21 | weight_path = sys.argv[2] # "lightning_logs/experiment_name/version_0/checkpoints/last.ckpt" 22 | input_wav = sys.argv[3] 23 | output_dir = sys.argv[4] 24 | 25 | config = OmegaConf.load(config_path) 26 | 27 | y, sr = torchaudio.load(input_wav, frame_offset=config.sr*90, num_frames=config.sr*20) 28 | assert sr == config.sr 29 | 30 | n = CerberusWrapper.load_from_checkpoint(weight_path) 31 | 32 | rt = n.get_transcripts(y) 33 | 34 | fs = config.sr / config.hop_length 35 | plt.rcParams.update({"figure.facecolor": (1.0, 1.0, 1.0, 1.0), 36 | "axes.facecolor": (1.0, 1.0, 1.0, 1.0), 37 | }) 38 | 39 | os.makedirs(output_dir, exist_ok=True) 40 | 41 | for i, inst in enumerate(config.inst): 42 | plt.title(inst + "_Activation") 43 | librosa.display.specshow(rt[0][i].float().detach().numpy(), hop_length=1, sr=int(fs), x_axis='time', y_axis='cqt_note', 44 | fmin=pm.note_number_to_hz(config.midi_min)) 45 | plt.savefig(os.path.join(output_dir, f"transcript_activation_{inst}.jpg")) 46 | threshold = 0.8 47 | if inst == "Drums": 48 | threshold = threshold / 4 49 | 50 | plt.title(inst + "_Binary") 51 | librosa.display.specshow(rt[0][i].float().detach().numpy() > threshold, hop_length=1, sr=int(fs), x_axis='time', y_axis='cqt_note', 52 | fmin=pm.note_number_to_hz(config.midi_min)) 53 | plt.savefig(os.path.join(output_dir, f"transcript_threshold_{inst}_{threshold}.jpg")) 54 | 55 | pmidi = n.multiple_piano_roll_to_pretty_midi(rt[0][:3]) 56 | pmidi.write(os.path.join(output_dir, "transcripted_midi.mid")) 57 | 58 | source_hat = n.get_separated_sources(y) 59 | 60 | for i, inst in enumerate(config.inst): 61 | audio_save_path = os.path.join(output_dir, f"separated_{inst}.wav") 62 | torchaudio.save(audio_save_path, source_hat[0, i].detach().unsqueeze(0), sample_rate=config.sr) 63 | -------------------------------------------------------------------------------- /train/loss/deep_cluster_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DeepClusterLoss(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, tf_embedding: torch.Tensor, separation_gt_spec: torch.Tensor) -> torch.Tensor: 10 | """ 11 | tf_embedding : (batch, freq*time, embedding_size) 12 | separation_gt_spec : (batch, inst, time, freq) 13 | """ 14 | batch, num_inst, time, freq = separation_gt_spec.shape 15 | 16 | target_idx = separation_gt_spec.argmax(dim=1) 17 | t = nn.functional.one_hot(target_idx, num_classes=num_inst).float() 18 | t = t.view(batch, time*freq, num_inst) 19 | 20 | v = tf_embedding 21 | 22 | # (b, embedding_size, TF) * (b, TF, embedding_size) = (b, embedding_size, embedding_size) 23 | vvT = torch.matmul(v.transpose(-1, -2), v) 24 | ttT = torch.matmul(t.transpose(-1, -2), t) 25 | vTt = torch.matmul(v.transpose(-1, -2), t) 26 | 27 | loss = vvT.norm() + ttT.norm() - 2*vTt.norm() 28 | return loss -------------------------------------------------------------------------------- /train/loss/mask_inference_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | import torch.nn.functional as F 5 | from utils.dsp import apply_masks 6 | 7 | class MaskInferenceLoss(nn.Module): 8 | def __init__(self, num_inst): 9 | super().__init__() 10 | self.num_inst = num_inst 11 | self.mse = nn.MSELoss() 12 | 13 | def forward(self, separation_mask, mix_mag, separation_gt_mag): 
14 | """ 15 | separation_mask : (batch, inst, time, freq) 16 | mix_mag : (batch, time, freq) 17 | separation_gt_mag : (batch, inst, time, freq) 18 | """ 19 | hat = apply_masks(separation_mask, mix_mag, self.num_inst) 20 | loss = self.mse(hat, separation_gt_mag) 21 | return loss -------------------------------------------------------------------------------- /train/metrics/transcript_metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from mir_eval.multipitch import evaluate as evaluate_frames 6 | from mir_eval.util import midi_to_hz 7 | 8 | 9 | def get_tf(roll, midi_min, scaling): 10 | # roll (freq, time) 11 | time = np.arange(roll.shape[1]) 12 | freqs = [roll[:, t].nonzero(as_tuple=True)[0] for t in time] 13 | t_ref, f_ref = time, freqs 14 | t_ref = t_ref.astype(np.float64) * scaling 15 | f_ref = [np.array([midi_to_hz(midi_min + midi) for midi in freqs]) for freqs in f_ref] 16 | return t_ref, f_ref 17 | 18 | 19 | class TrMetrics: 20 | def __init__(self, config): 21 | self.config = config 22 | self.scaling = self.config.hop_length / self.config.sr 23 | 24 | def __call__(self, pred, label, threshold=None): 25 | """ 26 | pred : (freq, time) float 27 | label : (freq, time) float 28 | 29 | return : dict 30 | 31 | e.g. OrderedDict([('Precision', 0.07260726072607261), 32 | ('Recall', 0.056921086675291076), 33 | ('Accuracy', 0.03295880149812734), 34 | ('Substitution Error', 0.3738680465717982), 35 | ('Miss Error', 0.5692108667529108), 36 | ('False Alarm Error', 0.35316946959896506), 37 | ('Total Error', 1.296248382923674), 38 | ('Chroma Precision', 0.23927392739273928), 39 | ('Chroma Recall', 0.18758085381630013), 40 | ('Chroma Accuracy', 0.11750405186385737), 41 | ('Chroma Substitution Error', 0.24320827943078913), 42 | ('Chroma Miss Error', 0.5692108667529108), 43 | ('Chroma False Alarm Error', 0.35316946959896506), 44 | ('Chroma Total Error', 1.165588615782665)]) 45 | 46 | """ 47 | if threshold is None: 48 | threshold = self.config.transcription_threshold 49 | 50 | t_ref, f_ref = get_tf(label.int(), self.config.midi_min, scaling=self.scaling) 51 | t_est, f_est = get_tf((pred > threshold).int(), self.config.midi_min, scaling=self.scaling) 52 | frame_metrics = evaluate_frames(t_ref, f_ref, t_est, f_est) 53 | 54 | return frame_metrics -------------------------------------------------------------------------------- /train/network/cerberus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | import torchaudio.models 5 | 6 | from network.shared_body import SharedBody 7 | from network.clustering_head import ClusteringHead 8 | from network.separation_head import SeparationHead 9 | from network.transcription_head import TranscriptionHead 10 | 11 | class Cerberus(nn.Module): 12 | def __init__(self, config): 13 | super().__init__() 14 | self.config = config 15 | 16 | self.shared_body = SharedBody(config) 17 | if 'sep' in self.config.heads: 18 | self.separation_head = SeparationHead(config) 19 | else: 20 | self.separation_head = None 21 | 22 | if 'dc' in self.config.heads: 23 | self.clustering_head = ClusteringHead(config) 24 | else: 25 | self.clustering_head = None 26 | 27 | if 'tr' in self.config.heads: 28 | self.transcription_head = TranscriptionHead(config) 29 | else: 30 | self.transcription_head = None 31 | 32 | def forward(self, mix_mag): 33 | spec = mix_mag 34 | shared_representation = 
self.shared_body(spec) 35 | 36 | rt = dict() 37 | 38 | if self.separation_head is not None: 39 | separation_mask = self.separation_head(shared_representation) 40 | rt['separation_mask'] = separation_mask 41 | 42 | if self.clustering_head is not None: 43 | embedding = self.clustering_head(shared_representation) 44 | rt['embedding'] = embedding 45 | 46 | if self.transcription_head is not None: 47 | transcripts = self.transcription_head(shared_representation) 48 | rt['transcripts'] = transcripts 49 | 50 | return rt -------------------------------------------------------------------------------- /train/network/cerberus_wrapper.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from collections import defaultdict 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import pytorch_lightning as pl 9 | import torchaudio 10 | from omegaconf import OmegaConf 11 | import librosa.display 12 | import pretty_midi as pm 13 | import matplotlib.pyplot as plt 14 | 15 | from network.cerberus import Cerberus 16 | from network.transform_layer import ISTFT 17 | 18 | sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../") 19 | 20 | from loss.mask_inference_loss import MaskInferenceLoss 21 | from loss.deep_cluster_loss import DeepClusterLoss 22 | from metrics.transcript_metric import TrMetrics 23 | from utils.dsp import realimag, apply_masks 24 | from dataset.slakh2100 import Slakh2100 25 | from dataset.midi import piano_roll_to_pretty_midi, program_dict 26 | 27 | class CerberusWrapper(pl.LightningModule): 28 | def __init__(self, config): 29 | super().__init__() 30 | self.config = config 31 | self.save_hyperparameters(self.config) 32 | 33 | self.stft = torchaudio.transforms.Spectrogram(n_fft=config.n_fft, hop_length=config.hop_length, power=None) 34 | self.istft = ISTFT(config) 35 | 36 | if 'sep' in self.config.heads: 37 | self.mask_inference_loss = MaskInferenceLoss(config.num_inst) 38 | self.mask_metric = None # Not implemented. 39 | else: 40 | self.mask_inference_loss = None 41 | self.mask_metric = None 42 | 43 | if 'dc' in self.config.heads: 44 | self.deep_clustering_loss = DeepClusterLoss() 45 | self.clustering_metric = None # Not implemented. 
46 | else: 47 | self.deep_clustering_loss = None 48 | self.clustering_metric = None 49 | 50 | if 'tr' in self.config.heads: 51 | self.transcription_loss = nn.BCELoss() 52 | self.transcription_metric = TrMetrics(config) 53 | else: 54 | self.transcription_loss = None 55 | self.transcription_metric = None 56 | 57 | self.cerberus = Cerberus(config) 58 | 59 | # auto lr find 60 | self.lr = config.lr 61 | 62 | self.saved_gt_to_tensorboard = False 63 | 64 | def val_dataloader(self): 65 | config = self.config 66 | 67 | valid_ds = Slakh2100(data_dir=config.data_dir, 68 | phase="validation", 69 | inst=config.inst, 70 | duration=config.duration, 71 | use_cache=True, 72 | sr=config.sr, 73 | random=False, 74 | n_fft=config.n_fft, 75 | hop_length=config.hop_length, 76 | midi_min=config.midi_min) 77 | valid_dl = torch.utils.data.DataLoader(valid_ds, shuffle=False, batch_size=config.batch_size, num_workers=config.num_workers, pin_memory=True) 78 | return valid_dl 79 | 80 | def train_dataloader(self): 81 | config = self.config 82 | 83 | train_ds = Slakh2100(data_dir=config.data_dir, 84 | phase="train", 85 | inst=config.inst, 86 | duration=config.duration, 87 | use_cache=True, 88 | sr=config.sr, 89 | random=True, 90 | n_fft=config.n_fft, 91 | hop_length=config.hop_length, 92 | midi_min=config.midi_min) 93 | 94 | train_dl = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=config.batch_size, num_workers=config.num_workers, pin_memory=True) 95 | return train_dl 96 | 97 | def test_dataloader(self): 98 | config = self.config 99 | 100 | test_ds = Slakh2100(data_dir=config.data_dir, 101 | phase="test", 102 | inst=config.inst, 103 | duration=4., 104 | use_cache=True, 105 | sr=config.sr, 106 | random=False, 107 | n_fft=config.n_fft, 108 | hop_length=config.hop_length, 109 | midi_min=config.midi_min) 110 | test_dl = torch.utils.data.DataLoader(test_ds, shuffle=False, batch_size=2, num_workers=config.num_workers, pin_memory=True) 111 | return test_dl 112 | 113 | 114 | def batch_preprocessing(self, mini_batch:dict): 115 | """ 116 | mix -> its stft, magnitude, phase 117 | separation(source) -> stft, magnitude 118 | 119 | mini_batch.keys() 120 | >> "mix", "separation_gt", "transcript_gt" 121 | mini_batch = self.batch_preprocessing(mini_batch) 122 | mini_batch.keys() 123 | >> "mix", "separation_gt", "transcript_gt", 124 | "mix_mag", "mix_phase", "mix_stft", 125 | "separation_gt_stft", "separation_gt_mag" 126 | """ 127 | 128 | mix_stft = self.stft(mini_batch['mix']) # (batch, freq, time, 2) 129 | mix_mag, mix_phase = torchaudio.functional.magphase(mix_stft, power=1.0) 130 | 131 | mix_mag = mix_mag.permute(0, 2, 1) # (batch, time, freq) 132 | 133 | sep_stft = self.stft(mini_batch['separation_gt']) 134 | sep_mag, sep_phase = torchaudio.functional.magphase(sep_stft, power=1.0) 135 | sep_mag = sep_mag.permute(0, 1, 3, 2) 136 | 137 | mini_batch['mix_stft'] = mix_stft 138 | mini_batch['mix_mag'] = mix_mag 139 | mini_batch['mix_phase'] = mix_phase 140 | mini_batch['separation_gt_stft'] = sep_stft 141 | mini_batch['separation_gt_mag'] = sep_mag 142 | 143 | return mini_batch 144 | 145 | def sample_preprocessing(self, mix:torch.Tensor): 146 | prep = dict() 147 | 148 | # mix : (batch, time) 149 | mix_stft = self.stft(mix) 150 | 151 | # stft : (batch, freq, time, 2) 152 | mix_mag, phase = torchaudio.functional.magphase(mix_stft, power=1.0) 153 | 154 | # mag : (batch, freq, time) 155 | 156 | mix_mag = mix_mag.permute(0, 2, 1) # mag : (batch, time, freq) 157 | 158 | prep['mix_mag'] = mix_mag 159 | prep['phase'] = phase 160 
| return prep 161 | 162 | def get_separated_sources(self, mix:torch.Tensor): 163 | """ 164 | Just for inference 165 | audio(batch, time) 입력으로 받아서 num_inst만큼 return (batch, inst, time) 166 | """ 167 | prep = self.sample_preprocessing(mix) 168 | mix_mag, phase = prep['mix_mag'], prep['phase'] 169 | rt = self(mix_mag) 170 | 171 | mask = rt['separation_mask'] 172 | sep_mag_hat = apply_masks(mask, mix_mag, self.config.num_inst) #(batch, inst, time, freq) 173 | 174 | # (batch, inst, time, freq) -> (batch, inst, freq, time) 175 | sep_mag_hat = sep_mag_hat.permute(0, 1, 3, 2) 176 | 177 | # use mix phase to do istft 178 | b, f, t = phase.shape 179 | phase_repeat = phase.repeat_interleave(self.config.num_inst, 0).view(b, self.config.num_inst, f, t) # (batch, inst, freq, time) 180 | 181 | # (batch, inst, freq, time, 2) 182 | sep_stft_hat = realimag(sep_mag_hat, phase_repeat) 183 | 184 | # (batch, inst, time) 185 | separated_wavs = self.istft(sep_stft_hat) 186 | return separated_wavs 187 | 188 | def get_transcripts(self, mix): 189 | """ 190 | Just for inference 191 | 192 | mix : audio(batch, time) 193 | 194 | return : transcription (batch, inst, pitch, time) 195 | """ 196 | prep = self.sample_preprocessing(mix) 197 | mix_mag = prep['mix_mag'] 198 | rt = self(mix_mag) 199 | transcripts = rt['transcripts'] 200 | 201 | # (batch, inst, pitch, time) 202 | return transcripts 203 | 204 | def common_step(self, mini_batch:dict, phase:str): 205 | mini_batch = self.batch_preprocessing(mini_batch) 206 | 207 | rt = self(mini_batch['mix_mag']) 208 | 209 | total_loss = 0. 210 | log = dict() 211 | 212 | if self.mask_inference_loss is not None: 213 | mask_inference_loss = self.mask_inference_loss(rt['separation_mask'], mini_batch['mix_mag'], mini_batch['separation_gt_mag']) 214 | total_loss += self.config.loss_beta * mask_inference_loss 215 | log[f'{phase}_mask_inference_loss'] = mask_inference_loss 216 | 217 | if self.deep_clustering_loss is not None: 218 | deep_clustering_loss = self.deep_clustering_loss(rt['embedding'], mini_batch['separation_gt_mag']) 219 | total_loss += self.config.loss_alpha * deep_clustering_loss 220 | log[f'{phase}_dc_loss'] = deep_clustering_loss 221 | 222 | if self.transcription_loss is not None: 223 | transcription_loss = self.transcription_loss(rt['transcripts'], mini_batch['transcripts_gt']) 224 | total_loss += self.config.loss_gamma * transcription_loss 225 | log[f'{phase}_tr_loss'] = transcription_loss 226 | 227 | log[f'{phase}_total_loss'] = total_loss 228 | self.log_dict(log, on_epoch=True, on_step=False) 229 | return total_loss 230 | 231 | def forward(self, mix_mag): 232 | rt = self.cerberus(mix_mag) 233 | return rt 234 | 235 | def on_train_start(self): 236 | metrics = {k: v for d in OmegaConf.to_container(self.config.metrics) for k, v in d.items()} 237 | if not isinstance(self.logger, pl.loggers.base.DummyLogger): 238 | # dummy logger일 때(auto lr find) 아래에서 에러나서 239 | self.logger.log_hyperparams(self.hparams, metrics=metrics) 240 | 241 | def training_step(self, mini_batch, batch_idx): 242 | phase = "train" 243 | total_loss = self.common_step(mini_batch, phase) 244 | return total_loss 245 | 246 | def validation_step(self, mini_batch, batch_idx): 247 | phase = "valid" 248 | total_loss = self.common_step(mini_batch, phase) 249 | return total_loss 250 | 251 | def test_step(self, mini_batch, batch_idx): 252 | """ 253 | batch -> transcript, separation -> transcription metric 254 | Warning : TOO SLOW 255 | """ 256 | 257 | phase = "test" 258 | batch_size = mini_batch['mix'].shape[0] 259 | 
transcript_est = self.get_transcripts(mini_batch['mix']) 260 | 261 | target_metrics = ['Precision', 'Recall', 'Accuracy'] 262 | 263 | # metrics['piano'] = {'Precision' : [0.33, 0.324], 'Recall' : [0.24, 0.15], 'Accuracy' : [0.33, 0.22]} 264 | metrics = dict() 265 | for inst in self.config.inst: 266 | # inst : str 267 | metrics[inst] = defaultdict(list) 268 | 269 | for i, inst in enumerate(self.config.inst): 270 | # inst : int 271 | for b in range(batch_size): 272 | tr = transcript_est[b, i] 273 | gt = mini_batch['transcripts_gt'][b, i] 274 | threshold = self.config.transcription_threshold / 4 if inst == "Drums" else self.config.transcription_threshold 275 | result = self.transcription_metric(tr, gt, threshold) 276 | 277 | for tm in target_metrics: 278 | metrics[inst][tm].append(result[tm]) 279 | 280 | return metrics 281 | 282 | def test_epoch_end(self, test_out): 283 | target_metrics = ['Precision', 'Recall', 'Accuracy'] 284 | num_samples = 0 285 | for metric in test_out: 286 | num_samples += len(metric[self.config.inst[0]][target_metrics[0]]) 287 | 288 | # final_metrics['piano'] = {'precision': 0.324, 'Recall': .2343 ... } 289 | final_metrics = dict() 290 | for inst in self.config.inst: 291 | # inst : str 292 | final_metrics[inst] = defaultdict(float) 293 | 294 | for out in test_out: 295 | for inst in self.config.inst: 296 | for tm in target_metrics: 297 | final_metrics[inst][tm] += sum(out[inst][tm]) 298 | 299 | for inst in self.config.inst: 300 | for tm in target_metrics: 301 | final_metrics[inst][tm] /= num_samples 302 | 303 | self.log_dict(final_metrics) 304 | 305 | def multiple_piano_roll_to_pretty_midi(self, rt): 306 | # rt = (inst, freq, time) (float tensor) 307 | # return : prettymidi object 308 | 309 | config = self.config 310 | pmidi = pm.PrettyMIDI() 311 | 312 | for i in range(config.num_inst): 313 | threshold = config.transcription_threshold / 4 if config.inst[i] == "Drums" else config.transcription_threshold 314 | is_drum = config.inst[i] == "Drums" 315 | 316 | pr = rt[i].cpu().float().detach() > threshold 317 | pr_pad = torch.nn.functional.pad(pr, (0, 0, config.midi_min, 128-config.num_pitches-config.midi_min)) 318 | instrument = piano_roll_to_pretty_midi(pr_pad.int().numpy(), fs=config.sr/config.hop_length, program=program_dict[config.inst[i]]).instruments[0] 319 | instrument.is_drum = is_drum 320 | pmidi.instruments.append(instrument) 321 | 322 | return pmidi 323 | 324 | def validation_epoch_end(self, val_out: list): 325 | """ 326 | Save Audio / Transcription results on every validation 327 | """ 328 | writer = self.logger.experiment 329 | 330 | mini_batch = next(iter(self.trainer.val_dataloaders[0])) 331 | for k,v in mini_batch.items(): 332 | mini_batch[k] = v.cuda() 333 | 334 | sample_audio, sr = torchaudio.load(self.config.sample_audio.path, frame_offset=self.config.sample_audio.offset, num_frames=self.config.sample_audio.num_frames) 335 | sample_audio = sample_audio.cuda() 336 | 337 | if self.cerberus.separation_head is not None: 338 | y_hat = self.get_separated_sources(mini_batch['mix']) 339 | for i in range(self.config.num_inst): 340 | writer.add_audio(f"{self.config.inst[i]}_hat", y_hat[4, i].cpu(), self.trainer.global_step , sample_rate=self.config.sr) 341 | if not self.saved_gt_to_tensorboard: 342 | writer.add_audio(f"{self.config.inst[i]}_gt", mini_batch['separation_gt'][4, i].cpu(), self.trainer.global_step , sample_rate=self.config.sr) 343 | 344 | y_hat = self.get_separated_sources(sample_audio) 345 | if not self.saved_gt_to_tensorboard: 346 | 
writer.add_audio(f"(real)gt", sample_audio[0].cpu(), self.trainer.global_step , sample_rate=self.config.sr) 347 | 348 | for i in range(self.config.num_inst): 349 | writer.add_audio(f"(real){self.config.inst[i]}_hat", y_hat[0, i].cpu(), self.trainer.global_step , sample_rate=self.config.sr) 350 | 351 | if self.cerberus.transcription_head is not None: 352 | transcripts_hat = self.get_transcripts(mini_batch['mix']) 353 | for i in range(self.config.num_inst): 354 | writer.add_image(f"{self.config.inst[i]}_tr_hat", transcripts_hat[4, i].unsqueeze(0).cpu(), self.trainer.global_step) 355 | writer.add_image(f"{self.config.inst[i]}_tr_hat_{self.config.transcription_threshold}", (transcripts_hat[4, i].unsqueeze(0).cpu() > self.config.transcription_threshold).float(), self.trainer.global_step) 356 | if not self.saved_gt_to_tensorboard: 357 | writer.add_image(f"{self.config.inst[i]}_tr_gt", mini_batch['transcripts_gt'][4, i].unsqueeze(0).cpu(), self.trainer.global_step) 358 | 359 | sample_transcription = self.get_transcripts(sample_audio) 360 | for i in range(self.config.num_inst): 361 | writer.add_image(f"(real){self.config.inst[i]}_tr_hat", sample_transcription[0, i].unsqueeze(0).cpu(), self.trainer.global_step) 362 | writer.add_image(f"(real){self.config.inst[i]}_tr_hat_{self.config.transcription_threshold}", (sample_transcription[0, i].unsqueeze(0).cpu() > self.config.transcription_threshold).float(), self.trainer.global_step) 363 | 364 | if not self.saved_gt_to_tensorboard: 365 | self.saved_gt_to_tensorboard = True 366 | 367 | 368 | def configure_optimizers(self): 369 | config = self.config 370 | 371 | if config.optimizer == 'adam': 372 | optimizer = optim.Adam(self.parameters(), lr=self.lr) 373 | elif config.optimizer == "radam": 374 | from optimizer.radam import RAdam 375 | optimizer = RAdam(self.parameters(), lr=self.lr) 376 | elif config.optimizer == "rmsprop": 377 | optimizer = optim.RMSprop(self.parameters(), lr=self.lr) 378 | else: 379 | raise NotImplementedError 380 | 381 | reduce_on_plateau=False 382 | 383 | # Setting Scheduler 384 | monitor = None 385 | if config.lr_scheduler == "cosine": 386 | scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( 387 | optimizer, T_0=200, eta_min=config.lr_min 388 | ) 389 | 390 | elif config.lr_scheduler == "plateau": 391 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 392 | optimizer, mode="min", patience=3, factor=config.lr_decay 393 | ) 394 | monitor = "valid_tr_loss" 395 | reduce_on_plateau=True 396 | 397 | elif config.lr_scheduler == "multistep": 398 | scheduler = optim.lr_scheduler.MultiStepLR( 399 | optimizer, 400 | [5 * (x + 2) for x in range(500)], 401 | gamma=config.lr_decay, 402 | ) 403 | 404 | elif config.lr_scheduler == "no": 405 | scheduler = None 406 | else: 407 | raise ValueError(f"unknown lr_scheduler :: {config.lr_scheduler}") 408 | 409 | if scheduler is not None: 410 | if monitor is not None: 411 | optimizers = [optimizer] 412 | schedulers = [ 413 | dict( 414 | scheduler=scheduler, 415 | monitor=monitor, 416 | interval='epoch', 417 | reduce_on_plateau=reduce_on_plateau, 418 | frequency=config.check_val_every_n_epoch, 419 | )] 420 | 421 | return optimizers, schedulers 422 | else: 423 | return [optimizer], [scheduler] 424 | else: 425 | return optimizer -------------------------------------------------------------------------------- /train/network/clustering_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class 
ClusteringHead(nn.Module): 5 | def __init__(self, config): 6 | super().__init__() 7 | 8 | # Deep cluster head 9 | dense_in = config.lstm_hidden_size * (int(config.lstm_bidirectional)+1) 10 | dense_out = (config.n_fft//2 + 1) * (config.embedding_size) 11 | 12 | self.num_inst = config.num_inst 13 | self.embedding_size = config.embedding_size 14 | 15 | self.dense_embedding = nn.Linear(dense_in, dense_out) 16 | self.activate = nn.Tanh() 17 | 18 | 19 | def forward(self, shared_representation: torch.Tensor) -> torch.Tensor: 20 | # (batch, time, shared_representation) 21 | 22 | batch, time, embed_numinst = shared_representation.shape 23 | 24 | proj = self.dense_embedding(shared_representation) # (batch, time, embedding_size * n_frequency) 25 | proj = self.activate(proj) 26 | 27 | proj = proj.view(batch, -1, self.embedding_size) # (batch, time * n_frequency, embedding_size) 28 | proj_norm = torch.norm(proj, p=2, dim=-1, keepdim=True) 29 | proj_one = proj / (proj_norm + 1e-12) 30 | 31 | return proj_one 32 | -------------------------------------------------------------------------------- /train/network/separation_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | 5 | class SeparationHead(nn.Module): 6 | def __init__(self, config): 7 | super().__init__() 8 | 9 | # Mask heads 10 | dense_in = config.lstm_hidden_size * (int(config.lstm_bidirectional)+1) 11 | dense_out = (config.n_fft//2 + 1) * config.num_inst 12 | 13 | self.num_inst = config.num_inst 14 | self.dense = nn.Linear(dense_in, dense_out) 15 | self.activate = nn.Softmax2d() 16 | 17 | 18 | def forward(self, shared_representation: torch.Tensor) -> torch.Tensor: 19 | # (batch, time, shared_representation) 20 | 21 | mask = self.dense(shared_representation) # (batch, time, frequency * num_inst) mask 22 | batch, time, frequency_inst = mask.shape 23 | mask = mask.view(batch, time, self.num_inst, -1) # (batch, time, num_inst, frequency) 24 | masks = mask.permute(0, 2, 1, 3) # (batch, inst, time, freq) 25 | masks = self.activate(masks) 26 | 27 | return masks -------------------------------------------------------------------------------- /train/network/shared_body.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | 5 | class SharedBody(nn.Module): 6 | def __init__(self, config): 7 | super().__init__() 8 | input_size = config.n_fft//2 + 1 9 | 10 | self.lstm = nn.LSTM(batch_first=True, 11 | input_size=input_size, 12 | hidden_size=config.lstm_hidden_size, 13 | num_layers=config.lstm_num_layers, 14 | bidirectional=config.lstm_bidirectional 15 | ) 16 | 17 | self.dropout = nn.Dropout(p=config.dropout_rate, inplace=False) 18 | 19 | def forward(self, spec) -> torch.Tensor: 20 | lstm_embed, (h, c) = self.lstm(spec) # (batch, time, lstm_embedding_dim) 21 | shared_representation = self.dropout(lstm_embed) 22 | 23 | return shared_representation -------------------------------------------------------------------------------- /train/network/transcription_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class TranscriptionHead(nn.Module): 5 | def __init__(self, config): 6 | super().__init__() 7 | self.num_pitches = config.num_pitches 8 | self.num_inst = config.num_inst 9 | 10 | dense_in = config.lstm_hidden_size * (int(config.lstm_bidirectional)+1) 11 | dense_out = 
self.num_pitches * self.num_inst 12 | 13 | self.dense = nn.Linear(dense_in, dense_out) 14 | self.activate = nn.Sigmoid() 15 | 16 | def forward(self, shared_representation): 17 | # (batch, time, shared_representation) 18 | batch, time, embed_numinst = shared_representation.shape 19 | 20 | transcript = self.dense(shared_representation) # (batch, time, num_inst * pitches) 21 | transcript = self.activate(transcript) 22 | transcript = transcript.view(batch, time, self.num_inst, -1) 23 | transcript = transcript.permute(0, 2, 3, 1) # (batch, num_inst, pitch, time) 24 | 25 | return transcript -------------------------------------------------------------------------------- /train/network/transform_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | 5 | class ISTFT(nn.Module): 6 | def __init__(self, config): 7 | super().__init__() 8 | self.hop_length = config.hop_length 9 | self.n_fft = config.n_fft 10 | window = torch.hann_window(config.n_fft) 11 | self.register_buffer("window", window) 12 | 13 | def forward(self, X : torch.Tensor): 14 | 15 | num_inst = 1 16 | if len(X.shape) == 5: 17 | # 악기별 X일 때 (batch, inst, freq, time, 2) 18 | num_inst = X.shape[1] 19 | X = X.view(X.shape[0]*X.shape[1], X.shape[2], X.shape[3], 2) 20 | 21 | x = torch.istft(X, 22 | n_fft=self.n_fft, 23 | hop_length=self.hop_length, 24 | window=self.window, 25 | return_complex=False 26 | ) 27 | if num_inst != 1: 28 | x = x.view(x.shape[0] // num_inst, num_inst, -1) 29 | 30 | return x -------------------------------------------------------------------------------- /train/optimizer/radam.py: -------------------------------------------------------------------------------- 1 | # original : 2 | # https://github.com/LiyuanLucasLiu/RAdam/blob/master/radam/radam.py 3 | import math 4 | import torch 5 | from torch.optim.optimizer import Optimizer, required 6 | 7 | class RAdam(Optimizer): 8 | 9 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 10 | if not 0.0 <= lr: 11 | raise ValueError("Invalid learning rate: {}".format(lr)) 12 | if not 0.0 <= eps: 13 | raise ValueError("Invalid epsilon value: {}".format(eps)) 14 | if not 0.0 <= betas[0] < 1.0: 15 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 16 | if not 0.0 <= betas[1] < 1.0: 17 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 18 | 19 | self.degenerated_to_sgd = degenerated_to_sgd 20 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 21 | for param in params: 22 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 23 | param['buffer'] = [[None, None, None] for _ in range(10)] 24 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 25 | super(RAdam, self).__init__(params, defaults) 26 | 27 | def __setstate__(self, state): 28 | super(RAdam, self).__setstate__(state) 29 | 30 | def step(self, closure=None): 31 | 32 | loss = None 33 | if closure is not None: 34 | loss = closure() 35 | 36 | for group in self.param_groups: 37 | 38 | for p in group['params']: 39 | if p.grad is None: 40 | continue 41 | grad = p.grad.data.float() 42 | if grad.is_sparse: 43 | raise RuntimeError('RAdam does not support sparse gradients') 44 | 45 | p_data_fp32 = p.data.float() 46 | 47 | state = self.state[p] 48 | 49 | if len(state) == 0: 50 | 
state['step'] = 0 51 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 52 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 53 | else: 54 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 55 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 56 | 57 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 58 | beta1, beta2 = group['betas'] 59 | 60 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value = 1 - beta2) 61 | exp_avg.mul_(beta1).add_(grad, alpha = 1 - beta1) 62 | 63 | state['step'] += 1 64 | buffered = group['buffer'][int(state['step'] % 10)] 65 | if state['step'] == buffered[0]: 66 | N_sma, step_size = buffered[1], buffered[2] 67 | else: 68 | buffered[0] = state['step'] 69 | beta2_t = beta2 ** state['step'] 70 | N_sma_max = 2 / (1 - beta2) - 1 71 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 72 | buffered[1] = N_sma 73 | 74 | # more conservative since it's an approximated value 75 | if N_sma >= 5: 76 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 77 | elif self.degenerated_to_sgd: 78 | step_size = 1.0 / (1 - beta1 ** state['step']) 79 | else: 80 | step_size = -1 81 | buffered[2] = step_size 82 | 83 | # more conservative since it's an approximated value 84 | if N_sma >= 5: 85 | if group['weight_decay'] != 0: 86 | p_data_fp32.add_(p_data_fp32, alpha = -group['weight_decay'] * group['lr']) 87 | denom = exp_avg_sq.sqrt().add_(group['eps']) 88 | p_data_fp32.addcdiv_(exp_avg, denom, value = -step_size * group['lr']) 89 | p.data.copy_(p_data_fp32) 90 | elif step_size > 0: 91 | if group['weight_decay'] != 0: 92 | p_data_fp32.add_(p_data_fp32, alpha = -group['weight_decay'] * group['lr']) 93 | p_data_fp32.add_(exp_avg, alpha = -step_size * group['lr']) 94 | p.data.copy_(p_data_fp32) 95 | 96 | return loss 97 | 98 | class PlainRAdam(Optimizer): 99 | 100 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 101 | if not 0.0 <= lr: 102 | raise ValueError("Invalid learning rate: {}".format(lr)) 103 | if not 0.0 <= eps: 104 | raise ValueError("Invalid epsilon value: {}".format(eps)) 105 | if not 0.0 <= betas[0] < 1.0: 106 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 107 | if not 0.0 <= betas[1] < 1.0: 108 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 109 | 110 | self.degenerated_to_sgd = degenerated_to_sgd 111 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 112 | 113 | super(PlainRAdam, self).__init__(params, defaults) 114 | 115 | def __setstate__(self, state): 116 | super(PlainRAdam, self).__setstate__(state) 117 | 118 | def step(self, closure=None): 119 | 120 | loss = None 121 | if closure is not None: 122 | loss = closure() 123 | 124 | for group in self.param_groups: 125 | 126 | for p in group['params']: 127 | if p.grad is None: 128 | continue 129 | grad = p.grad.data.float() 130 | if grad.is_sparse: 131 | raise RuntimeError('RAdam does not support sparse gradients') 132 | 133 | p_data_fp32 = p.data.float() 134 | 135 | state = self.state[p] 136 | 137 | if len(state) == 0: 138 | state['step'] = 0 139 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 140 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 141 | else: 142 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 143 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 144 | 145 | exp_avg, exp_avg_sq = 
state['exp_avg'], state['exp_avg_sq'] 146 | beta1, beta2 = group['betas'] 147 | 148 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 149 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 150 | 151 | state['step'] += 1 152 | beta2_t = beta2 ** state['step'] 153 | N_sma_max = 2 / (1 - beta2) - 1 154 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 155 | 156 | 157 | # more conservative since it's an approximated value 158 | if N_sma >= 5: 159 | if group['weight_decay'] != 0: 160 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 161 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 162 | denom = exp_avg_sq.sqrt().add_(group['eps']) 163 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 164 | p.data.copy_(p_data_fp32) 165 | elif self.degenerated_to_sgd: 166 | if group['weight_decay'] != 0: 167 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 168 | step_size = group['lr'] / (1 - beta1 ** state['step']) 169 | p_data_fp32.add_(-step_size, exp_avg) 170 | p.data.copy_(p_data_fp32) 171 | 172 | return loss 173 | 174 | 175 | class AdamW(Optimizer): 176 | 177 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 178 | if not 0.0 <= lr: 179 | raise ValueError("Invalid learning rate: {}".format(lr)) 180 | if not 0.0 <= eps: 181 | raise ValueError("Invalid epsilon value: {}".format(eps)) 182 | if not 0.0 <= betas[0] < 1.0: 183 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 184 | if not 0.0 <= betas[1] < 1.0: 185 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 186 | 187 | defaults = dict(lr=lr, betas=betas, eps=eps, 188 | weight_decay=weight_decay, warmup = warmup) 189 | super(AdamW, self).__init__(params, defaults) 190 | 191 | def __setstate__(self, state): 192 | super(AdamW, self).__setstate__(state) 193 | 194 | def step(self, closure=None): 195 | loss = None 196 | if closure is not None: 197 | loss = closure() 198 | 199 | for group in self.param_groups: 200 | 201 | for p in group['params']: 202 | if p.grad is None: 203 | continue 204 | grad = p.grad.data.float() 205 | if grad.is_sparse: 206 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 207 | 208 | p_data_fp32 = p.data.float() 209 | 210 | state = self.state[p] 211 | 212 | if len(state) == 0: 213 | state['step'] = 0 214 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 215 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 216 | else: 217 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 218 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 219 | 220 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 221 | beta1, beta2 = group['betas'] 222 | 223 | state['step'] += 1 224 | 225 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 226 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 227 | 228 | denom = exp_avg_sq.sqrt().add_(group['eps']) 229 | bias_correction1 = 1 - beta1 ** state['step'] 230 | bias_correction2 = 1 - beta2 ** state['step'] 231 | 232 | if group['warmup'] > state['step']: 233 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 234 | else: 235 | scheduled_lr = group['lr'] 236 | 237 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 238 | 239 | if group['weight_decay'] != 0: 240 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, 
p_data_fp32) 241 | 242 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 243 | 244 | p.data.copy_(p_data_fp32) 245 | 246 | return loss -------------------------------------------------------------------------------- /train/requirements.txt: -------------------------------------------------------------------------------- 1 | backtrace 2 | librosa 3 | mido 4 | mir-eval==0.6 5 | numpy 6 | omegaconf 7 | pretty-midi 8 | pytorch-lightning==1.0.4 9 | torch==1.7.0+cu92 10 | torchaudio==0.7.0 11 | torchsummary 12 | torchvision==0.8.1 13 | tqdm -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 5 | from pytorch_lightning.loggers import TensorBoardLogger 6 | 7 | from omegaconf import OmegaConf 8 | 9 | from network.cerberus_wrapper import CerberusWrapper 10 | from dataset.slakh2100 import Slakh2100 11 | from utils.debug import set_debug_mode 12 | 13 | set_debug_mode() 14 | config = OmegaConf.load("../configs/config.yaml") 15 | config.merge_with_cli() 16 | 17 | pl.trainer.seed_everything(config.seed) 18 | 19 | net = CerberusWrapper(config) 20 | 21 | ckpt_callback = ModelCheckpoint( 22 | monitor="valid_total_loss", 23 | filename="model-{epoch:03d}-{valid_total_loss:.4f}", 24 | save_top_k=1, 25 | mode='min', 26 | save_last=True 27 | ) 28 | 29 | logger = TensorBoardLogger(save_dir="lightning_logs", 30 | name=config.experiment_name, 31 | default_hp_metric=False) 32 | 33 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 34 | 35 | trainer = pl.Trainer( 36 | gpus=config.num_gpu, 37 | min_epochs=1, 38 | max_epochs=config.num_epochs, 39 | checkpoint_callback=ckpt_callback, 40 | check_val_every_n_epoch=config.check_val_every_n_epoch, 41 | gradient_clip_val=config.gradient_clip_val, 42 | auto_lr_find=config.find_lr, 43 | logger=logger, 44 | callbacks=[lr_monitor], 45 | weights_summary='full', 46 | profiler='simple', 47 | ) 48 | 49 | if config.find_lr: 50 | trainer.tune(net) 51 | 52 | trainer.fit(net) 53 | 54 | trainer.test(model=net) 55 | -------------------------------------------------------------------------------- /train/utils/debug.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import backtrace 3 | 4 | def set_debug_mode(): 5 | backtrace.hook(align=True) 6 | old_hook = sys.excepthook 7 | 8 | def new_hook(type_, value, tb): 9 | old_hook(type_, value, tb) 10 | if type_ != KeyboardInterrupt: 11 | import pdb 12 | 13 | pdb.post_mortem(tb) 14 | 15 | sys.excepthook = new_hook 16 | 17 | -------------------------------------------------------------------------------- /train/utils/decoding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original Implementation from Jongwook's 3 | https://github.com/jongwook/onsets-and-frames/blob/master/onsets_and_frames/decoding.py 4 | """ 5 | 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def extract_notes(onsets, frames, onset_threshold=0.5, frame_threshold=0.5): 11 | """ 12 | Finds the note timings based on the onsets and frames information 13 | Parameters 14 | ---------- 15 | onsets: torch.FloatTensor, shape = [frames, bins] 16 | frames: torch.FloatTensor, shape = [frames, bins] 17 | onset_threshold: float 18 | frame_threshold: float 19 | Returns 20 | ------- 21 | pitches: 
np.ndarray of bin_indices 22 | intervals: np.ndarray of rows containing (onset_index, offset_index) 23 | """ 24 | onsets = (onsets > onset_threshold).cpu().to(torch.uint8) 25 | frames = (frames > frame_threshold).cpu().to(torch.uint8) 26 | onset_diff = torch.cat([onsets[:1, :], onsets[1:, :] - onsets[:-1, :]], dim=0) == 1 27 | 28 | pitches = [] 29 | intervals = [] 30 | 31 | for nonzero in onset_diff.nonzero(): 32 | frame = nonzero[0].item() 33 | pitch = nonzero[1].item() 34 | 35 | onset = frame 36 | offset = frame 37 | 38 | while onsets[offset, pitch].item() or frames[offset, pitch].item(): 39 | offset += 1 40 | if offset == onsets.shape[0]: 41 | break 42 | 43 | if offset > onset: 44 | pitches.append(pitch) 45 | intervals.append([onset, offset]) 46 | 47 | return np.array(pitches), np.array(intervals) 48 | 49 | 50 | def notes_to_frames(pitches, intervals, shape): 51 | """ 52 | Takes lists specifying notes sequences and return 53 | Parameters 54 | ---------- 55 | pitches: list of pitch bin indices 56 | intervals: list of [onset, offset] ranges of bin indices 57 | shape: the shape of the original piano roll, [n_frames, n_bins] 58 | Returns 59 | ------- 60 | time: np.ndarray containing the frame indices 61 | freqs: list of np.ndarray, each containing the frequency bin indices 62 | """ 63 | roll = np.zeros(tuple(shape)) 64 | for pitch, (onset, offset) in zip(pitches, intervals): 65 | roll[onset:offset, pitch] = 1 66 | 67 | time = np.arange(roll.shape[0]) 68 | freqs = [roll[t, :].nonzero()[0] for t in time] 69 | return time, freqs -------------------------------------------------------------------------------- /train/utils/dsp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def apply_masks(mask, mix_mag, num_inst): 4 | mix_shape = mix_mag.shape 5 | mix_mag_repeat = mix_mag.repeat_interleave(num_inst, 0).view(mix_shape[0], num_inst, mix_shape[1], mix_shape[2]) 6 | sep_mag_hat = mask * mix_mag_repeat 7 | return sep_mag_hat 8 | 9 | def realimag(mag, phase): 10 | """ 11 | Combine a magnitude spectrogram and a phase spectrogram to a complex-valued spectrogram with shape (*, 2) 12 | """ 13 | spec_real = mag * torch.cos(phase) 14 | spec_imag = mag * torch.sin(phase) 15 | spec = torch.stack([spec_real, spec_imag], dim=-1) 16 | return spec --------------------------------------------------------------------------------
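Below is a minimal sketch (not part of the repository) of how the helpers in `train/utils/dsp.py` fit into the separation path, mirroring `CerberusWrapper.get_separated_sources()`: the mixture is turned into magnitude and phase with the torchaudio 0.7 API pinned in `requirements.txt`, a mask is applied with `apply_masks`, and `realimag` recombines the masked magnitudes with the mixture phase for the inverse STFT. The randomly generated mask here merely stands in for the network output, the shapes follow `configs/config.yaml` (`n_fft=1024`, `hop_length=256`, `num_inst=3`), and the import assumes the snippet is run from the `train/` directory, as `inference.py` is.

```python
import torch
import torchaudio

from utils.dsp import apply_masks, realimag

n_fft, hop_length, num_inst = 1024, 256, 3
stft = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)

mix = torch.randn(2, 32000)                          # (batch, time) dummy 2-second mixture at 16 kHz
mix_stft = stft(mix)                                 # (batch, freq, time, 2) since power=None (torchaudio 0.7)
mix_mag, phase = torchaudio.functional.magphase(mix_stft, power=1.0)
mix_mag = mix_mag.permute(0, 2, 1)                   # (batch, time, freq), as the network expects

# Random mask standing in for the separation head output (already normalized over instruments).
mask = torch.softmax(torch.randn(2, num_inst, mix_mag.shape[1], mix_mag.shape[2]), dim=1)

sep_mag_hat = apply_masks(mask, mix_mag, num_inst)   # (batch, inst, time, freq)
sep_mag_hat = sep_mag_hat.permute(0, 1, 3, 2)        # (batch, inst, freq, time)

# Reuse the mixture phase for every estimated source, then build complex-as-real spectrograms.
b, f, t = phase.shape
phase_rep = phase.repeat_interleave(num_inst, 0).view(b, num_inst, f, t)
sep_stft_hat = realimag(sep_mag_hat, phase_rep)      # (batch, inst, freq, time, 2)
print(sep_stft_hat.shape)
```

In the repository this reconstruction is finished by `network.transform_layer.ISTFT`, which flattens the instrument dimension before calling `torch.istft` and reshapes the waveforms back to `(batch, inst, time)`.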