├── tests ├── __init__.py ├── midi_files │ ├── Aicha.mid │ ├── empty.mid │ ├── Funkytown.mid │ ├── Maestro_1.mid │ ├── Maestro_2.mid │ ├── Maestro_3.mid │ ├── Maestro_4.mid │ ├── Maestro_5.mid │ ├── Maestro_6.mid │ ├── Maestro_7.mid │ ├── Maestro_8.mid │ ├── Maestro_9.mid │ ├── Shut Up.mid │ ├── test_midi.mid │ ├── In Too Deep.mid │ ├── Maestro_10.mid │ ├── POP909_008.mid │ ├── POP909_010.mid │ ├── POP909_022.mid │ ├── POP909_191.mid │ ├── Mr. Blue Sky.mid │ ├── I Gotta Feeling.mid │ ├── 6338816_Etude No. 4.mid │ ├── Les Yeux Revolvers.mid │ ├── 6354774_Macabre Waltz.mid │ ├── All The Small Things.mid │ ├── What a Fool Believes.mid │ ├── Girls Just Want to Have Fun.mid │ ├── d6caebd1964d9e4a3c5ea59525230e2a.mid │ └── d8faddb8596fff7abb24d78666f73e4e.mid ├── utils_tests.py ├── test_run.py └── test_nomml.py ├── Giga_MIDI_Logo_Final.png ├── loops_nomml ├── loop_ex_labeled.png ├── __init__.py ├── README.md ├── nomml.py ├── note_set.py ├── note_set_fast.py ├── process_file.py ├── process_file_fast.py ├── corr_mat.py └── corr_mat_fast.py ├── scripts ├── figures │ └── GigaMIDI_duration_bars.pdf ├── nomml_gigamidi.py ├── main_expressive.py ├── analyze_gigamidi_dataset.py ├── dataset_programs_dist.py ├── test_extract.ipynb └── dataset_short_files.py ├── Data Source Links for the GigaMIDI Dataset - Sheet1.pdf ├── .gitmodules ├── Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models ├── Saved Logistic Regression Models │ ├── model-DNVR.pkl │ ├── model-DNODR.pkl │ └── model-NOMML.pkl ├── Curated Evaluation Set │ ├── MIDI Score (Non-expressive MIDI Tracks) │ │ ├── ASAP-Score-only.numbers │ │ ├── ATEPP-Score-Final.numbers │ │ └── ATEPP-Score-Final-refinement.numbers │ └── MIDI Performance (Expressive MIDI Tracks) │ │ ├── Expressively-Performed-EP-Only-Aggregated-Final-GroundTruth.numbers │ │ └── Saarland.csv └── Optimal Threshold Selection │ └── Expressive_Performance_Detection_Training-Percentile-threshold.csv ├── LICENSE ├── main.py ├── 
pyproject.toml ├── MIDI-GPT-Loop └── README.md ├── Expressive music loop detector-NOMML12.ipynb ├── GigaMIDI ├── create_gigamidi_dataset.py ├── GigaMIDI.py └── README.md └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Giga_MIDI_Logo_Final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Giga_MIDI_Logo_Final.png -------------------------------------------------------------------------------- /tests/midi_files/Aicha.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Aicha.mid -------------------------------------------------------------------------------- /tests/midi_files/empty.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/empty.mid -------------------------------------------------------------------------------- /tests/midi_files/Funkytown.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Funkytown.mid -------------------------------------------------------------------------------- /tests/midi_files/Maestro_1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_1.mid -------------------------------------------------------------------------------- /tests/midi_files/Maestro_2.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_2.mid -------------------------------------------------------------------------------- /tests/midi_files/Maestro_3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_3.mid -------------------------------------------------------------------------------- /tests/midi_files/Maestro_4.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_4.mid -------------------------------------------------------------------------------- /tests/midi_files/Maestro_5.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_5.mid -------------------------------------------------------------------------------- /tests/midi_files/Maestro_6.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_6.mid -------------------------------------------------------------------------------- /tests/midi_files/Maestro_7.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_7.mid -------------------------------------------------------------------------------- /tests/midi_files/Maestro_8.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_8.mid -------------------------------------------------------------------------------- /tests/midi_files/Maestro_9.mid: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_9.mid -------------------------------------------------------------------------------- /tests/midi_files/Shut Up.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Shut Up.mid -------------------------------------------------------------------------------- /tests/midi_files/test_midi.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/test_midi.mid -------------------------------------------------------------------------------- /loops_nomml/loop_ex_labeled.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/loops_nomml/loop_ex_labeled.png -------------------------------------------------------------------------------- /tests/midi_files/In Too Deep.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/In Too Deep.mid -------------------------------------------------------------------------------- /tests/midi_files/Maestro_10.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_10.mid -------------------------------------------------------------------------------- /tests/midi_files/POP909_008.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/POP909_008.mid 
-------------------------------------------------------------------------------- /tests/midi_files/POP909_010.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/POP909_010.mid -------------------------------------------------------------------------------- /tests/midi_files/POP909_022.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/POP909_022.mid -------------------------------------------------------------------------------- /tests/midi_files/POP909_191.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/POP909_191.mid -------------------------------------------------------------------------------- /tests/midi_files/Mr. Blue Sky.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Mr. Blue Sky.mid -------------------------------------------------------------------------------- /tests/midi_files/I Gotta Feeling.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/I Gotta Feeling.mid -------------------------------------------------------------------------------- /tests/midi_files/6338816_Etude No. 4.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/6338816_Etude No. 
4.mid -------------------------------------------------------------------------------- /tests/midi_files/Les Yeux Revolvers.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Les Yeux Revolvers.mid -------------------------------------------------------------------------------- /scripts/figures/GigaMIDI_duration_bars.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/scripts/figures/GigaMIDI_duration_bars.pdf -------------------------------------------------------------------------------- /tests/midi_files/6354774_Macabre Waltz.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/6354774_Macabre Waltz.mid -------------------------------------------------------------------------------- /tests/midi_files/All The Small Things.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/All The Small Things.mid -------------------------------------------------------------------------------- /tests/midi_files/What a Fool Believes.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/What a Fool Believes.mid -------------------------------------------------------------------------------- /tests/midi_files/Girls Just Want to Have Fun.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Girls Just Want to Have Fun.mid -------------------------------------------------------------------------------- /Data 
Source Links for the GigaMIDI Dataset - Sheet1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Data Source Links for the GigaMIDI Dataset - Sheet1.pdf -------------------------------------------------------------------------------- /tests/midi_files/d6caebd1964d9e4a3c5ea59525230e2a.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/d6caebd1964d9e4a3c5ea59525230e2a.mid -------------------------------------------------------------------------------- /tests/midi_files/d8faddb8596fff7abb24d78666f73e4e.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/d8faddb8596fff7abb24d78666f73e4e.mid -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "MIDI-GPT-Loop/MMM-Loop"] 2 | path = MIDI-GPT-Loop/MMM-Loop 3 | url = git@github.com:Metacreation-Lab/MMM-Loop.git 4 | [submodule "MIDI-GPT-Loop/MMM"] 5 | path = MIDI-GPT-Loop/MMM 6 | url = git@github.com:Metacreation-Lab/MMM.git 7 | branch = Expressive-Loops 8 | -------------------------------------------------------------------------------- /loops_nomml/__init__.py: -------------------------------------------------------------------------------- 1 | """Main module.""" 2 | 3 | from .process_file import detect_loops, detect_loops_from_path 4 | from .nomml import get_median_metric_depth 5 | 6 | __all__ = [ 7 | "get_median_metric_depth", 8 | "detect_loops", 9 | "detect_loops_from_path" 10 | ] 11 | -------------------------------------------------------------------------------- /Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning 
Models/Saved Logistic Regression Models/model-DNVR.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Saved Logistic Regression Models/model-DNVR.pkl -------------------------------------------------------------------------------- /Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Saved Logistic Regression Models/model-DNODR.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Saved Logistic Regression Models/model-DNODR.pkl -------------------------------------------------------------------------------- /Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Saved Logistic Regression Models/model-NOMML.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Saved Logistic Regression Models/model-NOMML.pkl -------------------------------------------------------------------------------- /tests/utils_tests.py: -------------------------------------------------------------------------------- 1 | """Test validation methods.""" 2 | 3 | from __future__ import annotations 4 | 5 | from pathlib import Path 6 | 7 | SEED = 777 8 | 9 | HERE = Path(__file__).parent 10 | MIDI_PATHS_ALL = sorted((HERE / "midi_files").rglob("*.mid")) 11 | # MIDI_PATHS_ALL = [MIDI_PATHS_ALL[-1]] 12 | TEST_LOG_DIR = HERE / "test_logs" 13 | -------------------------------------------------------------------------------- /Analysis of Evaluation Set and 
Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Score (Non-expressive MIDI Tracks)/ASAP-Score-only.numbers: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Score (Non-expressive MIDI Tracks)/ASAP-Score-only.numbers -------------------------------------------------------------------------------- /Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Score (Non-expressive MIDI Tracks)/ATEPP-Score-Final.numbers: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Score (Non-expressive MIDI Tracks)/ATEPP-Score-Final.numbers -------------------------------------------------------------------------------- /Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Score (Non-expressive MIDI Tracks)/ATEPP-Score-Final-refinement.numbers: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Score (Non-expressive MIDI Tracks)/ATEPP-Score-Final-refinement.numbers -------------------------------------------------------------------------------- /Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Performance (Expressive MIDI Tracks)/Expressively-Performed-EP-Only-Aggregated-Final-GroundTruth.numbers: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Performance (Expressive MIDI Tracks)/Expressively-Performed-EP-Only-Aggregated-Final-GroundTruth.numbers -------------------------------------------------------------------------------- /tests/test_run.py: -------------------------------------------------------------------------------- 1 | """Testing tokenization, making sure the data integrity is not altered.""" 2 | 3 | from __future__ import annotations 4 | 5 | from pathlib import Path 6 | 7 | import pytest 8 | from symusic import Score 9 | 10 | from loops_nomml import detect_loops 11 | 12 | from .utils_tests import MIDI_PATHS_ALL 13 | 14 | 15 | @pytest.mark.parametrize("file_path", MIDI_PATHS_ALL, ids=lambda p: p.name) 16 | def test_one_track_midi_to_tokens_to_midi(file_path: Path): 17 | score = Score(file_path) 18 | _ = detect_loops(score) 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # GigaMIDI Dataset Licensing Information 2 | The dataset is distributed under the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) license. 3 | This license permits users to share, adapt, and utilize the dataset exclusively for non-commercial purposes, including research 4 | and educational applications, provided that proper attribution is given to the original creators. 5 | By adhering to the terms of CC BY-NC 4.0, users ensure the dataset's responsible use while fostering its accessibility 6 | for academic and non-commercial endeavours. 
7 | -------------------------------------------------------------------------------- /loops_nomml/README.md: -------------------------------------------------------------------------------- 1 | # Fast loops extractor (CPU/GPU accelerated) 2 | 3 | ## Requirements 4 | 5 | ```sh 6 | pip install numba 7 | pip install numpy==1.24.4 8 | pip install pretty-midi 9 | pip install symusic 10 | pip install miditok 11 | ``` 12 | 13 | ## Basic use example 14 | 15 | ```python 16 | import os 17 | 18 | # Set desired environment variable 19 | os.environ["USE_NUMBA"] = "1" 20 | os.environ["USE_CUDA"] = "1" 21 | 22 | # Check the variable 23 | print(os.environ["USE_NUMBA"]) 24 | print(os.environ["USE_CUDA"]) 25 | 26 | from process_file_fast import detect_loops_from_path 27 | 28 | midi_file = './your_midi_file.mid' 29 | 30 | result = detect_loops_from_path({'file_path': [midi_file]}) 31 | ``` 32 | 33 | ## Enjoy! :) 34 | -------------------------------------------------------------------------------- /loops_nomml/nomml.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import annotations 3 | 4 | import numpy as np 5 | from symusic import Score 6 | 7 | 8 | def get_metric_depth(time, tpq, max_depth=6): 9 | for i in range(max_depth): 10 | period = tpq / int(2 ** i) 11 | if time % period == 0: 12 | return 2 * i 13 | for i in range(max_depth): 14 | period = tpq * 2 / (int(2 ** i) * 3) 15 | if time % period == 0: 16 | return 2 * i + 1 17 | return max_depth * 2 18 | 19 | 20 | def get_median_metric_depth(path): 21 | mf = Score(path) 22 | median_metric_depths = [] 23 | for track in mf.tracks: 24 | metric_depths = [get_metric_depth(event.time, mf.tpq) for event in track.notes] 25 | if len(metric_depths) > 0: 26 | median_metric_depths.append(int(np.median(metric_depths))) 27 | return path, median_metric_depths 28 | -------------------------------------------------------------------------------- /scripts/nomml_gigamidi.py: 
"""Batch NOMML computation over a folder of MIDI files.

Recursively scans a folder, runs ``get_median_metric_depth`` on each MIDI
file in a process pool, and writes the results (with periodic checkpoints)
to ``<basename(folder)>.json``.
"""
import glob
import json
import os
import random
from multiprocessing import Pool

# Number of accepted results between checkpoint writes of the output JSON.
CHECKPOINT_EVERY = 50000


def load_json(path):
    """Read and return the JSON document stored at ``path``."""
    with open(path, "r") as f:
        return json.load(f)


def dump_json(x, path):
    """Write ``x`` to ``path`` as indented JSON."""
    with open(path, "w") as f:
        json.dump(x, f, indent=4)


def worker(args):
    """Pool worker: run ``get_median_metric_depth`` on one file.

    Returns ``None`` (instead of raising) on any failure so a single
    broken MIDI file does not abort the whole run.
    """
    try:
        # Lazy import so this module (e.g. its JSON helpers) can be
        # imported without the loops_nomml package; Python caches the
        # import after the first call, so the per-task cost is negligible.
        from loops_nomml import get_median_metric_depth

        return get_median_metric_depth(*args)
    except Exception as e:
        print("FAILED : ", e)
        return None


def main(folder, force=False, nthreads=8):
    """Compute median metric depths for every MIDI file under ``folder``.

    :param folder: root directory scanned recursively for .mid/.midi files.
    :param force: recompute even if the output JSON already exists.
    :param nthreads: number of worker processes.
    :return: mapping of relative file path -> per-track median metric depths.
    """
    from tqdm import tqdm

    output_path = os.path.basename(folder).lower() + ".json"
    if os.path.exists(output_path) and not force:
        return load_json(output_path)

    # Deduplicate via a set: on case-insensitive filesystems the "mid" and
    # "MID" patterns match the same files twice.
    paths = sorted(
        {
            p
            for ext in ("mid", "midi", "MID")
            for p in glob.glob(os.path.join(folder, "**", f"*.{ext}"), recursive=True)
        }
    )
    random.shuffle(paths)

    result = {}
    # Context manager guarantees the pool is torn down even if a task or a
    # checkpoint write raises (the original leaked the pool).
    with Pool(nthreads) as pool:
        tasks = pool.imap_unordered(worker, [(p,) for p in paths])
        for count, item in enumerate(filter(None, tqdm(tasks, total=len(paths)))):
            path, median_metric_depths = item
            result[os.path.relpath(path, folder)] = median_metric_depths
            if count % CHECKPOINT_EVERY == 0:
                dump_json(result, output_path)
    dump_json(result, output_path)
    return result


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--folder", type=str, required=True)
    parser.add_argument("--force", action="store_true")
    parser.add_argument("--nthreads", type=int, default=8)
    args = parser.parse_args()
    main(args.folder, force=args.force, nthreads=args.nthreads)
"""Batch NOMML computation (verbatim duplicate of scripts/nomml_gigamidi.py).

Scans a folder recursively for MIDI files and records each file's
per-track median metric depth in a JSON file in the working directory.
"""
import os
import glob
import json
import random
import numpy as np
from tqdm import tqdm
from multiprocessing import Pool
from symusic import Score

from loops_nomml import get_median_metric_depth


def load_json(path):
    """Deserialize and return the JSON document at ``path``."""
    with open(path, "r") as fh:
        return json.load(fh)


def dump_json(x, path):
    """Serialize ``x`` to ``path`` as indented JSON."""
    with open(path, "w") as fh:
        json.dump(x, fh, indent=4)


def worker(args):
    """Run ``get_median_metric_depth`` on one file; ``None`` on failure."""
    try:
        return get_median_metric_depth(*args)
    except Exception as e:
        print("FAILED : ", e)
        return None


# TODO test method

def main(folder, force=False, nthreads=8):
    """Process every MIDI file under ``folder`` and return the result map.

    Results are checkpointed periodically to ``<basename(folder)>.json``;
    an existing output file short-circuits the run unless ``force``.
    """
    output_path = os.path.basename(folder).lower() + ".json"
    if os.path.exists(output_path) and not force:
        return load_json(output_path)

    midi_paths = []
    for ext in ["mid", "midi", "MID"]:
        midi_paths.extend(glob.glob(folder + f"/**/*.{ext}", recursive=True))
    random.shuffle(midi_paths)

    result = {}
    pool = Pool(nthreads)
    progress = tqdm(pool.imap_unordered(worker, [(p,) for p in midi_paths]),
                    total=len(midi_paths))
    for count, pair in enumerate(filter(lambda x: x is not None, progress)):
        path, median_metric_depths = pair
        result[os.path.relpath(path, folder)] = median_metric_depths
        if count % 50000 == 0:
            dump_json(result, output_path)
    dump_json(result, output_path)
    return result


if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--folder", type=str, default="", required=True)
    cli.add_argument("--force", action="store_true")
    cli.add_argument("--nthreads", type=int, default=8)
    opts = cli.parse_args()
    main(opts.folder, force=opts.force, nthreads=opts.nthreads)
/scripts/main_expressive.py: -------------------------------------------------------------------------------- 1 | from loops_nomml.process_file import detect_loops 2 | import os 3 | from datasets import Dataset, load_dataset 4 | 5 | DATA_PATH = "D:\\Documents\\GigaMIDI" 6 | METADATA_NAME = "Expressive_Performance_Detection_NOMML_gigamidi_tismir.csv" 7 | SHARD_SIZE = 20000 8 | OUTPUT_NAME = "gigamidi_expressive_loops_no_quant" 9 | 10 | if __name__ == "__main__": 11 | metadata_path = os.path.join(DATA_PATH, METADATA_NAME) 12 | output_path = os.path.join(DATA_PATH, OUTPUT_NAME) 13 | if not os.path.exists(output_path): 14 | os.mkdir(output_path) 15 | 16 | dataset = load_dataset("csv", data_files=metadata_path)['train'] 17 | print(f"loaded {len(dataset)} tracks") 18 | dataset_expressive = dataset.filter(lambda row: row['medianMetricDepth'] == 12) 19 | print(f"filtered to {len(dataset_expressive)} expressive tracks") 20 | dataset_with_time_signature = dataset_expressive.filter(lambda row: row['hasTimeSignatures']) 21 | print(f"filtered to {len(dataset_with_time_signature)} expressive tracks with time signatures") 22 | 23 | unique_files = dataset_with_time_signature.unique('filepath') 24 | unique_files = [{"file_path": os.path.join(DATA_PATH, file_path), "file_name": file_path} for file_path in unique_files] 25 | unique_files_dataset = Dataset.from_list(unique_files) 26 | print(f"filtered to {len(unique_files_dataset)} unique MIDI files, expressive with time signatures") 27 | unique_files_dataset = unique_files_dataset.shuffle(seed=42) 28 | 29 | num_shards = int(round(len(unique_files_dataset) / SHARD_SIZE)) 30 | print(f"Splitting dataset in {num_shards} shards") 31 | print(f"Saving shards to {output_path}") 32 | for shard_idx in range(num_shards): 33 | shard = unique_files_dataset.shard(num_shards=num_shards, index=shard_idx) 34 | shard = shard.map( 35 | detect_loops, 36 | remove_columns=['file_path', 'file_name'], 37 | batched=True, 38 | batch_size=1, 39 | 
writer_batch_size=1, 40 | num_proc=8 41 | ) 42 | 43 | csv_path = os.path.join(output_path, OUTPUT_NAME + "_" + str(shard_idx) + ".csv") 44 | shard.to_csv(csv_path) 45 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from loops_nomml.process_file import detect_loops_from_path 2 | import os 3 | from datasets import Dataset, load_dataset 4 | 5 | DATA_PATH = "D:\\Documents\\GigaMIDI" 6 | METADATA_NAME = "Expressive_Performance_Detection_NOMML_gigamidi_tismir.csv" 7 | SHARD_SIZE = 20000 8 | OUTPUT_NAME = "gigamidi_non_expressive_loops" 9 | 10 | if __name__ == "__main__": 11 | metadata_path = os.path.join(DATA_PATH, METADATA_NAME) 12 | output_path = os.path.join(DATA_PATH, OUTPUT_NAME) 13 | if not os.path.exists(output_path): 14 | os.mkdir(output_path) 15 | 16 | dataset = load_dataset("csv", data_files=metadata_path)['train'] 17 | print(f"loaded {len(dataset)} tracks") 18 | dataset_non_expressive = dataset.filter(lambda row: row['medianMetricDepth'] < 12) 19 | print(f"filtered to {len(dataset_non_expressive)} non-expressive tracks") 20 | dataset_with_time_signature = dataset_non_expressive.filter(lambda row: row['hasTimeSignatures']) 21 | print(f"filtered to {len(dataset_with_time_signature)} non-expressive tracks with time signatures") 22 | 23 | unique_files = dataset_with_time_signature.unique('filepath') 24 | unique_files = [{"file_path": os.path.join(DATA_PATH, file_path), "file_name": file_path} for file_path in unique_files] 25 | unique_files_dataset = Dataset.from_list(unique_files) 26 | print(f"filtered to {len(unique_files_dataset)} unique MIDI files, non-expressive with time signatures") 27 | unique_files_dataset = unique_files_dataset.shuffle(seed=42) 28 | 29 | num_shards = int(round(len(unique_files_dataset) / SHARD_SIZE)) 30 | print(f"Splitting dataset in {num_shards} shards") 31 | print(f"Saving shards to {output_path}") 32 | for 
shard_idx in range(0, num_shards): 33 | shard = unique_files_dataset.shard(num_shards=num_shards, index=shard_idx) 34 | shard = shard.map( 35 | detect_loops_from_path, 36 | remove_columns=['file_path', 'file_name'], 37 | batched=True, 38 | batch_size=1, 39 | writer_batch_size=1, 40 | num_proc=8 41 | ) 42 | 43 | csv_path = os.path.join(output_path, OUTPUT_NAME + "_" + str(shard_idx) + ".csv") 44 | shard.to_csv(csv_path) 45 | -------------------------------------------------------------------------------- /scripts/analyze_gigamidi_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 python 2 | 3 | """Measuring the length of GigaMIDI files in bars.""" 4 | 5 | if __name__ == "__main__": 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | from datasets import concatenate_datasets, load_dataset 10 | from matplotlib import pyplot as plt 11 | from miditok.constants import SCORE_LOADING_EXCEPTION 12 | from miditok.utils import get_bars_ticks 13 | from symusic import Score 14 | from tqdm import tqdm 15 | from transformers import set_seed 16 | 17 | from utils.GigaMIDI.GigaMIDI import _SUBSETS 18 | from utils.utils import path_data_directory_local_fs 19 | 20 | SEED = 777 21 | NUM_FILES = 5000 22 | NUM_HIST_BINS = 25 23 | X_AXIS_LIM_BARS = 300 24 | X_AXIS_LIM_BEATS = X_AXIS_LIM_BARS * 4 25 | set_seed(SEED) 26 | FIGURES_PATH = Path("scripts", "analysis", "figures") 27 | 28 | # Measuring lengths in bars/beats 29 | dist_lengths_bars = {} 30 | for subset_name in _SUBSETS: 31 | subset = concatenate_datasets( 32 | list( 33 | load_dataset( 34 | str(path_data_directory_local_fs() / "GigaMIDI"), 35 | subset_name, 36 | trust_remote_code=True, 37 | ).values() 38 | ) 39 | ).shuffle() 40 | dist_lengths_bars[subset_name] = [] 41 | 42 | idx = 0 43 | with tqdm(total=NUM_FILES, desc=f"Analyzing `{subset_name}` subset") as pbar: 44 | while len(dist_lengths_bars[subset_name]) < NUM_FILES and idx < len(subset): 45 | try: 46 | 
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Sequence

    from symusic import Note
    from symusic.core import NoteTickList


def compute_note_sets(notes: NoteTickList, bars_ticks: Sequence[int]) -> list[NoteSet]:
    """
    Convert MIDI notes and measure start times into a sorted list of NoteSets.

    Barlines are represented as empty NoteSets with a duration of 0.

    :param notes: list of MIDI notes in a single track
    :param bars_ticks: list of measure start times in ticks
    :return: NoteSet representation of the MIDI track, sorted by start time
    """
    note_sets: list[NoteSet] = []
    for note in notes:
        # Consecutive notes sharing the exact same onset and offset are
        # merged into a single NoteSet (a chord).
        if not note_sets or not note_sets[-1].fits_in_set(note.start, note.end):
            note_sets.append(NoteSet(start=note.start, end=note.end))
        note_sets[-1].add_note(note)

    # Barlines become zero-duration, pitch-less NoteSets. list.sort() is
    # stable, so note sets precede barlines sharing the same start tick.
    all_sets = note_sets + [NoteSet(start=tick, end=tick) for tick in bars_ticks]
    all_sets.sort()
    return all_sets


class NoteSet:
    """
    A set of unique pitches that occur at the same start time and end at
    the same time (have the same duration).

    A NoteSet with no pitches and a duration of 0 represents the
    start of a measure (i.e. a barline).

    :param start: start time in MIDI ticks
    :param end: end time in MIDI ticks
    """

    def __init__(self, start: int, end: int) -> None:
        self.start = start
        self.end = end
        self.duration = self.end - self.start
        self.pitches: set[int] = set()  # MIDI note numbers

    def add_note(self, note: Note) -> None:
        """Record the pitch of *note* in this set."""
        self.pitches.add(note.pitch)

    def fits_in_set(self, start: int, end: int) -> bool:
        """Return True if a note spanning [start, end] belongs in this set."""
        return start == self.start and end == self.end

    def is_barline(self) -> bool:
        """Return True if this set marks the start of a measure."""
        return self.start == self.end and len(self.pitches) == 0

    def __str__(self) -> str:
        return f"NoteSet({self.start}, {self.duration}, {self.pitches})"

    def __eq__(self, other: object) -> bool:
        """
        Two NoteSets are equal if they match in duration and in the MIDI
        pitches present.

        Start times are deliberately NOT compared, so identical patterns
        at different positions compare equal (needed for loop detection).
        """
        if not isinstance(other, NoteSet):
            return False
        # Set equality replaces the original manual membership loop.
        return self.duration == other.duration and self.pitches == other.pitches

    def __lt__(self, other: NoteSet) -> bool:
        """NoteSets are ordered by start time."""
        return self.start < other.start
"numpy.typing.mypy_plugin" 59 | exclude = [ 60 | "venv", 61 | ".venv", 62 | ] 63 | 64 | [tool.ruff] 65 | target-version = "py312" 66 | 67 | [tool.ruff.lint] 68 | extend-select = [ 69 | "ARG", 70 | "A", 71 | "ANN", 72 | "B", 73 | "BLE", 74 | "C4", 75 | "COM", 76 | "D", 77 | "E", 78 | "EM", 79 | "EXE", 80 | "F", 81 | "FA", 82 | "FBT", 83 | "G", 84 | "I", 85 | "ICN", 86 | "INP", 87 | "INT", 88 | "ISC", 89 | "N", 90 | "NPY", 91 | "PERF", 92 | "PGH", 93 | "PTH", 94 | "PIE", 95 | # "PL", 96 | "PT", 97 | "Q", 98 | "RET", 99 | "RSE", 100 | "RUF", 101 | "S", 102 | # "SLF", 103 | "SIM", 104 | "T", 105 | "TCH", 106 | "TID", 107 | "UP", 108 | "W", 109 | ] 110 | 111 | ignore = [ 112 | "ANN003", 113 | "ANN101", 114 | "ANN102", 115 | "B905", 116 | "COM812", 117 | "D107", 118 | "D203", 119 | "D212", 120 | "FBT001", 121 | "FBT002", 122 | "UP038", 123 | "S105", 124 | "S311", 125 | ] 126 | 127 | [tool.ruff.lint.per-file-ignores] 128 | "tests/**" = [ 129 | "ANN201", # allow no return type hint for pytest methods 130 | "D103", # no need to document pytest methods 131 | "S101", # allow assertions in tests 132 | "T201", # print allowed 133 | ] 134 | "docs/conf.py" = ["INP001"] # not a package 135 | -------------------------------------------------------------------------------- /loops_nomml/note_set_fast.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Sequence 4 | 5 | if TYPE_CHECKING: 6 | from symusic import Note 7 | from symusic.core import NoteTickList 8 | 9 | 10 | class NoteSet: 11 | __slots__ = ("start", "end", "duration", "pitches") 12 | 13 | def __init__(self, start: int, end: int) -> None: 14 | s = int(start) 15 | e = int(end) 16 | self.start = s 17 | self.end = e 18 | self.duration = int(e - s) 19 | self.pitches = set() 20 | 21 | def add_note(self, note: "Note") -> None: 22 | try: 23 | self.pitches.add(int(note.pitch)) 24 | except Exception: 25 | pass 26 | 27 | 
def fits_in_set(self, start: int, end: int) -> bool: 28 | return int(start) == self.start and int(end) == self.end 29 | 30 | def is_barline(self) -> bool: 31 | return self.start == self.end and len(self.pitches) == 0 32 | 33 | def __str__(self) -> str: 34 | return f"NoteSet({self.start}, {self.duration}, {self.pitches})" 35 | 36 | def __eq__(self, other: object) -> bool: 37 | if not isinstance(other, NoteSet): 38 | return False 39 | if self.duration != other.duration: 40 | return False 41 | if len(self.pitches) != len(other.pitches): 42 | return False 43 | return self.pitches == other.pitches 44 | 45 | def __lt__(self, other: object): 46 | return self.start < other.start 47 | 48 | 49 | _MAX_DURATION_TICKS = 4_000_000 50 | _MAX_GAP_TICKS = 4_000_000 51 | _MIN_NOTE_DURATION = 0 52 | 53 | # New: reject implausible tick values (malformed events often have huge tick numbers) 54 | _MAX_TICK = 10 ** 9 # 1 billion ticks; adjust upward if your dataset legitimately uses larger ticks 55 | 56 | 57 | def compute_note_sets(notes: NoteTickList, bars_ticks: Sequence[int]) -> list[NoteSet]: 58 | processed_notes = [] 59 | valid_notes = [] 60 | append_valid = valid_notes.append 61 | for note in notes: 62 | try: 63 | s = int(note.start) 64 | e = int(note.end) 65 | except Exception: 66 | # malformed note event: skip 67 | continue 68 | # sanity checks: non-negative, reasonable magnitude 69 | if s < 0 or e < 0: 70 | continue 71 | if s > _MAX_TICK or e > _MAX_TICK: 72 | # suspiciously large tick values -> skip this note 73 | continue 74 | dur = e - s 75 | if dur < _MIN_NOTE_DURATION: 76 | continue 77 | if dur > _MAX_DURATION_TICKS: 78 | continue 79 | append_valid(note) 80 | 81 | if not valid_notes and not bars_ticks: 82 | return [] 83 | 84 | try: 85 | valid_notes.sort(key=lambda n: (int(n.start), int(n.end))) 86 | except Exception: 87 | valid_notes = sorted(valid_notes, key=lambda n: (getattr(n, "start", 0) or 0, getattr(n, "end", 0) or 0)) 88 | 89 | for note in valid_notes: 90 | try: 91 
| start = int(note.start) 92 | end = int(note.end) 93 | except Exception: 94 | continue 95 | start_new_set = len(processed_notes) == 0 or not processed_notes[-1].fits_in_set(start, end) 96 | if start_new_set: 97 | processed_notes.append(NoteSet(start=start, end=end)) 98 | processed_notes[-1].add_note(note) 99 | 100 | bars_clean = [] 101 | for b in bars_ticks: 102 | try: 103 | bi = int(b) 104 | except Exception: 105 | continue 106 | if bi < 0: 107 | continue 108 | if bi > _MAX_TICK: 109 | # skip implausible bar tick 110 | continue 111 | bars_clean.append(bi) 112 | if bars_clean: 113 | try: 114 | bars_clean = sorted(set(bars_clean)) 115 | except Exception: 116 | bars_clean = sorted(list(dict.fromkeys(bars_clean))) 117 | else: 118 | bars_clean = [] 119 | 120 | bar_note_sets = [NoteSet(start=db, end=db) for db in bars_clean] 121 | all_sets = processed_notes + bar_note_sets 122 | try: 123 | all_sets.sort() 124 | except Exception: 125 | all_sets = sorted(all_sets, key=lambda ns: getattr(ns, "start", 0) or 0) 126 | 127 | final_sets = [] 128 | prev_start = None 129 | for ns in all_sets: 130 | if ns.start is None or ns.end is None: 131 | continue 132 | if ns.start < 0 or ns.duration < 0: 133 | continue 134 | if prev_start is not None and (ns.start - prev_start) > _MAX_GAP_TICKS: 135 | # gap too large; keep but don't crash (original code had a pass) 136 | pass 137 | final_sets.append(ns) 138 | prev_start = ns.start 139 | 140 | return final_sets -------------------------------------------------------------------------------- /loops_nomml/process_file.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import annotations 3 | from typing import Dict, Tuple, Any, List 4 | import numpy as np 5 | from miditok.utils import get_bars_ticks, get_beats_ticks 6 | from miditok.constants import CLASS_OF_INST, INSTRUMENT_CLASSES 7 | from symusic import Score, Track, TimeSignature 8 | 9 | from .corr_mat import calc_correlation, 
get_valid_loops 10 | from .note_set import compute_note_sets 11 | 12 | MAX_NOTES_PER_TRACK = 50000 13 | MIN_NOTES_PER_TRACK = 5 14 | 15 | """ 16 | def get_instrument_type(track: Track) -> str: 17 | 18 | Determines MIDI instrument class of a track 19 | 20 | :param track: MIDI track to identify instrument of 21 | :return: name of instrument class 22 | 23 | if track.is_drum: 24 | return "Drums" 25 | 26 | return INSTRUMENT_CLASSES[CLASS_OF_INST[0]]["name"] 27 | 28 | """ 29 | 30 | 31 | def get_instrument_type(track) -> str: 32 | """ 33 | Determines MIDI instrument class of a track. 34 | 35 | :param track: A pretty_midi.Instrument object 36 | :return: name of instrument class 37 | """ 38 | if track.is_drum: 39 | return "Drums" 40 | 41 | program_number = track.program # MIDI program number (0–127) 42 | instrument_class_index = CLASS_OF_INST[program_number] 43 | instrument_class_name = INSTRUMENT_CLASSES[instrument_class_index]["name"] 44 | 45 | return instrument_class_name 46 | 47 | 48 | 49 | def create_loop_dict(endpoint_data: Tuple[int, int, float, float], track_idx: int, instrument_type: str) -> Dict[str,Any]: 50 | """ 51 | Formats loop metadata into a dictionary for dataset generation 52 | 53 | :param endpoint_data: tuple of loop start time in ticks, loop end time in 54 | ticks, loop duration in beats, and density in notes per beat 55 | :param track_idx: MIDI track index the loop belongs to 56 | :instrument_type: MIDI instrument the loop represents, as a string 57 | :return: data entry containing all metadata for a single loop 58 | """ 59 | start, end, beats, density = endpoint_data 60 | return { 61 | "track_idx": track_idx, 62 | "instrument_type": instrument_type, 63 | "start": start, 64 | "end": end, 65 | "duration_beats": beats, 66 | "note_density": density 67 | } 68 | 69 | def detect_loops_from_path(file_info: Dict) -> Dict[str,List]: 70 | """ 71 | Given a MIDI file, locate all loops present across all of its tracks 72 | 73 | :param file_info: dictionary 
containing a file_path key 74 | :return: dictionary of metadata for each identified loop 75 | """ 76 | file_path = file_info['file_path'] 77 | if isinstance(file_path, list): 78 | file_path = file_path[0] 79 | try: 80 | score = Score(file_path) 81 | except: 82 | print(f"Unable to parse score for {file_path}, skipping") 83 | return { 84 | "track_idx": [], 85 | "instrument_type": [], 86 | "start": [], 87 | "end": [], 88 | "duration_beats": [], 89 | "note_density": [], 90 | } 91 | return detect_loops(score, file_path=file_path) 92 | 93 | def detect_loops(score: Score, file_path: str = None) -> Dict[str,List]: 94 | """ 95 | Given a MIDI score, locate all loops present across off the tracks 96 | 97 | :param score: score to evaluate for loops 98 | :return: dictionary of metadata for each identified loop 99 | """ 100 | data = { 101 | "track_idx": [], 102 | "instrument_type": [], 103 | "start": [], 104 | "end": [], 105 | "duration_beats": [], 106 | "note_density": [], 107 | } 108 | if file_path is not None: 109 | data["file_path"] = [] 110 | # Check that there is a time signature. 
There might be none with abc files 111 | if len(score.time_signatures) == 0: 112 | score.time_signatures.append(TimeSignature(0, 4, 4)) 113 | 114 | try: 115 | bars_ticks = np.array(get_bars_ticks(score)) 116 | except ZeroDivisionError: 117 | print(f"Skipping, couldn't find any bar lines") 118 | return data 119 | 120 | beats_ticks = np.array(get_beats_ticks(score)) 121 | for idx, track in enumerate(score.tracks): 122 | if len(track.notes) > MAX_NOTES_PER_TRACK or len(track.notes) < MIN_NOTES_PER_TRACK: 123 | #print(f"Skipping track {idx} for length") 124 | continue 125 | # cut beats_tick at the end of the track 126 | if any(track_bars_mask := bars_ticks > track.end()): 127 | bars_ticks_track = bars_ticks[:np.nonzero(track_bars_mask)[0][0]] 128 | else: 129 | bars_ticks_track = bars_ticks 130 | 131 | # cut beats_tick at the end of the track 132 | if any(track_beats_mask := beats_ticks > track.end()): 133 | beats_ticks_track = beats_ticks[:np.nonzero(track_beats_mask)[0][0]] 134 | else: 135 | beats_ticks_track = beats_ticks 136 | 137 | if len(bars_ticks_track) > MAX_NOTES_PER_TRACK: 138 | print(f"Skipping track {idx} due to ill-formed bars") 139 | continue 140 | 141 | instrument_type = get_instrument_type(track) 142 | note_sets = compute_note_sets(track.notes, bars_ticks_track) 143 | lead_mat = calc_correlation(note_sets) 144 | _, loop_endpoints = get_valid_loops(note_sets, lead_mat, beats_ticks_track) 145 | for endpoint in loop_endpoints: 146 | loop_dict = create_loop_dict(endpoint, idx, instrument_type) 147 | for key in loop_dict.keys(): 148 | data[key].append(loop_dict[key]) 149 | if file_path is not None: 150 | data["file_path"].append(file_path) 151 | 152 | return data 153 | -------------------------------------------------------------------------------- /scripts/dataset_programs_dist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 python 2 | 3 | """ 4 | Counts the programs in the MegaMIDI dataset. 
5 | 6 | Results: 7 | Program -1 (Drums): 61400 (0.226%) 8 | Program 0 (Acoustic Grand Piano): 41079 (0.151%) 9 | Program 1 (Bright Acoustic Piano): 2511 (0.009%) 10 | Program 2 (Electric Grand Piano): 891 (0.003%) 11 | Program 3 (Honky-tonk Piano): 572 (0.002%) 12 | Program 4 (Electric Piano 1): 2156 (0.008%) 13 | Program 5 (Electric Piano 2): 1466 (0.005%) 14 | Program 6 (Harpsichord): 1727 (0.006%) 15 | Program 7 (Clavi): 621 (0.002%) 16 | Program 8 (Celesta): 407 (0.001%) 17 | Program 9 (Glockenspiel): 1167 (0.004%) 18 | Program 10 (Music Box): 405 (0.001%) 19 | Program 11 (Vibraphone): 1898 (0.007%) 20 | Program 12 (Marimba): 924 (0.003%) 21 | Program 13 (Xylophone): 440 (0.002%) 22 | Program 14 (Tubular Bells): 717 (0.003%) 23 | Program 15 (Dulcimer): 133 (0.000%) 24 | Program 16 (Drawbar Organ): 883 (0.003%) 25 | Program 17 (Percussive Organ): 874 (0.003%) 26 | Program 18 (Rock Organ): 1627 (0.006%) 27 | Program 19 (Church Organ): 520 (0.002%) 28 | Program 20 (Reed Organ): 158 (0.001%) 29 | Program 21 (Accordion): 997 (0.004%) 30 | Program 22 (Harmonica): 808 (0.003%) 31 | Program 23 (Tango Accordion): 343 (0.001%) 32 | Program 24 (Acoustic Guitar (nylon)): 3478 (0.013%) 33 | Program 25 (Acoustic Guitar (steel)): 7235 (0.027%) 34 | Program 26 (Electric Guitar (jazz)): 3503 (0.013%) 35 | Program 27 (Electric Guitar (clean)): 4416 (0.016%) 36 | Program 28 (Electric Guitar (muted)): 2767 (0.010%) 37 | Program 29 (Overdriven Guitar): 3219 (0.012%) 38 | Program 30 (Distortion Guitar): 3253 (0.012%) 39 | Program 31 (Guitar Harmonics): 299 (0.001%) 40 | Program 32 (Acoustic Bass): 3075 (0.011%) 41 | Program 33 (Electric Bass (finger)): 6316 (0.023%) 42 | Program 34 (Electric Bass (pick)): 831 (0.003%) 43 | Program 35 (Fretless Bass): 3110 (0.011%) 44 | Program 36 (Slap Bass 1): 354 (0.001%) 45 | Program 37 (Slap Bass 2): 298 (0.001%) 46 | Program 38 (Synth Bass 1): 1477 (0.005%) 47 | Program 39 (Synth Bass 2): 887 (0.003%) 48 | Program 40 (Violin): 1558 (0.006%) 49 
| Program 41 (Viola): 672 (0.002%) 50 | Program 42 (Cello): 909 (0.003%) 51 | Program 43 (Contrabass): 750 (0.003%) 52 | Program 44 (Tremolo Strings): 666 (0.002%) 53 | Program 45 (Pizzicato Strings): 2260 (0.008%) 54 | Program 46 (Orchestral Harp): 1304 (0.005%) 55 | Program 47 (Timpani): 2109 (0.008%) 56 | Program 48 (String Ensembles 1): 9898 (0.036%) 57 | Program 49 (String Ensembles 2): 3749 (0.014%) 58 | Program 50 (SynthStrings 1): 3028 (0.011%) 59 | Program 51 (SynthStrings 2): 740 (0.003%) 60 | Program 52 (Choir Aahs): 5394 (0.020%) 61 | Program 53 (Voice Oohs): 2101 (0.008%) 62 | Program 54 (Synth Voice): 1288 (0.005%) 63 | Program 55 (Orchestra Hit): 460 (0.002%) 64 | Program 56 (Trumpet): 5559 (0.020%) 65 | Program 57 (Trombone): 4676 (0.017%) 66 | Program 58 (Tuba): 2099 (0.008%) 67 | Program 59 (Muted Trumpet): 649 (0.002%) 68 | Program 60 (French Horn): 3926 (0.014%) 69 | Program 61 (Brass Section): 2340 (0.009%) 70 | Program 62 (Synth Brass 1): 994 (0.004%) 71 | Program 63 (Synth Brass 2): 449 (0.002%) 72 | Program 64 (Soprano Sax): 544 (0.002%) 73 | Program 65 (Alto Sax): 3359 (0.012%) 74 | Program 66 (Tenor Sax): 2479 (0.009%) 75 | Program 67 (Baritone Sax): 1081 (0.004%) 76 | Program 68 (Oboe): 2669 (0.010%) 77 | Program 69 (English Horn): 499 (0.002%) 78 | Program 70 (Bassoon): 2029 (0.007%) 79 | Program 71 (Clarinet): 4666 (0.017%) 80 | Program 72 (Piccolo): 1202 (0.004%) 81 | Program 73 (Flute): 4962 (0.018%) 82 | Program 74 (Recorder): 434 (0.002%) 83 | Program 75 (Pan Flute): 1034 (0.004%) 84 | Program 76 (Blown Bottle): 133 (0.000%) 85 | Program 77 (Shakuhachi): 205 (0.001%) 86 | Program 78 (Whistle): 359 (0.001%) 87 | Program 79 (Ocarina): 345 (0.001%) 88 | Program 80 (Lead 1 (square)): 1422 (0.005%) 89 | Program 81 (Lead 2 (sawtooth)): 2021 (0.007%) 90 | Program 82 (Lead 3 (calliope)): 913 (0.003%) 91 | Program 83 (Lead 4 (chiff)): 127 (0.000%) 92 | Program 84 (Lead 5 (charang)): 259 (0.001%) 93 | Program 85 (Lead 6 (voice)): 281 (0.001%) 
94 | Program 86 (Lead 7 (fifths)): 84 (0.000%) 95 | Program 87 (Lead 8 (bass + lead)): 876 (0.003%) 96 | Program 88 (Pad 1 (new age)): 931 (0.003%) 97 | Program 89 (Pad 2 (warm)): 1222 (0.004%) 98 | Program 90 (Pad 3 (polysynth)): 725 (0.003%) 99 | Program 91 (Pad 4 (choir)): 669 (0.002%) 100 | Program 92 (Pad 5 (bowed)): 260 (0.001%) 101 | Program 93 (Pad 6 (metallic)): 241 (0.001%) 102 | Program 94 (Pad 7 (halo)): 364 (0.001%) 103 | Program 95 (Pad 8 (sweep)): 508 (0.002%) 104 | Program 96 (FX 1 (rain)): 204 (0.001%) 105 | Program 97 (FX 2 (soundtrack)): 87 (0.000%) 106 | Program 98 (FX 3 (crystal)): 271 (0.001%) 107 | Program 99 (FX 4 (atmosphere)): 478 (0.002%) 108 | Program 100 (FX 5 (brightness)): 754 (0.003%) 109 | Program 101 (FX 6 (goblins)): 121 (0.000%) 110 | Program 102 (FX 7 (echoes)): 301 (0.001%) 111 | Program 103 (FX 8 (sci-fi)): 144 (0.001%) 112 | Program 104 (Sitar): 206 (0.001%) 113 | Program 105 (Banjo): 474 (0.002%) 114 | Program 106 (Shamisen): 135 (0.000%) 115 | Program 107 (Koto): 164 (0.001%) 116 | Program 108 (Kalimba): 224 (0.001%) 117 | Program 109 (Bag pipe): 87 (0.000%) 118 | Program 110 (Fiddle): 243 (0.001%) 119 | Program 111 (Shanai): 33 (0.000%) 120 | Program 112 (Tinkle Bell): 110 (0.000%) 121 | Program 113 (Agogo): 53 (0.000%) 122 | Program 114 (Steel Drums): 181 (0.001%) 123 | Program 115 (Woodblock): 143 (0.001%) 124 | Program 116 (Taiko Drum): 235 (0.001%) 125 | Program 117 (Melodic Tom): 187 (0.001%) 126 | Program 118 (Synth Drum): 362 (0.001%) 127 | Program 119 (Reverse Cymbal): 1228 (0.005%) 128 | Program 120 (Guitar Fret Noise, Guitar Cutting Noise): 346 (0.001%) 129 | Program 121 (Breath Noise, Flute Key Click): 72 (0.000%) 130 | Program 122 (Seashore, Rain, Thunder, Wind, Stream, Bubbles): 407 (0.001%) 131 | Program 123 (Bird Tweet, Dog, Horse Gallop): 85 (0.000%) 132 | Program 124 (Telephone Ring, Door Creaking, Door, Scratch, Wind Chime): 185 (0.001%) 133 | Program 125 (Helicopter, Car Sounds): 158 (0.001%) 134 | 
Program 126 (Applause, Laughing, Screaming, Punch, Heart Beat, Footstep): 185 (0.001%) 135 | Program 127 (Gunshot, Machine Gun, Lasergun, Explosion): 208 (0.001%) 136 | """ 137 | 138 | if __name__ == "__main__": 139 | from random import shuffle 140 | 141 | import numpy as np 142 | from miditok.constants import MIDI_INSTRUMENTS, SCORE_LOADING_EXCEPTION 143 | from symusic import Score 144 | from tqdm import tqdm 145 | from transformers.trainer_utils import set_seed 146 | 147 | from utils.baseline import is_score_valid, mmm 148 | 149 | set_seed(mmm.seed) 150 | 151 | NUM_FILES = 100000 152 | mmm.tokenizer.config.programs = list(range(-1, 128)) 153 | 154 | # Iterate over files 155 | dataset_files_paths = mmm.dataset_files_paths 156 | shuffle(dataset_files_paths) 157 | dataset_files_paths = dataset_files_paths[:NUM_FILES] 158 | all_programs = [] 159 | for file_path in tqdm(dataset_files_paths, desc="Analyzing files"): 160 | try: 161 | score = Score(file_path) 162 | except SCORE_LOADING_EXCEPTION: 163 | continue 164 | if is_score_valid(score): 165 | score = mmm.tokenizer.preprocess_score(score) 166 | all_programs += [ 167 | track.program if not track.is_drum else -1 for track in score.tracks 168 | ] 169 | 170 | all_programs = np.array(all_programs) 171 | for program in range(-1, 128): 172 | num_occurrences = len(np.where(all_programs == program)[0]) 173 | ratio = num_occurrences / len(all_programs) 174 | print( # noqa: T201 175 | f"Program {program} (" 176 | f"{'Drums' if program == -1 else MIDI_INSTRUMENTS[program]['name']}): " 177 | f"{num_occurrences} ({ratio:.3f}%)" 178 | ) 179 | -------------------------------------------------------------------------------- /Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Performance (Expressive MIDI Tracks)/Saarland.csv: -------------------------------------------------------------------------------- 1 | 
filepath,trackNum,instrument,isDrum,hasTimeSignatures,medianMetricDepth,tpq,velocity_per_track,onset_per_track,velocity_entropy_per_track,onset_entropy_per_track 2 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV849-01_001_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[70],[54],[2.242208320260014],[2.1877194014196313] 3 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV849-02_001_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[70],[57],[2.1944598281059347],[2.1227598342494285] 4 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV871-01_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[50],[57],[2.2215143362956855],[2.1560432773081404] 5 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV871-02_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[56],[57],[2.203766563904754],[2.1007718471971693] 6 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV875-01_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[45],[54],[2.301780586273925],[2.1838442915604994] 7 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV875-02_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[49],[52],[2.1812860620817767],[2.0979846011128274] 8 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV888-01_008_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[64],[52],[2.1082250272983973],[2.1107298161412866] 9 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV888-02_008_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[57],[56],[2.206863231520465],[2.0958697321118414] 10 | Eval-subsets/Saarland Music Data (SMD)/Bartok_SZ080-01_002_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[110],[58],[2.2161044801654106],[2.0749961898767566] 11 | Eval-subsets/Saarland Music Data (SMD)/Bartok_SZ080-02_002_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[96],[54],[2.2231056737689365],[2.1168599493420825] 12 | Eval-subsets/Saarland Music Data (SMD)/Bartok_SZ080-03_002_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[108],[58],[2.2144987525393445],[2.0118768121066197] 13 | Eval-subsets/Saarland Music Data 
(SMD)/Beethoven_Op027No1-01_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[99],[58],[2.2332564927802374],[2.1444039613100063] 14 | Eval-subsets/Saarland Music Data (SMD)/Beethoven_Op027No1-02_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[81],[55],[2.230988423853517],[2.2014043078234637] 15 | Eval-subsets/Saarland Music Data (SMD)/Beethoven_Op027No1-03_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[100],[58],[2.2316670391122067],[2.1411010733691795] 16 | Eval-subsets/Saarland Music Data (SMD)/Beethoven_Op031No2-01_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[95],[58],[2.27603091366256],[2.149934476450946] 17 | Eval-subsets/Saarland Music Data (SMD)/Beethoven_Op031No2-02_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[81],[57],[2.21160805229714],[2.2032513200533073] 18 | Eval-subsets/Saarland Music Data (SMD)/Beethoven_Op031No2-03_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[86],[58],[2.3149297700033453],[2.15484620927114] 19 | Eval-subsets/Saarland Music Data (SMD)/Beethoven_WoO080_001_20081107-SMD.mid,0,0,FALSE,TRUE,12,960,[106],[115],[2.1725154903213344],[2.0529039299803067] 20 | Eval-subsets/Saarland Music Data (SMD)/Brahms_Op005-01_002_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[110],[58],[2.253343013720334],[2.1357865892808614] 21 | Eval-subsets/Saarland Music Data (SMD)/Brahms_Op010No1_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[102],[58],[2.2754737287354354],[2.1821638405378394] 22 | Eval-subsets/Saarland Music Data (SMD)/Brahms_Op010No2_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[98],[58],[2.213394918335671],[2.2448515165885334] 23 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op010-03_007_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[88],[58],[2.1045524249432686],[2.1659579392860815] 24 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op010-04_007_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[83],[58],[2.2542927779942095],[2.159979927040184] 25 | Eval-subsets/Saarland Music Data 
(SMD)/Chopin_Op026No1_003_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[96],[58],[2.202725184192824],[2.2354584009352236] 26 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op026No2_005_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[100],[58],[2.1881458947675463],[2.11692043120291] 27 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op028-01_003_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[66],[47],[1.9361231770370437],[2.071933547679693] 28 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op028-03_003_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[62],[52],[2.236948427679256],[2.2063719586526718] 29 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op028-04_003_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[73],[50],[2.1465632974667974],[2.1917559128738] 30 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op028-11_003_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[59],[44],[2.0073681046323375],[2.286339826430251] 31 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op028-15_006_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[90],[57],[2.2794211506178277],[2.157003822283769] 32 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op028-17_005_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[89],[58],[2.151500180755954],[2.155545080515562] 33 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op029_004_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[86],[58],[2.0610677568177698],[2.1684607976033172] 34 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op048No1_007_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[95],[58],[2.203475547767666],[2.142876929559246] 35 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op066_006_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[91],[58],[2.0558183388968496],[2.1855233823119002] 36 | Eval-subsets/Saarland Music Data (SMD)/Haydn_Hob017No4_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[89],[58],[2.1828314668975484],[2.16630610134019] 37 | Eval-subsets/Saarland Music Data (SMD)/Haydn_HobXVINo52-01_008_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[88],[58],[2.2486043974740966],[2.1748447884374515] 38 | Eval-subsets/Saarland Music Data 
(SMD)/Haydn_HobXVINo52-02_008_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[90],[58],[2.2494194922620165],[2.196055535021124] 39 | Eval-subsets/Saarland Music Data (SMD)/Haydn_HobXVINo52-03_008_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[88],[58],[2.2193096916376054],[2.1728952328180213] 40 | Eval-subsets/Saarland Music Data (SMD)/Liszt_AnnesDePelerinage-LectureDante_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[113],[58],[2.2236645223841704],[2.0992465021883246] 41 | Eval-subsets/Saarland Music Data (SMD)/Liszt_KonzertetuedeNo2LaLeggierezza_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[102],[58],[2.0325195357473147],[1.8893376557365782] 42 | Eval-subsets/Saarland Music Data (SMD)/Liszt_VariationenBachmotivWeinenKlagenSorgenZagen_001_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[119],[58],[2.184494479552385],[2.127697490110233] 43 | Eval-subsets/Saarland Music Data (SMD)/Mozart_KV265_006_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[84],[58],[2.252625013291874],[2.157166107071062] 44 | Eval-subsets/Saarland Music Data (SMD)/Mozart_KV398_002_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[84],[58],[2.2460228428920783],[2.200669693579367] 45 | Eval-subsets/Saarland Music Data (SMD)/Rachmaninoff_Op036-01_007_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[103],[58],[2.1822623783379704],[2.1523931604966497] 46 | Eval-subsets/Saarland Music Data (SMD)/Rachmaninoff_Op036-02_007_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[102],[58],[2.201395213629322],[2.229464503041128] 47 | Eval-subsets/Saarland Music Data (SMD)/Rachmaninoff_Op036-03_007_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[109],[58],[2.1848708896471534],[2.1348497473142523] 48 | Eval-subsets/Saarland Music Data (SMD)/Rachmaninov_Op039No1_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[100],[58],[2.266573511283351],[2.158250339709647] 49 | Eval-subsets/Saarland Music Data (SMD)/Ravel_JeuxDEau_008_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[102],[58],[2.1628408923677926],[2.20423179159176] 50 | Eval-subsets/Saarland Music Data 
(SMD)/Ravel_ValsesNoblesEtSentimentales_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[105],[58],[2.204846504136731],[2.278582113192171] 51 | Eval-subsets/Saarland Music Data (SMD)/Skryabin_Op008No8_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[79],[57],[2.1402922752695988],[2.182203169781139] -------------------------------------------------------------------------------- /loops_nomml/process_file_fast.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Dict, Tuple, Any, List 3 | import os 4 | import numpy as np 5 | from miditok.utils import get_bars_ticks, get_beats_ticks 6 | from miditok.constants import CLASS_OF_INST, INSTRUMENT_CLASSES 7 | from symusic import Score, Track, TimeSignature 8 | 9 | from corr_mat_fast import calc_correlation, get_valid_loops 10 | from note_set_fast import compute_note_sets 11 | 12 | MAX_NOTES_PER_TRACK = 50000 13 | MIN_NOTES_PER_TRACK = 5 14 | 15 | # New: sanity threshold for tick values (should match other modules) 16 | _MAX_TICK = 10 ** 9 17 | 18 | # Optional: allow enabling numba/cuda via environment variables here as well. 19 | # Example: 20 | # export USE_NUMBA=1 21 | # export USE_CUDA=1 22 | # The corr_mat_fast module reads these env vars at import time. 23 | 24 | def get_instrument_type(track) -> str: 25 | """ 26 | Determines MIDI instrument class of a track. 
27 | 28 | :param track: A pretty_midi.Instrument object 29 | :return: name of instrument class 30 | """ 31 | if track.is_drum: 32 | return "Drums" 33 | 34 | program_number = track.program # MIDI program number (0–127) 35 | instrument_class_index = CLASS_OF_INST[program_number] 36 | instrument_class_name = INSTRUMENT_CLASSES[instrument_class_index]["name"] 37 | 38 | return instrument_class_name 39 | 40 | 41 | def create_loop_dict(endpoint_data: Tuple[int, int, float, float], track_idx: int, instrument_type: str) -> Dict[str,Any]: 42 | """ 43 | Formats loop metadata into a dictionary for dataset generation 44 | """ 45 | start, end, beats, density = endpoint_data 46 | return { 47 | "track_idx": track_idx, 48 | "instrument_type": instrument_type, 49 | "start": start, 50 | "end": end, 51 | "duration_beats": beats, 52 | "note_density": density 53 | } 54 | 55 | 56 | def detect_loops_from_path(file_info: Dict) -> Dict[str,List]: 57 | """ 58 | Given a MIDI file, locate all loops present across all of its tracks 59 | """ 60 | file_path = file_info['file_path'] 61 | if isinstance(file_path, list): 62 | file_path = file_path[0] 63 | try: 64 | score = Score(file_path) 65 | except Exception: 66 | # Unable to parse score (malformed file) -> skip file 67 | print(f"Unable to parse score for {file_path}, skipping") 68 | return { 69 | "track_idx": [], 70 | "instrument_type": [], 71 | "start": [], 72 | "end": [], 73 | "duration_beats": [], 74 | "note_density": [], 75 | } 76 | return detect_loops(score, file_path=file_path) 77 | 78 | 79 | def detect_loops(score: Score, file_path: str = None) -> Dict[str,List]: 80 | """ 81 | Given a MIDI score, locate all loops present across off the tracks 82 | """ 83 | data = { 84 | "track_idx": [], 85 | "instrument_type": [], 86 | "start": [], 87 | "end": [], 88 | "duration_beats": [], 89 | "note_density": [], 90 | } 91 | if file_path is not None: 92 | data["file_path"] = [] 93 | # Check that there is a time signature. 
There might be none with abc files 94 | if len(score.time_signatures) == 0: 95 | score.time_signatures.append(TimeSignature(0, 4, 4)) 96 | 97 | # Extract bars and beats ticks defensively 98 | try: 99 | bars_ticks_raw = get_bars_ticks(score) 100 | beats_ticks_raw = get_beats_ticks(score) 101 | except Exception: 102 | print(f"Skipping, couldn't extract bars/beats for {file_path or 'score'} due to malformed events") 103 | return data 104 | 105 | # sanitize arrays: ensure numeric, finite, reasonable magnitude 106 | try: 107 | bars_ticks = np.asarray(bars_ticks_raw, dtype=np.int64) 108 | bars_ticks = bars_ticks[np.isfinite(bars_ticks)] 109 | bars_ticks = bars_ticks[(bars_ticks >= 0) & (bars_ticks <= _MAX_TICK)] 110 | except Exception: 111 | bars_ticks = np.array([], dtype=np.int64) 112 | 113 | try: 114 | beats_ticks = np.asarray(beats_ticks_raw, dtype=np.int64) 115 | beats_ticks = beats_ticks[np.isfinite(beats_ticks)] 116 | beats_ticks = beats_ticks[(beats_ticks >= 0) & (beats_ticks <= _MAX_TICK)] 117 | except Exception: 118 | beats_ticks = np.array([], dtype=np.int64) 119 | 120 | for idx, track in enumerate(score.tracks): 121 | # Basic track length checks 122 | try: 123 | n_notes = len(track.notes) 124 | except Exception: 125 | # malformed track object: skip 126 | continue 127 | if n_notes > MAX_NOTES_PER_TRACK or n_notes < MIN_NOTES_PER_TRACK: 128 | continue 129 | 130 | # cut bars_ticks at the end of the track defensively 131 | try: 132 | track_end = int(track.end()) 133 | except Exception: 134 | # malformed end time: skip track 135 | continue 136 | 137 | if bars_ticks.size and np.any(bars_ticks > track_end): 138 | try: 139 | bars_ticks_track = bars_ticks[:np.nonzero(bars_ticks > track_end)[0][0]] 140 | except Exception: 141 | bars_ticks_track = bars_ticks 142 | else: 143 | bars_ticks_track = bars_ticks 144 | 145 | if beats_ticks.size and np.any(beats_ticks > track_end): 146 | try: 147 | beats_ticks_track = beats_ticks[:np.nonzero(beats_ticks > track_end)[0][0]] 148 | 
except Exception: 149 | beats_ticks_track = beats_ticks 150 | else: 151 | beats_ticks_track = beats_ticks 152 | 153 | if len(bars_ticks_track) > MAX_NOTES_PER_TRACK: 154 | # ill-formed bars for this track 155 | continue 156 | 157 | # instrument type 158 | try: 159 | instrument_type = get_instrument_type(track) 160 | except Exception: 161 | instrument_type = "Unknown" 162 | 163 | # Compute note sets with defensive handling of malformed note events 164 | try: 165 | note_sets = compute_note_sets(track.notes, bars_ticks_track) 166 | except Exception: 167 | # compute_note_sets failed due to malformed notes -> skip track 168 | continue 169 | 170 | # If compute_note_sets returned nothing and there are no bars, skip 171 | if not note_sets and (bars_ticks_track.size == 0): 172 | continue 173 | 174 | # Defensive: ensure note_sets entries look sane 175 | bad_ns = False 176 | for ns in note_sets: 177 | try: 178 | if ns.start is None or ns.end is None: 179 | bad_ns = True 180 | break 181 | if ns.start < 0 or ns.end < 0 or ns.start > _MAX_TICK or ns.end > _MAX_TICK: 182 | bad_ns = True 183 | break 184 | except Exception: 185 | bad_ns = True 186 | break 187 | if bad_ns: 188 | # skip track with malformed note sets 189 | continue 190 | 191 | # Compute correlation and loops with try/except so a single bad track doesn't crash everything 192 | try: 193 | lead_mat = calc_correlation(note_sets) 194 | except Exception: 195 | # correlation failed (malformed note_sets) -> skip track 196 | continue 197 | 198 | try: 199 | _, loop_endpoints = get_valid_loops(note_sets, lead_mat, beats_ticks_track) 200 | except Exception: 201 | # loop detection failed -> skip track 202 | continue 203 | 204 | for endpoint in loop_endpoints: 205 | loop_dict = create_loop_dict(endpoint, idx, instrument_type) 206 | for key in loop_dict.keys(): 207 | data[key].append(loop_dict[key]) 208 | if file_path is not None: 209 | data["file_path"].append(file_path) 210 | 211 | return data 
-------------------------------------------------------------------------------- /scripts/test_extract.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pretty_midi\n", 10 | "from corr_mat import calc_correlation, get_valid_loops\n", 11 | "from track import Track\n", 12 | "from util import get_instrument_type, create_loop_dict\n", 13 | "import os\n", 14 | "from tqdm import tqdm" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "test_file = \"./test_midi.mid\"\n", 24 | "pm = pretty_midi.PrettyMIDI(test_file)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "found 1 loops in Instrument(program=0, is_drum=True, name=\"GOT TO HAVE YOUR LOVE \")\n", 37 | "0 18.113216000000005 22.641519999999996 (4, 4) 8.0 DRUM 3.875\n", 38 | "found 1 loops in Instrument(program=39, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n", 39 | "1 9.056607999999999 13.584912000000005 (4, 4) 8.0 BASS 2.0\n", 40 | "found 1 loops in Instrument(program=4, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n", 41 | "2 33.96227999999998 43.01888799999996 (4, 4) 16.0 PIANO 0.5625\n", 42 | "found 1 loops in Instrument(program=48, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n", 43 | "3 24.905671999999992 29.433975999999983 (4, 4) 8.0 ENSEMBLE 0.75\n", 44 | "found 1 loops in Instrument(program=48, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n", 45 | "4 24.905671999999992 29.433975999999983 (4, 4) 8.0 ENSEMBLE 0.75\n", 46 | "found 1 loops in Instrument(program=28, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n", 47 | "5 27.169823999999988 31.69812799999998 (4, 4) 8.0 GUITAR 1.75\n", 48 
| "found 0 loops in Instrument(program=2, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n", 49 | "found 0 loops in Instrument(program=82, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n", 50 | "6 total loops in 8 tracks\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "final_loops = []\n", 56 | "for idx, instrument in enumerate(pm.instruments):\n", 57 | " instrument_type = get_instrument_type(instrument)\n", 58 | " track = Track(pm, instrument)\n", 59 | " note_list = track.notes\n", 60 | " lead_mat, lead_dur = calc_correlation(note_list)\n", 61 | " loops, loop_endpoints = get_valid_loops(track, lead_mat, lead_dur)\n", 62 | " print(f\"found {len(loops)} loops in {instrument}\")\n", 63 | "\n", 64 | " for endpoint in loop_endpoints:\n", 65 | " start, end, beats, density = endpoint\n", 66 | " time_sig = track.get_time_sig_at_time(start)\n", 67 | " print(idx, start, end, time_sig, beats, instrument_type, density)\n", 68 | " for loop in loops:\n", 69 | " final_loops.append(loop)\n", 70 | "print(f\"{len(final_loops)} total loops in {len(pm.instruments)} tracks\")" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 4, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "def run_file(file_path, name):\n", 80 | " try:\n", 81 | " pm = pretty_midi.PrettyMIDI(file_path)\n", 82 | " except:\n", 83 | " print(f\"failed to parse {file_path}, skipping\")\n", 84 | " return 0,0,[]\n", 85 | " \n", 86 | " total_loops = 0\n", 87 | " loops = []\n", 88 | " for idx, instrument in enumerate(pm.instruments):\n", 89 | " instrument_type = get_instrument_type(instrument)\n", 90 | " track = Track(pm, instrument)\n", 91 | " note_list = track.notes\n", 92 | " lead_mat, lead_dur = calc_correlation(note_list)\n", 93 | " full_loops, loop_endpoints = get_valid_loops(track, lead_mat, lead_dur)\n", 94 | " for endpoint in loop_endpoints:\n", 95 | " time_sig = track.get_time_sig_at_time(endpoint[0])\n", 96 | " if time_sig is None:\n", 97 | " continue\n", 98 | " loop_dict = 
create_loop_dict(endpoint, idx, instrument_type, time_sig, name)\n", 99 | " loops.append(loop_dict)\n", 100 | " for loop_list in full_loops:\n", 101 | " if loop_list[0].duration != 0 or loop_list[-1].duration != 0:\n", 102 | " loop_list[0]\n", 103 | " print(name, loop_list[0], loop_list[-1])\n", 104 | " total_loops += len(loop_endpoints)\n", 105 | " return total_loops, len(pm.instruments), loops" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 5, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "data": { 115 | "text/plain": [ 116 | "25050" 117 | ] 118 | }, 119 | "execution_count": 5, 120 | "metadata": {}, 121 | "output_type": "execute_result" 122 | } 123 | ], 124 | "source": [ 125 | "full_directory = \"D:\\\\Documents\\\\GigaMIDI\\\\Final_GigaMIDI_TISMIR\\\\Validatation-10%\\\\GigaMIDI-Val-Drum+Music-MD5\"\n", 126 | "len(os.listdir(full_directory))" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 6, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stderr", 136 | "output_type": "stream", 137 | "text": [ 138 | " 0%| | 0/100 [00:00 0:\n", 185 | " all_loops.append(loop)\n", 186 | "print(f\"Found {total_loops} loops in {total_tracks} tracks across {num_files} files\")" 187 | ] 188 | } 189 | ], 190 | "metadata": { 191 | "kernelspec": { 192 | "display_name": ".venv", 193 | "language": "python", 194 | "name": "python3" 195 | }, 196 | "language_info": { 197 | "codemirror_mode": { 198 | "name": "ipython", 199 | "version": 3 200 | }, 201 | "file_extension": ".py", 202 | "mimetype": "text/x-python", 203 | "name": "python", 204 | "nbconvert_exporter": "python", 205 | "pygments_lexer": "ipython3", 206 | "version": "3.10.5" 207 | } 208 | }, 209 | "nbformat": 4, 210 | "nbformat_minor": 2 211 | } 212 | -------------------------------------------------------------------------------- /MIDI-GPT-Loop/README.md: -------------------------------------------------------------------------------- 1 | # Loop 
Generation 2 | ## Generating and evaluating loops using MIDI-GPT 3 | 4 | Here we present the scripts for loop generation as well as evaluation scripts used for our paper. 5 | 6 | ### MMM 7 | 8 | This repository is the library to run and train our MIDI-GPT-Loop model. It is a fork of the `MMM` repository from the `Metacreation Lab` adapted for training and inference with loops. This library also uses a fork of `MidiTok` allowing the integration of loop tokens. 9 | 10 | The main entry point is `inference.py`, where `generate` and `generate_batch`, may be used for MIDI generation 11 | 12 | ### MMM-Loop 13 | 14 | This library holds the scripts used to generate the synthetic MIDI data, and evaluate the accuracy of the generation NOMML controls and the generated loops. 15 | 16 | # Installation and Setup 17 | 18 | > **Note:** The following scripts were created for computing on [Compute Canada](https://docs.alliancecan.ca/). Therefore, slight modifications (file organization, module imports, environment variables) may be needed in order to run these scripts elsewhere. 19 | 20 | The dependencies installed are those needed for inference as well as training, therefore some may not be needed if you do not wish to train a model. The following setup is identical for both submodules, MMM and MMM-Loop 21 | 22 | ## On Compute Canada 23 | 24 | On **Canada Canada**, load the correct modules and install dependencies: 25 | 26 | ```bash 27 | cd MMM-Loop 28 | bash slurm/install.sh 29 | ``` 30 | 31 | ## Elsewhere 32 | 33 | ### 🧰 Dependencies 34 | 35 | The code depends on several Python packages, some of which may require special installation steps or system libraries. 
36 | 37 | #### 📦 Key Python Dependencies 38 | 39 | | Package | Version | Notes | 40 | | ------------------------------------------- | ---------- | --------------------------------------------------------- | 41 | | `python` | 3.11 | Required for compatibility | 42 | | `symusic` | 0.5.0 | For symbolic music representations | 43 | | `MidiTok` | expressive | Custom fork installed from GitHub | 44 | | `transformers` | 4.49.0 | Hugging Face Transformers | 45 | | `accelerate` | 1.4.0 | Hugging Face Accelerate | 46 | | `tensorboard` | 2.19.0 | TensorBoard for logging | 47 | | `flash-attn` | 2.5.7 | May require building from source (see instructions below) | 48 | | `deepspeed` | 0.14.4 | For model parallelism and training acceleration | 49 | | `datasets` | 3.3.2 | Hugging Face Datasets | 50 | | `triton` | 3.1.0 | Required for some GPU optimizations | 51 | | `nvitop`, `pandas`, `matplotlib`, `sklearn` | Latest | Utility and visualization tools | 52 | 53 | ### 💻 Installation Instructions 54 | 55 | 1. Create the virtual environment 56 | 57 | Use Python 3.11: 58 | 59 | ```bash 60 | cd MMM 61 | ``` 62 | 63 | or 64 | 65 | ```bash 66 | cd MMM-Loop 67 | ``` 68 | 69 | ```bash 70 | python3.11 -m venv .venv 71 | source .venv/bin/activate 72 | ``` 73 | 74 | If `python3.11` is not available, install it via pyenv or your system's package manager. 75 | 76 | 2. 
Install dependencies 77 | 78 | ```bash 79 | pip install --upgrade pip 80 | 81 | # Required packages 82 | pip install symusic==0.5.0 83 | pip install git+https://github.com/DaoTwenty/MidiTok@expressive 84 | pip install transformers==4.49.0 accelerate==1.4.0 tensorboard==2.19.0 85 | pip install deepspeed==0.14.4 86 | pip install datasets==3.3.2 87 | pip install triton==3.1.0 88 | pip install nvitop pandas matplotlib scikit-learn 89 | ``` 90 | 91 | ### ⚡ Installing FlashAttention (Optional but Recommended) 92 | 93 | FlashAttention often provides significant speedups for training Transformer models, but may require a manual installation from source depending on your system and GPU. 94 | 95 | 1. Clone FlashAttention 96 | 97 | ```bash 98 | git clone https://github.com/Dao-AILab/flash-attention.git 99 | cd flash-attention 100 | git checkout v2.5.7 101 | ``` 102 | 103 | 2. Install with `pip` 104 | 105 | ```bash 106 | pip install . 107 | ``` 108 | 109 | > **Note:** You may need to have the following: 110 | 111 | - A CUDA-capable GPU (Compute Capability ≥ 7.5) 112 | - CUDA toolkit ≥ 11.8 113 | - Compatible PyTorch version (typically latest stable) 114 | - Refer to FlashAttention's official documentation for details. 115 | 116 | ### 📊 Verifying Installation 117 | 118 | Run the following to verify installed versions: 119 | 120 | ```bash 121 | pip freeze 122 | ``` 123 | 124 | ### 💬 Notes 125 | 126 | - On Compute Canada, system modules like `gcc`, `arrow`, and `rust` were required. These are **not needed** if you can build FlashAttention from source locally. 127 | - If `arrow` or `rust` are required by specific packages, ensure they are available on your system (`brew`, `apt`, or via `conda`). 
128 | 129 | # 🎶 Usage 130 | 131 | First, download the model via [https://1sfu-my.sharepoint.com/:u:/g/personal/pta63_sfu_ca/EbpBz06rnaJMtirT0DqvTFoBQSC2OqJ_gex88fenJ60CQQ?e=o24IMz](https://1sfu-my.sharepoint.com/:u:/g/personal/pta63_sfu_ca/EbpBz06rnaJMtirT0DqvTFoBQSC2OqJ_gex88fenJ60CQQ?e=o24IMz) and place it in ``MMM-Loop/models`` 132 | 133 | ## Loop Generation 134 | 135 | ### 🧠 Environment Variables 136 | 137 | The script (`slurm/ge_synthetic_data.sh`) expects these repositories to be present: 138 | 139 | ```bash 140 | export PYTHONPATH=$PYTHONPATH:/path/to/MMM:/path/to/MMM-Loop 141 | ``` 142 | > Replace `/path/to/MMM`and `/path/to/MMM-Loop` with the actual path where the repository is cloned. 143 | 144 | ### 🛠 Running Locally (Without SLURM) 145 | 146 | The SLURM batch loop can be mimicked using a local shell script or Python launcher. Here’s a basic loop for **manual local use**: 147 | 148 | ```bash 149 | #!/bin/bash 150 | 151 | source .venv/bin/activate 152 | 153 | MODEL=MIDI-GPT-Loop-model 154 | TOKENIZER=MMM_epl_mistral_tokenizer_with_acs.json 155 | CONFIG=slurm/gen_config.json 156 | NOMML=0 157 | NUM_GEN=1000 158 | BATCH=12 159 | OUTPUT="./SYNTHETIC_DATA" 160 | LABEL="V1" 161 | 162 | python -m src.generate_loops \ 163 | --config $CONFIG \ 164 | --model models/$MODEL \ 165 | --tokenizer models/$TOKENIZER \ 166 | --num_generations $NUM_GEN \ 167 | --max_attempts $NUM_GEN \ 168 | --batch $BATCH \ 169 | --nomml $EFFECTIVE_NOMML \ 170 | --output_folder $OUTPUT \ 171 | --label $LABEL \ 172 | --rank 0 & 173 | ``` 174 | 175 | ## 📈 Loop Evaluation 176 | 177 | ### 📝 Script Overview 178 | 179 | The SLURM script (`slurm/eval_loops/sh`): 180 | - Aggregates CSV result files (`RESULTS_*.csv`) from a directory. 181 | - Merges them into a single `RESULTS.csv`. 182 | - Runs the evaluation script via Python: `src.evaluate`. 
183 | 184 | ```bash 185 | #!/bin/bash 186 | 187 | SOURCE="./SYNTHETIC_DATA" 188 | LABEL="V1" 189 | OUTFILE="$SOURCE/$LABEL/RESULTS.csv" 190 | 191 | mkdir -p "$SOURCE/$LABEL" 192 | mkdir -p "$SOURCE/$LABEL/MIDI" 193 | 194 | # Clear existing merged file 195 | > "$OUTFILE" 196 | 197 | first=1 198 | for file in "$SOURCE/$LABEL"/RESULTS_*.csv; do 199 | if [ -f "$file" ]; then 200 | echo "Processing $file" 201 | if [ $first -eq 1 ]; then 202 | cat "$file" >> "$OUTFILE" 203 | first=0 204 | else 205 | tail -n +2 "$file" >> "$OUTFILE" 206 | fi 207 | fi 208 | done 209 | 210 | # Run evaluation script 211 | python -m src.evaluate --source "$SOURCE/$LABEL" 212 | ``` 213 | 214 | ## Visualisation 215 | 216 | Create Cross-entropy graph of the evaluation (`slurm/plot_results.sh`) 217 | 218 | ```bash 219 | #!/bin/bash 220 | 221 | SOURCE="../SYNTHETIC_DATA" 222 | LABEL="V1" 223 | INPUT_DIR="$SOURCE/$LABEL" 224 | 225 | PLOT_ARGS=" \ 226 | --input $INPUT_DIR \ 227 | --output $INPUT_DIR \ 228 | " 229 | 230 | source .venv/bin/activate 231 | 232 | python -m src.plot_results $PLOT_ARGS 233 | ``` 234 | 235 | ### 🧾 Output 236 | 237 | Results are merged to: `SYNTHETIC_DATA/V1/RESULTS.csv` -------------------------------------------------------------------------------- /Expressive music loop detector-NOMML12.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "id": "e65838ae", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "1094620" 13 | ] 14 | }, 15 | "execution_count": 5, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "len(file_paths)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "0dee9022", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import os\n", 32 | "import numpy as np\n", 33 | "import pandas as pd\n", 34 | "import ast\n", 35 
| "from symusic import Score\n", 36 | "from loops_nomml.note_set import compute_note_sets\n", 37 | "import loops_nomml.corr_mat as corr\n", 38 | "from loops_nomml.corr_mat import get_valid_loops\n", 39 | "from joblib import Parallel, delayed\n", 40 | "from tqdm.notebook import tqdm\n", 41 | "\n", 42 | "# ─── Patch safe get_duration_beats ────────────────────────────────────────\n", 43 | "def safe_get_duration_beats(start: int, end: int, ticks_beats: list[int]) -> float:\n", 44 | " i0 = max((i for i, t in enumerate(ticks_beats) if t <= start), default=0)\n", 45 | " i1 = max((i for i, t in enumerate(ticks_beats) if t <= end), default=i0)\n", 46 | " return float(i1 - i0)\n", 47 | "corr.get_duration_beats = safe_get_duration_beats\n", 48 | "\n", 49 | "# ─── Constants ─────────────────────────────────────────────────────────────\n", 50 | "GM_GROUPS = [\n", 51 | " 'Piano','Chromatic Percussion','Organ','Guitar',\n", 52 | " 'Bass','Strings','Ensemble','Brass',\n", 53 | " 'Reed','Pipe','Synth Lead','Synth Pad',\n", 54 | " 'Synth Effects','Ethnic','Percussive','Sound Effects'\n", 55 | "]\n", 56 | "DRUM_GROUP = 'Drums'\n", 57 | "\n", 58 | "# ─── similarity + soft-count ──────────────────────────────────────────────\n", 59 | "def note_similarity(a, b, v_a, v_b, w_p=0.5, w_v=0.3, w_t=0.2, max_time_diff=0.05):\n", 60 | " p = len(a.pitches & b.pitches) / max(1, len(a.pitches | b.pitches))\n", 61 | " v = 1 - abs(v_a - v_b) / 127\n", 62 | " t = np.exp(-abs(a.start - b.start) / max_time_diff)\n", 63 | " return w_p*p + w_v*v + w_t*t\n", 64 | "\n", 65 | "def calc_correlation_soft_count(ns, vel_means, tau):\n", 66 | " N = len(ns)\n", 67 | " C = np.zeros((N, N), dtype=int)\n", 68 | " for j in range(1, N):\n", 69 | " if note_similarity(ns[0], ns[j], vel_means[0], vel_means[j]) >= tau and ns[0].is_barline():\n", 70 | " C[0, j] = 1\n", 71 | " for i in range(1, N-1):\n", 72 | " for j in range(i+1, N):\n", 73 | " sim = note_similarity(ns[i], ns[j], vel_means[i], vel_means[j])\n", 74 | " if 
sim >= tau and (C[i-1, j-1] > 0 or ns[i].is_barline()):\n", 75 | " C[i, j] = C[i-1, j-1] + 1\n", 76 | " return C\n", 77 | "\n", 78 | "# ─── loopability score ───────────────────────────────────────────────────\n", 79 | "def score_loopability(ns, vel_means, tau, alpha=0.7, beta=0.3):\n", 80 | " C = calc_correlation_soft_count(ns, vel_means, tau)\n", 81 | " N = len(ns)\n", 82 | " if N < 2:\n", 83 | " return 0.0\n", 84 | " S_max = C.max() / N\n", 85 | " S_den = C.sum() / (N*(N-1)/2)\n", 86 | " return alpha * S_max + beta * S_den\n", 87 | "\n", 88 | "# ─── Process one file ─────────────────────────────────────────────────────\n", 89 | "def process_file(path, melodic_tau=0.3, drum_tau=0.1):\n", 90 | " loops = []\n", 91 | " try:\n", 92 | " score = Score(path, ttype='tick')\n", 93 | " try:\n", 94 | " beat_ticks = score.beat_ticks()\n", 95 | " except:\n", 96 | " ppq = getattr(score, 'ticks_per_quarter', getattr(score, 'ppq', 480))\n", 97 | " beat_ticks = list(range(0, score.end()+1, ppq))\n", 98 | " bars = [beat_ticks[i] for i in range(0, len(beat_ticks), 4)]\n", 99 | "\n", 100 | " for ti, track in enumerate(score.tracks):\n", 101 | " is_drum = getattr(track, 'channel', None) == 9\n", 102 | " tau = drum_tau if is_drum else melodic_tau\n", 103 | "\n", 104 | " prog = getattr(track, 'program', None)\n", 105 | " if \"drums-only\" in path:\n", 106 | " group = DRUM_GROUP\n", 107 | " else:\n", 108 | " group = DRUM_GROUP if is_drum else (GM_GROUPS[prog // 8] if prog is not None else 'Unknown')\n", 109 | "\n", 110 | " ns = compute_note_sets(track.notes, bars)\n", 111 | " if len(ns) < 2:\n", 112 | " continue\n", 113 | " vel_means = [\n", 114 | " float(np.mean([n.velocity for n in track.notes\n", 115 | " if n.start == nset.start and n.end == nset.end]))\n", 116 | " if any(n.start == nset.start and n.end == nset.end for n in track.notes)\n", 117 | " else 0.0\n", 118 | " for nset in ns\n", 119 | " ]\n", 120 | "\n", 121 | " loopability = score_loopability(ns, vel_means, tau)\n", 122 | " 
C = calc_correlation_soft_count(ns, vel_means, tau)\n", 123 | " try:\n", 124 | " _, endpoints = get_valid_loops(\n", 125 | " ns, C, beat_ticks,\n", 126 | " min_rep_notes=0,\n", 127 | " min_rep_beats=1.0 if not is_drum else 0.5,\n", 128 | " min_beats=1.0 if not is_drum else 0.5,\n", 129 | " max_beats=32.0,\n", 130 | " min_loop_note_density=0.0\n", 131 | " )\n", 132 | " except IndexError:\n", 133 | " continue\n", 134 | "\n", 135 | " for start, end, dur, dens in endpoints:\n", 136 | " loops.append({\n", 137 | " 'track_idx': ti,\n", 138 | " 'MIDI program number': prog,\n", 139 | " 'instrument_group': group,\n", 140 | " 'loopability': loopability,\n", 141 | " 'start_tick': start,\n", 142 | " 'end_tick': end,\n", 143 | " 'duration_beats': dur,\n", 144 | " 'note_density': dens\n", 145 | " })\n", 146 | " except Exception as e:\n", 147 | " print(f\"[Error] {os.path.basename(path)}: {e}\")\n", 148 | " return loops\n", 149 | "\n", 150 | "# ─── 1) Load CSV & select only rows whose NOMML list contains a 12 ──────────\n", 151 | "df_input = pd.read_csv(\n", 152 | " \"Final_GigaMIDI_Loop_V2_path-instrument-NOMML-type.csv\",\n", 153 | " converters={'NOMML': ast.literal_eval}\n", 154 | ")\n", 155 | "\n", 156 | "# keep rows where the NOMML list has at least one 12\n", 157 | "df_input = df_input[df_input['NOMML'].apply(lambda lst: isinstance(lst, (list,tuple)) and 12 in lst)]\n", 158 | "\n", 159 | "file_paths = df_input['file_path'].tolist()\n", 160 | "\n", 161 | "# ─── 2) Chunk size 100,000 for checkpoint ───────────────────────────────────\n", 162 | "chunk_size = 100000\n", 163 | "\n", 164 | "# ─── 3) Process in chunks, checkpoint each chunk ───────────────────────────\n", 165 | "all_rows = []\n", 166 | "for idx in range(0, len(file_paths), chunk_size):\n", 167 | " chunk = file_paths[idx: idx + chunk_size]\n", 168 | " results = Parallel(n_jobs=-1, backend='loky')(\n", 169 | " delayed(process_file)(p) for p in tqdm(chunk, desc=f\"Files {idx+1}-{idx+len(chunk)}\")\n", 170 | " )\n", 
171 | "\n", 172 | " # organize one row per file, unpacking loops into parallel arrays\n", 173 | " rows = []\n", 174 | " for path, loops in zip(chunk, results):\n", 175 | " rows.append({\n", 176 | " 'file_path': path,\n", 177 | " 'track_idx': [d['track_idx'] for d in loops],\n", 178 | " 'MIDI program number': [d['MIDI program number'] for d in loops],\n", 179 | " 'instrument_group': [d['instrument_group'] for d in loops],\n", 180 | " 'loopability': [d['loopability'] for d in loops],\n", 181 | " 'start_tick': [d['start_tick'] for d in loops],\n", 182 | " 'end_tick': [d['end_tick'] for d in loops],\n", 183 | " 'duration_beats': [d['duration_beats'] for d in loops],\n", 184 | " 'note_density': [d['note_density'] for d in loops]\n", 185 | " })\n", 186 | " df_chunk = pd.DataFrame(rows)\n", 187 | "\n", 188 | " # save checkpoint\n", 189 | " checkpoint = f\"loops_checkpoint_{idx//chunk_size + 1}.csv\"\n", 190 | " df_chunk.to_csv(checkpoint, index=False)\n", 191 | " print(f\"Saved checkpoint: {checkpoint}\")\n", 192 | "\n", 193 | " all_rows.extend(rows)\n", 194 | "\n", 195 | "# ─── 4) Final combined DataFrame ─────────────────────────────────────────────\n", 196 | "df_all = pd.DataFrame(all_rows)\n", 197 | "\n", 198 | "# ─── 5) Save the full output to CSV ─────────────────────────────────────────\n", 199 | "df_all.to_csv(\"loops_full_output.csv\", index=False)\n", 200 | "print(\"Saved full output: loops_full_output.csv\")\n" 201 | ] 202 | } 203 | ], 204 | "metadata": { 205 | "kernelspec": { 206 | "display_name": "Python 3 (ipykernel)", 207 | "language": "python", 208 | "name": "python3" 209 | }, 210 | "language_info": { 211 | "codemirror_mode": { 212 | "name": "ipython", 213 | "version": 3 214 | }, 215 | "file_extension": ".py", 216 | "mimetype": "text/x-python", 217 | "name": "python", 218 | "nbconvert_exporter": "python", 219 | "pygments_lexer": "ipython3", 220 | "version": "3.9.12" 221 | } 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 5 225 | } 226 | 
-------------------------------------------------------------------------------- /GigaMIDI/create_gigamidi_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 python 2 | 3 | """Script to create the WebDataset for the GigaMIDI dataset.""" 4 | 5 | from __future__ import annotations 6 | 7 | import json 8 | from typing import TYPE_CHECKING 9 | 10 | from datasets import Dataset 11 | from huggingface_hub import create_branch, upload_file 12 | from tqdm import tqdm 13 | from webdataset import ShardWriter 14 | 15 | from utils.GigaMIDI.GigaMIDI import _SPLITS 16 | 17 | if TYPE_CHECKING: 18 | from pathlib import Path 19 | 20 | from datasets import DatasetDict 21 | 22 | 23 | MAX_NUM_ENTRIES_PER_SHARD = 50000 24 | SUBSET_PATHS = { 25 | "all-instruments-with-drums": "drums+music", 26 | "drums-only": "drums", 27 | "no-drums": "music", 28 | } 29 | 30 | 31 | def create_webdataset_gigamidi(main_data_dir_path: Path) -> None: 32 | """ 33 | Create the WebDataset shard for the GigaMIDI dataset. 34 | 35 | :param main_data_dir_path: path of the directory containing the datasets. 
36 | """ 37 | dataset_path = main_data_dir_path / "GigaMIDI_original" 38 | webdataset_path = main_data_dir_path / "GigaMIDI" 39 | 40 | # Load metadata 41 | md5_sid_matches_scores = {} 42 | with ( 43 | main_data_dir_path / "MMD_METADATA" / "MMD_audio_matches.tsv" 44 | ).open() as matches_file: 45 | matches_file.seek(0) 46 | next(matches_file) # first line skipped 47 | for line in tqdm(matches_file, desc="Reading MMD match file"): 48 | midi_md5, score, audio_sid = line.split() 49 | if midi_md5 not in md5_sid_matches_scores: 50 | md5_sid_matches_scores[midi_md5] = [] 51 | md5_sid_matches_scores[midi_md5].append((audio_sid, float(score))) 52 | sid_to_mbid = json.load( 53 | (main_data_dir_path / "MMD_METADATA" / "MMD_sid_to_mbid.json").open() 54 | ) 55 | 56 | md5_genres = {} 57 | with ( 58 | main_data_dir_path / "MMD_METADATA" / "MMD_audio_matched_genre.jsonl" 59 | ).open() as file: 60 | for row in tqdm(file, desc="Reading genres MMD metadata"): 61 | entry = json.loads(row) 62 | md5 = entry.pop("md5") 63 | md5_genres[md5] = entry 64 | md5_genres_scraped = {} 65 | with ( 66 | main_data_dir_path / "MMD_METADATA" / "MMD_scraped_genre.jsonl" 67 | ).open() as file: 68 | for row in tqdm(file, desc="Reading scraped genres MMD metadata"): 69 | entry = json.loads(row) 70 | genres = [] 71 | for genre_list in entry["genre"]: 72 | genres += genre_list 73 | md5_genres_scraped[entry["md5"]] = genres 74 | md5_artist_title_scraped = {} 75 | with ( 76 | main_data_dir_path / "MMD_METADATA" / "MMD_scraped_title_artist.jsonl" 77 | ).open() as file: 78 | for row in tqdm(file, desc="Reading scraped titles/artists MMD metadata"): 79 | entry = json.loads(row) 80 | md5_artist_title_scraped[entry["md5"]] = entry["title_artist"][0] 81 | md5_expressive = {} 82 | with ( 83 | dataset_path / "Expressive_Performance_Detection_NOMML_gigamidi_tismir.csv" 84 | ).open() as file: 85 | file.seek(0) 86 | next(file) # skipping first row (header) 87 | for row in tqdm(file, desc="Reading expressiveness 
metadata"): 88 | parts = row.split(",") 89 | md5 = parts[0].split("/")[-1].split(".")[0] 90 | if md5 not in md5_expressive: 91 | md5_expressive[md5] = [] 92 | md5_expressive[md5].append(int(parts[5])) 93 | """md5_loop = {} 94 | with ( 95 | dataset_path 96 | / "GigaMIDI-combined-non-expressive-loop-data" 97 | / "Expressive_Performance_Detection_NOMML_gigamidi_tismir.csv" 98 | ).open() as file: 99 | file.seek(0) 100 | next(file) # skipping first row (header) 101 | for row in tqdm(file, desc="Reading loops metadata"): 102 | parts = row.split(",") 103 | md5 = parts[0].split("/")[-1].split(".")[0] 104 | if md5 not in md5_loop: 105 | md5_loop[md5] = {} 106 | track_idx = parts[1] 107 | if track_idx not in md5_loop[md5]: 108 | md5_loop[md5][track_idx] = [] 109 | md5_loop[md5][track_idx].append(int(parts[5]))""" 110 | 111 | # Sharding the data into tar archives 112 | num_shards = {} 113 | for subset, subset_path in SUBSET_PATHS.items(): 114 | num_shards[subset] = {} 115 | for split in _SPLITS: 116 | files_paths = list((dataset_path / split / subset_path).glob("**/*.mid")) 117 | save_path = webdataset_path / subset / split 118 | save_path.mkdir(parents=True, exist_ok=True) 119 | metadata = {} 120 | with ShardWriter( 121 | f"{save_path!s}/GigaMIDI_{subset}_{split}_%04d.tar", 122 | maxcount=MAX_NUM_ENTRIES_PER_SHARD, 123 | ) as writer: 124 | for file_path in files_paths: 125 | md5 = file_path.stem 126 | example = { 127 | "__key__": md5, 128 | "mid": file_path.open("rb").read(), # bytes 129 | } 130 | writer.write(example) 131 | 132 | # Get metadata if existing 133 | metadata_row = {} 134 | 135 | sid_matches = md5_sid_matches_scores.get(md5) 136 | if sid_matches: 137 | metadata_row["sid_matches"] = sid_matches 138 | metadata_row["mbid_matches"] = [] 139 | for sid, _ in sid_matches: 140 | mbids = sid_to_mbid.get(sid, None) 141 | if mbids: 142 | metadata_row["mbid_matches"].append([sid, mbids]) 143 | 144 | title_artist = md5_artist_title_scraped.get(md5) 145 | if title_artist: 146 
| ( 147 | metadata_row["title_scraped"], 148 | metadata_row["artist_scraped"], 149 | ) = title_artist 150 | genres_scraped = md5_genres_scraped.get(md5) 151 | if genres_scraped: 152 | metadata_row["genres_scraped"] = genres_scraped 153 | genres = md5_genres.get(md5) 154 | if genres: 155 | for key, val in genres.items(): 156 | metadata_row[f"genres_{key.split('_')[1]}"] = val 157 | interpreted_scores = md5_expressive.get(md5) 158 | if interpreted_scores: 159 | metadata_row["median_metric_depth"] = interpreted_scores 160 | # TODO loops 161 | if len(metadata_row) > 0: 162 | metadata[md5] = metadata_row 163 | 164 | num_shards[subset][split] = len(list(save_path.glob("*.tar"))) 165 | # Saving metadata for this subset and split 166 | """ 167 | with (save_path.parent / f"metadata_{subset}_{split}.csv").open("w") as f: 168 | writer = csv.writer(f) 169 | writer.writerow(["md5", ]) # TODO header 170 | for row in metadata: 171 | writer.writerow([row])""" 172 | with (save_path.parent / f"metadata_{subset}_{split}.json").open("w") as f: 173 | json.dump(metadata, f) 174 | 175 | # Saving n shards 176 | with (webdataset_path / "n_shards.json").open("w") as f: 177 | json.dump(num_shards, f, indent=4) 178 | 179 | 180 | def load_dataset_from_generator( 181 | dataset_path: Path, num_files_limit: int = 100 182 | ) -> Dataset: 183 | """ 184 | Load the dataset. 185 | 186 | :param dataset_path: path of the directory containing the datasets. 187 | :param num_files_limit: maximum number of entries/files to retain. 188 | :return: dataset. 189 | """ 190 | files_paths = list(dataset_path.glob("**/*.mid"))[:num_files_limit] 191 | return Dataset.from_dict({"music": [str(path_) for path_ in files_paths]}) 192 | 193 | 194 | def convert_to_parquet( 195 | dataset: DatasetDict, repo_id: str, token: str | None = None 196 | ) -> None: 197 | """ 198 | Convert a dataset to parquet files. 199 | 200 | :param dataset: dataset to convert. 201 | :param repo_id: id of the repo to upload the files. 
202 | :param token: token for authentication. 203 | """ 204 | create_branch( 205 | repo_id, 206 | branch="refs/convert/parquet", 207 | revision="main", 208 | token=token, 209 | repo_type="dataset", 210 | ) 211 | files_to_push = [] 212 | for split, subset in dataset.items(): 213 | file_path = f"{split}.parquet" 214 | subset.to_parquet(file_path) 215 | upload_file( 216 | path_or_fileobj=file_path, 217 | path_in_repo=file_path, 218 | repo_id=repo_id, 219 | token=token, 220 | revision="refs/convert/parquet", 221 | repo_type="dataset", 222 | ) 223 | files_to_push.append(file_path) 224 | 225 | 226 | if __name__ == "__main__": 227 | from argparse import ArgumentParser 228 | 229 | from utils.utils import path_data_directory_local_fs 230 | 231 | parser = ArgumentParser(description="Dataset creation script") 232 | parser.add_argument( 233 | "--hf-repo-name", type=str, required=False, default="Metacreation/GigaMIDI" 234 | ) 235 | parser.add_argument("--hf-token", type=str, required=False, default=None) 236 | args = vars(parser.parse_args()) 237 | 238 | create_webdataset_gigamidi(path_data_directory_local_fs()) 239 | 240 | """dataset_ = load_dataset( 241 | args["hf_repo_name"], "music", token=args["hf_token"], trust_remote_code=True 242 | ) 243 | convert_to_parquet(dataset_, args["hf_repo_name"], token=args["hf_token"])""" 244 | """from datasets import load_dataset 245 | 246 | dataset_ = load_dataset( 247 | str(path_data_directory_local_fs() / "GigaMIDI"), 248 | "no-drums", 249 | subsets=["no-drums", "all-instruments-with-drums"], 250 | trust_remote_code=True, 251 | ) 252 | data = dataset_["train"] 253 | for i in range(7): 254 | t = data[i] 255 | f = 0 256 | 257 | test = data[0] 258 | print(test) 259 | from symusic import Score 260 | 261 | score = Score.from_midi(test["music"]["bytes"]) 262 | t = 0""" 263 | -------------------------------------------------------------------------------- /loops_nomml/corr_mat.py: 
-------------------------------------------------------------------------------- 1 | 2 | from __future__ import annotations 3 | 4 | from typing import TYPE_CHECKING 5 | 6 | import numpy as np 7 | 8 | from .note_set import NoteSet 9 | 10 | if TYPE_CHECKING: 11 | from collections.abc import Sequence 12 | from typing import Tuple, Dict 13 | 14 | from symusic import Note 15 | 16 | 17 | # Implementation of Correlative Matrix approach presented in: 18 | # Jia Lien Hsu, Chih Chin Liu, and Arbee L.P. Chen. Discovering 19 | # nontrivial repeating patterns in music data. IEEE Transactions on 20 | # Multimedia, 3:311–325, 9 2001. 21 | def calc_correlation(note_sets: Sequence[NoteSet]) -> np.ndarray: 22 | """ 23 | Calculates a correlation matrix of repeated segments with the note_sets. 24 | All repetitions are required to start on the downbeat of measure. 25 | 26 | :param note_sets: list of NoteSets to calculate repetitions for 27 | :return: 2d square correlation matrix the length of note_sets, each 28 | entry is an integer representing the number of continuous matching 29 | elements counting backwards from the current row and column 30 | """ 31 | corr_size = len(note_sets) 32 | corr_mat = np.zeros((corr_size, corr_size), dtype='int16') 33 | 34 | # Complete the first row 35 | for j in range(1, corr_size): 36 | if note_sets[0] == note_sets[j] and note_sets[0].is_barline(): 37 | corr_mat[0, j] = 1 38 | # Complete rest of the correlation matrix 39 | for i in range(1, corr_size - 1): 40 | for j in range(i + 1, corr_size): 41 | if note_sets[i] == note_sets[j]: 42 | if corr_mat[i - 1, j - 1] == 0: 43 | if not note_sets[i].is_barline(): 44 | continue # loops must start on the downbeat (start of bar) 45 | corr_mat[i, j] = corr_mat[i - 1, j - 1] + 1 46 | 47 | return corr_mat 48 | 49 | 50 | def get_loop_density(loop: Sequence[NoteSet], num_beats: int | float) -> float: 51 | """ 52 | Calculates the density of a list of NoteSets in active notes per beat 53 | 54 | :param loop: list of 
NoteSet groups in the loop 55 | :param num_beats: duration of the loop in beats 56 | :return: loop density in active notes per beat 57 | """ 58 | return len([n_set for n_set in loop if n_set.start != n_set.end]) / num_beats 59 | 60 | 61 | def is_empty_loop(loop: Sequence[Note]) -> bool: 62 | """ 63 | Checks if a sequence of notes contains at least one non-rest 64 | 65 | :param loop: sequence of MIDI notes to check 66 | :return: True if a non-rest note exists, False otherwise 67 | """ 68 | for note in loop: 69 | if len(note.pitches) > 0: 70 | return False 71 | return True 72 | 73 | 74 | def compare_loops(p1: Sequence[NoteSet], p2: Sequence[NoteSet], min_rep_beats: int | float) -> int: 75 | """ 76 | Checks if two lists of NoteSets match up to a certain number of beats. 77 | Used to track the longest common loop 78 | 79 | :param p1: new loop to compare 80 | :param p2: existing loop to compare with 81 | :return: 0 for a mismatch, 1 if p1 is a subloop of p2, 2 if p2 is a 82 | subloop of p1 83 | """ 84 | min_rep_beats = int(round(min_rep_beats)) 85 | if len(p1) < len(p2): 86 | for i in range(min_rep_beats): 87 | if p1[i] != p2[i]: 88 | return 0 #not a subloop, theres a mismatch 89 | return 1 #is a subloop 90 | else: 91 | for i in range(min_rep_beats): 92 | if p1[i] != p2[i]: 93 | return 0 #not a subloop, theres a mismatch 94 | return 2 #existing loop is subloop of the new one, replace it 95 | 96 | 97 | def test_loop_exists(loop_list: Sequence[Sequence[NoteSet]], loop: Sequence[NoteSet], min_rep_beats: int | float) -> int: 98 | """ 99 | Checks if a loop already exists in a loop, and mark it for replacement if 100 | it is longer than the existing matching loop 101 | 102 | :param loop_list: list of loops to check 103 | :param loop: new loop to check for a match 104 | :param min_rep_beats: number of beats to check for a match 105 | :return: -1 if loop is a subloop of a current loop in loop_list, idx of 106 | existing loop to replace if loop is a superstring, or None if loop 
107 | is an entirely new loop 108 | """ 109 | for i, pat in enumerate(loop_list): 110 | result = compare_loops(loop, pat, min_rep_beats) 111 | if result == 1: 112 | return -1 #ignore this loop since its a subloop 113 | if result == 2: 114 | return i #replace existing loop with this new longer one 115 | return None #we're just appending the new loop 116 | 117 | 118 | def filter_sub_loops(candidate_indices: Dict[float, Tuple[int, int]]) -> Sequence[Tuple[int, int, float]]: 119 | """ 120 | Processes endpoints for identified loops, keeping only the largest 121 | unique loop when multiple loops intersect, thus eliminating "sub loops." 122 | For instance, if a 4 bar loop is made up of two 2 bar loops, only a 123 | single 2 bar loop will be returned. 124 | 125 | :param candidate_indices: dictionary of (start_tick, end_tick) for each 126 | identified group, keyed by loop length in beats 127 | :return: filtered list of loops with subloops removed 128 | """ 129 | candidate_indices = dict(sorted(candidate_indices.items())) 130 | 131 | repeats = {} 132 | final = [] 133 | for duration in candidate_indices.keys(): 134 | curr_start = 0 135 | curr_end = 0 136 | curr_dur = 0 137 | for start, end in candidate_indices[duration]: 138 | if start in repeats and repeats[start][0] == end: 139 | continue 140 | 141 | if start == curr_end: 142 | curr_end = end 143 | curr_dur += duration 144 | else: 145 | if curr_start != curr_end: 146 | repeats[curr_start] = (curr_end, curr_dur) 147 | curr_start = start 148 | curr_end = end 149 | curr_dur = duration 150 | 151 | final.append((start, end, duration)) 152 | 153 | return final 154 | 155 | 156 | def get_duration_beats(start: int, end: int, ticks_beats: Sequence[int]) -> float: 157 | """ 158 | Given a loop start and end time in ticks and a list of beat tick times, 159 | calculate the duration of the loop in beats 160 | 161 | :param start: start time of the loop in ticks 162 | :param end: end time of the loop in ticks 163 | :param ticks_beat: list 
of all the beat times in the track 164 | :return: duration of the loop in beats 165 | """ 166 | idx_beat_previous = None 167 | idx_beat_first_in = None 168 | idx_beat_last_in = None 169 | idx_beat_after = None 170 | 171 | for bi, beat_tick in enumerate(ticks_beats): 172 | if idx_beat_first_in is None and beat_tick >= start: 173 | idx_beat_first_in = bi 174 | idx_beat_previous = max(bi - 1, 0) 175 | elif idx_beat_last_in is None and beat_tick == end: 176 | idx_beat_last_in = idx_beat_after = bi 177 | elif idx_beat_last_in is None and beat_tick > end: 178 | idx_beat_last_in = max(bi - 1, 0) 179 | idx_beat_after = bi 180 | if idx_beat_after is None: 181 | idx_beat_after = idx_beat_last_in + ticks_beats[-1] - ticks_beats[-2] # TODO what if length 0? 182 | 183 | beat_length_before = ticks_beats[idx_beat_first_in] - ticks_beats[idx_beat_previous] 184 | if beat_length_before > 0: 185 | num_beats_before = (ticks_beats[idx_beat_first_in] - ticks_beats[idx_beat_previous]) / beat_length_before 186 | else: 187 | num_beats_before = 0 188 | beat_length_after = ticks_beats[idx_beat_after] - ticks_beats[idx_beat_last_in] 189 | if beat_length_after > 0: 190 | num_beats_after = (ticks_beats[idx_beat_after] - ticks_beats[end]) / beat_length_after 191 | else: 192 | num_beats_after = 0 193 | return float(idx_beat_last_in - idx_beat_first_in + num_beats_before + num_beats_after - 1) 194 | 195 | 196 | def get_valid_loops( 197 | note_sets: Sequence[NoteSet], 198 | corr_mat: np.ndarray, 199 | ticks_beats: Sequence[int], 200 | min_rep_notes: int=4, 201 | min_rep_beats: float=2.0, 202 | min_beats: float=4.0, 203 | max_beats: float=32.0, 204 | min_loop_note_density: float = 0.5, 205 | ) -> Tuple[Sequence[NoteSet], Tuple[int, int, float, float]]: 206 | """ 207 | Returns all of the loops detected in note_sets, filtering based on the 208 | specified hyperparameters. 
Loops that are subloops of larger loops will 209 | be filtered out 210 | 211 | :param min_rep_notes: Minimum number of notes that must be present in 212 | the repeated bookend of a loop for it to be considered valid 213 | :param min_rep_beats: Minimum length in beats of the repeated bookend 214 | of a loop for it be considered valid 215 | :param min_beats: Minimum total length of the loop in beats 216 | :param max_beats: Maximum total length of the loop in beats 217 | :param min_loop_note_density: Minimum valid density of a loop in average 218 | notes per beat across the whole loop 219 | :return: tuple containing the loop as a sequence of NoteSets, and an 220 | additional tuple with loop metadata: (start time in ticks, end time 221 | in ticks, duration in beats, density) 222 | """ 223 | min_rep_notes += 1 # don't count bar lines as a repetition 224 | x_num_elem, y_num_elem = np.where(corr_mat == min_rep_notes) 225 | 226 | # Parse the correlation matrix to retrieve the loops starts/ends ticks 227 | # keys are loops durations in beats, values tuples of indices TODO ?? 
228 | valid_indices = {} 229 | for i, x in enumerate(x_num_elem): 230 | y = y_num_elem[i] 231 | start_x = x - corr_mat[x, y] + 1 232 | start_y = y - corr_mat[x, y] + 1 233 | 234 | loop_start_time = note_sets[start_x].start 235 | loop_end_time = note_sets[start_y].start 236 | loop_num_beats = round(get_duration_beats(loop_start_time, loop_end_time, ticks_beats), 2) 237 | if max_beats >= loop_num_beats >= min_beats: 238 | if loop_num_beats not in valid_indices: 239 | valid_indices[loop_num_beats] = [] 240 | valid_indices[loop_num_beats].append((x_num_elem[i], y_num_elem[i])) 241 | 242 | filtered_indices = filter_sub_loops(valid_indices) 243 | 244 | loops = [] 245 | loop_bp = [] 246 | corr_size = corr_mat.shape[0] 247 | for start_x, start_y, loop_num_beats in filtered_indices: 248 | x = start_x 249 | y = start_y 250 | while x + 1 < corr_size and y + 1 < corr_size and corr_mat[x + 1, y + 1] > corr_mat[x, y]: 251 | x = x + 1 252 | y = y + 1 253 | beginning = x - corr_mat[x, y] + 1 254 | end = y - corr_mat[x, y] + 1 255 | start_tick = note_sets[beginning].start 256 | end_tick = note_sets[end].start 257 | duration_beats = get_duration_beats(start_tick, end_tick, ticks_beats) 258 | 259 | if duration_beats >= min_rep_beats and not is_empty_loop(note_sets[beginning:end]): 260 | loop = note_sets[beginning:(end + 1)] 261 | loop_density = get_loop_density(loop, loop_num_beats) 262 | if loop_density < min_loop_note_density: 263 | continue 264 | exist_result = test_loop_exists(loops, loop, min_rep_beats) 265 | if exist_result is None: 266 | loops.append(loop) 267 | loop_bp.append((start_tick, end_tick, loop_num_beats, loop_density)) 268 | elif exist_result > 0: # index to replace 269 | loops[exist_result] = loop 270 | loop_bp[exist_result] = (start_tick, end_tick, loop_num_beats, loop_density) 271 | 272 | return loops, loop_bp 273 | -------------------------------------------------------------------------------- /Analysis of Evaluation Set and Optimal Threshold Selection 
including Machine Learning Models/Optimal Threshold Selection/Expressive_Performance_Detection_Training-Percentile-threshold.csv: -------------------------------------------------------------------------------- 1 | ,trackNum,instrument,medianMetricDepth,tpq,velocity_per_track,onset_per_track,velocity_entropy_per_track,onset_entropy_per_track,DNVR,DNODR,"Ground Truth (0=Score, 1=Performance)" 2 | count,20911.0,20911.0,20911.0,20911.0,20911.0,20911.0,20911.0,20911.0,20911.0,20911.0,20911.0 3 | mean,0.02247620869398881,0.004686528621299794,11.616230691980297,405.8130170723543,71.75218784371862,47.5673090717804,2.17519401474504,2.3491900877509306,56.49778570371541,12.266475946985935,0.958203816173306 4 | std,0.1600209488189645,0.45176420847283083,1.9479262330559586,96.8261362680451,22.05038341551496,11.507997369933452,0.30037865799320457,0.31002732650143766,17.362506626390047,3.196148670826988,0.20012790450225568 5 | min,0.0,0.0,0.0,220.0,1.0,1.0,0.0,0.0,0.78740157480315,0.09765625,0.0 6 | 3.5%,0.0,0.0,4.0,384.0,13.0,8.0,1.6410221333320987,1.847199957331255,10.2362204724409,1.45833333333333,0.0 7 | 3.51%,0.0,0.0,4.0,384.0,13.0,8.0,1.64341771979318,1.852555048428222,10.2362204724409,1.45833333333333,0.0 8 | 3.52%,0.0,0.0,4.0,384.0,13.0,8.0,1.643599811492699,1.85465261855982,10.2362204724409,1.46484375,0.0 9 | 3.53%,0.0,0.0,4.0,384.0,13.0,8.0,1.64884707167248,1.8551538950148374,10.2362204724409,1.46484375,0.0 10 | 3.54%,0.0,0.0,4.0,384.0,13.0,8.0,1.6588237201427147,1.8586239971562957,10.2362204724409,1.46484375,0.0 11 | 3.55%,0.0,0.0,4.0,384.0,13.0,8.0,1.66094743328692,1.860947629316187,10.2362204724409,1.46484375,0.0 12 | 3.56%,0.0,0.0,4.0,384.0,13.0,8.0,1.66687569743084,1.862463289114066,10.2362204724409,1.46484375,0.0 13 | 3.57%,0.0,0.0,4.0,384.0,13.0,8.0,1.6720969699001071,1.8638320116484202,10.2362204724409,1.46484375,0.0 14 | 3.58%,0.0,0.0,4.0,384.0,13.0,8.0,1.674929776698617,1.8640964797040853,10.2362204724409,1.46484375,0.0 15 | 
3.59%,0.0,0.0,4.668999999999983,384.0,13.0,8.0,1.6928641492368666,1.8643394237437514,10.2362204724409,1.46484375,0.0 16 | 3.6%,0.0,0.0,5.759999999999991,384.0,13.0,8.0,1.69574253416963,1.8670664996869568,10.2362204724409,1.46484375,0.0 17 | 3.61%,0.0,0.0,6.0,384.0,13.0,8.0,1.7003000325145192,1.8678801716292,10.2362204724409,1.66666666666667,0.0 18 | 3.62%,0.0,0.0,6.0,384.0,13.0,8.0,1.7071083077240585,1.869199024407026,10.2362204724409,1.66666666666667,0.0 19 | 3.63%,0.0,0.0,6.0,384.0,13.0,8.0,1.7202515718959752,1.87180217690159,10.2362204724409,1.66666666666667,0.0 20 | 3.64%,0.0,0.0,6.0,384.0,13.0,9.0,1.7303242760737383,1.8735889466151716,10.2362204724409,1.66666666666667,0.0 21 | 3.65%,0.0,0.0,6.0,384.0,13.0,9.0,1.73286795139986,1.87464074053575,10.2362204724409,1.66666666666667,0.0 22 | 3.66%,0.0,0.0,6.0,384.0,13.0,9.0,1.73286795139986,1.8751403907674762,10.2362204724409,1.66666666666667,0.0 23 | 3.67%,0.0,0.0,6.0,384.0,13.0,9.0,1.7344455175356575,1.8774704083373137,10.2362204724409,1.66666666666667,0.0 24 | 3.68%,0.0,0.0,6.0,384.0,13.0,9.0,1.73512645696292,1.8806242090443093,10.2362204724409,1.66666666666667,0.0 25 | 3.69%,0.0,0.0,6.0,384.0,13.0,9.0,1.73512645696292,1.881979697998224,10.2362204724409,1.66666666666667,0.0 26 | 3.7%,0.0,0.0,6.0,384.0,13.0,9.0,1.73512645696292,1.884375397494431,10.2362204724409,1.66666666666667,0.0 27 | 3.71%,0.0,0.0,6.0,384.0,13.0,9.0,1.73512645696292,1.8854310253127236,10.2362204724409,1.66666666666667,0.0 28 | 3.72%,0.0,0.0,6.0,384.0,13.0,9.0,1.73512645696292,1.8861247034515896,10.2362204724409,1.66666666666667,0.0 29 | 3.73%,0.0,0.0,6.0,384.0,13.942999999999984,9.0,1.7468540433690751,1.8866247284476956,10.978740157480305,1.66666666666667,0.0 30 | 3.74%,0.0,0.0,6.0,384.0,14.0,9.0,1.74786809746676,1.88669678465808,11.0236220472441,1.66666666666667,0.0 31 | 3.75%,0.0,0.0,6.0,384.0,14.0,9.0,1.74786809746676,1.8867633891806825,11.0236220472441,1.66666666666667,0.0 32 | 
3.76%,0.0,0.0,6.216000000000008,384.0,14.0,9.0,1.749235158976883,1.8879165897684507,11.0236220472441,1.66666666666667,0.0 33 | 3.77%,0.0,0.0,8.0,384.0,14.0,9.0,1.7543489653872566,1.88915916375402,11.0236220472441,1.66666666666667,0.0 34 | 3.78%,0.0,0.0,8.0,384.0,14.0,9.0,1.7565046102348396,1.88915916375402,11.0236220472441,1.66666666666667,0.0 35 | 3.79%,0.0,0.0,8.0,384.0,14.0,9.0,1.76017552682609,1.8901455523684245,11.0236220472441,1.66666666666667,0.0 36 | 3.8%,0.0,0.0,8.0,384.0,14.0,9.0,1.7641516589423036,1.8917806280499891,11.0236220472441,1.66666666666667,0.0 37 | 3.81%,0.0,0.0,8.0,384.0,14.0,9.0,1.7676503992262922,1.8936273912507633,11.0236220472441,1.66666666666667,0.0 38 | 3.82%,0.0,0.0,8.0,384.0,14.0,9.0,1.7711315368391405,1.89378823239114,11.0236220472441,1.66666666666667,0.0 39 | 3.83%,0.0,0.0,8.0,384.0,14.0,9.0,1.7722341368064598,1.89378823239114,11.0236220472441,1.66666666666667,0.0 40 | 3.84%,0.0,0.0,8.0,384.0,14.0,9.0,1.7803737124381809,1.8946386676072702,11.0236220472441,1.66666666666667,0.0 41 | 3.85%,0.0,0.0,10.069999999999936,384.0,14.0,9.0,1.78557386526543,1.8978464722049861,11.0236220472441,1.66666666666667,0.0 42 | 3.86%,0.0,0.0,12.0,384.0,14.0,9.0,1.7857294090800417,1.89892678933633,11.0236220472441,1.66666666666667,0.0 43 | 3.87%,0.0,0.0,12.0,384.0,14.0,9.0,1.789379564136363,1.8997133227581198,11.0236220472441,1.7005729166666677,0.0 44 | 3.88%,0.0,0.0,12.0,384.0,14.0,9.0,1.7907074707888906,1.9021724839218015,11.0236220472441,1.846166666666664,0.0 45 | 3.89%,0.0,0.0,12.0,384.0,14.0,9.0,1.7915399392101077,1.90368668833762,11.0236220472441,1.875,0.0 46 | 3.9%,0.0,0.0,12.0,384.0,14.0,9.0,1.79175946922806,1.90368668833762,11.0236220472441,1.875,0.0 47 | 3.91%,0.0,0.0,12.0,384.0,14.0,9.0,1.79175946922806,1.9058469058028016,11.0236220472441,1.875,0.0 48 | 3.92%,0.0,0.0,12.0,384.0,14.0,9.0,1.79175946922806,1.906297674512365,11.0236220472441,1.875,0.0 49 | 3.93%,0.0,0.0,12.0,384.0,14.0,9.0,1.79175946922806,1.90728399932138,11.0236220472441,1.875,0.0 
50 | 3.94%,0.0,0.0,12.0,384.0,14.0,9.0,1.79175946922806,1.90728399932138,11.0236220472441,1.875,0.0 51 | 3.95%,0.0,0.0,12.0,384.0,14.0,9.94500000000005,1.79175946922806,1.90853528164356,11.0236220472441,1.875,0.0 52 | 3.96%,0.0,0.0,12.0,384.0,14.0,10.0,1.7917878768887283,1.9085725163331464,11.0236220472441,1.875,0.0 53 | 3.97%,0.0,0.0,12.0,384.0,14.0,10.0,1.7958592519658925,1.9102170160583758,11.0236220472441,1.875,0.0 54 | 3.98%,0.0,0.0,12.0,384.0,14.0,10.0,1.8095272062299135,1.9173405295680208,11.0236220472441,1.875,0.0 55 | 3.99%,0.0,0.0,12.0,384.0,15.0,10.0,1.8122218330460924,1.919268730921757,11.8110236220472,1.875,0.0 56 | 4%,0.0,0.0,12.0,384.0,15.0,10.0,1.817740144249624,1.921908958579216,11.8110236220472,1.875,0.0 57 | 4.01%,0.0,0.0,12.0,384.0,15.0,10.0,1.81884777724898,1.9229410566301042,11.8110236220472,1.875,0.0 58 | 4.02%,0.0,0.0,12.0,384.0,15.0,10.0,1.8198399615292822,1.92295892360442,11.8110236220472,1.875,0.0 59 | 4.03%,0.0,0.0,12.0,384.0,15.0,10.0,1.8251774601370696,1.9262625443250243,11.8110236220472,1.875,0.0 60 | 4.04%,0.0,0.0,12.0,384.0,15.0,10.0,1.8284410091155763,1.92739212613927,11.8110236220472,1.875,0.0 61 | 4.05%,0.0,0.0,12.0,384.0,15.0,10.0,1.8290448005300781,1.929631317498723,11.8110236220472,1.875,0.0 62 | 4.06%,0.0,0.0,12.0,384.0,15.0,10.0,1.8387412451018255,1.9308749430628558,11.8110236220472,1.875,0.0 63 | 4.07%,0.0,0.0,12.0,384.0,15.0,10.0,1.8414770403073026,1.9341879826241253,11.8110236220472,1.875,0.0 64 | 4.08%,0.0,0.0,12.0,384.0,15.0,10.0,1.844300645896431,1.9357406070586625,11.8110236220472,1.875,0.0 65 | 4.09%,0.0,0.0,12.0,384.0,15.0,10.0,1.8484789183547972,1.93619934568452,11.8110236220472,1.875,0.0 66 | 4.1%,0.0,0.0,12.0,384.0,15.0,10.0,1.8556092770482155,1.938055203531081,11.8110236220472,1.875,0.0 67 | 4.11%,0.0,0.0,12.0,384.0,15.0,10.0,1.8581555151930649,1.9452419877155491,11.8110236220472,1.875,0.0 68 | 4.12%,0.0,0.0,12.0,384.0,15.0,10.0,1.8606839748791058,1.946808362617706,11.8110236220472,1.875,0.0 69 | 
4.13%,0.0,0.0,12.0,384.0,15.0,10.0,1.8625483330972583,1.9480946835987856,11.8110236220472,1.875,0.0 70 | 4.14%,0.0,0.0,12.0,384.0,15.0,10.0,1.8646311472827994,1.9502876379902592,11.8110236220472,1.875,0.0 71 | 4.15%,0.0,0.0,12.0,384.0,15.0,10.0,1.867393051906037,1.9522471628500282,11.8110236220472,2.034375000000018,0.0 72 | 4.16%,0.0,0.0,12.0,384.0,15.0,10.0,1.870845405337236,1.9557341208567933,11.8110236220472,2.08333333333333,0.0 73 | 4.17%,0.0,0.0,12.0,384.0,15.0,10.0,1.872001404857079,1.9566235887987975,11.8110236220472,2.08333333333333,0.0 74 | 4.18%,0.0,0.0,12.0,384.0,15.0,10.0,1.8740618122105086,1.9622061996729672,11.8110236220472,2.08333333333333,1.0 75 | 4.19%,0.0,0.0,12.0,384.0,15.0,10.0,1.875586991460673,1.9638435223793027,11.8110236220472,2.08333333333333,1.0 76 | 4.2%,0.0,0.0,12.0,384.0,15.0,10.0,1.8788601482661171,1.964708555687614,11.8110236220472,2.08333333333333,1.0 77 | 4.21%,0.0,0.0,12.0,384.0,15.0,10.0,1.8813090426966708,1.9650726610011424,11.8110236220472,2.08333333333333,1.0 78 | 4.22%,0.0,0.0,12.0,384.0,15.0,10.0,1.8833259464900192,1.9668244613798862,11.8110236220472,2.08333333333333,1.0 79 | 4.23%,0.0,0.0,12.0,384.0,15.492999999999938,10.0,1.885097503590677,1.968519069916699,12.19921259842513,2.08333333333333,1.0 80 | 4.24%,0.0,0.0,12.0,384.0,16.0,10.0,1.8867947672607561,1.9694562392644799,12.5984251968504,2.08333333333333,1.0 81 | 4.25%,0.0,0.0,12.0,384.0,16.0,10.0,1.88702726883415,1.9727340405642466,12.5984251968504,2.08333333333333,1.0 82 | 4.26%,0.0,0.0,12.0,384.0,16.0,10.0,1.8875810435120208,1.976268663684989,12.5984251968504,2.08333333333333,1.0 83 | 4.27%,0.0,0.0,12.0,384.0,16.0,10.857000000000085,1.8879031479708135,1.9807480692407462,12.5984251968504,2.08333333333333,1.0 84 | 4.28%,0.0,0.0,12.0,384.0,16.0,11.0,1.88915916375402,1.9825107772601644,12.5984251968504,2.08333333333333,1.0 85 | 4.29%,0.0,0.0,12.0,384.0,16.0,11.0,1.889540710736435,1.9841601841484624,12.5984251968504,2.08333333333333,1.0 86 | 
4.3%,0.0,0.0,12.0,384.0,16.0,11.0,1.89118654382655,1.98605696815486,12.5984251968504,2.08333333333333,1.0 87 | 4.31%,0.0,0.0,12.0,384.0,16.0,11.0,1.89378823239114,1.9865522721238016,12.5984251968504,2.08333333333333,1.0 88 | 4.32%,0.0,0.0,12.0,384.0,16.0,11.0,1.89378823239114,1.9883612586606534,12.5984251968504,2.08333333333333,1.0 89 | 4.33%,0.0,0.0,12.0,384.0,16.0,11.0,1.8957379690011265,1.9889304963821068,12.5984251968504,2.08333333333333,1.0 90 | 4.34%,0.0,0.0,12.0,384.0,16.0,11.0,1.898265521640075,1.9901410988933488,12.5984251968504,2.08333333333333,1.0 91 | 4.35%,0.0,0.0,12.0,384.0,16.0,11.0,1.9011798275286038,1.9907208264464311,12.5984251968504,2.08333333333333,1.0 92 | 4.36%,0.0,0.0,12.0,384.0,16.0,11.0,1.9034547051415784,1.9922035413725887,12.5984251968504,2.08333333333333,1.0 93 | 4.37%,0.0,0.0,12.0,384.0,16.0,11.0,1.9054145198796109,1.99279814410984,12.5984251968504,2.08333333333333,1.0 94 | 4.38%,0.0,0.0,12.0,384.0,16.0,11.0,1.9059420123164048,1.9959159730930844,12.5984251968504,2.08333333333333,1.0 95 | 4.39%,0.0,0.0,12.0,384.0,16.0,11.0,1.9061517078204513,1.9968823346531932,12.5984251968504,2.08333333333333,1.0 96 | 4.4%,0.0,0.0,12.0,384.0,16.0,11.0,1.906165859476912,2.0023690864566617,12.5984251968504,2.08333333333333,1.0 97 | 4.41%,0.0,0.0,12.0,384.0,16.0,11.0,1.9073531774934107,2.0040748783423328,12.5984251968504,2.08333333333333,1.0 98 | 4.42%,0.0,0.0,12.0,384.0,16.0,11.0,1.9083781938383528,2.0044808496496294,12.5984251968504,2.08333333333333,1.0 99 | 4.43%,0.0,0.0,12.0,384.0,16.0,11.0,1.910151208135982,2.0049627474285217,12.5984251968504,2.08333333333333,1.0 100 | 4.44%,0.0,0.0,12.0,384.0,16.0,11.0,1.913697567543809,2.0052414460245642,12.5984251968504,2.08333333333333,1.0 101 | 4.45%,0.0,0.0,12.0,384.0,16.49500000000012,11.0,1.9167169320313464,2.006417223562179,12.988188976378028,2.29166666666667,1.0 102 | 4.46%,0.0,0.0,12.0,384.0,17.0,11.0,1.9177677338472006,2.0083300423976396,13.3858267716535,2.29166666666667,1.0 103 | 
4.47%,0.0,0.0,12.0,384.0,17.0,11.0,1.9206649087574386,2.0096663948928795,13.3858267716535,2.29166666666667,1.0 104 | 4.48%,0.0,0.0,12.0,384.0,17.0,11.0,1.9250333341959631,2.0112340867501795,13.3858267716535,2.29166666666667,1.0 105 | 4.49%,0.0,0.0,12.0,384.0,17.0,11.0,1.925133130993522,2.0120230786147806,13.3858267716535,2.29166666666667,1.0 106 | 4.5%,0.0,0.0,12.0,384.0,17.0,11.0,1.9255530695368681,2.0125241882899965,13.3858267716535,2.29166666666667,1.0 107 | 50%,0.0,0.0,12.0,384.0,75.0,52.0,2.22156599484268,2.38234678527306,59.0551181102362,13.5416666666667,1.0 108 | max,3.0,46.0,12.0,1024.0,126.0,120.0,3.08733027111959,3.18461464679519,99.2125984251969,54.5454545454546,1.0 109 | -------------------------------------------------------------------------------- /GigaMIDI/GigaMIDI.py: -------------------------------------------------------------------------------- 1 | """The GigaMIDI dataset.""" # noqa:N999 2 | 3 | from __future__ import annotations 4 | 5 | import json 6 | from collections import defaultdict 7 | from pathlib import Path 8 | from typing import TYPE_CHECKING, Literal 9 | 10 | import datasets 11 | 12 | if TYPE_CHECKING: 13 | from collections.abc import Sequence 14 | 15 | from datasets.utils.file_utils import ArchiveIterable 16 | 17 | _CITATION = "" 18 | _DESCRIPTION = "A large-scale MIDI symbolic music dataset." 
19 | _HOMEPAGE = "https://github.com/Metacreation-Lab/GigaMIDI" 20 | _LICENSE = "CC0, also see https://www.europarl.europa.eu/legal-notice/en/" 21 | _SUBSETS = ["all-instruments-with-drums", "drums-only", "no-drums"] 22 | _SPLITS = ["train", "validation", "test"] 23 | _BASE_DATA_DIR = "" 24 | _N_SHARDS_FILE = _BASE_DATA_DIR + "n_shards.json" 25 | _MUSIC_PATH = ( 26 | _BASE_DATA_DIR + "{subset}/{split}/GigaMIDI_{subset}_{split}_{shard_idx}.tar" 27 | ) 28 | _METADATA_PATH = _BASE_DATA_DIR + "{subset}/metadata_{subset}_{split}.json" 29 | _METADATA_FEATURES = { 30 | "sid_matches": datasets.Sequence( 31 | {"sid": datasets.Value("string"), "score": datasets.Value("float16")} 32 | ), 33 | "mbid_matches": datasets.Sequence( 34 | { 35 | "sid": datasets.Value("string"), 36 | "mbids": datasets.Sequence(datasets.Value("string")), 37 | } 38 | ), 39 | "artist_scraped": datasets.Value("string"), 40 | "title_scraped": datasets.Value("string"), 41 | "genres_scraped": datasets.Sequence(datasets.Value("string")), 42 | "genres_discogs": datasets.Sequence( 43 | {"genre": datasets.Value("string"), "count": datasets.Value("int16")} 44 | ), 45 | "genres_tagtraum": datasets.Sequence( 46 | {"genre": datasets.Value("string"), "count": datasets.Value("int16")} 47 | ), 48 | "genres_lastfm": datasets.Sequence( 49 | {"genre": datasets.Value("string"), "count": datasets.Value("int16")} 50 | ), 51 | "median_metric_depth": datasets.Sequence(datasets.Value("int16")), 52 | # "loops": datasets.Value("string"), 53 | } 54 | _VERSION = "1.0.0" 55 | 56 | 57 | """def cast_metadata_to_python( 58 | metadata: dict[str, Any], 59 | features: dict[str, datasets.Features] | None = None, 60 | ) -> dict: 61 | if features is None: 62 | features = _METADATA_FEATURES 63 | metadata_ = {} 64 | for feature_name, feature in features.items(): 65 | data = metadata.get(feature_name, None) 66 | if ( 67 | data 68 | and isinstance(feature, datasets.Sequence) 69 | and isinstance(feature.feature, (datasets.Sequence, dict)) 70 | 
): 71 | if isinstance(feature.feature, datasets.Sequence): 72 | metadata_[feature_name] = [ 73 | cast_metadata_to_python( 74 | {feature_name: sample}, {feature_name: feature.feature} 75 | ) 76 | for sample in data 77 | ] 78 | else: 79 | metadata_[feature_name] = { 80 | cast_metadata_to_python( 81 | {feature_name_: sample}, feature.feature 82 | ) 83 | for sample in data 84 | for feature_name_ in feature.feature 85 | } 86 | else: 87 | metadata_[feature_name] = data 88 | 89 | return metadata_""" 90 | 91 | 92 | class GigaMIDIConfig(datasets.BuilderConfig): 93 | """BuilderConfig for GigaMIDI.""" 94 | 95 | def __init__( 96 | self, 97 | name: str, 98 | subsets: Sequence[ 99 | Literal[ 100 | "all-instruments-with-drums", 101 | "all-instruments-no-drums", 102 | "drums-only", 103 | ] 104 | ] 105 | | None = None, 106 | **kwargs, 107 | ) -> None: 108 | """ 109 | BuilderConfig for GigaMIDI. 110 | 111 | Args: 112 | ---- 113 | name: `string` or `List[string]`: 114 | name of the dataset subset. Must be either "drums" for files containing 115 | only drum tracks, "music" for others or "all" for all. 116 | subsets: `Sequence[string]`: list of subsets to use. It is None by default 117 | and resort to the `name` argument to select one subset if not provided. 118 | **kwargs: keyword arguments forwarded to super. 
119 | 120 | """ 121 | if name == "all": 122 | self.subsets = _SUBSETS 123 | elif subsets is not None: 124 | self.subsets = subsets 125 | name = "_".join(subsets) 126 | else: 127 | self.subsets = [name] 128 | 129 | super().__init__(name=name, **kwargs) 130 | 131 | 132 | class GigaMIDI(datasets.GeneratorBasedBuilder): 133 | """The GigaMIDI dataset.""" 134 | 135 | VERSION = datasets.Version(_VERSION) 136 | BUILDER_CONFIGS = [ # noqa:RUF012 137 | GigaMIDIConfig( 138 | name=name, 139 | version=datasets.Version(_VERSION), 140 | ) 141 | for name in ["all", *_SUBSETS] 142 | ] 143 | DEFAULT_WRITER_BATCH_SIZE = 256 144 | 145 | def _info(self) -> datasets.DatasetInfo: 146 | features = datasets.Features( 147 | { 148 | "md5": datasets.Value("string"), 149 | "music": { 150 | "path": datasets.Value("string"), 151 | "bytes": datasets.Value("binary"), 152 | }, 153 | "is_drums": datasets.Value("bool"), 154 | **_METADATA_FEATURES, 155 | } 156 | ) 157 | return datasets.DatasetInfo( 158 | description=_DESCRIPTION, 159 | features=features, 160 | homepage=_HOMEPAGE, 161 | license=_LICENSE, 162 | citation=_CITATION, 163 | version=_VERSION, 164 | ) 165 | 166 | def _split_generators( 167 | self, dl_manager: datasets.DownloadManager | datasets.StreamingDownloadManager 168 | ) -> list[datasets.SplitGenerator]: 169 | n_shards_path = Path(dl_manager.download_and_extract(_N_SHARDS_FILE)) 170 | with n_shards_path.open() as f: 171 | n_shards = json.load(f) 172 | 173 | music_urls = defaultdict(dict) 174 | for split in _SPLITS: 175 | for subset in self.config.subsets: 176 | music_urls[split][subset] = [ 177 | _MUSIC_PATH.format(subset=subset, split=split, shard_idx=f"{i:04d}") 178 | for i in range(n_shards[subset][split]) 179 | ] 180 | 181 | meta_urls = defaultdict(dict) 182 | for split in _SPLITS: 183 | for subset in self.config.subsets: 184 | meta_urls[split][subset] = _METADATA_PATH.format( 185 | subset=subset, split=split 186 | ) 187 | 188 | # dl_manager.download_config.num_proc = len(urls) 189 
| 190 | meta_paths = dl_manager.download_and_extract(meta_urls) 191 | music_paths = dl_manager.download(music_urls) 192 | 193 | local_extracted_music_paths = ( 194 | dl_manager.extract(music_paths) 195 | if not dl_manager.is_streaming 196 | else { 197 | split: { 198 | subset: [None] * len(music_paths[split][subset]) 199 | for subset in self.config.subsets 200 | } 201 | for split in _SPLITS 202 | } 203 | ) 204 | 205 | return [ 206 | datasets.SplitGenerator( 207 | name=split_name, 208 | gen_kwargs={ 209 | "music_shards": { 210 | subset: [ 211 | dl_manager.iter_archive(shard) for shard in subset_shards 212 | ] 213 | for subset, subset_shards in music_paths[split_name].items() 214 | }, 215 | "local_extracted_shards_paths": local_extracted_music_paths[ 216 | split_name 217 | ], 218 | "metadata_paths": meta_paths[split_name], 219 | }, 220 | ) 221 | for split_name in _SPLITS 222 | ] 223 | 224 | def _generate_examples( 225 | self, 226 | music_shards: dict[str, Sequence[ArchiveIterable]], 227 | local_extracted_shards_paths: dict[str, Sequence[dict]], 228 | metadata_paths: dict[str, Path], 229 | ) -> dict: 230 | if not ( 231 | len(metadata_paths) 232 | == len(music_shards) 233 | == len(local_extracted_shards_paths) 234 | ): 235 | msg = "The number of subsets provided are not equals" 236 | raise ValueError(msg) 237 | 238 | for subset in self.config.subsets: 239 | if len(music_shards[subset]) != len(local_extracted_shards_paths[subset]): 240 | msg = "the number of shards must be equal to the number of paths" 241 | raise ValueError(msg) 242 | 243 | is_drums = subset == "drums" 244 | with Path(metadata_paths[subset]).open() as file: 245 | metadata = json.load(file) 246 | 247 | for music_shard, local_extracted_shard_path in zip( 248 | music_shards[subset], local_extracted_shards_paths[subset] 249 | ): 250 | for music_file_name, music_file in music_shard: 251 | md5 = music_file_name.split(".")[0] 252 | path = ( 253 | str(Path(str(local_extracted_shard_path)) / music_file_name) 254 
| if local_extracted_shard_path 255 | else music_file_name 256 | ) 257 | 258 | metadata_ = metadata.get(md5, {}) 259 | yield ( 260 | md5, 261 | { 262 | "md5": md5, 263 | "music": {"path": path, "bytes": music_file.read()}, 264 | "is_drums": is_drums, 265 | "sid_matches": [ 266 | {"sid": sid, "score": score} 267 | for sid, score in metadata_.get("sid_matches", []) 268 | ], 269 | "mbid_matches": [ 270 | {"sid": sid, "mbids": mbids} 271 | for sid, mbids in metadata_.get("mbid_matches", []) 272 | ], 273 | "artist_scraped": metadata_.get("artist_scraped", None), 274 | "title_scraped": metadata_.get("title_scraped", None), 275 | "genres_scraped": metadata_.get("genres_scraped", None), 276 | "genres_discogs": [ 277 | {"genre": genre, "count": count} 278 | for genre, count in metadata_.get( 279 | "genres_discogs", {} 280 | ).items() 281 | ], 282 | "genres_tagtraum": [ 283 | {"genre": genre, "count": count} 284 | for genre, count in metadata_.get( 285 | "genres_tagtraum", {} 286 | ).items() 287 | ], 288 | "genres_lastfm": [ 289 | {"genre": genre, "count": count} 290 | for genre, count in metadata_.get( 291 | "genres_lastfm", {} 292 | ).items() 293 | ], 294 | "median_metric_depth": metadata_.get( 295 | "median_metric_depth", None 296 | ), 297 | # "loops": metadata_.get("loops", None), 298 | }, 299 | ) 300 | -------------------------------------------------------------------------------- /scripts/dataset_short_files.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 python 2 | 3 | """ 4 | Script analyzing the programs and number of tracks of short files from the dataset. 
5 | 6 | Results: 7 | Number of files with less than 8 bars: 589546 8 | Program -1 (Drums): 368071 (0.549%) 9 | Program 0 (Acoustic Grand Piano): 291196 (0.434%) 10 | Program 1 (Bright Acoustic Piano): 110 (0.000%) 11 | Program 2 (Electric Grand Piano): 75 (0.000%) 12 | Program 3 (Honky-tonk Piano): 43 (0.000%) 13 | Program 4 (Electric Piano 1): 93 (0.000%) 14 | Program 5 (Electric Piano 2): 32 (0.000%) 15 | Program 6 (Harpsichord): 71 (0.000%) 16 | Program 7 (Clavi): 11 (0.000%) 17 | Program 8 (Celesta): 14 (0.000%) 18 | Program 9 (Glockenspiel): 26 (0.000%) 19 | Program 10 (Music Box): 48 (0.000%) 20 | Program 11 (Vibraphone): 123 (0.000%) 21 | Program 12 (Marimba): 54 (0.000%) 22 | Program 13 (Xylophone): 15 (0.000%) 23 | Program 14 (Tubular Bells): 29 (0.000%) 24 | Program 15 (Dulcimer): 34 (0.000%) 25 | Program 16 (Drawbar Organ): 23 (0.000%) 26 | Program 17 (Percussive Organ): 8 (0.000%) 27 | Program 18 (Rock Organ): 51 (0.000%) 28 | Program 19 (Church Organ): 150 (0.000%) 29 | Program 20 (Reed Organ): 4 (0.000%) 30 | Program 21 (Accordion): 13 (0.000%) 31 | Program 22 (Harmonica): 17 (0.000%) 32 | Program 23 (Tango Accordion): 16 (0.000%) 33 | Program 24 (Acoustic Guitar (nylon)): 189 (0.000%) 34 | Program 25 (Acoustic Guitar (steel)): 100 (0.000%) 35 | Program 26 (Electric Guitar (jazz)): 84 (0.000%) 36 | Program 27 (Electric Guitar (clean)): 90 (0.000%) 37 | Program 28 (Electric Guitar (muted)): 47 (0.000%) 38 | Program 29 (Overdriven Guitar): 73 (0.000%) 39 | Program 30 (Distortion Guitar): 85 (0.000%) 40 | Program 31 (Guitar Harmonics): 6 (0.000%) 41 | Program 32 (Acoustic Bass): 172 (0.000%) 42 | Program 33 (Electric Bass (finger)): 154 (0.000%) 43 | Program 34 (Electric Bass (pick)): 34 (0.000%) 44 | Program 35 (Fretless Bass): 63 (0.000%) 45 | Program 36 (Slap Bass 1): 23 (0.000%) 46 | Program 37 (Slap Bass 2): 24 (0.000%) 47 | Program 38 (Synth Bass 1): 1727 (0.003%) 48 | Program 39 (Synth Bass 2): 28 (0.000%) 49 | Program 40 (Violin): 40 (0.000%) 50 
| Program 41 (Viola): 22 (0.000%) 51 | Program 42 (Cello): 24 (0.000%) 52 | Program 43 (Contrabass): 26 (0.000%) 53 | Program 44 (Tremolo Strings): 45 (0.000%) 54 | Program 45 (Pizzicato Strings): 32 (0.000%) 55 | Program 46 (Orchestral Harp): 71 (0.000%) 56 | Program 47 (Timpani): 101 (0.000%) 57 | Program 48 (String Ensembles 1): 388 (0.001%) 58 | Program 49 (String Ensembles 2): 71 (0.000%) 59 | Program 50 (SynthStrings 1): 48 (0.000%) 60 | Program 51 (SynthStrings 2): 25 (0.000%) 61 | Program 52 (Choir Aahs): 79 (0.000%) 62 | Program 53 (Voice Oohs): 16 (0.000%) 63 | Program 54 (Synth Voice): 24 (0.000%) 64 | Program 55 (Orchestra Hit): 49 (0.000%) 65 | Program 56 (Trumpet): 202 (0.000%) 66 | Program 57 (Trombone): 107 (0.000%) 67 | Program 58 (Tuba): 49 (0.000%) 68 | Program 59 (Muted Trumpet): 27 (0.000%) 69 | Program 60 (French Horn): 116 (0.000%) 70 | Program 61 (Brass Section): 115 (0.000%) 71 | Program 62 (Synth Brass 1): 42 (0.000%) 72 | Program 63 (Synth Brass 2): 19 (0.000%) 73 | Program 64 (Soprano Sax): 6 (0.000%) 74 | Program 65 (Alto Sax): 28 (0.000%) 75 | Program 66 (Tenor Sax): 30 (0.000%) 76 | Program 67 (Baritone Sax): 12 (0.000%) 77 | Program 68 (Oboe): 52 (0.000%) 78 | Program 69 (English Horn): 9 (0.000%) 79 | Program 70 (Bassoon): 32 (0.000%) 80 | Program 71 (Clarinet): 95 (0.000%) 81 | Program 72 (Piccolo): 29 (0.000%) 82 | Program 73 (Flute): 92 (0.000%) 83 | Program 74 (Recorder): 7 (0.000%) 84 | Program 75 (Pan Flute): 28 (0.000%) 85 | Program 76 (Blown Bottle): 9 (0.000%) 86 | Program 77 (Shakuhachi): 12 (0.000%) 87 | Program 78 (Whistle): 12 (0.000%) 88 | Program 79 (Ocarina): 26 (0.000%) 89 | Program 80 (Lead 1 (square)): 2165 (0.003%) 90 | Program 81 (Lead 2 (sawtooth)): 2037 (0.003%) 91 | Program 82 (Lead 3 (calliope)): 13 (0.000%) 92 | Program 83 (Lead 4 (chiff)): 4 (0.000%) 93 | Program 84 (Lead 5 (charang)): 9 (0.000%) 94 | Program 85 (Lead 6 (voice)): 4 (0.000%) 95 | Program 86 (Lead 7 (fifths)): 3 (0.000%) 96 | Program 87 
(Lead 8 (bass + lead)): 31 (0.000%) 97 | Program 88 (Pad 1 (new age)): 40 (0.000%) 98 | Program 89 (Pad 2 (warm)): 37 (0.000%) 99 | Program 90 (Pad 3 (polysynth)): 19 (0.000%) 100 | Program 91 (Pad 4 (choir)): 9 (0.000%) 101 | Program 92 (Pad 5 (bowed)): 47 (0.000%) 102 | Program 93 (Pad 6 (metallic)): 11 (0.000%) 103 | Program 94 (Pad 7 (halo)): 21 (0.000%) 104 | Program 95 (Pad 8 (sweep)): 16 (0.000%) 105 | Program 96 (FX 1 (rain)): 3 (0.000%) 106 | Program 97 (FX 2 (soundtrack)): 9 (0.000%) 107 | Program 98 (FX 3 (crystal)): 10 (0.000%) 108 | Program 99 (FX 4 (atmosphere)): 12 (0.000%) 109 | Program 100 (FX 5 (brightness)): 9 (0.000%) 110 | Program 101 (FX 6 (goblins)): 24 (0.000%) 111 | Program 102 (FX 7 (echoes)): 5 (0.000%) 112 | Program 103 (FX 8 (sci-fi)): 4 (0.000%) 113 | Program 104 (Sitar): 16 (0.000%) 114 | Program 105 (Banjo): 13 (0.000%) 115 | Program 106 (Shamisen): 7 (0.000%) 116 | Program 107 (Koto): 6 (0.000%) 117 | Program 108 (Kalimba): 12 (0.000%) 118 | Program 109 (Bag pipe): 1 (0.000%) 119 | Program 110 (Fiddle): 19 (0.000%) 120 | Program 111 (Shanai): 3 (0.000%) 121 | Program 112 (Tinkle Bell): 15 (0.000%) 122 | Program 113 (Agogo): 4 (0.000%) 123 | Program 114 (Steel Drums): 7 (0.000%) 124 | Program 115 (Woodblock): 8 (0.000%) 125 | Program 116 (Taiko Drum): 20 (0.000%) 126 | Program 117 (Melodic Tom): 16 (0.000%) 127 | Program 118 (Synth Drum): 9 (0.000%) 128 | Program 119 (Reverse Cymbal): 17 (0.000%) 129 | Program 120 (Guitar Fret Noise, Guitar Cutting Noise): 6 (0.000%) 130 | Program 121 (Breath Noise, Flute Key Click): 2 (0.000%) 131 | Program 122 (Seashore, Rain, Thunder, Wind, Stream, Bubbles): 6 (0.000%) 132 | Program 123 (Bird Tweet, Dog, Horse Gallop): 3 (0.000%) 133 | Program 124 (Telephone Ring, Door Creaking, Door, Scratch, Wind Chime): 7 (0.000%) 134 | Program 125 (Helicopter, Car Sounds): 4 (0.000%) 135 | Program 126 (Applause, Laughing, Screaming, Punch, Heart Beat, Footstep): 16 (0.000%) 136 | Program 127 (Gunshot, Machine 
Number of files with at least 8 bars: 279213
26724 (0.014%) 177 | Program 36 (Slap Bass 1): 3124 (0.002%) 178 | Program 37 (Slap Bass 2): 2557 (0.001%) 179 | Program 38 (Synth Bass 1): 12350 (0.006%) 180 | Program 39 (Synth Bass 2): 7469 (0.004%) 181 | Program 40 (Violin): 14198 (0.007%) 182 | Program 41 (Viola): 5949 (0.003%) 183 | Program 42 (Cello): 8222 (0.004%) 184 | Program 43 (Contrabass): 6633 (0.003%) 185 | Program 44 (Tremolo Strings): 5769 (0.003%) 186 | Program 45 (Pizzicato Strings): 20000 (0.010%) 187 | Program 46 (Orchestral Harp): 12108 (0.006%) 188 | Program 47 (Timpani): 18025 (0.009%) 189 | Program 48 (String Ensembles 1): 87048 (0.044%) 190 | Program 49 (String Ensembles 2): 33319 (0.017%) 191 | Program 50 (SynthStrings 1): 26076 (0.013%) 192 | Program 51 (SynthStrings 2): 6654 (0.003%) 193 | Program 52 (Choir Aahs): 45370 (0.023%) 194 | Program 53 (Voice Oohs): 18223 (0.009%) 195 | Program 54 (Synth Voice): 10778 (0.006%) 196 | Program 55 (Orchestra Hit): 3813 (0.002%) 197 | Program 56 (Trumpet): 47874 (0.024%) 198 | Program 57 (Trombone): 40421 (0.021%) 199 | Program 58 (Tuba): 17945 (0.009%) 200 | Program 59 (Muted Trumpet): 5584 (0.003%) 201 | Program 60 (French Horn): 34418 (0.018%) 202 | Program 61 (Brass Section): 20338 (0.010%) 203 | Program 62 (Synth Brass 1): 8776 (0.004%) 204 | Program 63 (Synth Brass 2): 3898 (0.002%) 205 | Program 64 (Soprano Sax): 4883 (0.002%) 206 | Program 65 (Alto Sax): 28511 (0.015%) 207 | Program 66 (Tenor Sax): 20751 (0.011%) 208 | Program 67 (Baritone Sax): 9189 (0.005%) 209 | Program 68 (Oboe): 23015 (0.012%) 210 | Program 69 (English Horn): 4493 (0.002%) 211 | Program 70 (Bassoon): 18090 (0.009%) 212 | Program 71 (Clarinet): 40191 (0.021%) 213 | Program 72 (Piccolo): 10367 (0.005%) 214 | Program 73 (Flute): 43843 (0.022%) 215 | Program 74 (Recorder): 3815 (0.002%) 216 | Program 75 (Pan Flute): 8700 (0.004%) 217 | Program 76 (Blown Bottle): 1210 (0.001%) 218 | Program 77 (Shakuhachi): 1754 (0.001%) 219 | Program 78 (Whistle): 3024 (0.002%) 220 | 
Program 79 (Ocarina): 2934 (0.001%) 221 | Program 80 (Lead 1 (square)): 11486 (0.006%) 222 | Program 81 (Lead 2 (sawtooth)): 16702 (0.009%) 223 | Program 82 (Lead 3 (calliope)): 7935 (0.004%) 224 | Program 83 (Lead 4 (chiff)): 1143 (0.001%) 225 | Program 84 (Lead 5 (charang)): 2186 (0.001%) 226 | Program 85 (Lead 6 (voice)): 2405 (0.001%) 227 | Program 86 (Lead 7 (fifths)): 617 (0.000%) 228 | Program 87 (Lead 8 (bass + lead)): 7257 (0.004%) 229 | Program 88 (Pad 1 (new age)): 8296 (0.004%) 230 | Program 89 (Pad 2 (warm)): 10565 (0.005%) 231 | Program 90 (Pad 3 (polysynth)): 5853 (0.003%) 232 | Program 91 (Pad 4 (choir)): 6063 (0.003%) 233 | Program 92 (Pad 5 (bowed)): 2008 (0.001%) 234 | Program 93 (Pad 6 (metallic)): 2011 (0.001%) 235 | Program 94 (Pad 7 (halo)): 2952 (0.002%) 236 | Program 95 (Pad 8 (sweep)): 4357 (0.002%) 237 | Program 96 (FX 1 (rain)): 1616 (0.001%) 238 | Program 97 (FX 2 (soundtrack)): 761 (0.000%) 239 | Program 98 (FX 3 (crystal)): 2101 (0.001%) 240 | Program 99 (FX 4 (atmosphere)): 4091 (0.002%) 241 | Program 100 (FX 5 (brightness)): 6566 (0.003%) 242 | Program 101 (FX 6 (goblins)): 1170 (0.001%) 243 | Program 102 (FX 7 (echoes)): 2688 (0.001%) 244 | Program 103 (FX 8 (sci-fi)): 1336 (0.001%) 245 | Program 104 (Sitar): 1857 (0.001%) 246 | Program 105 (Banjo): 3734 (0.002%) 247 | Program 106 (Shamisen): 1057 (0.001%) 248 | Program 107 (Koto): 1416 (0.001%) 249 | Program 108 (Kalimba): 1844 (0.001%) 250 | Program 109 (Bag pipe): 819 (0.000%) 251 | Program 110 (Fiddle): 2041 (0.001%) 252 | Program 111 (Shanai): 362 (0.000%) 253 | Program 112 (Tinkle Bell): 1044 (0.001%) 254 | Program 113 (Agogo): 552 (0.000%) 255 | Program 114 (Steel Drums): 1569 (0.001%) 256 | Program 115 (Woodblock): 1216 (0.001%) 257 | Program 116 (Taiko Drum): 2245 (0.001%) 258 | Program 117 (Melodic Tom): 1599 (0.001%) 259 | Program 118 (Synth Drum): 3203 (0.002%) 260 | Program 119 (Reverse Cymbal): 10701 (0.005%) 261 | Program 120 (Guitar Fret Noise, Guitar Cutting 
Noise): 2697 (0.001%) 262 | Program 121 (Breath Noise, Flute Key Click): 548 (0.000%) 263 | Program 122 (Seashore, Rain, Thunder, Wind, Stream, Bubbles): 3429 (0.002%) 264 | Program 123 (Bird Tweet, Dog, Horse Gallop): 659 (0.000%) 265 | Program 124 (Telephone Ring, Door Creaking, Door, Scratch, Wind Chime): 1642 (0.001%) 266 | Program 125 (Helicopter, Car Sounds): 1261 (0.001%) 267 | Program 126 (Applause, Laughing, Screaming, Punch, Heart Beat, Footstep): 1438 (0.001%) 268 | Program 127 (Gunshot, Machine Gun, Lasergun, Explosion): 0 (0.000%) 269 | """ 270 | 271 | if __name__ == "__main__": 272 | from pathlib import Path 273 | 274 | import numpy as np 275 | from matplotlib import pyplot as plt 276 | from miditok.constants import MIDI_INSTRUMENTS, SCORE_LOADING_EXCEPTION 277 | from miditok.utils import get_bars_ticks 278 | from symusic import Score 279 | from tqdm import tqdm 280 | 281 | from utils.baseline import mmm 282 | from utils.constants import ( 283 | MIN_NUM_BARS_FILE_VALID, 284 | ) 285 | 286 | NUM_HIST_BINS = 50 287 | 288 | # Filter non-valid files 289 | dataset_files_paths = mmm.dataset_files_paths 290 | num_tracks, programs = [], [] 291 | for file_path in tqdm(dataset_files_paths, desc="Reading MIDI files"): 292 | try: 293 | score = Score(file_path) 294 | except SCORE_LOADING_EXCEPTION: 295 | continue 296 | score = mmm.tokenizer.preprocess_score(score) 297 | if len(get_bars_ticks(score)) < MIN_NUM_BARS_FILE_VALID: 298 | continue 299 | 300 | num_tracks.append(len(score.tracks)) 301 | programs += [-1 if track.is_drum else track.program for track in score.tracks] 302 | 303 | print( # noqa: T201 304 | f"Number of files with less than {MIN_NUM_BARS_FILE_VALID} bars: " 305 | f"{len(num_tracks)}" 306 | ) 307 | 308 | programs = np.array(programs) 309 | for program in range(-1, 128): 310 | num_occurrences = len(np.where(programs == program)[0]) 311 | ratio = num_occurrences / len(programs) 312 | print( # noqa: T201 313 | f"Program {program} (" 314 | f"{'Drums' 
if program == -1 else MIDI_INSTRUMENTS[program]['name']}): " 315 | f"{num_occurrences} ({ratio:.3f}%)" 316 | ) 317 | 318 | # Plotting the distributions 319 | fig, ax = plt.subplots() 320 | ax.hist(num_tracks, bins=NUM_HIST_BINS) 321 | ax.grid(axis="y", linestyle="--", linewidth=0.6) 322 | ax.set_ylabel("Count files") 323 | ax.set_xlabel("Number of tracks") 324 | fig.savefig(Path("GigaMIDI_length_bars.pdf"), bbox_inches="tight", dpi=300) 325 | plt.close(fig) 326 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GigaMIDI Dataset 2 | ## The Extended GigaMIDI Dataset with Music Loop Detection and Expressive Multitrack Loop Generation 3 | \ 4 | ![GigaMIDI Logo](./Giga_MIDI_Logo_Final.png) 5 | 6 | ## The extended GigaMIDI Dataset Summary 7 | We present the extended GigaMIDI dataset, a large-scale symbolic music collection comprising over 2.1 million unique MIDI files with detailed annotations for music loop detection. Expanding on its predecessor, this release introduces a novel expressive loop detection method that captures performance nuances such as microtiming and dynamic variation, essential for advanced generative music modelling. Our method extends previous approaches, which were limited to strictly quantized, non-expressive tracks, by employing the Note Onset Median Metric Level (NOMML) heuristic to distinguish expressive from non-expressive material. This enables robust loop detection across a broader spectrum of MIDI data. Our loop detection method reveals more than 9.2 million non-expressive loops spanning all General MIDI instruments, alongside 2.3 million expressive loops identified through our new method. As the largest resource of its kind, the extended GigaMIDI dataset provides a strong foundation for developing models that synthesize structurally coherent and expressively rich musical loops. 
As a use case, we leverage this dataset to train an expressive multitrack symbolic music loop generation model using the MIDI-GPT system, resulting in the creation of a synthetic loop dataset. 8 | 9 | ## GigaMIDI Dataset Version Update 10 | 11 | We present the extended GigaMIDI dataset (select v2.0.0), a large-scale symbolic music collection comprising over 2.1 million unique MIDI files with detailed annotations for music loop detection. Expanding on its predecessor, this release introduces a novel expressive loop detection method that captures performance nuances such as microtiming and dynamic variation, essential for advanced generative music modelling. Our method extends previous approaches, which were limited to strictly quantized, non-expressive tracks, by employing the Note Onset Median Metric Level (NOMML) heuristic to distinguish expressive from non-expressive material. This enables robust loop detection across a broader spectrum of MIDI data. Our loop detection method reveals more than 9.2 million non-expressive loops spanning all General MIDI instruments, alongside 2.3 million expressive loops identified through our new method. As the largest resource of its kind, the extended GigaMIDI dataset provides a strong foundation for developing models that synthesize structurally coherent and expressively rich musical loops. As a use case, we leverage this dataset to train an expressive multitrack symbolic music loop generation model using the MIDI-GPT system, resulting in the creation of a synthetic loop dataset. The GigaMIDI dataset is accessible for research purposes on the Hugging Face hub [https://huggingface.co/datasets/Metacreation/GigaMIDI] in a user-friendly way for convenience and reproducibility. 12 | 13 | 14 | For the Hugging Face Hub dataset release, we disclaim any responsibility for the misuse of this dataset. 
The subset version `v2.0.0` refers specifically to the extended GigaMIDI dataset, while version `v1.0.0` denotes the original GigaMIDI dataset (see [Lee et al., 2025](https://doi.org/10.5334/tismir.203)).
file. 53 | 54 | [**Expressive Loop Generation**](https://github.com/Metacreation-Lab/GigaMIDI-Dataset/tree/main/MIDI-GPT-Loop): code for the expressive loop generation and instructions are available in the hyperlink which connects to the readme file. 55 | 56 | [**/scripts**](./scripts): Scripts and code notebooks for analyzing the 57 | GigaMIDI dataset and the loop dataset 58 | 59 | [**/tests**](./tests): E2E tests for expressive performance detection and 60 | loop extractions 61 | 62 | [**Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models**](https://github.com/GigaMidiDataset/The-GigaMIDI-dataset-with-loops-and-expressive-music-performance-detection/tree/82d424ae7ff48a2fb3ce5bb07de13d5cca4fc8c5/Analysis%20of%20Evaluation%20Set%20and%20Optimal%20Threshold%20Selection%20including%20Machine%20Learning%20Models): This archive includes CSV files corresponding to our curated evaluation set, which comprises both a training set and a testing set. These files contain percentile calculations used to determine the optimal thresholds for each heuristic in expressive music performance detection. The use of percentiles from the data distribution is intended to establish clear boundaries between non-expressive and expressive tracks, based on the values of our heuristic features. Additionally, we provide pre-trained models in .pkl format, developed using features derived from our novel heuristics. The hyperparameter setup is detailed in the following section titled *Pipeline Configuration*. 63 | 64 | [**Data Source Links for the GigaMIDI Dataset**](https://github.com/GigaMidiDataset/The-GigaMIDI-dataset-with-loops-and-expressive-music-performance-detection/blob/8acb0e5ca8ac5eb21c072ed381fa737689748c81/Data%20Source%20Links%20for%20the%20GigaMIDI%20Dataset%20-%20Sheet1.pdf): Data source links for each collected subset of the GigaMIDI dataset are all organized and uploaded in PDF. 
score = Score("tests/midi_files/Mr. Blue Sky.mid")
```bash
python main.py
```
107 | Imported libraries: 108 | ``` 109 | pip install numpy tqdm symusic 110 | ``` 111 |
112 | Note: symusic library is used for MIDI parsing. 113 | 114 | ### Using with the command line
115 | usage: 116 | ```python 117 | python nomml.py [-h] --folder FOLDER [--force] [--nthreads NTHREADS] 118 | ``` 119 |
Note: If you run the code successfully, it will generate a .json file with the appropriate metadata.
155 | - **Soft-count detection** 156 | Captures expressive loops by combining pitch overlap, velocity similarity, and micro-timing tolerance. 157 | - **Loopability scoring** 158 | A single metric blending the length of the longest repeat with overall repetition density. 159 | - **Scalable batch processing** 160 | Checkpoints output every 500 000 files (`loops_checkpoint_1.csv`, `loops_checkpoint_2.csv`, …). 161 | - **Full parallelization** 162 | Leverages `joblib.Parallel` to utilize all CPU cores for maximum throughput. 163 | 164 | --- 165 | 166 | ## 🔧 Prerequisites 167 | 168 | - **Python 3.8+** 169 | - **Symusic** for MIDI I/O 170 | - **GigaMIDI `loops_nomml`** module (place alongside this repo) 171 | - Install required packages: 172 | ```bash 173 | pip install numpy pandas joblib tqdm 174 | 175 | 176 | --- 177 | 178 | ## 📦 Installation 179 | 180 | 1. Clone this repo alongside your local `loops_nomml` checkout: 181 | 182 | ```bash 183 | git clone https://github.com/YourUser/expressive-loop-detect.git 184 | cd expressive-loop-detect 185 | ``` 186 | 2. (Optional) Create & activate a virtual environment: 187 | 188 | ```bash 189 | python3 -m venv venv 190 | source venv/bin/activate 191 | pip install -r requirements.txt 192 | ``` 193 | 194 | --- 195 | 196 | ## ⚙️ Usage 197 | 198 | 1. **Prepare your CSV** 199 | Place your input CSV (default name: 200 | `Final_GigaMIDI_Loop_V2_path-instrument-NOMML-type.csv`) in the working directory. It must have: 201 | 202 | | Column | Description | 203 | | ----------- | ----------------------------------------------------------------- | 204 | | `file_path` | Path to each `.mid` or `.midi` file | 205 | | `NOMML` | Python list of per-track expressiveness flags (e.g. `[12,2,4,…]`) | 206 | 207 | 2. 
* **Checkpoint CSVs** (`loops_checkpoint_<n>.csv`): one row per MIDI file in that chunk, with columns:
**Example to load with correct list parsing** (the converters below use `eval` for brevity; prefer `ast.literal_eval` from the standard library when the CSV may come from an untrusted source):
Submit a pull request 301 | 302 | --- 303 | 304 | 305 | 306 | 307 | ## Acknowledgement 308 | We gratefully acknowledge the support and contributions that have directly or indirectly aided this research. This work was supported in part by funding from the Natural Sciences and Engineering Research Council of Canada (NSERC) and the Social Sciences and Humanities Research Council of Canada (SSHRC). We also extend our gratitude to the School of Interactive Arts and Technology (SIAT) at Simon Fraser University (SFU) for providing resources and an enriching research environment. Additionally, we thank the Centre for Digital Music (C4DM) at Queen Mary University of London (QMUL) for fostering collaborative opportunities and supporting our engagement with interdisciplinary research initiatives. 309 | 310 | Special thanks are extended to Dr. Cale Plut for his meticulous manual curation of musical styles and to Dr. Nathan Fradet for his invaluable assistance in developing the HuggingFace Hub website for the GigaMIDI dataset, ensuring it is accessible and user-friendly for music computing and MIR researchers. We also sincerely thank our research interns, Paul Triana and Davide Rizotti, for their thorough proofreading of the manuscript. 311 | 312 | Finally, we express our heartfelt appreciation to the individuals and communities who generously shared their MIDI files for research purposes. Their contributions have been instrumental in advancing this work and fostering collaborative knowledge in the field. 
313 | -------------------------------------------------------------------------------- /GigaMIDI/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | annotations_creators: [] 3 | license: 4 | - cc-by-4.0 5 | - other 6 | pretty_name: GigaMIDI 7 | size_categories: [] 8 | source_datasets: [] 9 | tags: [] 10 | task_ids: [] 11 | --- 12 | 13 | # Dataset Card for GigaMIDI 14 | 15 | ## Table of Contents 16 | 17 | - [Dataset Description](#dataset-description) 18 | - [Dataset Summary](#dataset-summary) 19 | - [How to use](#how-to-use) 20 | - [Dataset Structure](#dataset-structure) 21 | - [Data Instances](#data-instances) 22 | - [Data Fields](#data-fields) 23 | - [Data Splits](#data-splits) 24 | - [Dataset Creation](#dataset-creation) 25 | - [Curation Rationale](#curation-rationale) 26 | - [Source Data](#source-data) 27 | - [Annotations](#annotations) 28 | - [Personal and Sensitive Information](#personal-and-sensitive-information) 29 | - [Considerations for Using the Data](#considerations-for-using-the-data) 30 | - [Social Impact of Dataset](#social-impact-of-dataset) 31 | - [Discussion of Biases](#discussion-of-biases) 32 | - [Other Known Limitations](#other-known-limitations) 33 | - [Additional Information](#additional-information) 34 | - [Dataset Curators](#dataset-curators) 35 | - [Licensing Information](#licensing-information) 36 | 37 | 38 | ## Dataset Description 39 | 40 | 41 | - **Repository:** Anonymized during the peer review process 42 | 43 | - **Point of Contact:** Anonymized during the peer review process 44 | 45 | ### Dataset Summary 46 | 47 | The GigaMIDI dataset is a corpus of over 1 million MIDI files covering all music genres. 
48 | 49 | We provide three subsets: `drums-only`, which contain MIDI files exclusively containing drum tracks, `no-drums` for MIDI files containing any MIDI program except drums (channel 10) and `all-instruments-with-drums` for MIDI files containing multiple MIDI programs including drums. The `all` subset encompasses the three to get the full dataset. 50 | 51 | ## How to use 52 | 53 | The `datasets` library allows you to load and pre-process your dataset in pure Python at scale. The dataset can be downloaded and prepared in one call to your local drive by using the `load_dataset` function. 54 | 55 | ```python 56 | from datasets import load_dataset 57 | 58 | dataset = load_dataset("Metacreation/GigaMIDI", "all-instruments-with-drums", trust_remote_code=True) 59 | ``` 60 | 61 | You can load combinations of specific subsets by using the `subset` keyword argument when loading the dataset: 62 | 63 | ```Python 64 | from datasets import load_dataset 65 | 66 | dataset = load_dataset("Metacreation/GigaMIDI", "music", subsets=["no-drums", "all-instruments-with-drums"], trust_remote_code=True) 67 | ``` 68 | 69 | Using the datasets library, you can also stream the dataset on-the-fly by adding a `streaming=True` argument to the `load_dataset` function call. Loading a dataset in streaming mode loads individual samples of the dataset at a time, rather than downloading the entire dataset to disk. 70 | 71 | ```python 72 | from datasets import load_dataset 73 | 74 | dataset = load_dataset( 75 | "Metacreation/GigaMIDI", "all-instruments-with-drums", trust_remote_code=True, streaming=True 76 | ) 77 | 78 | print(next(iter(dataset))) 79 | ``` 80 | 81 | *Bonus*: create a [PyTorch dataloader](https://huggingface.co/docs/datasets/use_with_pytorch) directly with your own datasets (local/streamed). 
### Local

```python
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, RandomSampler

dataset = load_dataset("Metacreation/GigaMIDI", "all-instruments-with-drums", trust_remote_code=True, split="train")
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=32, drop_last=False)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
```

### Streaming

```python
from datasets import load_dataset
from torch.utils.data import DataLoader

dataset = load_dataset("Metacreation/GigaMIDI", "all-instruments-with-drums", trust_remote_code=True, streaming=True, split="train")
dataloader = DataLoader(dataset, batch_size=32)
```
131 | :param min_num_notes: minimum number of notes that score should contain. 132 | :return: boolean indicating if ``score`` is valid. 133 | """ 134 | if isinstance(score, Path): 135 | try: 136 | score = Score(score) 137 | except SCORE_LOADING_EXCEPTION: 138 | return False 139 | elif isinstance(score, bytes): 140 | try: 141 | score = Score.from_midi(score) 142 | except SCORE_LOADING_EXCEPTION: 143 | return False 144 | 145 | return ( 146 | len(get_bars_ticks(score)) >= min_num_bars and score.note_num() > min_num_notes 147 | ) 148 | 149 | dataset = load_dataset("Metacreation/GigaMIDI", "all-instruments-with-drums", trust_remote_code=True, split="train") 150 | dataset = dataset.filter( 151 | lambda ex: is_score_valid(ex["music"]["bytes"], min_num_bars=8, min_num_notes=50) 152 | ) 153 | ``` 154 | 155 | ## Dataset Structure 156 | 157 | ### Data Instances 158 | 159 | A typical data sample comprises the `md5` of the file which corresponds to its file name, a `music` entry containing dictionary mapping to its absolute file `path` and `bytes` that can be loaded with `symusic` as `score = Score.from_midi(dataset[sample_idx]["music"]["bytes"])`. 160 | Metadata accompanies each file, which is introduced in the next section. 
A data sample indexed from the dataset may look like this (the `bytes` entry is voluntarily shortened):
The fields of each data entry are: 184 | 185 | * `md5` (`string`): hash the MIDI file, corresponding to its file name; 186 | * `music` (`dict`): a dictionary containing the absolute `path` to the downloaded file and the file content as `bytes` to be loaded with an external Python package such as symusic; 187 | * `is_drums` (`boolean`): whether the sample comes from the `drums` subset, this can be useful when working with the `all` subset; 188 | * `sid_matches` (`dict[str, list[str] | list[float16]]`): ids of the Spotify entries matched and their scores. 189 | * `mbid_matches` (`dict[str, str | list[str]]`): ids of the MusicBrainz entries matched with the Spotify entries. 190 | * `artist_scraped` (`string`): scraped artist of the entry; 191 | * `title_scraped` (`string`): scraped song title of the entry; 192 | * `genres_scraped` (`list[str]`): scraped genres of the entry; 193 | * `genres_discogs` (`dict[str, list[str] | list[int16]]`): Discogs genres matched from the [AcousticBrainz dataset](https://multimediaeval.github.io/2018-AcousticBrainz-Genre-Task/data/); 194 | * `genres_tagtraum` (`dict[str, list[str] | list[int16]]`): Tagtraum genres matched from the [AcousticBrainz dataset](https://multimediaeval.github.io/2018-AcousticBrainz-Genre-Task/data/); 195 | * `genres_lastfm` (`dict[str, list[str] | list[int16]]`): Lastfm genres matched from the [AcousticBrainz dataset](https://multimediaeval.github.io/2018-AcousticBrainz-Genre-Task/data/); 196 | * `median_metric_depth` (`list[int16]`): 197 | 198 | 199 | ### Data Splits 200 | 201 | The dataset has been subdivided into portions for training (`train`), validation (`validation`) and testing (`test`). 202 | 203 | The validation and test splits contain each 10% of the dataset, while the training split contains the rest (about 80%). 
204 | 205 | ## Dataset Creation 206 | 207 | ### Curation Rationale 208 | 209 | [Needs More Information] 210 | 211 | ### Source Data 212 | 213 | #### Initial Data Collection and Normalization 214 | 215 | [Needs More Information] 216 | 217 | #### Who are the source language producers? 218 | 219 | [Needs More Information] 220 | 221 | ### Annotations 222 | 223 | #### Annotation process 224 | 225 | [Needs More Information] 226 | 227 | #### Who are the annotators? 228 | 229 | [Needs More Information] 230 | 231 | ## Considerations for Using the Data 232 | 233 | ### Discussion of Biases 234 | 235 | [More Information Needed] 236 | 237 | ### Other Known Limitations 238 | 239 | [More Information Needed] 240 | 241 | ## Additional Information 242 | 243 | ### Dataset Curators 244 | 245 | [More Information Needed] 246 | 247 | ### Licensing Information 248 | 249 | Available for research purposes via Hugging Face hub. 250 | 251 | ## Ethical Statement 252 | The GigaMIDI dataset consists of MIDI files acquired via the aggregation of previously available datasets and web scraping from publicly available online sources. Each subset is accompanied by source links, copyright information when available, and acknowledgments. File names are anonymized using MD5 hash encryption. We acknowledge and cited the work from the previous dataset papers that we aggregate and analyze as part of the GigaMIDI subsets. 253 | This data has been collected, used, and distributed under Fair Dealing [ref to country and law copyright act anonymized]. Fair Dealing permits the limited use of copyright-protected material without the risk of infringement and without having to seek the permission of copyright owners. It is intended to provide a balance between the rights of creators and the rights of users. 
As per instructions of the Copyright Office of [anonymized University], two protective measures have been put in place that are deemed sufficient given the nature of the data (accessible online): 254 | 255 | 1) We explicitly state that this dataset has been collected, used, and is distributed under Fair Dealing [ref to law/country removed here for anonymity]. 256 | 2) On the Hugging Face hub, we advertise that the data is available for research purposes only and collect the user's legal name and email as proof of agreement before granting access. 257 | 258 | We thus decline any responsibility for misuse. 259 | 260 | To justify the fair use of MIDI data for research purposes, we try to follow the FAIR (Findable, Accessible, Interoperable, and Reusable) principles in our MIDI data collection. These principles are widely recognized and frequently cited within the data research community, providing a robust framework for ethical data management. By following the FAIR principles, we ensure that our dataset is managed responsibly, supporting its use in research while maintaining high standards of accessibility, interoperability, and reusability. 261 | 262 | 263 | In navigating the use of MIDI datasets for research and creative explorations, it is imperative to consider the ethical implications inherent in dataset bias. MIDI dataset bias often reflects the prevailing practices in Western contemporary music production, where certain instruments, notably the piano and drums, dominate due to their inherent MIDI compatibility. The piano is a primary compositional tool and a ubiquitous MIDI controller and keyboard, facilitating input for a wide range of virtual instruments and synthesizers. Similarly, drums, whether through drum machines or MIDI drum pads, enjoy widespread use for rhythm programming and beat production. This prevalence stems from their intuitive interface and versatility within digital audio workstations. 
Consequently, MIDI datasets tend to be skewed towards piano and drums, with fewer representations of other instruments, particularly those that may require more nuanced interpretation or are less commonly played using MIDI controllers. 264 | 265 | 266 | A potential issue with the detected loops in the GigaMIDI dataset arises from the possibility that similar note content may appear, particularly in loop-focused applications. To mitigate this, we implemented an additional deduplication process for the detected loops. This process involved using MD5 checksums based on the extracted music loop content to ensure that identical loops are not provided to users. 267 | 268 | 269 | Another potential issue with the dataset is the album-effect. When using the dataset for machine learning tasks, a random split may result in data with nearly identical note content appearing in both the evaluation and training splits of the GigaMIDI dataset. To address this potential issue, we provide metadata, including the composer's name, uniform piece title, performer’s name, and genre, where available. Additionally, the GigaMIDI dataset includes a substantial portion of drum grooves, which are single-track MIDI files; such files typically do not contribute to the album-effect. 270 | 271 | 272 | Lastly, all source data is duly acknowledged and cited in accordance with fair use and ethical standards. More than 50% of the dataset was collected through web scraping and author-led initiatives, which include manual data collection from online sources and retrieval from Zenodo and GitHub. To ensure transparency and prevent misuse, links to data sources for each subset are systematically organized and provided in our GitHub repository, enabling users to identify and verify the datasets used. 
273 | 274 | ## FAIR (Findable, Accessible, Interoperable, Reusable) principles with the GigaMIDI dataset 275 | 276 | The FAIR (Findable, Accessible, Interoperable, Reusable) principles serve as a framework to ensure that data is well-managed, easily discoverable, and usable for a broad range of purposes in research. These principles are particularly important in the context of data management to facilitate open science, collaboration, and reproducibility. 277 | 278 | 1. Findable: Data should be easily discoverable by both humans and machines. This is typically achieved through proper metadata, traceable source links and searchable resources. Applying this to MIDI data, each subset of MIDI files collected from public domain sources should be accompanied by clear and consistent metadata. For example, organizing the source links of each data subset, as done with the GigaMIDI dataset, ensures that each source can be easily traced and referenced, improving discoverability. 279 | 280 | 2. Accessible: Once found, data should be easily retrievable using standard protocols. Accessibility does not necessarily imply open access, but it does mean that data should be available under well-defined conditions. For the GigaMIDI dataset, hosting the data on platforms like Hugging Face Hub improves accessibility, as these platforms provide efficient data retrieval mechanisms, especially for large-scale datasets. Ensuring that MIDI data is accessible for public use, while respecting any applicable licenses, supports wider research and analysis in music computing. 281 | 282 | 3. Interoperable: Data should be structured in such a way that it can be integrated with other datasets and used by various applications. MIDI data, being a widely accepted format in music research, is inherently interoperable, especially when standardized metadata and file formats are used. 
By ensuring that the GigaMIDI dataset complies with widely adopted standards and supports integration with state-of-the-art libraries in symbolic music processing, such as Symusic (https://github.com/Yikai-Liao/symusic) and MidiTok (https://github.com/Natooz/MidiTok), the dataset enhances its utility for music researchers and practitioners working across different platforms and systems. 283 | 284 | 4. Reusable: Data should be well-documented and licensed so it can be reused in future research. Reusability is ensured through proper metadata, clear licenses, and documentation of provenance. In the case of GigaMIDI, aggregating all subsets from public domain sources and linking them to the original sources strengthens the reproducibility and traceability of the data. This practice allows future researchers to not only use the dataset but also verify and expand upon it by referring to the original data sources. 285 | 286 | In summary, applying FAIR principles to managing MIDI data, such as the GigaMIDI dataset, ensures that the data is organized in a manner that promotes reproducibility and traceability. By clearly documenting the source links of each subset and ensuring the dataset is findable, accessible, interoperable, and reusable, the data becomes a robust resource for the research community. 
287 | 288 | 289 | 295 | -------------------------------------------------------------------------------- /loops_nomml/corr_mat_fast.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Sequence, Tuple, Dict, List, Optional 4 | 5 | import os 6 | import numpy as np 7 | import threading 8 | 9 | from note_set_fast import NoteSet # updated import path if you renamed file 10 | 11 | if TYPE_CHECKING: 12 | from symusic import Note 13 | 14 | # Module-level small cache for duration computations to avoid repeated work 15 | # Keyed by (start, end, tb_first, tb_last, tb_len) 16 | _duration_cache: Dict[Tuple[int, int, int, int, int], float] = {} 17 | _duration_cache_lock = threading.Lock() 18 | 19 | # Safety caps and windowing parameters (tune to your environment) 20 | _MAX_TICKS_KEEP = 100_000 # hard cap on ticks array length 21 | _TRUNCATE_KEEP = 20_000 # keep this many ticks from head and tail if truncating 22 | _MAX_NOTESETS_FOR_FULL_CORR = 5000 # if n > this, use windowed correlation 23 | _WINDOW_SIZE = 4000 # window size for windowed correlation 24 | _WINDOW_OVERLAP = 500 # overlap between windows to catch cross-boundary runs 25 | 26 | # New: sanity threshold for tick arrays used for interpolation 27 | _MAX_TICK = 10 ** 9 28 | 29 | # Optional Numba acceleration flags (opt-in via environment) 30 | _USE_NUMBA = os.environ.get("USE_NUMBA", "0") == "1" 31 | _USE_CUDA = os.environ.get("USE_CUDA", "0") == "1" 32 | 33 | _NUMBA_AVAILABLE = False 34 | _CUDA_AVAILABLE = False 35 | _njit = None 36 | _cuda = None 37 | 38 | if _USE_NUMBA: 39 | try: 40 | from numba import njit, prange 41 | _NUMBA_AVAILABLE = True 42 | _njit = njit 43 | except Exception: 44 | _NUMBA_AVAILABLE = False 45 | _njit = None 46 | 47 | if _USE_CUDA and _NUMBA_AVAILABLE: 48 | try: 49 | from numba import cuda 50 | _CUDA_AVAILABLE = cuda.is_available() 51 | _cuda = cuda 52 | except Exception: 53 | 
_CUDA_AVAILABLE = False 54 | _cuda = None 55 | 56 | # If CUDA requested but not available, fall back to CPU numba if available 57 | if _USE_CUDA and not _CUDA_AVAILABLE and _NUMBA_AVAILABLE: 58 | # keep _USE_CUDA False to avoid trying GPU kernels later 59 | _USE_CUDA = False 60 | 61 | # Helper: safe strictly increasing check 62 | def _is_strictly_increasing(arr: np.ndarray) -> bool: 63 | if arr.size < 2: 64 | return True 65 | try: 66 | return bool((arr[1:] > arr[:-1]).all()) 67 | except Exception: 68 | prev = arr[0] 69 | for v in arr[1:]: 70 | if v <= prev: 71 | return False 72 | prev = v 73 | return True 74 | 75 | 76 | def _sanitize_ticks_beats_once(ticks_beats: Sequence[int]) -> np.ndarray: 77 | tb = np.asarray(ticks_beats, dtype=np.int64) 78 | if tb.size == 0: 79 | return tb 80 | 81 | # Remove implausible tick values 82 | try: 83 | tb = tb[np.isfinite(tb)] 84 | except Exception: 85 | pass 86 | 87 | # Clip out-of-range ticks 88 | if tb.size > 0: 89 | tb = tb[(tb >= 0) & (tb <= _MAX_TICK)] 90 | 91 | if tb.size == 0: 92 | return tb 93 | 94 | if _is_strictly_increasing(tb): 95 | if tb.size > _MAX_TICKS_KEEP: 96 | head = tb[:_TRUNCATE_KEEP] 97 | tail = tb[-_TRUNCATE_KEEP:] 98 | tb = np.concatenate((head, tail)) 99 | return tb 100 | 101 | try: 102 | tb_unique = np.unique(tb) 103 | except Exception: 104 | try: 105 | tb_list = sorted(set(int(x) for x in tb if np.isfinite(x))) 106 | tb_unique = np.asarray(tb_list, dtype=np.int64) 107 | except Exception: 108 | return np.asarray([], dtype=np.int64) 109 | 110 | if tb_unique.size > _MAX_TICKS_KEEP: 111 | head = tb_unique[:_TRUNCATE_KEEP] 112 | tail = tb_unique[-_TRUNCATE_KEEP:] 113 | tb_unique = np.concatenate((head, tail)) 114 | 115 | return tb_unique 116 | 117 | 118 | # ------------------------- 119 | # Numba-accelerated helpers 120 | # ------------------------- 121 | # CPU njit implementation of diagonal processing (pure-Python fallback exists) 122 | if _NUMBA_AVAILABLE and not _USE_CUDA: 123 | # We implement a simplified 
def _process_diagonals_into_matrix(key_ids_slice: np.ndarray, is_barline_slice: np.ndarray, out_mat: np.ndarray, offset: int) -> None:
    """
    Wrapper that selects accelerated implementation if available, otherwise uses
    the original Python/Numpy implementation.

    Fills ``out_mat`` (a self-similarity matrix) with run-length counts along
    its diagonals: for each diagonal offset ``k``, positions where
    ``key_ids_slice[i] == key_ids_slice[i + k]`` get an incrementing count for
    each contiguous matching run, but a run only starts accumulating at a
    barline (see ``is_barline_slice`` checks below). ``offset`` translates
    slice-local indices to global matrix coordinates for windowed processing.

    Dispatch order: CUDA kernel (if available) -> numba CPU njit (if enabled)
    -> pure NumPy/Python fallback; each accelerated path falls back silently
    on any exception.
    """
    n_local = key_ids_slice.size
    if n_local < 2:
        return

    # If CUDA is enabled and available, use a GPU kernel to compute per-diagonal matches
    if _CUDA_AVAILABLE:
        try:
            # copy slice to device once; the kernel is re-launched per diagonal
            key_ids_dev = _cuda.to_device(key_ids_slice.astype(np.int32))
            # allocate a device array for matches (max length n_local)
            for k in range(1, n_local):
                matches_dev = _cuda.device_array(n_local - k, dtype=np.uint8)
                threadsperblock = 128
                blocks = (n_local - k + threadsperblock - 1) // threadsperblock
                _cuda_kernel_match[blocks, threadsperblock](key_ids_dev, k, n_local, matches_dev)
                matches = matches_dev.copy_to_host()
                if not matches.any():
                    continue

                # find runs in matches (host side)
                true_idx = np.flatnonzero(matches)
                if true_idx.size == 0:
                    continue

                # Identify contiguous runs in true_idx: a gap > 1 between
                # consecutive match indices ends one run and starts the next.
                if true_idx.size == 1:
                    runs = [(int(true_idx[0]), int(true_idx[0]))]
                else:
                    diffs = np.diff(true_idx)
                    breaks = np.nonzero(diffs > 1)[0]
                    run_starts = np.concatenate(([0], breaks + 1))
                    run_ends = np.concatenate((breaks, [true_idx.size - 1]))
                    runs = [(int(true_idx[s]), int(true_idx[e])) for s, e in zip(run_starts, run_ends)]

                # For each run, walk and set incremental counts. prev_val only
                # starts growing once a barline position is reached, so counts
                # measure bar-aligned repetition length.
                for run_start, run_end in runs:
                    prev_val = 0
                    for i_local in range(run_start, run_end + 1):
                        j_local = i_local + k
                        global_i = offset + i_local
                        global_j = offset + j_local
                        if i_local == 0:
                            if is_barline_slice[i_local]:
                                prev_val = 1
                                out_mat[global_i, global_j] = 1
                            else:
                                prev_val = 0
                        else:
                            if prev_val == 0 and not is_barline_slice[i_local]:
                                prev_val = 0
                            else:
                                prev_val = prev_val + 1
                                out_mat[global_i, global_j] = prev_val
            # also handle first-element barline fast path (matches against first
            # element); order relative to the k-loop does not matter since the
            # touched cells are independent.
            if is_barline_slice[0]:
                first_id = int(key_ids_slice[0])
                matches = (key_ids_slice[1:] == first_id)
                if matches.any():
                    global_i = offset
                    global_js = np.nonzero(matches)[0] + offset + 1
                    out_mat[global_i, global_js] = 1
            return
        except Exception:
            # If GPU path fails for any reason, fall back to CPU/NumPy implementation below
            pass

    # If Numba CPU is available and requested, use njit implementation
    if _NUMBA_AVAILABLE and not _USE_CUDA:
        try:
            _process_diagonals_numba_cpu(key_ids_slice.astype(np.int32), is_barline_slice.astype(np.uint8), out_mat, offset)
            return
        except Exception:
            # fall back to Python implementation on error
            pass

    # Original Python/Numpy implementation (fallback)
    # Fast path for first element matching later ones when it's a barline
    if is_barline_slice[0]:
        first_id = key_ids_slice[0]
        matches = (key_ids_slice[1:] == first_id)
        if matches.any():
            global_i = offset
            global_js = np.nonzero(matches)[0] + offset + 1
            out_mat[global_i, global_js] = 1

    # For each diagonal offset k, find matches between left and right slices
    for k in range(1, n_local):
        left = key_ids_slice[: n_local - k]
        right = key_ids_slice[k : n_local]
        matches = (left == right)
        if not matches.any():
            continue

        true_idx = np.flatnonzero(matches)
        if true_idx.size == 0:
            continue

        # Identify contiguous runs in true_idx
        if true_idx.size == 1:
            runs = [(int(true_idx[0]), int(true_idx[0]))]
        else:
            diffs = np.diff(true_idx)
            breaks = np.nonzero(diffs > 1)[0]
            run_starts = np.concatenate(([0], breaks + 1))
            run_ends = np.concatenate((breaks, [true_idx.size - 1]))
            runs = [(int(true_idx[s]), int(true_idx[e])) for s, e in zip(run_starts, run_ends)]

        # For each run, walk and set incremental counts; this inner loop is
        # necessary to produce the "growing" correlation counts along diagonals.
        for run_start, run_end in runs:
            prev_val = 0
            for i_local in range(run_start, run_end + 1):
                j_local = i_local + k
                global_i = offset + i_local
                global_j = offset + j_local
                if i_local == 0:
                    if is_barline_slice[i_local]:
                        prev_val = 1
                        out_mat[global_i, global_j] = 1
                    else:
                        prev_val = 0
                else:
                    if prev_val == 0 and not is_barline_slice[i_local]:
                        prev_val = 0
                    else:
                        prev_val = prev_val + 1
                        out_mat[global_i, global_j] = prev_val
min(n, start + win) 349 | key_ids_slice = key_ids[start:end] 350 | is_barline_slice = is_barline[start:end] 351 | _process_diagonals_into_matrix(key_ids_slice, is_barline_slice, corr_mat, offset=start) 352 | if end == n: 353 | break 354 | start += stride 355 | 356 | return corr_mat 357 | 358 | 359 | # The rest of the file (duration helpers and loop detection) remains unchanged 360 | # (kept here for completeness). They are unchanged from the previous version, 361 | # but are included so this module is self-contained. 362 | 363 | def get_loop_density(loop: Sequence[NoteSet], num_beats: int | float) -> float: 364 | if num_beats == 0: 365 | return 0.0 366 | active = 0 367 | for ns in loop: 368 | if ns.duration != 0: 369 | active += 1 370 | return active / float(num_beats) 371 | 372 | 373 | def is_empty_loop(loop: Sequence[Note]) -> bool: 374 | for ns in loop: 375 | if ns.pitches: 376 | return False 377 | return True 378 | 379 | 380 | def compare_loops(p1: Sequence[NoteSet], p2: Sequence[NoteSet], min_rep_beats: int | float) -> int: 381 | check_len = int(round(min_rep_beats)) 382 | check_len = min(check_len, len(p1), len(p2)) 383 | for i in range(check_len): 384 | if p1[i] != p2[i]: 385 | return 0 386 | return 1 if len(p1) < len(p2) else 2 387 | 388 | 389 | def test_loop_exists(loop_list: Sequence[Sequence[NoteSet]], loop: Sequence[NoteSet], min_rep_beats: int | float) -> Optional[int]: 390 | for i, pat in enumerate(loop_list): 391 | result = compare_loops(loop, pat, min_rep_beats) 392 | if result == 1: 393 | return -1 394 | if result == 2: 395 | return i 396 | return None 397 | 398 | 399 | def filter_sub_loops(candidate_indices: Dict[float, List[Tuple[int, int]]]) -> List[Tuple[int, int, float]]: 400 | if not candidate_indices: 401 | return [] 402 | 403 | final: List[Tuple[int, int, float]] = [] 404 | for duration in sorted(candidate_indices.keys()): 405 | intervals = candidate_indices[duration] 406 | if not intervals: 407 | continue 408 | intervals_sorted = 
sorted(intervals, key=lambda x: x[0]) 409 | merged_start, merged_end = intervals_sorted[0] 410 | for s, e in intervals_sorted[1:]: 411 | if s == merged_end: 412 | merged_end = e 413 | else: 414 | final.append((merged_start, merged_end, duration)) 415 | merged_start, merged_end = s, e 416 | final.append((merged_start, merged_end, duration)) 417 | 418 | seen = set() 419 | unique_final: List[Tuple[int, int, float]] = [] 420 | for s, e, d in sorted(final, key=lambda x: (x[0], x[1], x[2])): 421 | key = (s, e, d) 422 | if key not in seen: 423 | seen.add(key) 424 | unique_final.append((s, e, d)) 425 | return unique_final 426 | 427 | 428 | def _compute_frac_positions_vectorized(values: np.ndarray, tb: np.ndarray) -> np.ndarray: 429 | """ 430 | Vectorized computation of fractional positions for an array of tick values. 431 | Returns a float array where integer positions correspond to exact tick indices, 432 | and fractional positions are interpolated between indices. 433 | """ 434 | if values.size == 0: 435 | return np.array([], dtype=float) 436 | 437 | tb_size = tb.size 438 | # Use searchsorted to find insertion positions 439 | pos = tb.searchsorted(values, side='left') # pos in [0..tb_size] 440 | res = np.empty(values.shape, dtype=float) 441 | 442 | # Exact matches where pos < tb_size and tb[pos] == value 443 | exact_mask = (pos < tb_size) & (tb[pos] == values) 444 | if exact_mask.any(): 445 | res[exact_mask] = pos[exact_mask].astype(float) 446 | 447 | # pos == 0 and not exact 448 | mask_pos0 = (pos == 0) & (~exact_mask) 449 | if mask_pos0.any(): 450 | prev_tick = tb[0] 451 | next_tick = tb[1] if tb_size > 1 else tb[0] + 1 452 | denom = next_tick - prev_tick if next_tick != prev_tick else 1 453 | res[mask_pos0] = (values[mask_pos0] - prev_tick) / denom 454 | 455 | # pos >= tb_size (to the right of last tick) 456 | mask_right = (pos >= tb_size) & (~exact_mask) 457 | if mask_right.any(): 458 | if tb_size > 1: 459 | prev_tick = tb[-2] 460 | last_tick = tb[-1] 461 | denom = 
last_tick - prev_tick if last_tick != prev_tick else 1 462 | res[mask_right] = float(tb_size - 1) + (values[mask_right] - last_tick) / denom 463 | else: 464 | # single tick in tb 465 | res[mask_right] = float(0) + (values[mask_right] - tb[0]) # fallback 466 | 467 | # middle cases: 0 < pos < tb_size and not exact 468 | mask_mid = (~exact_mask) & (~mask_pos0) & (~mask_right) 469 | if mask_mid.any(): 470 | pos_mid = pos[mask_mid] 471 | prev_tick = tb[pos_mid - 1] 472 | next_tick = tb[pos_mid] 473 | denom = next_tick - prev_tick 474 | # avoid division by zero 475 | denom = np.where(denom == 0, 1, denom) 476 | res[mask_mid] = (pos_mid - 1).astype(float) + (values[mask_mid] - prev_tick) / denom 477 | 478 | return res 479 | 480 | 481 | def _compute_durations_batch(starts: np.ndarray, ends: np.ndarray, tb: np.ndarray) -> np.ndarray: 482 | """ 483 | Compute durations (in beats) for arrays of start and end tick values. 484 | Vectorized and uses _compute_frac_positions_vectorized. 485 | """ 486 | if tb is None or tb.size == 0: 487 | return np.zeros_like(starts, dtype=float) 488 | 489 | # Convert to numpy arrays 490 | starts_arr = np.asarray(starts, dtype=np.int64) 491 | ends_arr = np.asarray(ends, dtype=np.int64) 492 | 493 | # Compute fractional positions for starts and ends 494 | start_pos = _compute_frac_positions_vectorized(starts_arr, tb) 495 | end_pos = _compute_frac_positions_vectorized(ends_arr, tb) 496 | 497 | durations = end_pos - start_pos 498 | # Clip negative durations to 0 499 | durations = np.where(durations < 0.0, 0.0, durations) 500 | return durations.astype(float) 501 | 502 | 503 | def get_duration_beats(start: int, end: int, tb: np.ndarray, tick_to_idx: Optional[Dict[int, int]] = None) -> float: 504 | """ 505 | Backwards-compatible single-pair duration computation. 506 | For heavy workloads prefer the batch computation used in get_valid_loops. 
507 | """ 508 | if tb is None or tb.size == 0: 509 | return 0.0 510 | 511 | if end <= start: 512 | return 0.0 513 | 514 | try: 515 | tb_first = int(tb[0]) 516 | tb_last = int(tb[-1]) 517 | except Exception: 518 | return 0.0 519 | 520 | key = (int(start), int(end), tb_first, tb_last, int(tb.size)) 521 | with _duration_cache_lock: 522 | cached = _duration_cache.get(key) 523 | if cached is not None: 524 | return cached 525 | 526 | # Fast path for single-element tb 527 | if tb.size == 1: 528 | duration = float(max(0.0, end - start)) 529 | with _duration_cache_lock: 530 | _duration_cache[key] = duration 531 | return duration 532 | 533 | # Fallback to vectorized batch helper for a single pair 534 | dur = _compute_durations_batch(np.array([start], dtype=np.int64), np.array([end], dtype=np.int64), tb)[0] 535 | with _duration_cache_lock: 536 | _duration_cache[key] = dur 537 | return dur 538 | 539 | 540 | def get_valid_loops( 541 | note_sets: Sequence[NoteSet], 542 | corr_mat: np.ndarray, 543 | ticks_beats: Sequence[int], 544 | min_rep_notes: int = 4, 545 | min_rep_beats: float = 2.0, 546 | min_beats: float = 32.0, 547 | max_beats: float = 32.0, 548 | min_loop_note_density: float = 0.5, 549 | ) -> Tuple[List[Sequence[NoteSet]], List[Tuple[int, int, float, float]]]: 550 | """ 551 | Return detected loops and metadata. 
552 | """ 553 | min_rep_notes += 1 # original behavior to not count barlines 554 | x_idx, y_idx = np.where(corr_mat == min_rep_notes) 555 | if x_idx.size == 0: 556 | return [], [] 557 | 558 | tb_sanitized = _sanitize_ticks_beats_once(ticks_beats) 559 | if tb_sanitized.size == 0: 560 | # no valid beat ticks to compute durations -> no loops 561 | return [], [] 562 | 563 | # Precompute tick->index mapping once (cheap relative to repeated rebuilds) 564 | try: 565 | tick_to_idx_global = {int(t): i for i, t in enumerate(tb_sanitized)} 566 | except Exception: 567 | tick_to_idx_global = {} 568 | 569 | valid_indices: Dict[float, List[Tuple[int, int]]] = {} 570 | # Collect unique key pairs to compute durations in batch 571 | unique_pairs = {} 572 | pairs_list_starts = [] 573 | pairs_list_ends = [] 574 | pairs_keys = [] 575 | 576 | # First pass: collect candidate pairs and unique (start_tick, end_tick) pairs 577 | for xi, yi in zip(x_idx, y_idx): 578 | try: 579 | run_len = int(corr_mat[xi, yi]) 580 | except Exception: 581 | continue 582 | start_x = xi - run_len + 1 583 | start_y = yi - run_len + 1 584 | 585 | if start_x < 0 or start_y < 0 or start_x >= len(note_sets) or start_y >= len(note_sets): 586 | continue 587 | 588 | try: 589 | loop_start_time = int(note_sets[start_x].start) 590 | loop_end_time = int(note_sets[start_y].start) 591 | except Exception: 592 | continue 593 | 594 | # sanity check tick magnitudes 595 | if loop_start_time < 0 or loop_end_time < 0 or loop_start_time > _MAX_TICK or loop_end_time > _MAX_TICK: 596 | continue 597 | 598 | key_pair = (loop_start_time, loop_end_time) 599 | if key_pair not in unique_pairs: 600 | unique_pairs[key_pair] = len(pairs_list_starts) 601 | pairs_list_starts.append(loop_start_time) 602 | pairs_list_ends.append(loop_end_time) 603 | pairs_keys.append(key_pair) 604 | 605 | if not pairs_list_starts: 606 | return [], [] 607 | 608 | # Batch compute durations for all unique pairs 609 | starts_arr = np.asarray(pairs_list_starts, 
dtype=np.int64) 610 | ends_arr = np.asarray(pairs_list_ends, dtype=np.int64) 611 | durations_arr = _compute_durations_batch(starts_arr, ends_arr, tb_sanitized) 612 | # Round durations to 2 decimals to match previous behavior 613 | durations_rounded = np.round(durations_arr.astype(float), 2) 614 | 615 | # Build a mapping from key_pair to rounded duration 616 | keypair_to_duration: Dict[Tuple[int, int], float] = { 617 | kp: float(durations_rounded[idx]) for idx, kp in enumerate(pairs_keys) 618 | } 619 | 620 | # Second pass: populate valid_indices using computed durations 621 | for xi, yi in zip(x_idx, y_idx): 622 | try: 623 | run_len = int(corr_mat[xi, yi]) 624 | except Exception: 625 | continue 626 | start_x = xi - run_len + 1 627 | start_y = yi - run_len + 1 628 | 629 | if start_x < 0 or start_y < 0 or start_x >= len(note_sets) or start_y >= len(note_sets): 630 | continue 631 | 632 | try: 633 | loop_start_time = int(note_sets[start_x].start) 634 | loop_end_time = int(note_sets[start_y].start) 635 | except Exception: 636 | continue 637 | 638 | key_pair = (loop_start_time, loop_end_time) 639 | loop_num_beats = keypair_to_duration.get(key_pair, 0.0) 640 | 641 | if min_beats <= loop_num_beats <= max_beats: 642 | valid_indices.setdefault(loop_num_beats, []).append((int(xi), int(yi))) 643 | 644 | if not valid_indices: 645 | return [], [] 646 | 647 | filtered = filter_sub_loops(valid_indices) 648 | 649 | loops: List[Sequence[NoteSet]] = [] 650 | loop_bp: List[Tuple[int, int, float, float]] = [] 651 | corr_size = corr_mat.shape[0] 652 | 653 | for start_x, start_y, loop_num_beats in filtered: 654 | x = start_x 655 | y = start_y 656 | while x + 1 < corr_size and y + 1 < corr_size and corr_mat[x + 1, y + 1] > corr_mat[x, y]: 657 | x += 1 658 | y += 1 659 | 660 | beginning = x - int(corr_mat[x, y]) + 1 661 | end = y - int(corr_mat[x, y]) + 1 662 | 663 | if beginning < 0 or end < 0 or beginning >= len(note_sets) or end >= len(note_sets): 664 | continue 665 | 666 | start_tick = 
int(note_sets[beginning].start) 667 | end_tick = int(note_sets[end].start) 668 | 669 | # Try to get duration from keypair_to_duration first 670 | duration_beats = keypair_to_duration.get((start_tick, end_tick)) 671 | if duration_beats is None: 672 | # Fallback: compute single pair (rare) 673 | duration_beats = get_duration_beats(start_tick, end_tick, tb_sanitized, tick_to_idx=tick_to_idx_global) 674 | 675 | if duration_beats >= min_rep_beats and not is_empty_loop(note_sets[beginning:end]): 676 | loop = note_sets[beginning : (end + 1)] 677 | loop_density = get_loop_density(loop, loop_num_beats) 678 | if loop_density < min_loop_note_density: 679 | continue 680 | exist_result = test_loop_exists(loops, loop, min_rep_beats) 681 | if exist_result is None: 682 | loops.append(loop) 683 | loop_bp.append((start_tick, end_tick, loop_num_beats, loop_density)) 684 | elif exist_result > 0: 685 | loops[exist_result] = loop 686 | loop_bp[exist_result] = (start_tick, end_tick, loop_num_beats, loop_density) 687 | 688 | return loops, loop_bp --------------------------------------------------------------------------------