├── .gitignore
├── __init__.py
├── artifacts
│   ├── __init__.py
│   ├── cp_tokenizer.json
│   ├── phrase_lengths.json
│   ├── test_file_list.json
│   ├── tokenizer.json
│   ├── train_file_list.json
│   └── valid_file_list.json
├── baselines
│   ├── autoregressive_transformer
│   │   ├── data_loader.py
│   │   ├── train_autoregressive.py
│   │   └── train_pop_music.py
│   └── compound_transformer
│       ├── build_vocab.py
│       ├── cp_model.py
│       ├── data_loader.py
│       ├── generate.py
│       ├── reversible.py
│       └── train.py
├── configs
│   ├── __init__.py
│   ├── configs_custom.yaml
│   └── configs_os.yaml
├── data
│   └── __init__.py
├── generate.py
├── generate_custom.py
├── generate_variations.py
├── images
│   ├── Corruption_Refinement_Training.png
│   └── YY_Generation_Framework.png
├── key_profile.pickle
├── notebooks
│   ├── .ipynb_checkpoints
│   │   └── Tests-checkpoint.ipynb
│   ├── Midi_Encoding.ipynb
│   └── Tests.ipynb
├── phrase_generator
│   ├── __init__.py
│   ├── data_loader.py
│   └── train.py
├── phrase_refiner
│   ├── __init__.py
│   ├── data_loader.py
│   ├── train.py
│   └── transformations.py
├── phrase_selector
│   ├── __init__.py
│   ├── data_loader.py
│   └── train.py
├── preprocess
│   ├── __init__.py
│   ├── phrase_extraction.py
│   ├── preprocess_mono.py
│   └── write_midi_files.py
├── readme.md
├── requirements.txt
├── structure_derivation
│   ├── __init__.py
│   ├── data_loader.py
│   ├── pitch_evaluator.py
│   ├── sd_evaluator.py
│   └── train.py
└── utils
    ├── __init__.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# .gitignore
cache/
__pycache__/
phrase_chooser/
artifacts/phrase_generation/
artifacts/phrase_refinement/
artifacts/phrase_selection/
artifacts/motif_refinement/
artifacts/autoregressive_transformer/
artifacts/cp_transformer/
artifacts/phrase_similarity/
artifacts/fusion/
data/Mono_Midi_Files/
data/extracted_phrases/
data/annotations/
lightning_logs/
notebooks/
output/
artifacts.zip
data.zip
*.pyc
*.pth

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keshavbhandari/yinyang/59a2a001884a1f235ef9a0313565487874f6e3b9/__init__.py

--------------------------------------------------------------------------------
/artifacts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keshavbhandari/yinyang/59a2a001884a1f235ef9a0313565487874f6e3b9/artifacts/__init__.py

--------------------------------------------------------------------------------
/artifacts/test_file_list.json:
--------------------------------------------------------------------------------
["NLB152532_01.json", "deut4702.json", "NLB177519_01.json", "tirol12.json", "deut1771.json", "NLB075191_03.json", "jugos091.json", "NLB142805_01.json", "NLB075160_01.json", "NLB074234_01.json", "NLB151986_01.json", "NLB167111_01.json", "NLB149200_01.json", "lothr009.json", "lothr040.json", "NLB148500_01.json", "han0766.json", "NLB177163_01.json", "NLB134819_01.json", "deut2034.json", "NLB151467_01.json", "NLB145155_01.json", "NLB072487_01.json", "deut2337.json", "deut214.json", "NLB144206_01.json", "NLB178258_01.json", "deut2747.json", "NLB002963_01.json", "NLB144997_01.json", "kindr030.json", "han0443.json", "NLB002967_01.json", "NLB136709_01.json", "shanx372.json", "kindr150.json", "NLB073804_02.json", "NLB196747_01.json", "NLB139722_01.json", "han0560.json", "natmn069.json", "deut3335.json",
"NLB074461_01.json", "NLB179325_01.json", "NLB183423_01.json", "NLB138204_01.json", "NLB070117_01.json", "shanx384.json", "NLB181093_01.json", "NLB124531_01.json", "NLB196702_01.json", "NLB196783_01.json", "czech33.json", "deut0687.json", "NLB124231_01.json", "NLB191146_01.json", "NLB075031_01.json", "NLB139301_04.json", "deut3959.json", "NLB192054_01.json", "shanx253.json", "NLB196812_01.json", "NLB138443_01.json", "romani04.json", "NLB151279_01.json", "NLB136495_01.json", "han0699.json", "deut2660.json", "NLB075641_01.json", "NLB152330_01.json", "NLB197829_01.json", "NLB129133_01.json", "NLB074548_01.json", "NLB170299_01.json", "deut1497.json", "romani24.json", "NLB152567_01.json", "deut3026.json", "deut3957.json", "natmn060.json", "NLB195305_01.json", "NLB147927_01.json", "NLB126064_01.json", "NLB150927_01.json", "NLB197864_01.json", "NLB003684_01.json", "NLB152518_01.json", "NLB142415_01.json", "NLB004499_01.json", "NLB011074_01.json", "NLB123646_01.json", "NLB115828_01.json", "NLB151461_01.json", "deut4619.json", "france11.json", "NLB074200_01.json", "NLB150854_01.json", "magyar07.json", "NLB070712_01.json", "NLB075093_01.json"] -------------------------------------------------------------------------------- /artifacts/tokenizer.json: -------------------------------------------------------------------------------- 1 | {"Pitch_0": 1, "Pitch_1": 2, "Pitch_2": 3, "Pitch_3": 4, "Pitch_4": 5, "Pitch_5": 6, "Pitch_6": 7, "Pitch_7": 8, "Pitch_8": 9, "Pitch_9": 10, "Pitch_10": 11, "Pitch_11": 12, "Pitch_12": 13, "Pitch_13": 14, "Pitch_14": 15, "Pitch_15": 16, "Pitch_16": 17, "Pitch_17": 18, "Pitch_18": 19, "Pitch_19": 20, "Pitch_20": 21, "Pitch_21": 22, "Pitch_22": 23, "Pitch_23": 24, "Pitch_24": 25, "Pitch_25": 26, "Pitch_26": 27, "Pitch_27": 28, "Pitch_28": 29, "Pitch_29": 30, "Pitch_30": 31, "Pitch_31": 32, "Pitch_32": 33, "Pitch_33": 34, "Pitch_34": 35, "Pitch_35": 36, "Pitch_36": 37, "Pitch_37": 38, "Pitch_38": 39, "Pitch_39": 40, "Pitch_40": 41, "Pitch_41": 42, "Pitch_42": 43, "Pitch_43": 44, "Pitch_44": 45, "Pitch_45": 46, "Pitch_46": 47, "Pitch_47": 48, "Pitch_48": 49, "Pitch_49": 50, "Pitch_50": 51, "Pitch_51": 52, "Pitch_52": 53, "Pitch_53": 54, "Pitch_54": 55, "Pitch_55": 56, "Pitch_56": 57, "Pitch_57": 58, "Pitch_58": 59, "Pitch_59": 60, "Pitch_60": 61, "Pitch_61": 62, "Pitch_62": 63, "Pitch_63": 64, "Pitch_64": 65, "Pitch_65": 66, "Pitch_66": 67, "Pitch_67": 68, "Pitch_68": 69, "Pitch_69": 70, "Pitch_70": 71, "Pitch_71": 72, "Pitch_72": 73, "Pitch_73": 74, "Pitch_74": 75, "Pitch_75": 76, "Pitch_76": 77, "Pitch_77": 78, "Pitch_78": 79, "Pitch_79": 80, "Pitch_80": 81, "Pitch_81": 82, "Pitch_82": 83, "Pitch_83": 84, "Pitch_84": 85, "Pitch_85": 86, "Pitch_86": 87, "Pitch_87": 88, "Pitch_88": 89, "Pitch_89": 90, "Pitch_90": 91, "Pitch_91": 92, "Pitch_92": 93, "Pitch_93": 94, "Pitch_94": 95, "Pitch_95": 96, "Pitch_96": 97, "Pitch_97": 98, "Pitch_98": 99, "Pitch_99": 100, "Pitch_100": 101, "Pitch_101": 102, "Pitch_102": 103, "Pitch_103": 104, "Pitch_104": 105, "Pitch_105": 106, "Pitch_106": 107, "Pitch_107": 108, "Pitch_108": 109, "Pitch_109": 110, "Pitch_110": 111, "Pitch_111": 112, "Pitch_112": 113, "Pitch_113": 114, "Pitch_114": 115, "Pitch_115": 116, "Pitch_116": 117, "Pitch_117": 118, "Pitch_118": 119, "Pitch_119": 120, "Pitch_120": 121, "Pitch_121": 122, "Pitch_122": 123, "Pitch_123": 124, "Pitch_124": 125, "Pitch_125": 126, "Pitch_126": 127, "Pitch_127": 128, "PR_4": 129, "PL_10": 130, "Duration_0.75": 131, "Position_0": 132, "Duration_0.25": 133, "Position_0.75": 134, 
"Duration_0.5": 135, "Position_1.0": 136, "Position_1.5": 137, "Position_2.25": 138, "Position_2.5": 139, "Position_0.0": 140, "Duration_1.0": 141, "TimeSig_6/8": 142, "KS_A": 143, "MM_minor": 144, "PR_5": 145, "PL_6": 146, "Duration_1.5": 147, "PL_7": 148, "PL_8": 149, "PL_11": 150, "Duration_2.5": 151, "PR_14": 152, "Position_2.0": 153, "TimeSig_3/4": 154, "KS_G": 155, "MM_major": 156, "PR_17": 157, "Duration_2.0": 158, "PR_12": 159, "Position_0.5": 160, "PR_10": 161, "PL_12": 162, "Position_0.25": 163, "PR_7": 164, "Position_3.0": 165, "TimeSig_4/4": 166, "PL_9": 167, "Position_3.5": 168, "PL_4": 169, "PR_3": 170, "Position_1.75": 171, "TimeSig_2/4": 172, "TimeSig_7/8": 173, "PR_6": 174, "KS_E-": 175, "KS_C": 176, "PR_15": 177, "PR_18": 178, "Position_4.0": 179, "Position_5.5": 180, "TimeSig_6/4": 181, "KS_D": 182, "Duration_4.0": 183, "KS_F": 184, "Position_1.25": 185, "PR_9": 186, "PL_19": 187, "PR_8": 188, "PL_5": 189, "Duration_0.33": 190, "Position_0.33": 191, "Position_0.67": 192, "PL_14": 193, "PR_13": 194, "Duration_3.0": 195, "TimeSig_3/8": 196, "PL_15": 197, "PL_17": 198, "PL_13": 199, "Duration_0.12": 200, "Position_0.12": 201, "Position_0.38": 202, "Position_2.75": 203, "PL_18": 204, "KS_B-": 205, "PR_2": 206, "Position_3.75": 207, "PL_16": 208, "KS_D-": 209, "PR_21": 210, "PR_19": 211, "PL_21": 212, "PL_20": 213, "PL_26": 214, "PR_16": 215, "Duration_1.25": 216, "KS_B": 217, "Duration_4.5": 218, "KS_E": 219, "Position_6.0": 220, "Position_10.0": 221, "Position_14.0": 222, "Duration_8.0": 223, "Duration_6.0": 224, "TimeSig_4/1": 225, "Duration_12.0": 226, "Position_8.0": 227, "Duration_1.33": 228, "Position_3.33": 229, "Position_1.33": 230, "Position_2.67": 231, "TimeSig_3/1": 232, "Duration_14.0": 233, "Position_12.0": 234, "TimeSig_3/2": 235, "TimeSig_2/1": 236, "Position_5.0": 237, "TimeSig_4/2": 238, "TimeSig_6/2": 239, "Duration_10.0": 240, "Position_7.0": 241, "PR_11": 242, "PL_30": 243, "Duration_16.0": 244, "Position_11.0": 245, "Position_9.0": 246, "PR_1": 247, "PL_22": 248, "PL_24": 249, "PL_25": 250, "PL_23": 251, "Duration_5.0": 252, "TimeSig_2/2": 253, "Duration_0.67": 254, "Position_2.33": 255, "Position_3.67": 256, "PL_3": 257, "Position_7.5": 258, "Position_6.5": 259, "Position_3.25": 260, "Duration_0.38": 261, "Position_0.83": 262, "Position_1.17": 263, "Position_2.17": 264, "Position_2.83": 265, "Position_1.67": 266, "KS_A-": 267, "TimeSig_4/8": 268, "Position_3.83": 269, "Position_0.17": 270, "TimeSig_9/8": 271, "PR_0": 272, "Position_4.5": 273, "Position_5.75": 274, "Duration_0.17": 275, "Position_1.83": 276, "Position_0.88": 277, "PL_32": 278, "Position_1.88": 279, "Duration_2.75": 280, "PR_24": 281, "PR_23": 282, "Position_7.33": 283, "Position_4.67": 284, "Duration_3.5": 285, "TimeSig_5/4": 286, "Duration_1.75": 287, "PL_2": 288, "KS_F#": 289, "PR_20": 290, "PL_1": 291, "Duration_2.25": 292, "Position_3.17": 293, "Position_2.88": 294, "PL_42": 295, "TimeSig_5/8": 296, "Position_3.38": 297, "Position_2.38": 298, "Position_1.62": 299, "PR_22": 300, "PL_27": 301, "PL_33": 302, "Position_2.62": 303, "Position_3.12": 304, "Position_3.62": 305, "Position_3.88": 306, "Position_1.38": 307, "Position_0.62": 308, "Position_1.12": 309, "Position_2.12": 310, "PL_28": 311, "PL_29": 312, "Position_1.92": 313, "PR_25": 314, "PR_28": 315, "Duration_6.5": 316, "Duration_3.75": 317, "PL_31": 318, "Duration_0.08": 319, "Position_0.92": 320, "PL_65": 321, "Duration_0.88": 322, "TimeSig_1/4": 323, "Duration_1.67": 324, "Position_0.58": 325, "Position_2.08": 326, 
"Position_2.42": 327, "Position_1.08": 328, "Position_3.37": 329, "Position_2.37": 330, "Position_2.87": 331, "Position_3.08": 332, "PL_36": 333, "Duration_0.62": 334, "Duration_3.25": 335, "Position_4.75": 336, "Position_3.42": 337, "Position_3.58": 338, "Position_2.58": 339, "Position_3.92": 340, "Position_0.08": 341, "Position_0.42": 342, "Position_2.92": 343, "Duration_4.75": 344, "Duration_6.75": 345, "Position_2.54": 346, "Position_2.71": 347, "Position_3.87": 348, "Position_1.54": 349, "Position_1.71": 350, "Position_3.21": 351, "Position_3.54": 352, "Duration_7.0": 353, "Duration_5.5": 354, "Position_1.58": 355, "MM_phrygian": 356, "PL_34": 357, "PR_26": 358, "PL_39": 359, "PR_31": 360, "PR_27": 361, "PL_38": 362, "PL_40": 363, "PL_35": 364, "PL_45": 365, "TimeSig_12/8": 366, "MM_dorian": 367, "PL_47": 368, "PL_53": 369, "PL_82": 370, "MM_mixolydian": 371, "PL_49": 372, "PL_62": 373, "PL_46": 374, "PL_37": 375, "Duration_0.06": 376, "Position_0.31": 377, "Position_0.44": 378, "Position_0.56": 379, "Position_0.69": 380, "PL_41": 381, "PL_44": 382, "PL_52": 383, "PL_48": 384, "MM_lydian": 385, "Duration_9.0": 386, "TimeSig_9/4": 387, "Position_5.25": 388, "PL_106": 389, "PL_66": 390, "PL_59": 391, "PL_58": 392, "PL_113": 393, "Position_0.13": 394, "PL_56": 395, "PL_71": 396, "PL_43": 397, "PL_51": 398, "PL_57": 399, "PL_50": 400, "PL_84": 401, "PL_76": 402, "PL_67": 403, "Position_1.13": 404, "Position_1.63": 405, "Position_0.63": 406, "Position_1.42": 407, "PL_91": 408, "PL_63": 409, "PL_54": 410, "PL_60": 411, "PL_64": 412, "Position_0.94": 413, "Position_1.94": 414, "PL_79": 415, "PL_74": 416, "PL_121": 417, "PL_77": 418, "PL_128": 419, "PL_70": 420, "PL_61": 421, "PL_73": 422, "Position_4.33": 423, "Position_5.67": 424, "Position_4.25": 425, "TimeSig_7/4": 426, "TimeSig_8/8": 427, "Duration_1.38": 428, "TimeSig_2/8": 429, "Duration_2.67": 430, "Duration_3.33": 431, "TimeSig_8/4": 432, "Duration_10.5": 433, "Position_5.33": 434, "Position_5.83": 435, "Duration_2.33": 436, "PR_29": 437, "PL_55": 438, "Position_5.88": 439, "PL_88": 440, "PL_72": 441, "Duration_7.5": 442, "Duration_1.88": 443, "Duration_13.0": 444, "Position_16.0": 445, "TimeSig_10/2": 446, "Position_20.0": 447, "TimeSig_12/2": 448, "Position_18.0": 449, "Position_22.0": 450, "Position_24.0": 451, "Position_28.0": 452, "TimeSig_16/2": 453, "Position_26.0": 454, "Position_9.5": 455, "TimeSig_5/2": 456, "Position_4.17": 457, "Position_4.38": 458, "MM_locrian": 459, "Position_4.12": 460, "Position_5.62": 461, "Position_1.87": 462, "Duration_2.17": 463, "Position_0.87": 464, "Position_1.37": 465, "Position_0.37": 466, "Position_1.56": 467, "Position_2.56": 468, "Position_3.56": 469, "Position_3.06": 470, "Duration_1.17": 471, "PL_87": 472, "PL_150": 473, "PL_69": 474, "PL_81": 475, "Position_1.69": 476, "Position_1.19": 477, "Position_2.19": 478, "PL_110": 479, "PL_102": 480, "PL_68": 481, "PL_90": 482, "PL_96": 483, "Position_2.13": 484, "PL_100": 485, "PL_78": 486, "PL_85": 487, "Bar_None": 488, "PP_beginning": 489, "PP_middle": 490, "PP_end": 491, "CA_True": 492, "CA_False": 493, "BOS": 494, "EOS": 495, "SEP": 496, "COR_incorrect_transposition": 497, "COR_incorrect_inversion": 498, "COR_note_swapping": 499, "COR_melodic_stripping": 500, "COR_melodic_addition": 501, "COR_same_note_modification": 502, "COR_permute_note_pitch": 503, "COR_permute_note_duration": 504, "COR_permute_note_pitch_duration": 505, "COR_BAR_MASK": 506, "COR_PITCH_MASK": 507, "COR_DURATION_MASK": 508, "COR_FRAGMENT_NOTES": 509, "UNK": 510} 
--------------------------------------------------------------------------------
/baselines/autoregressive_transformer/data_loader.py:
--------------------------------------------------------------------------------
import os
import json
import random
import yaml
import copy
import argparse
from torch.utils.data import Dataset
import torch
from torch.nn import functional as F
import sys
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))
# Append the path to the current working directory
sys.path.append(os.getcwd())
from utils.utils import list_to_remi_encoding


class JSONDataset(Dataset):
    def __init__(self, configs, file_list, mode="train", shuffle=False):
        self.mode = mode
        # Data dir
        self.data_dir = configs['raw_data']['json_folder']
        self.file_list = file_list
        if shuffle:
            random.shuffle(self.file_list)
        # Get number of phrases in each file and store in list as [file_name, phrase_number_{n}]
        self.file_number_phrase_number = []
        for file_path in self.file_list:
            file_path = os.path.join(self.data_dir, file_path)
            with open(file_path, 'r') as f:
                data = json.load(f)
            phrase_number = len(data["phrases"].keys())
            # Add the file path and phrase number to the list
            for i in range(phrase_number):
                self.file_number_phrase_number.append([file_path, i])

        # Artifact folder
        self.artifact_folder = configs['raw_data']['artifact_folder']
        # Load the tokenizer dictionary
        tokenizer_filepath = os.path.join(self.artifact_folder, "tokenizer.json")
        with open(tokenizer_filepath, 'r') as f:
            self.tokenizer = json.load(f)

        # Get the maximum sequence length
        self.decoder_max_sequence_length = 2048

        # Print length of dataset
        print("Length of dataset: ", len(self.file_list))
        print("Length of phrases in dataset: ", len(self.file_number_phrase_number))

    def __len__(self):
        return len(self.file_number_phrase_number)

    def transpose(self, phrase, pitch_change):
        encoding = copy.deepcopy(phrase)
        # Shift the pitch value (index 3 of each event) by pitch_change semitones
        transposed_encoding = [
            [event[0], event[1], event[2], event[3] + pitch_change, *event[4:]]
            for event in encoding
        ]
        return transposed_encoding

    def transpose_key(self, current_key, semitones):
        # Keys are listed in chromatic order, so one index step is one semitone
        keys = ["KS_A-", "KS_A", "KS_B-", "KS_B", "KS_C", "KS_D-", "KS_D", "KS_E-", "KS_E", "KS_F", "KS_F#", "KS_G"]

        # Find the index of the current key in the list
        current_index = keys.index(current_key)

        # Calculate the new index after transposing by the given semitones
        new_index = (current_index + semitones) % len(keys)

        # Return the new key
        return keys[new_index]

    def augment_phrase(self, target, current_key):
        if random.random() < 0.5:
            # Find highest and lowest pitch values
            pitch_values = [event[3] for event in target]
            highest_pitch = max(pitch_values)
            lowest_pitch = min(pitch_values)
            # Choose a random, non-zero pitch change and pull it back towards 0
            # until the transposed phrase stays within the MIDI range 0 to 127
            pitch_change = random.choice([i for i in range(-12, 12) if i != 0])
            while highest_pitch + pitch_change > 127 or lowest_pitch + pitch_change < 0:
                if pitch_change < 0:
                    pitch_change += 1
                else:
                    pitch_change -= 1

            target = self.transpose(target, pitch_change)
            current_key = self.transpose_key(current_key, pitch_change)

        return target, current_key

    def __getitem__(self, idx):
        file_path = self.file_number_phrase_number[idx][0]
        phrase_number = self.file_number_phrase_number[idx][1]
        with open(file_path, 'r') as f:
            data = json.load(f)

        time_signature = data["metadata"]["time_signature"]
        key_signature = data["metadata"]["key_signature"]
        major_or_minor = data["metadata"]["major_or_minor"]
        # Concatenate all phrases from phrase_number to the end of the piece
        target = []
        for i in range(phrase_number, len(data["phrases"].keys())):
            phrase = data["phrases"][str(i)][0]
            target += phrase

        # Augment the phrases
        if self.mode == "train":
            target, key_signature = self.augment_phrase(target, key_signature)

        tempo_location = data["metadata"]["tempo"]

        # Convert to REMI encoding and add the conditioning and EOS tokens
        target = list_to_remi_encoding(target, tempo_location, time_signature)
        target = [major_or_minor] + [key_signature] + ["SEP"] + target + ["EOS"]
        # Tokenize the target
        target = [self.tokenizer[note] for note in target if note in self.tokenizer]

        # Pad or truncate the target to the maximum sequence length
        target = torch.tensor(target)
        if len(target) < self.decoder_max_sequence_length:
            target = F.pad(target, (0, self.decoder_max_sequence_length - len(target)))
        else:
            target = target[:self.decoder_max_sequence_length]
        target_attention_mask = torch.where(target != 0, 1, 0)
        target_attention_mask = target_attention_mask.float()

        train_data = {"input_ids": target, "labels": target, "attention_mask": target_attention_mask}

        return train_data


if __name__ == "__main__":

    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=os.path.normpath("configs/configs_os.yaml"),
                        help="Path to the config file")
    args = parser.parse_args()

    # Load config file
    with open(args.config, 'r') as f:
        configs = yaml.safe_load(f)

    batch_size = configs['training']['phrase_generation']['batch_size']

    # Artifact folder
    artifact_folder = configs['raw_data']['artifact_folder']
    # Load the tokenizer dictionary
    tokenizer_filepath = os.path.join(artifact_folder, "tokenizer.json")
    with open(tokenizer_filepath, 'r') as f:
        tokenizer = json.load(f)

    # Open the train, validation, and test sets json files
    with open(os.path.join(artifact_folder, "train_file_list.json"), "r") as f:
        train_file_list = json.load(f)
    with open(os.path.join(artifact_folder, "valid_file_list.json"), "r") as f:
        valid_file_list = json.load(f)
    with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f:
        test_file_list = json.load(f)

    # Print length of train, validation, and test sets
    print("Length of train set: ", len(train_file_list))
    print("Length of validation set: ", len(valid_file_list))
    print("Length of test set: ", len(test_file_list))

    # Load the dataset
    dataset = JSONDataset(configs, train_file_list, mode="train")

    for n, data in enumerate(dataset):
        # Print shape and type of each tensor
        print(data["input_ids"].shape, data["input_ids"].dtype)
        print(data["labels"].shape, data["labels"].dtype)
        print(data["attention_mask"].shape, data["attention_mask"].dtype)
        if n > 0:
            break
--------------------------------------------------------------------------------
/baselines/autoregressive_transformer/train_autoregressive.py:
--------------------------------------------------------------------------------
import torch
from torch.cuda import is_available as cuda_available, is_bf16_supported
from torch.backends.mps import is_available as mps_available
import yaml
import json
import os
import argparse
from transformers import GPT2LMHeadModel, GPT2Config, Trainer, TrainingArguments, EarlyStoppingCallback, TrainerCallback
from data_loader import JSONDataset
from tqdm import tqdm


# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.normpath("configs/configs_os.yaml"),
                    help="Path to the config file")
args = parser.parse_args()

# Load config file
with open(args.config, 'r') as f:
    configs = yaml.safe_load(f)

batch_size = 48
learning_rate = 0.0001
epochs = 30

# Artifact folder
artifact_folder = configs['raw_data']['artifact_folder']
# Load the tokenizer dictionary
tokenizer_filepath = os.path.join(artifact_folder, "tokenizer.json")
with open(tokenizer_filepath, 'r') as f:
    tokenizer = json.load(f)


# Open the train, validation, and test sets json files
with open(os.path.join(artifact_folder, "train_file_list.json"), "r") as f:
    train_file_list = json.load(f)
with open(os.path.join(artifact_folder, "valid_file_list.json"), "r") as f:
    valid_file_list = json.load(f)
with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f:
    test_file_list = json.load(f)

# Print length of train, validation, and test sets
print("Length of train set: ", len(train_file_list))
print("Length of validation set: ", len(valid_file_list))
print("Length of test set: ", len(test_file_list))

# Load the dataset
train_dataset = JSONDataset(configs, train_file_list, mode="train")
valid_dataset = JSONDataset(configs, valid_file_list, mode="eval")

# Get the vocab size
vocab_size = len(tokenizer)
# Get the phrase length
train_phrase_length = len(train_dataset.file_number_phrase_number)

# Define the model configuration
config_decoder = GPT2Config(
    vocab_size=vocab_size,
    n_positions=2048,
    n_embd=512,
    n_layer=4,
    n_head=4,
    pad_token_id=0,
    bos_token_id=tokenizer["BOS"],
    eos_token_id=tokenizer["EOS"],
    n_inner=2048,
)

print(config_decoder)

# Create the model
model = GPT2LMHeadModel(config=config_decoder)

# Print the number of parameters in the model
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters in the model: {num_params}")

# Create config for the Trainer
USE_CUDA = cuda_available()
print(f"USE_CUDA: {USE_CUDA}")
if not cuda_available():
    FP16 = FP16_EVAL = BF16 = BF16_EVAL = False
elif is_bf16_supported():
    BF16 = BF16_EVAL = True
    FP16 = FP16_EVAL = False
else:
    BF16 = BF16_EVAL = False
    FP16 = FP16_EVAL = True
USE_MPS = not USE_CUDA and mps_available()


class AccuracyMetricsCallback(TrainerCallback):
    """
    Callback to calculate accuracy metrics during evaluation.
    """
    def __init__(self, val_data):
        self.val_data = val_data

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        """
        Calculate accuracy metrics on the validation data during evaluation.
        """
        model.eval()  # Set the model to evaluation mode

        total_tokens = 0
        correct_tokens = 0

        with torch.no_grad():
            for inputs in tqdm(self.val_data):
                input_ids = inputs['input_ids']
                input_ids = input_ids.unsqueeze(0).to("cuda")

                # Get the model's predictions
                outputs = model(input_ids)
                logits = outputs.logits

                # Next-token prediction: compare logits at step t with the token at step t+1
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()

                # Exclude padding tokens from the calculation
                not_pad_mask = shift_labels != 0
                correct_tokens += (shift_logits.argmax(dim=-1) == shift_labels)[not_pad_mask].sum().item()
                total_tokens += not_pad_mask.sum().item()

        accuracy = correct_tokens / total_tokens

        metrics = {
            'accuracy': accuracy,
            'correct_tokens': correct_tokens,
            'total_tokens': total_tokens,
        }

        # Log the metrics to the Trainer's console
        print(f"Validation Accuracy: {metrics['accuracy']}")

        return control


# Create your custom callback
accuracy_metrics_callback = AccuracyMetricsCallback(val_data=valid_dataset)

# Define the training arguments
training_args = TrainingArguments(
    output_dir=os.path.join(artifact_folder, "autoregressive_transformer"),
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_strategy="steps",  # "steps" or "epoch"
    save_steps=1000,
    save_total_limit=1,
    learning_rate=learning_rate,
    max_steps=int(train_phrase_length // batch_size) * epochs,
    evaluation_strategy="steps",
    eval_steps=1000,
    gradient_accumulation_steps=1,
    warmup_steps=500,
    gradient_checkpointing=True,
    optim="adafactor",
    seed=444,
    logging_strategy="steps",
    logging_steps=100,
    logging_dir=os.path.join(artifact_folder, "autoregressive_transformer", "logs"),
    no_cuda=not USE_CUDA,
    fp16=FP16,
    fp16_full_eval=FP16_EVAL,
    bf16=BF16,
    bf16_full_eval=BF16_EVAL,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="tensorboard",
    run_name="autoregressive_transformer",
    push_to_hub=False
)

# Define the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]  # accuracy_metrics_callback can be added here
)

# Train and save the model
train_result = trainer.train()
trainer.save_model()  # Saves the tokenizer too
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
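
Once this script has finished, the checkpoint saved in artifacts/autoregressive_transformer can be sampled from. A hedged sketch (not part of the repo) that mirrors the conditioning prefix the data loader builds (mode, key, "SEP"); the artifact path and sampling settings here are illustrative assumptions:

import os, json, torch
from transformers import GPT2LMHeadModel

artifact_folder = "artifacts"  # assumed; normally taken from the config
with open(os.path.join(artifact_folder, "tokenizer.json")) as f:
    tokenizer = json.load(f)

# Load the model saved by trainer.save_model() above
model = GPT2LMHeadModel.from_pretrained(os.path.join(artifact_folder, "autoregressive_transformer"))
model.eval()

# Prime with the same conditioning prefix the data loader builds
prime = torch.tensor([[tokenizer["MM_major"], tokenizer["KS_C"], tokenizer["SEP"]]])
with torch.no_grad():
    generated = model.generate(prime, max_length=512, do_sample=True, top_k=30,
                               eos_token_id=tokenizer["EOS"], pad_token_id=0)
print(generated.shape)  # (1, <=512); ids decode via the inverted tokenizer
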
--------------------------------------------------------------------------------
/baselines/autoregressive_transformer/train_pop_music.py:
--------------------------------------------------------------------------------
import torch
from torch.cuda import is_available as cuda_available, is_bf16_supported
from torch.backends.mps import is_available as mps_available
import yaml
import json
import os
import argparse
from transformers import TransfoXLConfig, TransfoXLLMHeadModel, Trainer, TrainingArguments, EarlyStoppingCallback, TrainerCallback
from data_loader import JSONDataset
from tqdm import tqdm


# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.normpath("configs/configs_os.yaml"),
                    help="Path to the config file")
args = parser.parse_args()

# Load config file
with open(args.config, 'r') as f:
    configs = yaml.safe_load(f)

batch_size = 32
learning_rate = 0.00025
epochs = 30

# Artifact folder
artifact_folder = configs['raw_data']['artifact_folder']
# Load the tokenizer dictionary
tokenizer_filepath = os.path.join(artifact_folder, "tokenizer.json")
with open(tokenizer_filepath, 'r') as f:
    tokenizer = json.load(f)


# Open the train, validation, and test sets json files
with open(os.path.join(artifact_folder, "train_file_list.json"), "r") as f:
    train_file_list = json.load(f)
with open(os.path.join(artifact_folder, "valid_file_list.json"), "r") as f:
    valid_file_list = json.load(f)
with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f:
    test_file_list = json.load(f)

# Print length of train, validation, and test sets
print("Length of train set: ", len(train_file_list))
print("Length of validation set: ", len(valid_file_list))
print("Length of test set: ", len(test_file_list))

# Load the dataset
train_dataset = JSONDataset(configs, train_file_list, mode="train")
valid_dataset = JSONDataset(configs, valid_file_list, mode="eval")

# Get the vocab size
vocab_size = len(tokenizer)
# Get the phrase length
train_phrase_length = len(train_dataset.file_number_phrase_number)

# Create the model configuration
config_decoder = TransfoXLConfig()
config_decoder.vocab_size = vocab_size
config_decoder.pad_token_id = 0
config_decoder.bos_token_id = tokenizer["BOS"]
config_decoder.eos_token_id = tokenizer["EOS"]
config_decoder.num_hidden_layers = 5
config_decoder.num_attention_heads = 4
config_decoder.d_model = 512
config_decoder.d_inner = 2048
config_decoder.d_embed = 512
# A single cutoff so the adaptive softmax reduces to a full softmax
config_decoder.cutoffs = [0, vocab_size]
config_decoder.div_val = 1
config_decoder.untie_r = True
config_decoder.tie_projs = [False]

print(config_decoder)

# Create the model
model = TransfoXLLMHeadModel(config=config_decoder)

# Print the number of parameters in the model
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters in the model: {num_params}")

# Create config for the Trainer
USE_CUDA = cuda_available()
print(f"USE_CUDA: {USE_CUDA}")
if not cuda_available():
    FP16 = FP16_EVAL = BF16 = BF16_EVAL = False
elif is_bf16_supported():
    BF16 = BF16_EVAL = True
    FP16 = FP16_EVAL = False
else:
    BF16 = BF16_EVAL = False
    FP16 = FP16_EVAL = True
USE_MPS = not USE_CUDA and mps_available()


class AccuracyMetricsCallback(TrainerCallback):
    """
    Callback to calculate accuracy metrics during evaluation.
    """
    def __init__(self, val_data):
        self.val_data = val_data

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        """
        Calculate accuracy metrics on the validation data during evaluation.
        """
        model.eval()  # Set the model to evaluation mode

        total_tokens = 0
        correct_tokens = 0

        with torch.no_grad():
            for inputs in tqdm(self.val_data):
                input_ids = inputs['input_ids']
                input_ids = input_ids.unsqueeze(0).to("cuda")

                # Get the model's predictions
                outputs = model(input_ids)
                logits = outputs.logits

                # Next-token prediction: compare logits at step t with the token at step t+1
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()

                # Exclude padding tokens from the calculation
                not_pad_mask = shift_labels != 0
                correct_tokens += (shift_logits.argmax(dim=-1) == shift_labels)[not_pad_mask].sum().item()
                total_tokens += not_pad_mask.sum().item()

        accuracy = correct_tokens / total_tokens

        metrics = {
            'accuracy': accuracy,
            'correct_tokens': correct_tokens,
            'total_tokens': total_tokens,
        }

        # Log the metrics to the Trainer's console
        print(f"Validation Accuracy: {metrics['accuracy']}")

        return control


# Create your custom callback
accuracy_metrics_callback = AccuracyMetricsCallback(val_data=valid_dataset)

# Define the training arguments
training_args = TrainingArguments(
    output_dir=os.path.join(artifact_folder, "pop_music_transformer"),
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_strategy="steps",  # "steps" or "epoch"
    save_steps=1000,
    save_total_limit=1,
    learning_rate=learning_rate,
    max_steps=int(train_phrase_length // batch_size) * epochs,
    evaluation_strategy="steps",
    eval_steps=1000,
    gradient_accumulation_steps=1,
    warmup_steps=500,
    gradient_checkpointing=False,
    optim="adafactor",
    seed=444,
    logging_strategy="steps",
    logging_steps=100,
    logging_dir=os.path.join(artifact_folder, "pop_music_transformer", "logs"),
    no_cuda=not USE_CUDA,
    fp16=FP16,
    fp16_full_eval=FP16_EVAL,
    bf16=BF16,
    bf16_full_eval=BF16_EVAL,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="tensorboard",
    run_name="pop_music_transformer",
    push_to_hub=False
)

# Define the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]  # accuracy_metrics_callback can be added here
)

# Train and save the model
train_result = trainer.train()
trainer.save_model()  # Saves the tokenizer too
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()

--------------------------------------------------------------------------------
/baselines/compound_transformer/build_vocab.py:
--------------------------------------------------------------------------------
import os
import json
import yaml
import argparse
import glob
from tqdm import tqdm

# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.normpath("configs/configs_os.yaml"),
                    help="Path to the config file")
args = parser.parse_args()

# Load config file
with open(args.config, 'r') as f:
    configs = yaml.safe_load(f)

# Get the raw data folders
artifact_folder = configs["raw_data"]["artifact_folder"]
mono_folder = configs["raw_data"]["mono_folder"]
json_folder = configs["raw_data"]["json_folder"]
raw_data_folders = configs["raw_data"]["raw_data_folders"]

# Get all files from json_folder
json_files = glob.glob(json_folder + "/*.json")

# Create dictionaries to store the tokenizer
all_tokenizer_dict = {}
pitch_tokenizer = {}
family_tokenizer = {}
metric_tokenizer = {}
velocity_tokenizer = {}
chord_tokenizer = {}
duration_tokenizer = {}
time_signature_tokenizer = {}

# Initialize midi pitches from 0 to 127 in the tokenizer dictionary
# (ids start at 1 so that 0 is free to act as the padding id)
pitch_tokenizer["Ignore_None"] = len(pitch_tokenizer) + 1
for i in range(128):
    pitch_tokenizer[f"Pitch_{i}"] = len(pitch_tokenizer) + 1

family_tokenizer["Family_Metric"] = len(family_tokenizer) + 1
family_tokenizer["Family_Note"] = len(family_tokenizer) + 1
family_tokenizer["BOS"] = len(family_tokenizer) + 1
family_tokenizer["EOS"] = len(family_tokenizer) + 1

metric_tokenizer["Ignore_None"] = len(metric_tokenizer) + 1
metric_tokenizer["Bar_None"] = len(metric_tokenizer) + 1

chord_tokenizer["Ignore_None"] = len(chord_tokenizer) + 1

velocity_tokenizer["Ignore_None"] = len(velocity_tokenizer) + 1

duration_tokenizer["Ignore_None"] = len(duration_tokenizer) + 1

time_signature_tokenizer["Ignore_None"] = len(time_signature_tokenizer) + 1

for file in tqdm(json_files):
    # Load the JSON file
    with open(file, 'r') as f:
        data = json.load(f)
    for phrase_number in data["phrases"].keys():
        phrase = data['phrases'][phrase_number][0]
        tempo_location = data["metadata"]["tempo"]
        time_signature = data["metadata"]["time_signature"]

        if time_signature not in time_signature_tokenizer:
            time_signature_tokenizer[time_signature] = len(time_signature_tokenizer) + 1

        for note in phrase:
            if f'Position_{note[1]}' not in metric_tokenizer:
                metric_tokenizer[f'Position_{note[1]}'] = len(metric_tokenizer) + 1
            if f'Velocity_{note[5]}' not in velocity_tokenizer:
                velocity_tokenizer[f'Velocity_{note[5]}'] = len(velocity_tokenizer) + 1
            if f'Duration_{note[4]}' not in duration_tokenizer:
                duration_tokenizer[f'Duration_{note[4]}'] = len(duration_tokenizer) + 1

# Add the tokenizer dictionaries to the all_tokenizer dictionary
all_tokenizer_dict["pitch_tokenizer"] = pitch_tokenizer
all_tokenizer_dict["family_tokenizer"] = family_tokenizer
all_tokenizer_dict["metric_tokenizer"] = metric_tokenizer
all_tokenizer_dict["velocity_tokenizer"] = velocity_tokenizer
all_tokenizer_dict["chord_tokenizer"] = chord_tokenizer
all_tokenizer_dict["duration_tokenizer"] = duration_tokenizer
all_tokenizer_dict["time_signature_tokenizer"] = time_signature_tokenizer

# Save the tokenizer dictionary to the artifact folder
tokenizer_filepath = os.path.join(artifact_folder, "cp_tokenizer.json")
print(f"Saving tokenizer dictionary to {tokenizer_filepath}")
with open(tokenizer_filepath, 'w') as f:
    json.dump(all_tokenizer_dict, f)
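
The resulting cp_tokenizer.json holds one sub-vocabulary per compound-word field, each 1-indexed so that 0 remains free as the pad id. A small sketch (hypothetical, reusing artifact_folder from the script above) of inspecting it:

import json, os

with open(os.path.join(artifact_folder, "cp_tokenizer.json"), "r") as f:
    cp_tokenizer = json.load(f)

# One sub-dictionary per field: pitch, family, metric, velocity, chord,
# duration and time signature
for name, sub in cp_tokenizer.items():
    print(f"{name}: {len(sub)} tokens")
print(cp_tokenizer["family_tokenizer"])  # {'Family_Metric': 1, 'Family_Note': 2, 'BOS': 3, 'EOS': 4}
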
--------------------------------------------------------------------------------
/baselines/compound_transformer/data_loader.py:
--------------------------------------------------------------------------------
import os
import json
import random
import yaml
import copy
import argparse
from torch.utils.data import Dataset, DataLoader
import torch
from torch.nn import functional as F
import sys
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))
# Append the path to the current working directory
sys.path.append(os.getcwd())
from utils.utils import list_to_cp_encoding


class JSONDataset(Dataset):
    def __init__(self, configs, file_list, mode="train", shuffle=False):
        self.mode = mode
        # Data dir
        self.data_dir = configs['raw_data']['json_folder']
        self.file_list = file_list
        if shuffle:
            random.shuffle(self.file_list)
        # Get number of phrases in each file and store in list as [file_name, phrase_number_{n}]
        self.file_number_phrase_number = []
        for file_path in self.file_list:
            file_path = os.path.join(self.data_dir, file_path)
            with open(file_path, 'r') as f:
                data = json.load(f)
            phrase_number = len(data["phrases"].keys())
            # Add the file path and phrase number to the list
            for i in range(phrase_number):
                self.file_number_phrase_number.append([file_path, i])

        # Artifact folder
        self.artifact_folder = configs['raw_data']['artifact_folder']
        # Load the compound-word tokenizer dictionary
        tokenizer_filepath = os.path.join(self.artifact_folder, "cp_tokenizer.json")
        with open(tokenizer_filepath, 'r') as f:
            self.tokenizer = json.load(f)

        # Get the maximum sequence length
        self.decoder_max_sequence_length = 2048

        # Print length of dataset
        print("Length of dataset: ", len(self.file_list))
        print("Length of phrases in dataset: ", len(self.file_number_phrase_number))

    def __len__(self):
        return len(self.file_number_phrase_number)

    def transpose(self, phrase, pitch_change):
        encoding = copy.deepcopy(phrase)
        # Shift the pitch value (index 3 of each event) by pitch_change semitones
        transposed_encoding = [
            [event[0], event[1], event[2], event[3] + pitch_change, *event[4:]]
            for event in encoding
        ]
        return transposed_encoding

    def transpose_key(self, current_key, semitones):
        # Keys are listed in chromatic order, so one index step is one semitone
        keys = ["KS_A-", "KS_A", "KS_B-", "KS_B", "KS_C", "KS_D-", "KS_D", "KS_E-", "KS_E", "KS_F", "KS_F#", "KS_G"]

        # Find the index of the current key in the list
        current_index = keys.index(current_key)

        # Calculate the new index after transposing by the given semitones
        new_index = (current_index + semitones) % len(keys)

        # Return the new key
        return keys[new_index]

    def augment_phrase(self, target, current_key):
        if random.random() < 0.5:
            # Find highest and lowest pitch values
            pitch_values = [event[3] for event in target]
            highest_pitch = max(pitch_values)
            lowest_pitch = min(pitch_values)
            # Choose a random, non-zero pitch change and pull it back towards 0
            # until the transposed phrase stays within the MIDI range 0 to 127
            pitch_change = random.choice([i for i in range(-12, 12) if i != 0])
            while highest_pitch + pitch_change > 127 or lowest_pitch + pitch_change < 0:
                if pitch_change < 0:
                    pitch_change += 1
                else:
                    pitch_change -= 1

            target = self.transpose(target, pitch_change)
            current_key = self.transpose_key(current_key, pitch_change)

        return target, current_key

    def __getitem__(self, idx):
        file_path = self.file_number_phrase_number[idx][0]
        phrase_number = self.file_number_phrase_number[idx][1]
        with open(file_path, 'r') as f:
            data = json.load(f)

        time_signature = data["metadata"]["time_signature"]
        key_signature = data["metadata"]["key_signature"]
        major_or_minor = data["metadata"]["major_or_minor"]
        # Concatenate all phrases from phrase_number to the end of the piece
        target = []
        for i in range(phrase_number, len(data["phrases"].keys())):
            phrase = data["phrases"][str(i)][0]
            target += phrase

        # Augment the phrases
        if self.mode == "train":
            target, key_signature = self.augment_phrase(target, key_signature)

        tempo_location = data["metadata"]["tempo"]

        # Convert the target to a list of compound-word events
        target = list_to_cp_encoding(target, tempo_location, time_signature)

        BOS = ["BOS", "Ignore_None", "Ignore_None", "Ignore_None", "Ignore_None", "Ignore_None", "Ignore_None"]
        EOS = ["EOS", "Ignore_None", "Ignore_None", "Ignore_None", "Ignore_None", "Ignore_None", "Ignore_None"]
        target = [BOS] + target + [EOS]

        # Tokenize each field of the compound-word events into its own stream
        family_tensor = []
        metric_tensor = []
        pitch_tensor = []
        velocity_tensor = []
        duration_tensor = []
        chord_tensor = []
        time_signature_tensor = []

        for event in target:
            family_tensor.append(self.tokenizer["family_tokenizer"][event[0]])
            metric_tensor.append(self.tokenizer["metric_tokenizer"][event[1]])
            pitch_tensor.append(self.tokenizer["pitch_tokenizer"][event[2]])
            velocity_tensor.append(self.tokenizer["velocity_tokenizer"][event[3]])
            duration_tensor.append(self.tokenizer["duration_tokenizer"][event[4]])
            chord_tensor.append(self.tokenizer["chord_tokenizer"][event[5]])
            time_signature_tensor.append(self.tokenizer["time_signature_tokenizer"][event[6]])

        # Create tensors from the lists
        family_tensor = torch.tensor(family_tensor)
        metric_tensor = torch.tensor(metric_tensor)
        pitch_tensor = torch.tensor(pitch_tensor)
        velocity_tensor = torch.tensor(velocity_tensor)
        duration_tensor = torch.tensor(duration_tensor)
        chord_tensor = torch.tensor(chord_tensor)
        time_signature_tensor = torch.tensor(time_signature_tensor)

        # Pad or truncate the tensors to the maximum sequence length
        if len(family_tensor) < self.decoder_max_sequence_length:
            family_tensor = F.pad(family_tensor, (0, self.decoder_max_sequence_length - len(family_tensor)))
            metric_tensor = F.pad(metric_tensor, (0, self.decoder_max_sequence_length - len(metric_tensor)))
            pitch_tensor = F.pad(pitch_tensor, (0, self.decoder_max_sequence_length - len(pitch_tensor)))
            velocity_tensor = F.pad(velocity_tensor, (0, self.decoder_max_sequence_length - len(velocity_tensor)))
            duration_tensor = F.pad(duration_tensor, (0, self.decoder_max_sequence_length - len(duration_tensor)))
            chord_tensor = F.pad(chord_tensor, (0, self.decoder_max_sequence_length - len(chord_tensor)))
            time_signature_tensor = F.pad(time_signature_tensor, (0, self.decoder_max_sequence_length - len(time_signature_tensor)))
        else:
            family_tensor = family_tensor[:self.decoder_max_sequence_length]
            metric_tensor = metric_tensor[:self.decoder_max_sequence_length]
            pitch_tensor = pitch_tensor[:self.decoder_max_sequence_length]
            velocity_tensor = velocity_tensor[:self.decoder_max_sequence_length]
            duration_tensor = duration_tensor[:self.decoder_max_sequence_length]
            chord_tensor = chord_tensor[:self.decoder_max_sequence_length]
            time_signature_tensor = time_signature_tensor[:self.decoder_max_sequence_length]

        # Labels are each stream shifted by one, padded with 0 at the end
        family_tensor_labels = F.pad(family_tensor[1:], (0, 1), value=0)
        metric_tensor_labels = F.pad(metric_tensor[1:], (0, 1), value=0)
        pitch_tensor_labels = F.pad(pitch_tensor[1:], (0, 1), value=0)
        velocity_tensor_labels = F.pad(velocity_tensor[1:], (0, 1), value=0)
        duration_tensor_labels = F.pad(duration_tensor[1:], (0, 1), value=0)
        chord_tensor_labels = F.pad(chord_tensor[1:], (0, 1), value=0)
        time_signature_tensor_labels = F.pad(time_signature_tensor[1:], (0, 1), value=0)

        # Stack the seven streams: (7, seq_len), then permute to (seq_len, 7)
        train_data = torch.cat((time_signature_tensor.unsqueeze(0), chord_tensor.unsqueeze(0), metric_tensor.unsqueeze(0), family_tensor.unsqueeze(0), pitch_tensor.unsqueeze(0), duration_tensor.unsqueeze(0), velocity_tensor.unsqueeze(0)), dim=0)
        labels = torch.cat((time_signature_tensor_labels.unsqueeze(0), chord_tensor_labels.unsqueeze(0), metric_tensor_labels.unsqueeze(0), family_tensor_labels.unsqueeze(0), pitch_tensor_labels.unsqueeze(0), duration_tensor_labels.unsqueeze(0), velocity_tensor_labels.unsqueeze(0)), dim=0)

        # Switch the dimensions
        train_data = train_data.permute(1, 0)
        labels = labels.permute(1, 0)

        # Loss mask
        attention_mask = torch.ones_like(family_tensor)
        return {"x": train_data, "y": labels, "loss_mask": attention_mask}


if __name__ == "__main__":

    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=os.path.normpath("configs/configs_os.yaml"),
                        help="Path to the config file")
    args = parser.parse_args()

    # Load config file
    with open(args.config, 'r') as f:
        configs = yaml.safe_load(f)

    batch_size = 4

    # Artifact folder
    artifact_folder = configs['raw_data']['artifact_folder']
    # Load the compound-word tokenizer dictionary
    tokenizer_filepath = os.path.join(artifact_folder, "cp_tokenizer.json")
    with open(tokenizer_filepath, 'r') as f:
        tokenizer = json.load(f)

    # Open the train, validation, and test sets json files
    with open(os.path.join(artifact_folder, "train_file_list.json"), "r") as f:
        train_file_list = json.load(f)
    with open(os.path.join(artifact_folder, "valid_file_list.json"), "r") as f:
        valid_file_list = json.load(f)
    with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f:
        test_file_list = json.load(f)

    # Print length of train, validation, and test sets
    print("Length of train set: ", len(train_file_list))
    print("Length of validation set: ", len(valid_file_list))
    print("Length of test set: ", len(test_file_list))

    # Load the dataset
    dataset = JSONDataset(configs, train_file_list, mode="train")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for n, data in enumerate(dataset):
        # Print shape and type of each tensor (keys match __getitem__)
        print(data["x"].shape, data["x"].dtype)
        print(data["y"].shape, data["y"].dtype)
        print(data["loss_mask"].shape, data["loss_mask"].dtype)
        if n > 0:
            break
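
Each sample from this dataset stacks seven parallel token streams (time signature, chord, metric, family, pitch, duration, velocity), so a batch comes out as (batch, 2048, 7) for both the inputs and the shifted labels. A hedged sketch, reusing the names from the __main__ block above:

from torch.utils.data import DataLoader

cp_dataset = JSONDataset(configs, train_file_list, mode="train")
cp_loader = DataLoader(cp_dataset, batch_size=4, shuffle=True)

batch = next(iter(cp_loader))
print(batch["x"].shape)          # torch.Size([4, 2048, 7]) - compound-word inputs
print(batch["y"].shape)          # torch.Size([4, 2048, 7]) - next-step labels
print(batch["loss_mask"].shape)  # torch.Size([4, 2048])
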
model = LinearAttentionTransformerLM( 46 | num_tokens = vocab_size, 47 | dim = 512, 48 | heads = 4, 49 | depth = 4, 50 | max_seq_len = 2048, 51 | causal = True, 52 | ff_dropout = 0, 53 | attn_layer_dropout = 0, 54 | attn_dropout = 0, 55 | emb_dim = 512, 56 | dim_head = 128, 57 | blindspot_size = 64, 58 | n_local_attn_heads = 4, 59 | local_attn_window_size = 128, 60 | reversible = True, 61 | ff_chunks = 2, 62 | ff_glu = True, 63 | attend_axially = False, 64 | shift_tokens = True 65 | ) 66 | 67 | device=0 68 | # Load the state dictionary from the .bin file 69 | model_state_dict = torch.load('artifacts/cp_transformer/pytorch_model.bin') #torch.device(device) 70 | 71 | # Load the state dictionary into the model 72 | model.load_state_dict(model_state_dict) 73 | 74 | model.eval() 75 | model.to("cuda" if cuda_available() else "cpu") 76 | 77 | 78 | # Load test file list 79 | with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f: 80 | test_file_list = json.load(f) 81 | 82 | 83 | def generate(test_file, configs, model, gen_length=256): 84 | # Read test file as json 85 | with open(os.path.join(configs['raw_data']['json_folder'], test_file), "r") as f: 86 | test_phrases = json.load(f) 87 | 88 | tempo_location = test_phrases['metadata']['tempo'] 89 | key_signature = test_phrases['metadata']['key_signature'] 90 | time_signature = test_phrases['metadata']['time_signature'] 91 | beats_in_bar = find_beats_in_bar(time_signature) 92 | 93 | # Get the first phrase from test file 94 | motif = test_phrases['phrases']['0'][0] 95 | 96 | # Convert motif to cp encoding 97 | cp_encoding = list_to_cp_encoding(motif, tempo_location, time_signature) 98 | 99 | # Tokenize cp encoding 100 | BOS = ["BOS", "Ignore_None", "Ignore_None", "Ignore_None", "Ignore_None", "Ignore_None", "Ignore_None"] 101 | cp_encoding = [BOS] + cp_encoding 102 | 103 | # Convert cp encoding to tensor 104 | family_tensor = [] 105 | metric_tensor = [] 106 | pitch_tensor = [] 107 | velocity_tensor = [] 108 | duration_tensor = [] 109 | chord_tensor = [] 110 | time_signature_tensor = [] 111 | 112 | for event in cp_encoding: 113 | family_tensor.append(tokenizer["family_tokenizer"][event[0]]) 114 | metric_tensor.append(tokenizer["metric_tokenizer"][event[1]]) 115 | pitch_tensor.append(tokenizer["pitch_tokenizer"][event[2]]) 116 | velocity_tensor.append(tokenizer["velocity_tokenizer"][event[3]]) 117 | duration_tensor.append(tokenizer["duration_tokenizer"][event[4]]) 118 | chord_tensor.append(tokenizer["chord_tokenizer"][event[5]]) 119 | time_signature_tensor.append(tokenizer["time_signature_tokenizer"][event[6]]) 120 | 121 | # Create tensors from the lists 122 | family_tensor = torch.tensor(family_tensor) 123 | metric_tensor = torch.tensor(metric_tensor) 124 | pitch_tensor = torch.tensor(pitch_tensor) 125 | velocity_tensor = torch.tensor(velocity_tensor) 126 | duration_tensor = torch.tensor(duration_tensor) 127 | chord_tensor = torch.tensor(chord_tensor) 128 | time_signature_tensor = torch.tensor(time_signature_tensor) 129 | 130 | # Length of the tensor 131 | input_length = len(family_tensor) 132 | print("Length of family tensor: ", input_length) 133 | 134 | # Pad the tensors to the maximum sequence length 135 | if len(family_tensor) < 2048: 136 | family_tensor = torch.cat((family_tensor, torch.zeros(2048 - len(family_tensor)).long())) 137 | metric_tensor = torch.cat((metric_tensor, torch.zeros(2048 - len(metric_tensor)).long())) 138 | pitch_tensor = torch.cat((pitch_tensor, torch.zeros(2048 - len(pitch_tensor)).long())) 139 | 
velocity_tensor = torch.cat((velocity_tensor, torch.zeros(2048 - len(velocity_tensor)).long())) 140 | duration_tensor = torch.cat((duration_tensor, torch.zeros(2048 - len(duration_tensor)).long())) 141 | chord_tensor = torch.cat((chord_tensor, torch.zeros(2048 - len(chord_tensor)).long())) 142 | time_signature_tensor = torch.cat((time_signature_tensor, torch.zeros(2048 - len(time_signature_tensor)).long())) 143 | 144 | 145 | # Concatenate the tensors 146 | input_data = torch.cat((time_signature_tensor.unsqueeze(0), chord_tensor.unsqueeze(0), metric_tensor.unsqueeze(0), family_tensor.unsqueeze(0), pitch_tensor.unsqueeze(0), duration_tensor.unsqueeze(0), velocity_tensor.unsqueeze(0)), dim=0) 147 | # Switch the dimensions 148 | input_data = input_data.permute(1, 0) 149 | 150 | # Add batch dimension 151 | input_data = input_data.unsqueeze(0) 152 | 153 | # Move to cuda 154 | input_data = input_data.to("cuda" if cuda_available() else "cpu") 155 | 156 | # Generate the continuation 157 | with torch.no_grad(): 158 | final_res = [] 159 | # Append input_data to final_res as an array after removing the batch dimension 160 | final_res.append(input_data[0, :input_length, :].cpu().numpy()) 161 | 162 | h, y_type = model.forward_hidden(input_data, is_training=True) 163 | 164 | while True: 165 | # Get the time step corresponding to the length of the input before padding 166 | h = h[:, input_length-1, :] 167 | y_type = y_type[:, input_length-1, :] 168 | 169 | # sample others 170 | next_arr = model.forward_output_sampling(h, y_type, intervene=False) 171 | if reverse_dec_tokenizer['family_tokenizer'][next_arr[3]] == 'EOS' and input_length < gen_length: 172 | next_arr = model.forward_output_sampling(h, y_type, intervene=True) 173 | 174 | final_res.append(next_arr[None, ...]) 175 | 176 | # forward 177 | input_ = torch.from_numpy(next_arr).long().to("cuda" if cuda_available() else "cpu") 178 | input_ = input_.unsqueeze(0).unsqueeze(0) # (1, 1, 7) 179 | # Add input_ to the input_data at the input_length time step before the paddings 180 | input_data[:, input_length, :] = input_ 181 | input_length += 1 182 | 183 | # Do a forward pass again 184 | h, y_type = model.forward_hidden( 185 | input_data, is_training=True) 186 | 187 | # end of sequence 188 | if reverse_dec_tokenizer['family_tokenizer'][next_arr[3]] == 'EOS' or input_length > gen_length-1: 189 | break 190 | 191 | print('\n--------[Done]--------') 192 | final_res = np.concatenate(final_res) 193 | print(final_res.shape) # (2048, 7) 194 | 195 | # Convert the final_res to a list 196 | final_res_list = final_res.tolist() 197 | 198 | # Convert the list to cp encoding 199 | cp_encoding = [] 200 | for event in final_res_list: 201 | family_event = reverse_dec_tokenizer['family_tokenizer'][event[3]] 202 | metric_event = reverse_dec_tokenizer['metric_tokenizer'][event[2]] 203 | pitch_event = reverse_dec_tokenizer['pitch_tokenizer'][event[4]] 204 | velocity_event = reverse_dec_tokenizer['velocity_tokenizer'][event[6]] if event[6] != 0 else "Ignore_None" 205 | duration_event = reverse_dec_tokenizer['duration_tokenizer'][event[5]] 206 | chord_event = reverse_dec_tokenizer['chord_tokenizer'][event[1]] if event[1] != 0 else "Ignore_None" 207 | time_signature_event = reverse_dec_tokenizer['time_signature_tokenizer'][event[0]] 208 | cp_encoding.append([family_event, metric_event, pitch_event, velocity_event, duration_event, chord_event, time_signature_event]) 209 | 210 | # Remove the BOS token 211 | cp_encoding = cp_encoding[1:] 212 | # Remove the EOS token if it exists 213 | 
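# (generation may instead have stopped because input_length reached
# gen_length, in which case there is no trailing EOS row; hence the guard)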
if cp_encoding[-1][0] == "EOS": 214 | cp_encoding = cp_encoding[:-1] 215 | 216 | # Convert cp encoding to list encoding 217 | list_encoding = cp_to_list_encoding(cp_encoding) 218 | 219 | # Fix the bars 220 | melodic_development_obj = Melodic_Development(beats_in_bar=beats_in_bar) 221 | list_encoding = melodic_development_obj.fix_bars(list_encoding) 222 | 223 | # Create an output folder if it doesn't exist 224 | output_folder = "output/compound_word" 225 | if not os.path.exists(output_folder): 226 | os.makedirs(output_folder) 227 | output_filepath = os.path.join(output_folder, test_file.split(".")[0] + ".mid") 228 | 229 | # Write the structure to a MIDI file 230 | encoding_to_midi(list_encoding, tempo_location, time_signature, output_filepath) 231 | 232 | 233 | if __name__ == "__main__": 234 | # Sampling is stochastic, so retry failed files, but cap the attempts so a file that always fails cannot hang the loop 235 | max_attempts = 5 236 | for test_file in test_file_list: 237 | for attempt in range(max_attempts): 238 | try: 239 | generate(test_file, configs, model, 300) 240 | print(f"Generated: {test_file}") 241 | break 242 | except Exception as e: 243 | print(f"Error generating: {test_file} (attempt {attempt + 1} of {max_attempts})") 244 | print(f"Error message: {str(e)}") 245 | else: 246 | print(f"Giving up on {test_file} after {max_attempts} attempts") 247 | print("All files processed!") -------------------------------------------------------------------------------- /baselines/compound_transformer/reversible.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from operator import itemgetter 4 | from torch.autograd.function import Function 5 | from torch.utils.checkpoint import get_device_states, set_device_states 6 | 7 | # for routing arguments into the functions of the reversible layer 8 | def route_args(router, args, depth): 9 | routed_args = [(dict(), dict()) for _ in range(depth)] 10 | matched_keys = [key for key in args.keys() if key in router] 11 | 12 | for key in matched_keys: 13 | val = args[key] 14 | for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): 15 | new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) 16 | routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) 17 | return routed_args 18 | 19 | def layer_drop(layers, prob): 20 | to_drop = torch.empty(len(layers)).uniform_(0, 1) < prob 21 | blocks = [block for block, drop in zip(layers, to_drop) if not drop] 22 | blocks = layers[:1] if len(blocks) == 0 else blocks 23 | return blocks 24 | 25 | # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 26 | class Deterministic(nn.Module): 27 | def __init__(self, net): 28 | super().__init__() 29 | self.net = net 30 | self.cpu_state = None 31 | self.cuda_in_fwd = None 32 | self.gpu_devices = None 33 | self.gpu_states = None 34 | 35 | def record_rng(self, *args): 36 | self.cpu_state = torch.get_rng_state() 37 | if torch.cuda._initialized: 38 | self.cuda_in_fwd = True 39 | self.gpu_devices, self.gpu_states = get_device_states(*args) 40 | 41 | def forward(self, *args, record_rng = False, set_rng = False, **kwargs): 42 | if record_rng: 43 | self.record_rng(*args) 44 | 45 | if not set_rng: 46 | return self.net(*args, **kwargs) 47 | 48 | rng_devices = [] 49 | if self.cuda_in_fwd: 50 | rng_devices = self.gpu_devices 51 | 52 | with torch.random.fork_rng(devices=rng_devices, enabled=True): 53 | torch.set_rng_state(self.cpu_state) 54 | if self.cuda_in_fwd: 55 | set_device_states(self.gpu_devices, self.gpu_states) 56 | return self.net(*args, **kwargs) 57 | 58 | # heavily inspired by 
https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 59 | # once multi-GPU is confirmed working, refactor and send PR back to source 60 | class ReversibleBlock(nn.Module): 61 | def __init__(self, f, g): 62 | super().__init__() 63 | self.f = Deterministic(f) 64 | self.g = Deterministic(g) 65 | 66 | def forward(self, x, f_args = {}, g_args = {}): 67 | x1, x2 = torch.chunk(x, 2, dim=2) 68 | y1, y2 = None, None 69 | 70 | with torch.no_grad(): 71 | y1 = x1 + self.f(x2, record_rng=self.training, **f_args) 72 | y2 = x2 + self.g(y1, record_rng=self.training, **g_args) 73 | 74 | return torch.cat([y1, y2], dim=2) 75 | 76 | def backward_pass(self, y, dy, f_args = {}, g_args = {}): 77 | y1, y2 = torch.chunk(y, 2, dim=2) 78 | del y 79 | 80 | dy1, dy2 = torch.chunk(dy, 2, dim=2) 81 | del dy 82 | 83 | with torch.enable_grad(): 84 | y1.requires_grad = True 85 | gy1 = self.g(y1, set_rng=True, **g_args) 86 | torch.autograd.backward(gy1, dy2) 87 | 88 | with torch.no_grad(): 89 | x2 = y2 - gy1 90 | del y2, gy1 91 | 92 | dx1 = dy1 + y1.grad 93 | del dy1 94 | y1.grad = None 95 | 96 | with torch.enable_grad(): 97 | x2.requires_grad = True 98 | fx2 = self.f(x2, set_rng=True, **f_args) 99 | torch.autograd.backward(fx2, dx1, retain_graph=True) 100 | 101 | with torch.no_grad(): 102 | x1 = y1 - fx2 103 | del y1, fx2 104 | 105 | dx2 = dy2 + x2.grad 106 | del dy2 107 | x2.grad = None 108 | 109 | x = torch.cat([x1, x2.detach()], dim=2) 110 | dx = torch.cat([dx1, dx2], dim=2) 111 | 112 | return x, dx 113 | 114 | class _ReversibleFunction(Function): 115 | @staticmethod 116 | def forward(ctx, x, blocks, args): 117 | ctx.args = args 118 | for block, kwarg in zip(blocks, args): 119 | x = block(x, **kwarg) 120 | ctx.y = x.detach() 121 | ctx.blocks = blocks 122 | return x 123 | 124 | @staticmethod 125 | def backward(ctx, dy): 126 | y = ctx.y 127 | args = ctx.args 128 | for block, kwargs in zip(ctx.blocks[::-1], args[::-1]): 129 | y, dy = block.backward_pass(y, dy, **kwargs) 130 | return dy, None, None 131 | 132 | 133 | class SequentialSequence(nn.Module): 134 | def __init__(self, layers, args_route = {}, layer_dropout = 0.): 135 | super().__init__() 136 | assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers' 137 | self.layers = layers 138 | self.args_route = args_route 139 | self.layer_dropout = layer_dropout 140 | 141 | def forward(self, x, **kwargs): 142 | args = route_args(self.args_route, kwargs, len(self.layers)) 143 | layers_and_args = list(zip(self.layers, args)) 144 | 145 | if self.training and self.layer_dropout > 0: 146 | layers_and_args = layer_drop(layers_and_args, self.layer_dropout) 147 | 148 | for (f, g), (f_args, g_args) in layers_and_args: 149 | x = x + f(x, **f_args) 150 | x = x + g(x, **g_args) 151 | return x 152 | 153 | class ReversibleSequence(nn.Module): 154 | def __init__(self, blocks, args_route = {}, layer_dropout = 0.): 155 | super().__init__() 156 | self.args_route = args_route 157 | self.layer_dropout = layer_dropout 158 | self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks]) 159 | 160 | def forward(self, x, **kwargs): 161 | x = torch.cat([x, x], dim=-1) 162 | 163 | blocks = self.blocks 164 | args = route_args(self.args_route, kwargs, len(blocks)) 165 | args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args)) 166 | 167 | layers_and_args = list(zip(blocks, args)) 168 | 169 | if self.training and self.layer_dropout > 0: 170 | layers_and_args = 
layer_drop(layers_and_args, self.layer_dropout) 171 | blocks, args = map(lambda ind: list(map(itemgetter(ind), layers_and_args)), (0, 1)) 172 | 173 | out = _ReversibleFunction.apply(x, blocks, args) 174 | return torch.stack(out.chunk(2, dim=-1)).mean(dim=0) -------------------------------------------------------------------------------- /baselines/compound_transformer/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.cuda import is_available as cuda_available, is_bf16_supported 3 | from torch.backends.mps import is_available as mps_available 4 | import yaml 5 | import json 6 | import os 7 | import random 8 | from torch import Tensor, argmax 9 | from transformers import Trainer, TrainingArguments, EarlyStoppingCallback, TrainerCallback 10 | from evaluate import load as load_metric 11 | from data_loader import JSONDataset, DataLoader 12 | import sys 13 | import argparse 14 | from tqdm import tqdm 15 | from cp_model import LinearAttentionTransformerLM 16 | 17 | # Parse command line arguments 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--config", type=str, default=os.path.normpath("/homes/kb658/PhraseBuilder/configs/configs_os.yaml"), 20 | help="Path to the config file") 21 | args = parser.parse_args() 22 | 23 | # Load config file 24 | with open(args.config, 'r') as f: 25 | configs = yaml.safe_load(f) 26 | 27 | batch_size = 116 28 | learning_rate = 0.0001 29 | epochs = 30 30 | 31 | # Artifact folder 32 | artifact_folder = configs['raw_data']['artifact_folder'] 33 | # Load encoder tokenizer json file dictionary 34 | tokenizer_filepath = os.path.join(artifact_folder, "cp_tokenizer.json") 35 | # Load the tokenizer dictionary 36 | with open(tokenizer_filepath, 'r') as f: 37 | tokenizer = json.load(f) 38 | 39 | 40 | # Open the train, validation, and test sets json files 41 | with open(os.path.join(artifact_folder, "train_file_list.json"), "r") as f: 42 | train_file_list = json.load(f) 43 | with open(os.path.join(artifact_folder, "valid_file_list.json"), "r") as f: 44 | valid_file_list = json.load(f) 45 | with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f: 46 | test_file_list = json.load(f) 47 | 48 | # Print length of train, validation, and test sets 49 | print("Length of train set: ", len(train_file_list)) 50 | print("Length of validation set: ", len(valid_file_list)) 51 | print("Length of test set: ", len(test_file_list)) 52 | 53 | # Load the dataset 54 | train_dataset = JSONDataset(configs, train_file_list, mode="train") 55 | valid_dataset = JSONDataset(configs, valid_file_list, mode="eval") 56 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 57 | valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False) 58 | 59 | # Get the vocab size 60 | vocab_size = [len(tokenizer["time_signature_tokenizer"])+1, len(tokenizer["chord_tokenizer"])+1, len(tokenizer["metric_tokenizer"])+1, len(tokenizer["family_tokenizer"])+1, len(tokenizer["pitch_tokenizer"])+1, len(tokenizer["duration_tokenizer"])+1, len(tokenizer["velocity_tokenizer"])+1] 61 | print("Vocab size: ", vocab_size) 62 | 63 | # Get the phrase length 64 | train_phrase_length = len(train_dataset.file_number_phrase_number) 65 | 66 | model = LinearAttentionTransformerLM( 67 | num_tokens = vocab_size, 68 | dim = 512, 69 | heads = 4, 70 | depth = 4, 71 | max_seq_len = 2048, 72 | causal = True, # auto-regressive or not 73 | ff_dropout = 0.1, # dropout for feedforward 74 | 
attn_layer_dropout = 0.1, # dropout right after self-attention layer 75 | attn_dropout = 0.1, # dropout post-attention 76 | emb_dim = 512, # embedding factorization, to save on memory 77 | dim_head = 128, # be able to fix the dimension of each head, making it independent of the embedding dimension and the number of heads 78 | blindspot_size = 64, # this gives the q(kv) attention a blindspot of 64 tokens back in the causal case, but gives back an order of magnitude return in memory savings. should be paired with local attention of at least a window size of this setting. setting this to 1 will allow for full q(kv) attention of past 79 | n_local_attn_heads = 4, # number of local attention heads for (qk)v attention. this can be a tuple specifying the exact number of local attention heads at that depth 80 | local_attn_window_size = 128, # receptive field of the local attention 81 | reversible = True, # use reversible nets, from Reformer paper 82 | ff_chunks = 2, # feedforward chunking, from Reformer paper 83 | ff_glu = True, # use GLU variant for feedforward 84 | attend_axially = False, # will fold the sequence by the local attention window size, and do an extra strided attention followed by a feedforward with the cheap q(kv) attention 85 | shift_tokens = True # add single token shifting, for great improved convergence 86 | ) 87 | 88 | # Print the number of parameters in the model 89 | num_params = sum(p.numel() for p in model.parameters()) 90 | print(f"Number of parameters in the model: {num_params}") 91 | 92 | # Create config for the Trainer 93 | USE_CUDA = cuda_available() 94 | print(f"USE_CUDA: {USE_CUDA}") 95 | if not cuda_available(): 96 | FP16 = FP16_EVAL = BF16 = BF16_EVAL = False 97 | elif is_bf16_supported(): 98 | BF16 = BF16_EVAL = False #True 99 | FP16 = FP16_EVAL = False 100 | else: 101 | BF16 = BF16_EVAL = False 102 | FP16 = FP16_EVAL = True 103 | USE_MPS = not USE_CUDA and mps_available() 104 | 105 | loss_func = torch.nn.CrossEntropyLoss(ignore_index=0) 106 | def calc_loss(predict, target): 107 | loss = loss_func(predict, target) 108 | return loss 109 | 110 | # subclass trainer 111 | class CustomTrainer(Trainer): 112 | def compute_loss(self, model, inputs, return_outputs=False): 113 | x = inputs.pop("x") 114 | target = inputs.pop("y") 115 | # loss_mask = inputs.pop("loss_mask") 116 | 117 | y_tempo, y_chord, y_barbeat, y_type, y_pitch, y_duration, y_velocity = model(x, target, is_training=True) 118 | # reshape (b, s, f) -> (b, f, s) 119 | y_tempo = y_tempo[:, ...].permute(0, 2, 1) 120 | y_chord = y_chord[:, ...].permute(0, 2, 1) 121 | y_barbeat = y_barbeat[:, ...].permute(0, 2, 1) 122 | y_type = y_type[:, ...].permute(0, 2, 1) 123 | y_pitch = y_pitch[:, ...].permute(0, 2, 1) 124 | y_duration = y_duration[:, ...].permute(0, 2, 1) 125 | y_velocity = y_velocity[:, ...].permute(0, 2, 1) 126 | 127 | # loss 128 | loss_tempo = calc_loss( 129 | y_tempo, target[..., 0]) 130 | loss_chord = calc_loss( 131 | y_chord, target[..., 1]) 132 | loss_barbeat = calc_loss( 133 | y_barbeat, target[..., 2]) 134 | loss_type = calc_loss( 135 | y_type, target[..., 3]) 136 | loss_pitch = calc_loss( 137 | y_pitch, target[..., 4]) 138 | loss_duration = calc_loss( 139 | y_duration, target[..., 5]) 140 | loss_velocity = calc_loss( 141 | y_velocity, target[..., 6]) 142 | 143 | loss = loss_tempo + loss_chord + loss_barbeat + loss_type + loss_pitch + loss_duration + loss_velocity 144 | loss = loss / 7 145 | # Convert to BFLoat16 146 | loss = loss.to(torch.bfloat16) 147 | outputs = (y_tempo, y_chord, y_barbeat, 
y_type, y_pitch, y_duration, y_velocity) 148 | 149 | return (loss, outputs) if return_outputs else loss 150 | 151 | # Define the training arguments 152 | training_args = TrainingArguments( 153 | output_dir=os.path.join(artifact_folder, "cp_transformer"), 154 | per_device_train_batch_size=batch_size, 155 | per_device_eval_batch_size=batch_size, 156 | save_strategy="steps", # "steps" or "epoch" 157 | save_steps=1000, 158 | save_total_limit=1, 159 | learning_rate=learning_rate, 160 | max_steps=int(train_phrase_length//batch_size)*epochs, 161 | evaluation_strategy="steps", 162 | eval_steps=1000, 163 | gradient_accumulation_steps=1, 164 | gradient_checkpointing=False, 165 | max_grad_norm=3.0, 166 | optim="adafactor", 167 | seed=444, 168 | logging_strategy="steps", 169 | logging_steps=100, 170 | logging_dir=os.path.join(artifact_folder, "cp_transformer", "logs"), 171 | no_cuda=not USE_CUDA, 172 | fp16=FP16, 173 | fp16_full_eval=FP16_EVAL, 174 | bf16=BF16, 175 | bf16_full_eval=BF16_EVAL, 176 | load_best_model_at_end=True, 177 | metric_for_best_model="eval_loss", 178 | greater_is_better=False, 179 | report_to="tensorboard", 180 | run_name="cp_transformer", 181 | push_to_hub=False, 182 | label_names=["y"], 183 | ) 184 | 185 | # Define the Trainer 186 | trainer = CustomTrainer( 187 | model=model, 188 | args=training_args, 189 | train_dataset=train_dataset, 190 | eval_dataset=valid_dataset, 191 | callbacks=[EarlyStoppingCallback(early_stopping_patience=3)] #accuracy_metrics_callback, 192 | ) 193 | 194 | # Train and save the model 195 | train_result = trainer.train() 196 | trainer.save_model() # Saves the tokenizer too 197 | trainer.log_metrics("train", train_result.metrics) 198 | trainer.save_metrics("train", train_result.metrics) 199 | trainer.save_state() -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keshavbhandari/yinyang/59a2a001884a1f235ef9a0313565487874f6e3b9/configs/__init__.py -------------------------------------------------------------------------------- /configs/configs_custom.yaml: -------------------------------------------------------------------------------- 1 | 2 | model: 3 | phrase_refinement_model: 4 | encoder_max_sequence_length: 1024 5 | decoder_max_sequence_length: 512 6 | num_layers: 4 7 | num_heads: 4 8 | hidden_size: 512 9 | intermediate_size: 2048 10 | phrase_generation_model: 11 | encoder_max_sequence_length: 2048 12 | decoder_max_sequence_length: 512 13 | num_layers: 4 14 | num_heads: 4 15 | hidden_size: 512 16 | intermediate_size: 2048 17 | phrase_selection_model: 18 | max_sequence_length: 512 19 | num_layers: 4 20 | num_heads: 4 21 | hidden_size: 512 22 | intermediate_size: 2048 23 | projection_size: 128 24 | structure_derivation_model: 25 | max_sequence_length: 1024 26 | num_layers: 4 27 | num_heads: 4 28 | hidden_size: 512 29 | intermediate_size: 2048 30 | projection_size: 128 31 | fusion_model: 32 | encoder_max_sequence_length: 1024 33 | decoder_max_sequence_length: 512 34 | num_layers: 12 35 | num_heads: 8 36 | hidden_size: 512 37 | intermediate_size: 2048 38 | 39 | training: 40 | phrase_refinement: 41 | epochs: 30 42 | batch_size: 128 43 | learning_rate: 0.0001 44 | validation_split: 0.1 45 | test_split: 100 46 | weight_decay: 0.01 47 | max_grad_norm: 3.0 48 | gradient_accumulation_steps: 1 49 | phrase_generation: 50 | epochs: 30 51 | batch_size: 56 52 | learning_rate: 0.0001 53 | weight_decay: 
0.01 54 | max_grad_norm: 3.0 55 | gradient_accumulation_steps: 1 56 | phrase_selection: 57 | epochs: 30 58 | batch_size: 256 59 | learning_rate: 0.0001 60 | weight_decay: 0.01 61 | gradient_accumulation_steps: 1 62 | structure_derivation: 63 | epochs: 30 64 | batch_size: 128 65 | learning_rate: 0.0001 66 | weight_decay: 0.01 67 | gradient_accumulation_steps: 1 68 | fusion: 69 | epochs: 700 70 | batch_size: 80 71 | learning_rate: 0.0001 72 | weight_decay: 0.01 73 | gradient_accumulation_steps: 1 74 | 75 | raw_data: 76 | raw_data_folders: 77 | dataset_1: 78 | folder_path: C:/Users/Keshav/Desktop/QMUL/Research/Datasets/MTC/essen 79 | file_extension: krn 80 | annotation_filepath: data/annotations/essen_sequences-1.1.jsonl/essen_sequences.jsonl 81 | dataset_2: 82 | folder_path: C:/Users/Keshav/Desktop/QMUL/Research/Datasets/MTC/mtc-ann-2.0.1/MTC-ANN-2.0.1/krn 83 | file_extension: krn 84 | annotation_filepath: data/annotations/MTC-ANN-2.0.1_sequences-1.1.jsonl/mtcann_sequences.jsonl 85 | dataset_3: 86 | folder_path: C:/Users/Keshav/Desktop/QMUL/Research/Datasets/MTC/MTC-FS-INST-2.0/MTC-FS-INST-2.0/krn 87 | file_extension: krn 88 | annotation_filepath: data/annotations/MTC-FS-INST-2.0_sequences-1.1.jsonl/mtcfsinst_sequences.jsonl 89 | mono_folder: data/Mono_Midi_Files 90 | json_folder: data/extracted_phrases 91 | artifact_folder: artifacts 92 | 93 | generation: 94 | metadata: 95 | key_signature: "C" 96 | major_or_minor: "major" 97 | time_signature: "4/4" # "4/4", "3/4", "2/4", "6/8" 98 | tempo: 90 99 | prompt: 100 | # The prompt is a list of lists, where each inner list represents a note. 101 | # Each note is represented by a list of six values: [Bar number, onset, instrument, pitch, duration, velocity] 102 | # Duration value of 1 means a quarter note, 0.5 means an eighth note, etc. 
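# For example, a hypothetical note [0, 1.5, 0, 67, 0.5, 91] would mean:
# bar 0, onset on beat 1.5, instrument 0, MIDI pitch 67 (G4),
# an eighth-note duration (0.5 quarter lengths), and velocity 91.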
103 | # The values are as follows following the example of Twinkle Twinkle: 104 | phrase_1: [[0,0,0,60,0.5,91], [0,0.5,0,60,0.5,91], [0,1,0,67,0.5,91], [0,1.5,0,67,0.5,91], [0,2,0,69,0.5,91], [0,2.5,0,69,0.5,91], [0,3,0,67,1,91], [1,0,0,65,0.5,91], [1,0.5,0,65,0.5,91], [1,1,0,64,0.5,91], [1,1.5,0,64,0.5,91], [1,2,0,62,0.5,91], [1,2.5,0,62,0.5,91], [1,3,0,60,1,91]] 105 | # phrase_1: [[0,0,0,62,0.25,91],[0,0.25,0,62,0.25,91],[0,0.5,0,67,0.5,91],[0,1,0,67,0.5,91],[0,1.5,0,71,0.5,91],[1,0,0,71,0.25,91],[1,0.25,0,71,0.25,91],[1,0.5,0,74,0.5,91],[1,1,0,76,0.25,91],[1,1.25,0,74,0.25,91],[1,1.5,0,71,0.5,91],[2,0,0,71,0.25,91],[2,0.25,0,71,0.25,91],[2,0.5,0,74,0.5,91],[2,1,0,72,0.5,91],[2,1.5,0,69,0.5,91],[3,0,0,69,0.25,91],[3,0.25,0,69,0.25,91],[3,0.5,0,71,0.5,91],[3,1,0,79,0.5,91],[3,1.5,0,83,0.5,91]] 106 | use_velocity: false 107 | use_phrase_selection: true 108 | structure: ABBBAA 109 | phrases_per_section: [5, 1, 1, 5, 5, 3] 110 | transformations: [] # "contract_melody", "retrograde_melody_pitch_rhythm", "reduce_melody", "invert_melody_tonal", "expand_melody" 111 | allow_modulation: false 112 | ratio: 113 | phrase_generation_frequency: 2 114 | phrase_refinement_frequency: 1 115 | motif_repetition_frequency: 50 116 | write_midi: false # Change to false to write to mxl -------------------------------------------------------------------------------- /configs/configs_os.yaml: -------------------------------------------------------------------------------- 1 | 2 | model: 3 | phrase_refinement_model: 4 | encoder_max_sequence_length: 1024 5 | decoder_max_sequence_length: 512 6 | num_layers: 4 7 | num_heads: 4 8 | hidden_size: 512 9 | intermediate_size: 2048 10 | phrase_generation_model: 11 | encoder_max_sequence_length: 2048 12 | decoder_max_sequence_length: 512 13 | num_layers: 4 14 | num_heads: 4 15 | hidden_size: 512 16 | intermediate_size: 2048 17 | phrase_selection_model: 18 | max_sequence_length: 512 19 | num_layers: 4 20 | num_heads: 4 21 | hidden_size: 512 22 | intermediate_size: 2048 23 | projection_size: 128 24 | structure_derivation_model: 25 | max_sequence_length: 1024 26 | num_layers: 4 27 | num_heads: 4 28 | hidden_size: 512 29 | intermediate_size: 2048 30 | projection_size: 128 31 | fusion_model: 32 | encoder_max_sequence_length: 1024 33 | decoder_max_sequence_length: 512 34 | num_layers: 12 35 | num_heads: 8 36 | hidden_size: 512 37 | intermediate_size: 2048 38 | 39 | training: 40 | phrase_refinement: 41 | epochs: 30 42 | batch_size: 128 43 | learning_rate: 0.0001 44 | validation_split: 0.1 45 | test_split: 100 46 | weight_decay: 0.01 47 | max_grad_norm: 3.0 48 | gradient_accumulation_steps: 1 49 | phrase_generation: 50 | epochs: 30 51 | batch_size: 56 52 | learning_rate: 0.0001 53 | weight_decay: 0.01 54 | max_grad_norm: 3.0 55 | gradient_accumulation_steps: 1 56 | phrase_selection: 57 | epochs: 30 58 | batch_size: 256 59 | learning_rate: 0.0001 60 | weight_decay: 0.01 61 | gradient_accumulation_steps: 1 62 | structure_derivation: 63 | epochs: 30 64 | batch_size: 128 65 | learning_rate: 0.0001 66 | weight_decay: 0.01 67 | gradient_accumulation_steps: 1 68 | fusion: 69 | epochs: 700 70 | batch_size: 80 71 | learning_rate: 0.0001 72 | weight_decay: 0.01 73 | gradient_accumulation_steps: 1 74 | 75 | raw_data: 76 | raw_data_folders: 77 | dataset_1: 78 | folder_path: C:/Users/Keshav/Desktop/QMUL/Research/Datasets/MTC/essen 79 | file_extension: krn 80 | annotation_filepath: data/annotations/essen_sequences-1.1.jsonl/essen_sequences.jsonl 81 | dataset_2: 82 | folder_path: 
C:/Users/Keshav/Desktop/QMUL/Research/Datasets/MTC/mtc-ann-2.0.1/MTC-ANN-2.0.1/krn 83 | file_extension: krn 84 | annotation_filepath: data/annotations/MTC-ANN-2.0.1_sequences-1.1.jsonl/mtcann_sequences.jsonl 85 | dataset_3: 86 | folder_path: C:/Users/Keshav/Desktop/QMUL/Research/Datasets/MTC/MTC-FS-INST-2.0/MTC-FS-INST-2.0/krn 87 | file_extension: krn 88 | annotation_filepath: data/annotations/MTC-FS-INST-2.0_sequences-1.1.jsonl/mtcfsinst_sequences.jsonl 89 | mono_folder: data/Mono_Midi_Files 90 | json_folder: data/extracted_phrases 91 | artifact_folder: artifacts 92 | 93 | # generation: 94 | # generate_all: false 95 | # test_filepath: belgium1.json 96 | # use_velocity: false 97 | # use_phrase_selection: true 98 | # structure: ABACA 99 | # phrases_per_section: [5, 5, 3, 5, 3] 100 | # transformations: [] # "contract_melody", "retrograde_melody_pitch_rhythm", "reduce_melody", "invert_melody_tonal", "expand_melody" 101 | # allow_modulation: false 102 | # ratio: 103 | # phrase_generation_frequency: 2 104 | # phrase_refinement_frequency: 1 105 | # motif_repetition_frequency: 50 106 | 107 | # generation: 108 | # generate_all: false 109 | # test_filepath: danmark1.json 110 | # use_velocity: false 111 | # use_phrase_selection: true 112 | # structure: ABACA 113 | # phrases_per_section: [5, 5, 5, 5, 5] 114 | # transformations: [] # "contract_melody", "retrograde_melody_pitch_rhythm", "reduce_melody", "invert_melody_tonal", "expand_melody" 115 | # allow_modulation: false 116 | # ratio: 117 | # phrase_generation_frequency: 2 118 | # phrase_refinement_frequency: 1 119 | # motif_repetition_frequency: 50 120 | 121 | generation: 122 | generate_all: false 123 | test_filepath: 124 | use_velocity: false 125 | use_phrase_selection: true 126 | structure: ABBAA 127 | phrases_per_section: [5, 2, 5, 5, 3] 128 | transformations: [] # "contract_melody", "retrograde_melody_pitch_rhythm", "reduce_melody", "invert_melody_tonal", "expand_melody" 129 | allow_modulation: false 130 | ratio: 131 | phrase_generation_frequency: 2 132 | phrase_refinement_frequency: 1 133 | motif_repetition_frequency: 50 134 | write_midi: false # Change to false to write to mxl -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keshavbhandari/yinyang/59a2a001884a1f235ef9a0313565487874f6e3b9/data/__init__.py -------------------------------------------------------------------------------- /images/Corruption_Refinement_Training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keshavbhandari/yinyang/59a2a001884a1f235ef9a0313565487874f6e3b9/images/Corruption_Refinement_Training.png -------------------------------------------------------------------------------- /images/YY_Generation_Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keshavbhandari/yinyang/59a2a001884a1f235ef9a0313565487874f6e3b9/images/YY_Generation_Framework.png -------------------------------------------------------------------------------- /key_profile.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keshavbhandari/yinyang/59a2a001884a1f235ef9a0313565487874f6e3b9/key_profile.pickle -------------------------------------------------------------------------------- 
/notebooks/.ipynb_checkpoints/Tests-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "id": "8396050f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import yaml\n", 11 | "import jsonlines\n", 12 | "\n", 13 | "config_path = r\"C:\\Users\\Keshav\\Desktop\\QMUL\\Research\\PhraseBuilder\\configs\\configs_windows.yaml\"\n", 14 | "\n", 15 | "# Load config file\n", 16 | "with open(config_path, 'r') as f:\n", 17 | " configs = yaml.safe_load(f)\n", 18 | " \n", 19 | "artifact_folder = configs[\"raw_data\"][\"artifact_folder\"]\n", 20 | "mono_folder = configs[\"raw_data\"][\"mono_folder\"]\n", 21 | "json_folder = configs[\"raw_data\"][\"json_folder\"]\n", 22 | "raw_data_folders = configs[\"raw_data\"][\"raw_data_folders\"]" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 10, 28 | "id": "ff4def3f", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "def load_jsonl(file_path):\n", 33 | " data = []\n", 34 | " with jsonlines.open(file_path) as reader:\n", 35 | " for line in reader:\n", 36 | " data.append(line)\n", 37 | " return data\n", 38 | "\n", 39 | "for dataset_name, dataset_info in raw_data_folders.items():\n", 40 | " annotation_file = load_jsonl(dataset_info.get('annotation_filepath'))\n", 41 | " break" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 106, 47 | "id": "e50f2d79", 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "[None, None, None, None, None, None, None, None, None, None, None, None, None, None, '1', None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "print(annotation_file[i]['features']['restduration_frac'])" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 15, 65 | "id": "f7917ea9", 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "['6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8', '6/8']\n", 73 | "[0.75, 0.25, 0.5, 0.75, 0.25, 0.5, 0.75, 0.25, 0.5, 1.0, 1.0, 0.5, 1.0, 0.5, 1.5, 1.0, 1.0, 0.5, 1.0, 0.5, 1.0, 0.5, 1.5, 1.0, 0.5, 1.0, 0.5, 0.75, 0.25, 0.5, 1.0, 1.0, 0.5, 1.0, 0.5, 1.0, 0.5, 1.0, 0.5, 2.5, 0.5, 1.0]\n", 74 | "['Dotted Eighth', '16th', 'Eighth', 'Dotted Eighth', '16th', 'Eighth', 'Dotted Eighth', '16th', 'Eighth', 'Quarter', 'Quarter', 'Eighth', 'Quarter', 'Eighth', 'Dotted Quarter', 'Quarter', 'Quarter', 'Eighth', 'Quarter', 'Eighth', 'Quarter', 'Eighth', 'Dotted Quarter', 'Quarter', 'Eighth', 'Quarter', 'Eighth', 'Dotted Eighth', '16th', 'Eighth', 'Quarter', 'Quarter', 'Eighth', 'Quarter', 'Eighth', 'Quarter', 'Eighth', 'Quarter', 'Eighth', 'Half tied to Eighth (2 1/2 total QL)', 'Eighth', 'Quarter']\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "i=0\n", 80 | "print(annotation_file[i]['features']['timesignature'])\n", 81 | "print(annotation_file[i]['features']['duration'])\n", 82 | "print(annotation_file[i]['features']['duration_fullname'])" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 16, 88 | 
"id": "72e19140", 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "name": "stdout", 93 | "output_type": "stream", 94 | "text": [ 95 | "['3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4', '3/4']\n", 96 | "[1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 2.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 2.0, 0.25, 0.25, 0.5, 0.5, 0.5, 1.0, 0.25, 0.25, 0.5, 0.5, 0.5, 1.0, 0.25, 0.25, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 2.0]\n", 97 | "['Quarter', 'Eighth', 'Eighth', 'Quarter', 'Quarter', 'Eighth', 'Eighth', 'Quarter', 'Quarter', 'Eighth', 'Eighth', 'Quarter', 'Quarter', 'Half', 'Quarter', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Quarter', 'Quarter', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Half', '16th', '16th', 'Eighth', 'Eighth', 'Eighth', 'Quarter', '16th', '16th', 'Eighth', 'Eighth', 'Eighth', 'Quarter', '16th', '16th', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Half']\n" 98 | ] 99 | } 100 | ], 101 | "source": [ 102 | "i=1\n", 103 | "print(annotation_file[i]['features']['timesignature'])\n", 104 | "print(annotation_file[i]['features']['duration'])\n", 105 | "print(annotation_file[i]['features']['duration_fullname'])" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 20, 111 | "id": "fe217189", 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "name": "stdout", 116 | "output_type": "stream", 117 | "text": [ 118 | "['4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4', '4/4']\n", 119 | "[0.5, 0.5, 1.0, 1.0, 1.0, 0.5, 0.5, 1.0, 0.5, 0.5, 1.0, 0.5, 0.5, 1.0, 0.5, 0.5, 0.5, 0.5, 1.0, 2.0, 1.0, 1.0, 0.5, 0.5, 1.0, 0.5, 0.5, 0.5, 0.5, 1.0, 1.5, 0.5, 1.0, 1.0, 1.5, 0.5, 1.5, 0.5, 1.0, 1.0, 2.0]\n", 120 | "['Eighth', 'Eighth', 'Quarter', 'Quarter', 'Quarter', 'Eighth', 'Eighth', 'Quarter', 'Eighth', 'Eighth', 'Quarter', 'Eighth', 'Eighth', 'Quarter', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Quarter', 'Half', 'Quarter', 'Quarter', 'Eighth', 'Eighth', 'Quarter', 'Eighth', 'Eighth', 'Eighth', 'Eighth', 'Quarter', 'Dotted Quarter', 'Eighth', 'Quarter', 'Quarter', 'Dotted Quarter', 'Eighth', 'Dotted Quarter', 'Eighth', 'Quarter', 'Quarter', 'Half']\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "i=2\n", 126 | "print(annotation_file[i]['features']['timesignature'])\n", 127 | "print(annotation_file[i]['features']['duration'])\n", 128 | "print(annotation_file[i]['features']['duration_fullname'])" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 130, 134 | "id": "af95a15b", 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | "4/4 4.0\n" 142 | ] 143 | }, 144 | { 145 | "data": { 146 | "text/plain": [ 147 | "[[0, 0.5, 0, 55, 1.0, 91],\n", 148 | " [0, 1.0, 0, 55, 0.5, 91],\n", 149 | " [0, 2.0, 0, 55, 1.0, 91],\n", 150 | " [0, 3.0, 0, 55, 0.5, 91],\n", 151 | " [0, 3.5, 0, 60, 1.0, 91],\n", 152 | 
" [1, 0.5, 0, 52, 1.0, 91],\n", 153 | " [1, 1.5, 0, 55, 0.5, 91],\n", 154 | " [1, 2.0, 0, 53, 0.5, 91],\n", 155 | " [1, 3.0, 0, 53, 0.5, 91],\n", 156 | " [1, 3.5, 0, 58, 1.0, 91],\n", 157 | " [2, 0.5, 0, 58, 1.0, 91],\n", 158 | " [2, 1.5, 0, 57, 0.5, 91],\n", 159 | " [2, 2.0, 0, 59, 0.5, 91],\n", 160 | " [2, 2.5, 0, 60, 0.5, 91],\n", 161 | " [2, 3.0, 0, 55, 0.5, 91],\n", 162 | " [2, 3.5, 0, 55, 1.0, 91],\n", 163 | " [3, 0.5, 0, 53, 1.0, 91],\n", 164 | " [3, 2.5, 0, 52, 1.0, 91],\n", 165 | " [3, 3.5, 0, 55, 1.0, 91],\n", 166 | " [4, 0.5, 0, 55, 1.0, 91],\n", 167 | " [4, 1.5, 0, 55, 0.5, 91],\n", 168 | " [4, 2.0, 0, 54, 0.5, 91],\n", 169 | " [4, 2.5, 0, 54, 1.0, 91],\n", 170 | " [4, 3.5, 0, 60, 1.0, 91],\n", 171 | " [5, 0.5, 0, 60, 1.0, 91],\n", 172 | " [5, 1.5, 0, 60, 0.5, 91],\n", 173 | " [5, 2.0, 0, 59, 0.5, 91],\n", 174 | " [5, 3.0, 0, 59, 0.5, 91],\n", 175 | " [5, 3.5, 0, 62, 1.0, 91],\n", 176 | " [6, 0.5, 0, 62, 1.0, 91],\n", 177 | " [6, 1.5, 0, 62, 0.5, 91],\n", 178 | " [6, 2.0, 0, 52, 0.5, 91],\n", 179 | " [6, 2.5, 0, 60, 0.5, 91],\n", 180 | " [6, 3.0, 0, 57, 0.5, 91],\n", 181 | " [6, 3.5, 0, 55, 1.0, 91],\n", 182 | " [7, 0.5, 0, 59, 0.5, 91],\n", 183 | " [7, 1.0, 0, 57, 0.5, 91],\n", 184 | " [7, 1.5, 0, 55, 2.0, 91],\n", 185 | " [7, 3.5, 0, 50, 1.0, 91],\n", 186 | " [8, 0.5, 0, 50, 1.0, 91],\n", 187 | " [8, 1.5, 0, 51, 1.0, 91],\n", 188 | " [8, 2.5, 0, 51, 1.0, 91],\n", 189 | " [8, 3.5, 0, 52, 1.0, 91],\n", 190 | " [9, 0.5, 0, 52, 1.0, 91],\n", 191 | " [9, 1.5, 0, 53, 1.0, 91],\n", 192 | " [9, 2.5, 0, 53, 1.0, 91],\n", 193 | " [9, 3.5, 0, 54, 1.0, 91],\n", 194 | " [10, 0.5, 0, 54, 1.0, 91],\n", 195 | " [10, 1.5, 0, 55, 1.0, 91],\n", 196 | " [10, 2.5, 0, 55, 1.0, 91],\n", 197 | " [10, 3.5, 0, 60, 1.0, 91],\n", 198 | " [11, 0.5, 0, 60, 1.0, 91],\n", 199 | " [11, 1.5, 0, 60, 3.0, 91],\n", 200 | " [12, 0.5, 0, 60, 0.25, 91],\n", 201 | " [12, 0.75, 0, 59, 0.25, 91],\n", 202 | " [12, 1.0, 0, 57, 0.25, 91],\n", 203 | " [12, 1.25, 0, 55, 0.25, 91],\n", 204 | " [12, 1.5, 0, 55, 2.0, 91],\n", 205 | " [13, 0.0, 0, 60, 1.0, 91],\n", 206 | " [13, 1.0, 0, 55, 0.5, 91],\n", 207 | " [13, 1.5, 0, 57, 1.5, 91],\n", 208 | " [13, 3.0, 0, 57, 0.5, 91],\n", 209 | " [13, 3.5, 0, 62, 0.5, 91],\n", 210 | " [14, 0.0, 0, 60, 0.5, 91],\n", 211 | " [14, 0.5, 0, 59, 0.5, 91],\n", 212 | " [14, 1.0, 0, 57, 0.5, 91],\n", 213 | " [14, 2.5, 0, 59, 1.0, 91],\n", 214 | " [14, 3.5, 0, 60, 1.0, 91],\n", 215 | " [15, 0.5, 0, 55, 1.0, 91],\n", 216 | " [15, 1.5, 0, 57, 0.5, 91],\n", 217 | " [15, 2.0, 0, 65, 0.5, 91],\n", 218 | " [15, 2.5, 0, 64, 0.5, 91],\n", 219 | " [15, 3.0, 0, 62, 0.5, 91],\n", 220 | " [15, 3.5, 0, 60, 1.0, 91],\n", 221 | " [16, 0.5, 0, 59, 1.0, 91],\n", 222 | " [16, 1.5, 0, 60, 2.0, 91]]" 223 | ] 224 | }, 225 | "execution_count": 130, 226 | "metadata": {}, 227 | "output_type": "execute_result" 228 | } 229 | ], 230 | "source": [ 231 | "i = 1105\n", 232 | "\n", 233 | "time_signature = annotation_file[i]['features']['timesignature'][0]\n", 234 | "numerator = int(time_signature.split(\"/\")[0])\n", 235 | "denominator = int(time_signature.split(\"/\")[1])\n", 236 | "if denominator == 4:\n", 237 | " beats_in_bar = numerator * (denominator / 8) * 2\n", 238 | "elif denominator == 8:\n", 239 | " beats_in_bar = numerator * (denominator / 8) / 2\n", 240 | "elif denominator == 2:\n", 241 | " beats_in_bar = numerator * (denominator / 8) * 8\n", 242 | "elif denominator == 1:\n", 243 | " beats_in_bar = numerator * (denominator / 8) * 32\n", 244 | "print(time_signature, beats_in_bar)\n", 245 | " \n", 246 | "pitches = 
annotation_file[i]['features']['midipitch']\n", 247 | "durations = annotation_file[i]['features']['duration']\n", 248 | "next_note_rest_value = annotation_file[i]['features']['restduration_frac']\n", 249 | "\n", 250 | "encoding = []\n", 251 | "bar = 0\n", 252 | "onset = 0\n", 253 | "for idx, pitch_value in enumerate(pitches):\n", 254 | " note_info = []\n", 255 | " if idx == 0:\n", 256 | " # Check if previous note was a rest\n", 257 | " if next_note_rest_value[idx] is None:\n", 258 | " rest = 0\n", 259 | " else: \n", 260 | " if \"/\" in next_note_rest_value[idx]:\n", 261 | " rest = float(int(next_note_rest_value[idx].split(\"/\")[0]) / int(next_note_rest_value[idx].split(\"/\")[1]))\n", 262 | " else:\n", 263 | " rest = int(next_note_rest_value[idx])\n", 264 | " \n", 265 | " note_info.append([bar, onset+rest, 0, pitches[idx], durations[idx], 91])\n", 266 | " else:\n", 267 | " # Check if previous note was a rest\n", 268 | " if next_note_rest_value[idx] is None:\n", 269 | " rest = 0\n", 270 | " else: \n", 271 | " if \"/\" in next_note_rest_value[idx]:\n", 272 | " rest = float(int(next_note_rest_value[idx].split(\"/\")[0]) / int(next_note_rest_value[idx].split(\"/\")[1]))\n", 273 | " else:\n", 274 | " rest = int(next_note_rest_value[idx])\n", 275 | " \n", 276 | " onset += durations[idx-1] + rest \n", 277 | " \n", 278 | " if onset >= beats_in_bar:\n", 279 | " previous_onset = encoding[-1][1]\n", 280 | " onset = (previous_onset + durations[idx-1] + rest) % beats_in_bar\n", 281 | " bar += 1\n", 282 | " note_info.append([bar, onset, 0, pitches[idx], durations[idx], 91])\n", 283 | " encoding+=note_info\n", 284 | " \n", 285 | "encoding" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 121, 291 | "id": "daa95e87", 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "from music21 import stream, meter, note, metadata, tempo\n", 296 | "\n", 297 | "def encoding_to_midi(encoding, tempo, time_signature):\n", 298 | "\n", 299 | " # Create a Score\n", 300 | " score = stream.Score()\n", 301 | " score.metadata = metadata.Metadata()\n", 302 | " score.metadata.title = \"Your MIDI Score\"\n", 303 | "\n", 304 | " # Create a Part for the instrument\n", 305 | " part = stream.Part()\n", 306 | "\n", 307 | " # Set the initial tempo\n", 308 | " initial_tempo = tempo_info.get(0, 120)\n", 309 | " part.append(tempo.MetronomeMark(number=initial_tempo))\n", 310 | "\n", 311 | " # Set the time signature\n", 312 | " time_signature = meter.TimeSignature(time_signature)\n", 313 | "\n", 314 | " # Add the time signature to the Part\n", 315 | " part.append(time_signature)\n", 316 | "\n", 317 | " # Iterate through the MIDI data and create Note objects\n", 318 | " for entry in encoding:\n", 319 | " bar_number, onset_position, instrument_number, pitch, duration, velocity = entry\n", 320 | "\n", 321 | " # Create a Note\n", 322 | " n = note.Note(pitch, quarterLength=duration)\n", 323 | " n.volume.velocity = velocity\n", 324 | "\n", 325 | " # Calculate the offset position\n", 326 | " offset_position = bar_number * time_signature.barDuration.quarterLength + onset_position\n", 327 | "\n", 328 | " # Add the Note to the Part at the calculated offset position\n", 329 | " part.insert(offset_position, n)\n", 330 | "\n", 331 | " # Check if there is a tempo change for the next bar\n", 332 | " next_tempo = tempo_info.get(bar_number + 1, None)\n", 333 | " if next_tempo is not None:\n", 334 | " part.append(tempo.MetronomeMark(number=next_tempo))\n", 335 | "\n", 336 | " # Add the Part to the Score\n", 337 | " 
score.append(part)\n", 338 | "\n", 339 | " # Write the Score to a MIDI file\n", 340 | " midi_file_path = \"output.mid\"\n", 341 | " score.write('midi', fp=midi_file_path)\n", 342 | "\n", 343 | " print(f\"MIDI file '{midi_file_path}' generated successfully.\")" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 124, 349 | "id": "556f443a", 350 | "metadata": {}, 351 | "outputs": [ 352 | { 353 | "name": "stdout", 354 | "output_type": "stream", 355 | "text": [ 356 | "MIDI file 'output.mid' generated successfully.\n" 357 | ] 358 | } 359 | ], 360 | "source": [ 361 | "# Tempo information as a dictionary {bar_number: tempo}\n", 362 | "tempo_info = {0: 120, 2: 100}\n", 363 | "\n", 364 | "encoding_to_midi(encoding, tempo, time_signature)" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "id": "826dd478", 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [] 374 | } 375 | ], 376 | "metadata": { 377 | "kernelspec": { 378 | "display_name": "Python 3 (ipykernel)", 379 | "language": "python", 380 | "name": "python3" 381 | }, 382 | "language_info": { 383 | "codemirror_mode": { 384 | "name": "ipython", 385 | "version": 3 386 | }, 387 | "file_extension": ".py", 388 | "mimetype": "text/x-python", 389 | "name": "python", 390 | "nbconvert_exporter": "python", 391 | "pygments_lexer": "ipython3", 392 | "version": "3.10.13" 393 | } 394 | }, 395 | "nbformat": 4, 396 | "nbformat_minor": 5 397 | } 398 | -------------------------------------------------------------------------------- /phrase_generator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keshavbhandari/yinyang/59a2a001884a1f235ef9a0313565487874f6e3b9/phrase_generator/__init__.py -------------------------------------------------------------------------------- /phrase_generator/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import yaml 5 | import copy 6 | import argparse 7 | from torch.utils.data import Dataset 8 | import torch 9 | from torch.nn import functional as F 10 | import sys 11 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 12 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 13 | from utils.utils import list_to_remi_encoding 14 | 15 | 16 | class JSONDataset(Dataset): 17 | def __init__(self, configs, file_list, mode="train", shuffle = False): 18 | self.mode = mode 19 | # Data dir 20 | self.data_dir = configs['raw_data']['json_folder'] 21 | self.file_list = file_list 22 | if shuffle: 23 | random.shuffle(self.file_list) 24 | # Get number of phrases in each file and store in list as [file_name, phrase_number_{n}] 25 | self.file_number_phrase_number = [] 26 | for file_path in self.file_list: 27 | file_path = os.path.join(self.data_dir, file_path) 28 | with open(file_path, 'r') as f: 29 | data = json.load(f) 30 | phrase_number = len(data["phrases"].keys()) 31 | # Exclude the last phrase as this will be target 32 | for i in range(phrase_number-1): 33 | self.file_number_phrase_number.append([file_path, i]) 34 | 35 | # Artifact folder 36 | self.artifact_folder = configs['raw_data']['artifact_folder'] 37 | # Load encoder tokenizer json file dictionary 38 | tokenizer_filepath = os.path.join(self.artifact_folder, "tokenizer.json") 39 | # Load the tokenizer dictionary 40 | with open(tokenizer_filepath, 'r') as f: 41 | self.tokenizer = json.load(f) 42 | 43 | # Get the maximum sequence length 44 | 
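# (the encoder consumes the concatenated preceding phrases while the decoder
# emits only the next phrase, which is why the configs give the encoder a
# longer limit, e.g. 2048 vs 512 tokens for the phrase generation model)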
self.encoder_max_sequence_length = configs['model']['phrase_generation_model']['encoder_max_sequence_length'] 45 | self.decoder_max_sequence_length = configs['model']['phrase_generation_model']['decoder_max_sequence_length'] 46 | 47 | # Print length of dataset 48 | print("Length of dataset: ", len(self.file_list)) 49 | print("Length of phrases in dataset: ", len(self.file_number_phrase_number)) 50 | 51 | def __len__(self): 52 | return len(self.file_number_phrase_number) 53 | 54 | def transpose(self, phrase, pitch_change): 55 | encoding = copy.deepcopy(phrase) 56 | 57 | transposed_encoding = [ 58 | [event[0], event[1], event[2], event[3] + pitch_change, *event[4:]] 59 | for event in encoding 60 | ] 61 | 62 | return transposed_encoding 63 | 64 | def transpose_key(self, current_key, semitones): 65 | keys = ["KS_A-", "KS_A", "KS_B-", "KS_B", "KS_C", "KS_D-", "KS_D", "KS_E-", "KS_E", "KS_F", "KS_F#", "KS_G"] 66 | 67 | # Find the index of the current key in the list 68 | current_index = keys.index(current_key) 69 | 70 | # Calculate the new index after transposing by the given semitones 71 | new_index = (current_index + semitones) % len(keys) 72 | 73 | # Return the new key 74 | return keys[new_index] 75 | 76 | def augment_phrase(self, phrase_1, target, current_key): 77 | if random.random() < 0.5: 78 | pitch_change = random.choice([i for i in range(-12,12) if i not in [0]]) 79 | 80 | encoding = phrase_1 + target 81 | 82 | # Find highest and lowest pitch values 83 | pitch_values = [event[3] for event in encoding] 84 | highest_pitch = max(pitch_values) 85 | lowest_pitch = min(pitch_values) 86 | # Choose a random pitch change value but ensure it is not 0 and within the midi pitch range of 0 to 127 87 | pitch_change = random.choice([i for i in range(-12,12) if i not in [0]]) 88 | while highest_pitch + pitch_change > 127 or lowest_pitch + pitch_change < 0: 89 | if pitch_change < 0: 90 | pitch_change += 1 91 | else: 92 | pitch_change -= 1 93 | 94 | phrase_1 = self.transpose(phrase_1, pitch_change) 95 | target = self.transpose(target, pitch_change) 96 | current_key = self.transpose_key(current_key, pitch_change) 97 | 98 | return phrase_1, target, current_key 99 | 100 | def __getitem__(self, idx): 101 | file_path = self.file_number_phrase_number[idx][0] 102 | phrase_number = self.file_number_phrase_number[idx][1] 103 | with open(file_path, 'r') as f: 104 | data = json.load(f) 105 | 106 | time_signature = data["metadata"]["time_signature"] 107 | key_signature = data["metadata"]["key_signature"] 108 | major_or_minor = data["metadata"]["major_or_minor"] 109 | # # Get the phrase and the target 110 | # # Get all the phrases before the target and concatenate them 111 | phrases = [] 112 | for i in range(phrase_number+1): 113 | phrase = data["phrases"][str(i)][0] 114 | phrases += phrase 115 | 116 | # Extract an arbitrary phrase from random point till the last element of the list 117 | total_phrase_length = len(phrases) 118 | current_phrase_length = len(data["phrases"][str(phrase_number)][0]) 119 | 120 | go_back = random.randint(current_phrase_length, total_phrase_length) 121 | phrase_1 = phrases[-go_back:] 122 | # phrase_1 = data["phrases"][str(phrase_number)][0] 123 | target = data["phrases"][str(phrase_number+1)][0] 124 | target_cadence = data["phrases"][str(phrase_number + 1)][2] 125 | target_pitch_range = data["phrases"][str(phrase_number + 1)][3] 126 | target_length = data["phrases"][str(phrase_number + 1)][4] 127 | 128 | # Augment the phrases 129 | if self.mode == "train": 130 | phrase_1, target, 
key_signature = self.augment_phrase(phrase_1, target, key_signature) 131 | 132 | tempo_location = data["metadata"]["tempo"] 133 | 134 | # List to remi encoding 135 | phrase = list_to_remi_encoding(phrase_1, tempo_location, time_signature) 136 | # Add the BOS and EOS tokens to the phrase 137 | if random.random() < 0.4: 138 | phrase = ["BOS"] + phrase + ["SEP"] + [target_pitch_range] + [major_or_minor] + [key_signature] + [target_length] + [target_cadence] + ["SEP"] + ["EOS"] 139 | elif random.random() < 0.55: 140 | phrase = ["BOS"] + phrase + ["SEP"] + [target_pitch_range] + [key_signature] + [target_length] + [target_cadence] + ["SEP"] + ["EOS"] 141 | elif random.random() < 0.7: 142 | phrase = ["BOS"] + phrase + ["SEP"] + [target_pitch_range] + [major_or_minor] + [target_length] + [target_cadence] + ["SEP"] + ["EOS"] 143 | elif random.random() < 0.85: 144 | phrase = ["BOS"] + phrase + ["SEP"] + [major_or_minor] + [key_signature] + [target_length] + [target_cadence] + ["SEP"] + ["EOS"] 145 | else: 146 | phrase = ["BOS"] + phrase + ["SEP"] + [target_cadence] + ["SEP"] + ["EOS"] 147 | 148 | # Tokenize the phrase 149 | phrase = [self.tokenizer[note] for note in phrase if note in self.tokenizer] 150 | 151 | # Add the BOS and EOS tokens to the target 152 | target = list_to_remi_encoding(target, tempo_location, time_signature) 153 | target = target + ["EOS"] 154 | # Tokenize the target 155 | target = [self.tokenizer[note] for note in target if note in self.tokenizer] 156 | 157 | # Convert to tensor and pad the phrase to a fixed length of max_sequence_length if the phrase is shorter than max_sequence_length 158 | phrase = torch.tensor(phrase) 159 | if len(phrase) < self.encoder_max_sequence_length: 160 | phrase = F.pad(phrase, (0, self.encoder_max_sequence_length - len(phrase))) 161 | else: 162 | phrase = phrase[-self.encoder_max_sequence_length:] 163 | # Attention mask based on non-padded tokens of the phrase 164 | phrase_attention_mask = torch.where(phrase != 0, 1, 0) 165 | phrase_attention_mask = phrase_attention_mask.type(torch.bool) 166 | 167 | # Do the same for the target 168 | target = torch.tensor(target) 169 | if len(target) < self.decoder_max_sequence_length: 170 | target = F.pad(target, (0, self.decoder_max_sequence_length - len(target))) 171 | else: 172 | target = target[:self.decoder_max_sequence_length] 173 | 174 | train_data = {"input_ids": phrase, "labels": target, "attention_mask": phrase_attention_mask} 175 | 176 | return train_data 177 | 178 | 179 | if __name__ == "__main__": 180 | 181 | # Parse command line arguments 182 | parser = argparse.ArgumentParser() 183 | parser.add_argument("--config", type=str, default=os.path.normpath("configs/configs_os.yaml"), 184 | help="Path to the config file") 185 | args = parser.parse_args() 186 | 187 | # Load config file 188 | with open(args.config, 'r') as f: 189 | configs = yaml.safe_load(f) 190 | 191 | batch_size = configs['training']['phrase_generation']['batch_size'] 192 | 193 | # Artifact folder 194 | artifact_folder = configs['raw_data']['artifact_folder'] 195 | # Load encoder tokenizer json file dictionary 196 | tokenizer_filepath = os.path.join(artifact_folder, "tokenizer.json") 197 | # Load the tokenizer dictionary 198 | with open(tokenizer_filepath, 'r') as f: 199 | tokenizer = json.load(f) 200 | 201 | 202 | # Open the train, validation, and test sets json files 203 | with open(os.path.join(artifact_folder, "train_file_list.json"), "r") as f: 204 | train_file_list = json.load(f) 205 | with open(os.path.join(artifact_folder, 
"valid_file_list.json"), "r") as f: 206 | valid_file_list = json.load(f) 207 | with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f: 208 | test_file_list = json.load(f) 209 | 210 | # Print length of train, validation, and test sets 211 | print("Length of train set: ", len(train_file_list)) 212 | print("Length of validation set: ", len(valid_file_list)) 213 | print("Length of test set: ", len(test_file_list)) 214 | 215 | # Load the dataset 216 | dataset = JSONDataset(configs, train_file_list, mode="train") 217 | # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 218 | 219 | for n, data in enumerate(dataset): 220 | # print shape and type of tensor 221 | print(data["input_ids"].shape, data["input_ids"].dtype) 222 | print(data["labels"].shape, data["labels"].dtype) 223 | print(data["attention_mask"].shape, data["attention_mask"].dtype) 224 | if n > 0: 225 | break -------------------------------------------------------------------------------- /phrase_generator/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.cuda import is_available as cuda_available, is_bf16_supported 3 | from torch.backends.mps import is_available as mps_available 4 | import yaml 5 | import json 6 | import os 7 | import random 8 | from torch import Tensor, argmax 9 | from transformers import EncoderDecoderModel, EncoderDecoderConfig, BertConfig, Trainer, TrainingArguments, EarlyStoppingCallback 10 | from evaluate import load as load_metric 11 | from data_loader import JSONDataset 12 | import sys 13 | import argparse 14 | 15 | 16 | # Parse command line arguments 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--config", type=str, default=os.path.normpath("/homes/kb658/yinyang/configs/configs_os.yaml"), 19 | help="Path to the config file") 20 | args = parser.parse_args() 21 | 22 | # Load config file 23 | with open(args.config, 'r') as f: 24 | configs = yaml.safe_load(f) 25 | 26 | batch_size = configs['training']['phrase_generation']['batch_size'] 27 | learning_rate = configs['training']['phrase_generation']['learning_rate'] 28 | epochs = configs['training']['phrase_generation']['epochs'] 29 | 30 | # Artifact folder 31 | artifact_folder = configs['raw_data']['artifact_folder'] 32 | # Load encoder tokenizer json file dictionary 33 | tokenizer_filepath = os.path.join(artifact_folder, "tokenizer.json") 34 | # Load the tokenizer dictionary 35 | with open(tokenizer_filepath, 'r') as f: 36 | tokenizer = json.load(f) 37 | 38 | 39 | # Open the train, validation, and test sets json files 40 | with open(os.path.join(artifact_folder, "train_file_list.json"), "r") as f: 41 | train_file_list = json.load(f) 42 | with open(os.path.join(artifact_folder, "valid_file_list.json"), "r") as f: 43 | valid_file_list = json.load(f) 44 | with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f: 45 | test_file_list = json.load(f) 46 | 47 | # Print length of train, validation, and test sets 48 | print("Length of train set: ", len(train_file_list)) 49 | print("Length of validation set: ", len(valid_file_list)) 50 | print("Length of test set: ", len(test_file_list)) 51 | 52 | # Load the dataset 53 | train_dataset = JSONDataset(configs, train_file_list, mode="train") 54 | valid_dataset = JSONDataset(configs, valid_file_list, mode="eval") 55 | # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 56 | 57 | # Get the vocab size 58 | vocab_size = len(tokenizer) 59 | # Get the phrase length 60 | 
train_phrase_length = len(train_dataset.file_number_phrase_number) 61 | 62 | # Create the encoder-decoder model 63 | config_encoder = BertConfig() 64 | config_encoder.vocab_size = vocab_size 65 | config_encoder.max_position_embeddings = configs['model']['phrase_generation_model']['encoder_max_sequence_length'] 66 | config_encoder.max_length = configs['model']['phrase_generation_model']['encoder_max_sequence_length'] 67 | config_encoder.pad_token_id = 0 68 | config_encoder.bos_token_id = tokenizer["BOS"] 69 | config_encoder.eos_token_id = tokenizer["EOS"] 70 | config_encoder.num_hidden_layers = configs['model']['phrase_generation_model']['num_layers'] 71 | config_encoder.num_attention_heads = configs['model']['phrase_generation_model']['num_heads'] 72 | config_encoder.hidden_size = configs['model']['phrase_generation_model']['hidden_size'] 73 | config_encoder.intermediate_size = configs['model']['phrase_generation_model']['intermediate_size'] 74 | 75 | config_decoder = BertConfig() 76 | config_decoder.vocab_size = vocab_size 77 | config_decoder.max_position_embeddings = configs['model']['phrase_generation_model']['decoder_max_sequence_length'] 78 | config_decoder.max_length = configs['model']['phrase_generation_model']['decoder_max_sequence_length'] 79 | config_decoder.bos_token_id = tokenizer["BOS"] 80 | config_decoder.eos_token_id = tokenizer["EOS"] 81 | config_decoder.pad_token_id = 0 82 | config_decoder.num_hidden_layers = configs['model']['phrase_generation_model']['num_layers'] 83 | config_decoder.num_attention_heads = configs['model']['phrase_generation_model']['num_heads'] 84 | config_decoder.hidden_size = configs['model']['phrase_generation_model']['hidden_size'] 85 | config_decoder.intermediate_size = configs['model']['phrase_generation_model']['intermediate_size'] 86 | 87 | # set decoder config to causal lm 88 | config_decoder.is_decoder = True 89 | config_decoder.add_cross_attention = True 90 | config_decoder.tie_encoder_decoder = False 91 | config_decoder.tie_word_embeddings = False 92 | 93 | config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder) 94 | model = EncoderDecoderModel(config=config) 95 | # config.max_length = configs['model']['phrase_generation_model']['max_sequence_length'] 96 | config.decoder_start_token_id = tokenizer["BOS"] 97 | config.pad_token_id = 0 98 | 99 | # Print the number of parameters in the model 100 | num_params = sum(p.numel() for p in model.parameters()) 101 | print(f"Number of parameters in the model: {num_params}") 102 | 103 | # Create config for the Trainer 104 | USE_CUDA = cuda_available() 105 | print(f"USE_CUDA: {USE_CUDA}") 106 | if not cuda_available(): 107 | FP16 = FP16_EVAL = BF16 = BF16_EVAL = False 108 | elif is_bf16_supported(): 109 | BF16 = BF16_EVAL = True 110 | FP16 = FP16_EVAL = False 111 | else: 112 | BF16 = BF16_EVAL = False 113 | FP16 = FP16_EVAL = True 114 | USE_MPS = not USE_CUDA and mps_available() 115 | 116 | metrics = {metric: load_metric(metric) for metric in ["accuracy"]} 117 | 118 | def compute_metrics(eval_pred): 119 | """ 120 | Compute metrics for pretraining. 121 | 122 | Must use preprocess_logits function that converts logits to predictions (argmax or sampling). 
123 | 124 | :param eval_pred: EvalPrediction containing predictions and labels 125 | :return: metrics 126 | """ 127 | predictions, labels = eval_pred 128 | not_pad_mask = labels != 0 129 | labels, predictions = labels[not_pad_mask], predictions[not_pad_mask] 130 | return metrics["accuracy"].compute(predictions=predictions.flatten(), references=labels.flatten()) 131 | 132 | def preprocess_logits(logits: Tensor, _: Tensor) -> Tensor: 133 | """ 134 | Preprocess the logits before accumulating them during evaluation. 135 | 136 | This allows to significantly reduce the memory usage and make the training tractable. 137 | """ 138 | pred_ids = argmax(logits[0], dim=-1) # long dtype 139 | return pred_ids 140 | 141 | # Define the training arguments 142 | training_args = TrainingArguments( 143 | output_dir=os.path.join(artifact_folder, "phrase_generation"), 144 | per_device_train_batch_size=batch_size, 145 | per_device_eval_batch_size=batch_size, 146 | save_strategy="steps", # "steps" or "epoch" 147 | save_steps=1000, 148 | save_total_limit=1, 149 | learning_rate=learning_rate, 150 | weight_decay= configs['training']['phrase_generation']['weight_decay'], 151 | max_grad_norm=configs['training']['phrase_generation']['max_grad_norm'], 152 | max_steps=int(train_phrase_length//batch_size)*epochs, 153 | evaluation_strategy="steps", 154 | eval_steps=1000, 155 | gradient_accumulation_steps=configs['training']['phrase_generation']['gradient_accumulation_steps'], 156 | gradient_checkpointing=True, 157 | optim="adafactor", 158 | seed=444, 159 | logging_strategy="steps", 160 | logging_steps=100, 161 | logging_dir=os.path.join(artifact_folder, "phrase_generation", "logs"), 162 | no_cuda=not USE_CUDA, 163 | fp16=FP16, 164 | fp16_full_eval=FP16_EVAL, 165 | bf16=BF16, 166 | bf16_full_eval=BF16_EVAL, 167 | load_best_model_at_end=True, 168 | metric_for_best_model="eval_loss", 169 | greater_is_better=False, 170 | report_to="tensorboard", 171 | run_name="phrase_generation", 172 | push_to_hub=False 173 | ) 174 | 175 | # Define the Trainer 176 | trainer = Trainer( 177 | model=model, 178 | args=training_args, 179 | train_dataset=train_dataset, 180 | eval_dataset=valid_dataset, 181 | compute_metrics=compute_metrics, 182 | preprocess_logits_for_metrics=preprocess_logits, 183 | callbacks=[EarlyStoppingCallback(early_stopping_patience=3)] 184 | ) 185 | 186 | # Train and save the model 187 | train_result = trainer.train() 188 | trainer.save_model() # Saves the tokenizer too 189 | trainer.log_metrics("train", train_result.metrics) 190 | trainer.save_metrics("train", train_result.metrics) 191 | trainer.save_state() 192 | -------------------------------------------------------------------------------- /phrase_refiner/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keshavbhandari/yinyang/59a2a001884a1f235ef9a0313565487874f6e3b9/phrase_refiner/__init__.py -------------------------------------------------------------------------------- /phrase_refiner/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import yaml 5 | import copy 6 | import argparse 7 | from torch.utils.data import Dataset 8 | import torch 9 | from torch.nn import functional as F 10 | import sys 11 | from transformations import Phrase_Corruption, Melodic_Development 12 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 14 | from 
utils.utils import list_to_remi_encoding, find_beats_in_bar 15 | 16 | 17 | class JSONDataset(Dataset): 18 | def __init__(self, configs, file_list, mode="train", shuffle = False): 19 | self.mode = mode 20 | # Data dir 21 | self.data_dir = configs['raw_data']['json_folder'] 22 | self.file_list = file_list 23 | if shuffle: 24 | random.shuffle(self.file_list) 25 | # Get number of phrases in each file and store in list as [file_name, phrase_number_{n}] 26 | self.file_number_phrase_number = [] 27 | for file_path in self.file_list: 28 | file_path = os.path.join(self.data_dir, file_path) 29 | with open(file_path, 'r') as f: 30 | data = json.load(f) 31 | phrase_number = len(data["phrases"].keys()) 32 | # Exclude the last phrase as this will be target 33 | for i in range(phrase_number-1): 34 | self.file_number_phrase_number.append([file_path, i]) 35 | 36 | # Artifact folder 37 | self.artifact_folder = configs['raw_data']['artifact_folder'] 38 | # Load encoder tokenizer json file dictionary 39 | tokenizer_filepath = os.path.join(self.artifact_folder, "tokenizer.json") 40 | # Load the tokenizer dictionary 41 | with open(tokenizer_filepath, 'r') as f: 42 | self.tokenizer = json.load(f) 43 | 44 | # Get the maximum sequence length 45 | self.encoder_max_sequence_length = configs['model']['phrase_refinement_model']['encoder_max_sequence_length'] 46 | self.decoder_max_sequence_length = configs['model']['phrase_refinement_model']['decoder_max_sequence_length'] 47 | 48 | # Print length of dataset 49 | print("Length of dataset: ", len(self.file_list)) 50 | print("Length of phrases in dataset: ", len(self.file_number_phrase_number)) 51 | 52 | def __len__(self): 53 | return len(self.file_number_phrase_number) 54 | 55 | def transpose(self, phrase, pitch_change): 56 | encoding = copy.deepcopy(phrase) 57 | 58 | transposed_encoding = [ 59 | [event[0], event[1], event[2], event[3] + pitch_change, *event[4:]] 60 | for event in encoding 61 | ] 62 | 63 | return transposed_encoding 64 | 65 | def augment_phrase(self, melodic_development_obj, phrase_1, phrase_2, target): 66 | if random.random() < 0.5: 67 | pitch_change = random.choice([i for i in range(-12,12) if i not in [0]]) 68 | 69 | encoding = phrase_1 + phrase_2 + target 70 | 71 | # Find highest and lowest pitch values 72 | pitch_values = [event[3] for event in encoding] 73 | highest_pitch = max(pitch_values) 74 | lowest_pitch = min(pitch_values) 75 | # Choose a random pitch change value but ensure it is not 0 and within the midi pitch range of 0 to 127 76 | pitch_change = random.choice([i for i in range(-12,12) if i not in [0]]) 77 | while highest_pitch + pitch_change > 127 or lowest_pitch + pitch_change < 0: 78 | if pitch_change < 0: 79 | pitch_change += 1 80 | else: 81 | pitch_change -= 1 82 | 83 | phrase_1 = self.transpose(phrase_1, pitch_change) 84 | phrase_2 = self.transpose(phrase_2, pitch_change) 85 | target = self.transpose(target, pitch_change) 86 | 87 | if random.random() < 0.5: 88 | # Modify the position of the phrase 2 89 | phrase_2 = melodic_development_obj.fix_bars(phrase_2, start_onset=0) 90 | 91 | return phrase_1, phrase_2, target 92 | 93 | def corrupt_phrase(self, phrase, key_signature, mode, beats_in_bar): 94 | # Load the phrase corruption class 95 | self.phrase_corruption_obj = Phrase_Corruption(beats_in_bar) 96 | phrase = self.phrase_corruption_obj.apply_corruptions(phrase, key_signature, mode) 97 | 98 | return phrase 99 | 100 | def get_last_bar(self, melodic_development_obj, phrase): 101 | # Get the last bar of the phrase 102 | grouped_phrase = 
melodic_development_obj.group_by_bar(phrase) 103 | last_bar = grouped_phrase[-1] 104 | return last_bar 105 | 106 | def __getitem__(self, idx): 107 | file_path = self.file_number_phrase_number[idx][0] 108 | phrase_number = self.file_number_phrase_number[idx][1] 109 | with open(file_path, 'r') as f: 110 | data = json.load(f) 111 | 112 | # Get the phrase and the target 113 | time_signature = data["metadata"]["time_signature"] 114 | key_signature = data["metadata"]["key_signature"] 115 | major_or_minor = data["metadata"]["major_or_minor"] 116 | phrase_1 = data["phrases"][str(phrase_number)][0] 117 | phrase_2 = data["phrases"][str(phrase_number + 1)][0] 118 | phrase_2_position = data["phrases"][str(phrase_number + 1)][1] 119 | phrase_2_cadence = data["phrases"][str(phrase_number + 1)][2] 120 | phrase_2_pitch_range = data["phrases"][str(phrase_number + 1)][3] 121 | phrase_2_length = data["phrases"][str(phrase_number + 1)][4] 122 | target = data["phrases"][str(phrase_number + 1)][0] 123 | 124 | beats_in_bar = find_beats_in_bar(time_signature) 125 | melodic_development_obj = Melodic_Development(beats_in_bar) 126 | 127 | # Augment the phrases 128 | if self.mode == "train": 129 | phrase_1, phrase_2, target = self.augment_phrase(melodic_development_obj, phrase_1, phrase_2, target) 130 | 131 | tempo_location = data["metadata"]["tempo"] 132 | 133 | if random.random() < 0.2: 134 | # Take the last bar of phrase_1 as phrase_1 135 | phrase_1 = self.get_last_bar(melodic_development_obj, phrase_1) 136 | # Corrupt the phrase 2 here 137 | phrase_2, corruption_tokens = self.corrupt_phrase(phrase_2, key_signature, major_or_minor, beats_in_bar) 138 | 139 | # Just take last note of phrase_1 140 | phrase_1 = phrase_1[-1:] 141 | phrase = phrase_1 + ["SEP"] + [key_signature] + [major_or_minor] + [phrase_2_length] + [phrase_2_cadence] + corruption_tokens + ["SEP"] + phrase_2 142 | else: 143 | # Corrupt the phrase 2 here 144 | phrase_2, corruption_tokens = self.corrupt_phrase(phrase_2, key_signature, major_or_minor, beats_in_bar) 145 | 146 | # Add phrase 1 to phrase 2 147 | if random.random() < 0.33: 148 | phrase = phrase_1 + ["SEP"] + [key_signature] + [major_or_minor] + [phrase_2_length] + [phrase_2_cadence] + corruption_tokens + ["SEP"] + phrase_2 149 | elif random.random() < 0.67: 150 | phrase = phrase_1 + ["SEP"] + [phrase_2_length] + [phrase_2_cadence] + corruption_tokens + ["SEP"] + phrase_2 151 | else: 152 | phrase = phrase_1 + ["SEP"] + [phrase_2_cadence] + corruption_tokens + ["SEP"] + phrase_2 153 | 154 | # List to remi encoding 155 | phrase = list_to_remi_encoding(phrase, tempo_location, time_signature) 156 | # Add the BOS and EOS tokens to the phrase 157 | phrase = ["BOS"] + phrase + ["EOS"] 158 | # Tokenize the phrase 159 | phrase = [self.tokenizer[note] for note in phrase if note in self.tokenizer] 160 | 161 | # Add the BOS and EOS tokens to the target 162 | target = list_to_remi_encoding(target, tempo_location, time_signature) 163 | target = target + ["EOS"] 164 | # Tokenize the target 165 | target = [self.tokenizer[note] for note in target if note in self.tokenizer] 166 | 167 | # Convert to tensor and pad the phrase to a fixed length of max_sequence_length if the phrase is shorter than max_sequence_length 168 | phrase = torch.tensor(phrase) 169 | if len(phrase) < self.encoder_max_sequence_length: 170 | phrase = F.pad(phrase, (0, self.encoder_max_sequence_length - len(phrase))) 171 | else: 172 | phrase = phrase[-self.encoder_max_sequence_length:] 173 | # Attention mask based on non-padded tokens of the 
phrase 174 | phrase_attention_mask = torch.where(phrase != 0, 1, 0) 175 | phrase_attention_mask = phrase_attention_mask.type(torch.bool) 176 | 177 | # Do the same for the target 178 | target = torch.tensor(target) 179 | if len(target) < self.decoder_max_sequence_length: 180 | target = F.pad(target, (0, self.decoder_max_sequence_length - len(target))) 181 | else: 182 | target = target[:self.decoder_max_sequence_length] 183 | 184 | train_data = {"input_ids": phrase, "labels": target, "attention_mask": phrase_attention_mask} 185 | 186 | return train_data 187 | 188 | 189 | if __name__ == "__main__": 190 | 191 | # Parse command line arguments 192 | parser = argparse.ArgumentParser() 193 | parser.add_argument("--config", type=str, default=os.path.normpath("configs/configs_os.yaml"), 194 | help="Path to the config file") 195 | args = parser.parse_args() 196 | 197 | # Load config file 198 | with open(args.config, 'r') as f: 199 | configs = yaml.safe_load(f) 200 | 201 | batch_size = configs['training']['phrase_generation']['batch_size'] 202 | 203 | # Artifact folder 204 | artifact_folder = configs['raw_data']['artifact_folder'] 205 | # Load encoder tokenizer json file dictionary 206 | tokenizer_filepath = os.path.join(artifact_folder, "tokenizer.json") 207 | # Load the tokenizer dictionary 208 | with open(tokenizer_filepath, 'r') as f: 209 | tokenizer = json.load(f) 210 | 211 | # Open the train, validation, and test sets json files 212 | with open(os.path.join(artifact_folder, "train_file_list.json"), "r") as f: 213 | train_file_list = json.load(f) 214 | with open(os.path.join(artifact_folder, "valid_file_list.json"), "r") as f: 215 | valid_file_list = json.load(f) 216 | with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f: 217 | test_file_list = json.load(f) 218 | 219 | # Print length of train, validation, and test sets 220 | print("Length of train set: ", len(train_file_list)) 221 | print("Length of validation set: ", len(valid_file_list)) 222 | print("Length of test set: ", len(test_file_list)) 223 | 224 | # Load the dataset 225 | dataset = JSONDataset(configs, train_file_list, mode="train") 226 | # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 227 | 228 | for n, data in enumerate(dataset): 229 | # print shape and type of tensor 230 | print(data["input_ids"].shape, data["input_ids"].dtype) 231 | print(data["labels"].shape, data["labels"].dtype) 232 | print(data["attention_mask"].shape, data["attention_mask"].dtype) 233 | if n > 5: 234 | break -------------------------------------------------------------------------------- /phrase_refiner/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.cuda import is_available as cuda_available, is_bf16_supported 3 | from torch.backends.mps import is_available as mps_available 4 | import yaml 5 | import json 6 | import os 7 | import random 8 | from torch import Tensor, argmax 9 | from transformers import EncoderDecoderModel, EncoderDecoderConfig, BertConfig, Trainer, TrainingArguments, EarlyStoppingCallback 10 | from evaluate import load as load_metric 11 | from data_loader import JSONDataset 12 | import sys 13 | import argparse 14 | 15 | 16 | # Parse command line arguments 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--config", type=str, default=os.path.normpath("/homes/kb658/yinyang/configs/configs_os.yaml"), 19 | help="Path to the config file") 20 | args = parser.parse_args() 21 | 22 | # Load config file 23 | with 
open(args.config, 'r') as f: 24 | configs = yaml.safe_load(f) 25 | 26 | batch_size = configs['training']['phrase_refinement']['batch_size'] 27 | # max_sequence_length = configs['model']['max_sequence_length'] 28 | learning_rate = configs['training']['phrase_refinement']['learning_rate'] 29 | epochs = configs['training']['phrase_refinement']['epochs'] 30 | 31 | # Artifact folder 32 | artifact_folder = configs['raw_data']['artifact_folder'] 33 | # Load encoder tokenizer json file dictionary 34 | tokenizer_filepath = os.path.join(artifact_folder, "tokenizer.json") 35 | # Load the tokenizer dictionary 36 | with open(tokenizer_filepath, 'r') as f: 37 | tokenizer = json.load(f) 38 | 39 | 40 | # Open the train, validation, and test sets json files if they exist 41 | train_file_list_path = os.path.join(artifact_folder, "train_file_list.json") 42 | valid_file_list_path = os.path.join(artifact_folder, "valid_file_list.json") 43 | test_file_list_path = os.path.join(artifact_folder, "test_file_list.json") 44 | 45 | if os.path.exists(train_file_list_path) and os.path.exists(valid_file_list_path) and os.path.exists(test_file_list_path): 46 | with open(train_file_list_path, "r") as f: 47 | train_file_list = json.load(f) 48 | with open(valid_file_list_path, "r") as f: 49 | valid_file_list = json.load(f) 50 | with open(test_file_list_path, "r") as f: 51 | test_file_list = json.load(f) 52 | else: 53 | # Data dir 54 | data_dir = configs['raw_data']['json_folder'] 55 | file_list = os.listdir(data_dir) 56 | valid_split = configs['training']['phrase_refinement']['validation_split'] 57 | n_test_files = configs['training']['phrase_refinement']['test_split'] 58 | 59 | # Split the file list into train and validation sets 60 | train_file_list = file_list[:int(len(file_list) * (1 - valid_split))] 61 | valid_file_list = file_list[int(len(file_list) * (1 - valid_split)):] 62 | 63 | # Now take 100 files randomly from train set as the test set and remove them from the train set 64 | test_file_list = random.sample(train_file_list, n_test_files) 65 | train_file_list = [file for file in train_file_list if file not in test_file_list] 66 | 67 | # Save the train, validation, and test sets to a json file 68 | with open(os.path.join(artifact_folder, "train_file_list.json"), "w") as f: 69 | json.dump(train_file_list, f) 70 | with open(os.path.join(artifact_folder, "valid_file_list.json"), "w") as f: 71 | json.dump(valid_file_list, f) 72 | with open(os.path.join(artifact_folder, "test_file_list.json"), "w") as f: 73 | json.dump(test_file_list, f) 74 | 75 | # Print length of train, validation, and test sets 76 | print("Length of train set: ", len(train_file_list)) 77 | print("Length of validation set: ", len(valid_file_list)) 78 | print("Length of test set: ", len(test_file_list)) 79 | 80 | # Load the dataset 81 | train_dataset = JSONDataset(configs, train_file_list, mode="train", shuffle=True) 82 | valid_dataset = JSONDataset(configs, valid_file_list, mode="eval", shuffle=False) 83 | # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 84 | 85 | # Get the vocab size 86 | vocab_size = len(tokenizer) 87 | # Get the phrase length 88 | train_phrase_length = len(train_dataset.file_number_phrase_number) 89 | 90 | # Create the encoder-decoder model 91 | config_encoder = BertConfig() 92 | config_encoder.vocab_size = vocab_size 93 | config_encoder.max_position_embeddings = configs['model']['phrase_refinement_model']['encoder_max_sequence_length'] 94 | config_encoder.max_length = 
configs['model']['phrase_refinement_model']['encoder_max_sequence_length'] 95 | config_encoder.pad_token_id = 0 96 | config_encoder.bos_token_id = tokenizer["BOS"] 97 | config_encoder.eos_token_id = tokenizer["EOS"] 98 | config_encoder.num_hidden_layers = configs['model']['phrase_refinement_model']['num_layers'] 99 | config_encoder.num_attention_heads = configs['model']['phrase_refinement_model']['num_heads'] 100 | config_encoder.hidden_size = configs['model']['phrase_refinement_model']['hidden_size'] 101 | config_encoder.intermediate_size = configs['model']['phrase_refinement_model']['intermediate_size'] 102 | 103 | config_decoder = BertConfig() 104 | config_decoder.vocab_size = vocab_size 105 | config_decoder.max_position_embeddings = configs['model']['phrase_refinement_model']['decoder_max_sequence_length'] 106 | config_decoder.max_length = configs['model']['phrase_refinement_model']['decoder_max_sequence_length'] 107 | config_decoder.bos_token_id = tokenizer["BOS"] 108 | config_decoder.eos_token_id = tokenizer["EOS"] 109 | config_decoder.pad_token_id = 0 110 | config_decoder.num_hidden_layers = configs['model']['phrase_refinement_model']['num_layers'] 111 | config_decoder.num_attention_heads = configs['model']['phrase_refinement_model']['num_heads'] 112 | config_decoder.hidden_size = configs['model']['phrase_refinement_model']['hidden_size'] 113 | config_decoder.intermediate_size = configs['model']['phrase_refinement_model']['intermediate_size'] 114 | 115 | # set decoder config to causal lm 116 | config_decoder.is_decoder = True 117 | config_decoder.add_cross_attention = True 118 | config_decoder.tie_encoder_decoder = False 119 | config_decoder.tie_word_embeddings = False 120 | 121 | 122 | config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder) 123 | model = EncoderDecoderModel(config=config) 124 | # config.max_length = configs['model']['phrase_refinement_model']['max_sequence_length'] 125 | config.decoder_start_token_id = tokenizer["BOS"] 126 | config.pad_token_id = 0 127 | 128 | # Print the number of parameters in the model 129 | num_params = sum(p.numel() for p in model.parameters()) 130 | print(f"Number of parameters in the model: {num_params}") 131 | 132 | # Create config for the Trainer 133 | USE_CUDA = cuda_available() 134 | print(f"USE_CUDA: {USE_CUDA}") 135 | if not cuda_available(): 136 | FP16 = FP16_EVAL = BF16 = BF16_EVAL = False 137 | elif is_bf16_supported(): 138 | BF16 = BF16_EVAL = True 139 | FP16 = FP16_EVAL = False 140 | else: 141 | BF16 = BF16_EVAL = False 142 | FP16 = FP16_EVAL = True 143 | USE_MPS = not USE_CUDA and mps_available() 144 | 145 | metrics = {metric: load_metric(metric) for metric in ["accuracy"]} 146 | 147 | def compute_metrics(eval_pred): 148 | """ 149 | Compute metrics for pretraining. 150 | 151 | Must use preprocess_logits function that converts logits to predictions (argmax or sampling). 152 | 153 | :param eval_pred: EvalPrediction containing predictions and labels 154 | :return: metrics 155 | """ 156 | predictions, labels = eval_pred 157 | not_pad_mask = labels != 0 158 | labels, predictions = labels[not_pad_mask], predictions[not_pad_mask] 159 | return metrics["accuracy"].compute(predictions=predictions.flatten(), references=labels.flatten()) 160 | 161 | def preprocess_logits(logits: Tensor, _: Tensor) -> Tensor: 162 | """ 163 | Preprocess the logits before accumulating them during evaluation. 164 | 165 | This allows to significantly reduce the memory usage and make the training tractable. 
166 | """ 167 | pred_ids = argmax(logits[0], dim=-1) # long dtype 168 | return pred_ids 169 | 170 | # Define the training arguments 171 | training_args = TrainingArguments( 172 | output_dir=os.path.join(artifact_folder, "phrase_refinement_v2"), 173 | per_device_train_batch_size=batch_size, 174 | per_device_eval_batch_size=batch_size, 175 | save_strategy="steps", # "steps" or "epoch" 176 | save_steps=500, 177 | save_total_limit=1, 178 | learning_rate=learning_rate, #1e-4, 179 | weight_decay= configs['training']['phrase_refinement']['weight_decay'], 180 | max_grad_norm=configs['training']['phrase_refinement']['max_grad_norm'], 181 | max_steps=int(train_phrase_length//batch_size)*epochs, 182 | evaluation_strategy="steps", 183 | eval_steps=500, 184 | gradient_accumulation_steps=configs['training']['phrase_refinement']['gradient_accumulation_steps'], 185 | gradient_checkpointing=True, 186 | optim="adafactor", 187 | seed=444, 188 | logging_strategy="steps", 189 | logging_steps=100, 190 | logging_dir=os.path.join(artifact_folder, "phrase_refinement_v2", "logs"), 191 | no_cuda=not USE_CUDA, 192 | fp16=FP16, 193 | fp16_full_eval=FP16_EVAL, 194 | bf16=BF16, 195 | bf16_full_eval=BF16_EVAL, 196 | load_best_model_at_end=True, 197 | metric_for_best_model="eval_loss", 198 | greater_is_better=False, 199 | report_to="tensorboard", 200 | run_name="phrase_refinement_v2", 201 | push_to_hub=False 202 | ) 203 | 204 | # Define the Trainer 205 | trainer = Trainer( 206 | model=model, 207 | args=training_args, 208 | train_dataset=train_dataset, 209 | eval_dataset=valid_dataset, 210 | compute_metrics=compute_metrics, 211 | preprocess_logits_for_metrics=preprocess_logits, 212 | callbacks=[EarlyStoppingCallback(early_stopping_patience=3)] 213 | ) 214 | 215 | # Train and save the model 216 | train_result = trainer.train() 217 | trainer.save_model() # Saves the tokenizer too 218 | trainer.log_metrics("train", train_result.metrics) 219 | trainer.save_metrics("train", train_result.metrics) 220 | trainer.save_state() 221 | -------------------------------------------------------------------------------- /phrase_selector/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keshavbhandari/yinyang/59a2a001884a1f235ef9a0313565487874f6e3b9/phrase_selector/__init__.py -------------------------------------------------------------------------------- /phrase_selector/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import yaml 5 | import copy 6 | import argparse 7 | from torch.utils.data import Dataset 8 | import torch 9 | from torch.nn import functional as F 10 | import sys 11 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 12 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 13 | from utils.utils import list_to_remi_encoding, find_beats_in_bar 14 | from phrase_refiner.transformations import Melodic_Development 15 | 16 | 17 | class JSONDataset(Dataset): 18 | def __init__(self, configs, file_list, mode="train", shuffle = False): 19 | self.mode = mode 20 | # Data dir 21 | self.data_dir = configs['raw_data']['json_folder'] 22 | self.file_list = file_list 23 | if shuffle: 24 | random.shuffle(self.file_list) 25 | # Get number of phrases in each file and store in list as [file_name, phrase_number_{n}] 26 | self.file_number_phrase_number = [] 27 | for file_path in self.file_list: 28 | file_path = os.path.join(self.data_dir, file_path) 29 | with open(file_path, 
'r') as f: 30 | data = json.load(f) 31 | phrase_number = len(data["phrases"].keys()) 32 | # Exclude the last phrase as this will be target 33 | for i in range(phrase_number-1): 34 | self.file_number_phrase_number.append([file_path, i]) 35 | 36 | # Length of file_number_phrase_number 37 | self.length_phrases = len(self.file_number_phrase_number) 38 | # Length of file_list 39 | self.length_files = len(self.file_list) 40 | 41 | # Print length of dataset 42 | print("Length of dataset: ", self.length_files) 43 | print("Length of phrases in dataset: ", self.length_phrases) 44 | 45 | # Artifact folder 46 | self.artifact_folder = configs['raw_data']['artifact_folder'] 47 | # Load encoder tokenizer json file dictionary 48 | tokenizer_filepath = os.path.join(self.artifact_folder, "tokenizer.json") 49 | # Load the tokenizer dictionary 50 | with open(tokenizer_filepath, 'r') as f: 51 | self.tokenizer = json.load(f) 52 | 53 | # Get the maximum sequence length 54 | self.max_sequence_length = configs['model']['phrase_selection_model']['max_sequence_length'] 55 | 56 | def __len__(self): 57 | return self.length_files 58 | 59 | def transpose(self, phrase, pitch_change): 60 | encoding = copy.deepcopy(phrase) 61 | 62 | transposed_encoding = [ 63 | [event[0], event[1], event[2], event[3] + pitch_change, *event[4:]] 64 | for event in encoding 65 | ] 66 | 67 | return transposed_encoding 68 | 69 | def augment_phrase(self, phrase_1): 70 | if random.random() < 0.5: 71 | pitch_change = random.choice([i for i in range(-12,12) if i not in [0]]) 72 | 73 | encoding = copy.deepcopy(phrase_1) 74 | 75 | # Find highest and lowest pitch values 76 | pitch_values = [event[3] for event in encoding] 77 | highest_pitch = max(pitch_values) 78 | lowest_pitch = min(pitch_values) 79 | # Choose a random pitch change value but ensure it is not 0 and within the midi pitch range of 0 to 127 80 | pitch_change = random.choice([i for i in range(-12,12) if i not in [0]]) 81 | while highest_pitch + pitch_change > 127 or lowest_pitch + pitch_change < 0: 82 | if pitch_change < 0: 83 | pitch_change += 1 84 | else: 85 | pitch_change -= 1 86 | 87 | phrase_1 = self.transpose(phrase_1, pitch_change) 88 | 89 | return phrase_1 90 | 91 | def reindex_phrase(self, phrase_1, phrase_2, beats_in_bar): 92 | self.melodic_development_obj = Melodic_Development(beats_in_bar=beats_in_bar) 93 | phrase_2 = self.melodic_development_obj.reindex_bars(phrase_2, start_bar=phrase_1[-1][0]+1) 94 | 95 | return phrase_2 96 | 97 | def __getitem__(self, idx): 98 | file_path = self.file_list[idx] 99 | # Get max phrases in file 100 | with open(os.path.join(self.data_dir, file_path), 'r') as f: 101 | data = json.load(f) 102 | 103 | time_signature = data["metadata"]["time_signature"] 104 | # Get key and major or minor for phrase 105 | key_signature = data["metadata"]["key_signature"] 106 | major_or_minor = data["metadata"]["major_or_minor"] 107 | 108 | # Choose a random phrase number 109 | total_phrases = len(data["phrases"].keys()) 110 | phrase_number = random.randint(0, total_phrases-2) 111 | 112 | phrase_1 = data["phrases"][str(phrase_number)][0] 113 | if random.random() < 0.5: 114 | # Get the phrase and the target from the same file as a positive sample 115 | phrase_2 = data["phrases"][str(phrase_number + 1)][0] 116 | # Reindex phrase_2 to match the last bar + 1 of phrase_1 117 | beats_in_bar = find_beats_in_bar(time_signature) 118 | phrase_2 = self.reindex_phrase(phrase_1, phrase_2, beats_in_bar) 119 | target = torch.tensor(1) 120 | else: 121 | # Choose a random file from 
self.file_list that is not idx 122 | random_file = random.choice([i for i in range(self.length_files) if i != idx]) 123 | random_file_path = self.file_list[random_file] 124 | with open(os.path.join(self.data_dir, random_file_path), 'r') as f: 125 | random_data = json.load(f) 126 | # Choose a random phrase from the random file as a negative sample 127 | random_phrase_number = random.randint(0, len(random_data["phrases"].keys())-1) 128 | phrase_2 = random_data["phrases"][str(random_phrase_number)][0] 129 | # Reindex phrase_2 to match the last bar + 1 of phrase_1 130 | beats_in_bar = find_beats_in_bar(time_signature) 131 | phrase_2 = self.reindex_phrase(phrase_1, phrase_2, beats_in_bar) 132 | key_signature = random_data["metadata"]["key_signature"] 133 | major_or_minor = random_data["metadata"]["major_or_minor"] 134 | target = torch.tensor(0) 135 | 136 | # Augment the phrases 137 | if self.mode == "train": 138 | phrase_1 = self.augment_phrase(phrase_1) 139 | phrase_2 = self.augment_phrase(phrase_2) 140 | # phrase_1, phrase_2 = self.augment_phrase(phrase_1, phrase_2) 141 | 142 | tempo_location = data["metadata"]["tempo"] 143 | 144 | # Add phrase 1 to phrase 2 145 | phrase = phrase_1 + ["SEP"] + phrase_2 146 | # List to remi encoding 147 | phrase = list_to_remi_encoding(phrase, tempo_location, time_signature) 148 | # Add the BOS and EOS tokens to the phrase 149 | phrase = ["BOS"] + phrase + ["EOS"] 150 | # Tokenize the phrase 151 | phrase = [self.tokenizer[note] for note in phrase if note in self.tokenizer] 152 | 153 | # Convert to tensor and pad the phrase to a fixed length of max_sequence_length if the phrase is shorter than max_sequence_length 154 | phrase = torch.tensor(phrase) 155 | if len(phrase) < self.max_sequence_length: 156 | phrase = F.pad(phrase, (0, self.max_sequence_length - len(phrase))) 157 | else: 158 | phrase = phrase[-self.max_sequence_length:] 159 | # Attention mask based on non-padded tokens of the phrase 160 | phrase_attention_mask = torch.where(phrase != 0, 1, 0) 161 | phrase_attention_mask = phrase_attention_mask.type(torch.bool) 162 | 163 | train_data = {"input_ids": phrase, "labels": target, "attention_mask": phrase_attention_mask} 164 | 165 | return train_data 166 | 167 | 168 | if __name__ == "__main__": 169 | 170 | # Parse command line arguments 171 | parser = argparse.ArgumentParser() 172 | parser.add_argument("--config", type=str, default=os.path.normpath("configs/configs_os.yaml"), 173 | help="Path to the config file") 174 | args = parser.parse_args() 175 | 176 | # Load config file 177 | with open(args.config, 'r') as f: 178 | configs = yaml.safe_load(f) 179 | 180 | batch_size = configs['training']['phrase_generation']['batch_size'] 181 | 182 | # Artifact folder 183 | artifact_folder = configs['raw_data']['artifact_folder'] 184 | # Load encoder tokenizer json file dictionary 185 | tokenizer_filepath = os.path.join(artifact_folder, "tokenizer.json") 186 | # Load the tokenizer dictionary 187 | with open(tokenizer_filepath, 'r') as f: 188 | tokenizer = json.load(f) 189 | 190 | # Open the train, validation, and test sets json files 191 | with open(os.path.join(artifact_folder, "train_file_list.json"), "r") as f: 192 | train_file_list = json.load(f) 193 | with open(os.path.join(artifact_folder, "valid_file_list.json"), "r") as f: 194 | valid_file_list = json.load(f) 195 | with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f: 196 | test_file_list = json.load(f) 197 | 198 | # Print length of train, validation, and test sets 199 | print("Length of 
train set: ", len(train_file_list)) 200 | print("Length of validation set: ", len(valid_file_list)) 201 | print("Length of test set: ", len(test_file_list)) 202 | 203 | # Load the dataset 204 | dataset = JSONDataset(configs, train_file_list, beats_in_bar=32, mode="train") 205 | # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 206 | 207 | for n, data in enumerate(dataset): 208 | # print shape and type of tensor 209 | print(data["input_ids"].shape, data["input_ids"].dtype) 210 | print(data["labels"].shape, data["labels"].dtype) 211 | if n > 5: 212 | break -------------------------------------------------------------------------------- /phrase_selector/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.cuda import is_available as cuda_available, is_bf16_supported 3 | from torch.backends.mps import is_available as mps_available 4 | import yaml 5 | import json 6 | import os 7 | from torch import Tensor, argmax 8 | from transformers import BertConfig, Trainer, TrainingArguments, EarlyStoppingCallback, BertModel 9 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support 10 | from data_loader import JSONDataset 11 | import argparse 12 | from transformers import AutoModelForSequenceClassification 13 | 14 | 15 | # Parse command line arguments 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--config", type=str, default=os.path.normpath("/homes/kb658/yinyang/configs/configs_os.yaml"), 18 | help="Path to the config file") 19 | args = parser.parse_args() 20 | 21 | # Load config file 22 | with open(args.config, 'r') as f: 23 | configs = yaml.safe_load(f) 24 | 25 | batch_size = configs['training']['phrase_generation']['batch_size'] 26 | learning_rate = configs['training']['phrase_generation']['learning_rate'] 27 | epochs = configs['training']['phrase_generation']['epochs'] 28 | 29 | # Artifact folder 30 | artifact_folder = configs['raw_data']['artifact_folder'] 31 | # Load encoder tokenizer json file dictionary 32 | tokenizer_filepath = os.path.join(artifact_folder, "tokenizer.json") 33 | # Load the tokenizer dictionary 34 | with open(tokenizer_filepath, 'r') as f: 35 | tokenizer = json.load(f) 36 | 37 | 38 | # Open the train, validation, and test sets json files 39 | with open(os.path.join(artifact_folder, "train_file_list.json"), "r") as f: 40 | train_file_list = json.load(f) 41 | with open(os.path.join(artifact_folder, "valid_file_list.json"), "r") as f: 42 | valid_file_list = json.load(f) 43 | with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f: 44 | test_file_list = json.load(f) 45 | 46 | # Print length of train, validation, and test sets 47 | print("Length of train set: ", len(train_file_list)) 48 | print("Length of validation set: ", len(valid_file_list)) 49 | print("Length of test set: ", len(test_file_list)) 50 | 51 | # Load the dataset 52 | train_dataset = JSONDataset(configs, train_file_list, mode="train") 53 | valid_dataset = JSONDataset(configs, valid_file_list, mode="eval") 54 | # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 55 | 56 | # Get the vocab size 57 | vocab_size = len(tokenizer) 58 | # Get the phrase length 59 | train_phrase_length = len(train_dataset.file_number_phrase_number) 60 | 61 | # Create the encoder-decoder model 62 | config_encoder = BertConfig() 63 | config_encoder.vocab_size = vocab_size 64 | config_encoder.max_position_embeddings = 
configs['model']['phrase_selection_model']['max_sequence_length'] 65 | config_encoder.max_length = configs['model']['phrase_selection_model']['max_sequence_length'] 66 | config_encoder.pad_token_id = 0 67 | config_encoder.bos_token_id = tokenizer["BOS"] 68 | config_encoder.eos_token_id = tokenizer["EOS"] 69 | config_encoder.num_hidden_layers = configs['model']['phrase_selection_model']['num_layers'] 70 | config_encoder.num_attention_heads = configs['model']['phrase_selection_model']['num_heads'] 71 | config_encoder.hidden_size = configs['model']['phrase_selection_model']['hidden_size'] 72 | config_encoder.intermediate_size = configs['model']['phrase_selection_model']['intermediate_size'] 73 | config_encoder.num_labels = 2 74 | 75 | model = AutoModelForSequenceClassification.from_config(config_encoder) 76 | 77 | # Print the number of parameters in the model 78 | num_params = sum(p.numel() for p in model.parameters()) 79 | print(f"Number of parameters in the model: {num_params}") 80 | 81 | # Create config for the Trainer 82 | USE_CUDA = cuda_available() 83 | print(f"USE_CUDA: {USE_CUDA}") 84 | if not cuda_available(): 85 | FP16 = FP16_EVAL = BF16 = BF16_EVAL = False 86 | elif is_bf16_supported(): 87 | BF16 = BF16_EVAL = True 88 | FP16 = FP16_EVAL = False 89 | else: 90 | BF16 = BF16_EVAL = False 91 | FP16 = FP16_EVAL = True 92 | USE_MPS = not USE_CUDA and mps_available() 93 | 94 | # Define the accuracy metric function 95 | def compute_metrics(p): 96 | preds = torch.argmax(torch.from_numpy(p.predictions), axis=1) 97 | return { 98 | 'accuracy': accuracy_score(p.label_ids, preds), 99 | 'precision': precision_recall_fscore_support(p.label_ids, preds, average='binary')[0], 100 | 'recall': precision_recall_fscore_support(p.label_ids, preds, average='binary')[1], 101 | 'f1': precision_recall_fscore_support(p.label_ids, preds, average='binary')[2], 102 | } 103 | 104 | # def compute_metrics(eval_pred): 105 | # predictions, labels = eval_pred 106 | # predictions = np.argmax(predictions, axis=1) 107 | # return accuracy.compute(predictions=predictions, references=labels) 108 | 109 | # Define the training arguments 110 | training_args = TrainingArguments( 111 | output_dir=os.path.join(artifact_folder, "phrase_selection"), 112 | per_device_train_batch_size=batch_size, 113 | per_device_eval_batch_size=batch_size, 114 | save_strategy="steps", # "steps" or "epoch" 115 | save_steps=500, 116 | save_total_limit=1, 117 | learning_rate=learning_rate, 118 | weight_decay= configs['training']['phrase_selection']['weight_decay'], 119 | max_steps=int(train_phrase_length//batch_size)*epochs, 120 | evaluation_strategy="steps", 121 | eval_steps=500, 122 | gradient_accumulation_steps=configs['training']['phrase_selection']['gradient_accumulation_steps'], 123 | gradient_checkpointing=True, 124 | optim="adafactor", 125 | seed=444, 126 | logging_strategy="steps", 127 | logging_steps=100, 128 | logging_dir=os.path.join(artifact_folder, "phrase_selection", "logs"), 129 | no_cuda=not USE_CUDA, 130 | fp16=FP16, 131 | fp16_full_eval=FP16_EVAL, 132 | bf16=BF16, 133 | bf16_full_eval=BF16_EVAL, 134 | load_best_model_at_end=True, 135 | metric_for_best_model="eval_loss", 136 | greater_is_better=False, 137 | report_to="tensorboard", 138 | run_name="phrase_selection", 139 | push_to_hub=False 140 | ) 141 | 142 | # Define the Trainer 143 | trainer = Trainer( 144 | model=model, 145 | args=training_args, 146 | train_dataset=train_dataset, 147 | eval_dataset=valid_dataset, 148 | compute_metrics=compute_metrics, 149 | 
callbacks=[EarlyStoppingCallback(early_stopping_patience=5)] 150 | ) 151 | 152 | # Train and save the model 153 | train_result = trainer.train() 154 | trainer.save_model() # Saves the tokenizer too 155 | trainer.log_metrics("train", train_result.metrics) 156 | trainer.save_metrics("train", train_result.metrics) 157 | trainer.save_state() 158 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/keshavbhandari/yinyang/59a2a001884a1f235ef9a0313565487874f6e3b9/preprocess/__init__.py -------------------------------------------------------------------------------- /preprocess/phrase_extraction.py: -------------------------------------------------------------------------------- 1 | from miditok import REMIPlus, TokenizerConfig, TokSequence 2 | from miditoolkit import MidiFile 3 | import glob 4 | import random 5 | import copy 6 | import pickle 7 | import json 8 | import os 9 | from tqdm import tqdm 10 | from pathlib import Path 11 | import yaml 12 | import argparse 13 | import jsonlines 14 | import sys 15 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 16 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 17 | from utils.utils import annotation_to_encoding 18 | 19 | def load_jsonl(file_path): 20 | data = [] 21 | with jsonlines.open(file_path) as reader: 22 | for line in reader: 23 | data.append(line) 24 | return data 25 | 26 | def extract_phrases(encoding, phrase_annotation): 27 | 28 | # Check if length of encoding is equal to length of phrase_endings 29 | if len(encoding) != len(phrase_annotation["features"]["phrase_end"]): 30 | return [] 31 | 32 | # Find indices where phrase_endings is True 33 | true_indices = [i for i, value in enumerate(phrase_annotation["features"]["phrase_end"]) if value] 34 | 35 | # Segment the encoding based on true_indices 36 | segmented_encoding = [] 37 | start_index = 0 38 | for end_index in true_indices: 39 | segment = encoding[start_index:end_index + 1] 40 | segmented_encoding.append(segment) 41 | start_index = end_index + 1 42 | 43 | # Add the remaining part if any 44 | if start_index < len(encoding): 45 | segmented_encoding.append(encoding[start_index:]) 46 | 47 | return segmented_encoding 48 | 49 | def main(configs): 50 | # Make artifact folder if it doesn't exist 51 | artifact_folder = configs["raw_data"]["artifact_folder"] 52 | os.makedirs(artifact_folder, exist_ok=True) 53 | 54 | # mono_folder = configs["raw_data"]["mono_folder"] 55 | # os.makedirs(mono_folder, exist_ok=True) 56 | # # Get all the files in the folder 57 | # mono_files = glob.glob(mono_folder + "/*.mid") 58 | raw_data_folders = configs["raw_data"]["raw_data_folders"] 59 | 60 | # Create a folder to store the json files 61 | json_folder = configs["raw_data"]["json_folder"] 62 | os.makedirs(json_folder, exist_ok=True) 63 | 64 | # Create a dictionary to store the tokenizer 65 | tokenizer_dict = {} 66 | # Initialize midi pitches from 0 to 127 to tokenizer dictionary 67 | for i in range(128): 68 | tokenizer_dict[f"Pitch_{i}"] = len(tokenizer_dict) + 1 69 | 70 | phrase_lengths = [] 71 | # Iterate through each dataset, extract phrases and save them as json files 72 | for dataset_name, dataset_info in raw_data_folders.items(): 73 | 74 | annotation_file = load_jsonl(dataset_info.get('annotation_filepath')) 75 | 76 | # Load the midi files that are in phrase annotations 77 | for phrase_annotation in tqdm(annotation_file): 78 | 
midi_filepath = os.path.join(json_folder, phrase_annotation["id"]) 79 | 80 | encoding, time_signature, key_signature, major_or_minor = annotation_to_encoding(phrase_annotation) 81 | if len(encoding) == 0: 82 | print(f"Skipping {midi_filepath} as encoding is most likely corrupt") 83 | continue 84 | 85 | # Create a dictionary to store the phrases 86 | phrase_dict = {'metadata': {'tempo': {0: 120}, 87 | 'time_signature': f"TimeSig_{time_signature}", 88 | 'key_signature': f"KS_{key_signature}", 89 | 'major_or_minor': f"MM_{major_or_minor}"}, 90 | 'phrases': {}} 91 | 92 | # Extract the phrases from the encoding 93 | phrases = extract_phrases(encoding, phrase_annotation) 94 | 95 | if len(phrases) <= 1: 96 | print(f"Skipping {midi_filepath} as there is only one phrase") 97 | continue 98 | 99 | for n, phrase in enumerate(phrases): 100 | phrase_info = [phrase] 101 | 102 | # Phrase position: beginning, middle or end relative to the total number of phrases 103 | phrase_position = "PP_middle" if n > 0 and n < len(phrases) - 1 else "PP_beginning" if n == 0 else "PP_end" 104 | phrase_info.append(phrase_position) 105 | 106 | # Check if last note duration is greater than equal to minim and last note pitch is lower than second last note pitch 107 | if len(phrase) > 1: 108 | if phrase[-1][4] >= 2 and phrase[-1][3] < phrase[-2][3]: 109 | # Add True to phrase_info 110 | phrase_info.append("CA_True") 111 | else: 112 | # Add False to phrase_info 113 | phrase_info.append("CA_False") 114 | else: 115 | # Add False to phrase_info 116 | phrase_info.append("CA_False") 117 | 118 | # Pitch range of the phrase 119 | pitch_range = max([note[3] for note in phrase]) - min([note[3] for note in phrase]) 120 | if f"PR_{pitch_range}" not in tokenizer_dict.keys(): 121 | tokenizer_dict[f"PR_{pitch_range}"] = len(tokenizer_dict) + 1 122 | phrase_info.append(f"PR_{pitch_range}") 123 | 124 | # Phrase length 125 | phrase_length = len(phrase) 126 | if f"PL_{phrase_length}" not in tokenizer_dict.keys(): 127 | tokenizer_dict[f"PL_{phrase_length}"] = len(tokenizer_dict) + 1 128 | phrase_info.append(f"PL_{phrase_length}") 129 | 130 | # Add phrase info to phrase dictionary 131 | phrase_dict['phrases'][n] = phrase_info 132 | 133 | # Add durations and onsets of each note in the phrase to tokenizer dictionary if it doesn't exist 134 | for note in phrase: 135 | note_duration = round(note[4], 2) 136 | note_onset = round(note[1], 2) 137 | if f"Duration_{note_duration}" not in tokenizer_dict.keys(): 138 | tokenizer_dict[f"Duration_{note_duration}"] = len(tokenizer_dict) + 1 139 | if f"Position_{note_onset}" not in tokenizer_dict.keys(): 140 | tokenizer_dict[f"Position_{note_onset}"] = len(tokenizer_dict) + 1 141 | 142 | # Add time signature to tokenizer dictionary if it doesn't exist 143 | if f"TimeSig_{time_signature}" not in tokenizer_dict.keys(): 144 | tokenizer_dict[f"TimeSig_{time_signature}"] = len(tokenizer_dict) + 1 145 | # Add key signature to tokenizer dictionary if it doesn't exist 146 | if f"KS_{key_signature}" not in tokenizer_dict.keys(): 147 | tokenizer_dict[f"KS_{key_signature}"] = len(tokenizer_dict) + 1 148 | # Add major or minor to tokenizer dictionary if it doesn't exist 149 | if f"MM_{major_or_minor}" not in tokenizer_dict.keys(): 150 | tokenizer_dict[f"MM_{major_or_minor}"] = len(tokenizer_dict) + 1 151 | 152 | phrase_lengths.append(len(phrase)) 153 | 154 | # Write phrases as a json file 155 | midi_file = Path(f"{midi_filepath}.json") 156 | with open(midi_file, "w") as f: 157 | json.dump(phrase_dict, f) 158 | 159 | # Add 
special tokens to tokenizer dictionary 160 | tokenizer_dict["Bar_None"] = len(tokenizer_dict) + 1 161 | # Add phrase position tokens to tokenizer dictionary 162 | tokenizer_dict["PP_beginning"] = len(tokenizer_dict) + 1 163 | tokenizer_dict["PP_middle"] = len(tokenizer_dict) + 1 164 | tokenizer_dict["PP_end"] = len(tokenizer_dict) + 1 165 | # Add cadence tokens to tokenizer dictionary 166 | tokenizer_dict["CA_True"] = len(tokenizer_dict) + 1 167 | tokenizer_dict["CA_False"] = len(tokenizer_dict) + 1 168 | # Add special tokens to tokenizer dictionary 169 | tokenizer_dict["BOS"] = len(tokenizer_dict) + 1 170 | tokenizer_dict["EOS"] = len(tokenizer_dict) + 1 171 | tokenizer_dict["SEP"] = len(tokenizer_dict) + 1 172 | # Add corruption tokens to tokenizer dictionary 173 | tokenizer_dict["COR_incorrect_transposition"] = len(tokenizer_dict) + 1 174 | tokenizer_dict["COR_incorrect_inversion"] = len(tokenizer_dict) + 1 175 | tokenizer_dict["COR_note_swapping"] = len(tokenizer_dict) + 1 176 | tokenizer_dict["COR_melodic_stripping"] = len(tokenizer_dict) + 1 177 | tokenizer_dict["COR_melodic_addition"] = len(tokenizer_dict) + 1 178 | tokenizer_dict["COR_same_note_modification"] = len(tokenizer_dict) + 1 179 | tokenizer_dict["COR_permute_note_pitch"] = len(tokenizer_dict) + 1 180 | tokenizer_dict["COR_permute_note_duration"] = len(tokenizer_dict) + 1 181 | tokenizer_dict["COR_permute_note_pitch_duration"] = len(tokenizer_dict) + 1 182 | tokenizer_dict["COR_BAR_MASK"] = len(tokenizer_dict) + 1 183 | tokenizer_dict["COR_PITCH_MASK"] = len(tokenizer_dict) + 1 184 | tokenizer_dict["COR_DURATION_MASK"] = len(tokenizer_dict) + 1 185 | tokenizer_dict["COR_FRAGMENT_NOTES"] = len(tokenizer_dict) + 1 186 | # Add special tokens to tokenizer dictionary 187 | tokenizer_dict["UNK"] = len(tokenizer_dict) + 1 188 | 189 | # Save the encoder tokenizer dictionary as a json file 190 | encoder_tokenizer_filepath = os.path.join(artifact_folder, "tokenizer.json") 191 | with open(encoder_tokenizer_filepath, "w") as f: 192 | json.dump(tokenizer_dict, f) 193 | 194 | return phrase_lengths 195 | 196 | 197 | 198 | if __name__ == "__main__": 199 | # Parse command line arguments 200 | parser = argparse.ArgumentParser() 201 | parser.add_argument("--config", type=str, default=os.path.normpath("configs/configs_os.yaml"), 202 | help="Path to the config file") 203 | args = parser.parse_args() 204 | 205 | # Load config file 206 | with open(args.config, 'r') as f: 207 | configs = yaml.safe_load(f) 208 | 209 | phrase_lengths = main(configs) 210 | 211 | # Describe the phrase lengths 212 | max_phrase_length = max(phrase_lengths) 213 | min_phrase_length = min(phrase_lengths) 214 | average_phrase_length = sum(phrase_lengths) / len(phrase_lengths) 215 | no_of_phrases = len(phrase_lengths) 216 | print(f"Number of phrases: {no_of_phrases}, Max phrase length: {max_phrase_length}, Min phrase length: {min_phrase_length}, Average phrase length: {average_phrase_length}") 217 | 218 | # Save the phrase lengths list in a json file in the artifact folder 219 | phrase_lengths_filepath = os.path.join(configs["raw_data"]["artifact_folder"], "phrase_lengths.json") 220 | 221 | with open(phrase_lengths_filepath, "w") as f: 222 | json.dump(phrase_lengths, f) 223 | print(f"Phrase lengths saved to {phrase_lengths_filepath}") 224 | 225 | -------------------------------------------------------------------------------- /preprocess/preprocess_mono.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import tqdm 3 | 
import os
4 | import argparse
5 | import yaml
6 | from music21 import harmony, pitch, converter, stream, chord, note, instrument
7 | 
8 | class Preprocess:
9 |     def __init__(self, folderpath, output_folderpath):
10 |         self.folderpath = folderpath
11 |         self.output_folderpath = output_folderpath
12 |         # Create the output folder if it does not exist
13 |         # if not os.path.exists(self.output_folderpath):
14 |         #     os.makedirs(self.output_folderpath)
15 | 
16 |     def process_folder(self, file_extension = "mxl"):
17 |         files = glob.glob(os.path.join(self.folderpath, f"**/*.{file_extension}"), recursive=True)
18 |         print("Number of files: ", len(files))
19 |         for file in tqdm.tqdm(files):
20 |             self.process_file(file, file_extension)
21 | 
22 |     def process_file(self, filepath, file_extension = "mxl"):
23 |         try:
24 |             score = converter.parse(filepath)
25 |         except (harmony.ChordStepModificationException, pitch.PitchException) as e:
26 |             print(f"Skipping file {os.path.basename(filepath)} due to chord parsing error")
27 |             return
28 |         score = self.change_instrument(score)
29 |         mono_stream = self.get_mono_stream(score)
30 |         self.write_midi(mono_stream, filepath, file_extension)
31 | 
32 |     def change_instrument(self, score):
33 |         # Change the instrument to piano
34 |         for part in score.parts:
35 |             if part.getInstrument().midiProgram != 0:
36 |                 part.insert(1, instrument.Piano())
37 |                 part.insert(0, instrument.Piano())
38 |         return score
39 | 
40 |     def get_mono_stream(self, score):
41 |         # Create an empty stream to store the filtered notes
42 |         filtered_stream = stream.Score()
43 | 
44 |         # Iterate through each element in the score
45 |         for element in score.recurse():
46 |             # Check if the element is a Note
47 |             if 'Note' in element.classes:
48 |                 # # Check if the Note is in the treble clef (G-clef) range
49 |                 # if element.activeSite.clef.name == 'treble':
50 |                 # Add the Note to the filtered stream
51 |                 filtered_stream.append(element)
52 | 
53 |             # Check if the element is a Chord
54 |             elif 'Chord' in element.classes and 'Harmony' not in element.classes:
55 |                 # Get the highest pitch in the chord
56 |                 highest_pitch = max(n.midi for n in element.pitches)
57 | 
58 |                 # Create a new Note with the highest pitch and add it to the filtered stream
59 |                 filtered_stream.append(note.Note(highest_pitch))
60 | 
61 |             # Add rest too
62 |             elif 'Rest' in element.classes:
63 |                 filtered_stream.append(element)
64 | 
65 |         return filtered_stream
66 | 
67 |     def write_midi(self, mono_stream, filepath, file_extension):
68 |         filename = os.path.basename(filepath).split(f".{file_extension}")[0] + "_mono.mid"
69 |         output_filepath = os.path.join(self.output_folderpath, filename)
70 |         mono_stream.write('midi', fp=output_filepath)
71 | 
72 | 
73 | if __name__ == "__main__":
74 |     # Parse command line arguments
75 |     parser = argparse.ArgumentParser()
76 |     parser.add_argument("--config", type=str, default=os.path.normpath("/homes/kb658/yinyang/configs/configs_os.yaml"),
77 |                         help="Path to the config file")
78 |     args = parser.parse_args()
79 | 
80 |     # Load config file
81 |     with open(args.config, 'r') as f:
82 |         configs = yaml.safe_load(f)
83 | 
84 |     raw_data_folders = configs["raw_data"]["raw_data_folders"]
85 |     output_folder = configs["raw_data"]["mono_folder"]
86 |     # Create output folder if it does not exist
87 |     if not os.path.exists(output_folder):
88 |         os.makedirs(output_folder)
89 | 
90 |     for dataset_name, dataset_info in raw_data_folders.items():
91 |         print(f"Dataset: {dataset_name}")
92 |         print(f"Path: {dataset_info.get('folder_path')}")
93 |         print(f"Type:
{dataset_info.get('file_extension')}")
94 | 
95 |         folderpath = dataset_info.get('folder_path')
96 |         preprocessor = Preprocess(folderpath, output_folder)
97 |         preprocessor.process_folder(file_extension = dataset_info.get('file_extension'))
98 |         print(f"Processed all files in {dataset_name}")
--------------------------------------------------------------------------------
/preprocess/write_midi_files.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import yaml
5 | import copy
6 | import argparse
7 | import glob
8 | from tqdm import tqdm
9 | import sys
10 | import jsonlines
11 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
12 | sys.path.append(os.path.dirname(SCRIPT_DIR))
13 | from utils.utils import annotation_to_encoding, encoding_to_midi
14 | 
15 | # Parse command line arguments
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--config", type=str, default=os.path.normpath("configs/configs_os.yaml"),
18 |                     help="Path to the config file")
19 | args = parser.parse_args()
20 | 
21 | # Load config file
22 | with open(args.config, 'r') as f:
23 |     configs = yaml.safe_load(f)
24 | 
25 | # Get the raw data folders
26 | artifact_folder = configs["raw_data"]["artifact_folder"]
27 | mono_folder = configs["raw_data"]["mono_folder"]
28 | json_folder = configs["raw_data"]["json_folder"]
29 | raw_data_folders = configs["raw_data"]["raw_data_folders"]
30 | mono_folder = configs["raw_data"]["mono_folder"]
31 | os.makedirs(mono_folder, exist_ok=True)
32 | 
33 | # Get all files from json_folder
34 | json_files = glob.glob(json_folder + "/*.json")
35 | 
36 | def load_jsonl(file_path):
37 |     data = []
38 |     with jsonlines.open(file_path) as reader:
39 |         for line in reader:
40 |             data.append(line)
41 |     return data
42 | 
43 | def extract_phrases(encoding, phrase_annotation):
44 | 
45 |     # Check if length of encoding is equal to length of phrase_endings
46 |     if len(encoding) != len(phrase_annotation["features"]["phrase_end"]):
47 |         return []
48 | 
49 |     # Find indices where phrase_endings is True
50 |     true_indices = [i for i, value in enumerate(phrase_annotation["features"]["phrase_end"]) if value]
51 | 
52 |     # Segment the encoding based on true_indices
53 |     segmented_encoding = []
54 |     start_index = 0
55 |     for end_index in true_indices:
56 |         segment = encoding[start_index:end_index + 1]
57 |         segmented_encoding.append(segment)
58 |         start_index = end_index + 1
59 | 
60 |     # Add the remaining part if any
61 |     if start_index < len(encoding):
62 |         segmented_encoding.append(encoding[start_index:])
63 | 
64 |     return segmented_encoding
65 | 
66 | for dataset_name, dataset_info in raw_data_folders.items():
67 | 
68 |     annotation_file = load_jsonl(dataset_info.get('annotation_filepath'))
69 | 
70 |     # Load the midi files that are in phrase annotations
71 |     for phrase_annotation in tqdm(annotation_file):
72 | 
73 |         midi_filepath = os.path.join(mono_folder, phrase_annotation["id"] + '.mid')
74 | 
75 |         encoding, time_signature, key_signature, major_or_minor = annotation_to_encoding(phrase_annotation)
76 |         if len(encoding) == 0:
77 |             print(f"Skipping {phrase_annotation['id']} as encoding is most likely corrupt")
78 |             continue
79 | 
80 |         # Extract the phrases from the encoding
81 |         phrases = extract_phrases(encoding, phrase_annotation)
82 | 
83 |         if len(phrases) <= 1:
84 |             print(f"Skipping {midi_filepath} as there is only one phrase")
85 |             continue
86 | 
87 |         encoding_to_midi(encoding, {0: 120}, f"TimeSig_{time_signature}", midi_filepath)
88 | 
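89 | # Example invocation (usage sketch, assuming the default config layout shipped with the repo):
90 | #   python preprocess/write_midi_files.py --config configs/configs_os.yaml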
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Yin-Yang: Developing Motifs With Long-Term Structure and Controllability
2 | 
3 | 🌐 [**Demo Website**](https://keshavbhandari.github.io/portfolio/yin-yang.html)
4 | 
5 | 📄 [**ArXiv Paper**](https://arxiv.org/abs/2501.17759)
6 | 
7 | 🚀 [**Run in Colab**](https://colab.research.google.com/drive/1qsS9pX4grGVVLk4N5W19DZRk9j2ftLN9#scrollTo=byBfZDvFZ5cJ)
8 | 
9 | ![Corruption Refinement Training](images/Corruption_Refinement_Training.png)
10 | ![Generation Framework](images/YY_Generation_Framework.png)
11 | ---
12 | 
13 | ## Overview
14 | 
15 | Yin-Yang is a novel framework for generating music that maintains long-term structure and motivic development. By leveraging a phrase generator, refiner, and selector model, it achieves the following:
16 | 
17 | - **Coherent generation** of melodies with long-term structure.
18 | - **Controllability** of musical structure and motivic transformations.
19 | - **Semi-interpretable outputs** for musical analysis.
20 | 
21 | If you're curious about the research behind this framework, check out the paper:
22 | **Bhandari, K., Wiggins, G. A., & Colton, S. (2025, March). Yin-Yang: Developing Motifs With Long-Term Structure And Controllability.**
23 | *International Conference on Computational Intelligence in Music, Sound, Art and Design (Part of EvoStar)*.
24 | 
25 | ---
26 | 
27 | ## Setup Instructions
28 | 
29 | ### 1. Clone the repository and install dependencies
30 | ```bash
31 | !git clone https://github.com/keshavbhandari/yinyang.git
32 | %cd yinyang
33 | %pip install -r requirements.txt
34 | ```
35 | 
36 | ### 2. Download data and model artifacts
37 | ```python
38 | import gdown
39 | 
40 | # Download dataset
41 | data_url = 'https://drive.google.com/uc?id=1DhtQV0-jVH1lOreqXY5V8L1qHOSO5PwZ'
42 | data_out = '/content/yinyang/data.zip'
43 | gdown.download(data_url, data_out, quiet=False)
44 | 
45 | # Download model artifacts
46 | artifacts_url = 'https://drive.google.com/uc?id=1cMGRjUonP3qoKtHA_gqc39GZA4O_zdbh'
47 | artifacts_out = '/content/yinyang/artifacts.zip'
48 | gdown.download(artifacts_url, artifacts_out, quiet=False)
49 | 
50 | # Unzip files
51 | !unzip -q /content/yinyang/artifacts.zip -d /content/yinyang/
52 | !unzip -q /content/yinyang/data.zip -d /content/yinyang/
53 | ```
54 | 
55 | ### 3. Generate a melody from a motif
56 | Edit the generation configs in configs/configs_os.yaml. Use the provided examples to set the sectional structure (e.g., ABC), the number of phrases per section (e.g., [4, 5, 4]), whether to use the phrase selection model, optional motivic transformations, and key modulations; a sketch of such a block is shown after the command below. Then run the following to generate a melody from a phrase containing the motif:
57 | 
58 | ```bash
59 | !python generate.py --config configs/configs_os.yaml
60 | ```
61 | 
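For orientation, the sketch below shows what such a generation block *might* look like. Every key name in it is an illustrative assumption, not the actual schema; consult the examples in configs/configs_os.yaml for the real field names:

```yaml
# Illustrative sketch only -- each key name here is an assumption,
# not the actual schema of configs_os.yaml.
generation:
  structure: "ABC"                 # sectional form of the piece
  phrases_per_section: [4, 5, 4]   # number of phrases in each section
  use_phrase_selection: true       # rank candidate phrases with the selector model
  transformations: []              # optional motivic transformations
  key_modulation: true             # allow key changes between sections
```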
62 | ### 4. Recreate experiments by training models from scratch
63 | To train individual models, use the following commands:
64 | 
65 | - Train the phrase refiner model:
66 | ```bash
67 | !python phrase_refiner/train.py --config configs/configs_os.yaml
68 | ```
69 | - Train the phrase generator model:
70 | ```bash
71 | !python phrase_generator/train.py --config configs/configs_os.yaml
72 | ```
73 | - Train the phrase selector model:
74 | ```bash
75 | !python phrase_selector/train.py --config configs/configs_os.yaml
76 | ```
77 | - Train the structure derivation model:
78 | ```bash
79 | !python structure_derivation/train.py --config configs/configs_os.yaml
80 | ```
81 | 
82 | ## Citation
83 | 
84 | If you use this repository in your work, please cite:
85 | 
86 | ```plaintext
87 | @inproceedings{bhandari2025yin,
88 |   title={Yin-Yang: Developing Motifs With Long-Term Structure and Controllability},
89 |   author={Bhandari, Keshav and Wiggins, Geraint A. and Colton, Simon},
90 |   booktitle={International Conference on Computational Intelligence in Music, Sound, Art and Design (Part of EvoStar)},
91 |   year={2025}
92 | }
93 | ```
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | axial_positional_embedding==0.2.1
2 | einops==0.7.0
3 | evaluate==0.4.1
4 | jsonlines==4.0.0
5 | linformer==0.2.3
6 | local_attention==1.9.0
7 | miditoolkit==0.1.16
8 | music21==9.1.0
9 | numpy==1.26.4
10 | product_key_memory==0.2.10
11 | PyYAML==6.0.1
12 | scikit_learn==1.3.0
13 | torch==2.2.0
14 | tqdm==4.66.2
15 | transformers==4.34.1
16 | vendi_score==0.0.3
--------------------------------------------------------------------------------
/structure_derivation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keshavbhandari/yinyang/59a2a001884a1f235ef9a0313565487874f6e3b9/structure_derivation/__init__.py
--------------------------------------------------------------------------------
/structure_derivation/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import yaml
5 | import copy
6 | import argparse
7 | from torch.utils.data import Dataset
8 | import torch
9 | from torch.nn import functional as F
10 | import sys
11 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
12 | sys.path.append(os.path.dirname(SCRIPT_DIR))
13 | from utils.utils import list_to_remi_encoding, find_beats_in_bar
14 | from phrase_refiner.transformations import Melodic_Development
15 | 
16 | 
17 | class JSONDataset(Dataset):
18 |     def __init__(self, configs, file_list, mode="train", shuffle=False):
19 |         self.mode = mode
20 |         # Data dir
21 |         self.data_dir = configs['raw_data']['json_folder']
22 |         self.file_list = file_list
23 |         if shuffle:
24 |             random.shuffle(self.file_list)
25 |         # Get number of phrases in each file and store in list as [file_name, phrase_number_{n}]
26 |         self.file_number_phrase_number = []
27 |         for file_path in self.file_list:
28 |             file_path = os.path.join(self.data_dir, file_path)
29 |             with open(file_path, 'r') as f:
30 |                 data = json.load(f)
31 |             phrase_number = len(data["phrases"].keys())
32 |             # Exclude the last phrase as this will be target
33 |             for i in range(phrase_number-1):
34 |                 self.file_number_phrase_number.append([file_path, i])
35 | 
36 |         # Length of file_number_phrase_number
37 |         self.length_phrases = len(self.file_number_phrase_number)
38 |         # Length of file_list
39 |         self.length_files = len(self.file_list)
40 | 
41 |         # Print length of dataset
42 |         print("Length of dataset: ", self.length_files)
43 |         print("Length of phrases in dataset: ", self.length_phrases)
44 | 
45 |         # Artifact folder
46 |         self.artifact_folder = configs['raw_data']['artifact_folder']
47 |         # Load encoder tokenizer json file dictionary
48 |         tokenizer_filepath = os.path.join(self.artifact_folder, "tokenizer.json")
49 |         # Load the tokenizer dictionary
50 |         with open(tokenizer_filepath, 'r') as f:
51 |             self.tokenizer = json.load(f)
52 | 
53 |         # Get the maximum sequence length
54 |         self.max_sequence_length = configs['model']['structure_derivation_model']['max_sequence_length']
55 | 
56 |     def __len__(self):
57 |         return self.length_files
58 | 
59 |     def transpose(self, phrase, pitch_change):
60 |         encoding = copy.deepcopy(phrase)
61 | 
62 |         transposed_encoding = [
63 |             [event[0], event[1], event[2], event[3] + pitch_change, *event[4:]]
64 |             for event in encoding
65 |         ]
66 | 
67 |         return transposed_encoding
68 | 
69 |     def augment_phrase(self, phrase_1):
70 |         if random.random() < 0.5:
71 | 
72 | 
73 |             encoding = copy.deepcopy(phrase_1)
74 | 
75 |             # Find highest and lowest pitch values
76 |             pitch_values = [event[3] for event in encoding]
77 |             highest_pitch = max(pitch_values)
78 |             lowest_pitch = min(pitch_values)
79 |             # Choose a random pitch change value but ensure it is not 0 and within the midi pitch range of 0 to 127
80 |             pitch_change = random.choice([i for i in range(-12,12) if i not in [0]])
81 |             while highest_pitch + pitch_change > 127 or lowest_pitch + pitch_change < 0:
82 |                 if pitch_change < 0:
83 |                     pitch_change += 1
84 |                 else:
85 |                     pitch_change -= 1
86 | 
87 |             phrase_1 = self.transpose(phrase_1, pitch_change)
88 | 
89 |         return phrase_1
90 | 
91 |     def reindex_phrase(self, phrase_1, phrase_2, beats_in_bar):
92 |         self.melodic_development_obj = Melodic_Development(beats_in_bar=beats_in_bar)
93 |         phrase_2 = self.melodic_development_obj.reindex_bars(phrase_2, start_bar=phrase_1[-1][0]+1)
94 | 
95 |         return phrase_2
96 | 
97 |     def __getitem__(self, idx):
98 |         file_path = self.file_list[idx]
99 |         # Get max phrases in file
100 |         with open(os.path.join(self.data_dir, file_path), 'r') as f:
101 |             data = json.load(f)
102 | 
103 |         time_signature = data["metadata"]["time_signature"]
104 |         # Get key and major or minor for phrase
105 |         key_signature = data["metadata"]["key_signature"]
106 |         major_or_minor = data["metadata"]["major_or_minor"]
107 | 
108 |         # Choose a random phrase number
109 |         total_phrases = len(data["phrases"].keys())
110 |         phrase_number = random.randint(0, total_phrases-1)
111 | 
112 |         phrase_1 = data["phrases"][str(phrase_number)][0]
113 |         if random.random() < 0.5:
114 |             # Get the phrase and the target from the same file as a positive sample
115 |             all_phrases = list(range(0, total_phrases))
116 |             all_phrases.remove(phrase_number)
117 |             phrase_number = random.choice(all_phrases)
118 |             phrase_2 = data["phrases"][str(phrase_number)][0]
119 |             # Reindex phrase_2 to match the last bar + 1 of phrase_1
120 |             beats_in_bar = find_beats_in_bar(time_signature)
121 |             phrase_2 = self.reindex_phrase(phrase_1, phrase_2, beats_in_bar)
122 |             target = torch.tensor(1)
123 |         else:
124 |             # Choose a random file from self.file_list that is not idx
125 |             random_file = random.choice([i for i in range(self.length_files) if i != idx])
126 |             random_file_path = self.file_list[random_file]
127 |             with open(os.path.join(self.data_dir, random_file_path), 'r') as f:
128 |                 random_data = json.load(f)
129 |             # Choose a random phrase from the random file as a negative sample
130 |             random_phrase_number = random.randint(0, len(random_data["phrases"].keys())-1)
131 |             phrase_2 = random_data["phrases"][str(random_phrase_number)][0]
132 |             # Reindex phrase_2 to match the last bar + 1 of phrase_1
133 |             beats_in_bar = find_beats_in_bar(time_signature)
134 |             phrase_2 = self.reindex_phrase(phrase_1, phrase_2, beats_in_bar)
135 |             key_signature = random_data["metadata"]["key_signature"]
136 |             major_or_minor = random_data["metadata"]["major_or_minor"]
137 |             target = torch.tensor(0)
138 | 
139 |         # Augment the phrases
140 |         if self.mode == "train":
141 |             phrase_1 = self.augment_phrase(phrase_1)
142 |             phrase_2 = self.augment_phrase(phrase_2)
143 |             # phrase_1, phrase_2 = self.augment_phrase(phrase_1, phrase_2)
144 | 
145 |         tempo_location = data["metadata"]["tempo"]
146 | 
147 |         # Add phrase 1 to phrase 2
148 |         phrase = phrase_1 + ["SEP"] + phrase_2
149 |         # List to remi encoding
150 |         phrase = list_to_remi_encoding(phrase, tempo_location, time_signature)
151 |         # Add the BOS and EOS tokens to the phrase
152 |         phrase = ["BOS"] + phrase + ["EOS"]
153 |         # Tokenize the phrase
154 |         phrase = [self.tokenizer[note] for note in phrase if note in self.tokenizer]
155 | 
156 |         # Convert to tensor and pad the phrase to a fixed length of max_sequence_length if the phrase is shorter than max_sequence_length
157 |         phrase = torch.tensor(phrase)
158 |         if len(phrase) < self.max_sequence_length:
159 |             phrase = F.pad(phrase, (0, self.max_sequence_length - len(phrase)))
160 |         else:
161 |             phrase = phrase[-self.max_sequence_length:]
162 |         # Attention mask based on non-padded tokens of the phrase
163 |         phrase_attention_mask = torch.where(phrase != 0, 1, 0)
164 |         phrase_attention_mask = phrase_attention_mask.type(torch.bool)
165 | 
166 |         train_data = {"input_ids": phrase, "labels": target, "attention_mask": phrase_attention_mask}
167 | 
168 |         return train_data
169 | 
170 | 
171 | if __name__ == "__main__":
172 | 
173 |     # Parse command line arguments
174 |     parser = argparse.ArgumentParser()
175 |     parser.add_argument("--config", type=str, default=os.path.normpath("configs/configs_os.yaml"),
176 |                         help="Path to the config file")
177 |     args = parser.parse_args()
178 | 
179 |     # Load config file
180 |     with open(args.config, 'r') as f:
181 |         configs = yaml.safe_load(f)
182 | 
183 |     batch_size = configs['training']['structure_derivation']['batch_size']
184 | 
185 |     # Artifact folder
186 |     artifact_folder = configs['raw_data']['artifact_folder']
187 |     # Load encoder tokenizer json file dictionary
188 |     tokenizer_filepath = os.path.join(artifact_folder, "tokenizer.json")
189 |     # Load the tokenizer dictionary
190 |     with open(tokenizer_filepath, 'r') as f:
191 |         tokenizer = json.load(f)
192 | 
193 |     # Open the train, validation, and test sets json files
194 |     with open(os.path.join(artifact_folder, "train_file_list.json"), "r") as f:
195 |         train_file_list = json.load(f)
196 |     with open(os.path.join(artifact_folder, "valid_file_list.json"), "r") as f:
197 |         valid_file_list = json.load(f)
198 |     with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f:
199 |         test_file_list = json.load(f)
200 | 
201 |     # Print length of train, validation, and test sets
202 |     print("Length of train set: ", len(train_file_list))
203 |     print("Length of validation set: ", len(valid_file_list))
204 |     print("Length of test set: ", len(test_file_list))
205 | 
206 |     # Load the dataset
207 |     dataset = JSONDataset(configs, train_file_list, mode="train")
208 |     # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
209 | 
210 |     for n, data in enumerate(dataset):
211 |         # print shape and type of tensor
212 |         print(data["input_ids"].shape, data["input_ids"].dtype)
213 |         print(data["labels"].shape, data["labels"].dtype)
214 |         if n > 5:
215 |             break
--------------------------------------------------------------------------------
/structure_derivation/pitch_evaluator.py:
--------------------------------------------------------------------------------
1 | from transformers import EncoderDecoderModel, AutoModelForSequenceClassification
2 | import torch
3 | from torch.cuda import is_available as cuda_available, is_bf16_supported
4 | import torch.nn.functional as F
5 | import pickle
6 | import yaml
7 | import json
8 | import os
9 | import argparse
10 | import random
11 | import copy
12 | import numpy as np
13 | from vendi_score import vendi
14 | from sklearn.metrics.pairwise import cosine_similarity
15 | import sys
16 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
17 | sys.path.append(os.path.dirname(SCRIPT_DIR))
18 | from utils.utils import parse_midi, list_to_remi_encoding, encoding_to_midi, string_to_list_encoding, find_beats_in_bar
19 | from phrase_refiner.transformations import Melodic_Development
20 | 
21 | # Parse command line arguments
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--config", type=str, default=os.path.normpath("/homes/kb658/yinyang/configs/configs_os.yaml"),
24 |                     help="Path to the config file")
25 | args = parser.parse_args()
26 | 
27 | # Load config file
28 | with open(args.config, 'r') as f:
29 |     configs = yaml.safe_load(f)
30 | 
31 | max_sequence_length = configs['model']['structure_derivation_model']['max_sequence_length']
32 | 
33 | # Artifact folder
34 | artifact_folder = configs['raw_data']['artifact_folder']
35 | 
36 | # Load tokenizer json file dictionary
37 | dec_tokenizer_filepath = os.path.join(artifact_folder, 'tokenizer.json')
38 | with open(dec_tokenizer_filepath, 'r') as f:
39 |     dec_tokenizer = json.load(f)
40 | reverse_dec_tokenizer = {str(v): k for k, v in dec_tokenizer.items()}
41 | 
42 | # # Load the phrase similarity model
43 | # phrase_similarity_model = AutoModelForSequenceClassification.from_pretrained(os.path.join(artifact_folder, "phrase_similarity"))
44 | # phrase_similarity_model.eval()
45 | # phrase_similarity_model.to("cuda" if cuda_available() else "cpu")
46 | 
47 | def reindex_phrase(phrase_1, phrase_2, beats_in_bar):
48 |     melodic_development_obj = Melodic_Development(beats_in_bar=beats_in_bar)
49 |     phrase_2 = melodic_development_obj.reindex_bars(phrase_2, start_bar=phrase_1[-1][0]+1)
50 | 
51 |     return phrase_2
52 | 
53 | # Test folder name
54 | test_folder = "/homes/kb658/PhraseBuilder/output/yin_yang_ablated_low"
55 | # test_folder = "/homes/kb658/PhraseBuilder/data/Mono_Midi_Files"
56 | run_test = False
57 | 
58 | if run_test:
59 |     # Load test file list
60 |     with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f:
61 |         test_file_list = json.load(f)
62 |     # Get all the files in the test folder
63 |     test_files = [os.path.basename(file) for file in test_file_list]
64 |     # Convert .json to .mid
65 |     test_files = [file.replace(".json", ".mid") for file in test_files]
66 | else:
67 |     # Get all the files in the test folder
68 |     test_files = os.listdir(test_folder)
69 | 
70 | avg_pr_per_song = []
71 | avg_pitches_per_song = []
72 | # Loop through all the test files
73 | for i, test_file in enumerate(test_files):
74 | 
75 |     # # Load a random test file (midi file)
76 |     # test_file = random.choice(test_files)
77 |     print("Test file: ", test_file)
78 |     test_file_path = os.path.join(test_folder, test_file)
79 | 
80 |     # Load the midi file
81 |     midi_data, time_signature = parse_midi(test_file_path)
82 |     beats_per_bar = find_beats_in_bar(time_signature)
83 | 
84 |     melodic_development_obj = Melodic_Development(beats_in_bar=beats_per_bar)
85 |     # Chop the midi data into bars
86 |     bars = melodic_development_obj.group_by_bar(midi_data)
87 |     total_bars = len(bars)
88 | 
89 |     # Load the same test file from extracted_phrases folder within data and extract the first phrase
90 |     extracted_phrases_folder = "/homes/kb658/PhraseBuilder/data/extracted_phrases"
91 |     extracted_phrases_file = os.path.join(extracted_phrases_folder, test_file.split(".")[0] + ".json")
92 |     # Load the extracted phrases
93 |     with open(extracted_phrases_file, 'r') as f:
94 |         extracted_phrases = json.load(f)
95 | 
96 |     first_phrase = extracted_phrases["phrases"]['0'][0]
97 |     # Get number of bars in the first phrase
98 |     num_bars = first_phrase[-1][0] + 1
99 |     increment = num_bars
100 | 
101 |     # Get the bars of phrase 1
102 |     phrase_1 = bars[0:num_bars]
103 |     # Flatten the phrase
104 |     phrase_1 = [note for bar in phrase_1 for note in bar]
105 |     print("Phrase 1: ", phrase_1)
106 | 
107 |     pitch_ranges, unique_pitches = [], set()
108 |     while True:
109 |         # Get the next num_bars bars as phrase 2
110 |         if increment + num_bars <= total_bars:
111 |             phrase_2 = bars[increment:increment+num_bars]
112 |             increment += num_bars
113 |         else:
114 |             break
115 |         # Flatten the phrase
116 |         phrase_2 = [note for bar in phrase_2 for note in bar]
117 | 
118 |         # Calculate average pitch range of the song
119 |         pitches = []
120 |         for note in phrase_2:
121 |             pitches.append(note[3])
122 |         pitch_range = max(pitches) - min(pitches)
123 |         pitch_ranges.append(pitch_range)
124 | 
125 |         # Accumulate the unique pitches seen across the song
126 |         unique_pitches.update(pitches)
127 | 
128 |     if len(pitch_ranges) == 0: continue  # guard: song too short to yield a comparison phrase
129 |     avg_pitch_range = sum(pitch_ranges) / len(pitch_ranges)
130 |     avg_pr_per_song.append(avg_pitch_range)
131 |     print("Average pitch range of bars in the song: ", avg_pitch_range)
132 |     avg_pitches_per_song.append(len(unique_pitches))
133 |     print("Unique pitches in the song: ", len(unique_pitches))
134 | 
135 | # Print the average pitch range of the songs in the test folder
136 | avg_pr = sum(avg_pr_per_song) / len(avg_pr_per_song)
137 | print("Average pitch range of bars in the folder: ", avg_pr)
138 | 
139 | # Print the average number of unique pitches per song in the test folder
140 | avg_pitches = sum(avg_pitches_per_song) / len(avg_pitches_per_song)
141 | print("Average number of unique pitches in the folder: ", avg_pitches)
--------------------------------------------------------------------------------
/structure_derivation/sd_evaluator.py:
--------------------------------------------------------------------------------
1 | from transformers import EncoderDecoderModel, AutoModelForSequenceClassification
2 | import torch
3 | from torch.cuda import is_available as cuda_available, is_bf16_supported
4 | import torch.nn.functional as F
5 | import pickle
6 | import yaml
7 | import json
8 | import os
9 | import argparse
10 | import random
11 | import copy
12 | import numpy as np
13 | from vendi_score import vendi
14 | from sklearn.metrics.pairwise import cosine_similarity
15 | import sys
16 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
17 | 
sys.path.append(os.path.dirname(SCRIPT_DIR)) 18 | from utils.utils import parse_midi, list_to_remi_encoding, encoding_to_midi, string_to_list_encoding, find_beats_in_bar 19 | from phrase_refiner.transformations import Melodic_Development 20 | 21 | # Parse command line arguments 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--config", type=str, default=os.path.normpath("/homes/kb658/yinyang/configs/configs_os.yaml"), 24 | help="Path to the config file") 25 | args = parser.parse_args() 26 | 27 | # Load config file 28 | with open(args.config, 'r') as f: 29 | configs = yaml.safe_load(f) 30 | 31 | max_sequence_length = configs['model']['structure_derivation_model']['max_sequence_length'] 32 | 33 | # Artifact folder 34 | artifact_folder = configs['raw_data']['artifact_folder'] 35 | 36 | # Load tokenizer json file dictionary 37 | dec_tokenizer_filepath = os.path.join(artifact_folder, 'tokenizer.json') 38 | with open(dec_tokenizer_filepath, 'r') as f: 39 | dec_tokenizer = json.load(f) 40 | reverse_dec_tokenizer = {str(v): k for k, v in dec_tokenizer.items()} 41 | 42 | # Load the phrase similarity model 43 | structure_derivation_model = AutoModelForSequenceClassification.from_pretrained(os.path.join(artifact_folder, "phrase_similarity")) 44 | structure_derivation_model.eval() 45 | structure_derivation_model.to("cuda" if cuda_available() else "cpu") 46 | 47 | def reindex_phrase(phrase_1, phrase_2, beats_in_bar): 48 | melodic_development_obj = Melodic_Development(beats_in_bar=beats_in_bar) 49 | phrase_2 = melodic_development_obj.reindex_bars(phrase_2, start_bar=phrase_1[-1][0]+1) 50 | 51 | return phrase_2 52 | 53 | # Test folder name 54 | test_folder = "/homes/kb658/PhraseBuilder/output/yin_yang_ablated_all" 55 | # test_folder = "/homes/kb658/PhraseBuilder/data/Mono_Midi_Files" 56 | run_test = False 57 | 58 | if run_test: 59 | # Load test file list 60 | with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f: 61 | test_file_list = json.load(f) 62 | # Get all the files in the test folder 63 | test_files = [os.path.basename(file) for file in test_file_list] 64 | # Convert .json to .mid 65 | test_files = [file.replace(".json", ".mid") for file in test_files] 66 | else: 67 | # Get all the files in the test folder 68 | test_files = os.listdir(test_folder) 69 | 70 | avg_prob_per_song = [] 71 | avg_vendi_per_song = [] 72 | # Loop through all the test files 73 | for i, test_file in enumerate(test_files): 74 | 75 | # # Load a random test file (midi file) 76 | # test_file = random.choice(test_files) 77 | print("Test file: ", test_file) 78 | test_file_path = os.path.join(test_folder, test_file) 79 | 80 | # Load the midi file 81 | midi_data, time_signature = parse_midi(test_file_path) 82 | beats_per_bar = find_beats_in_bar(time_signature) 83 | 84 | melodic_development_obj = Melodic_Development(beats_in_bar=beats_per_bar) 85 | # Chop the midi data into bars 86 | bars = melodic_development_obj.group_by_bar(midi_data) 87 | total_bars = len(bars) 88 | 89 | # Load the same test file from extracted_phrases folder within data and extract the first phrase 90 | extracted_phrases_folder = "/homes/kb658/PhraseBuilder/data/extracted_phrases" 91 | extracted_phrases_file = os.path.join(extracted_phrases_folder, test_file.split(".")[0] + ".json") 92 | # Load the extracted phrases 93 | with open(extracted_phrases_file, 'r') as f: 94 | extracted_phrases = json.load(f) 95 | 96 | first_phrase = extracted_phrases["phrases"]['0'][0] 97 | # Get number of bars in the first phrase 98 | num_bars = 
first_phrase[-1][0] + 1
99 |     increment = num_bars
100 | 
101 |     # Get the bars of phrase 1
102 |     phrase_1 = bars[0:num_bars]
103 |     # Flatten the phrase
104 |     phrase_1 = [note for bar in phrase_1 for note in bar]
105 |     print("Phrase 1: ", phrase_1)
106 | 
107 |     list_of_probs = []
108 |     list_of_embeddings = []
109 | 
110 |     while True:
111 |         # Get the next num_bars bars as phrase 2
112 |         if increment + num_bars <= total_bars:
113 |             phrase_2 = bars[increment:increment+num_bars]
114 |             increment += num_bars
115 |         else:
116 |             break
117 |         # Flatten the phrase
118 |         phrase_2 = [note for bar in phrase_2 for note in bar]
119 | 
120 |         # Reindex phrase 2
121 |         phrase_2 = reindex_phrase(phrase_1, phrase_2, beats_per_bar)
122 | 
123 |         # Add phrase 1 to phrase 2
124 |         phrase = phrase_1 + ["SEP"] + phrase_2
125 |         # List to remi encoding
126 |         phrase = list_to_remi_encoding(phrase, {}, time_signature)
127 |         # Add the BOS and EOS tokens to the phrase
128 |         phrase = ["BOS"] + phrase + ["EOS"]
129 |         # Tokenize the phrase
130 |         phrase = [dec_tokenizer[note] for note in phrase if note in dec_tokenizer]
131 | 
132 |         # Convert the phrase to tensor
133 |         input_ids = torch.tensor(phrase).unsqueeze(0).to("cuda" if cuda_available() else "cpu")
134 | 
135 |         output = structure_derivation_model(input_ids, output_hidden_states=True)
136 |         logits = output.logits
137 |         # Get the probability of the phrase as sigmoid of the logits
138 |         prob = torch.sigmoid(logits)
139 |         prob = prob[-1, -1].item()
140 |         # print("Probability of the phrase: ", prob)
141 |         list_of_probs.append(prob)
142 | 
143 |         # Get the hidden states as the phrase embedding
144 |         last_hidden_states = output.hidden_states[-1]
145 |         embedding = last_hidden_states[0, 0, :]
146 |         # print("Phrase embedding: ", embedding.shape)
147 |         # Convert the embedding to numpy
148 |         embedding = embedding.cpu().detach().numpy()
149 |         list_of_embeddings.append(embedding)
150 | 
151 | 
152 |     # Print average probability of the phrases
153 |     if len(list_of_probs) == 0:
154 |         continue
155 |     avg_prob = sum(list_of_probs) / len(list_of_probs)
156 |     print("Average probability of the phrases: ", avg_prob)
157 |     avg_prob_per_song.append(avg_prob)
158 | 
159 |     # Calculate the similarity matrix
160 |     similarity_matrix = cosine_similarity(list_of_embeddings)
161 |     # Print average vendi score of the phrases
162 |     avg_vendi_score = vendi.score_K(similarity_matrix)
163 |     print("Average vendi score of the phrases: ", avg_vendi_score)
164 |     avg_vendi_per_song.append(avg_vendi_score)
165 | 
166 | # Print the average probability of the phrases in the test folder
167 | avg_prob_per_song = sum(avg_prob_per_song) / len(avg_prob_per_song)
168 | print("Average probability of the phrases in the test folder: ", avg_prob_per_song)
169 | 
170 | avg_vendi_per_song = sum(avg_vendi_per_song) / len(avg_vendi_per_song)
171 | print("Average vendi score of the phrases in the test folder: ", avg_vendi_per_song)
--------------------------------------------------------------------------------
/structure_derivation/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.cuda import is_available as cuda_available, is_bf16_supported
3 | from torch.backends.mps import is_available as mps_available
4 | import yaml
5 | import json
6 | import os
7 | from torch import Tensor, argmax
8 | from transformers import BertConfig, Trainer, TrainingArguments, EarlyStoppingCallback, BertModel
9 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support
10 | from data_loader import JSONDataset
11 | import argparse
12 | from transformers import AutoModelForSequenceClassification
13 | 
14 | 
15 | # Parse command line arguments
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--config", type=str, default=os.path.normpath("/homes/kb658/yinyang/configs/configs_os.yaml"),
18 |                     help="Path to the config file")
19 | args = parser.parse_args()
20 | 
21 | # Load config file
22 | with open(args.config, 'r') as f:
23 |     configs = yaml.safe_load(f)
24 | 
25 | batch_size = configs['training']['structure_derivation']['batch_size']
26 | learning_rate = configs['training']['structure_derivation']['learning_rate']
27 | epochs = configs['training']['structure_derivation']['epochs']
28 | 
29 | # Artifact folder
30 | artifact_folder = configs['raw_data']['artifact_folder']
31 | # Load encoder tokenizer json file dictionary
32 | tokenizer_filepath = os.path.join(artifact_folder, "tokenizer.json")
33 | # Load the tokenizer dictionary
34 | with open(tokenizer_filepath, 'r') as f:
35 |     tokenizer = json.load(f)
36 | 
37 | 
38 | # Open the train, validation, and test sets json files
39 | with open(os.path.join(artifact_folder, "train_file_list.json"), "r") as f:
40 |     train_file_list = json.load(f)
41 | with open(os.path.join(artifact_folder, "valid_file_list.json"), "r") as f:
42 |     valid_file_list = json.load(f)
43 | with open(os.path.join(artifact_folder, "test_file_list.json"), "r") as f:
44 |     test_file_list = json.load(f)
45 | 
46 | # Print length of train, validation, and test sets
47 | print("Length of train set: ", len(train_file_list))
48 | print("Length of validation set: ", len(valid_file_list))
49 | print("Length of test set: ", len(test_file_list))
50 | 
51 | # Load the dataset
52 | train_dataset = JSONDataset(configs, train_file_list, mode="train")
53 | valid_dataset = JSONDataset(configs, valid_file_list, mode="eval")
54 | # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
55 | 
56 | # Get the vocab size
57 | vocab_size = len(tokenizer)
58 | # Get the phrase length
59 | train_phrase_length = len(train_dataset.file_number_phrase_number)
60 | 
61 | # Create the config for the encoder-only classification model
62 | config_encoder = BertConfig()
63 | config_encoder.vocab_size = vocab_size
64 | config_encoder.max_position_embeddings = configs['model']['structure_derivation_model']['max_sequence_length']
65 | config_encoder.max_length = configs['model']['structure_derivation_model']['max_sequence_length']
66 | config_encoder.pad_token_id = 0
67 | config_encoder.bos_token_id = tokenizer["BOS"]
68 | config_encoder.eos_token_id = tokenizer["EOS"]
69 | config_encoder.num_hidden_layers = configs['model']['structure_derivation_model']['num_layers']
70 | config_encoder.num_attention_heads = configs['model']['structure_derivation_model']['num_heads']
71 | config_encoder.hidden_size = configs['model']['structure_derivation_model']['hidden_size']
72 | config_encoder.intermediate_size = configs['model']['structure_derivation_model']['intermediate_size']
73 | config_encoder.num_labels = 2
74 | 
75 | model = AutoModelForSequenceClassification.from_config(config_encoder)
76 | 
77 | # Print the number of parameters in the model
78 | num_params = sum(p.numel() for p in model.parameters())
79 | print(f"Number of parameters in the model: {num_params}")
80 | 
81 | # Create config for the Trainer
82 | USE_CUDA = cuda_available()
83 | print(f"USE_CUDA: {USE_CUDA}")
84 | if not cuda_available():
85 |     FP16 = FP16_EVAL = BF16 = BF16_EVAL = False
86 | elif is_bf16_supported():
87 |     BF16 = BF16_EVAL = True
88 |     FP16 = FP16_EVAL = False
89 | else:
90 |     BF16 = BF16_EVAL = False
91 |     FP16 = FP16_EVAL = True
92 | USE_MPS = not USE_CUDA and mps_available()
93 | 
94 | # Define the accuracy metric function
95 | def compute_metrics(p):
96 |     preds = torch.argmax(torch.from_numpy(p.predictions), axis=1)
97 |     precision, recall, f1, _ = precision_recall_fscore_support(p.label_ids, preds, average='binary')
98 |     return {
99 |         'accuracy': accuracy_score(p.label_ids, preds),
100 |         'precision': precision,
101 |         'recall': recall,
102 |         'f1': f1,
103 |     }
104 | # def compute_metrics(eval_pred):
105 | #     predictions, labels = eval_pred
106 | #     predictions = np.argmax(predictions, axis=1)
107 | #     return accuracy.compute(predictions=predictions, references=labels)
108 | 
109 | # Define the training arguments
110 | training_args = TrainingArguments(
111 |     output_dir=os.path.join(artifact_folder, "phrase_similarity"),
112 |     per_device_train_batch_size=batch_size,
113 |     per_device_eval_batch_size=batch_size,
114 |     save_strategy="steps",  # "steps" or "epoch"
115 |     save_steps=500,
116 |     save_total_limit=1,
117 |     learning_rate=learning_rate,
118 |     weight_decay=configs['training']['structure_derivation']['weight_decay'],
119 |     max_steps=int(train_phrase_length//batch_size)*epochs,
120 |     evaluation_strategy="steps",
121 |     eval_steps=500,
122 |     gradient_accumulation_steps=configs['training']['structure_derivation']['gradient_accumulation_steps'],
123 |     gradient_checkpointing=True,
124 |     optim="adafactor",
125 |     seed=444,
126 |     logging_strategy="steps",
127 |     logging_steps=100,
128 |     logging_dir=os.path.join(artifact_folder, "phrase_similarity", "logs"),
129 |     no_cuda=not USE_CUDA,
130 |     fp16=FP16,
131 |     fp16_full_eval=FP16_EVAL,
132 |     bf16=BF16,
133 |     bf16_full_eval=BF16_EVAL,
134 |     load_best_model_at_end=True,
135 |     metric_for_best_model="eval_loss",
136 |     greater_is_better=False,
137 |     report_to="tensorboard",
138 |     run_name="structure_derivation",
139 |     push_to_hub=False
140 | )
141 | 
142 | # Define the Trainer
143 | trainer = Trainer(
144 |     model=model,
145 |     args=training_args,
146 |     train_dataset=train_dataset,
147 |     eval_dataset=valid_dataset,
148 |     compute_metrics=compute_metrics,
149 |     callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
150 | )
151 | 
152 | # Train and save the model
153 | train_result = trainer.train()
154 | trainer.save_model()  # Saves the tokenizer too
155 | trainer.log_metrics("train", train_result.metrics)
156 | trainer.save_metrics("train", train_result.metrics)
157 | trainer.save_state()
158 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/keshavbhandari/yinyang/59a2a001884a1f235ef9a0313565487874f6e3b9/utils/__init__.py
--------------------------------------------------------------------------------