├── README.md ├── arguments.py ├── chords_infer.py ├── constants.py ├── create_nottingham_roll.py ├── data_converter.py ├── env.yml ├── interpolate_double.py ├── interpolate_single.py ├── interpolate_triple.py ├── midi_io.py ├── model.py ├── note_sequence.py ├── note_sequence_ops.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | Lisa Kawai, Philippe Esling, Tatsuya Harada Attributes-aware Deep Music Transformation The 21st International Society for Music Information Retrieval Conference (ISMIR), 2020 -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from constants import QUARTERS_PER_BAR 3 | 4 | 5 | def generate_cfg(env_var, args, output_dir, checkpoint=None, chords_per_bar=2): 6 | if checkpoint: 7 | cfg = checkpoint['cfg'] 8 | if 'structure' not in cfg['model']: 9 | cfg['model']['structure'] = 'RhythmPitchGenModel' 10 | if 'steps_per_quarter' not in cfg: 11 | cfg['steps_per_quarter'] = 4 12 | if 'EOR_token' not in cfg: 13 | cfg['EOR_token'] = cfg['steps_per_quarter'] * QUARTERS_PER_BAR // chords_per_bar + 1 14 | cfg['SOR_token'] = cfg['EOR_token'] + 1 # Start of rhythm 15 | if 'chords_per_bar' not in cfg: 16 | if 'EOR_token' in cfg: 17 | cfg['chords_per_bar'] = cfg['steps_per_quarter'] * QUARTERS_PER_BAR // (cfg['EOR_token'] - 1) 18 | else: 19 | cfg['chords_per_bar'] = chords_per_bar 20 | if 'use_custom_vocab' not in cfg: 21 | cfg['use_custom_vocab'] = False 22 | if 'r_alpha' not in cfg: 23 | cfg['r_alpha'] = 1.0 24 | if 'p_alpha' not in cfg: 25 | cfg['p_alpha'] = 1.0 26 | if 'offset_valid' in cfg['data']: 27 | cfg['data']['offset_valid'] = cfg['data']['offset_valid'].replace('vaild', 'valid') 28 | if 'style_valid' in cfg['data']: 29 | cfg['data']['style_valid'] = cfg['data']['style_valid'].replace('valid_style.h5', 'valid_sf.h5') 30 | if 'event_vocab_size' not in cfg['data']: 31 | cfg['data']['event_vocab_size'] = 1297 32 | if 'num_instruments' not in cfg['data']: 33 | cfg['data']['num_instruments'] = 5 34 | return cfg 35 | 36 | cfg = env_var 37 | cfg['model_name'] = args.model_name 38 | cfg['r_alpha'] = args.r_alpha 39 | cfg['p_alpha'] = args.p_alpha 40 | cfg['output_dir'] = output_dir 41 | cfg['chords_per_bar'] = chords_per_bar 42 | if 'steps_per_quarter' not in cfg: 43 | cfg['steps_per_quarter'] = 4 44 | if 'use_custom_vocab' not in cfg: 45 | cfg['use_custom_vocab'] = False 46 | cfg['EOR_token'] = cfg['steps_per_quarter'] * QUARTERS_PER_BAR // chords_per_bar + 1 47 | cfg['SOR_token'] = cfg['EOR_token'] + 1 48 | if args.n_iters: 49 | cfg['n_iters'] = args.n_iters 50 | if args.learning_rate: 51 | cfg['learning_rate'] = args.learning_rate 52 | if args.model_structure: 53 | cfg['model']['structure'] = args.model_structure 54 | if args.hidden_size: 55 | for key in cfg['model'].keys(): 56 | if 'hidden_size' in cfg['model'][key]: 57 | cfg['model'][key]['hidden_size'] = args.hidden_size 58 | if args.num_layers: 59 | for key in cfg['model'].keys(): 60 | if 'num_layers' in cfg['model'][key]: 61 | cfg['model'][key]['num_layers'] = args.num_layers 62 | if args.c_num_layers: 63 | cfg['model']['chord_encoder']['num_layers'] = args.c_num_layers 64 | if args.c_hidden_size: 65 | cfg['model']['chord_encoder']['hidden_size'] = args.c_hidden_size 66 | return cfg 67 | 68 | 69 | def generate_cfg_fader(cfg, args, output_dir, checkpoint=None): 70 | if checkpoint: 71 | cfg = 
checkpoint['cfg'] 72 | if 'steps_per_quarter' not in cfg: 73 | cfg['steps_per_quarter'] = 12 74 | if 'bars_per_data' not in cfg: 75 | cfg['bars_per_data'] = 4 76 | if 'max_shift_steps' not in cfg: 77 | cfg['max_shift_steps'] = 48 78 | if 'z_dims' not in cfg: 79 | cfg['z_dims'] = 128 80 | if 'n_classes' not in cfg: 81 | cfg['n_classes'] = 8 82 | if 'activation_d' not in cfg: 83 | cfg['activation_d'] = 'tanh' 84 | return cfg 85 | 86 | cfg['output_dir'] = output_dir 87 | cfg['z_dims'] = args.z_dims 88 | 89 | if args.n_iters: 90 | cfg['n_iters'] = args.n_iters 91 | if args.learning_rate: 92 | cfg['learning_rate'] = args.learning_rate 93 | if args.learning_rate_d: 94 | cfg['learning_rate_d'] = args.learning_rate_d 95 | if args.lambda_d is not None: 96 | cfg['lambda_d'] = args.lambda_d 97 | if args.lambda_kl: 98 | cfg['lambda_kl'] = args.lambda_kl 99 | 100 | if args.model_structure: 101 | cfg['model']['structure'] = args.model_structure 102 | if args.c_num_layers: 103 | cfg['model']['chord_encoder']['num_layers'] = args.c_num_layers 104 | if args.c_hidden_size: 105 | cfg['model']['chord_encoder']['hidden_size'] = args.c_hidden_size 106 | if args.e_num_layers: 107 | cfg['model']['encoder']['num_layers'] = args.e_num_layers 108 | if args.e_hidden_size: 109 | cfg['model']['encoder']['hidden_size'] = args.e_hidden_size 110 | if args.d_num_layers: 111 | cfg['model']['decoder']['num_layers'] = args.d_num_layers 112 | if args.d_hidden_size: 113 | cfg['model']['decoder']['hidden_size'] = args.d_hidden_size 114 | if args.dis_num_layers: 115 | cfg['model']['discriminator']['num_layers'] = args.dis_num_layers 116 | if args.attribute: 117 | cfg['attr'] = args.attribute 118 | if args.batch_size: 119 | cfg['batch_size'] = args.batch_size 120 | if args.is_ordinal == 1: 121 | cfg['is_ordinal'] = True 122 | else: 123 | cfg['is_ordinal'] = False 124 | cfg['thresholds'] = np.array([0.0]) 125 | cfg['n_attr'] = len(cfg['attr']) 126 | return cfg -------------------------------------------------------------------------------- /chords_infer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Chord inference for NoteSequences.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import bisect 22 | import itertools 23 | import math 24 | import numbers 25 | 26 | import constants 27 | from note_sequence import NoteSequence, KeySignature, AnnotationType, TextAnnotation 28 | import note_sequence_ops 29 | import numpy as np 30 | 31 | 32 | # Names of pitch classes to use (mostly ignoring spelling). 33 | _PITCH_CLASS_NAMES = [ 34 | 'C', 'C#', 'D', 'Eb', 'E', 'F', 'F#', 'G', 'Ab', 'A', 'Bb', 'B'] 35 | 36 | # Pitch classes in a key (rooted at zero). 37 | _KEY_PITCHES = [0, 2, 4, 5, 7, 9, 11] 38 | 39 | # Pitch classes in each chord kind (rooted at zero). 
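# For example, the 'm7' kind rooted at C (pitch class 0) covers
# pitch classes {0, 3, 7, 10} = {C, Eb, G, Bb}.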
40 | _CHORD_KIND_PITCHES = { 41 | '': [0, 4, 7], 42 | 'm': [0, 3, 7], 43 | '+': [0, 4, 8], 44 | 'dim': [0, 3, 6], 45 | '7': [0, 4, 7, 10], 46 | 'maj7': [0, 4, 7, 11], 47 | 'm7': [0, 3, 7, 10], 48 | 'm7b5': [0, 3, 6, 10], 49 | } 50 | _CHORD_KINDS = _CHORD_KIND_PITCHES.keys() 51 | 52 | # All usable chords, including no-chord. 53 | _CHORDS = [constants.NO_CHORD] + list( 54 | itertools.product(range(12), _CHORD_KINDS)) 55 | 56 | # All key-chord pairs. 57 | _KEY_CHORDS = list(itertools.product(range(12), _CHORDS)) 58 | 59 | # Maximum length of chord sequence to infer. 60 | _MAX_NUM_CHORDS = 1000 61 | 62 | # Mapping from time signature to number of chords to infer per bar. 63 | _DEFAULT_TIME_SIGNATURE_CHORDS_PER_BAR = { 64 | (2, 2): 1, 65 | (2, 4): 1, 66 | (3, 4): 1, 67 | (4, 4): 2, 68 | (6, 8): 2, 69 | } 70 | 71 | 72 | def _key_chord_distribution(chord_pitch_out_of_key_prob): 73 | """Probability distribution over chords for each key.""" 74 | num_pitches_in_key = np.zeros([12, len(_CHORDS)], dtype=np.int32) 75 | num_pitches_out_of_key = np.zeros([12, len(_CHORDS)], dtype=np.int32) 76 | 77 | # For each key and chord, compute the number of chord notes in the key and the 78 | # number of chord notes outside the key. 79 | for key in range(12): 80 | key_pitches = set((key + offset) % 12 for offset in _KEY_PITCHES) 81 | for i, chord in enumerate(_CHORDS[1:]): 82 | root, kind = chord 83 | chord_pitches = set((root + offset) % 12 84 | for offset in _CHORD_KIND_PITCHES[kind]) 85 | num_pitches_in_key[key, i + 1] = len(chord_pitches & key_pitches) 86 | num_pitches_out_of_key[key, 87 | i + 1] = len(chord_pitches - key_pitches) 88 | 89 | # Compute the probability of each chord under each key, normalizing to sum to 90 | # one for each key. 91 | # TODO 92 | mat = ((1 - chord_pitch_out_of_key_prob) ** num_pitches_in_key * 93 | chord_pitch_out_of_key_prob ** num_pitches_out_of_key) 94 | mat /= mat.sum(axis=1)[:, np.newaxis] 95 | return mat 96 | 97 | 98 | def _key_chord_transition_distribution( 99 | key_chord_distribution, key_change_prob, chord_change_prob): 100 | """Transition distribution between key-chord pairs.""" 101 | mat = np.zeros([len(_KEY_CHORDS), len(_KEY_CHORDS)]) 102 | 103 | for i, key_chord_1 in enumerate(_KEY_CHORDS): 104 | key_1, chord_1 = key_chord_1 105 | chord_index_1 = i % len(_CHORDS) 106 | 107 | for j, key_chord_2 in enumerate(_KEY_CHORDS): 108 | key_2, chord_2 = key_chord_2 109 | chord_index_2 = j % len(_CHORDS) 110 | 111 | if key_1 != key_2: 112 | # Key change. Chord probability depends only on key and not previous 113 | # chord. 114 | mat[i, j] = (key_change_prob / 11) 115 | mat[i, j] *= key_chord_distribution[key_2, chord_index_2] 116 | 117 | else: 118 | # No key change. 119 | mat[i, j] = 1 - key_change_prob 120 | if chord_1 != chord_2: 121 | # Chord probability depends on key, but we have to redistribute the 122 | # probability mass on the previous chord since we know the chord 123 | # changed. 124 | mat[i, j] *= (chord_change_prob * ( 125 | key_chord_distribution[key_2, chord_index_2] + 126 | key_chord_distribution[key_2, chord_index_1] / (len(_CHORDS) - 1))) 127 | else: 128 | # No chord change. 
129 | mat[i, j] *= 1 - chord_change_prob 130 | 131 | return mat 132 | 133 | 134 | def _chord_pitch_vectors(): 135 | """Unit vectors over pitch classes for all chords.""" 136 | x = np.zeros([len(_CHORDS), 12]) 137 | for i, chord in enumerate(_CHORDS[1:]): 138 | root, kind = chord 139 | for offset in _CHORD_KIND_PITCHES[kind]: 140 | x[i + 1, (root + offset) % 12] = 1 141 | x[1:, :] /= np.linalg.norm(x[1:, :], axis=1)[:, np.newaxis] 142 | return x 143 | 144 | 145 | def sequence_note_pitch_vectors(sequence, seconds_per_frame, lowest_pitch_weight=2.): 146 | """Compute pitch class vectors for temporal frames across a sequence. 147 | 148 | Args: 149 | sequence: The NoteSequence for which to compute pitch class vectors. 150 | seconds_per_frame: The size of the frame corresponding to each pitch class 151 | vector, in seconds. Alternatively, a list of frame boundary times in 152 | seconds (not including initial start time and final end time). 153 | 154 | Returns: 155 | A numpy array with shape `[num_frames, 12]` where each row is a unit- 156 | normalized pitch class vector for the corresponding frame in `sequence`. 157 | """ 158 | if isinstance(seconds_per_frame, numbers.Number): 159 | # Construct array of frame boundary times. 160 | num_frames = int(math.ceil(sequence.total_time / seconds_per_frame)) 161 | frame_boundaries = seconds_per_frame * np.arange(1, num_frames) 162 | else: 163 | frame_boundaries = sorted(seconds_per_frame) 164 | num_frames = len(frame_boundaries) + 1 165 | 166 | x = np.zeros([num_frames, 12]) 167 | lowest_pitch = np.ones(num_frames, dtype=np.int) * 1000 168 | 169 | for note in sequence.notes: 170 | if note.is_drum: 171 | continue 172 | if note.program in constants.UNPITCHED_PROGRAMS: 173 | continue 174 | 175 | start_frame = bisect.bisect_right(frame_boundaries, note.start_time) 176 | end_frame = bisect.bisect_left(frame_boundaries, note.end_time) 177 | 178 | for frame in range(start_frame, end_frame + 1): 179 | lowest_pitch[frame] = min(int(lowest_pitch[frame]), int(note.pitch)) 180 | pitch_class = int(note.pitch) % 12 181 | 182 | if start_frame >= end_frame: 183 | x[start_frame, pitch_class] += note.end_time - note.start_time 184 | else: 185 | x[start_frame, pitch_class] += ( 186 | frame_boundaries[start_frame] - note.start_time) 187 | for frame in range(start_frame + 1, end_frame): 188 | x[frame, pitch_class] += ( 189 | frame_boundaries[frame] - frame_boundaries[frame - 1]) 190 | x[end_frame, pitch_class] += ( 191 | note.end_time - frame_boundaries[end_frame - 1]) 192 | 193 | lowest_pitch %= 12 194 | # for frame in range(num_frames): 195 | # x[frame, lowest_pitch[frame]] *= lowest_pitch_weight 196 | x_norm = np.linalg.norm(x, axis=1) 197 | nonzero_frames = x_norm > 0 198 | x[nonzero_frames, :] /= x_norm[nonzero_frames, np.newaxis] 199 | 200 | return x 201 | 202 | 203 | def _chord_frame_log_likelihood(note_pitch_vectors, chord_note_concentration): 204 | """Log-likelihood of observing each frame of note pitches under each chord.""" 205 | return chord_note_concentration * np.dot(note_pitch_vectors, 206 | _chord_pitch_vectors().T) 207 | 208 | 209 | def _key_chord_viterbi(chord_frame_loglik, 210 | key_chord_loglik, 211 | key_chord_transition_loglik): 212 | """Use the Viterbi algorithm to infer a sequence of key-chord pairs.""" 213 | num_frames, num_chords = chord_frame_loglik.shape 214 | num_key_chords = len(key_chord_transition_loglik) 215 | 216 | loglik_matrix = np.zeros([num_frames, num_key_chords]) 217 | path_matrix = np.zeros([num_frames, num_key_chords], dtype=np.int32) 
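# loglik_matrix[t, i] will hold the best log-likelihood of any key-chord path
# ending in key-chord pair i at frame t; path_matrix[t, i] records the index of
# that path's predecessor pair, so the optimal sequence can be backtracked.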
218 | 219 | # Initialize with a uniform distribution over keys. 220 | for i, key_chord in enumerate(_KEY_CHORDS): 221 | key, unused_chord = key_chord 222 | chord_index = i % len(_CHORDS) 223 | loglik_matrix[0, i] = ( 224 | -np.log(12) + key_chord_loglik[key, chord_index] + 225 | chord_frame_loglik[0, chord_index]) 226 | 227 | for frame in range(1, num_frames): 228 | # At each frame, store the log-likelihood of the best sequence ending in 229 | # each key-chord pair, along with the index of the parent key-chord pair 230 | # from the previous frame. 231 | mat = (np.tile(loglik_matrix[frame - 1][:, np.newaxis], 232 | [1, num_key_chords]) + 233 | key_chord_transition_loglik) 234 | path_matrix[frame, :] = mat.argmax(axis=0) 235 | loglik_matrix[frame, :] = ( 236 | mat[path_matrix[frame, :], range(num_key_chords)] + 237 | np.tile(chord_frame_loglik[frame], 12)) 238 | 239 | # Reconstruct the most likely sequence of key-chord pairs. 240 | path = [np.argmax(loglik_matrix[-1])] 241 | for frame in range(num_frames, 1, -1): 242 | path.append(path_matrix[frame - 1, path[-1]]) 243 | 244 | return [(index // num_chords, _CHORDS[index % num_chords]) 245 | for index in path[::-1]] 246 | 247 | 248 | class ChordInferenceError(Exception): # pylint:disable=g-bad-exception-name 249 | pass 250 | 251 | 252 | class SequenceAlreadyHasChordsError(ChordInferenceError): 253 | pass 254 | 255 | 256 | class UncommonTimeSignatureError(ChordInferenceError): 257 | pass 258 | 259 | 260 | class NonIntegerStepsPerChordError(ChordInferenceError): 261 | pass 262 | 263 | 264 | class EmptySequenceError(ChordInferenceError): 265 | pass 266 | 267 | 268 | class SequenceTooLongError(ChordInferenceError): 269 | pass 270 | 271 | 272 | def infer_chords_for_sequence(sequence, 273 | chords_per_bar=None, 274 | key_change_prob=0.001, 275 | chord_change_prob=0.5, 276 | chord_pitch_out_of_key_prob=0.01, 277 | chord_note_concentration=100.0, 278 | add_key_signatures=False): 279 | """Infer chords for a NoteSequence using the Viterbi algorithm. 280 | 281 | This uses some heuristics to infer chords for a quantized NoteSequence. At 282 | each chord position a key and chord will be inferred, and the chords will be 283 | added (as text annotations) to the sequence. 284 | 285 | If the sequence is quantized relative to meter, a fixed number of chords per 286 | bar will be inferred. Otherwise, the sequence is expected to have beat 287 | annotations and one chord will be inferred per beat. 288 | 289 | Args: 290 | sequence: The NoteSequence for which to infer chords. This NoteSequence will 291 | be modified in place. 292 | chords_per_bar: If `sequence` is quantized, the number of chords per bar to 293 | infer. If None, use a default number of chords based on the time 294 | signature of `sequence`. 295 | key_change_prob: Probability of a key change between two adjacent frames. 296 | chord_change_prob: Probability of a chord change between two adjacent 297 | frames. 298 | chord_pitch_out_of_key_prob: Probability of a pitch in a chord not belonging 299 | to the current key. 300 | chord_note_concentration: Concentration parameter for the distribution of 301 | observed pitches played over a chord. At zero, all pitches are equally 302 | likely. As concentration increases, observed pitches must match the 303 | chord pitches more closely. 304 | add_key_signatures: If True, also add inferred key signatures to 305 | `quantized_sequence` (and remove any existing key signatures). 306 | 307 | Raises: 308 | SequenceAlreadyHasChordsError: If `sequence` already has chords. 
309 | QuantizationStatusError: If `sequence` is not quantized relative to 310 | meter but `chords_per_bar` is specified or no beat annotations are 311 | present. 312 | UncommonTimeSignatureError: If `chords_per_bar` is not specified and 313 | `sequence` is quantized and has an uncommon time signature. 314 | NonIntegerStepsPerChordError: If the number of quantized steps per chord 315 | is not an integer. 316 | EmptySequenceError: If `sequence` is empty. 317 | SequenceTooLongError: If the number of chords to be inferred is too 318 | large. 319 | """ 320 | assert sequence.quantization_info.steps_per_quarter > 0 321 | 322 | # Infer a fixed number of chords per bar. 323 | if chords_per_bar is None: 324 | time_signature = (sequence.time_signatures[0].numerator, 325 | sequence.time_signatures[0].denominator) 326 | if time_signature not in _DEFAULT_TIME_SIGNATURE_CHORDS_PER_BAR: 327 | raise UncommonTimeSignatureError( 328 | 'No default chords per bar for time signature: (%d, %d)' % 329 | time_signature) 330 | chords_per_bar = _DEFAULT_TIME_SIGNATURE_CHORDS_PER_BAR[time_signature] 331 | 332 | # Determine the number of seconds (and steps) each chord is held. 333 | steps_per_bar_float = note_sequence_ops.steps_per_bar_in_quantized_sequence(sequence) 334 | steps_per_chord_float = steps_per_bar_float / chords_per_bar 335 | if steps_per_chord_float != round(steps_per_chord_float): 336 | raise NonIntegerStepsPerChordError('Non-integer number of steps per chord: %f' % 337 | steps_per_chord_float) 338 | steps_per_chord = int(steps_per_chord_float) 339 | steps_per_second = note_sequence_ops.steps_per_quarter_to_steps_per_second( 340 | sequence.quantization_info.steps_per_quarter, sequence.tempos[0].qpm) 341 | seconds_per_chord = steps_per_chord / steps_per_second 342 | 343 | num_chords = int(math.ceil(sequence.total_time / seconds_per_chord)) 344 | if num_chords == 0: 345 | raise EmptySequenceError('NoteSequence is empty.') 346 | 347 | if num_chords > _MAX_NUM_CHORDS: 348 | raise SequenceTooLongError( 349 | 'NoteSequence too long for chord inference: %d frames' % 350 | num_chords) 351 | 352 | # Compute pitch vectors for each chord frame, then compute log-likelihood of 353 | # observing those pitch vectors under each possible chord. 354 | note_pitch_vectors = sequence_note_pitch_vectors(sequence, seconds_per_chord) # TODO: save this vector 355 | chord_frame_loglik = _chord_frame_log_likelihood( 356 | note_pitch_vectors, chord_note_concentration) 357 | 358 | # Compute distribution over chords for each key, and transition distribution 359 | # between key-chord pairs. 360 | key_chord_distribution = _key_chord_distribution( 361 | chord_pitch_out_of_key_prob=chord_pitch_out_of_key_prob) 362 | key_chord_transition_distribution = _key_chord_transition_distribution( 363 | key_chord_distribution, 364 | key_change_prob=key_change_prob, 365 | chord_change_prob=chord_change_prob) 366 | key_chord_loglik = np.log(key_chord_distribution) 367 | key_chord_transition_loglik = np.log(key_chord_transition_distribution) 368 | 369 | key_chords = _key_chord_viterbi( 370 | chord_frame_loglik, key_chord_loglik, key_chord_transition_loglik) 371 | if add_key_signatures: 372 | del sequence.key_signatures[:] 373 | 374 | # Add the inferred chord changes to the sequence, optionally adding key 375 | # signature(s) as well. 
376 | current_key_name = None 377 | current_chord_name = None 378 | for frame, (key, chord) in enumerate(key_chords): 379 | time = frame * seconds_per_chord 380 | if _PITCH_CLASS_NAMES[key] != current_key_name: 381 | # A key change was inferred. 382 | if add_key_signatures: 383 | ks = KeySignature(time, key, key // 12) 384 | sequence.key_signatures.append(ks) 385 | else: 386 | if current_key_name is not None: 387 | print('Sequence has key change from %s to %s at %f seconds.', 388 | current_key_name, _PITCH_CLASS_NAMES[key], time) 389 | 390 | current_key_name = _PITCH_CLASS_NAMES[key] 391 | 392 | if chord == constants.NO_CHORD: 393 | figure = constants.NO_CHORD 394 | else: 395 | root, kind = chord 396 | figure = '%s%s' % (_PITCH_CLASS_NAMES[root], kind) 397 | 398 | ta = TextAnnotation(time, frame * steps_per_chord, figure, AnnotationType.CHORD_SYMBOL, chord[0], chord[1], 399 | note_pitch_vectors[frame]) 400 | sequence.text_annotations.append(ta) 401 | 402 | return sequence 403 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Constants for music processing in Magenta.""" 16 | 17 | # Meter-related constants. 18 | DEFAULT_QUARTERS_PER_MINUTE = 60.0 19 | DEFAULT_STEPS_PER_BAR = 16 # 4/4 music sampled at 4 steps per quarter note. 20 | DEFAULT_STEPS_PER_QUARTER = 4 21 | 22 | # Default absolute quantization. 23 | DEFAULT_STEPS_PER_SECOND = 100 24 | 25 | # Standard pulses per quarter. 26 | # https://en.wikipedia.org/wiki/Pulses_per_quarter_note 27 | STANDARD_PPQ = 220 28 | 29 | # Special melody events. 30 | NUM_SPECIAL_MELODY_EVENTS = 2 31 | MELODY_NOTE_OFF = -1 32 | MELODY_NO_EVENT = -2 33 | 34 | # Other melody-related constants. 35 | MIN_MELODY_EVENT = -2 36 | MAX_MELODY_EVENT = 127 37 | MIN_MIDI_PITCH = 0 # Inclusive. 38 | MAX_MIDI_PITCH = 127 # Inclusive. 39 | NOTES_PER_OCTAVE = 12 40 | 41 | # Velocity-related constants. 42 | MIN_MIDI_VELOCITY = 1 # Inclusive. 43 | MAX_MIDI_VELOCITY = 127 # Inclusive. 44 | 45 | # Program-related constants. 46 | MIN_MIDI_PROGRAM = 0 47 | MAX_MIDI_PROGRAM = 127 48 | 49 | # MIDI programs that typically sound unpitched. 50 | UNPITCHED_PROGRAMS = ( 51 | list(range(96, 104)) + list(range(112, 120)) + list(range(120, 128))) 52 | 53 | # Chord symbol for "no chord". 54 | NO_CHORD = 'N.C.' 55 | 56 | # The indices of the pitch classes in a major scale. 57 | MAJOR_SCALE = [0, 2, 4, 5, 7, 9, 11] 58 | 59 | # NOTE_KEYS[note] = The major keys that note belongs to. 60 | # ex. 
NOTE_KEYS[0] lists all the major keys that contain the note C, 61 | # which are: 62 | # [0, 1, 3, 5, 7, 8, 10] 63 | # [C, C#, D#, F, G, G#, A#] 64 | # 65 | # 0 = C 66 | # 1 = C# 67 | # 2 = D 68 | # 3 = D# 69 | # 4 = E 70 | # 5 = F 71 | # 6 = F# 72 | # 7 = G 73 | # 8 = G# 74 | # 9 = A 75 | # 10 = A# 76 | # 11 = B 77 | # 78 | # NOTE_KEYS can be generated using the code below, but is explicitly declared 79 | # for readability: 80 | # NOTE_KEYS = [[j for j in range(12) if (i - j) % 12 in MAJOR_SCALE] 81 | # for i in range(12)] 82 | NOTE_KEYS = [ 83 | [0, 1, 3, 5, 7, 8, 10], 84 | [1, 2, 4, 6, 8, 9, 11], 85 | [0, 2, 3, 5, 7, 9, 10], 86 | [1, 3, 4, 6, 8, 10, 11], 87 | [0, 2, 4, 5, 7, 9, 11], 88 | [0, 1, 3, 5, 6, 8, 10], 89 | [1, 2, 4, 6, 7, 9, 11], 90 | [0, 2, 3, 5, 7, 8, 10], 91 | [1, 3, 4, 6, 8, 9, 11], 92 | [0, 2, 4, 5, 7, 9, 10], 93 | [1, 3, 5, 6, 8, 10, 11], 94 | [0, 2, 4, 6, 7, 9, 11] 95 | ] 96 | 97 | # Names of pitch classes to use (mostly ignoring spelling). 98 | PITCH_CLASS_NAMES = [ 99 | 'C', 'C#', 'D', 'Eb', 'E', 'F', 'F#', 'G', 'Ab', 'A', 'Bb', 'B'] 100 | 101 | SCALE_MODE = {0: "MAJOR", 1: "MINOR"} 102 | 103 | # STEPS_PER_QUARTER = 12 104 | QUARTERS_PER_BAR = 4 105 | # BARS_PER_DATA = 4 106 | 107 | MAX_LENGTH = 16 108 | NUM_INSTRUMENTS = 4 109 | 110 | SOS_token = MAX_MIDI_PITCH + 1 # Start of sequence 111 | EOS_token = MAX_MIDI_PITCH + 2 # End of sequence 112 | REST_token = MAX_MIDI_PITCH + 3 # Token for REST 113 | 114 | # MAX_SHIFT_STEPS = STEPS_PER_QUARTER * QUARTERS_PER_BAR 115 | MAX_EVENT_LENGTH = 300 # TODO 116 | 117 | -------------------------------------------------------------------------------- /create_nottingham_roll.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import h5py 4 | import pretty_midi 5 | import glob 6 | 7 | from data_converter import RollAugMonoSingleSequenceConverter 8 | from constants import QUARTERS_PER_BAR 9 | from note_sequence_ops import quantize_note_sequence, calculate_style_feature, \ 10 | delete_auftakt, filter_note 11 | from chords_infer import infer_chords_for_sequence 12 | from midi_io import midi_to_note_sequence 13 | 14 | # CHANGE PARAM HERE 15 | CHORDS_PER_BAR = 1 16 | STEPS_PER_QUARTER = 4 17 | BARS_PER_DATA = 4 18 | 19 | 20 | def run(): 21 | base_path = "data/original" 22 | h5py_dir = 'data/h5py' # CHANGE PATH HERE 23 | metadata_dir = 'data/metadata' # CHANGE PATH HERE 24 | 25 | if not os.path.exists(h5py_dir): 26 | os.mkdir(h5py_dir) 27 | if not os.path.exists(metadata_dir): 28 | os.mkdir(metadata_dir) 29 | 30 | data_splits = ["train", "test", "valid"] 31 | 32 | RPSC = RollAugMonoSingleSequenceConverter(steps_per_quarter=STEPS_PER_QUARTER, quarters_per_bar=QUARTERS_PER_BAR, 33 | chords_per_bar=CHORDS_PER_BAR, bars_per_data=BARS_PER_DATA) 34 | 35 | for m_i, mode in enumerate(data_splits): 36 | print(mode) 37 | path_lists = glob.glob(os.path.join(base_path, mode, '*.mid')) 38 | h5_path_chord = os.path.join(h5py_dir, mode + '_chord.h5') 39 | h5_path_event = os.path.join(h5py_dir, mode + '_event.h5') 40 | h5_path_sf = os.path.join(h5py_dir, mode + '_sf.h5') 41 | h5_path_chord_f = os.path.join(h5py_dir, mode + '_chord_f.h5') 42 | 43 | data_list = [] 44 | f_chord = h5py.File(h5_path_chord, 'w') 45 | f_event = h5py.File(h5_path_event, 'w') 46 | f_sf = h5py.File(h5_path_sf, 'w') 47 | f_chord_f = h5py.File(h5_path_chord_f, 'w') 48 | 49 | for i, path in enumerate(path_lists): 50 | print(path) 51 | pm = pretty_midi.PrettyMIDI(path) 52 | ns = midi_to_note_sequence(pm) 53 | ns = 
delete_auftakt(ns) 54 | quantized_ns = quantize_note_sequence(ns, steps_per_quarter=STEPS_PER_QUARTER) 55 | quantized_sequence_with_chord = infer_chords_for_sequence( 56 | quantized_ns, chords_per_bar=CHORDS_PER_BAR, add_key_signatures=True) 57 | filtered_ns = filter_note(quantized_sequence_with_chord, ins_index=0) 58 | 59 | style_feature = calculate_style_feature(filtered_ns, num_inst=4, mono=True) 60 | if True: 61 | events, chord, chord_feature, data_num = RPSC.to_tensors_instrument_separate( 62 | filtered_ns, train=mode == 'train') 63 | if len(events) == 0: 64 | print(path, 'none event') 65 | continue 66 | 67 | name = path.split('/')[-1] 68 | for j, (c, c_f, n) in enumerate(zip(chord, chord_feature, events)): 69 | ind_shift = j // data_num 70 | ind_data = j % data_num 71 | key = '{}_{}_{}'.format(name, ind_data, ind_shift) 72 | try: 73 | f_chord.create_dataset(key, data=c) 74 | f_sf.create_dataset(key, data=style_feature) 75 | f_chord_f.create_dataset(key, data=c_f) 76 | f_event.create_dataset(key, data=n) 77 | data_list.append(key) 78 | except Exception as e: 79 | print(path, e) 80 | 81 | data_list_path = os.path.join(metadata_dir, mode + '.pkl') 82 | with open(data_list_path, 'wb') as f: 83 | pickle.dump(data_list, f, protocol=2) 84 | 85 | f_chord.close() 86 | f_event.close() 87 | f_sf.close() 88 | f_chord_f.close() 89 | 90 | 91 | if __name__ == "__main__": 92 | run() 93 | -------------------------------------------------------------------------------- /data_converter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from constants import STANDARD_PPQ, PITCH_CLASS_NAMES 4 | from note_sequence import Note, NoteSequence 5 | 6 | INS_TO_RANGE = { 7 | 'piano': [50, 95], 8 | } 9 | 10 | class RollAugMonoSingleSequenceConverter(object): 11 | _CHORD_KIND_INDEX = { 12 | '': 0, 13 | 'm': 1, 14 | '7': 2, 15 | } 16 | INS_NAMES = ['piano'] 17 | def __init__(self, steps_per_quarter, quarters_per_bar, chords_per_bar, bars_per_data, ins_to_range=INS_TO_RANGE): 18 | self._chords_per_data = chords_per_bar * bars_per_data 19 | self._steps_per_chords = steps_per_quarter * quarters_per_bar // chords_per_bar 20 | self._steps_per_data = steps_per_quarter * quarters_per_bar * bars_per_data 21 | self._notes_per_data = None 22 | self.reversed_chord_kind_index = {v: k for k, v in self._CHORD_KIND_INDEX.items()} 23 | self.INS_TO_RANGE = ins_to_range 24 | self.create_vocab() 25 | 26 | def to_tensors_instrument_separate(self, note_sequence, train=True): 27 | """ 28 | Convert note sequence to tensor assuming that monophonic for each instrument. 
29 | :param note_sequence: quantized note sequence with chord from midi 30 | :param train: YES if it's data for train 31 | :return: rhythm tensor [num_data, bars_per_data, steps_per_data], 32 | pitch tensor [num_data, bars_per_data, steps_per_data], chord tensor [num_data, chords_per_data] 33 | """ 34 | if train: 35 | shifts = [i for i in range(-5, 7)] 36 | else: 37 | shifts = [0] 38 | 39 | # Event list shape: [12, num_chords] 40 | event_lists = [] # shift x num_chord 41 | chord_indices = [] 42 | chord_features = [] 43 | 44 | for i, ta_chord in enumerate(note_sequence.text_annotations): 45 | this_event_list, use_this_bar = self._from_quantized_sequence( 46 | note_sequence, 47 | start_step=i * self._steps_per_chords, 48 | end_step=(i + 1) * self._steps_per_chords, 49 | shift=0) 50 | if use_this_bar: 51 | this_event_indices_shifted = list(map( 52 | lambda shift: self.shift_event_indices(this_event_list, shift), shifts)) 53 | event_lists.extend(this_event_indices_shifted) 54 | 55 | chord_shifted = list(map(lambda shift: self.chord_index(ta_chord, shift), shifts)) 56 | chord_indices.extend(chord_shifted) 57 | 58 | chord_feature_shifted = list(map(lambda shift: np.roll(ta_chord.pitch_vector, shift), shifts)) 59 | chord_features.extend(chord_feature_shifted) 60 | else: 61 | print('no use') 62 | 63 | event_lists = np.array(event_lists).reshape((-1, len(shifts), self._steps_per_chords)) # num_chord, shifts, event 64 | event_lists = np.transpose(event_lists, (1, 0, 2)) # shifts, num_chord, len_events 65 | chord_indices = np.array(chord_indices).reshape((-1, len(shifts))).transpose() # shifts, num_chord 66 | chord_features = np.array(chord_features).reshape((-1, len(shifts), 12)) # num_chord, shifts, pitch 67 | chord_features = np.transpose(chord_features, (1, 0, 2)) # shifts, num_chord, pitch 68 | 69 | total_data_num = event_lists.shape[1] // self._chords_per_data 70 | if total_data_num > 0: 71 | event_lists = event_lists[:, :total_data_num * self._chords_per_data] 72 | chord_indices = chord_indices[:, :total_data_num * self._chords_per_data] 73 | chord_features = chord_features[:, :total_data_num * self._chords_per_data] 74 | 75 | event_lists = event_lists.reshape((-1, self._chords_per_data, self._steps_per_chords)) 76 | chord_indices = chord_indices.reshape((-1, self._chords_per_data)) 77 | chord_features = chord_features.reshape((-1, self._chords_per_data, 12)) 78 | return event_lists, chord_indices, chord_features, total_data_num 79 | 80 | def _from_quantized_sequence(self, quantized_sequence, start_step, end_step, shift): 81 | """Extract a list of events from the given quantized NoteSequence object. 82 | 83 | Within a step, new pitches are started with NOTE_ON and existing pitches are 84 | ended with NOTE_OFF. TIME_SHIFT shifts the current step forward in time. 85 | Args: 86 | quantized_sequence: A quantized NoteSequence instance. 87 | start_step: Start converting the sequence at this time step. 88 | Returns: 89 | A list of events. 90 | """ 91 | # Adds the pitches which were on in the previous sequence. 
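# The returned `events` list has one entry per quantized step in
# [start_step, end_step): 'NOTEON_<pitch>' at a note onset, 'CONTINUE' while
# that note is held, and 'REST' for every remaining step.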
92 | notes = [note for note in quantized_sequence.notes 93 | if note.quantized_start_step < end_step and note.quantized_end_step > start_step] 94 | sorted_notes = sorted(notes, key=lambda note: (note.start_time, note.pitch)) 95 | events = ['REST' for _ in range(end_step - start_step)] 96 | use_this_bar = False 97 | for note in sorted_notes: 98 | note_start_step = max(note.quantized_start_step - start_step, 0) 99 | note_end_step = min(note.quantized_end_step - start_step, end_step - start_step) 100 | events[note_start_step] = 'NOTEON_{}'.format(note.pitch) 101 | use_this_bar = True 102 | for i in range(note_start_step + 1, note_end_step): 103 | events[i] = 'CONTINUE' 104 | return events, use_this_bar 105 | 106 | def to_note_sequence_from_events(self, events, seconds_per_step): 107 | """ 108 | Convert to note sequence. 109 | :param seconds_per_step: 110 | :return: 111 | """ 112 | ns = NoteSequence(ticks_per_quarter=STANDARD_PPQ) 113 | ns.total_time = 0 114 | for i, event in enumerate(events): # Each bar 115 | sequence_start_time = i * self._steps_per_chords * seconds_per_step 116 | sequence_end_time = (i + 1) * self._steps_per_chords * seconds_per_step 117 | ns = self._to_sequence(ns, event, seconds_per_step, sequence_start_time, sequence_end_time) 118 | ns.total_time = sequence_end_time 119 | return ns 120 | 121 | def _to_sequence(self, sequence, event_list, seconds_per_step, sequence_start_time, sequence_end_time): 122 | velocity = 60 123 | pitch = None 124 | pitch_start_step = 0 125 | 126 | for i, event_index in enumerate(event_list): 127 | event = self.vocab[int(event_index)] 128 | if event == 'CONTINUE': 129 | if pitch is not None: 130 | 'unexpected continue' 131 | continue 132 | if pitch is not None: 133 | start_time = pitch_start_step * seconds_per_step + sequence_start_time 134 | end_time = i * seconds_per_step + sequence_start_time 135 | end_time = min(end_time, sequence_end_time) 136 | note = Note(instrument=0, program=0, start_time=start_time, end_time=end_time, 137 | pitch=pitch, velocity=velocity, is_drum=False) 138 | sequence.notes.append(note) 139 | pitch = None 140 | if 'NOTEON' in event: 141 | pitch = int(event.split('_')[1]) 142 | pitch_start_step = i 143 | 144 | if pitch is not None: 145 | start_time = pitch_start_step * seconds_per_step + sequence_start_time 146 | end_time = (i + 1) * seconds_per_step + sequence_start_time 147 | end_time = min(end_time, sequence_end_time) 148 | note = Note(instrument=0, program=0, start_time=start_time, end_time=end_time, 149 | pitch=pitch, velocity=velocity, is_drum=False) 150 | sequence.notes.append(note) 151 | return sequence 152 | 153 | def create_vocab(self): 154 | self.vocab = ['NOTEON_{}'.format(pitch) for pitch in range( 155 | self.INS_TO_RANGE[self.INS_NAMES[0]][0], self.INS_TO_RANGE[self.INS_NAMES[0]][1])] 156 | self.vocab.extend(['REST', 'CONTINUE']) 157 | 158 | def shift_event_indices(self, events, shift): 159 | shifted_events = list(map(lambda e: self.index_with_shift(e, shift), events)) 160 | return shifted_events 161 | 162 | def index_with_shift(self, event, shift): 163 | if event in ['REST', 'CONTINUE']: 164 | return self.vocab.index(event) 165 | pitch = int(event.split('_')[1]) + shift 166 | assert self.INS_TO_RANGE[self.INS_NAMES[0]][0] <= pitch <= self.INS_TO_RANGE[self.INS_NAMES[0]][1] 167 | event_str = 'NOTEON_{}'.format(pitch) 168 | return self.vocab.index(event_str) 169 | 170 | def chord_index(self, ta_chord, offset): 171 | if isinstance(ta_chord.root, str) or ta_chord.kind not in self._CHORD_KIND_INDEX: # No chord 
172 | return 0 173 | root = (ta_chord.root + offset) % 12 174 | kind_ind = self._CHORD_KIND_INDEX[ta_chord.kind] 175 | return root * len(self._CHORD_KIND_INDEX) + kind_ind + 1 176 | 177 | def chord_from_index(self, chord_index): 178 | if chord_index == 0: 179 | return "N.C." 180 | root = (chord_index - 1) // len(self._CHORD_KIND_INDEX) 181 | kind_ind = (chord_index - 1) % len(self._CHORD_KIND_INDEX) 182 | return PITCH_CLASS_NAMES[root] + self.reversed_chord_kind_index[kind_ind] -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | learning_rate: 0.0001 2 | learning_rate_d: 0.0001 3 | lambda_d: 0.1 4 | lambda_kl: 0.1 5 | n_iters: 50000 6 | chords_per_bar: 1 7 | gradient_clip: 1.0 8 | vocab_size: 47 9 | c_vocab_size: 37 10 | activation: tanh 11 | activation_d: leakyrelu 12 | batch_size: 64 13 | steps_per_quarter: 4 14 | bars_per_data: 4 15 | max_shift_steps: 16 16 | chords_per_data: 4 17 | n_classes: 8 18 | # CHANGE PATH HERE 19 | data: 20 | chord: data/h5py/train_chord.h5 21 | chord_f: data/h5py/train_chord_f.h5 22 | style: data/h5py/train_sf.h5 23 | attr: data/h5py/train_style.h5 24 | attr_cls: data/h5py/train_style_cls.h5 25 | event: data/h5py/train_event.h5 26 | keys: data/metadata/train.pkl 27 | 28 | chord_test: data/h5py/test_chord.h5 29 | chord_f_test: data/h5py/test_chord_f.h5 30 | style_test: data/h5py/test_sf.h5 31 | attr_test: data/h5py/test_style.h5 32 | attr_cls_test: data/h5py/test_style_cls.h5 33 | event_test: data/h5py/test_event.h5 34 | keys_test: data/metadata/test.pkl 35 | 36 | chord_valid: data/h5py/valid_chord.h5 37 | chord_f_valid: data/h5py/valid_chord_f.h5 38 | style_valid: data/h5py/valid_sf.h5 39 | attr_valid: data/h5py/valid_style.h5 40 | attr_cls_valid: data/h5py/valid_style_cls.h5 41 | inst_valid: data/h5py/valid_inst.h5 42 | event_valid: data/h5py/valid_event.h5 43 | keys_valid: data/metadata/valid.pkl 44 | 45 | num_instruments: 1 46 | model: 47 | chord_encoder: 48 | hidden_size: 512 49 | num_layers: 4 50 | encoder: 51 | num_layers: 2 52 | cnn_hidden_size: 1152 53 | hidden_size: 1024 54 | cnn: 55 | cnn_1: 56 | n_channel: 128 57 | kernel_size: 4 58 | cnn_2: 59 | n_channel: 128 60 | kernel_size: 4 61 | pool: 62 | pool_1: 63 | kernel_size: 2 64 | stride: 2 65 | pool_2: 66 | kernel_size: 2 67 | stride: 2 68 | rnn: 69 | hidden_size: 512 70 | num_layers: 4 71 | decoder: 72 | hidden_size: 512 73 | num_layers: 4 74 | discriminator: 75 | num_layers: 2 76 | n_attr: 1 77 | attr: 78 | - Rhythmic_Value_Variability 79 | thresholds: 80 | Rhythmic_Value_Variability: 0. 
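The training entry point (train.py) appears in the file tree but is not reproduced in this section, so the snippet below is only a sketch of how env.yml plausibly reaches generate_cfg in arguments.py; the YAML loader (PyYAML), the output_dir value, and the Namespace contents are illustrative assumptions, limited to the attributes generate_cfg actually inspects.

import yaml
from argparse import Namespace
from arguments import generate_cfg

with open('env.yml') as f:
    env_var = yaml.safe_load(f)  # nested dict mirroring the file above

# None means "keep whatever env.yml already specifies".
args = Namespace(model_name='tmp', r_alpha=1.0, p_alpha=1.0,
                 n_iters=None, learning_rate=None, model_structure=None,
                 hidden_size=None, num_layers=None,
                 c_num_layers=None, c_hidden_size=None)

cfg = generate_cfg(env_var, args, output_dir='output/tmp', chords_per_bar=1)
print(cfg['EOR_token'], cfg['SOR_token'])  # 4 * 4 // 1 + 1 = 17, and 18

The interpolation scripts below instead call generate_cfg_fader with a saved checkpoint, which restores the configuration stored alongside the model weights.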
-------------------------------------------------------------------------------- /interpolate_double.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import pickle 4 | from argparse import ArgumentParser 5 | import numpy as np 6 | 7 | import torch 8 | 9 | import constants 10 | from model import Classifier, FaderVAE 11 | from arguments import generate_cfg_fader 12 | from data_converter import RollAugMonoSingleSequenceConverter 13 | from note_sequence_ops import steps_per_quarter_to_seconds_per_step, quantize_note_sequence 14 | from midi_io import note_sequence_to_pretty_midi 15 | from chords_infer import infer_chords_for_sequence 16 | from note_sequence import Tempo, TimeSignature, InstrumentInfo 17 | 18 | # CHANGE PATH HERE 19 | output_path_base = 'output' 20 | 21 | 22 | def to_tensor(data, device): 23 | if device == 'cpu': 24 | return torch.LongTensor(data, device=device) 25 | return torch.cuda.LongTensor(data, device=device) 26 | 27 | 28 | def to_float_tensor(data, device): 29 | if device == 'cpu': 30 | return torch.FloatTensor(data, device=device) 31 | return torch.cuda.FloatTensor(data, device=device) 32 | 33 | 34 | def get_original_style_name(attrs, original_style): 35 | if len(attrs) == 1: 36 | return '_org_{}{:.2}'.format(attrs[0], float(original_style[0][0])) 37 | if len(attrs) == 2: 38 | return '_org_{}{:.2}_{}{:.2}'.format(attrs[0], float(original_style[0][0]), 39 | attrs[1], float(original_style[0][1])) 40 | if len(attrs) == 3: 41 | return '_org_{}{:.2}_{}{:.2}_{}{:.2}'.format(attrs[0], float(original_style[0][0]), 42 | attrs[1], float(original_style[0][1]), 43 | attrs[2], float(original_style[0][2])) 44 | 45 | 46 | def test(f_chord_test, f_style_test, f_event_test, keys_test, model, device, attrs, 47 | batch_size): 48 | model.eval() 49 | losses, losses_d, losses_kl, accs = 0., 0., 0., 0. 
50 | preds, org_styles, chords, interpolation_for_back = {}, {}, {}, {} 51 | 52 | interpolate_num = 11 53 | vec_to_interpolate = to_float_tensor(np.array([(i - 5) * 0.1 for i in range(interpolate_num)]), device=device) 54 | mul_tgt = [-1.0, -0.5, -0.3, 0.3, 0.5, 1.0] 55 | vec_to_interpolate_mul = to_float_tensor(np.array([[i, j] for i in mul_tgt for j in mul_tgt]), device=device) 56 | 57 | with torch.no_grad(): 58 | for key in keys_test: 59 | chord_tensor = to_float_tensor(f_chord_test[key], device=device).repeat(batch_size, 1, 1) 60 | event_tensor = to_tensor(f_event_test[key], device=device).repeat(batch_size, 1, 1) 61 | style_tensor = to_float_tensor([f_style_test[attr + '/' + key] for attr in attrs], 62 | device=device).reshape(1, len(attrs)).repeat(batch_size, 1) 63 | 64 | style_tensor[:interpolate_num, 0] += vec_to_interpolate 65 | style_tensor[interpolate_num: 2 * interpolate_num, 1] += vec_to_interpolate 66 | style_tensor[interpolate_num * 2: interpolate_num * 2 + len(mul_tgt) ** 2] += vec_to_interpolate_mul 67 | 68 | loss, pred, lv, acc, distribution = model(event_tensor, chord_tensor, style_tensor) 69 | 70 | losses += loss 71 | accs += acc 72 | 73 | original_style_value_0 = float(np.array(f_style_test[attrs[0] + '/' + key])) 74 | original_style_value_1 = float(np.array(f_style_test[attrs[1] + '/' + key])) 75 | names = ['{}_{}_{:.5}'.format(key, attrs[0], original_style_value_0 + (i - 5) * 0.1) for i in range(interpolate_num)] +\ 76 | ['{}_{}_{:.5}'.format(key, attrs[1], original_style_value_1 + (i - 5) * 0.1) for i in range(interpolate_num)] +\ 77 | ['{}_demo_{}_{:.5}_{}_{:.5}'.format( 78 | key, attrs[0], original_style_value_0 + i, attrs[1], original_style_value_1 + j) 79 | for i in mul_tgt for j in mul_tgt] 80 | for i, name in enumerate(names): 81 | preds[name] = pred[i] 82 | 83 | return losses.item() / batch_size, accs / batch_size, preds 84 | 85 | def evaluation(model, cfg, device): 86 | f_chord_test = h5py.File(cfg['data']['chord_f_valid'], 'r') 87 | f_event_test = h5py.File(cfg['data']['event_valid'], 'r') 88 | f_style_test = h5py.File(cfg['data']['attr_valid'], 'r') 89 | with open(cfg['data']['keys_valid'], 'rb') as f: 90 | keys_test = pickle.load(f) 91 | 92 | return test(f_chord_test, f_style_test, f_event_test, keys_test, model, device, 93 | cfg['attr'], cfg['batch_size']) 94 | 95 | 96 | def run(args): 97 | output_dir = os.path.join(output_path_base, args.model_name) 98 | latest_model_text_file = os.path.join(output_dir, 'latest_model.txt') 99 | sample_dir = os.path.join(output_dir, 'samples') 100 | if not os.path.exists(sample_dir): 101 | os.mkdir(sample_dir) 102 | 103 | if not os.path.exists(output_dir) and not os.path.exists(latest_model_text_file): 104 | raise IOError("Model file not found.") 105 | 106 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 107 | print(torch.cuda.current_device()) 108 | 109 | if args.model_path: 110 | latest_model_path = args.model_path 111 | else: 112 | with open(latest_model_text_file, 'r') as f: 113 | latest_model_path = f.read() 114 | checkpoint = torch.load(latest_model_path) 115 | latest_model_name = latest_model_path.split('/')[-1] 116 | cfg = generate_cfg_fader(None, args, output_dir, checkpoint) 117 | 118 | print(cfg) 119 | 120 | model_d = Classifier(input_dim=cfg['z_dims'], 121 | num_layers=cfg['model']['discriminator']['num_layers'], 122 | n_attr=len(cfg['attr']), 123 | activation=cfg['activation_d'], 124 | n_classes=8, 125 | device=device) 126 | 127 | model = FaderVAE(vocab_size=cfg['vocab_size'], 128 | 
hidden_dims=cfg['model']['encoder']['hidden_size'], 129 | z_dims=cfg['z_dims'], 130 | n_step=cfg['bars_per_data'] * cfg['steps_per_quarter'] * constants.QUARTERS_PER_BAR, 131 | device=device, 132 | n_attr=cfg['n_attr']) 133 | 134 | model.load_state_dict(checkpoint['model']) 135 | model_d.load_state_dict(checkpoint['model_d']) 136 | model.to(device) 137 | model_d.to(device) 138 | print(model) 139 | 140 | losses, accs, preds = evaluation(model, cfg, device) 141 | val_result_path = os.path.join(output_dir, 'eval_result.txt') 142 | with open(val_result_path, 'a') as f: 143 | f.write("model: {}, val_loss: {}, acc: {}, ".format( 144 | latest_model_name, losses, accs)) 145 | 146 | RPSC = RollAugMonoSingleSequenceConverter(steps_per_quarter=cfg['steps_per_quarter'], 147 | quarters_per_bar=constants.QUARTERS_PER_BAR, 148 | chords_per_bar=cfg['chords_per_bar'], 149 | bars_per_data=cfg['bars_per_data']) 150 | 151 | sample_save_path = os.path.join(sample_dir, latest_model_name + '_original') 152 | if not os.path.exists(sample_save_path): 153 | os.mkdir(sample_save_path) 154 | 155 | seconds_per_step = steps_per_quarter_to_seconds_per_step(cfg['steps_per_quarter'], 60) 156 | 157 | f_chord_test = h5py.File(cfg['data']['chord_valid'], 'r') 158 | f_event_test = h5py.File(cfg['data']['event_valid'], 'r') 159 | with open(cfg['data']['keys_valid'], 'rb') as f: 160 | keys_test = pickle.load(f) 161 | 162 | chord_acc, chord_style_acc = 0., 0. 163 | for key, event in preds.items(): 164 | # Normalized 165 | event = event.reshape(cfg['chords_per_data'], -1) 166 | ns = RPSC.to_note_sequence_from_events(event, seconds_per_step) # Normalized 167 | tempo = Tempo(time=0., qpm=60) 168 | ns.tempos.append(tempo) 169 | ns.instrument_infos = { 170 | InstrumentInfo('piano', 0), 171 | } 172 | time_signature = TimeSignature(time=0, numerator=4, denominator=4) 173 | ns.time_signatures.append(time_signature) 174 | 175 | quantized_ns = quantize_note_sequence(ns, steps_per_quarter=cfg['steps_per_quarter']) 176 | try: 177 | ns_with_chord = infer_chords_for_sequence(quantized_ns, chords_per_bar=cfg['chords_per_bar']) 178 | except Exception as e: 179 | print(e) 180 | continue 181 | pm = note_sequence_to_pretty_midi(ns_with_chord) 182 | key_string = '_'.join(key.split('/')) 183 | output_path = os.path.join(sample_save_path, key_string + '.mid') 184 | 185 | print(output_path) 186 | pm.write(output_path) 187 | 188 | # Normalized key comparison 189 | chord_list = [ta_chord.text for ta_chord in ns_with_chord.text_annotations] 190 | chord_txt = ",".join(chord_list) 191 | output_chord_path = os.path.join(sample_save_path, key_string + '.txt') 192 | with open(output_chord_path, 'w') as f: 193 | f.write(chord_txt) 194 | 195 | with open(val_result_path, 'a') as f: 196 | f.write("chord_acc: {}\n".format(chord_acc / len(keys_test))) 197 | 198 | original_path = os.path.join(sample_dir, 'original') 199 | if not os.path.exists(original_path): 200 | os.mkdir(original_path) 201 | 202 | for key in keys_test: 203 | # Unnormalized 204 | ns = RPSC.to_note_sequence_from_events(np.array(f_event_test[key]), seconds_per_step) 205 | pm = note_sequence_to_pretty_midi(ns) 206 | key_string = '_'.join(key.split('/')) 207 | output_path = os.path.join(original_path, key_string + '.mid') 208 | pm.write(output_path) 209 | chord = list(f_chord_test[key]) 210 | chord_list = [RPSC.chord_from_index(c) for c in chord] 211 | chord_list = ",".join(chord_list) 212 | output_chord_path = os.path.join(original_path, key_string + '.txt') 213 | with open(output_chord_path, 'w') as 
f: 214 | f.write(chord_list) 215 | 216 | 217 | if __name__ == "__main__": 218 | parser = ArgumentParser() 219 | parser.add_argument('--gpu', type=int, default=[0, 1], nargs='+', help='used gpu') 220 | parser.add_argument('--model_name', type=str, default="tmp", help='model name') 221 | parser.add_argument('--model_path', type=str, help='to use a specific model, not latest') 222 | 223 | args = parser.parse_args() 224 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gpu) 225 | 226 | run(args) 227 | -------------------------------------------------------------------------------- /interpolate_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import pickle 4 | from argparse import ArgumentParser 5 | import numpy as np 6 | 7 | import torch 8 | from torch.distributions import kl_divergence, Normal 9 | 10 | import constants 11 | from model import Classifier, FaderVAE 12 | from arguments import generate_cfg_fader 13 | from data_converter import RollAugMonoSingleSequenceConverter 14 | from note_sequence_ops import steps_per_quarter_to_seconds_per_step, quantize_note_sequence 15 | from midi_io import note_sequence_to_pretty_midi 16 | from chords_infer import infer_chords_for_sequence 17 | from note_sequence import Tempo, TimeSignature, InstrumentInfo 18 | from utils import std_normal 19 | 20 | # CHANGE PATH HERE 21 | output_path_base = 'output' 22 | 23 | 24 | def to_tensor(data, device): 25 | if device == 'cpu': 26 | return torch.LongTensor(data, device=device) 27 | return torch.cuda.LongTensor(data, device=device) 28 | 29 | 30 | def to_float_tensor(data, device): 31 | if device == 'cpu': 32 | return torch.FloatTensor(data, device=device) 33 | return torch.cuda.FloatTensor(data, device=device) 34 | 35 | 36 | def create_style_const(attrs, batch_size, thresholds): 37 | if len(attrs) == 1: 38 | style_ratios = [[thre] for thre in thresholds[attrs[0]]] 39 | names = ['_{}{:.2}'.format(attrs[0], ratio[0]) for ratio in style_ratios] 40 | padding = [[0.] for _ in range(batch_size - len(style_ratios) - 1)] 41 | style_const = np.array(style_ratios + padding) 42 | style_const_cls = np.array([[i] for i in range(len(thresholds[attrs[0]]))] + padding) 43 | elif len(attrs) == 2: 44 | tgt_class = [0, 4, 7] 45 | style_ratios = [[thresholds[attrs[0]][i], thresholds[attrs[1]][j]] for i in tgt_class for j in tgt_class] 46 | names = ['_{}{:.2}_{}{:.2}'.format(attrs[0], ratio[0], attrs[1], ratio[1]) for ratio in style_ratios] 47 | padding = [[0., 0.] for _ in range(batch_size - len(style_ratios) - 1)] 48 | style_const = np.array(style_ratios + padding) 49 | style_const_cls = np.array([[i, j] for i in tgt_class for j in tgt_class] + padding) 50 | elif len(attrs) == 3: 51 | tgt_class = [2, 5] 52 | style_ratios = [[thresholds[attrs[0]][i], thresholds[attrs[1]][j], thresholds[attrs[2]][k]] 53 | for i in tgt_class for j in tgt_class for k in tgt_class] 54 | names = ['_{}{:.2}_{}{:.2}_{}{:.2}'.format(attrs[0], ratio[0], attrs[1], ratio[1], attrs[2], ratio[2]) 55 | for ratio in style_ratios] 56 | padding = [[0., 0., 0.] 
for _ in range(batch_size - len(style_ratios) - 1)] 57 | style_const = np.array(style_ratios + padding) 58 | style_const_cls = np.array([[i, j, k] for i in tgt_class for j in tgt_class for k in tgt_class] + padding) 59 | else: 60 | raise Exception('unsupported number of style') 61 | return style_const, style_const_cls, names 62 | 63 | 64 | def get_original_style_name(attrs, original_style): 65 | if len(attrs) == 1: 66 | return '_org_{}{:.2}'.format(attrs[0], float(original_style[0][0])) 67 | if len(attrs) == 2: 68 | return '_org_{}{:.2}_{}{:.2}'.format(attrs[0], float(original_style[0][0]), 69 | attrs[1], float(original_style[0][1])) 70 | if len(attrs) == 3: 71 | return '_org_{}{:.2}_{}{:.2}_{}{:.2}'.format(attrs[0], float(original_style[0][0]), 72 | attrs[1], float(original_style[0][1]), 73 | attrs[2], float(original_style[0][2])) 74 | 75 | 76 | def test(f_chord_test, f_style_test, f_style_cls_test, f_event_test, keys_test, model, model_d, device, attrs, 77 | batch_size, thresholds): 78 | model.eval() 79 | model_d.eval() 80 | losses, losses_d, losses_kl, accs = 0., 0., 0., 0. 81 | preds = {} 82 | 83 | style_const, style_const_cls, names = create_style_const(attrs, batch_size, thresholds) 84 | with torch.no_grad(): 85 | for key in keys_test: 86 | chord_tensor = to_float_tensor(f_chord_test[key], device=device).repeat(batch_size, 1, 1) 87 | event_tensor = to_tensor(f_event_test[key], device=device).repeat(batch_size, 1, 1) 88 | 89 | original_style_value = np.array([f_style_test[attr + '/' + key] for attr in attrs]).reshape(1, -1) 90 | style_tensor = np.concatenate([original_style_value, style_const]) 91 | style_tensor = to_float_tensor(style_tensor, device=device) 92 | 93 | original_style_cls_value = np.array([f_style_cls_test[attr + '/' + key] for attr in attrs]).reshape(1, -1) 94 | style_cls_tensor = np.concatenate([original_style_cls_value, style_const_cls]) 95 | style_cls_tensor = to_tensor(style_cls_tensor, device=device) 96 | 97 | loss, pred, lv, acc, distribution = model(event_tensor, chord_tensor, style_tensor) 98 | dis_out = model_d(lv) 99 | loss_d = model_d.calc_loss(dis_out, style_cls_tensor) 100 | 101 | normal = std_normal(distribution.mean.size()) 102 | loss_kl = kl_divergence(distribution, normal).mean() 103 | 104 | losses += loss 105 | accs += acc 106 | losses_d += loss_d 107 | losses_kl += loss_kl 108 | 109 | original_style_name = get_original_style_name(attrs, original_style_value) 110 | preds[key + original_style_name] = pred[0] 111 | for i, name in enumerate(names): 112 | preds[key + name] = pred[i + 1] 113 | 114 | return losses.item() / batch_size, losses_d.item() / batch_size, losses_kl.item() / batch_size, accs / batch_size,\ 115 | preds 116 | 117 | 118 | def evaluation(model, model_d, cfg, device): 119 | f_chord_test = h5py.File(cfg['data']['chord_f_valid'], 'r') 120 | f_event_test = h5py.File(cfg['data']['event_valid'], 'r') 121 | f_style_test = h5py.File(cfg['data']['attr_valid'], 'r') 122 | f_style_cls_test = h5py.File(cfg['data']['attr_cls_valid'], 'r') 123 | with open(cfg['data']['keys_valid'], 'rb') as f: 124 | keys_test = pickle.load(f) 125 | 126 | threshold_path = '/data/unagi0/kawai/Nottingham/processed_h5/roll_4_4_metadata/style_cls_thresholds.pkl' 127 | with open(threshold_path, 'rb') as f: 128 | thresholds = pickle.load(f, encoding='latin1') 129 | return test(f_chord_test, f_style_test, f_style_cls_test, f_event_test, keys_test, model, model_d, device, 130 | cfg['attr'], cfg['batch_size'], thresholds) 131 | 132 | 133 | def run(args): 134 | output_dir = 
os.path.join(output_path_base, args.model_name) 135 | latest_model_text_file = os.path.join(output_dir, 'latest_model.txt') 136 | sample_dir = os.path.join(output_dir, 'samples') 137 | if not os.path.exists(sample_dir): 138 | os.mkdir(sample_dir) 139 | 140 | if not os.path.exists(output_dir) and not os.path.exists(latest_model_text_file): 141 | raise IOError("Model file not found.") 142 | 143 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 144 | print(torch.cuda.current_device()) 145 | 146 | if args.model_path: 147 | latest_model_path = args.model_path 148 | else: 149 | with open(latest_model_text_file, 'r') as f: 150 | latest_model_path = f.read() 151 | checkpoint = torch.load(latest_model_path) 152 | latest_model_name = latest_model_path.split('/')[-1] 153 | cfg = generate_cfg_fader(None, args, output_dir, checkpoint) 154 | 155 | print(cfg) 156 | 157 | model_d = Classifier(input_dim=cfg['z_dims'], 158 | num_layers=cfg['model']['discriminator']['num_layers'], 159 | n_attr=len(cfg['attr']), 160 | activation=cfg['activation_d'], 161 | n_classes=8, 162 | device=device) 163 | 164 | model = FaderVAE(vocab_size=cfg['vocab_size'], 165 | hidden_dims=cfg['model']['encoder']['hidden_size'], 166 | z_dims=cfg['z_dims'], 167 | n_step=cfg['bars_per_data'] * cfg['steps_per_quarter'] * constants.QUARTERS_PER_BAR, 168 | device=device, 169 | n_attr=cfg['n_attr']) 170 | 171 | model.load_state_dict(checkpoint['model']) 172 | model_d.load_state_dict(checkpoint['model_d']) 173 | model.to(device) 174 | model_d.to(device) 175 | print(model) 176 | 177 | losses, losses_d, losses_kl, accs, preds = evaluation(model, model_d, cfg, device) 178 | 179 | val_result_path = os.path.join(output_dir, 'eval_result.txt') 180 | with open(val_result_path, 'a') as f: 181 | f.write("model: {}, val_loss: {}, val_loss_d: {}, val_loss_kl: {}, acc: {}, ".format( 182 | latest_model_name, losses, losses_d, losses_kl, accs)) 183 | 184 | RPSC = RollAugMonoSingleSequenceConverter(steps_per_quarter=cfg['steps_per_quarter'], 185 | quarters_per_bar=constants.QUARTERS_PER_BAR, 186 | chords_per_bar=cfg['chords_per_bar'], 187 | bars_per_data=cfg['bars_per_data']) 188 | 189 | sample_save_path = os.path.join(sample_dir, latest_model_name + '_inter') 190 | if not os.path.exists(sample_save_path): 191 | os.mkdir(sample_save_path) 192 | 193 | seconds_per_step = steps_per_quarter_to_seconds_per_step(cfg['steps_per_quarter'], 60) 194 | 195 | f_chord_test = h5py.File(cfg['data']['chord_valid'], 'r') 196 | f_event_test = h5py.File(cfg['data']['event_valid'], 'r') 197 | with open(cfg['data']['keys_valid'], 'rb') as f: 198 | keys_test = pickle.load(f) 199 | 200 | chord_acc, chord_style_acc = 0., 0. 
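# Decode every predicted event tensor back into a NoteSequence, re-infer its
# chords, and write the result as a MIDI file plus a text file listing the
# inferred chord sequence.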
201 | for key, event in preds.items(): 202 | # Normalized 203 | event = event.reshape(cfg['chords_per_data'], -1) 204 | ns = RPSC.to_note_sequence_from_events(event, seconds_per_step) # Normalized 205 | tempo = Tempo(time=0., qpm=60) 206 | ns.tempos.append(tempo) 207 | ns.instrument_infos = { 208 | InstrumentInfo('piano', 0), 209 | } 210 | time_signature = TimeSignature(time=0, numerator=4, denominator=4) 211 | ns.time_signatures.append(time_signature) 212 | 213 | quantized_ns = quantize_note_sequence(ns, steps_per_quarter=cfg['steps_per_quarter']) 214 | try: 215 | ns_with_chord = infer_chords_for_sequence(quantized_ns, chords_per_bar=cfg['chords_per_bar']) 216 | except Exception as e: 217 | print(e) 218 | continue 219 | pm = note_sequence_to_pretty_midi(ns_with_chord) 220 | key_string = '_'.join(key.split('/')) 221 | output_path = os.path.join(sample_save_path, key_string + '.mid') 222 | 223 | print(output_path) 224 | pm.write(output_path) 225 | 226 | chord_list = [ta_chord.text for ta_chord in ns_with_chord.text_annotations] 227 | chord_txt = ",".join(chord_list) 228 | output_chord_path = os.path.join(sample_save_path, key_string + '.txt') 229 | with open(output_chord_path, 'w') as f: 230 | f.write(chord_txt) 231 | 232 | with open(val_result_path, 'a') as f: 233 | f.write("chord_acc: {}\n".format(chord_acc / len(keys_test))) 234 | 235 | original_path = os.path.join(sample_dir, 'original') 236 | if not os.path.exists(original_path): 237 | os.mkdir(original_path) 238 | 239 | for key in keys_test: 240 | # Unnormalized 241 | ns = RPSC.to_note_sequence_from_events(np.array(f_event_test[key]), seconds_per_step) 242 | pm = note_sequence_to_pretty_midi(ns) 243 | key_string = '_'.join(key.split('/')) 244 | output_path = os.path.join(original_path, key_string + '.mid') 245 | pm.write(output_path) 246 | chord = list(f_chord_test[key]) 247 | chord_list = [RPSC.chord_from_index(c) for c in chord] 248 | chord_list = ",".join(chord_list) 249 | output_chord_path = os.path.join(original_path, key_string + '.txt') 250 | with open(output_chord_path, 'w') as f: 251 | f.write(chord_list) 252 | 253 | 254 | if __name__ == "__main__": 255 | parser = ArgumentParser() 256 | parser.add_argument('--gpu', type=int, default=[0, 1], nargs='+', help='used gpu') 257 | parser.add_argument('--model_name', type=str, default="tmp", help='model name') 258 | parser.add_argument('--model_path', type=str, help='to use a specific model, not latest') 259 | 260 | args = parser.parse_args() 261 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gpu) 262 | 263 | run(args) 264 | -------------------------------------------------------------------------------- /interpolate_triple.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import pickle 4 | from argparse import ArgumentParser 5 | import numpy as np 6 | 7 | import torch 8 | 9 | import constants 10 | from model import Classifier, FaderVAE 11 | from arguments import generate_cfg_fader 12 | from data_converter import RollAugMonoSingleSequenceConverter 13 | from note_sequence_ops import steps_per_quarter_to_seconds_per_step, quantize_note_sequence 14 | from midi_io import note_sequence_to_pretty_midi 15 | from chords_infer import infer_chords_for_sequence 16 | from note_sequence import Tempo, TimeSignature, InstrumentInfo 17 | 18 | # CHANGE PATH HERE 19 | output_path_base = 'output' 20 | 21 | 22 | def to_tensor(data, device): 23 | if device == 'cpu': 24 | return torch.LongTensor(data, device=device) 
25 | return torch.cuda.LongTensor(data, device=device) 26 | 27 | 28 | def to_float_tensor(data, device): 29 | if device == 'cpu': 30 | return torch.FloatTensor(data, device=device) 31 | return torch.cuda.FloatTensor(data, device=device) 32 | 33 | 34 | def test(f_chord_test, f_style_test, f_event_test, keys_test, model, device, attrs, 35 | batch_size): 36 | model.eval() 37 | losses, losses_d, losses_kl, accs = 0., 0., 0., 0. 38 | preds, org_styles, chords, interpolation_for_back = {}, {}, {}, {} 39 | 40 | interpolate_num = 11 41 | vec_to_interpolate = to_float_tensor(np.array([(i - 5) * 0.1 for i in range(interpolate_num)]), device=device) 42 | mul_tgt = [-0.5, 0.5] 43 | vec_to_interpolate_mul = to_float_tensor(np.array([[i, j, k] for i in mul_tgt for j in mul_tgt for k in mul_tgt]), 44 | device=device) 45 | 46 | with torch.no_grad(): 47 | for key in keys_test: 48 | chord_tensor = to_float_tensor(f_chord_test[key], device=device).repeat(batch_size, 1, 1) 49 | event_tensor = to_tensor(f_event_test[key], device=device).repeat(batch_size, 1, 1) 50 | style_tensor = to_float_tensor([f_style_test[attr + '/' + key] for attr in attrs], 51 | device=device).reshape(1, len(attrs)).repeat(batch_size, 1) 52 | 53 | style_tensor[:interpolate_num, 0] += vec_to_interpolate 54 | style_tensor[1 * interpolate_num: 2 * interpolate_num, 1] += vec_to_interpolate 55 | style_tensor[2 * interpolate_num: 3 * interpolate_num, 2] += vec_to_interpolate 56 | style_tensor[3 * interpolate_num: 3 * interpolate_num + len(mul_tgt) ** 3] += vec_to_interpolate_mul 57 | 58 | loss, pred, lv, acc, distribution = model(event_tensor, chord_tensor, style_tensor) 59 | 60 | losses += loss 61 | accs += acc 62 | 63 | original_style_value_0 = float(np.array(f_style_test[attrs[0] + '/' + key])) 64 | original_style_value_1 = float(np.array(f_style_test[attrs[1] + '/' + key])) 65 | original_style_value_2 = float(np.array(f_style_test[attrs[2] + '/' + key])) 66 | names = ['{}_{}_{:.5}'.format(key, attrs[0], original_style_value_0 + (i - 5) * 0.1) for i in 67 | range(interpolate_num)] + \ 68 | ['{}_{}_{:.5}'.format(key, attrs[1], original_style_value_1 + (i - 5) * 0.1) for i in 69 | range(interpolate_num)] + \ 70 | ['{}_{}_{:.5}'.format(key, attrs[2], original_style_value_2 + (i - 5) * 0.1) for i in 71 | range(interpolate_num)] + \ 72 | ['{}_demo_{}_{:.5}_{}_{:.5}_{}_{:.5}'.format( 73 | key, 74 | attrs[0], original_style_value_0 + i, 75 | attrs[1], original_style_value_1 + j, 76 | attrs[2], original_style_value_2 + k) 77 | for i in mul_tgt for j in mul_tgt for k in mul_tgt] 78 | for i, name in enumerate(names): 79 | preds[name] = pred[i] 80 | 81 | return losses.item() / batch_size, accs / batch_size, preds 82 | 83 | 84 | def evaluation(model, cfg, device): 85 | f_chord_test = h5py.File(cfg['data']['chord_f_valid'], 'r') 86 | f_event_test = h5py.File(cfg['data']['event_valid'], 'r') 87 | f_style_test = h5py.File(cfg['data']['attr_valid'], 'r') 88 | with open(cfg['data']['keys_valid'], 'rb') as f: 89 | keys_test = pickle.load(f) 90 | 91 | return test(f_chord_test, f_style_test, f_event_test, keys_test, model, device, 92 | cfg['attr'], cfg['batch_size']) 93 | 94 | 95 | def run(args): 96 | output_dir = os.path.join(output_path_base, args.model_name) 97 | latest_model_text_file = os.path.join(output_dir, 'latest_model.txt') 98 | sample_dir = os.path.join(output_dir, 'samples') 99 | if not os.path.exists(sample_dir): 100 | os.mkdir(sample_dir) 101 | 102 | if not os.path.exists(output_dir) and not os.path.exists(latest_model_text_file): 103 | raise 
IOError("Model file not found.") 104 | 105 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 106 | print(torch.cuda.current_device()) 107 | 108 | if args.model_path: 109 | latest_model_path = args.model_path 110 | else: 111 | with open(latest_model_text_file, 'r') as f: 112 | latest_model_path = f.read() 113 | checkpoint = torch.load(latest_model_path) 114 | latest_model_name = latest_model_path.split('/')[-1] 115 | cfg = generate_cfg_fader(None, args, output_dir, checkpoint) 116 | 117 | print(cfg) 118 | 119 | model_d = Classifier(input_dim=cfg['z_dims'], 120 | num_layers=cfg['model']['discriminator']['num_layers'], 121 | n_attr=len(cfg['attr']), 122 | activation=cfg['activation_d'], 123 | n_classes=8, 124 | device=device) 125 | 126 | model = FaderVAE(vocab_size=cfg['vocab_size'], 127 | hidden_dims=cfg['model']['encoder']['hidden_size'], 128 | z_dims=cfg['z_dims'], 129 | n_step=cfg['bars_per_data'] * cfg['steps_per_quarter'] * constants.QUARTERS_PER_BAR, 130 | device=device, 131 | n_attr=cfg['n_attr']) 132 | 133 | model.load_state_dict(checkpoint['model']) 134 | model_d.load_state_dict(checkpoint['model_d']) 135 | model.to(device) 136 | model_d.to(device) 137 | print(model) 138 | 139 | losses, accs, preds = evaluation(model, cfg, device) 140 | val_result_path = os.path.join(output_dir, 'eval_result.txt') 141 | with open(val_result_path, 'a') as f: 142 | f.write("model: {}, val_loss: {}, acc: {}, ".format( 143 | latest_model_name, losses, accs)) 144 | 145 | RPSC = RollAugMonoSingleSequenceConverter(steps_per_quarter=cfg['steps_per_quarter'], 146 | quarters_per_bar=constants.QUARTERS_PER_BAR, 147 | chords_per_bar=cfg['chords_per_bar'], 148 | bars_per_data=cfg['bars_per_data']) 149 | 150 | sample_save_path = os.path.join(sample_dir, latest_model_name + '_original') 151 | if not os.path.exists(sample_save_path): 152 | os.mkdir(sample_save_path) 153 | 154 | seconds_per_step = steps_per_quarter_to_seconds_per_step(cfg['steps_per_quarter'], 60) 155 | 156 | f_chord_test = h5py.File(cfg['data']['chord_valid'], 'r') 157 | f_event_test = h5py.File(cfg['data']['event_valid'], 'r') 158 | with open(cfg['data']['keys_valid'], 'rb') as f: 159 | keys_test = pickle.load(f) 160 | 161 | chord_acc, chord_style_acc = 0., 0. 
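# Render each interpolated prediction to MIDI in the same way as in the
# evaluation script above: quantize the reconstructed NoteSequence, infer its
# chords, and write the .mid file together with a .txt chord annotation.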
162 | for key, event in preds.items(): 163 | # Normalized 164 | event = event.reshape(cfg['chords_per_data'], -1) 165 | ns = RPSC.to_note_sequence_from_events(event, seconds_per_step) # Normalized 166 | tempo = Tempo(time=0., qpm=60) 167 | ns.tempos.append(tempo) 168 | ns.instrument_infos = { 169 | InstrumentInfo('piano', 0), 170 | } 171 | time_signature = TimeSignature(time=0, numerator=4, denominator=4) 172 | ns.time_signatures.append(time_signature) 173 | 174 | quantized_ns = quantize_note_sequence(ns, steps_per_quarter=cfg['steps_per_quarter']) 175 | try: 176 | ns_with_chord = infer_chords_for_sequence(quantized_ns, chords_per_bar=cfg['chords_per_bar']) 177 | except Exception as e: 178 | print(e) 179 | continue 180 | pm = note_sequence_to_pretty_midi(ns_with_chord) 181 | key_string = '_'.join(key.split('/')) 182 | output_path = os.path.join(sample_save_path, key_string + '.mid') 183 | 184 | print(output_path) 185 | pm.write(output_path) 186 | 187 | # Normalized key comparison 188 | chord_list = [ta_chord.text for ta_chord in ns_with_chord.text_annotations] 189 | chord_txt = ",".join(chord_list) 190 | output_chord_path = os.path.join(sample_save_path, key_string + '.txt') 191 | with open(output_chord_path, 'w') as f: 192 | f.write(chord_txt) 193 | 194 | with open(val_result_path, 'a') as f: 195 | f.write("chord_acc: {}\n".format(chord_acc / len(keys_test))) 196 | 197 | original_path = os.path.join(sample_dir, 'original') 198 | if not os.path.exists(original_path): 199 | os.mkdir(original_path) 200 | 201 | for key in keys_test: 202 | # Unnormalized 203 | ns = RPSC.to_note_sequence_from_events(np.array(f_event_test[key]), seconds_per_step) 204 | pm = note_sequence_to_pretty_midi(ns) 205 | key_string = '_'.join(key.split('/')) 206 | output_path = os.path.join(original_path, key_string + '.mid') 207 | pm.write(output_path) 208 | chord = list(f_chord_test[key]) 209 | chord_list = [RPSC.chord_from_index(c) for c in chord] 210 | chord_list = ",".join(chord_list) 211 | output_chord_path = os.path.join(original_path, key_string + '.txt') 212 | with open(output_chord_path, 'w') as f: 213 | f.write(chord_list) 214 | 215 | 216 | if __name__ == "__main__": 217 | parser = ArgumentParser() 218 | parser.add_argument('--gpu', type=int, default=[0, 1], nargs='+', help='used gpu') 219 | parser.add_argument('--model_name', type=str, default="tmp", help='model name') 220 | parser.add_argument('--model_path', type=str, help='to use a specific model, not latest') 221 | 222 | args = parser.parse_args() 223 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gpu) 224 | 225 | run(args) 226 | -------------------------------------------------------------------------------- /midi_io.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | 15 | """MIDI ops. 16 | Input and output wrappers for converting between MIDI and other formats. 
17 | ### THIS WORKS ONLY FOR 4/4 #### 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import collections 25 | import sys 26 | 27 | import pretty_midi 28 | import six 29 | 30 | from note_sequence import NoteSequence, TimeSignature, KeySignature, Tempo, InstrumentInfo, Note, PitchBend, \ 31 | ControlChange 32 | import constants 33 | 34 | 35 | # Allow pretty_midi to read MIDI files with absurdly high tick rates. 36 | # Useful for reading the MAPS dataset. 37 | # https://github.com/craffel/pretty-midi/issues/112 38 | pretty_midi.pretty_midi.MAX_TICK = 1e10 39 | 40 | # The offset used to change the mode of a key from major to minor when 41 | # generating a PrettyMIDI KeySignature. 42 | _PRETTY_MIDI_MAJOR_TO_MINOR_OFFSET = 12 43 | 44 | 45 | class MIDIConversionError(Exception): 46 | pass 47 | 48 | 49 | def midi_to_note_sequence(midi_data, fixed_instrument_infos=None): 50 | """Convert MIDI file contents to a NoteSequence. 51 | Converts a MIDI file encoded as a string into a NoteSequence. Decoding errors 52 | are very common when working with large sets of MIDI files, so be sure to 53 | handle MIDIConversionError exceptions. 54 | Args: 55 | midi_data: A string containing the contents of a MIDI file or populated 56 | pretty_midi.PrettyMIDI object. 57 | Returns: 58 | A NoteSequence. 59 | Raises: 60 | MIDIConversionError: An improper MIDI mode was supplied. 61 | """ 62 | # In practice many MIDI files cannot be decoded with pretty_midi. Catch all 63 | # errors here and try to log a meaningful message. So many different 64 | # exceptions are raised in pretty_midi.PrettyMidi that it is cumbersome to 65 | # catch them all only for the purpose of error logging. 66 | 67 | if isinstance(midi_data, pretty_midi.PrettyMIDI): 68 | midi = midi_data 69 | else: 70 | try: 71 | midi = pretty_midi.PrettyMIDI(six.BytesIO(midi_data)) 72 | except BaseException: 73 | raise MIDIConversionError('Midi decoding error %s: %s' % 74 | (sys.exc_info()[0], sys.exc_info()[1])) 75 | 76 | sequence = NoteSequence(ticks_per_quarter=midi.resolution) 77 | time_signature = TimeSignature(time=0, numerator=4, denominator=4) 78 | sequence.time_signatures.append(time_signature) 79 | 80 | # Populate key signatures. 81 | for midi_key in midi.key_signature_changes: 82 | midi_mode = midi_key.key_number // 12 # MAJOR if 0 else MINOR 83 | key_signature = KeySignature( 84 | time=midi_key.time, key=midi_key.key_number % 85 | 12, mode=midi_mode) 86 | sequence.key_signatures.append(key_signature) 87 | 88 | # Populate tempo changes. 89 | tempo_times, tempo_qpms = midi.get_tempo_changes() 90 | for time_in_seconds, tempo_in_qpm in zip(tempo_times, tempo_qpms): 91 | tempo = Tempo(time=time_in_seconds, qpm=tempo_in_qpm) 92 | sequence.tempos.append(tempo) 93 | 94 | # Populate notes by gathering them all from the midi's instruments. 95 | # Also set the sequence.total_time as the max end time in the notes. 
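# Notes, pitch bends, and control changes are first collected per instrument
# as (program, instrument, is_drum, event) tuples and only then appended to
# the NoteSequence.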
96 | midi_notes = [] 97 | midi_pitch_bends = [] 98 | midi_control_changes = [] 99 | for num_instrument, midi_instrument in enumerate(midi.instruments): 100 | # Populate instrument name from the midi's instruments 101 | if fixed_instrument_infos is not None: 102 | num_instrument = fixed_instrument_infos.index(midi_instrument.name) 103 | instrument_info = InstrumentInfo( 104 | midi_instrument.name, num_instrument) 105 | sequence.instrument_infos.append(instrument_info) 106 | 107 | for midi_note in midi_instrument.notes: 108 | if not sequence.total_time or midi_note.end > sequence.total_time: 109 | sequence.total_time = midi_note.end 110 | midi_notes.append((midi_instrument.program, num_instrument, 111 | midi_instrument.is_drum, midi_note)) 112 | for midi_pitch_bend in midi_instrument.pitch_bends: 113 | midi_pitch_bends.append( 114 | (midi_instrument.program, num_instrument, 115 | midi_instrument.is_drum, midi_pitch_bend)) 116 | for midi_control_change in midi_instrument.control_changes: 117 | midi_control_changes.append( 118 | (midi_instrument.program, num_instrument, 119 | midi_instrument.is_drum, midi_control_change)) 120 | 121 | for program, instrument, is_drum, midi_note in midi_notes: 122 | note = Note( 123 | instrument, 124 | program, 125 | midi_note.start, 126 | midi_note.end, 127 | midi_note.pitch, 128 | midi_note.velocity, 129 | is_drum) 130 | sequence.notes.append(note) 131 | 132 | for program, instrument, is_drum, midi_pitch_bend in midi_pitch_bends: 133 | pitch_bend = PitchBend( 134 | instrument, 135 | program, 136 | midi_pitch_bend.time, 137 | midi_pitch_bend.pitch, 138 | is_drum) 139 | sequence.pitch_bends.append(pitch_bend) 140 | 141 | for program, instrument, is_drum, midi_control_change in midi_control_changes: 142 | control_change = ControlChange( 143 | instrument, 144 | program, 145 | midi_control_change.time, 146 | midi_control_change.number, 147 | midi_control_change.value, 148 | is_drum) 149 | sequence.control_changes.append(control_change) 150 | 151 | return sequence 152 | 153 | 154 | def note_sequence_to_pretty_midi( 155 | sequence, drop_events_n_seconds_after_last_note=None): 156 | """Convert NoteSequence to a PrettyMIDI. 157 | Time is stored in the NoteSequence in absolute values (seconds) as opposed to 158 | relative values (MIDI ticks). When the NoteSequence is translated back to 159 | PrettyMIDI the absolute time is retained. The tempo map is also recreated. 160 | Args: 161 | sequence: A NoteSequence. 162 | drop_events_n_seconds_after_last_note: Events (e.g., time signature changes) 163 | that occur this many seconds after the last note will be dropped. If 164 | None, then no events will be dropped. 165 | Returns: 166 | A pretty_midi.PrettyMIDI object or None if sequence could not be decoded. 167 | """ 168 | ticks_per_quarter = sequence.ticks_per_quarter or constants.STANDARD_PPQ 169 | 170 | max_event_time = None 171 | if drop_events_n_seconds_after_last_note is not None: 172 | max_event_time = (max([n.end_time for n in sequence.notes] or [0]) + 173 | drop_events_n_seconds_after_last_note) 174 | 175 | # Try to find a tempo at time zero. The list is not guaranteed to be in 176 | # order. 
177 | initial_seq_tempo = None 178 | for seq_tempo in sequence.tempos: 179 | if seq_tempo.time == 0: 180 | initial_seq_tempo = seq_tempo 181 | break 182 | 183 | kwargs = {} 184 | if initial_seq_tempo: 185 | kwargs['initial_tempo'] = initial_seq_tempo.qpm 186 | else: 187 | kwargs['initial_tempo'] = constants.DEFAULT_QUARTERS_PER_MINUTE 188 | 189 | pm = pretty_midi.PrettyMIDI(resolution=ticks_per_quarter, **kwargs) 190 | 191 | # Create an empty instrument to contain time and key signatures. 192 | instrument = pretty_midi.Instrument(0) 193 | pm.instruments.append(instrument) 194 | 195 | # Populate time signatures. 196 | for seq_ts in sequence.time_signatures: 197 | if max_event_time and seq_ts.time > max_event_time: 198 | continue 199 | time_signature = pretty_midi.containers.TimeSignature( 200 | seq_ts.numerator, seq_ts.denominator, seq_ts.time) 201 | pm.time_signature_changes.append(time_signature) 202 | 203 | # Populate key signatures. 204 | for seq_key in sequence.key_signatures: 205 | if max_event_time and seq_key.time > max_event_time: 206 | continue 207 | key_number = seq_key.key 208 | if constants.SCALE_MODE[seq_key.mode] == "MINOR": 209 | key_number += _PRETTY_MIDI_MAJOR_TO_MINOR_OFFSET 210 | key_signature = pretty_midi.containers.KeySignature( 211 | key_number, seq_key.time) 212 | pm.key_signature_changes.append(key_signature) 213 | 214 | # Populate tempos. 215 | for seq_tempo in sequence.tempos: 216 | # Skip if this tempo was added in the PrettyMIDI constructor. 217 | if seq_tempo == initial_seq_tempo: 218 | continue 219 | if max_event_time and seq_tempo.time > max_event_time: 220 | continue 221 | tick_scale = 60.0 / (pm.resolution * seq_tempo.qpm) 222 | tick = pm.time_to_tick(seq_tempo.time) 223 | pm._tick_scales.append((tick, tick_scale)) 224 | pm._update_tick_to_time(0) 225 | 226 | # Populate instrument names by first creating an instrument map between 227 | # instrument index and name. 228 | # Then, going over this map in the instrument event for loop 229 | inst_infos = {} 230 | for inst_info in sequence.instrument_infos: 231 | inst_infos[inst_info.instrument] = inst_info.name 232 | 233 | # Populate instrument events by first gathering notes and other event types 234 | # in lists then write them sorted to the PrettyMidi object. 235 | instrument_events = collections.defaultdict( 236 | lambda: collections.defaultdict(list)) 237 | for seq_note in sequence.notes: 238 | instrument_events[(seq_note.instrument, seq_note.program, 239 | seq_note.is_drum)]['notes'].append( 240 | pretty_midi.Note( 241 | seq_note.velocity, seq_note.pitch, 242 | seq_note.start_time, seq_note.end_time)) 243 | for seq_bend in sequence.pitch_bends: 244 | if max_event_time and seq_bend.time > max_event_time: 245 | continue 246 | instrument_events[(seq_bend.instrument, seq_bend.program, 247 | seq_bend.is_drum)]['bends'].append( 248 | pretty_midi.PitchBend(seq_bend.bend, seq_bend.time)) 249 | for seq_cc in sequence.control_changes: 250 | if max_event_time and seq_cc.time > max_event_time: 251 | continue 252 | instrument_events[(seq_cc.instrument, seq_cc.program, 253 | seq_cc.is_drum)]['controls'].append( 254 | pretty_midi.ControlChange( 255 | seq_cc.control_number, 256 | seq_cc.control_value, seq_cc.time)) 257 | 258 | for (instr_id, prog_id, is_drum) in sorted(instrument_events.keys()): 259 | # For instr_id 0 append to the instrument created above. 
260 | if instr_id > 0: 261 | instrument = pretty_midi.Instrument(prog_id, is_drum) 262 | pm.instruments.append(instrument) 263 | else: 264 | instrument.is_drum = is_drum 265 | # propagate instrument name to the midi file 266 | instrument.program = prog_id 267 | if instr_id in inst_infos: 268 | instrument.name = inst_infos[instr_id] 269 | instrument.notes = instrument_events[ 270 | (instr_id, prog_id, is_drum)]['notes'] 271 | instrument.pitch_bends = instrument_events[ 272 | (instr_id, prog_id, is_drum)]['bends'] 273 | instrument.control_changes = instrument_events[ 274 | (instr_id, prog_id, is_drum)]['controls'] 275 | 276 | return pm 277 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Normal 5 | 6 | 7 | class FaderVAE(nn.Module): 8 | def __init__(self, 9 | vocab_size, 10 | hidden_dims, 11 | z_dims, 12 | n_step, 13 | device, 14 | n_attr, 15 | k=1000): 16 | super(FaderVAE, self).__init__() 17 | self.gru_0 = nn.GRU( 18 | vocab_size, 19 | hidden_dims, 20 | batch_first=True, 21 | bidirectional=True) 22 | self.linear_mu = nn.Linear(hidden_dims * 2, z_dims) 23 | self.linear_var = nn.Linear(hidden_dims * 2, z_dims) 24 | self.grucell_1 = nn.GRUCell( 25 | z_dims + vocab_size + n_attr, 26 | hidden_dims) 27 | self.grucell_2 = nn.GRUCell(hidden_dims, hidden_dims) 28 | self.linear_init_1 = nn.Linear(z_dims, hidden_dims) 29 | self.linear_out_1 = nn.Linear(hidden_dims, vocab_size) 30 | self.n_step = n_step 31 | self.vocab_size = vocab_size 32 | self.hidden_dims = hidden_dims 33 | self.eps = 1 34 | self.sample = None 35 | self.iteration = 0 36 | self.z_dims = z_dims 37 | self.k = torch.FloatTensor([k]) 38 | self.device = device 39 | 40 | def _sampling(self, x): 41 | idx = x.max(1)[1] 42 | x = torch.zeros_like(x) 43 | arange = torch.arange(x.size(0)).long() 44 | if torch.cuda.is_available(): 45 | arange = arange.cuda() 46 | x[arange, idx] = 1 47 | return x 48 | 49 | def encoder(self, x, condition): 50 | self.gru_0.flatten_parameters() 51 | x = self.gru_0(x) 52 | x = x[-1] 53 | x = x.transpose_(0, 1).contiguous() 54 | x = x.view(x.size(0), -1) 55 | mu = self.linear_mu(x) 56 | var = self.linear_var(x).exp_() 57 | distribution = Normal(mu, var) 58 | return distribution 59 | 60 | def encode(self, x, condition): 61 | b, c, s = x.size() 62 | x = x.reshape(b, -1) 63 | x = torch.eye(self.vocab_size)[x].to(self.device) 64 | dis = self.encoder(x, condition) 65 | z = dis.rsample() 66 | return z, None, None 67 | 68 | def decoder(self, z, condition, style): 69 | out = torch.zeros((z.size(0), self.vocab_size)) 70 | out[:, -1] = 1. 
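# Autoregressive decoding: at every step the previous output (one-hot) is
# concatenated with the latent z and the attribute vector and fed through two
# stacked GRU cells. During training, scheduled sampling mixes teacher forcing
# with the model's own argmax samples; the teacher-forcing probability eps
# decays following an inverse sigmoid, k / (k + exp(iteration / k)).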
71 | x, hx = [], [None, None] 72 | t = torch.tanh(self.linear_init_1(z)) 73 | hx[0] = t 74 | if torch.cuda.is_available(): 75 | out = out.cuda() 76 | for i in range(self.n_step): 77 | out = torch.cat([out.float(), z, style], 1) 78 | hx[0] = self.grucell_1(out, hx[0]) 79 | if i == 0: 80 | hx[1] = hx[0] 81 | hx[1] = self.grucell_2(hx[0], hx[1]) 82 | out = F.log_softmax(self.linear_out_1(hx[1]), 1) 83 | x.append(out) 84 | if self.training: 85 | p = torch.rand(1).item() 86 | if p < self.eps: 87 | out = self.sample[:, i, :] 88 | else: 89 | out = self._sampling(out) 90 | self.eps = self.k / \ 91 | (self.k + torch.exp(self.iteration / self.k)) 92 | else: 93 | out = self._sampling(out) 94 | return torch.stack(x, 1) 95 | 96 | def forward(self, x, condition, style): 97 | b, c, s = x.size() 98 | x = x.reshape(b, -1) 99 | x_indices = x 100 | x = torch.eye(self.vocab_size)[x].to(self.device) 101 | condition = condition.repeat(1, 1, 1, s) 102 | condition = condition.reshape(b, c * s, -1) 103 | if self.training: 104 | self.sample = x 105 | self.iteration += 1 106 | dis = self.encoder(x, condition) 107 | z = dis.rsample() 108 | recon = self.decoder(z, condition, style) 109 | preds = torch.argmax(recon, dim=-1) 110 | acc = torch.sum(torch.eq(preds, x_indices)).item() / (x_indices.size(0) * x_indices.size(1)) 111 | loss = F.nll_loss(recon.reshape(-1, recon.size(-1)), x_indices.reshape(-1)) 112 | return loss, preds, z, acc, dis 113 | 114 | 115 | class Classifier(nn.Module): 116 | def __init__(self, input_dim, num_layers, n_attr, n_classes, activation, device): 117 | super(Classifier, self).__init__() 118 | if activation == 'tanh': 119 | activation_f = nn.Tanh 120 | elif activation == 'relu': 121 | activation_f = nn.ReLU 122 | elif activation == 'leakyrelu': 123 | activation_f = nn.LeakyReLU 124 | 125 | assert num_layers >= 2 126 | layers = [] 127 | for _ in range(num_layers - 1): 128 | layers.append(nn.Linear(input_dim, input_dim // 2)) 129 | layers.append(activation_f()) 130 | input_dim = input_dim // 2 131 | layers.append(nn.Linear(input_dim, n_attr * n_classes)) 132 | layers.append(nn.Sigmoid()) 133 | self.layers = nn.ModuleList(layers) 134 | self.n_classes = n_classes 135 | self.n_attr = n_attr 136 | self.device = device 137 | self.ordinal_labels = torch.ones((8, 8)).triu() 138 | 139 | def calc_loss(self, pred, class_label, is_discriminator=True, is_ordinal=False): 140 | """ 141 | 142 | :param pred: [b, num_attr] 143 | :param class_label: [b, num_attr] 144 | :param is_discriminator: 145 | :param is_ordinal: 146 | :return: 147 | """ 148 | criterion = nn.BCELoss() 149 | if not is_ordinal: 150 | label = torch.eye(self.n_classes)[class_label] 151 | else: 152 | b, num_attr = class_label.size() 153 | class_label = class_label.reshape(-1) 154 | label = self.ordinal_labels[class_label] 155 | label = label.reshape(b, num_attr, self.n_classes) 156 | label = label.to(self.device) 157 | if not is_discriminator: 158 | label = - label + 1 159 | return criterion(pred, label) 160 | 161 | def forward(self, lv): 162 | for i in range(len(self.layers)): 163 | lv = self.layers[i](lv) 164 | lv = lv.reshape(lv.size(0), self.n_attr, self.n_classes) 165 | return lv 166 | -------------------------------------------------------------------------------- /note_sequence.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classes of note sequence. 
3 | """ 4 | import constants 5 | 6 | 7 | class NoteSequence(object): 8 | def __init__(self, ticks_per_quarter): 9 | self.ticks_per_quarter = ticks_per_quarter 10 | self.time_signatures = [] 11 | self.key_signatures = [] 12 | self.tempos = [] 13 | self.instrument_infos = [] 14 | self.total_time = None 15 | self.notes = [] 16 | self.pitch_bends = [] 17 | self.control_changes = [] 18 | self.text_annotations = [] 19 | self.quantization_info = QuantizationInfo() 20 | self.total_quantized_steps = None 21 | 22 | 23 | class TimeSignature(object): 24 | def __init__(self, time, numerator, denominator): 25 | """ 26 | Ex. 6/8, numerator = 6, denominator = 8. 27 | :param time: Starting time of this object. 28 | :param numerator: Beats per measure. 29 | :param denominator: The type of beat. 30 | """ 31 | self.time = time 32 | self.numerator = numerator 33 | self.denominator = denominator 34 | 35 | 36 | class KeySignature(object): 37 | def __init__(self, time, key, mode): 38 | self.time = time 39 | self.key = key 40 | self.mode = mode 41 | 42 | def __repr__(self): 43 | return "time: {}, key: {}, mode: {}".format(self.time, self.key, self.mode) 44 | 45 | 46 | class Tempo(object): 47 | def __init__(self, time, qpm): 48 | self.time = time 49 | self.qpm = qpm 50 | 51 | 52 | class InstrumentInfo(object): 53 | def __init__(self, name, instrument): 54 | self.name = name 55 | self.instrument = instrument 56 | 57 | def __repr__(self): 58 | return "name: {}, instrument: {}".format(self.name, self.instrument) 59 | 60 | 61 | class TextAnnotation(object): 62 | def __init__(self, time, quantized_step, text, annotation_type, root, kind, pitch_vector): 63 | self.time = time 64 | self.quantized_step = quantized_step 65 | self.text = text 66 | self.annotation_type = annotation_type 67 | self.root = root 68 | self.kind = kind 69 | self.pitch_vector = pitch_vector 70 | 71 | def __repr__(self): 72 | return "time: {}, step: {}, text: {}, type: {}, {}, {}, pitch: ".format( 73 | self.time, self.quantized_step, self.text, self.annotation_type, self.root, self.kind, self.pitch_vector) 74 | 75 | 76 | class AnnotationType(object): 77 | CHORD_SYMBOL = 1 78 | 79 | 80 | class Note(object): 81 | def __init__(self, instrument, program, start_time, end_time, pitch, velocity, is_drum): 82 | self.instrument = instrument 83 | self.program = program 84 | self.start_time = start_time 85 | self.end_time = end_time 86 | self.pitch = pitch 87 | self.velocity = velocity 88 | self.is_drum = is_drum 89 | self.quantized_start_step = None 90 | self.quantized_end_step = None 91 | 92 | def __repr__(self): 93 | return "instrument: {}, start: {}, end: {}, pitch: {}, qstart: {}, qend: {}\n".format( 94 | self.instrument, self.start_time, self.end_time, self.pitch, self.quantized_start_step, 95 | self.quantized_end_step) 96 | 97 | 98 | class PitchBend(object): 99 | def __init__(self, instrument, program, time, pitch, is_drum): 100 | self.instrument = instrument 101 | self.program = program 102 | self.time = time 103 | self.pitch = pitch 104 | self.is_drum = is_drum 105 | 106 | 107 | class ControlChange(object): 108 | def __init__(self, instrument, program, time, control_number, control_value, is_drum): 109 | self.instrument = instrument 110 | self.program = program 111 | self.time = time 112 | self.control_number = control_number 113 | self.control_value = control_value 114 | self.is_drum = is_drum 115 | 116 | 117 | class QuantizationInfo(object): 118 | def __init__(self): 119 | self.steps_per_quarter = None 120 | 121 | 122 | class PerformanceEvent(object): 
123 | # Start of a new note. 124 | NOTE_ON = 1 125 | # End of a note. 126 | NOTE_OFF = 2 127 | # Shift time forward. 128 | TIME_SHIFT = 3 129 | 130 | def __init__(self, event_type, event_value): 131 | if event_type in (PerformanceEvent.NOTE_ON, PerformanceEvent.NOTE_OFF): 132 | if not constants.MIN_MIDI_PITCH <= event_value <= constants.MAX_MIDI_PITCH: 133 | raise ValueError('Invalid pitch value: %s' % event_value) 134 | elif event_type == PerformanceEvent.TIME_SHIFT: 135 | if not 0 <= event_value: 136 | raise ValueError('Invalid time shift value: %s' % event_value) 137 | else: 138 | raise ValueError('Invalid event type: %s' % event_type) 139 | 140 | self.event_type = event_type 141 | self.event_value = event_value 142 | 143 | def __repr__(self): 144 | return 'PerformanceEvent(%r, %r)\n' % (self.event_type, self.event_value) 145 | 146 | def __eq__(self, other): 147 | if not isinstance(other, PerformanceEvent): 148 | return False 149 | return self.event_type == other.event_type and self.event_value == other.event_value 150 | 151 | def __hash__(self): 152 | return int(self.event_type * 300 + self.event_value) 153 | 154 | 155 | class Instrument(object): 156 | Piano = 0 157 | Guitar = 1 158 | Bass = 2 159 | Strings = 3 160 | Drums = 4 161 | 162 | 163 | class PerformanceEventWithInstrument(object): 164 | # Start of a new note. 165 | NOTE_ON = 1 166 | # End of a note. 167 | NOTE_OFF = 2 168 | # Shift time forward. 169 | TIME_SHIFT = 3 170 | # Both ends of sequence. 171 | ENDS_OF_SEQ = 4 172 | # Padding. 173 | PADDING = 5 174 | 175 | def __init__(self, event_type, event_value, instrument=None): 176 | if event_type in (PerformanceEventWithInstrument.NOTE_ON, PerformanceEventWithInstrument.NOTE_OFF): 177 | if not constants.MIN_MIDI_PITCH <= event_value <= constants.MAX_MIDI_PITCH: 178 | raise ValueError('Invalid pitch value: %s' % event_value) 179 | if instrument is None: 180 | raise ValueError('Instrument should not be empty.') 181 | elif event_type == PerformanceEventWithInstrument.TIME_SHIFT: 182 | if not 0 <= event_value: 183 | raise ValueError('Invalid time shift value: %s' % event_value) 184 | elif event_type in [PerformanceEventWithInstrument.ENDS_OF_SEQ, PerformanceEventWithInstrument.PADDING]: 185 | pass 186 | else: 187 | raise ValueError('Invalid event type: %s' % event_type) 188 | 189 | self.event_type = event_type 190 | self.event_value = event_value 191 | self.instrument = instrument 192 | 193 | def __repr__(self): 194 | return 'PerformanceEvent(%r, %r)\n' % (self.event_type, self.event_value) 195 | 196 | def __eq__(self, other): 197 | if not isinstance(other, PerformanceEventWithInstrument): 198 | return False 199 | return self.event_type == other.event_type and self.event_value == other.event_value and \ 200 | self.instrument == self.instrument 201 | 202 | def __hash__(self): 203 | if self.instrument: 204 | return int(self.event_type * 1000 + self.event_value + (self.instrument + 1) * 10000) 205 | else: 206 | return int(self.event_type * 1000 + self.event_value) 207 | 208 | -------------------------------------------------------------------------------- /note_sequence_ops.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | from note_sequence import Instrument 4 | 5 | QUANTIZE_CUTOFF = 0.5 6 | PITCH_CUTOFF = 20 7 | STEP_CUTOFF = 16 8 | 9 | 10 | class NegativeTimeError(Exception): 11 | pass 12 | 13 | 14 | # Public ops 15 | def quantize_note_sequence(note_sequence, steps_per_quarter): 16 | 
note_sequence.quantization_info.steps_per_quarter = steps_per_quarter 17 | steps_per_second = steps_per_quarter_to_steps_per_second( 18 | steps_per_quarter, note_sequence.tempos[0].qpm) 19 | note_sequence.total_quantized_steps = quantize_to_step(note_sequence.total_time, steps_per_second) 20 | return _quantize_notes(note_sequence, steps_per_second) 21 | 22 | 23 | def steps_per_bar_in_quantized_sequence(note_sequence): 24 | """Calculates steps per bar in a NoteSequence that has been quantized. 25 | Args: 26 | note_sequence: The NoteSequence to examine. 27 | Returns: 28 | Steps per bar as a floating point number. 29 | """ 30 | assert note_sequence.quantization_info.steps_per_quarter > 0 31 | 32 | quarters_per_beat = 4.0 / note_sequence.time_signatures[0].denominator 33 | quarters_per_bar = ( 34 | quarters_per_beat * note_sequence.time_signatures[0].numerator) 35 | steps_per_bar_float = ( 36 | note_sequence.quantization_info.steps_per_quarter * quarters_per_bar) 37 | return steps_per_bar_float 38 | 39 | 40 | # Private ops 41 | 42 | def steps_per_quarter_to_steps_per_second(steps_per_quarter, qpm): 43 | """Calculates steps per second given steps_per_quarter and a qpm.""" 44 | return steps_per_quarter * qpm / 60.0 45 | 46 | 47 | def steps_per_quarter_to_seconds_per_step(steps_per_quarter, qpm): 48 | return 60.0 / steps_per_quarter / qpm 49 | 50 | 51 | def quantize_to_step(unquantized_seconds, 52 | steps_per_second, 53 | quantize_cutoff=QUANTIZE_CUTOFF): 54 | """Quantizes seconds to the nearest step, given steps_per_second. 55 | See the comments above `QUANTIZE_CUTOFF` for details on how the quantizing 56 | algorithm works. 57 | Args: 58 | unquantized_seconds: Seconds to quantize. 59 | steps_per_second: Quantizing resolution. 60 | quantize_cutoff: Value to use for quantizing cutoff. 61 | Returns: 62 | The input value quantized to the nearest step. 63 | """ 64 | unquantized_steps = unquantized_seconds * steps_per_second 65 | return int(unquantized_steps + (1 - quantize_cutoff)) 66 | 67 | 68 | def _quantize_to_step(unquantized_seconds, 69 | steps_per_second, 70 | quantize_cutoff=QUANTIZE_CUTOFF): 71 | """Quantizes seconds to the nearest step, given steps_per_second. 72 | See the comments above `QUANTIZE_CUTOFF` for details on how the quantizing 73 | algorithm works. 74 | Args: 75 | unquantized_seconds: Seconds to quantize. 76 | steps_per_second: Quantizing resolution. 77 | quantize_cutoff: Value to use for quantizing cutoff. 78 | Returns: 79 | The input value quantized to the nearest step. 80 | """ 81 | unquantized_steps = unquantized_seconds * steps_per_second 82 | return int(unquantized_steps + (1 - quantize_cutoff)) 83 | 84 | 85 | def _quantize_notes(note_sequence, steps_per_second): 86 | """Quantize the notes and chords of a NoteSequence proto in place. 87 | Note start and end times, and chord times are snapped to a nearby quantized 88 | step, and the resulting times are stored in a separate field (e.g., 89 | quantized_start_step). See the comments above `QUANTIZE_CUTOFF` for details on 90 | how the quantizing algorithm works. 91 | Args: 92 | note_sequence: A music_pb2.NoteSequence protocol buffer. Will be modified in 93 | place. 94 | steps_per_second: Each second will be divided into this many quantized time 95 | steps. 96 | Raises: 97 | NegativeTimeError: If a note or chord occurs at a negative time. 98 | """ 99 | for note in note_sequence.notes: 100 | # Quantize the start and end times of the note. 
101 | note.quantized_start_step = _quantize_to_step(note.start_time, 102 | steps_per_second) 103 | note.quantized_end_step = _quantize_to_step( 104 | note.end_time, steps_per_second) 105 | if note.quantized_end_step == note.quantized_start_step: 106 | note.quantized_end_step += 1 107 | 108 | # Do not allow notes to start or end in negative time. 109 | if note.quantized_start_step < 0 or note.quantized_end_step < 0: 110 | raise NegativeTimeError( 111 | 'Got negative note time: start_step = %s, end_step = %s' % 112 | (note.quantized_start_step, note.quantized_end_step)) 113 | 114 | # Extend quantized sequence if necessary. 115 | if note.quantized_end_step > note_sequence.total_quantized_steps: 116 | note_sequence.total_quantized_steps = note.quantized_end_step 117 | 118 | # Also quantize control changes and text annotations. 119 | for event in note_sequence.control_changes: 120 | # Quantize the event time, disallowing negative time. 121 | event.quantized_step = _quantize_to_step(event.time, steps_per_second) 122 | if event.quantized_step < 0: 123 | raise NegativeTimeError( 124 | 'Got negative event time: step = %s' % event.quantized_step) 125 | 126 | return note_sequence 127 | 128 | 129 | def calculate_style_feature(note_sequence, num_inst=1, mono=False): 130 | note_sequence.notes.sort(key=lambda note: note.quantized_start_step) 131 | style_feature = np.zeros((num_inst, PITCH_CUTOFF * 2 + 1, STEP_CUTOFF)) 132 | for i, note in enumerate(note_sequence.notes): 133 | for j in range(i + 1, len(note_sequence.notes)): 134 | if note.instrument != note_sequence.notes[j].instrument: 135 | continue 136 | time_gap = note_sequence.notes[j].quantized_start_step - note.quantized_start_step 137 | if time_gap >= STEP_CUTOFF: 138 | break 139 | pitch_gap = note_sequence.notes[j].pitch - note.pitch 140 | if abs(pitch_gap) < PITCH_CUTOFF: 141 | if mono: 142 | instrument_index = note.instrument 143 | else: 144 | instrument_index = instrument_index_from_infos(note.instrument, note_sequence.instrument_infos) 145 | style_feature[instrument_index, pitch_gap + PITCH_CUTOFF, time_gap] += 1 146 | style_feature = style_feature.reshape(num_inst, -1) 147 | dim_feature = style_feature.shape[1] 148 | inst_norm_each = np.linalg.norm(style_feature, axis=1).repeat(dim_feature).reshape(num_inst, dim_feature) 149 | style_feature = 1. * style_feature / inst_norm_each 150 | return style_feature 151 | 152 | 153 | def instrument_index_from_infos(index, instrument_infos): 154 | instrument_name = instrument_infos[index].name 155 | if instrument_name == 'Piano': 156 | return Instrument.Piano 157 | if instrument_name == 'Guitar': 158 | return Instrument.Guitar 159 | if instrument_name == 'Bass': 160 | return Instrument.Bass 161 | if instrument_name == 'Strings': 162 | return Instrument.Strings 163 | if instrument_name == 'Drums': 164 | return Instrument.Drums 165 | raise 'Unknown instrument %s' % instrument_name 166 | 167 | 168 | def find_start_time(ns): 169 | ns.notes = sorted(ns.notes, key=lambda note: note.start_time) 170 | for note in ns.notes: 171 | if note.instrument != 0: 172 | return note.start_time 173 | print('cant find') 174 | return 0 175 | 176 | 177 | def delete_auftakt(ns): 178 | start_time = find_start_time(ns) 179 | filtered_notes = [] 180 | for note in ns.notes: 181 | note.start_time = max(note.start_time - start_time, 0.) 182 | note.end_time = max(note.end_time - start_time, 0.) 
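# Keep only notes that still have a positive duration after the pickup
# (auftakt) has been shifted back to time zero.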
183 | if note.end_time - note.start_time > 0.00001: 184 | filtered_notes.append(note) 185 | ns.notes = filtered_notes 186 | return ns 187 | 188 | 189 | def filter_note(ns, ins_index): 190 | filtered_notes = [] 191 | for note in ns.notes: 192 | if note.instrument == ins_index: 193 | filtered_notes.append(note) 194 | ns.notes = filtered_notes 195 | return ns -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | import h5py 5 | import pickle 6 | from argparse import ArgumentParser 7 | import ruamel 8 | import ruamel.yaml 9 | 10 | import torch 11 | from torch import optim 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch.distributions import kl_divergence, Normal 14 | 15 | from utils import time_since, std_normal 16 | from model import Classifier, FaderVAE 17 | from arguments import generate_cfg_fader 18 | from constants import QUARTERS_PER_BAR 19 | 20 | # CHANGE PATH HERE 21 | output_path_base = 'output' 22 | gradient_clip = 1.0 23 | 24 | 25 | def to_tensor(data, device): 26 | if device == 'cpu': 27 | return torch.LongTensor(data, device=device) 28 | return torch.cuda.LongTensor(data, device=device) 29 | 30 | 31 | def to_float_tensor(data, device): 32 | if device == 'cpu': 33 | return torch.FloatTensor(data, device=device) 34 | return torch.cuda.FloatTensor(data, device=device) 35 | 36 | 37 | def train(chord_tensor, style_tensor, style_cls_tensor, event_tensor, model, model_d, optimizer, optimizer_d, 38 | lambda_d, lambda_kl, is_ordinal): 39 | """ 40 | 41 | :param chord_tensor: [num_chord] 42 | :param style_tensor: [style_dim] 43 | :param event_tensor: [num_chord, event_max_length] 44 | :param model:h 45 | :param optimizer: 46 | :return: loss of rhythm, acc of rhythm, loss of pitch, acc of rhythm 47 | """ 48 | model.train() 49 | model_d.train() 50 | 51 | # train discriminator 52 | optimizer_d.zero_grad() 53 | lv, _, _ = model.encode(event_tensor, chord_tensor) 54 | output = model_d(lv.detach()) 55 | loss_d = model_d.calc_loss(output, style_cls_tensor, is_discriminator=True, is_ordinal=is_ordinal) 56 | loss_d.backward() 57 | 58 | torch.nn.utils.clip_grad_norm(model_d.parameters(), gradient_clip) 59 | optimizer_d.step() 60 | 61 | # train auto-encoder 62 | optimizer.zero_grad() 63 | loss_recon, _, lv, acc, distribution = model(event_tensor, chord_tensor, style_tensor) 64 | dis_out = model_d(lv) 65 | loss_d_gen = model_d.calc_loss(dis_out, style_cls_tensor, is_discriminator=False, is_ordinal=is_ordinal) 66 | 67 | normal = std_normal(distribution.mean.size()) 68 | loss_kl = kl_divergence(distribution, normal).mean() 69 | loss = loss_recon + lambda_d * loss_d_gen + lambda_kl * loss_kl 70 | 71 | loss.backward() 72 | 73 | torch.nn.utils.clip_grad_norm(model.parameters(), gradient_clip) 74 | optimizer.step() 75 | 76 | return loss_recon.item(), loss_d.item(), loss_kl.item(), acc 77 | 78 | 79 | def test(f_chord_test, f_style_test, f_style_cls_test, f_event_test, keys_test, model, model_d, device, attrs, 80 | batch_size, use_chord_vector, is_ordinal): 81 | model.eval() 82 | model_d.eval() 83 | losses, loss_ds, losses_kl, accs = 0., 0., 0., 0. 
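# Validation runs over full batches only (a trailing partial batch is
# skipped); the accumulated losses and accuracy are averaged over the number
# of full batches at the end.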
84 | 85 | with torch.no_grad(): 86 | for i in range(0, len(keys_test), batch_size): 87 | key_indices = keys_test[i: i + batch_size] 88 | if len(key_indices) < batch_size: 89 | continue 90 | if use_chord_vector: 91 | chord_tensor = to_float_tensor([f_chord_test[key] for key in key_indices], device=device) 92 | else: 93 | chord_tensor = to_tensor([f_chord_test[key] for key in key_indices], device=device) 94 | style_tensor = to_float_tensor([[f_style_test[attr + '/' + key] for attr in attrs] 95 | for key in key_indices], device=device).reshape(batch_size, -1) 96 | style_cls_tensor = to_tensor([[f_style_cls_test[attr + '/' + key] for attr in attrs] 97 | for key in key_indices], device=device).reshape(batch_size, -1) 98 | event_tensor = to_tensor([f_event_test[key] for key in key_indices], device=device) 99 | 100 | loss, _, lv, acc, distribution = model(event_tensor, chord_tensor, style_tensor) 101 | dis_out = model_d(lv) 102 | loss_d = model_d.calc_loss(dis_out, style_cls_tensor, is_ordinal=is_ordinal) 103 | 104 | normal = std_normal(distribution.mean.size()) 105 | loss_kl = kl_divergence(distribution, normal).mean() 106 | 107 | losses += loss 108 | loss_ds += loss_d 109 | losses_kl += loss_kl 110 | accs += acc 111 | 112 | batch_num = len(keys_test) // batch_size 113 | return losses.item() / batch_num, loss_ds.item() / batch_num, losses_kl.item() / batch_num, accs / batch_num 114 | 115 | 116 | def train_iters(model, model_d, optimizer, optimizer_d, cfg, device, print_every=500, model_save_every=100, 117 | start_iter=0, use_chord_vector=False): 118 | start = time.time() 119 | 120 | log_dir = os.path.join(cfg['output_dir'], 'logs') 121 | writer = SummaryWriter(log_dir=log_dir) 122 | 123 | print_losses = [] 124 | val_losses = [] 125 | print_loss_total = 0 # Reset every print_every 126 | 127 | if use_chord_vector: 128 | f_chord = h5py.File(cfg['data']['chord_f'], 'r') 129 | f_chord_test = h5py.File(cfg['data']['chord_f_test'], 'r') 130 | else: 131 | f_chord = h5py.File(cfg['data']['chord'], 'r') 132 | f_chord_test = h5py.File(cfg['data']['chord_test'], 'r') 133 | 134 | f_style = h5py.File(cfg['data']['attr'], 'r') 135 | f_style_cls = h5py.File(cfg['data']['attr_cls'], 'r') 136 | f_event = h5py.File(cfg['data']['event'], 'r') 137 | with open(cfg['data']['keys'], 'rb') as f: 138 | keys = pickle.load(f) 139 | 140 | f_style_test = h5py.File(cfg['data']['attr_test'], 'r') 141 | f_style_cls_test = h5py.File(cfg['data']['attr_cls_test'], 'r') 142 | f_event_test = h5py.File(cfg['data']['event_test'], 'r') 143 | with open(cfg['data']['keys_test'], 'rb') as f: 144 | keys_test = pickle.load(f) 145 | 146 | for iter in range(start_iter + 1, cfg['n_iters'] + 1): 147 | key_indices = random.sample(keys, cfg['batch_size']) 148 | if use_chord_vector: 149 | chord_tensor = to_float_tensor([f_chord[key] for key in key_indices], device=device) 150 | else: 151 | chord_tensor = to_tensor([f_chord[key] for key in key_indices], device=device) 152 | style_tensor = to_float_tensor([[f_style[attr + '/' + key] for attr in cfg['attr']] 153 | for key in key_indices], device=device).reshape(cfg['batch_size'], -1) 154 | style_cls_tensor = to_tensor([[f_style_cls[attr + '/' + key] for attr in cfg['attr']] 155 | for key in key_indices], device=device).reshape(cfg['batch_size'], -1) 156 | event_tensor = to_tensor([f_event[key] for key in key_indices], device=device) 157 | 158 | loss, loss_d, loss_kl, acc = train(chord_tensor, style_tensor, style_cls_tensor, event_tensor, model, model_d, 159 | optimizer, optimizer_d, cfg['lambda_d'], 
cfg['lambda_kl'], 160 | is_ordinal=cfg['is_ordinal']) 161 | print("loss: {:.5}, loss_d: {:.5}, loss_kl {:.5}, acc: {:.5}".format(loss, loss_d, loss_kl, acc)) 162 | writer.add_scalar("train/loss", loss, iter) 163 | writer.add_scalar("train/loss_d", loss_d, iter) 164 | writer.add_scalar("train/loss_kl", loss_kl, iter) 165 | writer.add_scalar("train/acc", acc, iter) 166 | print_loss_total += loss 167 | 168 | if iter % print_every == 0: 169 | print_loss_avg = print_loss_total / print_every 170 | print('%s (%d %d%%) %.4f' % (time_since(start, iter / cfg['n_iters']), 171 | iter, iter / cfg['n_iters'] * 100, print_loss_avg)) 172 | 173 | print_losses.append(print_loss_avg) 174 | print_loss_total = 0 175 | 176 | # Test 177 | val_loss, val_loss_d, val_loss_kl, acc = test(f_chord_test, f_style_test, f_style_cls_test, f_event_test, 178 | keys_test, model, model_d, device, cfg['attr'], 179 | cfg['batch_size'], use_chord_vector, 180 | is_ordinal=cfg['is_ordinal']) 181 | 182 | writer.add_scalar("test/loss", val_loss, iter) 183 | writer.add_scalar("test/loss_d", val_loss_d, iter) 184 | writer.add_scalar("test/loss_kl", val_loss_kl, iter) 185 | writer.add_scalar("test/acc", acc, iter) 186 | val_losses.append(val_loss) 187 | print("val_loss: {:.5}, loss_d: {:.5}, loss_kl: {:.5}, acc: {:.5}".format( 188 | val_loss, val_loss_d, val_loss_kl, acc)) 189 | 190 | if iter % model_save_every == 0: 191 | save_path = os.path.join(cfg['output_dir'], 'models', 'checkpoint_{}'.format(iter)) 192 | torch.save({ 193 | 'model': model.state_dict(), 194 | 'model_d': model_d.state_dict(), 195 | 'optimizer': optimizer.state_dict(), 196 | 'optimizer_d': optimizer_d.state_dict(), 197 | 'cfg': cfg, 198 | }, save_path) 199 | latest_model_text_file = os.path.join(cfg['output_dir'], 'latest_model.txt') 200 | with open(latest_model_text_file, 'w') as f: 201 | f.write(save_path) 202 | 203 | f_chord.close() 204 | f_style.close() 205 | f_event.close() 206 | f_chord_test.close() 207 | f_style_test.close() 208 | f_event_test.close() 209 | writer.close() 210 | 211 | 212 | def run(args): 213 | yaml_path = 'env.yml' 214 | yaml = ruamel.yaml.YAML() 215 | with open(yaml_path) as stream: 216 | env_var = yaml.load(stream) 217 | 218 | output_dir = os.path.join(output_path_base, args.model_name) 219 | model_dir = os.path.join(output_dir, 'models') 220 | latest_model_text_file = os.path.join(output_dir, 'latest_model.txt') 221 | 222 | checkpoint = None 223 | start_iter = 0 224 | if not os.path.exists(output_dir): 225 | os.mkdir(output_dir) 226 | os.mkdir(model_dir) 227 | else: 228 | if os.path.exists(latest_model_text_file): 229 | with open(latest_model_text_file, 'r') as f: 230 | latest_model_path = f.read() 231 | checkpoint = torch.load(latest_model_path) 232 | start_iter = int(latest_model_path.split('_')[-1]) 233 | 234 | cfg = generate_cfg_fader(env_var, args, output_dir, checkpoint) 235 | print(cfg) 236 | 237 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 238 | 239 | model_d = Classifier(input_dim=cfg['z_dims'], 240 | num_layers=cfg['model']['discriminator']['num_layers'], 241 | n_attr=len(cfg['attr']), 242 | activation=cfg['activation_d'], 243 | n_classes=cfg['n_classes'], 244 | device=device) 245 | 246 | use_chord_vector = True 247 | model = FaderVAE(vocab_size=cfg['vocab_size'], 248 | hidden_dims=cfg['model']['encoder']['hidden_size'], 249 | z_dims=cfg['z_dims'], 250 | n_step=cfg['bars_per_data'] * cfg['steps_per_quarter'] * QUARTERS_PER_BAR, 251 | device=device, 252 | n_attr=cfg['n_attr']) 253 | 254 | optimizer = 
optim.Adam(model.parameters(), lr=cfg['learning_rate']) 255 | optimizer_d = optim.Adam(model_d.parameters(), lr=cfg['learning_rate_d']) 256 | 257 | if checkpoint: 258 | model.load_state_dict(checkpoint['model']) 259 | model_d.load_state_dict(checkpoint['model_d']) 260 | optimizer.load_state_dict(checkpoint['optimizer']) 261 | optimizer_d.load_state_dict(checkpoint['optimizer_d']) 262 | if device != 'cpu': 263 | for state in optimizer.state.values(): 264 | for k, v in state.items(): 265 | if torch.is_tensor(v): 266 | state[k] = v.cuda() 267 | for state in optimizer_d.state.values(): 268 | for k, v in state.items(): 269 | if torch.is_tensor(v): 270 | state[k] = v.cuda() 271 | model.to(device) 272 | model_d.to(device) 273 | 274 | print(model) 275 | train_iters(model, model_d, optimizer, optimizer_d, cfg, device, print_every=1000, model_save_every=1000, 276 | start_iter=start_iter, use_chord_vector=use_chord_vector) 277 | 278 | 279 | if __name__ == "__main__": 280 | parser = ArgumentParser() 281 | parser.add_argument('--gpu', type=int, default=[0, 1], nargs='+', help='used gpu') 282 | parser.add_argument('--model_name', type=str, default="tmp", help='model name') 283 | parser.add_argument('--n_iters', type=int, help='epoch number') 284 | parser.add_argument('--learning_rate', type=float, help='learning rate') 285 | parser.add_argument('--learning_rate_d', type=float, help='learning rate for discriminator') 286 | parser.add_argument('--model_structure', type=str, help='model structure') 287 | parser.add_argument('--env', type=str, default='default', help='env yaml file') 288 | parser.add_argument('--c_num_layers', type=int, help='layer number for chord encoder') 289 | parser.add_argument('--c_hidden_size', type=int, help='hidden size for chord encoder') 290 | parser.add_argument('--e_num_layers', type=int, help='number of layers for encoder') 291 | parser.add_argument('--e_hidden_size', type=int, help='hidden size for encoder') 292 | parser.add_argument('--d_num_layers', type=int, help='number of layers for decoder') 293 | parser.add_argument('--d_hidden_size', type=int, help='hidden size for decoder') 294 | parser.add_argument('--dis_num_layers', type=int, help='number of layers for discriminator') 295 | parser.add_argument('--lambda_d', type=float, help='Loss weight for discriminator') 296 | parser.add_argument('--lambda_kl', type=float, help='Loss weight for kl') 297 | parser.add_argument('--attribute', type=str, nargs='*', help='attributes for style') 298 | parser.add_argument('--z_dims', type=int, default=128, help='dimension of latent vector') 299 | parser.add_argument('--batch_size', type=int, help='batch size') 300 | parser.add_argument('--is_ordinal', type=int, help='to use ordinal classification') 301 | 302 | args = parser.parse_args() 303 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.gpu) 304 | 305 | run(args) 306 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import matplotlib.pyplot as plt 5 | plt.switch_backend('agg') 6 | import matplotlib.ticker as ticker 7 | 8 | import torch 9 | from torch.distributions import kl_divergence, Normal 10 | 11 | 12 | def as_minutes(s): 13 | m = math.floor(s / 60) 14 | s -= m * 60 15 | return '%dm %ds' % (m, s) 16 | 17 | 18 | def time_since(since, percent): 19 | now = time.time() 20 | s = now - since 21 | es = s / percent 22 | rs = es - s 23 | return 
'%s (- %s)' % (as_minutes(s), as_minutes(rs)) 24 | 25 | 26 | def show_plot(points, output_dir, mode): 27 | plt.figure() 28 | fig, ax = plt.subplots() 29 | # this locator puts ticks at regular intervals 30 | loc = ticker.MultipleLocator(base=0.2) 31 | ax.yaxis.set_major_locator(loc) 32 | plt.plot(points) 33 | plt_path = os.path.join(output_dir, '{}_loss.png'.format(mode)) 34 | plt.savefig(plt_path) 35 | 36 | 37 | def sequence_mask(length, max_length=None): 38 | """Equivalent of TensorFlow's sequence_mask.""" 39 | if max_length is None: 40 | max_length = length.max() 41 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 42 | return x.unsqueeze(0) < length.unsqueeze(1) 43 | 44 | 45 | def get_masked_with_pad_tensor(size, src, trg, pad_token): 46 | """ 47 | :param size: the size of target input 48 | :param src: source tensor 49 | :param trg: target tensor 50 | :param pad_token: pad token 51 | :return: 52 | """ 53 | src = src[:, None, None, :] 54 | trg = trg[:, None, None, :] 55 | src_pad_tensor = torch.ones_like(src).to(src.device.type) * pad_token 56 | src_mask = src == src_pad_tensor 57 | trg_mask = src == src_pad_tensor 58 | if trg is not None: 59 | trg_pad_tensor = torch.ones_like(trg).to(trg.device.type) * pad_token 60 | dec_trg_mask = trg == trg_pad_tensor 61 | # Invert the mask so that future positions are marked True (i.e., masked). 62 | seq_mask = ~sequence_mask(torch.arange(1, size+1).to(trg.device), size) 63 | # look_ahead_mask = torch.max(dec_trg_mask, seq_mask) 64 | look_ahead_mask = dec_trg_mask | seq_mask 65 | 66 | else: 67 | trg_mask = None 68 | look_ahead_mask = None 69 | 70 | return src_mask, trg_mask, look_ahead_mask 71 | 72 | 73 | def std_normal(shape): 74 | N = Normal(torch.zeros(shape), torch.ones(shape)) 75 | if torch.cuda.is_available(): 76 | N.loc = N.loc.cuda() 77 | N.scale = N.scale.cuda() 78 | return N 79 | --------------------------------------------------------------------------------
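A minimal illustration (not part of the repository) of the ordinal target encoding built by Classifier.calc_loss in model.py when is_ordinal is set: torch.ones((8, 8)).triu() is upper triangular, so row i is 1 from column i onwards, and a class index expands into a multi-hot vector that marks that class and every class above it.

import torch

# Same construction as Classifier.ordinal_labels with n_classes = 8.
ordinal_labels = torch.ones((8, 8)).triu()

# Class index 3 expands to a multi-hot ordinal target.
print(ordinal_labels[3])
# tensor([0., 0., 0., 1., 1., 1., 1., 1.])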