├── .gitignore
├── LICENSE
├── README.md
├── audio_encoder
│   ├── __init__.py
│   ├── audio_processing.py
│   ├── encoder.py
│   └── train_encoder.py
├── requirements.txt
├── run.sh
├── setup.py
└── supervised_examples
    ├── cnn_genre_classification.py
    ├── eval_cnn.py
    ├── prepare_data.py
    └── test_accuracy.json

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Mansar Youness

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# COLA
(Unofficial) re-implementation in PyTorch of the model architecture of CONTRASTIVE LEARNING OF GENERAL-PURPOSE AUDIO REPRESENTATIONS.

Paper: [CONTRASTIVE LEARNING OF GENERAL-PURPOSE AUDIO REPRESENTATIONS](https://arxiv.org/abs/2010.10915)

Paper's official code (in TensorFlow): [tf-code](https://github.com/google-research/google-research/tree/master/cola)

### Install
```
git clone https://github.com/CVxTz/COLA_pytorch
cd COLA_pytorch
python -m pip install .
```

### Run

```
# Data download: download fma data and metadata

wget -c https://os.unil.cloud.switch.ch/fma/fma_metadata.zip
wget -c https://os.unil.cloud.switch.ch/fma/fma_small.zip
wget -c https://os.unil.cloud.switch.ch/fma/fma_large.zip

# Data preparation: prepare a json with fma_small labels and pre-compute mel-spectrograms, saving them as .npy

python supervised_examples/prepare_data.py --metadata_path "/media/ml/data_ml/fma_metadata/"
python audio_encoder/audio_processing.py --mp3_path "/media/ml/data_ml/fma_large/"
python audio_encoder/audio_processing.py --mp3_path "/media/ml/data_ml/fma_small/"

# Training

# Train with COLA

python audio_encoder/train_encoder.py --mp3_path "/media/ml/data_ml/fma_large/"

# Train supervised

python supervised_examples/cnn_genre_classification.py --metadata_path "/media/ml/data_ml/fma_metadata/" \
    --mp3_path "/media/ml/data_ml/fma_small/"

python supervised_examples/cnn_genre_classification.py --metadata_path "/media/ml/data_ml/fma_metadata/" \
    --mp3_path "/media/ml/data_ml/fma_small/" \
    --encoder_path "models/encoder.ckpt"

```

#### Data
[FMA Data](https://github.com/mdeff/fma)

#### Description

This is a short summary of the following paper and of the steps taken to implement it:

* [Contrastive Learning of General-Purpose Audio Representations](https://arxiv.org/abs/2010.10915)

The objective of this paper is to learn self-supervised, general-purpose audio
representations using Discriminative Pre-Training (DPT). The authors train a 2D CNN
(EfficientNet-B0) to transform mel-spectrograms into 512-dimensional vectors. Those
representations are then transferred to other tasks like speaker identification
or bird song detection.

The basic idea behind DPT is to define an anchor element, a positive element,
and one or more distractors. A model is then trained to match the anchor with
the positive example.

![](https://cdn-images-1.medium.com/max/800/1*BsHigYU_qjPpQvq1k09k1A.png)

DPT — Image By Author

One way of doing DPT is to use the triplet loss along with the cosine similarity
measure, training the model so that Cosine(F(P), F(A)) is much higher
than Cosine(F(D), F(A)). This pushes the representation of the anchor in the latent
space much closer to the positive example than to the distractor. The authors of
the paper linked above use this approach as a baseline to show that their own
approach, **COLA**, works much better.

#### COLA

COLA applies DPT to the audio domain. For each audio clip, the authors
pick one segment to be the anchor and another to be the positive example. For each
of those (anchor, positive) pairs, the other samples in the training batch act as
the distractors. This is a great idea for two reasons:

* There are multiple distractors, which makes the training task harder and forces the
  model to learn more meaningful representations to solve it.
* The distractors are reused from the other samples in the batch, which reduces the
  IO, compute, and memory cost of generating them (see the sketch below).
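As a concrete illustration, here is a minimal sketch of the in-batch objective, mirroring the `training_step` of `audio_encoder/encoder.py` in this repository (the similarity matrix below becomes bi-linear once one of the two embeddings has been passed through a learned linear layer, as discussed next):

```
import torch
from torch.nn import functional as F


def cola_loss(x1, x2):
    # x1: (batch, dim) projected anchor embeddings
    # x2: (batch, dim) projected positive embeddings
    # Row i of x2 is the positive for row i of x1; every other row is a distractor.
    y = torch.arange(x1.size(0), device=x1.device)  # index of the positive for each anchor
    y_hat = torch.mm(x1, x2.t())                    # (batch, batch) pairwise similarity matrix
    return F.cross_entropy(y_hat, y)                # Eq. 2 of the paper
```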
COLA also uses a bi-linear similarity, which is learned directly from the data.
The authors show that the bi-linear similarity works much better than the cosine
similarity, giving an extra **7%** average accuracy on downstream tasks.

After computing the similarity between the anchor and the other examples, the
similarity values are used in a cross-entropy loss that measures the model’s
ability to identify the positive example among the distractors (Eq. 2 in the
paper).

#### COLA Evaluation

**Linear Model Evaluation**

COLA is used to train an EfficientNet-B0 on AudioSet, a dataset of around 1M
audio clips taken from YouTube. The feature vectors generated by the model are
then used to train a linear classifier on a wide range of downstream tasks. The
better the representations learned by the model are, the better its performance
will be when used as input to a linear model that performs a supervised task.
The authors found that COLA outperforms other methods like the triplet loss by an
extra **20%** average accuracy on downstream tasks (Table 2 of the paper).

**Fine-tuning Evaluation**

Another way to test this approach is to fine-tune the model on downstream tasks.
This allowed the authors to compare a model pre-trained using COLA to one that
is trained from scratch. Their results show that the pre-trained model
outperforms the model trained from scratch by around **1.2%** on average.

#### PyTorch Implementation

I don’t have the compute resources to reproduce the experiments of the paper, so
I tried to do something similar on a much smaller scale. I pre-trained a model
using COLA on [FMA Large](https://github.com/mdeff/fma) (without the labels)
for a few epochs and then fine-tuned it on music genre classification applied to
FMA Small.

The results on FMA Small are as follows:

* Random guess: 12.5%
* Trained from scratch: 51.1%
* Pre-trained using COLA: **54.3%**

#### Conclusion

The paper [Contrastive Learning of General-Purpose Audio
Representations](https://arxiv.org/abs/2010.10915) introduces the COLA
pre-training approach, which implements some great ideas that make
self-supervised training more effective, like using batch samples as distractors
and the bi-linear similarity measure. This approach can be used to improve the
performance of downstream supervised audio tasks.
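In this repository, transferring the COLA pre-trained encoder to the supervised genre classifier boils down to loading the checkpoint weights non-strictly, as done in `supervised_examples/cnn_genre_classification.py` (the exact checkpoint file name depends on how `ModelCheckpoint` names it during pre-training):

```
import torch
from audio_encoder.encoder import AudioClassifier

model = AudioClassifier()
# Checkpoint written by audio_encoder/train_encoder.py; adjust the path/name as needed.
ckpt = torch.load("models/encoder.ckpt")
# strict=False tolerates the keys that differ between the COLA model and the classifier
# (the extra COLA projection layer is dropped, the new classification head starts fresh).
model.load_state_dict(ckpt["state_dict"], strict=False)
```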

Code:
[https://github.com/CVxTz/COLA_pytorch](https://github.com/CVxTz/COLA_pytorch)

--------------------------------------------------------------------------------
/audio_encoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CVxTz/COLA_pytorch/c4b7d4e807fec77e4e7d24509c30f97a669d4f28/audio_encoder/__init__.py
--------------------------------------------------------------------------------
/audio_encoder/audio_processing.py:
--------------------------------------------------------------------------------
import random

import librosa
import numpy as np

input_length = 16000 * 30

n_mels = 64


def pre_process_audio_mel_t(audio, sample_rate=16000):
    mel_spec = librosa.feature.melspectrogram(y=audio, sr=sample_rate, n_mels=n_mels)
    mel_db = (librosa.power_to_db(mel_spec, ref=np.max) + 40) / 40

    return mel_db.T


def load_audio_file(file_path, input_length=input_length):
    try:
        data = librosa.core.load(file_path, sr=16000)[0]
    except ZeroDivisionError:
        data = []

    if len(data) > input_length:

        max_offset = len(data) - input_length

        offset = np.random.randint(max_offset)

        data = data[offset : (input_length + offset)]

    else:
        if input_length > len(data):
            max_offset = input_length - len(data)

            offset = np.random.randint(max_offset)
        else:
            offset = 0

        data = np.pad(data, (offset, input_length - len(data) - offset), "constant")

    data = pre_process_audio_mel_t(data)
    return data


def random_crop(data, crop_size=128):
    start = int(random.random() * (data.shape[0] - crop_size))
    return data[start : (start + crop_size), :]


def random_mask(data, rate_start=0.1, rate_seq=0.2):
    new_data = data.copy()
    mean = new_data.mean()
    prev_zero = False
    for i in range(new_data.shape[0]):
        if random.random() < rate_start or (
            prev_zero and random.random() < rate_seq
        ):
            prev_zero = True
            new_data[i, :] = mean
        else:
            prev_zero = False

    return new_data


def random_multiply(data):
    new_data = data.copy()
    return new_data * (0.9 + random.random() / 5.)
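
# Illustrative usage (file name is hypothetical): the helpers above are chained by the
# training datasets, e.g.
#   x = np.load("track.npy")            # pre-computed mel-spectrogram, shape (time, n_mels)
#   x = random_mask(x)                  # replace random time steps with the mean value
#   x1 = random_crop(x, crop_size=100)  # anchor segment
#   x2 = random_crop(x, crop_size=100)  # positive segment
#   x1, x2 = random_multiply(x1), random_multiply(x2)  # random gain in [0.9, 1.1)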


def save(path):
    data = load_audio_file(path)
    np.save(path.replace(".mp3", ".npy"), data)
    return True


if __name__ == "__main__":
    from tqdm import tqdm
    from glob import glob
    from multiprocessing import Pool
    import argparse
    from pathlib import Path
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument("--mp3_path")
    args = parser.parse_args()

    base_path = Path(args.mp3_path)

    files = sorted(list(glob(str(base_path / "*/*.mp3"))))

    print(len(files))

    p = Pool(8)

    # for i, _ in tqdm(enumerate(p.imap(save, files)), total=len(files)):
    #     if i % 1000 == 0:
    #         print(i)

    data = load_audio_file(base_path / "000/000002.mp3", input_length=16000 * 30)

    print(data.shape, np.min(data), np.max(data))
    new_data = random_mask(data)

    print(np.mean(new_data == data))

    plt.figure()
    plt.imshow(data.T)
    plt.show()

    plt.figure()
    plt.imshow(new_data.T)
    plt.show()

    print(np.min(data), np.max(data))

--------------------------------------------------------------------------------
/audio_encoder/encoder.py:
--------------------------------------------------------------------------------
import pytorch_lightning as pl
import torch
from efficientnet_pytorch import EfficientNet
from torch.nn import functional as F


class Encoder(torch.nn.Module):
    def __init__(self, drop_connect_rate=0.1):
        super(Encoder, self).__init__()

        # Map the 1-channel spectrogram input to the 3 channels EfficientNet expects.
        self.cnn1 = torch.nn.Conv2d(1, 3, kernel_size=3)
        self.efficientnet = EfficientNet.from_name(
            "efficientnet-b0", include_top=False, drop_connect_rate=drop_connect_rate
        )

    def forward(self, x):
        x = x.unsqueeze(1)

        x = self.cnn1(x)
        x = self.efficientnet(x)

        y = x.squeeze(3).squeeze(2)

        return y


class Cola(pl.LightningModule):
    def __init__(self, p=0.1):
        super().__init__()
        self.save_hyperparameters()

        self.p = p

        self.do = torch.nn.Dropout(p=self.p)

        self.encoder = Encoder(drop_connect_rate=p)

        self.g = torch.nn.Linear(1280, 512)
        self.layer_norm = torch.nn.LayerNorm(normalized_shape=512)
        self.linear = torch.nn.Linear(512, 512, bias=False)

    def forward(self, x):
        x1, x2 = x

        x1 = self.do(self.encoder(x1))
        x1 = self.do(self.g(x1))
        x1 = self.do(torch.tanh(self.layer_norm(x1)))

        x2 = self.do(self.encoder(x2))
        x2 = self.do(self.g(x2))
        x2 = self.do(torch.tanh(self.layer_norm(x2)))

        # Only the anchor branch goes through this weight; combined with the dot
        # product in the steps below, this implements the bi-linear similarity.
        x1 = self.linear(x1)

        return x1, x2

    def training_step(self, x, batch_idx):
        x1, x2 = self(x)

        # Positive pairs share the same index within the batch; all other batch
        # items act as distractors.
        y = torch.arange(x1.size(0), device=x1.device)

        y_hat = torch.mm(x1, x2.t())

        loss = F.cross_entropy(y_hat, y)

        _, predicted = torch.max(y_hat, 1)
        acc = (predicted == y).double().mean()

        self.log("train_loss", loss)
        self.log("train_acc", acc)

        return loss

    def validation_step(self, x, batch_idx):
        x1, x2 = self(x)

        y = torch.arange(x1.size(0), device=x1.device)

        y_hat = torch.mm(x1, x2.t())

        loss = F.cross_entropy(y_hat, y)

        _, predicted = torch.max(y_hat, 1)
        acc = (predicted == y).double().mean()

        self.log("valid_loss", loss)
        self.log("valid_acc", acc)

    def test_step(self, x, batch_idx):
        x1, x2 = self(x)

        y = torch.arange(x1.size(0), device=x1.device)

        y_hat = torch.mm(x1, x2.t())

        loss = F.cross_entropy(y_hat, y)

        _, predicted = torch.max(y_hat, 1)
        acc = (predicted == y).double().mean()

        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


class AudioClassifier(pl.LightningModule):
    def __init__(self, classes=8, p=0.1):
        super().__init__()
        self.save_hyperparameters()

        self.p = p

        self.do = torch.nn.Dropout(p=self.p)

        self.encoder = Encoder(drop_connect_rate=self.p)

        self.g = torch.nn.Linear(1280, 512)
        self.layer_norm = torch.nn.LayerNorm(normalized_shape=512)

        self.fc1 = torch.nn.Linear(512, 256)
        self.fy = torch.nn.Linear(256, classes)

    def forward(self, x):
        x = self.do(self.encoder(x))

        x = self.do(self.g(x))
        x = self.do(torch.tanh(self.layer_norm(x)))

        x = F.relu(self.do(self.fc1(x)))
        y_hat = self.fy(x)

        return y_hat

    def training_step(self, batch, batch_idx):
        x, y = batch

        y_hat = self(x)

        loss = F.cross_entropy(y_hat, y)

        _, predicted = torch.max(y_hat, 1)
        acc = (predicted == y).double().mean()

        self.log("train_loss", loss)
        self.log("train_acc", acc)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        y_hat = self(x)

        loss = F.cross_entropy(y_hat, y)

        _, predicted = torch.max(y_hat, 1)
        acc = (predicted == y).double().mean()

        self.log("valid_loss", loss)
        self.log("valid_acc", acc)

    def test_step(self, batch, batch_idx):
        x, y = batch

        y_hat = self(x)

        loss = F.cross_entropy(y_hat, y)

        _, predicted = torch.max(y_hat, 1)
        acc = (predicted == y).double().mean()

        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

--------------------------------------------------------------------------------
/audio_encoder/train_encoder.py:
--------------------------------------------------------------------------------
from glob import glob

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from audio_encoder.audio_processing import random_crop, random_mask, random_multiply
from audio_encoder.encoder import Cola


class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, data, max_len=100, augment=True):
        self.data = data
        self.max_len = max_len
        self.augment = augment

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        npy_path = self.data[idx]

        x = np.load(npy_path)

        if self.augment:
            x = random_mask(x)

        # Two random crops of the same clip form the (anchor, positive) pair.
        x1 = random_crop(x, crop_size=self.max_len)
        x2 = random_crop(x, crop_size=self.max_len)

        if self.augment:
            x1 = random_multiply(x1)
            x2 = random_multiply(x2)

        x1 = torch.tensor(x1, dtype=torch.float)
        x2 = torch.tensor(x2, dtype=torch.float)

        return x1, x2


class DecayLearningRate(pl.Callback):
    def __init__(self):
        self.old_lrs = []

    def on_train_start(self, trainer, pl_module):
        # track the initial learning rates
        for opt_idx, optimizer in enumerate(trainer.optimizers):
            group = []
            for param_group in optimizer.param_groups:
                group.append(param_group["lr"])
            self.old_lrs.append(group)

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        for opt_idx, optimizer in enumerate(trainer.optimizers):
            old_lr_group = self.old_lrs[opt_idx]
            new_lr_group = []
            for p_idx, param_group in enumerate(optimizer.param_groups):
                old_lr = old_lr_group[p_idx]
                new_lr = old_lr * 0.99
                new_lr_group.append(new_lr)
                param_group["lr"] = new_lr
            self.old_lrs[opt_idx] = new_lr_group


if __name__ == "__main__":
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument("--mp3_path")
    args = parser.parse_args()

    mp3_path = Path(args.mp3_path)

    batch_size = 128
    epochs = 512

    files = sorted(list(glob(str(mp3_path / "*/*.npy"))))

    _train, test = train_test_split(files, test_size=0.05, random_state=1337)

    train, val = train_test_split(_train, test_size=0.05, random_state=1337)

    train_data = AudioDataset(train, augment=True)
    test_data = AudioDataset(test, augment=False)
    val_data = AudioDataset(val, augment=False)

    train_loader = DataLoader(
        train_data, batch_size=batch_size, num_workers=8, shuffle=True
    )
    val_loader = DataLoader(
        val_data, batch_size=batch_size, num_workers=8, shuffle=True
    )
    test_loader = DataLoader(
        test_data, batch_size=batch_size, shuffle=False, num_workers=8
    )

    model = Cola()

    logger = TensorBoardLogger(
        save_dir=".",
        name="lightning_logs",
    )

    checkpoint_callback = ModelCheckpoint(
        monitor="valid_acc", mode="max", filepath="models/", prefix="encoder"
    )

    trainer = pl.Trainer(
        max_epochs=epochs,
        gpus=1,
        logger=logger,
        checkpoint_callback=checkpoint_callback,
        callbacks=[DecayLearningRate()],
    )
    trainer.fit(model, train_loader, val_loader)

    trainer.test(test_dataloaders=test_loader)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.19.2
torchvision==0.7.0+cu101
tqdm==4.50.2
pytorch_lightning==1.0.0
torch==1.6.0+cu101
pandas==1.1.3
matplotlib==3.3.2
librosa==0.8.0
scikit_learn==0.23.2
efficientnet_pytorch  # required by audio_encoder/encoder.py
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
## Data download: download fma data and metadata
#
#
#wget -c https://os.unil.cloud.switch.ch/fma/fma_metadata.zip
#wget -c https://os.unil.cloud.switch.ch/fma/fma_small.zip
#wget -c https://os.unil.cloud.switch.ch/fma/fma_large.zip
#
#
## Data preparation: prepare a json with fma_small labels and pre-compute mel-spectrograms, saving them as .npy
#
#python supervised_examples/prepare_data.py --metadata_path "/media/ml/data_ml/fma_metadata/"
#python audio_encoder/audio_processing.py --mp3_path "/media/ml/data_ml/fma_large/"
#python audio_encoder/audio_processing.py --mp3_path "/media/ml/data_ml/fma_small/"
#
## Training
#
## Train with COLA
#
#python audio_encoder/train_encoder.py --mp3_path "/media/ml/data_ml/fma_large/"
#
## Train supervised
#
#python supervised_examples/cnn_genre_classification.py --metadata_path "/media/ml/data_ml/fma_metadata/" \
#    --mp3_path "/media/ml/data_ml/fma_small/"
#
#python supervised_examples/cnn_genre_classification.py --metadata_path "/media/ml/data_ml/fma_metadata/" \
#    --mp3_path "/media/ml/data_ml/fma_small/" \
#    --encoder_path "models/encoder.ckpt"

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import os

from setuptools import setup


# Utility function to read the README file.
# Used for the long_description. It's nice, because now 1) we have a top level
# README file and 2) it's easier to type in the README file than to put a raw
# string in below ...
def read(fname):
    return open(os.path.join(os.path.dirname(__file__), fname)).read()


setup(
    name="audio_encoder",
    version="0.0.1",
    author="Youness MANSAR",
    author_email="mansaryounessecp@gmail.com",
    description="audio_encoder",
    license="MIT",
    keywords="audio",
    url="https://github.com/CVxTz/NeuralAudioEncoder",
    packages=["audio_encoder"],
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Topic :: Utilities",
        "License :: OSI Approved :: MIT License",
    ],
)

--------------------------------------------------------------------------------
/supervised_examples/cnn_genre_classification.py:
--------------------------------------------------------------------------------
import json
from glob import glob

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from audio_encoder.audio_processing import random_crop, random_mask, random_multiply
from audio_encoder.encoder import AudioClassifier
from supervised_examples.prepare_data import get_id_from_path


class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, data, max_len=512, augment=True):
        self.data = data
        self.max_len = max_len
        self.augment = augment

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        npy_path = self.data[idx][0]
        label = self.data[idx][1]

        x = np.load(npy_path)

        x = random_crop(x, crop_size=self.max_len)

        if self.augment:
            x = random_mask(x)
            x = random_multiply(x)

        x = torch.tensor(x, dtype=torch.float)
        label = torch.tensor(label, dtype=torch.long)

        return x, label


class DecayLearningRate(pl.Callback):
    def __init__(self):
        self.old_lrs = []

    def on_train_start(self, trainer, pl_module):
        # track the initial learning rates
        for opt_idx, optimizer in enumerate(trainer.optimizers):
            group = []
            for param_group in optimizer.param_groups:
                group.append(param_group["lr"])
            self.old_lrs.append(group)

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        for opt_idx, optimizer in enumerate(trainer.optimizers):
            old_lr_group = self.old_lrs[opt_idx]
            new_lr_group = []
            for p_idx, param_group in enumerate(optimizer.param_groups):
                old_lr = old_lr_group[p_idx]
                new_lr = old_lr * 0.97
                new_lr_group.append(new_lr)
                param_group["lr"] = new_lr
            self.old_lrs[opt_idx] = new_lr_group


if __name__ == "__main__":
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument("--metadata_path")
    parser.add_argument("--mp3_path")
    parser.add_argument("--encoder_path")

    args = parser.parse_args()

    metadata_path = Path(args.metadata_path)
    mp3_path = Path(args.mp3_path)

    batch_size = 64
    epochs = 64

    CLASS_MAPPING = json.load(open(metadata_path / "mapping.json"))
    id_to_genres = json.load(open(metadata_path / "tracks_genre.json"))
    id_to_genres = {int(k): v for k, v in id_to_genres.items()}

    files = sorted(list(glob(str(mp3_path / "*/*.npy"))))

    labels = [CLASS_MAPPING[id_to_genres[int(get_id_from_path(x))]] for x in files]
    print(len(labels))

    samples = list(zip(files, labels))

    _train, test = train_test_split(
        samples, test_size=0.2, random_state=1337, stratify=[a[1] for a in samples]
    )

    train, val = train_test_split(
        _train, test_size=0.1, random_state=1337, stratify=[a[1] for a in _train]
    )

    train_data = AudioDataset(train, augment=True)
    test_data = AudioDataset(test, augment=False)
    val_data = AudioDataset(val, augment=False)

    train_loader = DataLoader(
        train_data, batch_size=batch_size, num_workers=8, shuffle=True
    )
    val_loader = DataLoader(
        val_data, batch_size=batch_size, num_workers=8, shuffle=True
    )
    test_loader = DataLoader(
        test_data, batch_size=batch_size, shuffle=False, num_workers=8
    )

    model = AudioClassifier()

    if args.encoder_path is not None:
        checkpoint_callback = ModelCheckpoint(
            monitor="valid_acc", mode="max", filepath="models/", prefix="pretrained"
        )
        logger = TensorBoardLogger(
            save_dir=".", name="lightning_logs", version="pretrained"
        )

        ckpt = torch.load(args.encoder_path)

        model.load_state_dict(ckpt["state_dict"], strict=False)

    else:

        checkpoint_callback = ModelCheckpoint(
            monitor="valid_acc", mode="max", filepath="models/", prefix="scratch"
        )

        logger = TensorBoardLogger(
            save_dir=".", name="lightning_logs", version="scratch"
        )

    trainer = pl.Trainer(
        max_epochs=epochs,
        gpus=1,
        logger=logger,
        checkpoint_callback=checkpoint_callback,
        callbacks=[DecayLearningRate()],
        gradient_clip_val=1.0,
    )
    trainer.fit(model, train_loader, val_loader)

    trainer.test(test_dataloaders=test_loader)

--------------------------------------------------------------------------------
/supervised_examples/eval_cnn.py:
--------------------------------------------------------------------------------
import json
from glob import glob

import torch
from audio_encoder.encoder import AudioClassifier
from supervised_examples.prepare_data import get_id_from_path
from supervised_examples.cnn_genre_classification import AudioDataset
from sklearn.model_selection import train_test_split
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import pytorch_lightning as pl

if __name__ == "__main__":
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument("--metadata_path", default="/media/ml/data_ml/fma_metadata/")
    parser.add_argument("--mp3_path", default="/media/ml/data_ml/fma_small/")

    args = parser.parse_args()

    metadata_path = Path(args.metadata_path)
    mp3_path = Path(args.mp3_path)

    batch_size = 32
    epochs = 64

    CLASS_MAPPING = json.load(open(metadata_path / "mapping.json"))
    id_to_genres = json.load(open(metadata_path / "tracks_genre.json"))
    id_to_genres = {int(k): v for k, v in id_to_genres.items()}

    files = sorted(list(glob(str(mp3_path / "*/*.npy"))))

    labels = [CLASS_MAPPING[id_to_genres[int(get_id_from_path(x))]] for x in files]
    print(len(labels))

    samples = list(zip(files, labels))

    _train, test = train_test_split(
        samples, test_size=0.2, random_state=1337, stratify=[a[1] for a in samples]
    )

    train, val = train_test_split(
        _train, test_size=0.1, random_state=1337, stratify=[a[1] for a in _train]
    )

    train_data = AudioDataset(train, augment=False)
    test_data = AudioDataset(test, augment=False)
    val_data = AudioDataset(val, augment=False)

    train_loader = DataLoader(
        train_data, batch_size=batch_size, num_workers=8, shuffle=True
    )
    val_loader = DataLoader(
        val_data, batch_size=batch_size, num_workers=8, shuffle=True
    )
    test_loader = DataLoader(
        test_data, batch_size=batch_size, shuffle=False, num_workers=8
    )

    model_paths = [
        "../models/pretrained-epoch=21.ckpt",
        "../models/scratch-epoch=52.ckpt",
    ]

    models = [AudioClassifier() for a in model_paths]
    for model, path in zip(models, model_paths):
        model.load_state_dict(torch.load(path)["state_dict"])
        model.cuda()

    accuracies = []
    accuracy = pl.metrics.Accuracy()
    accuracy.cuda()

    for model in tqdm(models):
        model.eval()
        # Reset the metric so each model's accuracy is computed independently.
        accuracy.reset()

        with torch.no_grad():
            for x, y in tqdm(test_loader):
                x = x.cuda()
                y = y.cuda()

                y_pred = model(x)

                accuracy(y_pred, y)

        accuracies.append(accuracy.compute().item())

    data = {"model_name": model_paths, "accuracies": accuracies}

    json.dump(data, open("test_accuracy.json", "w"), indent=4)

--------------------------------------------------------------------------------
/supervised_examples/prepare_data.py:
--------------------------------------------------------------------------------
import ast
import os
import sys
import warnings

import pandas as pd
from pandas.api.types import CategoricalDtype

if not sys.warnoptions:
    warnings.simplefilter("ignore")
import json


def load(filepath):
    # From https://github.com/mdeff/fma/blob/rc1/utils.py / MIT License

    filename = os.path.basename(filepath)

    if "features" in filename:
        return pd.read_csv(filepath, index_col=0, header=[0, 1, 2])

    if "echonest" in filename:
        return pd.read_csv(filepath, index_col=0, header=[0, 1, 2])

    if "genres" in filename:
        return pd.read_csv(filepath, index_col=0)

    if "tracks" in filename:
        tracks = pd.read_csv(filepath, index_col=0, header=[0, 1])

        COLUMNS = [
            ("track", "tags"),
            ("album", "tags"),
            ("artist", "tags"),
            ("track", "genres"),
            ("track", "genres_all"),
        ]
        for column in COLUMNS:
            tracks[column] = tracks[column].map(ast.literal_eval)

        COLUMNS = [
            ("track", "date_created"),
            ("track", "date_recorded"),
            ("album", "date_created"),
            ("album", "date_released"),
            ("artist", "date_created"),
            ("artist", "active_year_begin"),
            ("artist", "active_year_end"),
        ]
        for column in COLUMNS:
            tracks[column] = pd.to_datetime(tracks[column])

        SUBSETS = ("small", "medium", "large")
        tracks["set", "subset"] = tracks["set", "subset"].astype(
            CategoricalDtype(categories=SUBSETS, ordered=True)
        )

        COLUMNS = [
            ("track", "genre_top"),
            ("track", "license"),
            ("album", "type"),
            ("album", "information"),
            ("artist", "bio"),
        ]
        for column in COLUMNS:
            tracks[column] = tracks[column].astype("category")

        return tracks


def get_id_from_path(path):
    base_name = os.path.basename(path)

    return base_name.replace(".mp3", "").replace(".npy", "")


if __name__ == "__main__":
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument("--metadata_path")
    args = parser.parse_args()

    base_path = Path(args.metadata_path)

    in_path = base_path / "tracks.csv"
    genres_path = base_path / "genres.csv"

    out_path = base_path / "tracks_genre.json"
    mapping_path = base_path / "mapping.json"

    df = load(in_path)

    df2 = pd.read_csv(genres_path)

    id_to_title = {k: v for k, v in zip(df2.genre_id.tolist(), df2.title.tolist())}

    df.reset_index(inplace=True)

    print(df.head())
    print(df.columns.values)
    print(set(df[("set", "subset")].tolist()))

    df = df[df[("set", "subset")].isin(["small"])]

    print(set(df[("track", "genre_top")].tolist()))

    print(
        df[
            [
                ("track_id", ""),
                ("track", "genre_top"),
                ("track", "genres"),
                ("set", "subset"),
            ]
        ]
    )

    data = {
        k: v
        for k, v in zip(
            df[("track_id", "")].tolist(), df[("track", "genre_top")].tolist()
        )
    }

    json.dump(data, open(out_path, "w"), indent=4)

    mapping = {k: i for i, k in enumerate(set(df[("track", "genre_top")].tolist()))}

    json.dump(mapping, open(mapping_path, "w"), indent=4)

--------------------------------------------------------------------------------
/supervised_examples/test_accuracy.json:
--------------------------------------------------------------------------------
{
    "model_name": [
        "../models/pretrained-epoch=21.ckpt",
        "../models/scratch-epoch=52.ckpt"
    ],
    "accuracies": [
        0.543749988079071,
        0.5112500190734863
    ]
}
--------------------------------------------------------------------------------