├── tests
├── __init__.py
├── midi_files
│ ├── Aicha.mid
│ ├── empty.mid
│ ├── Funkytown.mid
│ ├── Maestro_1.mid
│ ├── Maestro_2.mid
│ ├── Maestro_3.mid
│ ├── Maestro_4.mid
│ ├── Maestro_5.mid
│ ├── Maestro_6.mid
│ ├── Maestro_7.mid
│ ├── Maestro_8.mid
│ ├── Maestro_9.mid
│ ├── Shut Up.mid
│ ├── test_midi.mid
│ ├── In Too Deep.mid
│ ├── Maestro_10.mid
│ ├── POP909_008.mid
│ ├── POP909_010.mid
│ ├── POP909_022.mid
│ ├── POP909_191.mid
│ ├── Mr. Blue Sky.mid
│ ├── I Gotta Feeling.mid
│ ├── 6338816_Etude No. 4.mid
│ ├── Les Yeux Revolvers.mid
│ ├── 6354774_Macabre Waltz.mid
│ ├── All The Small Things.mid
│ ├── What a Fool Believes.mid
│ ├── Girls Just Want to Have Fun.mid
│ ├── d6caebd1964d9e4a3c5ea59525230e2a.mid
│ └── d8faddb8596fff7abb24d78666f73e4e.mid
├── utils_tests.py
├── test_run.py
└── test_nomml.py
├── Giga_MIDI_Logo_Final.png
├── loops_nomml
├── loop_ex_labeled.png
├── __init__.py
├── README.md
├── nomml.py
├── note_set.py
├── note_set_fast.py
├── process_file.py
├── process_file_fast.py
├── corr_mat.py
└── corr_mat_fast.py
├── scripts
├── figures
│ └── GigaMIDI_duration_bars.pdf
├── nomml_gigamidi.py
├── main_expressive.py
├── analyze_gigamidi_dataset.py
├── dataset_programs_dist.py
├── test_extract.ipynb
└── dataset_short_files.py
├── Data Source Links for the GigaMIDI Dataset - Sheet1.pdf
├── .gitmodules
├── Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models
├── Saved Logistic Regression Models
│ ├── model-DNVR.pkl
│ ├── model-DNODR.pkl
│ └── model-NOMML.pkl
├── Curated Evaluation Set
│ ├── MIDI Score (Non-expressive MIDI Tracks)
│ │ ├── ASAP-Score-only.numbers
│ │ ├── ATEPP-Score-Final.numbers
│ │ └── ATEPP-Score-Final-refinement.numbers
│ └── MIDI Performance (Expressive MIDI Tracks)
│ │ ├── Expressively-Performed-EP-Only-Aggregated-Final-GroundTruth.numbers
│ │ └── Saarland.csv
└── Optimal Threshold Selection
│ └── Expressive_Performance_Detection_Training-Percentile-threshold.csv
├── LICENSE
├── main.py
├── pyproject.toml
├── MIDI-GPT-Loop
└── README.md
├── Expressive music loop detector-NOMML12.ipynb
├── GigaMIDI
├── create_gigamidi_dataset.py
├── GigaMIDI.py
└── README.md
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Giga_MIDI_Logo_Final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Giga_MIDI_Logo_Final.png
--------------------------------------------------------------------------------
/tests/midi_files/Aicha.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Aicha.mid
--------------------------------------------------------------------------------
/tests/midi_files/empty.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/empty.mid
--------------------------------------------------------------------------------
/tests/midi_files/Funkytown.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Funkytown.mid
--------------------------------------------------------------------------------
/tests/midi_files/Maestro_1.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_1.mid
--------------------------------------------------------------------------------
/tests/midi_files/Maestro_2.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_2.mid
--------------------------------------------------------------------------------
/tests/midi_files/Maestro_3.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_3.mid
--------------------------------------------------------------------------------
/tests/midi_files/Maestro_4.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_4.mid
--------------------------------------------------------------------------------
/tests/midi_files/Maestro_5.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_5.mid
--------------------------------------------------------------------------------
/tests/midi_files/Maestro_6.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_6.mid
--------------------------------------------------------------------------------
/tests/midi_files/Maestro_7.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_7.mid
--------------------------------------------------------------------------------
/tests/midi_files/Maestro_8.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_8.mid
--------------------------------------------------------------------------------
/tests/midi_files/Maestro_9.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_9.mid
--------------------------------------------------------------------------------
/tests/midi_files/Shut Up.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Shut Up.mid
--------------------------------------------------------------------------------
/tests/midi_files/test_midi.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/test_midi.mid
--------------------------------------------------------------------------------
/loops_nomml/loop_ex_labeled.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/loops_nomml/loop_ex_labeled.png
--------------------------------------------------------------------------------
/tests/midi_files/In Too Deep.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/In Too Deep.mid
--------------------------------------------------------------------------------
/tests/midi_files/Maestro_10.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Maestro_10.mid
--------------------------------------------------------------------------------
/tests/midi_files/POP909_008.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/POP909_008.mid
--------------------------------------------------------------------------------
/tests/midi_files/POP909_010.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/POP909_010.mid
--------------------------------------------------------------------------------
/tests/midi_files/POP909_022.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/POP909_022.mid
--------------------------------------------------------------------------------
/tests/midi_files/POP909_191.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/POP909_191.mid
--------------------------------------------------------------------------------
/tests/midi_files/Mr. Blue Sky.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Mr. Blue Sky.mid
--------------------------------------------------------------------------------
/tests/midi_files/I Gotta Feeling.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/I Gotta Feeling.mid
--------------------------------------------------------------------------------
/tests/midi_files/6338816_Etude No. 4.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/6338816_Etude No. 4.mid
--------------------------------------------------------------------------------
/tests/midi_files/Les Yeux Revolvers.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Les Yeux Revolvers.mid
--------------------------------------------------------------------------------
/scripts/figures/GigaMIDI_duration_bars.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/scripts/figures/GigaMIDI_duration_bars.pdf
--------------------------------------------------------------------------------
/tests/midi_files/6354774_Macabre Waltz.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/6354774_Macabre Waltz.mid
--------------------------------------------------------------------------------
/tests/midi_files/All The Small Things.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/All The Small Things.mid
--------------------------------------------------------------------------------
/tests/midi_files/What a Fool Believes.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/What a Fool Believes.mid
--------------------------------------------------------------------------------
/tests/midi_files/Girls Just Want to Have Fun.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/Girls Just Want to Have Fun.mid
--------------------------------------------------------------------------------
/Data Source Links for the GigaMIDI Dataset - Sheet1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Data Source Links for the GigaMIDI Dataset - Sheet1.pdf
--------------------------------------------------------------------------------
/tests/midi_files/d6caebd1964d9e4a3c5ea59525230e2a.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/d6caebd1964d9e4a3c5ea59525230e2a.mid
--------------------------------------------------------------------------------
/tests/midi_files/d8faddb8596fff7abb24d78666f73e4e.mid:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/tests/midi_files/d8faddb8596fff7abb24d78666f73e4e.mid
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "MIDI-GPT-Loop/MMM-Loop"]
2 | path = MIDI-GPT-Loop/MMM-Loop
3 | url = git@github.com:Metacreation-Lab/MMM-Loop.git
4 | [submodule "MIDI-GPT-Loop/MMM"]
5 | path = MIDI-GPT-Loop/MMM
6 | url = git@github.com:Metacreation-Lab/MMM.git
7 | branch = Expressive-Loops
8 |
--------------------------------------------------------------------------------
/loops_nomml/__init__.py:
--------------------------------------------------------------------------------
"""Main module."""

from .process_file import detect_loops, detect_loops_from_path
from .nomml import get_median_metric_depth

# Public API of the loops_nomml package: loop detection entry points plus
# the NOMML (median metric depth) helper used for expressiveness filtering.
__all__ = [
    "get_median_metric_depth",
    "detect_loops",
    "detect_loops_from_path"
]
--------------------------------------------------------------------------------
/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Saved Logistic Regression Models/model-DNVR.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Saved Logistic Regression Models/model-DNVR.pkl
--------------------------------------------------------------------------------
/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Saved Logistic Regression Models/model-DNODR.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Saved Logistic Regression Models/model-DNODR.pkl
--------------------------------------------------------------------------------
/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Saved Logistic Regression Models/model-NOMML.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Saved Logistic Regression Models/model-NOMML.pkl
--------------------------------------------------------------------------------
/tests/utils_tests.py:
--------------------------------------------------------------------------------
"""Test validation methods."""

from __future__ import annotations

from pathlib import Path

# Fixed RNG seed shared across the test suite for reproducibility.
SEED = 777

# Directory containing this file; MIDI fixtures live alongside it.
HERE = Path(__file__).parent
# Every .mid file under tests/midi_files, sorted for deterministic
# parametrized-test ordering.
MIDI_PATHS_ALL = sorted((HERE / "midi_files").rglob("*.mid"))
# MIDI_PATHS_ALL = [MIDI_PATHS_ALL[-1]]
TEST_LOG_DIR = HERE / "test_logs"
--------------------------------------------------------------------------------
/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Score (Non-expressive MIDI Tracks)/ASAP-Score-only.numbers:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Score (Non-expressive MIDI Tracks)/ASAP-Score-only.numbers
--------------------------------------------------------------------------------
/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Score (Non-expressive MIDI Tracks)/ATEPP-Score-Final.numbers:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Score (Non-expressive MIDI Tracks)/ATEPP-Score-Final.numbers
--------------------------------------------------------------------------------
/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Score (Non-expressive MIDI Tracks)/ATEPP-Score-Final-refinement.numbers:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Score (Non-expressive MIDI Tracks)/ATEPP-Score-Final-refinement.numbers
--------------------------------------------------------------------------------
/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Performance (Expressive MIDI Tracks)/Expressively-Performed-EP-Only-Aggregated-Final-GroundTruth.numbers:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/asigalov61/GigaMIDI-Dataset/main/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Performance (Expressive MIDI Tracks)/Expressively-Performed-EP-Only-Aggregated-Final-GroundTruth.numbers
--------------------------------------------------------------------------------
/tests/test_run.py:
--------------------------------------------------------------------------------
1 | """Testing tokenization, making sure the data integrity is not altered."""
2 |
3 | from __future__ import annotations
4 |
5 | from pathlib import Path
6 |
7 | import pytest
8 | from symusic import Score
9 |
10 | from loops_nomml import detect_loops
11 |
12 | from .utils_tests import MIDI_PATHS_ALL
13 |
14 |
@pytest.mark.parametrize("file_path", MIDI_PATHS_ALL, ids=lambda p: p.name)
def test_one_track_midi_to_tokens_to_midi(file_path: Path):
    """Smoke test: each bundled MIDI file must load into a symusic Score
    and pass through detect_loops without raising."""
    score = Score(file_path)
    _ = detect_loops(score)
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | # GigaMIDI Dataset Licensing Information
2 | The dataset is distributed under the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) license.
3 | This license permits users to share, adapt, and utilize the dataset exclusively for non-commercial purposes, including research
4 | and educational applications, provided that proper attribution is given to the original creators.
5 | By adhering to the terms of CC BY-NC 4.0, users ensure the dataset's responsible use while fostering its accessibility
6 | for academic and non-commercial endeavours.
7 |
--------------------------------------------------------------------------------
/loops_nomml/README.md:
--------------------------------------------------------------------------------
1 | # Fast loops extractor (CPU/GPU accelerated)
2 |
3 | ## Requirements
4 |
5 | ```sh
6 | pip install numba
7 | pip install numpy==1.24.4
8 | pip install pretty-midi
9 | pip install symusic
10 | pip install miditok
11 | ```
12 |
13 | ## Basic use example
14 |
15 | ```python
16 | import os
17 |
18 | # Set desired environment variable
19 | os.environ["USE_NUMBA"] = "1"
20 | os.environ["USE_CUDA"] = "1"
21 |
22 | # Check the variable
23 | print(os.environ["USE_NUMBA"])
24 | print(os.environ["USE_CUDA"])
25 |
26 | from process_file_fast import detect_loops_from_path
27 |
28 | midi_file = './your_midi_file.mid'
29 |
30 | result = detect_loops_from_path({'file_path': [midi_file]})
31 | ```
32 |
33 | ## Enjoy! :)
34 |
--------------------------------------------------------------------------------
/loops_nomml/nomml.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import annotations
3 |
4 | import numpy as np
5 | from symusic import Score
6 |
7 |
def get_metric_depth(time, tpq, max_depth=6):
    """Return the metric depth of an onset *time* (in MIDI ticks).

    Even depths 0, 2, 4, ... correspond to duple subdivisions of the
    quarter note (period tpq / 2**level); odd depths 1, 3, 5, ... to
    triplet subdivisions (period tpq * 2 / (3 * 2**level)). A time that
    matches no subdivision within *max_depth* levels gets the sentinel
    value max_depth * 2.

    :param time: onset time in ticks
    :param tpq: ticks per quarter note of the score
    :param max_depth: number of subdivision levels to test per grid
    :return: the metric depth as an int in [0, max_depth * 2]
    """
    # Duple grid first: a hit at subdivision `level` maps to depth 2*level.
    for level in range(max_depth):
        if time % (tpq / int(2 ** level)) == 0:
            return 2 * level
    # Triplet grid next: a hit at `level` maps to the odd depth 2*level + 1.
    for level in range(max_depth):
        if time % (tpq * 2 / (int(2 ** level) * 3)) == 0:
            return 2 * level + 1
    # No grid matched: off-grid (expressive) onset.
    return max_depth * 2
18 |
19 |
def get_median_metric_depth(path):
    """Compute, per track of the MIDI file at *path*, the integer median of
    its note-onset metric depths (see :func:`get_metric_depth`).

    Tracks without notes contribute no entry.

    :param path: path to a MIDI file readable by symusic
    :return: tuple ``(path, medians)`` where ``medians`` is one int per
        non-empty track
    """
    score = Score(path)
    medians = []
    for track in score.tracks:
        depths = [get_metric_depth(note.time, score.tpq) for note in track.notes]
        if depths:
            medians.append(int(np.median(depths)))
    return path, medians
28 |
--------------------------------------------------------------------------------
/scripts/nomml_gigamidi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import json
4 | import random
5 | from tqdm import tqdm
6 | from multiprocessing import Pool
7 |
8 | from loops_nomml import get_median_metric_depth
9 |
10 |
def load_json(path):
    """Deserialize and return the JSON document stored at *path*."""
    with open(path) as handle:
        return json.load(handle)
14 |
15 |
def dump_json(x, path):
    """Serialize *x* to *path* as pretty-printed JSON (4-space indent)."""
    with open(path, "w") as handle:
        json.dump(x, handle, indent=4)
19 |
20 |
def worker(args):
    """Pool worker: unpack *args* into get_median_metric_depth.

    Any failure is printed and swallowed (returning None) so that one
    corrupt MIDI file cannot kill a long batch run.
    """
    try:
        result = get_median_metric_depth(*args)
    except Exception as err:  # deliberate best-effort batch processing
        print("FAILED : ", err)
        return None
    return result
27 |
28 |
def main(folder, force=False, nthreads=8):
    """Compute median metric depths for every MIDI file under *folder*.

    Results are cached in a JSON file named after the folder's basename;
    pass ``force=True`` to recompute. Checkpoints are written periodically
    so very long runs leave partial output behind.

    :param folder: root directory to scan recursively for MIDI files
    :param force: recompute even if the cache file already exists
    :param nthreads: number of worker processes
    :return: dict mapping each file's folder-relative path to its list of
        per-track median metric depths
    """
    output_path = os.path.basename(folder).lower() + ".json"
    if os.path.exists(output_path) and not force:
        return load_json(output_path)

    # Collect MIDI files with the common extension casings.
    paths = [glob.glob(folder + f"/**/*.{ext}", recursive=True)
             for ext in ["mid", "midi", "MID"]]
    paths = [p for sublist in paths for p in sublist]
    random.shuffle(paths)

    result = {}
    # Fix: the original created Pool(nthreads) without ever closing/joining
    # it (leaked worker processes); the context manager guarantees cleanup
    # even if an exception interrupts the loop. It also removes the `p`
    # name shadowing between the pool and the comprehension variable.
    with Pool(nthreads) as pool:
        progress = tqdm(pool.imap_unordered(worker, [(p,) for p in paths]),
                        total=len(paths))
        for count, item in enumerate(filter(lambda x: x is not None, progress)):
            path, median_metric_depths = item
            result[os.path.relpath(path, folder)] = median_metric_depths
            # Periodic checkpoint so long runs are resumable/inspectable.
            if count % 50000 == 0:
                dump_json(result, output_path)
    dump_json(result, output_path)
    return result
49 |
50 |
if __name__ == "__main__":
    import argparse

    # Command-line entry point: scan --folder and write the JSON cache.
    cli = argparse.ArgumentParser()
    cli.add_argument("--folder", type=str, default="", required=True)
    cli.add_argument("--force", action="store_true")
    cli.add_argument("--nthreads", type=int, default=8)
    opts = cli.parse_args()
    main(opts.folder, force=opts.force, nthreads=opts.nthreads)
60 |
--------------------------------------------------------------------------------
/tests/test_nomml.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import json
4 | import random
5 | import numpy as np
6 | from tqdm import tqdm
7 | from multiprocessing import Pool
8 | from symusic import Score
9 |
10 | from loops_nomml import get_median_metric_depth
11 |
12 |
def load_json(path):
    """Read the file at *path* and return its parsed JSON content."""
    with open(path) as fp:
        return json.load(fp)
16 |
17 |
def dump_json(x, path):
    """Write *x* to *path* as indented (4-space) JSON."""
    with open(path, "w") as fp:
        json.dump(x, fp, indent=4)
21 |
22 |
def worker(args):
    """Pool worker around get_median_metric_depth.

    Errors are printed and converted to a None return value so a single
    unreadable file does not abort the whole pool.
    """
    try:
        outcome = get_median_metric_depth(*args)
    except Exception as err:  # intentional best-effort handling
        print("FAILED : ", err)
        return None
    return outcome
29 |
30 |
31 | #TODO test method
32 |
def main(folder, force=False, nthreads=8):
    """Compute median metric depths for every MIDI file under *folder*.

    Results are cached in a JSON file named after the folder's basename;
    pass ``force=True`` to recompute. Checkpoints are written periodically
    during long runs.

    :param folder: root directory scanned recursively for MIDI files
    :param force: recompute even if the cache file already exists
    :param nthreads: number of worker processes
    :return: dict mapping folder-relative file paths to lists of per-track
        median metric depths
    """
    output_path = os.path.basename(folder).lower() + ".json"
    if os.path.exists(output_path) and not force:
        return load_json(output_path)

    # Gather MIDI files across common extension casings.
    paths = [glob.glob(folder + f"/**/*.{ext}", recursive=True)
             for ext in ["mid", "midi", "MID"]]
    paths = [p for sublist in paths for p in sublist]
    random.shuffle(paths)

    result = {}
    # Fix: the pool was never closed/joined in the original (process leak);
    # the context manager ensures cleanup on all exit paths and avoids
    # shadowing `p` between the pool and the comprehension variable.
    with Pool(nthreads) as pool:
        progress = tqdm(pool.imap_unordered(worker, [(p,) for p in paths]),
                        total=len(paths))
        for count, item in enumerate(filter(lambda x: x is not None, progress)):
            path, median_metric_depths = item
            result[os.path.relpath(path, folder)] = median_metric_depths
            # Checkpoint every 50k processed files.
            if count % 50000 == 0:
                dump_json(result, output_path)
    dump_json(result, output_path)
    return result
53 |
54 |
if __name__ == "__main__":
    import argparse

    # CLI entry point mirroring scripts/nomml_gigamidi.py.
    cli = argparse.ArgumentParser()
    cli.add_argument("--folder", type=str, default="", required=True)
    cli.add_argument("--force", action="store_true")
    cli.add_argument("--nthreads", type=int, default=8)
    opts = cli.parse_args()
    main(opts.folder, force=opts.force, nthreads=opts.nthreads)
64 |
--------------------------------------------------------------------------------
/scripts/main_expressive.py:
--------------------------------------------------------------------------------
1 | from loops_nomml.process_file import detect_loops
2 | import os
3 | from datasets import Dataset, load_dataset
4 |
DATA_PATH = "D:\\Documents\\GigaMIDI"
METADATA_NAME = "Expressive_Performance_Detection_NOMML_gigamidi_tismir.csv"
SHARD_SIZE = 20000  # target number of unique files per output shard
OUTPUT_NAME = "gigamidi_expressive_loops_no_quant"

if __name__ == "__main__":
    metadata_path = os.path.join(DATA_PATH, METADATA_NAME)
    output_path = os.path.join(DATA_PATH, OUTPUT_NAME)
    # makedirs(exist_ok=True) also creates missing parents and is race-free,
    # unlike the exists()/mkdir() pair it replaces.
    os.makedirs(output_path, exist_ok=True)

    dataset = load_dataset("csv", data_files=metadata_path)['train']
    print(f"loaded {len(dataset)} tracks")
    # A median metric depth of 12 marks a track as expressive (see
    # loops_nomml.nomml: 12 is the off-grid sentinel, max_depth * 2).
    dataset_expressive = dataset.filter(lambda row: row['medianMetricDepth'] == 12)
    print(f"filtered to {len(dataset_expressive)} expressive tracks")
    dataset_with_time_signature = dataset_expressive.filter(lambda row: row['hasTimeSignatures'])
    print(f"filtered to {len(dataset_with_time_signature)} expressive tracks with time signatures")

    # One entry per unique MIDI file, with absolute and relative paths.
    unique_files = dataset_with_time_signature.unique('filepath')
    unique_files = [{"file_path": os.path.join(DATA_PATH, file_path), "file_name": file_path} for file_path in unique_files]
    unique_files_dataset = Dataset.from_list(unique_files)
    print(f"filtered to {len(unique_files_dataset)} unique MIDI files, expressive with time signatures")
    unique_files_dataset = unique_files_dataset.shuffle(seed=42)

    # Fix: round() can yield 0 when the dataset holds fewer than
    # SHARD_SIZE / 2 files, which would make Dataset.shard fail; always
    # produce at least one shard.
    num_shards = max(1, int(round(len(unique_files_dataset) / SHARD_SIZE)))
    print(f"Splitting dataset in {num_shards} shards")
    print(f"Saving shards to {output_path}")
    for shard_idx in range(num_shards):
        shard = unique_files_dataset.shard(num_shards=num_shards, index=shard_idx)
        shard = shard.map(
            detect_loops,
            remove_columns=['file_path', 'file_name'],
            batched=True,
            batch_size=1,
            writer_batch_size=1,
            num_proc=8
        )

        # One CSV of detected loops per shard.
        csv_path = os.path.join(output_path, OUTPUT_NAME + "_" + str(shard_idx) + ".csv")
        shard.to_csv(csv_path)
45 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from loops_nomml.process_file import detect_loops_from_path
2 | import os
3 | from datasets import Dataset, load_dataset
4 |
DATA_PATH = "D:\\Documents\\GigaMIDI"
METADATA_NAME = "Expressive_Performance_Detection_NOMML_gigamidi_tismir.csv"
SHARD_SIZE = 20000  # target number of unique files per output shard
OUTPUT_NAME = "gigamidi_non_expressive_loops"

if __name__ == "__main__":
    metadata_path = os.path.join(DATA_PATH, METADATA_NAME)
    output_path = os.path.join(DATA_PATH, OUTPUT_NAME)
    # makedirs(exist_ok=True) also creates missing parents and is race-free,
    # unlike the exists()/mkdir() pair it replaces.
    os.makedirs(output_path, exist_ok=True)

    dataset = load_dataset("csv", data_files=metadata_path)['train']
    print(f"loaded {len(dataset)} tracks")
    # Median metric depth below 12 (the off-grid sentinel) marks a track
    # as non-expressive.
    dataset_non_expressive = dataset.filter(lambda row: row['medianMetricDepth'] < 12)
    print(f"filtered to {len(dataset_non_expressive)} non-expressive tracks")
    dataset_with_time_signature = dataset_non_expressive.filter(lambda row: row['hasTimeSignatures'])
    print(f"filtered to {len(dataset_with_time_signature)} non-expressive tracks with time signatures")

    # One entry per unique MIDI file, with absolute and relative paths.
    unique_files = dataset_with_time_signature.unique('filepath')
    unique_files = [{"file_path": os.path.join(DATA_PATH, file_path), "file_name": file_path} for file_path in unique_files]
    unique_files_dataset = Dataset.from_list(unique_files)
    print(f"filtered to {len(unique_files_dataset)} unique MIDI files, non-expressive with time signatures")
    unique_files_dataset = unique_files_dataset.shuffle(seed=42)

    # Fix: round() can yield 0 when the dataset holds fewer than
    # SHARD_SIZE / 2 files, which would make Dataset.shard fail; always
    # produce at least one shard.
    num_shards = max(1, int(round(len(unique_files_dataset) / SHARD_SIZE)))
    print(f"Splitting dataset in {num_shards} shards")
    print(f"Saving shards to {output_path}")
    for shard_idx in range(num_shards):
        shard = unique_files_dataset.shard(num_shards=num_shards, index=shard_idx)
        shard = shard.map(
            detect_loops_from_path,
            remove_columns=['file_path', 'file_name'],
            batched=True,
            batch_size=1,
            writer_batch_size=1,
            num_proc=8
        )

        # One CSV of detected loops per shard.
        csv_path = os.path.join(output_path, OUTPUT_NAME + "_" + str(shard_idx) + ".csv")
        shard.to_csv(csv_path)
45 |
--------------------------------------------------------------------------------
/scripts/analyze_gigamidi_dataset.py:
--------------------------------------------------------------------------------
#!/usr/bin/python3 python

"""Measuring the length of GigaMIDI files in bars."""

if __name__ == "__main__":
    from pathlib import Path

    import numpy as np
    from datasets import concatenate_datasets, load_dataset
    from matplotlib import pyplot as plt
    from miditok.constants import SCORE_LOADING_EXCEPTION
    from miditok.utils import get_bars_ticks
    from symusic import Score
    from tqdm import tqdm
    from transformers import set_seed

    from utils.GigaMIDI.GigaMIDI import _SUBSETS
    from utils.utils import path_data_directory_local_fs

    SEED = 777
    NUM_FILES = 5000  # successfully-parsed files to sample per subset
    NUM_HIST_BINS = 25
    X_AXIS_LIM_BARS = 300
    X_AXIS_LIM_BEATS = X_AXIS_LIM_BARS * 4
    set_seed(SEED)
    # Fix: figures live under scripts/figures in this repository; the
    # previous Path("scripts", "analysis", "figures") pointed at a
    # directory that does not exist in the tree.
    FIGURES_PATH = Path("scripts", "figures")
    # savefig does not create missing directories, so ensure the path exists.
    FIGURES_PATH.mkdir(parents=True, exist_ok=True)

    # Measuring lengths in bars/beats
    dist_lengths_bars = {}
    for subset_name in _SUBSETS:
        subset = concatenate_datasets(
            list(
                load_dataset(
                    str(path_data_directory_local_fs() / "GigaMIDI"),
                    subset_name,
                    trust_remote_code=True,
                ).values()
            )
        ).shuffle()
        dist_lengths_bars[subset_name] = []

        idx = 0
        with tqdm(total=NUM_FILES, desc=f"Analyzing `{subset_name}` subset") as pbar:
            while len(dist_lengths_bars[subset_name]) < NUM_FILES and idx < len(subset):
                try:
                    score = Score.from_midi(subset[idx]["music"]["bytes"])
                    dist_lengths_bars[subset_name].append(len(get_bars_ticks(score)))
                    pbar.update(1)
                except SCORE_LOADING_EXCEPTION:
                    # Unreadable files are skipped and do not count toward
                    # the NUM_FILES quota.
                    pass
                finally:
                    idx += 1

    # Plotting length (bars) distribution
    for subset_name in dist_lengths_bars:
        dist_arr = np.array(dist_lengths_bars[subset_name])
        # Clip each distribution to the x-axis limit before plotting.
        dist_lengths_bars[subset_name] = dist_arr[
            np.where(dist_arr <= X_AXIS_LIM_BARS)[0]
        ]
    fig, ax = plt.subplots()
    ax.hist(
        dist_lengths_bars.values(),
        bins=NUM_HIST_BINS,
        density=True,
        label=dist_lengths_bars.keys(),
    )
    ax.grid(axis="y", linestyle="--", linewidth=0.6)
    ax.legend(prop={"size": 10})
    ax.set_ylabel("Density")
    ax.set_xlabel("Duration in bars")
    fig.savefig(
        FIGURES_PATH / "GigaMIDI_duration_bars.pdf", bbox_inches="tight", dpi=300
    )
    plt.close(fig)
75 |
--------------------------------------------------------------------------------
/loops_nomml/note_set.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import annotations
3 |
4 | from typing import TYPE_CHECKING
5 |
6 | if TYPE_CHECKING:
7 | from collections.abc import Sequence
8 |
9 | from symusic import Note
10 | from symusic.core import NoteTickList
11 |
12 |
def compute_note_sets(notes: NoteTickList, bars_ticks: Sequence[int]) -> list[NoteSet]:
    """
    Convert a track's MIDI notes plus measure start times into NoteSets.

    Barlines are represented as empty NoteSets with a duration of 0.

    :param notes: list of MIDI notes in a single track
    :param bars_ticks: list of measure start times in ticks
    :return: sorted NoteSet representation of the MIDI track
    """
    note_sets: list[NoteSet] = []
    for note in notes:
        # Open a new set whenever the note does not share the interval of the
        # most recent set.
        if not note_sets or not note_sets[-1].fits_in_set(note.start, note.end):
            note_sets.append(NoteSet(start=note.start, end=note.end))
        note_sets[-1].add_note(note)

    barline_sets = [NoteSet(start=tick, end=tick) for tick in bars_ticks]
    return sorted(note_sets + barline_sets)
33 |
34 |
class NoteSet:
    """
    A set of unique pitches sharing one start time and one end time.

    A NoteSet holding no pitches with a duration of 0 represents the
    start of a measure (i.e. a barline).

    :param start: start time in MIDI ticks
    :param end: end time in MIDI ticks
    """
    def __init__(self, start: int, end: int) -> None:
        self.start = start
        self.end = end
        self.duration = end - start
        self.pitches = set()  # MIDI note numbers present in this set

    def add_note(self, note: Note) -> None:
        """Record the pitch of ``note`` in this set."""
        self.pitches.add(note.pitch)

    def fits_in_set(self, start: int, end: int) -> bool:
        """Return True when the given interval matches this set exactly."""
        return (start, end) == (self.start, self.end)

    def is_barline(self) -> bool:
        """Return True when this set marks the start of a measure."""
        return not self.pitches and self.start == self.end

    def __str__(self) -> str:
        return f"NoteSet({self.start}, {self.duration}, {self.pitches})"

    def __eq__(self, other: object) -> bool:
        """
        Two NoteSets are equal when their durations and their MIDI pitch
        contents match (start times are deliberately not compared).
        """
        if not isinstance(other, NoteSet):
            return False
        return self.duration == other.duration and self.pitches == other.pitches

    def __lt__(self, other: object):
        """NoteSets are ordered by start time."""
        return self.start < other.start
88 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 |
2 | [build-system]
3 | requires = ["hatchling"]
4 | build-backend = "hatchling.build"
5 |
6 | [project]
7 | name = "GigaMIDI"
8 | version = "0.0.1"
9 | description = "Loop and expressive performance detection methods for symbolic music."
10 | readme = {file = "README.md", content-type = "text/markdown"}
11 | license = {file = "LICENSE"}
12 | requires-python = ">=3.8.0"
13 | authors = [
14 | ]
15 | keywords = [
16 | "artificial intelligence",
17 | "deep learning",
18 | "transformer",
19 | "midi",
20 | "music",
21 | "mir",
22 | ]
23 | classifiers = [
24 | "Intended Audience :: Developers",
25 | "Intended Audience :: Science/Research",
26 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
27 | "Topic :: Multimedia :: Sound/Audio :: MIDI",
28 | "License :: OSI Approved :: MIT License",
29 | "Programming Language :: Python",
30 | "Programming Language :: Python :: 3.8",
31 | "Programming Language :: Python :: 3.9",
32 | "Programming Language :: Python :: 3.10",
33 | "Programming Language :: Python :: 3.11",
34 | "Programming Language :: Python :: 3.12",
35 | "Operating System :: OS Independent",
36 | ]
37 | dependencies = [
38 | "numpy==1.26.3",
39 | "datasets>=2.20.0",
40 | "symusic>=0.5.0",
41 | "miditok>3.0.3",
42 | ]
43 |
44 | [project.urls]
45 | Homepage = "https://github.com/"
46 |
47 | [tool.hatch.version]
48 | path = "gigamidi/__init__.py"
49 |
50 | [tool.hatch.build.targets.sdist]
51 | include = [
52 | "/gigamidi",
53 | ]
54 |
[tool.mypy]
56 | warn_return_any = "True"
57 | warn_unused_configs = "True"
58 | plugins = "numpy.typing.mypy_plugin"
59 | exclude = [
60 | "venv",
61 | ".venv",
62 | ]
63 |
64 | [tool.ruff]
65 | target-version = "py312"
66 |
67 | [tool.ruff.lint]
68 | extend-select = [
69 | "ARG",
70 | "A",
71 | "ANN",
72 | "B",
73 | "BLE",
74 | "C4",
75 | "COM",
76 | "D",
77 | "E",
78 | "EM",
79 | "EXE",
80 | "F",
81 | "FA",
82 | "FBT",
83 | "G",
84 | "I",
85 | "ICN",
86 | "INP",
87 | "INT",
88 | "ISC",
89 | "N",
90 | "NPY",
91 | "PERF",
92 | "PGH",
93 | "PTH",
94 | "PIE",
95 | # "PL",
96 | "PT",
97 | "Q",
98 | "RET",
99 | "RSE",
100 | "RUF",
101 | "S",
102 | # "SLF",
103 | "SIM",
104 | "T",
105 | "TCH",
106 | "TID",
107 | "UP",
108 | "W",
109 | ]
110 |
111 | ignore = [
112 | "ANN003",
113 | "ANN101",
114 | "ANN102",
115 | "B905",
116 | "COM812",
117 | "D107",
118 | "D203",
119 | "D212",
120 | "FBT001",
121 | "FBT002",
122 | "UP038",
123 | "S105",
124 | "S311",
125 | ]
126 |
127 | [tool.ruff.lint.per-file-ignores]
128 | "tests/**" = [
129 | "ANN201", # allow no return type hint for pytest methods
130 | "D103", # no need to document pytest methods
131 | "S101", # allow assertions in tests
132 | "T201", # print allowed
133 | ]
134 | "docs/conf.py" = ["INP001"] # not a package
135 |
--------------------------------------------------------------------------------
/loops_nomml/note_set_fast.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, Sequence
4 |
5 | if TYPE_CHECKING:
6 | from symusic import Note
7 | from symusic.core import NoteTickList
8 |
9 |
class NoteSet:
    """A set of unique pitches sharing one (start, end) interval, in ticks."""

    __slots__ = ("start", "end", "duration", "pitches")

    def __init__(self, start: int, end: int) -> None:
        self.start = int(start)
        self.end = int(end)
        self.duration = self.end - self.start
        self.pitches = set()  # MIDI note numbers present in this set

    def add_note(self, note: "Note") -> None:
        # Best-effort: silently ignore notes whose pitch cannot be read as int.
        try:
            pitch = int(note.pitch)
        except Exception:
            return
        self.pitches.add(pitch)

    def fits_in_set(self, start: int, end: int) -> bool:
        return (int(start), int(end)) == (self.start, self.end)

    def is_barline(self) -> bool:
        # Zero duration and no pitches marks the start of a measure.
        return not self.pitches and self.start == self.end

    def __str__(self) -> str:
        return f"NoteSet({self.start}, {self.duration}, {self.pitches})"

    def __eq__(self, other: object) -> bool:
        # Equality compares duration and pitch contents, not start times.
        return (
            isinstance(other, NoteSet)
            and self.duration == other.duration
            and self.pitches == other.pitches
        )

    def __lt__(self, other: object):
        # NoteSets are ordered by start time.
        return self.start < other.start
47 |
48 |
# Sanity limits used to reject malformed MIDI events.
_MAX_DURATION_TICKS = 4_000_000  # longest plausible single-note duration
_MAX_GAP_TICKS = 4_000_000  # kept for compatibility; the gap check never altered output
_MIN_NOTE_DURATION = 0  # notes with a smaller (negative) duration are dropped

# Reject implausible tick values (malformed events often have huge tick numbers).
_MAX_TICK = 10 ** 9  # 1 billion ticks; adjust upward if your dataset legitimately uses larger ticks


def compute_note_sets(notes: NoteTickList, bars_ticks: Sequence[int]) -> list[NoteSet]:
    """
    Convert a track's MIDI notes plus measure start times into NoteSets.

    Malformed events (unreadable, negative, implausibly large, or overlong
    ticks) are dropped. Barlines are represented as empty NoteSets with a
    duration of 0.

    :param notes: list of MIDI notes in a single track
    :param bars_ticks: list of measure start times in ticks
    :return: sorted NoteSet representation of the MIDI track
    """
    # Validate each note once, keeping its integer tick values alongside so
    # they are not re-converted later.
    valid_notes = []
    for note in notes:
        try:
            start = int(note.start)
            end = int(note.end)
        except (AttributeError, TypeError, ValueError):
            continue  # malformed note event: skip
        duration = end - start
        if (
            start < 0
            or end < 0
            or start > _MAX_TICK  # suspiciously large tick values -> skip
            or end > _MAX_TICK
            or duration < _MIN_NOTE_DURATION
            or duration > _MAX_DURATION_TICKS
        ):
            continue
        valid_notes.append((start, end, note))

    if not valid_notes and not bars_ticks:
        return []

    # Group consecutive notes sharing the same (start, end) interval.
    valid_notes.sort(key=lambda item: (item[0], item[1]))
    note_sets: list[NoteSet] = []
    for start, end, note in valid_notes:
        if not note_sets or not note_sets[-1].fits_in_set(start, end):
            note_sets.append(NoteSet(start=start, end=end))
        note_sets[-1].add_note(note)

    # Clean, deduplicate and sort bar ticks; each becomes an empty NoteSet.
    bars_clean = set()
    for bar_tick in bars_ticks:
        try:
            tick = int(bar_tick)
        except (TypeError, ValueError):
            continue
        if 0 <= tick <= _MAX_TICK:
            bars_clean.add(tick)

    all_sets = note_sets + [NoteSet(start=tick, end=tick) for tick in sorted(bars_clean)]
    # Stable sort by start time; notes precede barlines at equal starts, as
    # in the original concatenation order.
    all_sets.sort()
    return all_sets
--------------------------------------------------------------------------------
/loops_nomml/process_file.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import annotations
3 | from typing import Dict, Tuple, Any, List
4 | import numpy as np
5 | from miditok.utils import get_bars_ticks, get_beats_ticks
6 | from miditok.constants import CLASS_OF_INST, INSTRUMENT_CLASSES
7 | from symusic import Score, Track, TimeSignature
8 |
9 | from .corr_mat import calc_correlation, get_valid_loops
10 | from .note_set import compute_note_sets
11 |
# Track-size bounds: tracks whose note counts fall outside this range are
# skipped by detect_loops.
MAX_NOTES_PER_TRACK = 50000
MIN_NOTES_PER_TRACK = 5
29 |
30 |
def get_instrument_type(track: Track) -> str:
    """
    Determine the MIDI instrument class of a track.

    :param track: symusic ``Track`` to identify the instrument of
    :return: name of the instrument class (``"Drums"`` for drum tracks)
    """
    if track.is_drum:
        return "Drums"

    # Map the MIDI program number (0-127) to its instrument class name.
    return INSTRUMENT_CLASSES[CLASS_OF_INST[track.program]]["name"]
46 |
47 |
48 |
def create_loop_dict(endpoint_data: Tuple[int, int, float, float], track_idx: int, instrument_type: str) -> Dict[str,Any]:
    """
    Format loop metadata into a dictionary for dataset generation.

    :param endpoint_data: tuple of loop start time in ticks, loop end time in
        ticks, loop duration in beats, and density in notes per beat
    :param track_idx: MIDI track index the loop belongs to
    :param instrument_type: MIDI instrument the loop represents, as a string
    :return: data entry containing all metadata for a single loop
    """
    loop_start, loop_end, duration_beats, note_density = endpoint_data
    entry = {
        "track_idx": track_idx,
        "instrument_type": instrument_type,
        "start": loop_start,
        "end": loop_end,
    }
    entry["duration_beats"] = duration_beats
    entry["note_density"] = note_density
    return entry
68 |
def detect_loops_from_path(file_info: Dict) -> Dict[str,List]:
    """
    Given a MIDI file, locate all loops present across all of its tracks.

    :param file_info: dictionary containing a ``file_path`` key (a string,
        or a list whose first element is used)
    :return: dictionary of metadata for each identified loop; empty lists
        when the file cannot be parsed
    """
    file_path = file_info['file_path']
    if isinstance(file_path, list):
        file_path = file_path[0]
    try:
        score = Score(file_path)
    # Parsing can fail in many ways; skip the file. A bare `except:` would
    # also swallow SystemExit/KeyboardInterrupt, so catch Exception instead.
    except Exception:  # noqa: BLE001
        print(f"Unable to parse score for {file_path}, skipping")
        return {
            "track_idx": [],
            "instrument_type": [],
            "start": [],
            "end": [],
            "duration_beats": [],
            "note_density": [],
        }
    return detect_loops(score, file_path=file_path)
92 |
def detect_loops(score: Score, file_path: str | None = None) -> Dict[str,List]:
    """
    Given a MIDI score, locate all loops present across all of its tracks.

    :param score: score to evaluate for loops
    :param file_path: when given, each loop entry also records this path
        under a ``file_path`` key
    :return: dictionary of metadata for each identified loop
    """
    data = {
        "track_idx": [],
        "instrument_type": [],
        "start": [],
        "end": [],
        "duration_beats": [],
        "note_density": [],
    }
    if file_path is not None:
        data["file_path"] = []
    # Check that there is a time signature. There might be none with abc files.
    if len(score.time_signatures) == 0:
        score.time_signatures.append(TimeSignature(0, 4, 4))

    try:
        bars_ticks = np.array(get_bars_ticks(score))
    except ZeroDivisionError:
        print("Skipping, couldn't find any bar lines")
        return data

    beats_ticks = np.array(get_beats_ticks(score))
    for idx, track in enumerate(score.tracks):
        # Skip tracks with too few notes to loop or too many to process.
        if not MIN_NOTES_PER_TRACK <= len(track.notes) <= MAX_NOTES_PER_TRACK:
            continue

        # Cut bars_ticks at the end of the track.
        if any(track_bars_mask := bars_ticks > track.end()):
            bars_ticks_track = bars_ticks[:np.nonzero(track_bars_mask)[0][0]]
        else:
            bars_ticks_track = bars_ticks

        # Cut beats_ticks at the end of the track.
        if any(track_beats_mask := beats_ticks > track.end()):
            beats_ticks_track = beats_ticks[:np.nonzero(track_beats_mask)[0][0]]
        else:
            beats_ticks_track = beats_ticks

        if len(bars_ticks_track) > MAX_NOTES_PER_TRACK:
            print(f"Skipping track {idx} due to ill-formed bars")
            continue

        instrument_type = get_instrument_type(track)
        note_sets = compute_note_sets(track.notes, bars_ticks_track)
        lead_mat = calc_correlation(note_sets)
        _, loop_endpoints = get_valid_loops(note_sets, lead_mat, beats_ticks_track)
        for endpoint in loop_endpoints:
            loop_dict = create_loop_dict(endpoint, idx, instrument_type)
            for key, value in loop_dict.items():
                data[key].append(value)
            if file_path is not None:
                data["file_path"].append(file_path)

    return data
153 |
--------------------------------------------------------------------------------
/scripts/dataset_programs_dist.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3 python
2 |
3 | """
4 | Counts the programs in the MegaMIDI dataset.
5 |
6 | Results:
7 | Program -1 (Drums): 61400 (0.226%)
8 | Program 0 (Acoustic Grand Piano): 41079 (0.151%)
9 | Program 1 (Bright Acoustic Piano): 2511 (0.009%)
10 | Program 2 (Electric Grand Piano): 891 (0.003%)
11 | Program 3 (Honky-tonk Piano): 572 (0.002%)
12 | Program 4 (Electric Piano 1): 2156 (0.008%)
13 | Program 5 (Electric Piano 2): 1466 (0.005%)
14 | Program 6 (Harpsichord): 1727 (0.006%)
15 | Program 7 (Clavi): 621 (0.002%)
16 | Program 8 (Celesta): 407 (0.001%)
17 | Program 9 (Glockenspiel): 1167 (0.004%)
18 | Program 10 (Music Box): 405 (0.001%)
19 | Program 11 (Vibraphone): 1898 (0.007%)
20 | Program 12 (Marimba): 924 (0.003%)
21 | Program 13 (Xylophone): 440 (0.002%)
22 | Program 14 (Tubular Bells): 717 (0.003%)
23 | Program 15 (Dulcimer): 133 (0.000%)
24 | Program 16 (Drawbar Organ): 883 (0.003%)
25 | Program 17 (Percussive Organ): 874 (0.003%)
26 | Program 18 (Rock Organ): 1627 (0.006%)
27 | Program 19 (Church Organ): 520 (0.002%)
28 | Program 20 (Reed Organ): 158 (0.001%)
29 | Program 21 (Accordion): 997 (0.004%)
30 | Program 22 (Harmonica): 808 (0.003%)
31 | Program 23 (Tango Accordion): 343 (0.001%)
32 | Program 24 (Acoustic Guitar (nylon)): 3478 (0.013%)
33 | Program 25 (Acoustic Guitar (steel)): 7235 (0.027%)
34 | Program 26 (Electric Guitar (jazz)): 3503 (0.013%)
35 | Program 27 (Electric Guitar (clean)): 4416 (0.016%)
36 | Program 28 (Electric Guitar (muted)): 2767 (0.010%)
37 | Program 29 (Overdriven Guitar): 3219 (0.012%)
38 | Program 30 (Distortion Guitar): 3253 (0.012%)
39 | Program 31 (Guitar Harmonics): 299 (0.001%)
40 | Program 32 (Acoustic Bass): 3075 (0.011%)
41 | Program 33 (Electric Bass (finger)): 6316 (0.023%)
42 | Program 34 (Electric Bass (pick)): 831 (0.003%)
43 | Program 35 (Fretless Bass): 3110 (0.011%)
44 | Program 36 (Slap Bass 1): 354 (0.001%)
45 | Program 37 (Slap Bass 2): 298 (0.001%)
46 | Program 38 (Synth Bass 1): 1477 (0.005%)
47 | Program 39 (Synth Bass 2): 887 (0.003%)
48 | Program 40 (Violin): 1558 (0.006%)
49 | Program 41 (Viola): 672 (0.002%)
50 | Program 42 (Cello): 909 (0.003%)
51 | Program 43 (Contrabass): 750 (0.003%)
52 | Program 44 (Tremolo Strings): 666 (0.002%)
53 | Program 45 (Pizzicato Strings): 2260 (0.008%)
54 | Program 46 (Orchestral Harp): 1304 (0.005%)
55 | Program 47 (Timpani): 2109 (0.008%)
56 | Program 48 (String Ensembles 1): 9898 (0.036%)
57 | Program 49 (String Ensembles 2): 3749 (0.014%)
58 | Program 50 (SynthStrings 1): 3028 (0.011%)
59 | Program 51 (SynthStrings 2): 740 (0.003%)
60 | Program 52 (Choir Aahs): 5394 (0.020%)
61 | Program 53 (Voice Oohs): 2101 (0.008%)
62 | Program 54 (Synth Voice): 1288 (0.005%)
63 | Program 55 (Orchestra Hit): 460 (0.002%)
64 | Program 56 (Trumpet): 5559 (0.020%)
65 | Program 57 (Trombone): 4676 (0.017%)
66 | Program 58 (Tuba): 2099 (0.008%)
67 | Program 59 (Muted Trumpet): 649 (0.002%)
68 | Program 60 (French Horn): 3926 (0.014%)
69 | Program 61 (Brass Section): 2340 (0.009%)
70 | Program 62 (Synth Brass 1): 994 (0.004%)
71 | Program 63 (Synth Brass 2): 449 (0.002%)
72 | Program 64 (Soprano Sax): 544 (0.002%)
73 | Program 65 (Alto Sax): 3359 (0.012%)
74 | Program 66 (Tenor Sax): 2479 (0.009%)
75 | Program 67 (Baritone Sax): 1081 (0.004%)
76 | Program 68 (Oboe): 2669 (0.010%)
77 | Program 69 (English Horn): 499 (0.002%)
78 | Program 70 (Bassoon): 2029 (0.007%)
79 | Program 71 (Clarinet): 4666 (0.017%)
80 | Program 72 (Piccolo): 1202 (0.004%)
81 | Program 73 (Flute): 4962 (0.018%)
82 | Program 74 (Recorder): 434 (0.002%)
83 | Program 75 (Pan Flute): 1034 (0.004%)
84 | Program 76 (Blown Bottle): 133 (0.000%)
85 | Program 77 (Shakuhachi): 205 (0.001%)
86 | Program 78 (Whistle): 359 (0.001%)
87 | Program 79 (Ocarina): 345 (0.001%)
88 | Program 80 (Lead 1 (square)): 1422 (0.005%)
89 | Program 81 (Lead 2 (sawtooth)): 2021 (0.007%)
90 | Program 82 (Lead 3 (calliope)): 913 (0.003%)
91 | Program 83 (Lead 4 (chiff)): 127 (0.000%)
92 | Program 84 (Lead 5 (charang)): 259 (0.001%)
93 | Program 85 (Lead 6 (voice)): 281 (0.001%)
94 | Program 86 (Lead 7 (fifths)): 84 (0.000%)
95 | Program 87 (Lead 8 (bass + lead)): 876 (0.003%)
96 | Program 88 (Pad 1 (new age)): 931 (0.003%)
97 | Program 89 (Pad 2 (warm)): 1222 (0.004%)
98 | Program 90 (Pad 3 (polysynth)): 725 (0.003%)
99 | Program 91 (Pad 4 (choir)): 669 (0.002%)
100 | Program 92 (Pad 5 (bowed)): 260 (0.001%)
101 | Program 93 (Pad 6 (metallic)): 241 (0.001%)
102 | Program 94 (Pad 7 (halo)): 364 (0.001%)
103 | Program 95 (Pad 8 (sweep)): 508 (0.002%)
104 | Program 96 (FX 1 (rain)): 204 (0.001%)
105 | Program 97 (FX 2 (soundtrack)): 87 (0.000%)
106 | Program 98 (FX 3 (crystal)): 271 (0.001%)
107 | Program 99 (FX 4 (atmosphere)): 478 (0.002%)
108 | Program 100 (FX 5 (brightness)): 754 (0.003%)
109 | Program 101 (FX 6 (goblins)): 121 (0.000%)
110 | Program 102 (FX 7 (echoes)): 301 (0.001%)
111 | Program 103 (FX 8 (sci-fi)): 144 (0.001%)
112 | Program 104 (Sitar): 206 (0.001%)
113 | Program 105 (Banjo): 474 (0.002%)
114 | Program 106 (Shamisen): 135 (0.000%)
115 | Program 107 (Koto): 164 (0.001%)
116 | Program 108 (Kalimba): 224 (0.001%)
117 | Program 109 (Bag pipe): 87 (0.000%)
118 | Program 110 (Fiddle): 243 (0.001%)
119 | Program 111 (Shanai): 33 (0.000%)
120 | Program 112 (Tinkle Bell): 110 (0.000%)
121 | Program 113 (Agogo): 53 (0.000%)
122 | Program 114 (Steel Drums): 181 (0.001%)
123 | Program 115 (Woodblock): 143 (0.001%)
124 | Program 116 (Taiko Drum): 235 (0.001%)
125 | Program 117 (Melodic Tom): 187 (0.001%)
126 | Program 118 (Synth Drum): 362 (0.001%)
127 | Program 119 (Reverse Cymbal): 1228 (0.005%)
128 | Program 120 (Guitar Fret Noise, Guitar Cutting Noise): 346 (0.001%)
129 | Program 121 (Breath Noise, Flute Key Click): 72 (0.000%)
130 | Program 122 (Seashore, Rain, Thunder, Wind, Stream, Bubbles): 407 (0.001%)
131 | Program 123 (Bird Tweet, Dog, Horse Gallop): 85 (0.000%)
132 | Program 124 (Telephone Ring, Door Creaking, Door, Scratch, Wind Chime): 185 (0.001%)
133 | Program 125 (Helicopter, Car Sounds): 158 (0.001%)
134 | Program 126 (Applause, Laughing, Screaming, Punch, Heart Beat, Footstep): 185 (0.001%)
135 | Program 127 (Gunshot, Machine Gun, Lasergun, Explosion): 208 (0.001%)
136 | """
137 |
if __name__ == "__main__":
    from random import shuffle

    import numpy as np
    from miditok.constants import MIDI_INSTRUMENTS, SCORE_LOADING_EXCEPTION
    from symusic import Score
    from tqdm import tqdm
    from transformers.trainer_utils import set_seed

    from utils.baseline import is_score_valid, mmm

    set_seed(mmm.seed)

    NUM_FILES = 100000  # maximum number of files sampled from the dataset
    mmm.tokenizer.config.programs = list(range(-1, 128))

    # Iterate over a random sample of files, collecting the program of every
    # track of every valid score (-1 stands for drum tracks).
    dataset_files_paths = mmm.dataset_files_paths
    shuffle(dataset_files_paths)
    dataset_files_paths = dataset_files_paths[:NUM_FILES]
    all_programs = []
    for file_path in tqdm(dataset_files_paths, desc="Analyzing files"):
        try:
            score = Score(file_path)
        except SCORE_LOADING_EXCEPTION:
            continue
        if is_score_valid(score):
            score = mmm.tokenizer.preprocess_score(score)
            all_programs += [
                track.program if not track.is_drum else -1 for track in score.tracks
            ]

    all_programs = np.array(all_programs)
    total_programs = len(all_programs)
    for program in range(-1, 128):
        num_occurrences = int(np.count_nonzero(all_programs == program))
        # Guard against an empty result set (e.g. no file could be parsed).
        ratio = num_occurrences / total_programs if total_programs else 0.0
        # `:.3%` multiplies by 100 and appends "%"; the previous `:.3f}%`
        # printed the raw fraction mislabeled as a percentage.
        print(  # noqa: T201
            f"Program {program} ("
            f"{'Drums' if program == -1 else MIDI_INSTRUMENTS[program]['name']}): "
            f"{num_occurrences} ({ratio:.3%})"
        )
179 |
--------------------------------------------------------------------------------
/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Curated Evaluation Set/MIDI Performance (Expressive MIDI Tracks)/Saarland.csv:
--------------------------------------------------------------------------------
1 | filepath,trackNum,instrument,isDrum,hasTimeSignatures,medianMetricDepth,tpq,velocity_per_track,onset_per_track,velocity_entropy_per_track,onset_entropy_per_track
2 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV849-01_001_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[70],[54],[2.242208320260014],[2.1877194014196313]
3 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV849-02_001_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[70],[57],[2.1944598281059347],[2.1227598342494285]
4 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV871-01_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[50],[57],[2.2215143362956855],[2.1560432773081404]
5 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV871-02_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[56],[57],[2.203766563904754],[2.1007718471971693]
6 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV875-01_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[45],[54],[2.301780586273925],[2.1838442915604994]
7 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV875-02_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[49],[52],[2.1812860620817767],[2.0979846011128274]
8 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV888-01_008_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[64],[52],[2.1082250272983973],[2.1107298161412866]
9 | Eval-subsets/Saarland Music Data (SMD)/Bach_BWV888-02_008_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[57],[56],[2.206863231520465],[2.0958697321118414]
10 | Eval-subsets/Saarland Music Data (SMD)/Bartok_SZ080-01_002_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[110],[58],[2.2161044801654106],[2.0749961898767566]
11 | Eval-subsets/Saarland Music Data (SMD)/Bartok_SZ080-02_002_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[96],[54],[2.2231056737689365],[2.1168599493420825]
12 | Eval-subsets/Saarland Music Data (SMD)/Bartok_SZ080-03_002_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[108],[58],[2.2144987525393445],[2.0118768121066197]
13 | Eval-subsets/Saarland Music Data (SMD)/Beethoven_Op027No1-01_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[99],[58],[2.2332564927802374],[2.1444039613100063]
14 | Eval-subsets/Saarland Music Data (SMD)/Beethoven_Op027No1-02_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[81],[55],[2.230988423853517],[2.2014043078234637]
15 | Eval-subsets/Saarland Music Data (SMD)/Beethoven_Op027No1-03_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[100],[58],[2.2316670391122067],[2.1411010733691795]
16 | Eval-subsets/Saarland Music Data (SMD)/Beethoven_Op031No2-01_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[95],[58],[2.27603091366256],[2.149934476450946]
17 | Eval-subsets/Saarland Music Data (SMD)/Beethoven_Op031No2-02_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[81],[57],[2.21160805229714],[2.2032513200533073]
18 | Eval-subsets/Saarland Music Data (SMD)/Beethoven_Op031No2-03_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[86],[58],[2.3149297700033453],[2.15484620927114]
19 | Eval-subsets/Saarland Music Data (SMD)/Beethoven_WoO080_001_20081107-SMD.mid,0,0,FALSE,TRUE,12,960,[106],[115],[2.1725154903213344],[2.0529039299803067]
20 | Eval-subsets/Saarland Music Data (SMD)/Brahms_Op005-01_002_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[110],[58],[2.253343013720334],[2.1357865892808614]
21 | Eval-subsets/Saarland Music Data (SMD)/Brahms_Op010No1_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[102],[58],[2.2754737287354354],[2.1821638405378394]
22 | Eval-subsets/Saarland Music Data (SMD)/Brahms_Op010No2_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[98],[58],[2.213394918335671],[2.2448515165885334]
23 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op010-03_007_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[88],[58],[2.1045524249432686],[2.1659579392860815]
24 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op010-04_007_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[83],[58],[2.2542927779942095],[2.159979927040184]
25 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op026No1_003_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[96],[58],[2.202725184192824],[2.2354584009352236]
26 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op026No2_005_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[100],[58],[2.1881458947675463],[2.11692043120291]
27 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op028-01_003_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[66],[47],[1.9361231770370437],[2.071933547679693]
28 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op028-03_003_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[62],[52],[2.236948427679256],[2.2063719586526718]
29 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op028-04_003_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[73],[50],[2.1465632974667974],[2.1917559128738]
30 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op028-11_003_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[59],[44],[2.0073681046323375],[2.286339826430251]
31 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op028-15_006_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[90],[57],[2.2794211506178277],[2.157003822283769]
32 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op028-17_005_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[89],[58],[2.151500180755954],[2.155545080515562]
33 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op029_004_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[86],[58],[2.0610677568177698],[2.1684607976033172]
34 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op048No1_007_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[95],[58],[2.203475547767666],[2.142876929559246]
35 | Eval-subsets/Saarland Music Data (SMD)/Chopin_Op066_006_20100611-SMD.mid,0,0,FALSE,TRUE,12,480,[91],[58],[2.0558183388968496],[2.1855233823119002]
36 | Eval-subsets/Saarland Music Data (SMD)/Haydn_Hob017No4_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[89],[58],[2.1828314668975484],[2.16630610134019]
37 | Eval-subsets/Saarland Music Data (SMD)/Haydn_HobXVINo52-01_008_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[88],[58],[2.2486043974740966],[2.1748447884374515]
38 | Eval-subsets/Saarland Music Data (SMD)/Haydn_HobXVINo52-02_008_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[90],[58],[2.2494194922620165],[2.196055535021124]
39 | Eval-subsets/Saarland Music Data (SMD)/Haydn_HobXVINo52-03_008_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[88],[58],[2.2193096916376054],[2.1728952328180213]
40 | Eval-subsets/Saarland Music Data (SMD)/Liszt_AnnesDePelerinage-LectureDante_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[113],[58],[2.2236645223841704],[2.0992465021883246]
41 | Eval-subsets/Saarland Music Data (SMD)/Liszt_KonzertetuedeNo2LaLeggierezza_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[102],[58],[2.0325195357473147],[1.8893376557365782]
42 | Eval-subsets/Saarland Music Data (SMD)/Liszt_VariationenBachmotivWeinenKlagenSorgenZagen_001_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[119],[58],[2.184494479552385],[2.127697490110233]
43 | Eval-subsets/Saarland Music Data (SMD)/Mozart_KV265_006_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[84],[58],[2.252625013291874],[2.157166107071062]
44 | Eval-subsets/Saarland Music Data (SMD)/Mozart_KV398_002_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[84],[58],[2.2460228428920783],[2.200669693579367]
45 | Eval-subsets/Saarland Music Data (SMD)/Rachmaninoff_Op036-01_007_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[103],[58],[2.1822623783379704],[2.1523931604966497]
46 | Eval-subsets/Saarland Music Data (SMD)/Rachmaninoff_Op036-02_007_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[102],[58],[2.201395213629322],[2.229464503041128]
47 | Eval-subsets/Saarland Music Data (SMD)/Rachmaninoff_Op036-03_007_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[109],[58],[2.1848708896471534],[2.1348497473142523]
48 | Eval-subsets/Saarland Music Data (SMD)/Rachmaninov_Op039No1_002_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[100],[58],[2.266573511283351],[2.158250339709647]
49 | Eval-subsets/Saarland Music Data (SMD)/Ravel_JeuxDEau_008_20110315-SMD.mid,0,0,FALSE,TRUE,12,480,[102],[58],[2.1628408923677926],[2.20423179159176]
50 | Eval-subsets/Saarland Music Data (SMD)/Ravel_ValsesNoblesEtSentimentales_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[105],[58],[2.204846504136731],[2.278582113192171]
51 | Eval-subsets/Saarland Music Data (SMD)/Skryabin_Op008No8_003_20090916-SMD.mid,0,0,FALSE,TRUE,12,480,[79],[57],[2.1402922752695988],[2.182203169781139]
--------------------------------------------------------------------------------
/loops_nomml/process_file_fast.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Dict, Tuple, Any, List
3 | import os
4 | import numpy as np
5 | from miditok.utils import get_bars_ticks, get_beats_ticks
6 | from miditok.constants import CLASS_OF_INST, INSTRUMENT_CLASSES
7 | from symusic import Score, Track, TimeSignature
8 |
9 | from corr_mat_fast import calc_correlation, get_valid_loops
10 | from note_set_fast import compute_note_sets
11 |
12 | MAX_NOTES_PER_TRACK = 50000
13 | MIN_NOTES_PER_TRACK = 5
14 |
15 | # New: sanity threshold for tick values (should match other modules)
16 | _MAX_TICK = 10 ** 9
17 |
18 | # Optional: allow enabling numba/cuda via environment variables here as well.
19 | # Example:
20 | # export USE_NUMBA=1
21 | # export USE_CUDA=1
22 | # The corr_mat_fast module reads these env vars at import time.
23 |
def get_instrument_type(track) -> str:
    """
    Return the MIDI instrument class name of a track.

    :param track: track object exposing ``is_drum`` and ``program``
        (in this module, a symusic ``Track`` — NOTE(review): the old
        docstring said pretty_midi.Instrument; both expose these fields)
    :return: instrument class name, or "Drums" for drum tracks
    """
    if track.is_drum:
        return "Drums"

    # Map the MIDI program number (0-127) to its class index, then
    # look up the human-readable class name.
    class_index = CLASS_OF_INST[track.program]
    return INSTRUMENT_CLASSES[class_index]["name"]
39 |
40 |
def create_loop_dict(endpoint_data: Tuple[int, int, float, float], track_idx: int, instrument_type: str) -> Dict[str,Any]:
    """
    Build the metadata dictionary describing one detected loop.

    :param endpoint_data: tuple of (start tick, end tick, duration in
        beats, note density)
    :param track_idx: index of the track the loop was found in
    :param instrument_type: instrument class name of that track
    :return: mapping with one entry per loop attribute
    """
    loop_start, loop_end, loop_beats, loop_density = endpoint_data
    return dict(
        track_idx=track_idx,
        instrument_type=instrument_type,
        start=loop_start,
        end=loop_end,
        duration_beats=loop_beats,
        note_density=loop_density,
    )
54 |
55 |
def detect_loops_from_path(file_info: Dict) -> Dict[str,List]:
    """
    Load the MIDI file referenced by ``file_info`` and detect all loops
    across its tracks.

    :param file_info: mapping with a ``'file_path'`` entry; the value may
        be a string or a one-element list containing the string
    :return: loop metadata columns (see ``detect_loops``); all-empty lists
        when the file cannot be parsed
    """
    path = file_info['file_path']
    if isinstance(path, list):
        path = path[0]
    try:
        score = Score(path)
    except Exception:
        # Malformed file: report and return an empty result table.
        print(f"Unable to parse score for {path}, skipping")
        empty_keys = (
            "track_idx",
            "instrument_type",
            "start",
            "end",
            "duration_beats",
            "note_density",
        )
        return {key: [] for key in empty_keys}
    return detect_loops(score, file_path=path)
77 |
78 |
def detect_loops(score: Score, file_path: str | None = None) -> Dict[str,List]:
    """
    Given a MIDI score, locate all loops present across all of the tracks.

    :param score: parsed symusic ``Score`` to scan for loops
    :param file_path: optional path of the source file; when given, a
        ``file_path`` column (one entry per loop) is added to the result
    :return: dict of parallel lists — one entry per detected loop — with
        keys ``track_idx``, ``instrument_type``, ``start``, ``end``,
        ``duration_beats``, ``note_density`` and, optionally, ``file_path``
    """
    data = {
        "track_idx": [],
        "instrument_type": [],
        "start": [],
        "end": [],
        "duration_beats": [],
        "note_density": [],
    }
    if file_path is not None:
        data["file_path"] = []
    # Check that there is a time signature. There might be none with abc files
    if len(score.time_signatures) == 0:
        score.time_signatures.append(TimeSignature(0, 4, 4))

    # Extract bars and beats ticks defensively; malformed tempo/time-sig
    # events can make miditok raise, in which case the whole file is skipped.
    try:
        bars_ticks_raw = get_bars_ticks(score)
        beats_ticks_raw = get_beats_ticks(score)
    except Exception:
        print(f"Skipping, couldn't extract bars/beats for {file_path or 'score'} due to malformed events")
        return data

    # sanitize arrays: ensure numeric, finite, reasonable magnitude
    # (ticks outside [0, _MAX_TICK] are dropped; on any conversion failure
    # we fall back to an empty tick array rather than aborting the file)
    try:
        bars_ticks = np.asarray(bars_ticks_raw, dtype=np.int64)
        bars_ticks = bars_ticks[np.isfinite(bars_ticks)]
        bars_ticks = bars_ticks[(bars_ticks >= 0) & (bars_ticks <= _MAX_TICK)]
    except Exception:
        bars_ticks = np.array([], dtype=np.int64)

    try:
        beats_ticks = np.asarray(beats_ticks_raw, dtype=np.int64)
        beats_ticks = beats_ticks[np.isfinite(beats_ticks)]
        beats_ticks = beats_ticks[(beats_ticks >= 0) & (beats_ticks <= _MAX_TICK)]
    except Exception:
        beats_ticks = np.array([], dtype=np.int64)

    for idx, track in enumerate(score.tracks):
        # Basic track length checks: skip tracks that are too sparse to
        # contain a meaningful loop or too dense to process affordably.
        try:
            n_notes = len(track.notes)
        except Exception:
            # malformed track object: skip
            continue
        if n_notes > MAX_NOTES_PER_TRACK or n_notes < MIN_NOTES_PER_TRACK:
            continue

        # cut bars_ticks at the end of the track defensively
        try:
            track_end = int(track.end())
        except Exception:
            # malformed end time: skip track
            continue

        # Truncate the global bar ticks at the first tick past this
        # track's end (bar ticks are produced in ascending order by
        # get_bars_ticks — TODO confirm against miditok docs).
        if bars_ticks.size and np.any(bars_ticks > track_end):
            try:
                bars_ticks_track = bars_ticks[:np.nonzero(bars_ticks > track_end)[0][0]]
            except Exception:
                bars_ticks_track = bars_ticks
        else:
            bars_ticks_track = bars_ticks

        # Same truncation for beat ticks.
        if beats_ticks.size and np.any(beats_ticks > track_end):
            try:
                beats_ticks_track = beats_ticks[:np.nonzero(beats_ticks > track_end)[0][0]]
            except Exception:
                beats_ticks_track = beats_ticks
        else:
            beats_ticks_track = beats_ticks

        if len(bars_ticks_track) > MAX_NOTES_PER_TRACK:
            # ill-formed bars for this track
            continue

        # instrument type ("Unknown" if the lookup fails, e.g. an
        # out-of-range program number)
        try:
            instrument_type = get_instrument_type(track)
        except Exception:
            instrument_type = "Unknown"

        # Compute note sets with defensive handling of malformed note events
        try:
            note_sets = compute_note_sets(track.notes, bars_ticks_track)
        except Exception:
            # compute_note_sets failed due to malformed notes -> skip track
            continue

        # If compute_note_sets returned nothing and there are no bars, skip
        if not note_sets and (bars_ticks_track.size == 0):
            continue

        # Defensive: ensure note_sets entries look sane (present,
        # non-negative, within the tick sanity bound)
        bad_ns = False
        for ns in note_sets:
            try:
                if ns.start is None or ns.end is None:
                    bad_ns = True
                    break
                if ns.start < 0 or ns.end < 0 or ns.start > _MAX_TICK or ns.end > _MAX_TICK:
                    bad_ns = True
                    break
            except Exception:
                bad_ns = True
                break
        if bad_ns:
            # skip track with malformed note sets
            continue

        # Compute correlation and loops with try/except so a single bad track doesn't crash everything
        try:
            lead_mat = calc_correlation(note_sets)
        except Exception:
            # correlation failed (malformed note_sets) -> skip track
            continue

        try:
            _, loop_endpoints = get_valid_loops(note_sets, lead_mat, beats_ticks_track)
        except Exception:
            # loop detection failed -> skip track
            continue

        # Flatten each loop's metadata into the parallel-list columns.
        # create_loop_dict's keys match the keys initialized in `data`.
        for endpoint in loop_endpoints:
            loop_dict = create_loop_dict(endpoint, idx, instrument_type)
            for key in loop_dict.keys():
                data[key].append(loop_dict[key])
            if file_path is not None:
                data["file_path"].append(file_path)

    return data
--------------------------------------------------------------------------------
/scripts/test_extract.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pretty_midi\n",
10 | "from corr_mat import calc_correlation, get_valid_loops\n",
11 | "from track import Track\n",
12 | "from util import get_instrument_type, create_loop_dict\n",
13 | "import os\n",
14 | "from tqdm import tqdm"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "test_file = \"./test_midi.mid\"\n",
24 | "pm = pretty_midi.PrettyMIDI(test_file)"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 3,
30 | "metadata": {},
31 | "outputs": [
32 | {
33 | "name": "stdout",
34 | "output_type": "stream",
35 | "text": [
36 | "found 1 loops in Instrument(program=0, is_drum=True, name=\"GOT TO HAVE YOUR LOVE \")\n",
37 | "0 18.113216000000005 22.641519999999996 (4, 4) 8.0 DRUM 3.875\n",
38 | "found 1 loops in Instrument(program=39, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n",
39 | "1 9.056607999999999 13.584912000000005 (4, 4) 8.0 BASS 2.0\n",
40 | "found 1 loops in Instrument(program=4, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n",
41 | "2 33.96227999999998 43.01888799999996 (4, 4) 16.0 PIANO 0.5625\n",
42 | "found 1 loops in Instrument(program=48, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n",
43 | "3 24.905671999999992 29.433975999999983 (4, 4) 8.0 ENSEMBLE 0.75\n",
44 | "found 1 loops in Instrument(program=48, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n",
45 | "4 24.905671999999992 29.433975999999983 (4, 4) 8.0 ENSEMBLE 0.75\n",
46 | "found 1 loops in Instrument(program=28, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n",
47 | "5 27.169823999999988 31.69812799999998 (4, 4) 8.0 GUITAR 1.75\n",
48 | "found 0 loops in Instrument(program=2, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n",
49 | "found 0 loops in Instrument(program=82, is_drum=False, name=\"GOT TO HAVE YOUR LOVE \")\n",
50 | "6 total loops in 8 tracks\n"
51 | ]
52 | }
53 | ],
54 | "source": [
55 | "final_loops = []\n",
56 | "for idx, instrument in enumerate(pm.instruments):\n",
57 | " instrument_type = get_instrument_type(instrument)\n",
58 | " track = Track(pm, instrument)\n",
59 | " note_list = track.notes\n",
60 | " lead_mat, lead_dur = calc_correlation(note_list)\n",
61 | " loops, loop_endpoints = get_valid_loops(track, lead_mat, lead_dur)\n",
62 | " print(f\"found {len(loops)} loops in {instrument}\")\n",
63 | "\n",
64 | " for endpoint in loop_endpoints:\n",
65 | " start, end, beats, density = endpoint\n",
66 | " time_sig = track.get_time_sig_at_time(start)\n",
67 | " print(idx, start, end, time_sig, beats, instrument_type, density)\n",
68 | " for loop in loops:\n",
69 | " final_loops.append(loop)\n",
70 | "print(f\"{len(final_loops)} total loops in {len(pm.instruments)} tracks\")"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 4,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "def run_file(file_path, name):\n",
80 | " try:\n",
81 | " pm = pretty_midi.PrettyMIDI(file_path)\n",
82 | " except:\n",
83 | " print(f\"failed to parse {file_path}, skipping\")\n",
84 | " return 0,0,[]\n",
85 | " \n",
86 | " total_loops = 0\n",
87 | " loops = []\n",
88 | " for idx, instrument in enumerate(pm.instruments):\n",
89 | " instrument_type = get_instrument_type(instrument)\n",
90 | " track = Track(pm, instrument)\n",
91 | " note_list = track.notes\n",
92 | " lead_mat, lead_dur = calc_correlation(note_list)\n",
93 | " full_loops, loop_endpoints = get_valid_loops(track, lead_mat, lead_dur)\n",
94 | " for endpoint in loop_endpoints:\n",
95 | " time_sig = track.get_time_sig_at_time(endpoint[0])\n",
96 | " if time_sig is None:\n",
97 | " continue\n",
98 | " loop_dict = create_loop_dict(endpoint, idx, instrument_type, time_sig, name)\n",
99 | " loops.append(loop_dict)\n",
100 | " for loop_list in full_loops:\n",
101 | " if loop_list[0].duration != 0 or loop_list[-1].duration != 0:\n",
102 | " loop_list[0]\n",
103 | " print(name, loop_list[0], loop_list[-1])\n",
104 | " total_loops += len(loop_endpoints)\n",
105 | " return total_loops, len(pm.instruments), loops"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 5,
111 | "metadata": {},
112 | "outputs": [
113 | {
114 | "data": {
115 | "text/plain": [
116 | "25050"
117 | ]
118 | },
119 | "execution_count": 5,
120 | "metadata": {},
121 | "output_type": "execute_result"
122 | }
123 | ],
124 | "source": [
125 | "full_directory = \"D:\\\\Documents\\\\GigaMIDI\\\\Final_GigaMIDI_TISMIR\\\\Validatation-10%\\\\GigaMIDI-Val-Drum+Music-MD5\"\n",
126 | "len(os.listdir(full_directory))"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 6,
132 | "metadata": {},
133 | "outputs": [
134 | {
135 | "name": "stderr",
136 | "output_type": "stream",
137 | "text": [
138 | " 0%| | 0/100 [00:00, ?it/s]"
139 | ]
140 | },
141 | {
142 | "name": "stdout",
143 | "output_type": "stream",
144 | "text": [
145 | "failed to parse D:\\Documents\\GigaMIDI\\Final_GigaMIDI_TISMIR\\Validatation-10%\\GigaMIDI-Val-Drum+Music-MD5\\.DS_Store, skipping\n"
146 | ]
147 | },
148 | {
149 | "name": "stderr",
150 | "output_type": "stream",
151 | "text": [
152 | " 20%|██ | 20/100 [00:09<00:24, 3.27it/s]d:\\Documents\\GigaMIDI\\midi_loop_detection\\.venv\\lib\\site-packages\\pretty_midi\\pretty_midi.py:100: RuntimeWarning: Tempo, Key or Time signature change events found on non-zero tracks. This is not a valid type 0 or type 1 MIDI file. Tempo, Key or Time Signature may be wrong.\n",
153 | " warnings.warn(\n",
154 | "100%|██████████| 100/100 [01:22<00:00, 1.21it/s]"
155 | ]
156 | },
157 | {
158 | "name": "stdout",
159 | "output_type": "stream",
160 | "text": [
161 | "Found 1441 loops in 929 tracks across 100 files\n"
162 | ]
163 | },
164 | {
165 | "name": "stderr",
166 | "output_type": "stream",
167 | "text": [
168 | "\n"
169 | ]
170 | }
171 | ],
172 | "source": [
173 | "total_loops = 0\n",
174 | "total_tracks = 0\n",
175 | "num_files = 100\n",
176 | "all_loops = []\n",
177 | "for file in tqdm(os.listdir(full_directory)[:num_files]):\n",
178 | " full_path = os.path.join(full_directory, file)\n",
179 | " string_path = full_path\n",
180 | " num_loops, num_tracks, loops = run_file(full_path, string_path)\n",
181 | " total_loops += num_loops\n",
182 | " total_tracks += num_tracks\n",
183 | " for loop in loops:\n",
184 | " if len(loop) > 0:\n",
185 | " all_loops.append(loop)\n",
186 | "print(f\"Found {total_loops} loops in {total_tracks} tracks across {num_files} files\")"
187 | ]
188 | }
189 | ],
190 | "metadata": {
191 | "kernelspec": {
192 | "display_name": ".venv",
193 | "language": "python",
194 | "name": "python3"
195 | },
196 | "language_info": {
197 | "codemirror_mode": {
198 | "name": "ipython",
199 | "version": 3
200 | },
201 | "file_extension": ".py",
202 | "mimetype": "text/x-python",
203 | "name": "python",
204 | "nbconvert_exporter": "python",
205 | "pygments_lexer": "ipython3",
206 | "version": "3.10.5"
207 | }
208 | },
209 | "nbformat": 4,
210 | "nbformat_minor": 2
211 | }
212 |
--------------------------------------------------------------------------------
/MIDI-GPT-Loop/README.md:
--------------------------------------------------------------------------------
1 | # Loop Generation
2 | ## Generating and evaluating loops using MIDI-GPT
3 |
4 | Here we present the scripts for loop generation as well as evaluation scripts used for our paper.
5 |
6 | ### MMM
7 |
8 | This repository is the library to run and train our MIDI-GPT-Loop model. It is a fork of the `MMM` repository from the `Metacreation Lab` adapted for training and inference with loops. This library also uses a fork of `MidiTok` allowing the integration of loop tokens.
9 |
The main entry point is `inference.py`, where `generate` and `generate_batch` may be used for MIDI generation.
11 |
12 | ### MMM-Loop
13 |
14 | This library holds the scripts used to generate the synthetic MIDI data, and evaluate the accuracy of the generation NOMML controls and the generated loops.
15 |
16 | # Installation and Setup
17 |
18 | > **Note:** The following scripts were created for computing on [Compute Canada](https://docs.alliancecan.ca/). Therefore, slight modifications (file organization, module imports, environment variables) may be needed in order to run these scripts elsewhere.
19 |
20 | The dependencies installed are those needed for inference as well as training, therefore some may not be needed if you do not wish to train a model. The following setup is identical for both submodules, MMM and MMM-Loop
21 |
22 | ## On Compute Canada
23 |
On **Compute Canada**, load the correct modules and install dependencies:
25 |
26 | ```bash
27 | cd MMM-Loop
28 | bash slurm/install.sh
29 | ```
30 |
31 | ## Elsewhere
32 |
33 | ### 🧰 Dependencies
34 |
35 | The code depends on several Python packages, some of which may require special installation steps or system libraries.
36 |
37 | #### 📦 Key Python Dependencies
38 |
39 | | Package | Version | Notes |
40 | | ------------------------------------------- | ---------- | --------------------------------------------------------- |
41 | | `python` | 3.11 | Required for compatibility |
42 | | `symusic` | 0.5.0 | For symbolic music representations |
43 | | `MidiTok` | expressive | Custom fork installed from GitHub |
44 | | `transformers` | 4.49.0 | Hugging Face Transformers |
45 | | `accelerate` | 1.4.0 | Hugging Face Accelerate |
46 | | `tensorboard` | 2.19.0 | TensorBoard for logging |
47 | | `flash-attn` | 2.5.7 | May require building from source (see instructions below) |
48 | | `deepspeed` | 0.14.4 | For model parallelism and training acceleration |
49 | | `datasets` | 3.3.2 | Hugging Face Datasets |
50 | | `triton` | 3.1.0 | Required for some GPU optimizations |
51 | | `nvitop`, `pandas`, `matplotlib`, `sklearn` | Latest | Utility and visualization tools |
52 |
53 | ### 💻 Installation Instructions
54 |
55 | 1. Create the virtual environment
56 |
57 | Use Python 3.11:
58 |
59 | ```bash
60 | cd MMM
61 | ```
62 |
63 | or
64 |
65 | ```bash
66 | cd MMM-Loop
67 | ```
68 |
69 | ```bash
70 | python3.11 -m venv .venv
71 | source .venv/bin/activate
72 | ```
73 |
74 | If `python3.11` is not available, install it via pyenv or your system's package manager.
75 |
76 | 2. Install dependencies
77 |
78 | ```bash
79 | pip install --upgrade pip
80 |
81 | # Required packages
82 | pip install symusic==0.5.0
83 | pip install git+https://github.com/DaoTwenty/MidiTok@expressive
84 | pip install transformers==4.49.0 accelerate==1.4.0 tensorboard==2.19.0
85 | pip install deepspeed==0.14.4
86 | pip install datasets==3.3.2
87 | pip install triton==3.1.0
88 | pip install nvitop pandas matplotlib scikit-learn
89 | ```
90 |
91 | ### ⚡ Installing FlashAttention (Optional but Recommended)
92 |
93 | FlashAttention often provides significant speedups for training Transformer models, but may require a manual installation from source depending on your system and GPU.
94 |
95 | 1. Clone FlashAttention
96 |
97 | ```bash
98 | git clone https://github.com/Dao-AILab/flash-attention.git
99 | cd flash-attention
100 | git checkout v2.5.7
101 | ```
102 |
103 | 2. Install with `pip`
104 |
105 | ```bash
106 | pip install .
107 | ```
108 |
109 | > **Note:** You may need to have the following:
110 |
111 | - A CUDA-capable GPU (Compute Capability ≥ 7.5)
112 | - CUDA toolkit ≥ 11.8
113 | - Compatible PyTorch version (typically latest stable)
114 | - Refer to FlashAttention's official documentation for details.
115 |
116 | ### 📊 Verifying Installation
117 |
118 | Run the following to verify installed versions:
119 |
120 | ```bash
121 | pip freeze
122 | ```
123 |
124 | ### 💬 Notes
125 |
126 | - On Compute Canada, system modules like `gcc`, `arrow`, and `rust` were required. These are **not needed** if you can build FlashAttention from source locally.
127 | - If `arrow` or `rust` are required by specific packages, ensure they are available on your system (`brew`, `apt`, or via `conda`).
128 |
129 | # 🎶 Usage
130 |
131 | First, download the model via [https://1sfu-my.sharepoint.com/:u:/g/personal/pta63_sfu_ca/EbpBz06rnaJMtirT0DqvTFoBQSC2OqJ_gex88fenJ60CQQ?e=o24IMz](https://1sfu-my.sharepoint.com/:u:/g/personal/pta63_sfu_ca/EbpBz06rnaJMtirT0DqvTFoBQSC2OqJ_gex88fenJ60CQQ?e=o24IMz) and place it in ``MMM-Loop/models``
132 |
133 | ## Loop Generation
134 |
135 | ### 🧠 Environment Variables
136 |
137 | The script (`slurm/ge_synthetic_data.sh`) expects these repositories to be present:
138 |
139 | ```bash
140 | export PYTHONPATH=$PYTHONPATH:/path/to/MMM:/path/to/MMM-Loop
141 | ```
> Replace `/path/to/MMM` and `/path/to/MMM-Loop` with the actual paths where the repositories are cloned.
143 |
144 | ### 🛠 Running Locally (Without SLURM)
145 |
146 | The SLURM batch loop can be mimicked using a local shell script or Python launcher. Here’s a basic loop for **manual local use**:
147 |
148 | ```bash
149 | #!/bin/bash
150 |
151 | source .venv/bin/activate
152 |
153 | MODEL=MIDI-GPT-Loop-model
154 | TOKENIZER=MMM_epl_mistral_tokenizer_with_acs.json
155 | CONFIG=slurm/gen_config.json
156 | NOMML=0
157 | NUM_GEN=1000
158 | BATCH=12
159 | OUTPUT="./SYNTHETIC_DATA"
160 | LABEL="V1"
161 |
162 | python -m src.generate_loops \
163 | --config $CONFIG \
164 | --model models/$MODEL \
165 | --tokenizer models/$TOKENIZER \
166 | --num_generations $NUM_GEN \
167 | --max_attempts $NUM_GEN \
168 | --batch $BATCH \
    --nomml $NOMML \
170 | --output_folder $OUTPUT \
171 | --label $LABEL \
172 | --rank 0 &
173 | ```
174 |
175 | ## 📈 Loop Evaluation
176 |
177 | ### 📝 Script Overview
178 |
The SLURM script (`slurm/eval_loops.sh`):
180 | - Aggregates CSV result files (`RESULTS_*.csv`) from a directory.
181 | - Merges them into a single `RESULTS.csv`.
182 | - Runs the evaluation script via Python: `src.evaluate`.
183 |
184 | ```bash
185 | #!/bin/bash
186 |
187 | SOURCE="./SYNTHETIC_DATA"
188 | LABEL="V1"
189 | OUTFILE="$SOURCE/$LABEL/RESULTS.csv"
190 |
191 | mkdir -p "$SOURCE/$LABEL"
192 | mkdir -p "$SOURCE/$LABEL/MIDI"
193 |
194 | # Clear existing merged file
195 | > "$OUTFILE"
196 |
197 | first=1
198 | for file in "$SOURCE/$LABEL"/RESULTS_*.csv; do
199 | if [ -f "$file" ]; then
200 | echo "Processing $file"
201 | if [ $first -eq 1 ]; then
202 | cat "$file" >> "$OUTFILE"
203 | first=0
204 | else
205 | tail -n +2 "$file" >> "$OUTFILE"
206 | fi
207 | fi
208 | done
209 |
210 | # Run evaluation script
211 | python -m src.evaluate --source "$SOURCE/$LABEL"
212 | ```
213 |
214 | ## Visualisation
215 |
216 | Create Cross-entropy graph of the evaluation (`slurm/plot_results.sh`)
217 |
218 | ```bash
219 | #!/bin/bash
220 |
221 | SOURCE="../SYNTHETIC_DATA"
222 | LABEL="V1"
223 | INPUT_DIR="$SOURCE/$LABEL"
224 |
225 | PLOT_ARGS=" \
226 | --input $INPUT_DIR \
227 | --output $INPUT_DIR \
228 | "
229 |
230 | source .venv/bin/activate
231 |
232 | python -m src.plot_results $PLOT_ARGS
233 | ```
234 |
235 | ### 🧾 Output
236 |
237 | Results are merged to: `SYNTHETIC_DATA/V1/RESULTS.csv`
--------------------------------------------------------------------------------
/Expressive music loop detector-NOMML12.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 5,
6 | "id": "e65838ae",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "data": {
11 | "text/plain": [
12 | "1094620"
13 | ]
14 | },
15 | "execution_count": 5,
16 | "metadata": {},
17 | "output_type": "execute_result"
18 | }
19 | ],
20 | "source": [
21 | "len(file_paths)"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "id": "0dee9022",
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "import os\n",
32 | "import numpy as np\n",
33 | "import pandas as pd\n",
34 | "import ast\n",
35 | "from symusic import Score\n",
36 | "from loops_nomml.note_set import compute_note_sets\n",
37 | "import loops_nomml.corr_mat as corr\n",
38 | "from loops_nomml.corr_mat import get_valid_loops\n",
39 | "from joblib import Parallel, delayed\n",
40 | "from tqdm.notebook import tqdm\n",
41 | "\n",
42 | "# ─── Patch safe get_duration_beats ────────────────────────────────────────\n",
43 | "def safe_get_duration_beats(start: int, end: int, ticks_beats: list[int]) -> float:\n",
44 | " i0 = max((i for i, t in enumerate(ticks_beats) if t <= start), default=0)\n",
45 | " i1 = max((i for i, t in enumerate(ticks_beats) if t <= end), default=i0)\n",
46 | " return float(i1 - i0)\n",
47 | "corr.get_duration_beats = safe_get_duration_beats\n",
48 | "\n",
49 | "# ─── Constants ─────────────────────────────────────────────────────────────\n",
50 | "GM_GROUPS = [\n",
51 | " 'Piano','Chromatic Percussion','Organ','Guitar',\n",
52 | " 'Bass','Strings','Ensemble','Brass',\n",
53 | " 'Reed','Pipe','Synth Lead','Synth Pad',\n",
54 | " 'Synth Effects','Ethnic','Percussive','Sound Effects'\n",
55 | "]\n",
56 | "DRUM_GROUP = 'Drums'\n",
57 | "\n",
58 | "# ─── similarity + soft-count ──────────────────────────────────────────────\n",
59 | "def note_similarity(a, b, v_a, v_b, w_p=0.5, w_v=0.3, w_t=0.2, max_time_diff=0.05):\n",
60 | " p = len(a.pitches & b.pitches) / max(1, len(a.pitches | b.pitches))\n",
61 | " v = 1 - abs(v_a - v_b) / 127\n",
62 | " t = np.exp(-abs(a.start - b.start) / max_time_diff)\n",
63 | " return w_p*p + w_v*v + w_t*t\n",
64 | "\n",
65 | "def calc_correlation_soft_count(ns, vel_means, tau):\n",
66 | " N = len(ns)\n",
67 | " C = np.zeros((N, N), dtype=int)\n",
68 | " for j in range(1, N):\n",
69 | " if note_similarity(ns[0], ns[j], vel_means[0], vel_means[j]) >= tau and ns[0].is_barline():\n",
70 | " C[0, j] = 1\n",
71 | " for i in range(1, N-1):\n",
72 | " for j in range(i+1, N):\n",
73 | " sim = note_similarity(ns[i], ns[j], vel_means[i], vel_means[j])\n",
74 | " if sim >= tau and (C[i-1, j-1] > 0 or ns[i].is_barline()):\n",
75 | " C[i, j] = C[i-1, j-1] + 1\n",
76 | " return C\n",
77 | "\n",
78 | "# ─── loopability score ───────────────────────────────────────────────────\n",
79 | "def score_loopability(ns, vel_means, tau, alpha=0.7, beta=0.3):\n",
80 | " C = calc_correlation_soft_count(ns, vel_means, tau)\n",
81 | " N = len(ns)\n",
82 | " if N < 2:\n",
83 | " return 0.0\n",
84 | " S_max = C.max() / N\n",
85 | " S_den = C.sum() / (N*(N-1)/2)\n",
86 | " return alpha * S_max + beta * S_den\n",
87 | "\n",
88 | "# ─── Process one file ─────────────────────────────────────────────────────\n",
89 | "def process_file(path, melodic_tau=0.3, drum_tau=0.1):\n",
90 | " loops = []\n",
91 | " try:\n",
92 | " score = Score(path, ttype='tick')\n",
93 | " try:\n",
94 | " beat_ticks = score.beat_ticks()\n",
95 | " except:\n",
96 | " ppq = getattr(score, 'ticks_per_quarter', getattr(score, 'ppq', 480))\n",
97 | " beat_ticks = list(range(0, score.end()+1, ppq))\n",
98 | " bars = [beat_ticks[i] for i in range(0, len(beat_ticks), 4)]\n",
99 | "\n",
100 | " for ti, track in enumerate(score.tracks):\n",
101 | " is_drum = getattr(track, 'channel', None) == 9\n",
102 | " tau = drum_tau if is_drum else melodic_tau\n",
103 | "\n",
104 | " prog = getattr(track, 'program', None)\n",
105 | " if \"drums-only\" in path:\n",
106 | " group = DRUM_GROUP\n",
107 | " else:\n",
108 | " group = DRUM_GROUP if is_drum else (GM_GROUPS[prog // 8] if prog is not None else 'Unknown')\n",
109 | "\n",
110 | " ns = compute_note_sets(track.notes, bars)\n",
111 | " if len(ns) < 2:\n",
112 | " continue\n",
113 | " vel_means = [\n",
114 | " float(np.mean([n.velocity for n in track.notes\n",
115 | " if n.start == nset.start and n.end == nset.end]))\n",
116 | " if any(n.start == nset.start and n.end == nset.end for n in track.notes)\n",
117 | " else 0.0\n",
118 | " for nset in ns\n",
119 | " ]\n",
120 | "\n",
121 | " loopability = score_loopability(ns, vel_means, tau)\n",
122 | " C = calc_correlation_soft_count(ns, vel_means, tau)\n",
123 | " try:\n",
124 | " _, endpoints = get_valid_loops(\n",
125 | " ns, C, beat_ticks,\n",
126 | " min_rep_notes=0,\n",
127 | " min_rep_beats=1.0 if not is_drum else 0.5,\n",
128 | " min_beats=1.0 if not is_drum else 0.5,\n",
129 | " max_beats=32.0,\n",
130 | " min_loop_note_density=0.0\n",
131 | " )\n",
132 | " except IndexError:\n",
133 | " continue\n",
134 | "\n",
135 | " for start, end, dur, dens in endpoints:\n",
136 | " loops.append({\n",
137 | " 'track_idx': ti,\n",
138 | " 'MIDI program number': prog,\n",
139 | " 'instrument_group': group,\n",
140 | " 'loopability': loopability,\n",
141 | " 'start_tick': start,\n",
142 | " 'end_tick': end,\n",
143 | " 'duration_beats': dur,\n",
144 | " 'note_density': dens\n",
145 | " })\n",
146 | " except Exception as e:\n",
147 | " print(f\"[Error] {os.path.basename(path)}: {e}\")\n",
148 | " return loops\n",
149 | "\n",
150 | "# ─── 1) Load CSV & select only rows whose NOMML list contains a 12 ──────────\n",
151 | "df_input = pd.read_csv(\n",
152 | " \"Final_GigaMIDI_Loop_V2_path-instrument-NOMML-type.csv\",\n",
153 | " converters={'NOMML': ast.literal_eval}\n",
154 | ")\n",
155 | "\n",
156 | "# keep rows where the NOMML list has at least one 12\n",
157 | "df_input = df_input[df_input['NOMML'].apply(lambda lst: isinstance(lst, (list,tuple)) and 12 in lst)]\n",
158 | "\n",
159 | "file_paths = df_input['file_path'].tolist()\n",
160 | "\n",
161 | "# ─── 2) Chunk size 100,000 for checkpoint ───────────────────────────────────\n",
162 | "chunk_size = 100000\n",
163 | "\n",
164 | "# ─── 3) Process in chunks, checkpoint each chunk ───────────────────────────\n",
165 | "all_rows = []\n",
166 | "for idx in range(0, len(file_paths), chunk_size):\n",
167 | " chunk = file_paths[idx: idx + chunk_size]\n",
168 | " results = Parallel(n_jobs=-1, backend='loky')(\n",
169 | " delayed(process_file)(p) for p in tqdm(chunk, desc=f\"Files {idx+1}-{idx+len(chunk)}\")\n",
170 | " )\n",
171 | "\n",
172 | " # organize one row per file, unpacking loops into parallel arrays\n",
173 | " rows = []\n",
174 | " for path, loops in zip(chunk, results):\n",
175 | " rows.append({\n",
176 | " 'file_path': path,\n",
177 | " 'track_idx': [d['track_idx'] for d in loops],\n",
178 | " 'MIDI program number': [d['MIDI program number'] for d in loops],\n",
179 | " 'instrument_group': [d['instrument_group'] for d in loops],\n",
180 | " 'loopability': [d['loopability'] for d in loops],\n",
181 | " 'start_tick': [d['start_tick'] for d in loops],\n",
182 | " 'end_tick': [d['end_tick'] for d in loops],\n",
183 | " 'duration_beats': [d['duration_beats'] for d in loops],\n",
184 | " 'note_density': [d['note_density'] for d in loops]\n",
185 | " })\n",
186 | " df_chunk = pd.DataFrame(rows)\n",
187 | "\n",
188 | " # save checkpoint\n",
189 | " checkpoint = f\"loops_checkpoint_{idx//chunk_size + 1}.csv\"\n",
190 | " df_chunk.to_csv(checkpoint, index=False)\n",
191 | " print(f\"Saved checkpoint: {checkpoint}\")\n",
192 | "\n",
193 | " all_rows.extend(rows)\n",
194 | "\n",
195 | "# ─── 4) Final combined DataFrame ─────────────────────────────────────────────\n",
196 | "df_all = pd.DataFrame(all_rows)\n",
197 | "\n",
198 | "# ─── 5) Save the full output to CSV ─────────────────────────────────────────\n",
199 | "df_all.to_csv(\"loops_full_output.csv\", index=False)\n",
200 | "print(\"Saved full output: loops_full_output.csv\")\n"
201 | ]
202 | }
203 | ],
204 | "metadata": {
205 | "kernelspec": {
206 | "display_name": "Python 3 (ipykernel)",
207 | "language": "python",
208 | "name": "python3"
209 | },
210 | "language_info": {
211 | "codemirror_mode": {
212 | "name": "ipython",
213 | "version": 3
214 | },
215 | "file_extension": ".py",
216 | "mimetype": "text/x-python",
217 | "name": "python",
218 | "nbconvert_exporter": "python",
219 | "pygments_lexer": "ipython3",
220 | "version": "3.9.12"
221 | }
222 | },
223 | "nbformat": 4,
224 | "nbformat_minor": 5
225 | }
226 |
--------------------------------------------------------------------------------
/GigaMIDI/create_gigamidi_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3 python
2 |
3 | """Script to create the WebDataset for the GigaMIDI dataset."""
4 |
5 | from __future__ import annotations
6 |
7 | import json
8 | from typing import TYPE_CHECKING
9 |
10 | from datasets import Dataset
11 | from huggingface_hub import create_branch, upload_file
12 | from tqdm import tqdm
13 | from webdataset import ShardWriter
14 |
15 | from utils.GigaMIDI.GigaMIDI import _SPLITS
16 |
17 | if TYPE_CHECKING:
18 | from pathlib import Path
19 |
20 | from datasets import DatasetDict
21 |
22 |
23 | MAX_NUM_ENTRIES_PER_SHARD = 50000
24 | SUBSET_PATHS = {
25 | "all-instruments-with-drums": "drums+music",
26 | "drums-only": "drums",
27 | "no-drums": "music",
28 | }
29 |
30 |
def create_webdataset_gigamidi(main_data_dir_path: Path) -> None:
    """
    Create the WebDataset shards for the GigaMIDI dataset.

    Reads the MMD metadata files (audio matches, genres, scraped
    titles/artists) and the expressiveness CSV, then writes the MIDI files
    of each subset/split as WebDataset tar shards, one JSON metadata file
    per subset/split, and a global ``n_shards.json`` index.

    :param main_data_dir_path: path of the directory containing the datasets.
    """
    dataset_path = main_data_dir_path / "GigaMIDI_original"
    webdataset_path = main_data_dir_path / "GigaMIDI"

    # md5 -> list of (audio sid, match score) pairs from the MMD audio matches
    md5_sid_matches_scores = {}
    with (
        main_data_dir_path / "MMD_METADATA" / "MMD_audio_matches.tsv"
    ).open() as matches_file:
        next(matches_file)  # first line skipped (header)
        for line in tqdm(matches_file, desc="Reading MMD match file"):
            midi_md5, score, audio_sid = line.split()
            md5_sid_matches_scores.setdefault(midi_md5, []).append(
                (audio_sid, float(score))
            )
    # read_text + json.loads closes the file handle; the original passed an
    # un-closed ``.open()`` handle to ``json.load`` and leaked it.
    sid_to_mbid = json.loads(
        (main_data_dir_path / "MMD_METADATA" / "MMD_sid_to_mbid.json").read_text()
    )

    # md5 -> genre entry (remaining keys of the JSONL row, e.g. counts per
    # genre source such as "genre_discogs")
    md5_genres = {}
    with (
        main_data_dir_path / "MMD_METADATA" / "MMD_audio_matched_genre.jsonl"
    ).open() as file:
        for row in tqdm(file, desc="Reading genres MMD metadata"):
            entry = json.loads(row)
            md5 = entry.pop("md5")
            md5_genres[md5] = entry

    # md5 -> flattened list of scraped genres
    md5_genres_scraped = {}
    with (
        main_data_dir_path / "MMD_METADATA" / "MMD_scraped_genre.jsonl"
    ).open() as file:
        for row in tqdm(file, desc="Reading scraped genres MMD metadata"):
            entry = json.loads(row)
            genres = []
            for genre_list in entry["genre"]:
                genres += genre_list
            md5_genres_scraped[entry["md5"]] = genres

    # md5 -> first (title, artist) pair scraped for this file
    md5_artist_title_scraped = {}
    with (
        main_data_dir_path / "MMD_METADATA" / "MMD_scraped_title_artist.jsonl"
    ).open() as file:
        for row in tqdm(file, desc="Reading scraped titles/artists MMD metadata"):
            entry = json.loads(row)
            md5_artist_title_scraped[entry["md5"]] = entry["title_artist"][0]

    # md5 -> list of per-track values read from column 5 of the NOMML CSV
    md5_expressive = {}
    with (
        dataset_path / "Expressive_Performance_Detection_NOMML_gigamidi_tismir.csv"
    ).open() as file:
        next(file)  # skipping first row (header)
        for row in tqdm(file, desc="Reading expressiveness metadata"):
            parts = row.split(",")
            # first column is a file path; keep only the md5 stem
            md5 = parts[0].split("/")[-1].split(".")[0]
            md5_expressive.setdefault(md5, []).append(int(parts[5]))
    # TODO: load the loop metadata (GigaMIDI-combined-non-expressive-loop-data)
    # and attach it to the per-file metadata rows below.

    # Sharding the data into tar archives
    num_shards = {}
    for subset, subset_path in SUBSET_PATHS.items():
        num_shards[subset] = {}
        for split in _SPLITS:
            files_paths = list((dataset_path / split / subset_path).glob("**/*.mid"))
            save_path = webdataset_path / subset / split
            save_path.mkdir(parents=True, exist_ok=True)
            metadata = {}
            with ShardWriter(
                f"{save_path!s}/GigaMIDI_{subset}_{split}_%04d.tar",
                maxcount=MAX_NUM_ENTRIES_PER_SHARD,
            ) as writer:
                for file_path in files_paths:
                    md5 = file_path.stem
                    writer.write(
                        {
                            "__key__": md5,
                            # read_bytes closes the handle, unlike the
                            # original ``open("rb").read()`` which leaked it
                            "mid": file_path.read_bytes(),
                        }
                    )

                    # Get metadata if existing
                    metadata_row = {}

                    sid_matches = md5_sid_matches_scores.get(md5)
                    if sid_matches:
                        metadata_row["sid_matches"] = sid_matches
                        metadata_row["mbid_matches"] = []
                        for sid, _ in sid_matches:
                            mbids = sid_to_mbid.get(sid)
                            if mbids:
                                metadata_row["mbid_matches"].append([sid, mbids])

                    title_artist = md5_artist_title_scraped.get(md5)
                    if title_artist:
                        (
                            metadata_row["title_scraped"],
                            metadata_row["artist_scraped"],
                        ) = title_artist
                    genres_scraped = md5_genres_scraped.get(md5)
                    if genres_scraped:
                        metadata_row["genres_scraped"] = genres_scraped
                    genres = md5_genres.get(md5)
                    if genres:
                        # keys look like "genre_<source>" -> "genres_<source>"
                        for key, val in genres.items():
                            metadata_row[f"genres_{key.split('_')[1]}"] = val
                    interpreted_scores = md5_expressive.get(md5)
                    if interpreted_scores:
                        metadata_row["median_metric_depth"] = interpreted_scores
                    # TODO loops
                    if len(metadata_row) > 0:
                        metadata[md5] = metadata_row

            num_shards[subset][split] = len(list(save_path.glob("*.tar")))
            # Saving metadata for this subset and split
            with (save_path.parent / f"metadata_{subset}_{split}.json").open("w") as f:
                json.dump(metadata, f)

    # Saving n shards (read back by the dataset loading script)
    with (webdataset_path / "n_shards.json").open("w") as f:
        json.dump(num_shards, f, indent=4)
178 |
179 |
def load_dataset_from_generator(
    dataset_path: Path, num_files_limit: int = 100
) -> Dataset:
    """
    Load the dataset from the MIDI files found under ``dataset_path``.

    :param dataset_path: path of the directory containing the datasets.
    :param num_files_limit: maximum number of entries/files to retain.
    :return: dataset with a single "music" column of file paths (as strings).
    """
    selected_paths = list(dataset_path.glob("**/*.mid"))[:num_files_limit]
    music_column = [str(midi_path) for midi_path in selected_paths]
    return Dataset.from_dict({"music": music_column})
192 |
193 |
def convert_to_parquet(
    dataset: DatasetDict, repo_id: str, token: str | None = None
) -> None:
    """
    Convert a dataset to parquet files and upload them to the Hub.

    Creates the ``refs/convert/parquet`` branch on the dataset repository,
    then writes each split to a local ``<split>.parquet`` file and uploads
    it to that branch.

    :param dataset: dataset to convert.
    :param repo_id: id of the repo to upload the files.
    :param token: token for authentication.
    """
    create_branch(
        repo_id,
        branch="refs/convert/parquet",
        revision="main",
        token=token,
        repo_type="dataset",
    )
    # NOTE: the original version also accumulated the file names in an
    # unused ``files_to_push`` list; removed as dead code.
    for split, split_dataset in dataset.items():
        file_path = f"{split}.parquet"
        split_dataset.to_parquet(file_path)
        upload_file(
            path_or_fileobj=file_path,
            path_in_repo=file_path,
            repo_id=repo_id,
            token=token,
            revision="refs/convert/parquet",
            repo_type="dataset",
        )
224 |
225 |
if __name__ == "__main__":
    # Imported lazily so that importing this module has no CLI side effects.
    from argparse import ArgumentParser

    from utils.utils import path_data_directory_local_fs

    parser = ArgumentParser(description="Dataset creation script")
    parser.add_argument(
        "--hf-repo-name", type=str, required=False, default="Metacreation/GigaMIDI"
    )
    parser.add_argument("--hf-token", type=str, required=False, default=None)
    # NOTE(review): ``args`` is only referenced by the disabled code below;
    # the shard-creation call itself takes no CLI arguments.
    args = vars(parser.parse_args())

    create_webdataset_gigamidi(path_data_directory_local_fs())

    # The two string literals below are disabled code kept for reference:
    # pushing the dataset as parquet files, then a local loading smoke test.
    """dataset_ = load_dataset(
        args["hf_repo_name"], "music", token=args["hf_token"], trust_remote_code=True
    )
    convert_to_parquet(dataset_, args["hf_repo_name"], token=args["hf_token"])"""
    """from datasets import load_dataset

    dataset_ = load_dataset(
        str(path_data_directory_local_fs() / "GigaMIDI"),
        "no-drums",
        subsets=["no-drums", "all-instruments-with-drums"],
        trust_remote_code=True,
    )
    data = dataset_["train"]
    for i in range(7):
        t = data[i]
        f = 0

    test = data[0]
    print(test)
    from symusic import Score

    score = Score.from_midi(test["music"]["bytes"])
    t = 0"""
--------------------------------------------------------------------------------
/loops_nomml/corr_mat.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import annotations
3 |
4 | from typing import TYPE_CHECKING
5 |
6 | import numpy as np
7 |
8 | from .note_set import NoteSet
9 |
10 | if TYPE_CHECKING:
11 | from collections.abc import Sequence
12 | from typing import Tuple, Dict
13 |
14 | from symusic import Note
15 |
16 |
17 | # Implementation of Correlative Matrix approach presented in:
18 | # Jia Lien Hsu, Chih Chin Liu, and Arbee L.P. Chen. Discovering
19 | # nontrivial repeating patterns in music data. IEEE Transactions on
20 | # Multimedia, 3:311–325, 9 2001.
def calc_correlation(note_sets: Sequence[NoteSet]) -> np.ndarray:
    """
    Build the correlative matrix of repeated segments within note_sets.

    Entry ``[i, j]`` counts how many consecutive elements match when
    aligning position ``i`` with position ``j`` (counting backwards along
    the diagonal). A run may only begin on a barline NoteSet, so every
    detected repetition starts on the downbeat of a measure.

    :param note_sets: list of NoteSets to calculate repetitions for
    :return: 2d square int16 correlation matrix the length of note_sets
    """
    size = len(note_sets)
    corr_mat = np.zeros((size, size), dtype='int16')

    # First row: a run anchored at index 0 can only open on a barline.
    for col in range(1, size):
        if note_sets[0].is_barline() and note_sets[0] == note_sets[col]:
            corr_mat[0, col] = 1

    # Remaining rows: extend diagonal runs of matching elements.
    for row in range(1, size - 1):
        for col in range(row + 1, size):
            if note_sets[row] != note_sets[col]:
                continue
            previous = corr_mat[row - 1, col - 1]
            if previous == 0 and not note_sets[row].is_barline():
                continue  # loops must start on the downbeat (start of bar)
            corr_mat[row, col] = previous + 1

    return corr_mat
48 |
49 |
50 | def get_loop_density(loop: Sequence[NoteSet], num_beats: int | float) -> float:
51 | """
52 | Calculates the density of a list of NoteSets in active notes per beat
53 |
54 | :param loop: list of NoteSet groups in the loop
55 | :param num_beats: duration of the loop in beats
56 | :return: loop density in active notes per beat
57 | """
58 | return len([n_set for n_set in loop if n_set.start != n_set.end]) / num_beats
59 |
60 |
def is_empty_loop(loop: Sequence[Note]) -> bool:
    """
    Check whether every element of a note sequence is a rest.

    :param loop: sequence of MIDI notes to check
    :return: True when no element carries any pitch, False as soon as one
        non-rest note exists
    """
    return all(len(note.pitches) == 0 for note in loop)
72 |
73 |
def compare_loops(p1: Sequence[NoteSet], p2: Sequence[NoteSet], min_rep_beats: int | float) -> int:
    """
    Check whether two lists of NoteSets match over their leading elements.

    Used to track the longest common loop: only the first
    ``round(min_rep_beats)`` positions are compared, so both sequences are
    assumed to be at least that long.

    :param p1: new loop to compare
    :param p2: existing loop to compare with
    :param min_rep_beats: number of leading positions to compare
    :return: 0 for a mismatch, 1 if p1 is a subloop of p2, 2 if p2 is a
        subloop of p1 (also returned when both have the same length)
    """
    num_to_check = int(round(min_rep_beats))
    for idx in range(num_to_check):
        if p1[idx] != p2[idx]:
            return 0  # not a subloop, there's a mismatch
    # Prefixes agree: the shorter sequence is a subloop of the longer one.
    return 1 if len(p1) < len(p2) else 2
95 |
96 |
def test_loop_exists(loop_list: Sequence[Sequence[NoteSet]], loop: Sequence[NoteSet], min_rep_beats: int | float) -> int:
    """
    Check whether a loop already exists in a list of loops, and mark an
    existing loop for replacement when the new one is longer.

    :param loop_list: list of loops to check
    :param loop: new loop to check for a match
    :param min_rep_beats: number of beats to check for a match
    :return: -1 if loop is a subloop of a current loop in loop_list, idx of
        existing loop to replace if loop is a superstring, or None if loop
        is an entirely new loop
    """
    for idx, existing in enumerate(loop_list):
        outcome = compare_loops(loop, existing, min_rep_beats)
        if outcome == 1:
            return -1  # discard: the new loop is a subloop of an existing one
        if outcome == 2:
            return idx  # the existing loop is shorter: replace it here
    return None  # no overlap found: caller should append the new loop
116 |
117 |
def filter_sub_loops(candidate_indices: Dict[float, Tuple[int, int]]) -> Sequence[Tuple[int, int, float]]:
    """
    Process endpoints for identified loops, dropping candidates whose exact
    span was already merged as a contiguous run at a shorter duration
    ("sub loops"). For instance, if a 4 bar loop is made up of two 2 bar
    loops, the 4 bar duplicate span is skipped.

    :param candidate_indices: dictionary of (start_tick, end_tick) for each
        identified group, keyed by loop length in beats
    :return: filtered list of (start_tick, end_tick, duration) tuples
    """
    # Process shorter durations first so their merged runs can veto the
    # same spans re-detected at longer durations.
    ordered = dict(sorted(candidate_indices.items()))

    merged_runs = {}  # run start tick -> (run end tick, accumulated duration)
    kept = []
    for duration, endpoints in ordered.items():
        run_start = 0
        run_end = 0
        run_duration = 0
        for start, end in endpoints:
            # Skip spans already recorded as a merged run of shorter loops.
            if merged_runs.get(start, (None,))[0] == end:
                continue

            if start == run_end:
                # Contiguous with the current run: extend it.
                run_end = end
                run_duration += duration
            else:
                # Close the previous run (if non-empty) and open a new one.
                if run_start != run_end:
                    merged_runs[run_start] = (run_end, run_duration)
                run_start = start
                run_end = end
                run_duration = duration

            kept.append((start, end, duration))

    return kept
154 |
155 |
def get_duration_beats(start: int, end: int, ticks_beats: Sequence[int]) -> float:
    """
    Given a loop start and end time in ticks and a list of beat tick times,
    calculate the duration of the loop in beats.

    The duration is the number of whole beats between the first beat at or
    after ``start`` and the last beat at or before ``end``, plus the
    fractional beats between ``start``/``end`` and those boundary beats.

    Fixes over the original implementation:
    - the "before" fraction divided an interval by itself (always 1 or 0);
      it now measures ``start`` against the first beat inside the loop;
    - the "after" fraction indexed ``ticks_beats[end]`` with a *tick value*
      instead of an index (potential IndexError / wrong tick);
    - ``idx_beat_after`` previously had a tick delta added to an *index*
      when ``end`` lay past the last beat; the tick is now extrapolated.

    :param start: start time of the loop in ticks
    :param end: end time of the loop in ticks
    :param ticks_beats: list of all the beat times in the track
    :return: duration of the loop in beats
    """
    idx_beat_previous = None  # beat right before the first beat in the loop
    idx_beat_first_in = None  # first beat at or after ``start``
    idx_beat_last_in = None  # last beat at or before ``end``
    idx_beat_after = None  # first beat strictly after ``end``, if any

    for bi, beat_tick in enumerate(ticks_beats):
        if idx_beat_first_in is None and beat_tick >= start:
            idx_beat_first_in = bi
            idx_beat_previous = max(bi - 1, 0)
        elif idx_beat_last_in is None and beat_tick == end:
            idx_beat_last_in = idx_beat_after = bi
        elif idx_beat_last_in is None and beat_tick > end:
            idx_beat_last_in = max(bi - 1, 0)
            idx_beat_after = bi

    # NOTE(review): assumes at least one beat tick >= start exists;
    # otherwise idx_beat_first_in stays None (same failure mode as before).
    if idx_beat_last_in is None:
        # ``end`` lies at/after the last known beat.
        idx_beat_last_in = len(ticks_beats) - 1

    if idx_beat_after is not None:
        tick_after = ticks_beats[idx_beat_after]
    else:
        # Extrapolate one beat past the track end using the last inter-beat
        # interval (0 when fewer than two beats are known).
        last_interval = (
            ticks_beats[-1] - ticks_beats[-2] if len(ticks_beats) >= 2 else 0
        )
        tick_after = ticks_beats[-1] + last_interval

    # Fractional beat between ``start`` and the first beat inside the loop.
    beat_length_before = (
        ticks_beats[idx_beat_first_in] - ticks_beats[idx_beat_previous]
    )
    if beat_length_before > 0:
        num_beats_before = (
            ticks_beats[idx_beat_first_in] - start
        ) / beat_length_before
    else:
        num_beats_before = 0
    # Fractional beat between the last beat inside the loop and ``end``.
    beat_length_after = tick_after - ticks_beats[idx_beat_last_in]
    if beat_length_after > 0:
        num_beats_after = (end - ticks_beats[idx_beat_last_in]) / beat_length_after
    else:
        num_beats_after = 0
    return float(
        idx_beat_last_in - idx_beat_first_in + num_beats_before + num_beats_after
    )
194 |
195 |
def get_valid_loops(
    note_sets: Sequence[NoteSet],
    corr_mat: np.ndarray,
    ticks_beats: Sequence[int],
    min_rep_notes: int=4,
    min_rep_beats: float=2.0,
    min_beats: float=4.0,
    max_beats: float=32.0,
    min_loop_note_density: float = 0.5,
) -> Tuple[Sequence[NoteSet], Tuple[int, int, float, float]]:
    """
    Returns all of the loops detected in note_sets, filtering based on the
    specified hyperparameters. Loops that are subloops of larger loops will
    be filtered out

    :param note_sets: NoteSets the correlation matrix was computed from
    :param corr_mat: correlation matrix produced by ``calc_correlation``
    :param ticks_beats: tick times of all the beats in the track
    :param min_rep_notes: Minimum number of notes that must be present in
        the repeated bookend of a loop for it to be considered valid
    :param min_rep_beats: Minimum length in beats of the repeated bookend
        of a loop for it to be considered valid
    :param min_beats: Minimum total length of the loop in beats
    :param max_beats: Maximum total length of the loop in beats
    :param min_loop_note_density: Minimum valid density of a loop in average
        notes per beat across the whole loop
    :return: tuple containing the loop as a sequence of NoteSets, and an
        additional tuple with loop metadata: (start time in ticks, end time
        in ticks, duration in beats, density)
    """
    min_rep_notes += 1  # don't count bar lines as a repetition
    # Cells holding exactly min_rep_notes mark ends of qualifying runs.
    x_num_elem, y_num_elem = np.where(corr_mat == min_rep_notes)

    # Parse the correlation matrix to retrieve the loops starts/ends ticks.
    # Keys of ``valid_indices`` are loop durations in beats, values are
    # lists of (row, col) matrix indices where a qualifying run ends.
    valid_indices = {}
    for i, x in enumerate(x_num_elem):
        y = y_num_elem[i]
        # Walk back along the diagonal to the start of the matching run.
        start_x = x - corr_mat[x, y] + 1
        start_y = y - corr_mat[x, y] + 1

        # The loop spans from the first occurrence to its repetition.
        loop_start_time = note_sets[start_x].start
        loop_end_time = note_sets[start_y].start
        loop_num_beats = round(get_duration_beats(loop_start_time, loop_end_time, ticks_beats), 2)
        if max_beats >= loop_num_beats >= min_beats:
            if loop_num_beats not in valid_indices:
                valid_indices[loop_num_beats] = []
            valid_indices[loop_num_beats].append((x_num_elem[i], y_num_elem[i]))

    filtered_indices = filter_sub_loops(valid_indices)

    loops = []
    loop_bp = []  # per-loop metadata (start_tick, end_tick, beats, density)
    corr_size = corr_mat.shape[0]
    for start_x, start_y, loop_num_beats in filtered_indices:
        # Extend the matching run as far as it keeps growing diagonally.
        x = start_x
        y = start_y
        while x + 1 < corr_size and y + 1 < corr_size and corr_mat[x + 1, y + 1] > corr_mat[x, y]:
            x = x + 1
            y = y + 1
        beginning = x - corr_mat[x, y] + 1
        end = y - corr_mat[x, y] + 1
        start_tick = note_sets[beginning].start
        end_tick = note_sets[end].start
        # Duration of the repeated bookend (first occurrence only).
        duration_beats = get_duration_beats(start_tick, end_tick, ticks_beats)

        if duration_beats >= min_rep_beats and not is_empty_loop(note_sets[beginning:end]):
            loop = note_sets[beginning:(end + 1)]
            loop_density = get_loop_density(loop, loop_num_beats)
            if loop_density < min_loop_note_density:
                continue
            exist_result = test_loop_exists(loops, loop, min_rep_beats)
            if exist_result is None:
                loops.append(loop)
                loop_bp.append((start_tick, end_tick, loop_num_beats, loop_density))
            # NOTE(review): exist_result == 0 (replace the loop at index 0)
            # falls through this check — looks unintended; confirm.
            elif exist_result > 0:  # index to replace
                loops[exist_result] = loop
                loop_bp[exist_result] = (start_tick, end_tick, loop_num_beats, loop_density)

    return loops, loop_bp
273 |
--------------------------------------------------------------------------------
/Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models/Optimal Threshold Selection/Expressive_Performance_Detection_Training-Percentile-threshold.csv:
--------------------------------------------------------------------------------
1 | ,trackNum,instrument,medianMetricDepth,tpq,velocity_per_track,onset_per_track,velocity_entropy_per_track,onset_entropy_per_track,DNVR,DNODR,"Ground Truth (0=Score, 1=Performance)"
2 | count,20911.0,20911.0,20911.0,20911.0,20911.0,20911.0,20911.0,20911.0,20911.0,20911.0,20911.0
3 | mean,0.02247620869398881,0.004686528621299794,11.616230691980297,405.8130170723543,71.75218784371862,47.5673090717804,2.17519401474504,2.3491900877509306,56.49778570371541,12.266475946985935,0.958203816173306
4 | std,0.1600209488189645,0.45176420847283083,1.9479262330559586,96.8261362680451,22.05038341551496,11.507997369933452,0.30037865799320457,0.31002732650143766,17.362506626390047,3.196148670826988,0.20012790450225568
5 | min,0.0,0.0,0.0,220.0,1.0,1.0,0.0,0.0,0.78740157480315,0.09765625,0.0
6 | 3.5%,0.0,0.0,4.0,384.0,13.0,8.0,1.6410221333320987,1.847199957331255,10.2362204724409,1.45833333333333,0.0
7 | 3.51%,0.0,0.0,4.0,384.0,13.0,8.0,1.64341771979318,1.852555048428222,10.2362204724409,1.45833333333333,0.0
8 | 3.52%,0.0,0.0,4.0,384.0,13.0,8.0,1.643599811492699,1.85465261855982,10.2362204724409,1.46484375,0.0
9 | 3.53%,0.0,0.0,4.0,384.0,13.0,8.0,1.64884707167248,1.8551538950148374,10.2362204724409,1.46484375,0.0
10 | 3.54%,0.0,0.0,4.0,384.0,13.0,8.0,1.6588237201427147,1.8586239971562957,10.2362204724409,1.46484375,0.0
11 | 3.55%,0.0,0.0,4.0,384.0,13.0,8.0,1.66094743328692,1.860947629316187,10.2362204724409,1.46484375,0.0
12 | 3.56%,0.0,0.0,4.0,384.0,13.0,8.0,1.66687569743084,1.862463289114066,10.2362204724409,1.46484375,0.0
13 | 3.57%,0.0,0.0,4.0,384.0,13.0,8.0,1.6720969699001071,1.8638320116484202,10.2362204724409,1.46484375,0.0
14 | 3.58%,0.0,0.0,4.0,384.0,13.0,8.0,1.674929776698617,1.8640964797040853,10.2362204724409,1.46484375,0.0
15 | 3.59%,0.0,0.0,4.668999999999983,384.0,13.0,8.0,1.6928641492368666,1.8643394237437514,10.2362204724409,1.46484375,0.0
16 | 3.6%,0.0,0.0,5.759999999999991,384.0,13.0,8.0,1.69574253416963,1.8670664996869568,10.2362204724409,1.46484375,0.0
17 | 3.61%,0.0,0.0,6.0,384.0,13.0,8.0,1.7003000325145192,1.8678801716292,10.2362204724409,1.66666666666667,0.0
18 | 3.62%,0.0,0.0,6.0,384.0,13.0,8.0,1.7071083077240585,1.869199024407026,10.2362204724409,1.66666666666667,0.0
19 | 3.63%,0.0,0.0,6.0,384.0,13.0,8.0,1.7202515718959752,1.87180217690159,10.2362204724409,1.66666666666667,0.0
20 | 3.64%,0.0,0.0,6.0,384.0,13.0,9.0,1.7303242760737383,1.8735889466151716,10.2362204724409,1.66666666666667,0.0
21 | 3.65%,0.0,0.0,6.0,384.0,13.0,9.0,1.73286795139986,1.87464074053575,10.2362204724409,1.66666666666667,0.0
22 | 3.66%,0.0,0.0,6.0,384.0,13.0,9.0,1.73286795139986,1.8751403907674762,10.2362204724409,1.66666666666667,0.0
23 | 3.67%,0.0,0.0,6.0,384.0,13.0,9.0,1.7344455175356575,1.8774704083373137,10.2362204724409,1.66666666666667,0.0
24 | 3.68%,0.0,0.0,6.0,384.0,13.0,9.0,1.73512645696292,1.8806242090443093,10.2362204724409,1.66666666666667,0.0
25 | 3.69%,0.0,0.0,6.0,384.0,13.0,9.0,1.73512645696292,1.881979697998224,10.2362204724409,1.66666666666667,0.0
26 | 3.7%,0.0,0.0,6.0,384.0,13.0,9.0,1.73512645696292,1.884375397494431,10.2362204724409,1.66666666666667,0.0
27 | 3.71%,0.0,0.0,6.0,384.0,13.0,9.0,1.73512645696292,1.8854310253127236,10.2362204724409,1.66666666666667,0.0
28 | 3.72%,0.0,0.0,6.0,384.0,13.0,9.0,1.73512645696292,1.8861247034515896,10.2362204724409,1.66666666666667,0.0
29 | 3.73%,0.0,0.0,6.0,384.0,13.942999999999984,9.0,1.7468540433690751,1.8866247284476956,10.978740157480305,1.66666666666667,0.0
30 | 3.74%,0.0,0.0,6.0,384.0,14.0,9.0,1.74786809746676,1.88669678465808,11.0236220472441,1.66666666666667,0.0
31 | 3.75%,0.0,0.0,6.0,384.0,14.0,9.0,1.74786809746676,1.8867633891806825,11.0236220472441,1.66666666666667,0.0
32 | 3.76%,0.0,0.0,6.216000000000008,384.0,14.0,9.0,1.749235158976883,1.8879165897684507,11.0236220472441,1.66666666666667,0.0
33 | 3.77%,0.0,0.0,8.0,384.0,14.0,9.0,1.7543489653872566,1.88915916375402,11.0236220472441,1.66666666666667,0.0
34 | 3.78%,0.0,0.0,8.0,384.0,14.0,9.0,1.7565046102348396,1.88915916375402,11.0236220472441,1.66666666666667,0.0
35 | 3.79%,0.0,0.0,8.0,384.0,14.0,9.0,1.76017552682609,1.8901455523684245,11.0236220472441,1.66666666666667,0.0
36 | 3.8%,0.0,0.0,8.0,384.0,14.0,9.0,1.7641516589423036,1.8917806280499891,11.0236220472441,1.66666666666667,0.0
37 | 3.81%,0.0,0.0,8.0,384.0,14.0,9.0,1.7676503992262922,1.8936273912507633,11.0236220472441,1.66666666666667,0.0
38 | 3.82%,0.0,0.0,8.0,384.0,14.0,9.0,1.7711315368391405,1.89378823239114,11.0236220472441,1.66666666666667,0.0
39 | 3.83%,0.0,0.0,8.0,384.0,14.0,9.0,1.7722341368064598,1.89378823239114,11.0236220472441,1.66666666666667,0.0
40 | 3.84%,0.0,0.0,8.0,384.0,14.0,9.0,1.7803737124381809,1.8946386676072702,11.0236220472441,1.66666666666667,0.0
41 | 3.85%,0.0,0.0,10.069999999999936,384.0,14.0,9.0,1.78557386526543,1.8978464722049861,11.0236220472441,1.66666666666667,0.0
42 | 3.86%,0.0,0.0,12.0,384.0,14.0,9.0,1.7857294090800417,1.89892678933633,11.0236220472441,1.66666666666667,0.0
43 | 3.87%,0.0,0.0,12.0,384.0,14.0,9.0,1.789379564136363,1.8997133227581198,11.0236220472441,1.7005729166666677,0.0
44 | 3.88%,0.0,0.0,12.0,384.0,14.0,9.0,1.7907074707888906,1.9021724839218015,11.0236220472441,1.846166666666664,0.0
45 | 3.89%,0.0,0.0,12.0,384.0,14.0,9.0,1.7915399392101077,1.90368668833762,11.0236220472441,1.875,0.0
46 | 3.9%,0.0,0.0,12.0,384.0,14.0,9.0,1.79175946922806,1.90368668833762,11.0236220472441,1.875,0.0
47 | 3.91%,0.0,0.0,12.0,384.0,14.0,9.0,1.79175946922806,1.9058469058028016,11.0236220472441,1.875,0.0
48 | 3.92%,0.0,0.0,12.0,384.0,14.0,9.0,1.79175946922806,1.906297674512365,11.0236220472441,1.875,0.0
49 | 3.93%,0.0,0.0,12.0,384.0,14.0,9.0,1.79175946922806,1.90728399932138,11.0236220472441,1.875,0.0
50 | 3.94%,0.0,0.0,12.0,384.0,14.0,9.0,1.79175946922806,1.90728399932138,11.0236220472441,1.875,0.0
51 | 3.95%,0.0,0.0,12.0,384.0,14.0,9.94500000000005,1.79175946922806,1.90853528164356,11.0236220472441,1.875,0.0
52 | 3.96%,0.0,0.0,12.0,384.0,14.0,10.0,1.7917878768887283,1.9085725163331464,11.0236220472441,1.875,0.0
53 | 3.97%,0.0,0.0,12.0,384.0,14.0,10.0,1.7958592519658925,1.9102170160583758,11.0236220472441,1.875,0.0
54 | 3.98%,0.0,0.0,12.0,384.0,14.0,10.0,1.8095272062299135,1.9173405295680208,11.0236220472441,1.875,0.0
55 | 3.99%,0.0,0.0,12.0,384.0,15.0,10.0,1.8122218330460924,1.919268730921757,11.8110236220472,1.875,0.0
56 | 4%,0.0,0.0,12.0,384.0,15.0,10.0,1.817740144249624,1.921908958579216,11.8110236220472,1.875,0.0
57 | 4.01%,0.0,0.0,12.0,384.0,15.0,10.0,1.81884777724898,1.9229410566301042,11.8110236220472,1.875,0.0
58 | 4.02%,0.0,0.0,12.0,384.0,15.0,10.0,1.8198399615292822,1.92295892360442,11.8110236220472,1.875,0.0
59 | 4.03%,0.0,0.0,12.0,384.0,15.0,10.0,1.8251774601370696,1.9262625443250243,11.8110236220472,1.875,0.0
60 | 4.04%,0.0,0.0,12.0,384.0,15.0,10.0,1.8284410091155763,1.92739212613927,11.8110236220472,1.875,0.0
61 | 4.05%,0.0,0.0,12.0,384.0,15.0,10.0,1.8290448005300781,1.929631317498723,11.8110236220472,1.875,0.0
62 | 4.06%,0.0,0.0,12.0,384.0,15.0,10.0,1.8387412451018255,1.9308749430628558,11.8110236220472,1.875,0.0
63 | 4.07%,0.0,0.0,12.0,384.0,15.0,10.0,1.8414770403073026,1.9341879826241253,11.8110236220472,1.875,0.0
64 | 4.08%,0.0,0.0,12.0,384.0,15.0,10.0,1.844300645896431,1.9357406070586625,11.8110236220472,1.875,0.0
65 | 4.09%,0.0,0.0,12.0,384.0,15.0,10.0,1.8484789183547972,1.93619934568452,11.8110236220472,1.875,0.0
66 | 4.1%,0.0,0.0,12.0,384.0,15.0,10.0,1.8556092770482155,1.938055203531081,11.8110236220472,1.875,0.0
67 | 4.11%,0.0,0.0,12.0,384.0,15.0,10.0,1.8581555151930649,1.9452419877155491,11.8110236220472,1.875,0.0
68 | 4.12%,0.0,0.0,12.0,384.0,15.0,10.0,1.8606839748791058,1.946808362617706,11.8110236220472,1.875,0.0
69 | 4.13%,0.0,0.0,12.0,384.0,15.0,10.0,1.8625483330972583,1.9480946835987856,11.8110236220472,1.875,0.0
70 | 4.14%,0.0,0.0,12.0,384.0,15.0,10.0,1.8646311472827994,1.9502876379902592,11.8110236220472,1.875,0.0
71 | 4.15%,0.0,0.0,12.0,384.0,15.0,10.0,1.867393051906037,1.9522471628500282,11.8110236220472,2.034375000000018,0.0
72 | 4.16%,0.0,0.0,12.0,384.0,15.0,10.0,1.870845405337236,1.9557341208567933,11.8110236220472,2.08333333333333,0.0
73 | 4.17%,0.0,0.0,12.0,384.0,15.0,10.0,1.872001404857079,1.9566235887987975,11.8110236220472,2.08333333333333,0.0
74 | 4.18%,0.0,0.0,12.0,384.0,15.0,10.0,1.8740618122105086,1.9622061996729672,11.8110236220472,2.08333333333333,1.0
75 | 4.19%,0.0,0.0,12.0,384.0,15.0,10.0,1.875586991460673,1.9638435223793027,11.8110236220472,2.08333333333333,1.0
76 | 4.2%,0.0,0.0,12.0,384.0,15.0,10.0,1.8788601482661171,1.964708555687614,11.8110236220472,2.08333333333333,1.0
77 | 4.21%,0.0,0.0,12.0,384.0,15.0,10.0,1.8813090426966708,1.9650726610011424,11.8110236220472,2.08333333333333,1.0
78 | 4.22%,0.0,0.0,12.0,384.0,15.0,10.0,1.8833259464900192,1.9668244613798862,11.8110236220472,2.08333333333333,1.0
79 | 4.23%,0.0,0.0,12.0,384.0,15.492999999999938,10.0,1.885097503590677,1.968519069916699,12.19921259842513,2.08333333333333,1.0
80 | 4.24%,0.0,0.0,12.0,384.0,16.0,10.0,1.8867947672607561,1.9694562392644799,12.5984251968504,2.08333333333333,1.0
81 | 4.25%,0.0,0.0,12.0,384.0,16.0,10.0,1.88702726883415,1.9727340405642466,12.5984251968504,2.08333333333333,1.0
82 | 4.26%,0.0,0.0,12.0,384.0,16.0,10.0,1.8875810435120208,1.976268663684989,12.5984251968504,2.08333333333333,1.0
83 | 4.27%,0.0,0.0,12.0,384.0,16.0,10.857000000000085,1.8879031479708135,1.9807480692407462,12.5984251968504,2.08333333333333,1.0
84 | 4.28%,0.0,0.0,12.0,384.0,16.0,11.0,1.88915916375402,1.9825107772601644,12.5984251968504,2.08333333333333,1.0
85 | 4.29%,0.0,0.0,12.0,384.0,16.0,11.0,1.889540710736435,1.9841601841484624,12.5984251968504,2.08333333333333,1.0
86 | 4.3%,0.0,0.0,12.0,384.0,16.0,11.0,1.89118654382655,1.98605696815486,12.5984251968504,2.08333333333333,1.0
87 | 4.31%,0.0,0.0,12.0,384.0,16.0,11.0,1.89378823239114,1.9865522721238016,12.5984251968504,2.08333333333333,1.0
88 | 4.32%,0.0,0.0,12.0,384.0,16.0,11.0,1.89378823239114,1.9883612586606534,12.5984251968504,2.08333333333333,1.0
89 | 4.33%,0.0,0.0,12.0,384.0,16.0,11.0,1.8957379690011265,1.9889304963821068,12.5984251968504,2.08333333333333,1.0
90 | 4.34%,0.0,0.0,12.0,384.0,16.0,11.0,1.898265521640075,1.9901410988933488,12.5984251968504,2.08333333333333,1.0
91 | 4.35%,0.0,0.0,12.0,384.0,16.0,11.0,1.9011798275286038,1.9907208264464311,12.5984251968504,2.08333333333333,1.0
92 | 4.36%,0.0,0.0,12.0,384.0,16.0,11.0,1.9034547051415784,1.9922035413725887,12.5984251968504,2.08333333333333,1.0
93 | 4.37%,0.0,0.0,12.0,384.0,16.0,11.0,1.9054145198796109,1.99279814410984,12.5984251968504,2.08333333333333,1.0
94 | 4.38%,0.0,0.0,12.0,384.0,16.0,11.0,1.9059420123164048,1.9959159730930844,12.5984251968504,2.08333333333333,1.0
95 | 4.39%,0.0,0.0,12.0,384.0,16.0,11.0,1.9061517078204513,1.9968823346531932,12.5984251968504,2.08333333333333,1.0
96 | 4.4%,0.0,0.0,12.0,384.0,16.0,11.0,1.906165859476912,2.0023690864566617,12.5984251968504,2.08333333333333,1.0
97 | 4.41%,0.0,0.0,12.0,384.0,16.0,11.0,1.9073531774934107,2.0040748783423328,12.5984251968504,2.08333333333333,1.0
98 | 4.42%,0.0,0.0,12.0,384.0,16.0,11.0,1.9083781938383528,2.0044808496496294,12.5984251968504,2.08333333333333,1.0
99 | 4.43%,0.0,0.0,12.0,384.0,16.0,11.0,1.910151208135982,2.0049627474285217,12.5984251968504,2.08333333333333,1.0
100 | 4.44%,0.0,0.0,12.0,384.0,16.0,11.0,1.913697567543809,2.0052414460245642,12.5984251968504,2.08333333333333,1.0
101 | 4.45%,0.0,0.0,12.0,384.0,16.49500000000012,11.0,1.9167169320313464,2.006417223562179,12.988188976378028,2.29166666666667,1.0
102 | 4.46%,0.0,0.0,12.0,384.0,17.0,11.0,1.9177677338472006,2.0083300423976396,13.3858267716535,2.29166666666667,1.0
103 | 4.47%,0.0,0.0,12.0,384.0,17.0,11.0,1.9206649087574386,2.0096663948928795,13.3858267716535,2.29166666666667,1.0
104 | 4.48%,0.0,0.0,12.0,384.0,17.0,11.0,1.9250333341959631,2.0112340867501795,13.3858267716535,2.29166666666667,1.0
105 | 4.49%,0.0,0.0,12.0,384.0,17.0,11.0,1.925133130993522,2.0120230786147806,13.3858267716535,2.29166666666667,1.0
106 | 4.5%,0.0,0.0,12.0,384.0,17.0,11.0,1.9255530695368681,2.0125241882899965,13.3858267716535,2.29166666666667,1.0
107 | 50%,0.0,0.0,12.0,384.0,75.0,52.0,2.22156599484268,2.38234678527306,59.0551181102362,13.5416666666667,1.0
108 | max,3.0,46.0,12.0,1024.0,126.0,120.0,3.08733027111959,3.18461464679519,99.2125984251969,54.5454545454546,1.0
109 |
--------------------------------------------------------------------------------
/GigaMIDI/GigaMIDI.py:
--------------------------------------------------------------------------------
1 | """The GigaMIDI dataset.""" # noqa:N999
2 |
3 | from __future__ import annotations
4 |
5 | import json
6 | from collections import defaultdict
7 | from pathlib import Path
8 | from typing import TYPE_CHECKING, Literal
9 |
10 | import datasets
11 |
12 | if TYPE_CHECKING:
13 | from collections.abc import Sequence
14 |
15 | from datasets.utils.file_utils import ArchiveIterable
16 |
_CITATION = ""
_DESCRIPTION = "A large-scale MIDI symbolic music dataset."
_HOMEPAGE = "https://github.com/Metacreation-Lab/GigaMIDI"
_LICENSE = "CC0, also see https://www.europarl.europa.eu/legal-notice/en/"
# Names of the three dataset subsets and the data splits available on the hub.
_SUBSETS = ["all-instruments-with-drums", "drums-only", "no-drums"]
_SPLITS = ["train", "validation", "test"]
# Path templates relative to the dataset repository root; shards are tar
# archives named with a zero-padded shard index (see _split_generators).
_BASE_DATA_DIR = ""
_N_SHARDS_FILE = _BASE_DATA_DIR + "n_shards.json"
_MUSIC_PATH = (
    _BASE_DATA_DIR + "{subset}/{split}/GigaMIDI_{subset}_{split}_{shard_idx}.tar"
)
_METADATA_PATH = _BASE_DATA_DIR + "{subset}/metadata_{subset}_{split}.json"
# Schema of the per-file metadata columns merged into every example.
_METADATA_FEATURES = {
    # Matches against external song ids with a confidence score.
    "sid_matches": datasets.Sequence(
        {"sid": datasets.Value("string"), "score": datasets.Value("float16")}
    ),
    # MusicBrainz ids associated with each matched sid.
    "mbid_matches": datasets.Sequence(
        {
            "sid": datasets.Value("string"),
            "mbids": datasets.Sequence(datasets.Value("string")),
        }
    ),
    "artist_scraped": datasets.Value("string"),
    "title_scraped": datasets.Value("string"),
    "genres_scraped": datasets.Sequence(datasets.Value("string")),
    # Genre labels from three sources, each with an occurrence count.
    "genres_discogs": datasets.Sequence(
        {"genre": datasets.Value("string"), "count": datasets.Value("int16")}
    ),
    "genres_tagtraum": datasets.Sequence(
        {"genre": datasets.Value("string"), "count": datasets.Value("int16")}
    ),
    "genres_lastfm": datasets.Sequence(
        {"genre": datasets.Value("string"), "count": datasets.Value("int16")}
    ),
    # Note Onset Median Metric Level (NOMML) per track.
    "median_metric_depth": datasets.Sequence(datasets.Value("int16")),
    # "loops": datasets.Value("string"),
}
_VERSION = "1.0.0"
55 |
56 |
57 | """def cast_metadata_to_python(
58 | metadata: dict[str, Any],
59 | features: dict[str, datasets.Features] | None = None,
60 | ) -> dict:
61 | if features is None:
62 | features = _METADATA_FEATURES
63 | metadata_ = {}
64 | for feature_name, feature in features.items():
65 | data = metadata.get(feature_name, None)
66 | if (
67 | data
68 | and isinstance(feature, datasets.Sequence)
69 | and isinstance(feature.feature, (datasets.Sequence, dict))
70 | ):
71 | if isinstance(feature.feature, datasets.Sequence):
72 | metadata_[feature_name] = [
73 | cast_metadata_to_python(
74 | {feature_name: sample}, {feature_name: feature.feature}
75 | )
76 | for sample in data
77 | ]
78 | else:
79 | metadata_[feature_name] = {
80 | cast_metadata_to_python(
81 | {feature_name_: sample}, feature.feature
82 | )
83 | for sample in data
84 | for feature_name_ in feature.feature
85 | }
86 | else:
87 | metadata_[feature_name] = data
88 |
89 | return metadata_"""
90 |
91 |
class GigaMIDIConfig(datasets.BuilderConfig):
    """BuilderConfig for GigaMIDI."""

    def __init__(
        self,
        name: str,
        subsets: Sequence[
            # Fix: aligned with ``_SUBSETS`` — the previous annotation listed
            # "all-instruments-no-drums" (which does not exist) and omitted
            # "no-drums".
            Literal[
                "all-instruments-with-drums",
                "drums-only",
                "no-drums",
            ]
        ]
        | None = None,
        **kwargs,
    ) -> None:
        """
        BuilderConfig for GigaMIDI.

        Args:
        ----
            name: `string`:
                name of the dataset configuration: one of the subset names in
                ``_SUBSETS`` or "all" to select every subset.
            subsets: `Sequence[string]`: list of subsets to use. It is None by
                default and resorts to the `name` argument to select one
                subset if not provided.
            **kwargs: keyword arguments forwarded to super.

        """
        if name == "all":
            self.subsets = _SUBSETS
        elif subsets is not None:
            self.subsets = subsets
            # The config name reflects the explicit subset combination.
            name = "_".join(subsets)
        else:
            self.subsets = [name]

        super().__init__(name=name, **kwargs)
130 |
131 |
class GigaMIDI(datasets.GeneratorBasedBuilder):
    """The GigaMIDI dataset."""

    VERSION = datasets.Version(_VERSION)
    BUILDER_CONFIGS = [  # noqa:RUF012
        GigaMIDIConfig(
            name=name,
            version=datasets.Version(_VERSION),
        )
        for name in ["all", *_SUBSETS]
    ]
    DEFAULT_WRITER_BATCH_SIZE = 256

    def _info(self) -> datasets.DatasetInfo:
        """Return the dataset metadata: example schema, homepage and license."""
        features = datasets.Features(
            {
                "md5": datasets.Value("string"),
                # The raw MIDI file: its path inside the shard and its bytes.
                "music": {
                    "path": datasets.Value("string"),
                    "bytes": datasets.Value("binary"),
                },
                "is_drums": datasets.Value("bool"),
                **_METADATA_FEATURES,
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
            version=_VERSION,
        )

    def _split_generators(
        self, dl_manager: datasets.DownloadManager | datasets.StreamingDownloadManager
    ) -> list[datasets.SplitGenerator]:
        """
        Download the shards and metadata, and build one generator per split.

        Reads ``n_shards.json`` to learn how many tar shards each
        subset/split pair has, downloads them along with the metadata JSON
        files, and (when not streaming) extracts the archives locally.
        """
        n_shards_path = Path(dl_manager.download_and_extract(_N_SHARDS_FILE))
        with n_shards_path.open() as f:
            n_shards = json.load(f)

        music_urls = defaultdict(dict)
        for split in _SPLITS:
            for subset in self.config.subsets:
                music_urls[split][subset] = [
                    _MUSIC_PATH.format(subset=subset, split=split, shard_idx=f"{i:04d}")
                    for i in range(n_shards[subset][split])
                ]

        meta_urls = defaultdict(dict)
        for split in _SPLITS:
            for subset in self.config.subsets:
                meta_urls[split][subset] = _METADATA_PATH.format(
                    subset=subset, split=split
                )

        # dl_manager.download_config.num_proc = len(urls)

        meta_paths = dl_manager.download_and_extract(meta_urls)
        music_paths = dl_manager.download(music_urls)

        # In streaming mode the archives are not extracted; use None
        # placeholders with the same nested shape so that both mappings can
        # be iterated in lockstep in _generate_examples.
        local_extracted_music_paths = (
            dl_manager.extract(music_paths)
            if not dl_manager.is_streaming
            else {
                split: {
                    subset: [None] * len(music_paths[split][subset])
                    for subset in self.config.subsets
                }
                for split in _SPLITS
            }
        )

        return [
            datasets.SplitGenerator(
                name=split_name,
                gen_kwargs={
                    "music_shards": {
                        subset: [
                            dl_manager.iter_archive(shard) for shard in subset_shards
                        ]
                        for subset, subset_shards in music_paths[split_name].items()
                    },
                    "local_extracted_shards_paths": local_extracted_music_paths[
                        split_name
                    ],
                    "metadata_paths": meta_paths[split_name],
                },
            )
            for split_name in _SPLITS
        ]

    def _generate_examples(
        self,
        music_shards: dict[str, Sequence[ArchiveIterable]],
        local_extracted_shards_paths: dict[str, Sequence[dict]],
        metadata_paths: dict[str, Path],
    ) -> Iterator[tuple[str, dict]]:
        """
        Yield ``(md5, example)`` pairs for every MIDI file of each subset.

        All three arguments are mappings keyed by subset name and must cover
        the same subsets.

        Raises:
        ------
            ValueError: if the mappings do not have the same number of
                subsets, or if a subset's number of shards differs from its
                number of extracted paths.

        """
        if not (
            len(metadata_paths)
            == len(music_shards)
            == len(local_extracted_shards_paths)
        ):
            msg = "The number of subsets provided are not equals"
            raise ValueError(msg)

        for subset in self.config.subsets:
            if len(music_shards[subset]) != len(local_extracted_shards_paths[subset]):
                msg = "the number of shards must be equal to the number of paths"
                raise ValueError(msg)

            # Fix: the subset is named "drums-only" (see _SUBSETS); comparing
            # against "drums" made is_drums always False.
            is_drums = subset == "drums-only"
            with Path(metadata_paths[subset]).open() as file:
                metadata = json.load(file)

            for music_shard, local_extracted_shard_path in zip(
                music_shards[subset], local_extracted_shards_paths[subset]
            ):
                for music_file_name, music_file in music_shard:
                    # File names are "<md5>.mid"; the md5 is the example key.
                    md5 = music_file_name.split(".")[0]
                    path = (
                        str(Path(str(local_extracted_shard_path)) / music_file_name)
                        if local_extracted_shard_path
                        else music_file_name
                    )

                    metadata_ = metadata.get(md5, {})
                    yield (
                        md5,
                        {
                            "md5": md5,
                            "music": {"path": path, "bytes": music_file.read()},
                            "is_drums": is_drums,
                            "sid_matches": [
                                {"sid": sid, "score": score}
                                for sid, score in metadata_.get("sid_matches", [])
                            ],
                            "mbid_matches": [
                                {"sid": sid, "mbids": mbids}
                                for sid, mbids in metadata_.get("mbid_matches", [])
                            ],
                            "artist_scraped": metadata_.get("artist_scraped", None),
                            "title_scraped": metadata_.get("title_scraped", None),
                            "genres_scraped": metadata_.get("genres_scraped", None),
                            "genres_discogs": [
                                {"genre": genre, "count": count}
                                for genre, count in metadata_.get(
                                    "genres_discogs", {}
                                ).items()
                            ],
                            "genres_tagtraum": [
                                {"genre": genre, "count": count}
                                for genre, count in metadata_.get(
                                    "genres_tagtraum", {}
                                ).items()
                            ],
                            "genres_lastfm": [
                                {"genre": genre, "count": count}
                                for genre, count in metadata_.get(
                                    "genres_lastfm", {}
                                ).items()
                            ],
                            "median_metric_depth": metadata_.get(
                                "median_metric_depth", None
                            ),
                            # "loops": metadata_.get("loops", None),
                        },
                    )
--------------------------------------------------------------------------------
/scripts/dataset_short_files.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3 python
2 |
3 | """
4 | Script analyzing the programs and number of tracks of short files from the dataset.
5 |
6 | Results:
7 | Number of files with less than 8 bars: 589546
8 | Program -1 (Drums): 368071 (0.549%)
9 | Program 0 (Acoustic Grand Piano): 291196 (0.434%)
10 | Program 1 (Bright Acoustic Piano): 110 (0.000%)
11 | Program 2 (Electric Grand Piano): 75 (0.000%)
12 | Program 3 (Honky-tonk Piano): 43 (0.000%)
13 | Program 4 (Electric Piano 1): 93 (0.000%)
14 | Program 5 (Electric Piano 2): 32 (0.000%)
15 | Program 6 (Harpsichord): 71 (0.000%)
16 | Program 7 (Clavi): 11 (0.000%)
17 | Program 8 (Celesta): 14 (0.000%)
18 | Program 9 (Glockenspiel): 26 (0.000%)
19 | Program 10 (Music Box): 48 (0.000%)
20 | Program 11 (Vibraphone): 123 (0.000%)
21 | Program 12 (Marimba): 54 (0.000%)
22 | Program 13 (Xylophone): 15 (0.000%)
23 | Program 14 (Tubular Bells): 29 (0.000%)
24 | Program 15 (Dulcimer): 34 (0.000%)
25 | Program 16 (Drawbar Organ): 23 (0.000%)
26 | Program 17 (Percussive Organ): 8 (0.000%)
27 | Program 18 (Rock Organ): 51 (0.000%)
28 | Program 19 (Church Organ): 150 (0.000%)
29 | Program 20 (Reed Organ): 4 (0.000%)
30 | Program 21 (Accordion): 13 (0.000%)
31 | Program 22 (Harmonica): 17 (0.000%)
32 | Program 23 (Tango Accordion): 16 (0.000%)
33 | Program 24 (Acoustic Guitar (nylon)): 189 (0.000%)
34 | Program 25 (Acoustic Guitar (steel)): 100 (0.000%)
35 | Program 26 (Electric Guitar (jazz)): 84 (0.000%)
36 | Program 27 (Electric Guitar (clean)): 90 (0.000%)
37 | Program 28 (Electric Guitar (muted)): 47 (0.000%)
38 | Program 29 (Overdriven Guitar): 73 (0.000%)
39 | Program 30 (Distortion Guitar): 85 (0.000%)
40 | Program 31 (Guitar Harmonics): 6 (0.000%)
41 | Program 32 (Acoustic Bass): 172 (0.000%)
42 | Program 33 (Electric Bass (finger)): 154 (0.000%)
43 | Program 34 (Electric Bass (pick)): 34 (0.000%)
44 | Program 35 (Fretless Bass): 63 (0.000%)
45 | Program 36 (Slap Bass 1): 23 (0.000%)
46 | Program 37 (Slap Bass 2): 24 (0.000%)
47 | Program 38 (Synth Bass 1): 1727 (0.003%)
48 | Program 39 (Synth Bass 2): 28 (0.000%)
49 | Program 40 (Violin): 40 (0.000%)
50 | Program 41 (Viola): 22 (0.000%)
51 | Program 42 (Cello): 24 (0.000%)
52 | Program 43 (Contrabass): 26 (0.000%)
53 | Program 44 (Tremolo Strings): 45 (0.000%)
54 | Program 45 (Pizzicato Strings): 32 (0.000%)
55 | Program 46 (Orchestral Harp): 71 (0.000%)
56 | Program 47 (Timpani): 101 (0.000%)
57 | Program 48 (String Ensembles 1): 388 (0.001%)
58 | Program 49 (String Ensembles 2): 71 (0.000%)
59 | Program 50 (SynthStrings 1): 48 (0.000%)
60 | Program 51 (SynthStrings 2): 25 (0.000%)
61 | Program 52 (Choir Aahs): 79 (0.000%)
62 | Program 53 (Voice Oohs): 16 (0.000%)
63 | Program 54 (Synth Voice): 24 (0.000%)
64 | Program 55 (Orchestra Hit): 49 (0.000%)
65 | Program 56 (Trumpet): 202 (0.000%)
66 | Program 57 (Trombone): 107 (0.000%)
67 | Program 58 (Tuba): 49 (0.000%)
68 | Program 59 (Muted Trumpet): 27 (0.000%)
69 | Program 60 (French Horn): 116 (0.000%)
70 | Program 61 (Brass Section): 115 (0.000%)
71 | Program 62 (Synth Brass 1): 42 (0.000%)
72 | Program 63 (Synth Brass 2): 19 (0.000%)
73 | Program 64 (Soprano Sax): 6 (0.000%)
74 | Program 65 (Alto Sax): 28 (0.000%)
75 | Program 66 (Tenor Sax): 30 (0.000%)
76 | Program 67 (Baritone Sax): 12 (0.000%)
77 | Program 68 (Oboe): 52 (0.000%)
78 | Program 69 (English Horn): 9 (0.000%)
79 | Program 70 (Bassoon): 32 (0.000%)
80 | Program 71 (Clarinet): 95 (0.000%)
81 | Program 72 (Piccolo): 29 (0.000%)
82 | Program 73 (Flute): 92 (0.000%)
83 | Program 74 (Recorder): 7 (0.000%)
84 | Program 75 (Pan Flute): 28 (0.000%)
85 | Program 76 (Blown Bottle): 9 (0.000%)
86 | Program 77 (Shakuhachi): 12 (0.000%)
87 | Program 78 (Whistle): 12 (0.000%)
88 | Program 79 (Ocarina): 26 (0.000%)
89 | Program 80 (Lead 1 (square)): 2165 (0.003%)
90 | Program 81 (Lead 2 (sawtooth)): 2037 (0.003%)
91 | Program 82 (Lead 3 (calliope)): 13 (0.000%)
92 | Program 83 (Lead 4 (chiff)): 4 (0.000%)
93 | Program 84 (Lead 5 (charang)): 9 (0.000%)
94 | Program 85 (Lead 6 (voice)): 4 (0.000%)
95 | Program 86 (Lead 7 (fifths)): 3 (0.000%)
96 | Program 87 (Lead 8 (bass + lead)): 31 (0.000%)
97 | Program 88 (Pad 1 (new age)): 40 (0.000%)
98 | Program 89 (Pad 2 (warm)): 37 (0.000%)
99 | Program 90 (Pad 3 (polysynth)): 19 (0.000%)
100 | Program 91 (Pad 4 (choir)): 9 (0.000%)
101 | Program 92 (Pad 5 (bowed)): 47 (0.000%)
102 | Program 93 (Pad 6 (metallic)): 11 (0.000%)
103 | Program 94 (Pad 7 (halo)): 21 (0.000%)
104 | Program 95 (Pad 8 (sweep)): 16 (0.000%)
105 | Program 96 (FX 1 (rain)): 3 (0.000%)
106 | Program 97 (FX 2 (soundtrack)): 9 (0.000%)
107 | Program 98 (FX 3 (crystal)): 10 (0.000%)
108 | Program 99 (FX 4 (atmosphere)): 12 (0.000%)
109 | Program 100 (FX 5 (brightness)): 9 (0.000%)
110 | Program 101 (FX 6 (goblins)): 24 (0.000%)
111 | Program 102 (FX 7 (echoes)): 5 (0.000%)
112 | Program 103 (FX 8 (sci-fi)): 4 (0.000%)
113 | Program 104 (Sitar): 16 (0.000%)
114 | Program 105 (Banjo): 13 (0.000%)
115 | Program 106 (Shamisen): 7 (0.000%)
116 | Program 107 (Koto): 6 (0.000%)
117 | Program 108 (Kalimba): 12 (0.000%)
118 | Program 109 (Bag pipe): 1 (0.000%)
119 | Program 110 (Fiddle): 19 (0.000%)
120 | Program 111 (Shanai): 3 (0.000%)
121 | Program 112 (Tinkle Bell): 15 (0.000%)
122 | Program 113 (Agogo): 4 (0.000%)
123 | Program 114 (Steel Drums): 7 (0.000%)
124 | Program 115 (Woodblock): 8 (0.000%)
125 | Program 116 (Taiko Drum): 20 (0.000%)
126 | Program 117 (Melodic Tom): 16 (0.000%)
127 | Program 118 (Synth Drum): 9 (0.000%)
128 | Program 119 (Reverse Cymbal): 17 (0.000%)
129 | Program 120 (Guitar Fret Noise, Guitar Cutting Noise): 6 (0.000%)
130 | Program 121 (Breath Noise, Flute Key Click): 2 (0.000%)
131 | Program 122 (Seashore, Rain, Thunder, Wind, Stream, Bubbles): 6 (0.000%)
132 | Program 123 (Bird Tweet, Dog, Horse Gallop): 3 (0.000%)
133 | Program 124 (Telephone Ring, Door Creaking, Door, Scratch, Wind Chime): 7 (0.000%)
134 | Program 125 (Helicopter, Car Sounds): 4 (0.000%)
135 | Program 126 (Applause, Laughing, Screaming, Punch, Heart Beat, Footstep): 16 (0.000%)
136 | Program 127 (Gunshot, Machine Gun, Lasergun, Explosion): 0 (0.000%)
137 |
138 | When reversing the condition on the file duration (to keep long files):
139 | Number of files with at least 8 bars: 279213
140 | Program -1 (Drums): 247301 (0.126%)
141 | Program 0 (Acoustic Grand Piano): 243676 (0.124%)
142 | Program 1 (Bright Acoustic Piano): 21722 (0.011%)
143 | Program 2 (Electric Grand Piano): 7453 (0.004%)
144 | Program 3 (Honky-tonk Piano): 4806 (0.002%)
145 | Program 4 (Electric Piano 1): 19245 (0.010%)
146 | Program 5 (Electric Piano 2): 13067 (0.007%)
147 | Program 6 (Harpsichord): 15961 (0.008%)
148 | Program 7 (Clavi): 5560 (0.003%)
149 | Program 8 (Celesta): 3726 (0.002%)
150 | Program 9 (Glockenspiel): 9911 (0.005%)
151 | Program 10 (Music Box): 3361 (0.002%)
152 | Program 11 (Vibraphone): 16293 (0.008%)
153 | Program 12 (Marimba): 7852 (0.004%)
154 | Program 13 (Xylophone): 3917 (0.002%)
155 | Program 14 (Tubular Bells): 6310 (0.003%)
156 | Program 15 (Dulcimer): 1123 (0.001%)
157 | Program 16 (Drawbar Organ): 7152 (0.004%)
158 | Program 17 (Percussive Organ): 7511 (0.004%)
159 | Program 18 (Rock Organ): 14286 (0.007%)
160 | Program 19 (Church Organ): 4997 (0.003%)
161 | Program 20 (Reed Organ): 1365 (0.001%)
162 | Program 21 (Accordion): 8649 (0.004%)
163 | Program 22 (Harmonica): 7059 (0.004%)
164 | Program 23 (Tango Accordion): 2963 (0.002%)
165 | Program 24 (Acoustic Guitar (nylon)): 30844 (0.016%)
166 | Program 25 (Acoustic Guitar (steel)): 62442 (0.032%)
167 | Program 26 (Electric Guitar (jazz)): 30134 (0.015%)
168 | Program 27 (Electric Guitar (clean)): 39305 (0.020%)
169 | Program 28 (Electric Guitar (muted)): 23762 (0.012%)
170 | Program 29 (Overdriven Guitar): 28155 (0.014%)
171 | Program 30 (Distortion Guitar): 28541 (0.015%)
172 | Program 31 (Guitar Harmonics): 2644 (0.001%)
173 | Program 32 (Acoustic Bass): 26662 (0.014%)
174 | Program 33 (Electric Bass (finger)): 54744 (0.028%)
175 | Program 34 (Electric Bass (pick)): 7300 (0.004%)
176 | Program 35 (Fretless Bass): 26724 (0.014%)
177 | Program 36 (Slap Bass 1): 3124 (0.002%)
178 | Program 37 (Slap Bass 2): 2557 (0.001%)
179 | Program 38 (Synth Bass 1): 12350 (0.006%)
180 | Program 39 (Synth Bass 2): 7469 (0.004%)
181 | Program 40 (Violin): 14198 (0.007%)
182 | Program 41 (Viola): 5949 (0.003%)
183 | Program 42 (Cello): 8222 (0.004%)
184 | Program 43 (Contrabass): 6633 (0.003%)
185 | Program 44 (Tremolo Strings): 5769 (0.003%)
186 | Program 45 (Pizzicato Strings): 20000 (0.010%)
187 | Program 46 (Orchestral Harp): 12108 (0.006%)
188 | Program 47 (Timpani): 18025 (0.009%)
189 | Program 48 (String Ensembles 1): 87048 (0.044%)
190 | Program 49 (String Ensembles 2): 33319 (0.017%)
191 | Program 50 (SynthStrings 1): 26076 (0.013%)
192 | Program 51 (SynthStrings 2): 6654 (0.003%)
193 | Program 52 (Choir Aahs): 45370 (0.023%)
194 | Program 53 (Voice Oohs): 18223 (0.009%)
195 | Program 54 (Synth Voice): 10778 (0.006%)
196 | Program 55 (Orchestra Hit): 3813 (0.002%)
197 | Program 56 (Trumpet): 47874 (0.024%)
198 | Program 57 (Trombone): 40421 (0.021%)
199 | Program 58 (Tuba): 17945 (0.009%)
200 | Program 59 (Muted Trumpet): 5584 (0.003%)
201 | Program 60 (French Horn): 34418 (0.018%)
202 | Program 61 (Brass Section): 20338 (0.010%)
203 | Program 62 (Synth Brass 1): 8776 (0.004%)
204 | Program 63 (Synth Brass 2): 3898 (0.002%)
205 | Program 64 (Soprano Sax): 4883 (0.002%)
206 | Program 65 (Alto Sax): 28511 (0.015%)
207 | Program 66 (Tenor Sax): 20751 (0.011%)
208 | Program 67 (Baritone Sax): 9189 (0.005%)
209 | Program 68 (Oboe): 23015 (0.012%)
210 | Program 69 (English Horn): 4493 (0.002%)
211 | Program 70 (Bassoon): 18090 (0.009%)
212 | Program 71 (Clarinet): 40191 (0.021%)
213 | Program 72 (Piccolo): 10367 (0.005%)
214 | Program 73 (Flute): 43843 (0.022%)
215 | Program 74 (Recorder): 3815 (0.002%)
216 | Program 75 (Pan Flute): 8700 (0.004%)
217 | Program 76 (Blown Bottle): 1210 (0.001%)
218 | Program 77 (Shakuhachi): 1754 (0.001%)
219 | Program 78 (Whistle): 3024 (0.002%)
220 | Program 79 (Ocarina): 2934 (0.001%)
221 | Program 80 (Lead 1 (square)): 11486 (0.006%)
222 | Program 81 (Lead 2 (sawtooth)): 16702 (0.009%)
223 | Program 82 (Lead 3 (calliope)): 7935 (0.004%)
224 | Program 83 (Lead 4 (chiff)): 1143 (0.001%)
225 | Program 84 (Lead 5 (charang)): 2186 (0.001%)
226 | Program 85 (Lead 6 (voice)): 2405 (0.001%)
227 | Program 86 (Lead 7 (fifths)): 617 (0.000%)
228 | Program 87 (Lead 8 (bass + lead)): 7257 (0.004%)
229 | Program 88 (Pad 1 (new age)): 8296 (0.004%)
230 | Program 89 (Pad 2 (warm)): 10565 (0.005%)
231 | Program 90 (Pad 3 (polysynth)): 5853 (0.003%)
232 | Program 91 (Pad 4 (choir)): 6063 (0.003%)
233 | Program 92 (Pad 5 (bowed)): 2008 (0.001%)
234 | Program 93 (Pad 6 (metallic)): 2011 (0.001%)
235 | Program 94 (Pad 7 (halo)): 2952 (0.002%)
236 | Program 95 (Pad 8 (sweep)): 4357 (0.002%)
237 | Program 96 (FX 1 (rain)): 1616 (0.001%)
238 | Program 97 (FX 2 (soundtrack)): 761 (0.000%)
239 | Program 98 (FX 3 (crystal)): 2101 (0.001%)
240 | Program 99 (FX 4 (atmosphere)): 4091 (0.002%)
241 | Program 100 (FX 5 (brightness)): 6566 (0.003%)
242 | Program 101 (FX 6 (goblins)): 1170 (0.001%)
243 | Program 102 (FX 7 (echoes)): 2688 (0.001%)
244 | Program 103 (FX 8 (sci-fi)): 1336 (0.001%)
245 | Program 104 (Sitar): 1857 (0.001%)
246 | Program 105 (Banjo): 3734 (0.002%)
247 | Program 106 (Shamisen): 1057 (0.001%)
248 | Program 107 (Koto): 1416 (0.001%)
249 | Program 108 (Kalimba): 1844 (0.001%)
250 | Program 109 (Bag pipe): 819 (0.000%)
251 | Program 110 (Fiddle): 2041 (0.001%)
252 | Program 111 (Shanai): 362 (0.000%)
253 | Program 112 (Tinkle Bell): 1044 (0.001%)
254 | Program 113 (Agogo): 552 (0.000%)
255 | Program 114 (Steel Drums): 1569 (0.001%)
256 | Program 115 (Woodblock): 1216 (0.001%)
257 | Program 116 (Taiko Drum): 2245 (0.001%)
258 | Program 117 (Melodic Tom): 1599 (0.001%)
259 | Program 118 (Synth Drum): 3203 (0.002%)
260 | Program 119 (Reverse Cymbal): 10701 (0.005%)
261 | Program 120 (Guitar Fret Noise, Guitar Cutting Noise): 2697 (0.001%)
262 | Program 121 (Breath Noise, Flute Key Click): 548 (0.000%)
263 | Program 122 (Seashore, Rain, Thunder, Wind, Stream, Bubbles): 3429 (0.002%)
264 | Program 123 (Bird Tweet, Dog, Horse Gallop): 659 (0.000%)
265 | Program 124 (Telephone Ring, Door Creaking, Door, Scratch, Wind Chime): 1642 (0.001%)
266 | Program 125 (Helicopter, Car Sounds): 1261 (0.001%)
267 | Program 126 (Applause, Laughing, Screaming, Punch, Heart Beat, Footstep): 1438 (0.001%)
268 | Program 127 (Gunshot, Machine Gun, Lasergun, Explosion): 0 (0.000%)
269 | """
270 |
if __name__ == "__main__":
    from pathlib import Path

    import numpy as np
    from matplotlib import pyplot as plt
    from miditok.constants import MIDI_INSTRUMENTS, SCORE_LOADING_EXCEPTION
    from miditok.utils import get_bars_ticks
    from symusic import Score
    from tqdm import tqdm

    from utils.baseline import mmm
    from utils.constants import (
        MIN_NUM_BARS_FILE_VALID,
    )

    NUM_HIST_BINS = 50

    # Keep only files that load correctly and contain enough bars; every
    # statistic below covers files with >= MIN_NUM_BARS_FILE_VALID bars.
    dataset_files_paths = mmm.dataset_files_paths
    num_tracks, programs = [], []
    for file_path in tqdm(dataset_files_paths, desc="Reading MIDI files"):
        try:
            score = Score(file_path)
        except SCORE_LOADING_EXCEPTION:
            continue
        score = mmm.tokenizer.preprocess_score(score)
        if len(get_bars_ticks(score)) < MIN_NUM_BARS_FILE_VALID:
            continue

        num_tracks.append(len(score.tracks))
        programs += [-1 if track.is_drum else track.program for track in score.tracks]

    # Fix: the previous message said "less than" although the loop above
    # discards files shorter than MIN_NUM_BARS_FILE_VALID bars.
    print(  # noqa: T201
        f"Number of files with at least {MIN_NUM_BARS_FILE_VALID} bars: "
        f"{len(num_tracks)}"
    )

    # Per-program track counts; -1 stands for drum tracks.
    programs = np.array(programs)
    for program in range(-1, 128):
        num_occurrences = len(np.where(programs == program)[0])
        # Guard against an empty dataset; ".3%" renders the fraction as a
        # real percentage (the previous ".3f}%" printed a raw fraction
        # followed by a percent sign).
        ratio = num_occurrences / len(programs) if len(programs) else 0.0
        print(  # noqa: T201
            f"Program {program} ("
            f"{'Drums' if program == -1 else MIDI_INSTRUMENTS[program]['name']}): "
            f"{num_occurrences} ({ratio:.3%})"
        )

    # Plot the distribution of the number of tracks per file.
    fig, ax = plt.subplots()
    ax.hist(num_tracks, bins=NUM_HIST_BINS)
    ax.grid(axis="y", linestyle="--", linewidth=0.6)
    ax.set_ylabel("Count files")
    ax.set_xlabel("Number of tracks")
    fig.savefig(Path("GigaMIDI_length_bars.pdf"), bbox_inches="tight", dpi=300)
    plt.close(fig)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GigaMIDI Dataset
2 | ## The Extended GigaMIDI Dataset with Music Loop Detection and Expressive Multitrack Loop Generation
3 | \
4 | 
5 |
6 | ## The extended GigaMIDI Dataset Summary
7 | We present the extended GigaMIDI dataset, a large-scale symbolic music collection comprising over 2.1 million unique MIDI files with detailed annotations for music loop detection. Expanding on its predecessor, this release introduces a novel expressive loop detection method that captures performance nuances such as microtiming and dynamic variation, essential for advanced generative music modelling. Our method extends previous approaches, which were limited to strictly quantized, non-expressive tracks, by employing the Note Onset Median Metric Level (NOMML) heuristic to distinguish expressive from non-expressive material. This enables robust loop detection across a broader spectrum of MIDI data. Our loop detection method reveals more than 9.2 million non-expressive loops spanning all General MIDI instruments, alongside 2.3 million expressive loops identified through our new method. As the largest resource of its kind, the extended GigaMIDI dataset provides a strong foundation for developing models that synthesize structurally coherent and expressively rich musical loops. As a use case, we leverage this dataset to train an expressive multitrack symbolic music loop generation model using the MIDI-GPT system, resulting in the creation of a synthetic loop dataset.
8 |
9 | ## GigaMIDI Dataset Version Update
10 |
11 | We present the extended GigaMIDI dataset (select v2.0.0), a large-scale symbolic music collection comprising over 2.1 million unique MIDI files with detailed annotations for music loop detection. Expanding on its predecessor, this release introduces a novel expressive loop detection method that captures performance nuances such as microtiming and dynamic variation, essential for advanced generative music modelling. Our method extends previous approaches, which were limited to strictly quantized, non-expressive tracks, by employing the Note Onset Median Metric Level (NOMML) heuristic to distinguish expressive from non-expressive material. This enables robust loop detection across a broader spectrum of MIDI data. Our loop detection method reveals more than 9.2 million non-expressive loops spanning all General MIDI instruments, alongside 2.3 million expressive loops identified through our new method. As the largest resource of its kind, the extended GigaMIDI dataset provides a strong foundation for developing models that synthesize structurally coherent and expressively rich musical loops. As a use case, we leverage this dataset to train an expressive multitrack symbolic music loop generation model using the MIDI-GPT system, resulting in the creation of a synthetic loop dataset. The GigaMIDI dataset is accessible for research purposes on the Hugging Face hub [https://huggingface.co/datasets/Metacreation/GigaMIDI] in a user-friendly way for convenience and reproducibility.
12 |
13 |
14 | For the Hugging Face Hub dataset release, we disclaim any responsibility for the misuse of this dataset.
15 | The subset version `v2.0.0` refers specifically to the extended GigaMIDI dataset, while version `v1.0.0` denotes the original GigaMIDI dataset (see [Lee et al., 2025](https://doi.org/10.5334/tismir.203)).
16 | New users must request access via our Hugging Face Hub page before retrieving the dataset from the following link:
17 | [https://huggingface.co/datasets/Metacreation/GigaMIDI/viewer/v2.0.0](https://huggingface.co/datasets/Metacreation/GigaMIDI/viewer/v2.0.0)
18 |
19 | ### Dataset Curators
20 |
21 | Main curator: Keon Ju Maverick Lee
22 |
23 | Assistance: Jeff Ens, Sara Adkins, Nathan Fradet, Pedro Sarmento, Mathieu Barthet, Phillip Long, Paul Triana
24 |
25 | Research Director: Philippe Pasquier
26 |
27 | ### Citation/Reference
28 |
29 | If you use the GigaMIDI dataset or any part of this project, please cite the following paper:
30 | https://transactions.ismir.net/articles/10.5334/tismir.203
31 | ```bibtex
32 | @article{lee2025gigamidi,
33 | title={The GigaMIDI Dataset with Features for Expressive Music Performance Detection},
34 | author={Lee, Keon Ju Maverick and Ens, Jeff and Adkins, Sara and Sarmento, Pedro and Barthet, Mathieu and Pasquier, Philippe},
35 | journal={Transactions of the International Society for Music Information Retrieval (TISMIR)},
36 | volume={8},
37 | number={1},
38 | pages={1--19},
39 | year={2025}
40 | }
41 | ```
42 |
43 | ## Repository Layout
44 |
45 | [**/GigaMIDI**](./GigaMIDI): Code for creating the full GigaMIDI dataset from
46 | source files, and README with example code for loading and processing the
47 | data set using the `datasets` library
48 |
49 | [**/loops_nomml**](./loops_nomml): Source files for non-expressive loop detection algorithm
50 | and expressive performance detection algorithm
51 |
52 | [**Expressive Loop Detector**](Expressive%20music%20loop%20detector-NOMML12.ipynb): code for the expressive loop detection method (in the .ipynb file) and instructions are available in the later section of this readme file.
53 |
54 | [**Expressive Loop Generation**](https://github.com/Metacreation-Lab/GigaMIDI-Dataset/tree/main/MIDI-GPT-Loop): code for the expressive loop generation and instructions are available in the hyperlink which connects to the readme file.
55 |
56 | [**/scripts**](./scripts): Scripts and code notebooks for analyzing the
57 | GigaMIDI dataset and the loop dataset
58 |
59 | [**/tests**](./tests): E2E tests for expressive performance detection and
60 | loop extractions
61 |
62 | [**Analysis of Evaluation Set and Optimal Threshold Selection including Machine Learning Models**](https://github.com/GigaMidiDataset/The-GigaMIDI-dataset-with-loops-and-expressive-music-performance-detection/tree/82d424ae7ff48a2fb3ce5bb07de13d5cca4fc8c5/Analysis%20of%20Evaluation%20Set%20and%20Optimal%20Threshold%20Selection%20including%20Machine%20Learning%20Models): This archive includes CSV files corresponding to our curated evaluation set, which comprises both a training set and a testing set. These files contain percentile calculations used to determine the optimal thresholds for each heuristic in expressive music performance detection. The use of percentiles from the data distribution is intended to establish clear boundaries between non-expressive and expressive tracks, based on the values of our heuristic features. Additionally, we provide pre-trained models in .pkl format, developed using features derived from our novel heuristics. The hyperparameter setup is detailed in the following section titled *Pipeline Configuration*.
63 |
64 | [**Data Source Links for the GigaMIDI Dataset**](https://github.com/GigaMidiDataset/The-GigaMIDI-dataset-with-loops-and-expressive-music-performance-detection/blob/8acb0e5ca8ac5eb21c072ed381fa737689748c81/Data%20Source%20Links%20for%20the%20GigaMIDI%20Dataset%20-%20Sheet1.pdf): Data source links for each collected subset of the GigaMIDI dataset are all organized and uploaded in PDF.
65 |
66 | ## Running MIDI-based Loop Detection
67 |
68 | Included with GigaMIDI dataset is a collection of all loops identified in the
69 | dataset between 4 and 32 bars in length, with a minimum density of 0.5 notes
70 | per beat. For our purposes, we consider a segment of a track to be loopable if
71 | it is bookended by a repeated phrase of a minimum length (at least 2 beats
72 | and 4 note events)
73 |
74 | 
75 |
76 | ### Starter Code
77 |
78 | To run loop detection on a single MIDI file, use the `detect_loops` function
79 | ```python
80 | from loops_nomml import detect_loops
81 | from symusic import Score
82 |
83 | score = Score("tests/midi_files/Mr. Blue Sky.mid")
84 | loops = detect_loops(score)
85 | print(loops)
86 | ```
87 |
88 | The output will contain all the metadata needed to locate the loop within the
89 | file. Start and end times are represented as MIDI ticks, and density is
90 | given in units of notes per beat:
91 | ```
92 | {'track_idx': [0, 0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 5], 'instrument_type': ['Piano', 'Piano', 'Piano', 'Piano', 'Piano', 'Piano', 'Piano', 'Piano', 'Piano', 'Drums', 'Drums', 'Drums', 'Drums', 'Drums', 'Piano', 'Piano'], 'start': [238080, 67200, 165120, 172800, 1920, 97920, 15360, 216960, 276480, 7680, 195840, 122880, 284160, 117120, 49920, 65280], 'end': [241920, 82560, 180480, 188160, 3840, 99840, 17280, 220800, 291840, 9600, 211200, 138240, 291840, 130560, 51840, 80640], 'duration_beats': [8.0, 32.0, 32.0, 32.0, 4.0, 4.0, 4.0, 8.0, 32.0, 4.0, 32.0, 32.0, 16.0, 28.0, 4.0, 32.0], 'note_density': [0.75, 1.84375, 0.8125, 0.8125, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.8125, 2.46875, 2.4375, 2.5, 0.5, 0.6875]}
93 | ```
94 |
95 | ### Batch Processing Loops
96 |
97 | We also provide a script, `main.py` that batch extracts all loops in a
98 | dataset. This requires that you have downloaded GigaMIDI, see the [dataset README](./GigaMIDI/README.md) for instructions on doing this. Once you have the dataset downloaded, update the `DATA_PATH` and `METADATA_NAME` globals to reflect the location of GigaMIDI on your machine and run the script:
99 |
100 | ```bash
101 | python main.py
102 | ```
103 |
104 |
105 | ## Instruction for using the code for note onset median metric level (NOMML) heuristic
106 | ### Install and import Python libraries for the NOMML code:
107 | Imported libraries:
108 | ```
109 | pip install numpy tqdm symusic
110 | ```
111 |
112 | Note: symusic library is used for MIDI parsing.
113 |
114 | ### Using with the command line
115 | usage:
116 | ```bash
117 | python nomml.py [-h] --folder FOLDER [--force] [--nthreads NTHREADS]
118 | ```
119 |
120 | Note: If you run the code successfully, it will generate a JSON file with the appropriate metadata.
121 |
122 | The metadata median metric depth corresponds to the Note Onset Median Metric Level (NOMML).
123 | Please refer to the latest information below, as this repository was intended to be temporary during the anonymous peer-review process
124 | of the academic paper. Based on our experiments, tracks at levels 0-11 can be classified as non-expressive, while level 12 indicates
125 | expressive MIDI tracks. Note that this classification applies at the track level, not the file level.
126 |
127 | ## Pipeline Configuration
128 |
129 | The following pipeline configuration was determined through hyperparameter tuning using leave-one-out cross-validation and GridSearchCV for the logistic regression model:
130 |
131 | ```python
132 | # Hyperparameters
133 | {'C': 0.046415888336127774}
134 |
135 | # Logistic Regression Instance
136 | LogisticRegression(random_state=0, C=0.046415888336127774, max_iter=10000, tol=0.1)
137 |
138 | # Pipeline
139 | Pipeline(steps=[('scaler', StandardScaler(with_std=False)),
140 | ('logistic',
141 | LogisticRegression(C=0.046415888336127774, max_iter=10000,
142 | tol=0.1))])
143 | ```
144 |
145 | # Expressive Loop Detection Pipeline
146 |
147 | This repository contains a highly parallelized Python pipeline for detecting both **non-expressive** (hard-match) and **expressive** (soft-count) loops in large MIDI datasets. Built on the GigaMIDI `loops_nomml` library and Symusic for MIDI parsing, it uses `joblib` to distribute work across all available CPU cores. The workflow periodically writes out **checkpoint** CSVs (every 500 000 files) and produces a final, aggregated result. The code is available in the file named Expressive music loop detector-NOMML12.ipynb.
148 |
149 | ---
150 |
151 | ## 🚀 Key Features
152 |
153 | - **Hard-match detection**
154 | Finds exact, bar-aligned pitch-set repeats.
155 | - **Soft-count detection**
156 | Captures expressive loops by combining pitch overlap, velocity similarity, and micro-timing tolerance.
157 | - **Loopability scoring**
158 | A single metric blending the length of the longest repeat with overall repetition density.
159 | - **Scalable batch processing**
160 | Checkpoints output every 500 000 files (`loops_checkpoint_1.csv`, `loops_checkpoint_2.csv`, …).
161 | - **Full parallelization**
162 | Leverages `joblib.Parallel` to utilize all CPU cores for maximum throughput.
163 |
164 | ---
165 |
166 | ## 🔧 Prerequisites
167 |
168 | - **Python 3.8+**
169 | - **Symusic** for MIDI I/O
170 | - **GigaMIDI `loops_nomml`** module (place alongside this repo)
171 | - Install required packages:
172 | ```bash
173 | pip install numpy pandas joblib tqdm
174 | ```
175 |
176 | ---
177 |
178 | ## 📦 Installation
179 |
180 | 1. Clone this repo alongside your local `loops_nomml` checkout:
181 |
182 | ```bash
183 | git clone https://github.com/YourUser/expressive-loop-detect.git
184 | cd expressive-loop-detect
185 | ```
186 | 2. (Optional) Create & activate a virtual environment:
187 |
188 | ```bash
189 | python3 -m venv venv
190 | source venv/bin/activate
191 | pip install -r requirements.txt
192 | ```
193 |
194 | ---
195 |
196 | ## ⚙️ Usage
197 |
198 | 1. **Prepare your CSV**
199 | Place your input CSV (default name:
200 | `Final_GigaMIDI_Loop_V2_path-instrument-NOMML-type.csv`) in the working directory. It must have:
201 |
202 | | Column | Description |
203 | | ----------- | ----------------------------------------------------------------- |
204 | | `file_path` | Path to each `.mid` or `.midi` file |
205 | | `NOMML` | Python list of per-track expressiveness flags (e.g. `[12,2,4,…]`) |
206 |
207 | 2. **Configure parameters**
208 | At the top of `detect_loops.py`, you can adjust:
209 |
210 | * `melodic_tau` (default `0.3`): similarity threshold for melodic tracks
211 | * `drum_tau` (default `0.1`): threshold for drum tracks
212 | * `chunk_size` (default `500_000`): number of files per checkpoint
213 | * Bars are quantized every 4 beats; min/max loop lengths and density filters live in the `get_valid_loops` call.
214 |
215 | 3. **Run the detector**
216 |
217 | ```bash
218 | python detect_loops.py
219 | ```
220 |
221 | Or open the Jupyter notebook:
222 |
223 | ```bash
224 | jupyter notebook detect_loops.ipynb
225 | ```
226 |
227 | The script will:
228 |
229 | * Read the CSV of file paths
230 | * Process MIDI files in parallel, chunk by chunk
231 | * Save checkpoint CSVs named `loops_checkpoint_1.csv`, `loops_checkpoint_2.csv`, …
232 | * After all chunks, combine results into one DataFrame and save `loops_full_output.csv`
233 |
234 | ---
235 |
236 | ## 📊 Output
237 |
238 | * **Checkpoint CSVs** (`loops_checkpoint_<N>.csv`): one row per MIDI file in that chunk, with columns:
239 |
240 | * `file_path`
241 | * `track_idx` (list of track indices with detected loops)
242 | * `MIDI program number` (list)
243 | * `instrument_group` (list of GM groups or “Drums”)
244 | * `loopability` (list of floats)
245 | * `start_tick`, `end_tick` (lists of integers)
246 | * `duration_beats` (list of floats)
247 | * `note_density` (list of floats)
248 |
249 | * **Full output** (`loops_full_output.csv`): concatenation of all checkpoint rows.
250 |
251 | **Example to load with correct list parsing**:
252 |
253 | ```python
254 | import pandas as pd
255 |
256 | converters = {
257 | 'track_idx': eval,
258 | 'MIDI program number': eval,
259 | 'instrument_group': eval,
260 | 'loopability': eval,
261 | 'start_tick': eval,
262 | 'end_tick': eval,
263 | 'duration_beats': eval,
264 | 'note_density': eval
265 | }
266 |
267 | df = pd.read_csv("loops_full_output.csv", converters=converters)
268 | ```
269 | | Column | Type | Description |
270 | | ------------------------ | -------------- | ------------------------------------------- |
271 | | `file_path` | string | Original MIDI filepath |
272 | | `track_idx` | list of ints | Indices of tracks where loops were detected |
273 | | `MIDI program number` | list of ints | Corresponding MIDI program codes |
274 | | `instrument_group` | list of strs | GM group (or “Drums”) for each loop |
275 | | `loopability` | list of floats | Loopability score per detected loop |
276 | | `start_tick`, `end_tick` | list of ints | Loop boundaries (MIDI ticks) |
277 | | `duration_beats` | list of floats | Loop lengths in beats |
278 | | `note_density` | list of floats | Active-notes-per-beat density per loop |
279 |
280 | ---
281 |
282 | ## 📝 Troubleshooting
283 |
284 | * **No loops found**: try lowering `melodic_tau` or relaxing `min_rep_beats`/`min_beats`.
285 | * **`IndexError` in beat-duration**: the script patches `get_duration_beats` for safety.
286 | * **Performance issues**: set `n_jobs` in `Parallel(...)` to fewer cores or reduce `chunk_size`.
287 |
288 | ---
289 |
290 | # Expressive Loop Generation
291 |
292 | Information about loop generation is found at `MIDI-GPT-Loop/README.md`
293 |
294 | ---
295 |
296 | ## 🤝 Contributing
297 |
298 | 1. Fork this repo
299 | 2. Create a feature branch
300 | 3. Submit a pull request
301 |
302 | ---
303 |
304 |
305 |
306 |
307 | ## Acknowledgement
308 | We gratefully acknowledge the support and contributions that have directly or indirectly aided this research. This work was supported in part by funding from the Natural Sciences and Engineering Research Council of Canada (NSERC) and the Social Sciences and Humanities Research Council of Canada (SSHRC). We also extend our gratitude to the School of Interactive Arts and Technology (SIAT) at Simon Fraser University (SFU) for providing resources and an enriching research environment. Additionally, we thank the Centre for Digital Music (C4DM) at Queen Mary University of London (QMUL) for fostering collaborative opportunities and supporting our engagement with interdisciplinary research initiatives.
309 |
310 | Special thanks are extended to Dr. Cale Plut for his meticulous manual curation of musical styles and to Dr. Nathan Fradet for his invaluable assistance in developing the HuggingFace Hub website for the GigaMIDI dataset, ensuring it is accessible and user-friendly for music computing and MIR researchers. We also sincerely thank our research interns, Paul Triana and Davide Rizotti, for their thorough proofreading of the manuscript.
311 |
312 | Finally, we express our heartfelt appreciation to the individuals and communities who generously shared their MIDI files for research purposes. Their contributions have been instrumental in advancing this work and fostering collaborative knowledge in the field.
313 |
--------------------------------------------------------------------------------
/GigaMIDI/README.md:
--------------------------------------------------------------------------------
1 | ---
2 | annotations_creators: []
3 | license:
4 | - cc-by-4.0
5 | - other
6 | pretty_name: GigaMIDI
7 | size_categories: []
8 | source_datasets: []
9 | tags: []
10 | task_ids: []
11 | ---
12 |
13 | # Dataset Card for GigaMIDI
14 |
15 | ## Table of Contents
16 |
17 | - [Dataset Description](#dataset-description)
18 | - [Dataset Summary](#dataset-summary)
19 | - [How to use](#how-to-use)
20 | - [Dataset Structure](#dataset-structure)
21 | - [Data Instances](#data-instances)
22 | - [Data Fields](#data-fields)
23 | - [Data Splits](#data-splits)
24 | - [Dataset Creation](#dataset-creation)
25 | - [Curation Rationale](#curation-rationale)
26 | - [Source Data](#source-data)
27 | - [Annotations](#annotations)
28 | - [Personal and Sensitive Information](#personal-and-sensitive-information)
29 | - [Considerations for Using the Data](#considerations-for-using-the-data)
30 | - [Social Impact of Dataset](#social-impact-of-dataset)
31 | - [Discussion of Biases](#discussion-of-biases)
32 | - [Other Known Limitations](#other-known-limitations)
33 | - [Additional Information](#additional-information)
34 | - [Dataset Curators](#dataset-curators)
35 | - [Licensing Information](#licensing-information)
36 |
37 |
38 | ## Dataset Description
39 |
40 |
41 | - **Repository:** Anonymized during the peer review process
42 |
43 | - **Point of Contact:** Anonymized during the peer review process
44 |
45 | ### Dataset Summary
46 |
47 | The GigaMIDI dataset is a corpus of over 1 million MIDI files covering all music genres.
48 |
49 | We provide three subsets: `drums-only`, which contain MIDI files exclusively containing drum tracks, `no-drums` for MIDI files containing any MIDI program except drums (channel 10) and `all-instruments-with-drums` for MIDI files containing multiple MIDI programs including drums. The `all` subset encompasses the three to get the full dataset.
50 |
51 | ## How to use
52 |
53 | The `datasets` library allows you to load and pre-process your dataset in pure Python at scale. The dataset can be downloaded and prepared in one call to your local drive by using the `load_dataset` function.
54 |
55 | ```python
56 | from datasets import load_dataset
57 |
58 | dataset = load_dataset("Metacreation/GigaMIDI", "all-instruments-with-drums", trust_remote_code=True)
59 | ```
60 |
61 | You can load combinations of specific subsets by using the `subset` keyword argument when loading the dataset:
62 |
63 | ```Python
64 | from datasets import load_dataset
65 |
66 | dataset = load_dataset("Metacreation/GigaMIDI", "music", subsets=["no-drums", "all-instruments-with-drums"], trust_remote_code=True)
67 | ```
68 |
69 | Using the datasets library, you can also stream the dataset on-the-fly by adding a `streaming=True` argument to the `load_dataset` function call. Loading a dataset in streaming mode loads individual samples of the dataset at a time, rather than downloading the entire dataset to disk.
70 |
71 | ```python
72 | from datasets import load_dataset
73 |
74 | dataset = load_dataset(
75 | "Metacreation/GigaMIDI", "all-instruments-with-drums", trust_remote_code=True, streaming=True
76 | )
77 |
78 | print(next(iter(dataset)))
79 | ```
80 |
81 | *Bonus*: create a [PyTorch dataloader](https://huggingface.co/docs/datasets/use_with_pytorch) directly with your own datasets (local/streamed).
82 |
83 | ### Local
84 |
85 | ```python
86 | from datasets import load_dataset
88 | from torch.utils.data import DataLoader, BatchSampler, RandomSampler
88 |
89 | dataset = load_dataset("Metacreation/GigaMIDI", "all-instruments-with-drums", trust_remote_code=True, split="train")
90 | batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=32, drop_last=False)
91 | dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
92 | ```
93 |
94 | ### Streaming
95 |
96 | ```python
97 | from datasets import load_dataset
98 | from torch.utils.data import DataLoader
99 |
100 | dataset = load_dataset("Metacreation/GigaMIDI", "all-instruments-with-drums", trust_remote_code=True, split="train")
101 | dataloader = DataLoader(dataset, batch_size=32)
102 | ```
103 |
104 | ### Example scripts
105 |
106 | MIDI files can be easily loaded and tokenized with [Symusic](https://github.com/Yikai-Liao/symusic) and [MidiTok](https://github.com/Natooz/MidiTok) respectively.
107 |
108 | ```python
109 | from datasets import load_dataset
110 |
111 | dataset = load_dataset("Metacreation/GigaMIDI", "all-instruments-with-drums", trust_remote_code=True, split="train")
112 | ```
113 |
114 | The dataset can be [processed](https://huggingface.co/docs/datasets/process) by using the `dataset.map` and `dataset.filter` methods.
115 |
116 | ```Python
117 | from pathlib import Path
118 | from datasets import load_dataset
119 | from miditok.constants import SCORE_LOADING_EXCEPTION
120 | from miditok.utils import get_bars_ticks
121 | from symusic import Score
122 |
123 | def is_score_valid(
124 | score: Score | Path | bytes, min_num_bars: int, min_num_notes: int
125 | ) -> bool:
126 | """
127 | Check if a ``symusic.Score`` is valid, contains the minimum required number of bars.
128 |
129 | :param score: ``symusic.Score`` to inspect or path to a MIDI file.
130 | :param min_num_bars: minimum number of bars the score should contain.
131 | :param min_num_notes: minimum number of notes that score should contain.
132 | :return: boolean indicating if ``score`` is valid.
133 | """
134 | if isinstance(score, Path):
135 | try:
136 | score = Score(score)
137 | except SCORE_LOADING_EXCEPTION:
138 | return False
139 | elif isinstance(score, bytes):
140 | try:
141 | score = Score.from_midi(score)
142 | except SCORE_LOADING_EXCEPTION:
143 | return False
144 |
145 | return (
146 | len(get_bars_ticks(score)) >= min_num_bars and score.note_num() > min_num_notes
147 | )
148 |
149 | dataset = load_dataset("Metacreation/GigaMIDI", "all-instruments-with-drums", trust_remote_code=True, split="train")
150 | dataset = dataset.filter(
151 | lambda ex: is_score_valid(ex["music"]["bytes"], min_num_bars=8, min_num_notes=50)
152 | )
153 | ```
154 |
155 | ## Dataset Structure
156 |
157 | ### Data Instances
158 |
159 | A typical data sample comprises the `md5` of the file which corresponds to its file name, a `music` entry containing dictionary mapping to its absolute file `path` and `bytes` that can be loaded with `symusic` as `score = Score.from_midi(dataset[sample_idx]["music"]["bytes"])`.
160 | Metadata accompanies each file, which is introduced in the next section.
161 |
162 | A data sample indexed from the dataset may look like this (the `bytes` entry is voluntarily shortened):
163 |
164 | ```python
165 | {
166 | 'md5': '0211bbf6adf0cf10d42117e5929929a4',
167 | 'music': {'path': '/Users/nathan/.cache/huggingface/datasets/downloads/extracted/cc8e36bbe8d5ec7ecf1160714d38de3f2f670c13bc83e0289b2f1803f80d2970/0211bbf6adf0cf10d42117e5929929a4.mid', 'bytes': b"MThd\x00\x00\x00\x06\x00\x01\x00\x05\x01\x00MTrk\x00"},
168 | 'is_drums': False,
169 | 'sid_matches': {'sid': ['065TU5v0uWSQmnTlP5Cnsz', '29OG7JWrnT0G19tOXwk664', '2lL9TiCxUt7YpwJwruyNGh'], 'score': [0.711, 0.8076, 0.8315]},
170 | 'mbid_matches': {'sid': ['065TU5v0uWSQmnTlP5Cnsz', '29OG7JWrnT0G19tOXwk664', '2lL9TiCxUt7YpwJwruyNGh'], 'mbids': [['43d521a9-54b0-416a-b15e-08ad54982e63', '70645f54-a13d-4123-bf49-c73d8c961db8', 'f46bba68-588f-49e7-bb4d-e321396b0d8e'], ['43d521a9-54b0-416a-b15e-08ad54982e63', '70645f54-a13d-4123-bf49-c73d8c961db8'], ['3a4678e6-9d8f-4379-aa99-78c19caf1ff5']]},
171 | 'artist_scraped': 'Bach, Johann Sebastian',
172 | 'title_scraped': 'Contrapunctus 1 from Art of Fugue',
173 | 'genres_scraped': ['classical', 'romantic'],
174 | 'genres_discogs': {'genre': ['classical', 'classical---baroque'], 'count': [14, 1]},
175 | 'genres_tagtraum': {'genre': ['classical', 'classical---baroque'], 'count': [1, 1]},
176 | 'genres_lastfm': {'genre': [], 'count': []},
177 | 'median_metric_depth': [0, 0, 0, 0]
178 | }
179 | ```
180 |
181 | ### Data Fields
182 |
183 | The GigaMIDI dataset comprises the [MetaMIDI dataset](https://www.metacreation.net/projects/metamidi-dataset). Consequently, the GigaMIDI dataset also contains its [metadata](https://github.com/jeffreyjohnens/MetaMIDIDataset) which we compiled here in a convenient and easy to use dataset format. The fields of each data entry are:
184 |
185 | * `md5` (`string`): hash of the MIDI file, corresponding to its file name;
186 | * `music` (`dict`): a dictionary containing the absolute `path` to the downloaded file and the file content as `bytes` to be loaded with an external Python package such as symusic;
187 | * `is_drums` (`boolean`): whether the sample comes from the `drums` subset, this can be useful when working with the `all` subset;
188 | * `sid_matches` (`dict[str, list[str] | list[float16]]`): ids of the Spotify entries matched and their scores.
189 | * `mbid_matches` (`dict[str, str | list[str]]`): ids of the MusicBrainz entries matched with the Spotify entries.
190 | * `artist_scraped` (`string`): scraped artist of the entry;
191 | * `title_scraped` (`string`): scraped song title of the entry;
192 | * `genres_scraped` (`list[str]`): scraped genres of the entry;
193 | * `genres_discogs` (`dict[str, list[str] | list[int16]]`): Discogs genres matched from the [AcousticBrainz dataset](https://multimediaeval.github.io/2018-AcousticBrainz-Genre-Task/data/);
194 | * `genres_tagtraum` (`dict[str, list[str] | list[int16]]`): Tagtraum genres matched from the [AcousticBrainz dataset](https://multimediaeval.github.io/2018-AcousticBrainz-Genre-Task/data/);
195 | * `genres_lastfm` (`dict[str, list[str] | list[int16]]`): Lastfm genres matched from the [AcousticBrainz dataset](https://multimediaeval.github.io/2018-AcousticBrainz-Genre-Task/data/);
196 | * `median_metric_depth` (`list[int16]`): per-track Note Onset Median Metric Level (NOMML); tracks at levels 0–11 are classified as non-expressive, and level 12 indicates an expressive MIDI track.
197 |
198 |
199 | ### Data Splits
200 |
201 | The dataset has been subdivided into portions for training (`train`), validation (`validation`) and testing (`test`).
202 |
203 | The validation and test splits contain each 10% of the dataset, while the training split contains the rest (about 80%).
204 |
205 | ## Dataset Creation
206 |
207 | ### Curation Rationale
208 |
209 | [Needs More Information]
210 |
211 | ### Source Data
212 |
213 | #### Initial Data Collection and Normalization
214 |
215 | [Needs More Information]
216 |
217 | #### Who are the source language producers?
218 |
219 | [Needs More Information]
220 |
221 | ### Annotations
222 |
223 | #### Annotation process
224 |
225 | [Needs More Information]
226 |
227 | #### Who are the annotators?
228 |
229 | [Needs More Information]
230 |
231 | ## Considerations for Using the Data
232 |
233 | ### Discussion of Biases
234 |
235 | [More Information Needed]
236 |
237 | ### Other Known Limitations
238 |
239 | [More Information Needed]
240 |
241 | ## Additional Information
242 |
243 | ### Dataset Curators
244 |
245 | [More Information Needed]
246 |
247 | ### Licensing Information
248 |
249 | Available for research purposes via Hugging Face hub.
250 |
251 | ## Ethical Statement
252 | The GigaMIDI dataset consists of MIDI files acquired via the aggregation of previously available datasets and web scraping from publicly available online sources. Each subset is accompanied by source links, copyright information when available, and acknowledgments. File names are anonymized using MD5 hashing. We acknowledge and cite the work from the previous dataset papers that we aggregate and analyze as part of the GigaMIDI subsets.
253 | This data has been collected, used, and distributed under Fair Dealing [ref to country and law copyright act anonymized]. Fair Dealing permits the limited use of copyright-protected material without the risk of infringement and without having to seek the permission of copyright owners. It is intended to provide a balance between the rights of creators and the rights of users. As per instructions of the Copyright Office of [anonymized University], two protective measures have been put in place that are deemed sufficient given the nature of the data (accessible online):
254 |
255 | 1) We explicitly state that this dataset has been collected, used, and is distributed under Fair Dealing [ref to law/country removed here for anonymity].
256 | 2) On the Hugging Face hub, we advertise that the data is available for research purposes only and collect the user's legal name and email as proof of agreement before granting access.
257 |
258 | We thus decline any responsibility for misuse.
259 |
260 | To justify the fair use of MIDI data for research purposes, we try to follow the FAIR (Findable, Accessible, Interoperable, and Reusable) principles in our MIDI data collection. These principles are widely recognized and frequently cited within the data research community, providing a robust framework for ethical data management. By following the FAIR principles, we ensure that our dataset is managed responsibly, supporting its use in research while maintaining high standards of accessibility, interoperability, and reusability.
261 |
262 |
263 | In navigating the use of MIDI datasets for research and creative explorations, it is imperative to consider the ethical implications inherent in dataset bias. MIDI dataset bias often reflects the prevailing practices in Western contemporary music production, where certain instruments, notably the piano and drums, dominate due to their inherent MIDI compatibility. The piano is a primary compositional tool and a ubiquitous MIDI controller and keyboard, facilitating input for a wide range of virtual instruments and synthesizers. Similarly, drums, whether through drum machines or MIDI drum pads, enjoy widespread use for rhythm programming and beat production. This prevalence stems from their intuitive interface and versatility within digital audio workstations. Consequently, MIDI datasets tend to be skewed towards piano and drums, with fewer representations of other instruments, particularly those that may require more nuanced interpretation or are less commonly played using MIDI controllers.
264 |
265 |
266 | A potential issue with the detected loops in the GigaMIDI dataset arises from the possibility that similar note content may appear, particularly in loop-focused applications. To mitigate this, we implemented an additional deduplication process for the detected loops. This process involved using MD5 checksums based on the extracted music loop content to ensure that identical loops are not provided to users.
267 |
268 |
269 | Another potential issue with the dataset is the album-effect. When using the dataset for machine learning tasks, a random split may result in data with nearly identical note content appearing in both the evaluation and training splits of the GigaMIDI dataset. To address this potential issue, we provide metadata, including the composer's name, uniform piece title, performer’s name, and genre, where available. Additionally, the GigaMIDI dataset includes a substantial portion of drum grooves, which are single-track MIDI files; such files typically do not contribute to the album-effect.
270 |
271 |
272 | Lastly, all source data is duly acknowledged and cited in accordance with fair use and ethical standards. More than 50% of the dataset was collected through web scraping and author-led initiatives, which include manual data collection from online sources and retrieval from Zenodo and GitHub. To ensure transparency and prevent misuse, links to data sources for each subset are systematically organized and provided in our GitHub repository, enabling users to identify and verify the datasets used.
273 |
274 | ## FAIR (Findable, Accessible, Interoperable, Reusable) principles with the GigaMIDI dataset
275 |
276 | The FAIR (Findable, Accessible, Interoperable, Reusable) principles serve as a framework to ensure that data is well-managed, easily discoverable, and usable for a broad range of purposes in research. These principles are particularly important in the context of data management to facilitate open science, collaboration, and reproducibility.
277 |
278 | 1. Findable: Data should be easily discoverable by both humans and machines. This is typically achieved through proper metadata, traceable source links and searchable resources. Applying this to MIDI data, each subset of MIDI files collected from public domain sources should be accompanied by clear and consistent metadata. For example, organizing the source links of each data subset, as done with the GigaMIDI dataset, ensures that each source can be easily traced and referenced, improving discoverability.
279 |
280 | 2. Accessible: Once found, data should be easily retrievable using standard protocols. Accessibility does not necessarily imply open access, but it does mean that data should be available under well-defined conditions. For the GigaMIDI dataset, hosting the data on platforms like Hugging Face Hub improves accessibility, as these platforms provide efficient data retrieval mechanisms, especially for large-scale datasets. Ensuring that MIDI data is accessible for public use, while respecting any applicable licenses, supports wider research and analysis in music computing.
281 |
282 | 3. Interoperable: Data should be structured in such a way that it can be integrated with other datasets and used by various applications. MIDI data, being a widely accepted format in music research, is inherently interoperable, especially when standardized metadata and file formats are used. By ensuring that the GigaMIDI dataset complies with widely adopted standards and supports integration with state-of-the-art libraries in symbolic music processing, such as Symusic (https://github.com/Yikai-Liao/symusic) and MidiTok (https://github.com/Natooz/MidiTok), the dataset enhances its utility for music researchers and practitioners working across different platforms and systems.
283 |
284 | 4. Reusable: Data should be well-documented and licensed so it can be reused in future research. Reusability is ensured through proper metadata, clear licenses, and documentation of provenance. In the case of GigaMIDI, aggregating all subsets from public domain sources and linking them to the original sources strengthens the reproducibility and traceability of the data. This practice allows future researchers to not only use the dataset but also verify and expand upon it by referring to the original data sources.
285 |
286 | In summary, applying FAIR principles to managing MIDI data, such as the GigaMIDI dataset, ensures that the data is organized in a manner that promotes reproducibility and traceability. By clearly documenting the source links of each subset and ensuring the dataset is findable, accessible, interoperable, and reusable, the data becomes a robust resource for the research community.
287 |
288 |
289 |
295 |
--------------------------------------------------------------------------------
/loops_nomml/corr_mat_fast.py:
--------------------------------------------------------------------------------
"""Fast self-similarity (correlation) matrix construction for loop detection.

Provides sanitization helpers for beat-tick arrays and optional
Numba-accelerated diagonal processing, selected at import time via the
``USE_NUMBA`` / ``USE_CUDA`` environment variables.
"""
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, Tuple, Dict, List, Optional

import os
import numpy as np
import threading

from note_set_fast import NoteSet  # updated import path if you renamed file

if TYPE_CHECKING:
    from symusic import Note

# Module-level small cache for duration computations to avoid repeated work.
# Keyed by (start, end, tb_first, tb_last, tb_len).
_duration_cache: Dict[Tuple[int, int, int, int, int], float] = {}
# Guards _duration_cache: callers may run in worker threads.
_duration_cache_lock = threading.Lock()

# Safety caps and windowing parameters (tune to your environment)
_MAX_TICKS_KEEP = 100_000 # hard cap on ticks array length
_TRUNCATE_KEEP = 20_000 # keep this many ticks from head and tail if truncating
_MAX_NOTESETS_FOR_FULL_CORR = 5000 # if n > this, use windowed correlation
_WINDOW_SIZE = 4000 # window size for windowed correlation
_WINDOW_OVERLAP = 500 # overlap between windows to catch cross-boundary runs

# Sanity threshold: ticks above this are treated as corrupt and discarded.
_MAX_TICK = 10 ** 9

# Optional Numba acceleration flags (opt-in via environment variables).
_USE_NUMBA = os.environ.get("USE_NUMBA", "0") == "1"
_USE_CUDA = os.environ.get("USE_CUDA", "0") == "1"

_NUMBA_AVAILABLE = False
_CUDA_AVAILABLE = False
_njit = None
_cuda = None

# Import numba lazily and tolerate any failure (missing package, bad install):
# the module falls back to pure-Python/numpy paths when unavailable.
if _USE_NUMBA:
    try:
        from numba import njit, prange
        _NUMBA_AVAILABLE = True
        _njit = njit
    except Exception:
        _NUMBA_AVAILABLE = False
        _njit = None

# CUDA is only attempted when CPU numba imported successfully.
if _USE_CUDA and _NUMBA_AVAILABLE:
    try:
        from numba import cuda
        _CUDA_AVAILABLE = cuda.is_available()
        _cuda = cuda
    except Exception:
        _CUDA_AVAILABLE = False
        _cuda = None

# If CUDA requested but not available, fall back to CPU numba if available
if _USE_CUDA and not _CUDA_AVAILABLE and _NUMBA_AVAILABLE:
    # keep _USE_CUDA False to avoid trying GPU kernels later
    _USE_CUDA = False
60 |
61 | # Helper: safe strictly increasing check
62 | def _is_strictly_increasing(arr: np.ndarray) -> bool:
63 | if arr.size < 2:
64 | return True
65 | try:
66 | return bool((arr[1:] > arr[:-1]).all())
67 | except Exception:
68 | prev = arr[0]
69 | for v in arr[1:]:
70 | if v <= prev:
71 | return False
72 | prev = v
73 | return True
74 |
75 |
def _sanitize_ticks_beats_once(ticks_beats: Sequence[int]) -> np.ndarray:
    """Validate and normalize a beat-tick array for interpolation.

    Drops non-finite and out-of-range values (outside ``[0, _MAX_TICK]``),
    guarantees the result is strictly increasing (sorting and
    deduplicating when it is not), and caps the length of oversized
    arrays by keeping only the head and tail.
    """

    def _cap_length(a: np.ndarray) -> np.ndarray:
        # Keep the first/last _TRUNCATE_KEEP ticks of oversized arrays.
        if a.size > _MAX_TICKS_KEEP:
            return np.concatenate((a[:_TRUNCATE_KEEP], a[-_TRUNCATE_KEEP:]))
        return a

    tb = np.asarray(ticks_beats, dtype=np.int64)
    if tb.size == 0:
        return tb

    # Drop non-finite entries (a no-op for int64; kept defensively).
    try:
        tb = tb[np.isfinite(tb)]
    except Exception:
        pass

    # Discard implausible tick values.
    if tb.size > 0:
        tb = tb[(tb >= 0) & (tb <= _MAX_TICK)]
    if tb.size == 0:
        return tb

    # Already monotone: only the length cap may apply.
    if _is_strictly_increasing(tb):
        return _cap_length(tb)

    # Not monotone: sort + deduplicate, with a pure-Python fallback.
    try:
        ordered = np.unique(tb)
    except Exception:
        try:
            distinct = sorted(set(int(x) for x in tb if np.isfinite(x)))
            ordered = np.asarray(distinct, dtype=np.int64)
        except Exception:
            return np.asarray([], dtype=np.int64)

    return _cap_length(ordered)
116 |
117 |
118 | # -------------------------
119 | # Numba-accelerated helpers
120 | # -------------------------
121 | # CPU njit implementation of diagonal processing (pure-Python fallback exists)
122 | if _NUMBA_AVAILABLE and not _USE_CUDA:
123 | # We implement a simplified njit version that mirrors the original logic.
124 | # Note: numba njit does not support all numpy conveniences used above, so
125 | # we implement the core loops in a form numba accepts.
126 | @njit # type: ignore[name-defined]
127 | def _process_diagonals_numba_cpu(key_ids: np.ndarray, is_barline: np.ndarray, out_mat: np.ndarray, offset: int) -> None:
128 | n_local = key_ids.size
129 | if n_local < 2:
130 | return
131 |
132 | # Fast path for first element matching later ones when it's a barline
133 | if is_barline[0]:
134 | first_id = key_ids[0]
135 | for idx in range(1, n_local):
136 | if key_ids[idx] == first_id:
137 | out_mat[offset, offset + idx] = 1
138 |
139 | for k in range(1, n_local):
140 | # left: 0..n_local-k-1, right: k..n_local-1
141 | prev_val = 0
142 | for i_local in range(0, n_local - k):
143 | j_local = i_local + k
144 | if key_ids[i_local] == key_ids[j_local]:
145 | # contiguous run handling: we need to detect run starts/ends
146 | # We emulate the "growing" counts by checking previous diagonal value
147 | if i_local == 0:
148 | if is_barline[i_local]:
149 | prev_val = 1
150 | out_mat[offset + i_local, offset + j_local] = 1
151 | else:
152 | prev_val = 0
153 | else:
154 | if prev_val == 0 and (not is_barline[i_local]):
155 | prev_val = 0
156 | else:
157 | prev_val = prev_val + 1
158 | out_mat[offset + i_local, offset + j_local] = prev_val
159 | else:
160 | prev_val = 0
161 |
# GPU kernel: compute equality matches for a given diagonal offset k
if _CUDA_AVAILABLE:
    @cuda.jit # type: ignore[name-defined]
    def _cuda_kernel_match(key_ids, k, n_local, matches):
        # One thread per diagonal element: matches[i] = 1 iff
        # key_ids[i] == key_ids[i + k]; out-of-range threads do nothing.
        i = cuda.grid(1)
        if i < n_local - k:
            matches[i] = 1 if key_ids[i] == key_ids[i + k] else 0
169 |
170 |
def _process_diagonals_into_matrix(key_ids_slice: np.ndarray, is_barline_slice: np.ndarray, out_mat: np.ndarray, offset: int) -> None:
    """
    Fill the diagonals of ``out_mat`` with growing run counts of repeated
    note-set ids, dispatching to the fastest available implementation.

    Dispatch order: CUDA kernel (when ``_CUDA_AVAILABLE``), CPU numba njit
    (when compiled and CUDA was not requested), then the pure Python/NumPy
    fallback.  All three paths are intended to produce identical writes, so
    falling through after a partial GPU failure only re-assigns the same
    values.

    Parameters
    ----------
    key_ids_slice : np.ndarray
        Interned (duration, pitches) ids for the window being processed.
    is_barline_slice : np.ndarray
        Boolean flags parallel to ``key_ids_slice`` marking barline entries.
    out_mat : np.ndarray
        Full correlation matrix, written in place at ``offset``-shifted indices.
    offset : int
        Global index of the slice's first element within ``out_mat``.
    """
    n_local = key_ids_slice.size
    if n_local < 2:
        return

    # If CUDA is enabled and available, use a GPU kernel to compute per-diagonal matches
    if _CUDA_AVAILABLE:
        try:
            # copy slice to device
            key_ids_dev = _cuda.to_device(key_ids_slice.astype(np.int32))
            # allocate a device array for matches (max length n_local)
            for k in range(1, n_local):
                matches_dev = _cuda.device_array(n_local - k, dtype=np.uint8)
                threadsperblock = 128
                blocks = (n_local - k + threadsperblock - 1) // threadsperblock
                _cuda_kernel_match[blocks, threadsperblock](key_ids_dev, k, n_local, matches_dev)
                # NOTE(review): one device round-trip per diagonal; batching
                # diagonals would cut transfer overhead — confirm before changing.
                matches = matches_dev.copy_to_host()
                if not matches.any():
                    continue

                # find runs in matches (host side)
                true_idx = np.flatnonzero(matches)
                if true_idx.size == 0:
                    continue

                # Identify contiguous runs in true_idx
                if true_idx.size == 1:
                    runs = [(int(true_idx[0]), int(true_idx[0]))]
                else:
                    diffs = np.diff(true_idx)
                    breaks = np.nonzero(diffs > 1)[0]
                    run_starts = np.concatenate(([0], breaks + 1))
                    run_ends = np.concatenate((breaks, [true_idx.size - 1]))
                    runs = [(int(true_idx[s]), int(true_idx[e])) for s, e in zip(run_starts, run_ends)]

                # For each run, walk and set incremental counts
                for run_start, run_end in runs:
                    prev_val = 0
                    for i_local in range(run_start, run_end + 1):
                        j_local = i_local + k
                        global_i = offset + i_local
                        global_j = offset + j_local
                        if i_local == 0:
                            if is_barline_slice[i_local]:
                                prev_val = 1
                                out_mat[global_i, global_j] = 1
                            else:
                                prev_val = 0
                        else:
                            # A dormant run (prev_val == 0) only activates at a barline.
                            if prev_val == 0 and not is_barline_slice[i_local]:
                                prev_val = 0
                            else:
                                prev_val = prev_val + 1
                                out_mat[global_i, global_j] = prev_val
            # also handle first-element barline fast path (matches against first element)
            if is_barline_slice[0]:
                first_id = int(key_ids_slice[0])
                matches = (key_ids_slice[1:] == first_id)
                if matches.any():
                    global_i = offset
                    global_js = np.nonzero(matches)[0] + offset + 1
                    out_mat[global_i, global_js] = 1
            return
        except Exception:
            # If GPU path fails for any reason, fall back to CPU/NumPy implementation below
            pass

    # If Numba CPU is available and requested, use njit implementation
    if _NUMBA_AVAILABLE and not _USE_CUDA:
        try:
            _process_diagonals_numba_cpu(key_ids_slice.astype(np.int32), is_barline_slice.astype(np.uint8), out_mat, offset)
            return
        except Exception:
            # fall back to Python implementation on error
            pass

    # Original Python/Numpy implementation (fallback)
    # Fast path for first element matching later ones when it's a barline
    if is_barline_slice[0]:
        first_id = key_ids_slice[0]
        matches = (key_ids_slice[1:] == first_id)
        if matches.any():
            global_i = offset
            global_js = np.nonzero(matches)[0] + offset + 1
            out_mat[global_i, global_js] = 1

    # For each diagonal offset k, find matches between left and right slices
    for k in range(1, n_local):
        left = key_ids_slice[: n_local - k]
        right = key_ids_slice[k : n_local]
        matches = (left == right)
        if not matches.any():
            continue

        true_idx = np.flatnonzero(matches)
        if true_idx.size == 0:
            continue

        # Identify contiguous runs in true_idx
        if true_idx.size == 1:
            runs = [(int(true_idx[0]), int(true_idx[0]))]
        else:
            diffs = np.diff(true_idx)
            breaks = np.nonzero(diffs > 1)[0]
            run_starts = np.concatenate(([0], breaks + 1))
            run_ends = np.concatenate((breaks, [true_idx.size - 1]))
            runs = [(int(true_idx[s]), int(true_idx[e])) for s, e in zip(run_starts, run_ends)]

        # For each run, walk and set incremental counts; this inner loop is
        # necessary to produce the "growing" correlation counts along diagonals.
        for run_start, run_end in runs:
            prev_val = 0
            for i_local in range(run_start, run_end + 1):
                j_local = i_local + k
                global_i = offset + i_local
                global_j = offset + j_local
                if i_local == 0:
                    if is_barline_slice[i_local]:
                        prev_val = 1
                        out_mat[global_i, global_j] = 1
                    else:
                        prev_val = 0
                else:
                    # A dormant run (prev_val == 0) only activates at a barline.
                    if prev_val == 0 and not is_barline_slice[i_local]:
                        prev_val = 0
                    else:
                        prev_val = prev_val + 1
                        out_mat[global_i, global_j] = prev_val
303 |
304 |
def calc_correlation(note_sets: Sequence[NoteSet]) -> np.ndarray:
    """Build the n-by-n int16 self-similarity matrix over *note_sets*.

    Each note set is interned to a small integer id keyed on its
    (duration, pitches) pair so diagonal comparisons reduce to integer
    equality.  Small inputs are processed as one window; larger inputs are
    swept with overlapping windows of `_WINDOW_SIZE` elements.
    """
    n = len(note_sets)
    if n < 2:
        return np.zeros((n, n), dtype=np.int16)

    key_to_id: Dict[Tuple[int, frozenset], int] = {}
    key_ids = np.empty(n, dtype=np.int32)
    next_id = 1
    for idx, ns in enumerate(note_sets):
        # Defensive: tolerate malformed note sets and implausible durations.
        try:
            duration = int(ns.duration)
            pitch_set = frozenset(ns.pitches)
            if duration < 0 or duration > _MAX_TICK:
                duration = 0
        except Exception:
            duration = 0
            pitch_set = frozenset()
        key = (duration, pitch_set)
        assigned = key_to_id.get(key)
        if assigned is None:
            assigned = next_id
            key_to_id[key] = assigned
            next_id += 1
        key_ids[idx] = assigned

    barline_flags = np.fromiter(
        (bool(getattr(ns, "is_barline", lambda: False)()) for ns in note_sets),
        dtype=bool,
        count=n,
    )
    corr = np.zeros((n, n), dtype=np.int16)

    # Small enough: one pass over the whole sequence.
    if n <= _MAX_NOTESETS_FOR_FULL_CORR:
        _process_diagonals_into_matrix(key_ids, barline_flags, corr, offset=0)
        return corr

    # Large input: sweep with overlapping windows.
    window = _WINDOW_SIZE
    overlap = min(_WINDOW_OVERLAP, window // 4)
    if overlap >= window:
        overlap = window // 4
    stride = window - overlap
    if stride <= 0:
        stride = max(1, window // 2)

    start = 0
    while start < n:
        stop = min(n, start + window)
        _process_diagonals_into_matrix(
            key_ids[start:stop], barline_flags[start:stop], corr, offset=start
        )
        if stop == n:
            break
        start += stride

    return corr
357 |
358 |
359 | # The rest of the file (duration helpers and loop detection) remains unchanged
360 | # (kept here for completeness). They are unchanged from the previous version,
361 | # but are included so this module is self-contained.
362 |
def get_loop_density(loop: Sequence[NoteSet], num_beats: int | float) -> float:
    """Return the count of non-zero-duration note sets per beat.

    Returns 0.0 when *num_beats* is zero to avoid division by zero.
    """
    if num_beats == 0:
        return 0.0
    sounding = sum(1 for ns in loop if ns.duration != 0)
    return sounding / float(num_beats)
371 |
372 |
def is_empty_loop(loop: Sequence[Note]) -> bool:
    """Return True when no note set in *loop* carries any pitches."""
    return not any(ns.pitches for ns in loop)
378 |
379 |
def compare_loops(p1: Sequence[NoteSet], p2: Sequence[NoteSet], min_rep_beats: int | float) -> int:
    """Compare the leading elements of two candidate loops.

    Returns 0 when the prefixes (up to ``round(min_rep_beats)`` elements,
    clamped to both lengths) differ; otherwise 1 when *p1* is strictly
    shorter than *p2*, else 2.
    """
    prefix_len = min(int(round(min_rep_beats)), len(p1), len(p2))
    if any(a != b for a, b in zip(p1[:prefix_len], p2[:prefix_len])):
        return 0
    return 1 if len(p1) < len(p2) else 2
387 |
388 |
def test_loop_exists(loop_list: Sequence[Sequence[NoteSet]], loop: Sequence[NoteSet], min_rep_beats: int | float) -> Optional[int]:
    """Check *loop* against already-collected loops.

    Returns -1 when an existing loop supersedes *loop* (it is a shorter
    duplicate), the index of the first existing loop that *loop* supersedes,
    or None when *loop* is new.
    """
    for position, existing in enumerate(loop_list):
        verdict = compare_loops(loop, existing, min_rep_beats)
        if verdict == 1:
            return -1
        if verdict == 2:
            return position
    return None
397 |
398 |
def filter_sub_loops(candidate_indices: Dict[float, List[Tuple[int, int]]]) -> List[Tuple[int, int, float]]:
    """Merge exactly-adjacent index intervals per duration bucket.

    For every duration (in ascending order) intervals are sorted by start and
    chained together when one begins exactly where the previous ends.  The
    merged (start, end, duration) triples are returned deduplicated and
    sorted lexicographically.
    """
    if not candidate_indices:
        return []

    merged_all: List[Tuple[int, int, float]] = []
    for duration in sorted(candidate_indices):
        spans = candidate_indices[duration]
        if not spans:
            continue
        ordered = sorted(spans, key=lambda span: span[0])
        cur_start, cur_end = ordered[0]
        for nxt_start, nxt_end in ordered[1:]:
            if nxt_start == cur_end:
                # Abuts the running interval exactly: extend it.
                cur_end = nxt_end
            else:
                merged_all.append((cur_start, cur_end, duration))
                cur_start, cur_end = nxt_start, nxt_end
        merged_all.append((cur_start, cur_end, duration))

    # Deduplicate while preserving the sorted order.
    deduped: List[Tuple[int, int, float]] = []
    seen = set()
    for triple in sorted(merged_all):
        if triple not in seen:
            seen.add(triple)
            deduped.append(triple)
    return deduped
426 |
427 |
428 | def _compute_frac_positions_vectorized(values: np.ndarray, tb: np.ndarray) -> np.ndarray:
429 | """
430 | Vectorized computation of fractional positions for an array of tick values.
431 | Returns a float array where integer positions correspond to exact tick indices,
432 | and fractional positions are interpolated between indices.
433 | """
434 | if values.size == 0:
435 | return np.array([], dtype=float)
436 |
437 | tb_size = tb.size
438 | # Use searchsorted to find insertion positions
439 | pos = tb.searchsorted(values, side='left') # pos in [0..tb_size]
440 | res = np.empty(values.shape, dtype=float)
441 |
442 | # Exact matches where pos < tb_size and tb[pos] == value
443 | exact_mask = (pos < tb_size) & (tb[pos] == values)
444 | if exact_mask.any():
445 | res[exact_mask] = pos[exact_mask].astype(float)
446 |
447 | # pos == 0 and not exact
448 | mask_pos0 = (pos == 0) & (~exact_mask)
449 | if mask_pos0.any():
450 | prev_tick = tb[0]
451 | next_tick = tb[1] if tb_size > 1 else tb[0] + 1
452 | denom = next_tick - prev_tick if next_tick != prev_tick else 1
453 | res[mask_pos0] = (values[mask_pos0] - prev_tick) / denom
454 |
455 | # pos >= tb_size (to the right of last tick)
456 | mask_right = (pos >= tb_size) & (~exact_mask)
457 | if mask_right.any():
458 | if tb_size > 1:
459 | prev_tick = tb[-2]
460 | last_tick = tb[-1]
461 | denom = last_tick - prev_tick if last_tick != prev_tick else 1
462 | res[mask_right] = float(tb_size - 1) + (values[mask_right] - last_tick) / denom
463 | else:
464 | # single tick in tb
465 | res[mask_right] = float(0) + (values[mask_right] - tb[0]) # fallback
466 |
467 | # middle cases: 0 < pos < tb_size and not exact
468 | mask_mid = (~exact_mask) & (~mask_pos0) & (~mask_right)
469 | if mask_mid.any():
470 | pos_mid = pos[mask_mid]
471 | prev_tick = tb[pos_mid - 1]
472 | next_tick = tb[pos_mid]
473 | denom = next_tick - prev_tick
474 | # avoid division by zero
475 | denom = np.where(denom == 0, 1, denom)
476 | res[mask_mid] = (pos_mid - 1).astype(float) + (values[mask_mid] - prev_tick) / denom
477 |
478 | return res
479 |
480 |
def _compute_durations_batch(starts: np.ndarray, ends: np.ndarray, tb: np.ndarray) -> np.ndarray:
    """
    Compute durations (in beats) for parallel arrays of start/end tick values.

    Delegates to _compute_frac_positions_vectorized for the tick-to-beat
    mapping; negative spans are clipped to zero.  Returns zeros when no beat
    grid is available.
    """
    if tb is None or tb.size == 0:
        return np.zeros_like(starts, dtype=float)

    start_ticks = np.asarray(starts, dtype=np.int64)
    end_ticks = np.asarray(ends, dtype=np.int64)

    positions_start = _compute_frac_positions_vectorized(start_ticks, tb)
    positions_end = _compute_frac_positions_vectorized(end_ticks, tb)

    spans = positions_end - positions_start
    return np.where(spans < 0.0, 0.0, spans).astype(float)
501 |
502 |
def get_duration_beats(start: int, end: int, tb: np.ndarray, tick_to_idx: Optional[Dict[int, int]] = None) -> float:
    """
    Backwards-compatible single-pair duration computation.
    For heavy workloads prefer the batch computation used in get_valid_loops.

    Results are memoized in the module-level ``_duration_cache`` keyed on the
    tick pair plus a cheap fingerprint of ``tb`` (first/last tick and size).
    """
    if tb is None or tb.size == 0:
        return 0.0
    if end <= start:
        return 0.0

    try:
        first_tick = int(tb[0])
        last_tick = int(tb[-1])
    except Exception:
        return 0.0

    cache_key = (int(start), int(end), first_tick, last_tick, int(tb.size))
    with _duration_cache_lock:
        hit = _duration_cache.get(cache_key)
    if hit is not None:
        return hit

    if tb.size == 1:
        # Degenerate one-tick grid: fall back to raw tick distance.
        result = float(max(0.0, end - start))
    else:
        # Single-pair call into the vectorized batch helper.
        result = _compute_durations_batch(
            np.array([start], dtype=np.int64), np.array([end], dtype=np.int64), tb
        )[0]

    with _duration_cache_lock:
        _duration_cache[cache_key] = result
    return result
538 |
539 |
def get_valid_loops(
    note_sets: Sequence[NoteSet],
    corr_mat: np.ndarray,
    ticks_beats: Sequence[int],
    min_rep_notes: int = 4,
    min_rep_beats: float = 2.0,
    min_beats: float = 32.0,
    max_beats: float = 32.0,
    min_loop_note_density: float = 0.5,
) -> Tuple[List[Sequence[NoteSet]], List[Tuple[int, int, float, float]]]:
    """
    Return detected loops and metadata.

    Scans ``corr_mat`` for diagonal runs of exactly ``min_rep_notes + 1``
    repeated note sets, converts their tick spans to beat durations, filters
    by beat range and note density, and deduplicates overlapping candidates.

    Parameters
    ----------
    note_sets : Sequence[NoteSet]
        Note sets the correlation matrix was built from (indices must align).
    corr_mat : np.ndarray
        Self-similarity matrix as produced by calc_correlation.
    ticks_beats : Sequence[int]
        Beat positions in ticks; sanitized before use.
    min_rep_notes, min_rep_beats, min_beats, max_beats, min_loop_note_density
        Filtering thresholds for candidate loops.

    Returns
    -------
    (loops, loop_bp)
        ``loops`` is a list of note-set slices; ``loop_bp`` holds parallel
        (start_tick, end_tick, duration_beats, density) tuples.
    """
    min_rep_notes += 1 # original behavior to not count barlines
    x_idx, y_idx = np.where(corr_mat == min_rep_notes)
    if x_idx.size == 0:
        return [], []

    tb_sanitized = _sanitize_ticks_beats_once(ticks_beats)
    if tb_sanitized.size == 0:
        # no valid beat ticks to compute durations -> no loops
        return [], []

    # Precompute tick->index mapping once (cheap relative to repeated rebuilds)
    try:
        tick_to_idx_global = {int(t): i for i, t in enumerate(tb_sanitized)}
    except Exception:
        tick_to_idx_global = {}

    valid_indices: Dict[float, List[Tuple[int, int]]] = {}
    # Collect unique key pairs to compute durations in batch
    unique_pairs = {}
    pairs_list_starts = []
    pairs_list_ends = []
    pairs_keys = []

    # First pass: collect candidate pairs and unique (start_tick, end_tick) pairs
    for xi, yi in zip(x_idx, y_idx):
        try:
            run_len = int(corr_mat[xi, yi])
        except Exception:
            continue
        # Walk back to the start of the diagonal run on both axes.
        start_x = xi - run_len + 1
        start_y = yi - run_len + 1

        if start_x < 0 or start_y < 0 or start_x >= len(note_sets) or start_y >= len(note_sets):
            continue

        try:
            loop_start_time = int(note_sets[start_x].start)
            loop_end_time = int(note_sets[start_y].start)
        except Exception:
            continue

        # sanity check tick magnitudes
        if loop_start_time < 0 or loop_end_time < 0 or loop_start_time > _MAX_TICK or loop_end_time > _MAX_TICK:
            continue

        key_pair = (loop_start_time, loop_end_time)
        if key_pair not in unique_pairs:
            unique_pairs[key_pair] = len(pairs_list_starts)
            pairs_list_starts.append(loop_start_time)
            pairs_list_ends.append(loop_end_time)
            pairs_keys.append(key_pair)

    if not pairs_list_starts:
        return [], []

    # Batch compute durations for all unique pairs
    starts_arr = np.asarray(pairs_list_starts, dtype=np.int64)
    ends_arr = np.asarray(pairs_list_ends, dtype=np.int64)
    durations_arr = _compute_durations_batch(starts_arr, ends_arr, tb_sanitized)
    # Round durations to 2 decimals to match previous behavior
    durations_rounded = np.round(durations_arr.astype(float), 2)

    # Build a mapping from key_pair to rounded duration
    keypair_to_duration: Dict[Tuple[int, int], float] = {
        kp: float(durations_rounded[idx]) for idx, kp in enumerate(pairs_keys)
    }

    # Second pass: populate valid_indices using computed durations
    for xi, yi in zip(x_idx, y_idx):
        try:
            run_len = int(corr_mat[xi, yi])
        except Exception:
            continue
        start_x = xi - run_len + 1
        start_y = yi - run_len + 1

        if start_x < 0 or start_y < 0 or start_x >= len(note_sets) or start_y >= len(note_sets):
            continue

        try:
            loop_start_time = int(note_sets[start_x].start)
            loop_end_time = int(note_sets[start_y].start)
        except Exception:
            continue

        key_pair = (loop_start_time, loop_end_time)
        loop_num_beats = keypair_to_duration.get(key_pair, 0.0)

        # Bucket candidates by duration for sub-loop merging below.
        if min_beats <= loop_num_beats <= max_beats:
            valid_indices.setdefault(loop_num_beats, []).append((int(xi), int(yi)))

    if not valid_indices:
        return [], []

    filtered = filter_sub_loops(valid_indices)

    loops: List[Sequence[NoteSet]] = []
    loop_bp: List[Tuple[int, int, float, float]] = []
    corr_size = corr_mat.shape[0]

    for start_x, start_y, loop_num_beats in filtered:
        x = start_x
        y = start_y
        # Extend along the diagonal while the run count keeps growing.
        while x + 1 < corr_size and y + 1 < corr_size and corr_mat[x + 1, y + 1] > corr_mat[x, y]:
            x += 1
            y += 1

        # Back up by the final run length to find the loop's first element.
        beginning = x - int(corr_mat[x, y]) + 1
        end = y - int(corr_mat[x, y]) + 1

        if beginning < 0 or end < 0 or beginning >= len(note_sets) or end >= len(note_sets):
            continue

        start_tick = int(note_sets[beginning].start)
        end_tick = int(note_sets[end].start)

        # Try to get duration from keypair_to_duration first
        duration_beats = keypair_to_duration.get((start_tick, end_tick))
        if duration_beats is None:
            # Fallback: compute single pair (rare)
            duration_beats = get_duration_beats(start_tick, end_tick, tb_sanitized, tick_to_idx=tick_to_idx_global)

        if duration_beats >= min_rep_beats and not is_empty_loop(note_sets[beginning:end]):
            loop = note_sets[beginning : (end + 1)]
            loop_density = get_loop_density(loop, loop_num_beats)
            if loop_density < min_loop_note_density:
                continue
            # Deduplicate against already-accepted loops; a longer variant
            # replaces the shorter one it supersedes.
            exist_result = test_loop_exists(loops, loop, min_rep_beats)
            if exist_result is None:
                loops.append(loop)
                loop_bp.append((start_tick, end_tick, loop_num_beats, loop_density))
            elif exist_result > 0:
                loops[exist_result] = loop
                loop_bp[exist_result] = (start_tick, end_tick, loop_num_beats, loop_density)

    return loops, loop_bp
--------------------------------------------------------------------------------