├── .gitignore ├── LICENSE.txt ├── README.md ├── frechet_music_distance ├── __init__.py ├── __main__.py ├── dataset_loaders │ ├── __init__.py │ ├── abc_loader.py │ ├── dataset_loader.py │ ├── midi_as_mtf_loader.py │ └── utils.py ├── fmd.py ├── gaussian_estimators │ ├── __init__.py │ ├── bootstrapping_estimator.py │ ├── gaussian_estimator.py │ ├── leodit_wolf_estimator.py │ ├── max_likelihood_estimator.py │ ├── oas_estimator.py │ ├── shrikage_estimator.py │ └── utils.py ├── memory.py ├── models │ ├── __init__.py │ ├── clamp │ │ ├── __init__.py │ │ ├── clamp_extractor.py │ │ ├── clamp_model.py │ │ └── clamp_utils.py │ ├── clamp2 │ │ ├── __init__.py │ │ ├── clamp2_extractor.py │ │ ├── clamp2_model.py │ │ ├── config.py │ │ ├── m3_patch_encoder.py │ │ └── m3_patchilizer.py │ ├── feature_extractor.py │ └── utils.py └── utils.py ├── pyproject.toml ├── requirements.txt ├── requirements_dev.txt └── tests ├── conftest.py ├── data ├── abc │ ├── example_1.abc │ ├── example_2.abc │ ├── example_3.abc │ ├── example_4.abc │ └── example_5.abc └── midi │ ├── example_1.mid │ ├── example_2.mid │ ├── example_3.mid │ ├── example_4.mid │ └── example_5.mid └── test_fmd.py /.gitignore: -------------------------------------------------------------------------------- 1 | venv*/ 2 | .venv*/ 3 | 4 | .idea/ 5 | cache/ 6 | **/__pycache__/ 7 | logs/ 8 | **/*.pth 9 | data/ 10 | !tests/data/ 11 | checkpoints/ 12 | .DS_store 13 | dist/ 14 | frechet_music_distance.egg-info/ 15 | .sander-wood/ 16 | .pytest_cache -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 jryban 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | 
2 | # Frechet Music Distance
3 | 
4 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
5 | [![arXiv](https://img.shields.io/badge/arXiv-2412.07948v2-b31b1b.svg)](https://arxiv.org/abs/2412.07948)
6 | 
7 | ## Table of Contents
8 | - [Introduction](#introduction)
9 | - [Features](#features)
10 | - [Installation](#installation)
11 | - [Usage](#usage)
12 | - [Extending FMD](#extending-fmd)
13 | - [Citation](#citation)
14 | - [Acknowledgements](#acknowledgements)
15 | - [License](#license)
16 | 
17 | 
18 | ## Introduction
19 | A library for calculating Frechet Music Distance (FMD). This is an official implementation of the paper [_Frechet Music Distance: A Metric For Generative Symbolic Music Evaluation_](https://www.arxiv.org/abs/2412.07948).
20 | 
21 | 
22 | ## Features
23 | - Calculating FMD and FMD-Inf scores between two datasets for evaluation
24 | - Caching extracted features and distribution parameters to speed up subsequent computations
25 | - Support for various symbolic music representations (**MIDI** and **ABC**)
26 | - Support for various embedding models (**CLaMP 2**, **CLaMP 1**)
27 | - Support for various methods of estimating embedding distribution parameters (**MLE**, **Ledoit-Wolf**, **Shrinkage**, **OAS**, **Bootstrap**)
28 | - Computation of per-song FMD to find outliers in the dataset
29 | 
30 | 
31 | ## Installation
32 | 
33 | The library can be installed from [PyPI](https://pypi.org/project/frechet-music-distance/) using pip:
34 | ```bash
35 | pip install frechet-music-distance
36 | ```
37 | 
38 | **Note**: If it doesn't work, try updating `pip`:
39 | ```bash
40 | pip install --upgrade pip
41 | ```
42 | 
43 | You can also install from source by cloning the repository and installing it locally:
44 | ```bash
45 | git clone https://github.com/jryban/frechet-music-distance.git
46 | cd frechet-music-distance
47 | pip install -e .
48 | ```
49 | 
50 | The library was tested on Linux and macOS, but it should work on Windows as well.
51 | 
52 | **Note**: If you encounter `NotOpenSSLWarning`, please downgrade your `urllib3` version to `1.26.6`:
53 | ```bash
54 | pip install urllib3==1.26.6
55 | ```
56 | or use a different version of Python that supports OpenSSL, following the instructions provided in this [urllib3 GitHub issue](https://github.com/urllib3/urllib3/issues/3020).
57 | 
58 | 
59 | ## Usage
60 | The library currently supports **MIDI** and **ABC** symbolic music representations.
61 | 
62 | **Note**: When using ABC notation, please ensure that each song is located in a separate file.
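Dataset arguments below are directories: files are discovered recursively, and each song must be stored in its own `.abc`, `.mid`, or `.midi` file. A purely illustrative layout (directory and file names are arbitrary):

```text
reference_dataset/
├── song_001.mid
├── song_002.mid
└── ...
test_dataset/
├── generated_001.mid
├── generated_002.mid
└── ...
```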
63 | 
64 | ### Command Line
65 | 
66 | ```bash
67 | fmd score [-h] [--model {clamp2,clamp}] [--estimator {mle,bootstrap,oas,shrinkage,leodit_wolf}] [--inf] [--steps STEPS] [--min_n MIN_N] [--clear-cache] [reference_dataset] [test_dataset]
68 | 
69 | ```
70 | 
71 | #### Positional arguments:
72 | * `reference_dataset`: Path to the reference dataset
73 | * `test_dataset`: Path to the test dataset
74 | 
75 | #### Options:
76 | * `--model {clamp2,clamp}, -m {clamp2,clamp}` Embedding model name
77 | * `--estimator {mle,bootstrap,oas,shrinkage,leodit_wolf}, -e {mle,bootstrap,oas,shrinkage,leodit_wolf}` Gaussian estimator for mean and covariance
78 | * `--inf` Use FMD-Inf extrapolation
79 | * `--steps STEPS, -s STEPS` Number of steps when calculating FMD-Inf
80 | * `--min_n MIN_N, -n MIN_N` Minimum sample size when calculating FMD-Inf (must be smaller than the size of the test dataset)
81 | * `--clear-cache` Clear the pre-computed cache before FMD calculation
82 | 
83 | #### Cleanup
84 | Additionally, the pre-computed cache can be cleared by executing:
85 | 
86 | ```bash
87 | fmd clear
88 | ```
89 | 
90 | ### Python API
91 | 
92 | #### Initialization
93 | You can initialize the metric like so:
94 | 
95 | ```python
96 | from frechet_music_distance import FrechetMusicDistance
97 | 
98 | metric = FrechetMusicDistance(feature_extractor='<feature-extractor-name>', gaussian_estimator='<gaussian-estimator-name>', verbose=True)
99 | ```
100 | Valid values for `feature_extractor` are: `clamp2` (default), `clamp`
101 | Valid values for `gaussian_estimator` are: `mle` (default), `bootstrap`, `shrinkage`, `leodit_wolf`, `oas`
102 | 
103 | If you want more control over the feature extraction model and Gaussian estimator, you can instantiate the objects yourself and pass them to the constructor directly like so:
104 | 
105 | ```python
106 | from frechet_music_distance import FrechetMusicDistance
107 | from frechet_music_distance.gaussian_estimators import LeoditWolfEstimator, MaxLikelihoodEstimator, OASEstimator, BootstrappingEstimator, ShrinkageEstimator
108 | from frechet_music_distance.models import CLaMP2Extractor, CLaMPExtractor
109 | 
110 | extractor = CLaMP2Extractor(verbose=True)
111 | estimator = ShrinkageEstimator(shrinkage=0.1)
112 | fmd = FrechetMusicDistance(feature_extractor=extractor, gaussian_estimator=estimator, verbose=True)
113 | 
114 | ```
115 | 
116 | #### Standard FMD score
117 | ```python
118 | score = metric.score(
119 |     reference_path="<path-to-reference-dataset>",
120 |     test_path="<path-to-test-dataset>"
121 | )
122 | ```
123 | 
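For reference, the returned value is the Fréchet distance between two multivariate Gaussians fitted to the reference and test embeddings (this is what `_compute_fmd` in `fmd.py` evaluates):

```math
\mathrm{FMD} = \lVert \mu_{\mathrm{ref}} - \mu_{\mathrm{test}} \rVert^2 + \operatorname{Tr}\left(\Sigma_{\mathrm{ref}} + \Sigma_{\mathrm{test}} - 2\,(\Sigma_{\mathrm{ref}}\,\Sigma_{\mathrm{test}})^{1/2}\right)
```

where the means and covariances come from the selected Gaussian estimator; lower values indicate that the test embeddings are distributed more like the reference embeddings.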
124 | 
125 | #### FMD-Inf score
126 | ```python
127 | 
128 | result = metric.score_inf(
129 |     reference_path="<path-to-reference-dataset>",
130 |     test_path="<path-to-test-dataset>",
131 |     steps=<number-of-steps>,  # default=25
132 |     min_n=<minimum-sample-size>  # default=500
133 | )
134 | 
135 | result.score  # To get the FMD-Inf score
136 | result.r2  # To get the R^2 of the FMD-Inf linear regression
137 | result.slope  # To get the slope of the regression
138 | result.points  # To get the point estimates used in the FMD-Inf regression
139 | 
140 | ```
141 | 
142 | #### Individual (per-song) score
143 | ```python
144 | 
145 | result = metric.score_individual(
146 |     reference_path="<path-to-reference-dataset>",
147 |     test_song_path="<path-to-test-song>",
148 | )
149 | 
150 | ```
151 | 
152 | #### Cleanup
153 | Additionally, the pre-computed cache can be cleared like so:
154 | 
155 | ```python
156 | from frechet_music_distance.utils import clear_cache
157 | 
158 | clear_cache()
159 | ```
160 | 
161 | ## Extending FMD
162 | 
163 | ### Embedding Model
164 | 
165 | You can add your own model as a feature extractor like so:
166 | 
167 | ```python
168 | from frechet_music_distance.models import FeatureExtractor
169 | 
170 | class MyExtractor(FeatureExtractor):
171 | 
172 |     def __init__(self, verbose: bool = True) -> None:
173 |         super().__init__(verbose)
174 |         """"""
175 | 
176 | 
177 |     @torch.no_grad()
178 |     def _extract_feature(self, data: Any) -> NDArray:
179 |         """"""
180 | 
181 | 
182 |     def extract_features(self, dataset_path: str | Path) -> NDArray:
183 |         """"""
184 |         # load the dataset from dataset_path into `data` here
185 |         return super()._extract_features(data)
186 | 
187 | 
188 |     def extract_feature(self, filepath: str | Path) -> NDArray:
189 |         """"""
190 |         # load the single file from filepath into `data` here
191 |         return self._extract_feature(data)
192 | 
193 | 
194 | ```
195 | If your model uses the same data format as CLaMP 2 or CLaMP, you can use `frechet_music_distance.dataset_loaders.ABCLoader` or `frechet_music_distance.dataset_loaders.MIDIasMTFLoader` for loading music data.
196 | 
197 | ### Gaussian Estimator
198 | 
199 | You can add your own estimator like so:
200 | ```python
201 | from frechet_music_distance.gaussian_estimators import GaussianEstimator
202 | 
203 | class MyEstimator(GaussianEstimator):
204 | 
205 |     def __init__(self, num_samples: int = 1000) -> None:
206 |         super().__init__()
207 |         """"""
208 | 
209 |     def estimate_parameters(self, features: NDArray) -> tuple[NDArray, NDArray]:
210 |         """"""
211 |         # compute `mean` and `cov` from the features here
212 |         return mean, cov
213 | ```
214 | 
215 | ## Supported Embedding Models
216 | 
217 | | Model | Name in library | Description | Creator |
218 | | --- | --- | --- |-----------------|
219 | | [CLaMP](https://github.com/microsoft/muzic/tree/main/clamp) | `clamp` | CLaMP: Contrastive Language-Music Pre-training for Cross-Modal Symbolic Music Information Retrieval | Microsoft Muzic |
220 | | [CLaMP2](https://github.com/sanderwood/clamp2) | `clamp2` | CLaMP 2: Multimodal Music Information Retrieval Across 101 Languages Using Large Language Models | sanderwood |
221 | 
222 | 
223 | ## Citation
224 | 
225 | If you use Frechet Music Distance in your research, please cite the following paper:
226 | 
227 | ```bibtex
228 | @article{retkowski2024frechet,
229 |   title={Frechet Music Distance: A Metric For Generative Symbolic Music Evaluation},
230 |   author={Retkowski, Jan and St{\k{e}}pniak, Jakub and Modrzejewski, Mateusz},
231 |   journal={arXiv preprint arXiv:2412.07948},
232 |   year={2024}
233 | }
234 | ```
235 | 
236 | ## Acknowledgements
237 | 
238 | This library uses code from the following repositories for handling the embedding models:
239 | * CLaMP 1: [microsoft/muzic/clamp](https://github.com/microsoft/muzic/tree/main/clamp)
240 | * CLaMP 2: [sanderwood/clamp2](https://github.com/sanderwood/clamp2)
241 | 
242 | ## License
243 | This project is licensed under the **MIT License**. See the [LICENSE](LICENSE.txt) file for details.
244 | 
245 | ---
246 | 
-------------------------------------------------------------------------------- /frechet_music_distance/__init__.py: --------------------------------------------------------------------------------
1 | from .fmd import FrechetMusicDistance
2 | 
-------------------------------------------------------------------------------- /frechet_music_distance/__main__.py: --------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | from .fmd import FrechetMusicDistance
4 | from .utils import clear_cache
5 | 
6 | 
7 | def create_parser() -> argparse.ArgumentParser:
8 |     """
9 |     Create the top-level parser and subparsers for the 'score' and 'clear' commands.
10 |     """
11 |     parser = argparse.ArgumentParser(prog="fmd", description="A script for calculating Frechet Music Distance [FMD]")
12 | 
13 |     subparsers = parser.add_subparsers(dest="command", help="Sub-command to run")
14 |     # ------------------------
15 |     # Subparser: "score"
16 |     # ------------------------
17 |     score_parser = subparsers.add_parser("score", help="Compute Frechet Music Distance")
18 |     score_parser.add_argument("reference_dataset", nargs="?", help="Path to reference dataset")
19 |     score_parser.add_argument("test_dataset", nargs="?", help="Path to test dataset")
20 |     score_parser.add_argument("--model", "-m", choices=["clamp2", "clamp"], default="clamp2", help="Embedding model name")
21 |     score_parser.add_argument("--estimator", "-e", choices=["mle", "bootstrap", "oas", "shrinkage", "leodit_wolf"], default="mle", help="Gaussian estimator for mean and covariance")
22 |     score_parser.add_argument("--inf", action="store_true", help="Use FMD-Inf extrapolation")
23 |     score_parser.add_argument("--steps", "-s", default=25, type=int, help="Number of steps when calculating FMD-Inf")
24 |     score_parser.add_argument("--min_n", "-n", default=500, type=int, help="Minimum sample size when calculating FMD-Inf (must be smaller than the size of the test dataset)")
25 |     score_parser.add_argument("--clear-cache", action="store_true", help="Clear precomputed cache")
26 | 
27 |     # ------------------------
28 |     # Subparser: "clear"
29 |     # ------------------------
30 |     subparsers.add_parser("clear", help="Clear precomputed cache")
31 |     return parser
32 | 
33 | 
34 | def run_score(parser: argparse.ArgumentParser, args: argparse.Namespace, metric: FrechetMusicDistance) -> None:
35 |     if args.clear_cache:
36 |         clear_cache()
37 |     if not args.reference_dataset or not args.test_dataset:
38 |         parser.error("The following arguments are required: reference_dataset, test_dataset")
39 | 
40 |     if args.inf:
41 |         result = metric.score_inf(args.reference_dataset, args.test_dataset, steps=args.steps, min_n=args.min_n)
42 |         print(f"Frechet Music Distance [FMD-Inf]: {result.score}; R^2 = {result.r2}")
43 | 
44 |     else:
45 |         score = metric.score(args.reference_dataset, args.test_dataset)
46 |         print(f"Frechet Music Distance [FMD]: {score}")
47 | 
48 | 
49 | def main() -> None:
50 |     parser = create_parser()
51 |     args = parser.parse_args()
52 |     if args.command == "clear":
53 |         clear_cache()
54 | 
55 |     elif args.command == "score":
56 | 
57 |         metric = FrechetMusicDistance(feature_extractor=args.model, gaussian_estimator=args.estimator, verbose=True)
58 |         run_score(parser, args, metric)
59 | 
60 | 
61 | if __name__ == "__main__":
62 |     main()
63 | 
-------------------------------------------------------------------------------- /frechet_music_distance/dataset_loaders/__init__.py: --------------------------------------------------------------------------------
1 | from .abc_loader import ABCLoader
2 | from .dataset_loader import DatasetLoader
3 | from .midi_as_mtf_loader import MIDIasMTFLoader
4 | 
-------------------------------------------------------------------------------- /frechet_music_distance/dataset_loaders/abc_loader.py: --------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Union
3 | 
4 | from .dataset_loader import DatasetLoader
5 | 
6 | 
7 | class ABCLoader(DatasetLoader):
8 | 
9 |     def __init__(self, verbose: bool = True) -> None:
10 |         supported_extensions = (".abc",)
11 |         super().__init__(supported_extensions, verbose)
12 | 
13 |     def load_file(self, filepath:
Union[str, Path]) -> str: 14 | self._validate_file(filepath) 15 | 16 | with open(filepath, "r", encoding="utf-8") as file: 17 | data = file.read() 18 | 19 | return data 20 | -------------------------------------------------------------------------------- /frechet_music_distance/dataset_loaders/dataset_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from collections.abc import Iterable 5 | from functools import reduce 6 | from multiprocessing import Pool as ProcessPool 7 | from pathlib import Path 8 | from typing import Any 9 | 10 | from tqdm import tqdm 11 | 12 | 13 | class DatasetLoader(ABC): 14 | 15 | def __init__(self, supported_extensions: tuple[str], verbose: bool = True) -> None: 16 | self.verbose = verbose 17 | self._supported_extensions = supported_extensions 18 | 19 | @abstractmethod 20 | def load_file(self, filepath: str | Path) -> Any: 21 | pass 22 | 23 | def load_dataset(self, dataset: str | Path) -> Iterable[Any]: 24 | if self.verbose: 25 | print(f"Loading files from {dataset}") 26 | file_paths = self.get_file_paths(dataset) 27 | return self._load_files(file_paths) 28 | 29 | def load_dataset_async(self, dataset: str | Path) -> Iterable[Any]: 30 | if self.verbose: 31 | print(f"Loading files from {dataset}") 32 | file_paths = self.get_file_paths(dataset) 33 | return self._load_files_async(file_paths) 34 | 35 | def get_file_paths(self, dataset_path: str | Path) -> Iterable[str]: 36 | dataset_path = Path(dataset_path) 37 | file_paths = reduce( 38 | lambda acc, arr: acc + arr, 39 | [[str(f) for f in dataset_path.rglob(f"**/*{file_ext}")] for file_ext in self._supported_extensions] 40 | ) 41 | return file_paths 42 | 43 | def _load_files(self, file_paths: Iterable[str]) -> Iterable[Any]: 44 | results = [] 45 | 46 | pbar = tqdm(total=len(file_paths), disable=(not self.verbose)) 47 | 48 | for filepath in file_paths: 49 | res = self.load_file(filepath) 50 | results.append(res) 51 | pbar.update() 52 | 53 | return results 54 | 55 | def _load_files_async(self, file_paths: Iterable[str]) -> Iterable[Any]: 56 | task_results = [] 57 | 58 | pool = ProcessPool() 59 | pbar = tqdm(total=len(file_paths), disable=(not self.verbose)) 60 | 61 | for filepath in file_paths: 62 | res = pool.apply_async( 63 | self.load_file, 64 | args=(filepath,), 65 | callback=lambda *args, **kwargs: pbar.update(), 66 | ) 67 | task_results.append(res) 68 | pool.close() 69 | pool.join() 70 | 71 | return [task.get() for task in task_results] 72 | 73 | def _validate_file(self, filepath: str | Path) -> None: 74 | ext = Path(filepath).suffix 75 | if ext not in self._supported_extensions: 76 | msg = f"{self} supports the following extensions: {self._supported_extensions}, but got: {ext}" 77 | raise ValueError(msg) 78 | -------------------------------------------------------------------------------- /frechet_music_distance/dataset_loaders/midi_as_mtf_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import mido 6 | 7 | from .dataset_loader import DatasetLoader 8 | 9 | 10 | class MIDIasMTFLoader(DatasetLoader): 11 | 12 | def __init__(self, m3_compatible: bool = True, verbose: bool = True) -> None: 13 | supported_extensions = (".mid", ".midi") 14 | super().__init__(supported_extensions, verbose) 15 | self._m3_compatible = m3_compatible 16 | 17 | def load_file(self, filepath: str | 
Path) -> str: 18 | self._validate_file(filepath) 19 | 20 | skip_elements = {"text", "copyright", "track_name", "instrument_name", 21 | "lyrics", "marker", "cue_marker", "device_name", "sequencer_specific"} 22 | try: 23 | # Load a MIDI file 24 | mid = mido.MidiFile(str(filepath)) 25 | msg_list = ["ticks_per_beat " + str(mid.ticks_per_beat)] 26 | 27 | # Traverse the MIDI file 28 | for msg in mid.merged_track: 29 | if not self._m3_compatible or (msg.type != "sysex" and not (msg.is_meta and msg.type in skip_elements)): 30 | str_msg = self._msg_to_str(msg) 31 | msg_list.append(str_msg) 32 | except Exception as e: 33 | msg = f"Could not load file: {filepath}. Error: {e}" 34 | raise OSError(msg) from e 35 | 36 | return "\n".join(msg_list) 37 | 38 | def _msg_to_str(self, msg: str) -> str: 39 | str_msg = "" 40 | for value in msg.dict().values(): 41 | str_msg += " " + str(value) 42 | 43 | return str_msg.strip().encode("unicode_escape").decode("utf-8") 44 | -------------------------------------------------------------------------------- /frechet_music_distance/dataset_loaders/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | from .abc_loader import ABCLoader 6 | from .dataset_loader import DatasetLoader 7 | from .midi_as_mtf_loader import MIDIasMTFLoader 8 | 9 | 10 | def get_dataset_ext(dataset_path: str | Path, supported_extensions: set[str] | None = None) -> str | None: 11 | if supported_extensions is None: 12 | supported_extensions = {".mid", ".midi", ".abc"} 13 | 14 | for file in Path(dataset_path).rglob("**/*"): 15 | if file.suffix in supported_extensions: 16 | return file.suffix 17 | return None 18 | 19 | 20 | def get_dataset_loader_by_extension_and_model(file_ext: str, model_name: str, verbose: bool | None = None) -> DatasetLoader: 21 | if model_name == "clamp": 22 | if file_ext == ".abc": 23 | return ABCLoader(verbose=verbose) 24 | 25 | elif model_name == "clamp2": 26 | if file_ext == ".abc": 27 | return ABCLoader(verbose=verbose) 28 | elif file_ext in {".midi", ".mid"}: 29 | return MIDIasMTFLoader(verbose=verbose) 30 | 31 | msg = f"Unsupported file extension {file_ext} and model {model_name} combination" 32 | raise ValueError(msg) -------------------------------------------------------------------------------- /frechet_music_distance/fmd.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import scipy.linalg 8 | from numpy.typing import NDArray 9 | from tqdm import tqdm 10 | 11 | from .gaussian_estimators import GaussianEstimator, MaxLikelihoodEstimator 12 | from .gaussian_estimators.utils import get_estimator_by_name 13 | from .models import FeatureExtractor 14 | from .models.utils import get_feature_extractor_by_name 15 | 16 | 17 | @dataclass 18 | class FMDInfResults: 19 | score: float 20 | slope: float 21 | r2: float 22 | points: list[tuple[int, float]] 23 | 24 | 25 | class FrechetMusicDistance: 26 | 27 | def __init__( 28 | self, 29 | feature_extractor: str | FeatureExtractor = "clamp2", 30 | gaussian_estimator: str | GaussianEstimator = "mle", 31 | verbose: bool = True, 32 | ) -> None: 33 | if isinstance(feature_extractor, str): 34 | feature_extractor = get_feature_extractor_by_name(feature_extractor, verbose=verbose) 35 | 36 | if isinstance(gaussian_estimator, str): 37 | gaussian_estimator = 
get_estimator_by_name(gaussian_estimator) 38 | 39 | self._feature_extractor = feature_extractor 40 | self._gaussian_estimator = gaussian_estimator 41 | self._verbose = verbose 42 | 43 | def score(self, reference_path: str | Path, test_path: str | Path) -> float: 44 | reference_features = self._feature_extractor.extract_features(reference_path) 45 | mean_reference, covariance_reference = self._gaussian_estimator.estimate_parameters(reference_features) 46 | 47 | test_features = self._feature_extractor.extract_features(test_path) 48 | mean_test, covariance_test = self._gaussian_estimator.estimate_parameters(test_features) 49 | 50 | return self._compute_fmd(mean_reference, mean_test, covariance_reference, covariance_test) 51 | 52 | def score_inf( 53 | self, 54 | reference_path: str | Path, 55 | test_path: str | Path, 56 | steps: int = 25, 57 | min_n: int = 500, 58 | ) -> FMDInfResults: 59 | 60 | reference_features = self._feature_extractor.extract_features(reference_path) 61 | test_features = self._feature_extractor.extract_features(test_path) 62 | mean_reference, covariance_reference = self._gaussian_estimator.estimate_parameters(reference_features) 63 | 64 | score, slope, r2, points = self._compute_fmd_inf(mean_reference, covariance_reference, test_features, steps, min_n) 65 | return FMDInfResults(score, slope, r2, points) 66 | 67 | def score_individual(self, reference_path: str | Path, test_song_path: str | Path) -> float: 68 | reference_features = self._feature_extractor.extract_features(reference_path) 69 | test_feature = self._feature_extractor.extract_feature(test_song_path) 70 | mean_reference, covariance_reference = self._gaussian_estimator.estimate_parameters(reference_features) 71 | mean_test, covariance_test = test_feature.flatten(), covariance_reference 72 | 73 | return self._compute_fmd(mean_reference, mean_test, covariance_reference, covariance_test) 74 | 75 | def _compute_fmd( 76 | self, 77 | mean_reference: NDArray, 78 | mean_test: NDArray, 79 | cov_reference: NDArray, 80 | cov_test: NDArray, 81 | eps: float = 1e-6, 82 | ) -> float: 83 | mu_test = np.atleast_1d(mean_test) 84 | mu_ref = np.atleast_1d(mean_reference) 85 | 86 | sigma_test = np.atleast_2d(cov_test) 87 | sigma_ref = np.atleast_2d(cov_reference) 88 | 89 | assert ( 90 | mu_test.shape == mu_ref.shape 91 | ), f"Reference and test mean vectors have different dimensions, {mu_test.shape} and {mu_ref.shape}" 92 | assert ( 93 | sigma_test.shape == sigma_ref.shape 94 | ), f"Reference and test covariances have different dimensions, {sigma_test.shape} and {sigma_ref.shape}" 95 | 96 | diff = mu_test - mu_ref 97 | 98 | # Product might be almost singular 99 | covmean, _ = scipy.linalg.sqrtm(sigma_test.dot(sigma_ref), disp=False) 100 | if not np.isfinite(covmean).all(): 101 | msg = f"FMD calculation produces singular product; adding {eps} to diagonal of cov estimates" 102 | if self._verbose: 103 | print(msg) 104 | offset = np.eye(sigma_test.shape[0]) * eps 105 | covmean = scipy.linalg.sqrtm((sigma_test + offset).dot(sigma_ref + offset)) 106 | 107 | # Numerical error might give slight imaginary component 108 | if np.iscomplexobj(covmean): 109 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 110 | m = np.max(np.abs(covmean.imag)) 111 | msg = f"Imaginary component {m}" 112 | raise ValueError(msg) 113 | covmean = covmean.real 114 | 115 | tr_covmean = np.trace(covmean) 116 | 117 | return (diff.dot(diff) + np.trace(sigma_test) + np.trace(sigma_ref) - 2 * tr_covmean).item() 118 | 119 | def _compute_fmd_inf( 120 | self, 121 
| mean_reference: NDArray, 122 | cov_reference: NDArray, 123 | test_features: NDArray, 124 | steps: int = 25, 125 | min_n: int = 500, 126 | ) -> tuple[float, float, float, NDArray]: 127 | 128 | # Calculate maximum n 129 | max_n = len(test_features) 130 | 131 | assert min_n < max_n, f"min_n={min_n} must be smaller than number of elements in the test set: max_n={max_n}" 132 | 133 | # Generate list of ns to use 134 | ns = [int(n) for n in np.linspace(min_n, max_n, steps)] 135 | results = [] 136 | rng = np.random.default_rng() 137 | 138 | for n in tqdm(ns, desc="Calculating FMD-inf", disable=(not self._verbose)): 139 | # Select n feature frames randomly (with replacement) 140 | indices = rng.choice(test_features.shape[0], size=n, replace=True) 141 | sample_test_features = test_features[indices] 142 | 143 | mean_test, cov_test = MaxLikelihoodEstimator().estimate_parameters(sample_test_features) 144 | fad_score = self._compute_fmd(mean_reference, mean_test, cov_reference, cov_test) 145 | 146 | # Add to results 147 | results.append([n, fad_score]) 148 | 149 | # Compute FMD-inf based on linear regression of 1/n 150 | ys = np.array(results) 151 | xs = 1 / np.array(ns) 152 | slope, intercept = np.polyfit(xs, ys[:, 1], 1) 153 | 154 | # Compute R^2 155 | r2 = 1 - np.sum((ys[:, 1] - (slope * xs + intercept)) ** 2) / np.sum((ys[:, 1] - np.mean(ys[:, 1])) ** 2) 156 | 157 | # Since intercept is the FMD-inf, we can just return it 158 | return intercept.item(), slope.item(), r2.item(), results 159 | -------------------------------------------------------------------------------- /frechet_music_distance/gaussian_estimators/__init__.py: -------------------------------------------------------------------------------- 1 | from .shrikage_estimator import ShrinkageEstimator 2 | from .bootstrapping_estimator import BootstrappingEstimator 3 | from .gaussian_estimator import GaussianEstimator 4 | from .leodit_wolf_estimator import LeoditWolfEstimator 5 | from .max_likelihood_estimator import MaxLikelihoodEstimator 6 | from .oas_estimator import OASEstimator 7 | -------------------------------------------------------------------------------- /frechet_music_distance/gaussian_estimators/bootstrapping_estimator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.typing import NDArray 3 | 4 | from .gaussian_estimator import GaussianEstimator 5 | from .max_likelihood_estimator import MaxLikelihoodEstimator 6 | 7 | 8 | class BootstrappingEstimator(GaussianEstimator): 9 | 10 | def __init__(self, num_samples: int = 1000) -> None: 11 | super().__init__() 12 | self._num_samples = num_samples 13 | self._mle = MaxLikelihoodEstimator() 14 | self._rng = np.random.default_rng() 15 | 16 | def estimate_parameters(self, features: NDArray) -> tuple[NDArray, NDArray]: 17 | running_mean = 0 18 | runing_cov = np.zeros((features.shape[1], features.shape[1])) 19 | for _ in range(self._num_samples): 20 | sample_indices = self._rng.choice(features.shape[0], size=features.shape[0], replace=True) 21 | bootstrap_sample = features[sample_indices] 22 | mean, cov = self._mle.estimate_parameters(bootstrap_sample) 23 | running_mean += mean / self._num_samples 24 | runing_cov += cov / self._num_samples 25 | 26 | return running_mean, runing_cov 27 | -------------------------------------------------------------------------------- /frechet_music_distance/gaussian_estimators/gaussian_estimator.py: -------------------------------------------------------------------------------- 1 | from 
abc import ABC, abstractmethod 2 | 3 | from numpy.typing import NDArray 4 | 5 | from ..memory import MEMORY 6 | 7 | 8 | class GaussianEstimator(ABC): 9 | 10 | def __init__(self) -> None: 11 | self.estimate_parameters = MEMORY.cache(self.estimate_parameters, ignore=["self"]) 12 | 13 | @abstractmethod 14 | def estimate_parameters(self, features: NDArray) -> tuple[NDArray, NDArray]: 15 | pass 16 | -------------------------------------------------------------------------------- /frechet_music_distance/gaussian_estimators/leodit_wolf_estimator.py: -------------------------------------------------------------------------------- 1 | from numpy.typing import NDArray 2 | from sklearn.covariance import LedoitWolf 3 | 4 | from .gaussian_estimator import GaussianEstimator 5 | 6 | 7 | class LeoditWolfEstimator(GaussianEstimator): 8 | 9 | def __init__(self, block_size: int = 1000) -> None: 10 | super().__init__() 11 | self._model = LedoitWolf(assume_centered=False, block_size=block_size) 12 | 13 | def estimate_parameters(self, features: NDArray) -> tuple[NDArray, NDArray]: 14 | results = self._model.fit(features) 15 | 16 | mean = results.location_ 17 | cov = results.covariance_ 18 | return mean, cov 19 | -------------------------------------------------------------------------------- /frechet_music_distance/gaussian_estimators/max_likelihood_estimator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.typing import NDArray 3 | 4 | from .gaussian_estimator import GaussianEstimator 5 | 6 | 7 | class MaxLikelihoodEstimator(GaussianEstimator): 8 | 9 | def __init__(self) -> None: 10 | super().__init__() 11 | 12 | def estimate_parameters(self, features: NDArray) -> tuple[NDArray, NDArray]: 13 | mean = np.mean(features, axis=0) 14 | covariance = np.cov(features, rowvar=False) 15 | return mean, covariance 16 | -------------------------------------------------------------------------------- /frechet_music_distance/gaussian_estimators/oas_estimator.py: -------------------------------------------------------------------------------- 1 | from numpy.typing import NDArray 2 | from sklearn.covariance import OAS 3 | 4 | from .gaussian_estimator import GaussianEstimator 5 | 6 | 7 | class OASEstimator(GaussianEstimator): 8 | 9 | def __init__(self) -> None: 10 | super().__init__() 11 | self._model = OAS(assume_centered=False) 12 | 13 | def estimate_parameters(self, features: NDArray) -> tuple[NDArray, NDArray]: 14 | results = self._model.fit(features) 15 | 16 | mean = results.location_ 17 | cov = results.covariance_ 18 | return mean, cov 19 | -------------------------------------------------------------------------------- /frechet_music_distance/gaussian_estimators/shrikage_estimator.py: -------------------------------------------------------------------------------- 1 | from numpy.typing import NDArray 2 | from sklearn.covariance import ShrunkCovariance 3 | 4 | from .gaussian_estimator import GaussianEstimator 5 | 6 | 7 | class ShrinkageEstimator(GaussianEstimator): 8 | 9 | def __init__(self, shrinkage: float = 0.1) -> None: 10 | super().__init__() 11 | self._model = ShrunkCovariance(assume_centered=False, shrinkage=shrinkage) 12 | 13 | def estimate_parameters(self, features: NDArray) -> tuple[NDArray, NDArray]: 14 | results = self._model.fit(features) 15 | 16 | mean = results.location_ 17 | cov = results.covariance_ 18 | return mean, cov 19 | -------------------------------------------------------------------------------- 
/frechet_music_distance/gaussian_estimators/utils.py: -------------------------------------------------------------------------------- 1 | from .shrikage_estimator import ShrinkageEstimator 2 | from .bootstrapping_estimator import BootstrappingEstimator 3 | from .gaussian_estimator import GaussianEstimator 4 | from .leodit_wolf_estimator import LeoditWolfEstimator 5 | from .max_likelihood_estimator import MaxLikelihoodEstimator 6 | from .oas_estimator import OASEstimator 7 | 8 | 9 | def get_estimator_by_name(name: str) -> GaussianEstimator: 10 | if name == "mle": 11 | return MaxLikelihoodEstimator() 12 | elif name == "bootstrap": 13 | return BootstrappingEstimator() 14 | elif name == "shrinkage": 15 | return ShrinkageEstimator() 16 | elif name == "leodit_wolf": 17 | return LeoditWolfEstimator() 18 | elif name == "oas": 19 | return OASEstimator() 20 | else: 21 | msg = f"Unknown estimator: {name}, valid options are: mle, bootstrap, shrinkage, leodit_wolf, oas" 22 | raise ValueError(msg) 23 | -------------------------------------------------------------------------------- /frechet_music_distance/memory.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from joblib import Memory 4 | 5 | CAHE_MEMORY_DIR = Path.home() / ".cache" / "frechet_music_distance" / "precomputed" 6 | 7 | MEMORY = Memory(CAHE_MEMORY_DIR, verbose=0) 8 | MEMORY.reduce_size(bytes_limit="10G") 9 | -------------------------------------------------------------------------------- /frechet_music_distance/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .clamp import CLaMPExtractor 2 | from .clamp2 import CLaMP2Extractor 3 | from .feature_extractor import FeatureExtractor 4 | -------------------------------------------------------------------------------- /frechet_music_distance/models/clamp/__init__.py: -------------------------------------------------------------------------------- 1 | from .clamp_extractor import CLaMPExtractor 2 | -------------------------------------------------------------------------------- /frechet_music_distance/models/clamp/clamp_extractor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging as lg 4 | from pathlib import Path 5 | 6 | import torch 7 | from numpy.typing import NDArray 8 | 9 | from frechet_music_distance.dataset_loaders.abc_loader import ABCLoader, DatasetLoader 10 | from frechet_music_distance.dataset_loaders.utils import get_dataset_ext 11 | from frechet_music_distance.models.feature_extractor import FeatureExtractor 12 | 13 | from .clamp_model import CLaMP 14 | from .clamp_utils import PATCH_LENGTH, MusicPatchilizer 15 | 16 | logger = lg.getLogger(__name__) 17 | 18 | 19 | class CLaMPExtractor(FeatureExtractor): 20 | 21 | def __init__(self, verbose: bool = True) -> None: 22 | super().__init__(verbose) 23 | self._clamp_model_name = "sander-wood/clamp-small-1024" 24 | self._device = self._get_available_device() 25 | self._model = CLaMP.from_pretrained(self._clamp_model_name) 26 | self._model = self._model.to(self._device) 27 | self._model.eval() 28 | 29 | self._patchilizer = MusicPatchilizer() 30 | self._softmax = torch.nn.Softmax(dim=1) 31 | 32 | self._patch_length = PATCH_LENGTH 33 | self._abc_dataset_loader = ABCLoader(verbose=verbose) 34 | 35 | 36 | @staticmethod 37 | def _get_available_device() -> torch.device: 38 | if torch.cuda.is_available(): 39 | 
logger.info(f"There are {torch.cuda.device_count()} GPU(s) available.") 40 | logger.info(f"We will use the GPU: {torch.cuda.get_device_name(0)}") 41 | return torch.device("cuda") 42 | else: 43 | logger.info("No GPU available, using the CPU instead.") 44 | return torch.device("cpu") 45 | 46 | def _encoding_data(self, data: list[str], music_length: int) -> list[torch.Tensor]: 47 | """ 48 | Encode the data into ids 49 | 50 | Args: 51 | data (list): List of strings 52 | 53 | Returns: 54 | ids_list (list): List of ids 55 | """ 56 | ids_list = [] 57 | for item in data: 58 | patches = self._patchilizer.encode(item, music_length=music_length, add_eos_patch=True) 59 | new_patches = torch.tensor(patches) 60 | new_patches = new_patches.reshape(-1) 61 | ids_list.append(new_patches) 62 | return ids_list 63 | 64 | @staticmethod 65 | def _abc_filter(lines: list[str]) -> str: 66 | """ 67 | Filter out the metadata from the abc file 68 | 69 | Args: 70 | lines (list): List of lines in the abc file 71 | 72 | Returns: 73 | music (str): Music string 74 | """ 75 | music = "" 76 | for line in lines: 77 | if line[:2] in ["A:", "B:", "C:", "D:", "F:", "G", "H:", "N:", "O:", "R:", "r:", "S:", "T:", "W:", "w:", 78 | "X:", "Z:"] \ 79 | or line == "\n" \ 80 | or (line.startswith("%") and not line.startswith("%%score")): 81 | continue 82 | else: 83 | if "%" in line and not line.startswith("%%score"): 84 | line = "%".join(line.split("%")[:-1]) 85 | music += line[:-1] + "\n" 86 | else: 87 | music += line + '\n' 88 | return music 89 | 90 | 91 | def _get_features(self, ids_list: list[torch.Tensor]) -> torch.Tensor: 92 | """ 93 | Get the features from the CLaMP _model 94 | 95 | Args: 96 | ids_list (list): List of ids 97 | 98 | Returns: 99 | features_list (torch.Tensor): Tensor of features with a shape of (batch_size, hidden_size) 100 | """ 101 | 102 | features_list = [] 103 | with torch.no_grad(): 104 | for ids in ids_list: 105 | ids = ids.unsqueeze(0) 106 | masks = torch.tensor([1] * (int(len(ids[0]) / PATCH_LENGTH))).unsqueeze(0) 107 | features = self._model.music_enc(ids, masks)["last_hidden_state"] 108 | features = self._model.avg_pooling(features, masks) 109 | features = self._model.music_proj(features) 110 | features_list.append(features[0]) 111 | 112 | return torch.stack(features_list).to(self._device) 113 | 114 | @torch.no_grad() 115 | def _extract_feature(self, data: str) -> torch.Tensor: 116 | """ 117 | Extract features from the music data 118 | 119 | Args: 120 | data (str): music data in abc format 121 | Returns: 122 | features (torch.Tensor): Extracted features 123 | 124 | """ 125 | # self._abc_filter([data]) 126 | ids = self._encoding_data([data], music_length=self._model.config.max_length) 127 | features = self._get_features(ids_list=ids) 128 | return features.detach().cpu().numpy() 129 | 130 | def _choose_dataset_loader(self, extension: str) -> DatasetLoader: 131 | if extension == ".abc": 132 | return self._abc_dataset_loader 133 | else: 134 | msg = f"CLAmP 2 supports .abc files but got {extension}" 135 | raise ValueError(msg) 136 | 137 | def extract_features(self, dataset_path: str | Path) -> NDArray: 138 | extension = get_dataset_ext(dataset_path) 139 | data = self._choose_dataset_loader(extension).load_dataset_async(dataset_path) 140 | 141 | return super()._extract_features(data) 142 | 143 | def extract_feature(self, filepath: str | Path) -> NDArray: 144 | extension = Path(filepath).suffix 145 | data = self._choose_dataset_loader(extension).load_file(filepath) 146 | 147 | return self._extract_feature(data) 
148 | -------------------------------------------------------------------------------- /frechet_music_distance/models/clamp/clamp_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from typing import Tuple 5 | 6 | import requests 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoConfig, AutoModel, BertConfig, PreTrainedModel 10 | 11 | from .clamp_utils import MusicEncoder 12 | 13 | 14 | class CLaMP(PreTrainedModel): 15 | """ 16 | CLaMP model for joint text and music encoding. 17 | 18 | Args: 19 | config (:obj:`BertConfig`): Model configuration class with all the parameters of the model. 20 | Initializing with a config file does not load the weights associated with the model, only the configuration. 21 | Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. 22 | text_model_name (:obj:`str`, `optional`, defaults to :obj:`"distilroberta-base"`): 23 | The name of the pre-trained text model to be used for text encoding. 24 | 25 | Attributes: 26 | text_enc (:obj:`AutoModel`): The pre-trained text model used for text encoding. 27 | text_proj (:obj:`torch.nn.Linear`): A linear layer to project the text encoding to the hidden size of the model. 28 | music_enc (:obj:`MusicEncoder`): The music encoder model used for music encoding. 29 | music_proj (:obj:`torch.nn.Linear`): A linear layer to project the music encoding to the hidden size of the model. 30 | """ 31 | 32 | def __init__(self, config: BertConfig, text_model_name: str = "distilroberta-base") -> None: 33 | super().__init__(config) 34 | self.text_enc = AutoModel.from_pretrained(text_model_name) 35 | self.text_proj = torch.nn.Linear(config.hidden_size, config.hidden_size) 36 | torch.nn.init.normal_(self.text_proj.weight, std=0.02) 37 | 38 | self.music_enc = MusicEncoder(config=config) 39 | self.music_proj = torch.nn.Linear(config.hidden_size, config.hidden_size) 40 | torch.nn.init.normal_(self.music_proj.weight, std=0.02) 41 | 42 | def forward(self, input_texts: torch.LongTensor, text_masks: torch.LongTensor, input_musics: torch.LongTensor, 43 | music_masks: torch.LongTensor) -> tuple[torch.FloatTensor, torch.FloatTensor]: 44 | """ 45 | Args: 46 | input_texts (:obj:`torch.LongTensor` of shape :obj:`(batch_size, text_length)`): 47 | Tensor containing the integer-encoded text. 48 | text_masks (:obj:`torch.LongTensor` of shape :obj:`(batch_size, text_length)`): 49 | Tensor containing the attention masks for the text. 50 | input_musics (:obj:`torch.LongTensor` of shape :obj:`(batch_size, music_length, patch_length)`): 51 | Tensor containing the integer-encoded music patches. 52 | music_masks (:obj:`torch.LongTensor` of shape :obj:`(batch_size, music_length)`): 53 | Tensor containing the attention masks for the music patches. 54 | 55 | Returns: 56 | :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: 57 | music_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`): 58 | The music features extracted from the music encoder. 59 | text_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`): 60 | The text features extracted from the text encoder. 
61 | """ 62 | # Encode input texts 63 | text_features = self.text_enc(input_texts.to(self.device), attention_mask=text_masks.to(self.device))[ 64 | "last_hidden_state"] 65 | text_features = self.avg_pooling(text_features, text_masks) 66 | text_features = self.text_proj(text_features) 67 | 68 | # Encode input musics 69 | music_features = self.music_enc(input_musics, music_masks)["last_hidden_state"] 70 | music_features = self.avg_pooling(music_features, music_masks) 71 | music_features = self.music_proj(music_features) 72 | 73 | return music_features, text_features 74 | 75 | def avg_pooling(self, input_features: torch.FloatTensor, input_masks: torch.LongTensor) -> torch.FloatTensor: 76 | """ 77 | Applies average pooling to the input features. 78 | 79 | Args: 80 | input_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_length, hidden_size)`): 81 | Tensor containing the input features. 82 | input_masks (:obj:`torch.LongTensor` of shape :obj:`(batch_size, seq_length)`): 83 | Tensor containing the attention masks for the input features. 84 | 85 | Returns: 86 | :obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`: 87 | The pooled features. 88 | """ 89 | input_masks = input_masks.unsqueeze(-1).to(self.device) 90 | input_features = input_features * input_masks 91 | avg_pool = input_features.sum(dim=1) / input_masks.sum(dim=1) 92 | 93 | return avg_pool 94 | 95 | @classmethod 96 | def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs) -> CLaMP: 97 | """ 98 | Instantiate a CLaMP model from a pre-trained model configuration. 99 | 100 | Args: 101 | pretrained_model_name_or_path (:obj:`str`): 102 | This can be either: 103 | "clamp-small-512" for the small CLaMP model with 512 max sequence length. 104 | "clamp-small-1024" for the small CLaMP model with 1024 max sequence length. 105 | 106 | Returns: 107 | :class:`~transformers.CLaMP`: The CLaMP model. 108 | """ 109 | model_dir = "." 
+ pretrained_model_name_or_path 110 | 111 | # If the pre-trained model is not found locally, download it from Hugging Face 112 | if not os.path.exists(model_dir): 113 | # Create the model directory and download the config and pytorch model files 114 | print(f"Downloading CLaMP model from: {pretrained_model_name_or_path} to local machine") 115 | os.makedirs(model_dir) 116 | config_url = f"https://huggingface.co/{pretrained_model_name_or_path}/raw/main/config.json" 117 | model_url = f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/pytorch_model.bin" 118 | chunk_size = 1024 * 1024 # 1MB 119 | 120 | # download config file 121 | with requests.get(config_url, stream=True) as r: 122 | r.raise_for_status() 123 | total_size = int(r.headers.get("content-length", 0)) 124 | with open(model_dir + "/config.json", "wb") as f: 125 | with tqdm(total=total_size, unit="B", unit_scale=True, desc="Downloading config") as pbar: 126 | for chunk in r.iter_content(chunk_size=chunk_size): 127 | f.write(chunk) 128 | pbar.update(len(chunk)) 129 | 130 | # download pytorch model file 131 | with requests.get(model_url, stream=True) as r: 132 | r.raise_for_status() 133 | total_size = int(r.headers.get("content-length", 0)) 134 | with open(model_dir + "/pytorch_model.bin", "wb") as f: 135 | with tqdm(total=total_size, unit="B", unit_scale=True, desc="Downloading model") as pbar: 136 | for chunk in r.iter_content(chunk_size=chunk_size): 137 | f.write(chunk) 138 | pbar.update(len(chunk)) 139 | 140 | # Load the model weights and configuration 141 | config = AutoConfig.from_pretrained(model_dir, *model_args, **kwargs) 142 | model = cls(config) 143 | state_dict = torch.load(model_dir + str("/pytorch_model.bin"), weights_only=True) 144 | model.load_state_dict(state_dict, strict=False) 145 | 146 | return model 147 | -------------------------------------------------------------------------------- /frechet_music_distance/models/clamp/clamp_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Tuple 3 | 4 | import torch 5 | from unidecode import unidecode 6 | from transformers import BertModel, PreTrainedModel, BertConfig 7 | import contextlib 8 | 9 | # Constants for patch length and number of features in a patch 10 | PATCH_LENGTH = 64 11 | PATCH_FEATURES = 98 12 | 13 | class MusicPatchilizer: 14 | """ 15 | Class for converting music data to patches and vice-versa. 16 | 17 | Attributes: 18 | delimiters (tuple): A tuple of strings containing the delimiters used for splitting bars. 19 | regexPattern (str): A regular expression pattern for splitting bars. 20 | pad_id (int): The id of the padding token. 21 | mask_id (int): The id of the mask token. 22 | eos_id (int): The id of the end-of-sequence token. 23 | 24 | Methods: 25 | split_bars(body): Splits a body of music into individual bars using the delimiters specified in `self.delimiters`. 26 | bar2patch(bar, patch_length): Encodes a single bar as a patch of specified length. 27 | patch2bar(patch): Converts a patch to a bar string. 28 | encode(music, music_length, patch_length=PATCH_LENGTH, add_eos_patch=False): Encodes the input music string as a list of patches. 29 | decode(patches): Decodes a sequence of patches into a music score. 
30 | """ 31 | def __init__(self) -> None: 32 | # Delimiters used for splitting bars 33 | self.delimiters = "|:", "::", ":|", "[|", "||", "|]", "|" 34 | # Regular expression pattern for splitting bars 35 | self.regexPattern = "(" + "|".join(map(re.escape, self.delimiters)) + ")" 36 | # Padding, mask, and end-of-sequence token ids 37 | self.pad_id = 0 38 | self.mask_id = 96 39 | self.eos_id = 97 40 | 41 | def split_bars(self, body: str) -> list[str]: 42 | """ 43 | Splits a body of music into individual bars using the delimiters specified in `self.delimiters`. 44 | 45 | Args: 46 | body (str): A string containing the body of music to be split into bars. 47 | 48 | Returns: 49 | list: A list of strings containing the individual bars. 50 | """ 51 | body = "".join(body) 52 | bars = re.split(self.regexPattern, body) 53 | while("" in bars): 54 | bars.remove("") 55 | if bars[0] in self.delimiters: 56 | bars[1] = bars[0]+bars[1] 57 | bars = bars[1:] 58 | bars = [bars[i*2]+bars[i*2+1] for i in range(int(len(bars)/2))] 59 | 60 | return bars 61 | 62 | def bar2patch(self, bar: str, patch_length: int) -> list[int]: 63 | """ 64 | Encodes a single bar as a patch of specified length. 65 | 66 | Args: 67 | bar (str): A string containing the bar to be encoded. 68 | patch_length (int): An integer indicating the length of the patch to be returned. 69 | 70 | Returns: 71 | list: A list of integer-encoded musical tokens. 72 | """ 73 | patch = [self.pad_id] * patch_length 74 | 75 | for i in range(min(patch_length, len(bar))): 76 | chr = bar[i] 77 | idx = ord(chr) 78 | if 32 <= idx < 127: 79 | patch[i] = idx-31 80 | 81 | if i+1 str: 87 | """ 88 | Converts a patch to a bar string. 89 | 90 | Args: 91 | patch (list): A list of integer-encoded musical tokens. 92 | 93 | Returns: 94 | str: A string containing the decoded bar. 95 | """ 96 | bar = "" 97 | 98 | for idx in patch: 99 | if 0 < idx < 96: 100 | bar += chr(idx + 31) 101 | else: 102 | break 103 | 104 | return bar 105 | 106 | def encode(self, music: str, music_length: int, patch_length: int = PATCH_LENGTH, add_eos_patch: bool = False) -> list[list[int]]: 107 | """ 108 | Encodes the input music string as a list of patches. 109 | 110 | Args: 111 | music (str): A string containing the music to be encoded. 112 | music_length (int): An integer indicating the maximum number of patches to be returned. 113 | patch_length (int): An integer indicating the length of each patch. 114 | add_eos_patch (bool): A boolean indicating whether to add an extra patch consisting of all EOS tokens at the end of the encoded music. 115 | 116 | Returns: 117 | list: A list of integer-encoded patches. 
118 | """ 119 | # Convert to ASCII and split into lines 120 | music = unidecode(music) 121 | lines = music.split("\n") 122 | with contextlib.suppress(Exception): 123 | lines.remove("") 124 | 125 | body = "" 126 | patches = [] 127 | 128 | # Iterate over lines, splitting bars and encoding each one as a patch 129 | for line in lines: 130 | # check if the line is a music score line or not 131 | if len(line)>1 and ((line[0].isalpha() and line[1] == ":") or line.startswith("%%score")): 132 | # if the current line is a music score line, encode the previous body as patches 133 | if body!="": 134 | bars = self.split_bars(body) 135 | for bar in bars: 136 | # encode each bar in the body as a patch and append to the patches list 137 | patch = self.bar2patch(bar, patch_length) 138 | patches.append(patch) 139 | # reset the body variable 140 | body = "" 141 | # encode the current line as a patch and append to the patches list 142 | patch = self.bar2patch(line, patch_length) 143 | patches.append(patch) 144 | else: 145 | # if the line is not a music score line, append to the body variable 146 | body += line 147 | 148 | if body!="": 149 | bars = self.split_bars(body) 150 | 151 | for bar in bars: 152 | # encode each bar in the body as a patch and append to the patches list 153 | patch = self.bar2patch(bar, patch_length) 154 | patches.append(patch) 155 | # add an extra patch consisting of all EOS tokens, if required 156 | if add_eos_patch: 157 | eos_patch = [self.eos_id] * patch_length 158 | patches = patches + [eos_patch] 159 | 160 | return patches[:music_length] 161 | 162 | def decode(self, patches: list[list[int]]) -> str: 163 | """ 164 | Decodes a sequence of patches into a music score. 165 | 166 | Args: 167 | patches (list): A list of integer-encoded patches. 168 | 169 | Returns: 170 | str: A string containing the decoded music score. 171 | """ 172 | music = "" 173 | for patch in patches: 174 | music += self.patch2bar(patch) + "\n" 175 | 176 | return music 177 | 178 | 179 | class MusicEncoder(PreTrainedModel): 180 | """ 181 | MusicEncoder model for encoding music patches into a sequence of hidden states. 182 | 183 | Args: 184 | config (:obj:`BertConfig`): Model configuration class with all the parameters of the model. 185 | Initializing with a config file does not load the weights associated with the model, only the configuration. 186 | Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. 187 | 188 | Attributes: 189 | patch_embedding (:obj:`torch.nn.Linear`): A linear layer to convert the one-hot encoded patches to the hidden size of the model. 190 | enc (:obj:`BertModel`): The BERT model used to encode the patches. 191 | """ 192 | def __init__(self, config: BertConfig) -> None: 193 | super().__init__(config) 194 | self.patch_embedding = torch.nn.Linear(PATCH_LENGTH*PATCH_FEATURES, config.hidden_size) 195 | torch.nn.init.normal_(self.patch_embedding.weight, std=0.02) 196 | self.enc = BertModel(config=config) 197 | 198 | def forward(self, input_musics: torch.LongTensor, music_masks: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 199 | """ 200 | Args: 201 | input_musics (:obj:`torch.LongTensor` of shape :obj:`(batch_size, music_length, patch_length)`): 202 | Tensor containing the integer-encoded music patches. 203 | music_masks (:obj:`torch.LongTensor` of shape :obj:`(batch_size, music_length)`): 204 | Tensor containing the attention masks for the music patches. 
205 | 206 | Returns: 207 | :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: 208 | last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, music_length, hidden_size)`): 209 | Sequence of hidden-states at the output of the last layer of the model. 210 | """ 211 | # One-hot encode the input music patches 212 | input_musics = torch.nn.functional.one_hot(input_musics, num_classes=PATCH_FEATURES) 213 | 214 | # Reshape the input music patches to feed into the linear layer 215 | input_musics = input_musics.reshape(len(input_musics), -1, PATCH_LENGTH*PATCH_FEATURES).type(torch.FloatTensor) 216 | 217 | # Apply the linear layer to convert the one-hot encoded patches to hidden features 218 | input_musics = self.patch_embedding(input_musics.to(self.device)) 219 | 220 | # Apply the BERT model to encode the music data 221 | output = self.enc(inputs_embeds=input_musics, attention_mask=music_masks.to(self.device)) 222 | 223 | return output -------------------------------------------------------------------------------- /frechet_music_distance/models/clamp2/__init__.py: -------------------------------------------------------------------------------- 1 | from .clamp2_extractor import CLaMP2Extractor 2 | -------------------------------------------------------------------------------- /frechet_music_distance/models/clamp2/clamp2_extractor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import torch 6 | from accelerate import Accelerator 7 | from numpy.typing import NDArray 8 | from transformers import AutoTokenizer, BertConfig 9 | 10 | from frechet_music_distance.dataset_loaders import ABCLoader, DatasetLoader, MIDIasMTFLoader 11 | from frechet_music_distance.dataset_loaders.utils import get_dataset_ext 12 | from frechet_music_distance.models.feature_extractor import FeatureExtractor 13 | from frechet_music_distance.utils import download_file 14 | 15 | from . 
import config 16 | from .clamp2_model import CLaMP2 17 | from .m3_patchilizer import M3Patchilizer 18 | 19 | 20 | class CLaMP2Extractor(FeatureExtractor): 21 | 22 | def __init__(self, verbose: bool = True) -> None: 23 | super().__init__(verbose) 24 | self._accelerator = Accelerator() 25 | self._device = self._accelerator.device 26 | self._midi_dataset_loader = MIDIasMTFLoader(verbose=verbose) 27 | self._abc_dataset_loader = ABCLoader(verbose=verbose) 28 | 29 | m3_config = BertConfig( 30 | vocab_size=1, 31 | hidden_size=config.M3_HIDDEN_SIZE, 32 | num_hidden_layers=config.PATCH_NUM_LAYERS, 33 | num_attention_heads=config.M3_HIDDEN_SIZE//64, 34 | intermediate_size=config.M3_HIDDEN_SIZE*4, 35 | max_position_embeddings=config.PATCH_LENGTH 36 | ) 37 | self._model = CLaMP2(m3_config, text_model_name=config.TEXT_MODEL_NAME, hidden_size=config.CLAMP2_HIDDEN_SIZE) 38 | self._model = self._model.to(self._device) 39 | self._tokenizer = AutoTokenizer.from_pretrained(config.TEXT_MODEL_NAME) 40 | self._patchilizer = M3Patchilizer() 41 | 42 | self._model.eval() 43 | 44 | try: 45 | self._checkpoint = torch.load(config.CLAMP2_WEIGHTS_PATH, map_location="cpu", weights_only=True) 46 | except Exception: 47 | self._download_checkpoint() 48 | self._checkpoint = torch.load(config.CLAMP2_WEIGHTS_PATH, map_location="cpu", weights_only=True) 49 | 50 | self._model.load_state_dict(self._checkpoint["model"]) 51 | 52 | def _download_checkpoint(self) -> None: 53 | print(f"Downloading CLaMP2 weights from: {config.CLAMP2_WEIGHTS_URL} into {config.CLAMP2_WEIGHTS_PATH}") 54 | download_file(config.CLAMP2_WEIGHTS_URL, config.CLAMP2_WEIGHTS_PATH, verbose=self._verbose) 55 | 56 | @torch.no_grad() 57 | def _extract_feature(self, data: str) -> NDArray: 58 | 59 | input_data = self._patchilizer.encode(data, add_special_patches=True) 60 | input_data = torch.tensor(input_data) 61 | max_input_length = config.PATCH_LENGTH 62 | 63 | segment_list = [] 64 | for i in range(0, len(input_data), max_input_length): 65 | segment_list.append(input_data[i:i+max_input_length]) 66 | segment_list[-1] = input_data[-max_input_length:] 67 | 68 | last_hidden_states_list = [] 69 | 70 | for input_segment in segment_list: 71 | input_masks = torch.tensor([1]*input_segment.size(0)) 72 | pad_indices = torch.ones((config.PATCH_LENGTH - input_segment.size(0), config.PATCH_SIZE)).long() * self._patchilizer.pad_token_id 73 | input_masks = torch.cat((input_masks, torch.zeros(max_input_length - input_segment.size(0))), 0) 74 | input_segment = torch.cat((input_segment, pad_indices), 0) 75 | last_hidden_states = self._model.get_music_features(music_inputs=input_segment.unsqueeze(0).to(self._device), 76 | music_masks=input_masks.unsqueeze(0).to(self._device)) 77 | last_hidden_states_list.append(last_hidden_states) 78 | 79 | full_chunk_cnt = len(input_data) // max_input_length 80 | remain_chunk_len = len(input_data) % max_input_length 81 | if remain_chunk_len == 0: 82 | feature_weights = torch.tensor([max_input_length] * full_chunk_cnt, device=self._device).view(-1, 1) 83 | else: 84 | feature_weights = torch.tensor([max_input_length] * full_chunk_cnt + [remain_chunk_len], device=self._device).view(-1, 1) 85 | 86 | last_hidden_states_list = torch.concat(last_hidden_states_list, 0) 87 | last_hidden_states_list = last_hidden_states_list * feature_weights 88 | last_hidden_states_list = last_hidden_states_list.sum(dim=0) / feature_weights.sum() 89 | 90 | return last_hidden_states_list.unsqueeze(0).detach().cpu().numpy() 91 | 92 | def _choose_dataset_loader(self, extension: 
str) -> DatasetLoader: 93 | if extension in (".mid", ".midi"): 94 | return self._midi_dataset_loader 95 | elif extension == ".abc": 96 | return self._abc_dataset_loader 97 | else: 98 | msg = f"CLaMP 2 supports .mid, .midi and .abc files but got {extension}" 99 | raise ValueError(msg) 100 | 101 | def extract_features(self, dataset_path: str | Path) -> NDArray: 102 | extension = get_dataset_ext(dataset_path) 103 | data = self._choose_dataset_loader(extension).load_dataset_async(dataset_path) 104 | 105 | return super()._extract_features(data) 106 | 107 | def extract_feature(self, filepath: str | Path) -> NDArray: 108 | extension = Path(filepath).suffix 109 | data = self._choose_dataset_loader(extension).load_file(filepath) 110 | 111 | return self._extract_feature(data) 112 | -------------------------------------------------------------------------------- /frechet_music_distance/models/clamp2/clamp2_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModel, BertConfig, PreTrainedModel 3 | 4 | from .config import CLAMP2_HIDDEN_SIZE, M3_HIDDEN_SIZE, TEXT_MODEL_NAME 5 | from .m3_patch_encoder import M3PatchEncoder 6 | 7 | 8 | class CLaMP2(PreTrainedModel): 9 | 10 | def __init__( 11 | self, 12 | music_config: BertConfig, 13 | text_model_name: str = TEXT_MODEL_NAME, 14 | hidden_size: int = CLAMP2_HIDDEN_SIZE 15 | ) -> None: 16 | 17 | super().__init__(music_config) 18 | 19 | self.text_model = AutoModel.from_pretrained(text_model_name) # Load the text model 20 | self.text_proj = torch.nn.Linear(self.text_model.config.hidden_size, hidden_size) # Linear layer for text projections 21 | torch.nn.init.normal_(self.text_proj.weight, std=0.02) # Initialize weights with normal distribution 22 | 23 | self.music_model = M3PatchEncoder(music_config) # Initialize the music model 24 | self.music_proj = torch.nn.Linear(M3_HIDDEN_SIZE, hidden_size) # Linear layer for music projections 25 | torch.nn.init.normal_(self.music_proj.weight, std=0.02) # Initialize weights with normal distribution 26 | 27 | def avg_pooling(self, input_features: torch.Tensor, input_masks: torch.Tensor) -> torch.Tensor: 28 | input_masks = input_masks.unsqueeze(-1).to(self.device) # add a dimension to match the feature dimension 29 | input_features = input_features * input_masks # apply mask to input_features 30 | avg_pool = input_features.sum(dim=1) / input_masks.sum(dim=1) # calculate average pooling 31 | 32 | return avg_pool 33 | 34 | def get_music_features(self, music_inputs: torch.Tensor, music_masks: torch.Tensor) -> torch.Tensor: 35 | music_features = self.music_model(music_inputs.to(self.device), music_masks.to(self.device))["last_hidden_state"] 36 | 37 | # Normalize features (Reduce Temporal Dimension) 38 | music_features = self.avg_pooling(music_features, music_masks) 39 | music_features = self.music_proj(music_features) 40 | 41 | return music_features -------------------------------------------------------------------------------- /frechet_music_distance/models/clamp2/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | PATCH_SIZE = 64 # Size of each patch 5 | PATCH_LENGTH = 512 # Maximum number of patches in a sequence 6 | PATCH_NUM_LAYERS = 12 # Number of layers in the encoder 7 | TOKEN_NUM_LAYERS = 3 # Number of layers in the decoder 8 | M3_HIDDEN_SIZE = 768 # Size of the hidden layer 9 | 10 | # -------------------- Configuration for CLaMP2 ---------------- 11 | CLAMP2_HIDDEN_SIZE
= 768 # Size of the hidden layer 12 | TEXT_MODEL_NAME = "FacebookAI/xlm-roberta-base" # Name of the pre-trained text model 13 | 14 | CLAMP2_NUM_EPOCH = 100 # Maximum number of epochs for training 15 | CLAMP2_LEARNING_RATE = 5e-5 # Learning rate for the optimizer 16 | CLAMP2_BATCH_SIZE = 128 # Batch size per GPU (single card) during training 17 | LOGIT_SCALE = 1 # Scaling factor for contrastive loss 18 | MAX_TEXT_LENGTH = 128 # Maximum allowed length for text input 19 | TEXT_DROPOUT = True # Whether to apply dropout during text processing 20 | CLAMP2_DETERMINISTIC = True # Ensures deterministic results with random seeds 21 | CLAMP2_LOAD_M3 = True # Load weights from the M3 model 22 | CLAMP2_WEIGHTS_URL = "https://huggingface.co/sander-wood/clamp2/resolve/main/weights_clamp2_h_size_768_lr_5e-05_batch_128_scale_1_t_length_128_t_model_FacebookAI_xlm-roberta-base_t_dropout_True_m3_True.pth" 23 | 24 | CLAMP2_WEIGHT_DIR = Path.home() / ".cache" / "frechet_music_distance" / "checkpoints" / "clamp2" 25 | CLAMP_CKPT_NAME = ( 26 | "weights_clamp2_h_size_" + str(CLAMP2_HIDDEN_SIZE) + 27 | "_lr_" + str(CLAMP2_LEARNING_RATE) + 28 | "_batch_" + str(CLAMP2_BATCH_SIZE) + 29 | "_scale_" + str(LOGIT_SCALE) + 30 | "_t_length_" + str(MAX_TEXT_LENGTH) + 31 | "_t_model_" + TEXT_MODEL_NAME.replace("/", "_") + 32 | "_t_dropout_" + str(TEXT_DROPOUT) + 33 | "_m3_" + str(CLAMP2_LOAD_M3) + ".pth" 34 | ) 35 | CLAMP2_WEIGHTS_PATH = CLAMP2_WEIGHT_DIR / CLAMP_CKPT_NAME # Path to store CLaMP2 model weights 36 | -------------------------------------------------------------------------------- /frechet_music_distance/models/clamp2/m3_patch_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import BertConfig, BertModel, PreTrainedModel 3 | 4 | from .config import M3_HIDDEN_SIZE, PATCH_SIZE 5 | 6 | 7 | class M3PatchEncoder(PreTrainedModel): 8 | 9 | def __init__(self, config: BertConfig) -> None: 10 | super().__init__(config) 11 | self.patch_embedding = torch.nn.Linear(PATCH_SIZE*128, M3_HIDDEN_SIZE) 12 | torch.nn.init.normal_(self.patch_embedding.weight, std=0.02) 13 | self.base = BertModel(config=config) 14 | self.pad_token_id = 0 15 | self.bos_token_id = 1 16 | self.eos_token_id = 2 17 | self.mask_token_id = 3 18 | 19 | def forward( 20 | self, 21 | input_patches: torch.Tensor, # [batch_size, seq_length, hidden_size] 22 | input_masks: torch.Tensor # [batch_size, seq_length] 23 | ) -> torch.Tensor: 24 | 25 | # Transform input_patches into embeddings 26 | input_patches = torch.nn.functional.one_hot(input_patches, num_classes=128) 27 | input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE*128).type(torch.FloatTensor) 28 | input_patches = self.patch_embedding(input_patches.to(self.device)) 29 | 30 | # Apply BERT model to input_patches and input_masks 31 | return self.base(inputs_embeds=input_patches, attention_mask=input_masks) 32 | -------------------------------------------------------------------------------- /frechet_music_distance/models/clamp2/m3_patchilizer.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | from typing import Iterable 4 | 5 | from unidecode import unidecode 6 | 7 | from .config import PATCH_LENGTH, PATCH_SIZE 8 | 9 | 10 | class M3Patchilizer: 11 | 12 | def __init__(self) -> None: 13 | self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"] 14 | self.regexPattern = "(" + "|".join(map(re.escape, self.delimiters)) + ")" 15 | 
self.pad_token_id = 0 16 | self.bos_token_id = 1 17 | self.eos_token_id = 2 18 | self.mask_token_id = 3 19 | 20 | def split_bars(self, body: Iterable[str]) -> list[str]: 21 | bars = re.split(self.regexPattern, "".join(body)) 22 | bars = list(filter(None, bars)) # remove empty strings 23 | if bars[0] in self.delimiters: 24 | bars[1] = bars[0] + bars[1] 25 | bars = bars[1:] 26 | bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)] 27 | return bars 28 | 29 | def bar2patch(self, bar: str, patch_size: int = PATCH_SIZE) -> list[int]: 30 | patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id] 31 | patch = patch[:patch_size] 32 | patch += [self.pad_token_id] * (patch_size - len(patch)) 33 | return patch 34 | 35 | def patch2bar(self, patch: list[int]) -> str: 36 | return "".join(chr(idx) if idx > self.mask_token_id else "" for idx in patch) 37 | 38 | def encode( 39 | self, 40 | item: str, 41 | patch_size: int = PATCH_SIZE, 42 | add_special_patches: bool = False, 43 | truncate: bool = False, 44 | random_truncate: bool = False, 45 | ) -> list[list[int]]: 46 | 47 | item = unidecode(item) 48 | lines = re.findall(r".*?\n|.*$", item) 49 | lines = list(filter(None, lines)) # remove empty lines 50 | 51 | patches = [] 52 | 53 | if lines[0].split(" ")[0] == "ticks_per_beat": 54 | patch = "" 55 | for line in lines: 56 | if patch.startswith(line.split(" ")[0]) and (len(patch) + len(" ".join(line.split(" ")[1:])) <= patch_size-2): 57 | patch = patch[:-1] + "\t" + " ".join(line.split(" ")[1:]) 58 | else: 59 | if patch: 60 | patches.append(patch) 61 | patch = line 62 | if patch!="": 63 | patches.append(patch) 64 | else: 65 | for line in lines: 66 | if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%')): 67 | patches.append(line) 68 | else: 69 | bars = self.split_bars(line) 70 | if bars: 71 | bars[-1] += "\n" 72 | patches.extend(bars) 73 | 74 | if add_special_patches: 75 | bos_patch = chr(self.bos_token_id) * patch_size 76 | eos_patch = chr(self.eos_token_id) * patch_size 77 | patches = [bos_patch] + patches + [eos_patch] 78 | 79 | if len(patches) > PATCH_LENGTH and truncate: 80 | choices = ["head", "tail", "middle"] 81 | choice = random.choice(choices) 82 | if choice=="head" or random_truncate is False: 83 | patches = patches[:PATCH_LENGTH] 84 | elif choice=="tail": 85 | patches = patches[-PATCH_LENGTH:] 86 | else: 87 | start = random.randint(1, len(patches)-PATCH_LENGTH) 88 | patches = patches[start:start+PATCH_LENGTH] 89 | 90 | patches = [self.bar2patch(patch) for patch in patches] 91 | 92 | return patches 93 | 94 | def decode(self, patches: list[list[int]]) -> str: 95 | return "".join(self.patch2bar(patch) for patch in patches) 96 | -------------------------------------------------------------------------------- /frechet_music_distance/models/feature_extractor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from pathlib import Path 5 | from typing import Any, Iterable 6 | 7 | from ..memory import MEMORY 8 | import numpy as np 9 | from numpy.typing import NDArray 10 | from tqdm import tqdm 11 | 12 | 13 | class FeatureExtractor(ABC): 14 | 15 | def __init__(self, verbose: bool = True) -> None: 16 | self._verbose = verbose 17 | self.extract_features = MEMORY.cache(self.extract_features, ignore=["self"]) 18 | 19 | @abstractmethod 20 | def _extract_feature(self, data: Any) -> NDArray: 21 | pass 22 | 23 | def 
_extract_features(self, data: Iterable[Any]) -> NDArray: 24 | features = [] 25 | 26 | for song in tqdm(data, desc="Extracting features", disable=(not self._verbose)): 27 | feature = self._extract_feature(song) 28 | features.append(feature) 29 | 30 | return np.vstack(features) 31 | 32 | @abstractmethod 33 | def extract_features(self, dataset_path: str | Path) -> NDArray: 34 | pass 35 | 36 | @abstractmethod 37 | def extract_feature(self, filepath: str | Path) -> NDArray: 38 | pass 39 | -------------------------------------------------------------------------------- /frechet_music_distance/models/utils.py: -------------------------------------------------------------------------------- 1 | from .clamp import CLaMPExtractor 2 | from .clamp2 import CLaMP2Extractor 3 | from .feature_extractor import FeatureExtractor 4 | 5 | 6 | def get_feature_extractor_by_name(name: str, verbose: bool = True) -> FeatureExtractor: 7 | if name == "clamp2": 8 | return CLaMP2Extractor(verbose=verbose) 9 | elif name == "clamp": 10 | return CLaMPExtractor(verbose=verbose) 11 | else: 12 | msg = f"Unknown feature extractor: {name}, valid options are: clamp, clamp2" 13 | raise ValueError(msg) 14 | -------------------------------------------------------------------------------- /frechet_music_distance/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import requests 6 | from tqdm import tqdm 7 | 8 | from .memory import MEMORY 9 | 10 | KB = 1024 11 | MB = 1024 * KB 12 | 13 | 14 | def download_file(url: str, destination: str | Path, verbose: bool = True, chunk_size: int = 10 * MB) -> None: 15 | try: 16 | with requests.get(url, stream=True) as response: 17 | response.raise_for_status() 18 | total_size = int(response.headers.get("content-length", 0)) 19 | 20 | if verbose: 21 | progress_bar = tqdm(total=total_size, unit="B", unit_scale=True) 22 | 23 | destination = Path(destination) 24 | destination.parent.mkdir(parents=True, exist_ok=True) 25 | with open(destination, "wb") as file: 26 | for chunk in response.iter_content(chunk_size=chunk_size): 27 | if verbose: 28 | progress_bar.update(len(chunk)) 29 | file.write(chunk) 30 | 31 | if verbose: 32 | progress_bar.close() 33 | 34 | except requests.exceptions.RequestException as e: 35 | print(f"Error downloading the file from url: {url}. Error: {e}") 36 | 37 | 38 | def clear_cache() -> None: 39 | MEMORY.clear(warn=False) 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "frechet_music_distance" 7 | version = "1.0.0" 8 | dependencies = [ 9 | "abctoolkit", 10 | "accelerate", 11 | "joblib", 12 | "numpy", 13 | "tqdm", 14 | "scipy", 15 | "requests", 16 | "mido", 17 | "transformers", 18 | "torch", 19 | "unidecode", 20 | "scikit-learn" 21 | ] 22 | requires-python = ">=3.9" 23 | authors = [ 24 | {name = "jryban"}, 25 | ] 26 | maintainers = [ 27 | {name = "jryban"}, 28 | ] 29 | description = "A library for computing Frechet Music Distance." 
30 | readme = "README.md" 31 | license = {file = "LICENSE.txt"} 32 | keywords = [ 33 | "frechet", "music", "distance", "metric", "symbolic", "evaluation", 34 | "generative", "frechet music distance", "symbolic music", "frechet distance", 35 | "music metric", "symbolic music evaluation" 36 | ] 37 | classifiers = [ 38 | "Programming Language :: Python :: 3", 39 | "License :: OSI Approved :: MIT License", 40 | "Operating System :: OS Independent", 41 | ] 42 | 43 | [project.urls] 44 | Homepage = "https://github.com/jryban/frechet-music-distance" 45 | Repository = "https://github.com/jryban/frechet-music-distance.git" 46 | 47 | [project.scripts] 48 | fmd = "frechet_music_distance.__main__:main" 49 | 50 | [tool.setuptools] 51 | packages = [ 52 | "frechet_music_distance", 53 | "frechet_music_distance.models", 54 | "frechet_music_distance.models.clamp2", 55 | "frechet_music_distance.models.clamp", 56 | "frechet_music_distance.gaussian_estimators", 57 | "frechet_music_distance.dataset_loaders", 58 | ] 59 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | abctoolkit==0.0.4 2 | accelerate==1.2.1 3 | certifi==2024.12.14 4 | charset-normalizer==3.4.0 5 | filelock==3.16.1 6 | fsspec==2024.12.0 7 | huggingface-hub==0.27.0 8 | idna==3.10 9 | jellyfish==1.1.3 10 | Jinja2==3.1.5 11 | joblib==1.4.2 12 | MarkupSafe==3.0.2 13 | mido==1.3.3 14 | mpmath==1.3.0 15 | networkx==3.4.2 16 | numpy==2.2.1 17 | packaging==24.2 18 | psutil==6.1.1 19 | PyYAML==6.0.2 20 | RapidFuzz==3.11.0 21 | regex==2024.11.6 22 | requests==2.32.3 23 | safetensors==0.4.5 24 | scipy==1.14.1 25 | setuptools==75.6.0 26 | sympy==1.13.1 27 | tokenizers==0.21.0 28 | torch==2.5.1 29 | tqdm==4.67.1 30 | transformers==4.47.1 31 | typing_extensions==4.12.2 32 | Unidecode==1.3.8 33 | urllib3==2.3.0 34 | scikit-learn==1.6.1 35 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | pytest==8.3.* -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pytest 3 | 4 | from frechet_music_distance.fmd import FrechetMusicDistance 5 | 6 | 7 | @pytest.fixture(scope="session", name="test_data_path") 8 | def fixture_test_data_path() -> Path: 9 | return Path("tests/data").resolve(strict=True) 10 | 11 | 12 | @pytest.fixture(scope="session", name="midi_data_path") 13 | def fixture_midi_data_path(test_data_path) -> Path: 14 | return test_data_path / "midi" 15 | 16 | 17 | @pytest.fixture(scope="session", name="abc_data_path") 18 | def fixture_abc_data_path(test_data_path) -> Path: 19 | return test_data_path / "abc" 20 | 21 | 22 | @pytest.fixture(scope="session", name="abc_song_path") 23 | def fixture_abc_song_path(abc_data_path) -> Path: 24 | return abc_data_path / "example_1.abc" 25 | 26 | 27 | @pytest.fixture(scope="session", name="midi_song_path") 28 | def fixture_midi_song_path(midi_data_path) -> Path: 29 | return midi_data_path / "example_1.mid" 30 | 31 | 32 | @pytest.fixture(scope="session", name="base_fmd_clamp") 33 | def fixture_base_fmd_clamp(): 34 | return FrechetMusicDistance(feature_extractor="clamp", gaussian_estimator="mle", verbose=False) 35 | 36 | 37 | @pytest.fixture(scope="session", name="base_fmd_clamp2") 38 | def 
fixture_base_fmd_clamp2(): 39 | return FrechetMusicDistance(feature_extractor="clamp2", gaussian_estimator="mle", verbose=False) -------------------------------------------------------------------------------- /tests/data/abc/example_1.abc: -------------------------------------------------------------------------------- 1 | M:9/8 2 | K:Cmaj 3 | G E E E 2 D E D C | G E E E F G A B c | G E E E 2 D E D C | A D D G E C D 2 A | G E E E 2 D E D C | G E E E F G A B c | G E E E 2 D E D C | A D D G E C D 2 D | E D E c 2 A B A G | E D E A /2 B /2 c A B 2 D | E D E c 2 A B A G | A D D D E G A 2 D | E D E c 2 A B A G | E D E A /2 B /2 c A B 2 B | G A B c B A B A G | A D D D E G A B c | 4 | -------------------------------------------------------------------------------- /tests/data/abc/example_2.abc: -------------------------------------------------------------------------------- 1 | M:6/8 2 | L: 1/8 3 | K:Cdor 4 | |: e 2 g g g f d | e 2 g c c' b g | g c' g g a b | b g f g f d | e 2 g f g f d | e c g c c' b g | g c' g a b d' 3 | b g f g 3 | g c' g b g d b | b g f g f d | -------------------------------------------------------------------------------- /tests/data/abc/example_3.abc: -------------------------------------------------------------------------------- 1 | M:4/4 2 | L: 1/8 3 | K:Cmaj 4 | V:1 5 | |: G 2 | c 2 c 2 B 2 d 2 | c 2 B 2 A 2 G 2 | A 2 B 2 c 2 d 2 | e 2 g 2 g 2 d e | f 2 d 2 B 2 A B | c 2 B 2 A 2 G 2 | A B c A G F E D | C 2 C 2 C 2 :| |: G 2 | c 2 c 2 c 2 d e | f 2 f 2 f 2 e d | e 2 g 2 d 2 e 2 | c B A G A 2 G 2 | c 2 c 2 B 2 c d | e 2 e 2 e 2 d e | f 2 d 2 B c d B | c 2 c 2 c 2 :| -------------------------------------------------------------------------------- /tests/data/abc/example_4.abc: -------------------------------------------------------------------------------- 1 | M: 3/4 2 | L: 1/16 3 | K: A 4 | A3B|=c4A4E4|A3BA4A3B|=c4A4E4|A3BA4A3=c|B4=G4E4|=G3AG4G3A|B4=G4F4|E3FE2z2=C3D|E4E4=F4|=G3AG4A3A|E4E4A4|B3=cB4A3B|=c4A4E4|A3BA4 -------------------------------------------------------------------------------- /tests/data/abc/example_5.abc: -------------------------------------------------------------------------------- 1 | M: 2/4 2 | L: 1/16 3 | K: G 4 | |B3AG2E2|D2E2D4|D2c2A4|D2B2G4|B3AG2E2|D2E2D4|D2c2A2F2|G4z4|A2B2A2B2|F2B2A4|A2A2B2^c2|d2D2D4|D2D2G2G2|B2B2d4|c2D2A4|B2D2G4|D2D2G2G2|B2B2d4|c2D2A2D2|G4z4 -------------------------------------------------------------------------------- /tests/data/midi/example_1.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jryban/frechet-music-distance/86fd0efc653d74ec85866a421ee78404675c8ea2/tests/data/midi/example_1.mid -------------------------------------------------------------------------------- /tests/data/midi/example_2.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jryban/frechet-music-distance/86fd0efc653d74ec85866a421ee78404675c8ea2/tests/data/midi/example_2.mid -------------------------------------------------------------------------------- /tests/data/midi/example_3.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jryban/frechet-music-distance/86fd0efc653d74ec85866a421ee78404675c8ea2/tests/data/midi/example_3.mid -------------------------------------------------------------------------------- /tests/data/midi/example_4.mid: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jryban/frechet-music-distance/86fd0efc653d74ec85866a421ee78404675c8ea2/tests/data/midi/example_4.mid -------------------------------------------------------------------------------- /tests/data/midi/example_5.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jryban/frechet-music-distance/86fd0efc653d74ec85866a421ee78404675c8ea2/tests/data/midi/example_5.mid -------------------------------------------------------------------------------- /tests/test_fmd.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from frechet_music_distance import FrechetMusicDistance 4 | from frechet_music_distance.fmd import FMDInfResults 5 | from frechet_music_distance.models import CLaMP2Extractor, CLaMPExtractor 6 | from frechet_music_distance.utils import clear_cache 7 | from frechet_music_distance.gaussian_estimators.utils import get_estimator_by_name 8 | 9 | 10 | class TestFrechetMusicDistance: 11 | @staticmethod 12 | def test_fmd_clamp2_basic_creation(base_fmd_clamp2): 13 | assert base_fmd_clamp2 is not None 14 | assert base_fmd_clamp2._verbose is False 15 | assert isinstance(base_fmd_clamp2._feature_extractor, CLaMP2Extractor) 16 | clear_cache() 17 | 18 | @staticmethod 19 | def test_basic_creation_clamp(base_fmd_clamp): 20 | assert base_fmd_clamp is not None 21 | assert base_fmd_clamp._verbose is False 22 | assert isinstance(base_fmd_clamp._feature_extractor, CLaMPExtractor) 23 | clear_cache() 24 | 25 | @staticmethod 26 | @pytest.mark.parametrize("input_dataset_path", ["midi_data_path", "abc_data_path"]) 27 | @pytest.mark.parametrize("estimator_name", ["shrinkage", "mle", "leodit_wolf", "bootstrap", "oas"]) 28 | def test_clamp2_score(base_fmd_clamp2, midi_data_path, abc_data_path, input_dataset_path, estimator_name): 29 | current_dataset = locals()[input_dataset_path] 30 | feature_extractor = get_estimator_by_name(estimator_name) 31 | base_fmd_clamp2._gaussian_estimator = feature_extractor 32 | score = base_fmd_clamp2.score(current_dataset, current_dataset) 33 | assert isinstance(score, float) 34 | assert score == pytest.approx(0, abs=0.1) 35 | clear_cache() 36 | 37 | @staticmethod 38 | @pytest.mark.parametrize("input_dataset_path", ["midi_data_path", "abc_data_path"]) 39 | @pytest.mark.parametrize("estimator_name", ["shrinkage", "mle", "leodit_wolf", "bootstrap", "oas"]) 40 | def test_clamp2_score_inf(base_fmd_clamp2, midi_data_path, abc_data_path, input_dataset_path, estimator_name): 41 | current_dataset = locals()[input_dataset_path] 42 | feature_extractor = get_estimator_by_name(estimator_name) 43 | base_fmd_clamp2._gaussian_estimator = feature_extractor 44 | score = base_fmd_clamp2.score_inf(current_dataset, current_dataset, steps=3, min_n=3) 45 | assert isinstance(score, FMDInfResults) 46 | assert isinstance(score.score, float) 47 | assert isinstance(score.r2, float) 48 | assert isinstance(score.slope, float) 49 | assert isinstance(score.points, list) 50 | clear_cache() 51 | 52 | @staticmethod 53 | @pytest.mark.parametrize("estimator_name", ["shrinkage", "mle", "leodit_wolf", "bootstrap", "oas"]) 54 | def test_clamp2_score_individual_midi(base_fmd_clamp2, midi_data_path, midi_song_path, estimator_name): 55 | feature_extractor = get_estimator_by_name(estimator_name) 56 | base_fmd_clamp2._gaussian_estimator = feature_extractor 57 | score = base_fmd_clamp2.score_individual(midi_data_path, midi_song_path) 58 | assert isinstance(score, float) 59 
| assert score == pytest.approx(339, abs=10) 60 | clear_cache() 61 | 62 | @staticmethod 63 | @pytest.mark.parametrize("estimator_name", ["shrinkage", "mle", "leodit_wolf", "bootstrap", "oas"]) 64 | def test_clamp2_score_individual_abc(base_fmd_clamp2, abc_data_path, abc_song_path, estimator_name): 65 | feature_extractor = get_estimator_by_name(estimator_name) 66 | base_fmd_clamp2._gaussian_estimator = feature_extractor 67 | score = base_fmd_clamp2.score_individual(abc_data_path, abc_song_path) 68 | assert isinstance(score, float) 69 | assert score == pytest.approx(275, abs=10) 70 | clear_cache() 71 | 72 | @staticmethod 73 | @pytest.mark.parametrize("estimator_name", ["shrinkage", "mle", "leodit_wolf", "bootstrap", "oas"]) 74 | def test_clamp_score(base_fmd_clamp, abc_data_path, estimator_name): 75 | feature_extractor = get_estimator_by_name(estimator_name) 76 | base_fmd_clamp._gaussian_estimator = feature_extractor 77 | score = base_fmd_clamp.score(abc_data_path, abc_data_path) 78 | assert isinstance(score, float) 79 | assert score == pytest.approx(0, abs=0.1) 80 | clear_cache() 81 | 82 | @staticmethod 83 | @pytest.mark.parametrize("estimator_name", ["shrinkage", "mle", "leodit_wolf", "bootstrap", "oas"]) 84 | def test_clamp_score_inf(base_fmd_clamp, abc_data_path, estimator_name): 85 | feature_extractor = get_estimator_by_name(estimator_name) 86 | base_fmd_clamp._gaussian_estimator = feature_extractor 87 | score = base_fmd_clamp.score_inf(abc_data_path, abc_data_path, steps=3, min_n=3) 88 | assert isinstance(score, FMDInfResults) 89 | assert isinstance(score.score, float) 90 | assert isinstance(score.r2, float) 91 | assert isinstance(score.slope, float) 92 | assert isinstance(score.points, list) 93 | clear_cache() 94 | 95 | @staticmethod 96 | @pytest.mark.parametrize("estimator_name", ["shrinkage", "mle", "leodit_wolf", "bootstrap", "oas"]) 97 | def test_clamp_score_individual(base_fmd_clamp, abc_data_path, abc_song_path, estimator_name): 98 | score = base_fmd_clamp.score_individual(abc_data_path, abc_song_path) 99 | assert isinstance(score, float) 100 | assert score == pytest.approx(90, abs=10) 101 | clear_cache() --------------------------------------------------------------------------------
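For quick reference, the test suite above doubles as documentation of the public API. The snippet below is a minimal usage sketch assembled from the calls exercised in tests/test_fmd.py and tests/conftest.py; the dataset directories and the song path are placeholders, not files shipped with the repository.

```python
from frechet_music_distance import FrechetMusicDistance
from frechet_music_distance.utils import clear_cache

# Build the metric with the CLaMP 2 feature extractor and an MLE Gaussian estimator,
# mirroring the fixtures in tests/conftest.py.
fmd = FrechetMusicDistance(feature_extractor="clamp2", gaussian_estimator="mle", verbose=True)

# Compare two datasets: directories of .mid/.midi or .abc files, one song per file.
# The paths below are placeholders.
score = fmd.score("data/reference_set", "data/generated_set")
print(f"FMD: {score:.3f}")

# FMD-Inf extrapolation, as exercised by test_clamp2_score_inf
# (steps and min_n follow the small values used in the tests).
inf_results = fmd.score_inf("data/reference_set", "data/generated_set", steps=3, min_n=3)
print(f"FMD-Inf: {inf_results.score:.3f} (R^2 = {inf_results.r2:.3f})")

# Per-song FMD against a reference set, useful for spotting outliers.
song_score = fmd.score_individual("data/reference_set", "data/generated_set/song_1.mid")
print(f"Per-song FMD: {song_score:.3f}")

# Extracted features and distribution parameters are cached between runs; clear if needed.
clear_cache()
```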