├── env.sh ├── assets ├── acc_on_fold_0.png └── data_spectrograms.png ├── plot.py ├── data ├── audio_io.py └── gtzan.py ├── inference.py ├── models ├── cnn.py └── cnn14_transfer.py ├── README.md ├── train_accelerate.py ├── train.py └── train_fabric.py /env.sh: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/acc_on_fold_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiuqiangkong/mini_music_tagging/HEAD/assets/acc_on_fold_0.png -------------------------------------------------------------------------------- /assets/data_spectrograms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiuqiangkong/mini_music_tagging/HEAD/assets/data_spectrograms.png -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import librosa 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | from data.gtzan import GTZAN 7 | 8 | 9 | def plot(): 10 | 11 | root = "/datasets/gtzan" 12 | sr = 44100 13 | 14 | labels = GTZAN.labels 15 | 16 | fig, axs = plt.subplots(3, 4, sharex=True, figsize=(10, 4)) 17 | 18 | for i, label in enumerate(labels): 19 | 20 | audio_path = Path(root, "genres", label, "{}.00000.au".format(label)) 21 | 22 | audio, _ = librosa.load(path=audio_path, sr=sr, mono=True) 23 | 24 | mel_sp = librosa.feature.melspectrogram( 25 | y=audio, 26 | sr=sr, 27 | n_fft=2048, 28 | hop_length=441, 29 | n_mels=128, 30 | ) 31 | # (freq_bins, frames_num) 32 | 33 | axs[i // 4, i % 4].matshow(np.log(mel_sp), origin='lower', aspect='auto', cmap='jet') 34 | axs[i // 4, i % 4].set_title(label) 35 | axs[i // 4, i % 4].axis('off') 36 | 37 | for i in range(10, 12): 38 | axs[i // 4, i % 4].set_visible(False) 39 | 40 | plt.tight_layout(pad=0.5, h_pad=0.5, w_pad=0.5) 41 | 42 | out_path = "data_spectrograms.png" 43 | plt.savefig(out_path) 44 | print("Write out to {}".format(out_path)) 45 | 46 | 47 | if __name__ == "__main__": 48 | 49 | plot() -------------------------------------------------------------------------------- /data/audio_io.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import librosa 4 | import numpy as np 5 | import torch 6 | import torchaudio 7 | 8 | 9 | def load( 10 | path: str, 11 | sr: int, 12 | mono: bool = True, 13 | offset: float = 0., # Load start time 14 | duration: Union[float, None] = None # Load duration 15 | ) -> np.ndarray: 16 | r"""Load audio. 17 | 18 | Returns: 19 | audio: (channels, audio_samples) 20 | 21 | Examples: 22 | >>> audio = load_audio(path="xx/yy.wav", sr=16000) 23 | """ 24 | 25 | # Prepare arguments 26 | orig_sr = librosa.get_samplerate(path) 27 | 28 | seg_start_sample = round(offset * orig_sr) 29 | 30 | if duration is None: 31 | seg_samples = -1 32 | else: 33 | seg_samples = round(duration * orig_sr) 34 | 35 | # Load audio 36 | audio, fs = torchaudio.load( 37 | path, 38 | frame_offset=seg_start_sample, 39 | num_frames=seg_samples 40 | ) 41 | # (channels, audio_samples) 42 | 43 | # Resample. 
Faster than librosa 44 | audio = torchaudio.functional.resample( 45 | waveform=audio, 46 | orig_freq=orig_sr, 47 | new_freq=sr 48 | ) 49 | # shape: (channels, audio_samples) 50 | 51 | if mono: 52 | audio = torch.mean(audio, dim=0, keepdim=True) 53 | 54 | audio = audio.numpy() 55 | 56 | return audio 57 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import numpy as np 4 | from pathlib import Path 5 | from torch.utils.data.sampler import SequentialSampler 6 | 7 | from data.gtzan import GTZAN 8 | from train import get_model 9 | 10 | 11 | def inference(args): 12 | 13 | # Arguments 14 | model_name = args.model_name 15 | 16 | # Default parameters 17 | test_fold = 0 18 | sr = 16000 19 | batch_size = 16 20 | device = "cuda" 21 | filename = Path(__file__).stem 22 | classes_num = GTZAN.classes_num 23 | 24 | root = "/datasets/gtzan" 25 | 26 | # Load checkpoint 27 | checkpoint_path = Path("checkpoints", "train", model_name, "latest.pth") 28 | 29 | model = get_model(model_name, classes_num) 30 | model.load_state_dict(torch.load(checkpoint_path)) 31 | model.to(device) 32 | 33 | # Test dataset 34 | test_dataset = GTZAN( 35 | root=root, 36 | split="test", 37 | test_fold=test_fold, 38 | sr=sr, 39 | ) 40 | 41 | # Test sampler 42 | test_sampler = SequentialSampler(test_dataset) 43 | 44 | # Dataloader 45 | test_dataloader = torch.utils.data.DataLoader( 46 | dataset=test_dataset, 47 | batch_size=batch_size, 48 | sampler=test_sampler, 49 | num_workers=16, 50 | pin_memory=True 51 | ) 52 | 53 | pred_ids = [] 54 | target_ids = [] 55 | 56 | for data in test_dataloader: 57 | 58 | segment = torch.Tensor(data["audio"]).to(device) 59 | target = data["target"] 60 | 61 | with torch.no_grad(): 62 | model.eval() 63 | output = model(audio=segment) 64 | 65 | pred_ids.append(np.argmax(output.cpu().numpy(), axis=-1)) 66 | target_ids.append(np.argmax(target, axis=-1)) 67 | 68 | pred_ids = np.concatenate(pred_ids, axis=0) 69 | target_ids = np.concatenate(target_ids, axis=0) 70 | 71 | accuracy = np.mean(pred_ids == target_ids) 72 | 73 | print("Accuracy: {:.3f}".format(accuracy)) 74 | 75 | 76 | if __name__ == "__main__": 77 | 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument('--model_name', type=str, default="Cnn") 80 | args = parser.parse_args() 81 | 82 | inference(args) -------------------------------------------------------------------------------- /models/cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | from torchaudio.transforms import MelSpectrogram 6 | 7 | 8 | class Cnn(nn.Module): 9 | def __init__(self, classes_num): 10 | super().__init__() 11 | 12 | self.mel_extractor = MelSpectrogram( 13 | sample_rate=16000, 14 | n_fft=2048, 15 | hop_length=160, 16 | f_min=0., 17 | f_max=8000, 18 | n_mels=128, 19 | power=2.0, 20 | normalized=True, 21 | ) 22 | 23 | self.conv1 = ConvBlock(in_channels=1, out_channels=32) 24 | self.conv2 = ConvBlock(in_channels=32, out_channels=64) 25 | self.conv3 = ConvBlock(in_channels=64, out_channels=128) 26 | self.conv4 = ConvBlock(in_channels=128, out_channels=256) 27 | 28 | self.onset_fc = nn.Linear(256, classes_num) 29 | 30 | def forward(self, audio): 31 | r""" 32 | Args: 33 | audio: (batch_size, channels_num, samples_num) 34 | 35 | Outputs: 36 | output: 
(batch_size, classes_num) 37 | """ 38 | 39 | x = self.mel_extractor(audio) 40 | # shape: (B, 1, F, T) 41 | 42 | x = torch.log10(torch.clamp(x, 1e-8)) 43 | 44 | x = rearrange(x, 'b c f t -> b c t f') 45 | # shape: (B, 1, T, F) 46 | 47 | x = self.conv1(x) 48 | x = self.conv2(x) 49 | x = self.conv3(x) 50 | x = self.conv4(x) 51 | # shape: (B, C, T, F) 52 | 53 | x, _ = torch.max(x, dim=-1) 54 | x, _ = torch.max(x, dim=-1) 55 | 56 | output = torch.sigmoid(self.onset_fc(x)) 57 | 58 | return output 59 | 60 | 61 | class ConvBlock(nn.Module): 62 | def __init__(self, in_channels, out_channels): 63 | super().__init__() 64 | 65 | self.conv1 = nn.Conv2d( 66 | in_channels=in_channels, 67 | out_channels=out_channels, 68 | kernel_size=(3, 3), 69 | padding=(1, 1), 70 | ) 71 | self.conv2 = nn.Conv2d( 72 | in_channels=out_channels, 73 | out_channels=out_channels, 74 | kernel_size=(3, 3), 75 | padding=(1, 1), 76 | ) 77 | 78 | self.bn1 = nn.BatchNorm2d(out_channels) 79 | self.bn2 = nn.BatchNorm2d(out_channels) 80 | 81 | def forward(self, x): 82 | r""" 83 | Args: 84 | x: (batch_size, in_channels, time_steps, freq_bins) 85 | 86 | Returns: 87 | output: (batch_size, out_channels, time_steps // 2, freq_bins // 2) 88 | """ 89 | 90 | x = F.relu_(self.bn1(self.conv1(x))) 91 | x = F.relu_(self.bn2(self.conv2(x))) 92 | 93 | output = F.avg_pool2d(x, kernel_size=(2, 2)) 94 | 95 | return output -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # A Minimal Implementation of Music Tagging 3 | 4 | This is a minimal implementation of music tagging with PyTorch. We use the GTZAN dataset containing 1,000 30-second audio clips for training and validation. The GTZAN dataset contains 10 genres. We use 900 audio files for training and 100 audio files for validation. We train a convolutional neural network as the classifier. 5 | 6 | ## 0. Download dataset 7 | 8 | The original dataset link [http://marsyas.info/index.html](http://marsyas.info/index.html) is no longer available, so please obtain the dataset from other sources. Here are the log mel spectrograms of audio clips from different genres. 9 | 10 |  11 | 12 | The downloaded dataset looks like: 13 | 14 |
15 | dataset_root (1.3 GB) 16 | └── genres 17 | ├── blues (100 files) 18 | ├── classical (100 files) 19 | ├── country (100 files) 20 | ├── disco (100 files) 21 | ├── hiphop (100 files) 22 | ├── jazz (100 files) 23 | ├── metal (100 files) 24 | ├── pop (100 files) 25 | ├── reggae (100 files) 26 | └── rock (100 files) 27 | 28 | 29 | ## 1. Install dependencies 30 | 31 | ```bash 32 | git clone https://github.com/qiuqiangkong/mini_music_tagging 33 | 34 | # Install Python environment. 35 | conda create --name music_tagging python=3.8 36 | 37 | # Activate environment. 38 | conda activate music_tagging 39 | 40 | # Install Python package dependencies. 41 | sh env.sh 42 | ``` 43 | 44 | ## 2. Single-GPU training 45 | 46 | We use the Wandb toolkit for logging. You may set `wandb_log` to `False` or use another logger. 47 | 48 | ```bash 49 | CUDA_VISIBLE_DEVICES=0 python train.py 50 | ``` 51 | 52 | ## 3. Multi-GPU training 53 | 54 | We use the Hugging Face Accelerate toolkit for multi-GPU training. Here is an example of training with 4 GPUs. 55 | 56 | ```bash 57 | CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --multi_gpu --num_processes 4 train_accelerate.py 58 | ``` 59 | 60 | Training for 10,000 steps takes around 20 minutes on a single RTX 4090 GPU card. 61 | 62 |
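If Accelerate has not been configured on this machine before, you can either pass the launch flags explicitly as above or create a default configuration once. Note that `accelerate config` is a standard Accelerate CLI command, not a script shipped with this repository:

```bash
# One-time interactive setup; afterwards `accelerate launch train_accelerate.py` reuses the saved config.
accelerate config
```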
The result looks like: 63 | 0it [00:00, ?it/s]step: 0, loss: 0.865 64 | Accuracy: 0.1 65 | Save model to checkpoints/train/Cnn/step=0.pth 66 | Save model to checkpoints/train/Cnn/latest.pth 67 | 200it [00:31, 7.80it/s]step: 200, loss: 0.159 68 | Accuracy: 0.48 69 | Save model to checkpoints/train/Cnn/step=200.pth 70 | Save model to checkpoints/train/Cnn/latest.pth 71 | ... 72 | Accuracy: 0.64 73 | Save model to checkpoints/train/Cnn/step=10000.pth 74 | Save model to checkpoints/train/Cnn/latest.pth 75 | 76 | 77 | The validation accuracy during training looks like: 78 | 79 |  80 | 81 | ## 4. Inference 82 | 83 | Users can run inference with a trained checkpoint. 84 | 85 | ```bash 86 | CUDA_VISIBLE_DEVICES=0 python inference.py 87 | ``` 88 | 89 | For example, testing on fold 0 gives the following result: 90 | 91 |
92 | Accuracy: 0.670 93 |94 | 95 | ## Reference 96 | 97 | ``` 98 | @article{kong2020panns, 99 | title={Panns: Large-scale pretrained audio neural networks for audio pattern recognition}, 100 | author={Kong, Qiuqiang and Cao, Yin and Iqbal, Turab and Wang, Yuxuan and Wang, Wenwu and Plumbley, Mark D}, 101 | journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing}, 102 | volume={28}, 103 | pages={2880--2894}, 104 | year={2020}, 105 | } 106 | ``` -------------------------------------------------------------------------------- /train_accelerate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import torch 5 | import torch.optim as optim 6 | from accelerate import Accelerator 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data.sampler import SequentialSampler 9 | from tqdm import tqdm 10 | 11 | import wandb 12 | 13 | wandb.require("core") 14 | 15 | from data.gtzan import GTZAN 16 | from train import InfiniteSampler, bce_loss, get_model, validate 17 | 18 | 19 | def train(args): 20 | 21 | # Arguments 22 | model_name = args.model_name 23 | 24 | # Default parameters 25 | test_fold = 0 26 | sr = 16000 27 | batch_size = 16 28 | num_workers = 16 29 | pin_memory = True 30 | learning_rate = 1e-4 31 | test_step_frequency = 200 32 | save_step_frequency = 200 33 | training_steps = 10000 34 | wandb_log = True 35 | 36 | filename = Path(__file__).stem 37 | classes_num = GTZAN.classes_num 38 | 39 | if wandb_log: 40 | wandb.init(project="mini_music_tagging") 41 | 42 | checkpoints_dir = Path("./checkpoints", filename, model_name) 43 | 44 | root = "/datasets/gtzan" 45 | 46 | # Dataset 47 | train_dataset = GTZAN( 48 | root=root, 49 | split="train", 50 | test_fold=test_fold, 51 | sr=sr, 52 | ) 53 | 54 | test_dataset = GTZAN( 55 | root=root, 56 | split="test", 57 | test_fold=test_fold, 58 | sr=sr, 59 | ) 60 | 61 | # Sampler 62 | train_sampler = InfiniteSampler(train_dataset) 63 | 64 | test_sampler = SequentialSampler(test_dataset) 65 | 66 | # Dataloader 67 | train_dataloader = DataLoader( 68 | dataset=train_dataset, 69 | batch_size=batch_size, 70 | sampler=train_sampler, 71 | num_workers=num_workers, 72 | pin_memory=pin_memory 73 | ) 74 | 75 | test_dataloader = DataLoader( 76 | dataset=test_dataset, 77 | batch_size=batch_size, 78 | sampler=test_sampler, 79 | num_workers=1, 80 | pin_memory=pin_memory 81 | ) 82 | 83 | # Model 84 | model = get_model(model_name, classes_num) 85 | 86 | # Optimizer 87 | optimizer = optim.AdamW(model.parameters(), lr=learning_rate) 88 | 89 | # Prepare for multiprocessing 90 | accelerator = Accelerator() 91 | 92 | model, optimizer, train_dataloader = accelerator.prepare( 93 | model, optimizer, train_dataloader) 94 | 95 | # Create checkpoints directory 96 | Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) 97 | 98 | # Train 99 | for step, data in enumerate(tqdm(train_dataloader)): 100 | 101 | audio = data["audio"] 102 | target = data["target"] 103 | 104 | # Forward 105 | model.train() 106 | output = model(audio=audio) 107 | 108 | # Loss 109 | loss = bce_loss(output, target) 110 | 111 | # Optimize 112 | optimizer.zero_grad() # Reset all parameter.grad to 0 113 | accelerator.backward(loss) # Update all parameter.grad 114 | optimizer.step() # Update all parameters based on all parameter.grad 115 | 116 | # Evaluate 117 | if step % test_step_frequency == 0: 118 | 119 | accelerator.wait_for_everyone() 120 | 121 | if accelerator.is_main_process: 122 | 123 | if 
accelerator.num_processes == 1: 124 | val_model = model 125 | else: 126 | val_model = model.module 127 | 128 | test_acc = validate(val_model, test_dataloader) 129 | print("Test Accuracy: {}".format(test_acc)) 130 | 131 | if wandb_log: 132 | wandb.log( 133 | data={"test_acc": test_acc}, 134 | step=step 135 | ) 136 | 137 | # Save model 138 | if step % save_step_frequency == 0: 139 | 140 | accelerator.wait_for_everyone() 141 | 142 | if accelerator.is_main_process: 143 | 144 | unwrapped_model = accelerator.unwrap_model(model) 145 | 146 | checkpoint_path = Path(checkpoints_dir, "step={}.pth".format(step)) 147 | torch.save(unwrapped_model.state_dict(), checkpoint_path) 148 | print("Save model to {}".format(checkpoint_path)) 149 | 150 | checkpoint_path = Path(checkpoints_dir, "latest.pth") 151 | torch.save(unwrapped_model.state_dict(), Path(checkpoint_path)) 152 | print("Save model to {}".format(checkpoint_path)) 153 | 154 | 155 | if step == training_steps: 156 | break 157 | 158 | 159 | if __name__ == "__main__": 160 | 161 | parser = argparse.ArgumentParser() 162 | parser.add_argument('--model_name', type=str, default="Cnn") 163 | args = parser.parse_args() 164 | 165 | train(args) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data.sampler import SequentialSampler 11 | from tqdm import tqdm 12 | 13 | import wandb 14 | 15 | wandb.require("core") 16 | 17 | from data.gtzan import GTZAN 18 | from models.cnn import Cnn 19 | 20 | 21 | def train(args): 22 | 23 | # Arguments 24 | model_name = args.model_name 25 | 26 | # Default parameters 27 | test_fold = 0 28 | sr = 16000 29 | batch_size = 16 30 | num_workers = 16 31 | pin_memory = True 32 | learning_rate = 1e-4 33 | test_step_frequency = 200 34 | save_step_frequency = 200 35 | training_steps = 10000 36 | wandb_log = True 37 | device = "cuda" 38 | 39 | filename = Path(__file__).stem 40 | classes_num = GTZAN.classes_num 41 | 42 | checkpoints_dir = Path("./checkpoints", filename, model_name) 43 | 44 | root = "/datasets/gtzan" 45 | 46 | if wandb_log: 47 | wandb.init(project="mini_music_tagging") 48 | 49 | # Dataset 50 | train_dataset = GTZAN( 51 | root=root, 52 | split="train", 53 | test_fold=test_fold, 54 | sr=sr, 55 | ) 56 | 57 | test_dataset = GTZAN( 58 | root=root, 59 | split="test", 60 | test_fold=test_fold, 61 | sr=sr, 62 | ) 63 | 64 | # Sampler 65 | train_sampler = InfiniteSampler(train_dataset) 66 | 67 | test_sampler = SequentialSampler(test_dataset) 68 | 69 | # Dataloader 70 | train_dataloader = DataLoader( 71 | dataset=train_dataset, 72 | batch_size=batch_size, 73 | sampler=train_sampler, 74 | num_workers=num_workers, 75 | pin_memory=pin_memory 76 | ) 77 | 78 | test_dataloader = DataLoader( 79 | dataset=test_dataset, 80 | batch_size=batch_size, 81 | sampler=test_sampler, 82 | num_workers=1, 83 | pin_memory=pin_memory 84 | ) 85 | 86 | # Model 87 | model = get_model(model_name, classes_num) 88 | model.to(device) 89 | 90 | # Optimizer 91 | optimizer = optim.AdamW(model.parameters(), lr=learning_rate) 92 | 93 | # Create checkpoints directory 94 | Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) 95 | 96 | # Train 97 | for step, data in enumerate(tqdm(train_dataloader)): 98 | 99 | 
# Move data to device 100 | audio = data["audio"].to(device) 101 | target = data["target"].to(device) 102 | 103 | # Forward 104 | model.train() 105 | output = model(audio=audio) 106 | 107 | # Loss 108 | loss = bce_loss(output, target) 109 | 110 | # Optimize 111 | optimizer.zero_grad() # Reset all parameter.grad to 0 112 | loss.backward() # Update all parameter.grad 113 | optimizer.step() # Update all parameters based on all parameter.grad 114 | 115 | if step % test_step_frequency == 0: 116 | print("step: {}, loss: {:.3f}".format(step, loss.item())) 117 | test_acc = validate(model, test_dataloader) 118 | print("Accuracy: {}".format(test_acc)) 119 | 120 | if wandb_log: 121 | wandb.log( 122 | data={"test_acc": test_acc}, 123 | step=step 124 | ) 125 | 126 | # Save model 127 | if step % save_step_frequency == 0: 128 | checkpoint_path = Path(checkpoints_dir, "step={}.pth".format(step)) 129 | torch.save(model.state_dict(), checkpoint_path) 130 | print("Save model to {}".format(checkpoint_path)) 131 | 132 | checkpoint_path = Path(checkpoints_dir, "latest.pth") 133 | torch.save(model.state_dict(), Path(checkpoint_path)) 134 | print("Save model to {}".format(checkpoint_path)) 135 | 136 | if step == training_steps: 137 | break 138 | 139 | 140 | def get_model(model_name, classes_num): 141 | if model_name == "Cnn": 142 | return Cnn(classes_num) 143 | else: 144 | raise NotImplementedError 145 | 146 | 147 | def bce_loss(output, target): 148 | return F.binary_cross_entropy(output, target) 149 | 150 | 151 | class InfiniteSampler: 152 | def __init__(self, dataset): 153 | 154 | self.indexes = list(range(len(dataset))) 155 | random.shuffle(self.indexes) 156 | 157 | def __iter__(self): 158 | 159 | pointer = 0 160 | 161 | while True: 162 | 163 | if pointer == len(self.indexes): 164 | random.shuffle(self.indexes) 165 | pointer = 0 166 | 167 | index = self.indexes[pointer] 168 | pointer += 1 169 | 170 | yield index 171 | 172 | 173 | def validate(model, dataloader): 174 | 175 | device = next(model.parameters()).device 176 | 177 | pred_ids = [] 178 | target_ids = [] 179 | 180 | for step, data in enumerate(dataloader): 181 | 182 | segment = torch.Tensor(data["audio"]).to(device) 183 | target = torch.Tensor(data["target"]).to(device) 184 | 185 | with torch.no_grad(): 186 | model.eval() 187 | output = model(audio=segment) 188 | 189 | pred_ids.append(np.argmax(output.cpu().numpy(), axis=-1)) 190 | target_ids.append(np.argmax(target.cpu().numpy(), axis=-1)) 191 | 192 | pred_ids = np.concatenate(pred_ids, axis=0) 193 | target_ids = np.concatenate(target_ids, axis=0) 194 | accuracy = np.mean(pred_ids == target_ids) 195 | 196 | return accuracy 197 | 198 | 199 | if __name__ == "__main__": 200 | 201 | parser = argparse.ArgumentParser() 202 | parser.add_argument('--model_name', type=str, default="Cnn") 203 | args = parser.parse_args() 204 | 205 | train(args) -------------------------------------------------------------------------------- /train_fabric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import time 4 | import librosa 5 | import numpy as np 6 | import soundfile 7 | import matplotlib.pyplot as plt 8 | from pathlib import Path 9 | import torch.optim as optim 10 | from data.gtzan import Gtzan, CLASSES_NUM 11 | from data.collate import collate_fn 12 | from models.cnn import Cnn 13 | from tqdm import tqdm 14 | import museval 15 | import argparse 16 | import random 17 | from torch.utils.data.sampler import SequentialSampler 18 | 
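# NOTE (editor's sketch, not original code): `data/collate.py` is not included in this
# listing, and `data/gtzan.py` defines `GTZAN` / `GTZAN.classes_num` rather than the
# `Gtzan` / `CLASSES_NUM` imported above, so this script will not run unmodified.
# Assuming each dataset item is a dict of numpy arrays and strings (as produced by
# data/gtzan.py), a compatible collate function could look like the sketch below,
# kept commented out so it does not shadow the real `data.collate.collate_fn`:
#
# def collate_fn(list_data_dict):
#     data_dict = {}
#     for key in list_data_dict[0].keys():
#         values = [d[key] for d in list_data_dict]
#         if isinstance(values[0], np.ndarray):
#             data_dict[key] = torch.from_numpy(np.stack(values, axis=0))
#         else:
#             data_dict[key] = values
#     return data_dict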
19 | 20 | def train(args): 21 | 22 | # Arguments 23 | model_name = args.model_name 24 | 25 | # Default parameters 26 | fold = 0 27 | batch_size = 16 28 | num_workers = 16 29 | test_step_frequency = 200 30 | save_step_frequency = 200 31 | training_steps = 10000 32 | debug = False 33 | device = "cuda" 34 | filename = Path(__file__).stem 35 | classes_num = CLASSES_NUM 36 | 37 | checkpoints_dir = Path("./checkpoints", filename, model_name) 38 | 39 | root = "/datasets/gtzan" 40 | 41 | # Dataset 42 | train_dataset = Gtzan( 43 | root=root, 44 | split="train", 45 | fold=fold, 46 | ) 47 | 48 | test_dataset = Gtzan( 49 | root=root, 50 | split="test", 51 | fold=fold, 52 | ) 53 | 54 | # Sampler 55 | train_sampler = Sampler(dataset_size=len(train_dataset)) 56 | 57 | test_sampler = SequentialSampler(test_dataset) 58 | 59 | # Dataloader 60 | train_dataloader = torch.utils.data.DataLoader( 61 | dataset=train_dataset, 62 | batch_size=batch_size, 63 | sampler=train_sampler, 64 | collate_fn=collate_fn, 65 | num_workers=num_workers, 66 | pin_memory=True 67 | ) 68 | 69 | test_dataloader = torch.utils.data.DataLoader( 70 | dataset=test_dataset, 71 | batch_size=batch_size, 72 | sampler=test_sampler, 73 | collate_fn=collate_fn, 74 | num_workers=num_workers, 75 | pin_memory=True 76 | ) 77 | 78 | # Model 79 | model = get_model(model_name, classes_num) 80 | model.to(device) 81 | 82 | # Optimizer 83 | optimizer = optim.AdamW(model.parameters(), lr=0.001) 84 | 85 | # Create checkpoints directory 86 | Path(checkpoints_dir).mkdir(parents=True, exist_ok=True) 87 | 88 | # Train 89 | for step, data in enumerate(tqdm(train_dataloader)): 90 | 91 | # Move data to device 92 | audio = data["audio"].to(device) 93 | target = data["target"].to(device) 94 | 95 | # Play the audio 96 | if debug: 97 | play_audio(mixture, target) 98 | 99 | # Forward 100 | model.train() 101 | output = model(audio=audio) 102 | 103 | # Loss 104 | loss = bce_loss(output, target) 105 | 106 | # Optimize 107 | optimizer.zero_grad() # Reset parameter.grad to 0 108 | loss.backward() # Update parameter.grad 109 | optimizer.step() # Update parameters based on parameter.grad 110 | from IPython import embed; embed(using=False); os._exit(0) 111 | 112 | if step % test_step_frequency == 0: 113 | print("step: {}, loss: {:.3f}".format(step, loss.item())) 114 | accuracy = validate(model, test_dataloader) 115 | print("Accuracy: {}".format(accuracy)) 116 | 117 | # Save model 118 | if step % save_step_frequency == 0: 119 | checkpoint_path = Path(checkpoints_dir, "step={}.pth".format(step)) 120 | torch.save(model.state_dict(), checkpoint_path) 121 | print("Save model to {}".format(checkpoint_path)) 122 | 123 | checkpoint_path = Path(checkpoints_dir, "latest.pth") 124 | torch.save(model.state_dict(), Path(checkpoint_path)) 125 | print("Save model to {}".format(checkpoint_path)) 126 | 127 | if step == training_steps: 128 | break 129 | 130 | 131 | def get_model(model_name, classes_num): 132 | if model_name == "Cnn": 133 | return Cnn(classes_num) 134 | else: 135 | raise NotImplementedError 136 | 137 | 138 | def bce_loss(output, target): 139 | return F.binary_cross_entropy(output, target) 140 | 141 | 142 | def play_audio(mixture, target): 143 | soundfile.write(file="tmp_mixture.wav", data=mixture[0].cpu().numpy().T, samplerate=44100) 144 | soundfile.write(file="tmp_target.wav", data=target[0].cpu().numpy().T, samplerate=44100) 145 | from IPython import embed; embed(using=False); os._exit(0) 146 | 147 | 148 | class Sampler: 149 | def __init__(self, dataset_size): 150 | 
self.indexes = list(range(dataset_size)) 151 | random.shuffle(self.indexes) 152 | 153 | def __iter__(self): 154 | 155 | pointer = 0 156 | 157 | while True: 158 | 159 | if pointer == len(self.indexes): 160 | random.shuffle(self.indexes) 161 | pointer = 0 162 | 163 | index = self.indexes[pointer] 164 | pointer += 1 165 | 166 | yield index 167 | 168 | 169 | def validate(model, dataloader): 170 | 171 | device = next(model.parameters()).device 172 | 173 | pred_ids = [] 174 | target_ids = [] 175 | 176 | for step, data in enumerate(dataloader): 177 | 178 | segment = torch.Tensor(data["audio"]).to(device) 179 | target = torch.Tensor(data["target"]).to(device) 180 | 181 | with torch.no_grad(): 182 | model.eval() 183 | output = model(audio=segment) 184 | 185 | pred_ids.append(np.argmax(output.cpu().numpy(), axis=-1)) 186 | target_ids.append(np.argmax(target.cpu().numpy(), axis=-1)) 187 | 188 | pred_ids = np.concatenate(pred_ids, axis=0) 189 | target_ids = np.concatenate(target_ids, axis=0) 190 | accuracy = np.mean(pred_ids == target_ids) 191 | 192 | return accuracy 193 | 194 | 195 | if __name__ == "__main__": 196 | 197 | parser = argparse.ArgumentParser() 198 | parser.add_argument('--model_name', type=str, default="Cnn") 199 | args = parser.parse_args() 200 | 201 | train(args) -------------------------------------------------------------------------------- /data/gtzan.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from pathlib import Path 4 | from typing import Callable, Dict, Optional, Tuple, Union 5 | 6 | import librosa 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | 10 | from data.audio_io import load 11 | 12 | 13 | class GTZAN(Dataset): 14 | r"""GTZAN [1] is a music dataset containing 1000 30-second music. 15 | GTZAN contains 10 genres. Audios are sampled at 22,050 Hz. Dataset size is 1.3 GB. 16 | 17 | [1] Tzanetakis, G., et al., Musical genre classification of audio signals. 2002 18 | 19 | The dataset looks like: 20 | 21 | dataset_root (1.3 GB) 22 | └── genres 23 | ├── blues (100 files) 24 | ├── classical (100 files) 25 | ├── country (100 files) 26 | ├── disco (100 files) 27 | ├── hiphop (100 files) 28 | ├── jazz (100 files) 29 | ├── metal (100 files) 30 | ├── pop (100 files) 31 | ├── reggae (100 files) 32 | └── rock (100 files) 33 | """ 34 | 35 | url = "http://marsyas.info/index.html" 36 | 37 | duration = 30024.07 # Dataset duration (s), including training, validation, and testing. 38 | 39 | labels = ["blues", "classical", "country", "disco", "hiphop", "jazz", 40 | "metal", "pop", "reggae", "rock"] 41 | 42 | classes_num = len(labels) 43 | lb_to_ix = {lb: ix for ix, lb in enumerate(labels)} 44 | ix_to_lb = {ix: lb for ix, lb in enumerate(labels)} 45 | 46 | def __init__( 47 | self, 48 | root: str = None, 49 | split: Union["train", "test"] = "train", 50 | test_fold: int = 0, # E.g., fold 0 is used for testing. Fold 1 - 9 are used for training. 51 | sr: float = 16000, # Sampling rate 52 | clip_duration = 30., 53 | transform: Optional[Callable] = None, 54 | target_transform: Optional[Callable] = None 55 | ) -> None: 56 | 57 | self.root = root 58 | self.split = split 59 | self.test_fold = test_fold 60 | self.sr = sr 61 | self.transform = transform 62 | self.target_transform = target_transform 63 | 64 | self.audio_samples = int(clip_duration * self.sr) 65 | 66 | if not Path(root).exists(): 67 | raise "Please download the GTZAN dataset from {} (Invalid anymore. 
Please search a source)".format(GTZAN.url) 68 | 69 | self.meta_dict = self.load_meta() 70 | # E.g., meta_dict = { 71 | # "label": ["blues", "disco", ...], 72 | # "audio_name": ["blues.00010.au", "disco00005.au", ...], 73 | # "audio_path": ["path/blues.00010.au", "path/disco00005.au", ...] 74 | # } 75 | 76 | def __getitem__(self, index: int) -> Dict: 77 | 78 | audio_path = self.meta_dict["audio_path"][index] 79 | label = self.meta_dict["label"][index] 80 | 81 | # Load audio 82 | audio = self.load_audio(path=audio_path) 83 | # shape: (channels, audio_samples) 84 | 85 | # Load target 86 | target_data = self.load_target(label=label) 87 | # shape: (classes_num,) 88 | 89 | full_data = { 90 | "audio_path": str(audio_path), 91 | "audio": audio 92 | } 93 | 94 | # Merge dict 95 | full_data.update(target_data) 96 | 97 | return full_data 98 | 99 | def __len__(self) -> int: 100 | 101 | audios_num = len(self.meta_dict["audio_name"]) 102 | 103 | return audios_num 104 | 105 | def load_meta(self) -> Dict: 106 | r"""Load metadata of the GTZAN dataset. 107 | """ 108 | 109 | labels = GTZAN.labels 110 | 111 | meta_dict = { 112 | "label": [], 113 | "audio_name": [], 114 | "audio_path": [] 115 | } 116 | 117 | audios_dir = Path(self.root, "genres") 118 | 119 | for genre in labels: 120 | 121 | audio_names = sorted(os.listdir(Path(audios_dir, genre))) 122 | # len(audio_names) = 1000 123 | 124 | train_audio_names, test_audio_names = self.split_train_test(audio_names) 125 | # len(train_audio_names) = 900 126 | # len(test_audio_names) = 100 127 | 128 | if self.split == "train": 129 | filtered_audio_names = train_audio_names 130 | 131 | elif self.split == "test": 132 | filtered_audio_names = test_audio_names 133 | 134 | for audio_name in filtered_audio_names: 135 | 136 | audio_path = Path(audios_dir, genre, audio_name) 137 | 138 | meta_dict["label"].append(genre) 139 | meta_dict["audio_name"].append(audio_name) 140 | meta_dict["audio_path"].append(audio_path) 141 | 142 | return meta_dict 143 | 144 | def split_train_test(self, audio_names: list) -> Tuple[list, list]: 145 | 146 | train_audio_names = [] 147 | test_audio_names = [] 148 | 149 | test_ids = range(self.test_fold * 10, (self.test_fold + 1) * 10) 150 | # E.g., if test_fold = 3, then test_ids = [30, 31, 32, ..., 39] 151 | 152 | for audio_name in audio_names: 153 | 154 | audio_id = int(re.search(r'\d+', audio_name).group()) 155 | # E.g., if audio_name is "blues.00037.au", then audio_id = 37 156 | 157 | if audio_id in test_ids: 158 | test_audio_names.append(audio_name) 159 | 160 | else: 161 | train_audio_names.append(audio_name) 162 | 163 | return train_audio_names, test_audio_names 164 | 165 | def load_audio(self, path: str) -> np.ndarray: 166 | 167 | audio = load(path=path, sr=self.sr) 168 | # shape: (channels, audio_samples) 169 | 170 | audio = librosa.util.fix_length(data=audio, size=self.audio_samples, axis=-1) 171 | # shape: (channels, audio_samples) 172 | 173 | if self.transform is not None: 174 | audio = self.transform(audio) 175 | 176 | return audio 177 | 178 | def load_target(self, label: str) -> np.ndarray: 179 | 180 | classes_num = GTZAN.classes_num 181 | lb_to_ix = GTZAN.lb_to_ix 182 | 183 | target = np.zeros(classes_num, dtype="float32") 184 | class_ix = lb_to_ix[label] 185 | target[class_ix] = 1 186 | # target: (classes_num,) 187 | 188 | target_data = { 189 | "target": target, 190 | "label": label 191 | } 192 | 193 | if self.target_transform: 194 | target_data = self.target_transform(target_data) 195 | 196 | return target_data 197 | 198 | 199 | if 
__name__ == "__main__": 200 | 201 | # Example 202 | from torch.utils.data import DataLoader 203 | 204 | root = "/datasets/gtzan" 205 | 206 | sr = 16000 207 | 208 | dataset = GTZAN( 209 | root=root, 210 | split="train", 211 | test_fold=0, 212 | sr=sr 213 | ) 214 | 215 | dataloader = DataLoader(dataset=dataset, batch_size=4) 216 | 217 | for data in dataloader: 218 | 219 | n = 0 220 | audio_path = data["audio_path"][n] 221 | audio = data["audio"][n].cpu().numpy() 222 | target = data["target"][n].cpu().numpy() 223 | label = data["label"][n] 224 | break 225 | 226 | # ------ Visualize ------ 227 | print("audio_path:", audio_path) 228 | print("audio:", audio.shape) 229 | print("target:", target.shape) 230 | print("label:", label) -------------------------------------------------------------------------------- /models/cnn14_transfer.py: -------------------------------------------------------------------------------- 1 | from torchlibrosa.stft import Spectrogram, LogmelFilterBank 2 | from torchlibrosa.augmentation import SpecAugmentation 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | # from pytorch_utils import do_mixup, interpolate, pad_framewise_output 8 | 9 | 10 | def init_layer(layer): 11 | """Initialize a Linear or Convolutional layer. """ 12 | nn.init.xavier_uniform_(layer.weight) 13 | 14 | if hasattr(layer, 'bias'): 15 | if layer.bias is not None: 16 | layer.bias.data.fill_(0.) 17 | 18 | 19 | def init_bn(bn): 20 | """Initialize a Batchnorm layer. """ 21 | bn.bias.data.fill_(0.) 22 | bn.weight.data.fill_(1.) 23 | 24 | 25 | class ConvBlock(nn.Module): 26 | def __init__(self, in_channels, out_channels): 27 | 28 | super(ConvBlock, self).__init__() 29 | 30 | self.conv1 = nn.Conv2d(in_channels=in_channels, 31 | out_channels=out_channels, 32 | kernel_size=(3, 3), stride=(1, 1), 33 | padding=(1, 1), bias=False) 34 | 35 | self.conv2 = nn.Conv2d(in_channels=out_channels, 36 | out_channels=out_channels, 37 | kernel_size=(3, 3), stride=(1, 1), 38 | padding=(1, 1), bias=False) 39 | 40 | self.bn1 = nn.BatchNorm2d(out_channels) 41 | self.bn2 = nn.BatchNorm2d(out_channels) 42 | 43 | self.init_weight() 44 | 45 | def init_weight(self): 46 | init_layer(self.conv1) 47 | init_layer(self.conv2) 48 | init_bn(self.bn1) 49 | init_bn(self.bn2) 50 | 51 | 52 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 53 | 54 | x = input 55 | x = F.relu_(self.bn1(self.conv1(x))) 56 | x = F.relu_(self.bn2(self.conv2(x))) 57 | if pool_type == 'max': 58 | x = F.max_pool2d(x, kernel_size=pool_size) 59 | elif pool_type == 'avg': 60 | x = F.avg_pool2d(x, kernel_size=pool_size) 61 | elif pool_type == 'avg+max': 62 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 63 | x2 = F.max_pool2d(x, kernel_size=pool_size) 64 | x = x1 + x2 65 | else: 66 | raise Exception('Incorrect argument!') 67 | 68 | return x 69 | 70 | 71 | class Cnn14(nn.Module): 72 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 73 | fmax, classes_num): 74 | 75 | super(Cnn14, self).__init__() 76 | 77 | window = 'hann' 78 | center = True 79 | pad_mode = 'reflect' 80 | ref = 1.0 81 | amin = 1e-10 82 | top_db = None 83 | 84 | # Spectrogram extractor 85 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 86 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 87 | freeze_parameters=True) 88 | 89 | # Logmel feature extractor 90 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 91 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, 
amin=amin, top_db=top_db, 92 | freeze_parameters=True) 93 | 94 | # Spec augmenter 95 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 96 | freq_drop_width=8, freq_stripes_num=2) 97 | 98 | self.bn0 = nn.BatchNorm2d(64) 99 | 100 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 101 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 102 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 103 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 104 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 105 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 106 | 107 | self.fc1 = nn.Linear(2048, 2048, bias=True) 108 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 109 | 110 | self.init_weight() 111 | 112 | def init_weight(self): 113 | init_bn(self.bn0) 114 | init_layer(self.fc1) 115 | init_layer(self.fc_audioset) 116 | 117 | def forward(self, input, mixup_lambda=None): 118 | """ 119 | Input: (batch_size, data_length)""" 120 | 121 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 122 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 123 | 124 | x = x.transpose(1, 3) 125 | x = self.bn0(x) 126 | x = x.transpose(1, 3) 127 | 128 | if self.training: 129 | x = self.spec_augmenter(x) 130 | 131 | # Mixup on spectrogram 132 | # if self.training and mixup_lambda is not None: 133 | # x = do_mixup(x, mixup_lambda) 134 | 135 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 136 | x = F.dropout(x, p=0.2, training=self.training) 137 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 138 | x = F.dropout(x, p=0.2, training=self.training) 139 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 140 | x = F.dropout(x, p=0.2, training=self.training) 141 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 142 | x = F.dropout(x, p=0.2, training=self.training) 143 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 144 | x = F.dropout(x, p=0.2, training=self.training) 145 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 146 | x = F.dropout(x, p=0.2, training=self.training) 147 | x = torch.mean(x, dim=3) 148 | 149 | (x1, _) = torch.max(x, dim=2) 150 | x2 = torch.mean(x, dim=2) 151 | x = x1 + x2 152 | x = F.dropout(x, p=0.5, training=self.training) 153 | x = F.relu_(self.fc1(x)) 154 | embedding = F.dropout(x, p=0.5, training=self.training) 155 | clipwise_output = torch.sigmoid(self.fc_audioset(x)) 156 | 157 | output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding} 158 | 159 | return output_dict 160 | 161 | 162 | class Transfer_Cnn14(nn.Module): 163 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, 164 | fmax, classes_num, freeze_base): 165 | """Classifier for a new task using pretrained Cnn14 as a sub module. 
166 | """ 167 | super(Transfer_Cnn14, self).__init__() 168 | audioset_classes_num = 527 169 | 170 | self.base = Cnn14(sample_rate, window_size, hop_size, mel_bins, fmin, 171 | fmax, audioset_classes_num) 172 | 173 | # Transfer to another task layer 174 | self.fc_transfer = nn.Linear(2048, classes_num, bias=True) 175 | 176 | if freeze_base: 177 | # Freeze AudioSet pretrained layers 178 | for param in self.base.parameters(): 179 | param.requires_grad = False 180 | 181 | self.init_weights() 182 | 183 | def init_weights(self): 184 | init_layer(self.fc_transfer) 185 | 186 | def load_from_pretrain(self, pretrained_checkpoint_path): 187 | checkpoint = torch.load(pretrained_checkpoint_path) 188 | self.base.load_state_dict(checkpoint['model']) 189 | 190 | def forward(self, input, mixup_lambda=None): 191 | """Input: (batch_size, data_length) 192 | """ 193 | output_dict = self.base(input, mixup_lambda) 194 | embedding = output_dict['embedding'] 195 | 196 | clipwise_output = torch.sigmoid(self.fc_transfer(embedding)) 197 | output_dict['clipwise_output'] = clipwise_output 198 | 199 | return output_dict --------------------------------------------------------------------------------