├── .gitignore ├── README.md ├── config.yaml ├── download └── download.py ├── evaluate ├── evaluate_testset.py ├── inference_testset.py └── midi_melody_accuracy.py ├── layer ├── __init__.py └── input.py ├── midi_tokenizer.py ├── midiaudiopair.py ├── preprocess ├── README.md ├── beat_quantizer.py ├── bpm_quantize.py ├── melody_accuracy.py ├── pop_align.py └── split_spleeter.py ├── requirements.txt ├── train_dataset.csv ├── transformer_wrapper.py └── utils ├── __init__.py ├── demo.py └── dsp.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.zip 3 | .ipynb_checkpoints 4 | .vscode 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pop2Piano : Pop Audio-based Piano Cover Generation 2 | --- 3 | ## 4 | - [Paper](https://arxiv.org/abs/2211.00895) 5 | - [Colab](http://bit.ly/pop2piano-colab) 6 | - [Project Page](http://sweetcocoa.github.io/pop2piano_samples) 7 | 8 | ## How to prepare dataset 9 | ### Download Original Media 10 | --- 11 | - List of data : ```train_dataset.csv``` 12 | - Downloader : ```download/download.py``` 13 | - ```python download.py ../train_dataset.csv output_dir/``` 14 | 15 | ### Preprocess Data 16 | --- 17 | - [Details](./preprocess/) 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | project: pop2piano 2 | dataset: 3 | target_length: 256 4 | input_length: 1024 5 | n_bars: 2 6 | sample_rate: 22050 7 | use_mel: true 8 | mel_is_conditioned: true 9 | composer_to_feature_token: 10 | composer1: 2052 11 | composer2: 2053 12 | composer3: 2054 13 | composer4: 2055 14 | composer5: 2056 15 | composer6: 2057 16 | composer7: 2058 17 | composer8: 2059 18 | composer9: 2060 19 | composer10: 2061 20 | composer11: 2062 21 | composer12: 2063 22 | composer13: 2064 23 | composer14: 2065 24 | composer15: 2066 25 | composer16: 2067 26 | composer17: 2068 27 | composer18: 2069 28 | composer19: 2070 29 | composer20: 2071 30 | composer21: 2072 31 | t5: 32 | feed_forward_proj: gated-gelu 33 | tie_word_embeddings: false 34 | tie_encoder_decoder: false 35 | vocab_size: 2400 36 | n_positions: 1024 37 | relative_attention_num_buckets: 32 38 | tokenizer: 39 | vocab_size: 40 | special: 4 41 | note: 128 42 | velocity: 2 43 | time: 100 44 | training: 45 | seed: 3407 46 | resume: false 47 | offline: false 48 | num_gpu: 1 49 | max_epochs: 5000 50 | accumulate_grad_batches: 1 51 | check_val_every_n_epoch: 20 52 | find_lr: false 53 | optimizer: adafactor 54 | version: none 55 | lr: 0.001 56 | lr_min: 1.0e-06 57 | lr_scheduler: false 58 | lr_decay: 0.99 59 | batch_size: 32 60 | num_workers: 32 61 | gradient_clip_val: 3.0 62 | -------------------------------------------------------------------------------- /download/download.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python youtube_down.py piano_covers.txt /output/dir 4 | """ 5 | 6 | import os 7 | import multiprocessing 8 | 9 | import tempfile 10 | import shutil 11 | import glob 12 | import pandas as pd 13 | import re 14 | 15 | from tqdm import tqdm 16 | from joblib import Parallel, delayed 17 | from omegaconf import OmegaConf 18 | 19 | 20 | def download_piano( 21 | url: str, 22 | output_dir: str, 23 | postprocess=True, 24 | dry_run=False, 25 | ) -> 
int: 26 | # os.makedirs(os.path.dirname(output_dir), exist_ok=True) 27 | 28 | with tempfile.TemporaryDirectory() as tmpdir: 29 | output = f"{tmpdir}/%(uploader)s___%(title)s___%(id)s___%(duration)d.%(ext)s" 30 | 31 | if postprocess: 32 | postprocess_call = '--postprocessor-args "-ac 1 -ar 16000"' 33 | else: 34 | postprocess_call = "" 35 | result = os.system( 36 | f"""youtube-dl -o "{output}" \\ 37 | --extract-audio \\ 38 | --audio-quality 0 \\ 39 | --audio-format wav \\ 40 | --retries 50 \\ 41 | --prefer-ffmpeg \\ 42 | {"--get-filename" if dry_run else ""}\\ 43 | {postprocess_call} \\ 44 | --force-ipv4 \\ 45 | --yes-playlist \\ 46 | --ignore-errors \\ 47 | {url}""" 48 | ) 49 | 50 | if not dry_run: 51 | 52 | files = os.listdir(tmpdir) 53 | 54 | for filename in files: 55 | filename_wo_ext, ext = os.path.splitext(filename) 56 | uploader, title, ytid, duration = filename_wo_ext.split("___") 57 | meta = OmegaConf.create() 58 | meta.piano = OmegaConf.create() 59 | meta.piano.uploader = uploader 60 | meta.piano.title = title 61 | meta.piano.ytid = ytid 62 | meta.piano.duration = int(duration) 63 | OmegaConf.save(meta, os.path.join(output_dir, ytid + ".yaml")) 64 | shutil.move( 65 | os.path.join(tmpdir, filename), 66 | os.path.join(output_dir, f"{ytid}{ext}"), 67 | ) 68 | 69 | return result 70 | 71 | 72 | def download_piano_main(piano_list, output_dir, dry_run=False): 73 | """ 74 | piano_list : list of youtube id 75 | """ 76 | os.makedirs(output_dir, exist_ok=True) 77 | Parallel(n_jobs=multiprocessing.cpu_count())( 78 | delayed(download_piano)( 79 | url=f"https://www.youtube.com/watch?v={ytid}", 80 | output_dir=output_dir, 81 | postprocess=True, 82 | dry_run=dry_run, 83 | ) 84 | for ytid in tqdm(piano_list) 85 | ) 86 | 87 | 88 | def download_pop(piano_id, pop_id, output_dir, dry_run): 89 | output_file_template = "%(id)s___%(title)s___%(duration)d.%(ext)s" 90 | pop_output_dir = os.path.join(output_dir, piano_id) 91 | os.makedirs(pop_output_dir, exist_ok=True) 92 | output_template = os.path.join(output_dir, piano_id, output_file_template) 93 | url = f"https://www.youtube.com/watch?v={pop_id}" 94 | 95 | result = os.system( 96 | f"""youtube-dl -o "{output_template}" \\ 97 | --extract-audio \\ 98 | --audio-quality 0 \\ 99 | --audio-format wav \\ 100 | --retries 25 \\ 101 | {"--get-filename" if dry_run else ""}\\ 102 | --prefer-ffmpeg \\ 103 | --match-filter 'duration < 300 & duration > 150'\\ 104 | --postprocessor-args "-ac 2 -ar 44100" \\ 105 | {url}""" 106 | ) 107 | 108 | if not dry_run: 109 | files = list(filter(lambda x: x.endswith(".wav"), os.listdir(pop_output_dir))) 110 | files = glob.glob(os.path.join(pop_output_dir, "*.wav")) 111 | for filename in files: 112 | filename_wo_ext, ext = os.path.splitext(os.path.basename(filename)) 113 | ytid, title, duration = filename_wo_ext.split("___") 114 | yaml = os.path.join(output_dir, piano_id + ".yaml") 115 | 116 | meta = OmegaConf.load(yaml) 117 | meta.song = OmegaConf.create() 118 | meta.song.ytid = ytid 119 | meta.song.title = title 120 | meta.song.duration = int(duration) 121 | 122 | OmegaConf.save(meta, yaml) 123 | shutil.move( 124 | os.path.join(filename), 125 | os.path.join(pop_output_dir, f"{ytid}{ext}"), 126 | ) 127 | 128 | 129 | def download_pop_main(piano_list, pop_list, output_dir, dry_run=False): 130 | """ 131 | piano_list : list of youtube id 132 | pop_list : corresponding youtube id of pop songs 133 | """ 134 | 135 | Parallel(n_jobs=multiprocessing.cpu_count())( 136 | delayed(download_pop)( 137 | piano_id=piano_id, 138 | pop_id=pop_id, 139 
| output_dir=output_dir, 140 | dry_run=dry_run, 141 | ) 142 | for piano_id, pop_id in tqdm(list(zip(piano_list, pop_list))) 143 | ) 144 | 145 | 146 | if __name__ == "__main__": 147 | import argparse 148 | 149 | parser = argparse.ArgumentParser(description="piano cover downloader") 150 | 151 | parser.add_argument("dataset", type=str, default=None, help="provided csv") 152 | parser.add_argument("output_dir", type=str, default=None, help="output dir") 153 | parser.add_argument( 154 | "--num_audio", 155 | type=int, 156 | default=None, 157 | help="if specified, only {num_audio} pairs will be downloaded", 158 | ) 159 | parser.add_argument( 160 | "--dry_run", default=False, action="store_true", help="whether dry_run" 161 | ) 162 | args = parser.parse_args() 163 | 164 | df = pd.read_csv(args.dataset) 165 | df = df[: args.num_audio] 166 | piano_list = df["piano_ids"].tolist() 167 | download_piano_main(piano_list, args.output_dir, args.dry_run) 168 | 169 | available_piano_list = glob.glob(args.output_dir + "/**/*.yaml", recursive=True) 170 | df.index = df["piano_ids"] 171 | 172 | failed_piano = [] 173 | 174 | available_piano_list_id = [ 175 | os.path.splitext(os.path.basename(ap))[0] for ap in available_piano_list 176 | ] 177 | 178 | for piano_id_to_be_downloaded in tqdm(df["piano_ids"]): 179 | if piano_id_to_be_downloaded in available_piano_list_id: 180 | continue 181 | else: 182 | failed_piano.append(piano_id_to_be_downloaded) 183 | 184 | if len(failed_piano) > 0: 185 | print(f"{len(failed_piano)} of files are failed to be downloaded") 186 | df = df.drop(index=failed_piano) 187 | 188 | piano_list = df["piano_ids"].tolist() 189 | pop_list = df["pop_ids"].tolist() 190 | 191 | download_pop_main( 192 | piano_list, pop_list, output_dir=args.output_dir, dry_run=args.dry_run 193 | ) 194 | -------------------------------------------------------------------------------- /evaluate/evaluate_testset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import sys 3 | import os 4 | 5 | import librosa 6 | import pretty_midi 7 | 8 | from omegaconf import OmegaConf 9 | 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | from midiaudiopair import MidiAudioPair 12 | from evaluate import midi_melody_accuracy as ma 13 | from transformer_wrapper import DEFAULT_COMPOSERS 14 | 15 | 16 | def evaluate(meta_file, composer_dic, model_id): 17 | 18 | import warnings 19 | 20 | warnings.filterwarnings(action="ignore") 21 | 22 | sample = MidiAudioPair(meta_file) 23 | 24 | if ( 25 | sample.error_code == MidiAudioPair.NO_PIANO 26 | or sample.error_code == MidiAudioPair.NO_SONG_DIR 27 | or sample.error_code == MidiAudioPair.NO_SONG 28 | ): 29 | return 30 | 31 | if "vocals" in sample.invalids: 32 | print("no vocal:", meta_file) 33 | return 34 | 35 | vocals, sr = librosa.load(sample.vocals, sr=44100) 36 | HOP_LENGTH = 1024 37 | f0, _, _ = ma._f0(vocals, sr, hop_length=HOP_LENGTH) 38 | 39 | chroma_accuracys = list() 40 | 41 | for composer, value in composer_dic.items(): 42 | midi_path = sample.generated(composer, model_id) 43 | midi = pretty_midi.PrettyMIDI(midi_path) 44 | chroma_accuracy, pitch_accuracy = ma._evaluate_melody(midi, f0, sr, HOP_LENGTH) 45 | result = sample.result_json(model_id) 46 | if os.path.exists(result): 47 | result_json = OmegaConf.load(result) 48 | else: 49 | result_json = OmegaConf.create() 50 | result_json[composer] = OmegaConf.create() 51 | result_json[composer].melody_chroma_accuracy = chroma_accuracy.item() 52 | 
result_json[composer].melody_pitch_accuracy = pitch_accuracy.item() 53 | OmegaConf.save(result_json, result) 54 | chroma_accuracys.append(chroma_accuracy) 55 | 56 | mean_accuracy = sum(chroma_accuracys) / len(chroma_accuracys) 57 | gt_accuracy = sample.yaml.eval.melody_chroma_accuracy 58 | print(gt_accuracy, mean_accuracy) 59 | 60 | return mean_accuracy 61 | 62 | 63 | def main(meta_files, composer_config, model_id, **kwargs): 64 | from tqdm.auto import tqdm 65 | import multiprocessing 66 | from joblib import Parallel, delayed 67 | 68 | if composer_config is None: 69 | composer_dic = DEFAULT_COMPOSERS 70 | else: 71 | composer_dic = OmegaConf.load(composer_config) 72 | 73 | # for meta_file in tqdm(meta_files): 74 | # evaluate(meta_file, composer_dic, model_id) 75 | 76 | mean_accuracys = Parallel(n_jobs=multiprocessing.cpu_count() // 2)( 77 | delayed(evaluate)(meta_file, composer_dic, model_id) 78 | for meta_file in tqdm(meta_files) 79 | ) 80 | 81 | print( 82 | "Total Accuracy of", model_id, "is", sum(mean_accuracys) / len(mean_accuracys) 83 | ) 84 | 85 | 86 | if __name__ == "__main__": 87 | import argparse 88 | 89 | parser = argparse.ArgumentParser(description="eval melody accuracy") 90 | 91 | parser.add_argument( 92 | "data_dir", 93 | type=str, 94 | default=None, 95 | help="""directory contains {id}/{pop_filename.wav} 96 | """, 97 | ) 98 | 99 | parser.add_argument( 100 | "--composer_config", 101 | type=str, 102 | default=None, 103 | help="""config composer_to_token.yaml""", 104 | ) 105 | 106 | parser.add_argument( 107 | "--model_id", 108 | type=str, 109 | default="model_id", 110 | help="""model id""", 111 | ) 112 | 113 | args = parser.parse_args() 114 | 115 | meta_files = sorted(glob.glob(args.data_dir + "/**/*.yaml", recursive=True)) 116 | print("meta ", len(meta_files)) 117 | 118 | main(meta_files, **vars(args)) 119 | -------------------------------------------------------------------------------- /evaluate/inference_testset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | 5 | import librosa 6 | import torch 7 | import numpy as np 8 | import pretty_midi 9 | from omegaconf import OmegaConf 10 | from tqdm.auto import tqdm 11 | 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | from midiaudiopair import MidiAudioPair 14 | from transformer_wrapper import TransformerWrapper, DEFAULT_COMPOSERS 15 | from evaluate import midi_melody_accuracy as ma 16 | from sweetdebug import sweetdebug 17 | 18 | 19 | def inference_main(meta_files, ckpt, config, id, **kwargs): 20 | 21 | import warnings 22 | 23 | sweetdebug(use_telegram_if_cache_exists=False) 24 | warnings.filterwarnings(action="ignore") 25 | 26 | config = OmegaConf.load(config) 27 | wrapper = TransformerWrapper(config) 28 | wrapper = wrapper.load_from_checkpoint(ckpt, config=config).cuda() 29 | wrapper.eval() 30 | 31 | with torch.no_grad(): 32 | for meta_file in tqdm(meta_files): 33 | sample = MidiAudioPair(meta_file) 34 | 35 | # Pass if the midi of all composers are generated. 
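            # (that is: when every composer's MIDI already exists for this
            #  sample it is skipped entirely, so interrupted inference runs
            #  can resume without redoing finished work)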
36 | # ------------------------------------------- 37 | some_not_generated = False 38 | for composer, value in wrapper.composer_to_feature_token.items(): 39 | midi_path = sample.generated(composer=composer, generated=id) 40 | os.makedirs(os.path.dirname(midi_path), exist_ok=True) 41 | 42 | if not os.path.exists(midi_path): 43 | some_not_generated = True 44 | 45 | all_generated = not some_not_generated 46 | if all_generated: 47 | continue 48 | # --------------------------------------------- 49 | 50 | # load pre-computed beats 51 | # ------------------------------------ 52 | beatstep = np.load(sample.beatstep) 53 | # ------------------------------------ 54 | 55 | # load audio if needed 56 | if wrapper.use_mel: 57 | y, sr = librosa.load(sample.song, sr=config.dataset.sample_rate) 58 | vqvae_token = None 59 | else: 60 | vqvae_token = torch.load(sample.vqvae, map_location="cuda") 61 | y = None 62 | sr = None 63 | 64 | for composer, value in wrapper.composer_to_feature_token.items(): 65 | midi_path = sample.generated(composer=composer, generated=id) 66 | os.makedirs(os.path.dirname(midi_path), exist_ok=True) 67 | 68 | if os.path.exists(midi_path): 69 | continue 70 | 71 | wrapper.generate( 72 | audio_path=None, 73 | composer=composer, 74 | model=id, 75 | save_midi=True, 76 | save_mix=False, 77 | show_plot=False, 78 | midi_path=midi_path, 79 | vqvae_token=vqvae_token, 80 | beatsteps=beatstep - beatstep[0], 81 | audio_y=y, 82 | audio_sr=sr, 83 | ) 84 | 85 | 86 | def evaluate(meta_file, composer_dic, model_id): 87 | 88 | import warnings 89 | 90 | warnings.filterwarnings(action="ignore") 91 | 92 | sample = MidiAudioPair(meta_file) 93 | 94 | if ( 95 | sample.error_code == MidiAudioPair.NO_PIANO 96 | or sample.error_code == MidiAudioPair.NO_SONG_DIR 97 | or sample.error_code == MidiAudioPair.NO_SONG 98 | ): 99 | return 100 | 101 | if "vocals" in sample.invalids: 102 | print("no vocal:", meta_file) 103 | return 104 | 105 | vocals, sr = librosa.load(sample.vocals, sr=44100) 106 | HOP_LENGTH = 1024 107 | f0, _, _ = ma._f0(vocals, sr, hop_length=HOP_LENGTH) 108 | 109 | chroma_accuracys = list() 110 | 111 | for composer, value in composer_dic.items(): 112 | midi_path = sample.generated(composer, model_id) 113 | midi = pretty_midi.PrettyMIDI(midi_path) 114 | chroma_accuracy, pitch_accuracy = ma._evaluate_melody(midi, f0, sr, HOP_LENGTH) 115 | result = sample.result_json(model_id) 116 | if os.path.exists(result): 117 | result_json = OmegaConf.load(result) 118 | else: 119 | result_json = OmegaConf.create() 120 | result_json[composer] = OmegaConf.create() 121 | result_json[composer].melody_chroma_accuracy = chroma_accuracy.item() 122 | result_json[composer].melody_pitch_accuracy = pitch_accuracy.item() 123 | OmegaConf.save(result_json, result) 124 | chroma_accuracys.append(chroma_accuracy) 125 | 126 | mean_accuracy = sum(chroma_accuracys) / len(chroma_accuracys) 127 | gt_accuracy = sample.yaml.eval.melody_chroma_accuracy 128 | print(gt_accuracy, mean_accuracy) 129 | 130 | return mean_accuracy 131 | 132 | 133 | def evaluate_main(meta_files, config, model_id, **kwargs): 134 | from tqdm.auto import tqdm 135 | import multiprocessing 136 | from joblib import Parallel, delayed 137 | 138 | config = OmegaConf.load(config) 139 | composer_dic = config.composer_to_feature_token 140 | if config.dataset.use_mel and not config.dataset.mel_is_conditioned: 141 | composer_dic = DEFAULT_COMPOSERS 142 | 143 | mean_accuracys = Parallel(n_jobs=multiprocessing.cpu_count() // 2)( 144 | delayed(evaluate)(meta_file, composer_dic, 
model_id) 145 | for meta_file in tqdm(meta_files) 146 | ) 147 | 148 | print( 149 | "Total Accuracy of", model_id, "is", sum(mean_accuracys) / len(mean_accuracys) 150 | ) 151 | 152 | 153 | if __name__ == "__main__": 154 | import argparse 155 | 156 | parser = argparse.ArgumentParser(description="bpm estimate using essentia") 157 | 158 | parser.add_argument( 159 | "data_dir", 160 | type=str, 161 | default=None, 162 | help="""directory contains {id}/{pop_filename.wav} 163 | """, 164 | ) 165 | 166 | parser.add_argument( 167 | "--ckpt", 168 | type=str, 169 | default=None, 170 | help="""ckpt *.ckpt""", 171 | ) 172 | 173 | parser.add_argument( 174 | "--config", 175 | type=str, 176 | default=None, 177 | help="""config *.yaml""", 178 | ) 179 | 180 | parser.add_argument( 181 | "--id", 182 | type=str, 183 | default="model_name_id", 184 | help="""config composer_to_token.yaml""", 185 | ) 186 | 187 | parser.add_argument("--evaluate", action="store_true", default=False) 188 | 189 | args = parser.parse_args() 190 | 191 | meta_files = sorted(glob.glob(args.data_dir + "/**/*.yaml", recursive=True)) 192 | print("meta ", len(meta_files)) 193 | 194 | inference_main(meta_files, **vars(args)) 195 | if args.evaluate: 196 | evaluate_main(meta_files, config=args.config, model_id=args.id) 197 | -------------------------------------------------------------------------------- /evaluate/midi_melody_accuracy.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import pretty_midi 4 | import mir_eval.melody 5 | 6 | 7 | def get_highest_pitches_from_piano_roll(pr): 8 | """ 9 | params: 10 | pr : (128, time(frame)) 11 | 12 | return: 13 | highest_pitches : (time(frame), ) 14 | """ 15 | highest_pitches = [] 16 | for i in range(pr.shape[1]): 17 | ps = np.nonzero(pr[:, i]) 18 | if len(ps[0]) == 0: 19 | highest_pitches.append(np.nan) 20 | else: 21 | highest_pitches.append(ps[0][-1]) 22 | highest_pitches = np.array(highest_pitches) 23 | 24 | return highest_pitches 25 | 26 | 27 | def _f0(y, sr, hop_length): 28 | pyin = librosa.pyin( 29 | y, 30 | fmin=librosa.note_to_hz("C2"), 31 | fmax=librosa.note_to_hz("C6"), 32 | sr=sr, 33 | hop_length=hop_length, 34 | ) 35 | # f0, voiced_flag, voiced_probs 36 | return pyin 37 | 38 | 39 | def _evaluate_melody(midi, f0, sr, hop_length): 40 | x_coords = np.arange(0, midi.get_end_time(), hop_length / sr) 41 | pr = midi.get_piano_roll(fs=sr / hop_length, times=x_coords) 42 | highest_pitches = get_highest_pitches_from_piano_roll(pr) 43 | 44 | (ref_v, ref_c, est_v, est_c) = mir_eval.melody.to_cent_voicing( 45 | x_coords, f0, x_coords, librosa.midi_to_hz(highest_pitches) 46 | ) 47 | 48 | raw_chroma = mir_eval.melody.raw_chroma_accuracy(ref_v, ref_c, est_v, est_c) 49 | raw_pitch = mir_eval.melody.raw_pitch_accuracy(ref_v, ref_c, est_v, est_c) 50 | return raw_chroma, raw_pitch 51 | 52 | 53 | def evaluate_melody( 54 | midi: pretty_midi.PrettyMIDI, 55 | vocals: np.array, 56 | sr: int = 44100, 57 | hop_length: int = 1024, 58 | ): 59 | f0, voiced_flag, voiced_probs = _f0(vocals, sr=sr, hop_length=hop_length) 60 | 61 | raw_chroma, raw_pitch = _evaluate_melody(midi, f0, sr, hop_length) 62 | return raw_chroma, raw_pitch 63 | 64 | 65 | if __name__ == "__main__": 66 | pass 67 | -------------------------------------------------------------------------------- /layer/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sweetcocoa/pop2piano/f218c094c3e185d43fa700ed5724cc616be5608f/layer/__init__.py -------------------------------------------------------------------------------- /layer/input.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | 5 | 6 | class LogMelSpectrogram(nn.Module): 7 | def __init__(self) -> None: 8 | super().__init__() 9 | self.melspectrogram = torchaudio.transforms.MelSpectrogram( 10 | sample_rate=22050, 11 | n_fft=4096, 12 | hop_length=1024, 13 | f_min=10.0, 14 | n_mels=512, 15 | ) 16 | 17 | def forward(self, x): 18 | # x : audio(batch, sample) 19 | # X : melspec (batch, freq, frame) 20 | with torch.no_grad(): 21 | with torch.cuda.amp.autocast(enabled=False): 22 | X = self.melspectrogram(x) 23 | X = X.clamp(min=1e-6).log() 24 | 25 | return X 26 | 27 | 28 | class ConcatEmbeddingToMel(nn.Module): 29 | def __init__(self, embedding_offset, n_vocab, n_dim) -> None: 30 | super().__init__() 31 | self.embedding = nn.Embedding(num_embeddings=n_vocab, embedding_dim=n_dim) 32 | self.embedding_offset = embedding_offset 33 | 34 | def forward(self, feature, index_value): 35 | """ 36 | index_value : (batch, ) 37 | feature : (batch, time, feature_dim) 38 | """ 39 | index_shifted = index_value - self.embedding_offset 40 | 41 | # (batch, 1, feature_dim) 42 | composer_embedding = self.embedding(index_shifted).unsqueeze(1) 43 | # print(composer_embedding.shape, feature.shape) 44 | # (batch, 1 + time, feature_dim) 45 | inputs_embeds = torch.cat([composer_embedding, feature], dim=1) 46 | return inputs_embeds 47 | -------------------------------------------------------------------------------- /midi_tokenizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numba import jit 3 | import pretty_midi 4 | import scipy.interpolate as interp 5 | 6 | TOKEN_SPECIAL: int = 0 7 | TOKEN_NOTE: int = 1 8 | TOKEN_VELOCITY: int = 2 9 | TOKEN_TIME: int = 3 10 | 11 | DEFAULT_VELOCITY: int = 77 12 | 13 | TIE: int = 2 14 | EOS: int = 1 15 | PAD: int = 0 16 | 17 | 18 | def extrapolate_beat_times(beat_times, n_extend=1): 19 | beat_times_function = interp.interp1d( 20 | np.arange(beat_times.size), 21 | beat_times, 22 | bounds_error=False, 23 | fill_value="extrapolate", 24 | ) 25 | 26 | ext_beats = beat_times_function( 27 | np.linspace(0, beat_times.size + n_extend - 1, beat_times.size + n_extend) 28 | ) 29 | 30 | return ext_beats 31 | 32 | 33 | @jit(nopython=True, cache=True) 34 | def fast_tokenize(idx, token_type, n_special, n_note, n_velocity): 35 | if token_type == TOKEN_TIME: 36 | return n_special + n_note + n_velocity + idx 37 | elif token_type == TOKEN_VELOCITY: 38 | return n_special + n_note + idx 39 | elif token_type == TOKEN_NOTE: 40 | return n_special + idx 41 | elif token_type == TOKEN_SPECIAL: 42 | return idx 43 | else: 44 | return -1 45 | 46 | 47 | @jit(nopython=True, cache=True) 48 | def fast_detokenize(idx, n_special, n_note, n_velocity, time_idx_offset): 49 | if idx >= n_special + n_note + n_velocity: 50 | return (TOKEN_TIME, (idx - (n_special + n_note + n_velocity)) + time_idx_offset) 51 | elif idx >= n_special + n_note: 52 | return TOKEN_VELOCITY, idx - (n_special + n_note) 53 | elif idx >= n_special: 54 | return TOKEN_NOTE, idx - n_special 55 | else: 56 | return TOKEN_SPECIAL, idx 57 | 58 | 59 | class MidiTokenizer: 60 | def __init__(self, config) -> None: 61 | self.config = config 62 | 63 | def 
tokenize_note(self, idx, token_type): 64 | rt = fast_tokenize( 65 | idx, 66 | token_type, 67 | self.config.vocab_size.special, 68 | self.config.vocab_size.note, 69 | self.config.vocab_size.velocity, 70 | ) 71 | if rt == -1: 72 | raise ValueError(f"type {type} is not a predefined token type.") 73 | else: 74 | return rt 75 | 76 | def notes_to_tokens(self, notes): 77 | """ 78 | notes : (onset idx, offset idx, pitch, velocity) 79 | """ 80 | max_time_idx = notes[:, :2].max() 81 | 82 | times = [[] for i in range((max_time_idx + 1))] 83 | for onset, offset, pitch, velocity in notes: 84 | times[onset].append([pitch, velocity]) 85 | times[offset].append([pitch, 0]) 86 | 87 | tokens = [] 88 | current_velocity = 0 89 | for i, time in enumerate(times): 90 | if len(time) == 0: 91 | continue 92 | tokens.append(self.tokenize_note(i, TOKEN_TIME)) 93 | for pitch, velocity in time: 94 | velocity = int(velocity > 0) 95 | if current_velocity != velocity: 96 | current_velocity = velocity 97 | tokens.append(self.tokenize_note(velocity, TOKEN_VELOCITY)) 98 | tokens.append(self.tokenize_note(pitch, TOKEN_NOTE)) 99 | 100 | return np.array(tokens, dtype=int) 101 | 102 | def detokenize(self, token, time_idx_offset): 103 | type, value = fast_detokenize( 104 | token, 105 | n_special=self.config.vocab_size.special, 106 | n_note=self.config.vocab_size.note, 107 | n_velocity=self.config.vocab_size.velocity, 108 | time_idx_offset=time_idx_offset, 109 | ) 110 | if type != TOKEN_TIME: 111 | value = int(value) 112 | return [type, value] 113 | 114 | def to_string(self, tokens, time_idx_offset=0): 115 | nums = [ 116 | self.detokenize(token, time_idx_offset=time_idx_offset) for token in tokens 117 | ] 118 | strings = [] 119 | for i in range(len(nums)): 120 | type = nums[i][0] 121 | value = nums[i][1] 122 | 123 | if type == TOKEN_TIME: 124 | type = "time" 125 | elif type == TOKEN_SPECIAL: 126 | if value == EOS: 127 | value = "EOS" 128 | elif value == PAD: 129 | value = "PAD" 130 | elif value == TIE: 131 | value = "TIE" 132 | else: 133 | value = "Unknown Special" 134 | elif type == TOKEN_NOTE: 135 | type = "note" 136 | elif type == TOKEN_VELOCITY: 137 | type = "velocity" 138 | strings.append((type, value)) 139 | return strings 140 | 141 | def split_notes(self, notes, beatsteps, time_from, time_to): 142 | """ 143 | Assumptions 144 | - notes are sorted by onset time 145 | - beatsteps are sorted by time 146 | """ 147 | start_idx = np.searchsorted(beatsteps, time_from) 148 | start_note = np.searchsorted(notes[:, 0], start_idx) 149 | 150 | end_idx = np.searchsorted(beatsteps, time_to) 151 | end_note = np.searchsorted(notes[:, 0], end_idx) 152 | splited_notes = notes[start_note:end_note] 153 | 154 | return splited_notes, (start_idx, end_idx, start_note, end_note) 155 | 156 | def notes_to_relative_tokens( 157 | self, notes, offset_idx, add_eos=False, add_composer=False, composer_value=None 158 | ): 159 | """ 160 | notes : (onset idx, offset idx, pitch, velocity) 161 | """ 162 | 163 | def _add_eos(tokens): 164 | tokens = np.concatenate((tokens, np.array([EOS], dtype=tokens.dtype))) 165 | return tokens 166 | 167 | def _add_composer(tokens, composer_value): 168 | tokens = np.concatenate( 169 | (np.array([composer_value], dtype=tokens.dtype), tokens) 170 | ) 171 | return tokens 172 | 173 | if len(notes) == 0: 174 | tokens = np.array([], dtype=int) 175 | if add_eos: 176 | tokens = _add_eos(tokens) 177 | if add_composer: 178 | tokens = _add_composer(tokens, composer_value=composer_value) 179 | return tokens 180 | 181 | max_time_idx = notes[:, 
:2].max()
182 | 
183 |         # times[time_idx] = [[pitch, .. ], [pitch, 0], ..]
184 |         times = [[] for i in range((max_time_idx + 1 - offset_idx))]
185 |         for abs_onset, abs_offset, pitch, velocity in notes:
186 |             rel_onset = abs_onset - offset_idx
187 |             rel_offset = abs_offset - offset_idx
188 |             times[rel_onset].append([pitch, velocity])
189 |             times[rel_offset].append([pitch, 0])
190 | 
191 |         # from here on, everything is relative to time 0 (the offset)
192 |         tokens = []
193 |         current_velocity = 0
194 |         current_time_idx = 0
195 | 
196 |         for rel_idx, time in enumerate(times):
197 |             if len(time) == 0:
198 |                 continue
199 |             time_idx_shift = rel_idx - current_time_idx
200 |             current_time_idx = rel_idx
201 | 
202 |             tokens.append(self.tokenize_note(time_idx_shift, TOKEN_TIME))
203 |             for pitch, velocity in time:
204 |                 velocity = int(velocity > 0)
205 |                 if current_velocity != velocity:
206 |                     current_velocity = velocity
207 |                     tokens.append(self.tokenize_note(velocity, TOKEN_VELOCITY))
208 |                 tokens.append(self.tokenize_note(pitch, TOKEN_NOTE))
209 | 
210 |         tokens = np.array(tokens, dtype=int)
211 |         if add_eos:
212 |             tokens = _add_eos(tokens)
213 |         if add_composer:
214 |             tokens = _add_composer(tokens, composer_value=composer_value)
215 |         return tokens
216 | 
217 |     def relative_batch_tokens_to_midi(
218 |         self,
219 |         tokens,
220 |         beatstep,
221 |         beat_offset_idx=None,
222 |         bars_per_batch=None,
223 |         cutoff_time_idx=None,
224 |     ):
225 |         """
226 |         tokens : (batch, sequence)
227 |         beatstep : (times, )
228 |         """
229 |         beat_offset_idx = 0 if beat_offset_idx is None else beat_offset_idx
230 |         notes = None
231 |         bars_per_batch = 2 if bars_per_batch is None else bars_per_batch
232 | 
233 |         N = len(tokens)
234 |         for n in range(N):
235 |             _tokens = tokens[n]
236 |             _start_idx = beat_offset_idx + n * bars_per_batch * 4
237 |             _cutoff_time_idx = cutoff_time_idx + _start_idx
238 |             _notes = self.relative_tokens_to_notes(
239 |                 _tokens,
240 |                 start_idx=_start_idx,
241 |                 cutoff_time_idx=_cutoff_time_idx,
242 |             )
243 |             # print(_notes, "\n-------")
244 |             if len(_notes) == 0:
245 |                 pass
246 |                 # print("_notes zero")
247 |             elif notes is None:
248 |                 notes = _notes
249 |             else:
250 |                 notes = np.concatenate((notes, _notes), axis=0)
251 | 
252 |         if notes is None:
253 |             notes = []
254 |         midi = self.notes_to_midi(notes, beatstep, offset_sec=beatstep[beat_offset_idx])
255 |         return midi, notes
256 | 
257 |     def relative_tokens_to_notes(self, tokens, start_idx, cutoff_time_idx=None):
258 |         # TODO remove legacy
259 |         # decoding: handle the case where the first token is a composer token
260 |         if tokens[0] >= sum(self.config.vocab_size.values()):
261 |             tokens = tokens[1:]
262 | 
263 |         words = [self.detokenize(token, time_idx_offset=0) for token in tokens]
264 | 
265 |         if hasattr(start_idx, "item"):
266 |             """
267 |             if numpy or torch tensor
268 |             """
269 |             start_idx = start_idx.item()
270 | 
271 |         current_idx = start_idx
272 |         current_velocity = 0
273 |         note_onsets_ready = [None for i in range(self.config.vocab_size.note + 1)]
274 |         notes = []
275 |         for type, number in words:
276 |             if type == TOKEN_SPECIAL:
277 |                 if number == EOS:
278 |                     break
279 |             elif type == TOKEN_TIME:
280 |                 current_idx += number
281 |                 if cutoff_time_idx is not None:
282 |                     current_idx = min(current_idx, cutoff_time_idx)
283 | 
284 |             elif type == TOKEN_VELOCITY:
285 |                 current_velocity = number
286 |             elif type == TOKEN_NOTE:
287 |                 pitch = number
288 |                 if current_velocity == 0:
289 |                     # note_offset
290 |                     if note_onsets_ready[pitch] is None:
291 |                         # offset without onset
292 |                         pass
293 |                     else:
294 |                         onset_idx = note_onsets_ready[pitch]
295 |                         if onset_idx >= current_idx:
296 |                             # No time shift after previous note_on
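                            # (the pending onset is at or after the current
                            #  time index, i.e. the note would have zero or
                            #  negative length, so it is dropped)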
297 |                             pass
298 |                         else:
299 |                             offset_idx = current_idx
300 |                             notes.append(
301 |                                 [onset_idx, offset_idx, pitch, DEFAULT_VELOCITY]
302 |                             )
303 |                             note_onsets_ready[pitch] = None
304 |                 else:
305 |                     # note_on
306 |                     if note_onsets_ready[pitch] is None:
307 |                         note_onsets_ready[pitch] = current_idx
308 |                     else:
309 |                         # note-on already exists
310 |                         onset_idx = note_onsets_ready[pitch]
311 |                         if onset_idx >= current_idx:
312 |                             # No time shift after previous note_on
313 |                             pass
314 |                         else:
315 |                             offset_idx = current_idx
316 |                             notes.append(
317 |                                 [onset_idx, offset_idx, pitch, DEFAULT_VELOCITY]
318 |                             )
319 |                             note_onsets_ready[pitch] = current_idx
320 |             else:
321 |                 raise ValueError
322 | 
323 |         for pitch, note_on in enumerate(note_onsets_ready):
324 |             # force offset if no offset for each pitch
325 |             if note_on is not None:
326 |                 if cutoff_time_idx is None:
327 |                     cutoff = note_on + 1
328 |                 else:
329 |                     cutoff = max(cutoff_time_idx, note_on + 1)
330 | 
331 |                 offset_idx = max(current_idx, cutoff)
332 |                 notes.append([note_on, offset_idx, pitch, DEFAULT_VELOCITY])
333 | 
334 |         if len(notes) == 0:
335 |             return []
336 |         else:
337 |             notes = np.array(notes)
338 |             note_order = notes[:, 0] * 128 + notes[:, 1]
339 |             notes = notes[note_order.argsort()]
340 |             return notes
341 | 
342 |     def notes_to_midi(self, notes, beatstep, offset_sec=None):
343 |         new_pm = pretty_midi.PrettyMIDI(resolution=384, initial_tempo=120.0)
344 |         new_inst = pretty_midi.Instrument(program=0)
345 |         new_notes = []
346 |         if offset_sec is None:
347 |             offset_sec = 0.0
348 | 
349 |         for onset_idx, offset_idx, pitch, velocity in notes:
350 |             new_note = pretty_midi.Note(
351 |                 velocity=velocity,
352 |                 pitch=pitch,
353 |                 start=beatstep[onset_idx] - offset_sec,
354 |                 end=beatstep[offset_idx] - offset_sec,
355 |             )
356 |             new_notes.append(new_note)
357 |         new_inst.notes = new_notes
358 |         new_pm.instruments.append(new_inst)
359 |         new_pm.remove_invalid_notes()
360 |         return new_pm
361 | 
362 | 
363 | @jit(nopython=True, cache=False)
364 | def fast_notes_to_relative_tokens(
365 |     notes, offset_idx, max_time_idx, n_special, n_note, n_velocity
366 | ):
367 |     """
368 |     notes : (onset idx, offset idx, pitch, velocity)
369 |     """
370 | 
371 |     times_p = [np.array([], dtype=int) for i in range((max_time_idx + 1 - offset_idx))]
372 |     times_v = [np.array([], dtype=int) for i in range((max_time_idx + 1 - offset_idx))]
373 | 
374 |     for abs_onset, abs_offset, pitch, velocity in notes:
375 |         rel_onset = abs_onset - offset_idx
376 |         rel_offset = abs_offset - offset_idx
377 |         times_p[rel_onset] = np.append(times_p[rel_onset], pitch)
378 |         times_v[rel_onset] = np.append(times_v[rel_onset], velocity)
379 |         times_p[rel_offset] = np.append(times_p[rel_offset], pitch)
380 |         times_v[rel_offset] = np.append(times_v[rel_offset], 0)  # offsets carry velocity 0, as in notes_to_relative_tokens above
381 | 
382 |     # from here on, everything is relative to time 0 (the offset)
383 |     tokens = []
384 |     current_velocity = np.array([0])
385 |     current_time_idx = np.array([0])
386 | 
387 |     # because the range below could be 0..
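    # A worked example of the token-id layout implied by fast_tokenize,
    # assuming the vocab sizes from config.yaml (special=4, note=128,
    # velocity=2, time=100):
    #   special : ids   0..3    (PAD=0, EOS=1, TIE=2)
    #   note    : ids   4..131  (id = n_special + pitch)
    #   velocity: ids 132..133  (id = n_special + n_note + velocity_bin)
    #   time    : ids 134..233  (id = n_special + n_note + n_velocity + shift)
    # e.g. pitch 60 -> token 64, velocity bin 1 -> token 133, and a time
    # shift of 3 half-beats -> token 137.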
388 | for i in range(len(times_p)): 389 | rel_idx = i 390 | notes_at_time = times_p[i] 391 | if len(notes_at_time) == 0: 392 | continue 393 | 394 | time_idx_shift = rel_idx - current_time_idx[0] 395 | current_time_idx[0] = rel_idx 396 | 397 | token = fast_tokenize( 398 | time_idx_shift, 399 | TOKEN_TIME, 400 | n_special=n_special, 401 | n_note=n_note, 402 | n_velocity=n_velocity, 403 | ) 404 | tokens.append(token) 405 | 406 | for j in range(len(notes_at_time)): 407 | pitch = times_p[j] 408 | velocity = times_v[j] 409 | # for pitch, velocity in time: 410 | velocity = int(velocity > 0) 411 | if current_velocity[0] != velocity: 412 | current_velocity[0] = velocity 413 | token = fast_tokenize( 414 | velocity, 415 | TOKEN_VELOCITY, 416 | n_special=n_special, 417 | n_note=n_note, 418 | n_velocity=n_velocity, 419 | ) 420 | tokens.append(token) 421 | token = fast_tokenize( 422 | pitch, 423 | TOKEN_NOTE, 424 | n_special=n_special, 425 | n_note=n_note, 426 | n_velocity=n_velocity, 427 | ) 428 | tokens.append(token) 429 | 430 | return np.array(tokens) 431 | -------------------------------------------------------------------------------- /midiaudiopair.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | from omegaconf import OmegaConf 5 | 6 | 7 | BLACKLIST_PIANO_YTID = [ 8 | "cp37xi5Jbs", 9 | "0meKPm-75As", 10 | "0uN66vwQElI", 11 | "S5zn1FJ29GU", 12 | "s_npS7szUjk", 13 | "SCssvPlXbvc", 14 | "sH7ErWQut5g", 15 | "DYQhtNMCzsA", 16 | "4CaGkbWUovE", 17 | "SQtrlqkIl4o", 18 | "ykpFk4EniDk", 19 | "WpHke7iywS8", 20 | ] 21 | 22 | 23 | class MidiAudioPair: 24 | VALID = 0 25 | 26 | NO_SONG = 1 27 | NO_PIANO = 2 28 | NO_SONG_DIR = 3 29 | BAD_DURATION = 4 30 | BAD_TITLE = 5 31 | NO_TEMPO = 6 32 | BLACKLIST = 7 33 | BAD_ACCURACY = 8 34 | 35 | ERROR_CODE = { 36 | VALID: "Valid", 37 | NO_SONG: "No Song", 38 | NO_PIANO: "No Piano", 39 | NO_SONG_DIR: "No Song Dir", 40 | BAD_DURATION: "Duration Bad", 41 | BAD_TITLE: "Bad Title", 42 | NO_TEMPO: "No Tempo", 43 | BLACKLIST: "Blacklist", 44 | BAD_ACCURACY: "Bad Accuracy", 45 | } 46 | 47 | def validate_files(self): 48 | attrs = [ 49 | "midi", 50 | "song", 51 | "beattime", 52 | "beatstep", 53 | "beatinterval", 54 | "qmidi", 55 | "qmix", 56 | "notes", 57 | "vqvae", 58 | "vocals", 59 | ] 60 | 61 | invalids = [] 62 | for attr in attrs: 63 | file = getattr(self, attr, None) 64 | if file is None or not os.path.exists(file): 65 | invalids.append(attr) 66 | 67 | return invalids 68 | 69 | def validate_yaml(self, audio_dir, yaml): 70 | if not hasattr(yaml, "song"): 71 | return MidiAudioPair.NO_SONG 72 | 73 | if not hasattr(yaml, "piano"): 74 | return MidiAudioPair.NO_PIANO 75 | 76 | if yaml.piano.ytid in BLACKLIST_PIANO_YTID: 77 | return MidiAudioPair.BLACKLIST 78 | 79 | song_dir = os.path.join(audio_dir, yaml.piano.ytid) 80 | if not os.path.exists(song_dir) or not os.path.isdir(song_dir): 81 | return MidiAudioPair.NO_SONG_DIR 82 | 83 | piano_sec = int(yaml.piano.duration) 84 | song_sec = int(yaml.song.duration) 85 | if piano_sec / song_sec > 1.2 or piano_sec / song_sec < 0.83: 86 | return MidiAudioPair.BAD_DURATION 87 | 88 | if yaml.piano.title.find("HANPPYEOM") != -1: 89 | return MidiAudioPair.BAD_TITLE 90 | 91 | if not hasattr(yaml, "tempo"): 92 | return MidiAudioPair.NO_TEMPO 93 | 94 | if not hasattr(yaml, "eval") or yaml.eval.melody_chroma_accuracy < 0.15: 95 | return MidiAudioPair.BAD_ACCURACY 96 | 97 | return MidiAudioPair.VALID 98 | 99 | def set_song_attrs(self): 100 | basename = 
os.path.join(self.song_dir, f"{self.yaml.song.ytid}")
101 | 
102 |         self.mix = basename + ".mix.flac"
103 |         self.midi = basename + ".mid"
104 |         self.song = basename + ".pitchshift.wav"
105 |         self.beattime = basename + ".beattime.npy"
106 |         self.beatstep = basename + ".beatstep.npy"
107 | 
108 |         self.beatinterval = basename + ".beatinterval.npy"
109 | 
110 |         self.qmidi = basename + ".qmidi.mid"
111 |         self.qmix = basename + ".qmix.flac"
112 |         self.notes = basename + ".notes.npy"
113 |         self.vqvae = basename + ".vqvae.pt"
114 |         self.vocals = basename + ".vocals.mp3"
115 | 
116 |     def delete_files_myself(self):
117 |         shutil.rmtree(os.path.join(self.audio_dir, self.yaml.piano.ytid))
118 |         os.remove(self.yaml_path)
119 |         os.remove(self.original_midi)
120 |         if os.path.exists(self.original_wav):
121 |             os.remove(self.original_wav)
122 | 
123 |     def __init__(self, yaml_path, audio_dir=None, auto_remove_no_song=False):
124 |         self.yaml_path = yaml_path
125 | 
126 |         self.yaml = OmegaConf.load(yaml_path)
127 | 
128 |         self.audio_dir = (
129 |             audio_dir if audio_dir is not None else os.path.dirname(yaml_path)
130 |         )
131 |         self.song_dir = os.path.join(self.audio_dir, self.yaml.piano.ytid)
132 | 
133 |         self.error_code = self.validate_yaml(self.audio_dir, self.yaml)
134 | 
135 |         self.original_midi = os.path.join(self.audio_dir, f"{self.yaml.piano.ytid}.mid")
136 |         self.original_wav = os.path.join(self.audio_dir, f"{self.yaml.piano.ytid}.wav")
137 | 
138 |         if self.error_code == MidiAudioPair.NO_SONG:
139 |             print("no song :", yaml_path)
140 |             if auto_remove_no_song:
141 |                 print("remove :", yaml_path)
142 |                 self.delete_files_myself()
143 |             return
144 |         else:
145 |             self.set_song_attrs()
146 | 
147 |         self.invalids = self.validate_files()
148 |         self.is_valid = (self.error_code == MidiAudioPair.VALID) and (
149 |             len(self.invalids) == 0
150 |         )
151 | 
152 |         if self.error_code != MidiAudioPair.NO_SONG:
153 |             self.original_song = os.path.join(
154 |                 self.song_dir, f"{self.yaml.song.ytid}.wav"
155 |             )
156 |             self.title = f"{self.yaml.piano.title}___{self.yaml.song.title}"
157 |         else:
158 |             self.title = f"{self.yaml.piano.title}"
159 | 
160 |     def __repr__(self):
161 |         return f"{MidiAudioPair.ERROR_CODE[self.error_code]}, inv{self.invalids}, {self.yaml_path}, {self.title}"
162 | 
163 |     def generated(self, composer, generated="model_name"):
164 |         midi_path = os.path.join(
165 |             self.song_dir, generated, self.yaml.song.ytid + "." + composer + ".mid"
166 |         )
167 |         return midi_path
168 | 
169 |     def result_json(self, generated="model_name"):
170 |         json_path = os.path.join(
171 |             self.song_dir, generated, self.yaml.song.ytid + ".result.json"
172 |         )
173 |         return json_path
174 | 
--------------------------------------------------------------------------------
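A minimal usage sketch for `MidiAudioPair` (not a file in the repository): it assumes a directory of per-pair `{id}.yaml` metadata files as written by `download.py`, and the `dataset_dir` path and the `"composer1"` / `"my_model"` names are placeholders.

```python
# Hypothetical example: scan a dataset directory and keep only the pairs
# that passed every preprocessing stage (valid metadata, all derived files).
import glob

from midiaudiopair import MidiAudioPair

meta_files = sorted(glob.glob("dataset_dir/*.yaml"))  # assumed layout
usable = []
for meta_file in meta_files:
    pair = MidiAudioPair(meta_file)
    # error_code summarizes the metadata checks; is_valid additionally
    # requires every derived file (qmidi, beatstep, notes, vocals, ...)
    # to exist on disk.
    if pair.error_code != MidiAudioPair.NO_SONG and pair.is_valid:
        usable.append(pair)

print(f"{len(usable)}/{len(meta_files)} pairs usable")
for pair in usable[:3]:
    print(pair.song, pair.qmidi, pair.generated("composer1", "my_model"))
```

The error-code guard comes before `is_valid` because `__init__` returns early for `NO_SONG` samples, so `is_valid` is never set on them.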
/preprocess/README.md:
--------------------------------------------------------------------------------
1 | # Preprocess Scripts
2 | ---
3 | - Note : the order of these scripts is IMPORTANT.
4 | - The preprocessing steps themselves are simple, but the environment setup is not. Please bear with it.
5 | - If you encounter any problems, please do not hesitate to email me or open an issue on GitHub.
6 | 
7 | 1. Transcribe piano wavs to midi
8 |     - You should transcribe {piano_cover_file.wav} -> {piano_cover_file.mid}
9 |     - I recommend using the original code from this repo : [High-resolution Piano Transcription with Pedals by Regressing Onsets and Offsets Times](https://github.com/qiuqiangkong/piano_transcription_inference)
10 | 
11 | 2. Synchronize midi
12 | ```bash
13 | python pop_align.py DATA_DIR
14 | ```
15 | 
16 | 3. Estimate Pop's beats
17 | ```bash
18 | python bpm_quantize.py DATA_DIR
19 | ```
20 | 
21 | 4. Get the separated vocal track
22 | ```bash
23 | python split_spleeter.py DATA_DIR
24 | ```
25 | 
26 | 5. Calculate melody chroma accuracy
27 | ```bash
28 | python melody_accuracy.py DATA_DIR
29 | ```
30 | 
31 | # Expected Structure
32 | ```
33 | ├── -7lV0oJ0QXc
34 | │   ├── EHl_eQhgefw.beatinterval.npy
35 | │   ├── EHl_eQhgefw.beatstep.npy
36 | │   ├── EHl_eQhgefw.beattime.npy
37 | │   ├── EHl_eQhgefw.mid
38 | │   ├── EHl_eQhgefw.notes.npy
39 | │   ├── EHl_eQhgefw.pitchshift.wav
40 | │   ├── EHl_eQhgefw.qmidi.mid
41 | │   ├── EHl_eQhgefw.qmix.flac
42 | │   ├── EHl_eQhgefw.vocals.mp3
43 | │   ├── EHl_eQhgefw.wav
44 | │   └── The Beatles - With a Little Help from My Friends ____With A Little Help From My Friends - The Beatles _.txt
45 | ├── -7lV0oJ0QXc.mid
46 | ├── -7lV0oJ0QXc.wav
47 | ├── -7lV0oJ0QXc.yaml
48 | ```
49 | 
50 | ## Descriptions for each data
51 | 1. ```*.beattime.npy```
52 |     - Beat timesteps (unit : second) extracted using essentia. ```np.ndarray```. (num_beats, )
53 | 2. ```*.beatstep.npy```
54 |     - Timesteps (unit : second) at every half-beat, calculated by linear interpolation of ```beattime```.
55 | 3. ```*.notes.npy```
56 |     - ```np.ndarray``` shape: ```(number_of_notes, 4)```
57 |     - Each row contains : ```[onset(unit: index), offset(unit: index), pitch, velocity]```
58 |     - The onset/offset values are indices into the ```beatstep``` times.
59 |     - For example,
60 |         - ```beatstep = [0.6, 1.0, 1.4]```
61 |         - ```note = [0, 1, 77, 88]```
62 |         - then ```note``` means a note that starts at 0.6 sec, ends at 1.0 sec, and has pitch 77 and velocity 88.
63 | 
64 | 
--------------------------------------------------------------------------------
/preprocess/beat_quantizer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import librosa
3 | import essentia
4 | import essentia.standard
5 | import numpy as np
6 | import scipy.interpolate as interp
7 | import note_seq
8 | 
9 | SAMPLERATE = 44100
10 | 
11 | 
12 | def nearest_onset_offset_digitize(on, off, bins):
13 |     intermediate = (bins[1:] + bins[:-1]) / 2
14 |     on_idx = np.digitize(on, intermediate)
15 |     off_idx = np.digitize(off, intermediate)
16 |     off_idx[on_idx == off_idx] += 1
17 |     # off_idx = np.clip(off_idx, a_min=0, a_max=len(bins) - 1)
18 |     return on_idx, off_idx
19 | 
20 | 
21 | def apply_sustain_pedal(pm):
22 |     ns = note_seq.midi_to_note_sequence(pm)
23 |     susns = note_seq.apply_sustain_control_changes(ns)
24 |     suspm = note_seq.note_sequence_to_pretty_midi(susns)
25 |     return suspm
26 | 
27 | 
28 | def interpolate_beat_times(beat_times, steps_per_beat, extend=False):
29 |     beat_times_function = interp.interp1d(
30 |         np.arange(beat_times.size),
31 |         beat_times,
32 |         bounds_error=False,
33 |         fill_value="extrapolate",
34 |     )
35 |     if extend:
36 |         beat_steps_8th = beat_times_function(
37 |             np.linspace(0, beat_times.size, beat_times.size * steps_per_beat + 1)
38 |         )
39 |     else:
40 |         beat_steps_8th = beat_times_function(
41 |             np.linspace(0, beat_times.size - 1, beat_times.size * steps_per_beat - 1)
42 |         )
43 |     return beat_steps_8th
44 | 
45 | 
46 | def midi_quantize_by_beats(
47 |     sample, beat_times, steps_per_beat, ignore_sustain_pedal=False
48 | ):
49 |     ns = note_seq.midi_file_to_note_sequence(sample.midi)
50 |     if ignore_sustain_pedal:
51 |         susns = ns
52 |     else:
53 |         susns = note_seq.apply_sustain_control_changes(ns)
54 | 
55 |     qns = copy.deepcopy(susns)
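    # Worked example of the quantization below (a sketch with made-up
    # numbers): with beat_times = [0.5, 1.0, 1.5] and steps_per_beat = 2,
    # interpolate_beat_times(..., extend=False) yields the half-beat bins
    # [0.5, 0.75, 1.0, 1.25, 1.5]; a note sounding from 0.62s to 1.10s
    # then digitizes to (on_idx, off_idx) = (0, 2), i.e. it is snapped
    # to start at 0.5s and end at 1.0s on the half-beat grid.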
56 | 57 | notes = np.array([[n.start_time, n.end_time] for n in susns.notes]) 58 | note_attributes = np.array([[n.pitch, n.velocity] for n in susns.notes]) 59 | 60 | note_ons = np.array(notes[:, 0]) 61 | note_offs = np.array(notes[:, 1]) 62 | 63 | beat_steps_8th = interpolate_beat_times(beat_times, steps_per_beat, extend=False) 64 | 65 | on_idx, off_idx = nearest_onset_offset_digitize(note_ons, note_offs, beat_steps_8th) 66 | 67 | beat_steps_8th = interpolate_beat_times(beat_times, steps_per_beat, extend=True) 68 | 69 | discrete_notes = np.concatenate( 70 | (np.stack((on_idx, off_idx), axis=1), note_attributes), axis=1 71 | ) 72 | 73 | def delete_duplicate_notes(dnotes): 74 | note_order = dnotes[:, 0] * 128 + dnotes[:, 2] 75 | dnotes = dnotes[note_order.argsort()] 76 | indices = [] 77 | for i in range(1, len(dnotes)): 78 | if dnotes[i, 0] == dnotes[i - 1, 0] and dnotes[i, 2] == dnotes[i - 1, 2]: 79 | indices.append(i) 80 | dnotes = np.delete(dnotes, indices, axis=0) 81 | note_order = dnotes[:, 0] * 128 + dnotes[:, 1] 82 | dnotes = dnotes[note_order.argsort()] 83 | return dnotes 84 | 85 | discrete_notes = delete_duplicate_notes(discrete_notes) 86 | 87 | digitized_note_ons, digitized_note_offs = ( 88 | beat_steps_8th[on_idx], 89 | beat_steps_8th[off_idx], 90 | ) 91 | 92 | for i, note in enumerate(qns.notes): 93 | note.start_time = digitized_note_ons[i] 94 | note.end_time = digitized_note_offs[i] 95 | 96 | return qns, discrete_notes, beat_steps_8th 97 | 98 | 99 | def extract_rhythm(song, y=None): 100 | if y is None: 101 | y, sr = librosa.load(song, sr=SAMPLERATE) 102 | 103 | essentia_tracker = essentia.standard.RhythmExtractor2013(method="multifeature") 104 | ( 105 | bpm, 106 | beat_times, 107 | confidence, 108 | estimates, 109 | essentia_beat_intervals, 110 | ) = essentia_tracker(y) 111 | return bpm, beat_times, confidence, estimates, essentia_beat_intervals 112 | -------------------------------------------------------------------------------- /preprocess/bpm_quantize.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import sys 3 | import os 4 | 5 | 6 | import librosa 7 | import soundfile as sf 8 | import numpy as np 9 | 10 | import note_seq 11 | from omegaconf import OmegaConf 12 | from beat_quantizer import extract_rhythm, midi_quantize_by_beats 13 | 14 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 15 | from midiaudiopair import MidiAudioPair 16 | from utils.dsp import get_stereo 17 | 18 | 19 | def estimate(meta_file, ignore_sustain_pedal): 20 | sample = MidiAudioPair(meta_file) 21 | 22 | if ( 23 | sample.error_code == MidiAudioPair.NO_PIANO 24 | or sample.error_code == MidiAudioPair.NO_SONG_DIR 25 | or sample.error_code == MidiAudioPair.NO_SONG 26 | ): 27 | return 28 | 29 | bpm, beat_times, confidence, estimates, essentia_beat_intervals = extract_rhythm( 30 | sample.song 31 | ) 32 | beat_times = np.array(beat_times) 33 | essentia_beat_intervals = np.array(essentia_beat_intervals) 34 | 35 | qns, discrete_notes, beat_steps_8th = midi_quantize_by_beats( 36 | sample, beat_times, 2, ignore_sustain_pedal=ignore_sustain_pedal 37 | ) 38 | 39 | qpm = note_seq.note_sequence_to_pretty_midi(qns) 40 | qpm.instruments[0].control_changes = [] 41 | qpm.write(sample.qmidi) 42 | y, sr = librosa.load(sample.song, sr=None) 43 | qpm_y = qpm.fluidsynth(sr) 44 | qmix = get_stereo(y, qpm_y, 0.4) 45 | sf.write(file=sample.qmix, data=qmix.T, samplerate=sr, format="flac") 46 | 47 | meta = OmegaConf.load(meta_file) 48 | 
meta.tempo = OmegaConf.create() 49 | meta.tempo.bpm = bpm 50 | meta.tempo.confidence = confidence 51 | OmegaConf.save(meta, meta_file) 52 | 53 | np.save(sample.notes, discrete_notes) 54 | np.save(sample.beatstep, beat_steps_8th) 55 | np.save(sample.beattime, beat_times) 56 | np.save(sample.beatinterval, essentia_beat_intervals) 57 | 58 | 59 | def main(meta_files, ignore_sustain_pedal): 60 | from tqdm import tqdm 61 | import multiprocessing 62 | from joblib import Parallel, delayed 63 | 64 | def files(): 65 | pbar = tqdm(meta_files) 66 | for meta_file in pbar: 67 | pbar.set_description(meta_file) 68 | yield meta_file 69 | 70 | Parallel(n_jobs=multiprocessing.cpu_count() // 2)( 71 | delayed(estimate)(meta_file, ignore_sustain_pedal) for meta_file in files() 72 | ) 73 | 74 | 75 | if __name__ == "__main__": 76 | import argparse 77 | 78 | parser = argparse.ArgumentParser(description="bpm estimate using essentia") 79 | 80 | parser.add_argument( 81 | "data_dir", 82 | type=str, 83 | default=None, 84 | help="""directory contains {id}/{pop_filename.wav} 85 | """, 86 | ) 87 | 88 | parser.add_argument( 89 | "--ignore_sustain_pedal", 90 | default=False, 91 | action="store_true", 92 | ) 93 | 94 | args = parser.parse_args() 95 | 96 | meta_files = sorted(glob.glob(args.data_dir + "/*.yaml")) 97 | print("meta ", len(meta_files)) 98 | 99 | main(meta_files, args.ignore_sustain_pedal) 100 | -------------------------------------------------------------------------------- /preprocess/melody_accuracy.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import sys 3 | import os 4 | 5 | import librosa 6 | import pretty_midi 7 | 8 | from omegaconf import OmegaConf 9 | 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | from midiaudiopair import MidiAudioPair 12 | from evaluate import midi_melody_accuracy as ma 13 | 14 | 15 | def estimate(meta_file): 16 | 17 | import warnings 18 | 19 | warnings.filterwarnings(action="ignore") 20 | 21 | sample = MidiAudioPair(meta_file) 22 | 23 | if ( 24 | sample.error_code == MidiAudioPair.NO_PIANO 25 | or sample.error_code == MidiAudioPair.NO_SONG_DIR 26 | or sample.error_code == MidiAudioPair.NO_SONG 27 | ): 28 | return 29 | 30 | if "vocals" in sample.invalids: 31 | print("no vocal:", meta_file) 32 | return 33 | 34 | midi = pretty_midi.PrettyMIDI(sample.qmidi) 35 | vocals, sr = librosa.load(sample.vocals, sr=44100) 36 | 37 | chroma_accuracy, pitch_accuracy = ma.evaluate_melody( 38 | midi, vocals, sr=sr, hop_length=1024 39 | ) 40 | meta = OmegaConf.load(meta_file) 41 | meta.eval = OmegaConf.create() 42 | meta.eval.melody_chroma_accuracy = chroma_accuracy.item() 43 | meta.eval.melody_pitch_accuracy = pitch_accuracy.item() 44 | OmegaConf.save(meta, meta_file) 45 | 46 | 47 | def main(meta_files): 48 | from tqdm import tqdm 49 | import multiprocessing 50 | from joblib import Parallel, delayed 51 | 52 | def files(): 53 | pbar = tqdm(meta_files) 54 | for meta_file in pbar: 55 | pbar.set_description(meta_file) 56 | yield meta_file 57 | 58 | Parallel(n_jobs=multiprocessing.cpu_count() // 2)( 59 | delayed(estimate)(meta_file) for meta_file in files() 60 | ) 61 | 62 | 63 | if __name__ == "__main__": 64 | import argparse 65 | 66 | parser = argparse.ArgumentParser(description="bpm estimate using essentia") 67 | 68 | parser.add_argument( 69 | "data_dir", 70 | type=str, 71 | default=None, 72 | help="""directory contains {id}/{pop_filename.wav} 73 | """, 74 | ) 75 | 76 | args = parser.parse_args() 77 | 78 | 
meta_files = sorted(glob.glob(args.data_dir + "/**/*.yaml", recursive=True)) 79 | print("meta ", len(meta_files)) 80 | 81 | main(meta_files) 82 | -------------------------------------------------------------------------------- /preprocess/pop_align.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import soundfile as sf 3 | import glob 4 | import os 5 | import copy 6 | import sys 7 | 8 | import numpy as np 9 | import pyrubberband as pyrb 10 | import pretty_midi 11 | from omegaconf import OmegaConf 12 | from tqdm.auto import tqdm 13 | 14 | from synctoolbox.dtw.mrmsdtw import sync_via_mrmsdtw 15 | from synctoolbox.dtw.utils import ( 16 | compute_optimal_chroma_shift, 17 | shift_chroma_vectors, 18 | make_path_strictly_monotonic, 19 | ) 20 | from synctoolbox.feature.chroma import ( 21 | pitch_to_chroma, 22 | quantize_chroma, 23 | quantized_chroma_to_CENS, 24 | ) 25 | from synctoolbox.feature.dlnco import pitch_onset_features_to_DLNCO 26 | from synctoolbox.feature.pitch import audio_to_pitch_features 27 | from synctoolbox.feature.pitch_onset import audio_to_pitch_onset_features 28 | from synctoolbox.feature.utils import estimate_tuning 29 | 30 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 31 | print(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 32 | from utils.dsp import normalize, get_stereo 33 | from midiaudiopair import MidiAudioPair 34 | 35 | Fs = 22050 36 | feature_rate = 50 37 | step_weights = np.array([1.5, 1.5, 2.0]) 38 | threshold_rec = 10 ** 6 39 | 40 | 41 | def save_delayed_song( 42 | sample, 43 | dry_run, 44 | ): 45 | import warnings 46 | 47 | warnings.filterwarnings(action="ignore") 48 | 49 | song_audio, _ = librosa.load(sample.original_song, Fs) 50 | midi_pm = pretty_midi.PrettyMIDI(sample.original_midi) 51 | 52 | if np.power(song_audio, 2).sum() < 1: # low energy: invalid file 53 | print("invalid audio :", sample.original_song) 54 | sample.delete_files_myself() 55 | return 56 | 57 | rd = get_aligned_results(midi_pm=midi_pm, song_audio=song_audio) 58 | 59 | mix_song = rd["mix_song"] 60 | song_pitch_shifted = rd["song_pitch_shifted"] 61 | midi_warped_pm = rd["midi_warped_pm"] 62 | pitch_shift_for_song_audio = rd["pitch_shift_for_song_audio"] 63 | tuning_offset_song = rd["tuning_offset_song"] 64 | tuning_offset_piano = rd["tuning_offset_piano"] 65 | 66 | try: 67 | if dry_run: 68 | print("write audio files: ", sample.song) 69 | else: 70 | sf.write( 71 | file=sample.song, 72 | data=song_pitch_shifted, 73 | samplerate=Fs, 74 | format="wav", 75 | ) 76 | except: 77 | print("Fail : ", sample.song) 78 | 79 | try: 80 | if dry_run: 81 | print("write warped midi :", sample.midi) 82 | else: 83 | midi_warped_pm.write(sample.midi) 84 | 85 | except: 86 | midi_warped_pm._tick_scales = midi_pm._tick_scales 87 | try: 88 | if dry_run: 89 | print("write warped midi2 :", sample.midi) 90 | else: 91 | midi_warped_pm.write(sample.midi) 92 | 93 | except: 94 | print("ad-hoc failed midi : ", sample.midi) 95 | print("ad-hoc midi : ", sample.midi) 96 | 97 | sample.yaml.song.pitch_shift = pitch_shift_for_song_audio.item() 98 | sample.yaml.song.tuning_offset = tuning_offset_song.item() 99 | sample.yaml.piano.tuning_offset = tuning_offset_piano.item() 100 | OmegaConf.save(sample.yaml, sample.yaml_path) 101 | 102 | 103 | def get_aligned_results(midi_pm, song_audio): 104 | piano_audio = midi_pm.fluidsynth(Fs) 105 | 106 | song_audio = normalize(song_audio) 107 | 108 | # The reason for estimating tuning :: 109 | 
# https://www.audiolabs-erlangen.de/resources/MIR/FMP/C3/C3S1_TranspositionTuning.html 110 | tuning_offset_1 = estimate_tuning(song_audio, Fs) 111 | tuning_offset_2 = estimate_tuning(piano_audio, Fs) 112 | 113 | # DLNCO features (Sebastian Ewert, Meinard Müller, and Peter Grosche: High Resolution Audio Synchronization Using Chroma Onset Features, In Proceedings of IEEE International Conference on Acoustics, Speech, and Signal Processing (ICASSP): 1869–1872, 2009.): 114 | # helpful to increase synchronization accuracy, especially for music with clear onsets. 115 | 116 | # Quantized and smoothed chroma : CENS features 117 | # Because, MrMsDTW Requires CENS. 118 | f_chroma_quantized_1, f_DLNCO_1 = get_features_from_audio( 119 | song_audio, tuning_offset_1 120 | ) 121 | f_chroma_quantized_2, f_DLNCO_2 = get_features_from_audio( 122 | piano_audio, tuning_offset_2 123 | ) 124 | 125 | # Shift chroma vectors : 126 | # Otherwise, different keys of two audio leads to degradation of alignment. 127 | opt_chroma_shift = compute_optimal_chroma_shift( 128 | quantized_chroma_to_CENS(f_chroma_quantized_1, 201, 50, feature_rate)[0], 129 | quantized_chroma_to_CENS(f_chroma_quantized_2, 201, 50, feature_rate)[0], 130 | ) 131 | f_chroma_quantized_2 = shift_chroma_vectors(f_chroma_quantized_2, opt_chroma_shift) 132 | f_DLNCO_2 = shift_chroma_vectors(f_DLNCO_2, opt_chroma_shift) 133 | 134 | wp = sync_via_mrmsdtw( 135 | f_chroma1=f_chroma_quantized_1, 136 | f_onset1=f_DLNCO_1, 137 | f_chroma2=f_chroma_quantized_2, 138 | f_onset2=f_DLNCO_2, 139 | input_feature_rate=feature_rate, 140 | step_weights=step_weights, 141 | threshold_rec=threshold_rec, 142 | verbose=False, 143 | ) 144 | 145 | wp = make_path_strictly_monotonic(wp) 146 | pitch_shift_for_song_audio = -opt_chroma_shift % 12 147 | if pitch_shift_for_song_audio > 6: 148 | pitch_shift_for_song_audio -= 12 149 | 150 | if pitch_shift_for_song_audio != 0: 151 | song_audio_shifted = pyrb.pitch_shift( 152 | song_audio, Fs, pitch_shift_for_song_audio 153 | ) 154 | else: 155 | song_audio_shifted = song_audio 156 | 157 | time_map_second = wp / feature_rate 158 | midi_pm_warped = copy.deepcopy(midi_pm) 159 | 160 | midi_pm_warped = simple_adjust_times( 161 | midi_pm_warped, time_map_second[1], time_map_second[0] 162 | ) 163 | piano_audio_warped = midi_pm_warped.fluidsynth(Fs) 164 | 165 | song_audio_shifted = normalize(song_audio_shifted) 166 | stereo_sonification_piano = get_stereo(song_audio_shifted, piano_audio_warped) 167 | 168 | rd = dict( 169 | mix_song=stereo_sonification_piano, 170 | song_pitch_shifted=song_audio_shifted, 171 | midi_warped_pm=midi_pm_warped, 172 | pitch_shift_for_song_audio=pitch_shift_for_song_audio, 173 | tuning_offset_song=tuning_offset_1, 174 | tuning_offset_piano=tuning_offset_2, 175 | ) 176 | return rd 177 | 178 | 179 | def simple_adjust_times(pm, original_times, new_times): 180 | """ 181 | most of these codes are from original pretty_midi 182 | https://github.com/craffel/pretty-midi/blob/main/pretty_midi/pretty_midi.py 183 | """ 184 | for instrument in pm.instruments: 185 | instrument.notes = [ 186 | copy.deepcopy(note) 187 | for note in instrument.notes 188 | if note.start >= original_times[0] and note.end <= original_times[-1] 189 | ] 190 | # Get array of note-on locations and correct them 191 | note_ons = np.array( 192 | [note.start for instrument in pm.instruments for note in instrument.notes] 193 | ) 194 | adjusted_note_ons = np.interp(note_ons, original_times, new_times) 195 | # Same for note-offs 196 | note_offs = np.array( 197 | 
197 |         [note.end for instrument in pm.instruments for note in instrument.notes]
198 |     )
199 |     adjusted_note_offs = np.interp(note_offs, original_times, new_times)
200 |     # Correct notes
201 |     for n, note in enumerate(
202 |         [note for instrument in pm.instruments for note in instrument.notes]
203 |     ):
204 |         note.start = (adjusted_note_ons[n] > 0) * adjusted_note_ons[n]
205 |         note.end = (adjusted_note_offs[n] > 0) * adjusted_note_offs[n]
206 |     # After performing alignment, some notes may have an end time which is
207 |     # on or before the start time. Remove these!
208 |     pm.remove_invalid_notes()
209 | 
210 |     def adjust_events(event_getter):
211 |         """This function calls event_getter with each instrument as the
212 |         sole argument and adjusts the events which are returned."""
213 |         # Sort the events by time
214 |         for instrument in pm.instruments:
215 |             event_getter(instrument).sort(key=lambda e: e.time)
216 |         # Correct the events by interpolating
217 |         event_times = np.array(
218 |             [
219 |                 event.time
220 |                 for instrument in pm.instruments
221 |                 for event in event_getter(instrument)
222 |             ]
223 |         )
224 |         adjusted_event_times = np.interp(event_times, original_times, new_times)
225 |         for n, event in enumerate(
226 |             [
227 |                 event
228 |                 for instrument in pm.instruments
229 |                 for event in event_getter(instrument)
230 |             ]
231 |         ):
232 |             event.time = adjusted_event_times[n]
233 |         for instrument in pm.instruments:
234 |             # We want to keep only the final event which has time ==
235 |             # new_times[0]
236 |             valid_events = [
237 |                 event
238 |                 for event in event_getter(instrument)
239 |                 if event.time == new_times[0]
240 |             ]
241 |             if valid_events:
242 |                 valid_events = valid_events[-1:]
243 |             # Otherwise only keep events within the new set of times
244 |             valid_events.extend(
245 |                 event
246 |                 for event in event_getter(instrument)
247 |                 if event.time > new_times[0] and event.time < new_times[-1]
248 |             )
249 |             event_getter(instrument)[:] = valid_events
250 | 
251 |     # Correct pitch bends and control changes
252 |     adjust_events(lambda i: i.pitch_bends)
253 |     adjust_events(lambda i: i.control_changes)
254 | 
255 |     return pm
256 | 
257 | 
258 | def get_features_from_audio(audio, tuning_offset, visualize=False):
259 |     f_pitch = audio_to_pitch_features(
260 |         f_audio=audio,
261 |         Fs=Fs,
262 |         tuning_offset=tuning_offset,
263 |         feature_rate=feature_rate,
264 |         verbose=visualize,
265 |     )
266 |     f_chroma = pitch_to_chroma(f_pitch=f_pitch)
267 |     f_chroma_quantized = quantize_chroma(f_chroma=f_chroma)
268 | 
269 |     f_pitch_onset = audio_to_pitch_onset_features(
270 |         f_audio=audio, Fs=Fs, tuning_offset=tuning_offset, verbose=visualize
271 |     )
272 |     f_DLNCO = pitch_onset_features_to_DLNCO(
273 |         f_peaks=f_pitch_onset,
274 |         feature_rate=feature_rate,
275 |         feature_sequence_length=f_chroma_quantized.shape[1],
276 |         visualize=visualize,
277 |     )
278 |     return f_chroma_quantized, f_DLNCO
279 | 
280 | 
281 | def main(samples, dry_run):
282 |     import multiprocessing
283 |     from joblib import Parallel, delayed
284 | 
285 |     Parallel(n_jobs=multiprocessing.cpu_count() // 2)(
286 |         delayed(save_delayed_song)(sample=sample, dry_run=dry_run)
287 |         for sample in tqdm(samples)
288 |     )
289 | 
290 | 
291 | if __name__ == "__main__":
292 | 
293 |     import argparse
294 | 
295 |     parser = argparse.ArgumentParser(description="align pop songs with their piano-cover MIDI via MrMsDTW")
296 | 
297 |     parser.add_argument(
298 |         "data_dir",
299 |         type=str,
300 |         default=None,
301 |         help="""directory containing {id}/{song_filename.wav}
302 |         """,
303 |     )
304 |     parser.add_argument(
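        # A dry run prints the output paths instead of writing audio/MIDI
        # (see save_delayed_song above); note that each sample's yaml is still updated.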
action="store_true", help="whether dry_run" 306 | ) 307 | 308 | args = parser.parse_args() 309 | 310 | def getfiles(): 311 | meta_files = sorted(glob.glob(args.data_dir + "/*.yaml")) 312 | print("meta ", len(meta_files)) 313 | 314 | samples = list() 315 | for meta_file in tqdm(meta_files): 316 | m = MidiAudioPair(meta_file, auto_remove_no_song=True) 317 | if m.error_code != MidiAudioPair.NO_SONG: 318 | aux_txt = os.path.join( 319 | m.audio_dir, 320 | m.yaml.piano.ytid, 321 | f"{m.yaml.piano.title[:50]}___{m.yaml.song.title[:50]}.txt", 322 | ) 323 | with open(aux_txt, "w") as f: 324 | f.write(".") 325 | samples.append(m) 326 | 327 | print(f"files available {len(samples)}") 328 | return samples 329 | 330 | samples = getfiles() 331 | main(samples=samples, dry_run=args.dry_run) 332 | -------------------------------------------------------------------------------- /preprocess/split_spleeter.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | import sys 5 | 6 | from tqdm.auto import tqdm 7 | 8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | from midiaudiopair import MidiAudioPair 10 | 11 | 12 | def split_spleeter(meta_files): 13 | # Use audio loader explicitly for loading audio waveform : 14 | from spleeter.audio.adapter import AudioAdapter 15 | from spleeter.separator import Separator 16 | import spleeter 17 | 18 | sample_rate = 44100 19 | audio_loader = AudioAdapter.default() 20 | 21 | # Using embedded configuration. 22 | separator = Separator("spleeter:2stems") 23 | 24 | for meta_file in tqdm(meta_files): 25 | sample = MidiAudioPair(meta_file) 26 | if sample.error_code == MidiAudioPair.NO_SONG: 27 | continue 28 | if os.path.exists(sample.vocals): 29 | continue 30 | 31 | waveform, _ = audio_loader.load(sample.song, sample_rate=sample_rate) 32 | 33 | # Perform the separation : 34 | prediction = separator.separate(waveform) 35 | 36 | audio_loader.save( 37 | path=sample.vocals, 38 | data=prediction["vocals"][:, 0:1], 39 | codec=spleeter.audio.Codec.MP3, 40 | sample_rate=sample_rate, 41 | ) 42 | 43 | 44 | if __name__ == "__main__": 45 | import argparse 46 | 47 | parser = argparse.ArgumentParser(description="bpm estimate using essentia") 48 | 49 | parser.add_argument( 50 | "data_dir", 51 | type=str, 52 | default=None, 53 | help="""directory contains {id}/{pop_filename.wav} 54 | """, 55 | ) 56 | 57 | parser.add_argument( 58 | "--random_order", 59 | default=False, 60 | action="store_true", 61 | help="Random order process (to run multiple process)", 62 | ) 63 | 64 | args = parser.parse_args() 65 | 66 | meta_files = sorted(glob.glob(args.data_dir + "/*.yaml")) 67 | if args.random_order: 68 | random.shuffle(meta_files) 69 | 70 | print("meta ", len(meta_files)) 71 | 72 | split_spleeter(meta_files) 73 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pretty-midi==0.2.9 2 | omegaconf==2.1.1 3 | transformers==4.16.1 4 | pytorch-lightning==1.8.4 5 | essentia==2.1b6.dev1034 6 | note-seq==0.0.5 7 | pyFluidSynth==1.3.0 8 | torch==1.13.1 9 | torchaudio==0.13.1 10 | -------------------------------------------------------------------------------- /transformer_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import librosa 6 | import torch 7 | import 
8 | import pytorch_lightning as pl
9 | import soundfile as sf
10 | from torch.nn.utils.rnn import pad_sequence
11 | from transformers import T5Config, T5ForConditionalGeneration
12 | 
13 | from midi_tokenizer import MidiTokenizer, extrapolate_beat_times
14 | from layer.input import LogMelSpectrogram, ConcatEmbeddingToMel
15 | from preprocess.beat_quantizer import extract_rhythm, interpolate_beat_times
16 | from utils.dsp import get_stereo
17 | 
18 | 
19 | DEFAULT_COMPOSERS = {"various composer": 2052}
20 | 
21 | 
22 | class TransformerWrapper(pl.LightningModule):
23 |     def __init__(self, config):
24 |         super().__init__()
25 |         self.config = config
26 | 
27 |         self.tokenizer = MidiTokenizer(config.tokenizer)
28 |         self.t5config = T5Config.from_pretrained("t5-small")
29 | 
30 |         for k, v in config.t5.items():
31 |             setattr(self.t5config, k, v)  # override t5-small defaults with config.t5 values
32 | 
33 |         self.transformer = T5ForConditionalGeneration(self.t5config)
34 |         self.use_mel = self.config.dataset.use_mel
35 |         self.mel_is_conditioned = self.config.dataset.mel_is_conditioned
36 |         self.composer_to_feature_token = config.composer_to_feature_token
37 | 
38 |         if self.use_mel and not self.mel_is_conditioned:
39 |             self.composer_to_feature_token = DEFAULT_COMPOSERS
40 | 
41 |         if self.use_mel:
42 |             self.spectrogram = LogMelSpectrogram()
43 |             if self.mel_is_conditioned:
44 |                 n_dim = 512
45 |                 composer_n_vocab = len(self.composer_to_feature_token)
46 |                 embedding_offset = min(self.composer_to_feature_token.values())
47 |                 self.mel_conditioner = ConcatEmbeddingToMel(
48 |                     embedding_offset=embedding_offset,
49 |                     n_vocab=composer_n_vocab,
50 |                     n_dim=n_dim,
51 |                 )
52 |         else:
53 |             self.spectrogram = None
54 | 
55 |         self.lr = config.training.lr
56 | 
57 |     def forward(self, input_ids, labels):
58 |         """
59 |         Deprecated.
60 |         """
61 |         rt = self.transformer(input_ids=input_ids, labels=labels)
62 |         return rt
63 | 
64 |     @torch.no_grad()
65 |     def single_inference(
66 |         self,
67 |         feature_tokens=None,
68 |         audio=None,
69 |         beatstep=None,
70 |         max_length=256,
71 |         max_batch_size=64,
72 |         n_bars=None,
73 |         composer_value=None,
74 |     ):
75 |         """
76 |         Generate a full-song MIDI token sequence from a long audio input.
77 | 
78 |         feature_tokens or audio : shape (time, )
79 | 
80 |         beatstep : shape (time, )
81 |         - the beatstep values corresponding to input_ids
82 |             (with the offset removed, i.e. beatstep[0] == 0)
83 |         - beatstep[-1] : the time at which input_ids ends
84 |             (i.e. beatstep[-1] == len(y) // sr)
85 |         """
86 | 
87 |         assert feature_tokens is not None or audio is not None
88 |         assert beatstep is not None
89 | 
90 |         if feature_tokens is not None:
91 |             assert len(feature_tokens.shape) == 1
92 | 
93 |         if audio is not None:
94 |             assert len(audio.shape) == 1
95 | 
96 |         config = self.config
97 |         PAD = self.t5config.pad_token_id
98 |         n_bars = config.dataset.n_bars if n_bars is None else n_bars
99 | 
100 |         if beatstep[0] > 0.01:
101 |             print(
102 |                 f"inference warning : beatstep[0] is not 0 ({beatstep[0]}). all beatsteps will be shifted."
103 |             )
104 |             beatstep = beatstep - beatstep[0]
105 | 
106 |         if self.use_mel:
107 |             input_ids = None
108 |             inputs_embeds, ext_beatstep = self.prepare_inference_mel(
109 |                 audio,
110 |                 beatstep,
111 |                 n_bars=n_bars,
112 |                 padding_value=PAD,
113 |                 composer_value=composer_value,
114 |             )
115 |             batch_size = inputs_embeds.shape[0]
116 |         else:
117 |             raise NotImplementedError
118 | 
119 |         # The full batch may not fit into GPU memory, so generate in chunks of max_batch_size.
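        # e.g. with 100 mel segments and max_batch_size=64, generation runs in
        # two chunks ([0:64], [64:100]); each chunk's tokens are padded to a
        # common length and concatenated back together below.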
120 |         relative_tokens = list()
121 |         for i in range(0, batch_size, max_batch_size):
122 |             start = i
123 |             end = min(batch_size, i + max_batch_size)
124 | 
125 |             if input_ids is None:
126 |                 _input_ids = None
127 |                 _inputs_embeds = inputs_embeds[start:end]
128 |             else:
129 |                 _input_ids = input_ids[start:end]
130 |                 _inputs_embeds = None
131 | 
132 |             _relative_tokens = self.transformer.generate(
133 |                 input_ids=_input_ids,
134 |                 inputs_embeds=_inputs_embeds,
135 |                 max_length=max_length,
136 |             )
137 |             _relative_tokens = _relative_tokens.cpu().numpy()
138 |             relative_tokens.append(_relative_tokens)
139 | 
140 |         max_length = max([rt.shape[-1] for rt in relative_tokens])
141 |         for i in range(len(relative_tokens)):
142 |             relative_tokens[i] = np.pad(
143 |                 relative_tokens[i],
144 |                 [(0, 0), (0, max_length - relative_tokens[i].shape[-1])],
145 |                 constant_values=PAD,
146 |             )
147 |         relative_tokens = np.concatenate(relative_tokens)
148 | 
149 |         pm, notes = self.tokenizer.relative_batch_tokens_to_midi(
150 |             relative_tokens,
151 |             beatstep=ext_beatstep,
152 |             bars_per_batch=n_bars,
153 |             cutoff_time_idx=(n_bars + 1) * 4,
154 |         )
155 | 
156 |         return relative_tokens, notes, pm
157 | 
158 |     def prepare_inference_mel(
159 |         self, audio, beatstep, n_bars, padding_value, composer_value=None
160 |     ):
161 |         n_steps = n_bars * 4
162 |         n_target_step = len(beatstep)
163 |         sample_rate = self.config.dataset.sample_rate
164 |         ext_beatstep = extrapolate_beat_times(beatstep, (n_bars + 1) * 4 + 1)
165 | 
166 |         def split_audio(audio):
167 |             # Split the audio at beat-interval boundaries.
168 |             # The resulting segments have different lengths,
169 |             # because the beat intervals themselves differ in duration.
170 |             batch = []
171 | 
172 |             for i in range(0, n_target_step, n_steps):
173 | 
174 |                 start_idx = i
175 |                 end_idx = min(i + n_steps, n_target_step)
176 | 
177 |                 start_sample = int(ext_beatstep[start_idx] * sample_rate)
178 |                 end_sample = int(ext_beatstep[end_idx] * sample_rate)
179 |                 feature = audio[start_sample:end_sample]
180 |                 batch.append(feature)
181 |             return batch
182 | 
183 |         def pad_and_stack_batch(batch):
184 |             batch = pad_sequence(batch, batch_first=True, padding_value=padding_value)
185 |             return batch
186 | 
187 |         batch = split_audio(audio)
188 |         batch = pad_and_stack_batch(batch)
189 | 
190 |         inputs_embeds = self.spectrogram(batch).transpose(-1, -2)
191 |         if self.mel_is_conditioned:
192 |             composer_value = torch.tensor(composer_value).to(self.device)
193 |             composer_value = composer_value.repeat(inputs_embeds.shape[0])
194 |             inputs_embeds = self.mel_conditioner(inputs_embeds, composer_value)
195 |         return inputs_embeds, ext_beatstep
196 | 
197 |     @torch.no_grad()
198 |     def generate(
199 |         self,
200 |         audio_path=None,
201 |         composer=None,
202 |         model="generated",
203 |         steps_per_beat=2,
204 |         stereo_amp=0.5,
205 |         n_bars=2,
206 |         ignore_duplicate=True,
207 |         show_plot=False,
208 |         save_midi=False,
209 |         save_mix=False,
210 |         midi_path=None,
211 |         mix_path=None,
212 |         click_amp=0.2,
213 |         add_click=False,
214 |         max_batch_size=None,
215 |         beatsteps=None,
216 |         mix_sample_rate=None,
217 |         audio_y=None,
218 |         audio_sr=None,
219 |     ):
220 |         config = self.config
221 |         device = self.device
222 | 
223 |         if audio_path is not None:
224 |             extension = os.path.splitext(audio_path)[1]
225 |             mix_path = (
226 |                 audio_path.replace(extension, f".{model}.{composer}.wav")
227 |                 if mix_path is None
228 |                 else mix_path
229 |             )
230 |             midi_path = (
231 |                 audio_path.replace(extension, f".{model}.{composer}.mid")
232 |                 if midi_path is None
233 |                 else midi_path
234 | 
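                # default output names are derived from the input audio:
                # <stem>.<model>.<composer>.wav / .mid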
) 235 | 236 | max_batch_size = 64 // n_bars if max_batch_size is None else max_batch_size 237 | composer_to_feature_token = self.composer_to_feature_token 238 | 239 | if composer is None: 240 | composer = random.sample(list(composer_to_feature_token.keys()), 1)[0] 241 | 242 | composer_value = composer_to_feature_token[composer] 243 | mix_sample_rate = ( 244 | config.dataset.sample_rate if mix_sample_rate is None else mix_sample_rate 245 | ) 246 | 247 | if not ignore_duplicate: 248 | if os.path.exists(midi_path): 249 | return 250 | 251 | ESSENTIA_SAMPLERATE = 44100 252 | 253 | if beatsteps is None: 254 | y, sr = librosa.load(audio_path, sr=ESSENTIA_SAMPLERATE) 255 | ( 256 | bpm, 257 | beat_times, 258 | confidence, 259 | estimates, 260 | essentia_beat_intervals, 261 | ) = extract_rhythm(audio_path, y=y) 262 | beat_times = np.array(beat_times) 263 | beatsteps = interpolate_beat_times(beat_times, steps_per_beat, extend=True) 264 | else: 265 | y = None 266 | 267 | if self.use_mel: 268 | if audio_y is None and config.dataset.sample_rate != ESSENTIA_SAMPLERATE: 269 | if y is not None: 270 | y = librosa.core.resample( 271 | y, 272 | orig_sr=ESSENTIA_SAMPLERATE, 273 | target_sr=config.dataset.sample_rate, 274 | ) 275 | sr = config.dataset.sample_rate 276 | else: 277 | y, sr = librosa.load(audio_path, sr=config.dataset.sample_rate) 278 | elif audio_y is not None: 279 | if audio_sr != config.dataset.sample_rate: 280 | audio_y = librosa.core.resample( 281 | audio_y, orig_sr=audio_sr, target_sr=config.dataset.sample_rate 282 | ) 283 | audio_sr = config.dataset.sample_rate 284 | y = audio_y 285 | sr = audio_sr 286 | 287 | start_sample = int(beatsteps[0] * sr) 288 | end_sample = int(beatsteps[-1] * sr) 289 | _audio = torch.from_numpy(y)[start_sample:end_sample].to(device) 290 | fzs = None 291 | else: 292 | raise NotImplementedError 293 | 294 | relative_tokens, notes, pm = self.single_inference( 295 | feature_tokens=fzs, 296 | audio=_audio, 297 | beatstep=beatsteps - beatsteps[0], 298 | max_length=config.dataset.target_length 299 | * max(1, (n_bars // config.dataset.n_bars)), 300 | max_batch_size=max_batch_size, 301 | n_bars=n_bars, 302 | composer_value=composer_value, 303 | ) 304 | 305 | for n in pm.instruments[0].notes: 306 | n.start += beatsteps[0] 307 | n.end += beatsteps[0] 308 | 309 | if show_plot or save_mix: 310 | if mix_sample_rate != sr: 311 | y = librosa.core.resample(y, orig_sr=sr, target_sr=mix_sample_rate) 312 | sr = mix_sample_rate 313 | if add_click: 314 | clicks = ( 315 | librosa.clicks(times=beatsteps, sr=sr, length=len(y)) * click_amp 316 | ) 317 | y = y + clicks 318 | pm_y = pm.fluidsynth(sr) 319 | stereo = get_stereo(y, pm_y, pop_scale=stereo_amp) 320 | 321 | if show_plot: 322 | import IPython.display as ipd 323 | from IPython.display import display 324 | import note_seq 325 | 326 | display("Stereo MIX", ipd.Audio(stereo, rate=sr)) 327 | display("Rendered MIDI", ipd.Audio(pm_y, rate=sr)) 328 | display("Original Song", ipd.Audio(y, rate=sr)) 329 | display(note_seq.plot_sequence(note_seq.midi_to_note_sequence(pm))) 330 | 331 | if save_mix: 332 | sf.write( 333 | file=mix_path, 334 | data=stereo.T, 335 | samplerate=sr, 336 | format="wav", 337 | ) 338 | 339 | if save_midi: 340 | pm.write(midi_path) 341 | 342 | return pm, composer, mix_path, midi_path 343 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sweetcocoa/pop2piano/f218c094c3e185d43fa700ed5724cc616be5608f/utils/__init__.py
--------------------------------------------------------------------------------
/utils/demo.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | 
4 | import shutil
5 | import youtube_dl
6 | from youtube_dl.utils import sanitize_filename
7 | 
8 | 
9 | def download_youtube(url, dst_dir, dst_filename=None, keep_video=False):
10 |     ydl_opts = {
11 |         "format": "mp4",
12 |         "restrictfilenames": True,
13 |         "keepvideo": keep_video,
14 |         "postprocessors": [
15 |             {
16 |                 "key": "FFmpegExtractAudio",
17 |                 "preferredcodec": "mp3",
18 |                 "preferredquality": "192",
19 |             }
20 |         ],
21 |     }
22 | 
23 |     with youtube_dl.YoutubeDL(ydl_opts) as ydl:
24 |         rt = ydl.extract_info(url)
25 | 
26 |     title = sanitize_filename(rt["title"], restricted=True)
27 |     reg_title = re.sub(r"[^a-zA-Zㄱ-ㅎ가-힣0-9\ \-\_\.]", "", rt["title"])
28 | 
29 |     result_video_filename = f"{title}-{rt['id']}.{rt['ext']}"
30 |     result_audio_filename = f"{title}-{rt['id']}.mp3"
31 |     result_ok = os.path.exists(result_audio_filename)
32 | 
33 |     dst_audio_filename = (
34 |         f"{reg_title}-{rt['id']}.mp3" if dst_filename is None else f"{dst_filename}.mp3"
35 |     )
36 |     dst_audio_filepath = os.path.join(dst_dir, dst_audio_filename)
37 |     dst_video_filename = f"{reg_title}-{rt['id']}.{rt['ext']}"
38 |     dst_video_filepath = os.path.join(dst_dir, dst_video_filename)
39 | 
40 |     if result_ok:
41 |         os.makedirs(dst_dir, exist_ok=True)
42 |         shutil.move(result_audio_filename, dst_audio_filepath)
43 |         if keep_video:
44 |             shutil.move(result_video_filename, dst_video_filepath)
45 |             return dst_audio_filepath, dst_video_filepath
46 | 
47 |     return dst_audio_filepath
48 | 
--------------------------------------------------------------------------------
/utils/dsp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.interpolate import interp1d
3 | 
4 | 
5 | def normalize(audio, min_y=-1.0, max_y=1.0, eps=1e-8):
6 |     assert len(audio.shape) == 1
7 |     max_y -= eps
8 |     min_y += eps
9 |     amax = audio.max()
10 |     amin = audio.min()
11 |     audio = (max_y - min_y) * (audio - amin) / (amax - amin) + min_y
12 |     return audio
13 | 
14 | 
15 | def get_stereo(pop_y, midi_y, pop_scale=0.99):
16 |     if len(pop_y) > len(midi_y):
17 |         midi_y = np.pad(midi_y, (0, len(pop_y) - len(midi_y)))
18 |     elif len(pop_y) < len(midi_y):
19 |         pop_y = np.pad(pop_y, (0, -len(pop_y) + len(midi_y)))
20 |     stereo = np.stack((midi_y, pop_y * pop_scale))
21 |     return stereo
22 | 
23 | 
24 | def generate_variable_f0_sine_wave(f0, len_y, sr):
25 |     """
26 |     integrate instantaneous frequencies to get a pure-tone sine wave
27 |     """
28 |     x_sample = np.arange(len(f0))
29 |     intp = interp1d(x_sample, f0, kind="linear")
30 |     f0_audiorate = intp(np.linspace(0, len(f0) - 1, len_y))
31 |     pitch_wave = np.sin((np.nan_to_num(f0_audiorate) / sr * 2 * np.pi).cumsum())
32 |     return pitch_wave
33 | 
34 | 
35 | def fluidsynth_without_normalize(self, fs=44100, sf2_path=None):
36 |     """Synthesize using fluidsynth, without signal normalization.
37 |     Parameters
38 |     ----------
39 |     fs : int
40 |         Sampling rate to synthesize at.
41 |     sf2_path : str
42 |         Path to a .sf2 file.
43 |         Default ``None``, which uses the TimGM6mb.sf2 file included with
44 |         ``pretty_midi``.
45 |     Returns
46 |     -------
47 |     synthesized : np.ndarray
48 |         Waveform of the MIDI data, synthesized at ``fs``.
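    Notes
    -----
    Identical to ``PrettyMIDI.fluidsynth`` except that the final peak
    normalization is skipped, so stems keep their relative loudness when
    mixed later (e.g. by ``get_stereo``).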
49 | """ 50 | # If there are no instruments, or all instruments have no notes, return 51 | # an empty array 52 | if len(self.instruments) == 0 or all(len(i.notes) == 0 for i in self.instruments): 53 | return np.array([]) 54 | # Get synthesized waveform for each instrument 55 | waveforms = [i.fluidsynth(fs=fs, sf2_path=sf2_path) for i in self.instruments] 56 | # Allocate output waveform, with #sample = max length of all waveforms 57 | synthesized = np.zeros(np.max([w.shape[0] for w in waveforms])) 58 | # Sum all waveforms in 59 | for waveform in waveforms: 60 | synthesized[: waveform.shape[0]] += waveform 61 | # Normalize 62 | # synthesized /= np.abs(synthesized).max() 63 | return synthesized 64 | --------------------------------------------------------------------------------