├── .gitattributes ├── .gitignore ├── End2End ├── Data.py ├── MIDI_program_map.py ├── MIDI_program_map.tsv ├── Openmic_map.py ├── Wrong.ipynb ├── __init__.py ├── config │ ├── .tseparation.yaml.swp │ ├── Hungarian_IR.yaml │ ├── Instrument_Classification.yaml │ ├── Jointist.yaml │ ├── datamodule │ │ ├── h5.yaml │ │ ├── msd.yaml │ │ ├── openmic.yaml │ │ ├── slakh.yaml │ │ ├── slakh_ir.yaml │ │ └── wild.yaml │ ├── detection │ │ ├── CombinedModel_Av2.yaml │ │ ├── CombinedModel_Av2_Teacher.yaml │ │ ├── CombinedModel_CLS.yaml │ │ ├── CombinedModel_CLSv2.yaml │ │ ├── CombinedModel_H.yaml │ │ ├── CombinedModel_Linear.yaml │ │ ├── CombinedModel_NewCLSv2.yaml │ │ ├── CombinedModel_S.yaml │ │ ├── CombinedModel_Sv2.yaml │ │ ├── CombinedModel_Sv2_torch.yaml │ │ ├── OpenMicBaseline.yaml │ │ ├── Original.yaml │ │ ├── backbone │ │ │ ├── AcousticModelCnn8Dropout.yaml │ │ │ ├── CNN14.yaml │ │ │ ├── CNN14_less_pooling.yaml │ │ │ ├── CNN8.yaml │ │ │ ├── ResNet101.yaml │ │ │ └── ResNet50.yaml │ │ ├── feature │ │ │ └── mel.yaml │ │ └── transformer │ │ │ ├── BERT.yaml │ │ │ ├── BERTv2.yaml │ │ │ ├── DETR_Transformer.yaml │ │ │ ├── DETR_Transformerv2.yaml │ │ │ ├── Linear.yaml │ │ │ ├── MusicTaggingTransformer.yaml │ │ │ ├── torch_Transformer.yaml │ │ │ └── torch_Transformer_API.yaml │ ├── detection_config.yaml │ ├── jointist_inference.yaml │ ├── jointist_ss_inference.yaml │ ├── jointist_testing.yaml │ ├── openmic-DETR_Hungarian_IR.yaml │ ├── pkl2pianoroll.yaml │ ├── pkl2pianoroll_MSD.yaml │ ├── pkl2sparsepianoroll_MSD.yaml │ ├── pred_transcription_config.yaml │ ├── scheduler │ │ ├── LambdaLR.yaml │ │ ├── Lambda_ss.yaml │ │ └── MultiStepLR.yaml │ ├── separation │ │ ├── CUNet.yaml │ │ ├── TCUNet.yaml │ │ └── feature │ │ │ └── SS_STFT.yaml │ ├── separation_config.yaml │ ├── transcription │ │ ├── FrameOnly.yaml │ │ ├── Original.yaml │ │ ├── Semantic_Segmentation.yaml │ │ ├── backend │ │ │ ├── CNN_GRU.yaml │ │ │ ├── CNN_LSTM.yaml │ │ │ └── language │ │ │ │ ├── GRU.yaml │ │ │ │ └── LSTM.yaml │ │ ├── feature │ │ │ └── mel.yaml │ │ └── postprocessor │ │ │ ├── OnsetFramePostProcessor.yaml │ │ │ └── RegressionPostProcessor.yaml │ ├── transcription_config.yaml │ └── tseparation.yaml ├── constants.py ├── create_notes_for_instruments_classification_MIDI_class.py ├── create_notes_for_instruments_classification_MIDI_instrument.py ├── create_notes_for_openmic.py ├── create_openmic2018.py ├── data │ ├── _data_modules.py │ ├── augmentors.py │ ├── data_modules.py │ ├── mixing_secrets_vocals.py │ ├── samplers.py │ └── target_processors.py ├── dataset_creation │ ├── README.md │ ├── __init__.py │ ├── crash.py │ ├── create_groove.py │ ├── create_musdb18.py │ ├── create_muse.py │ ├── create_notes_for_instruments_classification.py │ ├── create_slakh2100.py │ ├── groove_prepare_midi.py │ ├── midi_track_group_config_1.csv │ ├── midi_track_group_config_2.csv │ ├── midi_track_group_config_3.csv │ ├── midi_track_group_config_4.csv │ ├── mixing_secrets │ │ ├── __init__.py │ │ └── segment_vocal_stems.py │ ├── plugin_to_midi_program.json │ ├── prepare_closed_set.py │ └── test9.py ├── inference_instruments_filter.py ├── loss.py ├── losses.py ├── lr_schedulers.py ├── models │ ├── instrument_detection │ │ ├── CLS.py │ │ ├── CLS_CNN14.py │ │ ├── __init__.py │ │ ├── backbone.py │ │ ├── combined.py │ │ ├── detr.py │ │ ├── openmic_baseline.py │ │ └── utils.py │ ├── instruments_classification_models.py │ ├── position_encoding.py │ ├── separation │ │ ├── __init__.py │ │ ├── base.py │ │ ├── cond_unet.py │ │ └── t_cond_unet.py │ ├── transcription │ │ ├── 
__init__.py │ │ ├── acoustic.py │ │ ├── combined.py │ │ └── seg_baseline.py │ ├── transformer.py │ └── utils.py ├── notes.txt ├── openmic.py ├── piano_vad.py ├── samplers.py ├── slakh_instruments.pkl ├── target_processors.py ├── tasks │ ├── detection │ │ ├── __init__.py │ │ ├── binary.py │ │ ├── hungarian.py │ │ ├── hungarian_autoregressive.py │ │ ├── linear.py │ │ └── softmax_autoregressive.py │ ├── jointist.py │ ├── jointist_ss.py │ ├── separation │ │ ├── __init__.py │ │ ├── separation.py │ │ └── utils.py │ ├── t_separation.py │ └── transcription │ │ ├── __init__.py │ │ ├── transcription.py │ │ └── utils.py ├── transcription_utils.py ├── util │ ├── __init__.py │ ├── box_ops.py │ ├── misc.py │ └── plot_utils.py └── utils.py ├── GPU_debug.py ├── README.md ├── SS_visualization.ipynb ├── create_slakh2100.py ├── evaluate_end2end_Filter.py ├── experiments.md ├── f1.ipynb ├── inst_wise.png ├── jointist_explanation.md ├── model_fig.png ├── openmic_dataprocessing.sh ├── piece_wise.png ├── pkl2pianoroll.py ├── pkl2pianoroll_MSD.py ├── pkl2pianoroll_MSD.sh ├── pkl2pianoroll_MTAT.py ├── pkl2sparsepianoroll_MSD.py ├── pred_detection.py ├── pred_jointist.py ├── pred_jointist_ss.py ├── pred_transcription.py ├── requirements.txt ├── roll_convert2channel.py ├── roll_convert_sparse.py ├── slakh2100_dataprocessing.sh ├── songs.zip ├── test_detection.py ├── test_jointist.py ├── test_openmic_DETR_Hungarian.py ├── test_separation.py ├── test_transcription.py ├── test_tseparation.py ├── train_detection.py ├── train_jointist.py ├── train_separation.py ├── train_transcription.py ├── train_tseparation.py └── weights └── link.txt /.gitattributes: -------------------------------------------------------------------------------- 1 | songs.zip filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | .DS_Store 4 | .egg-info/ 5 | .ipynb_checkpoints 6 | .tox 7 | .pytest_cache 8 | datasets/ 9 | lightning_logs/ 10 | logs/ 11 | requirements.txt 12 | workspaces/ 13 | dataset_processed/ 14 | *.pt 15 | hdf5s/ 16 | instruments_classification_notes3/ 17 | instruments_classification_notes_MIDI_class/ 18 | outputs/ 19 | pickles/ 20 | !End2End_Dataloading.ipynb 21 | End2End/MIDI_program_Debug.ipynb 22 | visualization 23 | statistics 24 | checkpoints 25 | -------------------------------------------------------------------------------- /End2End/MIDI_program_map.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import csv 3 | with open('./End2End/MIDI_program_map.tsv') as csv_file: 4 | csv_reader = csv.reader(csv_file, delimiter='\t') 5 | line_count = 0 6 | idx2class_name = {} 7 | idx2instrument_name = {} 8 | idx2instrument_class = {} 9 | for row in csv_reader: 10 | idx2class_name[int(row[0])] = row[1] 11 | idx2instrument_name[int(row[0])] = row[2] 12 | idx2instrument_class[int(row[0])] = row[3] 13 | 14 | 15 | slakh_instrument = pickle.load(open('./End2End/slakh_instruments.pkl', 'rb')) 16 | MIDIProgramName2class_idx = {} 17 | class_idx2MIDIProgramName = {} 18 | 19 | W_MIDIClassName2class_idx = {} 20 | W_class_idx2MIDIClass = {} 21 | 22 | MIDIClassName2class_idx = {} 23 | class_idx2MIDIClass = {} 24 | for idx,i in enumerate(slakh_instrument): 25 | MIDIProgramName2class_idx[idx2instrument_name[i]] = idx 26 | class_idx2MIDIProgramName[idx] = idx2instrument_name[i] 27 | 28 | 
W_MIDIClassName2class_idx[idx2instrument_class[i]] = idx 29 | W_class_idx2MIDIClass[idx] = idx2instrument_class[i] 30 | 31 | # # Assigning Empty class 32 | # MIDIProgramName2class_idx['empty'] = idx+1 33 | 34 | # More general definition 35 | unique_instrument_class = [] 36 | for i in idx2instrument_class.items(): 37 | if i[1] in unique_instrument_class: 38 | continue 39 | else: 40 | unique_instrument_class.append(i[1]) 41 | 42 | for idx, class_name in enumerate(unique_instrument_class): 43 | MIDIClassName2class_idx[class_name] = idx 44 | class_idx2MIDIClass[idx] = class_name 45 | 46 | # assign empty class 47 | MIDIClassName2class_idx['Empty'] = idx+1 48 | class_idx2MIDIClass[idx+1] = 'Empty' 49 | 50 | MIDI_PROGRAM_NUM = len(MIDIProgramName2class_idx) 51 | MIDI_Class_NUM = len(MIDIClassName2class_idx) 52 | W_MIDI_Class_NUM = len(W_class_idx2MIDIClass) -------------------------------------------------------------------------------- /End2End/MIDI_program_map.tsv: -------------------------------------------------------------------------------- 1 | 0 Piano Acoustic Grand Piano Piano 2 | 1 Piano Bright Acoustic Piano Piano 3 | 2 Piano Electric Grand Piano Piano 4 | 3 Piano Honky-tonk Piano Piano 5 | 4 Piano Electric Piano 1 Electric Piano 6 | 5 Piano Electric Piano 2 Electric Piano 7 | 6 Piano Harpsichord Harpsichord 8 | 7 Piano Clavinet Clavinet 9 | 8 Chromatic Percussion Celesta Chromatic Percussion 10 | 9 Chromatic Percussion Glockenspiel Chromatic Percussion 11 | 10 Chromatic Percussion Music box Chromatic Percussion 12 | 11 Chromatic Percussion Vibraphone Chromatic Percussion 13 | 12 Chromatic Percussion Marimba Chromatic Percussion 14 | 13 Chromatic Percussion Xylophone Chromatic Percussion 15 | 14 Chromatic Percussion Tubular Bells Chromatic Percussion 16 | 15 Chromatic Percussion Dulcimer Chromatic Percussion 17 | 16 Organ Drawbar Organ Organ 18 | 17 Organ Percussive Organ Organ 19 | 18 Organ Rock Organ Organ 20 | 19 Organ Church Organ Organ 21 | 20 Organ Reed Organ Organ 22 | 21 Organ Accordion Accordion 23 | 22 Organ Harmonica Harmonica 24 | 23 Organ Tango Accordion Accordion 25 | 24 Guitar Acoustic Guitar (nylon) Acoustic Guitar 26 | 25 Guitar Acoustic Guitar (steel) Acoustic Guitar 27 | 26 Guitar Electric Guitar (jazz) Electric Guitar 28 | 27 Guitar Electric Guitar (clean) Electric Guitar 29 | 28 Guitar Electric Guitar (muted) Electric Guitar 30 | 29 Guitar Overdriven Guitar Electric Guitar 31 | 30 Guitar Distortion Guitar Electric Guitar 32 | 31 Guitar Guitar Harmonics Electric Guitar 33 | 32 Bass Acoustic Bass Bass 34 | 33 Bass Electric Bass (finger) Bass 35 | 34 Bass Electric Bass (pick) Bass 36 | 35 Bass Fretless Bass Bass 37 | 36 Bass Slap Bass 1 Bass 38 | 37 Bass Slap Bass 2 Bass 39 | 38 Bass Synth Bass 1 Bass 40 | 39 Bass Synth Bass 2 Bass 41 | 40 Strings Violin Violin 42 | 41 Strings Viola Viola 43 | 42 Strings Cello Cello 44 | 43 Strings Contrabass Contrabass 45 | 44 Strings Tremolo Strings Strings 46 | 45 Strings Pizzicato Strings Strings 47 | 46 Strings Orchestral Harp Harp 48 | 47 Strings Timpani Timpani 49 | 48 Ensemble String Ensemble 1 Strings 50 | 49 Ensemble String Ensemble 2 Strings 51 | 50 Ensemble Synth Strings 1 Strings 52 | 51 Ensemble Synth Strings 2 Strings 53 | 52 Ensemble Choir Aahs Voice 54 | 53 Ensemble Voice Oohs Voice 55 | 54 Ensemble Synth Choir Voice 56 | 55 Ensemble Orchestra Hit Strings 57 | 56 Brass Trumpet Trumpet 58 | 57 Brass Trombone Trombone 59 | 58 Brass Tuba Tuba 60 | 59 Brass Muted Trumpet Trumpet 61 | 60 Brass French Horn Horn 62 | 61 Brass 
Brass Section Brass 63 | 62 Brass Synth Brass 1 Brass 64 | 63 Brass Synth Brass 2 Brass 65 | 64 Reed Soprano Sax Saxophone 66 | 65 Reed Alto Sax Saxophone 67 | 66 Reed Tenor Sax Saxophone 68 | 67 Reed Baritone Sax Saxophone 69 | 68 Reed Oboe Oboe 70 | 69 Reed English Horn Horn 71 | 70 Reed Bassoon Bassoon 72 | 71 Reed Clarinet Clarinet 73 | 72 Pipe Piccolo Piccolo 74 | 73 Pipe Flute Flute 75 | 74 Pipe Recorder Recorder 76 | 75 Pipe Pan Flute Pipe 77 | 76 Pipe Blown bottle Pipe 78 | 77 Pipe Shakuhachi Pipe 79 | 78 Pipe Whistle Pipe 80 | 79 Pipe Ocarina Pipe 81 | 80 Synth Lead Lead 1 (square) Synth Lead 82 | 81 Synth Lead Lead 2 (sawtooth) Synth Lead 83 | 82 Synth Lead Lead 3 (calliope) Synth Lead 84 | 83 Synth Lead Lead 4 chiff Synth Lead 85 | 84 Synth Lead Lead 5 (charang) Synth Lead 86 | 85 Synth Lead Lead 6 (voice) Synth Lead 87 | 86 Synth Lead Lead 7 (fifths) Synth Lead 88 | 87 Synth Lead Lead 8 (bass + lead) Synth Lead 89 | 88 Synth Pad Pad 1 (new age) Synth Pad 90 | 89 Synth Pad Pad 2 (warm) Synth Pad 91 | 90 Synth Pad Pad 3 (polysynth) Synth Pad 92 | 91 Synth Pad Pad 4 (choir) Synth Pad 93 | 92 Synth Pad Pad 5 (bowed) Synth Pad 94 | 93 Synth Pad Pad 6 (metallic) Synth Pad 95 | 94 Synth Pad Pad 7 (halo) Synth Pad 96 | 95 Synth Pad Pad 8 (sweep) Synth Pad 97 | 96 Synth Effects FX 1 (rain) Synth Effects 98 | 97 Synth Effects FX 2 (soundtrack) Synth Effects 99 | 98 Synth Effects FX 3 (crystal) Synth Effects 100 | 99 Synth Effects FX 4 (atmosphere) Synth Effects 101 | 100 Synth Effects FX 5 (brightness) Synth Effects 102 | 101 Synth Effects FX 6 (goblins) Synth Effects 103 | 102 Synth Effects FX 7 (echoes) Synth Effects 104 | 103 Synth Effects FX 8 (sci-fi) Synth Effects 105 | 104 Ethnic Sitar Ethnic 106 | 105 Ethnic Banjo Ethnic 107 | 106 Ethnic Shamisen Ethnic 108 | 107 Ethnic Koto Ethnic 109 | 108 Ethnic Kalimba Ethnic 110 | 109 Ethnic Bagpipe Ethnic 111 | 110 Ethnic Fiddle Ethnic 112 | 111 Ethnic Shana Ethnic 113 | 112 Percussive Tinkle Bell Percussive 114 | 113 Percussive Agogo Percussive 115 | 114 Percussive Steel Drums Percussive 116 | 115 Percussive Woodblock Percussive 117 | 116 Percussive Taiko Drum Percussive 118 | 117 Percussive Melodic Tom Percussive 119 | 118 Percussive Synth Drum Percussive 120 | 119 Percussive Reverse Cymbal Percussive 121 | 120 Sound Effects Guitar Fret Noise Sound Effects 122 | 121 Sound Effects Breath Noise Sound Effects 123 | 122 Sound Effects Seashore Sound Effects 124 | 123 Sound Effects Bird Tweet Sound Effects 125 | 124 Sound Effects Telephone Ring Sound Effects 126 | 125 Sound Effects Helicopter Sound Effects 127 | 126 Sound Effects Applause Sound Effects 128 | 127 Sound Effects Gunshot Sound Effects 129 | 128 Drums Drums Drums -------------------------------------------------------------------------------- /End2End/Openmic_map.py: -------------------------------------------------------------------------------- 1 | OpenmicIDX2Name ={ 2 | 0: 'accordion', 3 | 1: 'banjo', 4 | 2: 'bass', 5 | 3: 'cello', 6 | 4: 'clarinet', 7 | 5: 'cymbals', 8 | 6: 'drums', 9 | 7: 'flute', 10 | 8: 'guitar', 11 | 9: 'mallet_percussion', 12 | 10: 'mandolin', 13 | 11: 'organ', 14 | 12: 'piano', 15 | 13: 'saxophone', 16 | 14: 'synthesizer', 17 | 15: 'trombone', 18 | 16: 'trumpet', 19 | 17: 'ukulele', 20 | 18: 'violin', 21 | 19: 'voice', 22 | 20: 'Empty' # this is for DETR 23 | } 24 | 25 | Name2OpenmicIDX = {} 26 | for idx,name in OpenmicIDX2Name.items(): 27 | Name2OpenmicIDX[name] = idx 28 | 29 | OpenMic_Class_NUM = len(Name2OpenmicIDX) 
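A minimal usage sketch (not a file in the repository) of how the index-to-name dictionaries defined above in Openmic_map.py could be consumed, for example to turn a list of instrument names into a multi-hot target vector. It assumes the repository root is on the Python path and that torch is installed; the helper name names_to_multihot is hypothetical.

import torch
from End2End.Openmic_map import Name2OpenmicIDX, OpenMic_Class_NUM

def names_to_multihot(instrument_names):
    """Hypothetical helper: convert OpenMic instrument names into a multi-hot tensor."""
    target = torch.zeros(OpenMic_Class_NUM)
    for name in instrument_names:
        target[Name2OpenmicIDX[name]] = 1.0
    return target

# e.g. 'piano', 'drums' and 'voice' map to indices 12, 6 and 19 in OpenmicIDX2Name
print(names_to_multihot(['piano', 'drums', 'voice']))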
-------------------------------------------------------------------------------- /End2End/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinWaiCheuk/Jointist/1846975ec904c8b23ea653b9b7881f343574e7a6/End2End/__init__.py -------------------------------------------------------------------------------- /End2End/config/.tseparation.yaml.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinWaiCheuk/Jointist/1846975ec904c8b23ea653b9b7881f343574e7a6/End2End/config/.tseparation.yaml.swp -------------------------------------------------------------------------------- /End2End/config/Hungarian_IR.yaml: -------------------------------------------------------------------------------- 1 | gpus: 1 2 | epochs: 500 3 | augmentation: False 4 | batch_size: 8 5 | segment_seconds: 10 6 | frames_per_second: 100 7 | 8 | 9 | datamodule: 10 | waveform_hdf5s_dir: 11 | notes_pkls_dir: 12 | dataset_cfg: 13 | train: 14 | segment_seconds: ${segment_seconds} 15 | frames_per_second: ${frames_per_second} 16 | pre_load_audio: False 17 | transcription: False 18 | val: 19 | segment_seconds: ${segment_seconds} 20 | frames_per_second: ${frames_per_second} 21 | pre_load_audio: False 22 | transcription: False 23 | test: 24 | segment_seconds: null 25 | frames_per_second: ${frames_per_second} 26 | pre_load_audio: False 27 | transcription: False 28 | dataloader_cfg: 29 | train: 30 | batch_size: ${batch_size} 31 | num_workers: 4 32 | shuffle: True 33 | pin_memory: True 34 | val: 35 | batch_size: ${batch_size} 36 | num_workers: 4 37 | shuffle: False 38 | pin_memory: True 39 | test: 40 | batch_size: 1 41 | num_workers: 4 42 | shuffle: False 43 | pin_memory: True 44 | 45 | MIDI_MAPPING: # This whole part will be overwritten in the main code 46 | type: 'plugin_names' 47 | plugin_labels_num: 0 48 | NAME_TO_IX: 0 49 | IX_TO_NAME: 0 50 | 51 | model: 52 | warm_up_epochs: 0 53 | alpha: 1 54 | type: 'Cnn14Seq2Seq_biLSTM' 55 | args: 56 | lr: 1e-3 57 | 58 | 59 | checkpoint: 60 | monitor: 'Total_Loss/Valid' 61 | filename: "e={epoch:02d}-acc={mAP/Valid:.2f}-loss={Loss/Valid:.2f}" 62 | save_top_k: 3 63 | mode: 'min' 64 | 65 | lr: 66 | warm_up_steps: 1000 67 | reduce_lr_steps: 10000 68 | 69 | trainer: 70 | checkpoint_callback: True 71 | gpus: ${gpus} 72 | accelerator: 'ddp' 73 | sync_batchnorm: True 74 | max_epochs: ${epochs} 75 | replace_sampler_ddp: False 76 | profiler: 'simple' 77 | 78 | 79 | 80 | evaluate: 81 | max_evaluation_steps: 100 -------------------------------------------------------------------------------- /End2End/config/Instrument_Classification.yaml: -------------------------------------------------------------------------------- 1 | gpus: 1 2 | epochs: 500 3 | augmentation: False 4 | batch_size: 8 5 | segment_seconds: 10 6 | frames_per_second: 100 7 | 8 | 9 | datamodule: 10 | waveform_hdf5s_dir: 11 | notes_pkls_dir: 12 | dataset_cfg: 13 | train: 14 | segment_seconds: ${segment_seconds} 15 | frames_per_second: ${frames_per_second} 16 | pre_load_audio: False 17 | transcription: False 18 | val: 19 | segment_seconds: ${segment_seconds} 20 | frames_per_second: ${frames_per_second} 21 | pre_load_audio: False 22 | transcription: False 23 | test: 24 | segment_seconds: null 25 | frames_per_second: ${frames_per_second} 26 | pre_load_audio: False 27 | transcription: False 28 | dataloader_cfg: 29 | train: 30 | batch_size: ${batch_size} 31 | num_workers: 4 32 | shuffle: True 33 | pin_memory: 
True 34 | val: 35 | batch_size: ${batch_size} 36 | num_workers: 4 37 | shuffle: False 38 | pin_memory: True 39 | test: 40 | batch_size: 1 41 | num_workers: 4 42 | shuffle: False 43 | pin_memory: True 44 | 45 | MIDI_MAPPING: # This whole part will be overwritten in the main code 46 | type: 'plugin_names' 47 | plugin_labels_num: 0 48 | NAME_TO_IX: 0 49 | IX_TO_NAME: 0 50 | 51 | model: 52 | warm_up_epochs: 0 53 | alpha: 1 54 | type: 'Cnn14MeanMax' 55 | args: 56 | lr: 1e-3 57 | 58 | 59 | checkpoint: 60 | monitor: 'Total_Loss/Valid' 61 | filename: "e={epoch:02d}-acc={mAP/Valid:.2f}-loss={Loss/Valid:.2f}" 62 | save_top_k: 3 63 | mode: 'min' 64 | 65 | lr: 66 | warm_up_steps: 1000 67 | reduce_lr_steps: 10000 68 | 69 | trainer: 70 | checkpoint_callback: True 71 | gpus: ${gpus} 72 | accelerator: 'ddp' 73 | sync_batchnorm: True 74 | max_epochs: ${epochs} 75 | replace_sampler_ddp: False 76 | profiler: 'simple' 77 | 78 | 79 | 80 | evaluate: 81 | max_evaluation_steps: 100 -------------------------------------------------------------------------------- /End2End/config/Jointist.yaml: -------------------------------------------------------------------------------- 1 | gpus: 1 2 | epochs: 1000 3 | augmentation: False 4 | batch_size: 16 5 | num_workers: 4 6 | segment_seconds: 10 7 | frames_per_second: 100 8 | every_n_epochs: 20 9 | lr: 1e-3 10 | checkpoint_path: 'outputs/2021-11-03/15-32-36/Decoder_L2-empty_0.1-feature_weigh_0.1-Cnn14Transformer-hidden=256-Q=20-LearnPos=False-aux_loss-bsz=32-audio_len=10/version_0/checkpoints/last.ckpt' 11 | 12 | 13 | datamodule: 14 | waveform_hdf5s_dir: 15 | notes_pkls_dir: 16 | dataset_cfg: 17 | train: 18 | segment_seconds: ${segment_seconds} 19 | frames_per_second: ${frames_per_second} 20 | pre_load_audio: False 21 | transcription: True 22 | random_crop: True 23 | val: 24 | segment_seconds: ${segment_seconds} 25 | frames_per_second: ${frames_per_second} 26 | pre_load_audio: False 27 | transcription: True 28 | random_crop: False 29 | test: 30 | segment_seconds: ${segment_seconds} 31 | frames_per_second: ${frames_per_second} 32 | pre_load_audio: False 33 | transcription: True 34 | random_crop: False 35 | dataloader_cfg: 36 | train: 37 | batch_size: ${batch_size} 38 | num_workers: ${num_workers} 39 | shuffle: True 40 | pin_memory: True 41 | val: 42 | batch_size: ${batch_size} 43 | num_workers: ${num_workers} 44 | shuffle: False 45 | pin_memory: True 46 | test: 47 | batch_size: ${batch_size} 48 | num_workers: ${num_workers} 49 | shuffle: False 50 | pin_memory: True 51 | 52 | MIDI_MAPPING: # This whole part will be overwritten in the main code 53 | type: 'MIDI_class' 54 | plugin_labels_num: 0 55 | NAME_TO_IX: 0 56 | IX_TO_NAME: 0 57 | 58 | checkpoint: 59 | monitor: 'Total_Loss/Valid' 60 | filename: "e={epoch:02d}-trainloss={Total_Loss/Train:.3f}-validloss{Total_Loss/Valid:.3f}" 61 | save_top_k: 1 62 | mode: 'min' 63 | save_last: True 64 | every_n_epochs: ${every_n_epochs} 65 | 66 | trainer: 67 | checkpoint_callback: True 68 | gpus: ${gpus} 69 | accelerator: 'ddp' 70 | sync_batchnorm: True 71 | max_epochs: ${epochs} 72 | replace_sampler_ddp: False 73 | profiler: 'simple' 74 | check_val_every_n_epoch: ${every_n_epochs} 75 | num_sanity_val_steps: 2 76 | 77 | 78 | defaults: 79 | - feature: STFT 80 | - detection: config 81 | - transcription: config 82 | - scheduler: LambdaLR 83 | -------------------------------------------------------------------------------- /End2End/config/datamodule/h5.yaml: -------------------------------------------------------------------------------- 1 | 
type: 'H5Dataset' 2 | args: 3 |   h5_path: # will be overwritten in the main pred_transcription.py 4 |   # by default, the h5 dataset predicts the pianoroll for the full length of the audio 5 |   waveform_hdf5s_dir: # This parameter won't be used; it is only applicable when training and evaluating 6 |   notes_pkls_dir: # This parameter won't be used; it is only applicable when training and evaluating 7 | dataloader_cfg: 8 |   pred: 9 |     batch_size: 1 10 |     num_workers: 0 11 |     shuffle: False 12 |     pin_memory: True -------------------------------------------------------------------------------- /End2End/config/datamodule/msd.yaml: -------------------------------------------------------------------------------- 1 | type: 'MSD' 2 | waveform_hdf5s_dir: #dummy variable, won't be used when MSD is used 3 | notes_pkls_dir: #dummy variable, won't be used when MSD is used 4 | dataloader_cfg: 5 |   dataset_split: 'TRAIN' 6 |   sampler_type: 'normal' 7 |   batch_size: 1 8 |   num_workers: 8 9 |   num_readers: 32 10 |   num_chunks: 1 11 |   start: 0 12 |   end: 40000 -------------------------------------------------------------------------------- /End2End/config/datamodule/openmic.yaml: -------------------------------------------------------------------------------- 1 | waveform_hdf5s_dir: 2 | notes_pkls_dir: 3 | dataset_cfg: 4 |   train: 5 |     pre_load_audio: False 6 |     slack_mapping: False 7 |   val: 8 |     pre_load_audio: False 9 |     slack_mapping: False 10 |   test: 11 |     pre_load_audio: False 12 |     slack_mapping: False 13 | dataloader_cfg: 14 |   train: 15 |     batch_size: ${batch_size} 16 |     num_workers: ${num_workers} 17 |     shuffle: True 18 |     pin_memory: True 19 |   val: 20 |     batch_size: ${batch_size} 21 |     num_workers: ${num_workers} 22 |     shuffle: False 23 |     pin_memory: True 24 |   test: 25 |     batch_size: ${batch_size} 26 |     num_workers: ${num_workers} 27 |     shuffle: False 28 |     pin_memory: True -------------------------------------------------------------------------------- /End2End/config/datamodule/slakh.yaml: -------------------------------------------------------------------------------- 1 | slakhdata_root: '../../MusicDataset' 2 | waveform_dir: 'waveforms' 3 | pkl_dir: 'packed_pkl' 4 | type: 'slakh' 5 | dataset_cfg: 6 |   train: 7 |     segment_seconds: ${segment_seconds} 8 |     frames_per_second: ${frames_per_second} 9 |     transcription: True 10 |     random_crop: True 11 |     source: ${source} 12 |     download: ${download} 13 |     name_to_ix: ${MIDI_MAPPING.NAME_TO_IX} 14 |     ix_to_name: ${MIDI_MAPPING.IX_TO_NAME} 15 |     plugin_labels_num: ${MIDI_MAPPING.plugin_labels_num} 16 |   val: 17 |     segment_seconds: ${segment_seconds} 18 |     frames_per_second: ${frames_per_second} 19 |     transcription: True 20 |     random_crop: False 21 |     source: ${source} 22 |     download: False 23 |     name_to_ix: ${MIDI_MAPPING.NAME_TO_IX} 24 |     ix_to_name: ${MIDI_MAPPING.IX_TO_NAME} 25 |     plugin_labels_num: ${MIDI_MAPPING.plugin_labels_num} 26 |   test: 27 |     segment_seconds: null 28 |     frames_per_second: ${frames_per_second} 29 |     transcription: True 30 |     random_crop: False 31 |     source: ${source} 32 |     download: ${download} 33 |     name_to_ix: ${MIDI_MAPPING.NAME_TO_IX} 34 |     ix_to_name: ${MIDI_MAPPING.IX_TO_NAME} 35 |     plugin_labels_num: ${MIDI_MAPPING.plugin_labels_num} 36 | dataloader_cfg: 37 |   train: 38 |     batch_size: ${batch_size} 39 |     num_workers: ${num_workers} 40 |     shuffle: True 41 |     pin_memory: True 42 |   val: 43 |     batch_size: ${batch_size} 44 |     num_workers: ${num_workers} 45 |     shuffle: False 46 |     pin_memory: True 47 |   test: 48 |     batch_size: 1 49 |     num_workers: ${num_workers} 50 |     shuffle: False 51 | 
pin_memory: True -------------------------------------------------------------------------------- /End2End/config/datamodule/slakh_ir.yaml: -------------------------------------------------------------------------------- 1 | slakhdata_root: '../../MusicDataset' 2 | waveform_dir: 'waveforms' 3 | pkl_dir: 'packed_pkl' 4 | dataset_cfg: 5 |   train: 6 |     segment_seconds: ${segment_seconds} 7 |     frames_per_second: ${frames_per_second} 8 |     transcription: False 9 |     random_crop: True 10 |     source: ${source} 11 |     download: ${download} 12 |     name_to_ix: ${MIDI_MAPPING.NAME_TO_IX} 13 |     ix_to_name: ${MIDI_MAPPING.IX_TO_NAME} 14 |     plugin_labels_num: ${MIDI_MAPPING.plugin_labels_num} 15 |   val: 16 |     segment_seconds: ${segment_seconds} 17 |     frames_per_second: ${frames_per_second} 18 |     transcription: False 19 |     random_crop: False 20 |     source: ${source} 21 |     download: False 22 |     name_to_ix: ${MIDI_MAPPING.NAME_TO_IX} 23 |     ix_to_name: ${MIDI_MAPPING.IX_TO_NAME} 24 |     plugin_labels_num: ${MIDI_MAPPING.plugin_labels_num} 25 |   test: 26 |     segment_seconds: ${segment_seconds} 27 |     frames_per_second: ${frames_per_second} 28 |     transcription: False 29 |     random_crop: False 30 |     source: ${source} 31 |     download: ${download} 32 |     name_to_ix: ${MIDI_MAPPING.NAME_TO_IX} 33 |     ix_to_name: ${MIDI_MAPPING.IX_TO_NAME} 34 |     plugin_labels_num: ${MIDI_MAPPING.plugin_labels_num} 35 | dataloader_cfg: 36 |   train: 37 |     batch_size: ${batch_size} 38 |     num_workers: ${num_workers} 39 |     shuffle: True 40 |     pin_memory: True 41 |   val: 42 |     batch_size: ${batch_size} 43 |     num_workers: ${num_workers} 44 |     shuffle: False 45 |     pin_memory: True 46 |   test: 47 |     batch_size: ${batch_size} 48 |     num_workers: ${num_workers} 49 |     shuffle: False 50 |     pin_memory: True -------------------------------------------------------------------------------- /End2End/config/datamodule/wild.yaml: -------------------------------------------------------------------------------- 1 | type: 'WildDataset' 2 | args: 3 |   audio_path: ${audio_path} 4 |   audio_ext: ${audio_ext} 5 |   waveform_hdf5s_dir: # This parameter won't be used; it is only applicable when training and evaluating 6 |   notes_pkls_dir: # This parameter won't be used; it is only applicable when training and evaluating 7 | dataloader_cfg: 8 |   pred: 9 |     batch_size: 1 10 |     num_workers: 1 11 |     shuffle: False 12 |     pin_memory: True -------------------------------------------------------------------------------- /End2End/config/detection/CombinedModel_Av2.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 |   - backbone: CNN14_less_pooling 3 |   - transformer: torch_Transformer_API 4 | 5 | type: 'CombinedModel_Av2' #All options are available in models/instrument_detection/combined.py 6 | task: 'Hungarian_Autoregressive' 7 | 8 | model: 9 |   args: 10 |     num_classes: ${MIDI_MAPPING.plugin_labels_num} 11 |     num_Q: 39 12 |     feature_weight: 0.1 13 |     d_model: 256 14 |     dropout: 0.2 15 |     positional: 16 |       temperature: 10000 17 |       normalize: True 18 |       scale: null 19 |   optimizer: 20 |     lr: ${lr} 21 |     eps: 1e-08 22 |     weight_decay: 0.0 23 |     amsgrad: True 24 |     eos_coef: 0.1 25 | 26 | evaluate: 27 |   max_evaluation_steps: 100 28 | -------------------------------------------------------------------------------- /End2End/config/detection/CombinedModel_Av2_Teacher.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 |   - backbone: CNN14 3 |   - transformer: torch_Transformer_API 4 | 5 | type: 'CombinedModel_Av2_Teacher' #All options are available in
models/instrument_detection/combined.py 6 | task: 'Softmax' 7 | 8 | model: 9 |   args: 10 |     num_classes: ${MIDI_MAPPING.plugin_labels_num} 11 |     num_Q: 39 12 |     feature_weight: 0.1 13 |     d_model: 256 14 |     dropout: 0.2 15 |     positional: 16 |       temperature: 10000 17 |       normalize: True 18 |       scale: null 19 |   optimizer: 20 |     lr: ${lr} 21 |     eps: 1e-08 22 |     weight_decay: 0.0 23 |     amsgrad: True 24 |     eos_coef: 0.1 25 |   shuffle_target: False 26 |   target_dropout: 0 27 |   scale_logits: False 28 | 29 | evaluate: 30 |   max_evaluation_steps: 100 31 | -------------------------------------------------------------------------------- /End2End/config/detection/CombinedModel_CLS.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 |   - backbone: CNN8 3 |   - transformer: BERT 4 | 5 | type: 'CombinedModel_CLS' #All options are available in models/instrument_detection/combined.py 6 | task: 'Binary' 7 | 8 | model: 9 |   args: 10 |     num_classes: ${MIDI_MAPPING.plugin_labels_num} 11 |     num_Q: 39 12 |     feature_weight: 0.1 13 |     positional: 14 |       temperature: 10000 15 |       normalize: False 16 |       scale: null 17 |   optimizer: 18 |     lr: ${lr} 19 |     eps: 1e-08 20 |     weight_decay: 0.0 21 |     amsgrad: True 22 |     eos_coef: 0.1 23 | 24 | evaluate: 25 |   max_evaluation_steps: 100 26 | -------------------------------------------------------------------------------- /End2End/config/detection/CombinedModel_CLSv2.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 |   - backbone: CNN8 3 |   - transformer: BERTv2 4 |   - feature: mel 5 | 6 | type: 'CombinedModel_CLSv2' #All options are available in models/instrument_detection/combined.py 7 | task: 'Binary' 8 | 9 | model: 10 |   args: 11 |     num_classes: ${MIDI_MAPPING.plugin_labels_num} 12 |     num_Q: 39 13 |     feature_weight: 0.1 14 |     positional: 15 |       temperature: 10000 16 |       normalize: False 17 |       scale: null 18 |   optimizer: 19 |     lr: ${lr} 20 |     eps: 1e-08 21 |     weight_decay: 0.0 22 |     amsgrad: True 23 |     eos_coef: 0.1 24 | 25 | evaluate: 26 |   max_evaluation_steps: 100 27 | -------------------------------------------------------------------------------- /End2End/config/detection/CombinedModel_H.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 |   - backbone: CNN8 3 |   - transformer: DETR_Transformer 4 | 5 | type: 'CombinedModel_H' #All options are available in models/instrument_detection/combined.py 6 | task: 'Hungarian' 7 | 8 | model: 9 |   args: 10 |     num_classes: ${MIDI_MAPPING.plugin_labels_num} 11 |     num_Q: 39 12 |     feature_weight: 0.1 13 |     positional: 14 |       temperature: 10000 15 |       normalize: False 16 |       scale: null 17 |   optimizer: 18 |     lr: ${lr} 19 |     eps: 1e-08 20 |     weight_decay: 0.0 21 |     amsgrad: True 22 |     eos_coef: 0.1 23 | 24 | evaluate: 25 |   max_evaluation_steps: 100 26 | -------------------------------------------------------------------------------- /End2End/config/detection/CombinedModel_Linear.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 |   - backbone: CNN14 3 |   - transformer: Linear 4 | 5 | type: 'CombinedModel_Linear' #All options are available in models/instrument_detection/combined.py 6 | task: 'Binary' 7 | 8 | model: 9 |   args: 10 |     num_classes: ${MIDI_MAPPING.plugin_labels_num} 11 |     num_Q: 39 12 |     feature_weight: 0.1 13 |     positional: 14 |       temperature: 10000 15 |       normalize: False 16 |       scale: null 17 |   optimizer: 18 |     lr: ${lr} 19 |     eps: 1e-08 20 |     weight_decay: 0.0 21 |     amsgrad: True 22 |     eos_coef: 0.1 23 | 24 | 
evaluate: 25 |   max_evaluation_steps: 100 26 | -------------------------------------------------------------------------------- /End2End/config/detection/CombinedModel_NewCLSv2.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 |   - feature: mel 3 |   - backbone: CNN14 4 |   - transformer: MusicTaggingTransformer 5 | 6 | type: 'CombinedModel_NewCLSv2' #All options are available in models/instrument_detection/combined.py 7 | task: 'Binary' 8 | 9 | model: 10 |   args: 11 |     num_classes: ${MIDI_MAPPING.plugin_labels_num} 12 |     num_Q: 39 13 |     feature_weight: 0.1 14 |     positional: 15 |       temperature: 10000 16 |       normalize: False 17 |       scale: null 18 |   optimizer: 19 |     lr: ${lr} 20 |     eps: 1e-08 21 |     weight_decay: 0.0 22 |     amsgrad: True 23 |     eos_coef: 0.1 24 | 25 | evaluate: 26 |   max_evaluation_steps: 100 27 | -------------------------------------------------------------------------------- /End2End/config/detection/CombinedModel_S.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 |   - backbone: CNN8 3 |   - transformer: DETR_Transformer 4 | 5 | type: 'CombinedModel_S' #All options are available in models/instrument_detection/combined.py 6 | task: 'Binary' 7 | 8 | model: 9 |   args: 10 |     num_classes: ${MIDI_MAPPING.plugin_labels_num} 11 |     num_Q: 39 12 |     feature_weight: 0.1 13 |     positional: 14 |       temperature: 10000 15 |       normalize: True 16 |       scale: null 17 |   optimizer: 18 |     lr: ${lr} 19 |     eps: 1e-08 20 |     weight_decay: 0.0 21 |     amsgrad: True 22 |     eos_coef: 0.1 23 | 24 | evaluate: 25 |   max_evaluation_steps: 100 26 | -------------------------------------------------------------------------------- /End2End/config/detection/CombinedModel_Sv2.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 |   - backbone: CNN14 3 |   - transformer: DETR_Transformerv2 4 | 5 | type: 'CombinedModel_Sv2' #All options are available in models/instrument_detection/combined.py 6 | task: 'Binary' 7 | 8 | model: 9 |   args: 10 |     num_classes: ${MIDI_MAPPING.plugin_labels_num} 11 |     num_Q: 39 12 |     feature_weight: 0.1 13 |     positional: 14 |       temperature: 10000 15 |       normalize: True 16 |       scale: null 17 |   optimizer: 18 |     lr: ${lr} 19 |     eps: 1e-08 20 |     weight_decay: 0.0 21 |     amsgrad: True 22 |     eos_coef: 0.1 23 | 24 | evaluate: 25 |   max_evaluation_steps: 100 26 | -------------------------------------------------------------------------------- /End2End/config/detection/CombinedModel_Sv2_torch.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 |   - backbone: CNN14 3 |   - transformer: torch_Transformer_API 4 | 5 | type: 'CombinedModel_Sv2_torch' #All options are available in models/instrument_detection/combined.py 6 | task: 'Binary' 7 | 8 | model: 9 |   args: 10 |     num_classes: ${MIDI_MAPPING.plugin_labels_num} 11 |     num_Q: 38 12 |     feature_weight: 0.1 13 |     positional: 14 |       temperature: 10000 15 |       normalize: True 16 |       scale: null 17 |   optimizer: 18 |     lr: ${lr} 19 |     eps: 1e-08 20 |     weight_decay: 0.0 21 |     amsgrad: True 22 |     eos_coef: 0.1 23 | 24 | evaluate: 25 |   max_evaluation_steps: 100 26 | -------------------------------------------------------------------------------- /End2End/config/detection/OpenMicBaseline.yaml: -------------------------------------------------------------------------------- 1 | type: 'OpenMicBaseline' #All options are available in models/instrument_detection/combined.py 2 | task: 'BinaryOpenMic' 3 | 4 | model: 5 |   args: 6 |     freq_bins: 
${feature.STFT.n_mels} 7 |     classes_num: ${MIDI_MAPPING.plugin_labels_num} 8 |     emb_layers: 3 9 |     hidden_units: ${feature.STFT.n_mels} 10 |     drop_rate: 0.6 11 |   optimizer: 12 |     lr: ${lr} 13 |     eps: 1e-08 14 |     weight_decay: 0.0 15 |     amsgrad: True 16 |     eos_coef: 0.1 17 | 18 | evaluate: 19 |   max_evaluation_steps: 100 20 | -------------------------------------------------------------------------------- /End2End/config/detection/Original.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 |   - backbone: CNN8 3 |   - feature: mel 4 | 5 | type: 'Original' #All options are available in models/instrument_detection/combined.py 6 | task: 'BinaryOpenMic' 7 | 8 | model: 9 |   args: 10 |     num_classes: ${MIDI_MAPPING.plugin_labels_num} 11 |     num_Q: 20 12 |     feature_weight: 0.1 13 |     positional: 14 |       temperature: 10000 15 |       normalize: False 16 |       scale: null 17 |   optimizer: 18 |     lr: ${lr} 19 |     eps: 1e-08 20 |     weight_decay: 0.0 21 |     amsgrad: True 22 |     eos_coef: 0.1 23 | 24 | evaluate: 25 |   max_evaluation_steps: 100 26 | -------------------------------------------------------------------------------- /End2End/config/detection/backbone/AcousticModelCnn8Dropout.yaml: -------------------------------------------------------------------------------- 1 | type: 'AcousticModelCnn8Dropout' 2 | args: 3 |   dropout: 0.2 -------------------------------------------------------------------------------- /End2End/config/detection/backbone/CNN14.yaml: -------------------------------------------------------------------------------- 1 | type: 'CNN14' 2 | args: 3 |   n_mels: ${detection.feature.STFT.n_mels} 4 |   channel_list: 5 |     - 64 6 |     - 128 7 |     - 256 8 |     - 512 9 |     - 1024 10 |     - 2048 11 |   dropout: 0.2 -------------------------------------------------------------------------------- /End2End/config/detection/backbone/CNN14_less_pooling.yaml: -------------------------------------------------------------------------------- 1 | type: 'CNN14_less_pooling' 2 | args: 3 |   n_mels: ${feature.STFT.n_mels} 4 |   num_pooling: 4 5 |   pool_first: True 6 |   channel_list: 7 |     - 64 8 |     - 128 9 |     - 256 10 |     - 512 11 |     - 1024 12 |     - 2048 13 |   dropout: 0.2 -------------------------------------------------------------------------------- /End2End/config/detection/backbone/CNN8.yaml: -------------------------------------------------------------------------------- 1 | type: 'CNN8' 2 | args: 3 |   dropout: 0.2 -------------------------------------------------------------------------------- /End2End/config/detection/backbone/ResNet101.yaml: -------------------------------------------------------------------------------- 1 | type: 'Resnet101' 2 | args: 3 |   train_backbone: True 4 |   return_interm_layers: False 5 |   dilation: True -------------------------------------------------------------------------------- /End2End/config/detection/backbone/ResNet50.yaml: -------------------------------------------------------------------------------- 1 | type: 'Resnet50' 2 | args: 3 |   train_backbone: True 4 |   return_interm_layers: False 5 |   dilation: True -------------------------------------------------------------------------------- /End2End/config/detection/feature/mel.yaml: -------------------------------------------------------------------------------- 1 | STFT: 2 |   sample_rate: 16000 3 |   n_fft: 2048 4 |   hop_length: 160 5 |   n_mels: 229 6 |   f_min: 0 7 |   f_max: 8000 8 |   center: True 9 |   normalized: True 10 |   pad_mode: 'reflect' 11 | dB_args: 12 |   multiplier: 20 13 |   amin: 1e-10 14 |   db_multiplier: 1.0 15 |   top_db: null
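A minimal sketch (not the repository's actual feature extractor; torchaudio is assumed here) of how the STFT and dB_args fields in detection/feature/mel.yaml map onto a log-mel front end. Note that hop_length=160 at a 16 kHz sample rate corresponds to the frames_per_second: 100 used throughout the configs.

import torch
import torchaudio

# MelSpectrogram arguments mirror the STFT block of mel.yaml
mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_fft=2048, hop_length=160, n_mels=229,
    f_min=0.0, f_max=8000.0, center=True, normalized=True, pad_mode='reflect')

waveform = torch.randn(1, 16000 * 10)  # 10 seconds of dummy audio, matching segment_seconds: 10
# amplitude_to_DB arguments mirror the dB_args block of mel.yaml
log_mel = torchaudio.functional.amplitude_to_DB(
    mel(waveform), multiplier=20.0, amin=1e-10, db_multiplier=1.0, top_db=None)
print(log_mel.shape)  # torch.Size([1, 229, 1001]) -> (channel, n_mels, frames at 100 fps)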
-------------------------------------------------------------------------------- /End2End/config/detection/transformer/BERT.yaml: -------------------------------------------------------------------------------- 1 | type: 'BertEncoder' 2 | args: 3 | vocab_size: 256 4 | hidden_size: 256 5 | num_hidden_layers: 3 6 | num_attention_heads: 8 7 | intermediate_size: 2048 8 | hidden_act: "gelu" 9 | hidden_dropout_prob: 0.4 10 | max_position_embeddings: 700 11 | attention_probs_dropout_prob: 0.5 -------------------------------------------------------------------------------- /End2End/config/detection/transformer/BERTv2.yaml: -------------------------------------------------------------------------------- 1 | type: 'BertEncoder' 2 | args: 3 | vocab_size: 256 4 | hidden_size: 256 5 | num_hidden_layers: 3 6 | num_attention_heads: 8 7 | intermediate_size: 2048 8 | hidden_act: "gelu" 9 | hidden_dropout_prob: 0.4 10 | max_position_embeddings: 700 11 | attention_probs_dropout_prob: 0.5 -------------------------------------------------------------------------------- /End2End/config/detection/transformer/DETR_Transformer.yaml: -------------------------------------------------------------------------------- 1 | type: 'DETR_Transformer' 2 | args: 3 | d_model: 256 4 | dropout: 0.2 5 | nhead: 8 6 | dim_feedforward: 2048 7 | num_encoder_layers: 0 8 | num_decoder_layers: 1 9 | normalize_before: False 10 | return_intermediate_dec: True -------------------------------------------------------------------------------- /End2End/config/detection/transformer/DETR_Transformerv2.yaml: -------------------------------------------------------------------------------- 1 | type: 'DETR_Transformerv2' 2 | args: 3 | d_model: 256 4 | dropout: 0.2 5 | nhead: 8 6 | dim_feedforward: 2048 7 | num_encoder_layers: 0 8 | num_decoder_layers: 1 9 | normalize_before: False 10 | return_intermediate_dec: True -------------------------------------------------------------------------------- /End2End/config/detection/transformer/Linear.yaml: -------------------------------------------------------------------------------- 1 | hidden_dim: 256 -------------------------------------------------------------------------------- /End2End/config/detection/transformer/MusicTaggingTransformer.yaml: -------------------------------------------------------------------------------- 1 | type: 'MusicTaggingTransformer' 2 | args: 3 | d_model: 256 4 | dropout: 0.1 5 | nhead: 8 6 | num_encoder_layers: 4 7 | attention_max_len: 512 8 | n_seq_cls: 1 9 | n_token_cls: 1 -------------------------------------------------------------------------------- /End2End/config/detection/transformer/torch_Transformer.yaml: -------------------------------------------------------------------------------- 1 | type: 'torch_Transformer' 2 | encoderlayer: 3 | d_model: ${detection.model.args.d_model} 4 | dropout: ${detection.model.args.dropout} 5 | nhead: 8 6 | dim_feedforward: 2048 7 | batch_first: True 8 | encoder: 9 | num_layers: 1 10 | 11 | decoderlayer: 12 | d_model: ${detection.model.args.d_model} 13 | dropout: ${detection.model.args.dropout} 14 | nhead: 8 15 | dim_feedforward: 2048 16 | batch_first: True 17 | decoder: 18 | num_layers: 1 -------------------------------------------------------------------------------- /End2End/config/detection/transformer/torch_Transformer_API.yaml: -------------------------------------------------------------------------------- 1 | type: 'torch_Transformer_API' 2 | args: 3 | d_model: 256 4 | dropout: 0.1 5 | nhead: 8 6 | dim_feedforward: 2048 7 | 
num_encoder_layers: 0 8 | num_decoder_layers: 1 9 | norm_first: False 10 | batch_first: True -------------------------------------------------------------------------------- /End2End/config/detection_config.yaml: -------------------------------------------------------------------------------- 1 | gpus: 1 2 | epochs: 300 3 | augmentation: False 4 | batch_size: 32 5 | num_workers: 4 6 | segment_seconds: 10 7 | frames_per_second: 100 8 | every_n_epochs: 10 9 | source: False 10 | download: False 11 | lr: 1e-3 12 | checkpoint_path: 'weights/MTT.ckpt' 13 | seg_batch_size: 8 # only use during inference 14 | 15 | MIDI_MAPPING: # This whole part will be overwritten in the main code 16 | type: 'slakh' 17 | plugin_labels_num: 0 18 | NAME_TO_IX: 0 19 | IX_TO_NAME: 0 20 | 21 | 22 | # This checkpoint will only be used when training this model standalone 23 | checkpoint: 24 | monitor: 'Detection_Loss/Valid' 25 | filename: "e={epoch:02d}-trainloss={Detection_Loss/Train:.3f}-validloss{Detection_Loss/Valid:.3f}" 26 | save_top_k: 1 27 | mode: 'min' 28 | save_last: True 29 | every_n_epochs: ${every_n_epochs} 30 | 31 | trainer: 32 | checkpoint_callback: True 33 | gpus: ${gpus} 34 | accelerator: 'ddp' 35 | sync_batchnorm: True 36 | max_epochs: ${epochs} 37 | replace_sampler_ddp: False 38 | profiler: 'simple' 39 | check_val_every_n_epoch: ${every_n_epochs} 40 | 41 | defaults: 42 | - detection: CombinedModel_NewCLSv2 43 | - scheduler: LambdaLR 44 | - datamodule: slakh_ir -------------------------------------------------------------------------------- /End2End/config/jointist_inference.yaml: -------------------------------------------------------------------------------- 1 | gpus: 1 # choose your GPU 2 | audio_path: # need to use absolute path 3 | audio_ext: 'mp3' # mp3, wav, flac, anything that torchaudio.load supports 4 | seg_batch_size: 8 # only use during inference 5 | frames_per_second: 100 6 | segment_seconds: 10 7 | lr: null 8 | h5_name: 'ballroom_audio.h5' # valid only when datamodule=h5 is selected 9 | h5_root: '/opt/tiger/kinwai/jointist/sheetdoctor' # valid only when datamodule=h5 is selected 10 | 11 | 12 | MIDI_MAPPING: # This whole part will be overwritten in the main code 13 | type: 'MIDI_class' 14 | plugin_labels_num: 0 15 | NAME_TO_IX: 0 16 | IX_TO_NAME: 0 17 | 18 | checkpoint: 19 | transcription: 'weights/transcription1000.ckpt' 20 | detection: "weights/MTT.ckpt" 21 | 22 | trainer: 23 | gpus: ${gpus} 24 | accelerator: 'ddp' 25 | sync_batchnorm: True 26 | replace_sampler_ddp: False 27 | profiler: 'simple' 28 | 29 | 30 | defaults: 31 | - datamodule: wild 32 | - detection: CombinedModel_NewCLSv2 33 | - transcription: Original 34 | 35 | -------------------------------------------------------------------------------- /End2End/config/jointist_ss_inference.yaml: -------------------------------------------------------------------------------- 1 | gpus: 1 # choose your GPU 2 | audio_path: 'audio_path=/opt/tiger/kinwai/jointist/songs' # need to use absolute path 3 | audio_ext: 'mp3' # mp3, wav, flac, anything that torchaudio.load supports 4 | seg_batch_size: 8 # only use during inference 5 | frames_per_second: 100 6 | segment_seconds: 10 7 | lr: null 8 | h5_name: 'ballroom_audio.h5' # valid only when datamodule=h5 is selected 9 | h5_root: '/opt/tiger/kinwai/jointist/sheetdoctor' # valid only when datamodule=h5 is selected 10 | 11 | 12 | MIDI_MAPPING: # This whole part will be overwritten in the main code 13 | type: 'MIDI_class' 14 | plugin_labels_num: 0 15 | NAME_TO_IX: 0 16 | IX_TO_NAME: 0 17 | 18 | 
checkpoint: 19 | tseparation: 'weights/tseparation.ckpt' 20 | detection: "weights/MTT.ckpt" 21 | 22 | trainer: 23 | gpus: ${gpus} 24 | accelerator: 'ddp' 25 | sync_batchnorm: True 26 | replace_sampler_ddp: False 27 | profiler: 'simple' 28 | 29 | 30 | defaults: 31 | - datamodule: wild 32 | - detection: CombinedModel_NewCLSv2 33 | - separation: TCUNet 34 | - transcription: Original 35 | -------------------------------------------------------------------------------- /End2End/config/jointist_testing.yaml: -------------------------------------------------------------------------------- 1 | gpus: 1 # choose your GPU 2 | audio_path: # need to use absolute path 3 | audio_ext: 'mp3' # mp3, wav, flac, anything that torchaudio.load supports 4 | seg_batch_size: 8 # only use during inference 5 | frames_per_second: 100 6 | segment_seconds: 10 7 | lr: null 8 | batch_size: 1 9 | num_workers: 2 10 | 11 | MIDI_MAPPING: # This whole part will be overwritten in the main code 12 | type: 'MIDI_class' 13 | plugin_labels_num: 0 14 | NAME_TO_IX: 0 15 | IX_TO_NAME: 0 16 | 17 | checkpoint: 18 | transcription: 'weights/transcription1000.ckpt' 19 | detection: "weights/MTT.ckpt" 20 | 21 | trainer: 22 | gpus: ${gpus} 23 | accelerator: 'ddp' 24 | sync_batchnorm: True 25 | replace_sampler_ddp: False 26 | profiler: 'simple' 27 | 28 | 29 | defaults: 30 | - datamodule: slakh 31 | - detection: CombinedModel_NewCLSv2 32 | - transcription: Original 33 | 34 | -------------------------------------------------------------------------------- /End2End/config/openmic-DETR_Hungarian_IR.yaml: -------------------------------------------------------------------------------- 1 | gpus: 1 2 | epochs: 1000 3 | augmentation: False 4 | batch_size: 32 5 | num_workers: 4 6 | segment_seconds: 10 7 | frames_per_second: 100 8 | every_n_epochs: 10 9 | lr: 1e-3 10 | checkpoint_path: 'outputs/2021-11-03/15-32-36/Decoder_L2-empty_0.1-feature_weigh_0.1-Cnn14Transformer-hidden=256-Q=20-LearnPos=False-aux_loss-bsz=32-audio_len=10/version_0/checkpoints/last.ckpt' 11 | 12 | datamodule: 13 | waveform_hdf5s_dir: 14 | notes_pkls_dir: 15 | dataset_cfg: 16 | train: 17 | pre_load_audio: False 18 | val: 19 | pre_load_audio: False 20 | test: 21 | pre_load_audio: False 22 | dataloader_cfg: 23 | train: 24 | batch_size: ${batch_size} 25 | num_workers: ${num_workers} 26 | shuffle: True 27 | pin_memory: True 28 | val: 29 | batch_size: ${batch_size} 30 | num_workers: ${num_workers} 31 | shuffle: False 32 | pin_memory: True 33 | test: 34 | batch_size: ${batch_size} 35 | num_workers: ${num_workers} 36 | shuffle: False 37 | pin_memory: True 38 | 39 | MIDI_MAPPING: # This whole part will be overwritten in the main code 40 | type: 'MIDI_class' 41 | plugin_labels_num: 0 42 | NAME_TO_IX: 0 43 | IX_TO_NAME: 0 44 | 45 | model: 46 | type: 'Cnn14Transformer' 47 | optimizer: 48 | lr: ${lr} 49 | eps: 1e-08 50 | weight_decay: 0.0 51 | amsgrad: True 52 | eos_coef: 0.1 53 | args: 54 | hidden_dim: 256 55 | num_Q: 20 56 | max_pos: 50 57 | learnable_pos: False 58 | nheads: 8 59 | feature_weight: 0.1 60 | num_encoder_layers: 0 61 | num_decoder_layers: 2 62 | spec_args: 63 | sample_rate: 16000 64 | n_fft: 1024 65 | hop_length: 160 66 | n_mels: 229 67 | f_min: 0 68 | f_max: 8000 69 | center: True 70 | pad_mode: 'reflect' 71 | 72 | 73 | checkpoint: 74 | monitor: 'Total_Loss/Valid' 75 | filename: "e={epoch:02d}-trainloss={Total_Loss/Train:.3f}-validloss{Total_Loss/Valid:.3f}" 76 | save_top_k: 1 77 | mode: 'min' 78 | save_last: True 79 | every_n_epochs: ${every_n_epochs} 80 | 81 | trainer: 82 | 
checkpoint_callback: True 83 | gpus: ${gpus} 84 | accelerator: 'ddp' 85 | sync_batchnorm: True 86 | max_epochs: ${epochs} 87 | replace_sampler_ddp: False 88 | profiler: 'simple' 89 | check_val_every_n_epoch: ${every_n_epochs} 90 | num_sanity_val_steps: 2 91 | 92 | 93 | 94 | evaluate: 95 | max_evaluation_steps: 100 96 | 97 | LambdaLR_args: # only useful when using LambdaLR 98 | warm_up_steps: 1000 99 | reduce_lr_steps: 10000 100 | 101 | defaults: 102 | - scheduler: LambdaLR -------------------------------------------------------------------------------- /End2End/config/pkl2pianoroll.yaml: -------------------------------------------------------------------------------- 1 | audio_h5_path: './sheetdoctor/template_audio.h5' 2 | pkl_path: './multirun/2022-02-04/23-14-29/0/MIDI_output/' 3 | roll_output_path: './sheetdoctor/roll' -------------------------------------------------------------------------------- /End2End/config/pkl2pianoroll_MSD.yaml: -------------------------------------------------------------------------------- 1 | pkl_path: './outputs/2022-03-08/21-14-00/MIDI_output' 2 | roll_name: 'MSD_train_part01' 3 | roll_output_path: './MSD/roll' 4 | -------------------------------------------------------------------------------- /End2End/config/pkl2sparsepianoroll_MSD.yaml: -------------------------------------------------------------------------------- 1 | pkl_path: './MSD_Test/21-57-27/MIDI_output' 2 | roll_output_path: './MSD_Train/sparse_roll' 3 | -------------------------------------------------------------------------------- /End2End/config/pred_transcription_config.yaml: -------------------------------------------------------------------------------- 1 | gpus: 1 2 | epochs: 500 3 | augmentation: False 4 | batch_size: 4 5 | segment_seconds: 10 6 | frames_per_second: 100 7 | val_frequency: 50 8 | lr: 1e-3 9 | h5_name: 'ballroom_audio.h5' 10 | h5_root: '/opt/tiger/kinwai/jointist/h5dataset' 11 | inst_sampler: 12 | mode: 'imbalance' 13 | temp: 0.9 14 | samples: 3 15 | audio_noise: 0.1 16 | 17 | MIDI_MAPPING: # This whole part will be overwritten in the main code 18 | type: 'MIDI_class' 19 | plugin_labels_num: 0 20 | NAME_TO_IX: 0 21 | IX_TO_NAME: 0 22 | 23 | trainer: 24 | checkpoint_callback: True 25 | gpus: ${gpus} 26 | accelerator: 'ddp' 27 | sync_batchnorm: True 28 | max_epochs: ${epochs} 29 | replace_sampler_ddp: False 30 | profiler: 'simple' 31 | check_val_every_n_epoch: ${val_frequency} 32 | log_every_n_steps: 100 33 | 34 | checkpoint: 35 | monitor: 'Transcription_Loss/Valid' 36 | filename: "e={epoch:02d}-train_loss={Transcription_Loss/Train:.2f}-valid_loss={Transcription_Loss/Valid:.2f}" 37 | save_top_k: 1 38 | mode: 'min' 39 | save_last: True 40 | every_n_epochs: ${trainer.check_val_every_n_epoch} 41 | 42 | defaults: 43 | - transcription: Original 44 | - scheduler: LambdaLR 45 | - datamodule: h5 46 | -------------------------------------------------------------------------------- /End2End/config/scheduler/LambdaLR.yaml: -------------------------------------------------------------------------------- 1 | type: LambdaLR 2 | args: 3 | warm_up_steps: 1000 4 | reduce_lr_steps: 10000 -------------------------------------------------------------------------------- /End2End/config/scheduler/Lambda_ss.yaml: -------------------------------------------------------------------------------- 1 | type: LambdaLR 2 | args: 3 | warm_up_steps: 10000 4 | reduce_lr_steps: 50000 -------------------------------------------------------------------------------- /End2End/config/scheduler/MultiStepLR.yaml: 
-------------------------------------------------------------------------------- 1 | type: MultiStepLR 2 | milestones: 3 | - 50 4 | - 100 5 | - 150 6 | - 300 7 | - 450 8 | - 600 9 | - 750 10 | - 900 11 | gamma: 0.5 -------------------------------------------------------------------------------- /End2End/config/separation/CUNet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - feature: SS_STFT 3 | model: 4 | type: "CondUNet" 5 | args: 6 | channels_num: 1 7 | condition_size: ${MIDI_MAPPING.plugin_labels_num} 8 | is_gamma: False 9 | is_beta: True 10 | loss_types: 'l1_wav' 11 | 12 | evaluation: 13 | max_evaluation_steps: 100 14 | onset_threshold: 0.1 15 | offset_threshod: 0.1 16 | frame_threshold: 0.1 17 | pedal_offset_threshold: 0.2 18 | modeling_offset: ${transcription.model.args.modeling_offset} 19 | seg_batch_size: 8 20 | checkpoint_path: '/opt/tiger/kinwai/jointist/outputs/SS_Weights/2022-01-20_18-46-27_CondUNet-3p1n-csize=40/CondUNet-3p1n-csize=40/version_0/checkpoints/last.ckpt' 21 | output_path: null 22 | 23 | batchprocess: 24 | MIDI_MAPPING: ${MIDI_MAPPING} 25 | mode: ${inst_sampler.mode} 26 | temp: ${inst_sampler.temp} 27 | samples: ${inst_sampler.samples} 28 | neg_samples: ${inst_sampler.neg_samples} 29 | audio_noise: ${inst_sampler.audio_noise} 30 | transcription: ${transcription} 31 | source_separation: ${source} -------------------------------------------------------------------------------- /End2End/config/separation/TCUNet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - feature: SS_STFT 3 | 4 | model: 5 | type: "TCondUNet" 6 | args: 7 | mode: 'sum' 8 | condition_size: ${MIDI_MAPPING.plugin_labels_num} 9 | is_gamma: False 10 | is_beta: True 11 | loss_types: 'l1_wav' 12 | 13 | 14 | evaluation: 15 | max_evaluation_steps: 100 16 | onset_threshold: 0.1 17 | offset_threshod: 0.1 18 | frame_threshold: 0.1 19 | pedal_offset_threshold: 0.2 20 | modeling_offset: ${transcription.model.args.modeling_offset} 21 | seg_batch_size: 8 22 | checkpoint_path: '/opt/tiger/kinwai/jointist/weights/2022-02-03/11-53-36/TSeparation-3p0n-ste_roll-pretrainedT/version_0/checkpoints/last.ckpt' 23 | output_path: null 24 | 25 | batchprocess: 26 | MIDI_MAPPING: ${MIDI_MAPPING} 27 | mode: ${inst_sampler.mode} 28 | temp: ${inst_sampler.temp} 29 | samples: ${inst_sampler.samples} 30 | neg_samples: ${inst_sampler.neg_samples} 31 | audio_noise: ${inst_sampler.audio_noise} 32 | transcription: ${transcription} 33 | source_separation: ${source} -------------------------------------------------------------------------------- /End2End/config/separation/feature/SS_STFT.yaml: -------------------------------------------------------------------------------- 1 | STFT: 2 | n_fft: 1024 3 | hop_length: 160 4 | center: True 5 | pad_mode: "reflect" 6 | power: null 7 | return_complex: True 8 | iSTFT: 9 | n_fft: ${separation.feature.STFT.n_fft} 10 | hop_length: ${separation.feature.STFT.hop_length} 11 | center: ${separation.feature.STFT.center} 12 | pad_mode: ${separation.feature.STFT.pad_mode} -------------------------------------------------------------------------------- /End2End/config/separation_config.yaml: -------------------------------------------------------------------------------- 1 | gpus: 1 2 | epochs: 500 3 | augmentation: False 4 | batch_size: 4 5 | segment_seconds: 10 6 | frames_per_second: 100 7 | val_frequency: 5 8 | num_workers: 4 9 | lr: 1e-4 10 | source: True 11 | transcription: # to be 
overwritten depending on the model type 12 | inst_sampler: 13 | mode: 'imbalance' 14 | temp: 0.5 15 | samples: 3 16 | neg_samples: 1 17 | audio_noise: 0.0 18 | 19 | datamodule: 20 | waveform_hdf5s_dir: 21 | notes_pkls_dir: 22 | dataset_cfg: 23 | train: 24 | segment_seconds: ${segment_seconds} 25 | frames_per_second: ${frames_per_second} 26 | pre_load_audio: False 27 | transcription: ${transcription} 28 | random_crop: True 29 | source: ${source} 30 | val: 31 | segment_seconds: ${segment_seconds} 32 | frames_per_second: ${frames_per_second} 33 | pre_load_audio: False 34 | transcription: ${transcription} 35 | random_crop: False 36 | source: ${source} 37 | test: 38 | segment_seconds: ${segment_seconds} 39 | frames_per_second: ${frames_per_second} 40 | pre_load_audio: False 41 | transcription: ${transcription} 42 | random_crop: False 43 | source: ${source} 44 | dataloader_cfg: 45 | train: 46 | batch_size: ${batch_size} 47 | num_workers: ${num_workers} 48 | shuffle: True 49 | pin_memory: True 50 | val: 51 | batch_size: ${batch_size} 52 | num_workers: ${num_workers} 53 | shuffle: False 54 | pin_memory: True 55 | test: 56 | batch_size: 1 57 | num_workers: ${num_workers} 58 | shuffle: False 59 | pin_memory: True 60 | 61 | MIDI_MAPPING: # This whole part will be overwritten in the main code 62 | type: 'MIDI_class' 63 | plugin_labels_num: 0 64 | NAME_TO_IX: 0 65 | IX_TO_NAME: 0 66 | 67 | trainer: 68 | checkpoint_callback: True 69 | gpus: ${gpus} 70 | accelerator: 'ddp' 71 | sync_batchnorm: True 72 | max_epochs: ${epochs} 73 | replace_sampler_ddp: False 74 | profiler: 'simple' 75 | check_val_every_n_epoch: ${val_frequency} 76 | log_every_n_steps: 100 77 | 78 | checkpoint: 79 | monitor: 'Separation/Valid/Loss' 80 | filename: "e={epoch:02d}-train_loss={Separation/Train/Loss:.2f}-valid_loss={Separation/Valid/Loss:.2f}" 81 | save_top_k: 1 82 | mode: 'min' 83 | save_last: True 84 | every_n_epochs: ${trainer.check_val_every_n_epoch} 85 | 86 | defaults: 87 | - separation: CUNet 88 | - scheduler: LambdaLR 89 | -------------------------------------------------------------------------------- /End2End/config/transcription/FrameOnly.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - backend: CNN_GRU 3 | model: 4 | type: "FrameOnly" 5 | args: 6 | frames_per_second: ${frames_per_second} 7 | classes_num: 88 8 | modeling_offset: False 9 | modeling_velocity: False 10 | loss_types: 11 | - onset 12 | - frame 13 | 14 | evaluation: 15 | max_evaluation_steps: 100 16 | onset_threshold: 0.1 17 | offset_threshod: 0.1 18 | frame_threshold: 0.1 19 | pedal_offset_threshold: 0.2 20 | modeling_offset: ${model.args.modeling_offset} 21 | seg_batch_size: 8 22 | checkpoint_path: '/opt/tiger/kinwai/jointist/outputs/2021-12-14/14-01-06/Original-CNN8Dropout-GRU_256-MIDI_class-imbalance-fps=100-csize=39-bz=16/version_0/checkpoints/last.ckpt' 23 | output_path: null 24 | 25 | postprocessor: 26 | frames_per_second: ${frames_per_second} 27 | classes_num: ${MIDI_MAPPING.plugin_labels_num} 28 | onset_threshold: ${transcription.evaluation.onset_threshold} 29 | offset_threshold: ${transcription.evaluation.offset_threshod} 30 | frame_threshold: ${transcription.evaluation.frame_threshold} 31 | pedal_offset_threshold: ${transcription.evaluation.pedal_offset_threshold} 32 | modeling_offset: ${transcription.model.args.modeling_offset} -------------------------------------------------------------------------------- /End2End/config/transcription/Original.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - backend: CNN_GRU 3 | - postprocessor: OnsetFramePostProcessor 4 | - feature: mel 5 | 6 | model: 7 | type: "Original" 8 | args: 9 | frames_per_second: ${frames_per_second} 10 | classes_num: 88 11 | modeling_offset: False 12 | modeling_velocity: False 13 | loss_types: 14 | - onset 15 | - frame 16 | 17 | evaluation: 18 | max_evaluation_steps: 100 19 | onset_threshold: 0.1 20 | offset_threshod: 0.1 21 | frame_threshold: 0.1 22 | pedal_offset_threshold: 0.2 23 | modeling_offset: ${transcription.model.args.modeling_offset} 24 | seg_batch_size: 8 25 | checkpoint_path: '/workspace/public_data/raven/amt_ir/weights/transcription1000.ckpt' 26 | output_path: null 27 | 28 | batchprocess: 29 | MIDI_MAPPING: ${MIDI_MAPPING} 30 | mode: ${inst_sampler.mode} 31 | temp: ${inst_sampler.temp} 32 | samples: ${inst_sampler.samples} 33 | neg_samples: ${inst_sampler.neg_samples} 34 | audio_noise: ${inst_sampler.audio_noise} 35 | -------------------------------------------------------------------------------- /End2End/config/transcription/Semantic_Segmentation.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: "Semantic_Segmentation" 3 | args: 4 | out_class: ${MIDI_MAPPING.plugin_labels_num} 5 | dropout_rate: 0.4 6 | 7 | evaluation: 8 | max_evaluation_steps: 100 9 | onset_threshold: 0.1 10 | offset_threshod: 0.1 11 | frame_threshold: 0.1 12 | pedal_offset_threshold: 0.2 13 | modeling_offset: False 14 | seg_batch_size: 8 15 | checkpoint_path: '/opt/tiger/kinwai/jointist/outputs/2021-12-25/08-38-16/Semantic_Segmentation-MIDI_class-csize=39-bz=4/version_0/checkpoints/last.ckpt' 16 | output_path: null 17 | 18 | defaults: 19 | - postprocessor: OnsetFramePostProcessor 20 | - feature: mel 21 | 22 | batchprocess: 23 | MIDI_MAPPING: ${MIDI_MAPPING} -------------------------------------------------------------------------------- /End2End/config/transcription/backend/CNN_GRU.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - language: GRU 3 | 4 | acoustic: 5 | type: "CNN8Dropout" 6 | args: 7 | condition_size: ${MIDI_MAPPING.plugin_labels_num} 8 | in_channels: 1 9 | 10 | acoustic_dim: 768 11 | acoustic_dropout: 0.5 12 | language_dim: 512 13 | language_dropout: 0.5 -------------------------------------------------------------------------------- /End2End/config/transcription/backend/CNN_LSTM.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - language: LSTM 3 | 4 | acoustic: 5 | type: "CNN8Dropout" 6 | args: 7 | condition_size: ${MIDI_MAPPING.plugin_labels_num} 8 | in_channels: 1 9 | 10 | acoustic_dim: 768 11 | acoustic_dropout: 0.5 12 | language_dim: 512 13 | language_dropout: 0.5 -------------------------------------------------------------------------------- /End2End/config/transcription/backend/language/GRU.yaml: -------------------------------------------------------------------------------- 1 | type: "GRU" 2 | args: 3 | input_size: 768 4 | hidden_size: 256 5 | num_layers: 2 6 | bias: True 7 | batch_first: True 8 | dropout: 0.0 9 | bidirectional: True -------------------------------------------------------------------------------- /End2End/config/transcription/backend/language/LSTM.yaml: -------------------------------------------------------------------------------- 1 | type: "LSTM" 2 | args: 3 | input_size: 768 4 | hidden_size: 256 5 | num_layers: 2 6 | 
bias: True 7 | batch_first: True 8 | dropout: 0.0 9 | bidirectional: True -------------------------------------------------------------------------------- /End2End/config/transcription/feature/mel.yaml: -------------------------------------------------------------------------------- 1 | STFT: 2 | sample_rate: 16000 3 | n_fft: 2048 4 | hop_length: 160 5 | n_mels: 229 6 | f_min: 0 7 | f_max: 8000 8 | center: True 9 | normalized: True 10 | pad_mode: 'reflect' 11 | dB_args: 12 | multiplier: 20 13 | amin: 1e-10 14 | db_multiplier: 1.0 15 | top_db: null -------------------------------------------------------------------------------- /End2End/config/transcription/postprocessor/OnsetFramePostProcessor.yaml: -------------------------------------------------------------------------------- 1 | type: 'OnsetFramePostProcessor' 2 | args: 3 | frames_per_second: ${frames_per_second} 4 | onset_threshold: ${transcription.evaluation.onset_threshold} 5 | offset_threshold: ${transcription.evaluation.offset_threshod} 6 | frame_threshold: ${transcription.evaluation.frame_threshold} 7 | pedal_offset_threshold: ${transcription.evaluation.pedal_offset_threshold} 8 | modeling_offset: ${transcription.evaluation.modeling_offset} -------------------------------------------------------------------------------- /End2End/config/transcription/postprocessor/RegressionPostProcessor.yaml: -------------------------------------------------------------------------------- 1 | type: 'RegressionPostProcessor' 2 | args: 3 | frames_per_second: ${frames_per_second} 4 | onset_threshold: ${transcription.evaluation.onset_threshold} 5 | offset_threshold: ${transcription.evaluation.offset_threshod} 6 | frame_threshold: ${transcription.evaluation.frame_threshold} 7 | pedal_offset_threshold: ${transcription.evaluation.pedal_offset_threshold} 8 | modeling_offset: ${transcription.model.args.modeling_offset} -------------------------------------------------------------------------------- /End2End/config/transcription_config.yaml: -------------------------------------------------------------------------------- 1 | gpus: 1 2 | epochs: 500 3 | augmentation: False 4 | batch_size: 12 5 | segment_seconds: 10 6 | frames_per_second: 100 7 | val_frequency: 50 8 | lr: 1e-3 9 | source: True 10 | download: False 11 | inst_sampler: 12 | mode: 'imbalance' 13 | temp: 0.5 14 | samples: 3 15 | neg_samples: 1 16 | audio_noise: 0.0 17 | 18 | datamodule: 19 | slakhdata_root: '../../MusicDataset' 20 | waveform_dir: 'waveforms/packed_waveforms' 21 | pkl_dir: 'packed_pkl' 22 | dataset_cfg: 23 | train: 24 | segment_seconds: ${segment_seconds} 25 | frames_per_second: ${frames_per_second} 26 | transcription: True 27 | random_crop: True 28 | source: ${source} 29 | download: ${download} 30 | name_to_ix: ${MIDI_MAPPING.NAME_TO_IX} 31 | ix_to_name: ${MIDI_MAPPING.IX_TO_NAME} 32 | plugin_labels_num: ${MIDI_MAPPING.plugin_labels_num} 33 | val: 34 | segment_seconds: ${segment_seconds} 35 | frames_per_second: ${frames_per_second} 36 | transcription: True 37 | random_crop: False 38 | source: ${source} 39 | download: False 40 | name_to_ix: ${MIDI_MAPPING.NAME_TO_IX} 41 | ix_to_name: ${MIDI_MAPPING.IX_TO_NAME} 42 | plugin_labels_num: ${MIDI_MAPPING.plugin_labels_num} 43 | test: 44 | segment_seconds: null 45 | frames_per_second: ${frames_per_second} 46 | transcription: True 47 | random_crop: False 48 | source: ${source} 49 | download: ${download} 50 | name_to_ix: ${MIDI_MAPPING.NAME_TO_IX} 51 | ix_to_name: ${MIDI_MAPPING.IX_TO_NAME} 52 | plugin_labels_num: 
${MIDI_MAPPING.plugin_labels_num} 53 | dataloader_cfg: 54 | train: 55 | batch_size: ${batch_size} 56 | num_workers: 4 57 | shuffle: True 58 | pin_memory: True 59 | val: 60 | batch_size: ${batch_size} 61 | num_workers: 4 62 | shuffle: False 63 | pin_memory: True 64 | test: 65 | batch_size: 1 66 | num_workers: 4 67 | shuffle: False 68 | pin_memory: True 69 | 70 | MIDI_MAPPING: # This whole part will be overwritten in the main code 71 | type: 'MIDI_class' 72 | plugin_labels_num: 0 73 | NAME_TO_IX: 0 74 | IX_TO_NAME: 0 75 | 76 | trainer: 77 | checkpoint_callback: True 78 | gpus: ${gpus} 79 | accelerator: 'ddp' 80 | sync_batchnorm: True 81 | max_epochs: ${epochs} 82 | replace_sampler_ddp: False 83 | profiler: 'simple' 84 | check_val_every_n_epoch: ${val_frequency} 85 | log_every_n_steps: 100 86 | 87 | checkpoint: 88 | monitor: 'Transcription_Loss/Valid' 89 | filename: "e={epoch:02d}-train_loss={Transcription_Loss/Train:.2f}-valid_loss={Transcription_Loss/Valid:.2f}" 90 | save_top_k: 1 91 | mode: 'min' 92 | save_last: True 93 | every_n_epochs: ${trainer.check_val_every_n_epoch} 94 | 95 | defaults: 96 | - transcription: Original 97 | - scheduler: LambdaLR 98 | -------------------------------------------------------------------------------- /End2End/config/tseparation.yaml: -------------------------------------------------------------------------------- 1 | gpus: 1 2 | epochs: 500 3 | augmentation: False 4 | batch_size: 2 5 | segment_seconds: 10 6 | frames_per_second: 100 7 | val_frequency: 5 8 | num_workers: 8 9 | straight_through: False 10 | lr: 1e-4 11 | source: True 12 | inst_sampler: 13 | mode: 'imbalance' 14 | temp: 0.5 15 | samples: 3 16 | neg_samples: 0 17 | audio_noise: 0.0 18 | 19 | transcription_weights: '/opt/tiger/kinwai/jointist/weights/transcription1000.ckpt' # for loading pretrained transcription 20 | checkpoint_path: './weights/tseparation.ckpt' 21 | 22 | datamodule: 23 | waveform_hdf5s_dir: 24 | notes_pkls_dir: 25 | dataset_cfg: 26 | train: 27 | segment_seconds: ${segment_seconds} 28 | frames_per_second: ${frames_per_second} 29 | pre_load_audio: False 30 | transcription: True 31 | random_crop: True 32 | source: ${source} 33 | val: 34 | segment_seconds: ${segment_seconds} 35 | frames_per_second: ${frames_per_second} 36 | pre_load_audio: False 37 | transcription: True 38 | random_crop: False 39 | source: ${source} 40 | test: 41 | segment_seconds: null 42 | frames_per_second: ${frames_per_second} 43 | pre_load_audio: False 44 | transcription: True 45 | random_crop: False 46 | source: ${source} 47 | dataloader_cfg: 48 | train: 49 | batch_size: ${batch_size} 50 | num_workers: ${num_workers} 51 | shuffle: True 52 | pin_memory: True 53 | val: 54 | batch_size: ${batch_size} 55 | num_workers: ${num_workers} 56 | shuffle: False 57 | pin_memory: True 58 | test: 59 | batch_size: 1 60 | num_workers: ${num_workers} 61 | shuffle: False 62 | pin_memory: True 63 | 64 | MIDI_MAPPING: # This whole part will be overwritten in the main code 65 | type: 'MIDI_class' 66 | plugin_labels_num: 0 67 | NAME_TO_IX: 0 68 | IX_TO_NAME: 0 69 | 70 | checkpoint: 71 | monitor: 'Total_Loss/Valid' 72 | filename: "e={epoch:02d}-trainloss={Total_Loss/Train:.3f}-validloss{Total_Loss/Valid:.3f}" 73 | save_top_k: 1 74 | mode: 'min' 75 | save_last: True 76 | every_n_epochs: ${trainer.check_val_every_n_epoch} 77 | 78 | trainer: 79 | checkpoint_callback: True 80 | gpus: ${gpus} 81 | accelerator: 'ddp' 82 | sync_batchnorm: True 83 | max_epochs: ${epochs} 84 | replace_sampler_ddp: False 85 | profiler: 'simple' 86 | 
check_val_every_n_epoch: ${val_frequency} 87 | num_sanity_val_steps: 2 88 | resume_from_checkpoint: False 89 | 90 | defaults: 91 | - transcription: Original 92 | - separation: TCUNet 93 | - scheduler: LambdaLR 94 | -------------------------------------------------------------------------------- /End2End/constants.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import json 3 | import pathlib 4 | 5 | 6 | SAMPLE_RATE = 16000 7 | CLASSES_NUM = 88 # Number of notes of piano 8 | BEGIN_NOTE = 21 # MIDI note of A0, the lowest note of a piano. 9 | SEGMENT_SECONDS = 10.0 # Training segment duration 10 | HOP_SECONDS = 1.0 11 | FRAMES_PER_SECOND = 100 12 | VELOCITY_SCALE = 128 13 | 14 | TAGGING_SEGMENT_SECONDS = 2.0 15 | 16 | # Load plugin related information. 17 | with open('End2End/dataset_creation/plugin_to_midi_program.json') as f: 18 | plugin_dict = json.load(f) 19 | 20 | PLUGIN_LABELS = sorted([pathlib.Path(key).stem for key in plugin_dict.keys()]) 21 | # E.g., ['AGML2', ..., 'bass_trombone', 'bassoon', ...] 22 | 23 | PLUGIN_LABELS_NUM = len(PLUGIN_LABELS) 24 | PLUGIN_LB_TO_IX = {lb: i for i, lb in enumerate(PLUGIN_LABELS)} 25 | PLUGIN_IX_TO_LB = {i: lb for i, lb in enumerate(PLUGIN_LABELS)} 26 | 27 | # Get plugin name to instruments mapping. 28 | PLUGIN_NAME_TO_INSTRUMENT = {} 29 | 30 | for key in plugin_dict.keys(): 31 | count = -1 32 | 33 | for instrument_name in plugin_dict[key].keys(): 34 | this_count = plugin_dict[key][instrument_name] 35 | 36 | if this_count > count: 37 | instrument = instrument_name 38 | count = this_count 39 | 40 | PLUGIN_NAME_TO_INSTRUMENT[pathlib.Path(key).stem] = instrument 41 | 42 | # E.g., PLUGIN_NAME_TO_INSTRUMENT: { 43 | # 'elektrik_guitar': 'Overdriven Guitar', 44 | # 'session_kit_full': 'Drums', 45 | # ... 46 | # } 47 | 48 | BN_MOMENTUM = 0.01 # a globally applied momentum 49 | -------------------------------------------------------------------------------- /End2End/data/augmentors.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import sox 4 | 5 | from End2End.constants import SAMPLE_RATE 6 | 7 | 8 | class Augmentor: 9 | def __init__(self, augmentation: str): 10 | r"""Data augmentor. 11 | 12 | Args: 13 | augmentation: str, 'none' | 'aug' 14 | """ 15 | 16 | self.augmentation = augmentation 17 | self.random_state = np.random.RandomState(1234) 18 | self.sample_rate = SAMPLE_RATE 19 | 20 | def __call__(self, x): 21 | r"""Do augmentation. 
22 | 23 | Args: 24 | x: ndarray, (audio_length,) 25 | 26 | Returns: 27 | ndarray, (audio_length,) 28 | """ 29 | 30 | if self.augmentation == 'none': 31 | return x 32 | 33 | elif self.augmentation == 'aug': 34 | return self.aug(x) 35 | 36 | def aug(self, x): 37 | # Todo 38 | clip_samples = len(x) 39 | 40 | logger = logging.getLogger('sox') 41 | logger.propagate = False 42 | 43 | tfm = sox.Transformer() 44 | tfm.set_globals(verbosity=0) 45 | 46 | tfm.pitch(self.random_state.uniform(-0.1, 0.1, 1)[0]) 47 | tfm.contrast(self.random_state.uniform(0, 100, 1)[0]) 48 | 49 | tfm.equalizer( 50 | frequency=self.loguniform(32, 4096, 1)[0], 51 | width_q=self.random_state.uniform(1, 2, 1)[0], 52 | gain_db=self.random_state.uniform(-30, 10, 1)[0], 53 | ) 54 | 55 | tfm.equalizer( 56 | frequency=self.loguniform(32, 4096, 1)[0], 57 | width_q=self.random_state.uniform(1, 2, 1)[0], 58 | gain_db=self.random_state.uniform(-30, 10, 1)[0], 59 | ) 60 | 61 | tfm.reverb(reverberance=self.random_state.uniform(0, 70, 1)[0]) 62 | 63 | aug_x = tfm.build_array(input_array=x, sample_rate_in=self.sample_rate) 64 | aug_x = pad_truncate_sequence(aug_x, clip_samples) 65 | 66 | return aug_x 67 | 68 | def loguniform(self, low, high, size): 69 | return np.exp(self.random_state.uniform(np.log(low), np.log(high), size)) 70 | -------------------------------------------------------------------------------- /End2End/data/mixing_secrets_vocals.py: -------------------------------------------------------------------------------- 1 | """This module loads mixing secrets vocal stems and their pitch contour estimated by crepe. 2 | Run `./scripts/dataset-mixing-secrets/get-vocal-stems-to-local.sh` to copy them to your local folder. 3 | 4 | Details of CREPE: https://arxiv.org/abs/1802.06182 5 | 6 | The mixing secrets dataset doesn't have any official split. Based on these 10-second audio files, 7 | the first 19613 files are the training set, 19613:20850 is the validation set, and 20850:21890 is the test set. 8 | In this way, there is no artist/track overlap across the sets. 9 | 10 | """ 11 | import os 12 | import glob 13 | 14 | import torch 15 | import numpy as np 16 | import soundfile as sf 17 | 18 | NUM_VOCAL_STEMS = 21890 # split boundaries: train = [:19613], valid = [19613:20850], test = [20850:] 19 | 20 | 21 | class MixingSecretDataset(torch.utils.data.Dataset): 22 | def __init__(self, split: str, wav_path: str, npy_path: str): 23 | """ 24 | 25 | Args: 26 | split: 'train', 'valid', 'test' 27 | 28 | wav_path: directory path of the wav files. 29 | See ./scripts/dataset-mixing-secrets/get-vocal-stems-to-local.sh for more information. 30 | 31 | npy_path: directory path of the npy files. 32 | """ 33 | self.split = split 34 | self.wav_path = wav_path 35 | self.npy_path = npy_path 36 | 37 | wav_filenames = sorted(glob.glob(os.path.join(self.wav_path, '*.wav'))) 38 | npy_filenames = sorted(glob.glob(os.path.join(self.npy_path, '*.npy'))) 39 | 40 | if len(wav_filenames) != len(npy_filenames) or len(wav_filenames) != NUM_VOCAL_STEMS: 41 | raise RuntimeError(f'{len(wav_filenames)} != {len(npy_filenames)}.
They all should be {NUM_VOCAL_STEMS}.') 42 | 43 | self.filenames = [os.path.basename(fn)[:-4] for fn in wav_filenames] # keep file stems only; wav_path/npy_path are re-joined in __getitem__ 44 | if self.split == 'train': 45 | self.filenames = self.filenames[:19613] 46 | elif self.split == 'valid': 47 | self.filenames = self.filenames[19613:20850] 48 | elif self.split == 'test': 49 | self.filenames = self.filenames[20850:] 50 | else: 51 | raise ValueError(f'self.split is unexpected --> {self.split}') 52 | 53 | self.sample_rate = 16000 54 | 55 | def __len__(self): 56 | return len(self.filenames) 57 | 58 | def __getitem__(self, index: int): 59 | """ 60 | 61 | Args: 62 | index (int): an integer that is < 21890 63 | 64 | Returns: 65 | audio_signal (np.ndarray): [-1.0, 1.0] normalized audio signal, 10 seconds long, shape (160000,) 66 | pitch (np.ndarray): shape (1001=time, 360=pitch). A prediction is made every 10 ms; each pitch bin covers 20 cents, 67 | ranging from C1 to B7. 68 | """ 69 | 70 | audio_signal, sr = sf.read(os.path.join(self.wav_path, self.filenames[index] + '.wav')) 71 | assert sr == self.sample_rate 72 | pitch = np.load(os.path.join(self.npy_path, self.filenames[index] + '.activation.npy')) 73 | 74 | return audio_signal, pitch 75 | -------------------------------------------------------------------------------- /End2End/dataset_creation/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Slakh dataset processing code 3 | 4 | - Groups the MIDI instruments into a closed set 5 | 6 | ## Files 7 | 8 | - `midi_track_group_config_1.csv`: the config file of MIDI track group info. It groups instruments into 5 tracks based on MIDI program number. The 5 tracks represent: 9 | - Piano 10 | - Organ 11 | - Strings 12 | - Bass 13 | - Drums 14 | 15 | - `midi_track_group_config_2.csv`: the config file of MIDI track group info. It groups instruments into 6 tracks based on plug-in name. The 6 tracks represent: 16 | - Piano 17 | - Organ 18 | - Strings 19 | - Bass 20 | - Distorted 21 | - Drums 22 | 23 | - `midi_track_group_config_3.csv`: the config file of MIDI track group info. It groups instruments into 3 tracks based on plug-in name. The 3 tracks represent: 24 | - Piano (all instruments other than bass and drums) 25 | - Bass 26 | - Drums 27 | 28 | - `midi_track_group_config_4.csv`: the config file of MIDI track group info. It groups instruments into 2 tracks based on plug-in name. The 2 tracks represent: 29 | - Piano (all pitched instruments) 30 | - Drums 31 | 32 | - `prepare_closed_set.py`: the code to run the dataset processing. Parameters are hardcoded in the main function entry, with comments. A minimal sketch of how these config CSVs can be read is shown below.
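Each config CSV maps a Slakh plug-in name to a closed-set group and a mapped MIDI program number (`-1` marks drums and `-2` marks `None`). The sketch below is only an illustration of how these columns can be consumed; the authoritative grouping logic lives in `prepare_closed_set.py`, and the helper name `load_track_group_config` is made up for this example.

```python
import csv

def load_track_group_config(csv_path: str):
    """Read a midi_track_group_config_*.csv and map each plug-in name
    to its (closed-set group, mapped MIDI program number) pair."""
    plugin_to_group = {}
    with open(csv_path, newline='') as f:
        for row in csv.DictReader(f):
            plugin_to_group[row['Plugin_name']] = (
                row['Closed set'],                       # e.g. 'Piano', 'Drums', 'None'
                int(row['Mapped MIDI program number']),  # -1 for drums, -2 for 'None'
            )
    return plugin_to_group

# e.g. load_track_group_config('midi_track_group_config_2.csv')['elektrik_guitar']
# -> ('Distorted', 30)
```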
33 | 34 | ## Environment 35 | 36 | - `pip install pretty_midi`: to process MIDI files 37 | - `pip install pyFluidSynth`: to synthesize the processed MIDI files for preview (optional) 38 | - `ffmpeg`: to convert wav to mp3 (optional) 39 | 40 | ## The dataset 41 | 42 | The original Slakh dataset is temporarily stored at: 43 | https://www.dropbox.com/sh/5zh099o75kvpvz3/AADSqL8p3o7wLIvcta3IrP5Da?dl=0 44 | 45 | The MIDI programs used in Slakh dataset are collected for listening at: 46 | https://www.dropbox.com/sh/a0szdx51jl9p9i8/AADSvQnbQG_bmMfAMSsKR8UYa?dl=0 47 | 48 | The sound plug-ins used in Slakh dataset are collected for listening at: 49 | https://www.dropbox.com/sh/a2xw1t2h015nls0/AACJouypEmlgBPPK5EeHdqtxa?dl=0 50 | 51 | -------------------------------------------------------------------------------- /End2End/dataset_creation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinWaiCheuk/Jointist/1846975ec904c8b23ea653b9b7881f343574e7a6/End2End/dataset_creation/__init__.py -------------------------------------------------------------------------------- /End2End/dataset_creation/crash.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | class ExceptionHook: 5 | instance = None 6 | 7 | def __call__(self, *args, **kwargs): 8 | if self.instance is None: 9 | from IPython.core import ultratb 10 | 11 | self.instance = ultratb.FormattedTB(mode='Plain', color_scheme='Linux', call_pdb=1) 12 | return self.instance(*args, **kwargs) 13 | 14 | 15 | sys.excepthook = ExceptionHook() 16 | -------------------------------------------------------------------------------- /End2End/dataset_creation/create_groove.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import pandas as pd 5 | from concurrent.futures import ProcessPoolExecutor 6 | 7 | from jointist.config import SAMPLE_RATE 8 | from jointist.dataset_creation.create_slakh2100 import write_single_audio_to_hdf5, write_single_midi_to_hdf5 9 | 10 | 11 | def pack_audios_to_hdf5s(args): 12 | r"""Load & resample audios of the Slakh2100 dataset, then write them into 13 | hdf5 files. 14 | 15 | Args: 16 | dataset_dir: str, directory of dataset 17 | workspace: str, directory of your workspace 18 | 19 | Returns: 20 | None 21 | """ 22 | 23 | # arguments & parameters 24 | dataset_root = args.dataset_root 25 | meta_csv_path = args.meta_csv_path 26 | hdf5s_dir = args.hdf5s_dir 27 | sample_rate = SAMPLE_RATE 28 | 29 | df = pd.read_csv(meta_csv_path, sep=',') 30 | audios_num = len(df) 31 | 32 | # paths 33 | feature_extraction_time = time.time() 34 | 35 | for target_split in ["train", "test", "validation"]: 36 | 37 | params = [] 38 | 39 | for audio_index in range(audios_num): 40 | split = df['split'][audio_index] 41 | midi_filename = df['midi_filename'][audio_index] 42 | audio_filename = df['audio_filename'][audio_index] 43 | 44 | audio_path = os.path.join(dataset_root, audio_filename) 45 | 46 | hdf5_path = os.path.join(hdf5s_dir, split, "{}.h5".format(os.path.splitext(midi_filename)[0])) 47 | os.makedirs(os.path.dirname(hdf5_path), exist_ok=True) 48 | 49 | param = (audio_index, audio_path, hdf5_path, audio_filename, split, sample_rate) 50 | 51 | if split == target_split: 52 | params.append(param) 53 | 54 | print("------ Split: {} (Total: {} clips) ------".format(target_split, len(params))) 55 | 56 | # Debug by uncomment the following code. 
57 | # write_single_audio_to_hdf5(params[0]) 58 | 59 | # Pack audio files to hdf5 files in parallel. 60 | with ProcessPoolExecutor(max_workers=None) as pool: 61 | pool.map(write_single_audio_to_hdf5, params) 62 | 63 | print("Time: {:.3f} s".format(time.time() - feature_extraction_time)) 64 | 65 | 66 | def pack_midi_events_to_hdf5s(args): 67 | r"""Extract MIDI events of the processed Slakh2100 dataset, and write the 68 | MIDI events to hdf5 files. The processed MIDI files are obtained by merging 69 | tracks from open set tracks to predefined tracks, such as `piano`, `drums`, 70 | `strings`, etc. 71 | 72 | Args: 73 | processed_midis_dir: str, directory of processed MIDI files 74 | hdf5s_dir: str, directory to write out hdf5 files 75 | 76 | Returns: 77 | None 78 | """ 79 | 80 | # arguments & parameters 81 | processed_midis_dir = args.processed_midis_dir 82 | meta_csv_path = args.meta_csv_path 83 | hdf5s_dir = args.hdf5s_dir 84 | 85 | df = pd.read_csv(meta_csv_path, sep=',') 86 | audios_num = len(df) 87 | 88 | # paths 89 | feature_extraction_time = time.time() 90 | 91 | for target_split in ["train", "test", "validation"]: 92 | # for target_split in ["test"]: 93 | 94 | params = [] 95 | 96 | for midi_index in range(audios_num): 97 | split = df['split'][midi_index] 98 | midi_filename = df['midi_filename'][midi_index] 99 | 100 | # audio_path = os.path.join(dataset_root, audio_filename) 101 | midi_path = os.path.join(processed_midis_dir, midi_filename) 102 | 103 | hdf5_path = os.path.join(hdf5s_dir, split, "{}.h5".format(os.path.splitext(midi_filename)[0])) 104 | os.makedirs(os.path.dirname(hdf5_path), exist_ok=True) 105 | 106 | # if '1_funk-groove1_138_beat_4-4_1' in midi_path: 107 | 108 | param = (midi_index, midi_path, hdf5_path, split) 109 | 110 | if split == target_split: 111 | params.append(param) 112 | 113 | print("------ Split: {} (Total: {} clips) ------".format(target_split, len(params))) 114 | 115 | # Debug by uncomment the following code. 116 | # write_single_midi_to_hdf5(params[0]) 117 | 118 | # Pack audio files to hdf5 files in parallel. 
119 | with ProcessPoolExecutor() as pool: 120 | pool.map(write_single_midi_to_hdf5, params) 121 | 122 | # for param in params: 123 | # write_single_midi_to_hdf5(param) 124 | 125 | print("Time: {:.3f} s".format(time.time() - feature_extraction_time)) 126 | 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser() 130 | subparsers = parser.add_subparsers(dest="mode") 131 | 132 | parser_pack_audios = subparsers.add_parser("pack_audios_to_hdf5s") 133 | parser_pack_audios.add_argument("--dataset_root", type=str, required=True, help="Directory of groove audios.") 134 | parser_pack_audios.add_argument("--meta_csv_path", type=str, required=True, help="Directory of groove audios.") 135 | parser_pack_audios.add_argument( 136 | "--hdf5s_dir", 137 | type=str, 138 | required=True, 139 | help="Directory to write out hdf5 files.", 140 | ) 141 | 142 | parser_pack_midi_events = subparsers.add_parser("pack_midi_events_to_hdf5s") 143 | parser_pack_midi_events.add_argument( 144 | "--processed_midis_dir", 145 | type=str, 146 | required=True, 147 | help="Directory of processed MIDI files.", 148 | ) 149 | 150 | parser_pack_midi_events.add_argument("--meta_csv_path", type=str, required=True, help="Directory of groove audios.") 151 | 152 | parser_pack_midi_events.add_argument( 153 | "--hdf5s_dir", 154 | type=str, 155 | required=True, 156 | help="Directory to write out hdf5 files.", 157 | ) 158 | 159 | # Parse arguments 160 | args = parser.parse_args() 161 | 162 | if args.mode == "pack_audios_to_hdf5s": 163 | pack_audios_to_hdf5s(args) 164 | elif args.mode == "pack_midi_events_to_hdf5s": 165 | pack_midi_events_to_hdf5s(args) 166 | else: 167 | raise Exception("Incorrect arguments!") 168 | -------------------------------------------------------------------------------- /End2End/dataset_creation/create_musdb18.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pathlib 4 | import time 5 | import pandas as pd 6 | from concurrent.futures import ProcessPoolExecutor 7 | 8 | import h5py 9 | import librosa 10 | import numpy as np 11 | import musdb 12 | from mido import MidiFile 13 | 14 | from jointist.config import SAMPLE_RATE 15 | from jointist.utils import float32_to_int16 16 | 17 | from jointist.dataset_creation.create_slakh2100 import write_single_audio_to_hdf5, write_single_midi_to_hdf5 18 | 19 | 20 | def pack_audios_to_hdf5s(args): 21 | r"""Load & resample audios of the Slakh2100 dataset, then write them into 22 | hdf5 files. 
23 | 24 | Args: 25 | dataset_dir: str, directory of dataset 26 | workspace: str, directory of your workspace 27 | 28 | Returns: 29 | None 30 | """ 31 | 32 | # arguments & parameters 33 | dataset_root = args.dataset_root 34 | source_type = args.source_type 35 | hdf5s_dir = args.hdf5s_dir 36 | sample_rate = SAMPLE_RATE 37 | 38 | mono = True 39 | resample_type = "kaiser_fast" 40 | 41 | # paths 42 | feature_extraction_time = time.time() 43 | 44 | for subset in ['train', 'test']: 45 | 46 | mus = musdb.DB(root=dataset_root, subsets=subset) 47 | 48 | print("------ Split: {} (Total: {} clips) ------".format(subset, len(mus))) 49 | 50 | for track_index, track in enumerate(mus.tracks): 51 | 52 | hdf5_path = os.path.join(hdf5s_dir, subset, "{}.h5".format(track.name)) 53 | os.makedirs(os.path.dirname(hdf5_path), exist_ok=True) 54 | 55 | with h5py.File(hdf5_path, "w") as hf: 56 | 57 | hf.attrs.create("audio_name", data=track.name.encode(), dtype="S100") 58 | hf.attrs.create("sample_rate", data=sample_rate, dtype=np.int32) 59 | hf.attrs.create("split", data=subset.encode(), dtype="S20") 60 | # hf.attrs.create("duration", data=duration, dtype=np.float32) 61 | 62 | audio = track.targets[source_type].audio.T 63 | # (channels_num, audio_samples) 64 | 65 | # Preprocess audio to mono / stereo, and resample. 66 | audio = preprocess_audio(audio, mono, track.rate, sample_rate, resample_type) 67 | # (audio_samples,) 68 | 69 | hf.create_dataset(name='waveform', data=float32_to_int16(audio), dtype=np.int16) 70 | 71 | hf.attrs.create("duration", data=len(audio) / sample_rate, dtype=np.float32) 72 | 73 | print("{} Write to {}, {}".format(track_index, hdf5_path, audio.shape)) 74 | 75 | print("Time: {:.3f} s".format(time.time() - feature_extraction_time)) 76 | 77 | 78 | def preprocess_audio(audio, mono, origin_sr, sr, resample_type): 79 | r"""Preprocess audio to mono / stereo, and resample. 
80 | 81 | Args: 82 | audio: (channels_num, audio_samples), input audio 83 | mono: bool 84 | origin_sr: float, original sample rate 85 | sr: float, target sample rate 86 | resample_type: str, e.g., 'kaiser_fast' 87 | 88 | Returns: 89 | output: ndarray, output audio 90 | """ 91 | 92 | if mono: 93 | audio = np.mean(audio, axis=0) 94 | # (audio_samples,) 95 | 96 | output = librosa.core.resample(audio, orig_sr=origin_sr, target_sr=sr, res_type=resample_type) 97 | # (channels_num, audio_samples) | (audio_samples,) 98 | 99 | return output 100 | 101 | 102 | if __name__ == "__main__": 103 | parser = argparse.ArgumentParser() 104 | subparsers = parser.add_subparsers(dest="mode") 105 | 106 | parser_pack_audios = subparsers.add_parser("pack_audios_to_hdf5s") 107 | parser_pack_audios.add_argument("--dataset_root", type=str, required=True, help="Directory of Slakh2100 audios.") 108 | parser_pack_audios.add_argument("--source_type", type=str, required=True, help="Directory of Slakh2100 audios.") 109 | parser_pack_audios.add_argument( 110 | "--hdf5s_dir", 111 | type=str, 112 | required=True, 113 | help="Directory to write out hdf5 files.", 114 | ) 115 | 116 | # Parse arguments 117 | args = parser.parse_args() 118 | 119 | if args.mode == "pack_audios_to_hdf5s": 120 | pack_audios_to_hdf5s(args) 121 | 122 | else: 123 | raise Exception("Incorrect arguments!") 124 | -------------------------------------------------------------------------------- /End2End/dataset_creation/midi_track_group_config_2.csv: -------------------------------------------------------------------------------- 1 | Plugin_name,Closed set,Mapped MIDI program number,Description 2 | AGML2,Piano,0,acoustic guitar 3 | None,None,-2,/ 4 | across_the_pacific,Strings,48,/ 5 | ahoy,Organ,16,/ 6 | alicias_keys,Piano,0,/ 7 | alto_sax_vintage_solo,Organ,16,/ 8 | alto_saxophone,Organ,16,/ 9 | ambibella,Organ,16,Ambient organ with clear onset 10 | april_pan,Strings,48,Ambient organ 11 | ar_modern_sparkle_kit_full,Drums,-1,/ 12 | ar_modern_white_kit_full,Drums,-1,/ 13 | arctic_morning,Strings,48,/ 14 | august_foerster_grand,Piano,0,/ 15 | baritone_sax_vintage_solo,Organ,16,/ 16 | baritone_saxophone,Organ,16,/ 17 | bass_trombone,Organ,16,/ 18 | bassoon,Organ,16,/ 19 | bassoon_combi,Organ,16,/ 20 | bassoon_essential,Organ,16,/ 21 | bassoons_essential,Organ,16,/ 22 | belle_de_jour,Strings,48,/ 23 | brass_quartet_essential,Organ,16,/ 24 | celesta,Piano,0,Llike bell 25 | cello_ensemble,Strings,48,/ 26 | cello_solo,Strings,48,Maybe organ? 27 | cerulean,Strings,48,/ 28 | cheesy_lead,Organ,16,/ 29 | choir_a,Strings,48,/ 30 | choir_e,Strings,48,/ 31 | choir_o,Strings,48,/ 32 | chrystal,Organ,16,/ 33 | clarinet,Organ,16,/ 34 | clarinet_combi,Organ,16,/ 35 | clarinet_essential,Organ,16,/ 36 | clarinets_essential,Organ,16,/ 37 | classic_bass,Bass,33,/ 38 | cold_cave,Strings,48,/ 39 | concert_grand,Piano,0,/ 40 | crawling_lead,Distorted,30,/ 41 | daft,Distorted,30,Like brass 42 | dawn_chorus,Strings,48,/ 43 | december_saw,Distorted,30,Like effects 44 | double_bass_ensemble,Strings,48,It is bass bowing not plucking so still take as strings 45 | double_bass_solo,Strings,48,/ 46 | douglas_lead,Distorted,30,/ 47 | downforce_saw,Distorted,30,Not sure ? 48 | drifting_apart,Organ,16,/ 49 | elektrik_guitar,Distorted,30,/ 50 | english_horn,Organ,16,/ 51 | fever,Organ,16,/ 52 | flugelhorn,Organ,16,/ 53 | flute,Organ,16,/ 54 | flute_essential,Organ,16,/ 55 | flutes_essential,Organ,16,/ 56 | french_oboe,Organ,16,/ 57 | funk_bass,Bass,33,/ 58 | funk_guitar,Distorted,30,Not sure ? 
59 | funk_kit,Drums,-1,/ 60 | garage_kit_lite,Drums,-1,/ 61 | glockenspiel,Piano,0,Llike bell may need another category ? 62 | glockenspiel_essential,Piano,0,Llike bell may need another category ? 63 | grand_piano,Piano,0,/ 64 | guitar_lead,Distorted,30,/ 65 | hard_n_dirty,Distorted,30,/ 66 | harmonic_guitar,Distorted,30,/ 67 | harp,Piano,0,More like guitar 68 | harpsichord,Piano,0,More like guitar 69 | horn_1_essential,Organ,16,/ 70 | horn_2_essential,Organ,16,/ 71 | house_cat,Distorted,30,Not sure ? 72 | hybrid_keys_antique_toy,Piano,0,like bell with kick drum ? 73 | hybrid_keys_concert_marimba,Piano,0,Like bell 74 | hybrid_keys_futurebells,Piano,0,bell (some tracks are distorted) 75 | hybrid_keys_glockenspiel,Piano,0,like bell with kick drum 76 | hybrid_keys_hot_tropics,Piano,0,bell with noise ? 77 | hybrid_keys_tube_vibraphone,Piano,0,like bell with kick drum ? 78 | hybrid_lead,Distorted,30,/ 79 | jazz_guitar,Piano,0,guitar 80 | jazz_guitar2,Piano,0,guitar 81 | jazz_guitar3,Organ,16,More like organ ? 82 | jazz_guitar4,Piano,0,guitar 83 | jazz_organ,Organ,16,/ 84 | jazz_upright,Bass,33,/ 85 | marimba,Piano,0,Maybe bell ? 86 | marimba_essential,Piano,0,Maybe bell ? 87 | musicology,Organ,16,Some are weird 88 | mute_trumpet,Organ,16,/ 89 | muted_trumpet,Organ,16,/ 90 | mystic_lead,Distorted,30,/ 91 | nylon_guitar,Piano,0,/ 92 | nylon_guitar2,Piano,0,/ 93 | oboe,Organ,16,/ 94 | oboe_essential,Organ,16,/ 95 | oboes_essential,Organ,16,/ 96 | organ_kh_floeten_1_manual,Organ,16,/ 97 | organ_kh_grprplenum_manual,Organ,16,/ 98 | outta_space,Distorted,30,/ 99 | percussive_lead,Organ,16,/ 100 | piccolo,Organ,16,/ 101 | pimped_analog_saw,Distorted,30,/ 102 | poly_detuned_lead,Distorted,30,/ 103 | pop_bass,Bass,33,/ 104 | pop_kit,Drums,-1,/ 105 | processor_lead,Distorted,30,/ 106 | ragtime_piano,Piano,0,/ 107 | reamped_lead,Distorted,30,/ 108 | rhythm_rock_guitar,Distorted,30,/ 109 | rock_guitar,Organ,16,Some are distorted 110 | saxophone_essential,Organ,16,/ 111 | saxophone_section,Organ,16,/ 112 | saxophones_essential,Organ,16,/ 113 | scarbee_a_200,Piano,0,guitar 114 | scarbee_clavinet_full,Piano,0,like guitar 115 | scarbee_jay_bass_both,Bass,33,/ 116 | scarbee_jay_bass_slap_both,Bass,33,/ 117 | scarbee_mark_I,Organ,16,/ 118 | scarbee_mm_bass,Bass,33,/ 119 | scarbee_pianet,Piano,0,guitar 120 | scarbee_pre_bass,Bass,33,/ 121 | scarbee_rickenbacker_bass,Bass,33,/ 122 | scarbee_rickenbacker_bass_palm_muted,Bass,33,/ 123 | session_horns_pro_keyswitch_60s_horns,Organ,16,trumpet 124 | session_horns_pro_keyswitch_generic_section,Organ,16,trumpet 125 | session_kit_full,Drums,-1,/ 126 | session_strings_pro_2_ensemble_modern,Strings,48,/ 127 | session_strings_pro_2_ensemble_traditional,Strings,48,/ 128 | solo_guitar,Piano,0,Longer sustain sometimes distorted 129 | solo_strings,Strings,48,Solo maybe like organ? 130 | stadium_kit_full,Drums,-1,/ 131 | street_knowledge_kit,Drums,-1,/ 132 | string_ensemble,Strings,48,/ 133 | string_ensemble_essential,Strings,48,/ 134 | tenor_sax,Organ,16,/ 135 | tenor_saxophone,Organ,16,/ 136 | tenor_trombone,Organ,16,/ 137 | the_gentleman,Piano,0,/ 138 | the_giant_hard_and_tough,Piano,0,/ 139 | the_giant_modern_studio,Piano,0,/ 140 | the_giant_vibrant,Piano,0,/ 141 | the_grandeur,Piano,0,/ 142 | timpani,Bass,33,Unique pitched drum? 
143 | tonewheel_organ_b3,Organ,16,/ 144 | tonewheel_organ_c3,Organ,16,/ 145 | tonewheel_organ_m3,Organ,16,/ 146 | transistor_compact,Organ,16,/ 147 | transistor_continental,Organ,16,/ 148 | trombone,Organ,16,trumpet 149 | trombone_section,Organ,16,trumpet 150 | trumpet,Organ,16,trumpet 151 | trumpet_1,Organ,16,trumpet 152 | trumpet_2,Organ,16,trumpet 153 | trumpet_section,Organ,16,trumpet 154 | tuba,Organ,16,trumpet 155 | tubular_bells_metal,Piano,0,Like distorted bell ? 156 | tubular_bells_wood,Piano,0,Like distorted bell ? 157 | upright_bass,Bass,33,/ 158 | upright_bass2,Bass,33,/ 159 | upright_piano,Piano,0,/ 160 | viola_ensemble,Strings,48,/ 161 | viola_solo,Strings,48,Like organ ? 162 | violin_ensemble,Strings,48,/ 163 | violin_solo,Strings,48,Like organ ? 164 | woodwind_ensemble_essential,Organ,16,/ 165 | woodwind_quintet_essential,Organ,16,/ 166 | wurly_ep,Piano,0,Like guitar 167 | xylophone,Piano,0,Like distorted bell ? 168 | xylophone_essential,Piano,0,Like distorted bell ? -------------------------------------------------------------------------------- /End2End/dataset_creation/midi_track_group_config_3.csv: -------------------------------------------------------------------------------- 1 | Plugin_name,Closed set,Mapped MIDI program number,Description 2 | AGML2,Piano,0,acoustic guitar 3 | None,None,-2,/ 4 | across_the_pacific,Strings,0,/ 5 | ahoy,Organ,0,/ 6 | alicias_keys,Piano,0,/ 7 | alto_sax_vintage_solo,Organ,0,/ 8 | alto_saxophone,Organ,0,/ 9 | ambibella,Organ,0,Ambient organ with clear onset 10 | april_pan,Strings,0,Ambient organ 11 | ar_modern_sparkle_kit_full,Drums,-1,/ 12 | ar_modern_white_kit_full,Drums,-1,/ 13 | arctic_morning,Strings,0,/ 14 | august_foerster_grand,Piano,0,/ 15 | baritone_sax_vintage_solo,Organ,0,/ 16 | baritone_saxophone,Organ,0,/ 17 | bass_trombone,Organ,0,/ 18 | bassoon,Organ,0,/ 19 | bassoon_combi,Organ,0,/ 20 | bassoon_essential,Organ,0,/ 21 | bassoons_essential,Organ,0,/ 22 | belle_de_jour,Strings,0,/ 23 | brass_quartet_essential,Organ,0,/ 24 | celesta,Piano,0,Llike bell 25 | cello_ensemble,Strings,0,/ 26 | cello_solo,Strings,0,Maybe organ? 27 | cerulean,Strings,0,/ 28 | cheesy_lead,Organ,0,/ 29 | choir_a,Strings,0,/ 30 | choir_e,Strings,0,/ 31 | choir_o,Strings,0,/ 32 | chrystal,Organ,0,/ 33 | clarinet,Organ,0,/ 34 | clarinet_combi,Organ,0,/ 35 | clarinet_essential,Organ,0,/ 36 | clarinets_essential,Organ,0,/ 37 | classic_bass,Bass,33,/ 38 | cold_cave,Strings,0,/ 39 | concert_grand,Piano,0,/ 40 | crawling_lead,Distorted,0,/ 41 | daft,Distorted,0,Like brass 42 | dawn_chorus,Strings,0,/ 43 | december_saw,Distorted,0,Like effects 44 | double_bass_ensemble,Strings,0,It is bass bowing not plucking so still take as strings 45 | double_bass_solo,Strings,0,/ 46 | douglas_lead,Distorted,0,/ 47 | downforce_saw,Distorted,0,Not sure ? 48 | drifting_apart,Organ,0,/ 49 | elektrik_guitar,Distorted,0,/ 50 | english_horn,Organ,0,/ 51 | fever,Organ,0,/ 52 | flugelhorn,Organ,0,/ 53 | flute,Organ,0,/ 54 | flute_essential,Organ,0,/ 55 | flutes_essential,Organ,0,/ 56 | french_oboe,Organ,0,/ 57 | funk_bass,Bass,33,/ 58 | funk_guitar,Distorted,0,Not sure ? 59 | funk_kit,Drums,-1,/ 60 | garage_kit_lite,Drums,-1,/ 61 | glockenspiel,Piano,0,Llike bell may need another category ? 62 | glockenspiel_essential,Piano,0,Llike bell may need another category ? 
63 | grand_piano,Piano,0,/ 64 | guitar_lead,Distorted,0,/ 65 | hard_n_dirty,Distorted,0,/ 66 | harmonic_guitar,Distorted,0,/ 67 | harp,Piano,0,More like guitar 68 | harpsichord,Piano,0,More like guitar 69 | horn_1_essential,Organ,0,/ 70 | horn_2_essential,Organ,0,/ 71 | house_cat,Distorted,0,Not sure ? 72 | hybrid_keys_antique_toy,Piano,0,like bell with kick drum ? 73 | hybrid_keys_concert_marimba,Piano,0,Like bell 74 | hybrid_keys_futurebells,Piano,0,bell (some tracks are distorted) 75 | hybrid_keys_glockenspiel,Piano,0,like bell with kick drum 76 | hybrid_keys_hot_tropics,Piano,0,bell with noise ? 77 | hybrid_keys_tube_vibraphone,Piano,0,like bell with kick drum ? 78 | hybrid_lead,Distorted,0,/ 79 | jazz_guitar,Piano,0,guitar 80 | jazz_guitar2,Piano,0,guitar 81 | jazz_guitar3,Organ,0,More like organ ? 82 | jazz_guitar4,Piano,0,guitar 83 | jazz_organ,Organ,0,/ 84 | jazz_upright,Bass,33,/ 85 | marimba,Piano,0,Maybe bell ? 86 | marimba_essential,Piano,0,Maybe bell ? 87 | musicology,Organ,0,Some are weird 88 | mute_trumpet,Organ,0,/ 89 | muted_trumpet,Organ,0,/ 90 | mystic_lead,Distorted,0,/ 91 | nylon_guitar,Piano,0,/ 92 | nylon_guitar2,Piano,0,/ 93 | oboe,Organ,0,/ 94 | oboe_essential,Organ,0,/ 95 | oboes_essential,Organ,0,/ 96 | organ_kh_floeten_1_manual,Organ,0,/ 97 | organ_kh_grprplenum_manual,Organ,0,/ 98 | outta_space,Distorted,0,/ 99 | percussive_lead,Organ,0,/ 100 | piccolo,Organ,0,/ 101 | pimped_analog_saw,Distorted,0,/ 102 | poly_detuned_lead,Distorted,0,/ 103 | pop_bass,Bass,33,/ 104 | pop_kit,Drums,-1,/ 105 | processor_lead,Distorted,0,/ 106 | ragtime_piano,Piano,0,/ 107 | reamped_lead,Distorted,0,/ 108 | rhythm_rock_guitar,Distorted,0,/ 109 | rock_guitar,Organ,0,Some are distorted 110 | saxophone_essential,Organ,0,/ 111 | saxophone_section,Organ,0,/ 112 | saxophones_essential,Organ,0,/ 113 | scarbee_a_200,Piano,0,guitar 114 | scarbee_clavinet_full,Piano,0,like guitar 115 | scarbee_jay_bass_both,Bass,33,/ 116 | scarbee_jay_bass_slap_both,Bass,33,/ 117 | scarbee_mark_I,Organ,0,/ 118 | scarbee_mm_bass,Bass,33,/ 119 | scarbee_pianet,Piano,0,guitar 120 | scarbee_pre_bass,Bass,33,/ 121 | scarbee_rickenbacker_bass,Bass,33,/ 122 | scarbee_rickenbacker_bass_palm_muted,Bass,33,/ 123 | session_horns_pro_keyswitch_60s_horns,Organ,0,trumpet 124 | session_horns_pro_keyswitch_generic_section,Organ,0,trumpet 125 | session_kit_full,Drums,-1,/ 126 | session_strings_pro_2_ensemble_modern,Strings,0,/ 127 | session_strings_pro_2_ensemble_traditional,Strings,0,/ 128 | solo_guitar,Piano,0,Longer sustain sometimes distorted 129 | solo_strings,Strings,0,Solo maybe like organ? 130 | stadium_kit_full,Drums,-1,/ 131 | street_knowledge_kit,Drums,-1,/ 132 | string_ensemble,Strings,0,/ 133 | string_ensemble_essential,Strings,0,/ 134 | tenor_sax,Organ,0,/ 135 | tenor_saxophone,Organ,0,/ 136 | tenor_trombone,Organ,0,/ 137 | the_gentleman,Piano,0,/ 138 | the_giant_hard_and_tough,Piano,0,/ 139 | the_giant_modern_studio,Piano,0,/ 140 | the_giant_vibrant,Piano,0,/ 141 | the_grandeur,Piano,0,/ 142 | timpani,Bass,33,Unique pitched drum? 143 | tonewheel_organ_b3,Organ,0,/ 144 | tonewheel_organ_c3,Organ,0,/ 145 | tonewheel_organ_m3,Organ,0,/ 146 | transistor_compact,Organ,0,/ 147 | transistor_continental,Organ,0,/ 148 | trombone,Organ,0,trumpet 149 | trombone_section,Organ,0,trumpet 150 | trumpet,Organ,0,trumpet 151 | trumpet_1,Organ,0,trumpet 152 | trumpet_2,Organ,0,trumpet 153 | trumpet_section,Organ,0,trumpet 154 | tuba,Organ,0,trumpet 155 | tubular_bells_metal,Piano,0,Like distorted bell ? 
156 | tubular_bells_wood,Piano,0,Like distorted bell ? 157 | upright_bass,Bass,33,/ 158 | upright_bass2,Bass,33,/ 159 | upright_piano,Piano,0,/ 160 | viola_ensemble,Strings,0,/ 161 | viola_solo,Strings,0,Like organ ? 162 | violin_ensemble,Strings,0,/ 163 | violin_solo,Strings,0,Like organ ? 164 | woodwind_ensemble_essential,Organ,0,/ 165 | woodwind_quintet_essential,Organ,0,/ 166 | wurly_ep,Piano,0,Like guitar 167 | xylophone,Piano,0,Like distorted bell ? 168 | xylophone_essential,Piano,0,Like distorted bell ? -------------------------------------------------------------------------------- /End2End/dataset_creation/midi_track_group_config_4.csv: -------------------------------------------------------------------------------- 1 | Plugin_name,Closed set,Mapped MIDI program number,Description 2 | AGML2,Piano,0,acoustic guitar 3 | None,None,-2,/ 4 | across_the_pacific,Strings,0,/ 5 | ahoy,Organ,0,/ 6 | alicias_keys,Piano,0,/ 7 | alto_sax_vintage_solo,Organ,0,/ 8 | alto_saxophone,Organ,0,/ 9 | ambibella,Organ,0,Ambient organ with clear onset 10 | april_pan,Strings,0,Ambient organ 11 | ar_modern_sparkle_kit_full,Drums,-1,/ 12 | ar_modern_white_kit_full,Drums,-1,/ 13 | arctic_morning,Strings,0,/ 14 | august_foerster_grand,Piano,0,/ 15 | baritone_sax_vintage_solo,Organ,0,/ 16 | baritone_saxophone,Organ,0,/ 17 | bass_trombone,Organ,0,/ 18 | bassoon,Organ,0,/ 19 | bassoon_combi,Organ,0,/ 20 | bassoon_essential,Organ,0,/ 21 | bassoons_essential,Organ,0,/ 22 | belle_de_jour,Strings,0,/ 23 | brass_quartet_essential,Organ,0,/ 24 | celesta,Piano,0,Llike bell 25 | cello_ensemble,Strings,0,/ 26 | cello_solo,Strings,0,Maybe organ? 27 | cerulean,Strings,0,/ 28 | cheesy_lead,Organ,0,/ 29 | choir_a,Strings,0,/ 30 | choir_e,Strings,0,/ 31 | choir_o,Strings,0,/ 32 | chrystal,Organ,0,/ 33 | clarinet,Organ,0,/ 34 | clarinet_combi,Organ,0,/ 35 | clarinet_essential,Organ,0,/ 36 | clarinets_essential,Organ,0,/ 37 | classic_bass,Bass,0,/ 38 | cold_cave,Strings,0,/ 39 | concert_grand,Piano,0,/ 40 | crawling_lead,Distorted,0,/ 41 | daft,Distorted,0,Like brass 42 | dawn_chorus,Strings,0,/ 43 | december_saw,Distorted,0,Like effects 44 | double_bass_ensemble,Strings,0,It is bass bowing not plucking so still take as strings 45 | double_bass_solo,Strings,0,/ 46 | douglas_lead,Distorted,0,/ 47 | downforce_saw,Distorted,0,Not sure ? 48 | drifting_apart,Organ,0,/ 49 | elektrik_guitar,Distorted,0,/ 50 | english_horn,Organ,0,/ 51 | fever,Organ,0,/ 52 | flugelhorn,Organ,0,/ 53 | flute,Organ,0,/ 54 | flute_essential,Organ,0,/ 55 | flutes_essential,Organ,0,/ 56 | french_oboe,Organ,0,/ 57 | funk_bass,Bass,0,/ 58 | funk_guitar,Distorted,0,Not sure ? 59 | funk_kit,Drums,-1,/ 60 | garage_kit_lite,Drums,-1,/ 61 | glockenspiel,Piano,0,Llike bell may need another category ? 62 | glockenspiel_essential,Piano,0,Llike bell may need another category ? 63 | grand_piano,Piano,0,/ 64 | guitar_lead,Distorted,0,/ 65 | hard_n_dirty,Distorted,0,/ 66 | harmonic_guitar,Distorted,0,/ 67 | harp,Piano,0,More like guitar 68 | harpsichord,Piano,0,More like guitar 69 | horn_1_essential,Organ,0,/ 70 | horn_2_essential,Organ,0,/ 71 | house_cat,Distorted,0,Not sure ? 72 | hybrid_keys_antique_toy,Piano,0,like bell with kick drum ? 73 | hybrid_keys_concert_marimba,Piano,0,Like bell 74 | hybrid_keys_futurebells,Piano,0,bell (some tracks are distorted) 75 | hybrid_keys_glockenspiel,Piano,0,like bell with kick drum 76 | hybrid_keys_hot_tropics,Piano,0,bell with noise ? 77 | hybrid_keys_tube_vibraphone,Piano,0,like bell with kick drum ? 
78 | hybrid_lead,Distorted,0,/ 79 | jazz_guitar,Piano,0,guitar 80 | jazz_guitar2,Piano,0,guitar 81 | jazz_guitar3,Organ,0,More like organ ? 82 | jazz_guitar4,Piano,0,guitar 83 | jazz_organ,Organ,0,/ 84 | jazz_upright,Bass,0,/ 85 | marimba,Piano,0,Maybe bell ? 86 | marimba_essential,Piano,0,Maybe bell ? 87 | musicology,Organ,0,Some are weird 88 | mute_trumpet,Organ,0,/ 89 | muted_trumpet,Organ,0,/ 90 | mystic_lead,Distorted,0,/ 91 | nylon_guitar,Piano,0,/ 92 | nylon_guitar2,Piano,0,/ 93 | oboe,Organ,0,/ 94 | oboe_essential,Organ,0,/ 95 | oboes_essential,Organ,0,/ 96 | organ_kh_floeten_1_manual,Organ,0,/ 97 | organ_kh_grprplenum_manual,Organ,0,/ 98 | outta_space,Distorted,0,/ 99 | percussive_lead,Organ,0,/ 100 | piccolo,Organ,0,/ 101 | pimped_analog_saw,Distorted,0,/ 102 | poly_detuned_lead,Distorted,0,/ 103 | pop_bass,Bass,0,/ 104 | pop_kit,Drums,-1,/ 105 | processor_lead,Distorted,0,/ 106 | ragtime_piano,Piano,0,/ 107 | reamped_lead,Distorted,0,/ 108 | rhythm_rock_guitar,Distorted,0,/ 109 | rock_guitar,Organ,0,Some are distorted 110 | saxophone_essential,Organ,0,/ 111 | saxophone_section,Organ,0,/ 112 | saxophones_essential,Organ,0,/ 113 | scarbee_a_200,Piano,0,guitar 114 | scarbee_clavinet_full,Piano,0,like guitar 115 | scarbee_jay_bass_both,Bass,0,/ 116 | scarbee_jay_bass_slap_both,Bass,0,/ 117 | scarbee_mark_I,Organ,0,/ 118 | scarbee_mm_bass,Bass,0,/ 119 | scarbee_pianet,Piano,0,guitar 120 | scarbee_pre_bass,Bass,0,/ 121 | scarbee_rickenbacker_bass,Bass,0,/ 122 | scarbee_rickenbacker_bass_palm_muted,Bass,0,/ 123 | session_horns_pro_keyswitch_60s_horns,Organ,0,trumpet 124 | session_horns_pro_keyswitch_generic_section,Organ,0,trumpet 125 | session_kit_full,Drums,-1,/ 126 | session_strings_pro_2_ensemble_modern,Strings,0,/ 127 | session_strings_pro_2_ensemble_traditional,Strings,0,/ 128 | solo_guitar,Piano,0,Longer sustain sometimes distorted 129 | solo_strings,Strings,0,Solo maybe like organ? 130 | stadium_kit_full,Drums,-1,/ 131 | street_knowledge_kit,Drums,-1,/ 132 | string_ensemble,Strings,0,/ 133 | string_ensemble_essential,Strings,0,/ 134 | tenor_sax,Organ,0,/ 135 | tenor_saxophone,Organ,0,/ 136 | tenor_trombone,Organ,0,/ 137 | the_gentleman,Piano,0,/ 138 | the_giant_hard_and_tough,Piano,0,/ 139 | the_giant_modern_studio,Piano,0,/ 140 | the_giant_vibrant,Piano,0,/ 141 | the_grandeur,Piano,0,/ 142 | timpani,Bass,0,Unique pitched drum? 143 | tonewheel_organ_b3,Organ,0,/ 144 | tonewheel_organ_c3,Organ,0,/ 145 | tonewheel_organ_m3,Organ,0,/ 146 | transistor_compact,Organ,0,/ 147 | transistor_continental,Organ,0,/ 148 | trombone,Organ,0,trumpet 149 | trombone_section,Organ,0,trumpet 150 | trumpet,Organ,0,trumpet 151 | trumpet_1,Organ,0,trumpet 152 | trumpet_2,Organ,0,trumpet 153 | trumpet_section,Organ,0,trumpet 154 | tuba,Organ,0,trumpet 155 | tubular_bells_metal,Piano,0,Like distorted bell ? 156 | tubular_bells_wood,Piano,0,Like distorted bell ? 157 | upright_bass,Bass,0,/ 158 | upright_bass2,Bass,0,/ 159 | upright_piano,Piano,0,/ 160 | viola_ensemble,Strings,0,/ 161 | viola_solo,Strings,0,Like organ ? 162 | violin_ensemble,Strings,0,/ 163 | violin_solo,Strings,0,Like organ ? 164 | woodwind_ensemble_essential,Organ,0,/ 165 | woodwind_quintet_essential,Organ,0,/ 166 | wurly_ep,Piano,0,Like guitar 167 | xylophone,Piano,0,Like distorted bell ? 168 | xylophone_essential,Piano,0,Like distorted bell ? 
-------------------------------------------------------------------------------- /End2End/dataset_creation/mixing_secrets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinWaiCheuk/Jointist/1846975ec904c8b23ea653b9b7881f343574e7a6/End2End/dataset_creation/mixing_secrets/__init__.py -------------------------------------------------------------------------------- /End2End/dataset_creation/mixing_secrets/segment_vocal_stems.py: -------------------------------------------------------------------------------- 1 | """Run this file after running `unzip_vocals_only.py` in mixing-secrets repo 2 | """ 3 | import os 4 | import tqdm 5 | import argparse 6 | import tempfile 7 | import multiprocessing 8 | 9 | import numpy as np 10 | import auditok 11 | import librosa 12 | import soundfile as sf 13 | 14 | TARGET_SR = 16000 15 | 16 | 17 | def _process_one(args): 18 | """load audio, resample and save it again to a tempo dir, split it.""" 19 | source_path, target_path, wav_fn = args 20 | with tempfile.TemporaryDirectory() as temp_dir: 21 | src, sr = librosa.load(os.path.join(source_path, wav_fn), sr=TARGET_SR, mono=True, dtype=np.float32) 22 | 23 | src = np.expand_dims(src, 1) # time. ch for pysoundfile 24 | 25 | resampled_audio_path = os.path.join(temp_dir, 'audio.wav') 26 | sf.write(resampled_audio_path, src, sr, subtype='PCM_16') 27 | 28 | audio_regions = auditok.split( 29 | resampled_audio_path, 30 | max_dur=10.0, 31 | max_silence=1.0, 32 | sampling_rate=sr, 33 | channels=1, 34 | sample_width=1, 35 | drop_trailing_silence=True, 36 | analysis_window=1.0, 37 | ) 38 | for i, r in enumerate(audio_regions): 39 | filename = r.save( 40 | os.path.join( 41 | target_path, wav_fn.replace(' ', '-').replace('.wav', '') + "_{meta.start:.3f}-{meta.end:.3f}.wav" 42 | ) 43 | ) 44 | # print("region saved as: {}".format(filename)) 45 | 46 | 47 | def main(source_path, target_path): 48 | os.makedirs(target_path, exist_ok=True) 49 | 50 | wav_fns = [f for f in os.listdir(source_path) if f.endswith('wav')] 51 | 52 | pool = multiprocessing.Pool(multiprocessing.cpu_count()) 53 | 54 | for _ in tqdm.tqdm( 55 | pool.imap_unordered(_process_one, [(source_path, target_path, wav_fn) for wav_fn in wav_fns]), 56 | total=len(wav_fns), 57 | ): 58 | pass 59 | 60 | pool.close() 61 | pool.join() 62 | print('done.') 63 | 64 | 65 | if __name__ == '__main__': 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('--source_path', type=str, required=True, help='path that has all vocal stems') 68 | parser.add_argument('--target_path', type=str, required=True, help='path to save segmented vocal stems') 69 | 70 | args = parser.parse_args() 71 | 72 | main(args.source_path, args.target_path) 73 | -------------------------------------------------------------------------------- /End2End/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | def get_lr_lambda(step, warm_up_steps, reduce_lr_steps): 2 | r"""Get lr_lambda for LambdaLR. E.g., 3 | 4 | .. 
code-block: python 5 | lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000) 6 | 7 | from torch.optim.lr_scheduler import LambdaLR 8 | LambdaLR(optimizer, lr_lambda) 9 | """ 10 | if step <= warm_up_steps: 11 | return step / warm_up_steps 12 | else: 13 | return 0.9 ** (step // reduce_lr_steps) 14 | -------------------------------------------------------------------------------- /End2End/models/instrument_detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .detr import Cnn14Transformer, Cnn14_DETR_transformer 2 | # from .backbones import CNN14 -------------------------------------------------------------------------------- /End2End/models/instrument_detection/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def summarized_output(segmentwise_output, threshold=0.5): 5 | """ 6 | input: 7 | segmentwise_output: (N, output_classes) 8 | output: 9 | (output_classes) 10 | """ 11 | 12 | bool_pred = torch.sigmoid(segmentwise_output)>threshold 13 | bool_summarized = torch.zeros(bool_pred.shape[1]).to(bool_pred.device) 14 | 15 | for i in bool_pred: 16 | bool_summarized = torch.logical_or(bool_summarized, i) 17 | 18 | return bool_summarized 19 | 20 | 21 | def obtain_segments(audio, segment_samples): 22 | # Preparing placeholders for audio segmenting 23 | audio_length = audio.shape[1] 24 | # Pad audio to be evenly divided by segment_samples. 25 | pad_len = int(np.ceil(audio_length / segment_samples)) * segment_samples - audio_length 26 | 27 | if audio_length>segment_samples: 28 | audio = torch.cat((audio, torch.zeros((1, pad_len), device=audio.device)), axis=1) 29 | 30 | # Enframe to segments. 31 | segments = audio.unfold(1, segment_samples, segment_samples//2).squeeze(0) # faster version of enframe 32 | # (N, segment_samples) 33 | return segments, audio_length 34 | else: 35 | return audio, audio_length -------------------------------------------------------------------------------- /End2End/models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class PositionEmbeddingSine(nn.Module): 11 | """ 12 | This is a more standard version of the position embedding, very similar to the one 13 | used by the Attention is all you need paper, generalized to work on images. 
14 | """ 15 | def __init__(self, temperature=10000, normalize=False, scale=None): 16 | super().__init__() 17 | self.temperature = temperature 18 | self.normalize = normalize 19 | if scale is not None and normalize is False: 20 | raise ValueError("normalize should be True if scale is passed") 21 | if scale is None: 22 | scale = 2 * math.pi 23 | self.scale = scale 24 | 25 | def forward(self, x): 26 | # x here is only used to get the shape 27 | # The actual values of x is never used 28 | B, C, H, W = x.shape 29 | num_pos_feats = C//2 30 | mask = torch.ones(B, H, W).to(x.device) 31 | y_embed = mask.cumsum(1, dtype=torch.float32) 32 | x_embed = mask.cumsum(2, dtype=torch.float32) 33 | if self.normalize: 34 | eps = 1e-6 35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=x.device) 39 | dim_t = self.temperature ** (2 * (torch.div(dim_t,2, rounding_mode='trunc')) / num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, :, None] / dim_t 42 | pos_y = y_embed[:, :, :, None] / dim_t 43 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 44 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 46 | return pos 47 | 48 | 49 | class PositionEmbeddingSinev2(nn.Module): 50 | """ 51 | This is a more standard version of the position embedding, very similar to the one 52 | used by the Attention is all you need paper, generalized to work on images. 53 | """ 54 | def __init__(self, temperature=10000, normalize=False, scale=None): 55 | super().__init__() 56 | self.temperature = temperature 57 | self.normalize = normalize 58 | if scale is not None and normalize is False: 59 | raise ValueError("normalize should be True if scale is passed") 60 | if scale is None: 61 | scale = 2 * math.pi 62 | self.scale = scale 63 | 64 | def forward(self, x): 65 | # x here is only used to get the shape 66 | # The actual values of x is never used 67 | B, C, T = x.shape 68 | num_pos_feats = C 69 | mask = torch.ones(B, T).to(x.device) 70 | y_embed = mask.cumsum(1, dtype=torch.float32) 71 | if self.normalize: 72 | eps = 1e-6 73 | y_embed = y_embed / (y_embed[:, -1:] + eps) * self.scale 74 | 75 | dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=x.device) 76 | dim_t = self.temperature ** (2 * (torch.div(dim_t,2, rounding_mode='trunc')) / num_pos_feats) 77 | 78 | pos_y = y_embed[:, :, None] / dim_t 79 | pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) 80 | pos = pos_y.permute(0, 2, 1) 81 | return pos 82 | 83 | 84 | class PositionEmbeddingLearned(nn.Module): 85 | """ 86 | Absolute pos embedding, learned. 
87 | """ 88 | def __init__(self, num_pos_feats=256): 89 | super().__init__() 90 | self.row_embed = nn.Embedding(50, num_pos_feats) 91 | self.col_embed = nn.Embedding(50, num_pos_feats) 92 | self.reset_parameters() 93 | 94 | def reset_parameters(self): 95 | nn.init.uniform_(self.row_embed.weight) 96 | nn.init.uniform_(self.col_embed.weight) 97 | 98 | def forward(self, x): 99 | h, w = x.shape[-2:] 100 | i = torch.arange(w, device=x.device) 101 | j = torch.arange(h, device=x.device) 102 | x_emb = self.col_embed(i) 103 | y_emb = self.row_embed(j) 104 | pos = torch.cat([ 105 | x_emb.unsqueeze(0).repeat(h, 1, 1), 106 | y_emb.unsqueeze(1).repeat(1, w, 1), 107 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 108 | return pos 109 | 110 | 111 | def build_position_encoding(args): 112 | N_steps = args.hidden_dim // 2 113 | if args.position_embedding in ('v2', 'sine'): 114 | # TODO find a better way of exposing other arguments 115 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 116 | elif args.position_embedding in ('v3', 'learned'): 117 | position_embedding = PositionEmbeddingLearned(N_steps) 118 | else: 119 | raise ValueError(f"not supported {args.position_embedding}") 120 | 121 | return position_embedding -------------------------------------------------------------------------------- /End2End/models/separation/__init__.py: -------------------------------------------------------------------------------- 1 | from .cond_unet import CondUNet 2 | from .t_cond_unet import TCondUNet -------------------------------------------------------------------------------- /End2End/models/separation/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import sys 6 | 7 | 8 | def magphase(real, img): 9 | phase = torch.atan2(img, real) 10 | return phase, torch.cos(phase), torch.sin(phase) 11 | 12 | def init_layer(layer): 13 | """Initialize a Linear or Convolutional layer. """ 14 | nn.init.xavier_uniform_(layer.weight) 15 | 16 | if hasattr(layer, "bias"): 17 | if layer.bias is not None: 18 | layer.bias.data.fill_(0.0) 19 | 20 | 21 | def init_bn(bn): 22 | """Initialize a Batchnorm layer. """ 23 | bn.bias.data.fill_(0.0) 24 | bn.weight.data.fill_(1.0) 25 | 26 | 27 | def init_embedding(layer): 28 | """Initialize a Linear or Convolutional layer. """ 29 | nn.init.uniform_(layer.weight, -1., 1.) 30 | 31 | if hasattr(layer, 'bias'): 32 | if layer.bias is not None: 33 | layer.bias.data.fill_(0.) 34 | 35 | 36 | def init_gru(rnn): 37 | """Initialize a GRU layer. 
""" 38 | 39 | def _concat_init(tensor, init_funcs): 40 | (length, fan_out) = tensor.shape 41 | fan_in = length // len(init_funcs) 42 | 43 | for (i, init_func) in enumerate(init_funcs): 44 | init_func(tensor[i * fan_in : (i + 1) * fan_in, :]) 45 | 46 | def _inner_uniform(tensor): 47 | fan_in = nn.init._calculate_correct_fan(tensor, "fan_in") 48 | nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in)) 49 | 50 | for i in range(rnn.num_layers): 51 | _concat_init( 52 | getattr(rnn, "weight_ih_l{}".format(i)), 53 | [_inner_uniform, _inner_uniform, _inner_uniform], 54 | ) 55 | torch.nn.init.constant_(getattr(rnn, "bias_ih_l{}".format(i)), 0) 56 | 57 | _concat_init( 58 | getattr(rnn, "weight_hh_l{}".format(i)), 59 | [_inner_uniform, _inner_uniform, nn.init.orthogonal_], 60 | ) 61 | torch.nn.init.constant_(getattr(rnn, "bias_hh_l{}".format(i)), 0) 62 | 63 | 64 | def act(x, activation): 65 | if activation == "relu": 66 | return F.relu_(x) 67 | 68 | elif activation == "leaky_relu": 69 | return F.leaky_relu_(x, negative_slope=0.01) 70 | 71 | elif activation == "swish": 72 | return x * torch.sigmoid(x) 73 | 74 | else: 75 | raise Exception("Incorrect activation!") 76 | 77 | 78 | class Base: 79 | def __init__(self): 80 | pass 81 | 82 | def spectrogram(self, input, eps=0.): 83 | (real, imag) = self.stft(input) 84 | return torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5 85 | 86 | def spectrogram_phase(self, input, eps=0.): 87 | (real, imag) = self.stft(input) 88 | mag = torch.clamp(real ** 2 + imag ** 2, eps, np.inf) ** 0.5 89 | cos = real / mag 90 | sin = imag / mag 91 | return mag, cos, sin 92 | 93 | 94 | def spectrogram_to_wav(self, input, spectrogram, length=None): 95 | """Spectrogram to waveform. 96 | 97 | Args: 98 | input: (batch_size, segment_samples, channels_num) 99 | spectrogram: (batch_size, channels_num, time_steps, freq_bins) 100 | 101 | Outputs: 102 | output: (batch_size, segment_samples, channels_num) 103 | """ 104 | assert input.shape[1]==1, "Current model only supports mono audio" 105 | wav_list = [] 106 | # (real, imag) = self.stft(input[:, channel, :]) 107 | spec = self.stft(input.squeeze(1)).transpose(-1,-2) 108 | # spectrogram.shape=torch.Size([4, 1, 501, 513]) 109 | real = spec.real 110 | imag = spec.imag 111 | (_, cos, sin) = magphase(real, imag) 112 | 113 | recon_spec = spectrogram.squeeze(1)*cos +\ 114 | spectrogram.squeeze(1)*sin*1j 115 | 116 | output = self.istft(recon_spec.transpose(-1,-2), length) 117 | return output.unsqueeze(1) # make it (B, 1, T, F) to match previous setting -------------------------------------------------------------------------------- /End2End/models/transcription/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinWaiCheuk/Jointist/1846975ec904c8b23ea653b9b7881f343574e7a6/End2End/models/transcription/__init__.py -------------------------------------------------------------------------------- /End2End/models/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn as nn 5 | 6 | 7 | def init_layer(layer): 8 | r"""Initialize a Linear or Convolutional layer.""" 9 | nn.init.xavier_uniform_(layer.weight) 10 | 11 | if hasattr(layer, 'bias'): 12 | if layer.bias is not None: 13 | layer.bias.data.fill_(0.0) 14 | 15 | 16 | def init_bn(bn): 17 | r"""Initialize a Batchnorm layer.""" 18 | bn.bias.data.fill_(0.0) 19 | bn.weight.data.fill_(1.0) 20 | 21 | 22 | def 
init_gru(rnn): 23 | r"""Initialize a GRU layer.""" 24 | 25 | def _concat_init(tensor, init_funcs): 26 | (length, fan_out) = tensor.shape 27 | fan_in = length // len(init_funcs) 28 | 29 | for (i, init_func) in enumerate(init_funcs): 30 | init_func(tensor[i * fan_in : (i + 1) * fan_in, :]) 31 | 32 | def _inner_uniform(tensor): 33 | fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in') 34 | nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in)) 35 | 36 | for i in range(rnn.num_layers): 37 | _concat_init( 38 | getattr(rnn, 'weight_ih_l{}'.format(i)), 39 | [_inner_uniform, _inner_uniform, _inner_uniform], 40 | ) 41 | torch.nn.init.constant_(getattr(rnn, 'bias_ih_l{}'.format(i)), 0) 42 | 43 | _concat_init( 44 | getattr(rnn, 'weight_hh_l{}'.format(i)), 45 | [_inner_uniform, _inner_uniform, nn.init.orthogonal_], 46 | ) 47 | torch.nn.init.constant_(getattr(rnn, 'bias_hh_l{}'.format(i)), 0) 48 | 49 | 50 | class Normalization(): 51 | """This class is for normalizing the spectrograms batch by batch. The normalization used is min-max, two modes 'framewise' and 'imagewise' can be selected. In this paper, we found that 'imagewise' normalization works better than 'framewise'""" 52 | def __init__(self, mode='framewise'): 53 | if mode == 'framewise': 54 | def normalize(x): 55 | size = x.shape 56 | x_max = x.max(1, keepdim=True)[0] # Finding max values for each frame 57 | x_min = x.min(1, keepdim=True)[0] 58 | output = (x-x_min)/(x_max-x_min) # If there is a column with all zero, nan will occur 59 | output[torch.isnan(output)]=0 # Making nan to 0 60 | return output 61 | elif mode == 'imagewise': 62 | def normalize(x): 63 | size = x.shape 64 | x_max = x.reshape(size[0], size[1]*size[2]).max(1, keepdim=True)[0] 65 | x_min = x.reshape(size[0], size[1]*size[2]).min(1, keepdim=True)[0] 66 | x_max = x_max.unsqueeze(1) # Make it broadcastable 67 | x_min = x_min.unsqueeze(1) # Make it broadcastable 68 | return (x-x_min)/(x_max-x_min) 69 | else: 70 | print(f'please choose the correct mode') 71 | self.normalize = normalize 72 | 73 | def transform(self, x): 74 | return self.normalize(x) 75 | 76 | def __call__(self, x): 77 | return self.transform(x) -------------------------------------------------------------------------------- /End2End/notes.txt: -------------------------------------------------------------------------------- 1 | When using plugin_names 2 | Max instruments for training = 13 3 | Max instruments for val = 10 4 | 5 | When using MIDI programs 6 | Max instruments for training = 12 7 | Max instruments for val = 11 -------------------------------------------------------------------------------- /End2End/samplers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pathlib 4 | from typing import List 5 | 6 | import pickle 7 | import numpy as np 8 | import h5py 9 | from pytorch_lightning.utilities import rank_zero_only 10 | import torch.distributed as dist 11 | 12 | 13 | class End2EndSegmentSampler: 14 | def __init__( 15 | self, 16 | hdf5s_dir: str, 17 | split: str, 18 | segment_seconds: float, 19 | hop_seconds: float, 20 | batch_size: int, 21 | steps_per_epoch: int, 22 | evaluation: bool, 23 | max_evaluation_steps: int = -1, 24 | random_seed: int = 1234, 25 | mini_data: bool = False, 26 | ): 27 | r"""Sampler is used to sample segments for training or evaluation. 
28 | 29 | Args: 30 | hdf5s_dir: str 31 | split: 'train' | 'validation' | 'test' 32 | segment_seconds: float, e.g., 10.0 33 | hop_seconds: float, e.g., 1.0 34 | batch_size: int, e.g., 16 35 | evaluation: bool, set to True in training, and False in evaluation 36 | max_evaluation_steps: only activate when evaluation=True 37 | random_seed: int 38 | mini_data: bool, sample a small amount of data for debugging 39 | """ 40 | assert split in ['train', 'validation', 'test'] 41 | 42 | self.segment_seconds = segment_seconds 43 | self.hop_seconds = hop_seconds 44 | self.batch_size = batch_size 45 | self.steps_per_epoch = steps_per_epoch 46 | self.evaluation = evaluation 47 | self.max_evaluation_steps = max_evaluation_steps 48 | 49 | # paths 50 | split_hdf5s_dir = os.path.join(hdf5s_dir, split) 51 | 52 | # Traverse directory 53 | hdf5_paths = sorted([str(path) for path in pathlib.Path(split_hdf5s_dir).rglob('*.h5')]) 54 | 55 | self.segment_list = [] 56 | 57 | n = 0 58 | for hdf5_path in hdf5_paths: 59 | try: 60 | with h5py.File(hdf5_path, 'r') as hf: 61 | if hf.attrs['split'].decode() == split: 62 | 63 | audio_name = '{}.h5'.format(os.path.splitext(hf.attrs['audio_name'])[0].decode()) 64 | self.segment_list.append([hf.attrs['split'].decode(), audio_name,]) 65 | except: 66 | from IPython import embed; embed(using=False); os._exit(0) 67 | # self.segment_list looks like: 68 | # [['train', 'Track01122.h5', 0], 69 | # ['train', 'Track01122.h5', 1.0], 70 | # ['train', 'Track01122.h5', 2.0], 71 | # ...] 72 | 73 | if evaluation: 74 | logging.info('Mini-batches for evaluating {} set: {}'.format(split, max_evaluation_steps)) 75 | 76 | else: 77 | logging.info('Training segments: {}'.format(len(self.segment_list))) 78 | 79 | self.pointer = 0 80 | self.segment_indexes = np.arange(len(self.segment_list)) 81 | 82 | if len(self.segment_indexes) == 0: 83 | error_msg = 'No training data found in {}! Please set up your workspace and data path properly!'.format(split_hdf5s_dir) 84 | raise Exception(error_msg) 85 | 86 | 87 | # Both training and evaluation shuffle segment_indexes in the begining. 88 | self.random_state = np.random.RandomState(random_seed) 89 | self.random_state.shuffle(self.segment_indexes) 90 | 91 | def __iter__(self): 92 | r"""Get batch meta. 93 | 94 | Returns: 95 | batch_segment_list: list of list, e.g., 96 | [['train', 'Track00255.h5', 4.0], 97 | ['train', 'Track00894.h5', 53.0], 98 | ['train', 'Track01422.h5', 77.0], 99 | ...] 100 | """ 101 | if self.evaluation: 102 | return self.iter_eval() 103 | else: 104 | return self.iter_train() 105 | 106 | def iter_train(self): 107 | r"""Get batch meta for training. 108 | 109 | Returns: 110 | batch_segment_list: list of list, e.g., 111 | [['train', 'Track00255.h5', 4.0], 112 | ['train', 'Track00894.h5', 53.0], 113 | ['train', 'Track01422.h5', 77.0], 114 | ...] 115 | """ 116 | while True: 117 | batch_segment_list = [] 118 | i = 0 119 | while i < self.batch_size: 120 | index = self.segment_indexes[self.pointer] 121 | self.pointer += 1 122 | 123 | if self.pointer >= len(self.segment_indexes): 124 | self.random_state.shuffle(self.segment_indexes) 125 | self.pointer = 0 126 | 127 | batch_segment_list.append(self.segment_list[index]) 128 | i += 1 129 | 130 | yield batch_segment_list 131 | 132 | def iter_eval(self): 133 | r"""Get batch meta for evaluation. 134 | 135 | Returns: 136 | batch_segment_list: list of list, e.g., 137 | [['train', 'Track00255.h5', 4.0], 138 | ['train', 'Track00894.h5', 53.0], 139 | ['train', 'Track01422.h5', 77.0], 140 | ...] 
141 | """ 142 | _pointer = 0 143 | _steps = 0 144 | 145 | while _pointer < len(self.segment_indexes): 146 | 147 | if _steps == self.max_evaluation_steps: 148 | break 149 | 150 | batch_segment_list = [] 151 | i = 0 152 | while i < self.batch_size: 153 | index = self.segment_indexes[_pointer] 154 | _pointer += 1 155 | 156 | if _pointer >= len(self.segment_indexes): 157 | break 158 | 159 | batch_segment_list.append(self.segment_list[index]) 160 | i += 1 161 | 162 | _steps += 1 163 | 164 | yield batch_segment_list 165 | 166 | def __len__(self): 167 | return self.steps_per_epoch 168 | 169 | def state_dict(self): 170 | state = {'pointer': self.pointer, 'segment_indexes': self.segment_indexes} 171 | return state 172 | 173 | def load_state_dict(self, state): 174 | self.pointer = state['pointer'] 175 | self.segment_indexes = state['segment_indexes'] 176 | 177 | @rank_zero_only 178 | def log(self, str): 179 | logging.info(str) 180 | 181 | 182 | 183 | -------------------------------------------------------------------------------- /End2End/slakh_instruments.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinWaiCheuk/Jointist/1846975ec904c8b23ea653b9b7881f343574e7a6/End2End/slakh_instruments.pkl -------------------------------------------------------------------------------- /End2End/tasks/detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .binary import Binary, BinaryOpenMic 2 | from .hungarian import Hungarian 3 | from .hungarian_autoregressive import Hungarian_Autoregressive 4 | from .softmax_autoregressive import Softmax, SoftmaxOpenMic 5 | from .linear import Linear -------------------------------------------------------------------------------- /End2End/tasks/jointist.py: -------------------------------------------------------------------------------- 1 | """ 2 | Instrument Recognition + 3 | Transcription 4 | """ 5 | 6 | import torch 7 | from torch import nn as nn, optim as optim 8 | from torch.optim.lr_scheduler import LambdaLR, MultiStepLR 9 | import pytorch_lightning as pl 10 | 11 | from omegaconf import OmegaConf 12 | import pandas as pd 13 | 14 | class Jointist(pl.LightningModule): 15 | def __init__( 16 | self, 17 | detection_model: pl.LightningModule, 18 | transcription_model: pl.LightningModule, 19 | lr_lambda, 20 | cfg 21 | ): 22 | r"""Pytorch Lightning wrapper of PyTorch model, including forward, 23 | optimization of model, etc. 
24 | 25 | Args: 26 | network: nn.Module 27 | loss_function: func 28 | learning_rate, float, e.g., 1e-3 29 | lr_lambda: func 30 | """ 31 | super().__init__() 32 | self.detection_model = detection_model 33 | self.transcription_model = transcription_model 34 | self.lr_lambda = lr_lambda 35 | self.cfg = cfg 36 | 37 | def training_step(self, batch, batch_idx): 38 | detection_loss = self.detection_model.training_step(batch, batch_idx, self) 39 | transcription_loss = self.transcription_model.training_step(batch, batch_idx, self) 40 | 41 | total_loss = transcription_loss+detection_loss 42 | self.log('Total_Loss/Train', total_loss, on_step=False, on_epoch=True) 43 | 44 | return total_loss 45 | 46 | 47 | def validation_step(self, batch, batch_idx): 48 | outputs, detection_loss = self.detection_model.validation_step(batch, batch_idx, self) 49 | transcription_loss = self.transcription_model.validation_step(batch, batch_idx, self) 50 | total_loss = transcription_loss+detection_loss 51 | self.log('Total_Loss/Valid', total_loss, on_step=False, on_epoch=True) 52 | 53 | return outputs, detection_loss 54 | 55 | def validation_epoch_end(self, outputs): 56 | detection_loss = self.detection_model.validation_epoch_end(outputs, self) 57 | 58 | def test_step(self, batch, batch_idx): 59 | plugin_idxs = self.detection_model.test_step(batch, batch_idx, self) 60 | return self.transcription_model.test_step(batch, batch_idx, plugin_idxs, self) 61 | 62 | def test_epoch_end(self, outputs): 63 | self.transcription_model.test_epoch_end(outputs, self) 64 | 65 | 66 | 67 | def predict_step(self, batch, batch_idx): 68 | plugin_idxs = self.detection_model.predict_step(batch, batch_idx) 69 | self.transcription_model.predict_step(batch, batch_idx, plugin_idxs) 70 | 71 | return plugin_idxs 72 | 73 | def configure_optimizers(self): 74 | r"""Configure optimizer.""" 75 | optimizer = optim.Adam( 76 | list(self.transcription_model.parameters()) + list(self.detection_model.parameters()), 77 | **self.cfg.detection.model.optimizer, 78 | ) 79 | 80 | if self.cfg.scheduler.type=="MultiStepLR": 81 | scheduler = { 82 | 'scheduler': MultiStepLR(optimizer, 83 | milestones=list(self.cfg.scheduler.milestones), 84 | gamma=self.cfg.scheduler.gamma), 85 | 'interval': 'epoch', 86 | 'frequency': 1, 87 | } 88 | elif self.cfg.scheduler.type=="LambdaLR": 89 | scheduler = { 90 | 'scheduler': LambdaLR(optimizer, self.lr_lambda), 91 | 'interval': 'step', 92 | 'frequency': 1, 93 | } 94 | 95 | 96 | return [optimizer], [scheduler] -------------------------------------------------------------------------------- /End2End/tasks/jointist_ss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Instrument Recognition + 3 | Transcription + 4 | Music Source Separation 5 | """ 6 | 7 | import torch 8 | from torch import nn as nn, optim as optim 9 | from torch.optim.lr_scheduler import LambdaLR, MultiStepLR 10 | import pytorch_lightning as pl 11 | 12 | from omegaconf import OmegaConf 13 | import pandas as pd 14 | 15 | class Jointist_SS(pl.LightningModule): 16 | def __init__( 17 | self, 18 | detection_model: pl.LightningModule, 19 | tseparation_model: pl.LightningModule, 20 | lr_lambda, 21 | cfg 22 | ): 23 | r"""Pytorch Lightning wrapper of PyTorch model, including forward, 24 | optimization of model, etc. 
25 | 26 | Args: 27 | network: nn.Module 28 | loss_function: func 29 | learning_rate, float, e.g., 1e-3 30 | lr_lambda: func 31 | """ 32 | super().__init__() 33 | self.detection_model = detection_model 34 | self.tseparation_model = tseparation_model 35 | self.lr_lambda = lr_lambda 36 | self.cfg = cfg 37 | 38 | 39 | def predict_step(self, batch, batch_idx): 40 | plugin_idxs = self.detection_model.predict_step(batch, batch_idx) 41 | self.tseparation_model.predict_step(batch, batch_idx, plugin_idxs) 42 | 43 | 44 | def configure_optimizers(self): 45 | r"""Configure optimizer.""" 46 | optimizer = optim.Adam( 47 | list(self.transcription_model.parameters()) + list(self.detection_model.parameters()), 48 | **self.cfg.detection.model.optimizer, 49 | ) 50 | 51 | if self.cfg.scheduler.type=="MultiStepLR": 52 | scheduler = { 53 | 'scheduler': MultiStepLR(optimizer, 54 | milestones=list(self.cfg.scheduler.milestones), 55 | gamma=self.cfg.scheduler.gamma), 56 | 'interval': 'epoch', 57 | 'frequency': 1, 58 | } 59 | elif self.cfg.scheduler.type=="LambdaLR": 60 | scheduler = { 61 | 'scheduler': LambdaLR(optimizer, self.lr_lambda), 62 | 'interval': 'step', 63 | 'frequency': 1, 64 | } 65 | 66 | 67 | return [optimizer], [scheduler] -------------------------------------------------------------------------------- /End2End/tasks/separation/__init__.py: -------------------------------------------------------------------------------- 1 | from .separation import Separation -------------------------------------------------------------------------------- /End2End/tasks/separation/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | from matplotlib.ticker import (MultipleLocator, AutoMinorLocator) 4 | import numpy as np 5 | 6 | def calculate_sdr(ref, est): 7 | assert ref.dim()==est.dim(), f"ref {ref.shape} has a different size than est {est.shape}" 8 | 9 | s_true = ref 10 | s_artif = est - ref 11 | 12 | sdr = 10. 
* ( 13 | torch.log10(torch.clip(torch.mean(s_true ** 2, 1), 1e-8, torch.inf)) \ 14 | - torch.log10(torch.clip(torch.mean(s_artif ** 2, 1), 1e-8, torch.inf))) 15 | return sdr 16 | 17 | 18 | 19 | def _append_to_dict(dict, key, value): 20 | if key in dict.keys(): 21 | dict[key].append(value) 22 | else: 23 | dict[key] = [value] 24 | 25 | 26 | def barplot(stat_mean, title="Untitles", figsize=(4,24)): 27 | # stat_mean = collections.OrderedDict(sorted(stat_mean.items())) 28 | stat_mean = {k: v for k, v in sorted(stat_mean.items(), key=lambda item: item[1])} 29 | fig, ax = plt.subplots(1,1, figsize=figsize) 30 | xlabels = list(stat_mean.keys()) 31 | values = list(stat_mean.values()) 32 | ax.barh(xlabels, values, color='cyan') 33 | global_mean = sum(stat_mean.values())/len(stat_mean.values()) 34 | ax.vlines(global_mean, 0, len(stat_mean), 'r') 35 | ax.tick_params(labeltop=True, labelright=False) 36 | ax.set_ylim([-1,len(xlabels)]) 37 | ax.set_title(title) 38 | ax.grid(axis='x') 39 | ax.grid(b=True, which='minor', linestyle='--') 40 | 41 | # move the left boundary to origin 42 | ax.spines['left'].set_position('zero') 43 | # turn off the RHS boundary 44 | ax.spines['right'].set_color('none') 45 | 46 | fig.savefig(f'{title}.png', bbox_inches='tight') 47 | 48 | return global_mean, fig -------------------------------------------------------------------------------- /End2End/tasks/t_separation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn as nn, optim as optim 4 | from torch.optim.lr_scheduler import LambdaLR, MultiStepLR 5 | import pytorch_lightning as pl 6 | 7 | from omegaconf import OmegaConf 8 | import pandas as pd 9 | 10 | # for applying threshold on outputroll 11 | class STE(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, x, threshold): 14 | # ctx is the context object that can be called in backward 15 | # in DANN we use it to save the reversal scaler lambda 16 | # in this case, we don't need to use it 17 | return (x > threshold).float() 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | # Since we have two inputs during forward() 22 | # There would be grad for both x and threshold 23 | # Therefore we need to return a grad for each 24 | # But we don't need grad for threshold, so make it None 25 | return F.hardtanh(grad_output), None 26 | 27 | class TSeparation(pl.LightningModule): 28 | def __init__( 29 | self, 30 | transcription_model: pl.LightningModule, 31 | separation_model: pl.LightningModule, 32 | batch_data_preprocessor, 33 | lr_lambda, 34 | cfg 35 | ): 36 | r"""Pytorch Lightning wrapper of PyTorch model, including forward, 37 | optimization of model, etc. 
38 | 39 | Args: 40 | network: nn.Module 41 | loss_function: func 42 | learning_rate, float, e.g., 1e-3 43 | lr_lambda: func 44 | """ 45 | super().__init__() 46 | self.transcription_model = transcription_model 47 | self.separation_model = separation_model 48 | self.lr_lambda = lr_lambda 49 | self.batch_data_preprocessor = batch_data_preprocessor 50 | self.cfg = cfg 51 | 52 | def training_step(self, batch, batch_idx): 53 | batch = self.batch_data_preprocessor(batch) 54 | transcription_output = self.transcription_model.training_step(batch, batch_idx, self) 55 | transcription_loss = transcription_output['loss'] 56 | outputs = transcription_output['outputs'] 57 | if self.cfg.straight_through==True: 58 | outputs['frame_output'] = STE.apply(outputs['frame_output'], self.cfg.transcription.evaluation.frame_threshold) 59 | separation_loss = self.separation_model.training_step(batch, batch_idx, outputs, self) 60 | 61 | total_loss = transcription_loss + separation_loss 62 | self.log('Total_Loss/Train', total_loss, on_step=False, on_epoch=True) 63 | 64 | return total_loss 65 | 66 | 67 | def validation_step(self, batch, batch_idx): 68 | batch = self.batch_data_preprocessor(batch) 69 | transcription_loss, outputs = self.transcription_model.validation_step(batch, batch_idx, self) 70 | if self.cfg.straight_through==True: 71 | outputs['frame_output'] = STE.apply(outputs['frame_output'], self.cfg.transcription.evaluation.frame_threshold) 72 | separation_loss = self.separation_model.validation_step(batch, batch_idx, outputs, self) 73 | 74 | total_loss = transcription_loss + separation_loss['Separation/Valid/Loss'] 75 | self.log('Total_Loss/Valid', total_loss, on_step=False, on_epoch=True) 76 | 77 | return outputs 78 | 79 | def test_step(self, batch, batch_idx): 80 | # TODO: Update it for Jointist 81 | _, _, output_dict = self.transcription_model.test_step(batch,batch_idx, None, False, self) 82 | if self.cfg.straight_through==True: 83 | output_dict['frame_output'] = STE.apply(output_dict['frame_output'], self.cfg.transcription.evaluation.frame_threshold) 84 | 85 | sdr_dict = self.separation_model.test_step(batch,batch_idx, output_dict, self) 86 | 87 | return sdr_dict 88 | 89 | def test_epoch_end(self, outputs): 90 | self.separation_model.test_epoch_end(outputs, self) 91 | 92 | 93 | 94 | def predict_step(self, batch, batch_idx, plugin_idxs): 95 | output_dict = self.transcription_model.predict_step(batch, batch_idx, plugin_idxs, True) 96 | self.separation_model.predict_step(batch, batch_idx, output_dict, plugin_idxs) 97 | 98 | def configure_optimizers(self): 99 | r"""Configure optimizer.""" 100 | optimizer = optim.Adam( 101 | list(self.transcription_model.parameters()) + list(self.separation_model.parameters()), 102 | lr=self.cfg.lr, 103 | betas=(0.9, 0.999), 104 | eps=1e-08, 105 | weight_decay=0.0, 106 | amsgrad=True, 107 | ) 108 | 109 | if self.cfg.scheduler.type=="MultiStepLR": 110 | scheduler = { 111 | 'scheduler': MultiStepLR(optimizer, 112 | milestones=list(self.cfg.scheduler.milestones), 113 | gamma=self.cfg.scheduler.gamma), 114 | 'interval': 'epoch', 115 | 'frequency': 1, 116 | } 117 | elif self.cfg.scheduler.type=="LambdaLR": 118 | scheduler = { 119 | 'scheduler': LambdaLR(optimizer, self.lr_lambda), 120 | 'interval': 'step', 121 | 'frequency': 1, 122 | } 123 | 124 | 125 | return [optimizer], [scheduler] -------------------------------------------------------------------------------- /End2End/tasks/transcription/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .transcription import Transcription, BaselineTranscription -------------------------------------------------------------------------------- /End2End/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /End2End/util/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | # modified from torchvision to also return the union 24 | def box_iou(boxes1, boxes2): 25 | area1 = box_area(boxes1) 26 | area2 = box_area(boxes2) 27 | 28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 30 | 31 | wh = (rb - lt).clamp(min=0) # [N,M,2] 32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 33 | 34 | union = area1[:, None] + area2 - inter 35 | 36 | iou = inter / union 37 | return iou, union 38 | 39 | 40 | def generalized_box_iou(boxes1, boxes2): 41 | """ 42 | Generalized IoU from https://giou.stanford.edu/ 43 | 44 | The boxes should be in [x0, y0, x1, y1] format 45 | 46 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 47 | and M = len(boxes2) 48 | """ 49 | # degenerate boxes gives inf / nan results 50 | # so do an early check 51 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 52 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 53 | iou, union = box_iou(boxes1, boxes2) 54 | 55 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 56 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 57 | 58 | wh = (rb - lt).clamp(min=0) # [N,M,2] 59 | area = wh[:, :, 0] * wh[:, :, 1] 60 | 61 | return iou - (area - union) / area 62 | 63 | 64 | def masks_to_boxes(masks): 65 | """Compute the bounding boxes around the provided masks 66 | 67 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
68 | 69 | Returns a [N, 4] tensors, with the boxes in xyxy format 70 | """ 71 | if masks.numel() == 0: 72 | return torch.zeros((0, 4), device=masks.device) 73 | 74 | h, w = masks.shape[-2:] 75 | 76 | y = torch.arange(0, h, dtype=torch.float) 77 | x = torch.arange(0, w, dtype=torch.float) 78 | y, x = torch.meshgrid(y, x) 79 | 80 | x_mask = (masks * x.unsqueeze(0)) 81 | x_max = x_mask.flatten(1).max(-1)[0] 82 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 83 | 84 | y_mask = (masks * y.unsqueeze(0)) 85 | y_max = y_mask.flatten(1).max(-1)[0] 86 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 87 | 88 | return torch.stack([x_min, y_min, x_max, y_max], 1) 89 | -------------------------------------------------------------------------------- /End2End/util/plot_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plotting utilities to visualize training logs. 3 | """ 4 | import torch 5 | import pandas as pd 6 | import numpy as np 7 | import seaborn as sns 8 | import matplotlib.pyplot as plt 9 | 10 | from pathlib import Path, PurePath 11 | 12 | 13 | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): 14 | ''' 15 | Function to plot specific fields from training log(s). Plots both training and test results. 16 | 17 | :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file 18 | - fields = which results to plot from each log file - plots both training and test for each field. 19 | - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots 20 | - log_name = optional, name of log file if different than default 'log.txt'. 21 | 22 | :: Outputs - matplotlib plots of results in fields, color coded for each log file. 23 | - solid lines are training results, dashed lines are test results. 24 | 25 | ''' 26 | func_name = "plot_utils.py::plot_logs" 27 | 28 | # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, 29 | # convert single Path to list to avoid 'not iterable' error 30 | 31 | if not isinstance(logs, list): 32 | if isinstance(logs, PurePath): 33 | logs = [logs] 34 | print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") 35 | else: 36 | raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ 37 | Expect list[Path] or single Path obj, received {type(logs)}") 38 | 39 | # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir 40 | for i, dir in enumerate(logs): 41 | if not isinstance(dir, PurePath): 42 | raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") 43 | if not dir.exists(): 44 | raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") 45 | # verify log_name exists 46 | fn = Path(dir / log_name) 47 | if not fn.exists(): 48 | print(f"-> missing {log_name}. 
Have you gotten to Epoch 1 in training?") 49 | print(f"--> full path of missing log file: {fn}") 50 | return 51 | 52 | # load log file(s) and plot 53 | dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] 54 | 55 | fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) 56 | 57 | for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): 58 | for j, field in enumerate(fields): 59 | if field == 'mAP': 60 | coco_eval = pd.DataFrame( 61 | np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1] 62 | ).ewm(com=ewm_col).mean() 63 | axs[j].plot(coco_eval, c=color) 64 | else: 65 | df.interpolate().ewm(com=ewm_col).mean().plot( 66 | y=[f'train_{field}', f'test_{field}'], 67 | ax=axs[j], 68 | color=[color] * 2, 69 | style=['-', '--'] 70 | ) 71 | for ax, field in zip(axs, fields): 72 | ax.legend([Path(p).name for p in logs]) 73 | ax.set_title(field) 74 | 75 | 76 | def plot_precision_recall(files, naming_scheme='iter'): 77 | if naming_scheme == 'exp_id': 78 | # name becomes exp_id 79 | names = [f.parts[-3] for f in files] 80 | elif naming_scheme == 'iter': 81 | names = [f.stem for f in files] 82 | else: 83 | raise ValueError(f'not supported {naming_scheme}') 84 | fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) 85 | for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): 86 | data = torch.load(f) 87 | # precision is n_iou, n_points, n_cat, n_area, max_det 88 | precision = data['precision'] 89 | recall = data['params'].recThrs 90 | scores = data['scores'] 91 | # take precision for all classes, all areas and 100 detections 92 | precision = precision[0, :, :, 0, -1].mean(1) 93 | scores = scores[0, :, :, 0, -1].mean(1) 94 | prec = precision.mean() 95 | rec = data['recall'][0, :, 0, -1].mean() 96 | print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + 97 | f'score={scores.mean():0.3f}, ' + 98 | f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' 99 | ) 100 | axs[0].plot(recall, precision, c=color) 101 | axs[1].plot(recall, scores, c=color) 102 | 103 | axs[0].set_title('Precision / Recall') 104 | axs[0].legend(names) 105 | axs[1].set_title('Scores / Recall') 106 | axs[1].legend(names) 107 | return fig, axs 108 | -------------------------------------------------------------------------------- /End2End/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import yaml 5 | import datetime 6 | import pickle 7 | 8 | 9 | def create_logging(log_dir, filemode): 10 | os.makedirs(log_dir, exist_ok=True) 11 | i1 = 0 12 | 13 | while os.path.isfile(os.path.join(log_dir, '{:04d}.log'.format(i1))): 14 | i1 += 1 15 | 16 | log_path = os.path.join(log_dir, '{:04d}.log'.format(i1)) 17 | logging.basicConfig( 18 | level=logging.DEBUG, 19 | format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s', 20 | datefmt='%a, %d %b %Y %H:%M:%S', 21 | filename=log_path, 22 | filemode=filemode, 23 | ) 24 | 25 | # Print to console 26 | console = logging.StreamHandler() 27 | console.setLevel(logging.INFO) 28 | formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') 29 | console.setFormatter(formatter) 30 | logging.getLogger('').addHandler(console) 31 | 32 | return logging 33 | 34 | 35 | def float32_to_int16(x: np.ndarray): 36 | assert np.max(np.abs(x)) <= 2.0 37 | return (x * 32767.0).astype(np.int16) 38 | 39 | 40 | def int16_to_float32(x: np.ndarray): 41 | return (x / 32767.0).astype(np.float32) 42 | 43 | 44 | def read_yaml(config_yaml): 45 | 
with open(config_yaml, "r") as fr: 46 | return yaml.load(fr, Loader=yaml.FullLoader) 47 | 48 | 49 | def note_to_freq(piano_note): 50 | return 2 ** ((piano_note - 39) / 12) * 440 51 | 52 | 53 | def get_pitch_shift_factor(pitch_shift): 54 | return 2 ** (pitch_shift / 12) 55 | 56 | 57 | class StatisticsContainer(object): 58 | def __init__(self, statistics_path): 59 | self.statistics_path = statistics_path 60 | 61 | self.backup_statistics_path = "{}_{}.pkl".format( 62 | os.path.splitext(self.statistics_path)[0], 63 | datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), 64 | ) 65 | 66 | self.statistics_dict = {"train": [], "test": []} 67 | 68 | def append(self, steps, statistics, split): 69 | statistics["steps"] = steps 70 | self.statistics_dict[split].append(statistics) 71 | 72 | def dump(self): 73 | pickle.dump(self.statistics_dict, open(self.statistics_path, "wb")) 74 | pickle.dump(self.statistics_dict, open(self.backup_statistics_path, "wb")) 75 | logging.info(" Dump statistics to {}".format(self.statistics_path)) 76 | logging.info(" Dump statistics to {}".format(self.backup_statistics_path)) 77 | 78 | def load_state_dict(self, resume_steps): 79 | self.statistics_dict = pickle.load(open(self.statistics_path, "rb")) 80 | 81 | resume_statistics_dict = {"train": [], "test": []} 82 | 83 | for key in self.statistics_dict.keys(): 84 | for statistics in self.statistics_dict[key]: 85 | if statistics["steps"] <= resume_steps: 86 | resume_statistics_dict[key].append(statistics) 87 | 88 | self.statistics_dict = resume_statistics_dict 89 | -------------------------------------------------------------------------------- /GPU_debug.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_lightning as pl 4 | from sklearn.datasets import make_blobs 5 | from sklearn.model_selection import train_test_split 6 | import torch.optim as optim 7 | 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | X, Y = make_blobs(10000,1000,centers=10, cluster_std=10) 12 | X_train, X_test, y_train, y_test = train_test_split(X,Y, test_size=0.2, random_state=0) 13 | 14 | 15 | trainset = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(),torch.from_numpy(y_train)) 16 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,shuffle=True, num_workers=2) 17 | 18 | 19 | class Model(pl.LightningModule): 20 | def __init__(self): 21 | super(Model, self).__init__() 22 | self.lstm = nn.LSTM(100, 256, bidirectional=True) 23 | self.classifier = nn.Linear(256*2*10,10) 24 | 25 | def forward(self, x): 26 | x, _ = self.lstm(x.view(-1,10,100)) 27 | x = self.classifier(x.flatten(1)) 28 | return x 29 | 30 | 31 | def training_step(self, batch, batch_idx): 32 | pred = self(batch[0]) 33 | loss = torch.nn.functional.cross_entropy(pred, batch[1]) 34 | return loss 35 | 36 | 37 | 38 | def configure_optimizers(self): 39 | r"""Configure optimizer.""" 40 | return optim.Adam(self.parameters()) 41 | 42 | 43 | model = Model() 44 | 45 | trainer = pl.Trainer(max_epochs=99999, gpus=2, accelerator="ddp") 46 | 47 | 48 | trainer.fit(model, trainloader) 49 | # check if bin 0-20 has changed 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Table of Contents 3 | - [Table of Contents](#table-of-contents) 4 | - [Jointist](#jointist) 5 | - [Setup](#setup) 6 | - [Inference](#inference) 7 | - [a. 
Instrument Recognition + Transcription](#a-instrument-recognition--transcription) 8 | - [b. Instrument Recognition + Transcription + Source Separation](#b-instrument-recognition--transcription--source-separation) 9 | - [Using individual pretrained models](#using-individual-pretrained-models) 10 | - [Transcription](#transcription) 11 | - [Training](#training) 12 | - [Instrument Recognition](#instrument-recognition) 13 | - [Transcription](#transcription-1) 14 | - [End2end training (Jointist)](#end2end-training-jointist) 15 | - [Experiments](#experiments) 16 | 17 | 18 | 19 | # Jointist 20 | 21 | Jointist is a joint-training framework capable of: 22 | 1. Instrument Recognition 23 | 1. Multi-Instrument Transcription 24 | 1. Music Source Separation 25 | 26 | 27 | 28 | Demo: [https://jointist.github.io/Demo/](https://jointist.github.io/Demo/) 29 | 30 | Paper: [https://arxiv.org/abs/2302.00286](https://arxiv.org/abs/2302.00286) 31 | 32 | 33 | ## Setup 34 | This code is developed using the docker image `nvidia/cuda:10.2-devel-ubuntu18.04` and Python version 3.8.10. 35 | 36 | To set up the environment for Jointist, install the dependencies: 37 | ```bash 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | If you get `OSError: sndfile library not found`, you need to install `libsndfile1` using: 42 | 43 | ```bash 44 | apt install libsndfile1 45 | ``` 46 | 47 | 50 | 51 | The pretrained **model weights** can be downloaded from [dropbox](https://www.dropbox.com/s/n0eerriphw65qsr/jointist_weights.zip?dl=0). Put the model weights under the `weights` folder after downloading. 52 | 53 | The example **songs** for inference are included in this repo as `songs.zip`. 54 | 55 | After unzipping it using the following command, a new folder called `songs` will be created. 56 | 57 | ```bash 58 | unzip songs.zip 59 | ``` 60 | 61 | ## Inference 62 | ### a. Instrument Recognition + Transcription 63 | The following script detects the instruments in the song and transcribes the detected instruments: 64 | ```bash 65 | python pred_jointist.py audio_path=songs audio_ext=mp3 gpus=[0] 66 | ``` 67 | 68 | It will first run an instrument recognition model, and the predicted instruments are used as the conditions for the transcription model. 69 | 70 | If you have multiple GPUs, the argument `gpus` controls which GPU to use. For example, if you want to use GPU:2, then you can do `gpus=[2]`. 71 | 72 | The `audio_path` specifies the path to the input audio files. If your audio files are not in `.mp3` format, you can change the `audio_ext` argument to the audio format of your songs. Since we use `torchaudio.load` to load audio files, you can use any audio format as long as it is supported by `torchaudio.load`. 73 | 74 | The output MIDI files will be stored inside the `outputs/YYYY-MM-DD/HH-MM-SS/MIDI_output` folder. 75 | 76 | Model weights can be changed under `checkpoint` of `End2End/config/jointist_inference.yaml`. 77 | 78 | - `transcription1000.ckpt` is the model trained only on the transcription task. 79 | - `tseparation.ckpt` is the model weight jointly trained with both transcription and source separation tasks. 80 | 81 | ### b. Instrument Recognition + Transcription + Source Separation 82 | 83 | The following inference script performs instrument detection, transcription, and source separation: 84 | 85 | ```bash 86 | python pred_jointist_ss.py audio_path=songs audio_ext=mp3 gpus=[0] 87 | ``` 88 | 89 | Same as above, the output MIDI files will be stored inside the `outputs/YYYY-MM-DD/HH-MM-SS/MIDI_output` folder.
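To sanity-check a run, you can inspect the generated MIDI files directly. The snippet below is only an illustrative sketch and is not part of the repository; it assumes `pretty_midi` is installed and that the output files use the `.mid` extension, and the example output path is hypothetical (replace the date/time with the folder created by your run).

```python
from pathlib import Path

import pretty_midi  # install separately if it is not already in your environment

# Hypothetical example path: adjust the date/time to match your run.
midi_dir = Path("outputs/2023-01-01/12-00-00/MIDI_output")

for midi_path in sorted(midi_dir.glob("*.mid")):
    pm = pretty_midi.PrettyMIDI(str(midi_path))
    # Print each transcribed instrument and how many notes it received.
    for inst in pm.instruments:
        name = inst.name or pretty_midi.program_to_instrument_name(inst.program)
        print(f"{midi_path.name}: {name} ({len(inst.notes)} notes)")
```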
90 | 91 | Model weights can be changed under `checkpoint` of `End2End/config/jointist_ss_inference.yaml`. `tseparation.ckpt` is the checkpoint with a better transcription F1 score and source separation SDR after training both of them end2end. 92 | 93 | Implementation details for Jointist are available [here](./jointist_explanation.md). 94 | 95 | 96 | ## Using individual pretrained models 97 | ### Transcription 98 | ``` 99 | python pred_transcription.py datamodule=wild 100 | ``` 101 | 102 | Currently supported `datamodule`: 103 | 1. wild 104 | 1. h5 105 | 1. slakh 106 | The configuration, such as `path` and `audio_ext`, for each datamodule can be modified inside `End2End/config/datamodule/xxx.yaml` 107 | 108 | ## Training 109 | 110 | ### Instrument Recognition 111 | 112 | ```bash 113 | python train_detection.py detection=CombinedModel_NewCLSv2 datamodule=slakh epoch=50 gpus=4 every_n_epochs=2 114 | ``` 115 | 116 | `detection`: controls the model type 117 | `detection/backbone`: controls which CNN backbone to use 118 | `datamodule`: controls which dataset to use `(openmic2018/slakh)`. It affects the instrument mappings. 119 | 120 | Please refer to `End2End/config/detection_config.yaml` for more configuration parameters. 121 | 122 | ### Transcription 123 | 124 | ```bash 125 | python train_transcription.py transcription.backend.acoustic.type=CNN8Dropout_Wide inst_sampler.mode=imbalance inst_sampler.samples=2 inst_sampler.neg_samples=2 inst_sampler.temp=0.5 inst_sampler.audio_noise=0 gpus=[0] batch_size=2 126 | ``` 127 | 128 | `transcription.backend.acoustic.type`: controls the model type 129 | `inst_sampler.mode=imbalance`: controls which sampling mode to use 130 | `inst_sampler.samples`: controls how many positive samples are mined for training 131 | `inst_sampler.neg_samples`: controls how many negative samples are mined for training 132 | `inst_sampler.temp`: sampling temperature, only effective when using imbalance sampling 133 | `inst_sampler.audio_noise`: controls if random noise should be added to the audio during training 134 | `gpus`: controls which gpus to use.
`[0]` means using cuda:0; `[2]` means using cuda:2; `[0,1,2,3]` means using four gpus cuda:0-3 135 | 136 | Please refer to `End2End/config/transcription_config.yaml` for more configuration parameters 137 | 138 | ### End2end training (Jointist) 139 | 140 | ``` 141 | python train_jointist.py 142 | ``` 143 | 144 | 145 | ## Experiments 146 | [link](./experiments.md) -------------------------------------------------------------------------------- /evaluate_end2end_Filter.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import os 3 | 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.plugins import DDPPlugin 6 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 7 | 8 | from End2End.Data import DataModuleEnd2End, End2EndBatchDataPreprocessor 9 | from End2End.tasks.transcription import Transcription 10 | 11 | from End2End.MIDI_program_map import ( 12 | MIDI_Class_NUM, 13 | MIDIClassName2class_idx, 14 | class_idx2MIDIClass, 15 | ) 16 | from End2End.data.augmentors import Augmentor 17 | from End2End.lr_schedulers import get_lr_lambda 18 | import End2End.models.transcription.combined as TranscriptionModel 19 | from End2End.losses import get_loss_function 20 | 21 | # Libraries related to hydra 22 | import hydra 23 | from hydra.utils import to_absolute_path 24 | 25 | @hydra.main(config_path="End2End/config/", config_name="transcription_config") 26 | def main(cfg): 27 | cfg.datamodule.waveform_hdf5s_dir = to_absolute_path(os.path.join('hdf5s', 'waveforms')) 28 | if cfg.transcription.evaluation.output_path: 29 | cfg.transcription.evaluation.output_path = to_absolute_path(cfg.transcription.evaluation.output_path) 30 | else: 31 | cfg.transcription.evaluation.output_path = os.getcwd() 32 | 33 | if cfg.MIDI_MAPPING.type=='plugin_names': 34 | cfg.MIDI_MAPPING.plugin_labels_num = PLUGIN_LABELS_NUM 35 | cfg.MIDI_MAPPING.NAME_TO_IX = PLUGIN_LB_TO_IX 36 | cfg.MIDI_MAPPING.IX_TO_NAME = PLUGIN_IX_TO_LB 37 | cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes3/') 38 | elif cfg.MIDI_MAPPING.type=='MIDI_class': 39 | cfg.MIDI_MAPPING.plugin_labels_num = MIDI_Class_NUM 40 | cfg.MIDI_MAPPING.NAME_TO_IX = MIDIClassName2class_idx 41 | cfg.MIDI_MAPPING.IX_TO_NAME = class_idx2MIDIClass 42 | cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes_MIDI_class/') 43 | else: 44 | raise ValueError(f"Please choose the correct MIDI_MAPPING.type") 45 | 46 | experiment_name = ("Eval-" 47 | f"{cfg.transcription.model.type}-" 48 | f"{cfg.MIDI_MAPPING.type}-" 49 | f"hidden=128-" 50 | f"fps={cfg.transcription.model.args.frames_per_second}-" 51 | f"csize={MIDI_Class_NUM}-" 52 | f"bz={cfg.batch_size}" 53 | ) 54 | 55 | 56 | data_module = DataModuleEnd2End(**cfg.datamodule,augmentor=None, MIDI_MAPPING=cfg.MIDI_MAPPING) 57 | data_module.setup('test') 58 | 59 | checkpoint_path = to_absolute_path(cfg.transcription.evaluation.checkpoint_path) 60 | 61 | # model 62 | Model = getattr(TranscriptionModel, cfg.transcription.model.type) 63 | model = Model(cfg, **cfg.transcription.model.args) 64 | 65 | pl_model = Transcription.load_from_checkpoint(checkpoint_path, 66 | network=model, 67 | loss_function=None, 68 | lr_lambda=None, 69 | batch_data_preprocessor=End2EndBatchDataPreprocessor(cfg.MIDI_MAPPING, cfg.inst_sampler.type, cfg.inst_sampler.temp), 70 | cfg=cfg) 71 | 72 | 73 | logger = pl.loggers.TensorBoardLogger(save_dir='.', name=experiment_name) 74 | trainer = pl.Trainer( 75 | 
**cfg.trainer, 76 | plugins=[DDPPlugin(find_unused_parameters=True)], 77 | logger=logger 78 | ) 79 | trainer.test(pl_model, data_module.test_dataloader()) 80 | 81 | if __name__ == '__main__': 82 | main() -------------------------------------------------------------------------------- /inst_wise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinWaiCheuk/Jointist/1846975ec904c8b23ea653b9b7881f343574e7a6/inst_wise.png -------------------------------------------------------------------------------- /jointist_explanation.md: -------------------------------------------------------------------------------- 1 | # Inference code explanation 2 | The main inference code is located at the `predict_step()` function under the `Transcription` class in `End2End/tasks/transcription/transcription.py`. It involves two steps: 3 | 4 | `waveforms => piano rolls => pkl/midi files` 5 | 6 | 7 | ## Step 1: Converting audio clips into posteriorgrams (or piano rolls) 8 | During inference, audio clips can be of any length. We specify `segment_samples` to determine how long each segment is, and use `segment_samples//2` as the hop size to cut long audio clips into something that our model can handle. For example, a 4-minute audio clip would be a waveform of the shape (4*60*16000)=(3840000). And `segment_samples=160,000` will generate waveform segments of the shape (47, 160000). Then we use `seg_batch_size` to control how many segments are fed to the network per batch. If `seg_batch_size=8`, `ceil(47/8)=6` feed-forwards are needed to finish transcribing the audio. 9 | 10 | `predict_probabilities()` at line 429 of `End2End/tasks/transcription/transcription.py` is responsible for this operation: 11 | 12 | ```python 13 | predict_probabilities( 14 | network, 15 | audio, 16 | condition, 17 | segment_samples, 18 | seg_batch_size 19 | ) 20 | ``` 21 | 22 | **network**: The PyTorch model used for the feedforward pass 23 | 24 | **audio**: Single waveform (batch size must be 1 during inference) of the shape (len). 25 | 26 | **condition**: One-hot vector corresponding to the instruments of the shape (39) 27 | 28 | **segment_samples**: The length of each audio segment (default is 10 seconds/160,000 samples) 29 | 30 | **seg_batch_size**: How many segments per feedforward (default is 8) 31 | 32 | This function returns a dictionary called `_output_dict` containing two keys: 33 | 34 | ```python 35 | _output_dict={ 36 | 'frame_output': (num_frames, 88), 37 | 'reg_onset_output': (num_frames, 88), 38 | } 39 | ``` 40 | 41 | Each `_output_dict` from `predict_probabilities()` at line 429 of `End2End/tasks/transcription/transcription.py` corresponds to one musical instrument. 42 | 43 | After that, we concatenate the outputs for the different instruments, forming a new dictionary `output_dict` at line 435 of `End2End/tasks/transcription/transcription.py`. 44 | 45 | ```python 46 | output_dict={ 47 | 'frame_output': (num_frames, 88*num_instruments), 48 | 'reg_onset_output': (num_frames, 88*num_instruments), 49 | } 50 | ``` 51 | 52 | 53 | `frame_output` and `reg_onset_output` are posteriorgrams with values between [0,1] indicating the probability that the notes are present. Piano rolls can be easily obtained by applying a threshold to `frame_output`. In Jointist, we use an onset-and-frame postprocessor `OnsetFramePostProcessor()` to directly convert these posteriorgrams into pkl/midi files via `postprocess_probabilities_to_midi_events()`.
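As a sanity check on the segment arithmetic described above, the following standalone sketch (not part of the repository) reproduces the segment count and the number of feed-forward passes for a 4-minute clip at 16 kHz, assuming half-overlapping 10-second segments:

```python
import torch

SAMPLE_RATE = 16000
segment_samples = 160_000            # 10-second segments
hop_samples = segment_samples // 2   # 50% overlap, i.e. segment_samples//2

audio = torch.zeros(1, 4 * 60 * SAMPLE_RATE)              # (1, 3840000), a silent 4-minute clip
segments = audio.unfold(1, segment_samples, hop_samples)  # enframe into overlapping segments
segments = segments.squeeze(0)
print(segments.shape)                                     # torch.Size([47, 160000])

seg_batch_size = 8
num_forwards = -(-segments.shape[0] // seg_batch_size)    # ceiling division
print(num_forwards)                                       # 6 feed-forward passes
```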
54 | 55 | 56 | ## Step 2: Converting posteriorgrams into pkl/midi files 57 | After we have all the posteriorgrams for all instruments stored in `output_dict`, we use `postprocess_probabilities_to_midi_events()` to obtain the pkl/midi files: 58 | 59 | ``` 60 | midi_events=postprocess_probabilities_to_midi_events( 61 | output_dict, 62 | plugin_ids, 63 | IX_TO_NAME, 64 | classes_num, 65 | post_processor) 66 | ``` 67 | 68 | **output_dict**: The output obtained from line 435 of `End2End/tasks/transcription/transcription.py`. 69 | 70 | **plugin_ids**: a list of indices indicating which instruments are to be transcribed 71 | 72 | **IX_TO_NAME**: The dictionary mapping indices back to instrument names (strings) 73 | 74 | **classes_num**: number of instrument classes (39) 75 | 76 | **post_processor**: Different post-processors are available in `End2End/inference_instruments_filter.py`. The default is `OnsetFramePostProcessor()` -------------------------------------------------------------------------------- /model_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinWaiCheuk/Jointist/1846975ec904c8b23ea653b9b7881f343574e7a6/model_fig.png -------------------------------------------------------------------------------- /openmic_dataprocessing.sh: -------------------------------------------------------------------------------- 1 | WORKSPACE="/opt/tiger/kinwai/jointist" 2 | DATASET_DIR="${WORKSPACE}/openmic-2018" 3 | WAVEFORM_HDF5S_DIR="${WORKSPACE}/hdf5s/openmic_waveforms" 4 | 5 | # # Download audio files 6 | # wget https://zenodo.org/record/1432913/files/openmic-2018-v1.0.0.tgz?download=1 ./ 7 | # tar -xvf openmic-2018-v1.0.0.tgz\?download\=1 8 | 9 | # # pack audios 10 | # # Pack audio files into hdf5 files.
11 | python3 End2End/create_openmic2018.py pack_audios_to_hdf5s \ 12 | --audios_dir=$DATASET_DIR \ 13 | --hdf5s_dir=$WAVEFORM_HDF5S_DIR 14 | 15 | # The labels are inside the csv file, no need to create pkl files for it -------------------------------------------------------------------------------- /piece_wise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KinWaiCheuk/Jointist/1846975ec904c8b23ea653b9b7881f343574e7a6/piece_wise.png -------------------------------------------------------------------------------- /pkl2pianoroll.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.utils import to_absolute_path 3 | import h5py 4 | import numpy as np 5 | import pickle 6 | from End2End.target_processors import TargetProcessor 7 | from pathlib import Path 8 | import tqdm 9 | from hydra.utils import to_absolute_path 10 | import os 11 | from End2End.MIDI_program_map import ( 12 | MIDI_Class_NUM, 13 | MIDIClassName2class_idx, 14 | class_idx2MIDIClass, 15 | ) 16 | import torch 17 | 18 | # constants 19 | frames_per_second=100 20 | SAMPLE_RATE=16000 21 | 22 | 23 | 24 | 25 | @hydra.main(config_path="End2End/config/", config_name="pkl2pianoroll") 26 | def main(cfg): 27 | 28 | audio_h5_path = to_absolute_path(cfg.audio_h5_path) 29 | pkl_path = to_absolute_path(cfg.pkl_path) 30 | 31 | # output name based on the original audio_h5_path name 32 | roll_name = os.path.basename(audio_h5_path).split('_')[0] + '_roll.h5' 33 | roll_h5_path = os.path.join(to_absolute_path(cfg.roll_output_path), roll_name) 34 | 35 | 36 | target_processor = TargetProcessor(frames_per_second=frames_per_second, 37 | begin_note=21, 38 | classes_num=88) 39 | 40 | with h5py.File(audio_h5_path, 'r') as h5: 41 | pkl_list = list(Path(pkl_path).glob('*.pkl')) 42 | with h5py.File(roll_h5_path, "w") as hf: 43 | num_pkl = len(pkl_list) 44 | num_audio = len(h5.keys()) 45 | if num_pkl!=num_audio: 46 | val = input(f"num_pkl={num_pkl}, while num_audio={num_audio}\n" 47 | f"Do you want to continue? 
[y/n]") 48 | if val.lower()=='y': 49 | pass 50 | elif val.lower()=='n': 51 | raise ValueError(f"please check if it is normal to have missing pkl files") 52 | else: 53 | raise ValueError(f"Unkonwn input: {val}, please try again") 54 | 55 | for pkl_path in tqdm.tqdm(sorted(pkl_list)): 56 | piece_name = pkl_path.name[:-4] 57 | note_event = pickle.load(open(pkl_path, 'rb')) 58 | valid_length = len(h5[piece_name][()]) 59 | segment_seconds = valid_length/SAMPLE_RATE 60 | 61 | flat_frame_roll = event2roll(0, 62 | segment_seconds, 63 | note_event, 64 | target_processor) 65 | hf.create_dataset(piece_name, data=flat_frame_roll) 66 | 67 | 68 | def event2roll(start_time, segment_seconds, note_events, target_processor): 69 | keys = list(note_events.keys()) 70 | key = keys[0] 71 | target_dict_per_plugin = target_processor.pkl2roll(start_time=0, 72 | segment_seconds=segment_seconds, 73 | note_events=note_events[key], 74 | ) 75 | frame_roll = target_dict_per_plugin['frame_roll'] 76 | placeholder = np.zeros_like(frame_roll).astype('bool') 77 | placeholder = np.expand_dims(placeholder,0) 78 | placeholder = placeholder.repeat(39,0) 79 | 80 | placeholder[MIDIClassName2class_idx[key]] = frame_roll 81 | for key in keys[1:]: 82 | target_dict_per_plugin = target_processor.pkl2roll(start_time=0, 83 | segment_seconds=segment_seconds, 84 | note_events=note_events[key], 85 | ) 86 | placeholder[MIDIClassName2class_idx[key]] = target_dict_per_plugin['frame_roll'] 87 | 88 | return placeholder 89 | 90 | 91 | 92 | 93 | 94 | 95 | if __name__ == '__main__': 96 | main() -------------------------------------------------------------------------------- /pkl2pianoroll_MSD.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.utils import to_absolute_path 3 | import h5py 4 | import numpy as np 5 | import pickle 6 | from End2End.target_processors import TargetProcessor 7 | from pathlib import Path 8 | import tqdm 9 | from hydra.utils import to_absolute_path 10 | import os 11 | import sys 12 | from End2End.MIDI_program_map import ( 13 | MIDI_Class_NUM, 14 | MIDIClassName2class_idx, 15 | class_idx2MIDIClass, 16 | ) 17 | import torch 18 | import torchaudio 19 | 20 | # constants 21 | frames_per_second=100 22 | SAMPLE_RATE=16000 23 | 24 | 25 | 26 | @hydra.main(config_path="End2End/config/", config_name="pkl2pianoroll_MSD") 27 | def main(cfg): 28 | 29 | pkl_path = to_absolute_path(cfg.pkl_path) 30 | cfg.roll_output_path = to_absolute_path(cfg.roll_output_path) 31 | Path(cfg.roll_output_path).mkdir(parents=True, exist_ok=True) 32 | 33 | roll_h5_path = os.path.join(cfg.roll_output_path, cfg.roll_name) 34 | 35 | 36 | target_processor = TargetProcessor(frames_per_second=frames_per_second, 37 | begin_note=21, 38 | classes_num=88) 39 | 40 | pkl_list = list(Path(pkl_path).glob('*.pkl')) 41 | with h5py.File(roll_h5_path, "w") as hf: 42 | pkl_list = list(Path(pkl_path).glob('*.pkl')) 43 | for pkl_path in tqdm.tqdm(sorted(pkl_list)): 44 | piece_name = pkl_path.name[:-4] 45 | note_event = pickle.load(open(pkl_path, 'rb')) 46 | # valid_length = len(h5[piece_name][()]) 47 | # audio, _ = torchaudio.load(to_absolute_path(f'../MTAT/{piece_name}')) 48 | valid_length = 16000*30 49 | segment_seconds = valid_length/SAMPLE_RATE 50 | 51 | 52 | flat_frame_roll = event2roll(0, 53 | segment_seconds, 54 | note_event, 55 | target_processor) 56 | 57 | hf.create_dataset(piece_name.replace("['", "").replace("']", ""), data=flat_frame_roll) 58 | 59 | 60 | def event2roll(start_time, segment_seconds, 
note_events, target_processor): 61 | keys = list(note_events.keys()) 62 | key = keys[0] 63 | target_dict_per_plugin = target_processor.pkl2roll(start_time=0, 64 | segment_seconds=segment_seconds, 65 | note_events=note_events[key], 66 | ) 67 | frame_roll = target_dict_per_plugin['frame_roll'] 68 | placeholder = np.zeros_like(frame_roll).astype('bool') 69 | placeholder = np.expand_dims(placeholder,0) 70 | placeholder = placeholder.repeat(39,0) 71 | 72 | placeholder[MIDIClassName2class_idx[key]] = frame_roll 73 | for key in keys[1:]: 74 | target_dict_per_plugin = target_processor.pkl2roll(start_time=0, 75 | segment_seconds=segment_seconds, 76 | note_events=note_events[key], 77 | ) 78 | placeholder[MIDIClassName2class_idx[key]] = target_dict_per_plugin['frame_roll'] 79 | 80 | return placeholder 81 | 82 | 83 | 84 | 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /pkl2pianoroll_MSD.sh: -------------------------------------------------------------------------------- 1 | sleep 1h 2 | date 3 | python pkl2pianoroll_MSD.py 4 | -------------------------------------------------------------------------------- /pkl2pianoroll_MTAT.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.utils import to_absolute_path 3 | import h5py 4 | import numpy as np 5 | import pickle 6 | from End2End.target_processors import TargetProcessor 7 | from pathlib import Path 8 | import tqdm 9 | from hydra.utils import to_absolute_path 10 | import os 11 | from End2End.MIDI_program_map import ( 12 | MIDI_Class_NUM, 13 | MIDIClassName2class_idx, 14 | class_idx2MIDIClass, 15 | ) 16 | import torch 17 | import torchaudio 18 | 19 | # constants 20 | frames_per_second=100 21 | SAMPLE_RATE=16000 22 | 23 | 24 | 25 | 26 | @hydra.main(config_path="End2End/config/", config_name="pkl2pianoroll") 27 | def main(cfg): 28 | 29 | audio_h5_path = to_absolute_path(cfg.audio_h5_path) 30 | pkl_path = to_absolute_path(cfg.pkl_path) 31 | 32 | # output name based on the original audio_h5_path name 33 | roll_name = os.path.basename(audio_h5_path).split('_')[0] + '_roll.h5' 34 | roll_h5_path = os.path.join(to_absolute_path(cfg.roll_output_path), roll_name) 35 | 36 | 37 | target_processor = TargetProcessor(frames_per_second=frames_per_second, 38 | begin_note=21, 39 | classes_num=88) 40 | 41 | pkl_list = list(Path(pkl_path).glob('*.pkl')) 42 | with h5py.File(roll_h5_path, "w") as hf: 43 | pkl_list = list(Path(pkl_path).glob('*.pkl')) 44 | for pkl_path in tqdm.tqdm(sorted(pkl_list)): 45 | piece_name = pkl_path.name[:-4] 46 | note_event = pickle.load(open(pkl_path, 'rb')) 47 | # valid_length = len(h5[piece_name][()]) 48 | audio, _ = torchaudio.load(to_absolute_path(f'../MTAT/{piece_name}')) 49 | valid_length = audio.shape[1] 50 | segment_seconds = valid_length/SAMPLE_RATE 51 | 52 | 53 | flat_frame_roll = event2roll(0, 54 | segment_seconds, 55 | note_event, 56 | target_processor) 57 | hf.create_dataset(piece_name, data=flat_frame_roll) 58 | 59 | 60 | def event2roll(start_time, segment_seconds, note_events, target_processor): 61 | keys = list(note_events.keys()) 62 | key = keys[0] 63 | target_dict_per_plugin = target_processor.pkl2roll(start_time=0, 64 | segment_seconds=segment_seconds, 65 | note_events=note_events[key], 66 | ) 67 | frame_roll = target_dict_per_plugin['frame_roll'] 68 | placeholder = np.zeros_like(frame_roll).astype('bool') 69 | placeholder = np.expand_dims(placeholder,0) 70 | placeholder = 
placeholder.repeat(39,0) 71 | 72 | placeholder[MIDIClassName2class_idx[key]] = frame_roll 73 | for key in keys[1:]: 74 | target_dict_per_plugin = target_processor.pkl2roll(start_time=0, 75 | segment_seconds=segment_seconds, 76 | note_events=note_events[key], 77 | ) 78 | placeholder[MIDIClassName2class_idx[key]] = target_dict_per_plugin['frame_roll'] 79 | 80 | return placeholder 81 | 82 | 83 | 84 | 85 | 86 | 87 | if __name__ == '__main__': 88 | main() -------------------------------------------------------------------------------- /pkl2sparsepianoroll_MSD.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.utils import to_absolute_path 3 | import sparse 4 | import h5py 5 | import numpy as np 6 | import pickle 7 | from End2End.target_processors import TargetProcessor 8 | from pathlib import Path 9 | import tqdm 10 | from hydra.utils import to_absolute_path 11 | import os 12 | import sys 13 | from End2End.MIDI_program_map import ( 14 | MIDI_Class_NUM, 15 | MIDIClassName2class_idx, 16 | class_idx2MIDIClass, 17 | ) 18 | import torch 19 | import torchaudio 20 | 21 | # constants 22 | frames_per_second=100 23 | SAMPLE_RATE=16000 24 | 25 | 26 | 27 | @hydra.main(config_path="End2End/config/", config_name="pkl2sparsepianoroll_MSD") 28 | def main(cfg): 29 | 30 | pkl_path = to_absolute_path(cfg.pkl_path) 31 | cfg.roll_output_path = to_absolute_path(cfg.roll_output_path) 32 | Path(cfg.roll_output_path).mkdir(parents=True, exist_ok=True) 33 | 34 | 35 | 36 | target_processor = TargetProcessor(frames_per_second=frames_per_second, 37 | begin_note=21, 38 | classes_num=88) 39 | 40 | pkl_list = list(Path(pkl_path).glob('*.pkl')) 41 | for pkl_path in tqdm.tqdm(sorted(pkl_list)): 42 | piece_name = pkl_path.name[:-4] 43 | note_event = pickle.load(open(pkl_path, 'rb')) 44 | # valid_length = len(h5[piece_name][()]) 45 | # audio, _ = torchaudio.load(to_absolute_path(f'../MTAT/{piece_name}')) 46 | valid_length = 16000*30 47 | segment_seconds = valid_length/SAMPLE_RATE 48 | 49 | 50 | flat_frame_roll = event2roll(0, 51 | segment_seconds, 52 | note_event, 53 | target_processor) 54 | 55 | # The native torch is having a bug for bool 56 | # https://github.com/pytorch/pytorch/issues/49977 57 | # sparse_roll = torch.tensor(flat_frame_roll).to_sparse() 58 | # torch.save(sparse_roll, os.path.join(cfg.roll_output_path, f'''{piece_name.replace("['", "").replace("']", "")}.pt''')) 59 | sparse_roll = sparse.COO(flat_frame_roll) 60 | sparse.save_npz(os.path.join(cfg.roll_output_path, 61 | f'''{piece_name.replace("['", "").replace("']", "")}'''), 62 | sparse_roll) 63 | 64 | 65 | def event2roll(start_time, segment_seconds, note_events, target_processor): 66 | keys = list(note_events.keys()) 67 | key = keys[0] 68 | target_dict_per_plugin = target_processor.pkl2roll(start_time=0, 69 | segment_seconds=segment_seconds, 70 | note_events=note_events[key], 71 | ) 72 | frame_roll = target_dict_per_plugin['frame_roll'] 73 | placeholder = np.zeros_like(frame_roll).astype('bool') 74 | placeholder = np.expand_dims(placeholder,0) 75 | placeholder = placeholder.repeat(39,0) 76 | 77 | placeholder[MIDIClassName2class_idx[key]] = frame_roll 78 | for key in keys[1:]: 79 | target_dict_per_plugin = target_processor.pkl2roll(start_time=0, 80 | segment_seconds=segment_seconds, 81 | note_events=note_events[key], 82 | ) 83 | placeholder[MIDIClassName2class_idx[key]] = target_dict_per_plugin['frame_roll'] 84 | 85 | return placeholder 86 | 87 | 88 | 89 | 90 | 91 | 92 | if __name__ == '__main__': 
93 | main() 94 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/torch_stable.html 2 | torch==1.10.1+cu111 3 | torchaudio==0.10.1+cu111 4 | torchmetrics==0.6.2 5 | torchvision==0.11.2+cu111 6 | mir_eval==0.6 7 | h5py==2.10.0 8 | tqdm==4.56.0 9 | mido==1.2.9 10 | pytorch_lightning==1.4.5 11 | numpy==1.18.5 12 | sox==1.4.1 13 | pretty_midi==0.2.9 14 | pyfluidsynth==1.3.0 15 | matplotlib==3.3.4 16 | pandas==1.2.1 17 | librosa==0.8.1 18 | soundfile==0.10.2 19 | auditok 20 | einops==0.3.0 21 | black==20.8b1 22 | tabulate 23 | hydra-core 24 | git+https://github.com/KinWaiCheuk/slakh_loader 25 | -------------------------------------------------------------------------------- /roll_convert2channel.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | from IPython.display import Audio 4 | import matplotlib.pyplot as plt 5 | import os 6 | 7 | # roll39_h5_path = './roll/leadsheet_roll.h5' 8 | # roll2_h5_path = './roll2/leadsheet_roll2.h5' 9 | for file in os.listdir('./roll'): 10 | h5_file = os.path.basename(file) 11 | if '.ipynb_checkpoints' not in file and 'leadsheet_roll' not in file: 12 | roll39_h5_path = os.path.join('./roll/', h5_file) 13 | roll2_h5_path = os.path.join('./roll2', h5_file[:-3]+'2.h5') 14 | with h5py.File(roll39_h5_path, 'r') as h5roll: 15 | with h5py.File(roll2_h5_path, "w") as hf: 16 | name_list = list(h5roll.keys()) 17 | 18 | for i in name_list: 19 | pitched = h5roll[i][()][:38] 20 | drums = h5roll[i][()][38] 21 | 22 | placeholder = np.zeros((2, *pitched.shape[1:])).astype('bool') 23 | placeholder[0] = np.any(pitched, axis=0) 24 | placeholder[1] = drums 25 | 26 | hf.create_dataset(i, data=placeholder) 27 | -------------------------------------------------------------------------------- /roll_convert_sparse.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import os 5 | import torch 6 | from pathlib import Path 7 | 8 | # roll39_h5_path = './roll/leadsheet_roll.h5' 9 | # roll2_h5_path = './roll2/leadsheet_roll2.h5' 10 | input_path = '/opt/tiger/kinwai/jointist/MSD/roll/MSD_test_part01.h5' 11 | output_path = './MSD/sparse_roll/' 12 | Path(output_path).mkdir(parents=True, exist_ok=True) 13 | # for file in os.listdir('./MSD/roll/MTAT_roll39_full'): 14 | 15 | # file = 'MSD_test_part02' 16 | 17 | with h5py.File(input_path, 'r') as h5roll: 18 | name_list = list(h5roll.keys()) 19 | for i in name_list: 20 | sparse_roll = torch.tensor(h5roll[i][()]).to_sparse() 21 | torch.save(sparse_roll, os.path.join(output_path, f"{i}.pt")) 22 | # sparse_roll = sparse.COO(h5roll[i][()]) 23 | # hf.create_dataset(i, data=sparse_roll) 24 | -------------------------------------------------------------------------------- /slakh2100_dataprocessing.sh: -------------------------------------------------------------------------------- 1 | WORKSPACE=$PWD 2 | SLAKH_DATASET_DIR="${WORKSPACE}/datasets/slakh2100/slakh2100_flac" 3 | WAVEFORM_HDF5S_DIR="${WORKSPACE}/hdf5s/waveforms" 4 | 5 | # Download audio files 6 | ./scripts/dataset-slakh2100/download_slakh2100_from_hdfs.sh 7 | 8 | # pack audios 9 | # Pack audio files into hdf5 files. 
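# (Optional sanity check: a sketch that assumes each packed HDF5 stores one waveform
#  dataset per track, keyed by track name -- the audio-HDF5 layout assumed elsewhere
#  in this repo, e.g. in pkl2pianoroll.py. The exact file name under
#  $WAVEFORM_HDF5S_DIR is not fixed here, so "<split>.h5" is only a placeholder.)
# python3 -c "import h5py, sys; f = h5py.File(sys.argv[1], 'r'); print(len(f.keys()), 'tracks, e.g.', list(f.keys())[:3])" "${WAVEFORM_HDF5S_DIR}/<split>.h5"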
10 | python3 End2End/dataset_creation/create_slakh2100.py pack_audios_to_hdf5s \ 11 | --audios_dir=$SLAKH_DATASET_DIR \ 12 | --hdf5s_dir=$WAVEFORM_HDF5S_DIR 13 | 14 | # Create pkl files for instrument classification 15 | for SPLIT in 'train' 'validation' 'test' 16 | do 17 | python3 End2End/create_notes_for_instruments_classification_MIDI_class.py create_notes \ 18 | --path_dataset=$SLAKH_DATASET_DIR \ 19 | --workspace=$WORKSPACE \ 20 | --split=$SPLIT 21 | done 22 | 23 | # # ====== Train piano roll transcription 24 | # # Prepare slakh2100 into piano + drums MIDI files. 25 | # AUDIOS_DIR="/opt/tiger/debugtest/jointist/datasets/slakh2100/slakh2100_flac" 26 | # CONFIG_NAME="config_4" # Piano + drums 27 | # CONFIG_CSV_PATH="./jointist/dataset_creation/midi_track_group_${CONFIG_NAME}.csv" 28 | # PATH_DATASET_PROCESSED="${WORKSPACE}/dataset_processed/closed_set_${CONFIG_NAME}" 29 | # python3 jointist/dataset_creation/prepare_closed_set.py \ 30 | # --config_type="plugin" \ 31 | # --config_csv_path=$CONFIG_CSV_PATH \ 32 | # --path_dataset=$AUDIOS_DIR \ 33 | # --path_dataset_processed=$PATH_DATASET_PROCESSED 34 | 35 | # # Pack MIDI events 36 | # PROCESSED_MIDIS_DIR="${WORKSPACE}/dataset_processed/closed_set_config_4" 37 | # MIDI_EVENTS_HDF5S_DIR="${WORKSPACE}/pickles/prettymidi_events/closed_set_config_4" 38 | # python3 jointist/dataset_creation/create_slakh2100.py pack_midi_events_to_hdf5s \ 39 | # --processed_midis_dir=$PROCESSED_MIDIS_DIR \ 40 | # --hdf5s_dir=$MIDI_EVENTS_HDF5S_DIR -------------------------------------------------------------------------------- /songs.zip: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0306b730c41277a71463a4a12975b32bd5ec823ae213326395b3f7b27ffc9e67 3 | size 75269200 4 | -------------------------------------------------------------------------------- /test_openmic_DETR_Hungarian.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import os 3 | 4 | import torch.nn.functional as F 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.plugins import DDPPlugin 7 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 8 | 9 | from End2End.openmic import Openmic2018DataModule 10 | from End2End.Task import DETR_IR 11 | from End2End.MIDI_program_map import (MIDI_PROGRAM_NUM, 12 | MIDIProgramName2class_idx, 13 | class_idx2MIDIProgramName, 14 | MIDI_Class_NUM, 15 | MIDIClassName2class_idx, 16 | class_idx2MIDIClass, 17 | W_MIDI_Class_NUM, 18 | W_MIDIClassName2class_idx, 19 | W_class_idx2MIDIClass, 20 | ) 21 | import End2End.models.detr as DETR_Model 22 | 23 | from jointist.config import ( 24 | BEGIN_NOTE, 25 | PLUGIN_LABELS_NUM, 26 | FRAMES_PER_SECOND, 27 | SAMPLE_RATE, 28 | SEGMENT_SECONDS, 29 | VELOCITY_SCALE, 30 | TAGGING_SEGMENT_SECONDS, 31 | PLUGIN_NAME_TO_INSTRUMENT, 32 | PLUGIN_LB_TO_IX, 33 | PLUGIN_IX_TO_LB 34 | ) 35 | 36 | from jointist.data.augmentors import Augmentor 37 | from jointist.lr_schedulers import get_lr_lambda 38 | # from jointist.models.instruments_classification_models import get_model_class 39 | 40 | # Libraries related to hydra 41 | import hydra 42 | from hydra.utils import to_absolute_path 43 | from omegaconf import OmegaConf 44 | 45 | @hydra.main(config_path="End2End/config/", config_name="openmic-DETR_Hungarian_IR") 46 | def main(cfg): 47 | 48 | cfg.datamodule.waveform_hdf5s_dir = to_absolute_path(os.path.join('hdf5s', 'openmic_waveforms')) 49 | 50 | if 
cfg.MIDI_MAPPING.type=='plugin_names': 51 | cfg.MIDI_MAPPING.plugin_labels_num = PLUGIN_LABELS_NUM 52 | cfg.MIDI_MAPPING.NAME_TO_IX = PLUGIN_LB_TO_IX 53 | cfg.MIDI_MAPPING.IX_TO_NAME = PLUGIN_IX_TO_LB 54 | cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes3/') 55 | elif cfg.MIDI_MAPPING.type=='MIDI_programs': 56 | cfg.MIDI_MAPPING.plugin_labels_num = MIDI_PROGRAM_NUM 57 | cfg.MIDI_MAPPING.NAME_TO_IX = MIDIProgramName2class_idx 58 | cfg.MIDI_MAPPING.IX_TO_NAME = class_idx2MIDIProgramName 59 | cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes_MIDI_instrument/') 60 | elif cfg.MIDI_MAPPING.type=='MIDI_class': 61 | cfg.MIDI_MAPPING.plugin_labels_num = MIDI_Class_NUM 62 | cfg.MIDI_MAPPING.NAME_TO_IX = MIDIClassName2class_idx 63 | cfg.MIDI_MAPPING.IX_TO_NAME = class_idx2MIDIClass 64 | cfg.datamodule.notes_pkls_dir = to_absolute_path('datasets/openmic-2018') 65 | elif cfg.MIDI_MAPPING.type=='W_MIDI_class': 66 | cfg.MIDI_MAPPING.plugin_labels_num = W_MIDI_Class_NUM 67 | cfg.MIDI_MAPPING.NAME_TO_IX = W_MIDIClassName2class_idx 68 | cfg.MIDI_MAPPING.IX_TO_NAME = W_class_idx2MIDIClass 69 | cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes_MIDI_class/') 70 | 71 | 72 | experiment_name = ( 73 | f"Decoder_L{cfg.model.args.num_decoder_layers}-" 74 | f"empty_{cfg.model.eos_coef}-" 75 | f"feature_weigh_{cfg.model.args.feature_weight}-" 76 | f"{cfg.model.type}-" 77 | f"hidden={cfg.model.args.hidden_dim}-" 78 | f"Q={cfg.model.args.num_Q}-" 79 | f"LearnPos={cfg.model.args.learnable_pos}-" 80 | f"aux_loss-bsz={cfg.batch_size}-" 81 | f"audio_len={cfg.segment_seconds}" 82 | ) 83 | 84 | 85 | 86 | # data module # augmentor 87 | augmentor = Augmentor(augmentation=cfg.augmentation) if cfg.augmentation else None 88 | data_module = Openmic2018DataModule(**cfg.datamodule, MIDI_MAPPING=cfg.MIDI_MAPPING) 89 | # data_module.setup() 90 | 91 | # model 92 | # Model = getattr(IR_Model, cfg.model.type) 93 | Model = getattr(DETR_Model, cfg.model.type) 94 | model = Model(num_classes=cfg.MIDI_MAPPING.plugin_labels_num, **cfg.model.args) 95 | # model = Model(classes_num=cfg.MIDI_MAPPING.plugin_labels_num) 96 | 97 | # PL model 98 | pl_model = DETR_IR.load_from_checkpoint(to_absolute_path(cfg.checkpoint_path), 99 | network=model, 100 | learning_rate=cfg.lr, 101 | lr_lambda=None, 102 | cfg=cfg) 103 | 104 | trainer = pl.Trainer( 105 | **cfg.trainer, 106 | ) 107 | 108 | 109 | # Fit, evaluate, and save checkpoints. 
110 | trainer.test(pl_model, data_module) 111 | 112 | if __name__ == "__main__": 113 | main() -------------------------------------------------------------------------------- /test_separation.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import os 3 | 4 | import torch 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.plugins import DDPPlugin 7 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 8 | 9 | from End2End.Data import DataModuleEnd2End, End2EndBatchDataPreprocessor, FullPreprocessor 10 | from End2End.tasks.separation import Separation 11 | import End2End.models.separation as SeparationModel 12 | 13 | from End2End.MIDI_program_map import ( 14 | MIDI_Class_NUM, 15 | MIDIClassName2class_idx, 16 | class_idx2MIDIClass, 17 | ) 18 | from End2End.data.augmentors import Augmentor 19 | from End2End.lr_schedulers import get_lr_lambda 20 | import End2End.losses as Losses 21 | 22 | # Libraries related to hydra 23 | import hydra 24 | from hydra.utils import to_absolute_path 25 | 26 | 27 | 28 | 29 | @hydra.main(config_path="End2End/config/", config_name="separation_config") 30 | def main(cfg): 31 | r"""Evaluate a source separation system from a saved checkpoint. 32 | 33 | Args: 34 | workspace: str, path 35 | config_yaml: str, path 36 | gpus: int 37 | mini_data: bool 38 | 39 | Returns: 40 | None 41 | """ 42 | 43 | cfg.datamodule.waveform_hdf5s_dir = to_absolute_path(os.path.join('hdf5s', 'waveforms')) 44 | 45 | if cfg.MIDI_MAPPING.type=='plugin_names': 46 | cfg.MIDI_MAPPING.plugin_labels_num = PLUGIN_LABELS_NUM 47 | cfg.MIDI_MAPPING.NAME_TO_IX = PLUGIN_LB_TO_IX 48 | cfg.MIDI_MAPPING.IX_TO_NAME = PLUGIN_IX_TO_LB 49 | cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes3/') 50 | elif cfg.MIDI_MAPPING.type=='MIDI_class': 51 | cfg.MIDI_MAPPING.plugin_labels_num = MIDI_Class_NUM 52 | cfg.MIDI_MAPPING.NAME_TO_IX = MIDIClassName2class_idx 53 | cfg.MIDI_MAPPING.IX_TO_NAME = class_idx2MIDIClass 54 | cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes_MIDI_class/') 55 | else: 56 | raise ValueError("Please choose the correct MIDI_MAPPING.type") 57 | 58 | Model = getattr(SeparationModel, cfg.separation.model.type) 59 | if cfg.separation.model.type=='CondUNet': 60 | model = Model(**cfg.separation.model.args) 61 | cfg.transcription = False 62 | elif cfg.separation.model.type=='TCondUNet': 63 | model = Model(**cfg.separation.model.args, spec_cfg=cfg.separation.feature) 64 | cfg.transcription = True 65 | else: 66 | raise ValueError("Please choose the correct model type") 67 | 68 | 69 | # augmentor 70 | augmentor = Augmentor(augmentation=cfg.augmentation) if cfg.augmentation else None 71 | 72 | # data module 73 | data_module = DataModuleEnd2End(**cfg.datamodule,augmentor=augmentor, MIDI_MAPPING=cfg.MIDI_MAPPING) 74 | data_module.setup('test') 75 | 76 | experiment_name = ( 77 | f"Eval-{cfg.separation.model.type}-" 78 | f"{cfg.MIDI_MAPPING.type}-" 79 | f"{cfg.inst_sampler.mode}_{cfg.inst_sampler.temp}_" 80 | f"{cfg.inst_sampler.samples}p_{cfg.inst_sampler.neg_samples}" 81 | f"noise{cfg.inst_sampler.audio_noise}-" 82 | f"csize={MIDI_Class_NUM}-" 83 | f"bz={cfg.batch_size}" 84 | ) 85 | DataPreprocessor = End2EndBatchDataPreprocessor 86 | # loss function 87 | loss_function = getattr(Losses, cfg.separation.model.loss_types) 88 | 89 | # callbacks 90 | # save checkpoint callback 91 | 92 | logger = 
pl.loggers.TensorBoardLogger(save_dir='.', name=experiment_name) 93 | 94 | # learning rate reduce function. 95 | lr_lambda = partial(get_lr_lambda, **cfg.scheduler.args) 96 | 97 | checkpoint_path = to_absolute_path(cfg.separation.evaluation.checkpoint_path) 98 | # pl_model = Separation.load_from_checkpoint(checkpoint_path, 99 | # network=model, 100 | # loss_function=loss_function, 101 | # lr_lambda=None, 102 | # batch_data_preprocessor=DataPreprocessor(**cfg.separation.batchprocess), 103 | # cfg=cfg 104 | # ) 105 | ckpt = torch.load(checkpoint_path) 106 | 107 | new_state_dict = {} 108 | for key in ckpt['state_dict'].keys(): 109 | if 'separation_model' in key: 110 | new_key = '.'.join(key.split('.')[2:]) 111 | new_state_dict[new_key] = ckpt['state_dict'][key] 112 | if 'network' in key: 113 | new_key = '.'.join(key.split('.')[1:]) 114 | new_state_dict[new_key] = ckpt['state_dict'][key] 115 | 116 | model.load_state_dict(new_state_dict) 117 | 118 | 119 | pl_model = Separation( 120 | network=model, 121 | loss_function=loss_function, 122 | lr_lambda=None, 123 | batch_data_preprocessor=DataPreprocessor(**cfg.separation.batchprocess), 124 | cfg=cfg 125 | ) 126 | 127 | if cfg.trainer.gpus==0: # If CPU is used, disable syncbatch norm 128 | cfg.trainer.sync_batchnorm=False 129 | 130 | trainer = pl.Trainer( 131 | **cfg.trainer, 132 | callbacks=None, 133 | plugins=[DDPPlugin(find_unused_parameters=False)], 134 | logger=logger 135 | ) 136 | 137 | # Evaluate the loaded checkpoint on the test set. 138 | trainer.test(pl_model, data_module) 139 | 140 | 141 | if __name__ == '__main__': 142 | main() 143 | -------------------------------------------------------------------------------- /test_tseparation.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import os 3 | 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.plugins import DDPPlugin 6 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 7 | 8 | from End2End.Data import DataModuleEnd2End, End2EndBatchDataPreprocessor 9 | from End2End.tasks.separation import Separation 10 | from End2End.tasks.transcription import Transcription 11 | from End2End.tasks.t_separation import TSeparation 12 | 13 | import End2End.models.separation as SeparationModel 14 | import End2End.models.transcription.combined as TranscriptionModel 15 | 16 | from End2End.MIDI_program_map import ( 17 | MIDI_Class_NUM, 18 | MIDIClassName2class_idx, 19 | class_idx2MIDIClass, 20 | ) 21 | from End2End.data.augmentors import Augmentor 22 | from End2End.lr_schedulers import get_lr_lambda 23 | from End2End.losses import get_loss_function 24 | import End2End.losses as Losses 25 | 26 | # Libraries related to hydra 27 | import hydra 28 | from hydra.utils import to_absolute_path 29 | 30 | 31 | 32 | 33 | @hydra.main(config_path="End2End/config/", config_name="tseparation") 34 | def main(cfg): 35 | r"""Evaluate the joint transcription and source separation (TSeparation) system from a saved checkpoint. 
36 | 37 | Args: 38 | workspace: str, path 39 | config_yaml: str, path 40 | gpus: int 41 | mini_data: bool 42 | 43 | Returns: 44 | None 45 | """ 46 | 47 | 48 | cfg.datamodule.waveform_hdf5s_dir = to_absolute_path(os.path.join('hdf5s', 'waveforms')) 49 | checkpoint_path = to_absolute_path(cfg.checkpoint_path) 50 | 51 | cfg.MIDI_MAPPING.plugin_labels_num = MIDI_Class_NUM 52 | cfg.MIDI_MAPPING.NAME_TO_IX = MIDIClassName2class_idx 53 | cfg.MIDI_MAPPING.IX_TO_NAME = class_idx2MIDIClass 54 | cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes_MIDI_class/') 55 | 56 | experiment_name = ( 57 | f"Eval-TSeparation-" 58 | f"{cfg.inst_sampler.samples}p{cfg.inst_sampler.neg_samples}n-" 59 | f"ste_roll" 60 | ) 61 | 62 | # augmentor 63 | augmentor = Augmentor(augmentation=cfg.augmentation) if cfg.augmentation else None 64 | 65 | # data module 66 | data_module = DataModuleEnd2End(**cfg.datamodule,augmentor=augmentor, MIDI_MAPPING=cfg.MIDI_MAPPING) 67 | data_module.setup('test') 68 | 69 | lr_lambda = partial(get_lr_lambda, **cfg.scheduler.args) 70 | 71 | 72 | # loss function 73 | loss_function = getattr(Losses, cfg.separation.model.loss_types) 74 | model = getattr(SeparationModel, cfg.separation.model.type)\ 75 | (**cfg.separation.model.args, spec_cfg=cfg.separation.feature) 76 | 77 | separation_model = Separation( 78 | network=model, 79 | loss_function=loss_function, 80 | lr_lambda=lr_lambda, 81 | batch_data_preprocessor=None, 82 | cfg=cfg 83 | ) 84 | 85 | # defining transcription model 86 | Model = getattr(TranscriptionModel, cfg.transcription.model.type) 87 | model = Model(cfg, **cfg.transcription.model.args) 88 | loss_function = get_loss_function(cfg.transcription.model.loss_types) 89 | 90 | 91 | transcription_model = Transcription( 92 | network=model, 93 | loss_function=loss_function, 94 | lr_lambda=lr_lambda, 95 | batch_data_preprocessor=None, 96 | cfg=cfg 97 | ) 98 | 99 | 100 | # defining jointist 101 | tseparation = TSeparation.load_from_checkpoint( 102 | checkpoint_path, 103 | transcription_model = transcription_model, 104 | separation_model = separation_model, 105 | batch_data_preprocessor = End2EndBatchDataPreprocessor(cfg.MIDI_MAPPING, 106 | **cfg.inst_sampler, 107 | transcription=True, 108 | source_separation=True), 109 | lr_lambda=lr_lambda, 110 | cfg=cfg 111 | ) 112 | 113 | 114 | # defining Trainer 115 | logger = pl.loggers.TensorBoardLogger(save_dir='.', name=experiment_name) 116 | trainer = pl.Trainer( 117 | **cfg.trainer, 118 | plugins=[DDPPlugin(find_unused_parameters=False)], 119 | logger=logger 120 | ) 121 | 122 | # Fit, evaluate, and save checkpoints. 
123 | trainer.test(tseparation, data_module) 124 | 125 | 126 | if __name__ == '__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /train_jointist.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import os 3 | 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.plugins import DDPPlugin 6 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 7 | 8 | from End2End.Data import DataModuleEnd2End, End2EndBatchDataPreprocessor 9 | from End2End.tasks import Jointist, Transcription, Detection 10 | import End2End.models.instrument_detection as DetectionModel 11 | from End2End.models.transcription.instruments_filter_models import get_model_class 12 | 13 | from End2End.MIDI_program_map import ( 14 | MIDI_Class_NUM, 15 | MIDIClassName2class_idx, 16 | class_idx2MIDIClass, 17 | ) 18 | from End2End.data.augmentors import Augmentor 19 | from End2End.lr_schedulers import get_lr_lambda 20 | from End2End.losses import get_loss_function 21 | 22 | # Libraries related to hydra 23 | import hydra 24 | from hydra.utils import to_absolute_path 25 | 26 | 27 | 28 | 29 | @hydra.main(config_path="End2End/config/", config_name="Jointist") 30 | def main(cfg): 31 | r"""Train the Jointist system (instrument detection + transcription), evaluate, and save checkpoints. 32 | 33 | Args: 34 | workspace: str, path 35 | config_yaml: str, path 36 | gpus: int 37 | mini_data: bool 38 | 39 | Returns: 40 | None 41 | """ 42 | 43 | cfg.datamodule.waveform_hdf5s_dir = to_absolute_path(os.path.join('hdf5s', 'waveforms')) 44 | 45 | cfg.MIDI_MAPPING.plugin_labels_num = MIDI_Class_NUM 46 | cfg.MIDI_MAPPING.NAME_TO_IX = MIDIClassName2class_idx 47 | cfg.MIDI_MAPPING.IX_TO_NAME = class_idx2MIDIClass 48 | cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes_MIDI_class/') 49 | 50 | experiment_name = f"jointist-gpus={cfg.gpus}-bsz={cfg.batch_size}-seconds={cfg.segment_seconds}" 51 | 52 | # augmentor 53 | augmentor = Augmentor(augmentation=cfg.augmentation) if cfg.augmentation else None 54 | 55 | # data module 56 | data_module = DataModuleEnd2End(**cfg.datamodule,augmentor=augmentor, MIDI_MAPPING=cfg.MIDI_MAPPING) 57 | 58 | lr_lambda = partial(get_lr_lambda, **cfg.scheduler.args) 59 | # defining transcription model 60 | Model = get_model_class(cfg.transcription.model.type) 61 | model = Model(cfg.feature, **cfg.transcription.model.args) 62 | loss_function = get_loss_function(cfg.transcription.model.loss_types) 63 | transcription_model = Transcription( 64 | network=model, 65 | loss_function=loss_function, 66 | lr_lambda=lr_lambda, 67 | batch_data_preprocessor=End2EndBatchDataPreprocessor(cfg.MIDI_MAPPING, 'random'), 68 | cfg=cfg 69 | ) 70 | 71 | 72 | # defining instrument detection model 73 | Model = getattr(DetectionModel, cfg.detection.model.type) 74 | model = Model(num_classes=cfg.MIDI_MAPPING.plugin_labels_num, spec_args=cfg.feature, **cfg.detection.model.args) 75 | lr_lambda = partial(get_lr_lambda, **cfg.scheduler.args) 76 | detection_model = Detection( 77 | network=model, 78 | lr_lambda=lr_lambda, 79 | cfg=cfg 80 | ) 81 | 82 | # defining jointist 83 | jointist = Jointist( 84 | detection_model=detection_model, 85 | transcription_model=transcription_model, 86 | lr_lambda=lr_lambda, 87 | cfg=cfg 88 | ) 89 | 90 | 91 | # defining Trainer 92 | checkpoint_callback = ModelCheckpoint(**cfg.checkpoint, auto_insert_metric_name=False) 93 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 
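# (Note: ModelCheckpoint saves checkpoints according to the monitor/top-k settings
#  passed in via cfg.checkpoint, and LearningRateMonitor logs the learning-rate
#  schedule to the TensorBoard logger once per epoch.)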
94 | callbacks = [checkpoint_callback, lr_monitor] 95 | logger = pl.loggers.TensorBoardLogger(save_dir='.', name=experiment_name) 96 | trainer = pl.Trainer( 97 | **cfg.trainer, 98 | callbacks=callbacks, 99 | plugins=[DDPPlugin(find_unused_parameters=True)], 100 | logger=logger 101 | ) 102 | 103 | # Fit, evaluate, and save checkpoints. 104 | trainer.fit(jointist, data_module) 105 | 106 | 107 | if __name__ == '__main__': 108 | main() 109 | -------------------------------------------------------------------------------- /train_separation.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import os 3 | 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.plugins import DDPPlugin 6 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 7 | 8 | from End2End.Data import DataModuleEnd2End, End2EndBatchDataPreprocessor, FullPreprocessor 9 | from End2End.tasks.separation import Separation 10 | import End2End.models.separation as SeparationModel 11 | 12 | from End2End.MIDI_program_map import ( 13 | MIDI_Class_NUM, 14 | MIDIClassName2class_idx, 15 | class_idx2MIDIClass, 16 | ) 17 | from End2End.data.augmentors import Augmentor 18 | from End2End.lr_schedulers import get_lr_lambda 19 | import End2End.losses as Losses 20 | 21 | # Libraries related to hydra 22 | import hydra 23 | from hydra.utils import to_absolute_path 24 | 25 | 26 | 27 | 28 | @hydra.main(config_path="End2End/config/", config_name="separation_config") 29 | def main(cfg): 30 | r"""Train a source separation system, evaluate, and save checkpoints. 31 | 32 | Args: 33 | workspace: str, path 34 | config_yaml: str, path 35 | gpus: int 36 | mini_data: bool 37 | 38 | Returns: 39 | None 40 | """ 41 | 42 | cfg.datamodule.waveform_hdf5s_dir = to_absolute_path(os.path.join('hdf5s', 'waveforms')) 43 | 44 | if cfg.MIDI_MAPPING.type=='plugin_names': 45 | cfg.MIDI_MAPPING.plugin_labels_num = PLUGIN_LABELS_NUM 46 | cfg.MIDI_MAPPING.NAME_TO_IX = PLUGIN_LB_TO_IX 47 | cfg.MIDI_MAPPING.IX_TO_NAME = PLUGIN_IX_TO_LB 48 | cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes3/') 49 | elif cfg.MIDI_MAPPING.type=='MIDI_class': 50 | cfg.MIDI_MAPPING.plugin_labels_num = MIDI_Class_NUM 51 | cfg.MIDI_MAPPING.NAME_TO_IX = MIDIClassName2class_idx 52 | cfg.MIDI_MAPPING.IX_TO_NAME = class_idx2MIDIClass 53 | cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes_MIDI_class/') 54 | else: 55 | raise ValueError("Please choose the correct MIDI_MAPPING.type") 56 | 57 | Model = getattr(SeparationModel, cfg.separation.model.type) 58 | if cfg.separation.model.type=='CondUNet': 59 | model = Model(**cfg.separation.model.args) 60 | cfg.transcription = False 61 | elif cfg.separation.model.type=='TCondUNet': 62 | model = Model(**cfg.separation.model.args, spec_cfg=cfg.feature) 63 | cfg.transcription = True 64 | else: 65 | raise ValueError("Please choose the correct model type") 66 | 67 | 68 | # augmentor 69 | augmentor = Augmentor(augmentation=cfg.augmentation) if cfg.augmentation else None 70 | 71 | # data module 72 | data_module = DataModuleEnd2End(**cfg.datamodule,augmentor=augmentor, MIDI_MAPPING=cfg.MIDI_MAPPING) 73 | data_module.setup() 74 | 75 | experiment_name = ( 76 | f"{cfg.separation.model.type}-" 77 | f"{cfg.inst_sampler.samples}p{cfg.inst_sampler.neg_samples}n-" 78 | f"csize={MIDI_Class_NUM}" 79 | ) 80 | DataPreprocessor = End2EndBatchDataPreprocessor 81 | # loss function 82 | loss_function = 
getattr(Losses, cfg.separation.model.loss_types) 83 | 84 | # callbacks 85 | # save checkpoint callback 86 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 87 | checkpoint_callback = ModelCheckpoint(**cfg.checkpoint, 88 | auto_insert_metric_name=False) 89 | callbacks = [checkpoint_callback, lr_monitor] 90 | 91 | logger = pl.loggers.TensorBoardLogger(save_dir='.', name=experiment_name) 92 | 93 | # learning rate reduce function. 94 | lr_lambda = partial(get_lr_lambda, **cfg.scheduler.args) 95 | 96 | 97 | pl_model = Separation( 98 | network=model, 99 | loss_function=loss_function, 100 | lr_lambda=lr_lambda, 101 | batch_data_preprocessor=DataPreprocessor(**cfg.separation.batchprocess), 102 | cfg=cfg 103 | ) 104 | 105 | if cfg.trainer.gpus==0: # If CPU is used, disable syncbatch norm 106 | cfg.trainer.sync_batchnorm=False 107 | 108 | trainer = pl.Trainer( 109 | **cfg.trainer, 110 | callbacks=callbacks, 111 | plugins=[DDPPlugin(find_unused_parameters=False)], 112 | logger=logger 113 | ) 114 | 115 | # Fit, evaluate, and save checkpoints. 116 | trainer.fit(pl_model, data_module) 117 | trainer.test(pl_model, data_module.test_dataloader()) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /train_tseparation.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import os 3 | 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.plugins import DDPPlugin 6 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 7 | 8 | from End2End.Data import DataModuleEnd2End, End2EndBatchDataPreprocessor 9 | from End2End.tasks.separation import Separation 10 | from End2End.tasks.transcription import Transcription 11 | from End2End.tasks.t_separation import TSeparation 12 | 13 | import End2End.models.separation as SeparationModel 14 | import End2End.models.transcription.combined as TranscriptionModel 15 | 16 | from End2End.MIDI_program_map import ( 17 | MIDI_Class_NUM, 18 | MIDIClassName2class_idx, 19 | class_idx2MIDIClass, 20 | ) 21 | from End2End.data.augmentors import Augmentor 22 | from End2End.lr_schedulers import get_lr_lambda 23 | from End2End.losses import get_loss_function 24 | import End2End.losses as Losses 25 | 26 | # Libraries related to hydra 27 | import hydra 28 | from hydra.utils import to_absolute_path 29 | 30 | 31 | 32 | 33 | @hydra.main(config_path="End2End/config/", config_name="tseparation") 34 | def main(cfg): 35 | r"""Train the joint transcription and source separation (TSeparation) system, evaluate, and save checkpoints. 
36 | 37 | Args: 38 | workspace: str, path 39 | config_yaml: str, path 40 | gpus: int 41 | mini_data: bool 42 | 43 | Returns: 44 | None 45 | """ 46 | 47 | 48 | cfg.datamodule.waveform_hdf5s_dir = to_absolute_path(os.path.join('hdf5s', 'waveforms')) 49 | if cfg.trainer.resume_from_checkpoint: # resume previous training when this is given 50 | cfg.trainer.resume_from_checkpoint = to_absolute_path(cfg.trainer.resume_from_checkpoint) 51 | 52 | cfg.MIDI_MAPPING.plugin_labels_num = MIDI_Class_NUM 53 | cfg.MIDI_MAPPING.NAME_TO_IX = MIDIClassName2class_idx 54 | cfg.MIDI_MAPPING.IX_TO_NAME = class_idx2MIDIClass 55 | cfg.datamodule.notes_pkls_dir = to_absolute_path('instruments_classification_notes_MIDI_class/') 56 | 57 | if cfg.transcription_weights!=None: 58 | experiment_name = ( 59 | f"TSeparation-{cfg.inst_sampler.samples}p{cfg.inst_sampler.neg_samples}n-" 60 | f"ste_roll-pretrainedT" 61 | ) 62 | else: 63 | experiment_name = ( 64 | f"TSeparation-{cfg.inst_sampler.samples}p{cfg.inst_sampler.neg_samples}n-" 65 | f"ste_roll" 66 | ) 67 | 68 | # augmentor 69 | augmentor = Augmentor(augmentation=cfg.augmentation) if cfg.augmentation else None 70 | 71 | # data module 72 | data_module = DataModuleEnd2End(**cfg.datamodule,augmentor=augmentor, MIDI_MAPPING=cfg.MIDI_MAPPING) 73 | data_module.setup() 74 | 75 | lr_lambda = partial(get_lr_lambda, **cfg.scheduler.args) 76 | 77 | 78 | # loss function 79 | loss_function = getattr(Losses, cfg.separation.model.loss_types) 80 | model = getattr(SeparationModel, cfg.separation.model.type)\ 81 | (**cfg.separation.model.args, spec_cfg=cfg.separation.feature) 82 | 83 | separation_model = Separation( 84 | network=model, 85 | loss_function=loss_function, 86 | lr_lambda=lr_lambda, 87 | batch_data_preprocessor=None, 88 | cfg=cfg 89 | ) 90 | 91 | # defining transcription model 92 | Model = getattr(TranscriptionModel, cfg.transcription.model.type) 93 | model = Model(cfg, **cfg.transcription.model.args) 94 | loss_function = get_loss_function(cfg.transcription.model.loss_types) 95 | 96 | if cfg.transcription_weights!=None: 97 | checkpoint_path = to_absolute_path(cfg.transcription_weights) 98 | transcription_model = Transcription.load_from_checkpoint(checkpoint_path, 99 | network=model, 100 | loss_function=loss_function, 101 | lr_lambda=lr_lambda, 102 | batch_data_preprocessor=None, 103 | cfg=cfg) 104 | else: 105 | transcription_model = Transcription( 106 | network=model, 107 | loss_function=loss_function, 108 | lr_lambda=lr_lambda, 109 | batch_data_preprocessor=None, 110 | cfg=cfg 111 | ) 112 | 113 | 114 | # defining jointist 115 | tseparation = TSeparation( 116 | transcription_model = transcription_model, 117 | separation_model = separation_model, 118 | batch_data_preprocessor = End2EndBatchDataPreprocessor(cfg.MIDI_MAPPING, 119 | **cfg.inst_sampler, 120 | transcription=True, 121 | source_separation=True), 122 | lr_lambda=lr_lambda, 123 | cfg=cfg 124 | ) 125 | 126 | 127 | # defining Trainer 128 | checkpoint_callback = ModelCheckpoint(**cfg.checkpoint, auto_insert_metric_name=False) 129 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 130 | callbacks = [checkpoint_callback, lr_monitor] 131 | logger = pl.loggers.TensorBoardLogger(save_dir='.', name=experiment_name) 132 | trainer = pl.Trainer( 133 | **cfg.trainer, 134 | callbacks=callbacks, 135 | plugins=[DDPPlugin(find_unused_parameters=False)], 136 | logger=logger, 137 | ) 138 | 139 | # Fit, evaluate, and save checkpoints. 
140 | trainer.fit(tseparation, data_module) 141 | trainer.test(tseparation, data_module) 142 | 143 | if __name__ == '__main__': 144 | main() 145 | -------------------------------------------------------------------------------- /weights/link.txt: -------------------------------------------------------------------------------- 1 | https://www.dropbox.com/s/n0eerriphw65qsr/jointist_weights.zip --------------------------------------------------------------------------------
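weights/link.txt above points to a zip archive of pretrained Jointist weights. A minimal fetch-and-unpack sketch, assuming the Dropbox link allows a direct download (Dropbox links usually need the dl=1 query parameter for that) and that the archive is meant to be extracted into the repository's weights/ folder:

# download and unpack the pretrained weights; the dl=1 parameter and the target folder are assumptions
wget -O jointist_weights.zip "https://www.dropbox.com/s/n0eerriphw65qsr/jointist_weights.zip?dl=1"
unzip jointist_weights.zip -d weights/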