├── .gitignore ├── LICENSE ├── README.md ├── config_parser.py ├── docs └── config_file_explained.md ├── download_gspeech_v2.sh ├── inference.py ├── label_map.json ├── make_data_list.py ├── models ├── __init__.py └── kwmlp.py ├── notebooks ├── README.md ├── keyword_mlp_tutorial.ipynb └── mlp-mixer-audio.ipynb ├── requirements.txt ├── resources ├── kw-mlp.png └── wandb.png ├── sample_configs └── base_config.yaml ├── train.py ├── utils ├── __init__.py ├── augment.py ├── dataset.py ├── loss.py ├── misc.py ├── opt.py ├── scheduler.py └── trainer.py └── window_inference.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .vscode/ 3 | env/ 4 | configs/ 5 | data/ 6 | runs/ 7 | wandb/ 8 | notes.txt 9 | tests -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Mashrur Mahmud Morshed and Ahmad Omar Ahsan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Keyword-MLP 2 | 3 | Official PyTorch implementation of [*Attention-Free Keyword Spotting*](https://arxiv.org/abs/2110.07749v1). 4 | 5 | Keyword-MLP Architecture 6 | 7 | Open in Colab 8 | 9 | ## Setup 10 | 11 | ``` 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## Dataset 16 | To download the Google Speech Commands V2 dataset, you may run the provided bash script as below. This would download and extract the dataset to the "destination path" provided. 17 | 18 | ``` 19 | sh ./download_gspeech_v2.sh 20 | ``` 21 | 22 | ## Training 23 | 24 | The Speech Commands V2 dataset provides two files: `validation_list.txt` and `testing_list.txt`. Run: 25 | 26 | ``` 27 | python make_data_list.py -v -t -d -o 28 | ``` 29 | 30 | This will create the files `training_list.txt`, `validation_list.txt`, `testing_list.txt` and `label_map.json` at the specified output directory. 31 | 32 | Running `train.py` is fairly straightforward. Only a path to a config file is required. Inside the config file, you'll need to add the paths to the .txt files and the label_map.json file created above. 
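For reference, the dataset path entries at the top of the config (see [sample_configs/base_config.yaml](sample_configs/base_config.yaml)) look like the following, assuming the lists were generated into `./data/`:

```
data_root: ./data/
train_list_file: ./data/training_list.txt
val_list_file: ./data/validation_list.txt
test_list_file: ./data/testing_list.txt
label_map: ./data/label_map.json
```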
33 | 34 | ``` 35 | python train.py --conf path/to/config.yaml 36 | ``` 37 | 38 | Refer to the [example config](sample_configs/base_config.yaml) to see how the config file looks like, and see the [config explanation](docs/config_file_explained.md) for a complete rundown of the various config parameters. You may also take a look at the [colab tutorial](#tutorials) for a live example. 39 | 40 | 41 | ## Inference 42 | 43 | You can use the pre-trained model (or a model you trained) for inference, using the two scripts: 44 | 45 | - `inference.py`: For short ~1s clips, like the audios in the Speech Commands dataset 46 | - `window_inference.py`: For running inference on longer audio clips, where multiple keywords may be present. Runs inference on the audio in a sliding window manner. 47 | 48 | ``` 49 | python inference.py --conf sample_configs/base_config.yaml \ 50 | --ckpt \ 51 | --inp \ 52 | --out \ 53 | --lmap label_map.json \ 54 | --device cpu \ 55 | --batch_size 8 # should be possible to use much larger batches if necessary, like 128, 256, 512 etc. 56 | 57 | !python window_inference.py --conf sample_configs/base_config.yaml \ 58 | --ckpt \ 59 | --inp \ 60 | --out \ 61 | --lmap label_map.json \ 62 | --device cpu \ 63 | --wlen 1 \ 64 | --stride 0.5 \ 65 | --thresh 0.85 \ 66 | --mode multi 67 | ``` 68 | 69 | For a detailed usage example, check the [colab tutorial](#tutorials). 70 | 71 | ## Tutorials 72 | - [Tutorial: [Using pretrained model | Inference scripts | Training]](notebooks/keyword_mlp_tutorial.ipynb) 73 | - Open in Colab 74 | 75 | ## Weights & Biases 76 | 77 | You can optionally log your training runs with [wandb](https://wandb.ai/site). You may provide a path to a file containing your API key, or use the `WANDB_API_KEY` env variable, or simply provide it manually from a login prompt when you start your training. 78 | 79 | W&B Dashboard 80 | 81 | ## Pretrained Checkpoints 82 | 83 | | Model Name | # Params | GFLOPS | Accuracy (V2-35) | Link | 84 | | ---------- | -------- | ------ | ---------------- | ---- | 85 | | KW-MLP | 424K | 0.045 | 97.56 | [kw-mlp (1.7MB)](https://drive.google.com/uc?id=1lywXTaJjPud41f3G_NmuRHzhDY8uNbWe&export=download) | 86 | 87 | ## Citation 88 | 89 | ```bibtex 90 | @misc{morshed2021attentionfree, 91 | title = {Attention-Free Keyword Spotting}, 92 | author = {Mashrur M. Morshed and Ahmad Omar Ahsan}, 93 | year = {2021}, 94 | eprint = {2110.07749}, 95 | archivePrefix = {arXiv}, 96 | primaryClass = {cs.LG} 97 | } 98 | ``` -------------------------------------------------------------------------------- /config_parser.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | import torch 4 | import sys 5 | 6 | 7 | def get_config(config_file: str) -> dict: 8 | """Reads settings from config file. 9 | 10 | Args: 11 | config_file (str): YAML config file. 12 | 13 | Returns: 14 | dict: Dict containing settings. 15 | """ 16 | 17 | with open(config_file, "r") as f: 18 | base_config = yaml.load(f, Loader=yaml.FullLoader) 19 | 20 | if base_config["exp"]["wandb"]: 21 | if base_config["exp"]["wandb_api_key"] is not None: 22 | assert os.path.exists(base_config["exp"]["wandb_api_key"]), f"API key file not found at specified location {base_config['exp']['wandb_api_key']}." 
23 | 24 | if base_config["exp"]["device"] == "auto": 25 | base_config["exp"]["device"] = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | base_config["hparams"]["device"] = base_config["exp"]["device"] 27 | 28 | return base_config 29 | 30 | 31 | if __name__ == "__main__": 32 | config = get_config(sys.argv[1]) 33 | print("Using settings:\n", yaml.dump(config)) -------------------------------------------------------------------------------- /docs/config_file_explained.md: -------------------------------------------------------------------------------- 1 | # Understanding the Config File 2 | 3 | The config file contains all the hyperparameters and various settings regarding your training runs. I'll break down the numerous settings here as clearly as possible. 4 | 5 | ## Dataset Paths 6 | 7 | ``` 8 | data_root: ./data/ # Where you have extracted the google speech commands v2 dataset. 9 | train_list_file: ./data/training_list.txt # Contains paths to your training .wav files. 10 | val_list_file: ./data/validation_list.txt # Contains paths to your validation .wav files. 11 | test_list_file: ./data/testing_list.txt # Contains paths to your test .wav files. 12 | label_map: ./data/label_map.json # A json file containing {id: label} key value pairs. 13 | # The above four files can be generated by make_data_list.py. 14 | ``` 15 | 16 | ## Experiment Settings 17 | 18 | ``` 19 | exp: 20 | wandb: False # Whether to use wandb or not 21 | wandb_api_key: # Path to your key. Ignored if wandb is False. If blank, looks for key in the ${WANDB_API_KEY} env variable. 22 | proj_name: torch-kw-mlp # Name of your wandb project. Ignored if wandb is False. 23 | exp_dir: ./runs # Your checkpoints will be saved locally at exp_dir/exp_name 24 | exp_name: kw-mlp-0.1.0 # ..for example, ./runs/kw-mlp-0.1.0/something.pth 25 | device: auto # "auto" checks whether cuda is available; if not, uses cpu. You can also put in "cpu" or "cuda" as device. 26 | # only single device training is supported currently. 27 | log_freq: 20 # Saves logs every log_freq steps 28 | log_to_file: True # Saves logs to exp_dir/exp_name/training_logs.txt 29 | log_to_stdout: True # Prints logs to stdout 30 | val_freq: 1 # Validate every val_freq epochs 31 | n_workers: 2 # Number of workers for dataloader --- best to set to number of CPUs on machine 32 | pin_memory: True # Pin memory argument for dataloader 33 | cache: 2 # 0 -> no cache | 1 -> cache wav arrays | 2 -> cache MFCCs (and also prevents wav augmentations like time_shift, 34 | # resampling and add_background_noise) 35 | ``` 36 | 37 | ## Hyperparameters 38 | ``` 39 | hparams: # everything nested under hparams are hyperparamters, and will be logged as wandb hparams as well. 40 | ... 41 | ... 42 | ``` 43 | ### Basic settings 44 | ``` 45 | hparams: 46 | restore_ckpt: # Path to ckpt, if resuming an interrupted training run. Ckpt must have optimizer state as well. 47 | seed: 0 # Random seed for determinism 48 | batch_size: 256 # Batch size 49 | start_epoch: 0 # Start epoch, 0 by default. 50 | n_epochs: 140 # How many epochs will be trained. (1 epoch = (len(dataset) / batch_size) steps) 51 | l_smooth: 0.1 # If a positive float, uses LabelSmoothingLoss instead of the vanilla CrossEntropyLoss 52 | ``` 53 | 54 | ### Audio Processing 55 | ``` 56 | hparams: 57 | ... 58 | ... 
59 | audio: 60 | sr: 16000 # sampling rate 61 | n_mels: 40 # number of mel bands for melspectrogram (and MFCC) 62 | n_fft: 480 # n_fft, window length, hop length, center are also all args for calculating the melspectrogram 63 | win_length: 480 # Check the docs here for further explanation: 64 | hop_length: 160 # https://librosa.org/doc/main/generated/librosa.feature.melspectrogram.html#librosa.feature.melspectrogram 65 | center: False # MFCC conversion is currently done on CPU with librosa. May add in a CUDA MFCC conversion later (with nnAudio) 66 | ``` 67 | 68 | ### Model Settings 69 | ``` 70 | hparams: 71 | ... 72 | ... 73 | model: 74 | type: kw-mlp # Selects the KW-MLP architecture 75 | input_res: [40, 98] # Shape of input spectrogram (n_mels x T) 76 | patch_res: [40, 1] # Resolution of patches 77 | num_classes: 35 # Number of classes 78 | channels: 1 # MFCCs are single channel inputs 79 | dim: 64 # Patch embedding dim (d) 80 | depth: 12 # Number of gated MLP blocks (L) 81 | pre_norm: False # Prenorm or Postnorm gated-MLP. PostNorm has been shown to perform better 82 | prob_survival: 0.9 # Each gated MLP block has a 0.1 probability of being dropped, as a regularization scheme 83 | ``` 84 | 85 | ### Optimizer & Scheduling 86 | ``` 87 | hparams: 88 | ... 89 | ... 90 | optimizer: # AdamW with an lr of 0.001 and weight decay of 0.1, as in the paper. 91 | opt_type: adamw # Please modify get_optimizer() in utils/opt.py if you want to add support for more optimizer variants. 92 | opt_kwargs: 93 | lr: 0.001 94 | weight_decay: 0.1 95 | 96 | scheduler: # Warmup scheduling for 10 epochs and cosine annealing, as in the paper. 97 | n_warmup: 10 # Please modify get_scheduler() in utils/scheduler.py if you want to add support for other scheduling techniques. 98 | max_epochs: 140 # Up to which epoch the normal scheduler will be run. 99 | scheduler_type: cosine_annealing 100 | ``` 101 | 102 | ### Augmentation 103 | ``` 104 | hparams: 105 | ... 106 | ... 107 | augment: # Augmentations are applied only during training. In the paper, only spec_aug is used. Resample, time_shift and 108 | # bg_noise are available, like in the Keyword-Transformer paper, but increases training time significantly. 109 | # Make sure to comment out resample, time_shift and bg_noise if the goal is to reproduce the results of KW-MLP. 110 | 111 | # resample: # Randomly resamples between 85% and 115% 112 | # r_min: 0.85 113 | # r_max: 1.15 114 | 115 | # time_shift: # Randomly shifts samples left or right, up to 10% 116 | # s_min: -0.1 117 | # s_max: 0.1 118 | 119 | # bg_noise: # Adds background noise from a folder containing noise files. Make sure folder only contains .wav files 120 | # bg_folder: ./data/_background_noise_/ 121 | 122 | spec_aug: # Spectral augmentation. SpecAug is applied on the CPU currently, with a Numba JIT compiled function. May provide a 123 | # CUDA SpecAug later. 
124 | n_time_masks: 2 125 | time_mask_width: 25 126 | n_freq_masks: 2 127 | freq_mask_width: 7 128 | ``` -------------------------------------------------------------------------------- /download_gspeech_v2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data_dir=$1 4 | curr_dir=$PWD 5 | 6 | mkdir -p $data_dir 7 | 8 | cd $data_dir 9 | wget http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz -O - | tar -xz 10 | 11 | cd $curr_dir -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | """Run inference on short ~1s clips, like the ones in the Speech Commands dataset.""" 2 | 3 | from argparse import ArgumentParser 4 | from config_parser import get_config 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from utils.misc import get_model 8 | from utils.dataset import GoogleSpeechDataset 9 | from tqdm import tqdm 10 | import os 11 | import glob 12 | import json 13 | 14 | 15 | @torch.no_grad() 16 | def get_preds(net, dataloader, device) -> list: 17 | """Performs inference.""" 18 | 19 | net.eval() 20 | preds_list = [] 21 | 22 | for data in tqdm(dataloader): 23 | data = data.to(device) 24 | out = net(data) 25 | preds = out.argmax(1).cpu().numpy().ravel().tolist() 26 | preds_list.extend(preds) 27 | 28 | return preds_list 29 | 30 | 31 | def main(args): 32 | ###################### 33 | # create model 34 | ###################### 35 | config = get_config(args.conf) 36 | model = get_model(config["hparams"]["model"]) 37 | 38 | ###################### 39 | # load weights 40 | ###################### 41 | ckpt = torch.load(args.ckpt, map_location="cpu") 42 | model.load_state_dict(ckpt["model_state_dict"]) 43 | 44 | ###################### 45 | # setup data 46 | ###################### 47 | if os.path.isdir(args.inp): 48 | data_list = glob.glob(os.path.join(args.inp, "*.wav")) 49 | elif os.path.isfile(args.inp): 50 | data_list = [args.inp] 51 | 52 | dataset = GoogleSpeechDataset( 53 | data_list=data_list, 54 | label_map=None, 55 | audio_settings=config["hparams"]["audio"], 56 | aug_settings=None, 57 | cache=0 58 | ) 59 | 60 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False) 61 | 62 | ###################### 63 | # run inference 64 | ###################### 65 | if args.device == "auto": 66 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 67 | else: 68 | device = torch.device(args.device) 69 | 70 | model = model.to(device) 71 | preds = get_preds(model, dataloader, device) 72 | 73 | ###################### 74 | # save predictions 75 | ###################### 76 | if args.lmap: 77 | with open(args.lmap, "r") as f: 78 | label_map = json.load(f) 79 | preds = list(map(lambda a: label_map[str(a)], preds)) 80 | 81 | pred_dict = dict(zip(data_list, preds)) 82 | 83 | os.makedirs(args.out, exist_ok=True) 84 | out_path = os.path.join(args.out, "preds.json") 85 | 86 | with open(out_path, "w+") as f: 87 | json.dump(pred_dict, f) 88 | 89 | print(f"Saved preds to {out_path}") 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = ArgumentParser() 94 | parser.add_argument("--conf", type=str, required=True, help="Path to config file. Will be used only to construct model and process audio.") 95 | parser.add_argument("--ckpt", type=str, required=True, help="Path to checkpoint file.") 96 | parser.add_argument("--inp", type=str, required=True, help="Path to input. 
Can be a path to a .wav file, or a path to a folder containing .wav files.") 97 | parser.add_argument("--out", type=str, default="./", help="Path to output folder. Predictions will be stored in {out}/preds.json.") 98 | parser.add_argument("--lmap", type=str, default=None, help="Path to label_map.json. If not provided, will save predictions as class indices instead of class names.") 99 | parser.add_argument("--device", type=str, default="auto", help="One of auto, cpu, or cuda.") 100 | parser.add_argument("--batch_size", type=int, default=1, help="Batch size for batch inference.") 101 | 102 | args = parser.parse_args() 103 | 104 | assert os.path.exists(args.inp), f"Could not find input {args.inp}" 105 | 106 | main(args) -------------------------------------------------------------------------------- /label_map.json: -------------------------------------------------------------------------------- 1 | {"0": "backward", "1": "bed", "2": "bird", "3": "cat", "4": "dog", "5": "down", "6": "eight", "7": "five", "8": "follow", "9": "forward", "10": "four", "11": "go", "12": "happy", "13": "house", "14": "learn", "15": "left", "16": "marvin", "17": "nine", "18": "no", "19": "off", "20": "on", "21": "one", "22": "right", "23": "seven", "24": "sheila", "25": "six", "26": "stop", "27": "three", "28": "tree", "29": "two", "30": "up", "31": "visual", "32": "wow", "33": "yes", "34": "zero"} -------------------------------------------------------------------------------- /make_data_list.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | from utils.dataset import get_train_val_test_split 4 | import os 5 | 6 | 7 | def main(args): 8 | 9 | train_list, val_list, test_list, label_map = get_train_val_test_split(args.data_root, args.val_list_file, args.test_list_file) 10 | 11 | with open(os.path.join(args.out_dir, "training_list.txt"), "w+") as f: 12 | f.write("\n".join(train_list)) 13 | 14 | with open(os.path.join(args.out_dir, "validation_list.txt"), "w+") as f: 15 | f.write("\n".join(val_list)) 16 | 17 | with open(os.path.join(args.out_dir, "testing_list.txt"), "w+") as f: 18 | f.write("\n".join(test_list)) 19 | 20 | with open(os.path.join(args.out_dir, "label_map.json"), "w+") as f: 21 | json.dump(label_map, f) 22 | 23 | print("Saved data lists and label map.") 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = ArgumentParser() 28 | parser.add_argument("-v", "--val_list_file", type=str, required=True, help="Path to validation_list.txt.") 29 | parser.add_argument("-t", "--test_list_file", type=str, required=True, help="Path to test_list.txt.") 30 | parser.add_argument("-d", "--data_root", type=str, required=True, help="Root directory of speech commands v2 dataset.") 31 | parser.add_argument("-o", "--out_dir", type=str, required=True, help="Output directory for data lists and label map.") 32 | args = parser.parse_args() 33 | 34 | main(args) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Research-BD/Keyword-MLP/a87c8ac9c4d3ba01745d68cc5dcf9cd21dd67b72/models/__init__.py -------------------------------------------------------------------------------- /models/kwmlp.py: -------------------------------------------------------------------------------- 1 | 2 | ############################################################################### 3 | # The code 
for the kw-mlp model is mostly adapted from lucidrains/g-mlp-pytorch 4 | ############################################################################### 5 | # MIT License 6 | # 7 | # Copyright (c) 2021 Phil Wang 8 | # 9 | # Permission is hereby granted, free of charge, to any person obtaining a copy 10 | # of this software and associated documentation files (the "Software"), to deal 11 | # in the Software without restriction, including without limitation the rights 12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | # copies of the Software, and to permit persons to whom the Software is 14 | # furnished to do so, subject to the following conditions: 15 | # 16 | # The above copyright notice and this permission notice shall be included in all 17 | # copies or substantial portions of the Software. 18 | # 19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | # SOFTWARE. 26 | ############################################################################### 27 | 28 | 29 | import torch 30 | import torch.nn.functional as F 31 | from torch import nn, einsum 32 | from einops.layers.torch import Rearrange, Reduce 33 | from random import randrange 34 | 35 | 36 | # helpers 37 | 38 | def dropout_layers(layers, prob_survival): 39 | if prob_survival == 1: 40 | return layers 41 | 42 | num_layers = len(layers) 43 | to_drop = torch.zeros(num_layers).uniform_(0., 1.) > prob_survival 44 | 45 | # make sure at least one layer makes it 46 | if all(to_drop): 47 | rand_index = randrange(num_layers) 48 | to_drop[rand_index] = False 49 | 50 | layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop] 51 | return layers 52 | 53 | 54 | class Residual(nn.Module): 55 | def __init__(self, fn): 56 | super().__init__() 57 | self.fn = fn 58 | 59 | def forward(self, x): 60 | return self.fn(x) + x 61 | 62 | 63 | class PreNorm(nn.Module): 64 | def __init__(self, dim, fn): 65 | super().__init__() 66 | self.fn = fn 67 | self.norm = nn.LayerNorm(dim) 68 | 69 | def forward(self, x, **kwargs): 70 | x = self.norm(x) 71 | return self.fn(x, **kwargs) 72 | 73 | 74 | class PostNorm(nn.Module): 75 | def __init__(self, dim, fn): 76 | super().__init__() 77 | self.norm = nn.LayerNorm(dim) 78 | self.fn = fn 79 | 80 | def forward(self, x, **kwargs): 81 | return self.norm(self.fn(x, **kwargs)) 82 | 83 | 84 | class SpatialGatingUnit(nn.Module): 85 | def __init__(self, dim, dim_seq, act = nn.Identity(), init_eps = 1e-3): 86 | super().__init__() 87 | dim_out = dim // 2 88 | 89 | self.norm = nn.LayerNorm(dim_out) 90 | self.proj = nn.Conv1d(dim_seq, dim_seq, 1) 91 | 92 | self.act = act 93 | 94 | init_eps /= dim_seq 95 | nn.init.uniform_(self.proj.weight, -init_eps, init_eps) 96 | nn.init.constant_(self.proj.bias, 1.) 
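    # The near-zero weight / unit-bias init above means the gate starts out roughly
    # constant 1, so the unit approximately passes "res" through unchanged at the
    # start of training. In forward() below, the input is split in half along the
    # channel dim: one half ("res") is kept as-is, the other half is layer-normed,
    # mixed across the sequence (patch) axis by the kernel-size-1 Conv1d, and used
    # as a multiplicative gate on "res".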
97 | 98 | def forward(self, x): 99 | res, gate = x.chunk(2, dim = -1) 100 | gate = self.norm(gate) 101 | 102 | weight, bias = self.proj.weight, self.proj.bias 103 | gate = F.conv1d(gate, weight, bias) 104 | 105 | return self.act(gate) * res 106 | 107 | 108 | class gMLPBlock(nn.Module): 109 | def __init__( 110 | self, 111 | *, 112 | dim, 113 | dim_ff, 114 | seq_len, 115 | act = nn.Identity() 116 | ): 117 | super().__init__() 118 | self.proj_in = nn.Sequential( 119 | nn.Linear(dim, dim_ff), 120 | nn.GELU() 121 | ) 122 | 123 | self.sgu = SpatialGatingUnit(dim_ff, seq_len, act) 124 | self.proj_out = nn.Linear(dim_ff // 2, dim) 125 | 126 | def forward(self, x): 127 | x = self.proj_in(x) 128 | x = self.sgu(x) 129 | x = self.proj_out(x) 130 | return x 131 | 132 | 133 | class KW_MLP(nn.Module): 134 | """Keyword-MLP.""" 135 | 136 | def __init__( 137 | self, 138 | input_res = [40, 98], 139 | patch_res = [40, 1], 140 | num_classes = 35, 141 | dim = 64, 142 | depth = 12, 143 | ff_mult = 4, 144 | channels = 1, 145 | prob_survival = 0.9, 146 | pre_norm = False, 147 | **kwargs 148 | ): 149 | super().__init__() 150 | image_height, image_width = input_res 151 | patch_height, patch_width = patch_res 152 | assert (image_height % patch_height) == 0 and (image_width % patch_width) == 0, 'image height and width must be divisible by patch size' 153 | num_patches = (image_height // patch_height) * (image_width // patch_width) 154 | 155 | P_Norm = PreNorm if pre_norm else PostNorm 156 | 157 | dim_ff = dim * ff_mult 158 | 159 | self.to_patch_embed = nn.Sequential( 160 | Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_height, p2 = patch_width), 161 | nn.Linear(channels * patch_height * patch_width, dim) 162 | ) 163 | 164 | self.prob_survival = prob_survival 165 | 166 | self.layers = nn.ModuleList( 167 | [Residual(P_Norm(dim, gMLPBlock(dim=dim, dim_ff=dim_ff, seq_len=num_patches))) for i in range(depth)] 168 | ) 169 | 170 | self.to_logits = nn.Sequential( 171 | nn.LayerNorm(dim), 172 | Reduce('b n d -> b d', 'mean'), 173 | nn.Linear(dim, num_classes) 174 | ) 175 | 176 | def forward(self, x): 177 | x = self.to_patch_embed(x) 178 | layers = self.layers if not self.training else dropout_layers(self.layers, self.prob_survival) 179 | x = nn.Sequential(*layers)(x) 180 | return self.to_logits(x) -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | # Notebooks 2 | 3 | - [Keyword-MLP Tutorial](keyword_mlp_tutorial.ipynb): Shows a complete example of training and inference (with pretrained model) with the Keyword-MLP repository. 4 | - [MLP-Mixer-Audio](mlp-mixer-audio.ipynb): Contains a TensorFlow implementation of MLP-Mixer which works with audio, which we mention in our ablation studies. 
-------------------------------------------------------------------------------- /notebooks/mlp-mixer-audio.ipynb: -------------------------------------------------------------------------------- 1 | {"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"pygments_lexer":"ipython3","nbconvert_exporter":"python","version":"3.6.4","file_extension":".py","codemirror_mode":{"name":"ipython","version":3},"name":"python","mimetype":"text/x-python"}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# Import Libraries","metadata":{}},{"cell_type":"code","source":"import os\nimport pathlib\nimport glob\n\nimport librosa.util\nimport librosa\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport seaborn as sns\nimport tensorflow as tf\n\nfrom tensorflow.keras.layers.experimental import preprocessing\nfrom tensorflow.keras import layers\nfrom tensorflow.keras import models\nfrom IPython import display\nimport tensorflow_datasets as tfds\nimport datetime, os\nfrom wandb.keras import WandbCallback\nimport tensorflow_addons as tfa\n\n\n# Set seed for experiment reproducibility\nseed = 42\ntf.random.set_seed(seed)\nnp.random.seed(seed)","metadata":{"execution":{"iopub.status.busy":"2021-10-02T05:36:16.681182Z","iopub.execute_input":"2021-10-02T05:36:16.68157Z","iopub.status.idle":"2021-10-02T05:36:23.591645Z","shell.execute_reply.started":"2021-10-02T05:36:16.68149Z","shell.execute_reply":"2021-10-02T05:36:23.590689Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"import wandb\n\n# Insert your wandb login key to track metrics\n# wandb.login(key='Your login key')\n","metadata":{"execution":{"iopub.status.busy":"2021-10-02T05:36:25.883506Z","iopub.execute_input":"2021-10-02T05:36:25.883884Z","iopub.status.idle":"2021-10-02T05:36:26.856142Z","shell.execute_reply.started":"2021-10-02T05:36:25.883827Z","shell.execute_reply":"2021-10-02T05:36:26.855357Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# 35 words speech command dataset","metadata":{}},{"cell_type":"code","source":"def get_train_val_test_split(root: str, val_file: str, test_file: str):\n \"\"\"Creates train, val, and test split according to provided val and test files.\n Args:\n root (str): Path to base directory of the dataset.\n val_file (str): Path to file containing list of validation data files.\n test_file (str): Path to file containing list of test data files.\n\n Returns:\n train_list (list): List of paths to training data items.\n val_list (list): List of paths to validation data items.\n test_list (list): List of paths to test data items.\n train_label (list): List of train labels\n val_label (list): List of val labels\n test_label (list): List of test labels\n\n \"\"\"\n\n ####################\n # Labels\n ####################\n\n label_list = [label for label in sorted(os.listdir(root)) if\n os.path.isdir(os.path.join(root, label)) and label[0] != \"_\"]\n label_map = {idx: label for idx, label in enumerate(label_list)}\n label_to_idx = {v: int(k) for k, v in label_map.items()}\n\n ###################\n # Split\n ###################\n\n all_files_set = set()\n for label in label_list:\n all_files_set.update(set(glob.glob(os.path.join(root, label, \"*.wav\"))))\n\n with open(val_file, \"r\") as f:\n val_files_set = set(map(lambda a: os.path.join(root, a), f.read().rstrip(\"\\n\").split(\"\\n\")))\n\n with open(test_file, \"r\") as f:\n test_files_set = set(map(lambda a: 
os.path.join(root, a), f.read().rstrip(\"\\n\").split(\"\\n\")))\n\n assert len(val_files_set.intersection(\n test_files_set)) == 0, \"Sanity check: No files should be common between val and test.\"\n\n all_files_set -= val_files_set\n all_files_set -= test_files_set\n\n train_list, val_list, test_list = list(all_files_set), list(val_files_set), list(test_files_set)\n\n print(f\"Number of training samples: {len(train_list)}\")\n print(f\"Number of validation samples: {len(val_list)}\")\n print(f\"Number of test samples: {len(test_list)}\")\n\n train_label = []\n val_label = []\n test_label = []\n\n for path in train_list:\n train_label.append(label_to_idx[path.split('/')[-2]])\n\n for path in val_list:\n val_label.append(label_to_idx[path.split('/')[-2]])\n\n for path in test_list:\n test_label.append(label_to_idx[path.split('/')[-2]])\n\n return train_list, val_list, test_list, train_label, val_label, test_label","metadata":{"execution":{"iopub.status.busy":"2021-10-02T05:36:28.99479Z","iopub.execute_input":"2021-10-02T05:36:28.99511Z","iopub.status.idle":"2021-10-02T05:36:29.009692Z","shell.execute_reply.started":"2021-10-02T05:36:28.995079Z","shell.execute_reply":"2021-10-02T05:36:29.008763Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"def time_shift(wav: np.ndarray, sr: int, s_min: float, s_max: float) -> np.ndarray:\n \"\"\"Time shift augmentation.\n Refer to https://www.kaggle.com/haqishen/augmentation-methods-for-audio#1.-Time-shifting.\n Changed np.r_ to np.hstack for numba support.\n Args:\n wav (np.ndarray): Waveform array of shape (n_samples,).\n sr (int): Sampling rate.\n s_min (float): Minimum fraction of a second by which to shift.\n s_max (float): Maximum fraction of a second by which to shift.\n \n Returns:\n wav_time_shift (np.ndarray): Time-shifted waveform array.\n \"\"\"\n\n start = int(np.random.uniform(sr * s_min, sr * s_max))\n if start >= 0:\n wav_time_shift = np.hstack((wav[start:], np.random.uniform(-0.001, 0.001, start)))\n else:\n wav_time_shift = np.hstack((np.random.uniform(-0.001, 0.001, -start), wav[:start]))\n \n return wav_time_shift\n\n\ndef resample(x: np.ndarray, sr: int, r_min: float, r_max: float) -> np.ndarray:\n \"\"\"Resamples waveform.\n Args:\n x (np.ndarray): Input waveform, array of shape (n_samples, ).\n sr (int): Sampling rate.\n r_min (float): Minimum percentage of resampling.\n r_max (float): Maximum percentage of resampling.\n \"\"\"\n\n sr_new = sr * np.random.uniform(r_min, r_max)\n x = librosa.resample(x, sr, sr_new)\n return x, sr_new\n\n\n\ndef spec_augment(mel_spec: np.ndarray, n_time_masks: int, time_mask_width: int, n_freq_masks: int, freq_mask_width: int):\n \"\"\"Numpy implementation of spectral augmentation.\n Args:\n mel_spec (np.ndarray): Mel spectrogram, array of shape (n_mels, T).\n n_time_masks (int): Number of time bands. 
\n time_mask_width (int): Max width of each time band.\n n_freq_masks (int): Number of frequency bands.\n freq_mask_width (int): Max width of each frequency band.\n Returns:\n mel_spec (np.ndarray): Spectrogram with random time bands and freq bands masked out.\n \"\"\"\n \n offset, begin = 0, 0\n\n for _ in range(n_time_masks):\n offset = np.random.randint(0, time_mask_width)\n begin = np.random.randint(0, mel_spec.shape[1] - offset)\n mel_spec[:, begin: begin + offset] = 0.0\n \n for _ in range(n_freq_masks):\n offset = np.random.randint(0, freq_mask_width)\n begin = np.random.randint(0, mel_spec.shape[0] - offset)\n mel_spec[begin: begin + offset, :] = 0.0\n\n return mel_spec","metadata":{"execution":{"iopub.status.busy":"2021-10-02T05:36:30.979018Z","iopub.execute_input":"2021-10-02T05:36:30.979333Z","iopub.status.idle":"2021-10-02T05:36:30.990689Z","shell.execute_reply.started":"2021-10-02T05:36:30.979303Z","shell.execute_reply":"2021-10-02T05:36:30.989449Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"def sample_generator(data_list: list, label_list: list, augment: bool):\n \"\"\"\n Generator function to create samples\n :param data: Data list\n :param label_list: Label list\n :return:\n \"\"\"\n\n\n def transform(x, augment=True):\n sr = 16000\n x = librosa.util.fix_length(x, sr)\n x = librosa.feature.melspectrogram(y=x, n_fft=480, win_length=480, hop_length=160, center=False)\n x = librosa.feature.mfcc(S=librosa.power_to_db(x), n_mfcc=40)\n if augment:\n x = spec_augment(mel_spec=x, n_time_masks=2, time_mask_width=25, n_freq_masks=2, freq_mask_width=7)\n x = tf.expand_dims(x, axis=-1)\n #x = tf.tile(x, [1, 1, 3])\n return x\n\n for audio_file, label in zip(data_list, label_list):\n \n x = librosa.load(audio_file, sr=16000)[0]\n x = transform(x, augment)\n label = tf.convert_to_tensor(label, dtype=tf.int32)\n label = tf.one_hot(label, depth=35)\n yield x, label\n","metadata":{"execution":{"iopub.status.busy":"2021-10-02T05:36:32.930299Z","iopub.execute_input":"2021-10-02T05:36:32.930661Z","iopub.status.idle":"2021-10-02T05:36:32.940814Z","shell.execute_reply.started":"2021-10-02T05:36:32.93063Z","shell.execute_reply":"2021-10-02T05:36:32.939637Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"# Model training","metadata":{}},{"cell_type":"markdown","source":"## Model architecture","metadata":{}},{"cell_type":"code","source":"import tensorflow as tf\n\n\nclass Patches(tf.keras.layers.Layer):\n \"\"\"\n Extract patches from images\n \"\"\"\n def __init__(self, patch_size_w, patch_size_h):\n super(Patches, self).__init__()\n \n self.w = patch_size_w\n self.h = patch_size_h\n\n def call(self, images):\n \n batch_size = tf.shape(images)[0]\n patches = tf.image.extract_patches(\n images=images,\n sizes=[1, self.w, self.h, 1],\n strides=[1, self.w, self.h, 1],\n rates=[1, 1, 1, 1],\n padding='SAME',\n )\n \n dim = patches.shape[-1]\n\n patches = tf.reshape(patches, (batch_size, -1, dim))\n return patches\n\n\nclass Mixer(tf.keras.layers.Layer):\n def __init__(self, S, C, DS, DC):\n \"\"\"\n Mixer layer for the MLP mixer\n :param S: Number of patches\n :param C: Hidden dimension for projection\n :param DS: tunable token mixing hidden width\n :param DC: tunable channel mixing hidden width\n \"\"\"\n super(Mixer, self).__init__()\n self.layer_norm = tf.keras.layers.LayerNormalization()\n self.S = S\n self.C = C\n self.DS = DS\n self.DC = DC\n \n w_init = tf.random_normal_initializer()\n\n self.W1 = 
tf.Variable(initial_value=w_init(shape=(S, DS), dtype=\"float32\"),trainable=True, name='W1')\n self.W2 = tf.Variable(initial_value=w_init(shape=(DS, S), dtype=\"float32\"),trainable=True, name='W2')\n self.W3 = tf.Variable(initial_value=w_init(shape=(C, DC), dtype=\"float32\"),trainable=True, name='W3')\n self.W4 = tf.Variable(initial_value=w_init(shape=(DC, C), dtype=\"float32\"),trainable=True, name='W4')\n\n def call(self, X):\n \"\"\"\n Call function of mixer layer\n :param X: Input\n :return:\n \"\"\"\n X_T = tf.transpose(self.layer_norm(X), perm=(0, 2, 1))\n \n W1X = tf.matmul(X_T, self.W1)\n \n U = X + tf.transpose(tf.matmul(tf.nn.gelu(W1X), self.W2), perm=(0, 2, 1))\n\n W3U = tf.matmul(self.layer_norm(U), self.W3)\n Y = U + tf.matmul(tf.nn.gelu(W3U), self.W4)\n\n return Y\n\n\nclass MLPMixer(tf.keras.models.Model):\n def __init__(self, patch_size_w, patch_size_h, C, DS, DC, num_of_mixer_blocks, num_classes):\n \"\"\"\n Creates the Mixer model\n :param patch_size: Patch size\n :param S: number of patches\n :param C: dimension of projection layer\n :param DS: tunable token mixing hidden width\n :param DC: tunable channel mixing hidden width\n :param num_of_mixer_blocks: number of mixer layers\n :param num_classes: number of classes\n \"\"\"\n super(MLPMixer, self).__init__()\n self.projection = tf.keras.layers.Dense(C)\n self.S = int((40*98)/(patch_size_w * patch_size_h))\n self.mixer = [Mixer(self.S, C, DS, DC,) for _ in range(num_of_mixer_blocks)]\n self.C = C\n self.DS = DS\n self.DC = DC\n self.num_classes = num_classes\n self.patch_w = patch_size_w\n self.patch_h = patch_size_h\n\n \n self.classification_layer = tf.keras.models.Sequential([\n tf.keras.layers.GlobalAveragePooling1D(),\n tf.keras.layers.Dropout(0.2),\n tf.keras.layers.Dense(self.num_classes, activation='softmax')\n ])\n\n def call(self, images):\n \"\"\"\n Call function for MLPMixer model\n :param images: input image\n :return:\n \"\"\"\n patcher = Patches(self.patch_w, self.patch_h)\n\n X = patcher(images)\n \n X = self.projection(X)\n \n for block in self.mixer:\n X = block(X)\n \n out = self.classification_layer(X)\n return out","metadata":{"execution":{"iopub.status.busy":"2021-10-02T05:36:36.046634Z","iopub.execute_input":"2021-10-02T05:36:36.046966Z","iopub.status.idle":"2021-10-02T05:36:36.076599Z","shell.execute_reply.started":"2021-10-02T05:36:36.046927Z","shell.execute_reply":"2021-10-02T05:36:36.075725Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Learning rate scheduler \n","metadata":{}},{"cell_type":"code","source":"class CosineScheduler(tf.keras.optimizers.schedules.LearningRateSchedule):\n def __init__(self,\n learning_rate_base,\n total_steps,\n warmup_learning_rate=0.0,\n warmup_steps=0):\n self.learning_rate_base = learning_rate_base\n self.total_steps = total_steps\n self.warmup_learning_rate =warmup_learning_rate\n self.warmup_steps = warmup_steps\n \n def __call__(self,step):\n learning_rate = 0.5 * self.learning_rate_base * (1 + tf.cos(\n np.pi * \n (tf.cast(step, tf.float32) - self.warmup_steps)/ float(self.total_steps-self.warmup_steps)))\n if self.warmup_steps > 0:\n slope = (self.learning_rate_base - self.warmup_learning_rate) / self.warmup_steps\n warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate\n learning_rate = tf.where(step < self.warmup_steps, warmup_rate, learning_rate)\n lr = tf.where(step > self.total_steps, 0.0, learning_rate, name='learning_rate')\n wandb.log({\"lr\": lr})\n return 
lr\n\n","metadata":{"execution":{"iopub.status.busy":"2021-10-02T05:36:37.328714Z","iopub.execute_input":"2021-10-02T05:36:37.329039Z","iopub.status.idle":"2021-10-02T05:36:37.337785Z","shell.execute_reply.started":"2021-10-02T05:36:37.329011Z","shell.execute_reply":"2021-10-02T05:36:37.336526Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Model training function for a single epoch","metadata":{}},{"cell_type":"code","source":"# Train the model\n\ndef model_train(features, labels, model, loss_func,optimizer,train_acc,train_loss):\n # Define the GradientTape context\n with tf.GradientTape() as tape:\n # Get the probabilities\n predictions = model(features)\n # Calculate the loss\n loss = loss_func(labels, predictions)\n # Get the gradients\n gradients = tape.gradient(loss, model.trainable_variables)\n # Update the weights\n optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n\n # Update the loss and accuracy\n train_loss(loss)\n train_acc(labels, predictions)\n loss = train_loss.result()\n wandb.log({\"train_loss\": loss.numpy()})","metadata":{"execution":{"iopub.status.busy":"2021-10-02T05:36:38.879316Z","iopub.execute_input":"2021-10-02T05:36:38.879663Z","iopub.status.idle":"2021-10-02T05:36:38.887929Z","shell.execute_reply.started":"2021-10-02T05:36:38.879632Z","shell.execute_reply":"2021-10-02T05:36:38.887005Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Model validation function for a single epoch","metadata":{}},{"cell_type":"code","source":"\ndef model_validate(features, labels,model,loss_func,valid_loss,val_acc):\n predictions = model(features)\n v_loss = loss_func(labels, predictions)\n\n valid_loss(v_loss)\n val_acc(labels, predictions)\n (val_loss, val_acc) = valid_loss.result(), val_acc.result()\n wandb.log({\n \"val_loss\": val_loss.numpy()\n })","metadata":{"execution":{"iopub.status.busy":"2021-10-02T05:36:40.857335Z","iopub.execute_input":"2021-10-02T05:36:40.857729Z","iopub.status.idle":"2021-10-02T05:36:40.863514Z","shell.execute_reply.started":"2021-10-02T05:36:40.857696Z","shell.execute_reply":"2021-10-02T05:36:40.86211Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Training function","metadata":{}},{"cell_type":"code","source":"def train(train_dataset,\n val_dataset,\n model,\n optimizer,\n loss_func,\n train_loss,\n train_acc,\n valid_loss,\n valid_acc,\n epochs\n ):\n max_val_acc = 0\n checkpoint_path = \"training_1/cp.ckpt\"\n\n for epoch in range(epochs):\n # Run the model through train and test sets respectively\n for (features, labels) in train_dataset:\n model_train(features, labels, model, loss_func,optimizer,train_acc,train_loss)\n\n for val_features, val_labels in val_dataset:\n model_validate(val_features, val_labels,model,loss_func,valid_loss,valid_acc)\n\n # Grab the results\n (loss, acc) = train_loss.result(), train_acc.result()\n (val_loss, val_acc) = valid_loss.result(), valid_acc.result()\n if val_acc > max_val_acc:\n max_val_acc= val_acc\n model.save_weights(checkpoint_path)\n\n # Clear the current state of the metrics\n train_loss.reset_states(), train_acc.reset_states()\n valid_loss.reset_states(), valid_acc.reset_states()\n\n # Local logging\n template = \"Epoch {}, loss: {:.3f}, acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}\"\n print (template.format(epoch+1,\n loss,\n acc,\n val_loss,\n val_acc))\n wandb.log({\"train_accuracy\": acc,\n \"val_accuracy\": val_acc})\n 
wandb.log({\"best_val_acc\": max_val_acc})\n","metadata":{"execution":{"iopub.status.busy":"2021-10-02T05:37:11.533988Z","iopub.execute_input":"2021-10-02T05:37:11.534319Z","iopub.status.idle":"2021-10-02T05:37:11.542669Z","shell.execute_reply.started":"2021-10-02T05:37:11.534288Z","shell.execute_reply":"2021-10-02T05:37:11.541689Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## Initiate training","metadata":{}},{"cell_type":"code","source":"\nconfig_defaults= {\n \"C\": 256,\n \"DS\" :128,\n \"DC\" : 1024,\n \"num_of_mixer_blocks\" : 8,\n 'learning_rate' : 0.0003\n}\nwandb.init(config=config_defaults,project=\"MLP-mixer-audio\")\ntrain_list, val_list, test_list, train_label, val_label, test_label = get_train_val_test_split('../input/google-speech-v2', \n '../input/google-speech-v2/validation_list.txt', \n '../input/google-speech-v2/testing_list.txt')\ntrain_dataset = tf.data.Dataset.from_generator(\n sample_generator,\n args=(train_list, train_label,True),\n output_types=(tf.float32, tf.float32),\n output_shapes=((40, 98, 1), (35,))\n )\nval_dataset = tf.data.Dataset.from_generator(\n sample_generator,\n args=(val_list, val_label,False),\n output_types=(tf.float32, tf.float32),\n output_shapes=((40, 98, 1), (35,))\n )\ntest_dataset = tf.data.Dataset.from_generator(\n sample_generator,\n args=(test_list, test_label,False),\n output_types=(tf.float32, tf.float32),\n output_shapes=((40, 98, 1), (35,))\n )\nAUTOTUNE = tf.data.AUTOTUNE\ntrain_dataset = train_dataset.shuffle(1024).batch(64).cache()\nval_dataset = val_dataset.shuffle(1024).batch(64).cache()\n\n\nmodel = MLPMixer(40, 1, C=wandb.config.C, DS=wandb.config.DS, DC=wandb.config.DC, num_of_mixer_blocks=wandb.config.num_of_mixer_blocks, num_classes=35)\nlearning_rate = CosineScheduler(learning_rate_base=wandb.config.learning_rate, \n total_steps=23000, \n warmup_learning_rate=0.0, \n warmup_steps=1660)\nloss_func = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)\noptimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n\n# Average the loss across the batch size within an epoch\ntrain_loss = tf.keras.metrics.Mean(name=\"train_loss\")\nvalid_loss = tf.keras.metrics.Mean(name=\"test_loss\")\n\n# Specify the performance metric\ntrain_acc = tf.keras.metrics.CategoricalAccuracy(name=\"train_acc\")\nvalid_acc = tf.keras.metrics.CategoricalAccuracy(name=\"valid_acc\")\n\ntrain(train_dataset,\n val_dataset,\n model,\n optimizer,\n loss_func,\n train_loss,\n train_acc,\n valid_loss,\n valid_acc,\n epochs = 50)\ntest_audio = []\ntest_labels = []\n\nfor audio, label in test_dataset:\n test_audio.append(audio.numpy())\n test_labels.append(label.numpy())\n\ntest_audio = np.array(test_audio)\ntest_labels = np.array(test_labels)\n\ny_pred = np.argmax(model.predict(test_audio), axis=1)\ny_true = np.argmax(test_labels,axis=1)\n\ntest_acc = sum(y_pred == y_true) / len(y_true)\nprint(f'Test set accuracy: {test_acc:.0%}')\nwandb.log({'test_acc': test_acc})\n\ntf.keras.backend.clear_session()\ncheckpoint_path = \"training_1/cp.ckpt\"\nnew_model = MLPMixer(40, 1, C=wandb.config.C, DS=wandb.config.DS, DC=wandb.config.DC, num_of_mixer_blocks=wandb.config.num_of_mixer_blocks, num_classes=35)\nnew_model.load_weights(checkpoint_path)\n\ny_pred = np.argmax(new_model.predict(test_audio), axis=1)\ny_true = np.argmax(test_labels,axis=1)\n\ntest_acc = sum(y_pred == y_true) / len(y_true)\nprint(f'Test set accuracy: {test_acc:.0%}')\nwandb.log({'best_test_acc': 
test_acc})","metadata":{"execution":{"iopub.status.busy":"2021-10-02T05:37:26.396372Z","iopub.execute_input":"2021-10-02T05:37:26.396744Z","iopub.status.idle":"2021-10-02T06:03:21.692714Z","shell.execute_reply.started":"2021-10-02T05:37:26.396716Z","shell.execute_reply":"2021-10-02T06:03:21.691696Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"","metadata":{"execution":{"iopub.status.busy":"2021-09-18T15:12:11.538279Z","iopub.execute_input":"2021-09-18T15:12:11.538656Z","iopub.status.idle":"2021-09-18T16:39:05.242572Z","shell.execute_reply.started":"2021-09-18T15:12:11.538619Z","shell.execute_reply":"2021-09-18T16:39:05.241276Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"","metadata":{},"execution_count":null,"outputs":[]}]} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.0 2 | pyyaml>=5.3.1 3 | librosa 4 | audiomentations 5 | pydub 6 | wandb 7 | tqdm 8 | einops -------------------------------------------------------------------------------- /resources/kw-mlp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Research-BD/Keyword-MLP/a87c8ac9c4d3ba01745d68cc5dcf9cd21dd67b72/resources/kw-mlp.png -------------------------------------------------------------------------------- /resources/wandb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Research-BD/Keyword-MLP/a87c8ac9c4d3ba01745d68cc5dcf9cd21dd67b72/resources/wandb.png -------------------------------------------------------------------------------- /sample_configs/base_config.yaml: -------------------------------------------------------------------------------- 1 | ###################### 2 | # sample config file 3 | ###################### 4 | 5 | data_root: ./data/ 6 | train_list_file: ./data/training_list.txt 7 | val_list_file: ./data/validation_list.txt 8 | test_list_file: ./data/testing_list.txt 9 | label_map: ./data/label_map.json 10 | 11 | exp: 12 | wandb: True 13 | wandb_api_key: 14 | proj_name: torch-kw-mlp 15 | exp_dir: ./runs 16 | exp_name: kw-mlp-0.1.0 17 | device: auto 18 | log_freq: 20 # steps 19 | log_to_file: False 20 | log_to_stdout: True 21 | val_freq: 1 # epochs 22 | n_workers: 2 23 | pin_memory: True 24 | cache: 2 # 0 -> no cache | 1 -> cache wavs | 2 -> cache specs; stops wav augments 25 | 26 | 27 | hparams: 28 | restore_ckpt: 29 | seed: 0 30 | batch_size: 256 31 | start_epoch: 0 32 | n_epochs: 140 33 | l_smooth: 0.1 34 | 35 | audio: 36 | sr: 16000 37 | n_mels: 40 38 | n_fft: 480 39 | win_length: 480 40 | hop_length: 160 41 | center: False 42 | 43 | model: 44 | type: kw-mlp 45 | input_res: [40, 98] 46 | patch_res: [40, 1] 47 | num_classes: 35 48 | channels: 1 49 | dim: 64 50 | depth: 12 51 | pre_norm: False 52 | prob_survival: 0.9 53 | 54 | optimizer: 55 | opt_type: adamw 56 | opt_kwargs: 57 | lr: 0.001 58 | weight_decay: 0.1 59 | 60 | scheduler: 61 | n_warmup: 10 62 | max_epochs: 140 63 | scheduler_type: cosine_annealing 64 | 65 | augment: 66 | # resample: 67 | # r_min: 0.85 68 | # r_max: 1.15 69 | 70 | # time_shift: 71 | # s_min: -0.1 72 | # s_max: 0.1 73 | 74 | # bg_noise: 75 | # bg_folder: ./data/bg_folder/ 76 | 77 | spec_aug: 78 | n_time_masks: 2 79 | time_mask_width: 25 80 | n_freq_masks: 2 81 | freq_mask_width: 7 
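# A config can be sanity-checked before training by parsing and printing it with
# config_parser.py (assuming the repository root as the working directory):
#   python config_parser.py sample_configs/base_config.yaml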
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from config_parser import get_config 3 | 4 | from utils.loss import LabelSmoothingLoss 5 | from utils.opt import get_optimizer 6 | from utils.scheduler import WarmUpLR, get_scheduler 7 | from utils.trainer import train, evaluate 8 | from utils.dataset import get_loader 9 | from utils.misc import seed_everything, count_params, get_model, calc_step, log 10 | 11 | import torch 12 | from torch import nn 13 | import numpy as np 14 | import wandb 15 | 16 | import os 17 | import yaml 18 | import random 19 | import time 20 | 21 | 22 | def training_pipeline(config): 23 | """Initiates and executes all the steps involved with model training. 24 | 25 | Args: 26 | config (dict) - Dict containing various settings for the training run. 27 | """ 28 | 29 | config["exp"]["save_dir"] = os.path.join(config["exp"]["exp_dir"], config["exp"]["exp_name"]) 30 | os.makedirs(config["exp"]["save_dir"], exist_ok=True) 31 | 32 | ###################################### 33 | # save hyperparameters for current run 34 | ###################################### 35 | 36 | config_str = yaml.dump(config) 37 | print("Using settings:\n", config_str) 38 | 39 | with open(os.path.join(config["exp"]["save_dir"], "settings.txt"), "w+") as f: 40 | f.write(config_str) 41 | 42 | 43 | ##################################### 44 | # initialize training items 45 | ##################################### 46 | 47 | # data 48 | with open(config["train_list_file"], "r") as f: 49 | train_list = f.read().rstrip().split("\n") 50 | 51 | with open(config["val_list_file"], "r") as f: 52 | val_list = f.read().rstrip().split("\n") 53 | 54 | with open(config["test_list_file"], "r") as f: 55 | test_list = f.read().rstrip().split("\n") 56 | 57 | 58 | trainloader = get_loader(train_list, config, train=True) 59 | valloader = get_loader(val_list, config, train=False) 60 | testloader = get_loader(test_list, config, train=False) 61 | 62 | # model 63 | model = get_model(config["hparams"]["model"]) 64 | model = model.to(config["hparams"]["device"]) 65 | print(f"Created model with {count_params(model)} parameters.") 66 | 67 | # loss 68 | if config["hparams"]["l_smooth"]: 69 | criterion = LabelSmoothingLoss(num_classes=config["hparams"]["model"]["num_classes"], smoothing=config["hparams"]["l_smooth"]) 70 | else: 71 | criterion = nn.CrossEntropyLoss() 72 | 73 | # optimizer 74 | optimizer = get_optimizer(model, config["hparams"]["optimizer"]) 75 | 76 | # scheduler 77 | schedulers = { 78 | "warmup": None, 79 | "scheduler": None 80 | } 81 | 82 | if config["hparams"]["scheduler"]["n_warmup"]: 83 | schedulers["warmup"] = WarmUpLR(optimizer, total_iters=len(trainloader) * config["hparams"]["scheduler"]["n_warmup"]) 84 | 85 | if config["hparams"]["scheduler"]["scheduler_type"] is not None: 86 | total_iters = len(trainloader) * max(1, (config["hparams"]["scheduler"]["max_epochs"] - config["hparams"]["scheduler"]["n_warmup"])) 87 | schedulers["scheduler"] = get_scheduler(optimizer, config["hparams"]["scheduler"]["scheduler_type"], total_iters) 88 | 89 | 90 | ##################################### 91 | # Resume run 92 | ##################################### 93 | 94 | if config["hparams"]["restore_ckpt"]: 95 | ckpt = torch.load(config["hparams"]["restore_ckpt"]) 96 | 97 | config["hparams"]["start_epoch"] = ckpt["epoch"] + 1 98 | 
model.load_state_dict(ckpt["model_state_dict"]) 99 | optimizer.load_state_dict(ckpt["optimizer_state_dict"]) 100 | 101 | if schedulers["scheduler"]: 102 | schedulers["scheduler"].load_state_dict(ckpt["scheduler_state_dict"]) 103 | 104 | print(f'Restored state from {config["hparams"]["restore_ckpt"]} successfully.') 105 | 106 | 107 | ##################################### 108 | # Training 109 | ##################################### 110 | 111 | print("Initiating training.") 112 | train(model, optimizer, criterion, trainloader, valloader, schedulers, config) 113 | 114 | 115 | ##################################### 116 | # Final Test 117 | ##################################### 118 | 119 | final_step = calc_step(config["hparams"]["n_epochs"] + 1, len(trainloader), len(trainloader) - 1) 120 | 121 | # evaluating the final state (last.pth) 122 | test_acc, test_loss = evaluate(model, criterion, testloader, config["hparams"]["device"]) 123 | log_dict = { 124 | "test_loss_last": test_loss, 125 | "test_acc_last": test_acc 126 | } 127 | log(log_dict, final_step, config) 128 | 129 | # evaluating the best validation state (best.pth) 130 | ckpt = torch.load(os.path.join(config["exp"]["save_dir"], "best.pth")) 131 | model.load_state_dict(ckpt["model_state_dict"]) 132 | print("Best ckpt loaded.") 133 | 134 | test_acc, test_loss = evaluate(model, criterion, testloader, config["hparams"]["device"]) 135 | log_dict = { 136 | "test_loss_best": test_loss, 137 | "test_acc_best": test_acc 138 | } 139 | log(log_dict, final_step, config) 140 | 141 | 142 | def main(args): 143 | config = get_config(args.conf) 144 | seed_everything(config["hparams"]["seed"]) 145 | 146 | 147 | if config["exp"]["wandb"]: 148 | if config["exp"]["wandb_api_key"] is not None: 149 | with open(config["exp"]["wandb_api_key"], "r") as f: 150 | os.environ["WANDB_API_KEY"] = f.read() 151 | 152 | elif os.environ.get("WANDB_API_KEY", False): 153 | print(f"Found API key from env variable.") 154 | 155 | else: 156 | wandb.login() 157 | 158 | with wandb.init(project=config["exp"]["proj_name"], name=config["exp"]["exp_name"], config=config["hparams"]): 159 | training_pipeline(config) 160 | 161 | else: 162 | training_pipeline(config) 163 | 164 | 165 | 166 | if __name__ == "__main__": 167 | parser = ArgumentParser("Driver code.") 168 | parser.add_argument("--conf", type=str, required=True, help="Path to config.yaml file.") 169 | args = parser.parse_args() 170 | 171 | main(args) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Research-BD/Keyword-MLP/a87c8ac9c4d3ba01745d68cc5dcf9cd21dd67b72/utils/__init__.py -------------------------------------------------------------------------------- /utils/augment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numba as nb 3 | import librosa 4 | 5 | 6 | @nb.jit(nopython=True, cache=True) 7 | def time_shift(wav: np.ndarray, sr: int, s_min: float, s_max: float) -> np.ndarray: 8 | """Time shift augmentation. 9 | Refer to https://www.kaggle.com/haqishen/augmentation-methods-for-audio#1.-Time-shifting. 10 | Changed np.r_ to np.hstack for numba support. 11 | 12 | Args: 13 | wav (np.ndarray): Waveform array of shape (n_samples,). 14 | sr (int): Sampling rate. 15 | s_min (float): Minimum fraction of a second by which to shift. 
16 | s_max (float): Maximum fraction of a second by which to shift. 17 | 18 | Returns: 19 | wav_time_shift (np.ndarray): Time-shifted waveform array. 20 | """ 21 | 22 | start = int(np.random.uniform(sr * s_min, sr * s_max)) 23 | if start >= 0: 24 | wav_time_shift = np.hstack((wav[start:], np.random.uniform(-0.001, 0.001, start))) 25 | else: 26 | wav_time_shift = np.hstack((np.random.uniform(-0.001, 0.001, -start), wav[:start])) 27 | 28 | return wav_time_shift 29 | 30 | 31 | def resample(x: np.ndarray, sr: int, r_min: float, r_max: float) -> np.ndarray: 32 | """Resamples waveform. 33 | 34 | Args: 35 | x (np.ndarray): Input waveform, array of shape (n_samples, ). 36 | sr (int): Sampling rate. 37 | r_min (float): Minimum percentage of resampling. 38 | r_max (float): Maximum percentage of resampling. 39 | """ 40 | 41 | sr_new = sr * np.random.uniform(r_min, r_max) 42 | x = librosa.resample(x, sr, sr_new) 43 | return x, sr_new 44 | 45 | 46 | @nb.jit(nopython=True, cache=True) 47 | def spec_augment(mel_spec: np.ndarray, n_time_masks: int, time_mask_width: int, n_freq_masks: int, freq_mask_width: int): 48 | """Numpy implementation of spectral augmentation. 49 | 50 | Args: 51 | mel_spec (np.ndarray): Mel spectrogram, array of shape (n_mels, T). 52 | n_time_masks (int): Number of time bands. 53 | time_mask_width (int): Max width of each time band. 54 | n_freq_masks (int): Number of frequency bands. 55 | freq_mask_width (int): Max width of each frequency band. 56 | 57 | Returns: 58 | mel_spec (np.ndarray): Spectrogram with random time bands and freq bands masked out. 59 | """ 60 | 61 | offset, begin = 0, 0 62 | 63 | for _ in range(n_time_masks): 64 | offset = np.random.randint(0, time_mask_width) 65 | begin = np.random.randint(0, mel_spec.shape[1] - offset) 66 | mel_spec[:, begin: begin + offset] = 0.0 67 | 68 | for _ in range(n_freq_masks): 69 | offset = np.random.randint(0, freq_mask_width) 70 | begin = np.random.randint(0, mel_spec.shape[0] - offset) 71 | mel_spec[begin: begin + offset, :] = 0.0 72 | 73 | return mel_spec 74 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from torch.utils.data import Dataset, DataLoader 4 | import numpy as np 5 | import functools 6 | import librosa 7 | import glob 8 | import os 9 | from tqdm import tqdm 10 | import multiprocessing as mp 11 | import json 12 | 13 | from utils.augment import time_shift, resample, spec_augment 14 | from audiomentations import AddBackgroundNoise 15 | 16 | 17 | def get_train_val_test_split(root: str, val_file: str, test_file: str): 18 | """Creates train, val, and test split according to provided val and test files. 19 | Args: 20 | root (str): Path to base directory of the dataset. 21 | val_file (str): Path to file containing list of validation data files. 22 | test_file (str): Path to file containing list of test data files. 23 | 24 | Returns: 25 | train_list (list): List of paths to training data items. 26 | val_list (list): List of paths to validation data items. 27 | test_list (list): List of paths to test data items. 28 | label_map (dict): Mapping of indices to label classes. 
29 | """ 30 | 31 | #################### 32 | # Labels 33 | #################### 34 | 35 | label_list = [label for label in sorted(os.listdir(root)) if os.path.isdir(os.path.join(root, label)) and label[0] != "_"] 36 | label_map = {idx: label for idx, label in enumerate(label_list)} 37 | 38 | ################### 39 | # Split 40 | ################### 41 | 42 | all_files_set = set() 43 | for label in label_list: 44 | all_files_set.update(set(glob.glob(os.path.join(root, label, "*.wav")))) 45 | 46 | with open(val_file, "r") as f: 47 | val_files_set = set(map(lambda a: os.path.join(root, a), f.read().rstrip("\n").split("\n"))) 48 | 49 | with open(test_file, "r") as f: 50 | test_files_set = set(map(lambda a: os.path.join(root, a), f.read().rstrip("\n").split("\n"))) 51 | 52 | assert len(val_files_set.intersection(test_files_set)) == 0, "Sanity check: No files should be common between val and test." 53 | 54 | all_files_set -= val_files_set 55 | all_files_set -= test_files_set 56 | 57 | train_list, val_list, test_list = list(all_files_set), list(val_files_set), list(test_files_set) 58 | 59 | print(f"Number of training samples: {len(train_list)}") 60 | print(f"Number of validation samples: {len(val_list)}") 61 | print(f"Number of test samples: {len(test_list)}") 62 | 63 | return train_list, val_list, test_list, label_map 64 | 65 | 66 | class GoogleSpeechDataset(Dataset): 67 | """Dataset wrapper for Google Speech Commands V2.""" 68 | 69 | def __init__(self, data_list: list, audio_settings: dict, label_map: dict = None, aug_settings: dict = None, cache: int = 0): 70 | super().__init__() 71 | 72 | self.audio_settings = audio_settings 73 | self.aug_settings = aug_settings 74 | self.cache = cache 75 | 76 | if cache: 77 | print("Caching dataset into memory.") 78 | self.data_list = init_cache(data_list, audio_settings["sr"], cache, audio_settings) 79 | else: 80 | self.data_list = data_list 81 | 82 | # labels: if no label map is provided, will not load labels. (Use for inference) 83 | if label_map is not None: 84 | self.label_list = [] 85 | label_2_idx = {v: int(k) for k, v in label_map.items()} 86 | for path in data_list: 87 | self.label_list.append(label_2_idx[path.split("/")[-2]]) 88 | else: 89 | self.label_list = None 90 | 91 | 92 | if aug_settings is not None: 93 | if "bg_noise" in self.aug_settings: 94 | self.bg_adder = AddBackgroundNoise(sounds_path=aug_settings["bg_noise"]["bg_folder"]) 95 | 96 | 97 | def __len__(self): 98 | return len(self.data_list) 99 | 100 | 101 | def __getitem__(self, idx): 102 | if self.cache: 103 | x = self.data_list[idx] 104 | else: 105 | x = librosa.load(self.data_list[idx], self.audio_settings["sr"])[0] 106 | 107 | x = self.transform(x) 108 | 109 | if self.label_list is not None: 110 | label = self.label_list[idx] 111 | return x, label 112 | else: 113 | return x 114 | 115 | 116 | def transform(self, x): 117 | """Applies necessary preprocessing to audio. 118 | Args: 119 | x (np.ndarray) - Input waveform; array of shape (n_samples, ). 120 | 121 | Returns: 122 | x (torch.FloatTensor) - MFCC matrix of shape (n_mfcc, T). 
123 | """ 124 | 125 | sr = self.audio_settings["sr"] 126 | 127 | ################### 128 | # Waveform 129 | ################### 130 | 131 | if self.cache < 2: 132 | if self.aug_settings is not None: 133 | if "bg_noise" in self.aug_settings: 134 | x = self.bg_adder(samples=x, sample_rate=sr) 135 | 136 | if "time_shift" in self.aug_settings: 137 | x = time_shift(x, sr, **self.aug_settings["time_shift"]) 138 | 139 | if "resample" in self.aug_settings: 140 | x, _ = resample(x, sr, **self.aug_settings["resample"]) 141 | 142 | x = librosa.util.fix_length(x, sr) 143 | 144 | ################### 145 | # Spectrogram 146 | ################### 147 | 148 | x = librosa.feature.melspectrogram(y=x, **self.audio_settings) 149 | x = librosa.feature.mfcc(S=librosa.power_to_db(x), n_mfcc=self.audio_settings["n_mels"]) 150 | 151 | 152 | if self.aug_settings is not None: 153 | if "spec_aug" in self.aug_settings: 154 | x = spec_augment(x, **self.aug_settings["spec_aug"]) 155 | 156 | x = torch.from_numpy(x).float().unsqueeze(0) 157 | return x 158 | 159 | 160 | def cache_item_loader(path: str, sr: int, cache_level: int, audio_settings: dict) -> np.ndarray: 161 | x = librosa.load(path, sr)[0] 162 | if cache_level == 2: 163 | x = librosa.util.fix_length(x, sr) 164 | x = librosa.feature.melspectrogram(y=x, **audio_settings) 165 | x = librosa.feature.mfcc(S=librosa.power_to_db(x), n_mfcc=audio_settings["n_mels"]) 166 | return x 167 | 168 | 169 | def init_cache(data_list: list, sr: int, cache_level: int, audio_settings: dict, n_cache_workers: int = 4) -> list: 170 | """Loads entire dataset into memory for later use. 171 | Args: 172 | data_list (list): List of data items. 173 | sr (int): Sampling rate. 174 | cache_level (int): Cache levels, one of (1, 2), caching wavs and spectrograms respectively. 175 | n_cache_workers (int, optional): Number of workers. Defaults to 4. 176 | Returns: 177 | cache (list): List of data items. 178 | """ 179 | 180 | cache = [] 181 | loader_fn = functools.partial(cache_item_loader, sr=sr, cache_level=cache_level, audio_settings=audio_settings) 182 | 183 | pool = mp.Pool(n_cache_workers) 184 | 185 | for audio in tqdm(pool.imap(func=loader_fn, iterable=data_list), total=len(data_list)): 186 | cache.append(audio) 187 | 188 | pool.close() 189 | pool.join() 190 | 191 | return cache 192 | 193 | 194 | def get_loader(data_list, config, train=True): 195 | """Creates dataloaders for training, validation and testing. 196 | Args: 197 | config (dict): Dict containing various settings for the training run. 198 | train (bool): Training or evaluation mode. 199 | 200 | Returns: 201 | dataloader (DataLoader): DataLoader wrapper for training/validation/test data. 
202 | """ 203 | 204 | with open(config["label_map"], "r") as f: 205 | label_map = json.load(f) 206 | 207 | dataset = GoogleSpeechDataset( 208 | data_list=data_list, 209 | label_map=label_map, 210 | audio_settings=config["hparams"]["audio"], 211 | aug_settings=config["hparams"]["augment"] if train else None, 212 | cache=config["exp"]["cache"] 213 | ) 214 | 215 | dataloader = DataLoader( 216 | dataset, 217 | batch_size=config["hparams"]["batch_size"], 218 | num_workers=config["exp"]["n_workers"], 219 | pin_memory=config["exp"]["pin_memory"], 220 | shuffle=True if train else False 221 | ) 222 | 223 | return dataloader -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LabelSmoothingLoss(nn.Module): 6 | """Cross Entropy with Label Smoothing. 7 | 8 | Attributes: 9 | num_classes (int): Number of target classes. 10 | smoothing (float, optional): Smoothing fraction constant, in the range (0.0, 1.0). Defaults to 0.1. 11 | dim (int, optional): Dimension across which to apply loss. Defaults to -1. 12 | """ 13 | 14 | def __init__(self, num_classes: int, smoothing : float = 0.1, dim : int = -1): 15 | """Initializer for LabelSmoothingLoss. 16 | 17 | Args: 18 | num_classes (int): Number of target classes. 19 | smoothing (float, optional): Smoothing fraction constant, in the range (0.0, 1.0). Defaults to 0.1. 20 | dim (int, optional): Dimension across which to apply loss. Defaults to -1. 21 | """ 22 | super().__init__() 23 | 24 | self.confidence = 1.0 - smoothing 25 | self.smoothing = smoothing 26 | self.cls = num_classes 27 | self.dim = dim 28 | 29 | def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 30 | """Forward function. 31 | 32 | Args: 33 | pred (torch.Tensor): Model predictions, of shape (batch_size, num_classes). 34 | target (torch.Tensor): Target tensor of shape (batch_size). 35 | 36 | Returns: 37 | torch.Tensor: Loss. 38 | """ 39 | 40 | assert 0 <= self.smoothing < 1 41 | pred = pred.log_softmax(dim=self.dim) 42 | 43 | with torch.no_grad(): 44 | true_dist = torch.zeros_like(pred) 45 | true_dist.fill_(self.smoothing / (self.cls - 1)) 46 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 47 | return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | """Miscellaneous helper functions.""" 2 | 3 | import torch 4 | from torch import nn, optim 5 | import numpy as np 6 | import random 7 | import os 8 | import wandb 9 | from models.kwmlp import KW_MLP 10 | 11 | 12 | def seed_everything(seed: str) -> None: 13 | """Set manual seed. 14 | 15 | Args: 16 | seed (int): Supplied seed. 17 | """ 18 | 19 | random.seed(seed) 20 | os.environ['PYTHONHASHSEED'] = str(seed) 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.backends.cudnn.deterministic = True 25 | print(f'Set seed {seed}') 26 | 27 | 28 | def count_params(model: nn.Module) -> int: 29 | """Counts number of parameters in a model. 30 | 31 | Args: 32 | model (torch.nn.Module): Model instance for which number of params is to be counted. 33 | 34 | Returns: 35 | int: Parameter count. 
36 | """ 37 | 38 | return sum(map(lambda p: p.data.numel(), model.parameters())) 39 | 40 | 41 | def calc_step(epoch: int, n_batches: int, batch_index: int) -> int: 42 | """Calculates current step. 43 | 44 | Args: 45 | epoch (int): Current epoch. 46 | n_batches (int): Number of batches in dataloader. 47 | batch_index (int): Current batch index. 48 | 49 | Returns: 50 | int: Current step. 51 | """ 52 | return (epoch - 1) * n_batches + (1 + batch_index) 53 | 54 | 55 | def log(log_dict: dict, step: int, config: dict) -> None: 56 | """Handles logging for metric tracking server, local disk and stdout. 57 | 58 | Args: 59 | log_dict (dict): Log metric dict. 60 | step (int): Current step. 61 | config (dict): Config dict. 62 | """ 63 | 64 | # send logs to wandb tracking server 65 | if config["exp"]["wandb"]: 66 | wandb.log(log_dict, step=step) 67 | 68 | log_message = f"Step: {step} | " + " | ".join([f"{k}: {v}" for k, v in log_dict.items()]) 69 | 70 | # write logs to disk 71 | if config["exp"]["log_to_file"]: 72 | log_file = os.path.join(config["exp"]["save_dir"], "training_log.txt") 73 | 74 | with open(log_file, "a+") as f: 75 | f.write(log_message + "\n") 76 | 77 | # show logs in stdout 78 | if config["exp"]["log_to_stdout"]: 79 | print(log_message) 80 | 81 | 82 | def get_model(model_config: dict) -> nn.Module: 83 | """Creates model from config dict. 84 | 85 | Args: 86 | model_config (dict): Dict containing model config params. If the "name" key is not None, other params are ignored. 87 | 88 | Returns: 89 | nn.Module: Model instance. 90 | """ 91 | 92 | if model_config["type"] == "kw-mlp": 93 | return KW_MLP(**model_config) 94 | else: 95 | raise ValueError(f"Unknown model type: {model_config['type']}") 96 | 97 | 98 | def save_model(epoch: int, val_acc: float, save_path: str, net: nn.Module, optimizer : optim.Optimizer = None, scheduler: optim.lr_scheduler._LRScheduler = None, log_file : str = None) -> None: 99 | """Saves checkpoint. 100 | 101 | Args: 102 | epoch (int): Current epoch. 103 | val_acc (float): Validation accuracy. 104 | save_path (str): Checkpoint path. 105 | net (nn.Module): Model instance. 106 | optimizer (optim.Optimizer, optional): Optimizer. Defaults to None. 107 | scheduler (optim.lr_scheduler._LRScheduler): Scheduler. Defaults to None. 108 | log_file (str, optional): Log file. Defaults to None. 109 | """ 110 | 111 | ckpt_dict = { 112 | "epoch": epoch, 113 | "val_acc": val_acc, 114 | "model_state_dict": net.state_dict(), 115 | "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else optimizer, 116 | "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else scheduler 117 | } 118 | 119 | torch.save(ckpt_dict, save_path) 120 | 121 | log_message = f"Saved {save_path} with accuracy {val_acc}." 122 | print(log_message) 123 | 124 | if log_file is not None: 125 | with open(log_file, "a+") as f: 126 | f.write(log_message + "\n") 127 | -------------------------------------------------------------------------------- /utils/opt.py: -------------------------------------------------------------------------------- 1 | from torch import nn, optim 2 | 3 | 4 | def get_optimizer(net: nn.Module, opt_config: dict) -> optim.Optimizer: 5 | """Creates optimizer based on config. 6 | 7 | Args: 8 | net (nn.Module): Model instance. 9 | opt_config (dict): Dict containing optimizer settings. 10 | 11 | Raises: 12 | ValueError: Unsupported optimizer type. 13 | 14 | Returns: 15 | optim.Optimizer: Optimizer instance. 
16 | """ 17 | 18 | if opt_config["opt_type"] == "adamw": 19 | optimizer = optim.AdamW(net.parameters(), **opt_config["opt_kwargs"]) 20 | else: 21 | raise ValueError(f'Unsupported optimizer {opt_config["opt_type"]}') 22 | 23 | return optimizer 24 | -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | from torch.optim import lr_scheduler 3 | 4 | 5 | class WarmUpLR(lr_scheduler._LRScheduler): 6 | """WarmUp learning rate scheduler. 7 | 8 | Args: 9 | optimizer (optim.Optimizer): Optimizer instance 10 | total_iters (int): steps_per_epoch * n_warmup_epochs 11 | last_epoch (int): Final epoch. Defaults to -1. 12 | """ 13 | 14 | def __init__(self, optimizer: optim.Optimizer, total_iters: int, last_epoch: int = -1): 15 | """Initializer for WarmUpLR""" 16 | 17 | self.total_iters = total_iters 18 | super().__init__(optimizer, last_epoch) 19 | 20 | def get_lr(self): 21 | """Learning rate will be set to base_lr * last_epoch / total_iters.""" 22 | 23 | return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs] 24 | 25 | 26 | def get_scheduler(optimizer: optim.Optimizer, scheduler_type: str, T_max: int) -> lr_scheduler._LRScheduler: 27 | """Gets scheduler. 28 | 29 | Args: 30 | optimizer (optim.Optimizer): Optimizer instance. 31 | scheduler_type (str): Specified scheduler. 32 | T_max (int): Final step. 33 | 34 | Raises: 35 | ValueError: Unsupported scheduler type. 36 | 37 | Returns: 38 | lr_scheduler._LRScheduler: Scheduler instance. 39 | """ 40 | 41 | if scheduler_type == "cosine_annealing": 42 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=1e-8) 43 | else: 44 | raise ValueError(f"Unsupported scheduler type: {scheduler_type}") 45 | 46 | return scheduler -------------------------------------------------------------------------------- /utils/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | from typing import Callable, Tuple 4 | from torch.utils.data import DataLoader 5 | from utils.misc import log, save_model 6 | import os 7 | import time 8 | from tqdm import tqdm 9 | 10 | 11 | def train_single_batch(net: nn.Module, data: torch.Tensor, targets: torch.Tensor, optimizer: optim.Optimizer, criterion: Callable, device: torch.device) -> Tuple[float, int]: 12 | """Performs a single training step. 13 | 14 | Args: 15 | net (nn.Module): Model instance. 16 | data (torch.Tensor): Data tensor, of shape (batch_size, dim_1, ... , dim_N). 17 | targets (torch.Tensor): Target tensor, of shape (batch_size). 18 | optimizer (optim.Optimizer): Optimizer instance. 19 | criterion (Callable): Loss function. 20 | device (torch.device): Device. 21 | 22 | Returns: 23 | float: Loss scalar. 24 | int: Number of correct preds. 25 | """ 26 | 27 | data, targets = data.to(device), targets.to(device) 28 | 29 | optimizer.zero_grad() 30 | outputs = net(data) 31 | loss = criterion(outputs, targets) 32 | loss.backward() 33 | optimizer.step() 34 | 35 | correct = outputs.argmax(1).eq(targets).sum() 36 | return loss.item(), correct.item() 37 | 38 | 39 | @torch.no_grad() 40 | def evaluate(net: nn.Module, criterion: Callable, dataloader: DataLoader, device: torch.device) -> Tuple[float, float]: 41 | """Performs inference. 42 | 43 | Args: 44 | net (nn.Module): Model instance. 45 | criterion (Callable): Loss function. 
46 | dataloader (DataLoader): Test or validation loader. 47 | device (torch.device): Device. 48 | 49 | Returns: 50 | accuracy (float): Accuracy. 51 | float: Loss scalar. 52 | """ 53 | 54 | net.eval() 55 | correct = 0 56 | running_loss = 0.0 57 | 58 | for data, targets in tqdm(dataloader): 59 | data, targets = data.to(device), targets.to(device) 60 | out = net(data) 61 | correct += out.argmax(1).eq(targets).sum().item() 62 | loss = criterion(out, targets) 63 | running_loss += loss.item() 64 | 65 | net.train() 66 | accuracy = correct / len(dataloader.dataset) 67 | return accuracy, running_loss / len(dataloader) 68 | 69 | 70 | def train(net: nn.Module, optimizer: optim.Optimizer, criterion: Callable, trainloader: DataLoader, valloader: DataLoader, schedulers: dict, config: dict) -> None: 71 | """Trains model. 72 | 73 | Args: 74 | net (nn.Module): Model instance. 75 | optimizer (optim.Optimizer): Optimizer instance. 76 | criterion (Callable): Loss function. 77 | trainloader (DataLoader): Training data loader. 78 | valloader (DataLoader): Validation data loader. 79 | schedulers (dict): Dict containing schedulers. 80 | config (dict): Config dict. 81 | """ 82 | 83 | step = 0 84 | best_acc = 0.0 85 | n_batches = len(trainloader) 86 | device = config["hparams"]["device"] 87 | log_file = os.path.join(config["exp"]["save_dir"], "training_log.txt") 88 | 89 | ############################ 90 | # start training 91 | ############################ 92 | net.train() 93 | 94 | for epoch in range(config["hparams"]["start_epoch"], config["hparams"]["n_epochs"]): 95 | t0 = time.time() 96 | running_loss = 0.0 97 | correct = 0 98 | 99 | for batch_index, (data, targets) in enumerate(trainloader): 100 | 101 | if schedulers["warmup"] is not None and epoch < config["hparams"]["scheduler"]["n_warmup"]: 102 | schedulers["warmup"].step() 103 | 104 | elif schedulers["scheduler"] is not None: 105 | schedulers["scheduler"].step() 106 | 107 | #################### 108 | # optimization step 109 | #################### 110 | 111 | loss, corr = train_single_batch(net, data, targets, optimizer, criterion, device) 112 | running_loss += loss 113 | correct += corr 114 | 115 | if not step % config["exp"]["log_freq"]: 116 | log_dict = {"epoch": epoch, "loss": loss, "lr": optimizer.param_groups[0]["lr"]} 117 | log(log_dict, step, config) 118 | 119 | step += 1 120 | 121 | ####################### 122 | # epoch complete 123 | ####################### 124 | 125 | log_dict = {"epoch": epoch, "time_per_epoch": time.time() - t0, "train_acc": correct/(len(trainloader.dataset)), "avg_loss_per_ep": running_loss/len(trainloader)} 126 | log(log_dict, step, config) 127 | 128 | if not epoch % config["exp"]["val_freq"]: 129 | val_acc, avg_val_loss = evaluate(net, criterion, valloader, device) 130 | log_dict = {"epoch": epoch, "val_loss": avg_val_loss, "val_acc": val_acc} 131 | log(log_dict, step, config) 132 | 133 | # save best val ckpt 134 | if val_acc > best_acc: 135 | best_acc = val_acc 136 | save_path = os.path.join(config["exp"]["save_dir"], "best.pth") 137 | save_model(epoch, val_acc, save_path, net, optimizer, schedulers["scheduler"], log_file) 138 | 139 | ########################### 140 | # training complete 141 | ########################### 142 | 143 | val_acc, avg_val_loss = evaluate(net, criterion, valloader, device) 144 | log_dict = {"epoch": epoch, "val_loss": avg_val_loss, "val_acc": val_acc} 145 | log(log_dict, step, config) 146 | 147 | # save final ckpt 148 | save_path = os.path.join(config["exp"]["save_dir"], "last.pth") 149 | 
save_model(epoch, val_acc, save_path, net, optimizer, schedulers["scheduler"], log_file) -------------------------------------------------------------------------------- /window_inference.py: -------------------------------------------------------------------------------- 1 | """Runs inference on clips much longer than 1s, by running a sliding window and aggregating predictions.""" 2 | 3 | 4 | from argparse import ArgumentParser 5 | from config_parser import get_config 6 | import torch 7 | import numpy as np 8 | import librosa 9 | from utils.misc import get_model 10 | 11 | from tqdm import tqdm 12 | import os 13 | import glob 14 | import json 15 | 16 | 17 | def process_window(x, sr, audio_settings): 18 | x = librosa.util.fix_length(x, sr) 19 | x = librosa.feature.melspectrogram(y=x, **audio_settings) 20 | x = librosa.feature.mfcc(S=librosa.power_to_db(x), n_mfcc=audio_settings["n_mels"]) 21 | return x 22 | 23 | @torch.no_grad() 24 | def get_clip_pred(net, audio_path, win_len, stride, thresh, config, batch_size, device, mode, label_map) -> list: 25 | """Performs clip-level inference.""" 26 | 27 | net.eval() 28 | preds_list = [] 29 | 30 | audio_settings = config["hparams"]["audio"] 31 | sr = audio_settings["sr"] 32 | win_len, stride = int(win_len * sr), int(stride * sr) 33 | x = librosa.load(audio_path, sr)[0] 34 | 35 | windows, result = [], [] 36 | 37 | slice_positions = np.arange(0, len(x) - win_len + 1, stride) 38 | 39 | for b, i in enumerate(slice_positions): 40 | windows.append( 41 | process_window(x[i: i + win_len], sr, audio_settings) 42 | ) 43 | 44 | if (not (b + 1) % batch_size) or (b + 1) == len(slice_positions): 45 | windows = torch.from_numpy(np.stack(windows)).float().unsqueeze(1) 46 | windows = windows.to(device) 47 | out = net(windows) 48 | conf, preds = out.softmax(1).max(1) 49 | conf, preds = conf.cpu().numpy().reshape(-1, 1), preds.cpu().numpy().reshape(-1, 1) 50 | 51 | starts = slice_positions[b - preds.shape[0] + 1: b + 1, None] 52 | ends = starts + win_len 53 | 54 | res = np.hstack([preds, conf, starts, ends]) 55 | res = res[res[:, 1] > thresh].tolist() 56 | if len(res): 57 | result.extend(res) 58 | windows = [] 59 | 60 | ####################### 61 | # pred aggregation 62 | ####################### 63 | pred = [] 64 | if len(result): 65 | result = np.array(result) 66 | 67 | if mode == "max": 68 | pred = result[result[:, 1].argmax()][0] 69 | if label_map is not None: 70 | pred = label_map[str(int(pred))] 71 | elif mode == "n_voting": 72 | pred = np.bincount(result[:, 0].astype(int)).argmax() 73 | if label_map is not None: 74 | pred = label_map[str(int(pred))] 75 | elif mode == "multi": 76 | if label_map is not None: 77 | pred = list(map(lambda a: [label_map[str(int(a[0]))], a[1], a[2], a[3]], result)) 78 | else: 79 | pred = result.tolist() 80 | 81 | return pred 82 | 83 | 84 | def main(args): 85 | ###################### 86 | # create model 87 | ###################### 88 | config = get_config(args.conf) 89 | model = get_model(config["hparams"]["model"]) 90 | 91 | ###################### 92 | # load weights 93 | ###################### 94 | ckpt = torch.load(args.ckpt, map_location="cpu") 95 | model.load_state_dict(ckpt["model_state_dict"]) 96 | 97 | ###################### 98 | # setup data 99 | ###################### 100 | if os.path.isdir(args.inp): 101 | data_list = glob.glob(os.path.join(args.inp, "*.wav")) 102 | elif os.path.isfile(args.inp): 103 | data_list = [args.inp] 104 | 105 | ###################### 106 | # run inference 107 | ###################### 108 | if 
args.device == "auto": 109 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 110 | else: 111 | device = torch.device(args.device) 112 | 113 | model = model.to(device) 114 | 115 | label_map = None 116 | if args.lmap: 117 | with open(args.lmap, "r") as f: 118 | label_map = json.load(f) 119 | 120 | pred_dict = dict() 121 | for file_path in data_list: 122 | preds = get_clip_pred(model, file_path, args.wlen, args.stride, args.thresh, config, args.batch_size, device, args.mode, label_map) 123 | pred_dict[file_path] = preds 124 | 125 | os.makedirs(args.out, exist_ok=True) 126 | out_path = os.path.join(args.out, "preds_clip.json") 127 | 128 | with open(out_path, "w+") as f: 129 | json.dump(pred_dict, f) 130 | 131 | print(f"Saved preds to {out_path}") 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = ArgumentParser() 136 | parser.add_argument("--conf", type=str, required=True, help="Path to config file. Will be used only to construct model and process audio.") 137 | parser.add_argument("--ckpt", type=str, required=True, help="Path to checkpoint file.") 138 | parser.add_argument("--inp", type=str, required=True, help="Path to input. Can be a path to a .wav file, or a path to a folder containing .wav files.") 139 | parser.add_argument("--out", type=str, default="./", help="Path to output folder. Predictions will be stored in {out}/preds_clip.json.") 140 | parser.add_argument("--lmap", type=str, default=None, help="Path to label_map.json. If not provided, will save predictions as class indices instead of class names.") 141 | parser.add_argument("--device", type=str, default="auto", help="One of auto, cpu, or cuda.") 142 | parser.add_argument("--batch_size", type=int, default=1, help="Batch size for batch inference.") 143 | parser.add_argument("--wlen", type=float, default=1.0, help="Window length. E.g. for wlen = 1, will make inference on 1s windows from the clip.") 144 | parser.add_argument("--stride", type=float, default=0.2, help="By how much the sliding window will be shifted.") 145 | parser.add_argument("--thresh", type=float, default=0.85, help="Confidence threshold above which preds will be counted.") 146 | parser.add_argument("--mode", type=str, default="multi", help="""Prediction logic. One of: max, n_voting, multi. 147 | -'max' simply checks the confidences of every predicted window in a clip and returns the most confident prediction as the output. 148 | -'n_voting' returns the most frequent predicted class above the threshold. 149 | -'multi' expects that there are multiple different keyword classes in the audio. For each audio, the output is a list of lists, 150 | each sub-list being of the form [class, confidence, start, end].""") 151 | 152 | args = parser.parse_args() 153 | 154 | assert os.path.exists(args.inp), f"Could not find input {args.inp}" 155 | 156 | main(args) --------------------------------------------------------------------------------
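In the `multi` mode above, `window_inference.py` stores one `[class, confidence, start, end]` entry per accepted window, with `start` and `end` expressed as sample offsets into the clip. Below is a minimal post-processing sketch (not part of the repository) for turning the saved `preds_clip.json` into second-based timestamps; the default `sr=16000` is an assumption matching the Speech Commands V2 sampling rate, and in practice it should be taken from `hparams.audio.sr` in the config used for inference.

```python
import json


def timestamps_from_preds(preds_path: str, sr: int = 16000) -> dict:
    """Converts the sample offsets saved by window_inference.py (--mode multi)
    into second-based timestamps. Assumes sr matches the inference config."""
    with open(preds_path, "r") as f:
        pred_dict = json.load(f)  # {audio_path: [[keyword, confidence, start, end], ...]}

    readable = {}
    for audio_path, detections in pred_dict.items():
        readable[audio_path] = [
            {
                "keyword": kw,            # class name if --lmap was given, else a class index
                "confidence": float(conf),
                "start_s": start / sr,    # sample offset -> seconds
                "end_s": end / sr,
            }
            for kw, conf, start, end in detections
        ]
    return readable


if __name__ == "__main__":
    # Hypothetical usage: a preds_clip.json produced by the script above, in the current directory.
    print(json.dumps(timestamps_from_preds("preds_clip.json"), indent=2))
```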