├── .gitignore ├── README.rst ├── align_dataset.py ├── align_real_world.py ├── confidence_experiment.py ├── corrupt_midi.py ├── create_data.py ├── db_utils.py ├── find_best_aligners.py ├── overview.ipynb ├── parameter_experiment_gp.py └── parameter_experiment_random.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by http://www.gitignore.io 2 | 3 | ### Python ### 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # Installer logs 28 | pip-log.txt 29 | pip-delete-this-directory.txt 30 | 31 | # Unit test / coverage reports 32 | htmlcov/ 33 | .tox/ 34 | .coverage 35 | .cache 36 | nosetests.xml 37 | coverage.xml 38 | 39 | # Translations 40 | *.mo 41 | *.pot 42 | 43 | # Django stuff: 44 | *.log 45 | 46 | # Sphinx documentation 47 | docs/_build/ 48 | 49 | 50 | ### OSX ### 51 | .DS_Store 52 | .AppleDouble 53 | .LSOverride 54 | 55 | # Icon must end with two \r 56 | Icon 57 | 58 | 59 | # Thumbnails 60 | ._* 61 | 62 | # Files that might appear on external disk 63 | .Spotlight-V100 64 | .Trashes 65 | 66 | # Directories potentially created on remote AFP share 67 | .AppleDB 68 | .AppleDesktop 69 | Network Trash Folder 70 | Temporary Items 71 | .apdisk 72 | 73 | ### vim ### 74 | [._]*.s[a-w][a-z] 75 | [._]s[a-w][a-z] 76 | *.un~ 77 | Session.vim 78 | .netrwhist 79 | *~ 80 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | alignment-search 2 | ================ 3 | 4 | Code for searching for good MIDI alignment parameters. Incomplete dependency list: 5 | 6 | * `pretty_midi `_ 7 | * `djitw `_ 8 | * `librosa `_ 9 | * `joblib `_ 10 | * `craffel/spearmint/main_args `_ 11 | * Scipy/numpy 12 | 13 | Tentative general procedure: 14 | 15 | #. Collect 1000 non-corrupt MIDI files 16 | #. Create a corrupted version for each MIDI file, which reflects common minor issues found in audio-to-audio alignment (see create_data.py and corrupt_midi.py) 17 | #. Run GP-SMBO hyperparameter optimization over the standard DTW-based MIDI alignment scheme to choose the alignment architecture which best turns the lightly-corrupted MIDI files back into the original files (see parameter_experiment_random.py and parameter_experiment_gp.py) 18 | #. Create a second corrupted version for each MIDI file, which reflects quite major corruptions which we don't expect an alignment scheme to fix in all cases 19 | #. Run hyperparameter optimization again, and either jointly optimize the alignment performance and a confidence score, OR optimize alignment performance and then find the score calculation scheme which produces the highest Spearman rank coefficient between the alignment score and the error of each alignment, over all highly-performing alignment architectures 20 | #. 
Collect real-world audio/MIDI alignment pairs and run the best-performing alignment architecture on them, manually annotate whether it was successful or not, and find the ROC-AUC score for the confidence measure -------------------------------------------------------------------------------- /align_dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code for aligning an entire dataset 3 | ''' 4 | 5 | import glob 6 | import scipy.spatial, scipy.stats 7 | import librosa 8 | import os 9 | import numpy as np 10 | import create_data 11 | import djitw 12 | import collections 13 | 14 | 15 | def load_dataset(file_glob): 16 | """Load in a collection of feature files created by create_data.py. 17 | 18 | Parameters 19 | ---------- 20 | file_glob : str 21 | Glob string for .npz files to load. 22 | 23 | Returns 24 | ------- 25 | data : list of dict 26 | Loaded dataset, sorted by filename. 27 | """ 28 | # Load in all npz's, cast to dict to force full loading 29 | return [dict(feature_file=os.path.abspath(d), **np.load(d)) 30 | for d in sorted(glob.glob(file_glob))] 31 | 32 | 33 | def align_dataset(params, data): 34 | ''' 35 | Perform alignment of all corrupted MIDIs in the database given the supplied 36 | parameters and compute the mean alignment error across all examples 37 | 38 | Parameters 39 | ---------- 40 | params : dict 41 | Dictionary of alignment parameters. 42 | 43 | data : list of dict 44 | Collection of things to align, loaded via load_dataset. 45 | 46 | Returns 47 | ------- 48 | results : dict of list 49 | Dict mapping result names to lists, with one entry per alignment 50 | ''' 51 | def post_process_features(gram, beats): 52 | ''' 53 | Apply processing to a feature matrix given the supplied param values 54 | 55 | Parameters 56 | ---------- 57 | gram : np.ndarray 58 | Feature matrix, shape (n_features, n_samples) 59 | beats : np.ndarray 60 | Indices of beat locations in gram 61 | 62 | Returns 63 | ------- 64 | gram : np.ndarray 65 | Feature matrix, shape (n_samples, n_features), post-processed 66 | according to the values in `params` 67 | ''' 68 | # Convert to chroma 69 | if params['feature'] == 'chroma': 70 | gram = librosa.feature.chroma_cqt( 71 | C=gram, fmin=librosa.midi_to_hz(create_data.NOTE_START)) 72 | # Beat-synchronize the feature matrix 73 | if params['beat_sync']: 74 | gram = librosa.feature.sync(gram, beats, pad=False) 75 | # Compute log magnitude 76 | gram = librosa.logamplitude(gram, ref_power=gram.max()) 77 | # Normalize the feature vectors 78 | gram = librosa.util.normalize(gram, norm=params['norm']) 79 | # Standardize the feature vectors 80 | if params['standardize']: 81 | gram = scipy.stats.mstats.zscore(gram, axis=1) 82 | # Transpose it to (n_samples, n_features) and return it 83 | return gram.T 84 | # Dict of lists for storing the results of each alignment 85 | results = collections.defaultdict(list) 86 | for n, d in enumerate(data): 87 | # If we are beat syncing and either of the beat frame arrays is empty, 88 | # we can't really align, so just skip this file. 
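# (If beat tracking found no beats, librosa.feature.sync with pad=False would presumably return an empty beat-synchronous matrix, leaving nothing to align; this guard is an assumption about that degenerate case.)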
89 | if params['beat_sync'] and (d['orig_beat_frames'].size == 0 or 90 | d['corrupted_beat_frames'].size == 0): 91 | continue 92 | # Post process the chosen feature matrices 93 | orig_gram = post_process_features( 94 | d['orig_gram'], d['orig_beat_frames']) 95 | corrupted_gram = post_process_features( 96 | d['corrupted_gram'], d['corrupted_beat_frames']) 97 | # Compute a distance matrix according to the supplied metric 98 | distance_matrix = scipy.spatial.distance.cdist( 99 | orig_gram, corrupted_gram, params['metric']) 100 | # If the entire distance matrix is non-finite, we can't align, skip 101 | if not np.any(np.isfinite(distance_matrix)): 102 | continue 103 | # Set any NaN/inf values to the largest distance 104 | distance_matrix[np.logical_not(np.isfinite(distance_matrix))] = np.max( 105 | distance_matrix[np.isfinite(distance_matrix)]) 106 | # Compute a band mask or set to None for no mask 107 | if params['band_mask']: 108 | mask = np.zeros(distance_matrix.shape, dtype=np.bool) 109 | djitw.band_mask(1 - params['gully'], mask) 110 | else: 111 | mask = None 112 | # Get DTW path and score 113 | add_pen = params['add_pen']*np.median(distance_matrix) 114 | p, q, score = djitw.dtw( 115 | distance_matrix, params['gully'], add_pen, mask=mask, inplace=0) 116 | if params['beat_sync']: 117 | # If we are beat syncing, we have to compare against beat times 118 | # so we index adjusted_times by the beat indices 119 | adjusted_times = d['adjusted_times'][d['orig_beat_frames']] 120 | corrupted_times = d['corrupted_beat_times'] 121 | else: 122 | corrupted_times = d['corrupted_times'] 123 | adjusted_times = d['adjusted_times'] 124 | # Compute the error, clipped to within .5 seconds 125 | error = np.clip( 126 | corrupted_times[q] - adjusted_times[p], -.5, .5) 127 | # Compute the mean error for this MIDI 128 | mean_error = np.mean(np.abs(error)) 129 | # If the mean error is NaN or inf for some reason, set it to max (.5) 130 | if not np.isfinite(mean_error): 131 | mean_error = .5 132 | results['mean_errors'].append(mean_error) 133 | results['raw_scores'].append(score) 134 | results['raw_scores_no_penalty'].append(distance_matrix[p, q].sum()) 135 | results['path_lengths'].append(p.shape[0]) 136 | results['distance_matrix_means'].append(np.mean( 137 | distance_matrix[p.min():p.max() + 1, q.min():q.max() + 1])) 138 | results['feature_files'].append(os.path.basename(d['feature_file'])) 139 | return results 140 | -------------------------------------------------------------------------------- /align_real_world.py: -------------------------------------------------------------------------------- 1 | """ Align all of the real-world data. """ 2 | import djitw 3 | import numpy as np 4 | import pretty_midi 5 | import librosa 6 | import create_data 7 | import json 8 | import scipy.spatial 9 | import os 10 | import joblib 11 | 12 | GULLY = .96 13 | REAL_WORLD_PATH = 'data/real_world' 14 | 15 | 16 | def process_one_file(audio_file, midi_file, output_midi_file, pair_file, 17 | diagnostics_file): 18 | """ 19 | Wrapper routine for loading in audio/MIDI data, aligning, and writing 20 | out the result. 21 | 22 | Parameters 23 | ---------- 24 | audio_file, midi_file, output_midi_file, pair_file, diagnostics_file : str 25 | Paths to the audio file to align, MIDI file to align, and paths where 26 | to write the aligned MIDI, the synthesized pair file, and the DTW 27 | diagnostics file. 
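The pair file is a stereo wav with the aligned MIDI synthesis in one channel and the original audio in the other, for quick listening checks.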
28 | """ 29 | # Load in the audio data 30 | audio_data, _ = librosa.load(audio_file, sr=create_data.FS) 31 | # Compute the log-magnitude CQT of the data 32 | audio_cqt, audio_times = create_data.extract_cqt(audio_data) 33 | audio_cqt = librosa.logamplitude(audio_cqt, ref_power=audio_cqt.max()).T 34 | # Load and synthesize MIDI data 35 | midi_object = pretty_midi.PrettyMIDI(midi_file) 36 | midi_audio = midi_object.fluidsynth(fs=create_data.FS) 37 | # Compute log-magnitude CQT 38 | midi_cqt, midi_times = create_data.extract_cqt(midi_audio) 39 | midi_cqt = librosa.logamplitude(midi_cqt, ref_power=midi_cqt.max()).T 40 | # Compute cosine distance matrix 41 | distance_matrix = scipy.spatial.distance.cdist( 42 | midi_cqt, audio_cqt, 'cosine') 43 | # Get lowest cost path 44 | p, q, score = djitw.dtw( 45 | distance_matrix, GULLY, np.median(distance_matrix), inplace=False) 46 | # Normalize by path length 47 | score = score/len(p) 48 | # Normalize by distance matrix submatrix within path 49 | score = score/distance_matrix[p.min():p.max() + 1, q.min():q.max() + 1].mean() 50 | # Adjust the MIDI file 51 | midi_object.adjust_times(midi_times[p], audio_times[q]) 52 | # Write the result 53 | midi_object.write(output_midi_file) 54 | # Synthesize aligned MIDI 55 | midi_audio_aligned = midi_object.fluidsynth(fs=create_data.FS) 56 | # Trim or zero-pad the aligned MIDI audio to match the audio's length 57 | if midi_audio_aligned.shape[0] > audio_data.shape[0]: 58 | midi_audio_aligned = midi_audio_aligned[:audio_data.shape[0]] 59 | else: 60 | pad_amount = audio_data.shape[0] - midi_audio_aligned.shape[0] 61 | midi_audio_aligned = np.append(midi_audio_aligned, 62 | np.zeros(pad_amount)) 63 | # Stack one in each channel 64 | librosa.output.write_wav( 65 | pair_file, np.array([midi_audio_aligned, audio_data]), create_data.FS) 66 | # Write out diagnostics 67 | with open(diagnostics_file, 'wb') as f: 68 | json.dump({'p': list(p), 'q': list(q), 'score': score}, f) 69 | 70 | if __name__ == '__main__': 71 | # Utility function for getting lists of all files of a certain type 72 | def get_file_list(extension): 73 | return [os.path.join(REAL_WORLD_PATH, '{}{}'.format(n, extension)) 74 | for n in range(1000)] 75 | # Construct lists of each type of file 76 | audios = get_file_list('.mp3') 77 | mids = get_file_list('.mid') 78 | out_mids = get_file_list('-aligned.mid') 79 | pairs = get_file_list('-pair.wav') 80 | diags = get_file_list('-diagnostics.js') 81 | # Process each file from each list in parallel 82 | joblib.Parallel(n_jobs=20, verbose=51)( 83 | joblib.delayed(process_one_file)(audio, mid, output_mid, pair, diag) 84 | for (audio, mid, output_mid, pair, diag) 85 | in zip(audios, mids, out_mids, pairs, diags)) 86 | -------------------------------------------------------------------------------- /confidence_experiment.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Run the confidence score search experiment. 
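For each parameter setting from the GP search, this aligns the "hard" corrupted dataset and records candidate confidence scores (raw DTW scores with and without the additive penalty, optionally normalized by path length and/or distance matrix mean) alongside the per-song errors.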
3 | ''' 4 | 5 | import numpy as np 6 | import db_utils 7 | import scipy.stats 8 | import joblib 9 | import align_dataset 10 | try: 11 | import ujson as json 12 | except ImportError: 13 | import json 14 | import glob 15 | import os 16 | 17 | # Path to corrupted datasets, created by create_data.py 18 | CORRUPTED_HARD = 'data/corrupted_hard/*.npz' 19 | # Path to json results for parameter experiment 20 | PARAMETER_RESULTS_GLOB = 'results/parameter_experiment_gp/*.json' 21 | # Path where confidence experiment results should be written 22 | OUTPUT_RESULTS_PATH = 'results/confidence_experiment/' 23 | 24 | 25 | def check_trials(best_errors, parameter_trials): 26 | """Compute confidence score effectiveness for a set of parameter settings. 27 | 28 | Parameters 29 | ---------- 30 | best_errors : np.ndarray 31 | Array of the per-song errors resulting from the best aligner 32 | parameter_trials : list 33 | List of hyperparameter search trials on the "easy" dataset 34 | """ 35 | # Load in corrupted MIDI datasets 36 | hard_dataset = align_dataset.load_dataset(CORRUPTED_HARD) 37 | # Grab objective values for each trial 38 | objectives = [np.mean(r['results']['mean_errors']) 39 | for r in parameter_trials] 40 | # Sort the parameter settings by their objective value 41 | parameter_trials = [parameter_trials[n] for n in np.argsort(objectives)] 42 | for trial in parameter_trials: 43 | easy_results = trial['results'] 44 | # Retrieve the errors for each song 45 | easy_errors = np.array(easy_results['mean_errors']) 46 | # Run a paired difference test of the errors, i.e. test whether the 47 | # distribution of differences between best_errors[n] and 48 | # easy_errors[n] is significantly different from 0 under a t-test 49 | _, r_score = scipy.stats.ttest_1samp(best_errors - easy_errors, 0) 50 | # When best_errors = easy_errors, the r_score will be NaN 51 | if np.isnan(r_score): 52 | r_score = 1.
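# (scipy.stats.ttest_1samp returns (statistic, p-value), so r_score is a p-value: values near 1 mean this trial's errors are statistically indistinguishable from the best trial's.)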
53 | # Replace 'norm' param with numeric infinity if it's 'inf' 54 | params = trial['params'] 55 | if params['norm'] == str(np.inf): 56 | params['norm'] = np.inf 57 | # Align the hard dataset using these params and retrieve errors 58 | hard_results = align_dataset.align_dataset(params, hard_dataset) 59 | hard_errors = np.array(hard_results['mean_errors']) 60 | # Create results dict, storing the r_score and errors, plus the 61 | result = dict(r_score=r_score, 62 | easy_errors=easy_errors.tolist(), 63 | hard_errors=hard_errors.tolist()) 64 | # normalized scores computed below, trying all combinations of 65 | for include_pen in [0, 1]: 66 | for length_normalize in [0, 1]: 67 | for mean_normalize in [0, 1]: 68 | for results, name in zip([hard_results, easy_results], 69 | ['hard', 'easy']): 70 | # Retrieve the score with or without penalties included 71 | if include_pen: 72 | scores = np.array(results['raw_scores']) 73 | name += '_penalty' 74 | else: 75 | scores = np.array(results['raw_scores_no_penalty']) 76 | name += '_no_penalty' 77 | # Optionally normalize by path length 78 | if length_normalize: 79 | scores /= np.array(results['path_lengths']) 80 | name += '_len_norm' 81 | # Optionally normalize by distance matrix mean 82 | if mean_normalize: 83 | scores /= np.array( 84 | results['distance_matrix_means']) 85 | name += '_mean_norm' 86 | # Store the scores 87 | result[name + '_scores'] = scores.tolist() 88 | # Write out this result 89 | db_utils.dump_result(params, result, OUTPUT_RESULTS_PATH) 90 | 91 | if __name__ == '__main__': 92 | if not os.path.exists(OUTPUT_RESULTS_PATH): 93 | os.makedirs(OUTPUT_RESULTS_PATH) 94 | # Load in all parameter search experiment results 95 | parameter_trials = [] 96 | for result_file in glob.glob(PARAMETER_RESULTS_GLOB): 97 | with open(result_file) as f: 98 | parameter_trials.append(json.load(f)) 99 | # Grab objective values for each trial 100 | objectives = [np.mean(r['results']['mean_errors']) 101 | for r in parameter_trials] 102 | best_errors = np.array( 103 | parameter_trials[np.argmin(objectives)]['results']['mean_errors']) 104 | # Split up the parameter trials into 10 roughly equally sized divisions 105 | split_parameter_trials = [parameter_trials[n::10] for n in range(10)] 106 | 107 | # Run check_trials for all splits in parallel 108 | joblib.Parallel(n_jobs=10, verbose=51)( 109 | joblib.delayed(check_trials)(best_errors, trials) 110 | for trials in split_parameter_trials) 111 | -------------------------------------------------------------------------------- /corrupt_midi.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Methods for corrupting a MIDI file in real-world-ish ways. 3 | ''' 4 | import numpy as np 5 | 6 | 7 | def warp_time(original_times, std): 8 | """ 9 | Computes a random smooth time offset. 10 | 11 | Parameters 12 | ---------- 13 | original_times : np.ndarray 14 | Array of original times, to be warped 15 | std : float 16 | Standard deviation of smooth noise.
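Larger values produce larger (but still smooth) timing offsets.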
17 | 18 | Returns 19 | ------- 20 | warp_offset : np.ndarray 21 | Smooth time warping offset, to be applied by addition 22 | """ 23 | N = original_times.shape[0] 24 | # Invert a random spectrum, with most energy concentrated 25 | # on low frequencies via exponential decay 26 | warp_offset = np.fft.irfft( 27 | N*std*np.random.randn(N)*np.exp(-np.arange(N)))[:N] 28 | return warp_offset 29 | 30 | 31 | def crop_time(midi_object, original_times, start, end): 32 | """ 33 | Crop times out of a MIDI object 34 | 35 | Parameters 36 | ---------- 37 | midi_object : pretty_midi.PrettyMIDI 38 | MIDI object, notes will be modified to allow correct cropping 39 | with adjust_times. 40 | original_times : np.ndarray 41 | Array of original times, to be cropped 42 | start : float 43 | Start time of the crop 44 | end : float 45 | End time of the crop 46 | 47 | Returns 48 | ------- 49 | crop_offset : np.ndarray 50 | Time offset to be applied by addition 51 | """ 52 | for inst in midi_object.instruments: 53 | # Remove all notes within the interval we're cropping out 54 | inst.notes = [note for note in inst.notes if not ( 55 | note.start > start and note.start < end and 56 | note.end > start and note.end < end)] 57 | for note in inst.notes: 58 | # If the note starts before the interval and ends within the 59 | # interval truncate it so that it ends at the start of the interval 60 | if note.start < start and note.end > start and note.end < end: 61 | note.end = start 62 | # If the note starts within the interval and ends after the 63 | # interval move the start to the end of the interval. 64 | elif note.start > start and note.start < end and note.end > end: 65 | note.start = end 66 | # Move all events within the interval to the start 67 | for events in [inst.control_changes, inst.pitch_bends]: 68 | for event in events: 69 | if event.time > start and event.time < end: 70 | event.time = start 71 | # Start with a zero offset for every original time 72 | time_offset = np.zeros(original_times.shape[0]) 73 | # Shift everything at or after the interval start back by its duration 74 | time_offset[original_times >= start] -= (end - start) 75 | return time_offset 76 | 77 | 78 | def corrupt_instruments(midi_object, probability): 79 | ''' 80 | Randomly adjust the program numbers of instruments by +/-1. 81 | 82 | Parameters 83 | ---------- 84 | midi_object : pretty_midi.PrettyMIDI 85 | MIDI object; program numbers will be randomly adjusted. 86 | probability : float 87 | Probability in [0, 1] that the program number will be adjusted. 88 | ''' 89 | for instrument in midi_object.instruments: 90 | # Ignore drum instruments; changing their prog is futile 91 | if not instrument.is_drum: 92 | # Adjust the program with the supplied probability 93 | if np.random.rand() < probability: 94 | # Randomly add or subtract one 95 | new_prog = instrument.program + np.random.choice([-1, 1]) 96 | # Handle edge cases 97 | if new_prog == -1: 98 | new_prog = 1 99 | elif new_prog == 128: 100 | new_prog = 127 101 | # Overwrite the program number 102 | instrument.program = new_prog 103 | 104 | 105 | def remove_instruments(midi_object, probability): 106 | ''' 107 | Randomly remove instruments from a MIDI object. 108 | Will never allow there to be zero instruments. 109 | 110 | Parameters 111 | ---------- 112 | midi_object : pretty_midi.PrettyMIDI 113 | MIDI object; instruments will be randomly removed 114 | probability : float 115 | Probability of removing an instrument.
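Each instrument is dropped independently with this probability.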
116 | ''' 117 | # Pick a random subset of the instruments 118 | random_insts = [inst for inst in midi_object.instruments 119 | if np.random.rand() > probability] 120 | # Don't allow there to be 0 instruments 121 | if len(random_insts) == 0: 122 | midi_object.instruments = [np.random.choice(midi_object.instruments)] 123 | else: 124 | midi_object.instruments = random_insts 125 | 126 | 127 | def corrupt_velocity(midi_object, std): 128 | ''' 129 | Randomly corrupt the velocity of all notes. 130 | 131 | Parameters 132 | ---------- 133 | midi_object : pretty_midi.PrettyMIDI 134 | MIDI object; velocity will be randomly adjusted. 135 | std : float 136 | Velocities will be multiplied by N(1, std) 137 | ''' 138 | for instrument in midi_object.instruments: 139 | for note in instrument.notes: 140 | # Compute new velocity by scaling by N(1, std) 141 | new_velocity = note.velocity*(np.random.randn()*std + 1) 142 | # Clip to the range [0, 127], convert to int, and save 143 | note.velocity = int(np.clip(new_velocity, 0, 127)) 144 | 145 | 146 | def corrupt_midi(midi_object, original_times, warp_std, 147 | start_crop_prob, end_crop_prob, 148 | middle_crop_prob, remove_inst_prob, 149 | change_inst_prob, velocity_std): 150 | ''' 151 | Apply a series of corruptions to a MIDI object. 152 | 153 | Parameters 154 | ---------- 155 | midi_object : pretty_midi.PrettyMIDI 156 | MIDI object, will be corrupted in place. 157 | original_times : np.ndarray 158 | Array of original sampled times. 159 | warp_std : float 160 | Standard deviation of random smooth noise offsets. 161 | start_crop_prob : float 162 | Probability of cutting out the first 10% of the MIDI object. 163 | end_crop_prob : float 164 | Probability of cutting out the final 10% of the MIDI object. 165 | middle_crop_prob : float 166 | Probability of cutting out 1% of the MIDI object somewhere. 167 | remove_inst_prob : float 168 | Probability of removing instruments. 169 | change_inst_prob : float 170 | Probability of randomly adjusting instrument program numbers by +/-1. 171 | velocity_std : float 172 | Standard deviation of multiplicative scales to apply to velocities. 
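(Each note's velocity is scaled by 1 + N(0, velocity_std) and clipped to [0, 127]; see corrupt_velocity.)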
173 | 174 | Returns 175 | ------- 176 | adjusted_times : np.ndarray 177 | `original_times` adjusted by the cropping 178 | diagnostics : dict 179 | Diagnostics about the corruptions applied 180 | ''' 181 | # Store all keyword arguments as diagnostics 182 | diagnostics = dict((k, v) for (k, v) in locals().iteritems() 183 | if isinstance(v, (int, long, float))) 184 | # Smoothly warp times 185 | warp_offset = warp_time(original_times, warp_std) 186 | # Start with no cropping offset, as it will depend on the probabilities 187 | crop_offset = np.zeros(original_times.shape[0]) 188 | # Store whether we are cropping out the beginning 189 | diagnostics['crop_start'] = np.random.rand() < start_crop_prob 190 | if diagnostics['crop_start']: 191 | # Crop out the first 10% 192 | end_time = .1*original_times[-1] 193 | crop_offset += crop_time(midi_object, original_times, 0, end_time) 194 | diagnostics['crop_end'] = np.random.rand() < end_crop_prob 195 | if diagnostics['crop_end']: 196 | # Crop out the last 10% 197 | start_time = .9*original_times[-1] 198 | crop_offset += crop_time( 199 | midi_object, original_times, start_time, original_times[-1]) 200 | diagnostics['crop_middle'] = np.random.rand() < middle_crop_prob 201 | if diagnostics['crop_middle']: 202 | # Randomly crop out 1% from somewhere in the middle 203 | rand = np.random.rand() 204 | offset = original_times[-1]*(rand*.8 + .1) 205 | crop_offset += crop_time( 206 | midi_object, original_times, offset, 207 | offset + .01*original_times[-1]) 208 | # Store the number of instruments before and after random removal 209 | diagnostics['n_instruments_before'] = len(midi_object.instruments) 210 | # Randomly remove instruments 211 | remove_instruments(midi_object, remove_inst_prob) 212 | diagnostics['n_instruments_after'] = len(midi_object.instruments) 213 | # Corrupt their program numbers 214 | corrupt_instruments(midi_object, change_inst_prob) 215 | # Adjust velocity randomly 216 | corrupt_velocity(midi_object, velocity_std) 217 | # Apply the time warps computed above 218 | adjusted_times = original_times + warp_offset + crop_offset 219 | midi_object.adjust_times(original_times, adjusted_times) 220 | return adjusted_times, diagnostics 221 | -------------------------------------------------------------------------------- /create_data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Creates .npz archives of corrupted MIDI file features. 3 | ''' 4 | import numpy as np 5 | import pretty_midi 6 | import librosa 7 | import corrupt_midi 8 | import os 9 | import itertools 10 | import sys 11 | import argparse 12 | import glob 13 | import traceback 14 | import joblib 15 | import warnings 16 | 17 | FS = 22050 18 | NOTE_START = 36 19 | N_NOTES = 48 20 | HOP_LENGTH = 1024 21 | 22 | 23 | def extract_cqt(audio_data): 24 | ''' 25 | CQT routine with default parameters filled in, and some post-processing. 26 | 27 | Parameters 28 | ---------- 29 | audio_data : np.ndarray 30 | Audio data to compute CQT of 31 | 32 | Returns 33 | ------- 34 | cqt : np.ndarray 35 | CQT of the supplied audio data. 36 | frame_times : np.ndarray 37 | Times, in seconds, of each frame in the CQT 38 | ''' 39 | # Compute CQT 40 | cqt = librosa.cqt(audio_data, sr=FS, fmin=librosa.midi_to_hz(NOTE_START), 41 | n_bins=N_NOTES, hop_length=HOP_LENGTH, tuning=0.)
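# With these defaults, the CQT spans 4 octaves (48 bins at librosa's default 12 bins per octave) starting at MIDI note 36 (C2, about 65.4 Hz).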
42 | # Compute the time of each frame 43 | times = librosa.frames_to_time( 44 | np.arange(cqt.shape[1]), sr=FS, hop_length=HOP_LENGTH) 45 | # Use float32 for the cqt to save space/memory 46 | cqt = cqt.astype(np.float32) 47 | return cqt, times 48 | 49 | 50 | def extract_features(midi_object): 51 | ''' 52 | Main feature extraction routine for a MIDI file 53 | 54 | Parameters 55 | ---------- 56 | midi_object : pretty_midi.PrettyMIDI 57 | PrettyMIDI object to extract features from 58 | 59 | Returns 60 | ------- 61 | features : dict 62 | Dictionary of features 63 | ''' 64 | # Synthesize the midi object 65 | midi_audio = midi_object.fluidsynth(fs=FS) 66 | # Compute constant-Q transform 67 | gram, times = extract_cqt(midi_audio) 68 | # Estimate the tempo from the MIDI data 69 | try: 70 | tempo = midi_object.estimate_tempo() 71 | except IndexError: 72 | # When there's no tempo to estimate, estimate_tempo currently fails on 73 | # an IndexError (see https://github.com/craffel/pretty-midi/issues/36) 74 | warnings.warn('No tempo was found, 120 bpm will be used.') 75 | tempo = 120 76 | # estimate_tempo usually returns a double-time tempo (around 200 bpm), 77 | # which is what we want; when it returns a slower one, double it until it is 78 | while tempo < 160: 79 | tempo *= 2 80 | # Estimate the beats, forcing the tempo to be near the MIDI tempo 81 | beat_frames = librosa.beat.beat_track( 82 | midi_audio, bpm=tempo, hop_length=HOP_LENGTH)[1] 83 | beat_times = librosa.frames_to_time( 84 | beat_frames, sr=FS, hop_length=HOP_LENGTH) 85 | 86 | return {'times': times, 'gram': gram, 'beat_frames': beat_frames, 87 | 'beat_times': beat_times} 88 | 89 | 90 | def process_one_file(midi_filename, output_path, corruption_params): 91 | ''' 92 | Create features and diagnostics dict for original and corrupted MIDI file 93 | 94 | Parameters 95 | ---------- 96 | midi_filename : str 97 | Path to a MIDI file to corrupt.
98 | output_path : str 99 | Base path to write out .npz/.mid 100 | corruption_params : dict 101 | Parameters to pass to corrupt_midi.corrupt_midi 102 | 103 | Notes 104 | ----- 105 | Writes out a .npz file of features and diagnostics for the original and 106 | corrupted MIDI, plus the corrupted .mid file itself 107 | ''' 108 | try: 109 | # Load in and extract features/diagnostic information for the file 110 | midi_object = pretty_midi.PrettyMIDI(midi_filename) 111 | orig_features = extract_features(midi_object) 112 | # Prepend keys with 'orig' 113 | orig_features = dict( 114 | ('orig_{}'.format(k), v) for (k, v) in orig_features.iteritems()) 115 | # Corrupt MIDI object (in place) 116 | adjusted_times, diagnostics = corrupt_midi.corrupt_midi( 117 | midi_object, orig_features['orig_times'], **corruption_params) 118 | # Get features for corrupted MIDI 119 | corrupted_features = extract_features(midi_object) 120 | corrupted_features = dict(('corrupted_{}'.format(k), v) 121 | for (k, v) in corrupted_features.iteritems()) 122 | # Combine features, diagnostics into one fat dict 123 | data = dict(i for i in itertools.chain( 124 | orig_features.iteritems(), [('adjusted_times', adjusted_times)], 125 | diagnostics.iteritems(), corrupted_features.iteritems())) 126 | data['original_file'] = os.path.abspath(midi_filename) 127 | corrupted_filename = os.path.abspath(os.path.join( 128 | output_path, os.path.basename(midi_filename))) 129 | midi_object.write(corrupted_filename) 130 | data['corrupted_file'] = corrupted_filename 131 | # Write out the npz 132 | output_npz = os.path.join( 133 | output_path, 134 | os.path.splitext(os.path.basename(midi_filename))[0] + '.npz') 135 | np.savez_compressed(output_npz, **data) 136 | except Exception: 137 | print "Error parsing {}:".format(midi_filename) 138 | traceback.print_exc() 139 | 140 | if __name__ == '__main__': 141 | # Parse command-line arguments 142 | parser = argparse.ArgumentParser( 143 | description='Create a dataset of corrupted MIDI information.') 144 | parser.add_argument('mode', action='store', 145 | help='Create "easy" or "hard" corruptions?') 146 | parser.add_argument('midi_glob', action='store', 147 | help='Glob to MIDI files (e.g. data/mid/*/*.mid)') 148 | parser.add_argument('output_path', action='store', 149 | help='Where to output .npz files') 150 | parameters = vars(parser.parse_args(sys.argv[1:])) 151 | # Set shared values of corruption_params 152 | corruption_params = { 153 | 'start_crop_prob': .5, 154 | 'end_crop_prob': .5, 155 | 'middle_crop_prob': .1, 156 | 'change_inst_prob': 1.} 157 | if parameters['mode'] == 'hard': 158 | corruption_params['warp_std'] = 20. 159 | corruption_params['remove_inst_prob'] = .5 160 | corruption_params['velocity_std'] = 1. 161 | elif parameters['mode'] == 'easy': 162 | corruption_params['warp_std'] = 5.
163 | corruption_params['remove_inst_prob'] = .1 164 | corruption_params['velocity_std'] = .2 165 | else: 166 | raise ValueError('mode must be "easy" or "hard", got {}'.format( 167 | parameters['mode'])) 168 | # Create the output directory if it doesn't exist 169 | if not os.path.exists(parameters['output_path']): 170 | os.makedirs(parameters['output_path']) 171 | joblib.Parallel(n_jobs=10, verbose=51)( 172 | joblib.delayed(process_one_file)( 173 | midi_file, parameters['output_path'], corruption_params) 174 | for midi_file in glob.glob(parameters['midi_glob'])) 175 | -------------------------------------------------------------------------------- /db_utils.py: -------------------------------------------------------------------------------- 1 | """ Utility functions to store and retrieve experiment results as .json files """ 2 | 3 | import numpy as np 4 | try: 5 | import ujson as json 6 | except ImportError: 7 | import json 8 | import glob 9 | import os 10 | 11 | 12 | def get_experiment_results(results_glob): 13 | """Get a list of all results of an experiment from .json result files. 14 | 15 | Parameters 16 | ---------- 17 | results_glob : str 18 | Glob to results .json files 19 | 20 | Returns 21 | ------- 22 | params : list 23 | List of dicts, where each entry is the dict of parameter name/value 24 | associations for the corresponding objective value in objectives 25 | objectives : list 26 | List of float, where each entry is the objective value for the 27 | corresponding parameter settings in params 28 | """ 29 | results = [] 30 | for result_file in glob.glob(results_glob): 31 | with open(result_file) as f: 32 | results.append(json.load(f)) 33 | params = [result['params'] for result in results] 34 | objectives = [np.mean(r['results']['mean_errors']) for r in results] 35 | return params, objectives 36 | 37 | 38 | def get_best_result(results_glob): 39 | """Get the parameters and objective corresponding to the best result for an 40 | experiment. 41 | 42 | Parameters 43 | ---------- 44 | results_glob : str 45 | Glob to results .json files 46 | 47 | Returns 48 | ------- 49 | params : dict 50 | dict of parameter name/value associations for the best objective value 51 | objective : float 52 | Best achieved objective value, corresponding to the parameter settings 53 | in params 54 | """ 55 | # This function is just a wrapper around using argmin 56 | params, objectives = get_experiment_results(results_glob) 57 | best_result = np.argmin(objectives) 58 | return params[best_result], objectives[best_result] 59 | 60 | 61 | def dump_result(params, results, output_path): 62 | """Writes out a single result .json file in output_path. 63 | 64 | Parameters 65 | ---------- 66 | params : dict 67 | Dictionary of parameter names and values 68 | results : dict 69 | Dictionary of an alignment result 70 | output_path : str 71 | Where to write out the json file 72 | """ 73 | # Make a copy of params to avoid writing in-place below 74 | params = dict(params) 75 | # ujson can't handle infs, so we need to replace them manually: 76 | if params['norm'] == np.inf: 77 | params['norm'] = str(np.inf) 78 | # Convert params dict to a string of the form 79 | # param1_name_param1_value_param2_name_param2_value...
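# e.g. a hypothetical setting might yield 'feature_chroma_gully_0.960_add_pen_1.500_norm_2.json' (dict iteration makes the actual key order arbitrary)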
80 | param_string = "_".join( 81 | '{}_{}'.format(name, value) if type(value) != float else 82 | '{}_{:.3f}'.format(name, value) for name, value in params.items()) 83 | # Construct a path where the .json results file will be written 84 | output_filename = os.path.join(output_path, "{}.json".format(param_string)) 85 | # Store this result 86 | try: 87 | with open(output_filename, 'wb') as f: 88 | json.dump({'params': params, 'results': results}, f) 89 | # Ignore "OverflowError"s raised by ujson; they correspond to inf/NaN 90 | except OverflowError: 91 | pass 92 | -------------------------------------------------------------------------------- /find_best_aligners.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import tabulate 3 | import numpy as np 4 | import numba 5 | import ujson as json 6 | import collections 7 | 8 | 9 | @numba.jit(nopython=True) 10 | def kendall_ignore_ties(x, y): 11 | """ Compute Kendall's rank correlation coefficient, ignoring ties. 12 | 13 | Parameters 14 | ---------- 15 | x, y : np.ndarray 16 | Samples from two different random variables. 17 | 18 | Returns 19 | ------- 20 | tau : float 21 | Kendall's rank correlation coefficient, ignoring ties. 22 | """ 23 | # Accumulate numerator and denominator as we go 24 | numer = 0. 25 | denom = 0. 26 | # Compare all pairs 27 | for i in xrange(1, x.shape[0]): 28 | for j in xrange(0, i): 29 | # When there is a tie in either x or y, ignore this pair 30 | if x[i] == x[j] or y[i] == y[j]: 31 | continue 32 | else: 33 | # Add one when both x[i] < x[j] and y[i] < y[j] 34 | # or x[i] > x[j] and y[i] > y[j], 35 | # otherwise add -1 36 | numer += np.sign(x[i] - x[j])*np.sign(y[i] - y[j]) 37 | # Count one more untied pair in the denominator 38 | denom += 1 39 | # Divide to compute tau 40 | return numer/denom 41 | 42 | if __name__ == '__main__': 43 | 44 | # Load in all confidence experiment results 45 | results = [] 46 | for result_file in glob.glob('results/confidence_experiment/*.json'): 47 | with open(result_file) as f: 48 | results.append(json.load(f)) 49 | 50 | # Create a list which stores the performance of each aligner tried 51 | aligner_performance = [] 52 | for result in results: 53 | # Retrieve the parameters, to report later 54 | params = result['params'] 55 | # Retrieve the result, for less verbosity 56 | result = result['results'] 57 | # Only consider aligners whose r_score (p-value) is above .05 58 | if result['r_score'] > .05: 59 | # Combine the errors for the easy and hard datasets 60 | errors = np.array(result['hard_errors'] + result['easy_errors']) 61 | # Retrieve the reported aligner scores for all reported errors 62 | scores = dict((k, v) for k, v in result.items() if 'scores' in k) 63 | # Combine hard and easy scores 64 | scores = dict((k, np.array(v + scores[k.replace('hard', 'easy')])) 65 | for k, v in scores.items() if 'hard' in k) 66 | # Compute rank correlation coefficients for all scores 67 | rank_corrs = dict((k, kendall_ignore_ties(errors, s)) 68 | for k, s in scores.items()) 69 | # Find the name and score of the best-correlating score 70 | best_name = max(rank_corrs, key=rank_corrs.get) 71 | best_score = rank_corrs[best_name] 72 | # Store the performance of this aligner 73 | aligner_performance.append(collections.OrderedDict([ 74 | ('hard_error', np.mean(result['hard_errors'])), 75 | ('easy_error', np.mean(result['easy_errors'])), 76 | ('r_score', result['r_score']), 77 | ('best_name', best_name), 78 | ('best_score', best_score), 79 | ('params', params)])) 80 | # Sort aligners by their best_score, descending
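# (best_score is the Kendall tau between per-song error and a confidence score, so higher values mean that score tracks alignment error more faithfully)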
81 | aligner_performance.sort(key=lambda x: x['best_score'], reverse=True) 82 | print "Systems passing threshold: {}".format(len(aligner_performance)) 83 | # Print table of all aligners 84 | print tabulate.tabulate( 85 | aligner_performance, 86 | headers=dict((k, k) for k in aligner_performance[0])) 87 | -------------------------------------------------------------------------------- /parameter_experiment_gp.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Run the GP-based alignment parameter search experiment. 3 | ''' 4 | 5 | import simple_spearmint 6 | import numpy as np 7 | import argparse 8 | import align_dataset 9 | import os 10 | import db_utils 11 | 12 | # Path to corrupted dataset, created by create_data.py 13 | CORRUPTED_PATH = 'data/corrupted_easy/*.npz' 14 | # How many total trials of hyperparameter optimization should we run? 15 | N_TRIALS = 100 16 | # How many randomly selected hyperparameter settings should we start with? 17 | INITIAL_RANDOM = 100 18 | # Where should experiment results be output? 19 | OUTPUT_PATH = 'results/parameter_experiment_gp' 20 | # Where do the results from the random parameter search live? 21 | RANDOM_RESULTS_PATH = 'results/parameter_experiment_random' 22 | 23 | if __name__ == '__main__': 24 | # Retrieve the seed from the command line 25 | parser = argparse.ArgumentParser( 26 | description='Run a MIDI alignment parameter search experiment.', 27 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 28 | parser.add_argument('-s', '--seed', action='store', type=int, default=0, 29 | help='Random seed') 30 | seed = parser.parse_args().seed 31 | np.random.seed(seed) 32 | 33 | space = { 34 | # Use chroma or CQT for feature representation 35 | 'feature': {'type': 'enum', 'options': ['chroma', 'gram']}, 36 | # Beat sync, or don't 37 | 'beat_sync': {'type': 'enum', 'options': [True, False]}, 38 | # Don't normalize, max-norm, L1-norm, or L2 norm 39 | 'norm': {'type': 'enum', 'options': [None, np.inf, 1, 2]}, 40 | # Whether or not to z-score (standardize) the feature dimensions 41 | 'standardize': {'type': 'enum', 'options': [True, False]}, 42 | # Which distance metric to use for distance matrix 43 | 'metric': {'type': 'enum', 44 | 'options': ['euclidean', 'sqeuclidean', 'cosine']}, 45 | # DTW additive penalty 46 | 'add_pen': {'type': 'float', 'min': 0, 'max': 3}, 47 | # DTW end point tolerance 48 | 'gully': {'type': 'float', 'min': 0, 'max': 1}, 49 | # Whether to constrain the path to within the tolerance 50 | 'band_mask': {'type': 'enum', 'options': [True, False]}} 51 | 52 | # Create the results output directory if it doesn't exist 53 | if not os.path.exists(OUTPUT_PATH): 54 | os.makedirs(OUTPUT_PATH) 55 | 56 | # Initialize the simple_spearmint GP experiment 57 | experiment = simple_spearmint.SimpleSpearmint(space, noiseless=True) 58 | 59 | # Load in all random parameter search results 60 | random_params, random_objectives = db_utils.get_experiment_results( 61 | os.path.join(RANDOM_RESULTS_PATH, '*.json')) 62 | # Seed the GP optimizer with the INITIAL_RANDOM best results and 63 | # INITIAL_RANDOM random results 64 | best_indices = np.argsort(random_objectives)[:INITIAL_RANDOM] 65 | # Excluding best_indices avoids re-choosing the best trials at random 66 | random_indices = np.random.choice( 67 | [n for n in range(len(random_params)) if n not in best_indices], 68 | INITIAL_RANDOM, False) 69 | # Seed the GP optimizer with random parameter search results 70 | for n in np.append(best_indices, random_indices): 71 | # Replace 'inf' with actual values 72 |
params = dict((k, v) if v != 'inf' else (k, np.inf) 73 | for k, v in random_params[n].items()) 74 | experiment.update(params, random_objectives[n]) 75 | 76 | # Load in the alignment dataset 77 | data = align_dataset.load_dataset(CORRUPTED_PATH) 78 | 79 | for _ in range(N_TRIALS): 80 | # Retrieve GP-based parameter suggestion 81 | candidate_params = experiment.suggest() 82 | # Get results for these parameters 83 | result = align_dataset.align_dataset(candidate_params, data) 84 | # Write results out 85 | db_utils.dump_result(candidate_params, result, OUTPUT_PATH) 86 | # Update optimizer 87 | objective = np.mean(result['mean_errors']) 88 | experiment.update(candidate_params, objective) 89 | -------------------------------------------------------------------------------- /parameter_experiment_random.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Run the random alignment parameter search experiment. 3 | ''' 4 | 5 | import os 6 | import numpy as np 7 | import align_dataset 8 | import functools 9 | import db_utils 10 | 11 | # Path to corrupted dataset, created by create_data.py 12 | CORRUPTED_PATH = 'data/corrupted_easy/*.npz' 13 | # Path where results should be written 14 | OUTPUT_PATH = 'results/parameter_experiment_random' 15 | # Number of parameter settings to try 16 | N_TRIALS = 1000 17 | 18 | 19 | def experiment_wrapper(param_sampler, data, output_path): 20 | ''' 21 | Run alignment over the dataset and save the result. 22 | 23 | Parameters 24 | ---------- 25 | param_sampler : dict of functions 26 | Dictionary which maps parameter names to functions to sample values for 27 | those parameters. 28 | 29 | data : list of dict of np.ndarray 30 | Collection of aligned/unaligned MIDI pairs 31 | 32 | output_path : str 33 | Where to write the results .json file 34 | ''' 35 | # Call the sample function for each param name in the param sampler dict 36 | # to create a dict which maps param names to sampled values 37 | params = dict((name, sample()) for (name, sample) in param_sampler.items()) 38 | # Get the results dictionary for this parameter setting 39 | results = align_dataset.align_dataset(params, data) 40 | db_utils.dump_result(params, results, output_path) 41 | 42 | 43 | if __name__ == '__main__': 44 | 45 | param_sampler = { 46 | # Use chroma or CQT for feature representation 47 | 'feature': functools.partial(np.random.choice, ['chroma', 'gram']), 48 | # Beat sync, or don't 49 | 'beat_sync': functools.partial(np.random.choice, [0, 1]), 50 | # Don't normalize, max-norm, L1-norm, or L2 norm 51 | 'norm': functools.partial(np.random.choice, [None, np.inf, 1, 2]), 52 | # Whether or not to z-score (standardize) the feature dimensions 53 | 'standardize': functools.partial(np.random.choice, [0, 1]), 54 | # Which distance metric to use for distance matrix 55 | 'metric': functools.partial(np.random.choice, 56 | ['euclidean', 'sqeuclidean', 'cosine']), 57 | # DTW additive penalty 58 | 'add_pen': functools.partial(np.random.uniform, 0, 3), 59 | # DTW end point tolerance 60 | 'gully': functools.partial(np.random.uniform, 0, 1), 61 | # Whether to constrain the path to within the tolerance 62 | 'band_mask': functools.partial(np.random.choice, [0, 1])} 63 | 64 | # Load in the easy corrupted dataset 65 | data = align_dataset.load_dataset(CORRUPTED_PATH) 66 | # Create the results output directory if it doesn't exist 67 | if not os.path.exists(OUTPUT_PATH): 68 | os.makedirs(OUTPUT_PATH) 69 | 70 | for _ in range(N_TRIALS): 71 | experiment_wrapper(param_sampler, data, OUTPUT_PATH) 72
| --------------------------------------------------------------------------------