├── .gitignore
├── README.md
├── dataset.py
├── models.py
├── screenshots
│   ├── 0.png
│   ├── 1.png
│   ├── 2.png
│   ├── ast.png
│   └── comparision.png
├── train.py
├── train_config.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
runs
archive
pretrained_models
__pycache__
.idea
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
**AUDIO SPECTROGRAM TRANSFORMER**

[Source](https://arxiv.org/abs/2104.01778)

****

PyTorch implementation of a ***[ViT](https://arxiv.org/abs/2010.11929)***-based classifier that achieves **97%** accuracy on the [FSDD](https://github.com/Jakobovski/free-spoken-digit-dataset.git) (Free Spoken Digit) audio dataset.

![Comparison](screenshots/comparision.png)

| ![#12b5cb](screenshots/0.png) ViT Audio Classifier (acc 97%) |
| ------------------------------------------------------------ |
| ![#e52592](screenshots/1.png) ResNet Audio Classifier (93%)   |
| ![#425066](screenshots/2.png) ResNet with PolyLoss (93%)      |

Check the other branches for the **ResNet** and **ResNet + [PolyLoss](https://arxiv.org/abs/2204.12511?context=cs)** comparison code.
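
**Usage** (a minimal sketch; exact commands and package versions are not pinned by the repo, so treat them as assumptions):

1. Install the dependencies the scripts import: `torch`, `torchaudio`, `tqdm` and `tensorboard`.
2. Download [FSDD](https://github.com/Jakobovski/free-spoken-digit-dataset.git) and unpack it so the recordings match `WAV_FILES_PATH` in `train_config.py` (`archive/free-spoken-digit-dataset-master/recordings` by default).
3. Run `python train.py`. Checkpoints are written to `pretrained_models/` and TensorBoard logs to `runs/`.

Checkpoints are saved as `state_dict`s named `best_<test loss>.pth` / `epoch_<n>.pth`, so a trained model can be reloaded roughly like this (the file name below is a hypothetical example):

```python
import torch

from models import VIT
from train_config import IMAGE_SIZE

model = VIT(IMAGE_SIZE)
# hypothetical checkpoint name; pick an actual file from pretrained_models/
state_dict = torch.load("pretrained_models/best_0.08.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```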
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
import random

import torch
import torchaudio
from utils import *
from torch.utils.data import Dataset, DataLoader
from train_config import *


class TIDataset(Dataset):
    def __init__(self, mode="train"):
        self.file_names = os.listdir(WAV_FILES_PATH)
        random.shuffle(self.file_names)
        self.file_names = self.file_names[
            :int(len(self.file_names) * DATASET_PERCENTAGE)]
        if mode == "train":
            self.file_names = self.file_names[
                :int(len(self.file_names) * TRAIN_TEST_SPLIT)]
        else:
            self.file_names = self.file_names[
                int(len(self.file_names) * TRAIN_TEST_SPLIT):]
        # FSDD file names look like "<digit>_<speaker>_<index>.wav",
        # so the leading digit is the label
        self.labels = [int(i.split("_")[0]) for i in self.file_names]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        waveform, sr = torchaudio.load(
            os.path.join(WAV_FILES_PATH, self.file_names[idx]))
        window_len = int(sr * WINDOW_SIZE)  # 200 samples at 8 kHz
        stride_len = int(sr * STRIDE)  # 80 samples at 8 kHz
        # number of time frames = signal_size / stride (hop length)
        desired_size = 64 * stride_len
        waveform = cut_if_necessary(waveform, desired_size)
        waveform = pad_if_necessary(waveform, desired_size)
        n_fft = window_len * 2 - 1
        spectogram = torchaudio.transforms.Spectrogram(
            n_fft=n_fft, hop_length=stride_len)(waveform)
        # no_of_time_frames (x axis) = signal_size / stride = 64
        # no_of_freq_bins (y axis) = n_fft // 2 + 1 = window_len
        # spectrogram shape:     [1, window_len, signal_size / stride]   # [1, y, x]
        # mel spectrogram shape: [1, mel_filters, signal_size / stride]  # [1, y, x]
        mel_spectogram = torchaudio.transforms.MelScale(
            n_mels=32, n_stft=n_fft // 2 + 1)(spectogram)

        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return mel_spectogram, label


if __name__ == '__main__':
    training_data = TIDataset("eval")
    training_data_loader = DataLoader(training_data,
                                      batch_size=2, shuffle=True)
    img, label = next(iter(training_data_loader))
    print(len(training_data_loader), label.shape,
          img.shape)
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from train_config import *


class Attention(nn.Module):
    def __init__(self, dim, heads=8):
        super(Attention, self).__init__()
        self.dim = dim
        self.heads = heads
        self.dim_head = self.dim // self.heads
        self.query_linear = nn.Linear(self.dim, self.dim_head * self.heads)
        self.value_linear = nn.Linear(self.dim, self.dim_head * self.heads)
        self.key_linear = nn.Linear(self.dim, self.dim_head * self.heads)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        keys = self.transform_qkv(self.key_linear(x))
        values = self.transform_qkv(self.value_linear(x))
        queries = self.transform_qkv(self.query_linear(x))
        # scaled dot-product attention (note: scaled by sqrt(heads) here,
        # not the more common sqrt(dim_head))
        attention = self.softmax(
            torch.matmul(queries, keys.permute(0, 2, 1)) / (self.heads ** .5))
        out = self.opp_transform_qkv(torch.matmul(attention, values))
        return out

    def transform_qkv(self, x):
        # (b, n, dim) -> (b * heads, n, dim_head)
        b, n, _ = x.shape
        x = x.reshape(b, n, self.heads, self.dim_head)
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(-1, x.shape[2], x.shape[3])
        return x

    def opp_transform_qkv(self, x):
        # (b * heads, n, dim_head) -> (b, n, dim)
        x = x.reshape(x.shape[0] // self.heads, self.heads,
                      x.shape[1], x.shape[2])
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        return x
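
# A quick shape walk-through of the attention block above, an illustrative
# sketch using the values from the __main__ check at the bottom of this file
# (dim=768, heads=8, hence dim_head=96):
#   x                     : (b, 33, 768)
#   query/key/value proj  : (b, 33, 768)      # dim_head * heads = 96 * 8
#   transform_qkv         : (b * 8, 33, 96)   # heads folded into the batch dim
#   attention weights     : (b * 8, 33, 33)
#   attention @ values    : (b * 8, 33, 96)
#   opp_transform_qkv     : (b, 33, 768)      # heads concatenated back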


class Transformer(nn.Module):
    def __init__(self, dim, hidden_dim=2048, depth=3, dropout=0.1):
        super(Transformer, self).__init__()
        self.attention = Attention(dim)
        self.feed_forward = nn.Sequential(nn.Linear(dim, hidden_dim),
                                          nn.GELU(),
                                          nn.Dropout(dropout),
                                          nn.Linear(hidden_dim, dim),
                                          nn.Dropout(dropout)
                                          )
        self.layer_norm = nn.LayerNorm(dim)
        self.depth = depth

    def forward(self, x):
        # pre-norm blocks with residual connections; the same attention and
        # feed-forward weights are reused at every one of the `depth` steps
        for _ in range(self.depth):
            x = x + self.attention(self.layer_norm(x))
            x = x + self.feed_forward(self.layer_norm(x))
        return x


class VIT(nn.Module):
    def __init__(self, image_shape, patch_size=8, embedding_size=128):
        # image_shape = (c, h, w)
        c, h, w = image_shape
        super(VIT, self).__init__()
        self.patch_size = patch_size
        self.p1 = h // self.patch_size
        self.p2 = w // self.patch_size
        self.no_of_patches = self.p1 * self.p2
        # c (h p1) (w p2) -> (h w) (p1 p2 c)
        self.patch_dim = c * self.patch_size * self.patch_size
        self.new_shape = (self.no_of_patches, self.patch_dim)
        self.linear_embedding = nn.Linear(self.patch_dim, embedding_size)
        self.class_token = nn.Parameter(torch.rand(1, 1, embedding_size))
        self.position_encoding = nn.Parameter(
            torch.rand(1, self.no_of_patches + 1, embedding_size))

        self.transformer = Transformer(dim=embedding_size)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embedding_size),
            nn.Linear(embedding_size, CLASSES),
        )

    def forward(self, x):
        batch_size = x.shape[0]
        # x is a mel spectrogram of shape (batch, 1, freq_bins, time_frames)
        patches = x.reshape(batch_size, *self.new_shape)
        # patches shape is (batch, no_of_patches, c * patch_size ** 2)
        patches_embedding = self.linear_embedding(patches)
        # prepend the class token; the mlp_head reads it back out via out[:, 0]
        patches_embedding = torch.cat(
            (self.class_token.repeat(batch_size, 1, 1), patches_embedding),
            dim=1)
        # add position embedding
        patches_embedding += self.position_encoding
        # pass to transformer
        out = self.transformer(patches_embedding)
        out = self.mlp_head(out[:, 0])
        return out
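
# Concrete sizes for the default setup (IMAGE_SIZE = (1, 32, 64), patch_size=8,
# embedding_size=128, CLASSES=10); an illustrative sketch of one forward pass:
#   input mel spectrogram : (b, 1, 32, 64)
#   patches               : (b, 32, 64)    # (32 // 8) * (64 // 8) = 32 patches of dim 64
#   linear embedding      : (b, 32, 128)
#   with class token      : (b, 33, 128)
#   transformer output    : (b, 33, 128)
#   mlp_head(out[:, 0])   : (b, 10)
# Note that the plain reshape in forward() treats each frequency row of the
# spectrogram as one "patch" rather than cutting 8x8 tiles; the resulting
# sequence shape is the same either way.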


if __name__ == '__main__':
    batch_size = 2
    image_shape = (1, 32, 64)
    seq_shape = (33, 768)

    attention = Attention(768)
    print(attention(torch.rand(batch_size, *seq_shape)).shape)

    res = VIT(image_shape)
    print(res(torch.rand(batch_size, *image_shape)).shape)
--------------------------------------------------------------------------------
/screenshots/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/syedecryptr/audio-spectogram-transformer/f15bede33ce1fc5ea75893c523ce6671f3561424/screenshots/0.png
--------------------------------------------------------------------------------
/screenshots/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/syedecryptr/audio-spectogram-transformer/f15bede33ce1fc5ea75893c523ce6671f3561424/screenshots/1.png
--------------------------------------------------------------------------------
/screenshots/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/syedecryptr/audio-spectogram-transformer/f15bede33ce1fc5ea75893c523ce6671f3561424/screenshots/2.png
--------------------------------------------------------------------------------
/screenshots/ast.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/syedecryptr/audio-spectogram-transformer/f15bede33ce1fc5ea75893c523ce6671f3561424/screenshots/ast.png
--------------------------------------------------------------------------------
/screenshots/comparision.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/syedecryptr/audio-spectogram-transformer/f15bede33ce1fc5ea75893c523ce6671f3561424/screenshots/comparision.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset import TIDataset
from models import VIT
from train_config import *

writer = SummaryWriter(os.path.join(RUNS_PATH, EXPERIMENT_NAME))

if __name__ == '__main__':
    training_data = TIDataset()
    test_data = TIDataset("eval")
    train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE,
                                  shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE,
                                 shuffle=True)
    print(f"Starting training with {len(training_data)} "
          f"training samples and {len(test_data)} "
          f"test samples")
    audionet = VIT(IMAGE_SIZE).to(DEVICE)
    start_epoch = 1
    if RESUME_TRAIN:
        # checkpoints are saved as state_dicts below, so load them the same way
        audionet.load_state_dict(torch.load(RESUME_TRAIN_PATH))
    # TODO add scheduler
    optimizer = torch.optim.Adam(audionet.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(start_epoch, EPOCHS):
        print(f"\n---------------------- Epoch {epoch} ----------------------")
        # validation
        audionet.eval()
        total_corrects = 0
        total_labels = 0

        with torch.no_grad():
            avg_loss = 0
            for audios, labels in tqdm(test_dataloader):
                audios, labels = audios.to(DEVICE), labels.to(DEVICE)
                y_pred = audionet(audios)
                loss = criterion(y_pred, labels)
                avg_loss += loss.item()
                total_corrects += torch.sum(
                    (torch.argmax(y_pred, dim=1) == labels).to(torch.long))
                total_labels += labels.shape[0]

        test_loss = avg_loss / len(test_dataloader)
        writer.add_scalar('Loss/test', test_loss, epoch)
        print("accuracy", total_corrects / total_labels * 100)
        writer.add_scalar("Accuracy/test",
                          total_corrects / total_labels * 100, epoch)

        # training
        audionet.train()
        avg_loss = 0
        # anomaly detection is left on for easier debugging (it slows training)
        with torch.autograd.set_detect_anomaly(True):
            for audios, labels in tqdm(train_dataloader):
                audios, labels = audios.to(DEVICE), labels.to(DEVICE)
                optimizer.zero_grad()
                y_pred = audionet(audios)
                loss = criterion(y_pred, labels)
                avg_loss += loss.item()
                loss.backward()
                optimizer.step()
        train_loss = avg_loss / len(train_dataloader)
        writer.add_scalar('Loss/train', train_loss, epoch)
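
        # Note on the checkpoint bookkeeping below (a summary of the existing
        # logic, not new behaviour): every EPOCH_AFTER_SAVE_MODEL epochs an
        # "epoch_<n>.pth" state_dict is written, a single "best_<test loss>.pth"
        # is kept for the lowest test loss seen so far, and once more than
        # MODELS_TO_KEEP files have accumulated the oldest "epoch_<n>.pth" is
        # deleted (the best_* file sorts first by its numeric suffix, so it is
        # never the one removed).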
        if epoch % EPOCH_AFTER_SAVE_MODEL == 0:
            if not os.path.exists(MODEL_PATH):
                os.makedirs(MODEL_PATH)
            # keep a single best checkpoint, named after its test loss
            models = os.listdir(MODEL_PATH)
            models_suffixes = [i.split(".pth")[0].split("_")[0] for i in models]
            if "best" in models_suffixes:
                best_model = models[models_suffixes.index("best")]
                prev_loss = float(best_model.split(".pth")[0].split("_")[-1])
                if prev_loss > test_loss:
                    os.remove(os.path.join(MODEL_PATH, best_model))
                    torch.save(audionet.state_dict(),
                               os.path.join(MODEL_PATH,
                                            f"best_{test_loss}.pth"))
            else:
                torch.save(audionet.state_dict(),
                           os.path.join(MODEL_PATH, f"best_{test_loss}.pth"))
            # keep the number of saved checkpoints around MODELS_TO_KEEP
            models_available = sorted(os.listdir(MODEL_PATH),
                                      key=lambda x: float(
                                          x.split(".pth")[0].split("_")[-1]))
            if len(models_available) > MODELS_TO_KEEP:
                # element 0 is the "best_*.pth" checkpoint, so keep it
                os.remove(os.path.join(MODEL_PATH, models_available[1]))
            torch.save(audionet.state_dict(),
                       os.path.join(MODEL_PATH, f"epoch_{epoch}.pth"))
--------------------------------------------------------------------------------
/train_config.py:
--------------------------------------------------------------------------------
WAV_FILES_PATH = "archive/free-spoken-digit-dataset-master/recordings"
FILTER_BANKS = 32
WINDOW_SIZE = 25 / 1000  # in seconds
STRIDE = 10 / 1000  # in seconds
IMAGE_SIZE = (1, 32, 64)  # shape of the mel spectrogram (see dataset.py for details)
CLASSES = 10

RUNS_PATH = "./runs"
EXPERIMENT_NAME = "vit_based_model"

DATASET_PERCENTAGE = 1
TRAIN_TEST_SPLIT = 0.9

EPOCHS = 1000
BATCH_SIZE = 64
EPOCH_AFTER_SAVE_MODEL = 4
MODELS_TO_KEEP = 4
DEVICE = 0  # CUDA device index
LEARNING_RATE = 0.01

MODEL_PATH = "./pretrained_models"

RESUME_TRAIN = False
RESUME_TRAIN_PATH = ""
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from train_config import *
import torch.nn.functional as F


def pad_if_necessary(waveform, desired_size):
    # right-pad the waveform with zeros up to desired_size samples
    if waveform.shape[1] < desired_size:
        missing_samples = desired_size - waveform.shape[1]
        waveform = F.pad(waveform, (0, missing_samples))
    return waveform


def cut_if_necessary(waveform, desired_size):
    # truncate the waveform to the first desired_size samples
    if waveform.shape[1] > desired_size:
        waveform = waveform[:, :desired_size]
    return waveform
--------------------------------------------------------------------------------