├── .gitignore ├── README.md ├── configs └── example.yml ├── docs └── img │ └── mir_ref_logo.svg ├── mir_ref ├── __init__.py ├── conduct.py ├── dataloaders.py ├── datasets │ ├── custom_datasets.py │ ├── dataset.py │ ├── datasets │ │ ├── magnatagatune.py │ │ ├── mtg_jamendo.py │ │ └── vocalset.py │ └── mirdata_datasets.py ├── deform.py ├── deformations.py ├── evaluate.py ├── extract.py ├── features │ ├── custom_features.py │ ├── feature_extraction.py │ └── models │ │ ├── clmr.py │ │ ├── harmonic_cnn.py │ │ ├── mule.py │ │ └── openl3.py ├── metrics.py ├── probes │ └── probe_builder.py ├── train.py └── utils.py ├── requirements.txt ├── run.py └── tests ├── test_cfg.yml ├── test_datasets.py ├── test_deform.py └── test_generate.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # macos
156 | **/.DS_Store
157 | 
158 | # vscode
159 | .vscode/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # mir_ref
4 | 
5 | Representation Evaluation Framework for Music Information Retrieval tasks | [Paper](https://arxiv.org/abs/2312.05994)
6 | 
7 | `mir_ref` is an open-source library for evaluating audio representations (embeddings or others) on a variety of music-related downstream tasks and datasets. It has two main capabilities:
8 | 
9 | * Using a config file, you can specify all evaluation experiments you want to run. `mir_ref` will automatically acquire data and conduct the experiments - no coding or data handling needed. Many tasks, datasets, embedding models, etc. are ready to use (see the supported options section).
10 | * You can easily integrate your own features (e.g. embedding models), datasets, probes, metrics, and audio deformations and use them in `mir_ref` experiments.
11 | 
12 | `mir_ref` builds upon existing reproducibility efforts in music audio, including [`mirdata`](https://mirdata.readthedocs.io/en/stable/) for data handling, [`mir_eval`](https://craffel.github.io/mir_eval/) for evaluation metrics, and [`essentia models`](https://essentia.upf.edu/models.html) for pretrained audio models.
13 | 
14 | ## Disclaimer
15 | 
16 | A first beta release is expected by the end of April; it will include many fixes and documentation improvements.
17 | 
18 | ## Setup
19 | 
20 | Clone the repository. Create and activate a Python (>3.9) environment and install the requirements.
21 | 
22 | ```
23 | cd mir_ref
24 | pip install -r requirements.txt
25 | ```
26 | 
27 | ## Running
28 | 
29 | To run the experiments specified in the config file `configs/example.yml` end-to-end:
30 | 
31 | ```
32 | python run.py conduct -c example
33 | ```
34 | 
35 | This currently saves deformations and features to disk for later use; an online option will be available soon.
36 | 
37 | Alternatively, `mir_ref` comprises four main commands: `deform`, for generating deformations from a dataset; `extract`, for extracting features; `train`, for training probes; and `evaluate`, for evaluating them. These can be run as follows:
38 | 
39 | ```
40 | python run.py COMMAND -c example
41 | ```
42 | 
43 | `deform` optionally accepts an `n_jobs` option to parallelize deformation computation, and `extract` accepts the `--no_overwrite` flag to skip recomputing existing features.
44 | 
45 | #### Configuration file
46 | 
47 | An example configuration file is provided. Config files are written in YAML. A list of experiments is expected at the top level, and each experiment contains a task, datasets, features, and probes.
For each dataset, a list of deformation scenarios can be specified, following the argument syntax of [audiomentations](https://iver56.github.io/audiomentations/). 48 | 49 | ## Currently supported options 50 | 51 | ### Datasets and Tasks 52 | 53 | * `magnatagatune`: MagnaTagATune (autotagging) 54 | * `mtg_jamendo`: MTG Jamendo (autotagging) 55 | * `vocalset`: VocalSet (singer_identification, technique_identification) 56 | * `tinysol`: TinySol (instrument_classification, pitch_class_classification) 57 | * `beaport`: Beatport (key_estimation) 58 | 59 | ~Many more soon 60 | 61 | ### Features 62 | 63 | * `effnet-discogs` 64 | * `vggish-audioset` 65 | * `msd-musicnn` 66 | * `openl3` 67 | * `neuralfp` 68 | * `clmr-v2` 69 | * `mert-v1-95m-6` / `mert-v1-95m-0-1-2-3` / `mert-v1-95m-4-5-6-7-8` / `mert-v1-95m-9-10-11-12` (referring to the layers used) 70 | * `maest` 71 | 72 | ~More soon 73 | 74 | ## Example results 75 | 76 | We conducted an example evaluation of 7 models in 6 tasks with 4 deformations and 5 different probing setups. For the full results, please refer to the 'Evaluation' chapter of [this thesis document](https://zenodo.org/records/8380471), pages 39-58. 77 | 78 | ## Citing 79 | 80 | If you use `mir_ref`, please cite the following [paper](https://arxiv.org/abs/2312.05994): 81 | 82 | ``` 83 | @inproceedings{mir_ref, 84 | author = {Christos Plachouras and Pablo Alonso-Jim\'enez and Dmitry Bogdanov}, 85 | title = {mir_ref: A Representation Evaluation Framework for Music Information Retrieval Tasks}, 86 | booktitle = {37th Conference on Neural Information Processing Systems (NeurIPS), Machine Learning for Audio Workshop}, 87 | address = {New Orleans, LA, USA}, 88 | year = 2023, 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /configs/example.yml: -------------------------------------------------------------------------------- 1 | experiments: 2 | - task: 3 | name: autotagging 4 | type: multilabel_classification 5 | feature_aggregation: mean 6 | datasets: 7 | - name: magnatagatune 8 | type: custom 9 | dir: data/magnatagatune/ 10 | split_type: single 11 | deformations: 12 | - - type: AddGaussianSNR 13 | params: 14 | min_snr_in_db: 15 15 | max_snr_in_db: 15 16 | p: 1 17 | - - type: AddGaussianSNR 18 | params: 19 | min_snr_in_db: 0 20 | max_snr_in_db: 0 21 | p: 1 22 | - - type: Gain 23 | params: 24 | min_gain_in_db: -12 25 | max_gain_in_db: -12 26 | p: 1 27 | features: 28 | - vggish-audioset 29 | - clmr-v2 30 | - mert-v1-95m-6 31 | probes: 32 | - type: classifier 33 | emb_dim_reduction: False 34 | emb_shape: infer 35 | hidden_units: [] 36 | output_activation: sigmoid 37 | weight_decay: 1.0e-5 38 | # optimizer 39 | optimizer: adam 40 | learning_rate: 1.0e-3 41 | # training 42 | batch_size: 1083 43 | epochs: 100 44 | patience: 10 45 | train_sampling: random 46 | - type: classifier 47 | emb_dim_reduction: False 48 | emb_shape: infer 49 | hidden_units: [infer] 50 | output_activation: sigmoid 51 | weight_decay: 1.0e-5 52 | # optimizer 53 | optimizer: adam 54 | learning_rate: 1.0e-3 55 | # training 56 | batch_size: 1083 57 | epochs: 100 58 | patience: 10 59 | train_sampling: random 60 | - type: classifier 61 | emb_dim_reduction: False 62 | emb_shape: infer 63 | hidden_units: [256, 128] 64 | output_activation: sigmoid 65 | weight_decay: 1.0e-5 66 | # optimizer 67 | optimizer: adam 68 | learning_rate: 1.0e-3 69 | # training 70 | batch_size: 1083 71 | epochs: 100 72 | patience: 10 73 | train_sampling: random 74 | 75 | 
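# An additional deformation scenario could be specified as sketched below
# (commented out, illustrative only), chaining two deformations within one
# scenario. PitchShift and Mp3Compression are both handled in
# mir_ref/deformations.py; the parameter names follow the audiomentations
# argument syntax and may differ across audiomentations versions.
#
#     deformations:
#       - - type: PitchShift
#           params:
#             min_semitones: -2
#             max_semitones: -2
#             p: 1
#         - type: Mp3Compression
#           params:
#             min_bitrate: 64
#             max_bitrate: 64
#             p: 1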
-------------------------------------------------------------------------------- /docs/img/mir_ref_logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 17 | 37 | 39 | 42 | 46 | 50 | 51 | 54 | 58 | 62 | 63 | 77 | 91 | 105 | 119 | 133 | 147 | 156 | 165 | 174 | 183 | 192 | 201 | 202 | 207 | 211 | 218 | 223 | 228 | 233 | 238 | 239 | 240 | 241 | -------------------------------------------------------------------------------- /mir_ref/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrispla/mir_ref/691ae42815db6359ef66c1e175be55be42bbc340/mir_ref/__init__.py -------------------------------------------------------------------------------- /mir_ref/conduct.py: -------------------------------------------------------------------------------- 1 | """End-to-end experiment conduct. Currently, this is a hacky way of 2 | doing things, and it will be replaced soon with an online process.""" 3 | 4 | 5 | from mir_ref.deform import deform 6 | from mir_ref.evaluate import evaluate 7 | from mir_ref.extract import generate 8 | from mir_ref.train import train 9 | 10 | 11 | def conduct(cfg_path): 12 | """Conduct experiments in config end-to-end. 13 | 14 | Args: 15 | cfg_path (str): Path to config file. 16 | """ 17 | deform(cfg_path=cfg_path, n_jobs=1) 18 | generate(cfg_path=cfg_path) 19 | train(cfg_path=cfg_path) 20 | evaluate(cfg_path=cfg_path) 21 | -------------------------------------------------------------------------------- /mir_ref/dataloaders.py: -------------------------------------------------------------------------------- 1 | """Dataloaders. 2 | Code from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly 3 | """ 4 | 5 | from pathlib import Path 6 | 7 | import keras 8 | import numpy as np 9 | 10 | 11 | class DataGenerator(keras.utils.Sequence): 12 | """Generates data for Keras""" 13 | 14 | def __init__( 15 | self, 16 | ids_list, 17 | labels_dict, 18 | paths_dict, 19 | dim, 20 | n_classes, 21 | batch_size=8, 22 | shuffle=True, 23 | ): 24 | """Initialization""" 25 | self.dim = dim 26 | self.batch_size = batch_size 27 | self.ids_list = ids_list 28 | self.labels_dict = labels_dict 29 | self.paths_dict = paths_dict 30 | self.n_classes = n_classes 31 | self.shuffle = shuffle 32 | self.on_epoch_end() 33 | 34 | def __len__(self): 35 | """Denotes the number of batches per epoch. 
Last non-full batch is dropped.""" 36 | return int(np.floor(len(self.ids_list) / self.batch_size)) 37 | 38 | def __getitem__(self, index): 39 | """Generate one batch of data""" 40 | # Generate indexes of the batch 41 | indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size] 42 | 43 | # Find list of ids 44 | ids_list_temp = [self.ids_list[k] for k in indexes] 45 | 46 | # Generate data 47 | X, y = self.__data_generation(ids_list_temp=ids_list_temp) 48 | 49 | return X, y 50 | 51 | def on_epoch_end(self): 52 | """Updates indexes after each epoch""" 53 | self.indexes = np.arange(len(self.ids_list)) 54 | if self.shuffle is True: 55 | np.random.shuffle(self.indexes) 56 | 57 | def __data_generation(self, ids_list_temp): 58 | """Generates data containing batch_size samples 59 | X : (n_samples, *dim) 60 | """ 61 | # Initialization 62 | X = np.empty((self.batch_size, *self.dim)) 63 | y = np.empty((self.batch_size, self.n_classes), dtype=int) 64 | 65 | # Generate data 66 | for i, t_id in enumerate(ids_list_temp): 67 | # Store sample 68 | emb = np.load(self.paths_dict[t_id]) 69 | X[i,] = np.reshape(emb, *self.dim) 70 | 71 | y[i] = self.labels_dict[t_id] 72 | 73 | return X, y 74 | -------------------------------------------------------------------------------- /mir_ref/datasets/custom_datasets.py: -------------------------------------------------------------------------------- 1 | """Implement your own datasets here. Provide custom implementations 2 | for the given methods, without changing any of the given method 3 | and input/output names. 4 | """ 5 | 6 | from mir_ref.datasets.dataset import Dataset 7 | 8 | 9 | class CustomDataset0(Dataset): 10 | def __init__( 11 | self, 12 | name, 13 | dataset_type, 14 | data_dir, 15 | split_type, 16 | task_name, 17 | task_type, 18 | feature_aggregation, 19 | deformations_cfg, 20 | features_cfg, 21 | ): 22 | """Custom Dataset class. 23 | 24 | Args: 25 | name (str): Name of the dataset. 26 | dataset_type (str): Type of the dataset ("mirdata", "custom") 27 | data_dir (str): Path to the dataset directory. 28 | split_type (str): split_type (str, optional): Whether to use "all" or "single" split. 29 | Defaults to "single". If list of 3 floats, 30 | use a train, val, test split sizes. 31 | task_name (str): Name of the task. 32 | task_type (str): Type of the task. ("multiclass_classification", 33 | "multilabel_classification", "regression") 34 | feature_aggregation (str): Type of embedding aggregation ("mean", None) 35 | deformations_cfg (dict): Deformations config. 36 | features_cfg (dict): Embedding models config. 
37 | """ 38 | super().__init__( 39 | name=name, 40 | dataset_type=dataset_type, 41 | data_dir=data_dir, 42 | split_type=split_type, 43 | task_name=task_name, 44 | task_type=task_type, 45 | feature_aggregation=feature_aggregation, 46 | deformations_cfg=deformations_cfg, 47 | features_cfg=features_cfg, 48 | ) 49 | 50 | @Dataset.try_to_load_metadata 51 | def download(self): 52 | pass 53 | 54 | @Dataset.try_to_load_metadata 55 | def download_metadata(self): 56 | # optional method 57 | pass 58 | 59 | def load_track_ids(self): 60 | # return list of track_ids 61 | pass 62 | 63 | def load_audio_paths(self): 64 | # return dict of track_id: audio_path 65 | pass 66 | 67 | @Dataset.check_metadata_is_loaded 68 | def get_splits(self): 69 | # return list of dicts with {train: [track_id_0, ...], 70 | # test: [", ...], val: [", ...]} 71 | pass 72 | -------------------------------------------------------------------------------- /mir_ref/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | """Generic Dataset class and dataset factory. 2 | """ 3 | 4 | import os 5 | from pathlib import Path 6 | 7 | 8 | def get_dataset(dataset_cfg, task_cfg, features_cfg): 9 | kwargs = { 10 | "name": dataset_cfg["name"], 11 | "dataset_type": dataset_cfg["type"], 12 | "data_dir": dataset_cfg["dir"], 13 | "split_type": dataset_cfg["split_type"], 14 | "task_name": task_cfg["name"], 15 | "task_type": task_cfg["type"], 16 | "feature_aggregation": task_cfg["feature_aggregation"], 17 | "deformations_cfg": dataset_cfg["deformations"], 18 | "features_cfg": features_cfg, 19 | } 20 | if dataset_cfg["type"] == "mirdata": 21 | from mir_ref.datasets.mirdata_datasets import MirdataDataset 22 | 23 | return MirdataDataset(**kwargs) 24 | 25 | elif dataset_cfg["type"] == "custom": 26 | if dataset_cfg["name"] == "vocalset": 27 | from mir_ref.datasets.datasets.vocalset import VocalSet 28 | 29 | return VocalSet(**kwargs) 30 | elif dataset_cfg["name"] in [ 31 | "mtg-jamendo-moodtheme", 32 | "mtg-jamendo-instruments", 33 | "mtg-jamendo-genre", 34 | "mtg-jamendo-top50tags", 35 | ]: 36 | from mir_ref.datasets.datasets.mtg_jamendo import MTG_Jamendo 37 | 38 | return MTG_Jamendo(**kwargs) 39 | elif dataset_cfg["name"] == "magnatagatune": 40 | from mir_ref.datasets.datasets.magnatagatune import MagnaTagATune 41 | 42 | return MagnaTagATune(**kwargs) 43 | else: 44 | raise NotImplementedError( 45 | f"Custom dataset with name '{dataset_cfg['name']}' does not exist." 46 | ) 47 | 48 | else: 49 | raise NotImplementedError 50 | 51 | 52 | class Dataset: 53 | def __init__( 54 | self, 55 | name, 56 | dataset_type, 57 | task_name, 58 | task_type, 59 | feature_aggregation, 60 | deformations_cfg, 61 | features_cfg, 62 | split_type="single", 63 | data_dir=None, 64 | ): 65 | """Generic Dataset class. 66 | 67 | Args: 68 | name (str): Name of the dataset. 69 | dataset_type (str): Type of the dataset ("mirdata", "custom") 70 | task_name (str): Name of the task. 71 | task_type (str): Type of the task. 72 | data_dir (str, optional): Path to the dataset directory. 73 | Defaults to ./data/{name}/. 74 | split_type (str, optional): Whether to use "all" or "single" split. 75 | Defaults to "single". If list of 3 floats, 76 | use a train, val, test split sizes. 77 | deformations_cfg (list, optional): List of deformation scenarios. 78 | features_cfg (list, optional): List of embedding models. 
79 | """ 80 | self.deformations_cfg = deformations_cfg 81 | self.features_cfg = features_cfg 82 | self.name = name 83 | self.dataset_type = dataset_type 84 | self.data_dir = data_dir 85 | if data_dir is None: 86 | self.data_dir = f"./data/{name}/" 87 | self.split_type = split_type 88 | self.task_name = task_name 89 | self.task_type = task_type 90 | self.feature_aggregation = feature_aggregation 91 | try: 92 | self.load_metadata() 93 | except: 94 | self.track_ids = None # list of track_ids 95 | self.labels = None # dict of track_id: label 96 | self.encoded_labels = None # dict of track_id: encoded_label 97 | self.audio_paths = None # dict of track_id: audio_path 98 | self.common_audio_dir = None # common directory between all audio paths 99 | 100 | def check_params(self): 101 | # check validity of provided parameters 102 | return 103 | 104 | def check_metadata_is_loaded(func): 105 | """Decorator to check if metadata is loaded.""" 106 | 107 | def wrapper(self): 108 | if any( 109 | var is None 110 | for var in [ 111 | self.track_ids, 112 | self.labels, 113 | self.encoded_labels, 114 | self.audio_paths, 115 | self.common_audio_dir, 116 | ] 117 | ): 118 | raise ValueError( 119 | "Metadata not loaded. Make sure to run 'dataset.download()'." 120 | ) 121 | return func(self) 122 | 123 | return wrapper 124 | 125 | def try_to_load_metadata(func): 126 | """Try to load metadata after a function. Aimed at functions 127 | that download or preprocess datasets. 128 | """ 129 | 130 | def wrapper(self): 131 | func(self) 132 | try: 133 | self.load_metadata() 134 | except Exception as e: 135 | print(e) 136 | 137 | return wrapper 138 | 139 | def download(self): 140 | # to be overwritten by child class 141 | return 142 | 143 | def download_metadata(self): 144 | # to be overwritten by child class 145 | return 146 | 147 | def preprocess(self): 148 | # to be overwritten by child class 149 | return 150 | 151 | def load_track_ids(self): 152 | # to be overwritten by child class 153 | return 154 | 155 | def load_audio_paths(self): 156 | # to be overwritten by child class 157 | return 158 | 159 | def load_labels(self): 160 | # to be overwritten by child class 161 | return 162 | 163 | def load_encoded_labels(self): 164 | """Return encoded labels given task type and labels.""" 165 | 166 | # get lists of track_ids and labels (corresponding indices) 167 | track_ids = list(self.labels.keys()) 168 | labels_list = list(self.labels.values()) 169 | 170 | if self.task_type == "multiclass_classification": 171 | import keras 172 | from sklearn.preprocessing import LabelEncoder 173 | 174 | # fit label encoder on all tracks and labels 175 | self.label_encoder = LabelEncoder() 176 | self.label_encoder.fit(labels_list) 177 | # get encoded labels 178 | encoded_labels_list = self.label_encoder.transform(labels_list) 179 | encoded_categorical_labels_list = keras.utils.to_categorical( 180 | encoded_labels_list, num_classes=len(set(labels_list)) 181 | ) 182 | self.encoded_labels = { 183 | track_ids[i]: encoded_categorical_labels_list[i] 184 | for i in range(len(track_ids)) 185 | } 186 | 187 | elif self.task_type == "multilabel_classification": 188 | from sklearn.preprocessing import MultiLabelBinarizer 189 | 190 | # fit label encoder on all tracks and labels 191 | self.label_encoder = MultiLabelBinarizer() 192 | self.label_encoder.fit(labels_list) 193 | # get encoded labels 194 | encoded_labels_list = self.label_encoder.transform(labels_list) 195 | self.encoded_labels = { 196 | track_ids[i]: encoded_labels_list[i] for i in 
range(len(track_ids)) 197 | } 198 | 199 | else: 200 | raise NotImplementedError 201 | 202 | def load_common_audio_dir(self): 203 | """Get the deepest common directory between all audio paths. 204 | This will later be used as a reference for creating the dir 205 | structure for deformed audio and embeddings.""" 206 | 207 | self.common_audio_dir = str(os.path.commonpath(self.audio_paths.values())) 208 | 209 | def get_deformed_audio_path(self, track_id, deform_idx): 210 | """Get path of deformed audio based on audio path and deform_idx.""" 211 | 212 | audio_path = Path(self.audio_paths[track_id]) 213 | new_filestem = f"{audio_path.stem}_deform_{deform_idx}" 214 | return str( 215 | Path(self.data_dir) 216 | / "audio_deformed" 217 | / audio_path.relative_to(self.common_audio_dir).with_name( 218 | f"{new_filestem}{audio_path.suffix}" 219 | ) 220 | ) 221 | 222 | def get_embedding_path(self, track_id, feature): 223 | """Get path of embedding based on audio path and embedding model.""" 224 | 225 | return str( 226 | Path(self.data_dir) 227 | / "embeddings" 228 | / feature 229 | / Path(self.audio_paths[track_id]) 230 | .relative_to(self.common_audio_dir) 231 | .with_suffix(f".npy") 232 | ) 233 | 234 | def get_deformed_embedding_path(self, track_id, deform_idx, feature): 235 | """Get path of deformed embedding based on audio path, embedding model 236 | and deform_idx.""" 237 | 238 | return str( 239 | Path(self.data_dir) 240 | / "embeddings" 241 | / feature 242 | / Path(self.audio_paths[track_id]) 243 | .relative_to(self.common_audio_dir) 244 | .with_name( 245 | f"{Path(self.audio_paths[track_id]).stem}_deform_{deform_idx}.npy" 246 | ) 247 | ) 248 | 249 | def get_stratified_split(self, sizes=(0.8, 0.1, 0.1), seed=42): 250 | """Helper method to generate a stratified split of the dataset. 251 | 252 | Args: 253 | sizes (tuple, optional): Sizes of train, validation and test set. 254 | Defaults to (0.8, 0.1, 0.1), must add up to 1. 255 | seed (int, optional): Random seed. Defaults to 42. 
256 | """ 257 | from sklearn.model_selection import train_test_split 258 | 259 | if sum(sizes) != 1: 260 | raise ValueError("Sizes must add up to 1.") 261 | 262 | X = self.track_ids 263 | y = [self.labels[track_id] for track_id in self.track_ids] 264 | X_train, X_others, y_train, y_others = train_test_split( 265 | X, 266 | y, 267 | test_size=1 - sizes[0], 268 | random_state=seed, 269 | stratify=y, 270 | ) 271 | X_val, X_test, y_val, y_test = train_test_split( 272 | X_others, 273 | y_others, 274 | test_size=sizes[2] / (sizes[1] + sizes[2]), 275 | random_state=seed, 276 | stratify=y_others, 277 | ) 278 | return {"train": X_train, "validation": X_val, "test": X_test} 279 | 280 | def load_metadata(self): 281 | self.load_track_ids() 282 | self.load_labels() 283 | self.load_encoded_labels() 284 | self.load_audio_paths() 285 | self.load_common_audio_dir() 286 | 287 | def encode_label(self, label): 288 | return self.label_encoder.transform([label])[0] 289 | 290 | def decode_label(self, encoded_label): 291 | if self.task_type == "multiclass_classification": 292 | import numpy as np 293 | 294 | # get index maximum value 295 | encoded_label = np.argmax(encoded_label) 296 | return self.label_encoder.inverse_transform([encoded_label])[0] 297 | return self.label_encoder.inverse_transform([encoded_label])[0] 298 | -------------------------------------------------------------------------------- /mir_ref/datasets/datasets/magnatagatune.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os.path 3 | import zipfile 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import wget 8 | from tqdm import tqdm 9 | 10 | from mir_ref.datasets.dataset import Dataset 11 | 12 | 13 | class MagnaTagATune(Dataset): 14 | def __init__( 15 | self, 16 | name, 17 | dataset_type, 18 | data_dir, 19 | split_type, 20 | task_name, 21 | task_type, 22 | feature_aggregation, 23 | deformations_cfg, 24 | features_cfg, 25 | ): 26 | """Dataset wrapper for MagnaTagATune.""" 27 | super().__init__( 28 | name=name, 29 | dataset_type=dataset_type, 30 | data_dir=data_dir, 31 | split_type=split_type, 32 | task_name=task_name, 33 | task_type=task_type, 34 | feature_aggregation=feature_aggregation, 35 | deformations_cfg=deformations_cfg, 36 | features_cfg=features_cfg, 37 | ) 38 | 39 | @Dataset.try_to_load_metadata 40 | def download(self): 41 | # make data dir if it doesn't exist or if it exists but is empty 42 | if os.path.exists(os.path.join(self.data_dir, "audio")) and ( 43 | len(os.listdir(os.path.join(self.data_dir, "audio"))) != 0 44 | ): 45 | import warnings 46 | 47 | warnings.warn( 48 | f"Dataset '{self.name}' already exists in '{self.data_dir}'." 
49 | + "Skipping audio download.", 50 | stacklevel=2, 51 | ) 52 | self.download_metadata() 53 | return 54 | (Path(self.data_dir) / "audio").mkdir(parents=True, exist_ok=True) 55 | 56 | print(f"Downloading MagnaTagATune to {self.data_dir}...") 57 | for i in tqdm(["001", "002", "003"]): 58 | wget.download( 59 | url=f"https://mirg.city.ac.uk/datasets/magnatagatune/mp3.zip.{i}", 60 | out=os.path.join(self.data_dir, "audio/"), 61 | ) 62 | 63 | archive_dir = os.path.join(self.data_dir, "audio") 64 | 65 | # Combine the split archive files into a single file 66 | with open(os.path.join(archive_dir, "mp3.zip"), "wb") as f: 67 | for i in ["001", "002", "003"]: 68 | with open( 69 | os.path.join(archive_dir, f"mp3.zip.{i}"), 70 | "rb", 71 | ) as part: 72 | f.write(part.read()) 73 | 74 | # Extract the contents of the archive 75 | with zipfile.ZipFile(os.path.join(archive_dir, "mp3.zip"), "r") as zip_ref: 76 | zip_ref.extractall() 77 | 78 | # Remove zips 79 | for i in ["", ".001", ".002", ".003"]: 80 | os.remove(os.path.join(archive_dir, f"mp3.zip{i}")) 81 | 82 | self.download_metadata() 83 | 84 | @Dataset.try_to_load_metadata 85 | def download_metadata(self): 86 | if os.path.exists(os.path.join(self.data_dir, "metadata")) and ( 87 | len(os.listdir(os.path.join(self.data_dir, "metadata"))) != 0 88 | ): 89 | import warnings 90 | 91 | warnings.warn( 92 | f"Metadata for dataset '{self.name}' already exists in '{self.data_dir}'." 93 | + "Skipping metadata download.", 94 | stacklevel=2, 95 | ) 96 | return 97 | (Path(self.data_dir) / "metadata").mkdir(parents=True, exist_ok=True) 98 | 99 | urls = [ 100 | # annotations 101 | "https://mirg.city.ac.uk/datasets/magnatagatune/annotations_final.csv", 102 | # train, validation, and test splits from Won et al. 2020 103 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/train.npy", 104 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/valid.npy", 105 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/test.npy", 106 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/tags.npy", 107 | ] 108 | for url in urls: 109 | wget.download( 110 | url=url, 111 | out=os.path.join(self.data_dir, "metadata/"), 112 | ) 113 | 114 | def load_track_ids(self): 115 | with open( 116 | os.path.join(self.data_dir, "metadata", "annotations_final.csv"), "r" 117 | ) as f: 118 | annotations = csv.reader(f, delimiter="\t") 119 | next(annotations) # skip header 120 | self.track_ids = [line[0] for line in annotations] 121 | # manually remove some corrupt files 122 | self.track_ids.remove("35644") 123 | self.track_ids.remove("55753") 124 | self.track_ids.remove("57881") 125 | 126 | def load_labels(self): 127 | # get the list of top 50 tags used in Minz Won et al. 
2020 128 | tags = np.load(os.path.join(self.data_dir, "metadata", "tags.npy")) 129 | 130 | with open( 131 | os.path.join(self.data_dir, "metadata", "annotations_final.csv"), "r" 132 | ) as f: 133 | annotations = csv.reader(f, delimiter="\t") 134 | annotations_header = next(annotations) 135 | self.labels = { 136 | line[0]: [ 137 | annotations_header[j] 138 | for j in range(1, len(line) - 1) 139 | # only add the tag if it's in the tags list 140 | if line[j] == "1" and annotations_header[j] in tags 141 | ] 142 | for line in annotations 143 | # this is a slow way to do it, temporary fix for 144 | # some corrupt mp3s 145 | if line[0] in self.track_ids 146 | } 147 | 148 | def load_audio_paths(self): 149 | with open( 150 | os.path.join(self.data_dir, "metadata", "annotations_final.csv"), "r" 151 | ) as f: 152 | annotations = csv.reader(f, delimiter="\t") 153 | next(annotations) # skip header 154 | self.audio_paths = { 155 | line[0]: os.path.join(self.data_dir, "audio", line[-1]) 156 | for line in annotations 157 | # this is a slow way to do it, temporary fix for 158 | # some corrupt mp3s 159 | if line[0] in self.track_ids 160 | } 161 | 162 | @Dataset.check_metadata_is_loaded 163 | def get_splits(self): 164 | # get inverse dictionary to get track id from audio path 165 | rel_path_to_track_id = { 166 | (Path(v).parent.name + "/" + Path(v).name): k 167 | for k, v in self.audio_paths.items() 168 | } 169 | 170 | split = {} 171 | for set_filename, set_targetname in zip( 172 | ["train", "valid", "test"], ["train", "validation", "test"] 173 | ): 174 | relative_paths = np.load( 175 | os.path.join(self.data_dir, "metadata", f"{set_filename}.npy") 176 | ) 177 | # get track_ids by getting the full path and using the inv dict 178 | split[set_targetname] = [ 179 | rel_path_to_track_id[path.split("\t")[1]] for path in relative_paths 180 | ] 181 | 182 | if self.split_type not in ["all", "single"]: 183 | raise NotImplementedError(f"Split type '{self.split_type}' does not exist.") 184 | 185 | return [split] 186 | -------------------------------------------------------------------------------- /mir_ref/datasets/datasets/mtg_jamendo.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import hashlib 3 | import os.path 4 | import sys 5 | import tarfile 6 | from pathlib import Path 7 | 8 | import tqdm 9 | import wget 10 | 11 | from mir_ref.datasets.dataset import Dataset 12 | 13 | 14 | class MTG_Jamendo(Dataset): 15 | def __init__( 16 | self, 17 | name, 18 | dataset_type, 19 | data_dir, 20 | split_type, 21 | task_name, 22 | task_type, 23 | feature_aggregation, 24 | deformations_cfg, 25 | features_cfg, 26 | ): 27 | """Dataset wrapper for MTG Jamendo (sub)dataset(s). Each subset 28 | (moodtheme, instrument, top50tags, genre) is going to be treated 29 | as a separate dataset, meaning a separate data_dir needs to be 30 | specified for each one. This helps disambiguate versioning of 31 | the deformations and experiments. Since the methods are shared, 32 | however, a single class is used for all of them. 
33 | """ 34 | super().__init__( 35 | name=name, 36 | dataset_type=dataset_type, 37 | data_dir=data_dir, 38 | split_type=split_type, 39 | task_name=task_name, 40 | task_type=task_type, 41 | feature_aggregation=feature_aggregation, 42 | deformations_cfg=deformations_cfg, 43 | features_cfg=features_cfg, 44 | ) 45 | 46 | @Dataset.try_to_load_metadata 47 | def download(self): 48 | # make data dir if it doesn't exist or if it exists but is empty 49 | if os.path.exists(os.path.join(self.data_dir, "audio")) and ( 50 | len(os.listdir(os.path.join(self.data_dir, "audio"))) != 0 51 | ): 52 | import warnings 53 | 54 | warnings.warn( 55 | f"Dataset '{self.name}' already exists in '{self.data_dir}'." 56 | + "Skipping audio download.", 57 | stacklevel=2, 58 | ) 59 | self.download_metadata() 60 | return 61 | (Path(self.data_dir) / "audio").mkdir(parents=True, exist_ok=True) 62 | # download with mtg_jamendo download helper 63 | if self.name == "mtg-jamendo-moodtheme": 64 | # only moodtheme has a separate download target 65 | download_jamendo( 66 | dataset="autotagging_moodtheme", 67 | data_type="audio-low", 68 | download_from="mtg-fast", 69 | output_dir=os.path.join(self.data_dir, "audio/"), 70 | unpack_tars=True, 71 | remove_tars=True, 72 | ) 73 | elif ( 74 | self.name == "mtg-jamendo-instrument" 75 | or self.name == "mtg-jamendo-genre" 76 | or self.name == "mtg-jamendo-top50tags" 77 | ): 78 | # the whole dataset needs to be downloaded as no subset-specific 79 | # download targets are available 80 | download_jamendo( 81 | dataset="raw_30s", 82 | data_type="audio-low", 83 | download_from="mtg-fast", 84 | output_dir=os.path.join(self.data_dir, "audio/"), 85 | unpack_tars=True, 86 | remove_tars=True, 87 | ) 88 | # optionally I'd delete unneeded tracks here... 89 | else: 90 | raise NotImplementedError(f"Dataset '{self.name}' does not exist.") 91 | self.download_metadata() 92 | 93 | @Dataset.try_to_load_metadata 94 | def download_metadata(self): 95 | # download from github link 96 | url = ( 97 | "https://raw.githubusercontent.com/MTG/mtg-jamendo-dataset/" 98 | + "master/data/autotagging_" 99 | + (self.name).split("-")[2] 100 | + ".tsv" 101 | ) 102 | (Path(self.data_dir) / "metadata").mkdir(parents=True, exist_ok=True) 103 | wget.download(url, out=os.path.join(self.data_dir, "metadata/")) 104 | 105 | def load_track_ids(self): 106 | # open the corresponding file using the subset name in self.name 107 | with open( 108 | os.path.join( 109 | self.data_dir, 110 | "metadata/", 111 | f"autotagging_{(self.name).split('-')[2]}.tsv", 112 | ), 113 | "r", 114 | ) as f: 115 | metadata = f.readlines() 116 | self.track_ids = [ 117 | line.split("\t")[0].strip() for line in metadata[1:] if line 118 | ] 119 | 120 | def load_labels(self): 121 | with open( 122 | os.path.join( 123 | self.data_dir, 124 | "metadata/", 125 | f"autotagging_{(self.name).split('-')[2]}.tsv", 126 | ), 127 | "r", 128 | ) as f: 129 | metadata = f.readlines() 130 | # for each line, get key (track_id) and all stripped entries after 131 | # column 5 (list of labels) 132 | self.labels = { 133 | line.split("\t")[0].strip(): [ 134 | tag.strip() for tag in line.split("\t")[5:] 135 | ] 136 | for line in metadata[1:] 137 | } 138 | 139 | def load_audio_paths(self): 140 | with open( 141 | os.path.join( 142 | self.data_dir, 143 | "metadata/", 144 | f"autotagging_{(self.name).split('-')[2]}.tsv", 145 | ), 146 | "r", 147 | ) as f: 148 | metadata = f.readlines() 149 | # the lowres version we downloaded contains "low" before the extension 150 | self.audio_paths = { 151 | 
line.split("\t")[0].strip(): os.path.join( 152 | self.data_dir, 153 | "audio", 154 | (line.split("\t")[3].strip())[:-4] + ".low.mp3", 155 | ) 156 | for line in metadata[1:] 157 | if line 158 | } 159 | 160 | @Dataset.check_metadata_is_loaded 161 | def get_splits(self): 162 | subset = (self.name).split("-")[2] 163 | splits_url_dir = ( 164 | "https://github.com/MTG/mtg-jamendo-dataset/blob/master/data/splits/" 165 | ) 166 | 167 | if not os.path.exists(os.path.join(self.data_dir, "splits/")): 168 | # there are 5 splits for each subset 169 | for i in range(5): 170 | (Path(self.data_dir) / "metadata" / "splits").mkdir( 171 | parents=True, exist_ok=True 172 | ) 173 | for split in ["train", "validation", "test"]: 174 | url = f"{splits_url_dir}split-{i}/autotagging_{subset}-{split}.tsv" 175 | wget.download( 176 | url, 177 | out=os.path.join( 178 | self.data_dir, "metadata", "splits", f"split-{i}" 179 | ), 180 | ) 181 | splits = [] 182 | for i in range(5): 183 | splits.append({}) 184 | for split in ["train", "validation", "test"]: 185 | with open( 186 | os.path.join( 187 | self.data_dir, 188 | "metadata", 189 | "splits", 190 | f"split-{i}", 191 | f"autotagging_{subset}-{split}.tsv", 192 | ), 193 | "r", 194 | ) as f: 195 | split_metadata = f.readlines() 196 | self.splits[i][split] = [ 197 | line.split("\t")[0].strip() 198 | for line in split_metadata[1:] 199 | if line 200 | ] 201 | if self.split_type == "single": 202 | return [splits[0]] 203 | elif self.split_type == "all": 204 | return splits 205 | else: 206 | raise NotImplementedError(f"Split type '{self.split_type}' does not exist.") 207 | 208 | 209 | """Code to download MTG Jamendo. 210 | Source: https://github.com/MTG/mtg-jamendo-dataset 211 | """ 212 | 213 | download_from_names = {"gdrive": "GDrive", "mtg": "MTG", "mtg-fast": "MTG Fast mirror"} 214 | 215 | 216 | def compute_sha256(filename): 217 | with open(filename, "rb") as f: 218 | contents = f.read() 219 | checksum = hashlib.sha256(contents).hexdigest() 220 | return checksum 221 | 222 | 223 | def download_jamendo( 224 | dataset, data_type, download_from, output_dir, unpack_tars, remove_tars 225 | ): 226 | if not os.path.exists(output_dir): 227 | print("Output directory {} does not exist".format(output_dir), file=sys.stderr) 228 | return 229 | 230 | if download_from not in download_from_names: 231 | print( 232 | "Unknown --from argument, choices are {}".format( 233 | list(download_from_names.keys()) 234 | ), 235 | file=sys.stderr, 236 | ) 237 | return 238 | 239 | print("Downloading %s from %s" % (dataset, download_from_names[download_from])) 240 | # download checksums 241 | file_sha256_tars_url = ( 242 | "https://raw.githubusercontent.com/MTG/mtg-jamendo-dataset/master/data/download/" 243 | + f"{dataset}_{data_type}_sha256_tars.txt" 244 | ) 245 | wget.download(file_sha256_tars_url, out=output_dir) 246 | file_sha256_tracks_url = ( 247 | "https://raw.githubusercontent.com/MTG/mtg-jamendo-dataset/master/data/download/" 248 | + f"{dataset}_{data_type}_sha256_tracks.txt" 249 | ) 250 | wget.download(file_sha256_tracks_url, out=output_dir) 251 | 252 | file_sha256_tars = os.path.join( 253 | output_dir, dataset + "_" + data_type + "_sha256_tars.txt" 254 | ) 255 | file_sha256_tracks = os.path.join( 256 | output_dir, dataset + "_" + data_type + "_sha256_tracks.txt" 257 | ) 258 | 259 | # Read checksum values for tars and files inside. 
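# Each line in these checksum files is expected to be "<sha256> <filename>",
# so the dicts below map filename -> expected checksum: one dict for the tar
# archives themselves, one for the individual tracks inside them.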
260 | with open(file_sha256_tars) as f: 261 | sha256_tars = dict([(row[1], row[0]) for row in csv.reader(f, delimiter=" ")]) 262 | 263 | with open(file_sha256_tracks) as f: 264 | sha256_tracks = dict([(row[1], row[0]) for row in csv.reader(f, delimiter=" ")]) 265 | 266 | # Filenames to download. 267 | ids = sha256_tars.keys() 268 | 269 | removed = [] 270 | for filename in tqdm(ids, total=len(ids)): 271 | output = os.path.join(output_dir, filename) 272 | # print(output) 273 | # print(filename) 274 | 275 | # Download from Google Drive. 276 | if os.path.exists(output): 277 | print("Skipping %s (file already exists)" % output) 278 | continue 279 | 280 | elif download_from == "mtg": 281 | url = ( 282 | "https://essentia.upf.edu/documentation/datasets/mtg-jamendo/" 283 | "%s/%s/%s" % (dataset, data_type, filename) 284 | ) 285 | # print("From:", url) 286 | # print("To:", output) 287 | wget.download(url, out=output) 288 | 289 | elif download_from == "mtg-fast": 290 | url = "https://cdn.freesound.org/mtg-jamendo/" "%s/%s/%s" % ( 291 | dataset, 292 | data_type, 293 | filename, 294 | ) 295 | # print("From:", url) 296 | # print("To:", output) 297 | wget.download(url, out=output) 298 | 299 | # Validate the checksum. 300 | if compute_sha256(output) != sha256_tars[filename]: 301 | print( 302 | "%s does not match the checksum, removing the file" % output, 303 | file=sys.stderr, 304 | ) 305 | removed.append(filename) 306 | os.remove(output) 307 | # else: 308 | # print("%s checksum OK" % filename) 309 | 310 | if removed: 311 | print("Missing files:", " ".join(removed)) 312 | print("Re-run the script again") 313 | return 314 | 315 | print("Download complete") 316 | 317 | if unpack_tars: 318 | print("Unpacking tar archives") 319 | 320 | tracks_checked = [] 321 | for filename in tqdm(ids, total=len(ids)): 322 | output = os.path.join(output_dir, filename) 323 | print("Unpacking", output) 324 | tar = tarfile.open(output) 325 | tracks = tar.getnames()[1:] # The first element is folder name. 326 | tar.extractall(path=output_dir) 327 | tar.close() 328 | 329 | # Validate checksums for all unpacked files 330 | for track in tracks: 331 | trackname = os.path.join(output_dir, track) 332 | if compute_sha256(trackname) != sha256_tracks[track]: 333 | print("%s does not match the checksum" % trackname, file=sys.stderr) 334 | raise Exception("Corrupt file in the dataset: %s" % trackname) 335 | 336 | # print("%s track checksums OK" % filename) 337 | tracks_checked += tracks 338 | 339 | if remove_tars: 340 | os.remove(output) 341 | 342 | # Check if any tracks are missing in the unpacked archives. 
343 | if set(tracks_checked) != set(sha256_tracks.keys()): 344 | raise Exception( 345 | "Unpacked data contains tracks not present in the checksum files" 346 | ) 347 | 348 | print("Unpacking complete") 349 | 350 | # delete checksum files 351 | os.remove(file_sha256_tars) 352 | os.remove(file_sha256_tracks) 353 | -------------------------------------------------------------------------------- /mir_ref/datasets/datasets/vocalset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import subprocess 4 | from pathlib import Path 5 | 6 | import wget 7 | 8 | from mir_ref.datasets.dataset import Dataset 9 | 10 | 11 | class VocalSet(Dataset): 12 | def __init__( 13 | self, 14 | name, 15 | dataset_type, 16 | data_dir, 17 | split_type, 18 | task_name, 19 | task_type, 20 | feature_aggregation, 21 | deformations_cfg, 22 | features_cfg, 23 | ): 24 | """Dataset wrapper for VocalSet dataset.""" 25 | super().__init__( 26 | name=name, 27 | dataset_type=dataset_type, 28 | data_dir=data_dir, 29 | split_type=split_type, 30 | task_name=task_name, 31 | task_type=task_type, 32 | feature_aggregation=feature_aggregation, 33 | deformations_cfg=deformations_cfg, 34 | features_cfg=features_cfg, 35 | ) 36 | 37 | @Dataset.try_to_load_metadata 38 | def download(self): 39 | if Path(self.data_dir).exists(): 40 | import warnings 41 | 42 | warnings.warn( 43 | ( 44 | f"Dataset {self.name} already exists at {self.data_dir}." 45 | + " Skipping download." 46 | ) 47 | ) 48 | self.preprocess() 49 | return 50 | # make data dir 51 | Path(self.data_dir).mkdir(parents=True, exist_ok=False) 52 | zenodo_url = "https://zenodo.org/record/1203819/files/VocalSet11.zip" 53 | print(f"Downloading VocalSet to {self.data_dir}...") 54 | wget.download(zenodo_url, self.data_dir) 55 | # extract zip 56 | subprocess.run( 57 | [ 58 | "unzip", 59 | "-q", 60 | "-d", 61 | str(Path(self.data_dir)), 62 | str(Path(self.data_dir) / "VocalSet11.zip"), 63 | ] 64 | ) 65 | # remove zip 66 | subprocess.run(["rm", str(Path(self.data_dir) / "VocalSet11.zip")]) 67 | 68 | # preprocess 69 | self.preprocess() 70 | 71 | def preprocess(self): 72 | # need to make some corrections to filenames and delete some duplicates 73 | # Thanks to the MARBLE authors for the list of corrections and dups: 74 | # https://github.com/a43992899/MARBLE-Benchmark/blob/main/benchmark/tasks/VocalSet/preprocess.py 75 | 76 | if not ( 77 | Path(self.data_dir) 78 | / "FULL" 79 | / "**" 80 | / "*" 81 | / "vibrato/f2_scales_vibrato_a(1).wav" 82 | ).exists(): 83 | # dataset is probably already preprocessed 84 | return 85 | 86 | file_delete = [ 87 | "vibrato/f2_scales_vibrato_a(1).wav", 88 | "vibrato/caro_vibrato.wav", 89 | "vibrato/dona_vibrato.wav", 90 | "vibrato/row_vibrato.wav", 91 | "vibrado/slow_vibrato_arps.wav", 92 | ] 93 | 94 | filepaths_to_delete = [ 95 | glob.glob(str(Path(self.data_dir) / "FULL" / "**" / "*" / f)) 96 | for f in file_delete 97 | ] 98 | # flatten list of lists 99 | filepaths_to_delete = [ 100 | item for sublist in filepaths_to_delete for item in sublist 101 | ] 102 | # delete files 103 | for f in filepaths_to_delete: 104 | os.remove(f) 105 | 106 | # thanks black formatter for this beauty 107 | name_correction = [ 108 | ("/lip_trill/lip_trill_arps.wav", "/lip_trill/f2_lip_trill_arps.wav"), 109 | ("/lip_trill/scales_lip_trill.wav", "/lip_trill/m3_scales_lip_trill.wav"), 110 | ( 111 | "/straight/arpeggios_straight_a.wav", 112 | "/straight/f4_arpeggios_straight_a.wav", 113 | ), 114 | ( 115 | 
"/straight/arpeggios_straight_e.wav", 116 | "/straight/f4_arpeggios_straight_e.wav", 117 | ), 118 | ( 119 | "/straight/arpeggios_straight_i.wav", 120 | "/straight/f4_arpeggios_straight_i.wav", 121 | ), 122 | ( 123 | "/straight/arpeggios_straight_o.wav", 124 | "/straight/f4_arpeggios_straight_o.wav", 125 | ), 126 | ( 127 | "/straight/arpeggios_straight_u.wav", 128 | "/straight/f4_arpeggios_straight_u.wav", 129 | ), 130 | ("/straight/row_straight.wav", "/straight/m8_row_straight.wav"), 131 | ("/straight/scales_straight_a.wav", "/straight/f4_scales_straight_a.wav"), 132 | ("/straight/scales_straight_e.wav", "/straight/f4_scales_straight_e.wav"), 133 | ("/straight/scales_straight_i.wav", "/straight/f4_scales_straight_i.wav"), 134 | ("/straight/scales_straight_o.wav", "/straight/f4_scales_straight_o.wav"), 135 | ("/straight/scales_straight_u.wav", "/straight/f4_scales_straight_u.wav"), 136 | ("/vocal_fry/scales_vocal_fry.wav", "/vocal_fry/f2_scales_vocal_fry.wav"), 137 | ( 138 | "/fast_forte/arps_fast_piano_c.wav", 139 | "/fast_forte/f9_arps_fast_piano_c.wav", 140 | ), 141 | ( 142 | "/fast_piano/fast_piano_arps_f.wav", 143 | "/fast_piano/f2_fast_piano_arps_f.wav", 144 | ), 145 | ( 146 | "/fast_piano/arps_c_fast_piano.wav", 147 | "/fast_piano/m3_arps_c_fast_piano.wav", 148 | ), 149 | ( 150 | "/fast_piano/scales_fast_piano_f.wav", 151 | "/fast_piano/f3_scales_fast_piano_f.wav", 152 | ), 153 | ( 154 | "/fast_piano/scales_c_fast_piano_a.wav", 155 | "/fast_piano/m10_scales_c_fast_piano_a.wav", 156 | ), 157 | ( 158 | "/fast_piano/scales_c_fast_piano_e.wav", 159 | "/fast_piano/m10_scales_c_fast_piano_e.wav", 160 | ), 161 | ( 162 | "/fast_piano/scales_c_fast_piano_i.wav", 163 | "/fast_piano/m10_scales_c_fast_piano_i.wav", 164 | ), 165 | ( 166 | "/fast_piano/scales_c_fast_piano_o.wav", 167 | "/fast_piano/m10_scales_c_fast_piano_o.wav", 168 | ), 169 | ( 170 | "/fast_piano/scales_c_fast_piano_u.wav", 171 | "/fast_piano/m10_scales_c_fast_piano_u.wav", 172 | ), 173 | ( 174 | "/fast_piano/scales_f_fast_piano_a.wav", 175 | "/fast_piano/m10_scales_f_fast_piano_a.wav", 176 | ), 177 | ( 178 | "/fast_piano/scales_f_fast_piano_e.wav", 179 | "/fast_piano/m10_scales_f_fast_piano_e.wav", 180 | ), 181 | ( 182 | "/fast_piano/scales_f_fast_piano_i.wav", 183 | "/fast_piano/m10_scales_f_fast_piano_i.wav", 184 | ), 185 | ( 186 | "/fast_piano/scales_f_fast_piano_o.wav", 187 | "/fast_piano/m10_scales_f_fast_piano_o.wav", 188 | ), 189 | ( 190 | "/fast_piano/scales_f_fast_piano_u.wav", 191 | "/fast_piano/m10_scales_f_fast_piano_u.wav", 192 | ), 193 | ] 194 | 195 | for old, new in name_correction: 196 | old_matches = glob.glob( 197 | str(Path(self.data_dir) / "FULL" / "**" / "*" / old[1:]) 198 | ) 199 | target_filepaths = [f.replace(old, new) for f in old_matches] 200 | # rename 201 | for old, new in zip(old_matches, target_filepaths): 202 | os.rename(old, new) 203 | 204 | @Dataset.try_to_load_metadata 205 | def download_metadata(self): 206 | # not possible to download only metadata for vocalset 207 | self.download() 208 | 209 | def load_track_ids(self): 210 | # the names of the tracks can be used as a unique identifier, as they contain 211 | # singer, technique, and take index information. 
212 | 213 | if self.task_name == "singer_identification": 214 | self.track_ids = [ 215 | Path(path).stem 216 | for path in glob.glob( 217 | str(Path(self.data_dir) / "FULL/**/*.wav"), recursive=True 218 | ) 219 | ] 220 | elif self.task_name == "technique_identification": 221 | # use the 10 techniques used in the original paper 222 | techniques = [ 223 | "vibrato", 224 | "straight", 225 | "belt", 226 | "breathy", 227 | "lip_trill", 228 | "spoken", 229 | "inhaled", 230 | "trill", 231 | "trillo", 232 | "vocal_fry", 233 | ] 234 | # only add track_ids that contain one of the techniques 235 | self.track_ids = [ 236 | Path(path).stem 237 | for path in glob.glob( 238 | str(Path(self.data_dir) / "FULL/**/*.wav"), recursive=True 239 | ) 240 | if any(technique in Path(path).stem for technique in techniques) 241 | ] 242 | 243 | def load_labels(self): 244 | if self.task_name == "singer_identification": 245 | self.labels = { 246 | Path(path).stem: (Path(path).stem)[:2] 247 | if ((Path(path).stem)[:3] != "m10" and (Path(path).stem)[:3] != "m11") 248 | else (Path(path).stem)[:3] 249 | for path in glob.glob( 250 | str(Path(self.data_dir) / "FULL/**/*.wav"), recursive=True 251 | ) 252 | } 253 | elif self.task_name == "technique_identification": 254 | # use the 10 techniques used in the original paper 255 | techniques = [ 256 | "vibrato", 257 | "straight", 258 | "belt", 259 | "breathy", 260 | "lip_trill", 261 | "spoken", 262 | "inhaled", 263 | "trill", 264 | "trillo", 265 | "vocal_fry", 266 | ] 267 | labels = {} 268 | for track_id in self.track_ids: 269 | for technique in techniques: 270 | if technique in track_id: 271 | labels[track_id] = technique 272 | self.labels = labels 273 | else: 274 | raise NotImplementedError( 275 | f"Task '{self.task_name}' not available for this dataset." 276 | ) 277 | 278 | def load_audio_paths(self): 279 | audio_paths_list = [ 280 | path 281 | for path in glob.glob( 282 | str(Path(self.data_dir) / "FULL/**/*.wav"), recursive=True 283 | ) 284 | ] 285 | self.audio_paths = {Path(path).stem: path for path in audio_paths_list} 286 | 287 | @Dataset.check_metadata_is_loaded 288 | def get_splits(self): 289 | if self.task_name == "singer_identification": 290 | # no official splits are available, get stratified one 291 | return [super().get_stratified_split()] 292 | elif self.task_name == "technique_identification": 293 | train_singers = [ 294 | "f1", 295 | "f3", 296 | "f4", 297 | "f5", 298 | "f6", 299 | "f7", 300 | "f9", 301 | "m1", 302 | "m2", 303 | "m4", 304 | "m6", 305 | "m7", 306 | "m8", 307 | "m9", 308 | ] 309 | # the original paper does not have a validation set, so we steal f9, 310 | # m9, and m11 from train 311 | val_singers = ["f9", "m9", "m11"] 312 | test_singer = ["f2", "f8", "m3", "m5", "m10"] 313 | 314 | split = {} 315 | split["train"] = [ 316 | track_id for track_id in self.track_ids if track_id[:2] in train_singers 317 | ] 318 | split["validation"] = [ 319 | track_id for track_id in self.track_ids if track_id[:2] in val_singers 320 | ] 321 | split["test"] = [ 322 | track_id for track_id in self.track_ids if track_id[:2] in test_singer 323 | ] 324 | return [split] 325 | else: 326 | raise NotImplementedError( 327 | f"Task '{self.task_name}' not available for this dataset." 
328 | ) 329 | -------------------------------------------------------------------------------- /mir_ref/datasets/mirdata_datasets.py: -------------------------------------------------------------------------------- 1 | """Wrapper for mirdata datasets, using relevant functions 2 | and adjusting them based on task requirements.""" 3 | 4 | import mirdata 5 | 6 | from mir_ref.datasets.dataset import Dataset 7 | 8 | 9 | class MirdataDataset(Dataset): 10 | def __init__( 11 | self, 12 | name, 13 | dataset_type, 14 | data_dir, 15 | split_type, 16 | task_name, 17 | task_type, 18 | feature_aggregation, 19 | deformations_cfg, 20 | features_cfg, 21 | ): 22 | """Dataset wrapper for mirdata dataset. 23 | 24 | Args: 25 | name (str): Name of the dataset. 26 | dataset_type (str): Type of the dataset ("mirdata", "custom") 27 | task_name (str): Name of the task. 28 | task_type (str): Type of the task. 29 | data_dir (str, optional): Path to the dataset directory. 30 | Defaults to ./data/{name}/. 31 | split_type (str, optional): Whether to use "all" or "single" split. 32 | Defaults to "single". 33 | deformations_cfg (list, optional): List of deformation scenarios. 34 | features_cfg (list, optional): List of embedding models. 35 | """ 36 | super().__init__( 37 | name=name, 38 | dataset_type=dataset_type, 39 | data_dir=data_dir, 40 | split_type=split_type, 41 | task_name=task_name, 42 | task_type=task_type, 43 | feature_aggregation=feature_aggregation, 44 | deformations_cfg=deformations_cfg, 45 | features_cfg=features_cfg, 46 | ) 47 | # initialize mirdata dataset 48 | self.dataset = mirdata.initialize( 49 | dataset_name=self.name, data_home=self.data_dir 50 | ) 51 | 52 | def download(self): 53 | self.dataset.download() 54 | self.dataset.validate(verbose=False) 55 | 56 | # try to load metadata again 57 | self.load_metadata() 58 | 59 | def download_metadata(self): 60 | try: 61 | self.dataset.download(partial_download=["metadata"]) 62 | except ValueError: 63 | self.dataset.download(partial_download=["annotations"]) 64 | self.dataset.validate(verbose=False) 65 | 66 | # try to load metadata again 67 | self.load_metadata() 68 | 69 | def preprocess(self): 70 | """Modifications to the downloaded content of the dataset.""" 71 | return 72 | 73 | def load_track_ids(self): 74 | if self.task_name == "pitch_class_estimation": 75 | if self.name == "tinysol": 76 | track_ids = self.dataset.track_ids 77 | # only keep track_ids with single pitch annotations 78 | for track_id in track_ids: 79 | if len(self.dataset.track(track_id).pitch) != 1: 80 | track_ids.remove(track_id) 81 | elif self.task_name == "pitch_register_estimation": 82 | if self.name == "tinysol": 83 | track_ids = self.dataset.track_ids 84 | # only keep track_ids with single pitch annotations 85 | for track_id in track_ids: 86 | if len(self.dataset.track(track_id).pitch) != 1: 87 | track_ids.remove(track_id) 88 | elif self.task_name == "key_estimation": 89 | if self.name == "beatport_key": 90 | # only keep track_ids with single key annotations 91 | track_ids = [ 92 | track_id 93 | for track_id in self.dataset.track_ids 94 | if len(self.dataset.track(track_id).key) == 1 95 | and len(self.dataset.track(track_id).key[0].split(" ")) == 2 96 | and "other" not in self.dataset.track(track_id).key[0] 97 | ] 98 | else: 99 | track_ids = self.dataset.track_ids 100 | 101 | self.track_ids = track_ids 102 | 103 | def load_labels(self): 104 | labels = {} 105 | for track_id in self.track_ids: 106 | if ( 107 | self.task_name == "instrument_recognition" 108 | or self.task_name == 
"instrument_classification" 109 | ): 110 | if self.name == "tinysol": 111 | labels[track_id] = self.dataset.track(track_id).instrument_full 112 | elif self.task_name == "tagging": 113 | labels[track_id] = self.dataset.track(track_id).tags 114 | elif self.task_name == "pitch_class_classification": 115 | pitch = self.dataset.track(track_id).pitch 116 | labels[track_id] = "".join([c for c in pitch if not c.isdigit()]) 117 | elif self.task_name == "pitch_register_classification": 118 | pitch = self.dataset.track(track_id).pitch 119 | labels[track_id] = "".join([c for c in pitch if c.isdigit()]) 120 | elif self.task_name == "key_estimation": 121 | # map enharmonic keys, always use sharps 122 | enharm_map = { 123 | "Db": "C#", 124 | "Eb": "D#", 125 | "Gb": "F#", 126 | "Ab": "G#", 127 | "Bb": "A#", 128 | "F#_": "F#", # ok yes, that's an annotation fix 129 | } 130 | key = self.dataset.track(track_id).key[0].strip() 131 | for pitch_class in enharm_map.keys(): 132 | if pitch_class in key: 133 | key = key.split(" ") 134 | key[0] = enharm_map[key[0]] 135 | key = " ".join(key) 136 | break 137 | labels[track_id] = key 138 | 139 | self.labels = labels 140 | 141 | def load_audio_paths(self): 142 | self.audio_paths = { 143 | t_id: self.dataset.track(t_id).audio_path for t_id in self.track_ids 144 | } 145 | 146 | # @Dataset.check_metadata_is_loaded 147 | def get_splits(self, seed=42): 148 | if self.split_type not in ["all", "single"] or isinstance( 149 | self.split_type, list 150 | ): 151 | raise ValueError( 152 | "Split type must be 'all', 'single', or " 153 | + "list of 3 floats adding up to 1." 154 | ) 155 | # !!!validate metadata exists, and download if not 156 | # tags can be in metadata or annotations. Try metadata first. 157 | # try: 158 | # self.dataset.download(partial_download=["metadata"]) 159 | # except ValueError: 160 | # self.dataset.download(partial_download=["annotations"]) 161 | # self.dataset.validate(verbose=False) 162 | 163 | splits = [] 164 | if self.split_type in ["all", "single"]: 165 | # check for up to 50 splits 166 | for i in range(50): 167 | try: 168 | split = self.dataset.get_track_splits(i) 169 | except TypeError: 170 | # if it fails at i=0, there are either no splits or only one split 171 | if i == 0: 172 | try: 173 | split = self.dataset.get_track_splits() 174 | except NotImplementedError: 175 | # no splits are available, so we need to generate them 176 | print( 177 | "No official splits found, generating random, stratified " 178 | + f"ones with seed {seed}." 179 | ) 180 | splits.append(self.get_stratified_split(seed=seed)) 181 | break 182 | # if it fails at i>0, there are no more splits 183 | else: 184 | break 185 | # we need to determine if each split returned is in the format 186 | # ["train", "validation", "test"], or whether they are actually 187 | # just folds keyed by integer index. 
188 | try: 189 | _, _, _ = split["train"], split["validation"], split["test"] 190 | splits.append(split) 191 | except KeyError: 192 | # assume they're folds 193 | n_folds = len(split.keys()) 194 | if n_folds >= 3: 195 | # get splits for cross validation, assigning one for 196 | # validation and test, and the rest for training 197 | for fold_idx in range(n_folds): 198 | fold = {} 199 | available_folds = list(split.keys()) 200 | 201 | fold["validation"] = split[fold_idx] 202 | available_folds.remove(fold_idx) 203 | 204 | fold["test"] = split[(fold_idx + 1) % n_folds] 205 | available_folds.remove((fold_idx + 1) % n_folds) 206 | 207 | # the rest is train, get lists and flatten 208 | fold["train"] = sum( 209 | [split[af] for af in available_folds], [] 210 | ) 211 | 212 | splits.append(fold) 213 | else: 214 | print( 215 | "No official splits found, generating random, stratified " 216 | + f"ones with seed {seed}." 217 | ) 218 | splits.append(self.get_stratified_split(seed=seed)) 219 | if i == 0: 220 | break 221 | 222 | if self.split_type == "single": 223 | splits = [splits[0]] 224 | 225 | else: 226 | # else split_type is a list of sizes, meaning get stratified splits 227 | splits.append(self.get_stratified_split(sizes=self.split_type, seed=seed)) 228 | 229 | return splits 230 | -------------------------------------------------------------------------------- /mir_ref/deform.py: -------------------------------------------------------------------------------- 1 | """Apply deformations to audio files. 2 | """ 3 | 4 | from colorama import Fore, Style 5 | 6 | from mir_ref.datasets.dataset import get_dataset 7 | from mir_ref.deformations import generate_deformations 8 | from mir_ref.utils import load_config 9 | 10 | 11 | def deform(cfg_path, n_jobs): 12 | cfg = load_config(cfg_path) 13 | 14 | # iterate through every dataset of every experiment to generate deformations 15 | print(Fore.GREEN + "# Generating deformations...", Style.RESET_ALL) 16 | for exp_cfg in cfg["experiments"]: 17 | for dataset_cfg in exp_cfg["datasets"]: 18 | # generate deformations 19 | print(Fore.GREEN + f"## Dataset: {dataset_cfg['name']}", Style.RESET_ALL) 20 | 21 | dataset = get_dataset( 22 | dataset_cfg=dataset_cfg, 23 | task_cfg=exp_cfg["task"], 24 | features_cfg=exp_cfg["features"], 25 | ) 26 | dataset.download() 27 | dataset.preprocess() 28 | 29 | generate_deformations( 30 | dataset, 31 | n_jobs=n_jobs, 32 | ) 33 | -------------------------------------------------------------------------------- /mir_ref/deformations.py: -------------------------------------------------------------------------------- 1 | """Implementations of various audio deformations used for 2 | robustness evaluation. 
3 | """ 4 | 5 | import os 6 | 7 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" 8 | 9 | from pathlib import Path 10 | 11 | import librosa 12 | import soundfile as sf 13 | from joblib import Parallel, delayed 14 | from tqdm import tqdm 15 | 16 | 17 | def deform_audio( 18 | track_id: str, 19 | dataset, 20 | ): 21 | """For each deformation scenario, deform and save single audio file.""" 22 | 23 | # load audio 24 | y, sr = librosa.load(dataset.audio_paths[track_id], sr=None) 25 | 26 | for scenario_idx, scenario in enumerate(dataset.deformations_cfg): 27 | y_d = y.copy() 28 | # we're looping like this so that we retain the order of 29 | # deformations provided by the user 30 | # debatable whether this should be the default behavior 31 | for deformation in scenario: 32 | if deformation["type"] == "AddGaussianSNR": 33 | from audiomentations import AddGaussianSNR 34 | 35 | transform = AddGaussianSNR(**deformation["params"]) 36 | elif deformation["type"] == "ApplyImpulseResponse": 37 | from audiomentations import ApplyImpulseResponse 38 | 39 | transform = ApplyImpulseResponse(**deformation["params"]) 40 | elif deformation["type"] == "ClippingDistortion": 41 | from audiomentations import ClippingDistortion 42 | 43 | transform = ClippingDistortion(**deformation["params"]) 44 | elif deformation["type"] == "Gain": 45 | from audiomentations import Gain 46 | 47 | transform = Gain(**deformation["params"]) 48 | elif deformation["type"] == "Mp3Compression": 49 | from audiomentations import Mp3Compression 50 | 51 | transform = Mp3Compression(**deformation["params"]) 52 | elif deformation["type"] == "PitchShift": 53 | from audiomentations import PitchShift 54 | 55 | transform = PitchShift(**deformation["params"]) 56 | else: 57 | raise ValueError(f"Deformation {deformation['type']} not implemented.") 58 | 59 | # apply deformation 60 | y_d = transform(y_d, sr) 61 | del transform 62 | 63 | # save deformed audio 64 | output_filepath = dataset.get_deformed_audio_path( 65 | track_id=track_id, deform_idx=scenario_idx 66 | ) 67 | 68 | # create parent dirs if they don't exist, and write audio 69 | Path(output_filepath).parent.mkdir(parents=True, exist_ok=True) 70 | sf.write(output_filepath, y_d, sr) 71 | 72 | 73 | def deform_audio_essentia( 74 | track_id: str, 75 | dataset, 76 | ): 77 | """Some of the essentia ffmpeg calls seem to have deprecated parameters, 78 | something related to the time_base used. I'm temporarily moving the 79 | essentia loading and writing here. 
80 | """ 81 | # move to imports if used 82 | from essentia.standard import AudioLoader, AudioWriter 83 | 84 | input_filepath = dataset.audio_paths[track_id] 85 | 86 | # load audio 87 | y, sr, channels, _, bit_rate, _ = AudioLoader(filename=input_filepath)() 88 | # assert audio is mono or stereo 89 | assert channels <= 2 90 | # some book keeping for constructing the output path later 91 | file_dir = Path(input_filepath).parent 92 | file_stem = str(Path(input_filepath).stem) 93 | file_suffix = str(Path(input_filepath).suffix) 94 | 95 | for scenario_idx, scenario in enumerate(dataset.deformations_cfg): 96 | y_d = y.copy() 97 | # we're looping like this so that we retain the order of 98 | # deformations provided by the user 99 | # debatable whether this should be the default behavior 100 | for deformation in scenario: 101 | if deformation["type"] == "AddGaussianSNR": 102 | from audiomentations import AddGaussianSNR 103 | 104 | transform = AddGaussianSNR(**deformation["params"]) 105 | elif deformation["type"] == "ApplyImpulseResponse": 106 | from audiomentations import ApplyImpulseResponse 107 | 108 | transform = ApplyImpulseResponse(**deformation["params"]) 109 | elif deformation["type"] == "ClippingDistortion": 110 | from audiomentations import ClippingDistortion 111 | 112 | transform = ClippingDistortion(**deformation["params"]) 113 | elif deformation["type"] == "Gain": 114 | from audiomentations import Gain 115 | 116 | transform == Gain(**deformation["params"]) 117 | elif deformation["type"] == "Mp3Compression": 118 | from audiomentations import Mp3Compression 119 | 120 | transform = Mp3Compression(**deformation["params"]) 121 | elif deformation["type"] == "PitchShift": 122 | from audiomentations import PitchShift 123 | 124 | transform = PitchShift(**deformation["params"]) 125 | else: 126 | raise ValueError(f"Deformation {deformation['type']} not implemented.") 127 | 128 | # apply deformation 129 | y_d = transform(y_d, sr) 130 | 131 | # save deformed audio 132 | output_filepath = dataset.get_deformed_audio_path( 133 | track_id=track_id, deform_idx=scenario_idx 134 | ) 135 | 136 | # create parent dirs if they don't exist, and write audio 137 | Path(output_filepath).parent.mkdir(parents=True, exist_ok=True) 138 | 139 | # special case for lossy compression formats in which bitrate needs to be specified 140 | lossy_format_bitrates = [ 141 | 32, 142 | 40, 143 | 48, 144 | 56, 145 | 64, 146 | 80, 147 | 96, 148 | 112, 149 | 128, 150 | 144, 151 | 160, 152 | 192, 153 | 224, 154 | 256, 155 | 320, 156 | ] 157 | if ( 158 | file_suffix == ".mp3" or file_suffix == ".ogg" 159 | ) and bit_rate in lossy_format_bitrates: 160 | AudioWriter( 161 | filename=output_filepath, 162 | format=file_suffix[1:], 163 | sampleRate=sr, 164 | bitrate=bit_rate, 165 | )(y_d) 166 | else: 167 | AudioWriter( 168 | filename=output_filepath, format=file_suffix[1:], sampleRate=sr 169 | )(y_d) 170 | 171 | 172 | def generate_deformations( 173 | dataset, 174 | n_jobs: int = 1, 175 | ): 176 | """Generate deformed audio and save.""" 177 | 178 | # check if there are no deformations specified in the experiment 179 | if not dataset.deformations_cfg: 180 | print( 181 | f"No deformations specified for '{dataset.name}'. Skipping deformation generation." 
182 | ) 183 | return 184 | 185 | # create output dir for deformed audio if it doesn't exist 186 | (Path(dataset.data_dir) / "audio_deformed").mkdir(parents=True, exist_ok=True) 187 | 188 | if n_jobs == 1: 189 | for track_id in tqdm(dataset.track_ids): 190 | deform_audio(track_id, dataset) 191 | else: 192 | # this passes around the dataset object, which is not ideal for performance 193 | Parallel(n_jobs=n_jobs, verbose=0)( 194 | delayed(deform_audio)(track_id, dataset) 195 | for track_id in tqdm(dataset.track_ids) 196 | ) 197 | 198 | print("Deformed audio generated and saved.") 199 | -------------------------------------------------------------------------------- /mir_ref/evaluate.py: -------------------------------------------------------------------------------- 1 | """Evaluation of downstream models 2 | """ 3 | 4 | import json 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | from colorama import Fore, Style 9 | from sklearn.metrics import ( 10 | average_precision_score, 11 | classification_report, 12 | roc_auc_score, 13 | ) 14 | from sklearn.model_selection import ParameterGrid 15 | 16 | from mir_ref.dataloaders import DataGenerator 17 | from mir_ref.datasets.dataset import get_dataset 18 | from mir_ref.probes.probe_builder import get_model 19 | from mir_ref.utils import load_config 20 | 21 | 22 | def evaluate(cfg_path, run_id=None): 23 | """""" 24 | cfg = load_config(cfg_path) 25 | 26 | if run_id is None: 27 | print( 28 | "No run ID has been specified. Attempting to load latest training run,", 29 | "but this will fail if no runs have a timestamp IDs.", 30 | ) 31 | # attempt to load latest experiment by sorting dirs in logs 32 | dirs = [d for d in Path("./logs").iterdir() if d.is_dir()] 33 | # only keep dirs that has numeric and - in the name 34 | dirs = [d for d in dirs if d.name.replace("-", "").isnumeric()] 35 | if not dirs: 36 | raise ValueError( 37 | "No run ID has been specified and no timestamped runs have been found." 
38 | ) 39 | run_id = sorted(dirs)[-1].name 40 | 41 | for exp_cfg in cfg["experiments"]: 42 | run_params = { 43 | "dataset_cfg": exp_cfg["datasets"], 44 | "feature": exp_cfg["features"], 45 | "model_cfg": exp_cfg["probes"], 46 | } 47 | 48 | # create grid from parameters 49 | grid = ParameterGrid(run_params) 50 | 51 | # !!!temporary, assumes single dataset per task 52 | dataset = get_dataset( 53 | dataset_cfg=exp_cfg["datasets"][0], 54 | task_cfg=exp_cfg["task"], 55 | features_cfg=exp_cfg["features"], 56 | ) 57 | dataset.download() 58 | 59 | for params in grid: 60 | # get index of downstream model for naming-logging 61 | model_idx = exp_cfg["probes"].index(params["model_cfg"]) 62 | # get dataset object 63 | # dataset = get_dataset( 64 | # dataset_cfg=params["dataset_cfg"], 65 | # task_cfg=exp_cfg["task"], 66 | # deformations_cfg=exp_cfg["deformations"], 67 | # features_cfg=exp_cfg["features"], 68 | # ) 69 | # dataset.download() 70 | # run task for every split 71 | for split_idx in range(len(dataset.get_splits())): 72 | print( 73 | Fore.GREEN 74 | + f"Task: {dataset.task_name}\n" 75 | + f"└── Dataset: {dataset.name}\n" 76 | + f" └── Embeddings: {params['feature']}\n" 77 | + f" └── Model: {model_idx}\n" 78 | + f" └── Split: {split_idx}", 79 | Style.RESET_ALL, 80 | ) 81 | evaluate_probe( 82 | run_id=run_id, 83 | dataset=dataset, 84 | model_cfg=params["model_cfg"], 85 | model_idx=model_idx, 86 | feature=params["feature"], 87 | split_idx=split_idx, 88 | ) 89 | 90 | 91 | def evaluate_probe(run_id, dataset, model_cfg, model_idx, feature, split_idx): 92 | """Evaluate downstream models, including in cases 93 | with deformations. 94 | 95 | Args: 96 | run_id (str): ID of the current run, defaults to timestamp. 97 | dataset (Dataset): Dataset object. 98 | model_cfg (dict): Downstream model config. 99 | model_idx (int): Index of the downstream model in the list of models. 100 | split_idx (int): Index of the split in the list of splits. 101 | """ 102 | 103 | split = dataset.get_splits()[split_idx] 104 | n_classes = len(dataset.encoded_labels[dataset.track_ids[0]]) 105 | 106 | if model_cfg["emb_shape"] == "infer": 107 | # get embedding shape from the first embedding 108 | emb_shape = np.load( 109 | dataset.get_embedding_path( 110 | feature=feature, 111 | track_id=split["train"][0], 112 | ) 113 | ).shape 114 | elif isinstance(model_cfg["emb_shape"], int): 115 | emb_shape = model_cfg["emb_shape"] 116 | elif isinstance(model_cfg["emb_shape"], str): 117 | raise ValueError(f"{model_cfg['emb_shape']} not implemented.") 118 | 119 | run_dir = ( 120 | Path(run_id) / dataset.task_name / f"{dataset.name}_{feature}_model-{model_idx}" 121 | ) 122 | 123 | # get all dirs starting with run_dir, sort them, and get latest 124 | run_dirs = [ 125 | d 126 | for d in (Path("./logs") / run_dir.parent).iterdir() 127 | if str(d).startswith(str(Path("logs") / run_dir)) 128 | ] 129 | # remove ./logs from the start of dirs 130 | run_dirs = [d.relative_to("logs") for d in run_dirs] 131 | 132 | new_run_dir = sorted(run_dirs)[-1] 133 | if run_dir != new_run_dir: 134 | import warnings 135 | 136 | warnings.warn( 137 | f"Multiple runs for '{run_dir}' found. 
Loading latest run '{new_run_dir}'", 138 | stacklevel=2, 139 | ) 140 | run_dir = new_run_dir 141 | 142 | # load model 143 | model = get_model(model_cfg=model_cfg, dim=emb_shape, n_classes=n_classes) 144 | model.load_weights(filepath=Path("./logs") / run_dir / "weights.h5") 145 | 146 | # load data 147 | test_gen = DataGenerator( 148 | ids_list=split["test"], 149 | labels_dict=dataset.encoded_labels, 150 | paths_dict={ 151 | t_id: dataset.get_embedding_path(feature=feature, track_id=t_id) 152 | for t_id in split["test"] 153 | }, 154 | batch_size=model_cfg["batch_size"], 155 | dim=emb_shape, 156 | n_classes=n_classes, 157 | shuffle=False, 158 | ) 159 | 160 | pred = model.predict(x=test_gen, batch_size=model_cfg["batch_size"], verbose=1) 161 | 162 | if dataset.task_type == "multiclass_classification": 163 | # get one-hot encoded vectors where 1 is the argmax in each case 164 | y_pred = [np.eye(len(p))[np.argmax(p)] for p in pred] 165 | # predictions might be shorter because of partial batch drop 166 | y_true = [dataset.encoded_labels[t_id] for t_id in split["test"]][: len(y_pred)] 167 | # doing this twice because it has nice formatting, but need the json after 168 | print(classification_report(y_true=y_true, y_pred=y_pred)) 169 | metrics = classification_report(y_true=y_true, y_pred=y_pred, output_dict=True) 170 | if dataset.task_name == "key_estimation": 171 | # calculate weighted accuracy 172 | from mir_ref.metrics import key_detection_weighted_accuracy 173 | 174 | metrics["weighted_accuracy"] = key_detection_weighted_accuracy( 175 | y_true=[dataset.decode_label(y) for y in y_true], 176 | y_pred=[dataset.decode_label(y) for y in y_pred], 177 | ) 178 | print("Weighted accuracy:", metrics["weighted_accuracy"]) 179 | elif dataset.task_type == "multilabel_classification": 180 | y_true = [dataset.encoded_labels[t_id] for t_id in split["test"]][: len(pred)] 181 | y_pred = pred 182 | metrics = { 183 | "roc_auc": roc_auc_score(y_true, y_pred), 184 | "average_precision": average_precision_score(y_true, y_pred), 185 | } 186 | print(metrics) 187 | 188 | with open("./logs" / run_dir / "clean_metrics.json", "w+") as f: 189 | json.dump(metrics, f, indent=4) 190 | 191 | # do the same but with the deformed audio as the test set 192 | if dataset.deformations_cfg: 193 | for scenario_idx, scenario_cfg in enumerate(dataset.deformations_cfg): 194 | print( 195 | Fore.GREEN 196 | + "# Scenario: " 197 | + f"{scenario_idx+1}/{len(dataset.deformations_cfg)} " 198 | + f"{[cfg['type'] for cfg in scenario_cfg]}", 199 | Style.RESET_ALL, 200 | ) 201 | 202 | # load data 203 | test_gen = DataGenerator( 204 | ids_list=split["test"], 205 | labels_dict=dataset.encoded_labels, 206 | paths_dict={ 207 | t_id: dataset.get_deformed_embedding_path( 208 | feature=feature, track_id=t_id, deform_idx=scenario_idx 209 | ) 210 | for t_id in split["test"] 211 | }, 212 | batch_size=model_cfg["batch_size"], 213 | dim=emb_shape, 214 | n_classes=n_classes, 215 | shuffle=False, 216 | ) 217 | 218 | pred = model.predict( 219 | x=test_gen, 220 | batch_size=model_cfg["batch_size"], 221 | verbose=1, 222 | ) 223 | 224 | if dataset.task_type == "multiclass_classification": 225 | # get one-hot encoded vectors where 1 is the argmax in each case 226 | y_pred = [np.eye(len(p))[np.argmax(p)] for p in pred] 227 | # predictions might be shorter because of partial batch drop 228 | y_true = [dataset.encoded_labels[t_id] for t_id in split["test"]][ 229 | : len(y_pred) 230 | ] 231 | metrics = classification_report( 232 | y_true=y_true, y_pred=y_pred, 
output_dict=True 233 | ) 234 | print(classification_report(y_true=y_true, y_pred=y_pred)) 235 | if dataset.task_name == "key_estimation": 236 | # calculate weighted accuracy 237 | from mir_ref.metrics import key_detection_weighted_accuracy 238 | 239 | metrics["weighted_accuracy"] = key_detection_weighted_accuracy( 240 | y_true=[dataset.decode_label(y) for y in y_true], 241 | y_pred=[dataset.decode_label(y) for y in y_pred], 242 | ) 243 | print("Weighted accuracy:", metrics["weighted_accuracy"]) 244 | 245 | elif dataset.task_type == "multilabel_classification": 246 | y_true = [dataset.encoded_labels[t_id] for t_id in split["test"]][ 247 | : len(pred) 248 | ] 249 | y_pred = pred 250 | metrics = { 251 | "roc_auc": roc_auc_score(y_true, y_pred), 252 | "average_precision": average_precision_score(y_true, y_pred), 253 | } 254 | print(metrics) 255 | 256 | # save metrics 257 | with open( 258 | "./logs" / run_dir / f"deform_{scenario_idx}_metrics.json", 259 | "w+", 260 | ) as f: 261 | json.dump(metrics, f, indent=4) 262 | -------------------------------------------------------------------------------- /mir_ref/extract.py: -------------------------------------------------------------------------------- 1 | """Generate embeddings from the audio files. 2 | """ 3 | 4 | 5 | from colorama import Fore, Style 6 | 7 | from mir_ref.datasets.dataset import get_dataset 8 | from mir_ref.features.feature_extraction import generate_embeddings 9 | from mir_ref.utils import load_config 10 | 11 | 12 | def generate( 13 | cfg_path, 14 | skip_clean=False, 15 | skip_deformed=False, 16 | no_overwrite=False, 17 | deform_list=None, 18 | ): 19 | cfg = load_config(cfg_path) 20 | 21 | for exp_cfg in cfg["experiments"]: 22 | # iterate through every dataset to generate embeddings 23 | print(Fore.GREEN + "# Extracting features...", Style.RESET_ALL) 24 | for dataset_cfg in exp_cfg["datasets"]: 25 | print( 26 | Fore.GREEN + f"## Dataset: {dataset_cfg['name']}", 27 | Style.RESET_ALL, 28 | ) 29 | for model_name in exp_cfg["features"]: 30 | print(Fore.GREEN + f"### Feature: {model_name}", Style.RESET_ALL) 31 | 32 | dataset = get_dataset( 33 | dataset_cfg=dataset_cfg, 34 | task_cfg=exp_cfg["task"], 35 | features_cfg=exp_cfg["features"], 36 | ) 37 | dataset.download() 38 | dataset.preprocess() 39 | 40 | generate_embeddings( 41 | dataset, 42 | model_name=model_name, 43 | skip_clean=skip_clean, 44 | skip_deformed=skip_deformed, 45 | no_overwrite=no_overwrite, 46 | deform_list=deform_list, 47 | ) 48 | -------------------------------------------------------------------------------- /mir_ref/features/custom_features.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrispla/mir_ref/691ae42815db6359ef66c1e175be55be42bbc340/mir_ref/features/custom_features.py -------------------------------------------------------------------------------- /mir_ref/features/feature_extraction.py: -------------------------------------------------------------------------------- 1 | """Retrieve correct feature extractor and extract features. 
2 | """ 3 | 4 | import os 5 | 6 | import wget 7 | 8 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" 9 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" 10 | 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | 15 | def check_model_exists(model_path: str): 16 | if not os.path.exists(model_path): 17 | raise Exception( 18 | f"Model not found at {model_path}" 19 | + "Please download it from https://essentia.upf.edu/models.html" 20 | + "and place it in the mir_ref/features/models/weights directory." 21 | ) 22 | 23 | 24 | def get_input_output_paths( 25 | dataset, 26 | model_name, 27 | skip_clean=False, 28 | skip_deformed=False, 29 | no_overwrite=False, 30 | deform_list=None, 31 | ): 32 | """Get a list of input audio paths (including deformed audio) 33 | and a list of the corresponding output paths for the embeddings 34 | for a given dataset and embedding model. Don't include embeddings 35 | that have already been computed. 36 | 37 | Args: 38 | dataset: mir_ref Dataset object. 39 | model_name: Name of the embedding model. 40 | skip_clean: Whether to skip embedding generation for clean audio. 41 | skip_deformed: Whether to skip embedding generation for deformed audio. 42 | no_overwrite: Whether to skip embedding generation for existing embeddings. 43 | deform_list: List of deformation scenario indicies to include. If None, 44 | include all deformation scenarios. 45 | """ 46 | 47 | # get audio paths for clean and deformed audio 48 | if not skip_clean: 49 | audio_paths = [dataset.audio_paths[track_id] for track_id in dataset.track_ids] 50 | else: 51 | audio_paths = [] 52 | if not skip_deformed and dataset.deformations_cfg is not None: 53 | if deform_list is None: 54 | deform_list = range(len(dataset.deformations_cfg)) 55 | audio_paths += [ 56 | dataset.get_deformed_audio_path(track_id=track_id, deform_idx=deform_idx) 57 | for deform_idx in deform_list 58 | for track_id in dataset.track_ids 59 | ] 60 | 61 | # get output embedding paths for the same tracks 62 | if not skip_clean: 63 | emb_paths = [ 64 | dataset.get_embedding_path(track_id, model_name) 65 | for track_id in dataset.track_ids 66 | ] 67 | else: 68 | emb_paths = [] 69 | if not skip_deformed and dataset.deformations_cfg is not None: 70 | if deform_list is None: 71 | deform_list = range(len(dataset.deformations_cfg)) 72 | emb_paths += [ 73 | dataset.get_deformed_embedding_path( 74 | track_id=track_id, feature=model_name, deform_idx=deform_idx 75 | ) 76 | for deform_idx in deform_list 77 | for track_id in dataset.track_ids 78 | ] 79 | 80 | if no_overwrite: 81 | # remove audio paths and the respective embeddings paths if embedding 82 | # already exists 83 | audio_paths, emb_paths = zip( 84 | *[ 85 | (audio_path, emb_path) 86 | for audio_path, emb_path in zip(audio_paths, emb_paths) 87 | if not os.path.exists(emb_path) 88 | ] 89 | ) 90 | 91 | return audio_paths, emb_paths 92 | 93 | 94 | def compute_and_save_embeddings( 95 | model: object, 96 | model_name: str, 97 | aggregation: str, 98 | dataset, 99 | sample_rate: int, 100 | resample_quality=1, 101 | skip_clean=False, 102 | skip_deformed=False, 103 | no_overwrite=False, 104 | deform_list=None, 105 | ): 106 | """Compute embeddings given model object and 107 | audio path list. 
108 | """ 109 | 110 | from essentia.standard import MonoLoader 111 | 112 | monoloader = MonoLoader(sampleRate=sample_rate, resampleQuality=resample_quality) 113 | 114 | audio_paths, emb_paths = get_input_output_paths( 115 | dataset=dataset, 116 | model_name=model_name, 117 | skip_clean=skip_clean, 118 | skip_deformed=skip_deformed, 119 | no_overwrite=no_overwrite, 120 | deform_list=deform_list, 121 | ) 122 | 123 | for input_path, output_path in tqdm( 124 | zip(audio_paths, emb_paths), total=len(audio_paths) 125 | ): 126 | # Load audio 127 | monoloader.configure(filename=input_path) 128 | audio = monoloader() 129 | 130 | # Compute embeddings 131 | embedding = model(audio) 132 | 133 | if aggregation == "mean": 134 | embedding = np.mean(embedding, axis=0) 135 | elif aggregation is None: 136 | raise Exception(f"Aggregation method '{aggregation}' not implemented.") 137 | else: 138 | raise Exception(f"Aggregation method '{aggregation}' not implemented.") 139 | 140 | # Save embeddings 141 | if not os.path.exists(os.path.dirname(output_path)): 142 | os.makedirs(os.path.dirname(output_path)) 143 | with open(output_path, "wb") as f: 144 | np.save(f, embedding) 145 | 146 | 147 | def generate_embeddings( 148 | dataset, 149 | model_name: str, 150 | skip_clean=False, 151 | skip_deformed=False, 152 | no_overwrite=False, 153 | deform_list=None, 154 | transcode_and_load=False, 155 | ): 156 | """Generate embeddings from a list of audio files. 157 | 158 | Args: 159 | dataset_cfg (dict): Dataset configuration. 160 | task_cfg (dict): Task configuration. 161 | aggregation (str, optional): Embedding aggregation method. Defaults to "mean". 162 | skip_clean (bool, optional): Whether to skip embedding generation for clean 163 | audio. Defaults to False. 164 | skip_deformed (bool, optional): Whether to skip embedding generation for 165 | deformed audio. Defaults to False. 166 | no_overwrite (bool, optional): Whether to skip embedding generation for 167 | existing embeddings. Defaults to False. 168 | deform_list (list, optional): List of deformation scenario indicies to include. 169 | If None, include all deformation scenarios. 170 | """ 171 | 172 | aggregation = dataset.feature_aggregation 173 | # Load embedding model. Call essentia implementation if available, 174 | # otherwise custom implementation. 
175 | 176 | if model_name == "vggish-audioset": 177 | model_path = "mir_ref/features/models/weights/audioset-vggish-3.pb" 178 | if not os.path.exists(model_path): 179 | print(f"Downloading {model_name} to mir_ref/features/models/weights...") 180 | wget.download( 181 | "https://essentia.upf.edu/models/feature-extractors/vggish/audioset-vggish-3.pb", 182 | out="mir_ref/features/models/weights/", 183 | ) 184 | check_model_exists(model_path) 185 | 186 | from essentia.standard import MonoLoader, TensorflowPredictVGGish 187 | 188 | model = TensorflowPredictVGGish( 189 | graphFilename=model_path, output="model/vggish/embeddings" 190 | ) 191 | 192 | compute_and_save_embeddings( 193 | model=model, 194 | model_name=model_name, 195 | aggregation=aggregation, 196 | dataset=dataset, 197 | sample_rate=16000, 198 | skip_clean=skip_clean, 199 | skip_deformed=skip_deformed, 200 | no_overwrite=no_overwrite, 201 | deform_list=deform_list, 202 | ) 203 | 204 | elif model_name == "effnet-discogs": 205 | model_path = ( 206 | "mir_ref/features/models/weights/discogs_artist_embeddings-effnet-bs64-1.pb" 207 | ) 208 | if not os.path.exists(model_path): 209 | print(f"Downloading {model_name} to mir_ref/features/models/weights...") 210 | wget.download( 211 | "https://essentia.upf.edu/models/feature-extractors/discogs-effnet/discogs_artist_embeddings-effnet-bs64-1.pb", 212 | out="mir_ref/features/models/weights/", 213 | ) 214 | check_model_exists(model_path) 215 | 216 | from essentia.standard import MonoLoader, TensorflowPredictEffnetDiscogs 217 | 218 | model = TensorflowPredictEffnetDiscogs( 219 | graphFilename=model_path, output="PartitionedCall:1" 220 | ) 221 | 222 | compute_and_save_embeddings( 223 | model=model, 224 | model_name=model_name, 225 | aggregation=aggregation, 226 | dataset=dataset, 227 | sample_rate=16000, 228 | skip_clean=skip_clean, 229 | skip_deformed=skip_deformed, 230 | no_overwrite=no_overwrite, 231 | deform_list=deform_list, 232 | ) 233 | 234 | elif model_name == "msd-musicnn": 235 | model_path = "mir_ref/features/models/weights/msd-musicnn-1.pb" 236 | 237 | if not os.path.exists(model_path): 238 | print(f"Downloading {model_name} to mir_ref/features/models/weights...") 239 | wget.download( 240 | "https://essentia.upf.edu/models/feature-extractors/musicnn/msd-musicnn-1.pb", 241 | out="mir_ref/features/models/weights/", 242 | ) 243 | 244 | from essentia.standard import MonoLoader, TensorflowPredictMusiCNN 245 | 246 | model = TensorflowPredictMusiCNN( 247 | graphFilename=model_path, output="model/dense/BiasAdd" 248 | ) 249 | 250 | compute_and_save_embeddings( 251 | model=model, 252 | model_name=model_name, 253 | aggregation=aggregation, 254 | dataset=dataset, 255 | sample_rate=16000, 256 | resample_quality=4, 257 | skip_clean=skip_clean, 258 | skip_deformed=skip_deformed, 259 | no_overwrite=no_overwrite, 260 | deform_list=deform_list, 261 | ) 262 | 263 | elif model_name == "maest": 264 | model_path = "mir_ref/features/models/weights/discogs-maest-30s-pw-1.pb" 265 | 266 | if not os.path.exists(model_path): 267 | print(f"Downloading {model_name} to mir_ref/features/models/weights...") 268 | wget.download( 269 | "https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-30s-pw-1.pb", 270 | out="mir_ref/features/models/weights/", 271 | ) 272 | 273 | check_model_exists(model_path) 274 | 275 | from essentia.standard import MonoLoader, TensorflowPredictMAEST 276 | 277 | model = TensorflowPredictMAEST( 278 | graphFilename=model_path, output="model/dense/BiasAdd" 279 | ) 280 | 281 | 
compute_and_save_embeddings( 282 | model=model, 283 | model_name=model_name, 284 | aggregation=aggregation, 285 | dataset=dataset, 286 | sample_rate=16000, 287 | resample_quality=4, 288 | skip_clean=skip_clean, 289 | skip_deformed=skip_deformed, 290 | no_overwrite=no_overwrite, 291 | deform_list=deform_list, 292 | ) 293 | 294 | elif model_name == "openl3": 295 | from essentia.standard import MonoLoader 296 | 297 | from mir_ref.features.models.openl3 import EmbeddingsOpenL3 298 | 299 | model_path = "mir_ref/features/models/weights/openl3-music-mel128-emb512-3.pb" 300 | 301 | if not os.path.exists(model_path): 302 | print(f"Downloading {model_name} to mir_ref/features/models/weights...") 303 | wget.download( 304 | "https://essentia.upf.edu/models/feature-extractors/openl3/openl3-env-mel128-emb512-3.pb", 305 | out="mir_ref/features/models/weights/", 306 | ) 307 | 308 | check_model_exists(model_path) 309 | 310 | extractor = EmbeddingsOpenL3(model_path) 311 | 312 | audio_paths, emb_paths = get_input_output_paths( 313 | dataset=dataset, 314 | model_name=model_name, 315 | skip_clean=skip_clean, 316 | skip_deformed=skip_deformed, 317 | no_overwrite=no_overwrite, 318 | deform_list=deform_list, 319 | ) 320 | 321 | # Compute embeddings 322 | for input_path, output_path in tqdm( 323 | zip(audio_paths, emb_paths), total=len(audio_paths) 324 | ): 325 | embedding = extractor.compute(input_path) 326 | 327 | if aggregation == "mean": 328 | embedding = np.mean(embedding, axis=0) 329 | elif aggregation is None: 330 | pass 331 | else: 332 | raise Exception(f"Aggregation method '{aggregation}' not implemented.") 333 | 334 | # Save embeddings 335 | if not os.path.exists(os.path.dirname(output_path)): 336 | os.makedirs(os.path.dirname(output_path)) 337 | with open(output_path, "wb") as f: 338 | np.save(f, embedding) 339 | 340 | elif model_name == "neuralfp": 341 | import tensorflow as tf 342 | from essentia.standard import MonoLoader 343 | 344 | mel_spec_model_dir = "mir_ref/features/models/weights/neuralfp/mel_spec" 345 | fp_model_dir = "mir_ref/features/models/weights/neuralfp/fp" 346 | mel_spec_model = tf.saved_model.load(mel_spec_model_dir) 347 | mel_spec_infer = mel_spec_model.signatures["serving_default"] 348 | fp_model = tf.saved_model.load(fp_model_dir) 349 | fp_infer = fp_model.signatures["serving_default"] 350 | 351 | monoloader = MonoLoader(sampleRate=8000, resampleQuality=1) 352 | 353 | audio_paths, emb_paths = get_input_output_paths( 354 | dataset=dataset, 355 | model_name=model_name, 356 | skip_clean=skip_clean, 357 | skip_deformed=skip_deformed, 358 | no_overwrite=no_overwrite, 359 | deform_list=deform_list, 360 | ) 361 | 362 | for input_path, output_path in tqdm( 363 | zip(audio_paths, emb_paths), total=len(audio_paths) 364 | ): 365 | # Load audio 366 | monoloader.configure(filename=input_path) 367 | audio = monoloader() 368 | 369 | # Fingerprinting is done per 8000 samples, with a 4000 sample overlap, so pad 370 | audio = np.concatenate((np.zeros(4000), audio)) 371 | audio = np.concatenate((audio, np.zeros(4000 - (len(audio) % 4000)))) 372 | 373 | # Compute embeddings 374 | embeddings = [] 375 | for buffer_start in range(0, len(audio) - 4000, 4000): 376 | buffer = audio[buffer_start : buffer_start + 8000] 377 | # size (None, 1, 8000) 378 | buffer.resize(1, 8000) 379 | buffer = np.array([buffer]) 380 | # use mel spectrogram model 381 | mel_spec_emb = mel_spec_infer(tf.constant(buffer, dtype=tf.float32))[ 382 | "output_1" 383 | ] 384 | # use fingerprinter model 385 | fp_emb = 
fp_infer(mel_spec_emb)["output_1"] 386 | embeddings.append(fp_emb.numpy()[0]) 387 | 388 | if aggregation == "mean": 389 | embedding = np.mean(embeddings, axis=0) 390 | elif aggregation is None: 391 | raise Exception(f"Aggregation method '{aggregation}' not implemented.") 392 | else: 393 | raise Exception(f"Aggregation method '{aggregation}' not implemented.") 394 | 395 | # Save embeddings 396 | if not os.path.exists(os.path.dirname(output_path)): 397 | os.makedirs(os.path.dirname(output_path)) 398 | with open(output_path, "wb") as f: 399 | np.save(f, embedding) 400 | 401 | elif model_name == "mert-v1-330m" or model_name == "mert-v1-95m": 402 | # from transformers import Wav2Vec2Processor 403 | import torch 404 | import torchaudio.transforms as T 405 | from transformers import AutoModel, Wav2Vec2FeatureExtractor 406 | 407 | n_params = 330 if model_name == "mert-v1-330m" else 95 408 | # loading our model weights 409 | model = AutoModel.from_pretrained( 410 | f"m-a-p/MERT-v1-{n_params}M", trust_remote_code=True 411 | ) 412 | # loading the corresponding preprocessor config 413 | processor = Wav2Vec2FeatureExtractor.from_pretrained( 414 | f"m-a-p/MERT-v1-{n_params}M", trust_remote_code=True 415 | ) 416 | # get desired sample rate 417 | sample_rate = processor.sampling_rate 418 | 419 | audio_paths, emb_paths = get_input_output_paths( 420 | dataset=dataset, 421 | model_name=model_name, 422 | skip_clean=skip_clean, 423 | skip_deformed=skip_deformed, 424 | no_overwrite=no_overwrite, 425 | deform_list=deform_list, 426 | ) 427 | 428 | if transcode_and_load: 429 | import librosa 430 | import sox 431 | 432 | tfm = sox.Transformer() 433 | tfm.convert(samplerate=sample_rate, n_channels=1) 434 | else: 435 | from essentia.standard import MonoLoader 436 | 437 | monoloader = MonoLoader(sampleRate=sample_rate, resampleQuality=1) 438 | 439 | for input_path, output_path in tqdm( 440 | zip(audio_paths, emb_paths), total=len(audio_paths) 441 | ): 442 | # Load audio 443 | if transcode_and_load: 444 | wav_input_path = input_path[:-4] + str(sample_rate) + ".wav" 445 | tfm.build(input_path, wav_input_path) 446 | audio, _ = librosa.load(wav_input_path, sr=sample_rate) 447 | os.remove(wav_input_path) 448 | else: 449 | monoloader.configure(filename=input_path) 450 | audio = monoloader() 451 | 452 | inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt") 453 | with torch.no_grad(): 454 | outputs = model(**inputs, output_hidden_states=True) 455 | 456 | if aggregation == "mean": 457 | # we'll get the full embedding for now, meaning 13 layers x 768, or 458 | # 24 layers x 1024 for 95M and 330M respectively 459 | all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze() 460 | embedding = all_layer_hidden_states.mean(-2).detach().cpu().numpy() 461 | elif aggregation is None: 462 | pass 463 | else: 464 | raise Exception(f"Aggregation method '{aggregation}' not implemented.") 465 | 466 | # Save embeddings 467 | if not os.path.exists(os.path.dirname(output_path)): 468 | os.makedirs(os.path.dirname(output_path)) 469 | with open(output_path, "wb") as f: 470 | np.save(f, embedding) 471 | 472 | elif model_name == "clmr-v2": 473 | import subprocess 474 | 475 | import torch 476 | 477 | from mir_ref.features.models.clmr import SampleCNN, load_encoder_checkpoint 478 | 479 | # download model 480 | if not os.path.exists( 481 | "mir_ref/features/models/weights/clmr_checkpoint_10000.pt" 482 | ): 483 | print(f"Downloading {model_name} to mir_ref/features/models/weights...") 484 | wget.download( 485 | 
"https://github.com/Spijkervet/CLMR/releases/download/2.0/clmr_checkpoint_10000.zip", 486 | out="mir_ref/features/models/weights/", 487 | ) 488 | 489 | # unzip clmr_checkpoint_10000 490 | subprocess.run( 491 | [ 492 | "unzip", 493 | "mir_ref/features/models/weights/clmr_checkpoint_10000.zip", 494 | "-d", 495 | "mir_ref/features/models/weights/", 496 | ] 497 | ) 498 | # delete zip 499 | subprocess.run( 500 | [ 501 | "rm", 502 | "mir_ref/features/models/weights/clmr_checkpoint_10000.zip", 503 | ] 504 | ) 505 | # delete clmr_checkpoint_10000_optim.pt 506 | subprocess.run( 507 | [ 508 | "rm", 509 | "mir_ref/features/models/weights/clmr_checkpoint_10000_optim.pt", 510 | ] 511 | ) 512 | 513 | # load model 514 | encoder = SampleCNN(strides=[3, 3, 3, 3, 3, 3, 3, 3, 3], supervised=False) 515 | state_dict = load_encoder_checkpoint( 516 | "mir_ref/features/models/weights/clmr_checkpoint_10000.pt" 517 | ) 518 | encoder.load_state_dict(state_dict) 519 | encoder.eval() 520 | 521 | audio_paths, emb_paths = get_input_output_paths( 522 | dataset=dataset, 523 | model_name=model_name, 524 | skip_clean=skip_clean, 525 | skip_deformed=skip_deformed, 526 | no_overwrite=no_overwrite, 527 | deform_list=deform_list, 528 | ) 529 | 530 | if transcode_and_load: 531 | import librosa 532 | import sox 533 | 534 | tfm = sox.Transformer() 535 | tfm.convert(samplerate=22050, n_channels=1) 536 | else: 537 | from essentia.standard import MonoLoader 538 | 539 | monoloader = MonoLoader(sampleRate=22050, resampleQuality=1) 540 | 541 | for input_path, output_path in tqdm( 542 | zip(audio_paths, emb_paths), total=len(audio_paths) 543 | ): 544 | # Load audio 545 | if transcode_and_load: 546 | wav_input_path = input_path[:-4] + str(22050) + ".wav" 547 | tfm.build(input_path, wav_input_path) 548 | audio, _ = librosa.load(wav_input_path, sr=22050) 549 | os.remove(wav_input_path) 550 | else: 551 | monoloader.configure(filename=input_path) 552 | audio = monoloader() 553 | 554 | # get embedding per 59049 samples, padding the last buffer 555 | embeddings = [] 556 | buffer_size = 59049 557 | for i in range(0, len(audio), buffer_size): 558 | buffer = audio[i : i + buffer_size] 559 | if len(buffer) < buffer_size: 560 | buffer = np.pad( 561 | buffer, (0, buffer_size - len(buffer)), mode="constant" 562 | ) 563 | buffer = torch.from_numpy(buffer).float() 564 | buffer = buffer.unsqueeze(0).unsqueeze(0) 565 | embedding = encoder(buffer).squeeze() 566 | embeddings.append(embedding) 567 | 568 | if aggregation == "mean": 569 | embedding = torch.mean(torch.stack(embeddings), axis=0).detach().numpy() 570 | elif aggregation is None: 571 | embedding = torch.stack(embeddings).detach().numpy() 572 | else: 573 | raise Exception(f"Aggregation method '{aggregation}' not implemented.") 574 | 575 | # Save embeddings 576 | if not os.path.exists(os.path.dirname(output_path)): 577 | os.makedirs(os.path.dirname(output_path)) 578 | with open(output_path, "wb") as f: 579 | np.save(f, embedding) 580 | 581 | elif model_name == "mule": 582 | raise Exception("MULE embeddings are not fully supported yet.") 583 | 584 | from scooch import Config 585 | 586 | from mir_ref.features.models.mule import Analysis 587 | 588 | if aggregation is None: 589 | config = "mir_ref/features/models/weights/mule/mule_embedding_timeline.yml" 590 | elif aggregation == "mean": 591 | config = "mir_ref/features/models/weights/mule/mule_embedding_average.yml" 592 | else: 593 | raise Exception(f"Aggregation method '{aggregation}' not implemented.") 594 | 595 | cfg = Config(config) 596 | analysis = 
Analysis(cfg) 597 | 598 | audio_paths, emb_paths = get_input_output_paths( 599 | dataset=dataset, 600 | model_name=model_name, 601 | skip_clean=skip_clean, 602 | skip_deformed=skip_deformed, 603 | no_overwrite=no_overwrite, 604 | deform_list=deform_list, 605 | ) 606 | 607 | for input_file, output_file in tqdm( 608 | zip(audio_paths, emb_paths), total=len(audio_paths) 609 | ): 610 | feat = analysis.analyze(input_file) 611 | feat.save(output_file) 612 | -------------------------------------------------------------------------------- /mir_ref/features/models/clmr.py: -------------------------------------------------------------------------------- 1 | """Code used for inference for CLMR model. 2 | Source (slightly modified for inference purposes): 3 | https://github.com/Spijkervet/CLMR, Apache License 2.0 4 | """ 5 | 6 | import torch 7 | from collections import OrderedDict 8 | import torch.nn as nn 9 | 10 | 11 | def load_encoder_checkpoint(checkpoint_path: str) -> OrderedDict: 12 | state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu")) 13 | if "pytorch-lightning_version" in state_dict.keys(): 14 | new_state_dict = OrderedDict( 15 | { 16 | k.replace("model.encoder.", ""): v 17 | for k, v in state_dict["state_dict"].items() 18 | if "model.encoder." in k 19 | } 20 | ) 21 | else: 22 | new_state_dict = OrderedDict() 23 | for k, v in state_dict.items(): 24 | if "encoder." in k: 25 | new_state_dict[k.replace("encoder.", "")] = v 26 | 27 | new_state_dict["fc.weight"] = torch.zeros(50, 512) 28 | new_state_dict["fc.bias"] = torch.zeros(50) 29 | return new_state_dict 30 | 31 | 32 | class Model(nn.Module): 33 | def __init__(self): 34 | super(Model, self).__init__() 35 | 36 | def initialize(self, m): 37 | if isinstance(m, (nn.Conv1d)): 38 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") 39 | 40 | 41 | class Identity(nn.Module): 42 | def __init__(self): 43 | super(Identity, self).__init__() 44 | 45 | def forward(self, x): 46 | return x 47 | 48 | 49 | class SampleCNN(Model): 50 | def __init__(self, strides, supervised=False): 51 | super(SampleCNN, self).__init__() 52 | 53 | self.strides = strides 54 | self.supervised = supervised 55 | self.sequential = [ 56 | nn.Sequential( 57 | nn.Conv1d(1, 128, kernel_size=3, stride=3, padding=0), 58 | nn.BatchNorm1d(128), 59 | nn.ReLU(), 60 | ) 61 | ] 62 | 63 | self.hidden = [ 64 | [128, 128], 65 | [128, 128], 66 | [128, 256], 67 | [256, 256], 68 | [256, 256], 69 | [256, 256], 70 | [256, 256], 71 | [256, 256], 72 | [256, 512], 73 | ] 74 | 75 | assert len(self.hidden) == len( 76 | self.strides 77 | ), "Number of hidden layers and strides are not equal" 78 | for stride, (h_in, h_out) in zip(self.strides, self.hidden): 79 | self.sequential.append( 80 | nn.Sequential( 81 | nn.Conv1d(h_in, h_out, kernel_size=stride, stride=1, padding=1), 82 | nn.BatchNorm1d(h_out), 83 | nn.ReLU(), 84 | nn.MaxPool1d(stride, stride=stride), 85 | ) 86 | ) 87 | 88 | # 1 x 512 89 | self.sequential.append( 90 | nn.Sequential( 91 | nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1), 92 | nn.BatchNorm1d(512), 93 | nn.ReLU(), 94 | ) 95 | ) 96 | 97 | self.sequential = nn.Sequential(*self.sequential) 98 | 99 | if self.supervised: 100 | self.dropout = nn.Dropout(0.5) 101 | self.fc = nn.Linear(512, 50) 102 | 103 | def forward(self, x): 104 | out = self.sequential(x) 105 | if self.supervised: 106 | out = self.dropout(out) 107 | 108 | out = out.reshape(x.shape[0], out.size(1) * out.size(2)) 109 | logit = self.fc(out) 110 | 111 | return out 112 | 
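For reference, a minimal usage sketch of the encoder above, following how `generate_embeddings` in `mir_ref/features/feature_extraction.py` drives it; the checkpoint path and the 59049-sample buffer length come from that code, and the zero-filled `waveform` is only a placeholder:

```
import numpy as np
import torch

from mir_ref.features.models.clmr import SampleCNN, load_encoder_checkpoint

# build the encoder and load the pretrained CLMR weights
encoder = SampleCNN(strides=[3, 3, 3, 3, 3, 3, 3, 3, 3], supervised=False)
encoder.load_state_dict(
    load_encoder_checkpoint("mir_ref/features/models/weights/clmr_checkpoint_10000.pt")
)
encoder.eval()

# one 59049-sample mono buffer at 22050 Hz yields one 512-dimensional embedding
waveform = np.zeros(59049, dtype=np.float32)  # placeholder audio
x = torch.from_numpy(waveform).unsqueeze(0).unsqueeze(0)  # shape (1, 1, 59049)
with torch.no_grad():
    embedding = encoder(x).squeeze()  # shape (512,)
```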
-------------------------------------------------------------------------------- /mir_ref/features/models/harmonic_cnn.py: -------------------------------------------------------------------------------- 1 | """Code for Harmonic CNN embedding model inference. 2 | Source: https://github.com/minzwon/sota-music-tagging-models/blob/master/predict.py 3 | """ 4 | 5 | import os 6 | import sys 7 | import tempfile 8 | from pathlib import Path 9 | 10 | import cog 11 | import librosa 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torchaudio 18 | from torch.autograd import Variable 19 | 20 | sys.path.insert(0, "training") 21 | 22 | SAMPLE_RATE = 16000 23 | DATASET = "mtat" 24 | 25 | 26 | def initialize_filterbank(sample_rate, n_harmonic, semitone_scale): 27 | # MIDI 28 | # lowest note 29 | low_midi = note_to_midi('C1') 30 | 31 | # highest note 32 | high_note = hz_to_note(sample_rate / (2 * n_harmonic)) 33 | high_midi = note_to_midi(high_note) 34 | 35 | # number of scales 36 | level = (high_midi - low_midi) * semitone_scale 37 | midi = np.linspace(low_midi, high_midi, level + 1) 38 | hz = midi_to_hz(midi[:-1]) 39 | 40 | # stack harmonics 41 | harmonic_hz = [] 42 | for i in range(n_harmonic): 43 | harmonic_hz = np.concatenate((harmonic_hz, hz * (i+1))) 44 | 45 | return harmonic_hz, level 46 | 47 | 48 | class Conv_2d(nn.Module): 49 | def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2): 50 | super(Conv_2d, self).__init__() 51 | self.conv = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2) 52 | self.bn = nn.BatchNorm2d(output_channels) 53 | self.relu = nn.ReLU() 54 | self.mp = nn.MaxPool2d(pooling) 55 | 56 | def forward(self, x): 57 | out = self.mp(self.relu(self.bn(self.conv(x)))) 58 | return out 59 | 60 | 61 | class Res_2d_mp(nn.Module): 62 | def __init__(self, input_channels, output_channels, pooling=2): 63 | super(Res_2d_mp, self).__init__() 64 | self.conv_1 = nn.Conv2d(input_channels, output_channels, 3, padding=1) 65 | self.bn_1 = nn.BatchNorm2d(output_channels) 66 | self.conv_2 = nn.Conv2d(output_channels, output_channels, 3, padding=1) 67 | self.bn_2 = nn.BatchNorm2d(output_channels) 68 | self.relu = nn.ReLU() 69 | self.mp = nn.MaxPool2d(pooling) 70 | 71 | def forward(self, x): 72 | out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x))))) 73 | out = x + out 74 | out = self.mp(self.relu(out)) 75 | return out 76 | 77 | 78 | class HarmonicSTFT(nn.Module): 79 | def __init__(self, 80 | sample_rate=16000, 81 | n_fft=513, 82 | win_length=None, 83 | hop_length=None, 84 | pad=0, 85 | power=2, 86 | normalized=False, 87 | n_harmonic=6, 88 | semitone_scale=2, 89 | bw_Q=1.0, 90 | learn_bw=None): 91 | super(HarmonicSTFT, self).__init__() 92 | 93 | # Parameters 94 | self.sample_rate = sample_rate 95 | self.n_harmonic = n_harmonic 96 | self.bw_alpha = 0.1079 97 | self.bw_beta = 24.7 98 | 99 | # Spectrogram 100 | self.spec = torchaudio.transforms.Spectrogram(n_fft=n_fft, win_length=win_length, 101 | hop_length=None, pad=0, 102 | window_fn=torch.hann_window, 103 | power=power, normalized=normalized, 104 | wkwargs=None) 105 | self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB() 106 | 107 | # Initialize the filterbank. Equally spaced in MIDI scale. 
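        # Note: `initialize_filterbank` above calls `note_to_midi`, `hz_to_note` and
        # `midi_to_hz`, which are not imported in this file. Assuming the librosa
        # implementations are the intended ones (librosa is already imported at the
        # top of this file), a likely fix is:
        #
        #     from librosa import hz_to_note, midi_to_hz, note_to_midi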
108 | harmonic_hz, self.level = initialize_filterbank(sample_rate, n_harmonic, semitone_scale) 109 | 110 | # Center frequncies to tensor 111 | self.f0 = torch.tensor(harmonic_hz.astype('float32')) 112 | 113 | # Bandwidth parameters 114 | if learn_bw == 'only_Q': 115 | self.bw_Q = nn.Parameter(torch.tensor(np.array([bw_Q]).astype('float32'))) 116 | elif learn_bw == 'fix': 117 | self.bw_Q = torch.tensor(np.array([bw_Q]).astype('float32')) 118 | 119 | def get_harmonic_fb(self): 120 | # bandwidth 121 | bw = (self.bw_alpha * self.f0 + self.bw_beta) / self.bw_Q 122 | bw = bw.unsqueeze(0) # (1, n_band) 123 | f0 = self.f0.unsqueeze(0) # (1, n_band) 124 | fft_bins = self.fft_bins.unsqueeze(1) # (n_bins, 1) 125 | 126 | up_slope = torch.matmul(fft_bins, (2/bw)) + 1 - (2 * f0 / bw) 127 | down_slope = torch.matmul(fft_bins, (-2/bw)) + 1 + (2 * f0 / bw) 128 | fb = torch.max(self.zero, torch.min(down_slope, up_slope)) 129 | return fb 130 | 131 | def to_device(self, device, n_bins): 132 | self.f0 = self.f0.to(device) 133 | self.bw_Q = self.bw_Q.to(device) 134 | # fft bins 135 | self.fft_bins = torch.linspace(0, self.sample_rate//2, n_bins) 136 | self.fft_bins = self.fft_bins.to(device) 137 | self.zero = torch.zeros(1) 138 | self.zero = self.zero.to(device) 139 | 140 | def forward(self, waveform): 141 | # stft 142 | spectrogram = self.spec(waveform) 143 | 144 | # to device 145 | self.to_device(waveform.device, spectrogram.size(1)) 146 | 147 | # triangle filter 148 | harmonic_fb = self.get_harmonic_fb() 149 | harmonic_spec = torch.matmul(spectrogram.transpose(1, 2), harmonic_fb).transpose(1, 2) 150 | 151 | # (batch, channel, length) -> (batch, harmonic, f0, length) 152 | b, c, l = harmonic_spec.size() 153 | harmonic_spec = harmonic_spec.view(b, self.n_harmonic, self.level, l) 154 | 155 | # amplitude to db 156 | harmonic_spec = self.amplitude_to_db(harmonic_spec) 157 | return harmonic_spec 158 | 159 | 160 | class HarmonicCNN(nn.Module): 161 | ''' 162 | Won et al. 2020 163 | Data-driven harmonic filters for audio representation learning. 164 | Trainable harmonic band-pass filters, short-chunk CNN. 
165 | ''' 166 | def __init__(self, 167 | n_channels=128, 168 | sample_rate=16000, 169 | n_fft=512, 170 | f_min=0.0, 171 | f_max=8000.0, 172 | n_mels=128, 173 | n_class=50, 174 | n_harmonic=6, 175 | semitone_scale=2, 176 | learn_bw='only_Q'): 177 | super(HarmonicCNN, self).__init__() 178 | 179 | # Harmonic STFT 180 | self.hstft = HarmonicSTFT(sample_rate=sample_rate, 181 | n_fft=n_fft, 182 | n_harmonic=n_harmonic, 183 | semitone_scale=semitone_scale, 184 | learn_bw=learn_bw) 185 | self.hstft_bn = nn.BatchNorm2d(n_harmonic) 186 | 187 | # CNN 188 | self.layer1 = Conv_2d(n_harmonic, n_channels, pooling=2) 189 | self.layer2 = Res_2d_mp(n_channels, n_channels, pooling=2) 190 | self.layer3 = Res_2d_mp(n_channels, n_channels, pooling=2) 191 | self.layer4 = Res_2d_mp(n_channels, n_channels, pooling=2) 192 | self.layer5 = Conv_2d(n_channels, n_channels*2, pooling=2) 193 | self.layer6 = Res_2d_mp(n_channels*2, n_channels*2, pooling=(2, 3)) 194 | self.layer7 = Res_2d_mp(n_channels*2, n_channels*2, pooling=(2, 3)) 195 | 196 | # Dense 197 | self.dense1 = nn.Linear(n_channels*2, n_channels*2) 198 | self.bn = nn.BatchNorm1d(n_channels*2) 199 | self.dense2 = nn.Linear(n_channels*2, n_class) 200 | self.dropout = nn.Dropout(0.5) 201 | self.relu = nn.ReLU() 202 | 203 | def forward(self, x): 204 | # Spectrogram 205 | x = self.hstft_bn(self.hstft(x)) 206 | 207 | # CNN 208 | x = self.layer1(x) 209 | x = self.layer2(x) 210 | x = self.layer3(x) 211 | x = self.layer4(x) 212 | x = self.layer5(x) 213 | x = self.layer6(x) 214 | x = self.layer7(x) 215 | x = x.squeeze(2) 216 | 217 | # Global Max Pooling 218 | if x.size(-1) != 1: 219 | x = nn.MaxPool1d(x.size(-1))(x) 220 | x = x.squeeze(2) 221 | 222 | # Dense 223 | x = self.dense1(x) 224 | x = self.bn(x) 225 | x = self.relu(x) 226 | x = self.dropout(x) 227 | x = self.dense2(x) 228 | x = nn.Sigmoid()(x) 229 | 230 | return x 231 | 232 | 233 | class Predictor(cog.Predictor): 234 | def setup(self): 235 | if torch.cuda.is_available(): 236 | self.device = torch.device("cuda:0") 237 | else: 238 | self.device = torch.device("cpu") 239 | 240 | self.model = HarmonicCNN().to(self.device), 241 | self.input_length = 5 * 16000 242 | 243 | for key, mod in self.models.items(): 244 | filename = os.path.join("models", DATASET, key, "best_model.pth") 245 | state_dict = torch.load(filename, map_location=self.device) 246 | if "spec.mel_scale.fb" in state_dict.keys(): 247 | mod.spec.mel_scale.fb = state_dict["spec.mel_scale.fb"] 248 | mod.load_state_dict(state_dict) 249 | 250 | self.tags = np.load("split/mtat/tags.npy") 251 | 252 | @cog.input( 253 | "output_format", 254 | type=str, 255 | default="Visualization", 256 | options=["Visualization", "JSON"], 257 | help="Output either a bar chart visualization or a JSON blob", 258 | ) 259 | def predict(self, input, output_format): 260 | 261 | model = self.model.eval() 262 | input_length = self.input_length 263 | signal, _ = librosa.core.load(str(input), sr=SAMPLE_RATE) 264 | length = len(signal) 265 | hop = length // 2 - input_length // 2 266 | print("length, input_length", length, input_length) 267 | x = torch.zeros(1, input_length) 268 | x[0] = torch.Tensor(signal[hop:hop+input_length]).unsqueeze(0) 269 | x = Variable(x.to(self.device)) 270 | print("x.max(), x.min(), x.mean()", x.max(), x.min(), x.mean()) 271 | out = model(x) 272 | result = dict(zip(self.tags, out[0].detach().numpy().tolist())) 273 | 274 | if output_format == "JSON": 275 | return result 276 | 277 | result_list = list(sorted(result.items(), key=lambda x: x[1])) 278 | 
plt.figure(figsize=[5, 10]) 279 | plt.barh( 280 | np.arange(len(result_list)), [r[1] for r in result_list], align="center" 281 | ) 282 | plt.yticks(np.arange(len(result_list)), [r[0] for r in result_list]) 283 | plt.tight_layout() 284 | 285 | out_path = Path(tempfile.mkdtemp()) / "out.png" 286 | plt.savefig(out_path) 287 | return out_path 288 | -------------------------------------------------------------------------------- /mir_ref/features/models/mule.py: -------------------------------------------------------------------------------- 1 | """Code used for inference for MULE model. 2 | Source: https://github.com/PandoraMedia/music-audio-features: 3 | """ 4 | 5 | import tempfile 6 | 7 | import numpy as np 8 | from scooch import ConfigList, Configurable, Param 9 | 10 | 11 | class Feature(Configurable): 12 | """ 13 | The base class for all feature types. 14 | """ 15 | 16 | def __del__(self): 17 | """ 18 | **Destructor** 19 | """ 20 | if hasattr(self, "_data_file"): 21 | self._data_file.close() 22 | 23 | def add_data(self, data): 24 | """ 25 | Adds data to extend the object's current data via concatenation along the time axis. 26 | This is useful for populating data in chunks, where populating it all at once would 27 | cause excessive memory usage. 28 | 29 | Arg: 30 | data: np.ndarray - The data to be appended to the object's memmapped numpy array. 31 | """ 32 | if not hasattr(self, "_data_file"): 33 | self._data_file = tempfile.NamedTemporaryFile(mode="w+b") 34 | 35 | if not hasattr(self, "_data") or self._data is None: 36 | original_data_size = 0 37 | else: 38 | original_data_size = self._data.shape[1] 39 | final_size = original_data_size + data.shape[1] 40 | 41 | filename = self._data_file.name 42 | 43 | self._data = np.memmap( 44 | filename, 45 | dtype="float32", 46 | mode="r+", 47 | shape=(data.shape[0], final_size), 48 | order="F", 49 | ) 50 | self._data[:, original_data_size:] = data 51 | 52 | def _extract(self, source, length): 53 | """ 54 | Extracts feature data file or other feature a given time-chunk. 55 | 56 | Args: 57 | source_feature: mule.features.Feature - The feature to transform. 58 | 59 | start_time: int - The index in the input at which to start extracting / transforming. 60 | 61 | chunk_size: int - The length of the chunk following the `start_time` to extract 62 | the feature from. 63 | """ 64 | raise NotImplementedError( 65 | f"The {self.__name__.__class__} has no feature extraction method" 66 | ) 67 | 68 | def save(self, path): 69 | """ 70 | Save the feature data blob to disk. 71 | 72 | Args: 73 | path: str - The path to save the data to. 74 | """ 75 | np.save(path, self.data) 76 | 77 | def clear(self): 78 | """ 79 | Clears any previously analyzed feature data, ready for a new analysis. 80 | """ 81 | self._data = None 82 | if hasattr(self, "_data_file"): 83 | self._data_file.close() 84 | del self._data_file 85 | 86 | @property 87 | def data(self): 88 | """ 89 | The feature data blob itself. 90 | """ 91 | return self._data 92 | 93 | 94 | class SourceFile(Configurable): 95 | """ 96 | Base class for SCOOCH configurable file readers. 97 | """ 98 | 99 | def load(self, fname): 100 | """ 101 | Any preprocessing steps to load a file prior to reading it. 102 | 103 | Args: 104 | fname: file-like - A file like object to be loaded. 105 | """ 106 | raise NotImplementedError( 107 | f"The class, {self.__class__.__name__}, has no method for loading files" 108 | ) 109 | 110 | def read(self, n): 111 | """ 112 | Reads an amount of data from the file. 
113 | 114 | Args: 115 | n: int - A size parameter indicating the amount of data to read. 116 | 117 | Return: 118 | object - The decoded data read and in memory. 119 | """ 120 | raise NotImplementedError( 121 | f"The class, {self.__class__.__name__}, has no method for reading files" 122 | ) 123 | 124 | def close(self): 125 | """ 126 | Closes any previously loaded file. 127 | """ 128 | raise NotImplementedError( 129 | f"The class, {self.__class__.__name__}, has no method for closing files" 130 | ) 131 | 132 | def __len__(self): 133 | raise NotImplementedError( 134 | f"The class, {self.__class__.__name__} has no method for determining file data length" 135 | ) 136 | 137 | 138 | class SourceFeature(Feature): 139 | """ 140 | A feature that is derived directly from raw data, e.g., a data file. 141 | """ 142 | 143 | # SCOOCH Configuration 144 | _input_file = Param( 145 | SourceFile, 146 | doc="The file object defining the parameters of the raw data that this feature is constructed from.", 147 | ) 148 | 149 | _CHUNK_SIZE = 44100 * 60 * 15 150 | 151 | # Methods 152 | def from_file(self, fname): 153 | """ 154 | Takes a file and processes its data in chunks to form a feature. 155 | 156 | Args: 157 | fname: str - The path to the input file from which this feature is constructed. 158 | """ 159 | # Load file 160 | self._input_file.load(fname) 161 | 162 | # Read samples into data 163 | processed_input_frames = 0 164 | while processed_input_frames < len(self._input_file): 165 | data = self._extract( 166 | self._input_file, processed_input_frames, self._CHUNK_SIZE 167 | ) 168 | processed_input_frames += self._CHUNK_SIZE 169 | self.add_data(data) 170 | 171 | def clear(self): 172 | """ 173 | Clears any previously analyzed feature data, ready for a new analysis. 174 | """ 175 | super().clear() 176 | self._input_file.close() 177 | 178 | def __len__(self): 179 | """ 180 | Returns the number of bytes / samples / indices in the input data file. 181 | """ 182 | return len(self._input_file) 183 | 184 | 185 | class Extractor(Configurable): 186 | """ 187 | Base class for classes that are responsible for extracting data 188 | from mule.features.Feature classes. 189 | """ 190 | 191 | # 192 | # Methods 193 | # 194 | def extract_range(self, feature, start_index, end_index): 195 | """ 196 | Extracts data over a given index range from a single feature. 197 | 198 | Args: 199 | feature: mule.feature.Feature - A feature to extract data from. 200 | 201 | start_index: int - The first index (inclusive) at which to return data. 202 | 203 | end_index: int - The last index (exclusive) at which to return data. 204 | 205 | Return: 206 | numpy.ndarray - The extracted feature data. Features on first axis, time on 207 | second axis. 208 | """ 209 | raise NotImplementedError( 210 | f"The {self.__class__.__name__} class has no `extract_range` method." 211 | ) 212 | 213 | def extract_batch(self, features, indices): 214 | """ 215 | Extracts a batch of features from potentially multiple features, each potentially 216 | at distinct indices. 217 | 218 | Args: 219 | features: list(mule.features.Feature) - A list of features from which to extract 220 | data from. 221 | 222 | indices: list(int) - A list of indices, the same size as `features`. Each element 223 | provides an index at which to extract data from the coressponding element in the 224 | `features` argument. 225 | 226 | Return: 227 | np.ndarray - A batch of features, with features on the batch dimension on the first 228 | axis and feature data on the remaining axes. 
229 | """ 230 | raise NotImplementedError( 231 | f"The {self.__class__.__name__} class has no `extract_batch` method." 232 | ) 233 | 234 | 235 | class TransformFeature(Feature): 236 | """ 237 | Base class for all features that are transforms of other features. 238 | """ 239 | 240 | # 241 | # SCOOCH Configuration 242 | # 243 | _extractor = Param( 244 | Extractor, 245 | doc="An object defining how data will be extracted from the input feature and provided to the transformation of this feature.", 246 | ) 247 | 248 | # The size in time of each chunk that this feature will process at any one time. 249 | _CHUNK_SIZE = 44100 * 60 * 15 250 | 251 | # 252 | # Methods 253 | # 254 | def from_feature(self, source_feature): 255 | """ 256 | Populates this features data as a transform of the provided input feature. 257 | 258 | Args: 259 | source_feature: mule.features.Feature - A feature from which this feature will 260 | be created as a transformation thereof. 261 | """ 262 | boundaries = list(range(0, len(source_feature), self._CHUNK_SIZE)) + [ 263 | len(source_feature) 264 | ] 265 | chunks = [(start, end) for start, end in zip(boundaries[:-1], boundaries[1:])] 266 | for start_time, end_time in chunks: 267 | data = self._extract(source_feature, start_time, end_time - start_time) 268 | if data is not None and len(data): 269 | self.add_data(data) 270 | 271 | def _extract(self, source_feature, start_time, chunk_size): 272 | """ 273 | Extracts feature data as a transformation of a given source feature for a given 274 | time-chunk. 275 | 276 | Args: 277 | source_feature: mule.features.Feature - The feature to transform. 278 | 279 | start_time: int - The index in the feature at which to start extracting / transforming. 280 | 281 | chunk_size: int - The length of the chunk following the `start_time` to extract 282 | the feature from. 283 | """ 284 | end_time = start_time + chunk_size 285 | return self._extractor.extract_range(source_feature, start_time, end_time) 286 | 287 | def __len__(self): 288 | if hasattr(self, "_data"): 289 | return self._data.shape[1] 290 | else: 291 | return 0 292 | 293 | 294 | class Analysis(Configurable): 295 | """ 296 | A class encapsulating analysis of a single input file. 297 | """ 298 | 299 | # SCOOCH Configuration 300 | _source_feature = Param( 301 | SourceFeature, doc="The feature used to decode the provided raw file data." 302 | ) 303 | _feature_transforms = Param( 304 | ConfigList(TransformFeature), 305 | doc="Feature transformations to apply, in order, to the source feature generated from the input file.", 306 | ) 307 | 308 | # Methods 309 | def analyze(self, fname): 310 | """ 311 | Analyze features for a single filepath. 312 | 313 | Args: 314 | fname: str - The filename path, from which to generate features. 315 | 316 | Return: 317 | mule.features.Feature - The feature resulting from the configured feature 318 | transformations. 319 | """ 320 | for feat in [self._source_feature] + self._feature_transforms: 321 | feat.clear() 322 | 323 | self._source_feature.from_file(fname) 324 | input_feature = self._source_feature 325 | for feature in self._feature_transforms: 326 | feature.from_feature(input_feature) 327 | input_feature = feature 328 | 329 | return input_feature 330 | -------------------------------------------------------------------------------- /mir_ref/features/models/openl3.py: -------------------------------------------------------------------------------- 1 | """Code for OpenL3 embedding model inference. 
2 | Source: https://gist.github.com/palonso/cfebe37e5492b5a3a31775d8eae8d9a8 3 | """ 4 | 5 | from pathlib import Path 6 | import essentia.standard as es 7 | import numpy as np 8 | from essentia import Pool 9 | 10 | 11 | class MelSpectrogramOpenL3: 12 | def __init__(self, hop_time): 13 | self.hop_time = hop_time 14 | 15 | self.sr = 48000 16 | self.n_mels = 128 17 | self.frame_size = 2048 18 | self.hop_size = 242 19 | self.a_min = 1e-10 20 | self.d_range = 80 21 | self.db_ref = 1.0 22 | 23 | self.patch_samples = int(1 * self.sr) 24 | self.hop_samples = int(self.hop_time * self.sr) 25 | 26 | self.w = es.Windowing( 27 | size=self.frame_size, 28 | normalized=False, 29 | ) 30 | self.s = es.Spectrum(size=self.frame_size) 31 | self.mb = es.MelBands( 32 | highFrequencyBound=self.sr / 2, 33 | inputSize=self.frame_size // 2 + 1, 34 | log=False, 35 | lowFrequencyBound=0, 36 | normalize="unit_tri", 37 | numberBands=self.n_mels, 38 | sampleRate=self.sr, 39 | type="magnitude", 40 | warpingFormula="slaneyMel", 41 | weighting="linear", 42 | ) 43 | 44 | def compute(self, audio_file): 45 | audio = es.MonoLoader(filename=audio_file, sampleRate=self.sr)() 46 | 47 | batch = [] 48 | for audio_chunk in es.FrameGenerator( 49 | audio, frameSize=self.patch_samples, hopSize=self.hop_samples 50 | ): 51 | melbands = np.array( 52 | [ 53 | self.mb(self.s(self.w(frame))) 54 | for frame in es.FrameGenerator( 55 | audio_chunk, 56 | frameSize=self.frame_size, 57 | hopSize=self.hop_size, 58 | validFrameThresholdRatio=0.5, 59 | ) 60 | ] 61 | ) 62 | 63 | melbands = 10.0 * np.log10(np.maximum(self.a_min, melbands)) 64 | melbands -= 10.0 * np.log10(np.maximum(self.a_min, self.db_ref)) 65 | melbands = np.maximum(melbands, melbands.max() - self.d_range) 66 | melbands -= np.max(melbands) 67 | 68 | batch.append(melbands.copy()) 69 | 70 | return np.vstack(batch) 71 | 72 | 73 | class EmbeddingsOpenL3: 74 | def __init__(self, graph_path, hop_time=1, batch_size=60, melbands=128): 75 | self.hop_time = hop_time 76 | self.batch_size = batch_size 77 | 78 | self.graph_path = Path(graph_path) 79 | 80 | self.x_size = 199 81 | self.y_size = melbands 82 | self.squeeze = False 83 | 84 | self.permutation = [0, 3, 2, 1] 85 | 86 | self.input_layer = "melspectrogram" 87 | self.output_layer = "embeddings" 88 | 89 | self.mel_extractor = MelSpectrogramOpenL3(hop_time=self.hop_time) 90 | 91 | self.model = es.TensorflowPredict( 92 | graphFilename=str(self.graph_path), 93 | inputs=[self.input_layer], 94 | outputs=[self.output_layer], 95 | squeeze=self.squeeze, 96 | ) 97 | 98 | def compute(self, audio_file): 99 | mel_spectrogram = self.mel_extractor.compute(audio_file) 100 | # in OpenL3 the hop size is computed in the feature extraction level 101 | 102 | hop_size_samples = self.x_size 103 | 104 | batch = self.__melspectrogram_to_batch(mel_spectrogram, hop_size_samples) 105 | 106 | pool = Pool() 107 | embeddings = [] 108 | nbatches = int(np.ceil(batch.shape[0] / self.batch_size)) 109 | for i in range(nbatches): 110 | start = i * self.batch_size 111 | end = min(batch.shape[0], (i + 1) * self.batch_size) 112 | pool.set(self.input_layer, batch[start:end]) 113 | out_pool = self.model(pool) 114 | embeddings.append(out_pool[self.output_layer].squeeze()) 115 | 116 | return np.vstack(embeddings) 117 | 118 | def __melspectrogram_to_batch(self, melspectrogram, hop_time): 119 | npatches = int(np.ceil((melspectrogram.shape[0] - self.x_size) / hop_time) + 1) 120 | batch = np.zeros([npatches, self.x_size, self.y_size], dtype="float32") 121 | for i in range(npatches): 122 | 
last_frame = min(i * hop_time + self.x_size, melspectrogram.shape[0]) 123 | first_frame = i * hop_time 124 | data_size = last_frame - first_frame 125 | 126 | # the last patch may be empty, remove it and exit the loop 127 | if data_size <= 0: 128 | batch = np.delete(batch, i, axis=0) 129 | break 130 | else: 131 | batch[i, :data_size] = melspectrogram[first_frame:last_frame] 132 | 133 | batch = np.expand_dims(batch, 1) 134 | batch = es.TensorTranspose(permutation=self.permutation)(batch) 135 | 136 | return batch 137 | -------------------------------------------------------------------------------- /mir_ref/metrics.py: -------------------------------------------------------------------------------- 1 | """Custom metrics.""" 2 | 3 | 4 | def key_detection_weighted_accuracy(y_true, y_pred): 5 | """Calculate weighted accuracy for key detection. 6 | 7 | Args: 8 | y_true (list): List of keys (e.g. C# major). 9 | y_pred (list): List of predicted keys. 10 | 11 | Returns: 12 | float: Weighted accuracy. 13 | """ 14 | import mir_eval 15 | import numpy as np 16 | 17 | scores = [] 18 | macro_scores = {} 19 | 20 | for truth, pred in zip(y_true, y_pred): 21 | score = mir_eval.key.weighted_score(truth, pred) 22 | scores.append(score) 23 | if truth not in macro_scores: 24 | macro_scores[truth] = [] 25 | macro_scores[truth].append(score) 26 | 27 | # calculate macro scores 28 | macro_scores_mean = [] 29 | for key, values in macro_scores.items(): 30 | macro_scores_mean.append(np.mean(values)) 31 | 32 | return {"micro": np.mean(scores), "macro": np.mean(macro_scores_mean)} 33 | -------------------------------------------------------------------------------- /mir_ref/probes/probe_builder.py: -------------------------------------------------------------------------------- 1 | """Downstream models for feature evaluation, as well 2 | as helper code to construct and return models. 
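Currently a single probe type is implemented: an MLP classifier with configurable hidden layers (sizes given explicitly or inferred via "infer"/"power_infer") and L2 weight decay on all layers.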
3 | """ 4 | 5 | from keras import layers, regularizers 6 | from keras.models import Sequential 7 | 8 | 9 | def get_model(model_cfg, dim, n_classes): 10 | if model_cfg["type"] == "classifier": 11 | return classifier(model_cfg, dim, n_classes) 12 | else: 13 | raise ValueError(f"Model type '{model_cfg['type']}' not supported.") 14 | 15 | 16 | def classifier(model_cfg, dim, n_classes): 17 | """Classifier with configurable number of layers.""" 18 | 19 | # if "infer" for hidden units, infer them 20 | if (model_cfg["hidden_units"] is not None) and ( 21 | len(model_cfg["hidden_units"]) != 0 22 | ): 23 | if model_cfg["hidden_units"][0] == "power_infer": 24 | # get layer sizes with power of 2 regression 25 | # y = alpha * x ^ 2 + c, where y = emb_shape[0], c = n_classes, 26 | # and x = n_hidden_layers + 1 27 | alpha = (dim[0] - n_classes) / (len(model_cfg["hidden_units"]) + 1) 28 | for i in range(len(model_cfg["hidden_units"])): 29 | hu = int(alpha * ((i + 1) ** 2) + n_classes) 30 | if hu % 2 != 0: 31 | hu += 1 32 | model_cfg["hidden_units"][i] = hu 33 | 34 | print(f"Hidden units inferred, using {model_cfg['hidden_units']}.") 35 | 36 | if model_cfg["hidden_units"][0] == "infer": 37 | n_layers = len(model_cfg["hidden_units"]) 38 | step_size = (dim[0] - n_classes) / (n_layers + 1) 39 | for i in range(len(model_cfg["hidden_units"])): 40 | hu = int(n_classes + ((n_layers - i) * step_size)) 41 | if hu % 2 != 0: 42 | hu += 1 43 | model_cfg["hidden_units"][i] = hu 44 | 45 | print(f"Hidden units inferred, using {model_cfg['hidden_units']}.") 46 | 47 | model = Sequential() 48 | 49 | # add hidden layers 50 | for i, hu in enumerate(model_cfg["hidden_units"]): 51 | if i == 0: 52 | model.add( 53 | layers.Dense( 54 | units=hu, 55 | activation="relu", 56 | name=f"hidden_layer_{i}", 57 | input_shape=tuple(dim), 58 | kernel_regularizer=regularizers.L2(model_cfg["weight_decay"]), 59 | bias_regularizer=regularizers.L2(model_cfg["weight_decay"]), 60 | ) 61 | ) 62 | else: 63 | model.add( 64 | layers.Dense( 65 | units=hu, 66 | activation="relu", 67 | name=f"hidden_layer_{i}", 68 | kernel_regularizer=regularizers.L2(model_cfg["weight_decay"]), 69 | bias_regularizer=regularizers.L2(model_cfg["weight_decay"]), 70 | ) 71 | ) 72 | 73 | # add output layer 74 | if (model_cfg["hidden_units"] is not None) and ( 75 | len(model_cfg["hidden_units"]) != 0 76 | ): 77 | model.add( 78 | layers.Dense( 79 | units=n_classes, 80 | activation=model_cfg["output_activation"], 81 | name="output_layer", 82 | kernel_regularizer=regularizers.L2(model_cfg["weight_decay"]), 83 | bias_regularizer=regularizers.L2(model_cfg["weight_decay"]), 84 | ) 85 | ) 86 | else: 87 | model.add( 88 | layers.Dense( 89 | units=n_classes, 90 | activation=model_cfg["output_activation"], 91 | name="output_layer", 92 | input_shape=tuple(dim), 93 | kernel_regularizer=regularizers.L2(model_cfg["weight_decay"]), 94 | bias_regularizer=regularizers.L2(model_cfg["weight_decay"]), 95 | ) 96 | ) 97 | model.build() 98 | 99 | return model 100 | -------------------------------------------------------------------------------- /mir_ref/train.py: -------------------------------------------------------------------------------- 1 | """Train downstream models. 
2 | """ 3 | 4 | from pathlib import Path 5 | 6 | import keras 7 | import numpy as np 8 | import tensorflow as tf 9 | from colorama import Fore, Style 10 | from sklearn.model_selection import ParameterGrid 11 | from tensorflow.keras import losses, optimizers 12 | from tensorflow.keras.callbacks import EarlyStopping, TensorBoard 13 | 14 | from mir_ref.dataloaders import DataGenerator 15 | from mir_ref.datasets.dataset import get_dataset 16 | from mir_ref.probes.probe_builder import get_model 17 | from mir_ref.utils import load_config 18 | 19 | 20 | def train(cfg_path, run_id=None): 21 | """Make a grid of all combinations of dataset, embedding models, 22 | downstream models and splits and call training for each. 23 | 24 | Args: 25 | cfg_path (str): Path to config file. 26 | run_id (str, optional): Experiment ID, timestamp if not specified. 27 | """ 28 | cfg = load_config(cfg_path) 29 | 30 | if run_id is None: 31 | import datetime 32 | 33 | run_id = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 34 | 35 | for exp_cfg in cfg["experiments"]: 36 | run_params = { 37 | "dataset_cfg": exp_cfg["datasets"], 38 | "feature": exp_cfg["features"], 39 | "model_cfg": exp_cfg["probes"], 40 | } 41 | 42 | # create grid from parameters 43 | grid = ParameterGrid(run_params) 44 | 45 | for params in grid: 46 | # get index of downstream model for naming-logging 47 | model_idx = exp_cfg["probes"].index(params["model_cfg"]) 48 | # get dataset object 49 | dataset = get_dataset( 50 | dataset_cfg=params["dataset_cfg"], 51 | task_cfg=exp_cfg["task"], 52 | features_cfg=exp_cfg["features"], 53 | ) 54 | dataset.download() 55 | # !!!the following only works if there's one experiment per config 56 | if dataset.task_type == "multiclass_classification": 57 | try: 58 | for metric in metrics: 59 | metric.reset_states() 60 | except UnboundLocalError: 61 | metrics = [ 62 | keras.metrics.CategoricalAccuracy(), 63 | keras.metrics.Precision(), 64 | keras.metrics.Recall(), 65 | keras.metrics.AUC(), 66 | ] 67 | elif dataset.task_type == "multilabel_classification": 68 | try: 69 | for metric in metrics: 70 | metric.reset_states() 71 | except UnboundLocalError: 72 | metrics = [ 73 | keras.metrics.Precision(), 74 | keras.metrics.Recall(), 75 | keras.metrics.AUC(curve="ROC"), 76 | keras.metrics.AUC(curve="PR"), 77 | ] 78 | # run task for every split 79 | for split_idx in range(len(dataset.get_splits())): 80 | print( 81 | Fore.GREEN 82 | + f"Task: {dataset.task_name}\n" 83 | + f"└── Dataset: {dataset.name}\n" 84 | + f" └── Embeddings: {params['feature']}\n" 85 | + f" └── Model: {model_idx}\n" 86 | + f" └── Split: {split_idx}", 87 | Style.RESET_ALL, 88 | ) 89 | train_probe( 90 | run_id=run_id, 91 | dataset=dataset, 92 | model_cfg=params["model_cfg"], 93 | model_idx=model_idx, 94 | feature=params["feature"], 95 | split_idx=split_idx, 96 | metrics=metrics, 97 | ) 98 | 99 | 100 | def train_probe(run_id, dataset, model_cfg, model_idx, feature, split_idx, metrics): 101 | """Train a single model per split given parameters. 102 | 103 | Args: 104 | run_id (str): ID of the current run, defaults to timestamp. 105 | dataset (Dataset): Dataset object. 106 | model_cfg (dict): Downstream model config. 107 | model_idx (int): Index of the downstream model in the list of models. 108 | feature (str): Name of the embedding model. 109 | split_idx (int): Index of the split in the list of splits. 110 | metrics (list): List of metrics to use for training. 
111 | """ 112 | 113 | split = dataset.get_splits()[split_idx] 114 | n_classes = len(dataset.encoded_labels[dataset.track_ids[0]]) 115 | 116 | if model_cfg["emb_shape"] == "infer": 117 | # get embedding shape from the first embedding 118 | emb_shape = np.load( 119 | dataset.get_embedding_path( 120 | feature=feature, 121 | track_id=split["train"][0], 122 | ) 123 | ).shape 124 | elif isinstance(model_cfg["emb_shape"], int): 125 | emb_shape = model_cfg["emb_shape"] 126 | elif isinstance(model_cfg["emb_shape"], str): 127 | raise ValueError(f"{model_cfg['emb_shape']} not implemented.") 128 | 129 | model = get_model(model_cfg=model_cfg, dim=emb_shape, n_classes=n_classes) 130 | model.summary() 131 | 132 | tr_gen = DataGenerator( 133 | ids_list=split["train"], 134 | labels_dict=dataset.encoded_labels, 135 | paths_dict={ 136 | t_id: dataset.get_embedding_path(feature=feature, track_id=t_id) 137 | for t_id in split["train"] 138 | }, 139 | batch_size=model_cfg["batch_size"], 140 | dim=emb_shape, 141 | n_classes=n_classes, 142 | shuffle=True, 143 | ) 144 | val_gen = DataGenerator( 145 | ids_list=split["validation"], 146 | labels_dict=dataset.encoded_labels, 147 | paths_dict={ 148 | t_id: dataset.get_embedding_path(feature=feature, track_id=t_id) 149 | for t_id in split["validation"] 150 | }, 151 | batch_size=1, 152 | dim=emb_shape, 153 | n_classes=n_classes, 154 | shuffle=True, 155 | ) 156 | 157 | # dir for tensorboard logs and weights for this run 158 | run_dir = ( 159 | Path(run_id) / dataset.task_name / f"{dataset.name}_{feature}_model-{model_idx}" 160 | ) 161 | 162 | # if the dir already exists, change run dir for duplicates 163 | if (Path("./logs") / run_dir).exists(): 164 | i = 1 165 | while (Path("./logs") / run_dir).exists(): 166 | run_dir = ( 167 | Path(run_id) 168 | / dataset.task_name 169 | / f"{dataset.name}_{feature}_model-{model_idx} ({i})" 170 | ) 171 | i += 1 172 | import warnings 173 | 174 | # raise warning about existing experiment 175 | warnings.warn( 176 | f"Model in '{dataset.name}_{feature}_model-{model_idx}' already exists. 
" 177 | f"Renaming new run to '{run_dir}'", 178 | stacklevel=2, 179 | ) 180 | 181 | # create dir for this run if it doesn't exist 182 | if not (Path("./logs") / run_dir).exists(): 183 | (Path("./logs") / run_dir).mkdir(parents=True) 184 | 185 | # save model config in run dir 186 | with open(Path("./logs") / run_dir / "model_config.yml", "w+") as f: 187 | yaml.dump(model_cfg, f) 188 | 189 | if dataset.task_type == "multiclass_classification": 190 | checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( 191 | save_weights_only=False, 192 | filepath=str(Path("./logs") / run_dir / "weights.h5"), 193 | save_best_only=True, 194 | monitor="val_categorical_accuracy", 195 | mode="max", 196 | ) 197 | elif dataset.task_type == "multilabel_classification": 198 | checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( 199 | save_weights_only=False, 200 | filepath=str(Path("./logs") / run_dir / "weights.h5"), 201 | save_best_only=True, 202 | monitor="val_auc", 203 | mode="max", 204 | ) 205 | callbacks = [ 206 | EarlyStopping(patience=model_cfg["patience"]), 207 | TensorBoard(log_dir=str(Path("./logs") / run_dir)), 208 | checkpoint_callback, 209 | ] 210 | 211 | # make sure all metric and callback states are reset 212 | for callback in callbacks: 213 | if hasattr(callback, "reset_state"): 214 | callback.reset_state() 215 | 216 | # loss and optimizer 217 | if dataset.task_type == "multiclass_classification": 218 | loss = losses.CategoricalCrossentropy(from_logits=False) 219 | elif dataset.task_type == "multilabel_classification": 220 | loss = losses.BinaryCrossentropy() 221 | else: 222 | raise ValueError(f"Task type '{dataset.task_type}' not implemented.") 223 | 224 | if model_cfg["optimizer"] == "adam": 225 | optimizer = optimizers.Adam(learning_rate=model_cfg["learning_rate"]) 226 | else: 227 | raise ValueError(f"Optimizer '{model_cfg['optimizer']}' not implemented.") 228 | 229 | model.compile(optimizer=optimizer, loss=loss, metrics=metrics) 230 | 231 | model.fit( 232 | x=tr_gen, 233 | validation_data=val_gen, 234 | batch_size=model_cfg["batch_size"], 235 | validation_batch_size=1, 236 | epochs=model_cfg["epochs"], 237 | callbacks=callbacks, 238 | use_multiprocessing=True, 239 | workers=4, 240 | verbose=1, 241 | ) 242 | -------------------------------------------------------------------------------- /mir_ref/utils.py: -------------------------------------------------------------------------------- 1 | """Various shared utilities.""" 2 | 3 | import yaml 4 | 5 | 6 | def raise_missing_param(param, exp_idx, parent=None): 7 | """Raise an error for a missing parameter.""" 8 | if not parent: 9 | raise ValueError( 10 | f"Missing required parameter: '{param}' in experiment {exp_idx}." 11 | ) 12 | else: 13 | raise ValueError( 14 | f"Missing required parameter: '{param}' in '{parent}' of experiment {exp_idx}." 15 | ) 16 | 17 | 18 | def load_config(cfg_path): 19 | """Load a YAML config file. 
Check formatting, and add 20 | missing keys.""" 21 | with open(cfg_path, "r") as f: 22 | cfg = yaml.safe_load(f) 23 | 24 | if "experiments" not in cfg: 25 | raise ValueError("'experiments' missing, please check config file structure.") 26 | 27 | for i, exp in enumerate(cfg["experiments"]): 28 | # check top-level parameters 29 | for top_level_param in [ 30 | "task", 31 | "datasets", 32 | "features", 33 | "probes", 34 | ]: 35 | if top_level_param not in exp: 36 | raise_missing_param(param=top_level_param, exp_idx=i) 37 | 38 | # check task parameters 39 | for task_param in ["name", "type"]: 40 | if task_param not in exp["task"]: 41 | raise_missing_param(param=task_param, exp_idx=i, parent="task") 42 | if "feature_aggregation" not in exp["task"]: 43 | exp["task"]["feature_aggregation"] = "mean" 44 | 45 | # check dataset parameters 46 | for j, dataset in enumerate(exp["datasets"]): 47 | for dataset_param in ["name", "dir"]: 48 | if dataset_param not in dataset: 49 | raise_missing_param( 50 | param=dataset_param, exp_idx=i, parent="datasets" 51 | ) 52 | if "split_type" not in dataset: 53 | cfg["experiments"][i]["datasets"][j]["split_type"] = "random" 54 | if "deformations" not in dataset: 55 | cfg["experiments"][i]["datasets"][j]["deformations"] = [] 56 | 57 | # check downstream model parameters 58 | for j, model in enumerate(exp["probes"]): 59 | for model_param in ["type"]: 60 | if model_param not in model: 61 | raise_missing_param(param=model_param, exp_idx=i, parent="probes") 62 | if "emb_dim_reduction" not in model: 63 | cfg["experiments"][i]["probes"][j]["emb_dim_reduction"] = None 64 | if "emb_shape" not in model: 65 | cfg["experiments"][i]["probes"][j]["emb_shape"] = None 66 | if "hidden_units" not in model: 67 | cfg["experiments"][i]["probes"][j]["hidden_units"] = [] 68 | if "output_activation" not in model: 69 | if exp["task"]["type"] == "multiclass_classification": 70 | cfg["experiments"][i]["probes"][j]["output_activation"] = "softmax" 71 | elif exp["task"]["type"] == "multilabel_classification": 72 | cfg["experiments"][i]["probes"][j]["output_activation"] = "sigmoid" 73 | if "weight_decay" not in model: 74 | cfg["experiments"][i]["probes"][j]["weight_decay"] = 0.0 75 | if "optimizer" not in model: 76 | cfg["experiments"][i]["probes"][j]["optimizer"] = "adam" 77 | if "learning_rate" not in model: 78 | cfg["experiments"][i]["probes"][j]["learning_rate"] = 1e-3 79 | if "batch_size" not in model: 80 | cfg["experiments"][i]["probes"][j]["batch_size"] = 1 81 | if "epochs" not in model: 82 | cfg["experiments"][i]["probes"][j]["epochs"] = 100 83 | if "patience" not in model: 84 | cfg["experiments"][i]["probes"][j]["patience"] = 10 85 | if "train_sampling" not in model: 86 | cfg["experiments"][i]["probes"][j]["train_sampling"] = "random" 87 | 88 | return cfg 89 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | audiomentations 2 | colorama 3 | essentia-tensorflow==2.1b6.dev1110 4 | librosa 5 | pyyaml 6 | scikit-learn 7 | tensorflow==2.15 8 | torch==2.0.1 9 | torchaudio==2.0.2 10 | torchvision==0.15.2 11 | tqdm 12 | wget 13 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | """Script to invoke embedding generation and evaluation. 
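Provides the conduct, deform, extract, train and evaluate subcommands; the --config/-c argument is a config name that is resolved to ./configs/<name>.yml.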
2 | """ 3 | 4 | import argparse 5 | import os 6 | 7 | from mir_ref.conduct import conduct 8 | from mir_ref.deform import deform 9 | from mir_ref.evaluate import evaluate 10 | from mir_ref.extract import generate 11 | from mir_ref.train import train 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser() 16 | 17 | subparsers = parser.add_subparsers(dest="command") 18 | 19 | # End to end 20 | parser_conduct = subparsers.add_parser("conduct") 21 | parser_conduct.add_argument( 22 | "--config", 23 | "-c", 24 | default="configs/default.yml", 25 | help="Path of configuration file.", 26 | ) 27 | 28 | # Audio deformation 29 | parser_deform = subparsers.add_parser("deform") 30 | parser_deform.add_argument( 31 | "--config", 32 | "-c", 33 | default="configs/default.yml", 34 | help="Path of configuration file.", 35 | ) 36 | parser_deform.add_argument( 37 | "--n_jobs", default=1, type=int, help="Number of parallel jobs" 38 | ) 39 | 40 | # Feature extraction 41 | parser_extract = subparsers.add_parser("extract") 42 | parser_extract.add_argument( 43 | "--config", 44 | "-c", 45 | default="configs/default.yml", 46 | help="Path of configuration file.", 47 | ) 48 | parser_extract.add_argument( 49 | "--skip_clean", 50 | action="store_true", 51 | help="Skip extracting features from the clean audio.", 52 | ) 53 | parser_extract.add_argument( 54 | "--skip_deformed", 55 | action="store_true", 56 | help="Skip extracting features from the deformed audio.", 57 | ) 58 | parser_extract.add_argument( 59 | "--no_overwrite", 60 | action="store_true", 61 | help="Skip extracting features if they already exist.", 62 | ) 63 | parser_extract.add_argument( 64 | "--deform_list", 65 | default=None, 66 | help="Deformation scenario indices to extract features for. Arguments as comma-separated integers, e.g. 
0,1,2,3", 67 | ) 68 | 69 | # Training 70 | parser_train = subparsers.add_parser("train") 71 | parser_train.add_argument( 72 | "--config", 73 | "-c", 74 | default="configs/default.yml", 75 | help="Path of configuration file.", 76 | ) 77 | parser_train.add_argument( 78 | "--run_id", 79 | default=None, 80 | help="Optional experiment ID, otherwise timestamp is used.", 81 | ) 82 | 83 | # Evaluation 84 | parser_evaluate = subparsers.add_parser("evaluate") 85 | parser_evaluate.add_argument( 86 | "--config", 87 | "-c", 88 | default="configs/default.yml", 89 | help="Path of configuration file.", 90 | ) 91 | parser_evaluate.add_argument( 92 | "--run_id", 93 | default=None, 94 | help="Experiment ID to evaluate, otherwise retrieves latest if timestamp is available.", 95 | ) 96 | 97 | args = parser.parse_args() 98 | 99 | if args.command == "conduct": 100 | conduct(cfg_path=os.path.join("./configs/", args.config + ".yml")) 101 | 102 | if args.command == "deform": 103 | deform( 104 | cfg_path=os.path.join("./configs/", args.config + ".yml"), 105 | n_jobs=args.n_jobs, 106 | ) 107 | elif args.command == "extract": 108 | if args.deform_list: 109 | args.deform_list = [int(i) for i in args.deform_list.split(",")] 110 | generate( 111 | cfg_path=os.path.join("./configs/", args.config + ".yml"), 112 | skip_clean=args.skip_clean, 113 | skip_deformed=args.skip_deformed, 114 | no_overwrite=args.no_overwrite, 115 | deform_list=args.deform_list, 116 | ) 117 | elif args.command == "train": 118 | train( 119 | cfg_path=os.path.join("./configs/", args.config + ".yml"), 120 | run_id=args.run_id, 121 | ) 122 | elif args.command == "evaluate": 123 | evaluate( 124 | cfg_path=os.path.join("./configs/", args.config + ".yml"), 125 | run_id=args.run_id, 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /tests/test_cfg.yml: -------------------------------------------------------------------------------- 1 | # Configuration file for tests 2 | experiments: 3 | - task: # single task per experiment 4 | name: instrument_classification 5 | type: multiclass_classification 6 | feature_aggregation: mean 7 | datasets: 8 | - name: tinysol 9 | type: mirdata 10 | dir: tests/data/tinysol/ 11 | split_type: single 12 | deformations: 13 | - - type: AddGaussianSNR 14 | params: 15 | min_snr_db: 15 16 | max_snr_db: 15 17 | p: 1 18 | - type: Gain 19 | params: 20 | min_gain_db: -12 21 | max_gain_db: -12 22 | p: 1 23 | - - type: Mp3Compression 24 | params: 25 | min_bitrate: 32 26 | max_bitrate: 32 27 | p: 1 28 | features: 29 | - vggish-audioset 30 | - effnet-discogs 31 | - msd-musicnn 32 | - openl3 33 | - neuralfp 34 | probes: 35 | - type: classifier 36 | emb_dim_reduction: False 37 | emb_shape: infer 38 | hidden_units: [infer] 39 | output_activation: softmax 40 | weight_decay: 1.0e-5 41 | # optimizer 42 | optimizer: adam 43 | learning_rate: 1.0e-2 44 | # training 45 | batch_size: 16 46 | epochs: 2 47 | patience: 10 48 | train_sampling: random 49 | - type: classifier 50 | emb_dim_reduction: False 51 | emb_shape: infer 52 | hidden_units: [infer, infer] 53 | output_activation: softmax 54 | weight_decay: 1.0e-5 55 | # optimizer 56 | optimizer: adam 57 | learning_rate: 1.0e-2 58 | # training 59 | batch_size: 16 60 | epochs: 2 61 | patience: 10 62 | train_sampling: random 63 | - type: classifier 64 | emb_dim_reduction: False 65 | emb_shape: infer 66 | hidden_units: [64] 67 | output_activation: softmax 68 | weight_decay: 1.0e-5 69 | # optimizer 70 | 
optimizer: adam 71 | learning_rate: 1.0e-2 72 | # training 73 | batch_size: 16 74 | epochs: 50 75 | patience: 10 76 | train_sampling: random 77 | - type: classifier 78 | emb_dim_reduction: False 79 | emb_shape: infer 80 | hidden_units: [96, 64] 81 | output_activation: softmax 82 | weight_decay: 1.0e-5 83 | # optimizer 84 | optimizer: adam 85 | learning_rate: 1.0e-2 86 | # training 87 | batch_size: 16 88 | epochs: 50 89 | patience: 10 90 | train_sampling: random 91 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | """Tests for dataset objects.""" 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import yaml 8 | 9 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 10 | 11 | from mir_ref.datasets.dataset import get_dataset 12 | 13 | # load configuration file, used in all tests 14 | # it uses the tinysol dataset, implemented with mirdata 15 | with open("./tests/test_cfg.yml", "r") as f: 16 | exp_cfg = yaml.safe_load(f)["experiments"][0] 17 | 18 | dataset = get_dataset( 19 | exp_cfg["datasets"][0], 20 | exp_cfg["task"], 21 | exp_cfg["features"], 22 | ) 23 | 24 | 25 | def test_download_metadata(): 26 | dataset.download_metadata() 27 | assert Path("./tests/data/tinysol/annotation").is_dir() 28 | assert Path("./tests/data/tinysol/annotation/TinySOL_metadata.csv").exists() 29 | 30 | 31 | def test_load_metadata(): 32 | dataset.load_metadata() 33 | assert ( 34 | dataset.common_audio_dir == "tests/data/tinysol/audio" 35 | or dataset.common_audio_dir == "tests/data/tinysol/audio/" 36 | ) 37 | assert dataset.track_ids[0] == "BTb-ord-F#1-pp-N-N" 38 | assert dataset.labels["BTb-ord-F#1-pp-N-N"] == "Bass Tuba" 39 | assert len(set(dataset.labels.values())) == 14 40 | categorical_encoded_label = list(np.zeros((14)).astype(np.float32)) 41 | categorical_encoded_label[2] = 1.0 42 | assert np.array_equal( 43 | dataset.encoded_labels["BTb-ord-F#1-pp-N-N"], categorical_encoded_label 44 | ) 45 | assert ( 46 | dataset.audio_paths["BTb-ord-F#1-pp-N-N"] 47 | == "tests/data/tinysol/audio/Brass/Bass_Tuba/ordinario/" 48 | + "BTb-ord-F#1-pp-N-N.wav" 49 | ) 50 | 51 | 52 | def test_extended_metadata(): 53 | dataset.load_metadata() 54 | 55 | deformed_audio_path = dataset.get_deformed_audio_path("BTb-ord-F#1-pp-N-N", 0) 56 | assert deformed_audio_path == ( 57 | "tests/data/tinysol/audio_deformed/Brass/Bass_Tuba/ordinario/" 58 | + "BTb-ord-F#1-pp-N-N_deform_0.wav" 59 | ) 60 | emb_path = dataset.get_embedding_path("BTb-ord-F#1-pp-N-N", "vggish-audioset") 61 | assert emb_path == ( 62 | "tests/data/tinysol/embeddings/vggish-audioset/Brass/Bass_Tuba/ordinario/" 63 | + "BTb-ord-F#1-pp-N-N.npy" 64 | ) 65 | deformed_emb_path = dataset.get_deformed_embedding_path( 66 | "BTb-ord-F#1-pp-N-N", 0, "vggish-audioset" 67 | ) 68 | assert deformed_emb_path == ( 69 | "tests/data/tinysol/embeddings/vggish-audioset/Brass/Bass_Tuba/ordinario/" 70 | + "BTb-ord-F#1-pp-N-N_deform_0.npy" 71 | ) 72 | 73 | 74 | test_load_metadata() 75 | -------------------------------------------------------------------------------- /tests/test_deform.py: -------------------------------------------------------------------------------- 1 | """Tests for audio deformations.""" 2 | 3 | import shutil 4 | import sys 5 | from pathlib import Path 6 | 7 | import yaml 8 | 9 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 10 | 11 | with open("./tests/test_cfg.yml", "r") as f: 12 | 
exp_cfg = yaml.safe_load(f)["experiments"][0] 13 | 14 | # we need to load and download the dataset before we can test the deformations 15 | from mir_ref.datasets.dataset import get_dataset 16 | from mir_ref.deformations import generate_deformations 17 | 18 | dataset = get_dataset( 19 | dataset_cfg=exp_cfg["datasets"][0], 20 | task_cfg=exp_cfg["task"], 21 | features_cfg=exp_cfg["features"], 22 | ) 23 | 24 | dataset.download() 25 | dataset.preprocess() 26 | dataset.load_metadata() 27 | 28 | # keep only a few track_ids for testing 29 | # this also tests if everything is correctly anchored to the dataset object 30 | # and its track_ids 31 | # unfortunately we can't currently download only a few tracks from an mirdata dataset 32 | dataset.track_ids = dataset.track_ids[:5] 33 | first_track_id = dataset.track_ids[0] 34 | 35 | 36 | def test_single_threaded_deformations(): 37 | generate_deformations( 38 | dataset, 39 | n_jobs=1, 40 | ) 41 | assert Path( 42 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=0) 43 | ).exists() 44 | assert Path( 45 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=1) 46 | ).exists() 47 | assert not Path( 48 | dataset.audio_paths[first_track_id].replace(".wav", "_deform_2.wav") 49 | ).exists() 50 | assert not Path( 51 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=0).replace( 52 | ".wav", "_deform_0.wav" 53 | ) 54 | ).exists() 55 | 56 | # delete computed deformations 57 | shutil.rmtree("tests/data/tinysol/audio_deformed") 58 | 59 | 60 | def test_multi_threaded_deformations(): 61 | generate_deformations( 62 | dataset, 63 | n_jobs=2, 64 | ) 65 | assert Path( 66 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=0) 67 | ).exists() 68 | assert Path( 69 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=1) 70 | ).exists() 71 | assert not Path( 72 | dataset.audio_paths[first_track_id].replace(".wav", "_deform_2.wav") 73 | ).exists() 74 | assert not Path( 75 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=0).replace( 76 | ".wav", "_deform_0.wav" 77 | ) 78 | ).exists() 79 | 80 | # delete computed deformations 81 | shutil.rmtree("tests/data/tinysol/audio_deformed") 82 | -------------------------------------------------------------------------------- /tests/test_generate.py: -------------------------------------------------------------------------------- 1 | """Tests embedding inference.""" 2 | 3 | import shutil 4 | import sys 5 | from pathlib import Path 6 | 7 | import yaml 8 | 9 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 10 | 11 | with open("./tests/test_cfg.yml", "r") as f: 12 | exp_cfg = yaml.safe_load(f)["experiments"][0] 13 | 14 | # we need to load and download the dataset before we can test the deformations 15 | from mir_ref.datasets.dataset import get_dataset 16 | from mir_ref.features.feature_extraction import generate_embeddings 17 | 18 | dataset = get_dataset( 19 | dataset_cfg=exp_cfg["datasets"][0], 20 | task_cfg=exp_cfg["task"], 21 | features_cfg=exp_cfg["features"], 22 | ) 23 | 24 | dataset.download() 25 | dataset.preprocess() 26 | dataset.load_metadata() 27 | 28 | # keep only a few track_ids for testing 29 | # this also tests if everything is correctly anchored to the dataset object 30 | # and its track_ids 31 | # unfortunately we can't currently download only a few tracks from an mirdata dataset 32 | dataset.track_ids = dataset.track_ids[:5] 33 | first_track_id = dataset.track_ids[0] 34 | 35 | 36 | def test_models(): 37 | 
dataset.deformations_cfg = None 38 | for model_name in exp_cfg["features"]: 39 | generate_embeddings( 40 | dataset, 41 | model_name=model_name, 42 | ) 43 | 44 | assert Path( 45 | dataset.get_embedding_path(track_id=first_track_id, feature=model_name) 46 | ).exists() 47 | 48 | # delete computed embeddings 49 | shutil.rmtree("tests/data/tinysol/embeddings") 50 | 51 | 52 | def test_models_deformed_audio(): 53 | dataset.deformations_cfg = exp_cfg["deformations"] 54 | # we need to first compute the deformations 55 | from mir_ref.deformations import generate_deformations 56 | 57 | generate_deformations( 58 | dataset, 59 | n_jobs=2, 60 | ) 61 | 62 | for model_name in exp_cfg["features"]: 63 | generate_embeddings( 64 | dataset, 65 | model_name=model_name, 66 | ) 67 | assert Path( 68 | dataset.get_embedding_path(track_id=first_track_id, feature=model_name) 69 | ).exists() 70 | assert Path( 71 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=0) 72 | ) 73 | assert Path( 74 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=1) 75 | ) 76 | 77 | # delete computed deformations 78 | shutil.rmtree("tests/data/tinysol/audio_deformed") 79 | 80 | # delete computed embeddings 81 | shutil.rmtree("tests/data/tinysol/embeddings") 82 | 83 | 84 | # test_models() 85 | --------------------------------------------------------------------------------