├── .github
│   └── workflows
│       └── clmr.yml
├── .gitignore
├── LICENSE
├── README.md
├── _config.yml
├── clmr
│   ├── data.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── audio.py
│   │   ├── dataset.py
│   │   ├── gtzan.py
│   │   ├── librispeech.py
│   │   ├── magnatagatune.py
│   │   └── million_song_dataset.py
│   ├── evaluation.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── sample_cnn.py
│   │   ├── sample_cnn_xl.py
│   │   ├── shortchunk_cnn.py
│   │   └── sinc_net.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── callbacks.py
│   │   ├── contrastive_learning.py
│   │   ├── linear_evaluation.py
│   │   └── supervised_learning.py
│   └── utils
│       ├── __init__.py
│       ├── checkpoint.py
│       └── yaml_config_hook.py
├── config
│   └── config.yaml
├── examples
│   ├── .DS_Store
│   └── clmr-onnxruntime-web
│       ├── .DS_Store
│       ├── clmr_sample-cnn.onnx
│       ├── data
│       │   ├── .DS_Store
│       │   ├── fisher_losing_it.mp3
│       │   ├── fisher_losing_it_22050.mp3
│       │   ├── john_lennon_imagine.mp3
│       │   ├── john_lennon_imagine_22050.mp3
│       │   ├── nirvana_smells_like_teen_spirit.mp3
│       │   ├── nirvana_smells_like_teen_spirit_22050.mp3
│       │   ├── queen_love_of_my_life.mp3
│       │   └── queen_love_of_my_life_22050.mp3
│       └── index.html
├── export.py
├── linear_evaluation.py
├── main.py
├── media
│   └── clmr_model.png
├── preprocess.py
├── requirements.txt
├── setup.py
└── tests
    ├── .DS_Store
    ├── __init__.py
    ├── data
    │   └── audioset
    │       └── 1272-128104-0000.wav
    ├── test_audioset.py
    ├── test_dataset.py
    └── test_spectogram.py
/.github/workflows/clmr.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: CLMR
5 |
6 | on:
7 | push:
8 | branches: [ master ]
9 | pull_request:
10 | branches: [ master ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 |
17 | steps:
18 | - uses: actions/checkout@v2
19 | - name: Set up Python 3.9
20 | uses: actions/setup-python@v2
21 | with:
22 | python-version: 3.9
23 | - name: Install dependencies
24 | run: |
25 | sudo apt-get install libsndfile1
26 | python -m pip install --upgrade pip
27 | pip install black pytest
28 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
29 | - name: Lint with black
30 | run: |
31 | black --check .
32 | - name: Test with pytest
33 | run: |
34 | pytest
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | __pycache__
3 | /data
4 | runs/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Contrastive Learning of Musical Representations
2 |
3 | PyTorch implementation of [Contrastive Learning of Musical Representations](https://arxiv.org/abs/2103.09410) by Janne Spijkervet and John Ashley Burgoyne.
4 |
5 | 
6 | [Open In Colab](https://colab.research.google.com/drive/1Njz8EoN4br587xjpRKcssMuqQY6Cc5nj#scrollTo=aeKVT59FhWzV)
7 |
8 | [arXiv paper](https://arxiv.org/abs/2103.09410)
9 | [Supplementary material](https://github.com/Spijkervet/CLMR/releases/download/2.1/CLMR.-.Supplementary.Material.pdf)
10 |
11 |
17 |
18 | You can run a pre-trained CLMR model directly in your browser using ONNX Runtime: [try it here](https://spijkervet.github.io/CLMR/examples/clmr-onnxruntime-web).
19 |
20 |
21 | In this work, we introduce SimCLR to the music domain and contribute a large chain of audio data augmentations to form a simple framework for self-supervised learning of raw waveforms of music: CLMR. We evaluate the performance of these self-supervised representations on the task of music classification.
22 |
23 | - We achieve competitive results on the MagnaTagATune and Million Song Datasets relative to fully supervised training, despite only using a linear classifier on self-supervised learned representations, i.e., representations that were learned task-agnostically without any labels.
24 | - CLMR enables efficient classification: with only 1% of the labeled data, we achieve similar scores compared to using 100% of the labeled data.
25 | - CLMR is able to generalise to out-of-domain datasets: when training on entirely different music datasets, it is still able to perform competitively compared to fully supervised training on the target dataset.
26 |
27 | *This is the CLMR v2 implementation; for the original implementation, see the [`v1`](https://github.com/Spijkervet/CLMR/tree/v1) branch.*
28 |
29 |
30 |
31 |
32 |
33 | ![An illustration of CLMR.](media/clmr_model.png)
34 |
35 |
36 |
37 | This repository relies on my SimCLR implementation, which can be found [here](https://github.com/spijkervet/simclr) and on my `torchaudio-augmentations` package, found [here](https://github.com/Spijkervet/torchaudio-augmentations).
38 |
39 |
40 |
41 | ## Quickstart
42 | ```
43 | git clone https://github.com/spijkervet/clmr.git && cd clmr
44 |
45 | pip3 install -r requirements.txt
46 | # or
47 | python3 setup.py install
48 | ```
49 |
50 | The following commands download MagnaTagATune, preprocess it, start self-supervised pre-training on 1 GPU (with 8 parallel CPU workers), and run linear evaluation:
51 | ```
52 | python3 preprocess.py --dataset magnatagatune
53 |
54 | # add --workers 8 to increase the number of parallel CPU threads to speed up online data augmentations + training.
55 | python3 main.py --dataset magnatagatune --gpus 1 --workers 8
56 |
57 | python3 linear_evaluation.py --gpus 1 --workers 8 --checkpoint_path [path to checkpoint.pt, usually in ./runs]
58 | ```
59 |
60 | ## Pre-train on your own folder of audio files
61 | Simply run the following commands to pre-train the CLMR model on a folder containing .wav files (or .mp3 files, by setting `src_ext_audio=".mp3"` in `clmr/datasets/audio.py`). You may need to convert your audio files to the correct sample rate first, before feeding them to the encoder (which expects `22,050 Hz` audio by default).
62 |
63 | ```
64 | python preprocess.py --dataset audio --dataset_dir ./directory_containing_audio_files
65 |
66 | python main.py --dataset audio --dataset_dir ./directory_containing_audio_files
67 | ```
68 |
69 |
70 | ## Results
71 |
72 | ### MagnaTagATune
73 |
74 | | Encoder / Model | Batch-size / epochs | Fine-tune head | ROC-AUC | PR-AUC |
75 | |-------------|-------------|-------------|-------------|-------------|
76 | | SampleCNN / CLMR | 48 / 10000 | Linear Classifier | 88.7 | **35.6** |
77 | | SampleCNN / CLMR | 48 / 10000 | MLP (1 extra hidden layer) | **89.3** | **36.0** |
78 | | [SampleCNN (fully supervised)](https://www.mdpi.com/2076-3417/8/1/150) | 48 / - | - | 88.6 | 34.4 |
79 | | [Pons et al. (fully supervised)](https://arxiv.org/pdf/1711.02520.pdf) | 48 / - | - | 89.1 | 34.92 |
80 |
81 | ### Million Song Dataset
82 |
83 | | Encoder / Model | Batch-size / epochs | Fine-tune head | ROC-AUC | PR-AUC |
84 | |-------------|-------------|-------------|-------------|-------------|
85 | | SampleCNN / CLMR | 48 / 1000 | Linear Classifier | 85.7 | 25.0 |
86 | | [SampleCNN (fully supervised)](https://www.mdpi.com/2076-3417/8/1/150) | 48 / - | - | **88.4** | - |
87 | | [Pons et al. (fully supervised)](https://arxiv.org/pdf/1711.02520.pdf) | 48 / - | - | 87.4 | **28.5** |
88 |
89 |
90 | ## Pre-trained models
91 | *Links go to download*
92 |
93 | | Encoder (batch-size, epochs) | Fine-tune head | Pre-train dataset | ROC-AUC | PR-AUC |
94 | |-------------|-------------|-------------|-------------|-------------|
95 | | [SampleCNN (96, 10000)](https://github.com/Spijkervet/CLMR/releases/download/2.0/clmr_checkpoint_10000.zip) | [Linear Classifier](https://github.com/Spijkervet/CLMR/releases/download/2.0/finetuner_checkpoint_200.zip) | MagnaTagATune | 88.7 (89.3) | 35.6 (36.0) |
96 | | [SampleCNN (48, 1550)](https://github.com/Spijkervet/CLMR/releases/download/1.0/clmr_checkpoint_1550.pt) | [Linear Classifier](https://github.com/Spijkervet/CLMR/releases/download/1.0-l/finetuner_checkpoint_20.pt) | MagnaTagATune | 87.71 (88.47) | 34.27 (34.96) |
97 |
98 | ## Training
99 | ### 1. Pre-training
100 | Simply run the following command to pre-train the CLMR model on the MagnaTagATune dataset.
101 | ```
102 | python main.py --dataset magnatagatune
103 | ```
104 |
105 | ### 2. Linear evaluation
106 | To test a trained model, make sure to set the `checkpoint_path` variable in `config/config.yaml`, or specify it as a command-line argument:
107 | ```
108 | python linear_evaluation.py --checkpoint_path ./clmr_checkpoint_10000.pt
109 | ```
110 |
111 | ## Configuration
112 | The training configuration can be found in `config/config.yaml`. I personally prefer to use files instead of long strings of arguments when configuring a run. Every entry in the config file can be overridden with the corresponding command-line flag (e.g., `--max_epochs 500` to train for 500 epochs).
113 |
114 | ## Logging and TensorBoard
115 | To view results in TensorBoard, run:
116 | ```
117 | tensorboard --logdir ./runs
118 | ```
119 |
--------------------------------------------------------------------------------
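The "Pre-train on your own folder of audio files" section above notes that audio may first need to be converted to 22,050 Hz. If you prefer to convert files yourself rather than relying on the ffmpeg-based preprocessing in this repo, a minimal torchaudio-only sketch (with made-up file names) could look like this:

```
import torchaudio
import torchaudio.transforms as T

# Hypothetical input/output paths; 22,050 Hz is the default sample rate mentioned in the README.
waveform, sr = torchaudio.load("song.mp3")
if sr != 22050:
    waveform = T.Resample(orig_freq=sr, new_freq=22050)(waveform)
torchaudio.save("song_22050.wav", waveform, 22050)
```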
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
--------------------------------------------------------------------------------
/clmr/data.py:
--------------------------------------------------------------------------------
1 | """Wrapper for Torch Dataset class to enable contrastive training
2 | """
3 | import torch
4 | from torch import Tensor
5 | from torch.utils.data import Dataset
6 | from torchaudio_augmentations import Compose
7 | from typing import Tuple, List
8 |
9 |
10 | class ContrastiveDataset(Dataset):
11 | def __init__(self, dataset: Dataset, input_shape: List[int], transform: Compose):
12 | self.dataset = dataset
13 | self.transform = transform
14 | self.input_shape = input_shape
15 | self.ignore_idx = []
16 |
17 | def __getitem__(self, idx) -> Tuple[Tensor, Tensor]:
18 | if idx in self.ignore_idx:
19 | return self[idx + 1]
20 |
21 | audio, label = self.dataset[idx]
22 |
23 | if audio.shape[1] < self.input_shape[1]:
24 | self.ignore_idx.append(idx)
25 | return self[idx + 1]
26 |
27 | if self.transform:
28 | audio = self.transform(audio)
29 | return audio, label
30 |
31 | def __len__(self) -> int:
32 | return len(self.dataset)
33 |
34 | def concat_clip(self, n: int, audio_length: float) -> Tensor:
35 | audio, _ = self.dataset[n]
36 | batch = torch.split(audio, audio_length, dim=1)
37 | batch = torch.cat(batch[:-1])
38 | batch = batch.unsqueeze(dim=1)
39 |
40 | if self.transform:
41 | batch = self.transform(batch)
42 |
43 | return batch
44 |
--------------------------------------------------------------------------------
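A minimal sketch of how `ContrastiveDataset` above might be used for contrastive training, assuming the `torchaudio-augmentations` package (from which `data.py` imports `Compose`) also provides `RandomResizedCrop` and `ComposeMany` as its README describes; the 59049-sample input length, dataset path, and batch size are illustrative:

```
import torch
from torch.utils.data import DataLoader
from torchaudio_augmentations import ComposeMany, RandomResizedCrop

from clmr.data import ContrastiveDataset
from clmr.datasets import get_dataset

audio_length = 59049  # matches the 3^10 total downsampling of the SampleCNN encoder in this repo

# Two independently augmented crops per clip, as SimCLR-style training requires.
transform = ComposeMany([RandomResizedCrop(n_samples=audio_length)], num_augmented_samples=2)

train_dataset = get_dataset("magnatagatune", "./data", subset="train")
contrastive_train = ContrastiveDataset(
    train_dataset, input_shape=[1, audio_length], transform=transform
)
loader = DataLoader(contrastive_train, batch_size=48, num_workers=8, shuffle=True)
```

Each batch item is then expected to contain two augmented views of the same clip, which the contrastive objective treats as a positive pair.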
/clmr/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .dataset import Dataset
3 | from .audio import AUDIO
4 | from .librispeech import LIBRISPEECH
5 | from .gtzan import GTZAN
6 | from .magnatagatune import MAGNATAGATUNE
7 | from .million_song_dataset import MillionSongDataset
8 |
9 |
10 | def get_dataset(dataset, dataset_dir, subset, download=True):
11 |
12 | if not os.path.exists(dataset_dir):
13 | os.makedirs(dataset_dir)
14 |
15 | if dataset == "audio":
16 | d = AUDIO(root=dataset_dir)
17 | elif dataset == "librispeech":
18 | d = LIBRISPEECH(root=dataset_dir, download=download, subset=subset)
19 | elif dataset == "gtzan":
20 | d = GTZAN(root=dataset_dir, download=download, subset=subset)
21 | elif dataset == "magnatagatune":
22 | d = MAGNATAGATUNE(root=dataset_dir, download=download, subset=subset)
23 | elif dataset == "msd":
24 | d = MillionSongDataset(root=dataset_dir, subset=subset)
25 | else:
26 | raise NotImplementedError("Dataset not implemented")
27 | return d
28 |
--------------------------------------------------------------------------------
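A short usage sketch for `get_dataset()` above; the directory name mirrors the README example and must already contain `.wav` files:

```
from clmr.datasets import get_dataset

# "audio" indexes every .wav file found (recursively) under the given directory;
# the subset argument is accepted but not used for this dataset type.
dataset = get_dataset("audio", "./directory_containing_audio_files", subset="train")
print(len(dataset))
```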
/clmr/datasets/audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from torch import Tensor
4 | from typing import Tuple
5 |
6 |
7 | from clmr.datasets import Dataset
8 |
9 |
10 | class AUDIO(Dataset):
11 | """Create a Dataset for any folder of audio files.
12 | Args:
13 | root (str): Path to the directory where the dataset is found or downloaded.
14 | src_ext_audio (str): The extension of the audio files to analyze.
15 | """
16 |
17 | def __init__(
18 | self,
19 | root: str,
20 | src_ext_audio: str = ".wav",
21 | n_classes: int = 1,
22 | ) -> None:
23 | super(AUDIO, self).__init__(root)
24 |
25 | self._path = root
26 | self._src_ext_audio = src_ext_audio
27 | self.n_classes = n_classes
28 |
29 | self.fl = glob(
30 | os.path.join(self._path, "**", "*{}".format(self._src_ext_audio)),
31 | recursive=True,
32 | )
33 |
34 | if len(self.fl) == 0:
35 | raise RuntimeError(
36 | "Dataset not found. Please place the audio files in the {} folder.".format(
37 | self._path
38 | )
39 | )
40 |
41 | def file_path(self, n: int) -> str:
42 | fp = self.fl[n]
43 | return fp
44 |
45 | def __getitem__(self, n: int) -> Tuple[Tensor, Tensor]:
46 | """Load the n-th sample from the dataset.
47 |
48 | Args:
49 | n (int): The index of the sample to be loaded
50 |
51 | Returns:
52 | Tuple [Tensor, Tensor]: ``(waveform, label)``
53 | """
54 | audio, _ = self.load(n)
55 | label = []
56 | return audio, label
57 |
58 | def __len__(self) -> int:
59 | return len(self.fl)
60 |
--------------------------------------------------------------------------------
/clmr/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import torchaudio
4 | from torch.utils.data import Dataset as TorchDataset
5 | from abc import abstractmethod
6 |
7 |
8 | def preprocess_audio(source, target, sample_rate):
9 | p = subprocess.Popen(
10 | ["ffmpeg", "-i", source, "-ar", str(sample_rate), target, "-loglevel", "quiet"]
11 | )
12 | p.wait()
13 |
14 |
15 | class Dataset(TorchDataset):
16 |
17 | _ext_audio = ".wav"
18 |
19 | def __init__(self, root: str):
20 | pass
21 |
22 | @abstractmethod
23 | def file_path(self, n: int):
24 | pass
25 |
26 | def target_file_path(self, n: int) -> str:
27 | fp = self.file_path(n)
28 | file_basename, _ = os.path.splitext(fp)
29 | return file_basename + self._ext_audio
30 |
31 | def preprocess(self, n: int, sample_rate: int):
32 | fp = self.file_path(n)
33 | target_fp = self.target_file_path(n)
34 |
35 | if not os.path.exists(target_fp):
36 | preprocess_audio(fp, target_fp, sample_rate)
37 |
38 | def load(self, n):
39 | target_fp = self.target_file_path(n)
40 | try:
41 | audio, sample_rate = torchaudio.load(target_fp)
42 | except OSError as e:
43 | print("File not found, try running `python preprocess.py` first.\n\n", e)
44 | return
45 | return audio, sample_rate
46 |
--------------------------------------------------------------------------------
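To illustrate the `file_path` / `preprocess` / `load` contract of the base class above, here is a minimal, hypothetical subclass (the folder layout, class name, and `.mp3` extension are made up):

```
import os
from glob import glob

from clmr.datasets import Dataset


class FolderOfMP3s(Dataset):
    def __init__(self, root: str):
        super().__init__(root)
        self.fl = sorted(glob(os.path.join(root, "*.mp3")))
        self.n_classes = 1

    def file_path(self, n: int) -> str:
        return self.fl[n]

    def __getitem__(self, n: int):
        # load() reads the .wav produced earlier by preprocess(n, sample_rate).
        audio, sample_rate = self.load(n)
        return audio, 0  # dummy label

    def __len__(self) -> int:
        return len(self.fl)
```

Calling `dataset.preprocess(n, 22050)` converts the n-th source file into a 22,050 Hz `.wav` next to the original via ffmpeg, after which `load(n)` can read it.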
/clmr/datasets/gtzan.py:
--------------------------------------------------------------------------------
1 | import torchaudio
2 | from torchaudio.datasets.gtzan import gtzan_genres
3 | from torch.utils.data import Dataset
4 |
5 |
6 | class GTZAN(Dataset):
7 |
8 | subset_map = {"train": "training", "valid": "validation", "test": "testing"}
9 |
10 | def __init__(self, root, download, subset):
11 | self.dataset = torchaudio.datasets.GTZAN(
12 | root=root, download=download, subset=self.subset_map[subset]
13 | )
14 | self.labels = gtzan_genres
15 |
16 | self.label2idx = {}
17 | for idx, label in enumerate(self.labels):
18 | self.label2idx[label] = idx
19 |
20 | self.n_classes = len(self.label2idx.keys())
21 |
22 | def __getitem__(self, idx):
23 | audio, sr, label = self.dataset[idx]
24 | label = self.label2idx[label]
25 | return audio, label
26 |
27 | def __len__(self):
28 | return len(self.dataset)
29 |
--------------------------------------------------------------------------------
/clmr/datasets/librispeech.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torchaudio
3 | from torch.utils.data import Dataset
4 |
5 |
6 | class LIBRISPEECH(Dataset):
7 |
8 | subset_map = {"train": "train-clean-100", "test": "test-clean"}
9 |
10 | def __init__(self, root, download, subset):
11 | self.dataset = torchaudio.datasets.LIBRISPEECH(
12 | root=root, download=download, url=self.subset_map[subset]
13 | )
14 |
15 | self.speaker2idx = {}
16 |
17 | if not os.path.exists(self.dataset._path):
18 | raise RuntimeError(
19 | "Dataset not found. Please use `download=True` to download it."
20 | )
21 |
22 | self.speaker_ids = list(map(int, os.listdir(self.dataset._path)))
23 | for idx, speaker_id in enumerate(sorted(self.speaker_ids)):
24 | self.speaker2idx[speaker_id] = idx
25 |
26 | self.n_classes = len(self.speaker2idx.keys())
27 |
28 | def __getitem__(self, idx):
29 | (
30 | audio,
31 | sample_rate,
32 | utterance,
33 | speaker_id,
34 | chapter_id,
35 | utterance_id,
36 | ) = self.dataset[idx]
37 | label = self.speaker2idx[speaker_id]
38 | return audio, label
39 |
40 | def __len__(self):
41 | return len(self.dataset)
42 |
--------------------------------------------------------------------------------
/clmr/datasets/magnatagatune.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | import subprocess
4 | import torch
5 | import numpy as np
6 | import zipfile
7 | from collections import defaultdict
8 | from typing import Any, Tuple, Optional
9 | from tqdm import tqdm
10 |
11 | import soundfile as sf
12 | import torchaudio
13 |
14 | torchaudio.set_audio_backend("soundfile")
15 | from torch import Tensor, FloatTensor
16 | from torchaudio.datasets.utils import (
17 | download_url,
18 | extract_archive,
19 | )
20 |
21 | from clmr.datasets import Dataset
22 |
23 |
24 | FOLDER_IN_ARCHIVE = "magnatagatune"
25 | _CHECKSUMS = {
26 | "http://mi.soi.city.ac.uk/datasets/magnatagatune/mp3.zip.001": "",
27 | "http://mi.soi.city.ac.uk/datasets/magnatagatune/mp3.zip.002": "",
28 | "http://mi.soi.city.ac.uk/datasets/magnatagatune/mp3.zip.003": "",
29 | "http://mi.soi.city.ac.uk/datasets/magnatagatune/annotations_final.csv": "",
30 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/binary.npy": "",
31 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/tags.npy": "",
32 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/test.npy": "",
33 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/train.npy": "",
34 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/valid.npy": "",
35 | "https://github.com/jordipons/musicnn-training/raw/master/data/index/mtt/train_gt_mtt.tsv": "",
36 | "https://github.com/jordipons/musicnn-training/raw/master/data/index/mtt/val_gt_mtt.tsv": "",
37 | "https://github.com/jordipons/musicnn-training/raw/master/data/index/mtt/test_gt_mtt.tsv": "",
38 | "https://github.com/jordipons/musicnn-training/raw/master/data/index/mtt/index_mtt.tsv": "",
39 | }
40 |
41 |
42 | def get_file_list(root, subset, split):
43 | if subset == "train":
44 | if split == "pons2017":
45 | fl = open(os.path.join(root, "train_gt_mtt.tsv")).read().splitlines()
46 | else:
47 | fl = np.load(os.path.join(root, "train.npy"))
48 | elif subset == "valid":
49 | if split == "pons2017":
50 | fl = open(os.path.join(root, "val_gt_mtt.tsv")).read().splitlines()
51 | else:
52 | fl = np.load(os.path.join(root, "valid.npy"))
53 | else:
54 | if split == "pons2017":
55 | fl = open(os.path.join(root, "test_gt_mtt.tsv")).read().splitlines()
56 | else:
57 | fl = np.load(os.path.join(root, "test.npy"))
58 |
59 | if split == "pons2017":
60 | binary = {}
61 | index = open(os.path.join(root, "index_mtt.tsv")).read().splitlines()
62 | fp_dict = {}
63 | for i in index:
64 | clip_id, fp = i.split("\t")
65 | fp_dict[clip_id] = fp
66 |
67 | for idx, f in enumerate(fl):
68 | clip_id, label = f.split("\t")
69 | fl[idx] = "{}\t{}".format(clip_id, fp_dict[clip_id])
70 | clip_id = int(clip_id)
71 | binary[clip_id] = eval(label)
72 | else:
73 | binary = np.load(os.path.join(root, "binary.npy"))
74 |
75 | return fl, binary
76 |
77 |
78 | class MAGNATAGATUNE(Dataset):
79 | """Create a Dataset for MagnaTagATune.
80 | Args:
81 | root (str): Path to the directory where the dataset is found or downloaded.
82 | folder_in_archive (str, optional): The top-level directory of the dataset.
83 | download (bool, optional):
84 | Whether to download the dataset if it is not found at root path. (default: ``False``).
85 | subset (str, optional): Which subset of the dataset to use.
86 | One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``.
87 | If ``None``, the entire dataset is used. (default: ``None``).
88 | """
89 |
90 | _ext_audio = ".wav"
91 |
92 | def __init__(
93 | self,
94 | root: str,
95 | folder_in_archive: Optional[str] = FOLDER_IN_ARCHIVE,
96 | download: Optional[bool] = False,
97 | subset: Optional[str] = None,
98 | split: Optional[str] = "pons2017",
99 | ) -> None:
100 |
101 | super(MAGNATAGATUNE, self).__init__(root)
102 | self.root = root
103 | self.folder_in_archive = folder_in_archive
104 | self.download = download
105 | self.subset = subset
106 | self.split = split
107 |
108 | assert subset is None or subset in ["train", "valid", "test"], (
109 | "When `subset` not None, it must take a value from "
110 | + "{'train', 'valid', 'test'}."
111 | )
112 |
113 | self._path = os.path.join(root, folder_in_archive)
114 |
115 | if download:
116 | if not os.path.isdir(self._path):
117 | os.makedirs(self._path)
118 |
119 | zip_files = []
120 | for url, checksum in _CHECKSUMS.items():
121 | target_fn = os.path.basename(url)
122 | target_fp = os.path.join(self._path, target_fn)
123 | if ".zip" in target_fp:
124 | zip_files.append(target_fp)
125 |
126 | if not os.path.exists(target_fp):
127 | download_url(
128 | url,
129 | self._path,
130 | filename=target_fn,
131 | hash_value=checksum,
132 | hash_type="md5",
133 | )
134 |
135 | if not os.path.exists(
136 | os.path.join(
137 | self._path,
138 | "f",
139 | "american_bach_soloists-j_s__bach_solo_cantatas-01-bwv54__i_aria-30-59.mp3",
140 | )
141 | ):
142 | merged_zip = os.path.join(self._path, "mp3.zip")
143 | print("Merging zip files...")
144 | with open(merged_zip, "wb") as f:
145 | for filename in zip_files:
146 | with open(filename, "rb") as g:
147 | f.write(g.read())
148 |
149 | extract_archive(merged_zip)
150 |
151 | if not os.path.isdir(self._path):
152 | raise RuntimeError(
153 | "Dataset not found. Please use `download=True` to download it."
154 | )
155 |
156 | self.fl, self.binary = get_file_list(self._path, self.subset, self.split)
157 | self.n_classes = 50 # self.binary.shape[1]
158 | # self.audio = {}
159 | # for f in tqdm(self.fl):
160 | # clip_id, fp = f.split("\t")
161 | # if clip_id not in self.audio.keys():
162 | # audio, _ = load_magnatagatune_item(fp, self._path, self._ext_audio)
163 | # self.audio[clip_id] = audio
164 |
165 | def file_path(self, n: int) -> str:
166 | _, fp = self.fl[n].split("\t")
167 | return os.path.join(self._path, fp)
168 |
169 | def __getitem__(self, n: int) -> Tuple[Tensor, Tensor]:
170 | """Load the n-th sample from the dataset.
171 |
172 | Args:
173 | n (int): The index of the sample to be loaded
174 |
175 | Returns:
176 | tuple: ``(waveform, label)``
177 | """
178 | clip_id, fp = self.fl[n].split("\t")
179 | label = self.binary[int(clip_id)]
180 |
181 | audio, _ = self.load(n)
182 | label = FloatTensor(label)
183 | return audio, label
184 |
185 | def __len__(self) -> int:
186 | return len(self.fl)
187 |
--------------------------------------------------------------------------------
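A short usage sketch for the `MAGNATAGATUNE` class above; the root path is arbitrary, `download=True` fetches several gigabytes of MP3 archives plus the split files, and indexing an item assumes `python preprocess.py` has already produced the `.wav` files:

```
from clmr.datasets import MAGNATAGATUNE

train = MAGNATAGATUNE(root="./data", download=True, subset="train")
print(len(train), train.n_classes)  # number of training clips, 50 tags

waveform, label = train[0]
print(waveform.shape, label.shape)  # roughly (1, n_samples) and a 50-dimensional tag vector
```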
/clmr/datasets/million_song_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import torch
4 | import torchaudio
5 | from collections import defaultdict
6 | from pathlib import Path
7 | from torch import Tensor, FloatTensor
8 | from tqdm import tqdm
9 | from typing import Any, Tuple, Optional
10 |
11 | from clmr.datasets import Dataset
12 |
13 |
14 | def load_id2gt(gt_file, msd_7d):
15 | ids = []
16 | with open(gt_file) as f:
17 | id2gt = dict()
18 | for line in f.readlines():
19 | msd_id, gt = line.strip().split("\t") # id is string
20 | id_7d = msd_7d[msd_id]
21 | id2gt[msd_id] = eval(gt) # gt is array
22 | ids.append(msd_id)
23 | return ids, id2gt
24 |
25 |
26 | def load_id2path(index_file, msd_7d):
27 | paths = []
28 | with open(index_file) as f:
29 | id2path = dict()
30 | for line in f.readlines():
31 | msd_id, msd_path = line.strip().split("\t")
32 | id_7d = msd_7d[msd_id]
33 | path = os.path.join(id_7d[0], id_7d[1], f"{id_7d}.clip.mp3")
34 | id2path[msd_id] = path
35 | paths.append(path)
36 | return paths, id2path
37 |
38 |
39 | def default_indexer(ids, id2audio_path, id2gt):
40 | index = []
41 | track_index = defaultdict(list)
42 | track_idx = 0
43 | clip_idx = 0
44 | for clip_id in ids:
45 | fp = id2audio_path[clip_id]
46 | label = id2gt[clip_id]
47 | track_idx = clip_id
48 | clip_id = clip_idx
49 | clip_idx += 1
50 | index.append([track_idx, clip_id, fp, label])
51 | track_index[track_idx].append([clip_id, fp, label])
52 | return index, track_index
53 |
54 |
55 | def default_loader(path):
56 | audio, sr = torchaudio.load(path)
57 | audio = audio.mean(dim=0, keepdim=True)
58 | return audio, sr
59 |
60 |
61 | class MillionSongDataset(Dataset):
62 |
63 | _base_dir = "million_song_dataset"
64 | _ext_audio = ".wav"
65 |
66 | def __init__(
67 | self,
68 | root: str,
69 | base_dir: str = _base_dir,
70 | download: bool = False,
71 | subset: Optional[str] = None,
72 | ):
73 | if download:
74 | raise Exception("The Million Song Dataset is not publicly available")
75 |
76 | self.root = root
77 | self.base_dir = base_dir
78 | self.subset = subset
79 |
80 | assert subset is None or subset in ["train", "valid", "test"], (
81 | "When `subset` not None, it must take a value from "
82 | + "{'train', 'valid', 'test'}."
83 | )
84 |
85 | self._path = os.path.join(self.root, self.base_dir)
86 |
87 | if not os.path.exists(self._path):
88 | raise RuntimeError(
89 | "Dataset not found. Please place the MSD files in the {} folder.".format(
90 | self._path
91 | )
92 | )
93 |
94 | msd_processed_annot = Path(self._path, "processed_annotations")
95 |
96 | if self.subset == "train":
97 | self.annotations_file = Path(msd_processed_annot) / "train_gt_msd.tsv"
98 | elif self.subset == "valid":
99 | self.annotations_file = Path(msd_processed_annot) / "val_gt_msd.tsv"
100 | else:
101 | self.annotations_file = Path(msd_processed_annot) / "test_gt_msd.tsv"
102 |
103 | with open(Path(msd_processed_annot) / "MSD_id_to_7D_id.pkl", "rb") as f:
104 | self.msd_to_7d = pickle.load(f)
105 |
106 | # int to label
107 | with open(Path(msd_processed_annot) / "output_labels_msd.txt", "r") as f:
108 | lines = f.readlines()
109 | self.tags = eval(lines[1][lines[1].find("[") :])
110 | self.n_classes = len(self.tags)
111 |
112 | [audio_repr_paths, id2audio_path] = load_id2path(
113 | Path(msd_processed_annot) / "index_msd.tsv", self.msd_to_7d
114 | )
115 | [ids, id2gt] = load_id2gt(self.annotations_file, self.msd_to_7d)
116 |
117 | self.index, self.track_index = default_indexer(ids, id2audio_path, id2gt)
118 |
119 | def file_path(self, n: int) -> str:
120 | _, _, fp, _ = self.index[n]
121 | return os.path.join(self._path, "preprocessed", fp)
122 |
123 | def __getitem__(self, n: int) -> Tuple[Tensor, Tensor]:
124 | track_id, clip_id, fp, label = self.index[n]
125 | label = torch.FloatTensor(label)
126 |
127 | try:
128 | audio, _ = self.load(n)
129 | except Exception as e:
130 | print(f"Skipped {track_id, fp}, could not load audio: {e}")
131 | return self.__getitem__(n + 1)
132 | return audio, label
133 |
134 | def __len__(self) -> int:
135 | return len(self.index)
136 |
--------------------------------------------------------------------------------
/clmr/evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.utils.data import Dataset
5 | from tqdm import tqdm
6 | from sklearn import metrics
7 |
8 |
9 | def evaluate(
10 | encoder: nn.Module,
11 | finetuned_head: nn.Module,
12 | test_dataset: Dataset,
13 | dataset_name: str,
14 | audio_length: int,
15 | device,
16 | ) -> dict:
17 | est_array = []
18 | gt_array = []
19 |
20 | encoder = encoder.to(device)
21 | encoder.eval()
22 |
23 | if finetuned_head is not None:
24 | finetuned_head = finetuned_head.to(device)
25 | finetuned_head.eval()
26 |
27 | with torch.no_grad():
28 | for idx in tqdm(range(len(test_dataset))):
29 | _, label = test_dataset[idx]
30 | batch = test_dataset.concat_clip(idx, audio_length)
31 | batch = batch.to(device)
32 |
33 | output = encoder(batch)
34 | if finetuned_head:
35 | output = finetuned_head(output)
36 |
37 | # we always return logits, so we need a sigmoid here for multi-label classification
38 | if dataset_name in ["magnatagatune", "msd"]:
39 | output = torch.sigmoid(output)
40 | else:
41 | output = F.softmax(output, dim=1)
42 |
43 | track_prediction = output.mean(dim=0)
44 | est_array.append(track_prediction)
45 | gt_array.append(label)
46 |
47 | if dataset_name in ["magnatagatune", "msd"]:
48 | est_array = torch.stack(est_array, dim=0).cpu().numpy()
49 | gt_array = torch.stack(gt_array, dim=0).cpu().numpy()
50 | roc_aucs = metrics.roc_auc_score(gt_array, est_array, average="macro")
51 | pr_aucs = metrics.average_precision_score(gt_array, est_array, average="macro")
52 | return {
53 | "PR-AUC": pr_aucs,
54 | "ROC-AUC": roc_aucs,
55 | }
56 |
57 | est_array = torch.stack(est_array, dim=0)
58 | _, est_array = torch.max(est_array, 1) # extract the predicted labels here.
59 | accuracy = metrics.accuracy_score(gt_array, est_array)
60 | return {"Accuracy": accuracy}
61 |
--------------------------------------------------------------------------------
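A sketch of how `evaluate()` above might be called, assuming a preprocessed test set wrapped in `ContrastiveDataset` (which provides the `concat_clip` method used here). The checkpoint-loading step is omitted, so with an untrained encoder the scores are of course meaningless; the paths and audio length are illustrative:

```
import torch

from clmr.data import ContrastiveDataset
from clmr.datasets import get_dataset
from clmr.evaluation import evaluate
from clmr.models import SampleCNN

device = "cuda" if torch.cuda.is_available() else "cpu"
audio_length = 59049

test_dataset = ContrastiveDataset(
    get_dataset("magnatagatune", "./data", subset="test"),
    input_shape=[1, audio_length],
    transform=None,  # evaluation slices whole tracks itself via concat_clip()
)
encoder = SampleCNN(strides=[3] * 9, supervised=True, out_dim=50)

results = evaluate(encoder, None, test_dataset, "magnatagatune", audio_length, device)
print(results)  # {"PR-AUC": ..., "ROC-AUC": ...}
```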
/clmr/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Model, Identity
2 | from .sample_cnn import SampleCNN
3 | from .shortchunk_cnn import ShortChunkCNN_Res
4 | from .sinc_net import SincNet
5 |
--------------------------------------------------------------------------------
/clmr/models/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 |
4 |
5 | class Model(nn.Module):
6 | def __init__(self):
7 | super(Model, self).__init__()
8 |
9 | def initialize(self, m):
10 | if isinstance(m, (nn.Conv1d)):
11 | # nn.init.xavier_uniform_(m.weight)
12 | # if m.bias is not None:
13 | # nn.init.xavier_uniform_(m.bias)
14 |
15 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
16 |
17 |
18 | class Identity(nn.Module):
19 | def __init__(self):
20 | super(Identity, self).__init__()
21 |
22 | def forward(self, x):
23 | return x
24 |
--------------------------------------------------------------------------------
/clmr/models/sample_cnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .model import Model
4 |
5 |
6 | class SampleCNN(Model):
7 | def __init__(self, strides, supervised, out_dim):
8 | super(SampleCNN, self).__init__()
9 |
10 | self.strides = strides
11 | self.supervised = supervised
12 | self.sequential = [
13 | nn.Sequential(
14 | nn.Conv1d(1, 128, kernel_size=3, stride=3, padding=0),
15 | nn.BatchNorm1d(128),
16 | nn.ReLU(),
17 | )
18 | ]
19 |
20 | self.hidden = [
21 | [128, 128],
22 | [128, 128],
23 | [128, 256],
24 | [256, 256],
25 | [256, 256],
26 | [256, 256],
27 | [256, 256],
28 | [256, 256],
29 | [256, 512],
30 | ]
31 |
32 | assert len(self.hidden) == len(
33 | self.strides
34 | ), "Number of hidden layers and strides are not equal"
35 | for stride, (h_in, h_out) in zip(self.strides, self.hidden):
36 | self.sequential.append(
37 | nn.Sequential(
38 | nn.Conv1d(h_in, h_out, kernel_size=stride, stride=1, padding=1),
39 | nn.BatchNorm1d(h_out),
40 | nn.ReLU(),
41 | nn.MaxPool1d(stride, stride=stride),
42 | )
43 | )
44 |
45 | # 1 x 512
46 | self.sequential.append(
47 | nn.Sequential(
48 | nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
49 | nn.BatchNorm1d(512),
50 | nn.ReLU(),
51 | )
52 | )
53 |
54 | self.sequential = nn.Sequential(*self.sequential)
55 |
56 | if self.supervised:
57 | self.dropout = nn.Dropout(0.5)
58 | self.fc = nn.Linear(512, out_dim)
59 |
60 | def forward(self, x):
61 | out = self.sequential(x)
62 | if self.supervised:
63 | out = self.dropout(out)
64 |
65 | out = out.reshape(x.shape[0], out.size(1) * out.size(2))
66 | logit = self.fc(out)
67 | return logit
68 |
--------------------------------------------------------------------------------
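A toy forward pass through `SampleCNN` above, as a shape sanity check. With the stride-3 input convolution followed by nine stride-3 max-pooled blocks, the total downsampling factor is 3^10 = 59049, so a 59049-sample input (about 2.7 s at 22,050 Hz) collapses to a single 512-dimensional frame before the final linear layer; the stride list and output dimension here are illustrative:

```
import torch
from clmr.models import SampleCNN

encoder = SampleCNN(strides=[3] * 9, supervised=True, out_dim=50)
waveform = torch.randn(2, 1, 59049)  # (batch, channels, samples)
print(encoder(waveform).shape)       # expected: torch.Size([2, 50])
```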
/clmr/models/sample_cnn_xl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .model import Model
4 |
5 |
6 | class SampleCNNXL(Model):
7 | def __init__(self, strides, supervised, out_dim):
8 |         super(SampleCNNXL, self).__init__()
9 |
10 | self.strides = strides
11 | self.supervised = supervised
12 | self.sequential = [
13 | nn.Sequential(
14 | nn.Conv1d(1, 128, kernel_size=3, stride=3, padding=0),
15 | nn.BatchNorm1d(128),
16 | nn.ReLU(),
17 | )
18 | ]
19 |
20 | self.hidden = [
21 | [128, 128],
22 | [128, 128],
23 | [128, 256],
24 | [256, 256],
25 | [256, 512],
26 | [512, 512],
27 | [512, 1024],
28 | [1024, 1024],
29 | [1024, 2048],
30 | ]
31 |
32 | assert len(self.hidden) == len(
33 | self.strides
34 | ), "Number of hidden layers and strides are not equal"
35 | for stride, (h_in, h_out) in zip(self.strides, self.hidden):
36 | self.sequential.append(
37 | nn.Sequential(
38 | nn.Conv1d(h_in, h_out, kernel_size=stride, stride=1, padding=1),
39 | nn.BatchNorm1d(h_out),
40 | nn.ReLU(),
41 | nn.MaxPool1d(stride, stride=stride),
42 | )
43 | )
44 |
45 |         # 1 x 2048
46 | self.sequential.append(
47 | nn.Sequential(
48 | nn.Conv1d(2048, 2048, kernel_size=3, stride=1, padding=1),
49 | nn.BatchNorm1d(2048),
50 | nn.ReLU(),
51 | )
52 | )
53 |
54 | self.sequential = nn.Sequential(*self.sequential)
55 |
56 | if self.supervised:
57 | self.dropout = nn.Dropout(0.5)
58 | self.fc = nn.Linear(2048, out_dim)
59 |
60 | def forward(self, x):
61 | out = self.sequential(x)
62 | if self.supervised:
63 | out = self.dropout(out)
64 |
65 | out = out.reshape(x.shape[0], out.size(1) * out.size(2))
66 | logit = self.fc(out)
67 | return logit
68 |
--------------------------------------------------------------------------------
/clmr/models/shortchunk_cnn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class ShortChunkCNN_Res(nn.Module):
5 | """
6 | Short-chunk CNN architecture with residual connections.
7 | """
8 |
9 | def __init__(self, n_channels=128, n_classes=50):
10 | super(ShortChunkCNN_Res, self).__init__()
11 |
12 | self.spec_bn = nn.BatchNorm2d(1)
13 |
14 | # CNN
15 | self.layer1 = Res_2d(1, n_channels, stride=2)
16 | self.layer2 = Res_2d(n_channels, n_channels, stride=2)
17 | self.layer3 = Res_2d(n_channels, n_channels * 2, stride=2)
18 | self.layer4 = Res_2d(n_channels * 2, n_channels * 2, stride=2)
19 | self.layer5 = Res_2d(n_channels * 2, n_channels * 2, stride=2)
20 | self.layer6 = Res_2d(n_channels * 2, n_channels * 2, stride=2)
21 | self.layer7 = Res_2d(n_channels * 2, n_channels * 4, stride=2)
22 |
23 | # Dense
24 | self.dense1 = nn.Linear(n_channels * 4, n_channels * 4)
25 | self.bn = nn.BatchNorm1d(n_channels * 4)
26 |
27 | self.fc = nn.Linear(n_channels * 4, n_classes)
28 | self.dropout = nn.Dropout(0.5)
29 | self.relu = nn.ReLU()
30 |
31 | def forward(self, x):
32 | x = self.spec_bn(x)
33 |
34 | # CNN
35 | x = self.layer1(x)
36 | x = self.layer2(x)
37 | x = self.layer3(x)
38 | x = self.layer4(x)
39 | x = self.layer5(x)
40 | x = self.layer6(x)
41 | x = self.layer7(x)
42 | x = x.squeeze(2)
43 |
44 | # Global Max Pooling
45 | if x.size(-1) != 1:
46 | x = nn.MaxPool1d(x.size(-1))(x)
47 | x = x.squeeze(2)
48 |
49 | # Dense
50 | x = self.dense1(x)
51 | x = self.bn(x)
52 | x = self.relu(x)
53 | x = self.dropout(x)
54 | x = self.fc(x)
55 | # x = nn.Sigmoid()(x)
56 |
57 | return x
58 |
59 |
60 | class Res_2d(nn.Module):
61 | def __init__(self, input_channels, output_channels, shape=3, stride=2):
62 | super(Res_2d, self).__init__()
63 | # convolution
64 | self.conv_1 = nn.Conv2d(
65 | input_channels, output_channels, shape, stride=stride, padding=shape // 2
66 | )
67 | self.bn_1 = nn.BatchNorm2d(output_channels)
68 | self.conv_2 = nn.Conv2d(
69 | output_channels, output_channels, shape, padding=shape // 2
70 | )
71 | self.bn_2 = nn.BatchNorm2d(output_channels)
72 |
73 | # residual
74 | self.diff = False
75 | if (stride != 1) or (input_channels != output_channels):
76 | self.conv_3 = nn.Conv2d(
77 | input_channels,
78 | output_channels,
79 | shape,
80 | stride=stride,
81 | padding=shape // 2,
82 | )
83 | self.bn_3 = nn.BatchNorm2d(output_channels)
84 | self.diff = True
85 | self.relu = nn.ReLU()
86 |
87 | def forward(self, x):
88 | # convolution
89 | out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
90 |
91 | # residual
92 | if self.diff:
93 | x = self.bn_3(self.conv_3(x))
94 | out = x + out
95 | out = self.relu(out)
96 | return out
97 |
--------------------------------------------------------------------------------
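A toy forward pass through `ShortChunkCNN_Res` above. The seven stride-2 residual blocks reduce a 128-bin (mel) spectrogram to a single frequency bin, and global max pooling collapses whatever time frames remain; the input sizes are illustrative:

```
import torch
from clmr.models import ShortChunkCNN_Res

model = ShortChunkCNN_Res(n_channels=128, n_classes=50)
spec = torch.randn(4, 1, 128, 188)  # (batch, 1, n_mels, time frames)
print(model(spec).shape)            # expected: torch.Size([4, 50])
```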
/clmr/models/sinc_net.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 | import sys
6 | from torch.autograd import Variable
7 | import math
8 |
9 |
10 | def flip(x, dim):
11 | xsize = x.size()
12 | dim = x.dim() + dim if dim < 0 else dim
13 | x = x.contiguous()
14 | x = x.view(-1, *xsize[dim:])
15 | x = x.view(x.size(0), x.size(1), -1)[
16 | :,
17 | getattr(
18 | torch.arange(x.size(1) - 1, -1, -1), ("cpu", "cuda")[x.is_cuda]
19 | )().long(),
20 | :,
21 | ]
22 | return x.view(xsize)
23 |
24 |
25 | def sinc(band, t_right):
26 | y_right = torch.sin(2 * math.pi * band * t_right) / (2 * math.pi * band * t_right)
27 | y_left = flip(y_right, 0)
28 |
29 | y = torch.cat([y_left, Variable(torch.ones(1)).cuda(), y_right])
30 |
31 | return y
32 |
33 |
34 | class SincConv_fast(nn.Module):
35 | """Sinc-based convolution
36 | Parameters
37 | ----------
38 | in_channels : `int`
39 | Number of input channels. Must be 1.
40 | out_channels : `int`
41 | Number of filters.
42 | kernel_size : `int`
43 | Filter length.
44 | sample_rate : `int`, optional
45 | Sample rate. Defaults to 16000.
46 | Usage
47 | -----
48 | See `torch.nn.Conv1d`
49 | Reference
50 | ---------
51 | Mirco Ravanelli, Yoshua Bengio,
52 | "Speaker Recognition from raw waveform with SincNet".
53 | https://arxiv.org/abs/1808.00158
54 | """
55 |
56 | @staticmethod
57 | def to_mel(hz):
58 | return 2595 * np.log10(1 + hz / 700)
59 |
60 | @staticmethod
61 | def to_hz(mel):
62 | return 700 * (10 ** (mel / 2595) - 1)
63 |
64 | def __init__(
65 | self,
66 | out_channels,
67 | kernel_size,
68 | sample_rate=16000,
69 | in_channels=1,
70 | stride=1,
71 | padding=0,
72 | dilation=1,
73 | bias=False,
74 | groups=1,
75 | min_low_hz=50,
76 | min_band_hz=50,
77 | ):
78 |
79 | super(SincConv_fast, self).__init__()
80 |
81 | if in_channels != 1:
82 | # msg = (f'SincConv only support one input channel '
83 | # f'(here, in_channels = {in_channels:d}).')
84 | msg = (
85 | "SincConv only support one input channel (here, in_channels = {%i})"
86 | % (in_channels)
87 | )
88 | raise ValueError(msg)
89 |
90 | self.out_channels = out_channels
91 | self.kernel_size = kernel_size
92 |
93 | # Forcing the filters to be odd (i.e, perfectly symmetrics)
94 | if kernel_size % 2 == 0:
95 | self.kernel_size = self.kernel_size + 1
96 |
97 | self.stride = stride
98 | self.padding = padding
99 | self.dilation = dilation
100 |
101 | if bias:
102 | raise ValueError("SincConv does not support bias.")
103 | if groups > 1:
104 | raise ValueError("SincConv does not support groups.")
105 |
106 | self.sample_rate = sample_rate
107 | self.min_low_hz = min_low_hz
108 | self.min_band_hz = min_band_hz
109 |
110 | # initialize filterbanks such that they are equally spaced in Mel scale
111 | low_hz = 30
112 | high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
113 |
114 | mel = np.linspace(
115 | self.to_mel(low_hz), self.to_mel(high_hz), self.out_channels + 1
116 | )
117 | hz = self.to_hz(mel)
118 |
119 | # filter lower frequency (out_channels, 1)
120 | self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
121 |
122 | # filter frequency band (out_channels, 1)
123 | self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))
124 |
125 | # Hamming window
126 | # self.window_ = torch.hamming_window(self.kernel_size)
127 | n_lin = torch.linspace(
128 | 0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))
129 | ) # computing only half of the window
130 | self.window_ = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / self.kernel_size)
131 |
132 | # (1, kernel_size/2)
133 | n = (self.kernel_size - 1) / 2.0
134 | self.n_ = (
135 | 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate
136 | ) # Due to symmetry, I only need half of the time axes
137 |
138 | def forward(self, waveforms):
139 | """
140 | Parameters
141 | ----------
142 | waveforms : `torch.Tensor` (batch_size, 1, n_samples)
143 | Batch of waveforms.
144 | Returns
145 | -------
146 | features : `torch.Tensor` (batch_size, out_channels, n_samples_out)
147 | Batch of sinc filters activations.
148 | """
149 |
150 | self.n_ = self.n_.to(waveforms.device)
151 |
152 | self.window_ = self.window_.to(waveforms.device)
153 |
154 | low = self.min_low_hz + torch.abs(self.low_hz_)
155 |
156 | high = torch.clamp(
157 | low + self.min_band_hz + torch.abs(self.band_hz_),
158 | self.min_low_hz,
159 | self.sample_rate / 2,
160 | )
161 | band = (high - low)[:, 0]
162 |
163 | f_times_t_low = torch.matmul(low, self.n_)
164 | f_times_t_high = torch.matmul(high, self.n_)
165 |
166 | band_pass_left = (
167 | (torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) / (self.n_ / 2)
168 | ) * self.window_ # Equivalent of Eq.4 of the reference paper (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET). I just have expanded the sinc and simplified the terms. This way I avoid several useless computations.
169 | band_pass_center = 2 * band.view(-1, 1)
170 | band_pass_right = torch.flip(band_pass_left, dims=[1])
171 |
172 | band_pass = torch.cat(
173 | [band_pass_left, band_pass_center, band_pass_right], dim=1
174 | )
175 |
176 | band_pass = band_pass / (2 * band[:, None])
177 |
178 | self.filters = (band_pass).view(self.out_channels, 1, self.kernel_size)
179 |
180 | return F.conv1d(
181 | waveforms,
182 | self.filters,
183 | stride=self.stride,
184 | padding=self.padding,
185 | dilation=self.dilation,
186 | bias=None,
187 | groups=1,
188 | )
189 |
190 |
191 | class sinc_conv(nn.Module):
192 | def __init__(self, N_filt, Filt_dim, fs):
193 | super(sinc_conv, self).__init__()
194 |
195 | # Mel Initialization of the filterbanks
196 | low_freq_mel = 80
197 | high_freq_mel = 2595 * np.log10(1 + (fs / 2) / 700) # Convert Hz to Mel
198 | mel_points = np.linspace(
199 | low_freq_mel, high_freq_mel, N_filt
200 | ) # Equally spaced in Mel scale
201 | f_cos = 700 * (10 ** (mel_points / 2595) - 1) # Convert Mel to Hz
202 | b1 = np.roll(f_cos, 1)
203 | b2 = np.roll(f_cos, -1)
204 | b1[0] = 30
205 | b2[-1] = (fs / 2) - 100
206 |
207 | self.freq_scale = fs * 1.0
208 | self.filt_b1 = nn.Parameter(torch.from_numpy(b1 / self.freq_scale))
209 | self.filt_band = nn.Parameter(torch.from_numpy((b2 - b1) / self.freq_scale))
210 |
211 | self.N_filt = N_filt
212 | self.Filt_dim = Filt_dim
213 | self.fs = fs
214 |
215 | def forward(self, x):
216 |
217 | filters = Variable(torch.zeros((self.N_filt, self.Filt_dim))).cuda()
218 | N = self.Filt_dim
219 | t_right = Variable(
220 | torch.linspace(1, (N - 1) / 2, steps=int((N - 1) / 2)) / self.fs
221 | ).cuda()
222 |
223 | min_freq = 50.0
224 | min_band = 50.0
225 |
226 | filt_beg_freq = torch.abs(self.filt_b1) + min_freq / self.freq_scale
227 | filt_end_freq = filt_beg_freq + (
228 | torch.abs(self.filt_band) + min_band / self.freq_scale
229 | )
230 |
231 | n = torch.linspace(0, N, steps=N)
232 |
233 | # Filter window (hamming)
234 | window = 0.54 - 0.46 * torch.cos(2 * math.pi * n / N)
235 | window = Variable(window.float().cuda())
236 |
237 | for i in range(self.N_filt):
238 |
239 | low_pass1 = (
240 | 2
241 | * filt_beg_freq[i].float()
242 | * sinc(filt_beg_freq[i].float() * self.freq_scale, t_right)
243 | )
244 | low_pass2 = (
245 | 2
246 | * filt_end_freq[i].float()
247 | * sinc(filt_end_freq[i].float() * self.freq_scale, t_right)
248 | )
249 | band_pass = low_pass2 - low_pass1
250 |
251 | band_pass = band_pass / torch.max(band_pass)
252 |
253 | filters[i, :] = band_pass.cuda() * window
254 |
255 | out = F.conv1d(x, filters.view(self.N_filt, 1, self.Filt_dim))
256 |
257 | return out
258 |
259 |
260 | def act_fun(act_type):
261 |
262 | if act_type == "relu":
263 | return nn.ReLU()
264 |
265 | if act_type == "tanh":
266 | return nn.Tanh()
267 |
268 | if act_type == "sigmoid":
269 | return nn.Sigmoid()
270 |
271 | if act_type == "leaky_relu":
272 | return nn.LeakyReLU(0.2)
273 |
274 | if act_type == "elu":
275 | return nn.ELU()
276 |
277 | if act_type == "softmax":
278 | return nn.LogSoftmax(dim=1)
279 |
280 | if act_type == "linear":
281 | return nn.LeakyReLU(1) # initialized like this, but not used in forward!
282 |
283 |
284 | class LayerNorm(nn.Module):
285 | def __init__(self, features, eps=1e-6):
286 | super(LayerNorm, self).__init__()
287 | self.gamma = nn.Parameter(torch.ones(features))
288 | self.beta = nn.Parameter(torch.zeros(features))
289 | self.eps = eps
290 |
291 | def forward(self, x):
292 | mean = x.mean(-1, keepdim=True)
293 | std = x.std(-1, keepdim=True)
294 | return self.gamma * (x - mean) / (std + self.eps) + self.beta
295 |
296 |
297 | class MLP(nn.Module):
298 | def __init__(self, options):
299 | super(MLP, self).__init__()
300 |
301 | self.input_dim = int(options["input_dim"])
302 | self.fc_lay = options["fc_lay"]
303 | self.fc_drop = options["fc_drop"]
304 | self.fc_use_batchnorm = options["fc_use_batchnorm"]
305 | self.fc_use_laynorm = options["fc_use_laynorm"]
306 | self.fc_use_laynorm_inp = options["fc_use_laynorm_inp"]
307 | self.fc_use_batchnorm_inp = options["fc_use_batchnorm_inp"]
308 | self.fc_act = options["fc_act"]
309 |
310 | self.wx = nn.ModuleList([])
311 | self.bn = nn.ModuleList([])
312 | self.ln = nn.ModuleList([])
313 | self.act = nn.ModuleList([])
314 | self.drop = nn.ModuleList([])
315 |
316 | # input layer normalization
317 | if self.fc_use_laynorm_inp:
318 | self.ln0 = LayerNorm(self.input_dim)
319 |
320 | # input batch normalization
321 | if self.fc_use_batchnorm_inp:
322 | self.bn0 = nn.BatchNorm1d([self.input_dim], momentum=0.05)
323 |
324 | self.N_fc_lay = len(self.fc_lay)
325 |
326 | current_input = self.input_dim
327 |
328 | # Initialization of hidden layers
329 |
330 | for i in range(self.N_fc_lay):
331 |
332 | # dropout
333 | self.drop.append(nn.Dropout(p=self.fc_drop[i]))
334 |
335 | # activation
336 | self.act.append(act_fun(self.fc_act[i]))
337 |
338 | add_bias = True
339 |
340 | # layer norm initialization
341 | self.ln.append(LayerNorm(self.fc_lay[i]))
342 | self.bn.append(nn.BatchNorm1d(self.fc_lay[i], momentum=0.05))
343 |
344 | if self.fc_use_laynorm[i] or self.fc_use_batchnorm[i]:
345 | add_bias = False
346 |
347 | # Linear operations
348 | self.wx.append(nn.Linear(current_input, self.fc_lay[i], bias=add_bias))
349 |
350 | # weight initialization
351 | self.wx[i].weight = torch.nn.Parameter(
352 | torch.Tensor(self.fc_lay[i], current_input).uniform_(
353 | -np.sqrt(0.01 / (current_input + self.fc_lay[i])),
354 | np.sqrt(0.01 / (current_input + self.fc_lay[i])),
355 | )
356 | )
357 | self.wx[i].bias = torch.nn.Parameter(torch.zeros(self.fc_lay[i]))
358 |
359 | current_input = self.fc_lay[i]
360 |
361 | def forward(self, x):
362 |
363 | # Applying Layer/Batch Norm
364 | if bool(self.fc_use_laynorm_inp):
365 | x = self.ln0((x))
366 |
367 | if bool(self.fc_use_batchnorm_inp):
368 | x = self.bn0((x))
369 |
370 | for i in range(self.N_fc_lay):
371 |
372 | if self.fc_act[i] != "linear":
373 |
374 | if self.fc_use_laynorm[i]:
375 | x = self.drop[i](self.act[i](self.ln[i](self.wx[i](x))))
376 |
377 | if self.fc_use_batchnorm[i]:
378 | x = self.drop[i](self.act[i](self.bn[i](self.wx[i](x))))
379 |
380 | if (
381 | self.fc_use_batchnorm[i] == False
382 | and self.fc_use_laynorm[i] == False
383 | ):
384 | x = self.drop[i](self.act[i](self.wx[i](x)))
385 |
386 | else:
387 | if self.fc_use_laynorm[i]:
388 | x = self.drop[i](self.ln[i](self.wx[i](x)))
389 |
390 | if self.fc_use_batchnorm[i]:
391 | x = self.drop[i](self.bn[i](self.wx[i](x)))
392 |
393 | if (
394 | self.fc_use_batchnorm[i] == False
395 | and self.fc_use_laynorm[i] == False
396 | ):
397 | x = self.drop[i](self.wx[i](x))
398 |
399 | return x
400 |
401 |
402 | class SincNet(nn.Module):
403 | def __init__(
404 | self,
405 | cnn_N_filt,
406 | cnn_len_filt,
407 | cnn_max_pool_len,
408 | cnn_act,
409 | cnn_drop,
410 | cnn_use_laynorm,
411 | cnn_use_batchnorm,
412 | cnn_use_laynorm_inp,
413 | cnn_use_batchnorm_inp,
414 | input_dim,
415 | fs,
416 | ):
417 | super(SincNet, self).__init__()
418 |
419 | self.cnn_N_filt = cnn_N_filt
420 | self.cnn_len_filt = cnn_len_filt
421 | self.cnn_max_pool_len = cnn_max_pool_len
422 |
423 | self.cnn_act = cnn_act
424 | self.cnn_drop = cnn_drop
425 |
426 | self.cnn_use_laynorm = cnn_use_laynorm
427 | self.cnn_use_batchnorm = cnn_use_batchnorm
428 | self.cnn_use_laynorm_inp = cnn_use_laynorm_inp
429 | self.cnn_use_batchnorm_inp = cnn_use_batchnorm_inp
430 |
431 | self.input_dim = int(input_dim)
432 |
433 | self.fs = fs
434 |
435 | self.N_cnn_lay = len(self.cnn_N_filt)
436 | self.conv = nn.ModuleList([])
437 | self.bn = nn.ModuleList([])
438 | self.ln = nn.ModuleList([])
439 | self.act = nn.ModuleList([])
440 | self.drop = nn.ModuleList([])
441 |
442 | if self.cnn_use_laynorm_inp:
443 | self.ln0 = LayerNorm(self.input_dim)
444 |
445 | if self.cnn_use_batchnorm_inp:
446 | self.bn0 = nn.BatchNorm1d([self.input_dim], momentum=0.05)
447 |
448 | current_input = self.input_dim
449 |
450 | for i in range(self.N_cnn_lay):
451 |
452 | N_filt = int(self.cnn_N_filt[i])
453 | len_filt = int(self.cnn_len_filt[i])
454 |
455 | # dropout
456 | self.drop.append(nn.Dropout(p=self.cnn_drop[i]))
457 |
458 | # activation
459 | self.act.append(act_fun(self.cnn_act[i]))
460 |
461 | # layer norm initialization
462 | self.ln.append(
463 | LayerNorm(
464 | [
465 | N_filt,
466 | int(
467 | (current_input - self.cnn_len_filt[i] + 1)
468 | / self.cnn_max_pool_len[i]
469 | ),
470 | ]
471 | )
472 | )
473 |
474 | self.bn.append(
475 | nn.BatchNorm1d(
476 | N_filt,
477 | int(
478 | (current_input - self.cnn_len_filt[i] + 1)
479 | / self.cnn_max_pool_len[i]
480 | ),
481 | momentum=0.05,
482 | )
483 | )
484 |
485 | if i == 0:
486 | self.conv.append(
487 | SincConv_fast(self.cnn_N_filt[0], self.cnn_len_filt[0], self.fs)
488 | )
489 |
490 | else:
491 | self.conv.append(
492 | nn.Conv1d(
493 | self.cnn_N_filt[i - 1], self.cnn_N_filt[i], self.cnn_len_filt[i]
494 | )
495 | )
496 |
497 | current_input = int(
498 | (current_input - self.cnn_len_filt[i] + 1) / self.cnn_max_pool_len[i]
499 | )
500 |
501 | self.out_dim = current_input * N_filt
502 |
503 | def forward(self, x):
504 | batch = x.shape[0]
505 | seq_len = x.shape[1]
506 |
507 | if bool(self.cnn_use_laynorm_inp):
508 | x = self.ln0((x))
509 |
510 | if bool(self.cnn_use_batchnorm_inp):
511 | x = self.bn0((x))
512 |
513 | x = x.view(batch, 1, seq_len)
514 |
515 | for i in range(self.N_cnn_lay):
516 |
517 | if self.cnn_use_laynorm[i]:
518 | if i == 0:
519 | x = self.drop[i](
520 | self.act[i](
521 | self.ln[i](
522 | F.max_pool1d(
523 | torch.abs(self.conv[i](x)), self.cnn_max_pool_len[i]
524 | )
525 | )
526 | )
527 | )
528 | else:
529 | x = self.drop[i](
530 | self.act[i](
531 | self.ln[i](
532 | F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])
533 | )
534 | )
535 | )
536 |
537 | if self.cnn_use_batchnorm[i]:
538 | x = self.drop[i](
539 | self.act[i](
540 | self.bn[i](
541 | F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])
542 | )
543 | )
544 | )
545 |
546 | if self.cnn_use_batchnorm[i] == False and self.cnn_use_laynorm[i] == False:
547 | x = self.drop[i](
548 | self.act[i](F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i]))
549 | )
550 |
551 | x = x.view(batch, -1)
552 |
553 | return x
554 |
--------------------------------------------------------------------------------
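For reference, the filter assembled in `SincConv_fast.forward` above is Eq. 4 of the cited SincNet paper (Ravanelli & Bengio, 2018): each band-pass kernel is parametrised by two learned cut-off frequencies $f_1 < f_2$ as

\[
g[n, f_1, f_2] = 2 f_2 \operatorname{sinc}(2\pi f_2 n) - 2 f_1 \operatorname{sinc}(2\pi f_1 n),
\qquad \operatorname{sinc}(x) = \frac{\sin x}{x},
\]

multiplied element-wise by a Hamming window $w[n] = 0.54 - 0.46 \cos\!\left(\tfrac{2\pi n}{L}\right)$. The code expands the sinc terms so that only a difference of sines over a shared denominator has to be evaluated, computes the left half of the symmetric kernel, and obtains the right half with `torch.flip`.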
/clmr/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .callbacks import PlotSpectogramCallback
2 | from .contrastive_learning import ContrastiveLearning
3 | from .linear_evaluation import LinearEvaluation
4 | from .supervised_learning import SupervisedLearning
5 |
--------------------------------------------------------------------------------
/clmr/modules/callbacks.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 |
3 | matplotlib.use("Agg")  # select a non-interactive backend before importing pyplot
4 | import matplotlib.pyplot as plt
5 |
6 | from pytorch_lightning.callbacks import Callback
7 |
8 |
9 | class PlotSpectogramCallback(Callback):
10 | def on_train_start(self, trainer, pl_module):
11 |
12 | if not pl_module.hparams.time_domain:
13 | x, y = trainer.train_dataloader.dataset[0]
14 |
15 | fig = plt.figure()
16 | x_i = x[0, :]
17 | fig.add_subplot(1, 2, 1)
18 | plt.imshow(x_i)
19 | if x.shape[0] > 1:
20 | x_j = x[1, :]
21 | fig.add_subplot(1, 2, 2)
22 | plt.imshow(x_j)
23 |
24 | trainer.logger.experiment.add_figure(
25 | "Train/spectogram_sample", fig, global_step=0
26 | )
27 | plt.close()
28 |
--------------------------------------------------------------------------------
/clmr/modules/contrastive_learning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pytorch_lightning import LightningModule
4 | from torch import Tensor
5 |
6 | from simclr import SimCLR
7 | from simclr.modules import NT_Xent, LARS
8 |
9 |
10 | class ContrastiveLearning(LightningModule):
11 | def __init__(self, args, encoder: nn.Module):
12 | super().__init__()
13 | self.save_hyperparameters(args)
14 |
15 | self.encoder = encoder
16 | self.n_features = (
17 | self.encoder.fc.in_features
18 | ) # get dimensions of last fully-connected layer
19 | self.model = SimCLR(self.encoder, self.hparams.projection_dim, self.n_features)
20 | self.criterion = self.configure_criterion()
21 |
22 | def forward(self, x_i: Tensor, x_j: Tensor) -> Tensor:
23 | _, _, z_i, z_j = self.model(x_i, x_j)
24 | loss = self.criterion(z_i, z_j)
25 | return loss
26 |
27 | def training_step(self, batch, _) -> Tensor:
28 | x, _ = batch
29 | x_i = x[:, 0, :]
30 | x_j = x[:, 1, :]
31 | loss = self.forward(x_i, x_j)
32 | self.log("Train/loss", loss)
33 | return loss
34 |
35 | def configure_criterion(self) -> nn.Module:
36 | # PyTorch Lightning splits each batch across GPUs in DP mode, so NT-Xent needs the per-GPU batch size
37 | if self.hparams.accelerator == "dp" and self.hparams.gpus:
38 | batch_size = int(self.hparams.batch_size / self.hparams.gpus)
39 | else:
40 | batch_size = self.hparams.batch_size
41 |
42 | criterion = NT_Xent(batch_size, self.hparams.temperature, world_size=1)
43 | return criterion
44 |
45 | def configure_optimizers(self) -> dict:
46 | scheduler = None
47 | if self.hparams.optimizer == "Adam":
48 | optimizer = torch.optim.Adam(self.model.parameters(), lr=3e-4)
49 | elif self.hparams.optimizer == "LARS":
50 | # optimized using LARS with linear learning rate scaling
51 | # (i.e. LearningRate = 0.3 × BatchSize/256) and weight decay of 10^-6.
52 | learning_rate = 0.3 * self.hparams.batch_size / 256
53 | optimizer = LARS(
54 | self.model.parameters(),
55 | lr=learning_rate,
56 | weight_decay=self.hparams.weight_decay,
57 | exclude_from_weight_decay=["batch_normalization", "bias"],
58 | )
59 |
60 | # "decay the learning rate with the cosine decay schedule without restarts"
61 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
62 | optimizer, self.hparams.max_epochs, eta_min=0, last_epoch=-1
63 | )
64 | else:
65 | raise NotImplementedError
66 |
67 | if scheduler:
68 | return {"optimizer": optimizer, "lr_scheduler": scheduler}
69 | else:
70 | return {"optimizer": optimizer}
71 |
--------------------------------------------------------------------------------
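The comments in `configure_criterion` and `configure_optimizers` encode two SimCLR conventions: under DP the NT-Xent loss must be constructed with the per-GPU batch size, and the LARS learning rate follows the linear scaling rule LearningRate = 0.3 × BatchSize/256. A minimal sketch of both calculations, with illustrative values that are not taken from the repository:

```python
# Sketch only: the two batch-size-dependent quantities derived in
# ContrastiveLearning above, written out as plain functions.

def lars_learning_rate(batch_size: int, base_lr: float = 0.3) -> float:
    # SimCLR linear scaling rule: LearningRate = 0.3 * BatchSize / 256
    return base_lr * batch_size / 256


def nt_xent_batch_size(batch_size: int, accelerator: str, gpus: int) -> int:
    # In DP mode PyTorch Lightning splits each batch across GPUs, so the
    # NT-Xent criterion sees only batch_size / gpus pairs per device.
    if accelerator == "dp" and gpus:
        return batch_size // gpus
    return batch_size


print(lars_learning_rate(4096))        # 4.8
print(nt_xent_batch_size(48, "dp", 2)) # 24
```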
/clmr/modules/linear_evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchmetrics
4 | from copy import deepcopy
5 | from pytorch_lightning import LightningModule
6 | from torch import Tensor
7 | from torch.utils.data import DataLoader, Dataset, TensorDataset
8 | from typing import Tuple
9 | from tqdm import tqdm
10 |
11 |
12 | class LinearEvaluation(LightningModule):
13 | def __init__(self, args, encoder: nn.Module, hidden_dim: int, output_dim: int):
14 | super().__init__()
15 | self.save_hyperparameters(args)
16 |
17 | self.encoder = encoder
18 | self.hidden_dim = hidden_dim
19 | self.output_dim = output_dim
20 |
21 | if self.hparams.finetuner_mlp:
22 | self.model = nn.Sequential(
23 | nn.Linear(self.hidden_dim, self.hidden_dim),
24 | nn.ReLU(),
25 | nn.Linear(self.hidden_dim, self.output_dim),
26 | )
27 | else:
28 | self.model = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim))
29 | self.criterion = self.configure_criterion()
30 |
31 | self.accuracy = torchmetrics.Accuracy()
32 | self.average_precision = torchmetrics.AveragePrecision(pos_label=1)
33 |
34 | def forward(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
35 | preds = self._forward_representations(x, y)
36 | loss = self.criterion(preds, y)
37 | return loss, preds
38 |
39 | def _forward_representations(self, x: Tensor, y: Tensor) -> Tensor:
40 | """
41 | Perform a forward pass using either pre-computed representations, or the raw input data
42 | from which we still need to extract the representations with our encoder.
43 | """
44 | if x.shape[-1] == self.hidden_dim:
45 | h0 = x
46 | else:
47 | with torch.no_grad():
48 | h0 = self.encoder(x)
49 | return self.model(h0)
50 |
51 | def training_step(self, batch, _) -> Tensor:
52 | x, y = batch
53 | loss, preds = self.forward(x, y)
54 |
55 | self.log("Train/accuracy", self.accuracy(preds, y))
56 | # self.log("Train/pr_auc", self.average_precision(preds, y))
57 | self.log("Train/loss", loss)
58 | return loss
59 |
60 | def validation_step(self, batch, _) -> Tensor:
61 | x, y = batch
62 | loss, preds = self.forward(x, y)
63 |
64 | self.log("Valid/accuracy", self.accuracy(preds, y))
65 | # self.log("Valid/pr_auc", self.average_precision(preds, y))
66 | self.log("Valid/loss", loss)
67 | return loss
68 |
69 | def configure_criterion(self) -> nn.Module:
70 | if self.hparams.dataset in ["magnatagatune", "msd"]:
71 | criterion = nn.BCEWithLogitsLoss()
72 | else:
73 | criterion = nn.CrossEntropyLoss()
74 | return criterion
75 |
76 | def configure_optimizers(self) -> dict:
77 | optimizer = torch.optim.Adam(
78 | self.model.parameters(),
79 | lr=self.hparams.finetuner_learning_rate,
80 | weight_decay=self.hparams.weight_decay,
81 | )
82 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
83 | optimizer,
84 | mode="min",
85 | factor=0.1,
86 | patience=5,
87 | threshold=0.0001,
88 | threshold_mode="rel",
89 | cooldown=0,
90 | min_lr=0,
91 | eps=1e-08,
92 | verbose=False,
93 | )
94 | if scheduler:
95 | return {
96 | "optimizer": optimizer,
97 | "lr_scheduler": scheduler,
98 | "monitor": "Valid/loss",
99 | }
100 | else:
101 | return {"optimizer": optimizer}
102 |
103 | def extract_representations(self, dataloader: DataLoader) -> Dataset:
104 |
105 | representations = []
106 | ys = []
107 | for x, y in tqdm(dataloader):
108 | with torch.no_grad():
109 | h0 = self.encoder(x)
110 | representations.append(h0)
111 | ys.append(y)
112 |
113 | if len(representations) > 1:
114 | representations = torch.cat(representations, dim=0)
115 | ys = torch.cat(ys, dim=0)
116 | else:
117 | representations = representations[0]
118 | ys = ys[0]
119 |
120 | tensor_dataset = TensorDataset(representations, ys)
121 | return tensor_dataset
122 |
--------------------------------------------------------------------------------
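`extract_representations` above front-loads the expensive encoder pass: it runs the frozen encoder once over the whole dataloader and caches the outputs in a `TensorDataset`, so the linear head never touches raw audio during training. A minimal sketch of that pattern with a stand-in encoder (the toy `nn.Flatten`/`nn.Linear` encoder and all sizes below are assumptions for illustration):

```python
# Sketch only: the precompute-then-train pattern behind
# LinearEvaluation.extract_representations, using a toy stand-in encoder.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

hidden_dim, n_classes = 512, 50                       # assumed sizes
encoder = nn.Sequential(nn.Flatten(), nn.Linear(59049, hidden_dim)).eval()

raw_audio = torch.randn(8, 1, 59049)                  # fake dataset of 8 clips
labels = torch.randint(0, 2, (8, n_classes)).float()
loader = DataLoader(TensorDataset(raw_audio, labels), batch_size=4)

feats, ys = [], []
with torch.no_grad():                                 # frozen encoder: no gradients
    for x, y in loader:
        feats.append(encoder(x))
        ys.append(y)

rep_dataset = TensorDataset(torch.cat(feats), torch.cat(ys))
linear_head = nn.Linear(hidden_dim, n_classes)        # trained on rep_dataset afterwards
print(rep_dataset[0][0].shape)                        # torch.Size([512])
```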
/clmr/modules/supervised_learning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchmetrics
3 | import torch.nn as nn
4 | from pytorch_lightning import LightningModule
5 |
6 |
7 | class SupervisedLearning(LightningModule):
8 | def __init__(self, args, encoder: nn.Module, output_dim: int):
9 | super().__init__()
10 | self.save_hyperparameters(args)
11 | self.encoder = encoder
12 |
13 | self.encoder.fc.out_features = output_dim
14 | self.output_dim = output_dim
15 | self.model = self.encoder
16 | self.criterion = self.configure_criterion()
17 |
18 | self.average_precision = torchmetrics.AveragePrecision(pos_label=1)
19 |
20 | def forward(self, x, y):
21 | x = x[:, 0, :] # we only have 1 sample, no augmentations
22 | preds = self.model(x)
23 | loss = self.criterion(preds, y)
24 | return loss, preds
25 |
26 | def training_step(self, batch, batch_idx):
27 | x, y = batch
28 | loss, preds = self.forward(x, y)
29 | self.log("Train/pr_auc", self.average_precision(preds, y))
30 | self.log("Train/loss", loss)
31 | return loss
32 |
33 | def validation_step(self, batch, batch_idx):
34 | x, y = batch
35 | loss, preds = self.forward(x, y)
36 | self.log("Valid/pr_auc", self.average_precision(preds, y))
37 | self.log("Valid/loss", loss)
38 | return loss
39 |
40 | def configure_criterion(self):
41 | if self.hparams.dataset in ["magnatagatune"]:
42 | criterion = nn.BCEWithLogitsLoss()
43 | else:
44 | criterion = nn.CrossEntropyLoss()
45 | return criterion
46 |
47 | def configure_optimizers(self):
48 | optimizer = torch.optim.SGD(
49 | self.model.parameters(),
50 | lr=self.hparams.learning_rate,
51 | momentum=0.9,
52 | weight_decay=1e-6,
53 | nesterov=True,
54 | )
55 |
56 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
57 | optimizer, mode="min", factor=0.2, patience=5, verbose=True
58 | )
59 |
60 | if scheduler:
61 | return {
62 | "optimizer": optimizer,
63 | "lr_scheduler": scheduler,
64 | "monitor": "Valid/loss",
65 | }
66 | else:
67 | return {"optimizer": optimizer}
68 |
--------------------------------------------------------------------------------
/clmr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .checkpoint import load_encoder_checkpoint, load_finetuner_checkpoint
2 | from .yaml_config_hook import yaml_config_hook
3 |
--------------------------------------------------------------------------------
/clmr/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import OrderedDict
3 |
4 |
5 | def load_encoder_checkpoint(checkpoint_path: str, output_dim: int) -> OrderedDict:
6 | state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
7 | if "pytorch-lightning_version" in state_dict.keys():
8 | new_state_dict = OrderedDict(
9 | {
10 | k.replace("model.encoder.", ""): v
11 | for k, v in state_dict["state_dict"].items()
12 | if "model.encoder." in k
13 | }
14 | )
15 | else:
16 | new_state_dict = OrderedDict()
17 | for k, v in state_dict.items():
18 | if "encoder." in k:
19 | new_state_dict[k.replace("encoder.", "")] = v
20 |
21 | new_state_dict["fc.weight"] = torch.zeros(output_dim, 512)
22 | new_state_dict["fc.bias"] = torch.zeros(output_dim)
23 | return new_state_dict
24 |
25 |
26 | def load_finetuner_checkpoint(checkpoint_path: str) -> OrderedDict:
27 | state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
28 | if "pytorch-lightning_version" in state_dict.keys():
29 | state_dict = OrderedDict(
30 | {
31 | k.replace("model.", ""): v
32 | for k, v in state_dict["state_dict"].items()
33 | if "model." in k
34 | }
35 | )
36 | return state_dict
37 |
--------------------------------------------------------------------------------
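`load_encoder_checkpoint` recognises a PyTorch Lightning checkpoint by its `pytorch-lightning_version` key and strips the `model.encoder.` prefix so the weights fit a bare encoder. A small illustration of that key renaming on a fabricated state dict (the layer names below are made up):

```python
# Sketch only: the key renaming performed by load_encoder_checkpoint,
# shown on a fabricated Lightning-style checkpoint dictionary.
import torch
from collections import OrderedDict

fake_checkpoint = {
    "pytorch-lightning_version": "1.x",
    "state_dict": {
        "model.encoder.conv1.weight": torch.zeros(1),
        "model.encoder.conv1.bias": torch.zeros(1),
        "model.projector.0.weight": torch.zeros(1),  # dropped: not part of the encoder
    },
}

encoder_weights = OrderedDict(
    {
        k.replace("model.encoder.", ""): v
        for k, v in fake_checkpoint["state_dict"].items()
        if "model.encoder." in k
    }
)
print(list(encoder_weights))  # ['conv1.weight', 'conv1.bias']
```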
/clmr/utils/yaml_config_hook.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 |
4 |
5 | def yaml_config_hook(config_file):
6 | """
7 | Custom YAML config loader, which can include other YAML files (I like using config files
8 | instead of relying only on argparse)
9 | """
10 |
11 | # load yaml files in the nested 'defaults' section, which include defaults for experiments
12 | with open(config_file) as f:
13 | cfg = yaml.safe_load(f)
14 | for d in cfg.get("defaults", []):
15 | config_dir, cf = d.popitem()
16 | cf = os.path.join(os.path.dirname(config_file), config_dir, cf + ".yaml")
17 | with open(cf) as f:
18 | l = yaml.safe_load(f)
19 | cfg.update(l)
20 |
21 | if "defaults" in cfg.keys():
22 | del cfg["defaults"]
23 |
24 | return cfg
25 |
--------------------------------------------------------------------------------
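`yaml_config_hook` first loads the top-level file, then merges every file referenced in its nested `defaults:` list into the same dictionary. A self-contained sketch of that behaviour (the temporary files, the `experiment` sub-directory and the keys are invented for this demo):

```python
# Sketch only: how yaml_config_hook resolves a nested `defaults` section.
import os
import tempfile
import yaml

from clmr.utils import yaml_config_hook

tmp = tempfile.mkdtemp()
os.makedirs(os.path.join(tmp, "experiment"))

with open(os.path.join(tmp, "config.yaml"), "w") as f:
    yaml.safe_dump({"defaults": [{"experiment": "baseline"}], "batch_size": 48}, f)
with open(os.path.join(tmp, "experiment", "baseline.yaml"), "w") as f:
    yaml.safe_dump({"learning_rate": 0.0003}, f)

cfg = yaml_config_hook(os.path.join(tmp, "config.yaml"))
print(cfg)  # {'batch_size': 48, 'learning_rate': 0.0003}
```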
/config/config.yaml:
--------------------------------------------------------------------------------
1 | # infra options
2 | # gpus: 0
3 | # accelerator: "dp" # use ddp for gpus > 1. Also see PyTorch Lightning documentation on distributed training.
4 | workers: 0 # I recommend tuning this parameter for faster data augmentation processing
5 | dataset_dir: "./data"
6 |
7 | # train options
8 | seed: 42
9 | batch_size: 48
10 | # max_epochs: 200
11 | dataset: "magnatagatune" # ["magnatagatune", "msd", "gtzan", "audio"]
12 | supervised: 0 # train with supervised baseline
13 |
14 | # SimCLR model options
15 | projection_dim: 64 # Projection dim. of SimCLR projector
16 |
17 | # loss options
18 | optimizer: "Adam" # or LARS (experimental)
19 | learning_rate: 0.0003
20 | weight_decay: 1.0e-6 # "optimized using LARS [...] and weight decay of 10^-6"
21 | temperature: 0.5 # see appendix B.7.: Optimal temperature under different batch sizes
22 |
23 | # reload options
24 | checkpoint_path: "" # path to a pre-trained (encoder) checkpoint file
25 |
26 | # logistic regression options
27 | finetuner_mlp: 0
28 | finetuner_checkpoint_path: ""
29 | finetuner_max_epochs: 200
30 | finetuner_batch_size: 256
31 | finetuner_learning_rate: 0.001
32 |
33 | # audio data augmentation options
34 | audio_length: 59049
35 | sample_rate: 22050
36 | transforms_polarity: 0.8
37 | transforms_noise: 0.01
38 | transforms_gain: 0.3
39 | transforms_filters: 0.8
40 | transforms_delay: 0.3
41 | transforms_pitch: 0.6
42 | transforms_reverb: 0.6
43 |
--------------------------------------------------------------------------------
/examples/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/.DS_Store
--------------------------------------------------------------------------------
/examples/clmr-onnxruntime-web/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/.DS_Store
--------------------------------------------------------------------------------
/examples/clmr-onnxruntime-web/clmr_sample-cnn.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/clmr_sample-cnn.onnx
--------------------------------------------------------------------------------
/examples/clmr-onnxruntime-web/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/.DS_Store
--------------------------------------------------------------------------------
/examples/clmr-onnxruntime-web/data/fisher_losing_it.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/fisher_losing_it.mp3
--------------------------------------------------------------------------------
/examples/clmr-onnxruntime-web/data/fisher_losing_it_22050.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/fisher_losing_it_22050.mp3
--------------------------------------------------------------------------------
/examples/clmr-onnxruntime-web/data/john_lennon_imagine.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/john_lennon_imagine.mp3
--------------------------------------------------------------------------------
/examples/clmr-onnxruntime-web/data/john_lennon_imagine_22050.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/john_lennon_imagine_22050.mp3
--------------------------------------------------------------------------------
/examples/clmr-onnxruntime-web/data/nirvana_smells_like_teen_spirit.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/nirvana_smells_like_teen_spirit.mp3
--------------------------------------------------------------------------------
/examples/clmr-onnxruntime-web/data/nirvana_smells_like_teen_spirit_22050.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/nirvana_smells_like_teen_spirit_22050.mp3
--------------------------------------------------------------------------------
/examples/clmr-onnxruntime-web/data/queen_love_of_my_life.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/queen_love_of_my_life.mp3
--------------------------------------------------------------------------------
/examples/clmr-onnxruntime-web/data/queen_love_of_my_life_22050.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/queen_love_of_my_life_22050.mp3
--------------------------------------------------------------------------------
/examples/clmr-onnxruntime-web/index.html:
--------------------------------------------------------------------------------
4 | CLMR on ONNX Runtime
92 | On this page, you are running a CLMR model on the ONNX Runtime in your browser.
93 | No data is sent to a server; all predictions are done on your device!
95 | In short, the following happens under the hood:
97 | It first loads a pre-trained model in the background.
98 | The model extracts audio representations from raw samples in the audio buffer.
99 | The audio representations are given to a linear classifier, which predicts the corresponding audio tags.
100 | The multi-label predictions are shown in the bar plot.
110 | Select a piece of music:
112 | ./data/queen_love_of_my_life_22050.mp3
113 | ./data/nirvana_smells_like_teen_spirit_22050.mp3
114 | ./data/john_lennon_imagine_22050.mp3
115 | ./data/fisher_losing_it_22050.mp3
134 | Loading ONNX model...
--------------------------------------------------------------------------------
/export.py:
--------------------------------------------------------------------------------
1 | """
2 | This script exports a pre-trained CLMR PyTorch model to an ONNX model.
3 | """
4 |
5 | import argparse
6 | import os
7 | import torch
8 | from collections import OrderedDict
9 | from copy import deepcopy
10 | from clmr.models import SampleCNN, Identity
11 | from clmr.utils import load_encoder_checkpoint, load_finetuner_checkpoint
12 |
13 |
14 | def convert_encoder_to_onnx(
15 | encoder: torch.nn.Module, test_input: torch.Tensor, fp: str
16 | ) -> None:
17 | input_names = ["audio"]
18 | output_names = ["representation"]
19 |
20 | torch.onnx.export(
21 | encoder,
22 | test_input,
23 | fp,
24 | verbose=False,
25 | input_names=input_names,
26 | output_names=output_names,
27 | )
28 |
29 |
30 | if __name__ == "__main__":
31 | parser = argparse.ArgumentParser()
32 |
33 | parser.add_argument("--checkpoint_path", type=str, required=True)
34 | parser.add_argument("--finetuner_checkpoint_path", type=str, required=True)
35 | parser.add_argument("--n_classes", type=int, default=50)
36 | args = parser.parse_args()
37 |
38 | if not os.path.exists(args.checkpoint_path):
39 | raise FileNotFoundError("That encoder checkpoint does not exist")
40 |
41 | if not os.path.exists(args.finetuner_checkpoint_path):
42 | raise FileNotFoundError("That linear model checkpoint does not exist")
43 |
44 | # ------------
45 | # encoder
46 | # ------------
47 | encoder = SampleCNN(
48 | strides=[3, 3, 3, 3, 3, 3, 3, 3, 3],
49 | supervised=False,
50 | out_dim=args.n_classes,
51 | )
52 |
53 | n_features = encoder.fc.in_features # get dimensions of last fully-connected layer
54 |
55 | state_dict = load_encoder_checkpoint(args.checkpoint_path, args.n_classes)
56 | encoder.load_state_dict(state_dict)
57 | encoder.eval()
58 |
59 | # ------------
60 | # linear model
61 | # ------------
62 | state_dict = load_finetuner_checkpoint(args.finetuner_checkpoint_path)
63 | encoder.fc.load_state_dict(
64 | OrderedDict({k.replace("0.", ""): v for k, v in state_dict.items()})
65 | )
66 |
67 | encoder_export = deepcopy(encoder)
68 | # set last fully connected layer to an identity function:
69 | encoder_export.fc = Identity()
70 |
71 | batch_size = 1
72 | channels = 1
73 | audio_length = 59049
74 | test_input = torch.randn(batch_size, 1, audio_length)
75 |
76 | convert_encoder_to_onnx(encoder, test_input, "clmr_sample-cnn.onnx")
77 | convert_encoder_to_onnx(
78 | encoder_export, test_input, "clmr_encoder_only_sample-cnn.onnx"
79 | )
80 |
--------------------------------------------------------------------------------
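Once exported, `clmr_sample-cnn.onnx` can be run without PyTorch. A minimal inference sketch with onnxruntime (onnxruntime is not listed in requirements.txt, so treating it as installed is an assumption, and the printed output shape is only indicative):

```python
# Sketch only: running the exported clmr_sample-cnn.onnx outside PyTorch.
# onnxruntime is assumed to be installed separately; it is not in requirements.txt.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("clmr_sample-cnn.onnx")

audio_length = 59049  # same dummy input length used in export.py
audio = np.random.randn(1, 1, audio_length).astype(np.float32)

# The input/output names match those passed to torch.onnx.export above.
(output,) = session.run(["representation"], {"audio": audio})
print(output.shape)  # e.g. (1, 50) when the fine-tuned classification head is attached
```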
/linear_evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pytorch_lightning as pl
4 | from torch.utils.data import DataLoader
5 | from torchaudio_augmentations import Compose, RandomResizedCrop
6 | from pytorch_lightning import Trainer
7 | from pytorch_lightning.callbacks import EarlyStopping
8 | from pytorch_lightning.loggers import TensorBoardLogger
9 |
10 | from clmr.datasets import get_dataset
11 | from clmr.data import ContrastiveDataset
12 | from clmr.evaluation import evaluate
13 | from clmr.models import SampleCNN
14 | from clmr.modules import ContrastiveLearning, LinearEvaluation
15 | from clmr.utils import (
16 | yaml_config_hook,
17 | load_encoder_checkpoint,
18 | load_finetuner_checkpoint,
19 | )
20 |
21 |
22 | if __name__ == "__main__":
23 |
24 | parser = argparse.ArgumentParser(description="SimCLR")
25 | parser = Trainer.add_argparse_args(parser)
26 |
27 | config = yaml_config_hook("./config/config.yaml")
28 | for k, v in config.items():
29 | parser.add_argument(f"--{k}", default=v, type=type(v))
30 |
31 | args = parser.parse_args()
32 | pl.seed_everything(args.seed)
33 | args.accelerator = None
34 |
35 | if not os.path.exists(args.checkpoint_path):
36 | raise FileNotFoundError("That checkpoint does not exist")
37 |
38 | train_transform = [RandomResizedCrop(n_samples=args.audio_length)]
39 |
40 | # ------------
41 | # dataloaders
42 | # ------------
43 | train_dataset = get_dataset(args.dataset, args.dataset_dir, subset="train")
44 | valid_dataset = get_dataset(args.dataset, args.dataset_dir, subset="valid")
45 | test_dataset = get_dataset(args.dataset, args.dataset_dir, subset="test")
46 |
47 | contrastive_train_dataset = ContrastiveDataset(
48 | train_dataset,
49 | input_shape=(1, args.audio_length),
50 | transform=Compose(train_transform),
51 | )
52 |
53 | contrastive_valid_dataset = ContrastiveDataset(
54 | valid_dataset,
55 | input_shape=(1, args.audio_length),
56 | transform=Compose(train_transform),
57 | )
58 |
59 | contrastive_test_dataset = ContrastiveDataset(
60 | test_dataset,
61 | input_shape=(1, args.audio_length),
62 | transform=None,
63 | )
64 |
65 | train_loader = DataLoader(
66 | contrastive_train_dataset,
67 | batch_size=args.finetuner_batch_size,
68 | num_workers=args.workers,
69 | shuffle=True,
70 | )
71 |
72 | valid_loader = DataLoader(
73 | contrastive_valid_dataset,
74 | batch_size=args.finetuner_batch_size,
75 | num_workers=args.workers,
76 | shuffle=False,
77 | )
78 |
79 | test_loader = DataLoader(
80 | contrastive_test_dataset,
81 | batch_size=args.finetuner_batch_size,
82 | num_workers=args.workers,
83 | shuffle=False,
84 | )
85 |
86 | # ------------
87 | # encoder
88 | # ------------
89 | encoder = SampleCNN(
90 | strides=[3, 3, 3, 3, 3, 3, 3, 3, 3],
91 | supervised=args.supervised,
92 | out_dim=train_dataset.n_classes,
93 | )
94 |
95 | n_features = encoder.fc.in_features # get dimensions of last fully-connected layer
96 |
97 | state_dict = load_encoder_checkpoint(args.checkpoint_path, train_dataset.n_classes)
98 | encoder.load_state_dict(state_dict)
99 |
100 | cl = ContrastiveLearning(args, encoder)
101 | cl.eval()
102 | cl.freeze()
103 |
104 | module = LinearEvaluation(
105 | args,
106 | cl.encoder,
107 | hidden_dim=n_features,
108 | output_dim=train_dataset.n_classes,
109 | )
110 |
111 | train_representations_dataset = module.extract_representations(train_loader)
112 | train_loader = DataLoader(
113 | train_representations_dataset,
114 | batch_size=args.batch_size,
115 | num_workers=args.workers,
116 | shuffle=True,
117 | )
118 |
119 | valid_representations_dataset = module.extract_representations(valid_loader)
120 | valid_loader = DataLoader(
121 | valid_representations_dataset,
122 | batch_size=args.batch_size,
123 | num_workers=args.workers,
124 | shuffle=False,
125 | )
126 |
127 | if args.finetuner_checkpoint_path:
128 | state_dict = load_finetuner_checkpoint(args.finetuner_checkpoint_path)
129 | module.model.load_state_dict(state_dict)
130 | else:
131 | early_stop_callback = EarlyStopping(
132 | monitor="Valid/loss", patience=10, verbose=False, mode="min"
133 | )
134 |
135 | trainer = Trainer.from_argparse_args(
136 | args,
137 | logger=TensorBoardLogger(
138 | "runs", name="CLMRv2-eval-{}".format(args.dataset)
139 | ),
140 | max_epochs=args.finetuner_max_epochs,
141 | callbacks=[early_stop_callback],
142 | )
143 | trainer.fit(module, train_loader, valid_loader)
144 |
145 | device = "cuda:0" if args.gpus else "cpu"
146 | results = evaluate(
147 | module.encoder,
148 | module.model,
149 | contrastive_test_dataset,
150 | args.dataset,
151 | args.audio_length,
152 | device=device,
153 | )
154 | print(results)
155 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pytorch_lightning as pl
3 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
4 | from pytorch_lightning import Trainer
5 | from pytorch_lightning.loggers import TensorBoardLogger
6 | from torch.utils.data import DataLoader
7 |
8 | # Audio Augmentations
9 | from torchaudio_augmentations import (
10 | RandomApply,
11 | ComposeMany,
12 | RandomResizedCrop,
13 | PolarityInversion,
14 | Noise,
15 | Gain,
16 | HighLowPass,
17 | Delay,
18 | PitchShift,
19 | Reverb,
20 | )
21 |
22 | from clmr.data import ContrastiveDataset
23 | from clmr.datasets import get_dataset
24 | from clmr.evaluation import evaluate
25 | from clmr.models import SampleCNN
26 | from clmr.modules import ContrastiveLearning, SupervisedLearning
27 | from clmr.utils import yaml_config_hook
28 |
29 |
30 | if __name__ == "__main__":
31 |
32 | parser = argparse.ArgumentParser(description="CLMR")
33 | parser = Trainer.add_argparse_args(parser)
34 |
35 | config = yaml_config_hook("./config/config.yaml")
36 | for k, v in config.items():
37 | parser.add_argument(f"--{k}", default=v, type=type(v))
38 |
39 | args = parser.parse_args()
40 | pl.seed_everything(args.seed)
41 |
42 | # ------------
43 | # data augmentations
44 | # ------------
45 | if args.supervised:
46 | train_transform = [RandomResizedCrop(n_samples=args.audio_length)]
47 | num_augmented_samples = 1
48 | else:
49 | train_transform = [
50 | RandomResizedCrop(n_samples=args.audio_length),
51 | RandomApply([PolarityInversion()], p=args.transforms_polarity),
52 | RandomApply([Noise()], p=args.transforms_noise),
53 | RandomApply([Gain()], p=args.transforms_gain),
54 | RandomApply(
55 | [HighLowPass(sample_rate=args.sample_rate)], p=args.transforms_filters
56 | ),
57 | RandomApply([Delay(sample_rate=args.sample_rate)], p=args.transforms_delay),
58 | RandomApply(
59 | [
60 | PitchShift(
61 | n_samples=args.audio_length,
62 | sample_rate=args.sample_rate,
63 | )
64 | ],
65 | p=args.transforms_pitch,
66 | ),
67 | RandomApply(
68 | [Reverb(sample_rate=args.sample_rate)], p=args.transforms_reverb
69 | ),
70 | ]
71 | num_augmented_samples = 2
72 |
73 | # ------------
74 | # dataloaders
75 | # ------------
76 | train_dataset = get_dataset(args.dataset, args.dataset_dir, subset="train")
77 | valid_dataset = get_dataset(args.dataset, args.dataset_dir, subset="valid")
78 | contrastive_train_dataset = ContrastiveDataset(
79 | train_dataset,
80 | input_shape=(1, args.audio_length),
81 | transform=ComposeMany(
82 | train_transform, num_augmented_samples=num_augmented_samples
83 | ),
84 | )
85 |
86 | contrastive_valid_dataset = ContrastiveDataset(
87 | valid_dataset,
88 | input_shape=(1, args.audio_length),
89 | transform=ComposeMany(
90 | train_transform, num_augmented_samples=num_augmented_samples
91 | ),
92 | )
93 |
94 | train_loader = DataLoader(
95 | contrastive_train_dataset,
96 | batch_size=args.batch_size,
97 | num_workers=args.workers,
98 | drop_last=True,
99 | shuffle=True,
100 | )
101 |
102 | valid_loader = DataLoader(
103 | contrastive_valid_dataset,
104 | batch_size=args.batch_size,
105 | num_workers=args.workers,
106 | drop_last=True,
107 | shuffle=False,
108 | )
109 |
110 | # ------------
111 | # encoder
112 | # ------------
113 | encoder = SampleCNN(
114 | strides=[3, 3, 3, 3, 3, 3, 3, 3, 3],
115 | supervised=args.supervised,
116 | out_dim=train_dataset.n_classes,
117 | )
118 |
119 | # ------------
120 | # model
121 | # ------------
122 | if args.supervised:
123 | module = SupervisedLearning(args, encoder, output_dim=train_dataset.n_classes)
124 | else:
125 | module = ContrastiveLearning(args, encoder)
126 |
127 | logger = TensorBoardLogger("runs", name="CLMRv2-{}".format(args.dataset))
128 | if args.checkpoint_path:
129 | module = module.load_from_checkpoint(
130 | args.checkpoint_path, encoder=encoder, output_dim=train_dataset.n_classes
131 | )
132 |
133 | else:
134 | # ------------
135 | # training
136 | # ------------
137 |
138 | if args.supervised:
139 | early_stopping = EarlyStopping(monitor="Valid/loss", patience=20)
140 | else:
141 | early_stopping = None
142 |
143 | trainer = Trainer.from_argparse_args(
144 | args,
145 | logger=logger,
146 | sync_batchnorm=True,
147 | max_epochs=args.max_epochs,
148 | log_every_n_steps=10,
149 | check_val_every_n_epoch=1,
150 | accelerator=args.accelerator,
151 | )
152 | trainer.fit(module, train_loader, valid_loader)
153 |
154 | if args.supervised:
155 | test_dataset = get_dataset(args.dataset, args.dataset_dir, subset="test")
156 |
157 | contrastive_test_dataset = ContrastiveDataset(
158 | test_dataset,
159 | input_shape=(1, args.audio_length),
160 | transform=None,
161 | )
162 |
163 | device = "cuda:0" if args.gpus else "cpu"
164 | results = evaluate(
165 | module.encoder,
166 | None,
167 | contrastive_test_dataset,
168 | args.dataset,
169 | args.audio_length,
170 | device=device,
171 | )
172 | print(results)
173 |
--------------------------------------------------------------------------------
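In the contrastive branch above, `ComposeMany` produces two augmented views of every clip, which `ContrastiveLearning.training_step` later splits into `x_i` and `x_j`. A toy check of the tensor layout this implies (the `[batch, n_views, channels, samples]` layout is an assumption inferred from that indexing, not taken from the library's documentation):

```python
# Sketch only: the tensor layout implied by `x_i = x[:, 0, :]` / `x_j = x[:, 1, :]`
# in ContrastiveLearning.training_step.
import torch

batch_size, audio_length = 48, 59049
x = torch.randn(batch_size, 2, 1, audio_length)  # two augmented views per clip

x_i = x[:, 0, :]  # first view  -> [batch, 1, samples]
x_j = x[:, 1, :]  # second view -> [batch, 1, samples]
assert x_i.shape == x_j.shape == (batch_size, 1, audio_length)
```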
/media/clmr_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/media/clmr_model.png
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from tqdm import tqdm
3 | from clmr.datasets import get_dataset
4 |
5 | if __name__ == "__main__":
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument("--dataset", type=str, default="magnatagatune")
8 | parser.add_argument("--dataset_dir", type=str, default="./data")
9 | parser.add_argument("--sample_rate", type=int, default=22050)
10 | args = parser.parse_args()
11 |
12 | train_dataset = get_dataset(args.dataset, args.dataset_dir, subset="train")
13 | valid_dataset = get_dataset(args.dataset, args.dataset_dir, subset="valid")
14 | test_dataset = get_dataset(args.dataset, args.dataset_dir, subset="test")
15 |
16 | for i in tqdm(range(len(train_dataset))):
17 | train_dataset.preprocess(i, args.sample_rate)
18 |
19 | for i in tqdm(range(len(valid_dataset))):
20 | valid_dataset.preprocess(i, args.sample_rate)
21 |
22 | for i in tqdm(range(len(test_dataset))):
23 | test_dataset.preprocess(i, args.sample_rate)
24 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | simclr
2 | torchaudio-augmentations
3 | torch==1.9.0
4 | torchaudio
5 | pytorch-lightning
6 | soundfile
7 | scikit-learn
8 | matplotlib
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # Note: To use the 'upload' functionality of this file, you must:
5 | # $ pipenv install twine --dev
6 |
7 | import io
8 | import os
9 | import sys
10 | from shutil import rmtree
11 |
12 | from setuptools import find_packages, setup, Command
13 |
14 | # Package meta-data.
15 | NAME = "clmr"
16 | DESCRIPTION = "Contrastive Learning of Musical Representations"
17 | URL = "https://github.com/spijkervet/CLMR"
18 | EMAIL = "janne.spijkervet@gmail.com"
19 | AUTHOR = "Janne Spijkervet"
20 | REQUIRES_PYTHON = ">=3.6.0"
21 | VERSION = "0.1.0"
22 |
23 | # What packages are required for this module to be executed?
24 | REQUIRED = [
25 | "torch==1.9.0",
26 | "torchaudio",
27 | "simclr",
28 | "torchaudio-augmentations",
29 | "pytorch-lightning",
30 | "soundfile",
31 | "sklearn",
32 | "matplotlib",
33 | ]
34 |
35 | # What packages are optional?
36 | EXTRAS = {
37 | # 'fancy feature': ['django'],
38 | }
39 |
40 | # The rest you shouldn't have to touch too much :)
41 | # ------------------------------------------------
42 | # Except, perhaps the License and Trove Classifiers!
43 | # If you do change the License, remember to change the Trove Classifier for that!
44 |
45 | here = os.path.abspath(os.path.dirname(__file__))
46 |
47 | # Import the README and use it as the long-description.
48 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file!
49 | try:
50 | with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f:
51 | long_description = "\n" + f.read()
52 | except FileNotFoundError:
53 | long_description = DESCRIPTION
54 |
55 | # Load the package's __version__.py module as a dictionary.
56 | about = {}
57 | if not VERSION:
58 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_")
59 | with open(os.path.join(here, project_slug, "__version__.py")) as f:
60 | exec(f.read(), about)
61 | else:
62 | about["__version__"] = VERSION
63 |
64 |
65 | class UploadCommand(Command):
66 | """Support setup.py upload."""
67 |
68 | description = "Build and publish the package."
69 | user_options = []
70 |
71 | @staticmethod
72 | def status(s):
73 | """Prints things in bold."""
74 | print("\033[1m{0}\033[0m".format(s))
75 |
76 | def initialize_options(self):
77 | pass
78 |
79 | def finalize_options(self):
80 | pass
81 |
82 | def run(self):
83 | try:
84 | self.status("Removing previous builds…")
85 | rmtree(os.path.join(here, "dist"))
86 | except OSError:
87 | pass
88 |
89 | self.status("Building Source and Wheel (universal) distribution…")
90 | os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable))
91 |
92 | self.status("Uploading the package to PyPI via Twine…")
93 | os.system("twine upload dist/*")
94 |
95 | self.status("Pushing git tags…")
96 | os.system("git tag v{0}".format(about["__version__"]))
97 | os.system("git push --tags")
98 |
99 | sys.exit()
100 |
101 |
102 | # Where the magic happens:
103 | setup(
104 | name=NAME,
105 | version=about["__version__"],
106 | description=DESCRIPTION,
107 | long_description=long_description,
108 | long_description_content_type="text/markdown",
109 | author=AUTHOR,
110 | author_email=EMAIL,
111 | python_requires=REQUIRES_PYTHON,
112 | url=URL,
113 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]),
114 | # If your package is a single module, use this instead of 'packages':
115 | # py_modules=['mypackage'],
116 | # entry_points={
117 | # 'console_scripts': ['mycli=mymodule:cli'],
118 | # },
119 | install_requires=REQUIRED,
120 | extras_require=EXTRAS,
121 | include_package_data=True,
122 | license="MIT",
123 | classifiers=[
124 | # Trove classifiers
125 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
126 | "License :: OSI Approved :: MIT License",
127 | "Programming Language :: Python",
128 | "Programming Language :: Python :: 3",
129 | "Programming Language :: Python :: 3.6",
130 | "Programming Language :: Python :: Implementation :: CPython",
131 | "Programming Language :: Python :: Implementation :: PyPy",
132 | ],
133 | # $ setup.py publish support.
134 | cmdclass={
135 | "upload": UploadCommand,
136 | },
137 | )
138 |
--------------------------------------------------------------------------------
/tests/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/tests/.DS_Store
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/tests/__init__.py
--------------------------------------------------------------------------------
/tests/data/audioset/1272-128104-0000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/tests/data/audioset/1272-128104-0000.wav
--------------------------------------------------------------------------------
/tests/test_audioset.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torchaudio
3 | from torchaudio_augmentations import (
4 | Compose,
5 | RandomApply,
6 | RandomResizedCrop,
7 | PolarityInversion,
8 | Noise,
9 | Gain,
10 | Delay,
11 | PitchShift,
12 | Reverb,
13 | )
14 | from clmr.datasets import AUDIO
15 |
16 |
17 | class TestAudioSet(unittest.TestCase):
18 | sample_rate = 16000
19 |
20 | def get_audio_transforms(self, num_samples):
21 | transform = Compose(
22 | [
23 | RandomResizedCrop(n_samples=num_samples),
24 | RandomApply([PolarityInversion()], p=0.8),
25 | RandomApply([Noise(min_snr=0.3, max_snr=0.5)], p=0.3),
26 | RandomApply([Gain()], p=0.2),
27 | RandomApply([Delay(sample_rate=self.sample_rate)], p=0.5),
28 | RandomApply(
29 | [PitchShift(n_samples=num_samples, sample_rate=self.sample_rate)],
30 | p=0.4,
31 | ),
32 | RandomApply([Reverb(sample_rate=self.sample_rate)], p=0.3),
33 | ]
34 | )
35 | return transform
36 |
37 | def test_audioset(self):
38 | audio_dataset = AUDIO("./tests/data/audioset")
39 | audio, label = audio_dataset[0]
40 | assert audio.shape[0] == 1
41 | assert audio.shape[1] == 93680
42 |
43 | num_samples = (
44 | self.sample_rate * 5
45 | ) # the test item is approximately 5.8 seconds.
46 | transform = self.get_audio_transforms(num_samples=num_samples)
47 | audio = transform(audio)
48 | torchaudio.save("augmented_sample.wav", audio, sample_rate=self.sample_rate)
49 |
--------------------------------------------------------------------------------
/tests/test_dataset.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import pytest
3 | from clmr.datasets import (
4 | get_dataset,
5 | AUDIO,
6 | LIBRISPEECH,
7 | GTZAN,
8 | MAGNATAGATUNE,
9 | MillionSongDataset,
10 | )
11 |
12 |
13 | class TestAudioSet(unittest.TestCase):
14 |
15 | datasets = {
16 | "librispeech": LIBRISPEECH,
17 | "gtzan": GTZAN,
18 | "magnatagatune": MAGNATAGATUNE,
19 | "msd": MillionSongDataset,
20 | "audio": AUDIO,
21 | }
22 |
23 | def test_dataset_names(self):
24 | for dataset_name, dataset_type in self.datasets.items():
25 | with pytest.raises(RuntimeError):
26 | _ = get_dataset(
27 | dataset_name, "./data/audio", subset="train", download=False
28 | )
29 |
30 | def test_custom_audio_dataset(self):
31 | audio_dataset = get_dataset(
32 | "audio", "./tests/data/audioset", subset="train", download=False
33 | )
34 | assert type(audio_dataset) == AUDIO
35 | assert len(audio_dataset) == 1
36 |
--------------------------------------------------------------------------------
/tests/test_spectogram.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torchaudio
3 | import torch.nn as nn
4 | from torchaudio_augmentations import *
5 |
6 | from clmr.datasets import AUDIO
7 |
8 |
9 | class TestAudioSet(unittest.TestCase):
10 | sample_rate = 16000
11 |
12 | def get_audio_transforms(self, num_samples):
13 | transform = Compose(
14 | [
15 | RandomResizedCrop(n_samples=num_samples),
16 | RandomApply([PolarityInversion()], p=0.8),
17 | RandomApply([Noise(min_snr=0.3, max_snr=0.5)], p=0.3),
18 | RandomApply([Gain()], p=0.2),
19 | RandomApply([Delay(sample_rate=self.sample_rate)], p=0.5),
20 | RandomApply(
21 | [PitchShift(n_samples=num_samples, sample_rate=self.sample_rate)],
22 | p=0.4,
23 | ),
24 | RandomApply([Reverb(sample_rate=self.sample_rate)], p=0.3),
25 | ]
26 | )
27 | return transform
28 |
29 | def test_audioset(self):
30 | audio_dataset = AUDIO("tests/data/audioset")
31 | audio, label = audio_dataset[0]
32 |
33 | sample_rate = 22050
34 | n_fft = 1024
35 | n_mels = 128
36 | stype = "magnitude" # magnitude
37 | top_db = None # f_max
38 |
39 | transform = self.get_audio_transforms(num_samples=sample_rate)
40 |
41 | spec_transform = nn.Sequential(
42 | torchaudio.transforms.MelSpectrogram(
43 | sample_rate=sample_rate,
44 | n_fft=n_fft,
45 | n_mels=n_mels,
46 | ),
47 | torchaudio.transforms.AmplitudeToDB(stype=stype, top_db=top_db),
48 | )
49 |
50 | audio = transform(audio)
51 | audio = spec_transform(audio)
52 | assert audio.shape[1] == 128
53 | assert audio.shape[2] == 44
54 |
--------------------------------------------------------------------------------