├── .github └── workflows │ └── clmr.yml ├── .gitignore ├── LICENSE ├── README.md ├── _config.yml ├── clmr ├── data.py ├── datasets │ ├── __init__.py │ ├── audio.py │ ├── dataset.py │ ├── gtzan.py │ ├── librispeech.py │ ├── magnatagatune.py │ └── million_song_dataset.py ├── evaluation.py ├── models │ ├── __init__.py │ ├── model.py │ ├── sample_cnn.py │ ├── sample_cnn_xl.py │ ├── shortchunk_cnn.py │ └── sinc_net.py ├── modules │ ├── __init__.py │ ├── callbacks.py │ ├── contrastive_learning.py │ ├── linear_evaluation.py │ └── supervised_learning.py └── utils │ ├── __init__.py │ ├── checkpoint.py │ └── yaml_config_hook.py ├── config └── config.yaml ├── examples ├── .DS_Store └── clmr-onnxruntime-web │ ├── .DS_Store │ ├── clmr_sample-cnn.onnx │ ├── data │ ├── .DS_Store │ ├── fisher_losing_it.mp3 │ ├── fisher_losing_it_22050.mp3 │ ├── john_lennon_imagine.mp3 │ ├── john_lennon_imagine_22050.mp3 │ ├── nirvana_smells_like_teen_spirit.mp3 │ ├── nirvana_smells_like_teen_spirit_22050.mp3 │ ├── queen_love_of_my_life.mp3 │ └── queen_love_of_my_life_22050.mp3 │ └── index.html ├── export.py ├── linear_evaluation.py ├── main.py ├── media └── clmr_model.png ├── preprocess.py ├── requirements.txt ├── setup.py └── tests ├── .DS_Store ├── __init__.py ├── data └── audioset │ └── 1272-128104-0000.wav ├── test_audioset.py ├── test_dataset.py └── test_spectogram.py /.github/workflows/clmr.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: CLMR 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python 3.9 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.9 23 | - name: Install dependencies 24 | run: | 25 | sudo apt-get install libsndfile1 26 | python -m pip install --upgrade pip 27 | pip install black pytest 28 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 29 | - name: Lint with black 30 | run: | 31 | black --check . 32 | - name: Test with pytest 33 | run: | 34 | pytest 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | __pycache__ 3 | /data 4 | runs/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Contrastive Learning of Musical Representations 2 | 3 | PyTorch implementation of [Contrastive Learning of Musical Representations](https://arxiv.org/abs/2103.09410) by Janne Spijkervet and John Ashley Burgoyne. 4 | 5 | ![CLMR](https://github.com/spijkervet/clmr/actions/workflows/clmr.yml/badge.svg) 6 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Njz8EoN4br587xjpRKcssMuqQY6Cc5nj#scrollTo=aeKVT59FhWzV) 7 | 8 | [![arXiv](https://img.shields.io/badge/arXiv-2103.09410-b31b1b.svg)](https://arxiv.org/abs/2103.09410) 9 | [![Supplementary Material](https://img.shields.io/badge/Supplementary%20Material-2103.09410-blue.svg)](https://github.com/Spijkervet/CLMR/releases/download/2.1/CLMR.-.Supplementary.Material.pdf) 10 | 11 |
12 |
13 | [logo image: "CLMR x"]
14 |
15 |
16 |
17 |
18 | You can run a pre-trained CLMR model directly in your browser using ONNX Runtime: [live demo here](https://spijkervet.github.io/CLMR/examples/clmr-onnxruntime-web).
19 |
20 |
21 | In this work, we introduce SimCLR to the music domain and contribute a large chain of audio data augmentations to form a simple framework for self-supervised learning on raw waveforms of music: CLMR. We evaluate the performance of the self-supervised representations on the task of music classification.
22 |
23 | - We achieve competitive results on the MagnaTagATune and Million Song Datasets relative to fully supervised training, despite only using a linear classifier on self-supervised representations, i.e., representations that were learned task-agnostically, without any labels.
24 | - CLMR enables efficient classification: with only 1% of the labeled data, we achieve scores similar to those obtained with 100% of the labeled data.
25 | - CLMR generalises to out-of-domain datasets: even when pre-trained on entirely different music datasets, it still performs competitively with fully supervised training on the target dataset.
26 |
27 | *This is the CLMR v2 implementation; for the original implementation, go to the [`v1`](https://github.com/Spijkervet/CLMR/tree/v1) branch.*
28 |
29 |
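To make the pre-training objective concrete, here is a minimal, self-contained sketch of the contrastive (NT-Xent) loss that SimCLR-style frameworks such as CLMR optimise. The `encoder`, `projector` and `augment` objects below are hypothetical stand-ins, not the repository's SampleCNN or `torchaudio-augmentations` chain:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def nt_xent(z_i: torch.Tensor, z_j: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """NT-Xent loss over a batch of positive pairs (z_i[k], z_j[k]);
    every other example in the batch acts as a negative."""
    n = z_i.shape[0]
    z = F.normalize(torch.cat([z_i, z_j]), dim=1)   # (2N, D) unit vectors
    sim = (z @ z.t()) / temperature                 # pairwise cosine similarities
    mask = torch.eye(2 * n, dtype=torch.bool)
    sim = sim.masked_fill(mask, float("-inf"))      # a view is never its own negative
    targets = torch.arange(2 * n).roll(n)           # positive of row k sits at (k + N) mod 2N
    return F.cross_entropy(sim, targets)


# Hypothetical stand-ins for the encoder, projection head and augmentation chain:
encoder = nn.Sequential(nn.Flatten(), nn.Linear(59049, 512))
projector = nn.Linear(512, 64)


def augment(x):
    # stands in for the real chain of waveform augmentations, which is far richer
    return x + 0.01 * torch.randn_like(x)


audio = torch.randn(8, 1, 59049)                    # a batch of raw waveform clips
x_i, x_j = augment(audio), augment(audio)           # two independent views per clip
loss = nt_xent(projector(encoder(x_i)), projector(encoder(x_j)))
loss.backward()
```

Each clip yields two independently augmented views; the loss pulls the two projections of the same clip together and pushes apart all other clips in the batch.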
30 | ![CLMR model](media/clmr_model.png)
31 |
32 |
33 | An illustration of CLMR.
34 |
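After pre-training, evaluation is deliberately simple: the encoder is frozen and only a small head is trained on its representations. Here is a minimal sketch of this linear evaluation, assuming a hypothetical placeholder `encoder` (not the real SampleCNN) and the 50-tag multi-label setup of MagnaTagATune:

```python
import torch
import torch.nn as nn

encoder = nn.Sequential(nn.Flatten(), nn.Linear(59049, 512))  # placeholder encoder
encoder.requires_grad_(False)                                 # freeze the representations
encoder.eval()

linear_head = nn.Linear(512, 50)                              # the only trainable parameters
optimizer = torch.optim.Adam(linear_head.parameters(), lr=3e-4)
criterion = nn.BCEWithLogitsLoss()                            # multi-label tagging: one logit per tag

x = torch.randn(8, 1, 59049)                                  # dummy batch of waveform clips
y = torch.randint(0, 2, (8, 50)).float()                      # dummy binary tag targets

with torch.no_grad():
    h = encoder(x)                                            # frozen representations
loss = criterion(linear_head(h), y)
loss.backward()
optimizer.step()
```

Swapping `linear_head` for an MLP with one hidden layer corresponds to the "MLP (1 extra hidden layer)" rows in the results tables below.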
35 |
36 |
37 | This repository relies on my SimCLR implementation, which can be found [here](https://github.com/spijkervet/simclr), and on my `torchaudio-augmentations` package, found [here](https://github.com/Spijkervet/torchaudio-augmentations).
38 |
39 |
40 |
41 | ## Quickstart
42 | ```
43 | git clone https://github.com/spijkervet/clmr.git && cd clmr
44 |
45 | pip3 install -r requirements.txt
46 | # or
47 | python3 setup.py install
48 | ```
49 |
50 | The following commands download MagnaTagATune, preprocess it, run self-supervised pre-training on 1 GPU (with 8 parallel CPU workers), and finally run linear evaluation:
51 | ```
52 | python3 preprocess.py --dataset magnatagatune
53 |
54 | # add --workers 8 to increase the number of parallel CPU threads to speed up online data augmentations + training.
55 | python3 main.py --dataset magnatagatune --gpus 1 --workers 8
56 |
57 | python3 linear_evaluation.py --gpus 1 --workers 8 --checkpoint_path [path to checkpoint.pt, usually in ./runs]
58 | ```
59 |
60 | ## Pre-train on your own folder of audio files
61 | Run the following commands to pre-train the CLMR model on a folder of .wav files (or .mp3 files, after setting `src_ext_audio=".mp3"` in `clmr/datasets/audio.py`). You may need to convert your audio files to the correct sample rate first, before feeding them to the encoder (which expects `22,050 Hz` by default).
62 |
63 | ```
64 | python preprocess.py --dataset audio --dataset_dir ./directory_containing_audio_files
65 |
66 | python main.py --dataset audio --dataset_dir ./directory_containing_audio_files
67 | ```
68 |
69 |
70 | ## Results
71 |
72 | ### MagnaTagATune
73 |
74 | | Encoder / Model | Batch-size / epochs | Fine-tune head | ROC-AUC | PR-AUC |
75 | |-------------|-------------|-------------|-------------|-------------|
76 | | SampleCNN / CLMR | 48 / 10000 | Linear Classifier | 88.7 | **35.6** |
77 | | SampleCNN / CLMR | 48 / 10000 | MLP (1 extra hidden layer) | **89.3** | **36.0** |
78 | | [SampleCNN (fully supervised)](https://www.mdpi.com/2076-3417/8/1/150) | 48 / - | - | 88.6 | 34.4 |
79 | | [Pons et al. (fully supervised)](https://arxiv.org/pdf/1711.02520.pdf) | 48 / - | - | 89.1 | 34.92 |
80 |
81 | ### Million Song Dataset
82 |
83 | | Encoder / Model | Batch-size / epochs | Fine-tune head | ROC-AUC | PR-AUC |
84 | |-------------|-------------|-------------|-------------|-------------|
85 | | SampleCNN / CLMR | 48 / 1000 | Linear Classifier | 85.7 | 25.0 |
86 | | [SampleCNN (fully supervised)](https://www.mdpi.com/2076-3417/8/1/150) | 48 / - | - | **88.4** | - |
87 | | [Pons et al. (fully supervised)](https://arxiv.org/pdf/1711.02520.pdf) | 48 / - | - | 87.4 | **28.5** |
88 |
89 |
90 | ## Pre-trained models
91 | *Links point to downloads.*
92 |
93 | | Encoder (batch-size, epochs) | Fine-tune head | Pre-train dataset | ROC-AUC | PR-AUC |
94 | |-------------|-------------|-------------|-------------|-------------|
95 | | [SampleCNN (96, 10000)](https://github.com/Spijkervet/CLMR/releases/download/2.0/clmr_checkpoint_10000.zip) | [Linear Classifier](https://github.com/Spijkervet/CLMR/releases/download/2.0/finetuner_checkpoint_200.zip) | MagnaTagATune | 88.7 (89.3) | 35.6 (36.0) |
96 | | [SampleCNN (48, 1550)](https://github.com/Spijkervet/CLMR/releases/download/1.0/clmr_checkpoint_1550.pt) | [Linear Classifier](https://github.com/Spijkervet/CLMR/releases/download/1.0-l/finetuner_checkpoint_20.pt) | MagnaTagATune | 87.71 (88.47) | 34.27 (34.96) |
97 |
98 | ## Training
99 | ### 1.
Pre-training
100 | Simply run the following command to pre-train the CLMR model on the MagnaTagATune dataset:
101 | ```
102 | python main.py --dataset magnatagatune
103 | ```
104 |
105 | ### 2. Linear evaluation
106 | To test a trained model, make sure to set the `checkpoint_path` variable in `config/config.yaml`, or specify it as an argument:
107 | ```
108 | python linear_evaluation.py --checkpoint_path ./clmr_checkpoint_10000.pt
109 | ```
110 |
111 | ## Configuration
112 | The training configuration can be found in `config/config.yaml`. I personally prefer files to long strings of command-line arguments when configuring a run. Every entry in the config file can be overridden with the corresponding flag (e.g. `--max_epochs 500` if you would like to train for 500 epochs).
113 |
114 | ## Logging and TensorBoard
115 | To view results in TensorBoard, run:
116 | ```
117 | tensorboard --logdir ./runs
118 | ```
119 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
--------------------------------------------------------------------------------
/clmr/data.py:
--------------------------------------------------------------------------------
1 | """Wrapper for Torch Dataset class to enable contrastive training
2 | """
3 | import torch
4 | from torch import Tensor
5 | from torch.utils.data import Dataset
6 | from torchaudio_augmentations import Compose
7 | from typing import Tuple, List
8 |
9 |
10 | class ContrastiveDataset(Dataset):
11 |     def __init__(self, dataset: Dataset, input_shape: List[int], transform: Compose):
12 |         self.dataset = dataset
13 |         self.transform = transform
14 |         self.input_shape = input_shape
15 |         self.ignore_idx = []
16 |
17 |     def __getitem__(self, idx) -> Tuple[Tensor, Tensor]:
18 |         if idx in self.ignore_idx:
19 |             return self[(idx + 1) % len(self)]  # skip known-too-short clips; wrap past the end
20 |
21 |         audio, label = self.dataset[idx]
22 |
23 |         if audio.shape[1] < self.input_shape[1]:
24 |             self.ignore_idx.append(idx)
25 |             return self[(idx + 1) % len(self)]
26 |
27 |         if self.transform:
28 |             audio = self.transform(audio)
29 |         return audio, label
30 |
31 |     def __len__(self) -> int:
32 |         return len(self.dataset)
33 |
34 |     def concat_clip(self, n: int, audio_length: int) -> Tensor:
35 |         audio, _ = self.dataset[n]
36 |         batch = torch.split(audio, audio_length, dim=1)  # consecutive windows of audio_length samples
37 |         batch = torch.cat(batch[:-1])  # drop the last, shorter remainder window
38 |         batch = batch.unsqueeze(dim=1)
39 |
40 |         if self.transform:
41 |             batch = self.transform(batch)
42 |
43 |         return batch
44 |
--------------------------------------------------------------------------------
/clmr/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .dataset import Dataset
3 | from .audio import AUDIO
4 | from .librispeech import LIBRISPEECH
5 | from .gtzan import GTZAN
6 | from .magnatagatune import MAGNATAGATUNE
7 | from .million_song_dataset import MillionSongDataset
8 |
9 |
10 | def get_dataset(dataset, dataset_dir, subset, download=True):
11 |
12 |     if not os.path.exists(dataset_dir):
13 |         os.makedirs(dataset_dir)
14 |
15 |     if dataset == "audio":
16 |         d = AUDIO(root=dataset_dir)
17 |     elif dataset == "librispeech":
18 |         d = LIBRISPEECH(root=dataset_dir, download=download, subset=subset)
19 |     elif dataset == "gtzan":
20 |         d = GTZAN(root=dataset_dir, download=download, subset=subset)
21 |     elif dataset == "magnatagatune":
22 |         d = MAGNATAGATUNE(root=dataset_dir, download=download, subset=subset)
23 |     
elif dataset == "msd": 24 | d = MillionSongDataset(root=dataset_dir, subset=subset) 25 | else: 26 | raise NotImplementedError("Dataset not implemented") 27 | return d 28 | -------------------------------------------------------------------------------- /clmr/datasets/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from torch import Tensor 4 | from typing import Tuple 5 | 6 | 7 | from clmr.datasets import Dataset 8 | 9 | 10 | class AUDIO(Dataset): 11 | """Create a Dataset for any folder of audio files. 12 | Args: 13 | root (str): Path to the directory where the dataset is found or downloaded. 14 | src_ext_audio (str): The extension of the audio files to analyze. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | root: str, 20 | src_ext_audio: str = ".wav", 21 | n_classes: int = 1, 22 | ) -> None: 23 | super(AUDIO, self).__init__(root) 24 | 25 | self._path = root 26 | self._src_ext_audio = src_ext_audio 27 | self.n_classes = n_classes 28 | 29 | self.fl = glob( 30 | os.path.join(self._path, "**", "*{}".format(self._src_ext_audio)), 31 | recursive=True, 32 | ) 33 | 34 | if len(self.fl) == 0: 35 | raise RuntimeError( 36 | "Dataset not found. Please place the audio files in the {} folder.".format( 37 | self._path 38 | ) 39 | ) 40 | 41 | def file_path(self, n: int) -> str: 42 | fp = self.fl[n] 43 | return fp 44 | 45 | def __getitem__(self, n: int) -> Tuple[Tensor, Tensor]: 46 | """Load the n-th sample from the dataset. 47 | 48 | Args: 49 | n (int): The index of the sample to be loaded 50 | 51 | Returns: 52 | Tuple [Tensor, Tensor]: ``(waveform, label)`` 53 | """ 54 | audio, _ = self.load(n) 55 | label = [] 56 | return audio, label 57 | 58 | def __len__(self) -> int: 59 | return len(self.fl) 60 | -------------------------------------------------------------------------------- /clmr/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import torchaudio 4 | from torch.utils.data import Dataset as TorchDataset 5 | from abc import abstractmethod 6 | 7 | 8 | def preprocess_audio(source, target, sample_rate): 9 | p = subprocess.Popen( 10 | ["ffmpeg", "-i", source, "-ar", str(sample_rate), target, "-loglevel", "quiet"] 11 | ) 12 | p.wait() 13 | 14 | 15 | class Dataset(TorchDataset): 16 | 17 | _ext_audio = ".wav" 18 | 19 | def __init__(self, root: str): 20 | pass 21 | 22 | @abstractmethod 23 | def file_path(self, n: int): 24 | pass 25 | 26 | def target_file_path(self, n: int) -> str: 27 | fp = self.file_path(n) 28 | file_basename, _ = os.path.splitext(fp) 29 | return file_basename + self._ext_audio 30 | 31 | def preprocess(self, n: int, sample_rate: int): 32 | fp = self.file_path(n) 33 | target_fp = self.target_file_path(n) 34 | 35 | if not os.path.exists(target_fp): 36 | preprocess_audio(fp, target_fp, sample_rate) 37 | 38 | def load(self, n): 39 | target_fp = self.target_file_path(n) 40 | try: 41 | audio, sample_rate = torchaudio.load(target_fp) 42 | except OSError as e: 43 | print("File not found, try running `python preprocess.py` first.\n\n", e) 44 | return 45 | return audio, sample_rate 46 | -------------------------------------------------------------------------------- /clmr/datasets/gtzan.py: -------------------------------------------------------------------------------- 1 | import torchaudio 2 | from torchaudio.datasets.gtzan import gtzan_genres 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class GTZAN(Dataset): 7 | 
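    # Wraps torchaudio.datasets.GTZAN and maps its string genre labels (the 10
    # entries of `gtzan_genres`) to integer indices; `n_classes` is derived from
    # that mapping for the downstream classification head.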
8 | subset_map = {"train": "training", "valid": "validation", "test": "testing"} 9 | 10 | def __init__(self, root, download, subset): 11 | self.dataset = torchaudio.datasets.GTZAN( 12 | root=root, download=download, subset=self.subset_map[subset] 13 | ) 14 | self.labels = gtzan_genres 15 | 16 | self.label2idx = {} 17 | for idx, label in enumerate(self.labels): 18 | self.label2idx[label] = idx 19 | 20 | self.n_classes = len(self.label2idx.keys()) 21 | 22 | def __getitem__(self, idx): 23 | audio, sr, label = self.dataset[idx] 24 | label = self.label2idx[label] 25 | return audio, label 26 | 27 | def __len__(self): 28 | return len(self.dataset) 29 | -------------------------------------------------------------------------------- /clmr/datasets/librispeech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchaudio 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class LIBRISPEECH(Dataset): 7 | 8 | subset_map = {"train": "train-clean-100", "test": "test-clean"} 9 | 10 | def __init__(self, root, download, subset): 11 | self.dataset = torchaudio.datasets.LIBRISPEECH( 12 | root=root, download=download, url=self.subset_map[subset] 13 | ) 14 | 15 | self.speaker2idx = {} 16 | 17 | if not os.path.exists(self.dataset._path): 18 | raise RuntimeError( 19 | "Dataset not found. Please use `download=True` to download it." 20 | ) 21 | 22 | self.speaker_ids = list(map(int, os.listdir(self.dataset._path))) 23 | for idx, speaker_id in enumerate(sorted(self.speaker_ids)): 24 | self.speaker2idx[speaker_id] = idx 25 | 26 | self.n_classes = len(self.speaker2idx.keys()) 27 | 28 | def __getitem__(self, idx): 29 | ( 30 | audio, 31 | sample_rate, 32 | utterance, 33 | speaker_id, 34 | chapter_id, 35 | utterance_id, 36 | ) = self.dataset[idx] 37 | label = self.speaker2idx[speaker_id] 38 | return audio, label 39 | 40 | def __len__(self): 41 | return len(self.dataset) 42 | -------------------------------------------------------------------------------- /clmr/datasets/magnatagatune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import subprocess 4 | import torch 5 | import numpy as np 6 | import zipfile 7 | from collections import defaultdict 8 | from typing import Any, Tuple, Optional 9 | from tqdm import tqdm 10 | 11 | import soundfile as sf 12 | import torchaudio 13 | 14 | torchaudio.set_audio_backend("soundfile") 15 | from torch import Tensor, FloatTensor 16 | from torchaudio.datasets.utils import ( 17 | download_url, 18 | extract_archive, 19 | ) 20 | 21 | from clmr.datasets import Dataset 22 | 23 | 24 | FOLDER_IN_ARCHIVE = "magnatagatune" 25 | _CHECKSUMS = { 26 | "http://mi.soi.city.ac.uk/datasets/magnatagatune/mp3.zip.001": "", 27 | "http://mi.soi.city.ac.uk/datasets/magnatagatune/mp3.zip.002": "", 28 | "http://mi.soi.city.ac.uk/datasets/magnatagatune/mp3.zip.003": "", 29 | "http://mi.soi.city.ac.uk/datasets/magnatagatune/annotations_final.csv": "", 30 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/binary.npy": "", 31 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/tags.npy": "", 32 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/test.npy": "", 33 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/train.npy": "", 34 | "https://github.com/minzwon/sota-music-tagging-models/raw/master/split/mtat/valid.npy": "", 35 | 
"https://github.com/jordipons/musicnn-training/raw/master/data/index/mtt/train_gt_mtt.tsv": "", 36 | "https://github.com/jordipons/musicnn-training/raw/master/data/index/mtt/val_gt_mtt.tsv": "", 37 | "https://github.com/jordipons/musicnn-training/raw/master/data/index/mtt/test_gt_mtt.tsv": "", 38 | "https://github.com/jordipons/musicnn-training/raw/master/data/index/mtt/index_mtt.tsv": "", 39 | } 40 | 41 | 42 | def get_file_list(root, subset, split): 43 | if subset == "train": 44 | if split == "pons2017": 45 | fl = open(os.path.join(root, "train_gt_mtt.tsv")).read().splitlines() 46 | else: 47 | fl = np.load(os.path.join(root, "train.npy")) 48 | elif subset == "valid": 49 | if split == "pons2017": 50 | fl = open(os.path.join(root, "val_gt_mtt.tsv")).read().splitlines() 51 | else: 52 | fl = np.load(os.path.join(root, "valid.npy")) 53 | else: 54 | if split == "pons2017": 55 | fl = open(os.path.join(root, "test_gt_mtt.tsv")).read().splitlines() 56 | else: 57 | fl = np.load(os.path.join(root, "test.npy")) 58 | 59 | if split == "pons2017": 60 | binary = {} 61 | index = open(os.path.join(root, "index_mtt.tsv")).read().splitlines() 62 | fp_dict = {} 63 | for i in index: 64 | clip_id, fp = i.split("\t") 65 | fp_dict[clip_id] = fp 66 | 67 | for idx, f in enumerate(fl): 68 | clip_id, label = f.split("\t") 69 | fl[idx] = "{}\t{}".format(clip_id, fp_dict[clip_id]) 70 | clip_id = int(clip_id) 71 | binary[clip_id] = eval(label) 72 | else: 73 | binary = np.load(os.path.join(root, "binary.npy")) 74 | 75 | return fl, binary 76 | 77 | 78 | class MAGNATAGATUNE(Dataset): 79 | """Create a Dataset for MagnaTagATune. 80 | Args: 81 | root (str): Path to the directory where the dataset is found or downloaded. 82 | folder_in_archive (str, optional): The top-level directory of the dataset. 83 | download (bool, optional): 84 | Whether to download the dataset if it is not found at root path. (default: ``False``). 85 | subset (str, optional): Which subset of the dataset to use. 86 | One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``. 87 | If ``None``, the entire dataset is used. (default: ``None``). 88 | """ 89 | 90 | _ext_audio = ".wav" 91 | 92 | def __init__( 93 | self, 94 | root: str, 95 | folder_in_archive: Optional[str] = FOLDER_IN_ARCHIVE, 96 | download: Optional[bool] = False, 97 | subset: Optional[str] = None, 98 | split: Optional[str] = "pons2017", 99 | ) -> None: 100 | 101 | super(MAGNATAGATUNE, self).__init__(root) 102 | self.root = root 103 | self.folder_in_archive = folder_in_archive 104 | self.download = download 105 | self.subset = subset 106 | self.split = split 107 | 108 | assert subset is None or subset in ["train", "valid", "test"], ( 109 | "When `subset` not None, it must take a value from " 110 | + "{'train', 'valid', 'test'}." 
111 | ) 112 | 113 | self._path = os.path.join(root, folder_in_archive) 114 | 115 | if download: 116 | if not os.path.isdir(self._path): 117 | os.makedirs(self._path) 118 | 119 | zip_files = [] 120 | for url, checksum in _CHECKSUMS.items(): 121 | target_fn = os.path.basename(url) 122 | target_fp = os.path.join(self._path, target_fn) 123 | if ".zip" in target_fp: 124 | zip_files.append(target_fp) 125 | 126 | if not os.path.exists(target_fp): 127 | download_url( 128 | url, 129 | self._path, 130 | filename=target_fn, 131 | hash_value=checksum, 132 | hash_type="md5", 133 | ) 134 | 135 | if not os.path.exists( 136 | os.path.join( 137 | self._path, 138 | "f", 139 | "american_bach_soloists-j_s__bach_solo_cantatas-01-bwv54__i_aria-30-59.mp3", 140 | ) 141 | ): 142 | merged_zip = os.path.join(self._path, "mp3.zip") 143 | print("Merging zip files...") 144 | with open(merged_zip, "wb") as f: 145 | for filename in zip_files: 146 | with open(filename, "rb") as g: 147 | f.write(g.read()) 148 | 149 | extract_archive(merged_zip) 150 | 151 | if not os.path.isdir(self._path): 152 | raise RuntimeError( 153 | "Dataset not found. Please use `download=True` to download it." 154 | ) 155 | 156 | self.fl, self.binary = get_file_list(self._path, self.subset, self.split) 157 | self.n_classes = 50 # self.binary.shape[1] 158 | # self.audio = {} 159 | # for f in tqdm(self.fl): 160 | # clip_id, fp = f.split("\t") 161 | # if clip_id not in self.audio.keys(): 162 | # audio, _ = load_magnatagatune_item(fp, self._path, self._ext_audio) 163 | # self.audio[clip_id] = audio 164 | 165 | def file_path(self, n: int) -> str: 166 | _, fp = self.fl[n].split("\t") 167 | return os.path.join(self._path, fp) 168 | 169 | def __getitem__(self, n: int) -> Tuple[Tensor, Tensor]: 170 | """Load the n-th sample from the dataset. 
171 | 172 | Args: 173 | n (int): The index of the sample to be loaded 174 | 175 | Returns: 176 | tuple: ``(waveform, label)`` 177 | """ 178 | clip_id, fp = self.fl[n].split("\t") 179 | label = self.binary[int(clip_id)] 180 | 181 | audio, _ = self.load(n) 182 | label = FloatTensor(label) 183 | return audio, label 184 | 185 | def __len__(self) -> int: 186 | return len(self.fl) 187 | -------------------------------------------------------------------------------- /clmr/datasets/million_song_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import torch 4 | import torchaudio 5 | from collections import defaultdict 6 | from pathlib import Path 7 | from torch import Tensor, FloatTensor 8 | from tqdm import tqdm 9 | from typing import Any, Tuple, Optional 10 | 11 | from clmr.datasets import Dataset 12 | 13 | 14 | def load_id2gt(gt_file, msd_7d): 15 | ids = [] 16 | with open(gt_file) as f: 17 | id2gt = dict() 18 | for line in f.readlines(): 19 | msd_id, gt = line.strip().split("\t") # id is string 20 | id_7d = msd_7d[msd_id] 21 | id2gt[msd_id] = eval(gt) # gt is array 22 | ids.append(msd_id) 23 | return ids, id2gt 24 | 25 | 26 | def load_id2path(index_file, msd_7d): 27 | paths = [] 28 | with open(index_file) as f: 29 | id2path = dict() 30 | for line in f.readlines(): 31 | msd_id, msd_path = line.strip().split("\t") 32 | id_7d = msd_7d[msd_id] 33 | path = os.path.join(id_7d[0], id_7d[1], f"{id_7d}.clip.mp3") 34 | id2path[msd_id] = path 35 | paths.append(path) 36 | return paths, id2path 37 | 38 | 39 | def default_indexer(ids, id2audio_path, id2gt): 40 | index = [] 41 | track_index = defaultdict(list) 42 | track_idx = 0 43 | clip_idx = 0 44 | for clip_id in ids: 45 | fp = id2audio_path[clip_id] 46 | label = id2gt[clip_id] 47 | track_idx = clip_id 48 | clip_id = clip_idx 49 | clip_idx += 1 50 | index.append([track_idx, clip_id, fp, label]) 51 | track_index[track_idx].append([clip_id, fp, label]) 52 | return index, track_index 53 | 54 | 55 | def default_loader(path): 56 | audio, sr = torchaudio.load(path) 57 | audio = audio.mean(dim=0, keepdim=True) 58 | return audio, sr 59 | 60 | 61 | class MillionSongDataset(Dataset): 62 | 63 | _base_dir = "million_song_dataset" 64 | _ext_audio = ".wav" 65 | 66 | def __init__( 67 | self, 68 | root: str, 69 | base_dir: str = _base_dir, 70 | download: bool = False, 71 | subset: Optional[str] = None, 72 | ): 73 | if download: 74 | raise Exception("The Million Song Dataset is not publicly available") 75 | 76 | self.root = root 77 | self.base_dir = base_dir 78 | self.subset = subset 79 | 80 | assert subset is None or subset in ["train", "valid", "test"], ( 81 | "When `subset` not None, it must take a value from " 82 | + "{'train', 'valid', 'test'}." 83 | ) 84 | 85 | self._path = os.path.join(self.root, self.base_dir) 86 | 87 | if not os.path.exists(self._path): 88 | raise RuntimeError( 89 | "Dataset not found. 
Please place the MSD files in the {} folder.".format( 90 | self._path 91 | ) 92 | ) 93 | 94 | msd_processed_annot = Path(self._path, "processed_annotations") 95 | 96 | if self.subset == "train": 97 | self.annotations_file = Path(msd_processed_annot) / "train_gt_msd.tsv" 98 | elif self.subset == "valid": 99 | self.annotations_file = Path(msd_processed_annot) / "val_gt_msd.tsv" 100 | else: 101 | self.annotations_file = Path(msd_processed_annot) / "test_gt_msd.tsv" 102 | 103 | with open(Path(msd_processed_annot) / "MSD_id_to_7D_id.pkl", "rb") as f: 104 | self.msd_to_7d = pickle.load(f) 105 | 106 | # int to label 107 | with open(Path(msd_processed_annot) / "output_labels_msd.txt", "r") as f: 108 | lines = f.readlines() 109 | self.tags = eval(lines[1][lines[1].find("[") :]) 110 | self.n_classes = len(self.tags) 111 | 112 | [audio_repr_paths, id2audio_path] = load_id2path( 113 | Path(msd_processed_annot) / "index_msd.tsv", self.msd_to_7d 114 | ) 115 | [ids, id2gt] = load_id2gt(self.annotations_file, self.msd_to_7d) 116 | 117 | self.index, self.track_index = default_indexer(ids, id2audio_path, id2gt) 118 | 119 | def file_path(self, n: int) -> str: 120 | _, _, fp, _ = self.index[n] 121 | return os.path.join(self._path, "preprocessed", fp) 122 | 123 | def __getitem__(self, n: int) -> Tuple[Tensor, Tensor]: 124 | track_id, clip_id, fp, label = self.index[n] 125 | label = torch.FloatTensor(label) 126 | 127 | try: 128 | audio, _ = self.load(n) 129 | except Exception as e: 130 | print(f"Skipped {track_id, fp}, could not load audio: {e}") 131 | return self.__getitem__(n + 1) 132 | return audio, label 133 | 134 | def __len__(self) -> int: 135 | return len(self.index) 136 | -------------------------------------------------------------------------------- /clmr/evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import Dataset 5 | from tqdm import tqdm 6 | from sklearn import metrics 7 | 8 | 9 | def evaluate( 10 | encoder: nn.Module, 11 | finetuned_head: nn.Module, 12 | test_dataset: Dataset, 13 | dataset_name: str, 14 | audio_length: int, 15 | device, 16 | ) -> dict: 17 | est_array = [] 18 | gt_array = [] 19 | 20 | encoder = encoder.to(device) 21 | encoder.eval() 22 | 23 | if finetuned_head is not None: 24 | finetuned_head = finetuned_head.to(device) 25 | finetuned_head.eval() 26 | 27 | with torch.no_grad(): 28 | for idx in tqdm(range(len(test_dataset))): 29 | _, label = test_dataset[idx] 30 | batch = test_dataset.concat_clip(idx, audio_length) 31 | batch = batch.to(device) 32 | 33 | output = encoder(batch) 34 | if finetuned_head: 35 | output = finetuned_head(output) 36 | 37 | # we always return logits, so we need a sigmoid here for multi-label classification 38 | if dataset_name in ["magnatagatune", "msd"]: 39 | output = torch.sigmoid(output) 40 | else: 41 | output = F.softmax(output, dim=1) 42 | 43 | track_prediction = output.mean(dim=0) 44 | est_array.append(track_prediction) 45 | gt_array.append(label) 46 | 47 | if dataset_name in ["magnatagatune", "msd"]: 48 | est_array = torch.stack(est_array, dim=0).cpu().numpy() 49 | gt_array = torch.stack(gt_array, dim=0).cpu().numpy() 50 | roc_aucs = metrics.roc_auc_score(gt_array, est_array, average="macro") 51 | pr_aucs = metrics.average_precision_score(gt_array, est_array, average="macro") 52 | return { 53 | "PR-AUC": pr_aucs, 54 | "ROC-AUC": roc_aucs, 55 | } 56 | 57 | est_array = torch.stack(est_array, dim=0) 58 
| _, est_array = torch.max(est_array, 1)  # extract the predicted labels here.
59 |     accuracy = metrics.accuracy_score(gt_array, est_array)
60 |     return {"Accuracy": accuracy}
61 |
--------------------------------------------------------------------------------
/clmr/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Model, Identity
2 | from .sample_cnn import SampleCNN
3 | from .shortchunk_cnn import ShortChunkCNN_Res
4 | from .sinc_net import SincNet
5 |
--------------------------------------------------------------------------------
/clmr/models/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 |
4 |
5 | class Model(nn.Module):
6 |     def __init__(self):
7 |         super(Model, self).__init__()
8 |
9 |     def initialize(self, m):
10 |         if isinstance(m, (nn.Conv1d)):
11 |             # nn.init.xavier_uniform_(m.weight)
12 |             # if m.bias is not None:
13 |             #     nn.init.xavier_uniform_(m.bias)
14 |
15 |             nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
16 |
17 |
18 | class Identity(nn.Module):
19 |     def __init__(self):
20 |         super(Identity, self).__init__()
21 |
22 |     def forward(self, x):
23 |         return x
24 |
--------------------------------------------------------------------------------
/clmr/models/sample_cnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .model import Model
4 |
5 |
6 | class SampleCNN(Model):
7 |     def __init__(self, strides, supervised, out_dim):
8 |         super(SampleCNN, self).__init__()
9 |
10 |         self.strides = strides
11 |         self.supervised = supervised
12 |         self.sequential = [
13 |             nn.Sequential(
14 |                 nn.Conv1d(1, 128, kernel_size=3, stride=3, padding=0),
15 |                 nn.BatchNorm1d(128),
16 |                 nn.ReLU(),
17 |             )
18 |         ]
19 |
20 |         self.hidden = [
21 |             [128, 128],
22 |             [128, 128],
23 |             [128, 256],
24 |             [256, 256],
25 |             [256, 256],
26 |             [256, 256],
27 |             [256, 256],
28 |             [256, 256],
29 |             [256, 512],
30 |         ]
31 |
32 |         assert len(self.hidden) == len(
33 |             self.strides
34 |         ), "Number of hidden layers and strides are not equal"
35 |         for stride, (h_in, h_out) in zip(self.strides, self.hidden):
36 |             self.sequential.append(
37 |                 nn.Sequential(
38 |                     nn.Conv1d(h_in, h_out, kernel_size=stride, stride=1, padding=1),
39 |                     nn.BatchNorm1d(h_out),
40 |                     nn.ReLU(),
41 |                     nn.MaxPool1d(stride, stride=stride),
42 |                 )
43 |             )
44 |
45 |         # 1 x 512
46 |         self.sequential.append(
47 |             nn.Sequential(
48 |                 nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1),
49 |                 nn.BatchNorm1d(512),
50 |                 nn.ReLU(),
51 |             )
52 |         )
53 |
54 |         self.sequential = nn.Sequential(*self.sequential)
55 |
56 |         if self.supervised:
57 |             self.dropout = nn.Dropout(0.5)
58 |         self.fc = nn.Linear(512, out_dim)
59 |
60 |     def forward(self, x):
61 |         out = self.sequential(x)
62 |         if self.supervised:
63 |             out = self.dropout(out)
64 |
65 |         out = out.reshape(x.shape[0], out.size(1) * out.size(2))
66 |         logit = self.fc(out)
67 |         return logit
68 |
--------------------------------------------------------------------------------
/clmr/models/sample_cnn_xl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .model import Model
4 |
5 |
6 | class SampleCNNXL(Model):
7 |     def __init__(self, strides, supervised, out_dim):
8 |         super(SampleCNNXL, self).__init__()
9 |
10 |         self.strides = strides
11 |         self.supervised = supervised
12 |         self.sequential = [
13 | nn.Sequential( 14 | nn.Conv1d(1, 128, kernel_size=3, stride=3, padding=0), 15 | nn.BatchNorm1d(128), 16 | nn.ReLU(), 17 | ) 18 | ] 19 | 20 | self.hidden = [ 21 | [128, 128], 22 | [128, 128], 23 | [128, 256], 24 | [256, 256], 25 | [256, 512], 26 | [512, 512], 27 | [512, 1024], 28 | [1024, 1024], 29 | [1024, 2048], 30 | ] 31 | 32 | assert len(self.hidden) == len( 33 | self.strides 34 | ), "Number of hidden layers and strides are not equal" 35 | for stride, (h_in, h_out) in zip(self.strides, self.hidden): 36 | self.sequential.append( 37 | nn.Sequential( 38 | nn.Conv1d(h_in, h_out, kernel_size=stride, stride=1, padding=1), 39 | nn.BatchNorm1d(h_out), 40 | nn.ReLU(), 41 | nn.MaxPool1d(stride, stride=stride), 42 | ) 43 | ) 44 | 45 | # 1 x 512 46 | self.sequential.append( 47 | nn.Sequential( 48 | nn.Conv1d(2048, 2048, kernel_size=3, stride=1, padding=1), 49 | nn.BatchNorm1d(2048), 50 | nn.ReLU(), 51 | ) 52 | ) 53 | 54 | self.sequential = nn.Sequential(*self.sequential) 55 | 56 | if self.supervised: 57 | self.dropout = nn.Dropout(0.5) 58 | self.fc = nn.Linear(2048, out_dim) 59 | 60 | def forward(self, x): 61 | out = self.sequential(x) 62 | if self.supervised: 63 | out = self.dropout(out) 64 | 65 | out = out.reshape(x.shape[0], out.size(1) * out.size(2)) 66 | logit = self.fc(out) 67 | return logit 68 | -------------------------------------------------------------------------------- /clmr/models/shortchunk_cnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ShortChunkCNN_Res(nn.Module): 5 | """ 6 | Short-chunk CNN architecture with residual connections. 7 | """ 8 | 9 | def __init__(self, n_channels=128, n_classes=50): 10 | super(ShortChunkCNN_Res, self).__init__() 11 | 12 | self.spec_bn = nn.BatchNorm2d(1) 13 | 14 | # CNN 15 | self.layer1 = Res_2d(1, n_channels, stride=2) 16 | self.layer2 = Res_2d(n_channels, n_channels, stride=2) 17 | self.layer3 = Res_2d(n_channels, n_channels * 2, stride=2) 18 | self.layer4 = Res_2d(n_channels * 2, n_channels * 2, stride=2) 19 | self.layer5 = Res_2d(n_channels * 2, n_channels * 2, stride=2) 20 | self.layer6 = Res_2d(n_channels * 2, n_channels * 2, stride=2) 21 | self.layer7 = Res_2d(n_channels * 2, n_channels * 4, stride=2) 22 | 23 | # Dense 24 | self.dense1 = nn.Linear(n_channels * 4, n_channels * 4) 25 | self.bn = nn.BatchNorm1d(n_channels * 4) 26 | 27 | self.fc = nn.Linear(n_channels * 4, n_classes) 28 | self.dropout = nn.Dropout(0.5) 29 | self.relu = nn.ReLU() 30 | 31 | def forward(self, x): 32 | x = self.spec_bn(x) 33 | 34 | # CNN 35 | x = self.layer1(x) 36 | x = self.layer2(x) 37 | x = self.layer3(x) 38 | x = self.layer4(x) 39 | x = self.layer5(x) 40 | x = self.layer6(x) 41 | x = self.layer7(x) 42 | x = x.squeeze(2) 43 | 44 | # Global Max Pooling 45 | if x.size(-1) != 1: 46 | x = nn.MaxPool1d(x.size(-1))(x) 47 | x = x.squeeze(2) 48 | 49 | # Dense 50 | x = self.dense1(x) 51 | x = self.bn(x) 52 | x = self.relu(x) 53 | x = self.dropout(x) 54 | x = self.fc(x) 55 | # x = nn.Sigmoid()(x) 56 | 57 | return x 58 | 59 | 60 | class Res_2d(nn.Module): 61 | def __init__(self, input_channels, output_channels, shape=3, stride=2): 62 | super(Res_2d, self).__init__() 63 | # convolution 64 | self.conv_1 = nn.Conv2d( 65 | input_channels, output_channels, shape, stride=stride, padding=shape // 2 66 | ) 67 | self.bn_1 = nn.BatchNorm2d(output_channels) 68 | self.conv_2 = nn.Conv2d( 69 | output_channels, output_channels, shape, padding=shape // 2 70 | ) 71 | self.bn_2 = 
nn.BatchNorm2d(output_channels) 72 | 73 | # residual 74 | self.diff = False 75 | if (stride != 1) or (input_channels != output_channels): 76 | self.conv_3 = nn.Conv2d( 77 | input_channels, 78 | output_channels, 79 | shape, 80 | stride=stride, 81 | padding=shape // 2, 82 | ) 83 | self.bn_3 = nn.BatchNorm2d(output_channels) 84 | self.diff = True 85 | self.relu = nn.ReLU() 86 | 87 | def forward(self, x): 88 | # convolution 89 | out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x))))) 90 | 91 | # residual 92 | if self.diff: 93 | x = self.bn_3(self.conv_3(x)) 94 | out = x + out 95 | out = self.relu(out) 96 | return out 97 | -------------------------------------------------------------------------------- /clmr/models/sinc_net.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | import sys 6 | from torch.autograd import Variable 7 | import math 8 | 9 | 10 | def flip(x, dim): 11 | xsize = x.size() 12 | dim = x.dim() + dim if dim < 0 else dim 13 | x = x.contiguous() 14 | x = x.view(-1, *xsize[dim:]) 15 | x = x.view(x.size(0), x.size(1), -1)[ 16 | :, 17 | getattr( 18 | torch.arange(x.size(1) - 1, -1, -1), ("cpu", "cuda")[x.is_cuda] 19 | )().long(), 20 | :, 21 | ] 22 | return x.view(xsize) 23 | 24 | 25 | def sinc(band, t_right): 26 | y_right = torch.sin(2 * math.pi * band * t_right) / (2 * math.pi * band * t_right) 27 | y_left = flip(y_right, 0) 28 | 29 | y = torch.cat([y_left, Variable(torch.ones(1)).cuda(), y_right]) 30 | 31 | return y 32 | 33 | 34 | class SincConv_fast(nn.Module): 35 | """Sinc-based convolution 36 | Parameters 37 | ---------- 38 | in_channels : `int` 39 | Number of input channels. Must be 1. 40 | out_channels : `int` 41 | Number of filters. 42 | kernel_size : `int` 43 | Filter length. 44 | sample_rate : `int`, optional 45 | Sample rate. Defaults to 16000. 46 | Usage 47 | ----- 48 | See `torch.nn.Conv1d` 49 | Reference 50 | --------- 51 | Mirco Ravanelli, Yoshua Bengio, 52 | "Speaker Recognition from raw waveform with SincNet". 
53 | https://arxiv.org/abs/1808.00158 54 | """ 55 | 56 | @staticmethod 57 | def to_mel(hz): 58 | return 2595 * np.log10(1 + hz / 700) 59 | 60 | @staticmethod 61 | def to_hz(mel): 62 | return 700 * (10 ** (mel / 2595) - 1) 63 | 64 | def __init__( 65 | self, 66 | out_channels, 67 | kernel_size, 68 | sample_rate=16000, 69 | in_channels=1, 70 | stride=1, 71 | padding=0, 72 | dilation=1, 73 | bias=False, 74 | groups=1, 75 | min_low_hz=50, 76 | min_band_hz=50, 77 | ): 78 | 79 | super(SincConv_fast, self).__init__() 80 | 81 | if in_channels != 1: 82 | # msg = (f'SincConv only support one input channel ' 83 | # f'(here, in_channels = {in_channels:d}).') 84 | msg = ( 85 | "SincConv only support one input channel (here, in_channels = {%i})" 86 | % (in_channels) 87 | ) 88 | raise ValueError(msg) 89 | 90 | self.out_channels = out_channels 91 | self.kernel_size = kernel_size 92 | 93 | # Forcing the filters to be odd (i.e, perfectly symmetrics) 94 | if kernel_size % 2 == 0: 95 | self.kernel_size = self.kernel_size + 1 96 | 97 | self.stride = stride 98 | self.padding = padding 99 | self.dilation = dilation 100 | 101 | if bias: 102 | raise ValueError("SincConv does not support bias.") 103 | if groups > 1: 104 | raise ValueError("SincConv does not support groups.") 105 | 106 | self.sample_rate = sample_rate 107 | self.min_low_hz = min_low_hz 108 | self.min_band_hz = min_band_hz 109 | 110 | # initialize filterbanks such that they are equally spaced in Mel scale 111 | low_hz = 30 112 | high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz) 113 | 114 | mel = np.linspace( 115 | self.to_mel(low_hz), self.to_mel(high_hz), self.out_channels + 1 116 | ) 117 | hz = self.to_hz(mel) 118 | 119 | # filter lower frequency (out_channels, 1) 120 | self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1)) 121 | 122 | # filter frequency band (out_channels, 1) 123 | self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1)) 124 | 125 | # Hamming window 126 | # self.window_ = torch.hamming_window(self.kernel_size) 127 | n_lin = torch.linspace( 128 | 0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2)) 129 | ) # computing only half of the window 130 | self.window_ = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / self.kernel_size) 131 | 132 | # (1, kernel_size/2) 133 | n = (self.kernel_size - 1) / 2.0 134 | self.n_ = ( 135 | 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate 136 | ) # Due to symmetry, I only need half of the time axes 137 | 138 | def forward(self, waveforms): 139 | """ 140 | Parameters 141 | ---------- 142 | waveforms : `torch.Tensor` (batch_size, 1, n_samples) 143 | Batch of waveforms. 144 | Returns 145 | ------- 146 | features : `torch.Tensor` (batch_size, out_channels, n_samples_out) 147 | Batch of sinc filters activations. 148 | """ 149 | 150 | self.n_ = self.n_.to(waveforms.device) 151 | 152 | self.window_ = self.window_.to(waveforms.device) 153 | 154 | low = self.min_low_hz + torch.abs(self.low_hz_) 155 | 156 | high = torch.clamp( 157 | low + self.min_band_hz + torch.abs(self.band_hz_), 158 | self.min_low_hz, 159 | self.sample_rate / 2, 160 | ) 161 | band = (high - low)[:, 0] 162 | 163 | f_times_t_low = torch.matmul(low, self.n_) 164 | f_times_t_high = torch.matmul(high, self.n_) 165 | 166 | band_pass_left = ( 167 | (torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) / (self.n_ / 2) 168 | ) * self.window_ # Equivalent of Eq.4 of the reference paper (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET). 
I just have expanded the sinc and simplified the terms. This way I avoid several useless computations. 169 | band_pass_center = 2 * band.view(-1, 1) 170 | band_pass_right = torch.flip(band_pass_left, dims=[1]) 171 | 172 | band_pass = torch.cat( 173 | [band_pass_left, band_pass_center, band_pass_right], dim=1 174 | ) 175 | 176 | band_pass = band_pass / (2 * band[:, None]) 177 | 178 | self.filters = (band_pass).view(self.out_channels, 1, self.kernel_size) 179 | 180 | return F.conv1d( 181 | waveforms, 182 | self.filters, 183 | stride=self.stride, 184 | padding=self.padding, 185 | dilation=self.dilation, 186 | bias=None, 187 | groups=1, 188 | ) 189 | 190 | 191 | class sinc_conv(nn.Module): 192 | def __init__(self, N_filt, Filt_dim, fs): 193 | super(sinc_conv, self).__init__() 194 | 195 | # Mel Initialization of the filterbanks 196 | low_freq_mel = 80 197 | high_freq_mel = 2595 * np.log10(1 + (fs / 2) / 700) # Convert Hz to Mel 198 | mel_points = np.linspace( 199 | low_freq_mel, high_freq_mel, N_filt 200 | ) # Equally spaced in Mel scale 201 | f_cos = 700 * (10 ** (mel_points / 2595) - 1) # Convert Mel to Hz 202 | b1 = np.roll(f_cos, 1) 203 | b2 = np.roll(f_cos, -1) 204 | b1[0] = 30 205 | b2[-1] = (fs / 2) - 100 206 | 207 | self.freq_scale = fs * 1.0 208 | self.filt_b1 = nn.Parameter(torch.from_numpy(b1 / self.freq_scale)) 209 | self.filt_band = nn.Parameter(torch.from_numpy((b2 - b1) / self.freq_scale)) 210 | 211 | self.N_filt = N_filt 212 | self.Filt_dim = Filt_dim 213 | self.fs = fs 214 | 215 | def forward(self, x): 216 | 217 | filters = Variable(torch.zeros((self.N_filt, self.Filt_dim))).cuda() 218 | N = self.Filt_dim 219 | t_right = Variable( 220 | torch.linspace(1, (N - 1) / 2, steps=int((N - 1) / 2)) / self.fs 221 | ).cuda() 222 | 223 | min_freq = 50.0 224 | min_band = 50.0 225 | 226 | filt_beg_freq = torch.abs(self.filt_b1) + min_freq / self.freq_scale 227 | filt_end_freq = filt_beg_freq + ( 228 | torch.abs(self.filt_band) + min_band / self.freq_scale 229 | ) 230 | 231 | n = torch.linspace(0, N, steps=N) 232 | 233 | # Filter window (hamming) 234 | window = 0.54 - 0.46 * torch.cos(2 * math.pi * n / N) 235 | window = Variable(window.float().cuda()) 236 | 237 | for i in range(self.N_filt): 238 | 239 | low_pass1 = ( 240 | 2 241 | * filt_beg_freq[i].float() 242 | * sinc(filt_beg_freq[i].float() * self.freq_scale, t_right) 243 | ) 244 | low_pass2 = ( 245 | 2 246 | * filt_end_freq[i].float() 247 | * sinc(filt_end_freq[i].float() * self.freq_scale, t_right) 248 | ) 249 | band_pass = low_pass2 - low_pass1 250 | 251 | band_pass = band_pass / torch.max(band_pass) 252 | 253 | filters[i, :] = band_pass.cuda() * window 254 | 255 | out = F.conv1d(x, filters.view(self.N_filt, 1, self.Filt_dim)) 256 | 257 | return out 258 | 259 | 260 | def act_fun(act_type): 261 | 262 | if act_type == "relu": 263 | return nn.ReLU() 264 | 265 | if act_type == "tanh": 266 | return nn.Tanh() 267 | 268 | if act_type == "sigmoid": 269 | return nn.Sigmoid() 270 | 271 | if act_type == "leaky_relu": 272 | return nn.LeakyReLU(0.2) 273 | 274 | if act_type == "elu": 275 | return nn.ELU() 276 | 277 | if act_type == "softmax": 278 | return nn.LogSoftmax(dim=1) 279 | 280 | if act_type == "linear": 281 | return nn.LeakyReLU(1) # initializzed like this, but not used in forward! 
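# Note: act_fun maps the configuration strings used by SincNet and MLP
# ("relu", "tanh", "sigmoid", "leaky_relu", "elu", "softmax", "linear") to
# torch activation modules. "linear" returns nn.LeakyReLU(1), which is the
# identity function, so a layer configured as "linear" stays linear even
# where an activation is applied uniformly; MLP.forward below also branches
# on the "linear" case and skips the activation explicitly.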
282 | 283 | 284 | class LayerNorm(nn.Module): 285 | def __init__(self, features, eps=1e-6): 286 | super(LayerNorm, self).__init__() 287 | self.gamma = nn.Parameter(torch.ones(features)) 288 | self.beta = nn.Parameter(torch.zeros(features)) 289 | self.eps = eps 290 | 291 | def forward(self, x): 292 | mean = x.mean(-1, keepdim=True) 293 | std = x.std(-1, keepdim=True) 294 | return self.gamma * (x - mean) / (std + self.eps) + self.beta 295 | 296 | 297 | class MLP(nn.Module): 298 | def __init__(self, options): 299 | super(MLP, self).__init__() 300 | 301 | self.input_dim = int(options["input_dim"]) 302 | self.fc_lay = options["fc_lay"] 303 | self.fc_drop = options["fc_drop"] 304 | self.fc_use_batchnorm = options["fc_use_batchnorm"] 305 | self.fc_use_laynorm = options["fc_use_laynorm"] 306 | self.fc_use_laynorm_inp = options["fc_use_laynorm_inp"] 307 | self.fc_use_batchnorm_inp = options["fc_use_batchnorm_inp"] 308 | self.fc_act = options["fc_act"] 309 | 310 | self.wx = nn.ModuleList([]) 311 | self.bn = nn.ModuleList([]) 312 | self.ln = nn.ModuleList([]) 313 | self.act = nn.ModuleList([]) 314 | self.drop = nn.ModuleList([]) 315 | 316 | # input layer normalization 317 | if self.fc_use_laynorm_inp: 318 | self.ln0 = LayerNorm(self.input_dim) 319 | 320 | # input batch normalization 321 | if self.fc_use_batchnorm_inp: 322 | self.bn0 = nn.BatchNorm1d([self.input_dim], momentum=0.05) 323 | 324 | self.N_fc_lay = len(self.fc_lay) 325 | 326 | current_input = self.input_dim 327 | 328 | # Initialization of hidden layers 329 | 330 | for i in range(self.N_fc_lay): 331 | 332 | # dropout 333 | self.drop.append(nn.Dropout(p=self.fc_drop[i])) 334 | 335 | # activation 336 | self.act.append(act_fun(self.fc_act[i])) 337 | 338 | add_bias = True 339 | 340 | # layer norm initialization 341 | self.ln.append(LayerNorm(self.fc_lay[i])) 342 | self.bn.append(nn.BatchNorm1d(self.fc_lay[i], momentum=0.05)) 343 | 344 | if self.fc_use_laynorm[i] or self.fc_use_batchnorm[i]: 345 | add_bias = False 346 | 347 | # Linear operations 348 | self.wx.append(nn.Linear(current_input, self.fc_lay[i], bias=add_bias)) 349 | 350 | # weight initialization 351 | self.wx[i].weight = torch.nn.Parameter( 352 | torch.Tensor(self.fc_lay[i], current_input).uniform_( 353 | -np.sqrt(0.01 / (current_input + self.fc_lay[i])), 354 | np.sqrt(0.01 / (current_input + self.fc_lay[i])), 355 | ) 356 | ) 357 | self.wx[i].bias = torch.nn.Parameter(torch.zeros(self.fc_lay[i])) 358 | 359 | current_input = self.fc_lay[i] 360 | 361 | def forward(self, x): 362 | 363 | # Applying Layer/Batch Norm 364 | if bool(self.fc_use_laynorm_inp): 365 | x = self.ln0((x)) 366 | 367 | if bool(self.fc_use_batchnorm_inp): 368 | x = self.bn0((x)) 369 | 370 | for i in range(self.N_fc_lay): 371 | 372 | if self.fc_act[i] != "linear": 373 | 374 | if self.fc_use_laynorm[i]: 375 | x = self.drop[i](self.act[i](self.ln[i](self.wx[i](x)))) 376 | 377 | if self.fc_use_batchnorm[i]: 378 | x = self.drop[i](self.act[i](self.bn[i](self.wx[i](x)))) 379 | 380 | if ( 381 | self.fc_use_batchnorm[i] == False 382 | and self.fc_use_laynorm[i] == False 383 | ): 384 | x = self.drop[i](self.act[i](self.wx[i](x))) 385 | 386 | else: 387 | if self.fc_use_laynorm[i]: 388 | x = self.drop[i](self.ln[i](self.wx[i](x))) 389 | 390 | if self.fc_use_batchnorm[i]: 391 | x = self.drop[i](self.bn[i](self.wx[i](x))) 392 | 393 | if ( 394 | self.fc_use_batchnorm[i] == False 395 | and self.fc_use_laynorm[i] == False 396 | ): 397 | x = self.drop[i](self.wx[i](x)) 398 | 399 | return x 400 | 401 | 402 | class SincNet(nn.Module): 
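# SincNet stacks a SincConv_fast front end followed by standard Conv1d layers.
# Each layer applies convolution -> max-pooling -> optional layer/batch
# normalization -> activation -> dropout; when layer normalization is used, the
# first layer also rectifies the sinc filter outputs with torch.abs() before
# pooling (see forward below).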
403 | def __init__( 404 | self, 405 | cnn_N_filt, 406 | cnn_len_filt, 407 | cnn_max_pool_len, 408 | cnn_act, 409 | cnn_drop, 410 | cnn_use_laynorm, 411 | cnn_use_batchnorm, 412 | cnn_use_laynorm_inp, 413 | cnn_use_batchnorm_inp, 414 | input_dim, 415 | fs, 416 | ): 417 | super(SincNet, self).__init__() 418 | 419 | self.cnn_N_filt = cnn_N_filt 420 | self.cnn_len_filt = cnn_len_filt 421 | self.cnn_max_pool_len = cnn_max_pool_len 422 | 423 | self.cnn_act = cnn_act 424 | self.cnn_drop = cnn_drop 425 | 426 | self.cnn_use_laynorm = cnn_use_laynorm 427 | self.cnn_use_batchnorm = cnn_use_batchnorm 428 | self.cnn_use_laynorm_inp = cnn_use_laynorm_inp 429 | self.cnn_use_batchnorm_inp = cnn_use_batchnorm_inp 430 | 431 | self.input_dim = int(input_dim) 432 | 433 | self.fs = fs 434 | 435 | self.N_cnn_lay = len(self.cnn_N_filt) 436 | self.conv = nn.ModuleList([]) 437 | self.bn = nn.ModuleList([]) 438 | self.ln = nn.ModuleList([]) 439 | self.act = nn.ModuleList([]) 440 | self.drop = nn.ModuleList([]) 441 | 442 | if self.cnn_use_laynorm_inp: 443 | self.ln0 = LayerNorm(self.input_dim) 444 | 445 | if self.cnn_use_batchnorm_inp: 446 | self.bn0 = nn.BatchNorm1d([self.input_dim], momentum=0.05) 447 | 448 | current_input = self.input_dim 449 | 450 | for i in range(self.N_cnn_lay): 451 | 452 | N_filt = int(self.cnn_N_filt[i]) 453 | len_filt = int(self.cnn_len_filt[i]) 454 | 455 | # dropout 456 | self.drop.append(nn.Dropout(p=self.cnn_drop[i])) 457 | 458 | # activation 459 | self.act.append(act_fun(self.cnn_act[i])) 460 | 461 | # layer norm initialization 462 | self.ln.append( 463 | LayerNorm( 464 | [ 465 | N_filt, 466 | int( 467 | (current_input - self.cnn_len_filt[i] + 1) 468 | / self.cnn_max_pool_len[i] 469 | ), 470 | ] 471 | ) 472 | ) 473 | 474 | self.bn.append( 475 | nn.BatchNorm1d( 476 | N_filt, 477 | int( 478 | (current_input - self.cnn_len_filt[i] + 1) 479 | / self.cnn_max_pool_len[i] 480 | ), 481 | momentum=0.05, 482 | ) 483 | ) 484 | 485 | if i == 0: 486 | self.conv.append( 487 | SincConv_fast(self.cnn_N_filt[0], self.cnn_len_filt[0], self.fs) 488 | ) 489 | 490 | else: 491 | self.conv.append( 492 | nn.Conv1d( 493 | self.cnn_N_filt[i - 1], self.cnn_N_filt[i], self.cnn_len_filt[i] 494 | ) 495 | ) 496 | 497 | current_input = int( 498 | (current_input - self.cnn_len_filt[i] + 1) / self.cnn_max_pool_len[i] 499 | ) 500 | 501 | self.out_dim = current_input * N_filt 502 | 503 | def forward(self, x): 504 | batch = x.shape[0] 505 | seq_len = x.shape[1] 506 | 507 | if bool(self.cnn_use_laynorm_inp): 508 | x = self.ln0((x)) 509 | 510 | if bool(self.cnn_use_batchnorm_inp): 511 | x = self.bn0((x)) 512 | 513 | x = x.view(batch, 1, seq_len) 514 | 515 | for i in range(self.N_cnn_lay): 516 | 517 | if self.cnn_use_laynorm[i]: 518 | if i == 0: 519 | x = self.drop[i]( 520 | self.act[i]( 521 | self.ln[i]( 522 | F.max_pool1d( 523 | torch.abs(self.conv[i](x)), self.cnn_max_pool_len[i] 524 | ) 525 | ) 526 | ) 527 | ) 528 | else: 529 | x = self.drop[i]( 530 | self.act[i]( 531 | self.ln[i]( 532 | F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i]) 533 | ) 534 | ) 535 | ) 536 | 537 | if self.cnn_use_batchnorm[i]: 538 | x = self.drop[i]( 539 | self.act[i]( 540 | self.bn[i]( 541 | F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i]) 542 | ) 543 | ) 544 | ) 545 | 546 | if self.cnn_use_batchnorm[i] == False and self.cnn_use_laynorm[i] == False: 547 | x = self.drop[i]( 548 | self.act[i](F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])) 549 | ) 550 | 551 | x = x.view(batch, -1) 552 | 553 | return x 554 | 
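The arithmetic behind out_dim above is worth spelling out: each layer shrinks the time axis by len_filt - 1 (a valid convolution) and then divides it by the pooling length, and the final feature size is the remaining time steps times the last layer's filter count. A standalone sketch, using an example configuration (80/60/60 filters of lengths 251/5/5 with pooling 3/3/3 on a 3200-sample input; these numbers are illustrative assumptions, not values fixed by this repository):

def sincnet_out_dim(input_dim, n_filt, len_filt, pool_len):
    current = input_dim
    for lf, pl in zip(len_filt, pool_len):
        current = int((current - lf + 1) / pl)  # valid conv, then max-pool
    return current * n_filt[-1]

# 3200 -> 983 -> 326 -> 107 time steps; 107 * 60 filters = 6420 features
print(sincnet_out_dim(3200, [80, 60, 60], [251, 5, 5], [3, 3, 3]))  # 6420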
-------------------------------------------------------------------------------- /clmr/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .callbacks import PlotSpectogramCallback 2 | from .contrastive_learning import ContrastiveLearning 3 | from .linear_evaluation import LinearEvaluation 4 | from .supervised_learning import SupervisedLearning 5 | -------------------------------------------------------------------------------- /clmr/modules/callbacks.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | 4 | matplotlib.use("Agg") 5 | 6 | from pytorch_lightning.callbacks import Callback 7 | 8 | 9 | class PlotSpectogramCallback(Callback): 10 | def on_train_start(self, trainer, pl_module): 11 | 12 | if not pl_module.hparams.time_domain: 13 | x, y = trainer.train_dataloader.dataset[0] 14 | 15 | fig = plt.figure() 16 | x_i = x[0, :] 17 | fig.add_subplot(1, 2, 1) 18 | plt.imshow(x_i) 19 | if x.shape[0] > 1: 20 | x_j = x[1, :] 21 | fig.add_subplot(1, 2, 2) 22 | plt.imshow(x_j) 23 | 24 | trainer.logger.experiment.add_figure( 25 | "Train/spectogram_sample", fig, global_step=0 26 | ) 27 | plt.close() 28 | -------------------------------------------------------------------------------- /clmr/modules/contrastive_learning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pytorch_lightning import LightningModule 4 | from torch import Tensor 5 | 6 | from simclr import SimCLR 7 | from simclr.modules import NT_Xent, LARS 8 | 9 | 10 | class ContrastiveLearning(LightningModule): 11 | def __init__(self, args, encoder: nn.Module): 12 | super().__init__() 13 | self.save_hyperparameters(args) 14 | 15 | self.encoder = encoder 16 | self.n_features = ( 17 | self.encoder.fc.in_features 18 | ) # get dimensions of last fully-connected layer 19 | self.model = SimCLR(self.encoder, self.hparams.projection_dim, self.n_features) 20 | self.criterion = self.configure_criterion() 21 | 22 | def forward(self, x_i: Tensor, x_j: Tensor) -> Tensor: 23 | _, _, z_i, z_j = self.model(x_i, x_j) 24 | loss = self.criterion(z_i, z_j) 25 | return loss 26 | 27 | def training_step(self, batch, _) -> Tensor: 28 | x, _ = batch 29 | x_i = x[:, 0, :] 30 | x_j = x[:, 1, :] 31 | loss = self.forward(x_i, x_j) 32 | self.log("Train/loss", loss) 33 | return loss 34 | 35 | def configure_criterion(self) -> nn.Module: 36 | # PT lightning aggregates differently in DP mode 37 | if self.hparams.accelerator == "dp" and self.hparams.gpus: 38 | batch_size = int(self.hparams.batch_size / self.hparams.gpus) 39 | else: 40 | batch_size = self.hparams.batch_size 41 | 42 | criterion = NT_Xent(batch_size, self.hparams.temperature, world_size=1) 43 | return criterion 44 | 45 | def configure_optimizers(self) -> dict: 46 | scheduler = None 47 | if self.hparams.optimizer == "Adam": 48 | optimizer = torch.optim.Adam(self.model.parameters(), lr=3e-4) 49 | elif self.hparams.optimizer == "LARS": 50 | # optimized using LARS with linear learning rate scaling 51 | # (i.e. LearningRate = 0.3 × BatchSize/256) and weight decay of 10−6. 
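# For example, with this repository's default batch_size of 48 (see config/config.yaml), the line below gives 0.3 * 48 / 256 = 0.05625.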
52 | learning_rate = 0.3 * self.hparams.batch_size / 256 53 | optimizer = LARS( 54 | self.model.parameters(), 55 | lr=learning_rate, 56 | weight_decay=self.hparams.weight_decay, 57 | exclude_from_weight_decay=["batch_normalization", "bias"], 58 | ) 59 | 60 | # "decay the learning rate with the cosine decay schedule without restarts" 61 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 62 | optimizer, self.hparams.max_epochs, eta_min=0, last_epoch=-1 63 | ) 64 | else: 65 | raise NotImplementedError 66 | 67 | if scheduler: 68 | return {"optimizer": optimizer, "lr_scheduler": scheduler} 69 | else: 70 | return {"optimizer": optimizer} 71 | -------------------------------------------------------------------------------- /clmr/modules/linear_evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchmetrics 4 | from copy import deepcopy 5 | from pytorch_lightning import LightningModule 6 | from torch import Tensor 7 | from torch.utils.data import DataLoader, Dataset, TensorDataset 8 | from typing import Tuple 9 | from tqdm import tqdm 10 | 11 | 12 | class LinearEvaluation(LightningModule): 13 | def __init__(self, args, encoder: nn.Module, hidden_dim: int, output_dim: int): 14 | super().__init__() 15 | self.save_hyperparameters(args) 16 | 17 | self.encoder = encoder 18 | self.hidden_dim = hidden_dim 19 | self.output_dim = output_dim 20 | 21 | if self.hparams.finetuner_mlp: 22 | self.model = nn.Sequential( 23 | nn.Linear(self.hidden_dim, self.hidden_dim), 24 | nn.ReLU(), 25 | nn.Linear(self.hidden_dim, self.output_dim), 26 | ) 27 | else: 28 | self.model = nn.Sequential(nn.Linear(self.hidden_dim, self.output_dim)) 29 | self.criterion = self.configure_criterion() 30 | 31 | self.accuracy = torchmetrics.Accuracy() 32 | self.average_precision = torchmetrics.AveragePrecision(pos_label=1) 33 | 34 | def forward(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: 35 | preds = self._forward_representations(x, y) 36 | loss = self.criterion(preds, y) 37 | return loss, preds 38 | 39 | def _forward_representations(self, x: Tensor, y: Tensor) -> Tensor: 40 | """ 41 | Perform a forward pass using either the representations, or the input data 42 | from which we still need to extract the representations using our encoder. 
43 | """ 44 | if x.shape[-1] == self.hidden_dim: 45 | h0 = x 46 | else: 47 | with torch.no_grad(): 48 | h0 = self.encoder(x) 49 | return self.model(h0) 50 | 51 | def training_step(self, batch, _) -> Tensor: 52 | x, y = batch 53 | loss, preds = self.forward(x, y) 54 | 55 | self.log("Train/accuracy", self.accuracy(preds, y)) 56 | # self.log("Train/pr_auc", self.average_precision(preds, y)) 57 | self.log("Train/loss", loss) 58 | return loss 59 | 60 | def validation_step(self, batch, _) -> Tensor: 61 | x, y = batch 62 | loss, preds = self.forward(x, y) 63 | 64 | self.log("Valid/accuracy", self.accuracy(preds, y)) 65 | # self.log("Valid/pr_auc", self.average_precision(preds, y)) 66 | self.log("Valid/loss", loss) 67 | return loss 68 | 69 | def configure_criterion(self) -> nn.Module: 70 | if self.hparams.dataset in ["magnatagatune", "msd"]: 71 | criterion = nn.BCEWithLogitsLoss() 72 | else: 73 | criterion = nn.CrossEntropyLoss() 74 | return criterion 75 | 76 | def configure_optimizers(self) -> dict: 77 | optimizer = torch.optim.Adam( 78 | self.model.parameters(), 79 | lr=self.hparams.finetuner_learning_rate, 80 | weight_decay=self.hparams.weight_decay, 81 | ) 82 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 83 | optimizer, 84 | mode="min", 85 | factor=0.1, 86 | patience=5, 87 | threshold=0.0001, 88 | threshold_mode="rel", 89 | cooldown=0, 90 | min_lr=0, 91 | eps=1e-08, 92 | verbose=False, 93 | ) 94 | if scheduler: 95 | return { 96 | "optimizer": optimizer, 97 | "lr_scheduler": scheduler, 98 | "monitor": "Valid/loss", 99 | } 100 | else: 101 | return {"optimizer": optimizer} 102 | 103 | def extract_representations(self, dataloader: DataLoader) -> Dataset: 104 | 105 | representations = [] 106 | ys = [] 107 | for x, y in tqdm(dataloader): 108 | with torch.no_grad(): 109 | h0 = self.encoder(x) 110 | representations.append(h0) 111 | ys.append(y) 112 | 113 | if len(representations) > 1: 114 | representations = torch.cat(representations, dim=0) 115 | ys = torch.cat(ys, dim=0) 116 | else: 117 | representations = representations[0] 118 | ys = ys[0] 119 | 120 | tensor_dataset = TensorDataset(representations, ys) 121 | return tensor_dataset 122 | -------------------------------------------------------------------------------- /clmr/modules/supervised_learning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchmetrics 3 | import torch.nn as nn 4 | from pytorch_lightning import LightningModule 5 | 6 | 7 | class SupervisedLearning(LightningModule): 8 | def __init__(self, args, encoder: nn.Module, output_dim: int): 9 | super().__init__() 10 | self.save_hyperparameters(args) 11 | self.encoder = encoder 12 | 13 | self.encoder.fc.out_features = output_dim 14 | self.output_dim = output_dim 15 | self.model = self.encoder 16 | self.criterion = self.configure_criterion() 17 | 18 | self.average_precision = torchmetrics.AveragePrecision(pos_label=1) 19 | 20 | def forward(self, x, y): 21 | x = x[:, 0, :] # we only have 1 sample, no augmentations 22 | preds = self.model(x) 23 | loss = self.criterion(preds, y) 24 | return loss, preds 25 | 26 | def training_step(self, batch, batch_idx): 27 | x, y = batch 28 | loss, preds = self.forward(x, y) 29 | self.log("Train/pr_auc", self.average_precision(preds, y)) 30 | self.log("Train/loss", loss) 31 | return loss 32 | 33 | def validation_step(self, batch, batch_idx): 34 | x, y = batch 35 | loss, preds = self.forward(x, y) 36 | self.log("Valid/pr_auc", self.average_precision(preds, y)) 37 | 
self.log("Valid/loss", loss) 38 | return loss 39 | 40 | def configure_criterion(self): 41 | if self.hparams.dataset in ["magnatagatune"]: 42 | criterion = nn.BCEWithLogitsLoss() 43 | else: 44 | criterion = nn.CrossEntropyLoss() 45 | return criterion 46 | 47 | def configure_optimizers(self): 48 | optimizer = torch.optim.SGD( 49 | self.model.parameters(), 50 | lr=self.hparams.learning_rate, 51 | momentum=0.9, 52 | weight_decay=1e-6, 53 | nesterov=True, 54 | ) 55 | 56 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 57 | optimizer, mode="min", factor=0.2, patience=5, verbose=True 58 | ) 59 | 60 | if scheduler: 61 | return { 62 | "optimizer": optimizer, 63 | "lr_scheduler": scheduler, 64 | "monitor": "Valid/loss", 65 | } 66 | else: 67 | return {"optimizer": optimizer} 68 | -------------------------------------------------------------------------------- /clmr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoint import load_encoder_checkpoint, load_finetuner_checkpoint 2 | from .yaml_config_hook import yaml_config_hook 3 | -------------------------------------------------------------------------------- /clmr/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | 5 | def load_encoder_checkpoint(checkpoint_path: str, output_dim: int) -> OrderedDict: 6 | state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu")) 7 | if "pytorch-lightning_version" in state_dict.keys(): 8 | new_state_dict = OrderedDict( 9 | { 10 | k.replace("model.encoder.", ""): v 11 | for k, v in state_dict["state_dict"].items() 12 | if "model.encoder." in k 13 | } 14 | ) 15 | else: 16 | new_state_dict = OrderedDict() 17 | for k, v in state_dict.items(): 18 | if "encoder." in k: 19 | new_state_dict[k.replace("encoder.", "")] = v 20 | 21 | new_state_dict["fc.weight"] = torch.zeros(output_dim, 512) 22 | new_state_dict["fc.bias"] = torch.zeros(output_dim) 23 | return new_state_dict 24 | 25 | 26 | def load_finetuner_checkpoint(checkpoint_path: str) -> OrderedDict: 27 | state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu")) 28 | if "pytorch-lightning_version" in state_dict.keys(): 29 | state_dict = OrderedDict( 30 | { 31 | k.replace("model.", ""): v 32 | for k, v in state_dict["state_dict"].items() 33 | if "model." 
in k 34 | } 35 | ) 36 | return state_dict 37 | -------------------------------------------------------------------------------- /clmr/utils/yaml_config_hook.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | 4 | 5 | def yaml_config_hook(config_file): 6 | """ 7 | Custom YAML config loader, which can include other YAML files (I like using config files 8 | instead of argparse) 9 | """ 10 | 11 | # load yaml files in the nested 'defaults' section, which include defaults for experiments 12 | with open(config_file) as f: 13 | cfg = yaml.safe_load(f) 14 | for d in cfg.get("defaults", []): 15 | config_dir, cf = d.popitem() 16 | cf = os.path.join(os.path.dirname(config_file), config_dir, cf + ".yaml") 17 | with open(cf) as f: 18 | l = yaml.safe_load(f) 19 | cfg.update(l) 20 | 21 | if "defaults" in cfg.keys(): 22 | del cfg["defaults"] 23 | 24 | return cfg 25 | 
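A minimal usage sketch of the "defaults" mechanism handled above. The included file and its keys are hypothetical; the repository's own config.yaml does not use a defaults section. Run from the repository root:

# config/config.yaml (hypothetical):
#   defaults:
#     - dataset: magnatagatune   # merges config/dataset/magnatagatune.yaml
#   batch_size: 48
from clmr.utils import yaml_config_hook

cfg = yaml_config_hook("./config/config.yaml")
assert "defaults" not in cfg  # the key is consumed while loading
print(cfg["batch_size"])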
-------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | # infra options 2 | # gpus: 0 3 | # accelerator: "dp" # use ddp for gpus > 1. Also see PyTorch Lightning documentation on distributed training. 4 | workers: 0 # I recommend tuning this parameter for faster data augmentation processing 5 | dataset_dir: "./data" 6 | 7 | # train options 8 | seed: 42 9 | batch_size: 48 10 | # max_epochs: 200 11 | dataset: "magnatagatune" # ["magnatagatune", "msd", "gtzan", "audio"] 12 | supervised: 0 # train with supervised baseline 13 | 14 | # SimCLR model options 15 | projection_dim: 64 # Projection dim. of SimCLR projector 16 | 17 | # loss options 18 | optimizer: "Adam" # or LARS (experimental) 19 | learning_rate: 0.0003 20 | weight_decay: 1.0e-6 # "optimized using LARS [...] and weight decay of 10^-6" 21 | temperature: 0.5 # see appendix B.7.: Optimal temperature under different batch sizes 22 | 23 | # reload options 24 | checkpoint_path: "" # set to the directory containing `checkpoint_##.tar` 25 | 26 | # logistic regression options 27 | finetuner_mlp: 0 28 | finetuner_checkpoint_path: "" 29 | finetuner_max_epochs: 200 30 | finetuner_batch_size: 256 31 | finetuner_learning_rate: 0.001 32 | 33 | # audio data augmentation options 34 | audio_length: 59049 35 | sample_rate: 22050 36 | transforms_polarity: 0.8 37 | transforms_noise: 0.01 38 | transforms_gain: 0.3 39 | transforms_filters: 0.8 40 | transforms_delay: 0.3 41 | transforms_pitch: 0.6 42 | transforms_reverb: 0.6 43 | -------------------------------------------------------------------------------- /examples/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/.DS_Store -------------------------------------------------------------------------------- /examples/clmr-onnxruntime-web/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/.DS_Store -------------------------------------------------------------------------------- /examples/clmr-onnxruntime-web/clmr_sample-cnn.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/clmr_sample-cnn.onnx -------------------------------------------------------------------------------- /examples/clmr-onnxruntime-web/data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/.DS_Store -------------------------------------------------------------------------------- /examples/clmr-onnxruntime-web/data/fisher_losing_it.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/fisher_losing_it.mp3 -------------------------------------------------------------------------------- /examples/clmr-onnxruntime-web/data/fisher_losing_it_22050.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/fisher_losing_it_22050.mp3 -------------------------------------------------------------------------------- /examples/clmr-onnxruntime-web/data/john_lennon_imagine.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/john_lennon_imagine.mp3 -------------------------------------------------------------------------------- /examples/clmr-onnxruntime-web/data/john_lennon_imagine_22050.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/john_lennon_imagine_22050.mp3 
-------------------------------------------------------------------------------- /examples/clmr-onnxruntime-web/data/nirvana_smells_like_teen_spirit.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/nirvana_smells_like_teen_spirit.mp3 -------------------------------------------------------------------------------- /examples/clmr-onnxruntime-web/data/nirvana_smells_like_teen_spirit_22050.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/nirvana_smells_like_teen_spirit_22050.mp3 -------------------------------------------------------------------------------- /examples/clmr-onnxruntime-web/data/queen_love_of_my_life.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/queen_love_of_my_life.mp3 -------------------------------------------------------------------------------- /examples/clmr-onnxruntime-web/data/queen_love_of_my_life_22050.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/examples/clmr-onnxruntime-web/data/queen_love_of_my_life_22050.mp3 -------------------------------------------------------------------------------- /examples/clmr-onnxruntime-web/index.html: --------------------------------------------------------------------------------
[Static demo page; the HTML markup did not survive this dump, so only its visible text is kept:]
CLMR on ONNX Runtime
On this page, you are running a CLMR model on the ONNX Runtime in your browser. No data is sent to a server; all predictions are done on your device!
In short, the following happens under the hood:
1. It will first load a pre-trained model in the background.
2. The model extracts audio representations from raw samples from the audio buffer.
3. The audio representations are given to a linear classifier, which predicts the corresponding audio tags.
4. The multi-label predictions are shown in the bar plot.
No data is uploaded to the server. Ideally, the sample rate of the file should be 22,050 Hz, but lower / higher sample rates also work.
Loading ONNX model...
-------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script exports a pre-trained CLMR PyTorch model to an ONNX model. 3 | """ 4 | 5 | import argparse 6 | import os 7 | import torch 8 | from collections import OrderedDict 9 | from copy import deepcopy 10 | from clmr.models import SampleCNN, Identity 11 | from clmr.utils import load_encoder_checkpoint, load_finetuner_checkpoint 12 | 13 | 14 | def convert_encoder_to_onnx( 15 | encoder: torch.nn.Module, test_input: torch.Tensor, fp: str 16 | ) -> None: 17 | input_names = ["audio"] 18 | output_names = ["representation"] 19 | 20 | torch.onnx.export( 21 | encoder, 22 | test_input, 23 | fp, 24 | verbose=False, 25 | input_names=input_names, 26 | output_names=output_names, 27 | ) 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser() 32 | 33 | parser.add_argument("--checkpoint_path", type=str, required=True) 34 | parser.add_argument("--finetuner_checkpoint_path", type=str, required=True) 35 | parser.add_argument("--n_classes", type=int, default=50) 36 | args = parser.parse_args() 37 | 38 | if not os.path.exists(args.checkpoint_path): 39 | raise FileNotFoundError("That encoder checkpoint does not exist") 40 | 41 | if not os.path.exists(args.finetuner_checkpoint_path): 42 | raise FileNotFoundError("That linear model checkpoint does not exist") 43 | 44 | # ------------ 45 | # encoder 46 | # ------------ 47 | encoder = SampleCNN( 48 | strides=[3, 3, 3, 3, 3, 3, 3, 3, 3], 49 | supervised=False, 50 | out_dim=args.n_classes, 51 | ) 52 | 53 | n_features = encoder.fc.in_features # get dimensions of last fully-connected layer 54 | 55 | state_dict = load_encoder_checkpoint(args.checkpoint_path, args.n_classes) 56 | encoder.load_state_dict(state_dict) 57 | encoder.eval() 58 | 59 | # ------------ 60 | # linear model 61 | # ------------ 62 | state_dict = load_finetuner_checkpoint(args.finetuner_checkpoint_path) 63 | encoder.fc.load_state_dict( 64 | OrderedDict({k.replace("0.", ""): v for k, v in state_dict.items()}) 65 | ) 66 | 67 | encoder_export = deepcopy(encoder) 68 | # set last fully connected layer to an identity function: 69 | encoder_export.fc = Identity() 70 | 71 | batch_size = 1 72 | channels = 1 73 | audio_length = 59049 74 | test_input = torch.randn(batch_size, channels, audio_length) 75 | 76 | convert_encoder_to_onnx(encoder, test_input, "clmr_sample-cnn.onnx") 77 | convert_encoder_to_onnx( 78 | encoder_export, test_input, "clmr_encoder_only_sample-cnn.onnx" 79 | ) 80 | -------------------------------------------------------------------------------- /linear_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pytorch_lightning as pl 4 | from torch.utils.data import DataLoader 5 | from torchaudio_augmentations import Compose, RandomResizedCrop 6 | from pytorch_lightning import Trainer 7 | from pytorch_lightning.callbacks import EarlyStopping 8 | from pytorch_lightning.loggers import TensorBoardLogger 9 | 10 | from clmr.datasets import get_dataset 11 | from clmr.data import ContrastiveDataset 12 | from clmr.evaluation import evaluate 13 | from clmr.models import SampleCNN 14 | from clmr.modules import ContrastiveLearning, LinearEvaluation 15 | from clmr.utils import ( 16 | yaml_config_hook, 17 | load_encoder_checkpoint, 18
| load_finetuner_checkpoint, 19 | ) 20 | 21 | 22 | if __name__ == "__main__": 23 | 24 | parser = argparse.ArgumentParser(description="SimCLR") 25 | parser = Trainer.add_argparse_args(parser) 26 | 27 | config = yaml_config_hook("./config/config.yaml") 28 | for k, v in config.items(): 29 | parser.add_argument(f"--{k}", default=v, type=type(v)) 30 | 31 | args = parser.parse_args() 32 | pl.seed_everything(args.seed) 33 | args.accelerator = None 34 | 35 | if not os.path.exists(args.checkpoint_path): 36 | raise FileNotFoundError("That checkpoint does not exist") 37 | 38 | train_transform = [RandomResizedCrop(n_samples=args.audio_length)] 39 | 40 | # ------------ 41 | # dataloaders 42 | # ------------ 43 | train_dataset = get_dataset(args.dataset, args.dataset_dir, subset="train") 44 | valid_dataset = get_dataset(args.dataset, args.dataset_dir, subset="valid") 45 | test_dataset = get_dataset(args.dataset, args.dataset_dir, subset="test") 46 | 47 | contrastive_train_dataset = ContrastiveDataset( 48 | train_dataset, 49 | input_shape=(1, args.audio_length), 50 | transform=Compose(train_transform), 51 | ) 52 | 53 | contrastive_valid_dataset = ContrastiveDataset( 54 | valid_dataset, 55 | input_shape=(1, args.audio_length), 56 | transform=Compose(train_transform), 57 | ) 58 | 59 | contrastive_test_dataset = ContrastiveDataset( 60 | test_dataset, 61 | input_shape=(1, args.audio_length), 62 | transform=None, 63 | ) 64 | 65 | train_loader = DataLoader( 66 | contrastive_train_dataset, 67 | batch_size=args.finetuner_batch_size, 68 | num_workers=args.workers, 69 | shuffle=True, 70 | ) 71 | 72 | valid_loader = DataLoader( 73 | contrastive_valid_dataset, 74 | batch_size=args.finetuner_batch_size, 75 | num_workers=args.workers, 76 | shuffle=False, 77 | ) 78 | 79 | test_loader = DataLoader( 80 | contrastive_test_dataset, 81 | batch_size=args.finetuner_batch_size, 82 | num_workers=args.workers, 83 | shuffle=False, 84 | ) 85 | 86 | # ------------ 87 | # encoder 88 | # ------------ 89 | encoder = SampleCNN( 90 | strides=[3, 3, 3, 3, 3, 3, 3, 3, 3], 91 | supervised=args.supervised, 92 | out_dim=train_dataset.n_classes, 93 | ) 94 | 95 | n_features = encoder.fc.in_features # get dimensions of last fully-connected layer 96 | 97 | state_dict = load_encoder_checkpoint(args.checkpoint_path, train_dataset.n_classes) 98 | encoder.load_state_dict(state_dict) 99 | 100 | cl = ContrastiveLearning(args, encoder) 101 | cl.eval() 102 | cl.freeze() 103 | 104 | module = LinearEvaluation( 105 | args, 106 | cl.encoder, 107 | hidden_dim=n_features, 108 | output_dim=train_dataset.n_classes, 109 | ) 110 | 111 | train_representations_dataset = module.extract_representations(train_loader) 112 | train_loader = DataLoader( 113 | train_representations_dataset, 114 | batch_size=args.batch_size, 115 | num_workers=args.workers, 116 | shuffle=True, 117 | ) 118 | 119 | valid_representations_dataset = module.extract_representations(valid_loader) 120 | valid_loader = DataLoader( 121 | valid_representations_dataset, 122 | batch_size=args.batch_size, 123 | num_workers=args.workers, 124 | shuffle=False, 125 | ) 126 | 127 | if args.finetuner_checkpoint_path: 128 | state_dict = load_finetuner_checkpoint(args.finetuner_checkpoint_path) 129 | module.model.load_state_dict(state_dict) 130 | else: 131 | early_stop_callback = EarlyStopping( 132 | monitor="Valid/loss", patience=10, verbose=False, mode="min" 133 | ) 134 | 135 | trainer = Trainer.from_argparse_args( 136 | args, 137 | logger=TensorBoardLogger( 138 | "runs", 
name="CLMRv2-eval-{}".format(args.dataset) 139 | ), 140 | max_epochs=args.finetuner_max_epochs, 141 | callbacks=[early_stop_callback], 142 | ) 143 | trainer.fit(module, train_loader, valid_loader) 144 | 145 | device = "cuda:0" if args.gpus else "cpu" 146 | results = evaluate( 147 | module.encoder, 148 | module.model, 149 | contrastive_test_dataset, 150 | args.dataset, 151 | args.audio_length, 152 | device=device, 153 | ) 154 | print(results) 155 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pytorch_lightning as pl 3 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 4 | from pytorch_lightning import Trainer 5 | from pytorch_lightning.loggers import TensorBoardLogger 6 | from torch.utils.data import DataLoader 7 | 8 | # Audio Augmentations 9 | from torchaudio_augmentations import ( 10 | RandomApply, 11 | ComposeMany, 12 | RandomResizedCrop, 13 | PolarityInversion, 14 | Noise, 15 | Gain, 16 | HighLowPass, 17 | Delay, 18 | PitchShift, 19 | Reverb, 20 | ) 21 | 22 | from clmr.data import ContrastiveDataset 23 | from clmr.datasets import get_dataset 24 | from clmr.evaluation import evaluate 25 | from clmr.models import SampleCNN 26 | from clmr.modules import ContrastiveLearning, SupervisedLearning 27 | from clmr.utils import yaml_config_hook 28 | 29 | 30 | if __name__ == "__main__": 31 | 32 | parser = argparse.ArgumentParser(description="CLMR") 33 | parser = Trainer.add_argparse_args(parser) 34 | 35 | config = yaml_config_hook("./config/config.yaml") 36 | for k, v in config.items(): 37 | parser.add_argument(f"--{k}", default=v, type=type(v)) 38 | 39 | args = parser.parse_args() 40 | pl.seed_everything(args.seed) 41 | 42 | # ------------ 43 | # data augmentations 44 | # ------------ 45 | if args.supervised: 46 | train_transform = [RandomResizedCrop(n_samples=args.audio_length)] 47 | num_augmented_samples = 1 48 | else: 49 | train_transform = [ 50 | RandomResizedCrop(n_samples=args.audio_length), 51 | RandomApply([PolarityInversion()], p=args.transforms_polarity), 52 | RandomApply([Noise()], p=args.transforms_noise), 53 | RandomApply([Gain()], p=args.transforms_gain), 54 | RandomApply( 55 | [HighLowPass(sample_rate=args.sample_rate)], p=args.transforms_filters 56 | ), 57 | RandomApply([Delay(sample_rate=args.sample_rate)], p=args.transforms_delay), 58 | RandomApply( 59 | [ 60 | PitchShift( 61 | n_samples=args.audio_length, 62 | sample_rate=args.sample_rate, 63 | ) 64 | ], 65 | p=args.transforms_pitch, 66 | ), 67 | RandomApply( 68 | [Reverb(sample_rate=args.sample_rate)], p=args.transforms_reverb 69 | ), 70 | ] 71 | num_augmented_samples = 2 72 | 73 | # ------------ 74 | # dataloaders 75 | # ------------ 76 | train_dataset = get_dataset(args.dataset, args.dataset_dir, subset="train") 77 | valid_dataset = get_dataset(args.dataset, args.dataset_dir, subset="valid") 78 | contrastive_train_dataset = ContrastiveDataset( 79 | train_dataset, 80 | input_shape=(1, args.audio_length), 81 | transform=ComposeMany( 82 | train_transform, num_augmented_samples=num_augmented_samples 83 | ), 84 | ) 85 | 86 | contrastive_valid_dataset = ContrastiveDataset( 87 | valid_dataset, 88 | input_shape=(1, args.audio_length), 89 | transform=ComposeMany( 90 | train_transform, num_augmented_samples=num_augmented_samples 91 | ), 92 | ) 93 | 94 | train_loader = DataLoader( 95 | contrastive_train_dataset, 96 | batch_size=args.batch_size, 97 | 
num_workers=args.workers, 98 | drop_last=True, 99 | shuffle=True, 100 | ) 101 | 102 | valid_loader = DataLoader( 103 | contrastive_valid_dataset, 104 | batch_size=args.batch_size, 105 | num_workers=args.workers, 106 | drop_last=True, 107 | shuffle=False, 108 | ) 109 | 110 | # ------------ 111 | # encoder 112 | # ------------ 113 | encoder = SampleCNN( 114 | strides=[3, 3, 3, 3, 3, 3, 3, 3, 3], 115 | supervised=args.supervised, 116 | out_dim=train_dataset.n_classes, 117 | ) 118 | 119 | # ------------ 120 | # model 121 | # ------------ 122 | if args.supervised: 123 | module = SupervisedLearning(args, encoder, output_dim=train_dataset.n_classes) 124 | else: 125 | module = ContrastiveLearning(args, encoder) 126 | 127 | logger = TensorBoardLogger("runs", name="CLMRv2-{}".format(args.dataset)) 128 | if args.checkpoint_path: 129 | module = module.load_from_checkpoint( 130 | args.checkpoint_path, encoder=encoder, output_dim=train_dataset.n_classes 131 | ) 132 | 133 | else: 134 | # ------------ 135 | # training 136 | # ------------ 137 | 138 | if args.supervised: 139 | early_stopping = EarlyStopping(monitor="Valid/loss", patience=20) 140 | else: 141 | early_stopping = None 142 | 143 | trainer = Trainer.from_argparse_args( 144 | args, 145 | logger=logger, 146 | sync_batchnorm=True, 147 | max_epochs=args.max_epochs, 148 | log_every_n_steps=10, 149 | check_val_every_n_epoch=1, 150 | accelerator=args.accelerator, 151 | ) 152 | trainer.fit(module, train_loader, valid_loader) 153 | 154 | if args.supervised: 155 | test_dataset = get_dataset(args.dataset, args.dataset_dir, subset="test") 156 | 157 | contrastive_test_dataset = ContrastiveDataset( 158 | test_dataset, 159 | input_shape=(1, args.audio_length), 160 | transform=None, 161 | ) 162 | 163 | device = "cuda:0" if args.gpus else "cpu" 164 | results = evaluate( 165 | module.encoder, 166 | None, 167 | contrastive_test_dataset, 168 | args.dataset, 169 | args.audio_length, 170 | device=device, 171 | ) 172 | print(results) 173 | -------------------------------------------------------------------------------- /media/clmr_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/media/clmr_model.png -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from tqdm import tqdm 3 | from clmr.datasets import get_dataset 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--dataset", type=str, default="magnatagatune") 8 | parser.add_argument("--dataset_dir", type=str, default="./data") 9 | parser.add_argument("--sample_rate", type=int, default=22050) 10 | args = parser.parse_args() 11 | 12 | train_dataset = get_dataset(args.dataset, args.dataset_dir, subset="train") 13 | valid_dataset = get_dataset(args.dataset, args.dataset_dir, subset="valid") 14 | test_dataset = get_dataset(args.dataset, args.dataset_dir, subset="test") 15 | 16 | for i in tqdm(range(len(train_dataset))): 17 | train_dataset.preprocess(i, args.sample_rate) 18 | 19 | for i in tqdm(range(len(valid_dataset))): 20 | valid_dataset.preprocess(i, args.sample_rate) 21 | 22 | for i in tqdm(range(len(test_dataset))): 23 | test_dataset.preprocess(i, args.sample_rate) 24 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | simclr 2 | torchaudio-augmentations 3 | torch==1.9.0 4 | torchaudio 5 | pytorch-lightning 6 | soundfile 7 | sklearn 8 | matplotlib -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Note: To use the 'upload' functionality of this file, you must: 5 | # $ pipenv install twine --dev 6 | 7 | import io 8 | import os 9 | import sys 10 | from shutil import rmtree 11 | 12 | from setuptools import find_packages, setup, Command 13 | 14 | # Package meta-data. 15 | NAME = "clmr" 16 | DESCRIPTION = "Contrastive Learning of Musical Representations" 17 | URL = "https://github.com/spijkervet/CLMR" 18 | EMAIL = "janne.spijkervet@gmail.com" 19 | AUTHOR = "Janne Spijkervet" 20 | REQUIRES_PYTHON = ">=3.6.0" 21 | VERSION = "0.1.0" 22 | 23 | # What packages are required for this module to be executed? 24 | REQUIRED = [ 25 | "torch==1.9.0", 26 | "torchaudio", 27 | "simclr", 28 | "torchaudio-augmentations", 29 | "pytorch-lightning", 30 | "soundfile", 31 | "sklearn", 32 | "matplotlib", 33 | ] 34 | 35 | # What packages are optional? 36 | EXTRAS = { 37 | # 'fancy feature': ['django'], 38 | } 39 | 40 | # The rest you shouldn't have to touch too much :) 41 | # ------------------------------------------------ 42 | # Except, perhaps the License and Trove Classifiers! 43 | # If you do change the License, remember to change the Trove Classifier for that! 44 | 45 | here = os.path.abspath(os.path.dirname(__file__)) 46 | 47 | # Import the README and use it as the long-description. 48 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 49 | try: 50 | with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: 51 | long_description = "\n" + f.read() 52 | except FileNotFoundError: 53 | long_description = DESCRIPTION 54 | 55 | # Load the package's __version__.py module as a dictionary. 56 | about = {} 57 | if not VERSION: 58 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_") 59 | with open(os.path.join(here, project_slug, "__version__.py")) as f: 60 | exec(f.read(), about) 61 | else: 62 | about["__version__"] = VERSION 63 | 64 | 65 | class UploadCommand(Command): 66 | """Support setup.py upload.""" 67 | 68 | description = "Build and publish the package." 
69 | user_options = [] 70 | 71 | @staticmethod 72 | def status(s): 73 | """Prints things in bold.""" 74 | print("\033[1m{0}\033[0m".format(s)) 75 | 76 | def initialize_options(self): 77 | pass 78 | 79 | def finalize_options(self): 80 | pass 81 | 82 | def run(self): 83 | try: 84 | self.status("Removing previous builds…") 85 | rmtree(os.path.join(here, "dist")) 86 | except OSError: 87 | pass 88 | 89 | self.status("Building Source and Wheel (universal) distribution…") 90 | os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) 91 | 92 | self.status("Uploading the package to PyPI via Twine…") 93 | os.system("twine upload dist/*") 94 | 95 | self.status("Pushing git tags…") 96 | os.system("git tag v{0}".format(about["__version__"])) 97 | os.system("git push --tags") 98 | 99 | sys.exit() 100 | 101 | 102 | # Where the magic happens: 103 | setup( 104 | name=NAME, 105 | version=about["__version__"], 106 | description=DESCRIPTION, 107 | long_description=long_description, 108 | long_description_content_type="text/markdown", 109 | author=AUTHOR, 110 | author_email=EMAIL, 111 | python_requires=REQUIRES_PYTHON, 112 | url=URL, 113 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 114 | # If your package is a single module, use this instead of 'packages': 115 | # py_modules=['mypackage'], 116 | # entry_points={ 117 | # 'console_scripts': ['mycli=mymodule:cli'], 118 | # }, 119 | install_requires=REQUIRED, 120 | extras_require=EXTRAS, 121 | include_package_data=True, 122 | license="MIT", 123 | classifiers=[ 124 | # Trove classifiers 125 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 126 | "License :: OSI Approved :: MIT License", 127 | "Programming Language :: Python", 128 | "Programming Language :: Python :: 3", 129 | "Programming Language :: Python :: 3.6", 130 | "Programming Language :: Python :: Implementation :: CPython", 131 | "Programming Language :: Python :: Implementation :: PyPy", 132 | ], 133 | # $ setup.py publish support. 
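# With the hook below registered, `python setup.py upload` runs the UploadCommand
# defined above: it removes previous builds, rebuilds the source and wheel
# distributions, uploads them via Twine, and pushes a git version tag.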
134 | cmdclass={ 135 | "upload": UploadCommand, 136 | }, 137 | ) 138 | -------------------------------------------------------------------------------- /tests/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/tests/.DS_Store -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/tests/__init__.py -------------------------------------------------------------------------------- /tests/data/audioset/1272-128104-0000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spijkervet/CLMR/423414d086350d04fd46e48bc74af30a208eff90/tests/data/audioset/1272-128104-0000.wav -------------------------------------------------------------------------------- /tests/test_audioset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torchaudio 3 | from torchaudio_augmentations import ( 4 | Compose, 5 | RandomApply, 6 | RandomResizedCrop, 7 | PolarityInversion, 8 | Noise, 9 | Gain, 10 | Delay, 11 | PitchShift, 12 | Reverb, 13 | ) 14 | from clmr.datasets import AUDIO 15 | 16 | 17 | class TestAudioSet(unittest.TestCase): 18 | sample_rate = 16000 19 | 20 | def get_audio_transforms(self, num_samples): 21 | transform = Compose( 22 | [ 23 | RandomResizedCrop(n_samples=num_samples), 24 | RandomApply([PolarityInversion()], p=0.8), 25 | RandomApply([Noise(min_snr=0.3, max_snr=0.5)], p=0.3), 26 | RandomApply([Gain()], p=0.2), 27 | RandomApply([Delay(sample_rate=self.sample_rate)], p=0.5), 28 | RandomApply( 29 | [PitchShift(n_samples=num_samples, sample_rate=self.sample_rate)], 30 | p=0.4, 31 | ), 32 | RandomApply([Reverb(sample_rate=self.sample_rate)], p=0.3), 33 | ] 34 | ) 35 | return transform 36 | 37 | def test_audioset(self): 38 | audio_dataset = AUDIO("./tests/data/audioset") 39 | audio, label = audio_dataset[0] 40 | assert audio.shape[0] == 1 41 | assert audio.shape[1] == 93680 42 | 43 | num_samples = ( 44 | self.sample_rate * 5 45 | ) # the test item is approximately 5.8 seconds. 
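# 16,000 Hz * 5 s = 80,000 samples, which fits inside the 93,680-sample test clip asserted above.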
46 | transform = self.get_audio_transforms(num_samples=num_samples) 47 | audio = transform(audio) 48 | torchaudio.save("augmented_sample.wav", audio, sample_rate=self.sample_rate) 49 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pytest 3 | from clmr.datasets import ( 4 | get_dataset, 5 | AUDIO, 6 | LIBRISPEECH, 7 | GTZAN, 8 | MAGNATAGATUNE, 9 | MillionSongDataset, 10 | ) 11 | 12 | 13 | class TestAudioSet(unittest.TestCase): 14 | 15 | datasets = { 16 | "librispeech": LIBRISPEECH, 17 | "gtzan": GTZAN, 18 | "magnatagatune": MAGNATAGATUNE, 19 | "msd": MillionSongDataset, 20 | "audio": AUDIO, 21 | } 22 | 23 | def test_dataset_names(self): 24 | for dataset_name, dataset_type in self.datasets.items(): 25 | with pytest.raises(RuntimeError): 26 | _ = get_dataset( 27 | dataset_name, "./data/audio", subset="train", download=False 28 | ) 29 | 30 | def test_custom_audio_dataset(self): 31 | audio_dataset = get_dataset( 32 | "audio", "./tests/data/audioset", subset="train", download=False 33 | ) 34 | assert type(audio_dataset) == AUDIO 35 | assert len(audio_dataset) == 1 36 | -------------------------------------------------------------------------------- /tests/test_spectogram.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torchaudio 3 | import torch.nn as nn 4 | from torchaudio_augmentations import * 5 | 6 | from clmr.datasets import AUDIO 7 | 8 | 9 | class TestAudioSet(unittest.TestCase): 10 | sample_rate = 16000 11 | 12 | def get_audio_transforms(self, num_samples): 13 | transform = Compose( 14 | [ 15 | RandomResizedCrop(n_samples=num_samples), 16 | RandomApply([PolarityInversion()], p=0.8), 17 | RandomApply([Noise(min_snr=0.3, max_snr=0.5)], p=0.3), 18 | RandomApply([Gain()], p=0.2), 19 | RandomApply([Delay(sample_rate=self.sample_rate)], p=0.5), 20 | RandomApply( 21 | [PitchShift(n_samples=num_samples, sample_rate=self.sample_rate)], 22 | p=0.4, 23 | ), 24 | RandomApply([Reverb(sample_rate=self.sample_rate)], p=0.3), 25 | ] 26 | ) 27 | return transform 28 | 29 | def test_audioset(self): 30 | audio_dataset = AUDIO("tests/data/audioset") 31 | audio, label = audio_dataset[0] 32 | 33 | sample_rate = 22050 34 | n_fft = 1024 35 | n_mels = 128 36 | stype = "magnitude" 37 | top_db = None # do not clip the decibel range 38 | 39 | transform = self.get_audio_transforms(num_samples=sample_rate) 40 | 41 | spec_transform = nn.Sequential( 42 | torchaudio.transforms.MelSpectrogram( 43 | sample_rate=sample_rate, 44 | n_fft=n_fft, 45 | n_mels=n_mels, 46 | ), 47 | torchaudio.transforms.AmplitudeToDB(stype=stype, top_db=top_db), 48 | ) 49 | 50 | audio = transform(audio) 51 | audio = spec_transform(audio) 52 | assert audio.shape[1] == 128 53 | assert audio.shape[2] == 44 54 | --------------------------------------------------------------------------------
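The shape assertions at the end of test_spectogram.py follow from torchaudio's defaults (hop_length = n_fft // 2 and center=True); a standalone check, assuming those defaults:

import torch
import torchaudio

n_samples, n_fft, n_mels = 22050, 1024, 128
mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=22050, n_fft=n_fft, n_mels=n_mels
)
out = mel(torch.randn(1, n_samples))

# with center=True: n_frames = 1 + n_samples // hop_length = 1 + 22050 // 512 = 44
assert out.shape == (1, n_mels, 1 + n_samples // (n_fft // 2))  # (1, 128, 44)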