├── .gitignore
├── README.md
├── configs
│   └── example.yml
├── docs
│   └── img
│       └── mir_ref_logo.svg
├── mir_ref
│   ├── __init__.py
│   ├── conduct.py
│   ├── dataloaders.py
│   ├── datasets
│   │   ├── custom_datasets.py
│   │   ├── dataset.py
│   │   ├── datasets
│   │   │   ├── magnatagatune.py
│   │   │   ├── mtg_jamendo.py
│   │   │   └── vocalset.py
│   │   └── mirdata_datasets.py
│   ├── deform.py
│   ├── deformations.py
│   ├── evaluate.py
│   ├── extract.py
│   ├── features
│   │   ├── custom_features.py
│   │   ├── feature_extraction.py
│   │   └── models
│   │       ├── clmr.py
│   │       ├── harmonic_cnn.py
│   │       ├── mule.py
│   │       └── openl3.py
│   ├── metrics.py
│   ├── probes
│   │   └── probe_builder.py
│   ├── train.py
│   └── utils.py
├── requirements.txt
├── run.py
└── tests
    ├── test_cfg.yml
    ├── test_datasets.py
    ├── test_deform.py
    └── test_generate.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # macos
156 | **/.DS_Store
157 |
158 | # vscode
159 | .vscode/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # mir_ref
4 |
5 | Representation Evaluation Framework for Music Information Retrieval tasks | [Paper](https://arxiv.org/abs/2312.05994)
6 |
7 | `mir_ref` is an open-source library for evaluating audio representations (embeddings or others) on a variety of music-related downstream tasks and datasets. It has two main capabilities:
8 |
9 | * Using a config file, you can specify all evaluation experiments you want to run. `mir_ref` will automatically acquire data and conduct the experiments - no coding or data handling needed. Many tasks, datasets, embedding models etc. are ready to use (see supported options section).
10 | * You can easily integrate your own features (e.g. embedding models), datasets, probes, metrics, and audio deformations and use them in `mir_ref` experiments.
11 |
12 | `mir_ref` builds upon existing reproducibility efforts in music audio, including [`mirdata`](https://mirdata.readthedocs.io/en/stable/) for data handling, [`mir_eval`](https://craffel.github.io/mir_eval/) for evaluation metrics, and [`essentia models`](https://essentia.upf.edu/models.html) for pretrained audio models.
13 |
14 | ## Disclaimer
15 |
16 | A first beta release is expected by the end of April; it will include many fixes and documentation improvements.
17 |
18 | ## Setup
19 |
20 | Clone the repository. Create and activate a Python (>3.9) environment and install the requirements:
21 |
22 | ```
23 | git clone https://github.com/chrispla/mir_ref.git && cd mir_ref
24 | pip install -r requirements.txt
25 | ```
26 |
27 | ## Running
28 |
29 | To run the experiments specified in a config file `configs/example.yml` end-to-end:
30 |
31 | ```
32 | python run.py conduct -c example
33 | ```
34 |
35 | This currently saves deformations and features to disk so they can be reused later; an option to compute them on the fly will be available soon.
36 |
37 | Alternatively, `mir_ref` consists of 4 main commands: `deform`, for generating deformations from a dataset; `extract`, for extracting features; `train`, for training probes; and `evaluate`, for evaluating them. These can be run as follows:
38 |
39 | ```
40 | python run.py COMMAND -c example
41 | ```
42 |
43 | `deform` accepts an optional `n_jobs` argument to parallelize the computation of deformations, and `extract` accepts a `--no_overwrite` flag to skip recomputing existing features.
44 |
45 | #### Configuration file
46 |
47 | An example configuration file is provided. Config files are written in YAML. A list of experiments is expected at the top level, and each experiment contains a task, datasets, features, and probes. For each dataset, a list of deformation scenarios can be specified, following the argument syntax of [audiomentations](https://iver56.github.io/audiomentations/).
48 |
49 | ## Currently supported options
50 |
51 | ### Datasets and Tasks
52 |
53 | * `magnatagatune`: MagnaTagATune (autotagging)
54 | * `mtg_jamendo`: MTG Jamendo (autotagging)
55 | * `vocalset`: VocalSet (singer_identification, technique_identification)
56 | * `tinysol`: TinySol (instrument_classification, pitch_class_classification)
57 | * `beatport_key`: Beatport (key_estimation)
58 |
59 | Many more coming soon.
60 |
61 | ### Features
62 |
63 | * `effnet-discogs`
64 | * `vggish-audioset`
65 | * `msd-musicnn`
66 | * `openl3`
67 | * `neuralfp`
68 | * `clmr-v2`
69 | * `mert-v1-95m-6` / `mert-v1-95m-0-1-2-3` / `mert-v1-95m-4-5-6-7-8` / `mert-v1-95m-9-10-11-12` (referring to the layers used)
70 | * `maest`
71 |
72 | More coming soon.
73 |
74 | ## Example results
75 |
76 | We conducted an example evaluation of 7 models on 6 tasks with 4 deformations and 5 different probing setups. For the full results, please refer to the 'Evaluation' chapter of [this thesis document](https://zenodo.org/records/8380471), pages 39-58.
77 |
78 | ## Citing
79 |
80 | If you use `mir_ref`, please cite the following [paper](https://arxiv.org/abs/2312.05994):
81 |
82 | ```
83 | @inproceedings{mir_ref,
84 | author = {Christos Plachouras and Pablo Alonso-Jim\'enez and Dmitry Bogdanov},
85 | title = {mir_ref: A Representation Evaluation Framework for Music Information Retrieval Tasks},
86 | booktitle = {37th Conference on Neural Information Processing Systems (NeurIPS), Machine Learning for Audio Workshop},
87 | address = {New Orleans, LA, USA},
88 | year = 2023,
89 | }
90 | ```
91 |
--------------------------------------------------------------------------------
/configs/example.yml:
--------------------------------------------------------------------------------
1 | experiments:
2 | - task:
3 | name: autotagging
4 | type: multilabel_classification
5 | feature_aggregation: mean
6 | datasets:
7 | - name: magnatagatune
8 | type: custom
9 | dir: data/magnatagatune/
10 | split_type: single
11 | deformations:
12 | - - type: AddGaussianSNR
13 | params:
14 | min_snr_in_db: 15
15 | max_snr_in_db: 15
16 | p: 1
17 | - - type: AddGaussianSNR
18 | params:
19 | min_snr_in_db: 0
20 | max_snr_in_db: 0
21 | p: 1
22 | - - type: Gain
23 | params:
24 | min_gain_in_db: -12
25 | max_gain_in_db: -12
26 | p: 1
27 | features:
28 | - vggish-audioset
29 | - clmr-v2
30 | - mert-v1-95m-6
31 | probes:
32 | - type: classifier
33 | emb_dim_reduction: False
34 | emb_shape: infer
35 | hidden_units: []
36 | output_activation: sigmoid
37 | weight_decay: 1.0e-5
38 | # optimizer
39 | optimizer: adam
40 | learning_rate: 1.0e-3
41 | # training
42 | batch_size: 1083
43 | epochs: 100
44 | patience: 10
45 | train_sampling: random
46 | - type: classifier
47 | emb_dim_reduction: False
48 | emb_shape: infer
49 | hidden_units: [infer]
50 | output_activation: sigmoid
51 | weight_decay: 1.0e-5
52 | # optimizer
53 | optimizer: adam
54 | learning_rate: 1.0e-3
55 | # training
56 | batch_size: 1083
57 | epochs: 100
58 | patience: 10
59 | train_sampling: random
60 | - type: classifier
61 | emb_dim_reduction: False
62 | emb_shape: infer
63 | hidden_units: [256, 128]
64 | output_activation: sigmoid
65 | weight_decay: 1.0e-5
66 | # optimizer
67 | optimizer: adam
68 | learning_rate: 1.0e-3
69 | # training
70 | batch_size: 1083
71 | epochs: 100
72 | patience: 10
73 | train_sampling: random
74 |
75 |
--------------------------------------------------------------------------------
/docs/img/mir_ref_logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
241 |
--------------------------------------------------------------------------------
/mir_ref/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrispla/mir_ref/691ae42815db6359ef66c1e175be55be42bbc340/mir_ref/__init__.py
--------------------------------------------------------------------------------
/mir_ref/conduct.py:
--------------------------------------------------------------------------------
1 | """End-to-end experiment conduct. Currently, this is a hacky way of
2 | doing things, and it will be replaced soon with an online process."""
3 |
4 |
5 | from mir_ref.deform import deform
6 | from mir_ref.evaluate import evaluate
7 | from mir_ref.extract import generate
8 | from mir_ref.train import train
9 |
10 |
11 | def conduct(cfg_path):
12 | """Conduct experiments in config end-to-end.
13 |
14 | Args:
15 | cfg_path (str): Path to config file.
16 | """
17 | deform(cfg_path=cfg_path, n_jobs=1)
18 | generate(cfg_path=cfg_path)
19 | train(cfg_path=cfg_path)
20 | evaluate(cfg_path=cfg_path)
21 |
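 22 | # Illustrative usage (the config path below is a hypothetical example,
 23 | # mirroring what `python run.py conduct -c example` runs end-to-end):
 24 | # conduct(cfg_path="configs/example.yml")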
--------------------------------------------------------------------------------
/mir_ref/dataloaders.py:
--------------------------------------------------------------------------------
1 | """Dataloaders.
2 | Code from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
3 | """
4 |
5 | from pathlib import Path
6 |
7 | import keras
8 | import numpy as np
9 |
10 |
11 | class DataGenerator(keras.utils.Sequence):
12 | """Generates data for Keras"""
13 |
14 | def __init__(
15 | self,
16 | ids_list,
17 | labels_dict,
18 | paths_dict,
19 | dim,
20 | n_classes,
21 | batch_size=8,
22 | shuffle=True,
23 | ):
24 | """Initialization"""
25 | self.dim = dim
26 | self.batch_size = batch_size
27 | self.ids_list = ids_list
28 | self.labels_dict = labels_dict
29 | self.paths_dict = paths_dict
30 | self.n_classes = n_classes
31 | self.shuffle = shuffle
32 | self.on_epoch_end()
33 |
34 | def __len__(self):
35 | """Denotes the number of batches per epoch. Last non-full batch is dropped."""
36 | return int(np.floor(len(self.ids_list) / self.batch_size))
37 |
38 | def __getitem__(self, index):
39 | """Generate one batch of data"""
40 | # Generate indexes of the batch
41 | indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
42 |
43 | # Find list of ids
44 | ids_list_temp = [self.ids_list[k] for k in indexes]
45 |
46 | # Generate data
47 | X, y = self.__data_generation(ids_list_temp=ids_list_temp)
48 |
49 | return X, y
50 |
51 | def on_epoch_end(self):
52 | """Updates indexes after each epoch"""
53 | self.indexes = np.arange(len(self.ids_list))
54 | if self.shuffle is True:
55 | np.random.shuffle(self.indexes)
56 |
57 | def __data_generation(self, ids_list_temp):
58 | """Generates data containing batch_size samples
59 | X : (n_samples, *dim)
60 | """
61 | # Initialization
62 | X = np.empty((self.batch_size, *self.dim))
63 | y = np.empty((self.batch_size, self.n_classes), dtype=int)
64 |
65 | # Generate data
66 | for i, t_id in enumerate(ids_list_temp):
67 | # Store sample
68 | emb = np.load(self.paths_dict[t_id])
 69 |             X[i,] = np.reshape(emb, self.dim)
70 |
71 | y[i] = self.labels_dict[t_id]
72 |
73 | return X, y
74 |
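 75 | # Minimal usage sketch (illustrative only): the ids, label vectors, embedding
 76 | # paths, and dimensionality below are hypothetical, not part of the library.
 77 | #
 78 | # train_gen = DataGenerator(
 79 | #     ids_list=["track_0", "track_1"],
 80 | #     labels_dict={"track_0": [1, 0], "track_1": [0, 1]},
 81 | #     paths_dict={
 82 | #         "track_0": "embeddings/track_0.npy",
 83 | #         "track_1": "embeddings/track_1.npy",
 84 | #     },
 85 | #     dim=(512,),  # shape of each stored embedding
 86 | #     n_classes=2,
 87 | #     batch_size=2,
 88 | # )
 89 | # model.fit(train_gen, epochs=10)  # any compiled keras.Model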
--------------------------------------------------------------------------------
/mir_ref/datasets/custom_datasets.py:
--------------------------------------------------------------------------------
1 | """Implement your own datasets here. Provide custom implementations
2 | for the given methods, without changing any of the given method
3 | and input/output names.
4 | """
5 |
6 | from mir_ref.datasets.dataset import Dataset
7 |
8 |
9 | class CustomDataset0(Dataset):
10 | def __init__(
11 | self,
12 | name,
13 | dataset_type,
14 | data_dir,
15 | split_type,
16 | task_name,
17 | task_type,
18 | feature_aggregation,
19 | deformations_cfg,
20 | features_cfg,
21 | ):
22 | """Custom Dataset class.
23 |
24 | Args:
25 | name (str): Name of the dataset.
26 | dataset_type (str): Type of the dataset ("mirdata", "custom")
27 | data_dir (str): Path to the dataset directory.
 28 |             split_type (str, optional): Whether to use "all" or "single" split.
 29 |                                         Defaults to "single". If a list of 3 floats,
 30 |                                         use as train, val, test split sizes.
31 | task_name (str): Name of the task.
32 | task_type (str): Type of the task. ("multiclass_classification",
33 | "multilabel_classification", "regression")
34 | feature_aggregation (str): Type of embedding aggregation ("mean", None)
35 | deformations_cfg (dict): Deformations config.
36 | features_cfg (dict): Embedding models config.
37 | """
38 | super().__init__(
39 | name=name,
40 | dataset_type=dataset_type,
41 | data_dir=data_dir,
42 | split_type=split_type,
43 | task_name=task_name,
44 | task_type=task_type,
45 | feature_aggregation=feature_aggregation,
46 | deformations_cfg=deformations_cfg,
47 | features_cfg=features_cfg,
48 | )
49 |
50 | @Dataset.try_to_load_metadata
51 | def download(self):
52 | pass
53 |
54 | @Dataset.try_to_load_metadata
55 | def download_metadata(self):
56 | # optional method
57 | pass
58 |
59 | def load_track_ids(self):
60 | # return list of track_ids
61 | pass
62 |
63 | def load_audio_paths(self):
64 | # return dict of track_id: audio_path
65 | pass
66 |
67 | @Dataset.check_metadata_is_loaded
68 | def get_splits(self):
69 | # return list of dicts with {train: [track_id_0, ...],
70 | # test: [", ...], val: [", ...]}
71 | pass
72 |
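 73 | # Illustrative sketch (hypothetical, not part of the library): for a custom
 74 | # dataset laid out flat as <data_dir>/audio/<track_id>.wav, the stub methods
 75 | # above could be filled in roughly like this. A load_labels() populating
 76 | # self.labels would also be needed (see Dataset.load_metadata).
 77 | #
 78 | # from pathlib import Path
 79 | #
 80 | # def load_track_ids(self):
 81 | #     self.track_ids = [p.stem for p in Path(self.data_dir, "audio").glob("*.wav")]
 82 | #
 83 | # def load_audio_paths(self):
 84 | #     self.audio_paths = {
 85 | #         t_id: str(Path(self.data_dir, "audio", f"{t_id}.wav"))
 86 | #         for t_id in self.track_ids
 87 | #     }
 88 | #
 89 | # def get_splits(self):
 90 | #     # no official splits: fall back to the stratified helper from Dataset
 91 | #     return [self.get_stratified_split()]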
--------------------------------------------------------------------------------
/mir_ref/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | """Generic Dataset class and dataset factory.
2 | """
3 |
4 | import os
5 | from pathlib import Path
6 |
7 |
8 | def get_dataset(dataset_cfg, task_cfg, features_cfg):
9 | kwargs = {
10 | "name": dataset_cfg["name"],
11 | "dataset_type": dataset_cfg["type"],
12 | "data_dir": dataset_cfg["dir"],
13 | "split_type": dataset_cfg["split_type"],
14 | "task_name": task_cfg["name"],
15 | "task_type": task_cfg["type"],
16 | "feature_aggregation": task_cfg["feature_aggregation"],
17 | "deformations_cfg": dataset_cfg["deformations"],
18 | "features_cfg": features_cfg,
19 | }
20 | if dataset_cfg["type"] == "mirdata":
21 | from mir_ref.datasets.mirdata_datasets import MirdataDataset
22 |
23 | return MirdataDataset(**kwargs)
24 |
25 | elif dataset_cfg["type"] == "custom":
26 | if dataset_cfg["name"] == "vocalset":
27 | from mir_ref.datasets.datasets.vocalset import VocalSet
28 |
29 | return VocalSet(**kwargs)
30 | elif dataset_cfg["name"] in [
31 | "mtg-jamendo-moodtheme",
32 | "mtg-jamendo-instruments",
33 | "mtg-jamendo-genre",
34 | "mtg-jamendo-top50tags",
35 | ]:
36 | from mir_ref.datasets.datasets.mtg_jamendo import MTG_Jamendo
37 |
38 | return MTG_Jamendo(**kwargs)
39 | elif dataset_cfg["name"] == "magnatagatune":
40 | from mir_ref.datasets.datasets.magnatagatune import MagnaTagATune
41 |
42 | return MagnaTagATune(**kwargs)
43 | else:
44 | raise NotImplementedError(
45 | f"Custom dataset with name '{dataset_cfg['name']}' does not exist."
46 | )
47 |
48 | else:
49 | raise NotImplementedError
50 |
51 |
52 | class Dataset:
53 | def __init__(
54 | self,
55 | name,
56 | dataset_type,
57 | task_name,
58 | task_type,
59 | feature_aggregation,
60 | deformations_cfg,
61 | features_cfg,
62 | split_type="single",
63 | data_dir=None,
64 | ):
65 | """Generic Dataset class.
66 |
67 | Args:
68 | name (str): Name of the dataset.
69 | dataset_type (str): Type of the dataset ("mirdata", "custom")
70 | task_name (str): Name of the task.
71 | task_type (str): Type of the task.
72 | data_dir (str, optional): Path to the dataset directory.
73 | Defaults to ./data/{name}/.
74 | split_type (str, optional): Whether to use "all" or "single" split.
 75 |                                         Defaults to "single". If a list of 3 floats,
 76 |                                         use as train, val, test split sizes.
77 | deformations_cfg (list, optional): List of deformation scenarios.
78 | features_cfg (list, optional): List of embedding models.
79 | """
80 | self.deformations_cfg = deformations_cfg
81 | self.features_cfg = features_cfg
82 | self.name = name
83 | self.dataset_type = dataset_type
84 | self.data_dir = data_dir
85 | if data_dir is None:
86 | self.data_dir = f"./data/{name}/"
87 | self.split_type = split_type
88 | self.task_name = task_name
89 | self.task_type = task_type
90 | self.feature_aggregation = feature_aggregation
91 | try:
92 | self.load_metadata()
 93 |         except Exception:
94 | self.track_ids = None # list of track_ids
95 | self.labels = None # dict of track_id: label
96 | self.encoded_labels = None # dict of track_id: encoded_label
97 | self.audio_paths = None # dict of track_id: audio_path
98 | self.common_audio_dir = None # common directory between all audio paths
99 |
100 | def check_params(self):
101 | # check validity of provided parameters
102 | return
103 |
104 | def check_metadata_is_loaded(func):
105 | """Decorator to check if metadata is loaded."""
106 |
107 | def wrapper(self):
108 | if any(
109 | var is None
110 | for var in [
111 | self.track_ids,
112 | self.labels,
113 | self.encoded_labels,
114 | self.audio_paths,
115 | self.common_audio_dir,
116 | ]
117 | ):
118 | raise ValueError(
119 | "Metadata not loaded. Make sure to run 'dataset.download()'."
120 | )
121 | return func(self)
122 |
123 | return wrapper
124 |
125 | def try_to_load_metadata(func):
126 | """Try to load metadata after a function. Aimed at functions
127 | that download or preprocess datasets.
128 | """
129 |
130 | def wrapper(self):
131 | func(self)
132 | try:
133 | self.load_metadata()
134 | except Exception as e:
135 | print(e)
136 |
137 | return wrapper
138 |
139 | def download(self):
140 | # to be overwritten by child class
141 | return
142 |
143 | def download_metadata(self):
144 | # to be overwritten by child class
145 | return
146 |
147 | def preprocess(self):
148 | # to be overwritten by child class
149 | return
150 |
151 | def load_track_ids(self):
152 | # to be overwritten by child class
153 | return
154 |
155 | def load_audio_paths(self):
156 | # to be overwritten by child class
157 | return
158 |
159 | def load_labels(self):
160 | # to be overwritten by child class
161 | return
162 |
163 | def load_encoded_labels(self):
164 | """Return encoded labels given task type and labels."""
165 |
166 | # get lists of track_ids and labels (corresponding indices)
167 | track_ids = list(self.labels.keys())
168 | labels_list = list(self.labels.values())
169 |
170 | if self.task_type == "multiclass_classification":
171 | import keras
172 | from sklearn.preprocessing import LabelEncoder
173 |
174 | # fit label encoder on all tracks and labels
175 | self.label_encoder = LabelEncoder()
176 | self.label_encoder.fit(labels_list)
177 | # get encoded labels
178 | encoded_labels_list = self.label_encoder.transform(labels_list)
179 | encoded_categorical_labels_list = keras.utils.to_categorical(
180 | encoded_labels_list, num_classes=len(set(labels_list))
181 | )
182 | self.encoded_labels = {
183 | track_ids[i]: encoded_categorical_labels_list[i]
184 | for i in range(len(track_ids))
185 | }
186 |
187 | elif self.task_type == "multilabel_classification":
188 | from sklearn.preprocessing import MultiLabelBinarizer
189 |
190 | # fit label encoder on all tracks and labels
191 | self.label_encoder = MultiLabelBinarizer()
192 | self.label_encoder.fit(labels_list)
193 | # get encoded labels
194 | encoded_labels_list = self.label_encoder.transform(labels_list)
195 | self.encoded_labels = {
196 | track_ids[i]: encoded_labels_list[i] for i in range(len(track_ids))
197 | }
198 |
199 | else:
200 | raise NotImplementedError
201 |
202 | def load_common_audio_dir(self):
203 | """Get the deepest common directory between all audio paths.
204 | This will later be used as a reference for creating the dir
205 | structure for deformed audio and embeddings."""
206 |
207 | self.common_audio_dir = str(os.path.commonpath(self.audio_paths.values()))
208 |
209 | def get_deformed_audio_path(self, track_id, deform_idx):
210 | """Get path of deformed audio based on audio path and deform_idx."""
211 |
212 | audio_path = Path(self.audio_paths[track_id])
213 | new_filestem = f"{audio_path.stem}_deform_{deform_idx}"
214 | return str(
215 | Path(self.data_dir)
216 | / "audio_deformed"
217 | / audio_path.relative_to(self.common_audio_dir).with_name(
218 | f"{new_filestem}{audio_path.suffix}"
219 | )
220 | )
221 |
222 | def get_embedding_path(self, track_id, feature):
223 | """Get path of embedding based on audio path and embedding model."""
224 |
225 | return str(
226 | Path(self.data_dir)
227 | / "embeddings"
228 | / feature
229 | / Path(self.audio_paths[track_id])
230 | .relative_to(self.common_audio_dir)
231 |             .with_suffix(".npy")
232 | )
233 |
234 | def get_deformed_embedding_path(self, track_id, deform_idx, feature):
235 | """Get path of deformed embedding based on audio path, embedding model
236 | and deform_idx."""
237 |
238 | return str(
239 | Path(self.data_dir)
240 | / "embeddings"
241 | / feature
242 | / Path(self.audio_paths[track_id])
243 | .relative_to(self.common_audio_dir)
244 | .with_name(
245 | f"{Path(self.audio_paths[track_id]).stem}_deform_{deform_idx}.npy"
246 | )
247 | )
248 |
249 | def get_stratified_split(self, sizes=(0.8, 0.1, 0.1), seed=42):
250 | """Helper method to generate a stratified split of the dataset.
251 |
252 | Args:
253 | sizes (tuple, optional): Sizes of train, validation and test set.
254 | Defaults to (0.8, 0.1, 0.1), must add up to 1.
255 | seed (int, optional): Random seed. Defaults to 42.
256 | """
257 | from sklearn.model_selection import train_test_split
258 |
259 |         if abs(sum(sizes) - 1) > 1e-6:
260 |             raise ValueError("Sizes must add up to 1.")
261 |
262 | X = self.track_ids
263 | y = [self.labels[track_id] for track_id in self.track_ids]
264 | X_train, X_others, y_train, y_others = train_test_split(
265 | X,
266 | y,
267 | test_size=1 - sizes[0],
268 | random_state=seed,
269 | stratify=y,
270 | )
271 | X_val, X_test, y_val, y_test = train_test_split(
272 | X_others,
273 | y_others,
274 | test_size=sizes[2] / (sizes[1] + sizes[2]),
275 | random_state=seed,
276 | stratify=y_others,
277 | )
278 | return {"train": X_train, "validation": X_val, "test": X_test}
279 |
280 | def load_metadata(self):
281 | self.load_track_ids()
282 | self.load_labels()
283 | self.load_encoded_labels()
284 | self.load_audio_paths()
285 | self.load_common_audio_dir()
286 |
287 | def encode_label(self, label):
288 | return self.label_encoder.transform([label])[0]
289 |
290 | def decode_label(self, encoded_label):
291 | if self.task_type == "multiclass_classification":
292 | import numpy as np
293 |
294 | # get index maximum value
295 | encoded_label = np.argmax(encoded_label)
296 | return self.label_encoder.inverse_transform([encoded_label])[0]
297 | return self.label_encoder.inverse_transform([encoded_label])[0]
298 |
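299 | # Illustrative note (hypothetical paths, not part of the library): with
300 | # data_dir="data/example/", an audio file at data/example/audio/a/1.wav, and
301 | # assuming the common audio dir resolves to data/example/audio, the path
302 | # helpers above produce a parallel directory layout such as:
303 | #
304 | #   get_deformed_audio_path("1", 0)                -> data/example/audio_deformed/a/1_deform_0.wav
305 | #   get_embedding_path("1", "openl3")              -> data/example/embeddings/openl3/a/1.npy
306 | #   get_deformed_embedding_path("1", 0, "openl3")  -> data/example/embeddings/openl3/a/1_deform_0.npy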
--------------------------------------------------------------------------------
/mir_ref/datasets/datasets/magnatagatune.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os.path
3 | import zipfile
4 | from pathlib import Path
5 |
6 | import numpy as np
7 | import wget
8 | from tqdm import tqdm
9 |
10 | from mir_ref.datasets.dataset import Dataset
11 |
12 |
13 | class MagnaTagATune(Dataset):
14 | def __init__(
15 | self,
16 | name,
17 | dataset_type,
18 | data_dir,
19 | split_type,
20 | task_name,
21 | task_type,
22 | feature_aggregation,
23 | deformations_cfg,
24 | features_cfg,
25 | ):
26 | """Dataset wrapper for MagnaTagATune."""
27 | super().__init__(
28 | name=name,
29 | dataset_type=dataset_type,
30 | data_dir=data_dir,
31 | split_type=split_type,
32 | task_name=task_name,
33 | task_type=task_type,
34 | feature_aggregation=feature_aggregation,
35 | deformations_cfg=deformations_cfg,
36 | features_cfg=features_cfg,
37 | )
38 |
39 | @Dataset.try_to_load_metadata
40 | def download(self):
41 | # make data dir if it doesn't exist or if it exists but is empty
42 | if os.path.exists(os.path.join(self.data_dir, "audio")) and (
43 | len(os.listdir(os.path.join(self.data_dir, "audio"))) != 0
44 | ):
45 | import warnings
46 |
47 | warnings.warn(
48 | f"Dataset '{self.name}' already exists in '{self.data_dir}'."
49 | + "Skipping audio download.",
50 | stacklevel=2,
51 | )
52 | self.download_metadata()
53 | return
54 | (Path(self.data_dir) / "audio").mkdir(parents=True, exist_ok=True)
55 |
56 | print(f"Downloading MagnaTagATune to {self.data_dir}...")
57 | for i in tqdm(["001", "002", "003"]):
58 | wget.download(
59 | url=f"https://mirg.city.ac.uk/datasets/magnatagatune/mp3.zip.{i}",
60 | out=os.path.join(self.data_dir, "audio/"),
61 | )
62 |
63 | archive_dir = os.path.join(self.data_dir, "audio")
64 |
65 | # Combine the split archive files into a single file
66 | with open(os.path.join(archive_dir, "mp3.zip"), "wb") as f:
67 | for i in ["001", "002", "003"]:
68 | with open(
69 | os.path.join(archive_dir, f"mp3.zip.{i}"),
70 | "rb",
71 | ) as part:
72 | f.write(part.read())
73 |
74 | # Extract the contents of the archive
75 | with zipfile.ZipFile(os.path.join(archive_dir, "mp3.zip"), "r") as zip_ref:
 76 |             zip_ref.extractall(path=archive_dir)
77 |
78 | # Remove zips
79 | for i in ["", ".001", ".002", ".003"]:
80 | os.remove(os.path.join(archive_dir, f"mp3.zip{i}"))
81 |
82 | self.download_metadata()
83 |
84 | @Dataset.try_to_load_metadata
85 | def download_metadata(self):
86 | if os.path.exists(os.path.join(self.data_dir, "metadata")) and (
87 | len(os.listdir(os.path.join(self.data_dir, "metadata"))) != 0
88 | ):
89 | import warnings
90 |
91 | warnings.warn(
92 | f"Metadata for dataset '{self.name}' already exists in '{self.data_dir}'."
93 | + "Skipping metadata download.",
94 | stacklevel=2,
95 | )
96 | return
97 | (Path(self.data_dir) / "metadata").mkdir(parents=True, exist_ok=True)
98 |
99 | urls = [
100 | # annotations
101 | "https://mirg.city.ac.uk/datasets/magnatagatune/annotations_final.csv",
102 | # train, validation, and test splits from Won et al. 2020
103 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/train.npy",
104 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/valid.npy",
105 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/test.npy",
106 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/tags.npy",
107 | ]
108 | for url in urls:
109 | wget.download(
110 | url=url,
111 | out=os.path.join(self.data_dir, "metadata/"),
112 | )
113 |
114 | def load_track_ids(self):
115 | with open(
116 | os.path.join(self.data_dir, "metadata", "annotations_final.csv"), "r"
117 | ) as f:
118 | annotations = csv.reader(f, delimiter="\t")
119 | next(annotations) # skip header
120 | self.track_ids = [line[0] for line in annotations]
121 | # manually remove some corrupt files
122 | self.track_ids.remove("35644")
123 | self.track_ids.remove("55753")
124 | self.track_ids.remove("57881")
125 |
126 | def load_labels(self):
127 | # get the list of top 50 tags used in Minz Won et al. 2020
128 | tags = np.load(os.path.join(self.data_dir, "metadata", "tags.npy"))
129 |
130 | with open(
131 | os.path.join(self.data_dir, "metadata", "annotations_final.csv"), "r"
132 | ) as f:
133 | annotations = csv.reader(f, delimiter="\t")
134 | annotations_header = next(annotations)
135 | self.labels = {
136 | line[0]: [
137 | annotations_header[j]
138 | for j in range(1, len(line) - 1)
139 | # only add the tag if it's in the tags list
140 | if line[j] == "1" and annotations_header[j] in tags
141 | ]
142 | for line in annotations
143 | # this is a slow way to do it, temporary fix for
144 | # some corrupt mp3s
145 | if line[0] in self.track_ids
146 | }
147 |
148 | def load_audio_paths(self):
149 | with open(
150 | os.path.join(self.data_dir, "metadata", "annotations_final.csv"), "r"
151 | ) as f:
152 | annotations = csv.reader(f, delimiter="\t")
153 | next(annotations) # skip header
154 | self.audio_paths = {
155 | line[0]: os.path.join(self.data_dir, "audio", line[-1])
156 | for line in annotations
157 | # this is a slow way to do it, temporary fix for
158 | # some corrupt mp3s
159 | if line[0] in self.track_ids
160 | }
161 |
162 | @Dataset.check_metadata_is_loaded
163 | def get_splits(self):
164 | # get inverse dictionary to get track id from audio path
165 | rel_path_to_track_id = {
166 | (Path(v).parent.name + "/" + Path(v).name): k
167 | for k, v in self.audio_paths.items()
168 | }
169 |
170 | split = {}
171 | for set_filename, set_targetname in zip(
172 | ["train", "valid", "test"], ["train", "validation", "test"]
173 | ):
174 | relative_paths = np.load(
175 | os.path.join(self.data_dir, "metadata", f"{set_filename}.npy")
176 | )
177 | # get track_ids by getting the full path and using the inv dict
178 | split[set_targetname] = [
179 | rel_path_to_track_id[path.split("\t")[1]] for path in relative_paths
180 | ]
181 |
182 | if self.split_type not in ["all", "single"]:
183 | raise NotImplementedError(f"Split type '{self.split_type}' does not exist.")
184 |
185 | return [split]
186 |
--------------------------------------------------------------------------------
/mir_ref/datasets/datasets/mtg_jamendo.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import hashlib
3 | import os.path
4 | import sys
5 | import tarfile
6 | from pathlib import Path
7 |
 8 | from tqdm import tqdm
9 | import wget
10 |
11 | from mir_ref.datasets.dataset import Dataset
12 |
13 |
14 | class MTG_Jamendo(Dataset):
15 | def __init__(
16 | self,
17 | name,
18 | dataset_type,
19 | data_dir,
20 | split_type,
21 | task_name,
22 | task_type,
23 | feature_aggregation,
24 | deformations_cfg,
25 | features_cfg,
26 | ):
27 | """Dataset wrapper for MTG Jamendo (sub)dataset(s). Each subset
28 | (moodtheme, instrument, top50tags, genre) is going to be treated
29 | as a separate dataset, meaning a separate data_dir needs to be
30 | specified for each one. This helps disambiguate versioning of
31 | the deformations and experiments. Since the methods are shared,
32 | however, a single class is used for all of them.
33 | """
34 | super().__init__(
35 | name=name,
36 | dataset_type=dataset_type,
37 | data_dir=data_dir,
38 | split_type=split_type,
39 | task_name=task_name,
40 | task_type=task_type,
41 | feature_aggregation=feature_aggregation,
42 | deformations_cfg=deformations_cfg,
43 | features_cfg=features_cfg,
44 | )
45 |
46 | @Dataset.try_to_load_metadata
47 | def download(self):
48 | # make data dir if it doesn't exist or if it exists but is empty
49 | if os.path.exists(os.path.join(self.data_dir, "audio")) and (
50 | len(os.listdir(os.path.join(self.data_dir, "audio"))) != 0
51 | ):
52 | import warnings
53 |
54 | warnings.warn(
55 | f"Dataset '{self.name}' already exists in '{self.data_dir}'."
56 | + "Skipping audio download.",
57 | stacklevel=2,
58 | )
59 | self.download_metadata()
60 | return
61 | (Path(self.data_dir) / "audio").mkdir(parents=True, exist_ok=True)
62 | # download with mtg_jamendo download helper
63 | if self.name == "mtg-jamendo-moodtheme":
64 | # only moodtheme has a separate download target
65 | download_jamendo(
66 | dataset="autotagging_moodtheme",
67 | data_type="audio-low",
68 | download_from="mtg-fast",
69 | output_dir=os.path.join(self.data_dir, "audio/"),
70 | unpack_tars=True,
71 | remove_tars=True,
72 | )
73 | elif (
74 | self.name == "mtg-jamendo-instrument"
75 | or self.name == "mtg-jamendo-genre"
76 | or self.name == "mtg-jamendo-top50tags"
77 | ):
78 | # the whole dataset needs to be downloaded as no subset-specific
79 | # download targets are available
80 | download_jamendo(
81 | dataset="raw_30s",
82 | data_type="audio-low",
83 | download_from="mtg-fast",
84 | output_dir=os.path.join(self.data_dir, "audio/"),
85 | unpack_tars=True,
86 | remove_tars=True,
87 | )
88 | # optionally I'd delete unneeded tracks here...
89 | else:
90 | raise NotImplementedError(f"Dataset '{self.name}' does not exist.")
91 | self.download_metadata()
92 |
93 | @Dataset.try_to_load_metadata
94 | def download_metadata(self):
95 | # download from github link
96 | url = (
97 | "https://raw.githubusercontent.com/MTG/mtg-jamendo-dataset/"
98 | + "master/data/autotagging_"
99 | + (self.name).split("-")[2]
100 | + ".tsv"
101 | )
102 | (Path(self.data_dir) / "metadata").mkdir(parents=True, exist_ok=True)
103 | wget.download(url, out=os.path.join(self.data_dir, "metadata/"))
104 |
105 | def load_track_ids(self):
106 | # open the corresponding file using the subset name in self.name
107 | with open(
108 | os.path.join(
109 | self.data_dir,
110 | "metadata/",
111 | f"autotagging_{(self.name).split('-')[2]}.tsv",
112 | ),
113 | "r",
114 | ) as f:
115 | metadata = f.readlines()
116 | self.track_ids = [
117 | line.split("\t")[0].strip() for line in metadata[1:] if line
118 | ]
119 |
120 | def load_labels(self):
121 | with open(
122 | os.path.join(
123 | self.data_dir,
124 | "metadata/",
125 | f"autotagging_{(self.name).split('-')[2]}.tsv",
126 | ),
127 | "r",
128 | ) as f:
129 | metadata = f.readlines()
130 | # for each line, get key (track_id) and all stripped entries after
131 | # column 5 (list of labels)
132 | self.labels = {
133 | line.split("\t")[0].strip(): [
134 | tag.strip() for tag in line.split("\t")[5:]
135 | ]
136 | for line in metadata[1:]
137 | }
138 |
139 | def load_audio_paths(self):
140 | with open(
141 | os.path.join(
142 | self.data_dir,
143 | "metadata/",
144 | f"autotagging_{(self.name).split('-')[2]}.tsv",
145 | ),
146 | "r",
147 | ) as f:
148 | metadata = f.readlines()
149 | # the lowres version we downloaded contains "low" before the extension
150 | self.audio_paths = {
151 | line.split("\t")[0].strip(): os.path.join(
152 | self.data_dir,
153 | "audio",
154 | (line.split("\t")[3].strip())[:-4] + ".low.mp3",
155 | )
156 | for line in metadata[1:]
157 | if line
158 | }
159 |
160 | @Dataset.check_metadata_is_loaded
161 | def get_splits(self):
162 | subset = (self.name).split("-")[2]
163 |         splits_url_dir = (
164 |             "https://raw.githubusercontent.com/MTG/mtg-jamendo-dataset/master/data/splits/"
165 |         )
166 |
167 |         if not os.path.exists(os.path.join(self.data_dir, "metadata", "splits")):
168 |             # there are 5 splits for each subset
169 |             for i in range(5):
170 |                 (Path(self.data_dir) / "metadata" / "splits" / f"split-{i}").mkdir(
171 |                     parents=True, exist_ok=True
172 |                 )
173 |                 for split in ["train", "validation", "test"]:
174 |                     url = f"{splits_url_dir}split-{i}/autotagging_{subset}-{split}.tsv"
175 |                     wget.download(
176 |                         url,
177 |                         out=os.path.join(
178 |                             self.data_dir, "metadata", "splits", f"split-{i}"
179 |                         ),
180 |                     )
181 | splits = []
182 | for i in range(5):
183 | splits.append({})
184 | for split in ["train", "validation", "test"]:
185 | with open(
186 | os.path.join(
187 | self.data_dir,
188 | "metadata",
189 | "splits",
190 | f"split-{i}",
191 | f"autotagging_{subset}-{split}.tsv",
192 | ),
193 | "r",
194 | ) as f:
195 | split_metadata = f.readlines()
196 |                     splits[i][split] = [
197 | line.split("\t")[0].strip()
198 | for line in split_metadata[1:]
199 | if line
200 | ]
201 | if self.split_type == "single":
202 | return [splits[0]]
203 | elif self.split_type == "all":
204 | return splits
205 | else:
206 | raise NotImplementedError(f"Split type '{self.split_type}' does not exist.")
207 |
208 |
209 | """Code to download MTG Jamendo.
210 | Source: https://github.com/MTG/mtg-jamendo-dataset
211 | """
212 |
213 | download_from_names = {"gdrive": "GDrive", "mtg": "MTG", "mtg-fast": "MTG Fast mirror"}
214 |
215 |
216 | def compute_sha256(filename):
217 | with open(filename, "rb") as f:
218 | contents = f.read()
219 | checksum = hashlib.sha256(contents).hexdigest()
220 | return checksum
221 |
222 |
223 | def download_jamendo(
224 | dataset, data_type, download_from, output_dir, unpack_tars, remove_tars
225 | ):
226 | if not os.path.exists(output_dir):
227 | print("Output directory {} does not exist".format(output_dir), file=sys.stderr)
228 | return
229 |
230 | if download_from not in download_from_names:
231 | print(
232 | "Unknown --from argument, choices are {}".format(
233 | list(download_from_names.keys())
234 | ),
235 | file=sys.stderr,
236 | )
237 | return
238 |
239 | print("Downloading %s from %s" % (dataset, download_from_names[download_from]))
240 | # download checksums
241 | file_sha256_tars_url = (
242 | "https://raw.githubusercontent.com/MTG/mtg-jamendo-dataset/master/data/download/"
243 | + f"{dataset}_{data_type}_sha256_tars.txt"
244 | )
245 | wget.download(file_sha256_tars_url, out=output_dir)
246 | file_sha256_tracks_url = (
247 | "https://raw.githubusercontent.com/MTG/mtg-jamendo-dataset/master/data/download/"
248 | + f"{dataset}_{data_type}_sha256_tracks.txt"
249 | )
250 | wget.download(file_sha256_tracks_url, out=output_dir)
251 |
252 | file_sha256_tars = os.path.join(
253 | output_dir, dataset + "_" + data_type + "_sha256_tars.txt"
254 | )
255 | file_sha256_tracks = os.path.join(
256 | output_dir, dataset + "_" + data_type + "_sha256_tracks.txt"
257 | )
258 |
259 | # Read checksum values for tars and files inside.
260 | with open(file_sha256_tars) as f:
261 | sha256_tars = dict([(row[1], row[0]) for row in csv.reader(f, delimiter=" ")])
262 |
263 | with open(file_sha256_tracks) as f:
264 | sha256_tracks = dict([(row[1], row[0]) for row in csv.reader(f, delimiter=" ")])
265 |
266 | # Filenames to download.
267 | ids = sha256_tars.keys()
268 |
269 | removed = []
270 | for filename in tqdm(ids, total=len(ids)):
271 | output = os.path.join(output_dir, filename)
272 | # print(output)
273 | # print(filename)
274 |
275 | # Download from Google Drive.
276 | if os.path.exists(output):
277 | print("Skipping %s (file already exists)" % output)
278 | continue
279 |
280 | elif download_from == "mtg":
281 | url = (
282 | "https://essentia.upf.edu/documentation/datasets/mtg-jamendo/"
283 | "%s/%s/%s" % (dataset, data_type, filename)
284 | )
285 | # print("From:", url)
286 | # print("To:", output)
287 | wget.download(url, out=output)
288 |
289 | elif download_from == "mtg-fast":
290 | url = "https://cdn.freesound.org/mtg-jamendo/" "%s/%s/%s" % (
291 | dataset,
292 | data_type,
293 | filename,
294 | )
295 | # print("From:", url)
296 | # print("To:", output)
297 | wget.download(url, out=output)
298 |
299 | # Validate the checksum.
300 | if compute_sha256(output) != sha256_tars[filename]:
301 | print(
302 | "%s does not match the checksum, removing the file" % output,
303 | file=sys.stderr,
304 | )
305 | removed.append(filename)
306 | os.remove(output)
307 | # else:
308 | # print("%s checksum OK" % filename)
309 |
310 | if removed:
311 | print("Missing files:", " ".join(removed))
312 | print("Re-run the script again")
313 | return
314 |
315 | print("Download complete")
316 |
317 | if unpack_tars:
318 | print("Unpacking tar archives")
319 |
320 | tracks_checked = []
321 | for filename in tqdm(ids, total=len(ids)):
322 | output = os.path.join(output_dir, filename)
323 | print("Unpacking", output)
324 | tar = tarfile.open(output)
325 | tracks = tar.getnames()[1:] # The first element is folder name.
326 | tar.extractall(path=output_dir)
327 | tar.close()
328 |
329 | # Validate checksums for all unpacked files
330 | for track in tracks:
331 | trackname = os.path.join(output_dir, track)
332 | if compute_sha256(trackname) != sha256_tracks[track]:
333 | print("%s does not match the checksum" % trackname, file=sys.stderr)
334 | raise Exception("Corrupt file in the dataset: %s" % trackname)
335 |
336 | # print("%s track checksums OK" % filename)
337 | tracks_checked += tracks
338 |
339 | if remove_tars:
340 | os.remove(output)
341 |
342 | # Check if any tracks are missing in the unpacked archives.
343 | if set(tracks_checked) != set(sha256_tracks.keys()):
344 | raise Exception(
345 | "Unpacked data contains tracks not present in the checksum files"
346 | )
347 |
348 | print("Unpacking complete")
349 |
350 | # delete checksum files
351 | os.remove(file_sha256_tars)
352 | os.remove(file_sha256_tracks)
353 |
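354 | # Illustrative stand-alone call (mirrors the call made in MTG_Jamendo.download;
355 | # the output_dir below is a hypothetical example):
356 | # download_jamendo(
357 | #     dataset="autotagging_moodtheme",
358 | #     data_type="audio-low",
359 | #     download_from="mtg-fast",
360 | #     output_dir="data/mtg-jamendo-moodtheme/audio/",
361 | #     unpack_tars=True,
362 | #     remove_tars=True,
363 | # )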
--------------------------------------------------------------------------------
/mir_ref/datasets/datasets/vocalset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import subprocess
4 | from pathlib import Path
5 |
6 | import wget
7 |
8 | from mir_ref.datasets.dataset import Dataset
9 |
10 |
11 | class VocalSet(Dataset):
12 | def __init__(
13 | self,
14 | name,
15 | dataset_type,
16 | data_dir,
17 | split_type,
18 | task_name,
19 | task_type,
20 | feature_aggregation,
21 | deformations_cfg,
22 | features_cfg,
23 | ):
24 | """Dataset wrapper for VocalSet dataset."""
25 | super().__init__(
26 | name=name,
27 | dataset_type=dataset_type,
28 | data_dir=data_dir,
29 | split_type=split_type,
30 | task_name=task_name,
31 | task_type=task_type,
32 | feature_aggregation=feature_aggregation,
33 | deformations_cfg=deformations_cfg,
34 | features_cfg=features_cfg,
35 | )
36 |
37 | @Dataset.try_to_load_metadata
38 | def download(self):
39 | if Path(self.data_dir).exists():
40 | import warnings
41 |
42 | warnings.warn(
43 | (
44 | f"Dataset {self.name} already exists at {self.data_dir}."
45 | + " Skipping download."
46 | )
47 | )
48 | self.preprocess()
49 | return
50 | # make data dir
51 | Path(self.data_dir).mkdir(parents=True, exist_ok=False)
52 | zenodo_url = "https://zenodo.org/record/1203819/files/VocalSet11.zip"
53 | print(f"Downloading VocalSet to {self.data_dir}...")
54 | wget.download(zenodo_url, self.data_dir)
55 | # extract zip
56 | subprocess.run(
57 | [
58 | "unzip",
59 | "-q",
60 | "-d",
61 | str(Path(self.data_dir)),
62 | str(Path(self.data_dir) / "VocalSet11.zip"),
63 | ]
64 | )
65 | # remove zip
66 | subprocess.run(["rm", str(Path(self.data_dir) / "VocalSet11.zip")])
67 |
68 | # preprocess
69 | self.preprocess()
70 |
71 | def preprocess(self):
72 | # need to make some corrections to filenames and delete some duplicates
73 | # Thanks to the MARBLE authors for the list of corrections and dups:
74 | # https://github.com/a43992899/MARBLE-Benchmark/blob/main/benchmark/tasks/VocalSet/preprocess.py
75 |
 76 |         # a literal "**" path can never exist, so glob for the known duplicate
 77 |         if not glob.glob(
 78 |             str(
 79 |                 Path(self.data_dir)
 80 |                 / "FULL"
 81 |                 / "**"
 82 |                 / "*"
 83 |                 / "vibrato/f2_scales_vibrato_a(1).wav"
 84 |             ),
 85 |             recursive=True,
 86 |         ):
 87 |             # dataset is probably already preprocessed
 88 |             return
85 |
86 | file_delete = [
87 | "vibrato/f2_scales_vibrato_a(1).wav",
88 | "vibrato/caro_vibrato.wav",
89 | "vibrato/dona_vibrato.wav",
90 | "vibrato/row_vibrato.wav",
91 | "vibrado/slow_vibrato_arps.wav",
92 | ]
93 |
 94 |         filepaths_to_delete = [
 95 |             glob.glob(
 96 |                 str(Path(self.data_dir) / "FULL" / "**" / "*" / f), recursive=True
 97 |             )
 98 |             for f in file_delete
 99 |         ]
98 | # flatten list of lists
99 | filepaths_to_delete = [
100 | item for sublist in filepaths_to_delete for item in sublist
101 | ]
102 | # delete files
103 | for f in filepaths_to_delete:
104 | os.remove(f)
105 |
106 | # thanks black formatter for this beauty
107 | name_correction = [
108 | ("/lip_trill/lip_trill_arps.wav", "/lip_trill/f2_lip_trill_arps.wav"),
109 | ("/lip_trill/scales_lip_trill.wav", "/lip_trill/m3_scales_lip_trill.wav"),
110 | (
111 | "/straight/arpeggios_straight_a.wav",
112 | "/straight/f4_arpeggios_straight_a.wav",
113 | ),
114 | (
115 | "/straight/arpeggios_straight_e.wav",
116 | "/straight/f4_arpeggios_straight_e.wav",
117 | ),
118 | (
119 | "/straight/arpeggios_straight_i.wav",
120 | "/straight/f4_arpeggios_straight_i.wav",
121 | ),
122 | (
123 | "/straight/arpeggios_straight_o.wav",
124 | "/straight/f4_arpeggios_straight_o.wav",
125 | ),
126 | (
127 | "/straight/arpeggios_straight_u.wav",
128 | "/straight/f4_arpeggios_straight_u.wav",
129 | ),
130 | ("/straight/row_straight.wav", "/straight/m8_row_straight.wav"),
131 | ("/straight/scales_straight_a.wav", "/straight/f4_scales_straight_a.wav"),
132 | ("/straight/scales_straight_e.wav", "/straight/f4_scales_straight_e.wav"),
133 | ("/straight/scales_straight_i.wav", "/straight/f4_scales_straight_i.wav"),
134 | ("/straight/scales_straight_o.wav", "/straight/f4_scales_straight_o.wav"),
135 | ("/straight/scales_straight_u.wav", "/straight/f4_scales_straight_u.wav"),
136 | ("/vocal_fry/scales_vocal_fry.wav", "/vocal_fry/f2_scales_vocal_fry.wav"),
137 | (
138 | "/fast_forte/arps_fast_piano_c.wav",
139 | "/fast_forte/f9_arps_fast_piano_c.wav",
140 | ),
141 | (
142 | "/fast_piano/fast_piano_arps_f.wav",
143 | "/fast_piano/f2_fast_piano_arps_f.wav",
144 | ),
145 | (
146 | "/fast_piano/arps_c_fast_piano.wav",
147 | "/fast_piano/m3_arps_c_fast_piano.wav",
148 | ),
149 | (
150 | "/fast_piano/scales_fast_piano_f.wav",
151 | "/fast_piano/f3_scales_fast_piano_f.wav",
152 | ),
153 | (
154 | "/fast_piano/scales_c_fast_piano_a.wav",
155 | "/fast_piano/m10_scales_c_fast_piano_a.wav",
156 | ),
157 | (
158 | "/fast_piano/scales_c_fast_piano_e.wav",
159 | "/fast_piano/m10_scales_c_fast_piano_e.wav",
160 | ),
161 | (
162 | "/fast_piano/scales_c_fast_piano_i.wav",
163 | "/fast_piano/m10_scales_c_fast_piano_i.wav",
164 | ),
165 | (
166 | "/fast_piano/scales_c_fast_piano_o.wav",
167 | "/fast_piano/m10_scales_c_fast_piano_o.wav",
168 | ),
169 | (
170 | "/fast_piano/scales_c_fast_piano_u.wav",
171 | "/fast_piano/m10_scales_c_fast_piano_u.wav",
172 | ),
173 | (
174 | "/fast_piano/scales_f_fast_piano_a.wav",
175 | "/fast_piano/m10_scales_f_fast_piano_a.wav",
176 | ),
177 | (
178 | "/fast_piano/scales_f_fast_piano_e.wav",
179 | "/fast_piano/m10_scales_f_fast_piano_e.wav",
180 | ),
181 | (
182 | "/fast_piano/scales_f_fast_piano_i.wav",
183 | "/fast_piano/m10_scales_f_fast_piano_i.wav",
184 | ),
185 | (
186 | "/fast_piano/scales_f_fast_piano_o.wav",
187 | "/fast_piano/m10_scales_f_fast_piano_o.wav",
188 | ),
189 | (
190 | "/fast_piano/scales_f_fast_piano_u.wav",
191 | "/fast_piano/m10_scales_f_fast_piano_u.wav",
192 | ),
193 | ]
194 |
195 | for old, new in name_correction:
196 |             old_matches = glob.glob(
197 |                 str(Path(self.data_dir) / "FULL" / "**" / "*" / old[1:]),
198 |                 recursive=True,
199 |             )
199 | target_filepaths = [f.replace(old, new) for f in old_matches]
200 | # rename
201 | for old, new in zip(old_matches, target_filepaths):
202 | os.rename(old, new)
203 |
204 | @Dataset.try_to_load_metadata
205 | def download_metadata(self):
206 | # not possible to download only metadata for vocalset
207 | self.download()
208 |
209 | def load_track_ids(self):
210 | # the names of the tracks can be used as a unique identifier, as they contain
211 | # singer, technique, and take index information.
212 |
213 | if self.task_name == "singer_identification":
214 | self.track_ids = [
215 | Path(path).stem
216 | for path in glob.glob(
217 | str(Path(self.data_dir) / "FULL/**/*.wav"), recursive=True
218 | )
219 | ]
220 | elif self.task_name == "technique_identification":
221 | # use the 10 techniques used in the original paper
222 | techniques = [
223 | "vibrato",
224 | "straight",
225 | "belt",
226 | "breathy",
227 | "lip_trill",
228 | "spoken",
229 | "inhaled",
230 | "trill",
231 | "trillo",
232 | "vocal_fry",
233 | ]
234 | # only add track_ids that contain one of the techniques
235 | self.track_ids = [
236 | Path(path).stem
237 | for path in glob.glob(
238 | str(Path(self.data_dir) / "FULL/**/*.wav"), recursive=True
239 | )
240 | if any(technique in Path(path).stem for technique in techniques)
241 | ]
242 |
243 | def load_labels(self):
244 | if self.task_name == "singer_identification":
245 | self.labels = {
246 | Path(path).stem: (Path(path).stem)[:2]
247 | if ((Path(path).stem)[:3] != "m10" and (Path(path).stem)[:3] != "m11")
248 | else (Path(path).stem)[:3]
249 | for path in glob.glob(
250 | str(Path(self.data_dir) / "FULL/**/*.wav"), recursive=True
251 | )
252 | }
253 | elif self.task_name == "technique_identification":
254 | # use the 10 techniques used in the original paper
255 | techniques = [
256 | "vibrato",
257 | "straight",
258 | "belt",
259 | "breathy",
260 | "lip_trill",
261 | "spoken",
262 | "inhaled",
263 | "trill",
264 | "trillo",
265 | "vocal_fry",
266 | ]
267 | labels = {}
268 | for track_id in self.track_ids:
269 | for technique in techniques:
270 | if technique in track_id:
271 | labels[track_id] = technique
272 | self.labels = labels
273 | else:
274 | raise NotImplementedError(
275 | f"Task '{self.task_name}' not available for this dataset."
276 | )
277 |
278 | def load_audio_paths(self):
279 | audio_paths_list = [
280 | path
281 | for path in glob.glob(
282 | str(Path(self.data_dir) / "FULL/**/*.wav"), recursive=True
283 | )
284 | ]
285 | self.audio_paths = {Path(path).stem: path for path in audio_paths_list}
286 |
287 | @Dataset.check_metadata_is_loaded
288 | def get_splits(self):
289 | if self.task_name == "singer_identification":
290 | # no official splits are available, get stratified one
291 | return [super().get_stratified_split()]
292 | elif self.task_name == "technique_identification":
293 | train_singers = [
294 | "f1",
295 | "f3",
296 | "f4",
297 | "f5",
298 | "f6",
299 | "f7",
300 | "f9",
301 | "m1",
302 | "m2",
303 | "m4",
304 | "m6",
305 | "m7",
306 | "m8",
307 | "m9",
308 | ]
309 | # the original paper does not have a validation set, so we steal f9,
310 | # m9, and m11 from train
311 | val_singers = ["f9", "m9", "m11"]
312 | test_singer = ["f2", "f8", "m3", "m5", "m10"]
313 |
314 |             split = {}
315 |             # match the full singer prefix (e.g. "m10"), not just the first two
316 |             # characters, so that singers m10 and m11 are assigned correctly
317 |             split["train"] = [
318 |                 track_id
319 |                 for track_id in self.track_ids
320 |                 if track_id.split("_")[0] in train_singers
321 |             ]
322 |             split["validation"] = [
323 |                 track_id
324 |                 for track_id in self.track_ids
325 |                 if track_id.split("_")[0] in val_singers
326 |             ]
327 |             split["test"] = [
328 |                 track_id
329 |                 for track_id in self.track_ids
330 |                 if track_id.split("_")[0] in test_singer
331 |             ]
332 |             return [split]
325 | else:
326 | raise NotImplementedError(
327 | f"Task '{self.task_name}' not available for this dataset."
328 | )
329 |
--------------------------------------------------------------------------------
/mir_ref/datasets/mirdata_datasets.py:
--------------------------------------------------------------------------------
1 | """Wrapper for mirdata datasets, using relevant functions
2 | and adjusting them based on task requirements."""
3 |
4 | import mirdata
5 |
6 | from mir_ref.datasets.dataset import Dataset
7 |
8 |
9 | class MirdataDataset(Dataset):
10 | def __init__(
11 | self,
12 | name,
13 | dataset_type,
14 | data_dir,
15 | split_type,
16 | task_name,
17 | task_type,
18 | feature_aggregation,
19 | deformations_cfg,
20 | features_cfg,
21 | ):
22 | """Dataset wrapper for mirdata dataset.
23 |
24 | Args:
25 | name (str): Name of the dataset.
26 | dataset_type (str): Type of the dataset ("mirdata", "custom")
27 | task_name (str): Name of the task.
28 | task_type (str): Type of the task.
29 | data_dir (str, optional): Path to the dataset directory.
30 | Defaults to ./data/{name}/.
31 | split_type (str, optional): Whether to use "all" or "single" split.
32 | Defaults to "single".
33 | deformations_cfg (list, optional): List of deformation scenarios.
34 | features_cfg (list, optional): List of embedding models.
35 | """
36 | super().__init__(
37 | name=name,
38 | dataset_type=dataset_type,
39 | data_dir=data_dir,
40 | split_type=split_type,
41 | task_name=task_name,
42 | task_type=task_type,
43 | feature_aggregation=feature_aggregation,
44 | deformations_cfg=deformations_cfg,
45 | features_cfg=features_cfg,
46 | )
47 | # initialize mirdata dataset
48 | self.dataset = mirdata.initialize(
49 | dataset_name=self.name, data_home=self.data_dir
50 | )
51 |
52 | def download(self):
53 | self.dataset.download()
54 | self.dataset.validate(verbose=False)
55 |
56 | # try to load metadata again
57 | self.load_metadata()
58 |
59 | def download_metadata(self):
60 | try:
61 | self.dataset.download(partial_download=["metadata"])
62 | except ValueError:
63 | self.dataset.download(partial_download=["annotations"])
64 | self.dataset.validate(verbose=False)
65 |
66 | # try to load metadata again
67 | self.load_metadata()
68 |
69 | def preprocess(self):
70 | """Modifications to the downloaded content of the dataset."""
71 | return
72 |
 73 |     def load_track_ids(self):
 74 |         if self.task_name == "pitch_class_estimation":
 75 |             if self.name == "tinysol":
 76 |                 # only keep track_ids with single pitch annotations
 77 |                 # (build a new list rather than removing while iterating)
 78 |                 track_ids = [
 79 |                     track_id
 80 |                     for track_id in self.dataset.track_ids
 81 |                     if len(self.dataset.track(track_id).pitch) == 1
 82 |                 ]
 83 |         elif self.task_name == "pitch_register_estimation":
 84 |             if self.name == "tinysol":
 85 |                 # only keep track_ids with single pitch annotations
 86 |                 track_ids = [
 87 |                     track_id
 88 |                     for track_id in self.dataset.track_ids
 89 |                     if len(self.dataset.track(track_id).pitch) == 1
 90 |                 ]
88 | elif self.task_name == "key_estimation":
89 | if self.name == "beatport_key":
90 | # only keep track_ids with single key annotations
91 | track_ids = [
92 | track_id
93 | for track_id in self.dataset.track_ids
94 | if len(self.dataset.track(track_id).key) == 1
95 | and len(self.dataset.track(track_id).key[0].split(" ")) == 2
96 | and "other" not in self.dataset.track(track_id).key[0]
97 | ]
98 | else:
99 | track_ids = self.dataset.track_ids
100 |
101 | self.track_ids = track_ids
102 |
103 | def load_labels(self):
104 | labels = {}
105 | for track_id in self.track_ids:
106 | if (
107 | self.task_name == "instrument_recognition"
108 | or self.task_name == "instrument_classification"
109 | ):
110 | if self.name == "tinysol":
111 | labels[track_id] = self.dataset.track(track_id).instrument_full
112 | elif self.task_name == "tagging":
113 | labels[track_id] = self.dataset.track(track_id).tags
114 | elif self.task_name == "pitch_class_classification":
115 | pitch = self.dataset.track(track_id).pitch
116 | labels[track_id] = "".join([c for c in pitch if not c.isdigit()])
117 | elif self.task_name == "pitch_register_classification":
118 | pitch = self.dataset.track(track_id).pitch
119 | labels[track_id] = "".join([c for c in pitch if c.isdigit()])
120 | elif self.task_name == "key_estimation":
121 | # map enharmonic keys, always use sharps
122 | enharm_map = {
123 | "Db": "C#",
124 | "Eb": "D#",
125 | "Gb": "F#",
126 | "Ab": "G#",
127 | "Bb": "A#",
128 | "F#_": "F#", # ok yes, that's an annotation fix
129 | }
130 | key = self.dataset.track(track_id).key[0].strip()
131 | for pitch_class in enharm_map.keys():
132 | if pitch_class in key:
133 | key = key.split(" ")
134 | key[0] = enharm_map[key[0]]
135 | key = " ".join(key)
136 | break
137 | labels[track_id] = key
138 |
139 | self.labels = labels
140 |
141 | def load_audio_paths(self):
142 | self.audio_paths = {
143 | t_id: self.dataset.track(t_id).audio_path for t_id in self.track_ids
144 | }
145 |
146 | # @Dataset.check_metadata_is_loaded
147 | def get_splits(self, seed=42):
148 |         if self.split_type not in ["all", "single"] and not isinstance(
149 |             self.split_type, list
150 |         ):
151 |             raise ValueError(
152 |                 "Split type must be 'all', 'single', or "
153 |                 + "a list of 3 floats adding up to 1."
154 |             )
155 | # !!!validate metadata exists, and download if not
156 | # tags can be in metadata or annotations. Try metadata first.
157 | # try:
158 | # self.dataset.download(partial_download=["metadata"])
159 | # except ValueError:
160 | # self.dataset.download(partial_download=["annotations"])
161 | # self.dataset.validate(verbose=False)
162 |
163 | splits = []
164 | if self.split_type in ["all", "single"]:
165 | # check for up to 50 splits
166 | for i in range(50):
167 | try:
168 | split = self.dataset.get_track_splits(i)
169 | except TypeError:
170 | # if it fails at i=0, there are either no splits or only one split
171 | if i == 0:
172 | try:
173 | split = self.dataset.get_track_splits()
174 | except NotImplementedError:
175 | # no splits are available, so we need to generate them
176 | print(
177 | "No official splits found, generating random, stratified "
178 | + f"ones with seed {seed}."
179 | )
180 | splits.append(self.get_stratified_split(seed=seed))
181 | break
182 | # if it fails at i>0, there are no more splits
183 | else:
184 | break
185 | # we need to determine if each split returned is in the format
186 | # ["train", "validation", "test"], or whether they are actually
187 | # just folds keyed by integer index.
188 | try:
189 | _, _, _ = split["train"], split["validation"], split["test"]
190 | splits.append(split)
191 | except KeyError:
192 | # assume they're folds
193 | n_folds = len(split.keys())
194 | if n_folds >= 3:
195 | # get splits for cross validation, assigning one for
196 | # validation and test, and the rest for training
197 | for fold_idx in range(n_folds):
198 | fold = {}
199 | available_folds = list(split.keys())
200 |
201 | fold["validation"] = split[fold_idx]
202 | available_folds.remove(fold_idx)
203 |
204 | fold["test"] = split[(fold_idx + 1) % n_folds]
205 | available_folds.remove((fold_idx + 1) % n_folds)
206 |
207 | # the rest is train, get lists and flatten
208 | fold["train"] = sum(
209 | [split[af] for af in available_folds], []
210 | )
211 |
212 | splits.append(fold)
213 | else:
214 | print(
215 | "No official splits found, generating random, stratified "
216 | + f"ones with seed {seed}."
217 | )
218 | splits.append(self.get_stratified_split(seed=seed))
219 | if i == 0:
220 | break
221 |
222 | if self.split_type == "single":
223 | splits = [splits[0]]
224 |
225 | else:
226 | # else split_type is a list of sizes, meaning get stratified splits
227 | splits.append(self.get_stratified_split(sizes=self.split_type, seed=seed))
228 |
229 | return splits
230 |
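For orientation, here is a minimal sketch of the structure get_splits() returns, whether the splits come from mirdata folds or from get_stratified_split (the track IDs below are hypothetical):

    # Hypothetical illustration of the return value of get_splits():
    # a list with one dict per split/fold, mapping partition names to track IDs.
    splits = [
        {
            "train": ["track_0001", "track_0002", "track_0003"],
            "validation": ["track_0004"],
            "test": ["track_0005"],
        },
    ]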
--------------------------------------------------------------------------------
/mir_ref/deform.py:
--------------------------------------------------------------------------------
1 | """Apply deformations to audio files.
2 | """
3 |
4 | from colorama import Fore, Style
5 |
6 | from mir_ref.datasets.dataset import get_dataset
7 | from mir_ref.deformations import generate_deformations
8 | from mir_ref.utils import load_config
9 |
10 |
11 | def deform(cfg_path, n_jobs):
12 | cfg = load_config(cfg_path)
13 |
14 | # iterate through every dataset of every experiment to generate deformations
15 | print(Fore.GREEN + "# Generating deformations...", Style.RESET_ALL)
16 | for exp_cfg in cfg["experiments"]:
17 | for dataset_cfg in exp_cfg["datasets"]:
18 | # generate deformations
19 | print(Fore.GREEN + f"## Dataset: {dataset_cfg['name']}", Style.RESET_ALL)
20 |
21 | dataset = get_dataset(
22 | dataset_cfg=dataset_cfg,
23 | task_cfg=exp_cfg["task"],
24 | features_cfg=exp_cfg["features"],
25 | )
26 | dataset.download()
27 | dataset.preprocess()
28 |
29 | generate_deformations(
30 | dataset,
31 | n_jobs=n_jobs,
32 | )
33 |
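As a usage sketch, deform() can be driven directly from a config file such as the repository's configs/example.yml (assuming that file defines the experiments, datasets, and deformation scenarios):

    # Minimal usage sketch; configs/example.yml is the sample config shipped with the repo.
    from mir_ref.deform import deform

    deform(cfg_path="configs/example.yml", n_jobs=4)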
--------------------------------------------------------------------------------
/mir_ref/deformations.py:
--------------------------------------------------------------------------------
1 | """Implementations of various audio deformations used for
2 | robustness evaluation.
3 | """
4 |
5 | import os
6 |
7 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
8 |
9 | from pathlib import Path
10 |
11 | import librosa
12 | import soundfile as sf
13 | from joblib import Parallel, delayed
14 | from tqdm import tqdm
15 |
16 |
17 | def deform_audio(
18 | track_id: str,
19 | dataset,
20 | ):
21 | """For each deformation scenario, deform and save single audio file."""
22 |
23 | # load audio
24 | y, sr = librosa.load(dataset.audio_paths[track_id], sr=None)
25 |
26 | for scenario_idx, scenario in enumerate(dataset.deformations_cfg):
27 | y_d = y.copy()
28 | # we're looping like this so that we retain the order of
29 | # deformations provided by the user
30 | # debatable whether this should be the default behavior
31 | for deformation in scenario:
32 | if deformation["type"] == "AddGaussianSNR":
33 | from audiomentations import AddGaussianSNR
34 |
35 | transform = AddGaussianSNR(**deformation["params"])
36 | elif deformation["type"] == "ApplyImpulseResponse":
37 | from audiomentations import ApplyImpulseResponse
38 |
39 | transform = ApplyImpulseResponse(**deformation["params"])
40 | elif deformation["type"] == "ClippingDistortion":
41 | from audiomentations import ClippingDistortion
42 |
43 | transform = ClippingDistortion(**deformation["params"])
44 | elif deformation["type"] == "Gain":
45 | from audiomentations import Gain
46 |
47 | transform = Gain(**deformation["params"])
48 | elif deformation["type"] == "Mp3Compression":
49 | from audiomentations import Mp3Compression
50 |
51 | transform = Mp3Compression(**deformation["params"])
52 | elif deformation["type"] == "PitchShift":
53 | from audiomentations import PitchShift
54 |
55 | transform = PitchShift(**deformation["params"])
56 | else:
57 | raise ValueError(f"Deformation {deformation['type']} not implemented.")
58 |
59 | # apply deformation
60 | y_d = transform(y_d, sr)
61 | del transform
62 |
63 | # save deformed audio
64 | output_filepath = dataset.get_deformed_audio_path(
65 | track_id=track_id, deform_idx=scenario_idx
66 | )
67 |
68 | # create parent dirs if they don't exist, and write audio
69 | Path(output_filepath).parent.mkdir(parents=True, exist_ok=True)
70 | sf.write(output_filepath, y_d, sr)
71 |
72 |
73 | def deform_audio_essentia(
74 | track_id: str,
75 | dataset,
76 | ):
77 | """Some of the essentia ffmpeg calls seem to have deprecated parameters,
78 | something related to the time_base used. I'm temporarily moving the
79 | essentia loading and writing here.
80 | """
81 | # move to imports if used
82 | from essentia.standard import AudioLoader, AudioWriter
83 |
84 | input_filepath = dataset.audio_paths[track_id]
85 |
86 | # load audio
87 | y, sr, channels, _, bit_rate, _ = AudioLoader(filename=input_filepath)()
88 | # assert audio is mono or stereo
89 | assert channels <= 2
90 |     # some bookkeeping for constructing the output path later
91 | file_dir = Path(input_filepath).parent
92 | file_stem = str(Path(input_filepath).stem)
93 | file_suffix = str(Path(input_filepath).suffix)
94 |
95 | for scenario_idx, scenario in enumerate(dataset.deformations_cfg):
96 | y_d = y.copy()
97 | # we're looping like this so that we retain the order of
98 | # deformations provided by the user
99 | # debatable whether this should be the default behavior
100 | for deformation in scenario:
101 | if deformation["type"] == "AddGaussianSNR":
102 | from audiomentations import AddGaussianSNR
103 |
104 | transform = AddGaussianSNR(**deformation["params"])
105 | elif deformation["type"] == "ApplyImpulseResponse":
106 | from audiomentations import ApplyImpulseResponse
107 |
108 | transform = ApplyImpulseResponse(**deformation["params"])
109 | elif deformation["type"] == "ClippingDistortion":
110 | from audiomentations import ClippingDistortion
111 |
112 | transform = ClippingDistortion(**deformation["params"])
113 | elif deformation["type"] == "Gain":
114 | from audiomentations import Gain
115 |
116 |                 transform = Gain(**deformation["params"])
117 | elif deformation["type"] == "Mp3Compression":
118 | from audiomentations import Mp3Compression
119 |
120 | transform = Mp3Compression(**deformation["params"])
121 | elif deformation["type"] == "PitchShift":
122 | from audiomentations import PitchShift
123 |
124 | transform = PitchShift(**deformation["params"])
125 | else:
126 | raise ValueError(f"Deformation {deformation['type']} not implemented.")
127 |
128 | # apply deformation
129 | y_d = transform(y_d, sr)
130 |
131 | # save deformed audio
132 | output_filepath = dataset.get_deformed_audio_path(
133 | track_id=track_id, deform_idx=scenario_idx
134 | )
135 |
136 | # create parent dirs if they don't exist, and write audio
137 | Path(output_filepath).parent.mkdir(parents=True, exist_ok=True)
138 |
139 | # special case for lossy compression formats in which bitrate needs to be specified
140 | lossy_format_bitrates = [
141 | 32,
142 | 40,
143 | 48,
144 | 56,
145 | 64,
146 | 80,
147 | 96,
148 | 112,
149 | 128,
150 | 144,
151 | 160,
152 | 192,
153 | 224,
154 | 256,
155 | 320,
156 | ]
157 | if (
158 | file_suffix == ".mp3" or file_suffix == ".ogg"
159 | ) and bit_rate in lossy_format_bitrates:
160 | AudioWriter(
161 | filename=output_filepath,
162 | format=file_suffix[1:],
163 | sampleRate=sr,
164 | bitrate=bit_rate,
165 | )(y_d)
166 | else:
167 | AudioWriter(
168 | filename=output_filepath, format=file_suffix[1:], sampleRate=sr
169 | )(y_d)
170 |
171 |
172 | def generate_deformations(
173 | dataset,
174 | n_jobs: int = 1,
175 | ):
176 | """Generate deformed audio and save."""
177 |
178 | # check if there are no deformations specified in the experiment
179 | if not dataset.deformations_cfg:
180 | print(
181 | f"No deformations specified for '{dataset.name}'. Skipping deformation generation."
182 | )
183 | return
184 |
185 | # create output dir for deformed audio if it doesn't exist
186 | (Path(dataset.data_dir) / "audio_deformed").mkdir(parents=True, exist_ok=True)
187 |
188 | if n_jobs == 1:
189 | for track_id in tqdm(dataset.track_ids):
190 | deform_audio(track_id, dataset)
191 | else:
192 | # this passes around the dataset object, which is not ideal for performance
193 | Parallel(n_jobs=n_jobs, verbose=0)(
194 | delayed(deform_audio)(track_id, dataset)
195 | for track_id in tqdm(dataset.track_ids)
196 | )
197 |
198 | print("Deformed audio generated and saved.")
199 |
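For reference, a sketch of the deformations_cfg structure that deform_audio iterates over; the exact parameter names belong to audiomentations and depend on the installed version, so treat them as illustrative assumptions rather than the repository's own config schema:

    # Hypothetical deformations_cfg: a list of scenarios, each an ordered list of
    # deformations applied in sequence to every track.
    deformations_cfg = [
        # scenario 0: additive noise at a fixed SNR
        [{"type": "AddGaussianSNR", "params": {"min_snr_db": 5.0, "max_snr_db": 5.0, "p": 1.0}}],
        # scenario 1: gain reduction followed by MP3 compression
        [
            {"type": "Gain", "params": {"min_gain_db": -6.0, "max_gain_db": -6.0, "p": 1.0}},
            {"type": "Mp3Compression", "params": {"min_bitrate": 64, "max_bitrate": 64, "p": 1.0}},
        ],
    ]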
--------------------------------------------------------------------------------
/mir_ref/evaluate.py:
--------------------------------------------------------------------------------
1 | """Evaluation of downstream models
2 | """
3 |
4 | import json
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | from colorama import Fore, Style
9 | from sklearn.metrics import (
10 | average_precision_score,
11 | classification_report,
12 | roc_auc_score,
13 | )
14 | from sklearn.model_selection import ParameterGrid
15 |
16 | from mir_ref.dataloaders import DataGenerator
17 | from mir_ref.datasets.dataset import get_dataset
18 | from mir_ref.probes.probe_builder import get_model
19 | from mir_ref.utils import load_config
20 |
21 |
22 | def evaluate(cfg_path, run_id=None):
23 |     """Evaluate every trained probe defined in the config for the given run."""
24 | cfg = load_config(cfg_path)
25 |
26 | if run_id is None:
27 | print(
28 |             "No run ID has been specified. Attempting to load the latest training run,",
29 |             "but this will fail if no runs have timestamp IDs.",
30 | )
31 | # attempt to load latest experiment by sorting dirs in logs
32 | dirs = [d for d in Path("./logs").iterdir() if d.is_dir()]
33 | # only keep dirs that has numeric and - in the name
34 |     # only keep dirs whose names contain only digits and '-'
35 | if not dirs:
36 | raise ValueError(
37 | "No run ID has been specified and no timestamped runs have been found."
38 | )
39 | run_id = sorted(dirs)[-1].name
40 |
41 | for exp_cfg in cfg["experiments"]:
42 | run_params = {
43 | "dataset_cfg": exp_cfg["datasets"],
44 | "feature": exp_cfg["features"],
45 | "model_cfg": exp_cfg["probes"],
46 | }
47 |
48 | # create grid from parameters
49 | grid = ParameterGrid(run_params)
50 |
51 | # !!!temporary, assumes single dataset per task
52 | dataset = get_dataset(
53 | dataset_cfg=exp_cfg["datasets"][0],
54 | task_cfg=exp_cfg["task"],
55 | features_cfg=exp_cfg["features"],
56 | )
57 | dataset.download()
58 |
59 | for params in grid:
60 | # get index of downstream model for naming-logging
61 | model_idx = exp_cfg["probes"].index(params["model_cfg"])
62 | # get dataset object
63 | # dataset = get_dataset(
64 | # dataset_cfg=params["dataset_cfg"],
65 | # task_cfg=exp_cfg["task"],
66 | # deformations_cfg=exp_cfg["deformations"],
67 | # features_cfg=exp_cfg["features"],
68 | # )
69 | # dataset.download()
70 | # run task for every split
71 | for split_idx in range(len(dataset.get_splits())):
72 | print(
73 | Fore.GREEN
74 | + f"Task: {dataset.task_name}\n"
75 | + f"└── Dataset: {dataset.name}\n"
76 | + f" └── Embeddings: {params['feature']}\n"
77 | + f" └── Model: {model_idx}\n"
78 | + f" └── Split: {split_idx}",
79 | Style.RESET_ALL,
80 | )
81 | evaluate_probe(
82 | run_id=run_id,
83 | dataset=dataset,
84 | model_cfg=params["model_cfg"],
85 | model_idx=model_idx,
86 | feature=params["feature"],
87 | split_idx=split_idx,
88 | )
89 |
90 |
91 | def evaluate_probe(run_id, dataset, model_cfg, model_idx, feature, split_idx):
92 | """Evaluate downstream models, including in cases
93 | with deformations.
94 |
95 | Args:
96 | run_id (str): ID of the current run, defaults to timestamp.
97 | dataset (Dataset): Dataset object.
98 | model_cfg (dict): Downstream model config.
99 | model_idx (int): Index of the downstream model in the list of models.
100 | split_idx (int): Index of the split in the list of splits.
101 | """
102 |
103 | split = dataset.get_splits()[split_idx]
104 | n_classes = len(dataset.encoded_labels[dataset.track_ids[0]])
105 |
106 | if model_cfg["emb_shape"] == "infer":
107 | # get embedding shape from the first embedding
108 | emb_shape = np.load(
109 | dataset.get_embedding_path(
110 | feature=feature,
111 | track_id=split["train"][0],
112 | )
113 | ).shape
114 | elif isinstance(model_cfg["emb_shape"], int):
115 | emb_shape = model_cfg["emb_shape"]
116 | elif isinstance(model_cfg["emb_shape"], str):
117 | raise ValueError(f"{model_cfg['emb_shape']} not implemented.")
118 |
119 | run_dir = (
120 | Path(run_id) / dataset.task_name / f"{dataset.name}_{feature}_model-{model_idx}"
121 | )
122 |
123 | # get all dirs starting with run_dir, sort them, and get latest
124 | run_dirs = [
125 | d
126 | for d in (Path("./logs") / run_dir.parent).iterdir()
127 | if str(d).startswith(str(Path("logs") / run_dir))
128 | ]
129 | # remove ./logs from the start of dirs
130 | run_dirs = [d.relative_to("logs") for d in run_dirs]
131 |
132 | new_run_dir = sorted(run_dirs)[-1]
133 | if run_dir != new_run_dir:
134 | import warnings
135 |
136 | warnings.warn(
137 | f"Multiple runs for '{run_dir}' found. Loading latest run '{new_run_dir}'",
138 | stacklevel=2,
139 | )
140 | run_dir = new_run_dir
141 |
142 | # load model
143 | model = get_model(model_cfg=model_cfg, dim=emb_shape, n_classes=n_classes)
144 | model.load_weights(filepath=Path("./logs") / run_dir / "weights.h5")
145 |
146 | # load data
147 | test_gen = DataGenerator(
148 | ids_list=split["test"],
149 | labels_dict=dataset.encoded_labels,
150 | paths_dict={
151 | t_id: dataset.get_embedding_path(feature=feature, track_id=t_id)
152 | for t_id in split["test"]
153 | },
154 | batch_size=model_cfg["batch_size"],
155 | dim=emb_shape,
156 | n_classes=n_classes,
157 | shuffle=False,
158 | )
159 |
160 | pred = model.predict(x=test_gen, batch_size=model_cfg["batch_size"], verbose=1)
161 |
162 | if dataset.task_type == "multiclass_classification":
163 | # get one-hot encoded vectors where 1 is the argmax in each case
164 | y_pred = [np.eye(len(p))[np.argmax(p)] for p in pred]
165 | # predictions might be shorter because of partial batch drop
166 | y_true = [dataset.encoded_labels[t_id] for t_id in split["test"]][: len(y_pred)]
167 | # doing this twice because it has nice formatting, but need the json after
168 | print(classification_report(y_true=y_true, y_pred=y_pred))
169 | metrics = classification_report(y_true=y_true, y_pred=y_pred, output_dict=True)
170 | if dataset.task_name == "key_estimation":
171 | # calculate weighted accuracy
172 | from mir_ref.metrics import key_detection_weighted_accuracy
173 |
174 | metrics["weighted_accuracy"] = key_detection_weighted_accuracy(
175 | y_true=[dataset.decode_label(y) for y in y_true],
176 | y_pred=[dataset.decode_label(y) for y in y_pred],
177 | )
178 | print("Weighted accuracy:", metrics["weighted_accuracy"])
179 | elif dataset.task_type == "multilabel_classification":
180 | y_true = [dataset.encoded_labels[t_id] for t_id in split["test"]][: len(pred)]
181 | y_pred = pred
182 | metrics = {
183 | "roc_auc": roc_auc_score(y_true, y_pred),
184 | "average_precision": average_precision_score(y_true, y_pred),
185 | }
186 | print(metrics)
187 |
188 |     with open(Path("./logs") / run_dir / "clean_metrics.json", "w+") as f:
189 | json.dump(metrics, f, indent=4)
190 |
191 | # do the same but with the deformed audio as the test set
192 | if dataset.deformations_cfg:
193 | for scenario_idx, scenario_cfg in enumerate(dataset.deformations_cfg):
194 | print(
195 | Fore.GREEN
196 | + "# Scenario: "
197 | + f"{scenario_idx+1}/{len(dataset.deformations_cfg)} "
198 | + f"{[cfg['type'] for cfg in scenario_cfg]}",
199 | Style.RESET_ALL,
200 | )
201 |
202 | # load data
203 | test_gen = DataGenerator(
204 | ids_list=split["test"],
205 | labels_dict=dataset.encoded_labels,
206 | paths_dict={
207 | t_id: dataset.get_deformed_embedding_path(
208 | feature=feature, track_id=t_id, deform_idx=scenario_idx
209 | )
210 | for t_id in split["test"]
211 | },
212 | batch_size=model_cfg["batch_size"],
213 | dim=emb_shape,
214 | n_classes=n_classes,
215 | shuffle=False,
216 | )
217 |
218 | pred = model.predict(
219 | x=test_gen,
220 | batch_size=model_cfg["batch_size"],
221 | verbose=1,
222 | )
223 |
224 | if dataset.task_type == "multiclass_classification":
225 | # get one-hot encoded vectors where 1 is the argmax in each case
226 | y_pred = [np.eye(len(p))[np.argmax(p)] for p in pred]
227 | # predictions might be shorter because of partial batch drop
228 | y_true = [dataset.encoded_labels[t_id] for t_id in split["test"]][
229 | : len(y_pred)
230 | ]
231 | metrics = classification_report(
232 | y_true=y_true, y_pred=y_pred, output_dict=True
233 | )
234 | print(classification_report(y_true=y_true, y_pred=y_pred))
235 | if dataset.task_name == "key_estimation":
236 | # calculate weighted accuracy
237 | from mir_ref.metrics import key_detection_weighted_accuracy
238 |
239 | metrics["weighted_accuracy"] = key_detection_weighted_accuracy(
240 | y_true=[dataset.decode_label(y) for y in y_true],
241 | y_pred=[dataset.decode_label(y) for y in y_pred],
242 | )
243 | print("Weighted accuracy:", metrics["weighted_accuracy"])
244 |
245 | elif dataset.task_type == "multilabel_classification":
246 | y_true = [dataset.encoded_labels[t_id] for t_id in split["test"]][
247 | : len(pred)
248 | ]
249 | y_pred = pred
250 | metrics = {
251 | "roc_auc": roc_auc_score(y_true, y_pred),
252 | "average_precision": average_precision_score(y_true, y_pred),
253 | }
254 | print(metrics)
255 |
256 | # save metrics
257 | with open(
258 |                 Path("./logs") / run_dir / f"deform_{scenario_idx}_metrics.json",
259 | "w+",
260 | ) as f:
261 | json.dump(metrics, f, indent=4)
262 |
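As a small worked example (purely illustrative) of the np.eye trick used above to turn class probabilities into a one-hot prediction:

    import numpy as np

    p = np.array([0.1, 0.7, 0.15, 0.05])   # probabilities over 4 classes
    one_hot = np.eye(len(p))[np.argmax(p)]  # row 1 of the 4x4 identity matrix
    # one_hot -> array([0., 1., 0., 0.])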
--------------------------------------------------------------------------------
/mir_ref/extract.py:
--------------------------------------------------------------------------------
1 | """Generate embeddings from the audio files.
2 | """
3 |
4 |
5 | from colorama import Fore, Style
6 |
7 | from mir_ref.datasets.dataset import get_dataset
8 | from mir_ref.features.feature_extraction import generate_embeddings
9 | from mir_ref.utils import load_config
10 |
11 |
12 | def generate(
13 | cfg_path,
14 | skip_clean=False,
15 | skip_deformed=False,
16 | no_overwrite=False,
17 | deform_list=None,
18 | ):
19 | cfg = load_config(cfg_path)
20 |
21 | for exp_cfg in cfg["experiments"]:
22 | # iterate through every dataset to generate embeddings
23 | print(Fore.GREEN + "# Extracting features...", Style.RESET_ALL)
24 | for dataset_cfg in exp_cfg["datasets"]:
25 | print(
26 | Fore.GREEN + f"## Dataset: {dataset_cfg['name']}",
27 | Style.RESET_ALL,
28 | )
29 | for model_name in exp_cfg["features"]:
30 | print(Fore.GREEN + f"### Feature: {model_name}", Style.RESET_ALL)
31 |
32 | dataset = get_dataset(
33 | dataset_cfg=dataset_cfg,
34 | task_cfg=exp_cfg["task"],
35 | features_cfg=exp_cfg["features"],
36 | )
37 | dataset.download()
38 | dataset.preprocess()
39 |
40 | generate_embeddings(
41 | dataset,
42 | model_name=model_name,
43 | skip_clean=skip_clean,
44 | skip_deformed=skip_deformed,
45 | no_overwrite=no_overwrite,
46 | deform_list=deform_list,
47 | )
48 |
--------------------------------------------------------------------------------
/mir_ref/features/custom_features.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chrispla/mir_ref/691ae42815db6359ef66c1e175be55be42bbc340/mir_ref/features/custom_features.py
--------------------------------------------------------------------------------
/mir_ref/features/feature_extraction.py:
--------------------------------------------------------------------------------
1 | """Retrieve correct feature extractor and extract features.
2 | """
3 |
4 | import os
5 |
6 | import wget
7 |
8 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
9 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
10 |
11 | import numpy as np
12 | from tqdm import tqdm
13 |
14 |
15 | def check_model_exists(model_path: str):
16 | if not os.path.exists(model_path):
17 | raise Exception(
18 |             f"Model not found at {model_path}. "
19 |             + "Please download it from https://essentia.upf.edu/models.html "
20 |             + "and place it in the mir_ref/features/models/weights directory."
21 | )
22 |
23 |
24 | def get_input_output_paths(
25 | dataset,
26 | model_name,
27 | skip_clean=False,
28 | skip_deformed=False,
29 | no_overwrite=False,
30 | deform_list=None,
31 | ):
32 | """Get a list of input audio paths (including deformed audio)
33 | and a list of the corresponding output paths for the embeddings
34 | for a given dataset and embedding model. Don't include embeddings
35 | that have already been computed.
36 |
37 | Args:
38 | dataset: mir_ref Dataset object.
39 | model_name: Name of the embedding model.
40 | skip_clean: Whether to skip embedding generation for clean audio.
41 | skip_deformed: Whether to skip embedding generation for deformed audio.
42 | no_overwrite: Whether to skip embedding generation for existing embeddings.
43 |         deform_list: List of deformation scenario indices to include. If None,
44 | include all deformation scenarios.
45 | """
46 |
47 | # get audio paths for clean and deformed audio
48 | if not skip_clean:
49 | audio_paths = [dataset.audio_paths[track_id] for track_id in dataset.track_ids]
50 | else:
51 | audio_paths = []
52 | if not skip_deformed and dataset.deformations_cfg is not None:
53 | if deform_list is None:
54 | deform_list = range(len(dataset.deformations_cfg))
55 | audio_paths += [
56 | dataset.get_deformed_audio_path(track_id=track_id, deform_idx=deform_idx)
57 | for deform_idx in deform_list
58 | for track_id in dataset.track_ids
59 | ]
60 |
61 | # get output embedding paths for the same tracks
62 | if not skip_clean:
63 | emb_paths = [
64 | dataset.get_embedding_path(track_id, model_name)
65 | for track_id in dataset.track_ids
66 | ]
67 | else:
68 | emb_paths = []
69 | if not skip_deformed and dataset.deformations_cfg is not None:
70 | if deform_list is None:
71 | deform_list = range(len(dataset.deformations_cfg))
72 | emb_paths += [
73 | dataset.get_deformed_embedding_path(
74 | track_id=track_id, feature=model_name, deform_idx=deform_idx
75 | )
76 | for deform_idx in deform_list
77 | for track_id in dataset.track_ids
78 | ]
79 |
80 | if no_overwrite:
81 | # remove audio paths and the respective embeddings paths if embedding
82 | # already exists
83 | audio_paths, emb_paths = zip(
84 | *[
85 | (audio_path, emb_path)
86 | for audio_path, emb_path in zip(audio_paths, emb_paths)
87 | if not os.path.exists(emb_path)
88 | ]
89 | )
90 |
91 | return audio_paths, emb_paths
92 |
93 |
94 | def compute_and_save_embeddings(
95 | model: object,
96 | model_name: str,
97 | aggregation: str,
98 | dataset,
99 | sample_rate: int,
100 | resample_quality=1,
101 | skip_clean=False,
102 | skip_deformed=False,
103 | no_overwrite=False,
104 | deform_list=None,
105 | ):
106 | """Compute embeddings given model object and
107 | audio path list.
108 | """
109 |
110 | from essentia.standard import MonoLoader
111 |
112 | monoloader = MonoLoader(sampleRate=sample_rate, resampleQuality=resample_quality)
113 |
114 | audio_paths, emb_paths = get_input_output_paths(
115 | dataset=dataset,
116 | model_name=model_name,
117 | skip_clean=skip_clean,
118 | skip_deformed=skip_deformed,
119 | no_overwrite=no_overwrite,
120 | deform_list=deform_list,
121 | )
122 |
123 | for input_path, output_path in tqdm(
124 | zip(audio_paths, emb_paths), total=len(audio_paths)
125 | ):
126 | # Load audio
127 | monoloader.configure(filename=input_path)
128 | audio = monoloader()
129 |
130 | # Compute embeddings
131 | embedding = model(audio)
132 |
133 | if aggregation == "mean":
134 | embedding = np.mean(embedding, axis=0)
135 | elif aggregation is None:
136 | raise Exception(f"Aggregation method '{aggregation}' not implemented.")
137 | else:
138 | raise Exception(f"Aggregation method '{aggregation}' not implemented.")
139 |
140 | # Save embeddings
141 | if not os.path.exists(os.path.dirname(output_path)):
142 | os.makedirs(os.path.dirname(output_path))
143 | with open(output_path, "wb") as f:
144 | np.save(f, embedding)
145 |
146 |
147 | def generate_embeddings(
148 | dataset,
149 | model_name: str,
150 | skip_clean=False,
151 | skip_deformed=False,
152 | no_overwrite=False,
153 | deform_list=None,
154 | transcode_and_load=False,
155 | ):
156 | """Generate embeddings from a list of audio files.
157 |
158 |     Args:
159 |         dataset: mir_ref Dataset object.
160 |         model_name (str): Name of the embedding model.
161 |         skip_clean (bool, optional): Whether to skip embedding generation for clean
162 |             audio. Defaults to False.
163 |         skip_deformed (bool, optional): Whether to skip embedding generation for
164 |             deformed audio. Defaults to False.
165 |         no_overwrite (bool, optional): Whether to skip embedding generation for
166 |             existing embeddings. Defaults to False.
167 |         deform_list (list, optional): List of deformation scenario indices to include.
168 |             If None, include all deformation scenarios.
169 |         transcode_and_load (bool, optional): Whether to transcode audio with sox before loading. Defaults to False.
170 | """
171 |
172 | aggregation = dataset.feature_aggregation
173 | # Load embedding model. Call essentia implementation if available,
174 | # otherwise custom implementation.
175 |
176 | if model_name == "vggish-audioset":
177 | model_path = "mir_ref/features/models/weights/audioset-vggish-3.pb"
178 | if not os.path.exists(model_path):
179 | print(f"Downloading {model_name} to mir_ref/features/models/weights...")
180 | wget.download(
181 | "https://essentia.upf.edu/models/feature-extractors/vggish/audioset-vggish-3.pb",
182 | out="mir_ref/features/models/weights/",
183 | )
184 | check_model_exists(model_path)
185 |
186 | from essentia.standard import MonoLoader, TensorflowPredictVGGish
187 |
188 | model = TensorflowPredictVGGish(
189 | graphFilename=model_path, output="model/vggish/embeddings"
190 | )
191 |
192 | compute_and_save_embeddings(
193 | model=model,
194 | model_name=model_name,
195 | aggregation=aggregation,
196 | dataset=dataset,
197 | sample_rate=16000,
198 | skip_clean=skip_clean,
199 | skip_deformed=skip_deformed,
200 | no_overwrite=no_overwrite,
201 | deform_list=deform_list,
202 | )
203 |
204 | elif model_name == "effnet-discogs":
205 | model_path = (
206 | "mir_ref/features/models/weights/discogs_artist_embeddings-effnet-bs64-1.pb"
207 | )
208 | if not os.path.exists(model_path):
209 | print(f"Downloading {model_name} to mir_ref/features/models/weights...")
210 | wget.download(
211 | "https://essentia.upf.edu/models/feature-extractors/discogs-effnet/discogs_artist_embeddings-effnet-bs64-1.pb",
212 | out="mir_ref/features/models/weights/",
213 | )
214 | check_model_exists(model_path)
215 |
216 | from essentia.standard import MonoLoader, TensorflowPredictEffnetDiscogs
217 |
218 | model = TensorflowPredictEffnetDiscogs(
219 | graphFilename=model_path, output="PartitionedCall:1"
220 | )
221 |
222 | compute_and_save_embeddings(
223 | model=model,
224 | model_name=model_name,
225 | aggregation=aggregation,
226 | dataset=dataset,
227 | sample_rate=16000,
228 | skip_clean=skip_clean,
229 | skip_deformed=skip_deformed,
230 | no_overwrite=no_overwrite,
231 | deform_list=deform_list,
232 | )
233 |
234 | elif model_name == "msd-musicnn":
235 | model_path = "mir_ref/features/models/weights/msd-musicnn-1.pb"
236 |
237 | if not os.path.exists(model_path):
238 | print(f"Downloading {model_name} to mir_ref/features/models/weights...")
239 | wget.download(
240 | "https://essentia.upf.edu/models/feature-extractors/musicnn/msd-musicnn-1.pb",
241 | out="mir_ref/features/models/weights/",
242 | )
243 |
244 | from essentia.standard import MonoLoader, TensorflowPredictMusiCNN
245 |
246 | model = TensorflowPredictMusiCNN(
247 | graphFilename=model_path, output="model/dense/BiasAdd"
248 | )
249 |
250 | compute_and_save_embeddings(
251 | model=model,
252 | model_name=model_name,
253 | aggregation=aggregation,
254 | dataset=dataset,
255 | sample_rate=16000,
256 | resample_quality=4,
257 | skip_clean=skip_clean,
258 | skip_deformed=skip_deformed,
259 | no_overwrite=no_overwrite,
260 | deform_list=deform_list,
261 | )
262 |
263 | elif model_name == "maest":
264 | model_path = "mir_ref/features/models/weights/discogs-maest-30s-pw-1.pb"
265 |
266 | if not os.path.exists(model_path):
267 | print(f"Downloading {model_name} to mir_ref/features/models/weights...")
268 | wget.download(
269 | "https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-30s-pw-1.pb",
270 | out="mir_ref/features/models/weights/",
271 | )
272 |
273 | check_model_exists(model_path)
274 |
275 | from essentia.standard import MonoLoader, TensorflowPredictMAEST
276 |
277 | model = TensorflowPredictMAEST(
278 | graphFilename=model_path, output="model/dense/BiasAdd"
279 | )
280 |
281 | compute_and_save_embeddings(
282 | model=model,
283 | model_name=model_name,
284 | aggregation=aggregation,
285 | dataset=dataset,
286 | sample_rate=16000,
287 | resample_quality=4,
288 | skip_clean=skip_clean,
289 | skip_deformed=skip_deformed,
290 | no_overwrite=no_overwrite,
291 | deform_list=deform_list,
292 | )
293 |
294 | elif model_name == "openl3":
295 | from essentia.standard import MonoLoader
296 |
297 | from mir_ref.features.models.openl3 import EmbeddingsOpenL3
298 |
299 | model_path = "mir_ref/features/models/weights/openl3-music-mel128-emb512-3.pb"
300 |
301 | if not os.path.exists(model_path):
302 | print(f"Downloading {model_name} to mir_ref/features/models/weights...")
303 | wget.download(
304 |                 "https://essentia.upf.edu/models/feature-extractors/openl3/openl3-music-mel128-emb512-3.pb",
305 | out="mir_ref/features/models/weights/",
306 | )
307 |
308 | check_model_exists(model_path)
309 |
310 | extractor = EmbeddingsOpenL3(model_path)
311 |
312 | audio_paths, emb_paths = get_input_output_paths(
313 | dataset=dataset,
314 | model_name=model_name,
315 | skip_clean=skip_clean,
316 | skip_deformed=skip_deformed,
317 | no_overwrite=no_overwrite,
318 | deform_list=deform_list,
319 | )
320 |
321 | # Compute embeddings
322 | for input_path, output_path in tqdm(
323 | zip(audio_paths, emb_paths), total=len(audio_paths)
324 | ):
325 | embedding = extractor.compute(input_path)
326 |
327 | if aggregation == "mean":
328 | embedding = np.mean(embedding, axis=0)
329 | elif aggregation is None:
330 | pass
331 | else:
332 | raise Exception(f"Aggregation method '{aggregation}' not implemented.")
333 |
334 | # Save embeddings
335 | if not os.path.exists(os.path.dirname(output_path)):
336 | os.makedirs(os.path.dirname(output_path))
337 | with open(output_path, "wb") as f:
338 | np.save(f, embedding)
339 |
340 | elif model_name == "neuralfp":
341 | import tensorflow as tf
342 | from essentia.standard import MonoLoader
343 |
344 | mel_spec_model_dir = "mir_ref/features/models/weights/neuralfp/mel_spec"
345 | fp_model_dir = "mir_ref/features/models/weights/neuralfp/fp"
346 | mel_spec_model = tf.saved_model.load(mel_spec_model_dir)
347 | mel_spec_infer = mel_spec_model.signatures["serving_default"]
348 | fp_model = tf.saved_model.load(fp_model_dir)
349 | fp_infer = fp_model.signatures["serving_default"]
350 |
351 | monoloader = MonoLoader(sampleRate=8000, resampleQuality=1)
352 |
353 | audio_paths, emb_paths = get_input_output_paths(
354 | dataset=dataset,
355 | model_name=model_name,
356 | skip_clean=skip_clean,
357 | skip_deformed=skip_deformed,
358 | no_overwrite=no_overwrite,
359 | deform_list=deform_list,
360 | )
361 |
362 | for input_path, output_path in tqdm(
363 | zip(audio_paths, emb_paths), total=len(audio_paths)
364 | ):
365 | # Load audio
366 | monoloader.configure(filename=input_path)
367 | audio = monoloader()
368 |
369 | # Fingerprinting is done per 8000 samples, with a 4000 sample overlap, so pad
370 | audio = np.concatenate((np.zeros(4000), audio))
371 | audio = np.concatenate((audio, np.zeros(4000 - (len(audio) % 4000))))
372 |
373 | # Compute embeddings
374 | embeddings = []
375 | for buffer_start in range(0, len(audio) - 4000, 4000):
376 | buffer = audio[buffer_start : buffer_start + 8000]
377 | # size (None, 1, 8000)
378 |                 buffer = buffer.reshape(1, 8000)
379 | buffer = np.array([buffer])
380 | # use mel spectrogram model
381 | mel_spec_emb = mel_spec_infer(tf.constant(buffer, dtype=tf.float32))[
382 | "output_1"
383 | ]
384 | # use fingerprinter model
385 | fp_emb = fp_infer(mel_spec_emb)["output_1"]
386 | embeddings.append(fp_emb.numpy()[0])
387 |
388 | if aggregation == "mean":
389 | embedding = np.mean(embeddings, axis=0)
390 | elif aggregation is None:
391 | raise Exception(f"Aggregation method '{aggregation}' not implemented.")
392 | else:
393 | raise Exception(f"Aggregation method '{aggregation}' not implemented.")
394 |
395 | # Save embeddings
396 | if not os.path.exists(os.path.dirname(output_path)):
397 | os.makedirs(os.path.dirname(output_path))
398 | with open(output_path, "wb") as f:
399 | np.save(f, embedding)
400 |
401 | elif model_name == "mert-v1-330m" or model_name == "mert-v1-95m":
402 | # from transformers import Wav2Vec2Processor
403 | import torch
404 | import torchaudio.transforms as T
405 | from transformers import AutoModel, Wav2Vec2FeatureExtractor
406 |
407 | n_params = 330 if model_name == "mert-v1-330m" else 95
408 | # loading our model weights
409 | model = AutoModel.from_pretrained(
410 | f"m-a-p/MERT-v1-{n_params}M", trust_remote_code=True
411 | )
412 | # loading the corresponding preprocessor config
413 | processor = Wav2Vec2FeatureExtractor.from_pretrained(
414 | f"m-a-p/MERT-v1-{n_params}M", trust_remote_code=True
415 | )
416 | # get desired sample rate
417 | sample_rate = processor.sampling_rate
418 |
419 | audio_paths, emb_paths = get_input_output_paths(
420 | dataset=dataset,
421 | model_name=model_name,
422 | skip_clean=skip_clean,
423 | skip_deformed=skip_deformed,
424 | no_overwrite=no_overwrite,
425 | deform_list=deform_list,
426 | )
427 |
428 | if transcode_and_load:
429 | import librosa
430 | import sox
431 |
432 | tfm = sox.Transformer()
433 | tfm.convert(samplerate=sample_rate, n_channels=1)
434 | else:
435 | from essentia.standard import MonoLoader
436 |
437 | monoloader = MonoLoader(sampleRate=sample_rate, resampleQuality=1)
438 |
439 | for input_path, output_path in tqdm(
440 | zip(audio_paths, emb_paths), total=len(audio_paths)
441 | ):
442 | # Load audio
443 | if transcode_and_load:
444 | wav_input_path = input_path[:-4] + str(sample_rate) + ".wav"
445 | tfm.build(input_path, wav_input_path)
446 | audio, _ = librosa.load(wav_input_path, sr=sample_rate)
447 | os.remove(wav_input_path)
448 | else:
449 | monoloader.configure(filename=input_path)
450 | audio = monoloader()
451 |
452 | inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt")
453 | with torch.no_grad():
454 | outputs = model(**inputs, output_hidden_states=True)
455 |
456 | if aggregation == "mean":
457 | # we'll get the full embedding for now, meaning 13 layers x 768, or
458 |                 # 25 layers x 1024 for 95M and 330M respectively
459 | all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
460 | embedding = all_layer_hidden_states.mean(-2).detach().cpu().numpy()
461 | elif aggregation is None:
462 | pass
463 | else:
464 | raise Exception(f"Aggregation method '{aggregation}' not implemented.")
465 |
466 | # Save embeddings
467 | if not os.path.exists(os.path.dirname(output_path)):
468 | os.makedirs(os.path.dirname(output_path))
469 | with open(output_path, "wb") as f:
470 | np.save(f, embedding)
471 |
472 | elif model_name == "clmr-v2":
473 | import subprocess
474 |
475 | import torch
476 |
477 | from mir_ref.features.models.clmr import SampleCNN, load_encoder_checkpoint
478 |
479 | # download model
480 | if not os.path.exists(
481 | "mir_ref/features/models/weights/clmr_checkpoint_10000.pt"
482 | ):
483 | print(f"Downloading {model_name} to mir_ref/features/models/weights...")
484 | wget.download(
485 | "https://github.com/Spijkervet/CLMR/releases/download/2.0/clmr_checkpoint_10000.zip",
486 | out="mir_ref/features/models/weights/",
487 | )
488 |
489 | # unzip clmr_checkpoint_10000
490 | subprocess.run(
491 | [
492 | "unzip",
493 | "mir_ref/features/models/weights/clmr_checkpoint_10000.zip",
494 | "-d",
495 | "mir_ref/features/models/weights/",
496 | ]
497 | )
498 | # delete zip
499 | subprocess.run(
500 | [
501 | "rm",
502 | "mir_ref/features/models/weights/clmr_checkpoint_10000.zip",
503 | ]
504 | )
505 | # delete clmr_checkpoint_10000_optim.pt
506 | subprocess.run(
507 | [
508 | "rm",
509 | "mir_ref/features/models/weights/clmr_checkpoint_10000_optim.pt",
510 | ]
511 | )
512 |
513 | # load model
514 | encoder = SampleCNN(strides=[3, 3, 3, 3, 3, 3, 3, 3, 3], supervised=False)
515 | state_dict = load_encoder_checkpoint(
516 | "mir_ref/features/models/weights/clmr_checkpoint_10000.pt"
517 | )
518 | encoder.load_state_dict(state_dict)
519 | encoder.eval()
520 |
521 | audio_paths, emb_paths = get_input_output_paths(
522 | dataset=dataset,
523 | model_name=model_name,
524 | skip_clean=skip_clean,
525 | skip_deformed=skip_deformed,
526 | no_overwrite=no_overwrite,
527 | deform_list=deform_list,
528 | )
529 |
530 | if transcode_and_load:
531 | import librosa
532 | import sox
533 |
534 | tfm = sox.Transformer()
535 | tfm.convert(samplerate=22050, n_channels=1)
536 | else:
537 | from essentia.standard import MonoLoader
538 |
539 | monoloader = MonoLoader(sampleRate=22050, resampleQuality=1)
540 |
541 | for input_path, output_path in tqdm(
542 | zip(audio_paths, emb_paths), total=len(audio_paths)
543 | ):
544 | # Load audio
545 | if transcode_and_load:
546 | wav_input_path = input_path[:-4] + str(22050) + ".wav"
547 | tfm.build(input_path, wav_input_path)
548 | audio, _ = librosa.load(wav_input_path, sr=22050)
549 | os.remove(wav_input_path)
550 | else:
551 | monoloader.configure(filename=input_path)
552 | audio = monoloader()
553 |
554 | # get embedding per 59049 samples, padding the last buffer
555 | embeddings = []
556 | buffer_size = 59049
557 | for i in range(0, len(audio), buffer_size):
558 | buffer = audio[i : i + buffer_size]
559 | if len(buffer) < buffer_size:
560 | buffer = np.pad(
561 | buffer, (0, buffer_size - len(buffer)), mode="constant"
562 | )
563 | buffer = torch.from_numpy(buffer).float()
564 | buffer = buffer.unsqueeze(0).unsqueeze(0)
565 | embedding = encoder(buffer).squeeze()
566 | embeddings.append(embedding)
567 |
568 | if aggregation == "mean":
569 | embedding = torch.mean(torch.stack(embeddings), axis=0).detach().numpy()
570 | elif aggregation is None:
571 | embedding = torch.stack(embeddings).detach().numpy()
572 | else:
573 | raise Exception(f"Aggregation method '{aggregation}' not implemented.")
574 |
575 | # Save embeddings
576 | if not os.path.exists(os.path.dirname(output_path)):
577 | os.makedirs(os.path.dirname(output_path))
578 | with open(output_path, "wb") as f:
579 | np.save(f, embedding)
580 |
581 | elif model_name == "mule":
582 | raise Exception("MULE embeddings are not fully supported yet.")
583 |
584 | from scooch import Config
585 |
586 | from mir_ref.features.models.mule import Analysis
587 |
588 | if aggregation is None:
589 | config = "mir_ref/features/models/weights/mule/mule_embedding_timeline.yml"
590 | elif aggregation == "mean":
591 | config = "mir_ref/features/models/weights/mule/mule_embedding_average.yml"
592 | else:
593 | raise Exception(f"Aggregation method '{aggregation}' not implemented.")
594 |
595 | cfg = Config(config)
596 | analysis = Analysis(cfg)
597 |
598 | audio_paths, emb_paths = get_input_output_paths(
599 | dataset=dataset,
600 | model_name=model_name,
601 | skip_clean=skip_clean,
602 | skip_deformed=skip_deformed,
603 | no_overwrite=no_overwrite,
604 | deform_list=deform_list,
605 | )
606 |
607 | for input_file, output_file in tqdm(
608 | zip(audio_paths, emb_paths), total=len(audio_paths)
609 | ):
610 | feat = analysis.analyze(input_file)
611 | feat.save(output_file)
612 |
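Since every extractor above persists embeddings with np.save, downstream code can read them back with np.load; a minimal sketch with a hypothetical path (real paths come from dataset.get_embedding_path):

    import numpy as np

    # Hypothetical path; in practice use dataset.get_embedding_path(track_id, model_name)
    emb = np.load("data/my_dataset/embeddings/vggish-audioset/track_0001.npy")
    print(emb.shape)  # e.g. (128,) for a mean-aggregated VGGish embedding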
--------------------------------------------------------------------------------
/mir_ref/features/models/clmr.py:
--------------------------------------------------------------------------------
1 | """Code used for inference for CLMR model.
2 | Source (slightly modified for inference purposes):
3 | https://github.com/Spijkervet/CLMR, Apache License 2.0
4 | """
5 |
6 | import torch
7 | from collections import OrderedDict
8 | import torch.nn as nn
9 |
10 |
11 | def load_encoder_checkpoint(checkpoint_path: str) -> OrderedDict:
12 | state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
13 | if "pytorch-lightning_version" in state_dict.keys():
14 | new_state_dict = OrderedDict(
15 | {
16 | k.replace("model.encoder.", ""): v
17 | for k, v in state_dict["state_dict"].items()
18 | if "model.encoder." in k
19 | }
20 | )
21 | else:
22 | new_state_dict = OrderedDict()
23 | for k, v in state_dict.items():
24 | if "encoder." in k:
25 | new_state_dict[k.replace("encoder.", "")] = v
26 |
27 | new_state_dict["fc.weight"] = torch.zeros(50, 512)
28 | new_state_dict["fc.bias"] = torch.zeros(50)
29 | return new_state_dict
30 |
31 |
32 | class Model(nn.Module):
33 | def __init__(self):
34 | super(Model, self).__init__()
35 |
36 | def initialize(self, m):
37 | if isinstance(m, (nn.Conv1d)):
38 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
39 |
40 |
41 | class Identity(nn.Module):
42 | def __init__(self):
43 | super(Identity, self).__init__()
44 |
45 | def forward(self, x):
46 | return x
47 |
48 |
49 | class SampleCNN(Model):
50 | def __init__(self, strides, supervised=False):
51 | super(SampleCNN, self).__init__()
52 |
53 | self.strides = strides
54 | self.supervised = supervised
55 | self.sequential = [
56 | nn.Sequential(
57 | nn.Conv1d(1, 128, kernel_size=3, stride=3, padding=0),
58 | nn.BatchNorm1d(128),
59 | nn.ReLU(),
60 | )
61 | ]
62 |
63 | self.hidden = [
64 | [128, 128],
65 | [128, 128],
66 | [128, 256],
67 | [256, 256],
68 | [256, 256],
69 | [256, 256],
70 | [256, 256],
71 | [256, 256],
72 | [256, 512],
73 | ]
74 |
75 | assert len(self.hidden) == len(
76 | self.strides
77 | ), "Number of hidden layers and strides are not equal"
78 | for stride, (h_in, h_out) in zip(self.strides, self.hidden):
79 | self.sequential.append(
80 | nn.Sequential(
81 | nn.Conv1d(h_in, h_out, kernel_size=stride, stride=1, padding=1),
82 | nn.BatchNorm1d(h_out),
83 | nn.ReLU(),
84 | nn.MaxPool1d(stride, stride=stride),
85 | )
86 | )
87 |
88 | # 1 x 512
89 | self.sequential.append(
90 | nn.Sequential(
91 | nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
92 | nn.BatchNorm1d(512),
93 | nn.ReLU(),
94 | )
95 | )
96 |
97 | self.sequential = nn.Sequential(*self.sequential)
98 |
99 | if self.supervised:
100 | self.dropout = nn.Dropout(0.5)
101 |         self.fc = nn.Linear(512, 50)  # defined even when unsupervised so the checkpoint's fc.* keys load
102 |
103 | def forward(self, x):
104 | out = self.sequential(x)
105 | if self.supervised:
106 | out = self.dropout(out)
107 |
108 | out = out.reshape(x.shape[0], out.size(1) * out.size(2))
109 |         logit = self.fc(out)  # unused here; the 512-d embedding (out) is returned for inference
110 |
111 | return out
112 |
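A minimal sketch of driving the encoder above on a single 59049-sample buffer, mirroring the loop in feature_extraction.py (no pretrained checkpoint is loaded here, so the values are meaningless; only the shapes are illustrative):

    import torch

    encoder = SampleCNN(strides=[3, 3, 3, 3, 3, 3, 3, 3, 3], supervised=False)
    encoder.eval()

    buffer = torch.zeros(1, 1, 59049)  # (batch, channels, samples)
    with torch.no_grad():
        embedding = encoder(buffer).squeeze()
    print(embedding.shape)  # torch.Size([512])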
--------------------------------------------------------------------------------
/mir_ref/features/models/harmonic_cnn.py:
--------------------------------------------------------------------------------
1 | """Code for Harmonic CNN embedding model inference.
2 | Source: https://github.com/minzwon/sota-music-tagging-models/blob/master/predict.py
3 | """
4 |
5 | import os
6 | import sys
7 | import tempfile
8 | from pathlib import Path
9 |
10 | import cog
11 | import librosa
12 | import matplotlib.pyplot as plt
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torchaudio
18 | from torch.autograd import Variable
19 |
20 | sys.path.insert(0, "training")
21 |
22 | SAMPLE_RATE = 16000
23 | DATASET = "mtat"
24 |
25 |
26 | def initialize_filterbank(sample_rate, n_harmonic, semitone_scale):
27 | # MIDI
28 | # lowest note
29 |     low_midi = librosa.note_to_midi('C1')
30 |
31 |     # highest note
32 |     high_note = librosa.hz_to_note(sample_rate / (2 * n_harmonic))
33 |     high_midi = librosa.note_to_midi(high_note)
34 |
35 |     # number of scales
36 |     level = (high_midi - low_midi) * semitone_scale
37 |     midi = np.linspace(low_midi, high_midi, level + 1)
38 |     hz = librosa.midi_to_hz(midi[:-1])
39 |
40 | # stack harmonics
41 | harmonic_hz = []
42 | for i in range(n_harmonic):
43 | harmonic_hz = np.concatenate((harmonic_hz, hz * (i+1)))
44 |
45 | return harmonic_hz, level
46 |
47 |
48 | class Conv_2d(nn.Module):
49 | def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2):
50 | super(Conv_2d, self).__init__()
51 | self.conv = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
52 | self.bn = nn.BatchNorm2d(output_channels)
53 | self.relu = nn.ReLU()
54 | self.mp = nn.MaxPool2d(pooling)
55 |
56 | def forward(self, x):
57 | out = self.mp(self.relu(self.bn(self.conv(x))))
58 | return out
59 |
60 |
61 | class Res_2d_mp(nn.Module):
62 | def __init__(self, input_channels, output_channels, pooling=2):
63 | super(Res_2d_mp, self).__init__()
64 | self.conv_1 = nn.Conv2d(input_channels, output_channels, 3, padding=1)
65 | self.bn_1 = nn.BatchNorm2d(output_channels)
66 | self.conv_2 = nn.Conv2d(output_channels, output_channels, 3, padding=1)
67 | self.bn_2 = nn.BatchNorm2d(output_channels)
68 | self.relu = nn.ReLU()
69 | self.mp = nn.MaxPool2d(pooling)
70 |
71 | def forward(self, x):
72 | out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
73 | out = x + out
74 | out = self.mp(self.relu(out))
75 | return out
76 |
77 |
78 | class HarmonicSTFT(nn.Module):
79 | def __init__(self,
80 | sample_rate=16000,
81 | n_fft=513,
82 | win_length=None,
83 | hop_length=None,
84 | pad=0,
85 | power=2,
86 | normalized=False,
87 | n_harmonic=6,
88 | semitone_scale=2,
89 | bw_Q=1.0,
90 | learn_bw=None):
91 | super(HarmonicSTFT, self).__init__()
92 |
93 | # Parameters
94 | self.sample_rate = sample_rate
95 | self.n_harmonic = n_harmonic
96 | self.bw_alpha = 0.1079
97 | self.bw_beta = 24.7
98 |
99 | # Spectrogram
100 | self.spec = torchaudio.transforms.Spectrogram(n_fft=n_fft, win_length=win_length,
101 | hop_length=None, pad=0,
102 | window_fn=torch.hann_window,
103 | power=power, normalized=normalized,
104 | wkwargs=None)
105 | self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
106 |
107 | # Initialize the filterbank. Equally spaced in MIDI scale.
108 | harmonic_hz, self.level = initialize_filterbank(sample_rate, n_harmonic, semitone_scale)
109 |
110 |         # Center frequencies to tensor
111 | self.f0 = torch.tensor(harmonic_hz.astype('float32'))
112 |
113 | # Bandwidth parameters
114 | if learn_bw == 'only_Q':
115 | self.bw_Q = nn.Parameter(torch.tensor(np.array([bw_Q]).astype('float32')))
116 | elif learn_bw == 'fix':
117 | self.bw_Q = torch.tensor(np.array([bw_Q]).astype('float32'))
118 |
119 | def get_harmonic_fb(self):
120 | # bandwidth
121 | bw = (self.bw_alpha * self.f0 + self.bw_beta) / self.bw_Q
122 | bw = bw.unsqueeze(0) # (1, n_band)
123 | f0 = self.f0.unsqueeze(0) # (1, n_band)
124 | fft_bins = self.fft_bins.unsqueeze(1) # (n_bins, 1)
125 |
126 | up_slope = torch.matmul(fft_bins, (2/bw)) + 1 - (2 * f0 / bw)
127 | down_slope = torch.matmul(fft_bins, (-2/bw)) + 1 + (2 * f0 / bw)
128 | fb = torch.max(self.zero, torch.min(down_slope, up_slope))
129 | return fb
130 |
131 | def to_device(self, device, n_bins):
132 | self.f0 = self.f0.to(device)
133 | self.bw_Q = self.bw_Q.to(device)
134 | # fft bins
135 | self.fft_bins = torch.linspace(0, self.sample_rate//2, n_bins)
136 | self.fft_bins = self.fft_bins.to(device)
137 | self.zero = torch.zeros(1)
138 | self.zero = self.zero.to(device)
139 |
140 | def forward(self, waveform):
141 | # stft
142 | spectrogram = self.spec(waveform)
143 |
144 | # to device
145 | self.to_device(waveform.device, spectrogram.size(1))
146 |
147 | # triangle filter
148 | harmonic_fb = self.get_harmonic_fb()
149 | harmonic_spec = torch.matmul(spectrogram.transpose(1, 2), harmonic_fb).transpose(1, 2)
150 |
151 | # (batch, channel, length) -> (batch, harmonic, f0, length)
152 | b, c, l = harmonic_spec.size()
153 | harmonic_spec = harmonic_spec.view(b, self.n_harmonic, self.level, l)
154 |
155 | # amplitude to db
156 | harmonic_spec = self.amplitude_to_db(harmonic_spec)
157 | return harmonic_spec
158 |
159 |
160 | class HarmonicCNN(nn.Module):
161 | '''
162 | Won et al. 2020
163 | Data-driven harmonic filters for audio representation learning.
164 | Trainable harmonic band-pass filters, short-chunk CNN.
165 | '''
166 | def __init__(self,
167 | n_channels=128,
168 | sample_rate=16000,
169 | n_fft=512,
170 | f_min=0.0,
171 | f_max=8000.0,
172 | n_mels=128,
173 | n_class=50,
174 | n_harmonic=6,
175 | semitone_scale=2,
176 | learn_bw='only_Q'):
177 | super(HarmonicCNN, self).__init__()
178 |
179 | # Harmonic STFT
180 | self.hstft = HarmonicSTFT(sample_rate=sample_rate,
181 | n_fft=n_fft,
182 | n_harmonic=n_harmonic,
183 | semitone_scale=semitone_scale,
184 | learn_bw=learn_bw)
185 | self.hstft_bn = nn.BatchNorm2d(n_harmonic)
186 |
187 | # CNN
188 | self.layer1 = Conv_2d(n_harmonic, n_channels, pooling=2)
189 | self.layer2 = Res_2d_mp(n_channels, n_channels, pooling=2)
190 | self.layer3 = Res_2d_mp(n_channels, n_channels, pooling=2)
191 | self.layer4 = Res_2d_mp(n_channels, n_channels, pooling=2)
192 | self.layer5 = Conv_2d(n_channels, n_channels*2, pooling=2)
193 | self.layer6 = Res_2d_mp(n_channels*2, n_channels*2, pooling=(2, 3))
194 | self.layer7 = Res_2d_mp(n_channels*2, n_channels*2, pooling=(2, 3))
195 |
196 | # Dense
197 | self.dense1 = nn.Linear(n_channels*2, n_channels*2)
198 | self.bn = nn.BatchNorm1d(n_channels*2)
199 | self.dense2 = nn.Linear(n_channels*2, n_class)
200 | self.dropout = nn.Dropout(0.5)
201 | self.relu = nn.ReLU()
202 |
203 | def forward(self, x):
204 | # Spectrogram
205 | x = self.hstft_bn(self.hstft(x))
206 |
207 | # CNN
208 | x = self.layer1(x)
209 | x = self.layer2(x)
210 | x = self.layer3(x)
211 | x = self.layer4(x)
212 | x = self.layer5(x)
213 | x = self.layer6(x)
214 | x = self.layer7(x)
215 | x = x.squeeze(2)
216 |
217 | # Global Max Pooling
218 | if x.size(-1) != 1:
219 | x = nn.MaxPool1d(x.size(-1))(x)
220 | x = x.squeeze(2)
221 |
222 | # Dense
223 | x = self.dense1(x)
224 | x = self.bn(x)
225 | x = self.relu(x)
226 | x = self.dropout(x)
227 | x = self.dense2(x)
228 | x = nn.Sigmoid()(x)
229 |
230 | return x
231 |
232 |
233 | class Predictor(cog.Predictor):
234 | def setup(self):
235 | if torch.cuda.is_available():
236 | self.device = torch.device("cuda:0")
237 | else:
238 | self.device = torch.device("cpu")
239 |
240 |         self.model = HarmonicCNN().to(self.device)
241 | self.input_length = 5 * 16000
242 |         self.models = {"hcnn": self.model}  # NOTE: "hcnn" is assumed to match the upstream checkpoint dir name
243 |         for key, mod in self.models.items():
244 | filename = os.path.join("models", DATASET, key, "best_model.pth")
245 | state_dict = torch.load(filename, map_location=self.device)
246 | if "spec.mel_scale.fb" in state_dict.keys():
247 | mod.spec.mel_scale.fb = state_dict["spec.mel_scale.fb"]
248 | mod.load_state_dict(state_dict)
249 |
250 | self.tags = np.load("split/mtat/tags.npy")
251 |
252 | @cog.input(
253 | "output_format",
254 | type=str,
255 | default="Visualization",
256 | options=["Visualization", "JSON"],
257 | help="Output either a bar chart visualization or a JSON blob",
258 | )
259 | def predict(self, input, output_format):
260 |
261 | model = self.model.eval()
262 | input_length = self.input_length
263 | signal, _ = librosa.core.load(str(input), sr=SAMPLE_RATE)
264 | length = len(signal)
265 | hop = length // 2 - input_length // 2
266 | print("length, input_length", length, input_length)
267 | x = torch.zeros(1, input_length)
268 | x[0] = torch.Tensor(signal[hop:hop+input_length]).unsqueeze(0)
269 | x = Variable(x.to(self.device))
270 | print("x.max(), x.min(), x.mean()", x.max(), x.min(), x.mean())
271 | out = model(x)
272 |         result = dict(zip(self.tags, out[0].detach().cpu().numpy().tolist()))
273 |
274 | if output_format == "JSON":
275 | return result
276 |
277 | result_list = list(sorted(result.items(), key=lambda x: x[1]))
278 | plt.figure(figsize=[5, 10])
279 | plt.barh(
280 | np.arange(len(result_list)), [r[1] for r in result_list], align="center"
281 | )
282 | plt.yticks(np.arange(len(result_list)), [r[0] for r in result_list])
283 | plt.tight_layout()
284 |
285 | out_path = Path(tempfile.mkdtemp()) / "out.png"
286 | plt.savefig(out_path)
287 | return out_path
288 |
--------------------------------------------------------------------------------
/mir_ref/features/models/mule.py:
--------------------------------------------------------------------------------
1 | """Code used for inference for MULE model.
2 | Source: https://github.com/PandoraMedia/music-audio-features:
3 | """
4 |
5 | import tempfile
6 |
7 | import numpy as np
8 | from scooch import ConfigList, Configurable, Param
9 |
10 |
11 | class Feature(Configurable):
12 | """
13 | The base class for all feature types.
14 | """
15 |
16 | def __del__(self):
17 | """
18 | **Destructor**
19 | """
20 | if hasattr(self, "_data_file"):
21 | self._data_file.close()
22 |
23 | def add_data(self, data):
24 | """
25 | Adds data to extend the object's current data via concatenation along the time axis.
26 | This is useful for populating data in chunks, where populating it all at once would
27 | cause excessive memory usage.
28 |
29 | Arg:
30 | data: np.ndarray - The data to be appended to the object's memmapped numpy array.
31 | """
32 | if not hasattr(self, "_data_file"):
33 | self._data_file = tempfile.NamedTemporaryFile(mode="w+b")
34 |
35 | if not hasattr(self, "_data") or self._data is None:
36 | original_data_size = 0
37 | else:
38 | original_data_size = self._data.shape[1]
39 | final_size = original_data_size + data.shape[1]
40 |
41 | filename = self._data_file.name
42 |
43 | self._data = np.memmap(
44 | filename,
45 | dtype="float32",
46 | mode="r+",
47 | shape=(data.shape[0], final_size),
48 | order="F",
49 | )
50 | self._data[:, original_data_size:] = data
51 |
52 |     def _extract(self, source, start_time, chunk_size):
53 | """
54 |         Extracts feature data from a file or another feature over a given time-chunk.
55 |
56 | Args:
57 | source_feature: mule.features.Feature - The feature to transform.
58 |
59 | start_time: int - The index in the input at which to start extracting / transforming.
60 |
61 | chunk_size: int - The length of the chunk following the `start_time` to extract
62 | the feature from.
63 | """
64 | raise NotImplementedError(
65 | f"The {self.__name__.__class__} has no feature extraction method"
66 | )
67 |
68 | def save(self, path):
69 | """
70 | Save the feature data blob to disk.
71 |
72 | Args:
73 | path: str - The path to save the data to.
74 | """
75 | np.save(path, self.data)
76 |
77 | def clear(self):
78 | """
79 | Clears any previously analyzed feature data, ready for a new analysis.
80 | """
81 | self._data = None
82 | if hasattr(self, "_data_file"):
83 | self._data_file.close()
84 | del self._data_file
85 |
86 | @property
87 | def data(self):
88 | """
89 | The feature data blob itself.
90 | """
91 | return self._data
92 |
93 |
94 | class SourceFile(Configurable):
95 | """
96 | Base class for SCOOCH configurable file readers.
97 | """
98 |
99 | def load(self, fname):
100 | """
101 | Any preprocessing steps to load a file prior to reading it.
102 |
103 | Args:
104 | fname: file-like - A file like object to be loaded.
105 | """
106 | raise NotImplementedError(
107 | f"The class, {self.__class__.__name__}, has no method for loading files"
108 | )
109 |
110 | def read(self, n):
111 | """
112 | Reads an amount of data from the file.
113 |
114 | Args:
115 | n: int - A size parameter indicating the amount of data to read.
116 |
117 | Return:
118 | object - The decoded data read and in memory.
119 | """
120 | raise NotImplementedError(
121 | f"The class, {self.__class__.__name__}, has no method for reading files"
122 | )
123 |
124 | def close(self):
125 | """
126 | Closes any previously loaded file.
127 | """
128 | raise NotImplementedError(
129 | f"The class, {self.__class__.__name__}, has no method for closing files"
130 | )
131 |
132 | def __len__(self):
133 | raise NotImplementedError(
134 | f"The class, {self.__class__.__name__} has no method for determining file data length"
135 | )
136 |
137 |
138 | class SourceFeature(Feature):
139 | """
140 | A feature that is derived directly from raw data, e.g., a data file.
141 | """
142 |
143 | # SCOOCH Configuration
144 | _input_file = Param(
145 | SourceFile,
146 | doc="The file object defining the parameters of the raw data that this feature is constructed from.",
147 | )
148 |
149 | _CHUNK_SIZE = 44100 * 60 * 15
150 |
151 | # Methods
152 | def from_file(self, fname):
153 | """
154 | Takes a file and processes its data in chunks to form a feature.
155 |
156 | Args:
157 | fname: str - The path to the input file from which this feature is constructed.
158 | """
159 | # Load file
160 | self._input_file.load(fname)
161 |
162 | # Read samples into data
163 | processed_input_frames = 0
164 | while processed_input_frames < len(self._input_file):
165 | data = self._extract(
166 | self._input_file, processed_input_frames, self._CHUNK_SIZE
167 | )
168 | processed_input_frames += self._CHUNK_SIZE
169 | self.add_data(data)
170 |
171 | def clear(self):
172 | """
173 | Clears any previously analyzed feature data, ready for a new analysis.
174 | """
175 | super().clear()
176 | self._input_file.close()
177 |
178 | def __len__(self):
179 | """
180 | Returns the number of bytes / samples / indices in the input data file.
181 | """
182 | return len(self._input_file)
183 |
184 |
185 | class Extractor(Configurable):
186 | """
187 | Base class for classes that are responsible for extracting data
188 | from mule.features.Feature classes.
189 | """
190 |
191 | #
192 | # Methods
193 | #
194 | def extract_range(self, feature, start_index, end_index):
195 | """
196 | Extracts data over a given index range from a single feature.
197 |
198 | Args:
199 | feature: mule.feature.Feature - A feature to extract data from.
200 |
201 | start_index: int - The first index (inclusive) at which to return data.
202 |
203 | end_index: int - The last index (exclusive) at which to return data.
204 |
205 | Return:
206 | numpy.ndarray - The extracted feature data. Features on first axis, time on
207 | second axis.
208 | """
209 | raise NotImplementedError(
210 | f"The {self.__class__.__name__} class has no `extract_range` method."
211 | )
212 |
213 | def extract_batch(self, features, indices):
214 | """
215 | Extracts a batch of features from potentially multiple features, each potentially
216 | at distinct indices.
217 |
218 | Args:
219 | features: list(mule.features.Feature) - A list of features from which to extract
220 |                 data.
221 |
222 | indices: list(int) - A list of indices, the same size as `features`. Each element
223 |             provides an index at which to extract data from the corresponding element in the
224 | `features` argument.
225 |
226 | Return:
227 | np.ndarray - A batch of features, with features on the batch dimension on the first
228 | axis and feature data on the remaining axes.
229 | """
230 | raise NotImplementedError(
231 | f"The {self.__class__.__name__} class has no `extract_batch` method."
232 | )
233 |
234 |
235 | class TransformFeature(Feature):
236 | """
237 | Base class for all features that are transforms of other features.
238 | """
239 |
240 | #
241 | # SCOOCH Configuration
242 | #
243 | _extractor = Param(
244 | Extractor,
245 | doc="An object defining how data will be extracted from the input feature and provided to the transformation of this feature.",
246 | )
247 |
248 | # The size in time of each chunk that this feature will process at any one time.
249 | _CHUNK_SIZE = 44100 * 60 * 15
250 |
251 | #
252 | # Methods
253 | #
254 | def from_feature(self, source_feature):
255 | """
256 |         Populates this feature's data as a transform of the provided input feature.
257 |
258 | Args:
259 | source_feature: mule.features.Feature - A feature from which this feature will
260 | be created as a transformation thereof.
261 | """
262 | boundaries = list(range(0, len(source_feature), self._CHUNK_SIZE)) + [
263 | len(source_feature)
264 | ]
265 | chunks = [(start, end) for start, end in zip(boundaries[:-1], boundaries[1:])]
266 | for start_time, end_time in chunks:
267 | data = self._extract(source_feature, start_time, end_time - start_time)
268 | if data is not None and len(data):
269 | self.add_data(data)
270 |
271 | def _extract(self, source_feature, start_time, chunk_size):
272 | """
273 | Extracts feature data as a transformation of a given source feature for a given
274 | time-chunk.
275 |
276 | Args:
277 | source_feature: mule.features.Feature - The feature to transform.
278 |
279 | start_time: int - The index in the feature at which to start extracting / transforming.
280 |
281 | chunk_size: int - The length of the chunk following the `start_time` to extract
282 | the feature from.
283 | """
284 | end_time = start_time + chunk_size
285 | return self._extractor.extract_range(source_feature, start_time, end_time)
286 |
287 | def __len__(self):
288 |         if hasattr(self, "_data") and self._data is not None:
289 | return self._data.shape[1]
290 | else:
291 | return 0
292 |
293 |
294 | class Analysis(Configurable):
295 | """
296 | A class encapsulating analysis of a single input file.
297 | """
298 |
299 | # SCOOCH Configuration
300 | _source_feature = Param(
301 | SourceFeature, doc="The feature used to decode the provided raw file data."
302 | )
303 | _feature_transforms = Param(
304 | ConfigList(TransformFeature),
305 | doc="Feature transformations to apply, in order, to the source feature generated from the input file.",
306 | )
307 |
308 | # Methods
309 | def analyze(self, fname):
310 | """
311 | Analyze features for a single filepath.
312 |
313 | Args:
314 | fname: str - The filename path, from which to generate features.
315 |
316 | Return:
317 | mule.features.Feature - The feature resulting from the configured feature
318 | transformations.
319 | """
320 | for feat in [self._source_feature] + self._feature_transforms:
321 | feat.clear()
322 |
323 | self._source_feature.from_file(fname)
324 | input_feature = self._source_feature
325 | for feature in self._feature_transforms:
326 | feature.from_feature(input_feature)
327 | input_feature = feature
328 |
329 | return input_feature
330 |
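The chunked `from_feature` loop above visits every index of the source feature exactly once; the boundary arithmetic is easier to see in a small, self-contained numpy sketch (chunk size shrunk for readability):

    import numpy as np

    CHUNK_SIZE = 4  # the classes above use 44100 * 60 * 15 samples per chunk

    source = np.arange(10)  # stand-in for a decoded source feature
    boundaries = list(range(0, len(source), CHUNK_SIZE)) + [len(source)]
    chunks = list(zip(boundaries[:-1], boundaries[1:]))
    # -> [(0, 4), (4, 8), (8, 10)]: the final, shorter chunk is handled like any other
    pieces = [source[start:end] for start, end in chunks]
    assert np.array_equal(np.concatenate(pieces), source)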
--------------------------------------------------------------------------------
/mir_ref/features/models/openl3.py:
--------------------------------------------------------------------------------
1 | """Code for OpenL3 embedding model inference.
2 | Source: https://gist.github.com/palonso/cfebe37e5492b5a3a31775d8eae8d9a8
3 | """
4 |
5 | from pathlib import Path
6 | import essentia.standard as es
7 | import numpy as np
8 | from essentia import Pool
9 |
10 |
11 | class MelSpectrogramOpenL3:
12 | def __init__(self, hop_time):
13 | self.hop_time = hop_time
14 |
15 | self.sr = 48000
16 | self.n_mels = 128
17 | self.frame_size = 2048
18 | self.hop_size = 242
19 | self.a_min = 1e-10
20 | self.d_range = 80
21 | self.db_ref = 1.0
22 |
23 | self.patch_samples = int(1 * self.sr)
24 | self.hop_samples = int(self.hop_time * self.sr)
25 |
26 | self.w = es.Windowing(
27 | size=self.frame_size,
28 | normalized=False,
29 | )
30 | self.s = es.Spectrum(size=self.frame_size)
31 | self.mb = es.MelBands(
32 | highFrequencyBound=self.sr / 2,
33 | inputSize=self.frame_size // 2 + 1,
34 | log=False,
35 | lowFrequencyBound=0,
36 | normalize="unit_tri",
37 | numberBands=self.n_mels,
38 | sampleRate=self.sr,
39 | type="magnitude",
40 | warpingFormula="slaneyMel",
41 | weighting="linear",
42 | )
43 |
44 | def compute(self, audio_file):
45 | audio = es.MonoLoader(filename=audio_file, sampleRate=self.sr)()
46 |
47 | batch = []
48 | for audio_chunk in es.FrameGenerator(
49 | audio, frameSize=self.patch_samples, hopSize=self.hop_samples
50 | ):
51 | melbands = np.array(
52 | [
53 | self.mb(self.s(self.w(frame)))
54 | for frame in es.FrameGenerator(
55 | audio_chunk,
56 | frameSize=self.frame_size,
57 | hopSize=self.hop_size,
58 | validFrameThresholdRatio=0.5,
59 | )
60 | ]
61 | )
62 |
63 | melbands = 10.0 * np.log10(np.maximum(self.a_min, melbands))
64 | melbands -= 10.0 * np.log10(np.maximum(self.a_min, self.db_ref))
65 | melbands = np.maximum(melbands, melbands.max() - self.d_range)
66 | melbands -= np.max(melbands)
67 |
68 | batch.append(melbands.copy())
69 |
70 | return np.vstack(batch)
71 |
72 |
73 | class EmbeddingsOpenL3:
74 | def __init__(self, graph_path, hop_time=1, batch_size=60, melbands=128):
75 | self.hop_time = hop_time
76 | self.batch_size = batch_size
77 |
78 | self.graph_path = Path(graph_path)
79 |
80 | self.x_size = 199
81 | self.y_size = melbands
82 | self.squeeze = False
83 |
84 | self.permutation = [0, 3, 2, 1]
85 |
86 | self.input_layer = "melspectrogram"
87 | self.output_layer = "embeddings"
88 |
89 | self.mel_extractor = MelSpectrogramOpenL3(hop_time=self.hop_time)
90 |
91 | self.model = es.TensorflowPredict(
92 | graphFilename=str(self.graph_path),
93 | inputs=[self.input_layer],
94 | outputs=[self.output_layer],
95 | squeeze=self.squeeze,
96 | )
97 |
98 | def compute(self, audio_file):
99 | mel_spectrogram = self.mel_extractor.compute(audio_file)
100 |         # in OpenL3 the embedding hop is applied at the mel-spectrogram extraction level,
101 |         # so the patches built here do not overlap
101 |
102 | hop_size_samples = self.x_size
103 |
104 | batch = self.__melspectrogram_to_batch(mel_spectrogram, hop_size_samples)
105 |
106 | pool = Pool()
107 | embeddings = []
108 | nbatches = int(np.ceil(batch.shape[0] / self.batch_size))
109 | for i in range(nbatches):
110 | start = i * self.batch_size
111 | end = min(batch.shape[0], (i + 1) * self.batch_size)
112 | pool.set(self.input_layer, batch[start:end])
113 | out_pool = self.model(pool)
114 | embeddings.append(out_pool[self.output_layer].squeeze())
115 |
116 | return np.vstack(embeddings)
117 |
118 | def __melspectrogram_to_batch(self, melspectrogram, hop_time):
119 | npatches = int(np.ceil((melspectrogram.shape[0] - self.x_size) / hop_time) + 1)
120 | batch = np.zeros([npatches, self.x_size, self.y_size], dtype="float32")
121 | for i in range(npatches):
122 | last_frame = min(i * hop_time + self.x_size, melspectrogram.shape[0])
123 | first_frame = i * hop_time
124 | data_size = last_frame - first_frame
125 |
126 | # the last patch may be empty, remove it and exit the loop
127 | if data_size <= 0:
128 | batch = np.delete(batch, i, axis=0)
129 | break
130 | else:
131 | batch[i, :data_size] = melspectrogram[first_frame:last_frame]
132 |
133 | batch = np.expand_dims(batch, 1)
134 | batch = es.TensorTranspose(permutation=self.permutation)(batch)
135 |
136 | return batch
137 |
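A minimal usage sketch, assuming an Essentia-compatible OpenL3 graph file has already been downloaded to a hypothetical local path:

    # both paths below are hypothetical
    extractor = EmbeddingsOpenL3(graph_path="models/openl3-music-mel128-emb512.pb", hop_time=1)
    embeddings = extractor.compute("audio/example.wav")
    print(embeddings.shape)  # (n_patches, embedding_dim), roughly one patch per second of audio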
--------------------------------------------------------------------------------
/mir_ref/metrics.py:
--------------------------------------------------------------------------------
1 | """Custom metrics."""
2 |
3 |
4 | def key_detection_weighted_accuracy(y_true, y_pred):
5 | """Calculate weighted accuracy for key detection.
6 |
7 | Args:
8 | y_true (list): List of keys (e.g. C# major).
9 | y_pred (list): List of predicted keys.
10 |
11 | Returns:
12 | float: Weighted accuracy.
13 | """
14 | import mir_eval
15 | import numpy as np
16 |
17 | scores = []
18 | macro_scores = {}
19 |
20 | for truth, pred in zip(y_true, y_pred):
21 | score = mir_eval.key.weighted_score(truth, pred)
22 | scores.append(score)
23 | if truth not in macro_scores:
24 | macro_scores[truth] = []
25 | macro_scores[truth].append(score)
26 |
27 | # calculate macro scores
28 | macro_scores_mean = []
29 | for key, values in macro_scores.items():
30 | macro_scores_mean.append(np.mean(values))
31 |
32 | return {"micro": np.mean(scores), "macro": np.mean(macro_scores_mean)}
33 |
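A small worked example (key strings follow the mir_eval convention, e.g. "C major"); mir_eval scores an exact match as 1.0 and a perfect-fifth error as 0.5:

    y_true = ["C major", "C major", "A minor"]
    y_pred = ["C major", "G major", "A minor"]  # one perfect-fifth error on the second item

    scores = key_detection_weighted_accuracy(y_true, y_pred)
    # micro = mean([1.0, 0.5, 1.0]) ≈ 0.833
    # macro = mean([mean([1.0, 0.5]), mean([1.0])]) = 0.875
    print(scores)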
--------------------------------------------------------------------------------
/mir_ref/probes/probe_builder.py:
--------------------------------------------------------------------------------
1 | """Downstream models for feature evaluation, as well
2 | as helper code to construct and return models.
3 | """
4 |
5 | from keras import layers, regularizers
6 | from keras.models import Sequential
7 |
8 |
9 | def get_model(model_cfg, dim, n_classes):
10 | if model_cfg["type"] == "classifier":
11 | return classifier(model_cfg, dim, n_classes)
12 | else:
13 | raise ValueError(f"Model type '{model_cfg['type']}' not supported.")
14 |
15 |
16 | def classifier(model_cfg, dim, n_classes):
17 | """Classifier with configurable number of layers."""
18 |
19 |     # if hidden units are set to "infer" or "power_infer", infer their sizes
20 | if (model_cfg["hidden_units"] is not None) and (
21 | len(model_cfg["hidden_units"]) != 0
22 | ):
23 | if model_cfg["hidden_units"][0] == "power_infer":
24 | # get layer sizes with power of 2 regression
25 | # y = alpha * x ^ 2 + c, where y = emb_shape[0], c = n_classes,
26 | # and x = n_hidden_layers + 1
27 | alpha = (dim[0] - n_classes) / (len(model_cfg["hidden_units"]) + 1)
28 | for i in range(len(model_cfg["hidden_units"])):
29 | hu = int(alpha * ((i + 1) ** 2) + n_classes)
30 | if hu % 2 != 0:
31 | hu += 1
32 | model_cfg["hidden_units"][i] = hu
33 |
34 | print(f"Hidden units inferred, using {model_cfg['hidden_units']}.")
35 |
36 | if model_cfg["hidden_units"][0] == "infer":
37 | n_layers = len(model_cfg["hidden_units"])
38 | step_size = (dim[0] - n_classes) / (n_layers + 1)
39 | for i in range(len(model_cfg["hidden_units"])):
40 | hu = int(n_classes + ((n_layers - i) * step_size))
41 | if hu % 2 != 0:
42 | hu += 1
43 | model_cfg["hidden_units"][i] = hu
44 |
45 | print(f"Hidden units inferred, using {model_cfg['hidden_units']}.")
46 |
47 | model = Sequential()
48 |
49 | # add hidden layers
50 | for i, hu in enumerate(model_cfg["hidden_units"]):
51 | if i == 0:
52 | model.add(
53 | layers.Dense(
54 | units=hu,
55 | activation="relu",
56 | name=f"hidden_layer_{i}",
57 | input_shape=tuple(dim),
58 | kernel_regularizer=regularizers.L2(model_cfg["weight_decay"]),
59 | bias_regularizer=regularizers.L2(model_cfg["weight_decay"]),
60 | )
61 | )
62 | else:
63 | model.add(
64 | layers.Dense(
65 | units=hu,
66 | activation="relu",
67 | name=f"hidden_layer_{i}",
68 | kernel_regularizer=regularizers.L2(model_cfg["weight_decay"]),
69 | bias_regularizer=regularizers.L2(model_cfg["weight_decay"]),
70 | )
71 | )
72 |
73 | # add output layer
74 | if (model_cfg["hidden_units"] is not None) and (
75 | len(model_cfg["hidden_units"]) != 0
76 | ):
77 | model.add(
78 | layers.Dense(
79 | units=n_classes,
80 | activation=model_cfg["output_activation"],
81 | name="output_layer",
82 | kernel_regularizer=regularizers.L2(model_cfg["weight_decay"]),
83 | bias_regularizer=regularizers.L2(model_cfg["weight_decay"]),
84 | )
85 | )
86 | else:
87 | model.add(
88 | layers.Dense(
89 | units=n_classes,
90 | activation=model_cfg["output_activation"],
91 | name="output_layer",
92 | input_shape=tuple(dim),
93 | kernel_regularizer=regularizers.L2(model_cfg["weight_decay"]),
94 | bias_regularizer=regularizers.L2(model_cfg["weight_decay"]),
95 | )
96 | )
97 | model.build()
98 |
99 | return model
100 |
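A minimal sketch of building a probe directly, assuming a hypothetical 512-dimensional embedding and 10 classes; with `hidden_units: [infer]`, the single hidden layer is sized on the line between the embedding size and the number of classes ((512 - 10) / 2 + 10 = 261, rounded up to the even 262):

    # hypothetical probe config; the keys mirror the defaults filled in by load_config()
    model_cfg = {
        "type": "classifier",
        "hidden_units": ["infer"],       # resolved to [262] for dim=(512,) and n_classes=10
        "output_activation": "softmax",
        "weight_decay": 1e-5,
    }
    probe = get_model(model_cfg, dim=(512,), n_classes=10)
    probe.summary()  # Dense(262, relu) -> Dense(10, softmax)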
--------------------------------------------------------------------------------
/mir_ref/train.py:
--------------------------------------------------------------------------------
1 | """Train downstream models.
2 | """
3 |
4 | from pathlib import Path
5 | 
6 | import keras
7 | import numpy as np
8 | import tensorflow as tf
9 | import yaml
10 | from colorama import Fore, Style
11 | from sklearn.model_selection import ParameterGrid
12 | from tensorflow.keras import losses, optimizers
13 | from tensorflow.keras.callbacks import EarlyStopping, TensorBoard
14 | from mir_ref.dataloaders import DataGenerator
15 | from mir_ref.datasets.dataset import get_dataset
16 | from mir_ref.probes.probe_builder import get_model
17 | from mir_ref.utils import load_config
18 |
19 |
20 | def train(cfg_path, run_id=None):
21 | """Make a grid of all combinations of dataset, embedding models,
22 | downstream models and splits and call training for each.
23 |
24 | Args:
25 | cfg_path (str): Path to config file.
26 | run_id (str, optional): Experiment ID, timestamp if not specified.
27 | """
28 | cfg = load_config(cfg_path)
29 |
30 | if run_id is None:
31 | import datetime
32 |
33 | run_id = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
34 |
35 | for exp_cfg in cfg["experiments"]:
36 | run_params = {
37 | "dataset_cfg": exp_cfg["datasets"],
38 | "feature": exp_cfg["features"],
39 | "model_cfg": exp_cfg["probes"],
40 | }
41 |
42 | # create grid from parameters
43 | grid = ParameterGrid(run_params)
44 |
45 | for params in grid:
46 | # get index of downstream model for naming-logging
47 | model_idx = exp_cfg["probes"].index(params["model_cfg"])
48 | # get dataset object
49 | dataset = get_dataset(
50 | dataset_cfg=params["dataset_cfg"],
51 | task_cfg=exp_cfg["task"],
52 | features_cfg=exp_cfg["features"],
53 | )
54 | dataset.download()
55 |             # NOTE: metric objects are reused across grid runs; this assumes one experiment per config
56 | if dataset.task_type == "multiclass_classification":
57 | try:
58 | for metric in metrics:
59 | metric.reset_states()
60 | except UnboundLocalError:
61 | metrics = [
62 | keras.metrics.CategoricalAccuracy(),
63 | keras.metrics.Precision(),
64 | keras.metrics.Recall(),
65 | keras.metrics.AUC(),
66 | ]
67 | elif dataset.task_type == "multilabel_classification":
68 | try:
69 | for metric in metrics:
70 | metric.reset_states()
71 | except UnboundLocalError:
72 | metrics = [
73 | keras.metrics.Precision(),
74 | keras.metrics.Recall(),
75 | keras.metrics.AUC(curve="ROC"),
76 | keras.metrics.AUC(curve="PR"),
77 | ]
78 | # run task for every split
79 | for split_idx in range(len(dataset.get_splits())):
80 | print(
81 | Fore.GREEN
82 | + f"Task: {dataset.task_name}\n"
83 | + f"└── Dataset: {dataset.name}\n"
84 | + f" └── Embeddings: {params['feature']}\n"
85 | + f" └── Model: {model_idx}\n"
86 | + f" └── Split: {split_idx}",
87 | Style.RESET_ALL,
88 | )
89 | train_probe(
90 | run_id=run_id,
91 | dataset=dataset,
92 | model_cfg=params["model_cfg"],
93 | model_idx=model_idx,
94 | feature=params["feature"],
95 | split_idx=split_idx,
96 | metrics=metrics,
97 | )
98 |
99 |
100 | def train_probe(run_id, dataset, model_cfg, model_idx, feature, split_idx, metrics):
101 | """Train a single model per split given parameters.
102 |
103 | Args:
104 | run_id (str): ID of the current run, defaults to timestamp.
105 | dataset (Dataset): Dataset object.
106 | model_cfg (dict): Downstream model config.
107 | model_idx (int): Index of the downstream model in the list of models.
108 | feature (str): Name of the embedding model.
109 | split_idx (int): Index of the split in the list of splits.
110 | metrics (list): List of metrics to use for training.
111 | """
112 |
113 | split = dataset.get_splits()[split_idx]
114 | n_classes = len(dataset.encoded_labels[dataset.track_ids[0]])
115 |
116 | if model_cfg["emb_shape"] == "infer":
117 | # get embedding shape from the first embedding
118 | emb_shape = np.load(
119 | dataset.get_embedding_path(
120 | feature=feature,
121 | track_id=split["train"][0],
122 | )
123 | ).shape
124 | elif isinstance(model_cfg["emb_shape"], int):
125 | emb_shape = model_cfg["emb_shape"]
126 | elif isinstance(model_cfg["emb_shape"], str):
127 | raise ValueError(f"{model_cfg['emb_shape']} not implemented.")
128 |
129 | model = get_model(model_cfg=model_cfg, dim=emb_shape, n_classes=n_classes)
130 | model.summary()
131 |
132 | tr_gen = DataGenerator(
133 | ids_list=split["train"],
134 | labels_dict=dataset.encoded_labels,
135 | paths_dict={
136 | t_id: dataset.get_embedding_path(feature=feature, track_id=t_id)
137 | for t_id in split["train"]
138 | },
139 | batch_size=model_cfg["batch_size"],
140 | dim=emb_shape,
141 | n_classes=n_classes,
142 | shuffle=True,
143 | )
144 | val_gen = DataGenerator(
145 | ids_list=split["validation"],
146 | labels_dict=dataset.encoded_labels,
147 | paths_dict={
148 | t_id: dataset.get_embedding_path(feature=feature, track_id=t_id)
149 | for t_id in split["validation"]
150 | },
151 | batch_size=1,
152 | dim=emb_shape,
153 | n_classes=n_classes,
154 | shuffle=True,
155 | )
156 |
157 | # dir for tensorboard logs and weights for this run
158 | run_dir = (
159 | Path(run_id) / dataset.task_name / f"{dataset.name}_{feature}_model-{model_idx}"
160 | )
161 |
162 |     # if the dir already exists, change run dir for duplicates
163 |     if (Path("./logs") / run_dir).exists():
164 |         import warnings
165 | 
166 |         i = 1
167 |         while (Path("./logs") / run_dir).exists():
168 |             run_dir = (
169 |                 Path(run_id)
170 |                 / dataset.task_name
171 |                 / f"{dataset.name}_{feature}_model-{model_idx} ({i})"
172 |             )
173 |             i += 1
174 |         # warn about the existing run dir once a free name has been found
175 |         warnings.warn(
176 |             f"Model in '{dataset.name}_{feature}_model-{model_idx}' already exists. "
177 |             f"Renaming new run to '{run_dir}'",
178 |             stacklevel=2,
179 |         )
180 |
181 | # create dir for this run if it doesn't exist
182 | if not (Path("./logs") / run_dir).exists():
183 | (Path("./logs") / run_dir).mkdir(parents=True)
184 |
185 | # save model config in run dir
186 | with open(Path("./logs") / run_dir / "model_config.yml", "w+") as f:
187 | yaml.dump(model_cfg, f)
188 |
189 | if dataset.task_type == "multiclass_classification":
190 | checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
191 | save_weights_only=False,
192 | filepath=str(Path("./logs") / run_dir / "weights.h5"),
193 | save_best_only=True,
194 | monitor="val_categorical_accuracy",
195 | mode="max",
196 | )
197 | elif dataset.task_type == "multilabel_classification":
198 | checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
199 | save_weights_only=False,
200 | filepath=str(Path("./logs") / run_dir / "weights.h5"),
201 | save_best_only=True,
202 | monitor="val_auc",
203 | mode="max",
204 | )
205 | callbacks = [
206 | EarlyStopping(patience=model_cfg["patience"]),
207 | TensorBoard(log_dir=str(Path("./logs") / run_dir)),
208 | checkpoint_callback,
209 | ]
210 |
211 | # make sure all metric and callback states are reset
212 | for callback in callbacks:
213 | if hasattr(callback, "reset_state"):
214 | callback.reset_state()
215 |
216 | # loss and optimizer
217 | if dataset.task_type == "multiclass_classification":
218 | loss = losses.CategoricalCrossentropy(from_logits=False)
219 | elif dataset.task_type == "multilabel_classification":
220 | loss = losses.BinaryCrossentropy()
221 | else:
222 | raise ValueError(f"Task type '{dataset.task_type}' not implemented.")
223 |
224 | if model_cfg["optimizer"] == "adam":
225 | optimizer = optimizers.Adam(learning_rate=model_cfg["learning_rate"])
226 | else:
227 | raise ValueError(f"Optimizer '{model_cfg['optimizer']}' not implemented.")
228 |
229 | model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
230 |
231 | model.fit(
232 | x=tr_gen,
233 | validation_data=val_gen,
234 | batch_size=model_cfg["batch_size"],
235 | validation_batch_size=1,
236 | epochs=model_cfg["epochs"],
237 | callbacks=callbacks,
238 | use_multiprocessing=True,
239 | workers=4,
240 | verbose=1,
241 | )
242 |
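The `ParameterGrid` call in `train()` simply takes the Cartesian product of datasets, features and probes, yielding one training run per combination; a small self-contained sketch with hypothetical values:

    from sklearn.model_selection import ParameterGrid

    run_params = {
        "dataset_cfg": [{"name": "tinysol"}],
        "feature": ["vggish-audioset", "openl3"],
        "model_cfg": [{"hidden_units": [64]}, {"hidden_units": [96, 64]}],
    }
    # 1 dataset x 2 features x 2 probes -> 4 runs
    for params in ParameterGrid(run_params):
        print(params["feature"], params["model_cfg"]["hidden_units"])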
--------------------------------------------------------------------------------
/mir_ref/utils.py:
--------------------------------------------------------------------------------
1 | """Various shared utilities."""
2 |
3 | import yaml
4 |
5 |
6 | def raise_missing_param(param, exp_idx, parent=None):
7 | """Raise an error for a missing parameter."""
8 | if not parent:
9 | raise ValueError(
10 | f"Missing required parameter: '{param}' in experiment {exp_idx}."
11 | )
12 | else:
13 | raise ValueError(
14 | f"Missing required parameter: '{param}' in '{parent}' of experiment {exp_idx}."
15 | )
16 |
17 |
18 | def load_config(cfg_path):
19 | """Load a YAML config file. Check formatting, and add
20 | missing keys."""
21 | with open(cfg_path, "r") as f:
22 | cfg = yaml.safe_load(f)
23 |
24 | if "experiments" not in cfg:
25 | raise ValueError("'experiments' missing, please check config file structure.")
26 |
27 | for i, exp in enumerate(cfg["experiments"]):
28 | # check top-level parameters
29 | for top_level_param in [
30 | "task",
31 | "datasets",
32 | "features",
33 | "probes",
34 | ]:
35 | if top_level_param not in exp:
36 | raise_missing_param(param=top_level_param, exp_idx=i)
37 |
38 | # check task parameters
39 | for task_param in ["name", "type"]:
40 | if task_param not in exp["task"]:
41 | raise_missing_param(param=task_param, exp_idx=i, parent="task")
42 | if "feature_aggregation" not in exp["task"]:
43 | exp["task"]["feature_aggregation"] = "mean"
44 |
45 | # check dataset parameters
46 | for j, dataset in enumerate(exp["datasets"]):
47 | for dataset_param in ["name", "dir"]:
48 | if dataset_param not in dataset:
49 | raise_missing_param(
50 | param=dataset_param, exp_idx=i, parent="datasets"
51 | )
52 |             if "split_type" not in dataset:
53 |                 cfg["experiments"][i]["datasets"][j]["split_type"] = "random"
54 |             if "deformations" not in dataset:
55 |                 cfg["experiments"][i]["datasets"][j]["deformations"] = []
56 |
57 | # check downstream model parameters
58 | for j, model in enumerate(exp["probes"]):
59 | for model_param in ["type"]:
60 | if model_param not in model:
61 | raise_missing_param(param=model_param, exp_idx=i, parent="probes")
62 |             if "emb_dim_reduction" not in model:
63 |                 cfg["experiments"][i]["probes"][j]["emb_dim_reduction"] = None
64 |             if "emb_shape" not in model:
65 |                 cfg["experiments"][i]["probes"][j]["emb_shape"] = None
66 |             if "hidden_units" not in model:
67 |                 cfg["experiments"][i]["probes"][j]["hidden_units"] = []
68 |             if "output_activation" not in model:
69 |                 if exp["task"]["type"] == "multiclass_classification":
70 |                     cfg["experiments"][i]["probes"][j]["output_activation"] = "softmax"
71 |                 elif exp["task"]["type"] == "multilabel_classification":
72 |                     cfg["experiments"][i]["probes"][j]["output_activation"] = "sigmoid"
73 |             if "weight_decay" not in model:
74 |                 cfg["experiments"][i]["probes"][j]["weight_decay"] = 0.0
75 |             if "optimizer" not in model:
76 |                 cfg["experiments"][i]["probes"][j]["optimizer"] = "adam"
77 |             if "learning_rate" not in model:
78 |                 cfg["experiments"][i]["probes"][j]["learning_rate"] = 1e-3
79 |             if "batch_size" not in model:
80 |                 cfg["experiments"][i]["probes"][j]["batch_size"] = 1
81 |             if "epochs" not in model:
82 |                 cfg["experiments"][i]["probes"][j]["epochs"] = 100
83 |             if "patience" not in model:
84 |                 cfg["experiments"][i]["probes"][j]["patience"] = 10
85 |             if "train_sampling" not in model:
86 |                 cfg["experiments"][i]["probes"][j]["train_sampling"] = "random"
87 |
88 | return cfg
89 |
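A minimal sketch of how `load_config` fills in defaults for an under-specified probe; the config content below is hypothetical and only includes the required keys:

    import tempfile
    import yaml

    # hypothetical minimal config: only the required keys are present
    minimal_cfg = {
        "experiments": [
            {
                "task": {"name": "instrument_classification", "type": "multiclass_classification"},
                "datasets": [{"name": "tinysol", "dir": "data/tinysol/"}],
                "features": ["vggish-audioset"],
                "probes": [{"type": "classifier"}],
            }
        ]
    }
    with tempfile.NamedTemporaryFile("w", suffix=".yml", delete=False) as f:
        yaml.safe_dump(minimal_cfg, f)

    cfg = load_config(f.name)
    probe = cfg["experiments"][0]["probes"][0]
    print(probe["optimizer"], probe["learning_rate"], probe["batch_size"])  # adam 0.001 1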
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | audiomentations
2 | colorama
3 | essentia-tensorflow==2.1b6.dev1110
4 | librosa
5 | pyyaml
6 | scikit-learn
7 | tensorflow==2.15
8 | torch==2.0.1
9 | torchaudio==2.0.2
10 | torchvision==0.15.2
11 | tqdm
12 | wget
13 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | """Script to invoke embedding generation and evaluation.
2 | """
3 |
4 | import argparse
5 | import os
6 |
7 | from mir_ref.conduct import conduct
8 | from mir_ref.deform import deform
9 | from mir_ref.evaluate import evaluate
10 | from mir_ref.extract import generate
11 | from mir_ref.train import train
12 |
13 |
14 | def main():
15 | parser = argparse.ArgumentParser()
16 |
17 | subparsers = parser.add_subparsers(dest="command")
18 |
19 | # End to end
20 | parser_conduct = subparsers.add_parser("conduct")
21 | parser_conduct.add_argument(
22 | "--config",
23 | "-c",
24 | default="configs/default.yml",
25 | help="Path of configuration file.",
26 | )
27 |
28 | # Audio deformation
29 | parser_deform = subparsers.add_parser("deform")
30 | parser_deform.add_argument(
31 | "--config",
32 | "-c",
33 | default="configs/default.yml",
34 | help="Path of configuration file.",
35 | )
36 | parser_deform.add_argument(
37 | "--n_jobs", default=1, type=int, help="Number of parallel jobs"
38 | )
39 |
40 | # Feature extraction
41 | parser_extract = subparsers.add_parser("extract")
42 | parser_extract.add_argument(
43 | "--config",
44 | "-c",
45 | default="configs/default.yml",
46 | help="Path of configuration file.",
47 | )
48 | parser_extract.add_argument(
49 | "--skip_clean",
50 | action="store_true",
51 | help="Skip extracting features from the clean audio.",
52 | )
53 | parser_extract.add_argument(
54 | "--skip_deformed",
55 | action="store_true",
56 | help="Skip extracting features from the deformed audio.",
57 | )
58 | parser_extract.add_argument(
59 | "--no_overwrite",
60 | action="store_true",
61 | help="Skip extracting features if they already exist.",
62 | )
63 | parser_extract.add_argument(
64 | "--deform_list",
65 | default=None,
66 | help="Deformation scenario indices to extract features for. Arguments as comma-separated integers, e.g. 0,1,2,3",
67 | )
68 |
69 | # Training
70 | parser_train = subparsers.add_parser("train")
71 | parser_train.add_argument(
72 | "--config",
73 | "-c",
74 | default="configs/default.yml",
75 | help="Path of configuration file.",
76 | )
77 | parser_train.add_argument(
78 | "--run_id",
79 | default=None,
80 | help="Optional experiment ID, otherwise timestamp is used.",
81 | )
82 |
83 | # Evaluation
84 | parser_evaluate = subparsers.add_parser("evaluate")
85 | parser_evaluate.add_argument(
86 | "--config",
87 | "-c",
88 | default="configs/default.yml",
89 | help="Path of configuration file.",
90 | )
91 | parser_evaluate.add_argument(
92 | "--run_id",
93 | default=None,
94 | help="Experiment ID to evaluate, otherwise retrieves latest if timestamp is available.",
95 | )
96 |
97 | args = parser.parse_args()
98 |
99 | if args.command == "conduct":
100 |         conduct(cfg_path=args.config)
101 |
102 | if args.command == "deform":
103 | deform(
104 |             cfg_path=args.config,
105 | n_jobs=args.n_jobs,
106 | )
107 | elif args.command == "extract":
108 | if args.deform_list:
109 | args.deform_list = [int(i) for i in args.deform_list.split(",")]
110 | generate(
111 |             cfg_path=args.config,
112 | skip_clean=args.skip_clean,
113 | skip_deformed=args.skip_deformed,
114 | no_overwrite=args.no_overwrite,
115 | deform_list=args.deform_list,
116 | )
117 | elif args.command == "train":
118 | train(
119 |             cfg_path=args.config,
120 | run_id=args.run_id,
121 | )
122 | elif args.command == "evaluate":
123 | evaluate(
124 |             cfg_path=args.config,
125 | run_id=args.run_id,
126 | )
127 |
128 |
129 | if __name__ == "__main__":
130 | main()
131 |
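The subcommands above are thin wrappers around the package functions; under the same assumptions the pipeline can also be driven directly from Python (the config path and run id below are placeholders):

    from mir_ref.deform import deform
    from mir_ref.extract import generate
    from mir_ref.train import train
    from mir_ref.evaluate import evaluate

    cfg = "configs/example.yml"
    deform(cfg_path=cfg, n_jobs=4)                    # 1. create deformed copies of the audio
    generate(cfg_path=cfg, skip_clean=False,          # 2. extract clean + deformed embeddings
             skip_deformed=False, no_overwrite=False, deform_list=None)
    train(cfg_path=cfg, run_id="demo")                # 3. train the probes
    evaluate(cfg_path=cfg, run_id="demo")             # 4. evaluate and report results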
--------------------------------------------------------------------------------
/tests/test_cfg.yml:
--------------------------------------------------------------------------------
1 | # Configuration file for tests
2 | experiments:
3 | - task: # single task per experiment
4 | name: instrument_classification
5 | type: multiclass_classification
6 | feature_aggregation: mean
7 | datasets:
8 | - name: tinysol
9 | type: mirdata
10 | dir: tests/data/tinysol/
11 | split_type: single
12 | deformations:
13 | - - type: AddGaussianSNR
14 | params:
15 | min_snr_db: 15
16 | max_snr_db: 15
17 | p: 1
18 | - type: Gain
19 | params:
20 | min_gain_db: -12
21 | max_gain_db: -12
22 | p: 1
23 | - - type: Mp3Compression
24 | params:
25 | min_bitrate: 32
26 | max_bitrate: 32
27 | p: 1
28 | features:
29 | - vggish-audioset
30 | - effnet-discogs
31 | - msd-musicnn
32 | - openl3
33 | - neuralfp
34 | probes:
35 | - type: classifier
36 | emb_dim_reduction: False
37 | emb_shape: infer
38 | hidden_units: [infer]
39 | output_activation: softmax
40 | weight_decay: 1.0e-5
41 | # optimizer
42 | optimizer: adam
43 | learning_rate: 1.0e-2
44 | # training
45 | batch_size: 16
46 | epochs: 2
47 | patience: 10
48 | train_sampling: random
49 | - type: classifier
50 | emb_dim_reduction: False
51 | emb_shape: infer
52 | hidden_units: [infer, infer]
53 | output_activation: softmax
54 | weight_decay: 1.0e-5
55 | # optimizer
56 | optimizer: adam
57 | learning_rate: 1.0e-2
58 | # training
59 | batch_size: 16
60 | epochs: 2
61 | patience: 10
62 | train_sampling: random
63 | - type: classifier
64 | emb_dim_reduction: False
65 | emb_shape: infer
66 | hidden_units: [64]
67 | output_activation: softmax
68 | weight_decay: 1.0e-5
69 | # optimizer
70 | optimizer: adam
71 | learning_rate: 1.0e-2
72 | # training
73 | batch_size: 16
74 | epochs: 50
75 | patience: 10
76 | train_sampling: random
77 | - type: classifier
78 | emb_dim_reduction: False
79 | emb_shape: infer
80 | hidden_units: [96, 64]
81 | output_activation: softmax
82 | weight_decay: 1.0e-5
83 | # optimizer
84 | optimizer: adam
85 | learning_rate: 1.0e-2
86 | # training
87 | batch_size: 16
88 | epochs: 50
89 | patience: 10
90 | train_sampling: random
91 |
--------------------------------------------------------------------------------
/tests/test_datasets.py:
--------------------------------------------------------------------------------
1 | """Tests for dataset objects."""
2 |
3 | import sys
4 | from pathlib import Path
5 |
6 | import numpy as np
7 | import yaml
8 |
9 | sys.path.append(str(Path(__file__).resolve().parent.parent))
10 |
11 | from mir_ref.datasets.dataset import get_dataset
12 |
13 | # load configuration file, used in all tests
14 | # it uses the tinysol dataset, implemented with mirdata
15 | with open("./tests/test_cfg.yml", "r") as f:
16 | exp_cfg = yaml.safe_load(f)["experiments"][0]
17 |
18 | dataset = get_dataset(
19 | exp_cfg["datasets"][0],
20 | exp_cfg["task"],
21 | exp_cfg["features"],
22 | )
23 |
24 |
25 | def test_download_metadata():
26 | dataset.download_metadata()
27 | assert Path("./tests/data/tinysol/annotation").is_dir()
28 | assert Path("./tests/data/tinysol/annotation/TinySOL_metadata.csv").exists()
29 |
30 |
31 | def test_load_metadata():
32 | dataset.load_metadata()
33 | assert (
34 | dataset.common_audio_dir == "tests/data/tinysol/audio"
35 | or dataset.common_audio_dir == "tests/data/tinysol/audio/"
36 | )
37 | assert dataset.track_ids[0] == "BTb-ord-F#1-pp-N-N"
38 | assert dataset.labels["BTb-ord-F#1-pp-N-N"] == "Bass Tuba"
39 | assert len(set(dataset.labels.values())) == 14
40 | categorical_encoded_label = list(np.zeros((14)).astype(np.float32))
41 | categorical_encoded_label[2] = 1.0
42 | assert np.array_equal(
43 | dataset.encoded_labels["BTb-ord-F#1-pp-N-N"], categorical_encoded_label
44 | )
45 | assert (
46 | dataset.audio_paths["BTb-ord-F#1-pp-N-N"]
47 | == "tests/data/tinysol/audio/Brass/Bass_Tuba/ordinario/"
48 | + "BTb-ord-F#1-pp-N-N.wav"
49 | )
50 |
51 |
52 | def test_extended_metadata():
53 | dataset.load_metadata()
54 |
55 | deformed_audio_path = dataset.get_deformed_audio_path("BTb-ord-F#1-pp-N-N", 0)
56 | assert deformed_audio_path == (
57 | "tests/data/tinysol/audio_deformed/Brass/Bass_Tuba/ordinario/"
58 | + "BTb-ord-F#1-pp-N-N_deform_0.wav"
59 | )
60 | emb_path = dataset.get_embedding_path("BTb-ord-F#1-pp-N-N", "vggish-audioset")
61 | assert emb_path == (
62 | "tests/data/tinysol/embeddings/vggish-audioset/Brass/Bass_Tuba/ordinario/"
63 | + "BTb-ord-F#1-pp-N-N.npy"
64 | )
65 | deformed_emb_path = dataset.get_deformed_embedding_path(
66 | "BTb-ord-F#1-pp-N-N", 0, "vggish-audioset"
67 | )
68 | assert deformed_emb_path == (
69 | "tests/data/tinysol/embeddings/vggish-audioset/Brass/Bass_Tuba/ordinario/"
70 | + "BTb-ord-F#1-pp-N-N_deform_0.npy"
71 | )
72 |
73 |
74 | # test_load_metadata()
75 |
--------------------------------------------------------------------------------
/tests/test_deform.py:
--------------------------------------------------------------------------------
1 | """Tests for audio deformations."""
2 |
3 | import shutil
4 | import sys
5 | from pathlib import Path
6 |
7 | import yaml
8 |
9 | sys.path.append(str(Path(__file__).resolve().parent.parent))
10 |
11 | with open("./tests/test_cfg.yml", "r") as f:
12 | exp_cfg = yaml.safe_load(f)["experiments"][0]
13 |
14 | # we need to load and download the dataset before we can test the deformations
15 | from mir_ref.datasets.dataset import get_dataset
16 | from mir_ref.deformations import generate_deformations
17 |
18 | dataset = get_dataset(
19 | dataset_cfg=exp_cfg["datasets"][0],
20 | task_cfg=exp_cfg["task"],
21 | features_cfg=exp_cfg["features"],
22 | )
23 |
24 | dataset.download()
25 | dataset.preprocess()
26 | dataset.load_metadata()
27 |
28 | # keep only a few track_ids for testing
29 | # this also tests if everything is correctly anchored to the dataset object
30 | # and its track_ids
31 | # unfortunately we can't currently download only a few tracks from a mirdata dataset
32 | dataset.track_ids = dataset.track_ids[:5]
33 | first_track_id = dataset.track_ids[0]
34 |
35 |
36 | def test_single_threaded_deformations():
37 | generate_deformations(
38 | dataset,
39 | n_jobs=1,
40 | )
41 | assert Path(
42 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=0)
43 | ).exists()
44 | assert Path(
45 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=1)
46 | ).exists()
47 | assert not Path(
48 | dataset.audio_paths[first_track_id].replace(".wav", "_deform_2.wav")
49 | ).exists()
50 | assert not Path(
51 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=0).replace(
52 | ".wav", "_deform_0.wav"
53 | )
54 | ).exists()
55 |
56 | # delete computed deformations
57 | shutil.rmtree("tests/data/tinysol/audio_deformed")
58 |
59 |
60 | def test_multi_threaded_deformations():
61 | generate_deformations(
62 | dataset,
63 | n_jobs=2,
64 | )
65 | assert Path(
66 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=0)
67 | ).exists()
68 | assert Path(
69 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=1)
70 | ).exists()
71 | assert not Path(
72 | dataset.audio_paths[first_track_id].replace(".wav", "_deform_2.wav")
73 | ).exists()
74 | assert not Path(
75 | dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=0).replace(
76 | ".wav", "_deform_0.wav"
77 | )
78 | ).exists()
79 |
80 | # delete computed deformations
81 | shutil.rmtree("tests/data/tinysol/audio_deformed")
82 |
--------------------------------------------------------------------------------
/tests/test_generate.py:
--------------------------------------------------------------------------------
1 | """Tests embedding inference."""
2 |
3 | import shutil
4 | import sys
5 | from pathlib import Path
6 |
7 | import yaml
8 |
9 | sys.path.append(str(Path(__file__).resolve().parent.parent))
10 |
11 | with open("./tests/test_cfg.yml", "r") as f:
12 | exp_cfg = yaml.safe_load(f)["experiments"][0]
13 |
14 | # we need to load and download the dataset before we can test embedding extraction
15 | from mir_ref.datasets.dataset import get_dataset
16 | from mir_ref.features.feature_extraction import generate_embeddings
17 |
18 | dataset = get_dataset(
19 | dataset_cfg=exp_cfg["datasets"][0],
20 | task_cfg=exp_cfg["task"],
21 | features_cfg=exp_cfg["features"],
22 | )
23 |
24 | dataset.download()
25 | dataset.preprocess()
26 | dataset.load_metadata()
27 |
28 | # keep only a few track_ids for testing
29 | # this also tests if everything is correctly anchored to the dataset object
30 | # and its track_ids
31 | # unfortunately we can't currently download only a few tracks from a mirdata dataset
32 | dataset.track_ids = dataset.track_ids[:5]
33 | first_track_id = dataset.track_ids[0]
34 |
35 |
36 | def test_models():
37 | dataset.deformations_cfg = None
38 | for model_name in exp_cfg["features"]:
39 | generate_embeddings(
40 | dataset,
41 | model_name=model_name,
42 | )
43 |
44 | assert Path(
45 | dataset.get_embedding_path(track_id=first_track_id, feature=model_name)
46 | ).exists()
47 |
48 | # delete computed embeddings
49 | shutil.rmtree("tests/data/tinysol/embeddings")
50 |
51 |
52 | def test_models_deformed_audio():
53 | dataset.deformations_cfg = exp_cfg["deformations"]
54 | # we need to first compute the deformations
55 | from mir_ref.deformations import generate_deformations
56 |
57 | generate_deformations(
58 | dataset,
59 | n_jobs=2,
60 | )
61 |
62 | for model_name in exp_cfg["features"]:
63 | generate_embeddings(
64 | dataset,
65 | model_name=model_name,
66 | )
67 | assert Path(
68 | dataset.get_embedding_path(track_id=first_track_id, feature=model_name)
69 | ).exists()
70 |         assert Path(
71 |             dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=0)
72 |         ).exists()
73 |         assert Path(
74 |             dataset.get_deformed_audio_path(track_id=first_track_id, deform_idx=1)
75 |         ).exists()
76 |
77 | # delete computed deformations
78 | shutil.rmtree("tests/data/tinysol/audio_deformed")
79 |
80 | # delete computed embeddings
81 | shutil.rmtree("tests/data/tinysol/embeddings")
82 |
83 |
84 | # test_models()
85 |
--------------------------------------------------------------------------------