├── LICENSE ├── README.md ├── __init__.py ├── apply.py ├── bc_resnet_model.py ├── example_model └── model-sc-2.pt ├── get_data.py ├── main.py ├── requiements.txt ├── subspectral_norm.py ├── train.py └── util.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2022 re9ulus 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BC-ResNet for Keyword Spotting 2 | Unofficial implementation of [Broadcasted Residual Learning for Efficient Keyword Spotting](https://arxiv.org/abs/2106.04140) 3 | 4 | # TODO: 5 | - add specaug to train 6 | 7 | ### Usage 8 | Train 9 | ``` 10 | ; train scaled 2 times model for 50 epochs and save best checkpoint to model-sc-2.pt 11 | python main.py train --scale 2 --epoch 50 --checkpoint-file model-sc-2.pt 12 | 13 | ; Device: cuda 14 | ; Use subspectral norm: True 15 | ; --- start epoch 0 --- 16 | ; Train Epoch: 0 Loss: 3.6272 17 | ; Train Epoch: 0 Loss: 1.6613 18 | ; ... 19 | ; Train Epoch: 49 Loss: 0.3026 20 | ; Validation accuracy: 0.9626289950906722 21 | ; Top validation accuracy: 0.9628293758140467 22 | ; Test accuracy: 0.9604725124943208 23 | ``` 24 | 25 | Test 26 | ``` 27 | ; test saved model on test dataset 28 | python main.py test --scale 2 --model-file model-sc-2.pt 29 | 30 | ; Test accuracy: 0.9604725124943208 31 | ``` 32 | 33 | Apply 34 | ``` 35 | ; apply saved model to wav file 36 | python main.py apply --scale 2 --model-file model-sc-2.pt --wav-file SpeechCommands/speech_commands_v0.02/seven/5744b6a7_nohash_0.wav 37 | 38 | seven 0.99977 39 | six 0.00011 40 | stop 0.00008 41 | happy 0.00002 42 | up 0.00000 43 | ``` 44 | 45 | You can find pretrained `model-sc-2.pt` model in `example_model` folder. 46 | 47 | ### Options and help 48 | Use help 49 | ``` 50 | python main.py --help 51 | python main.py train --help 52 | python main.py test --help 53 | python main.py apply --help 54 | ``` 55 | 56 | This implementation use all 35 labels from Google Speech Commands Dataset. Original paper use 10 commands and additional re-balanced "Unknown word" and "Silence" labels (section 4.1 in paper). 57 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/re9ulus/BC-ResNet/086ce358885b4b4521b2763c4563f9e681412e74/__init__.py -------------------------------------------------------------------------------- /apply.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import get_data 3 | import numpy as np 4 | import torchaudio 5 | 6 | 7 | def number_of_correct(pred, target): 8 | return pred.squeeze().eq(target).sum().item() 9 | 10 | 11 | def get_likely_index(tensor): 12 | return tensor.argmax(dim=-1) 13 | 14 | 15 | def compute_accuracy(model, data_loader, device): 16 | model.eval() 17 | correct = 0 18 | for data, target in data_loader: 19 | data = data.to(device) 20 | target = target.to(device) 21 | 22 | pred = model(data) 23 | pred = get_likely_index(pred) 24 | 25 | correct += number_of_correct(pred, target) 26 | 27 | score = correct / len(data_loader.dataset) 28 | return score 29 | 30 | 31 | def apply_to_wav(model, waveform: torch.Tensor, sample_rate: float, device: str): 32 | model.eval() 33 | mel_spec = get_data.prepare_wav(waveform, sample_rate) 34 | mel_spec = torch.unsqueeze(mel_spec, dim=0).to(device) 35 | res = model(mel_spec) 36 | 37 | probs = torch.nn.Softmax(dim=-1)(res).cpu().detach().numpy() 38 | predictions = [] 39 | for idx in np.argsort(-probs): 40 | label = get_data.idx_to_label(idx) 41 | predictions.append((label, probs[idx])) 42 | return predictions 43 | 44 | 45 | def apply_to_file(model, wav_file: str, device: str): 46 | waveform, sample_rate = torchaudio.load(wav_file) 47 | return apply_to_wav(model, waveform, sample_rate, device) 48 | -------------------------------------------------------------------------------- /bc_resnet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from subspectral_norm import SubSpectralNorm 5 | 6 | 7 | DROPOUT = 0.1 8 | 9 | 10 | class NormalBlock(nn.Module): 11 | def __init__(self, n_chan: int, *, dilation: int = 1, dropout: float = DROPOUT, use_subspectral: bool = True): 12 | super().__init__() 13 | norm_layer = SubSpectralNorm(n_chan, 5) if use_subspectral else nn.BatchNorm2d(n_chan) 14 | self.f2 = nn.Sequential( 15 | nn.Conv2d(n_chan, n_chan, kernel_size=(3, 1), padding="same", groups=n_chan), 16 | norm_layer, 17 | ) 18 | self.f1 = nn.Sequential( 19 | nn.Conv2d(n_chan, n_chan, kernel_size=(1, 3), padding="same", groups=n_chan, dilation=(1, dilation)), 20 | nn.BatchNorm2d(n_chan), 21 | nn.SiLU(), 22 | nn.Conv2d(n_chan, n_chan, kernel_size=1), 23 | nn.Dropout2d(dropout) 24 | ) 25 | self.activation = nn.ReLU() 26 | 27 | def forward(self, x): 28 | n_freq = x.shape[2] 29 | x1 = self.f2(x) 30 | 31 | x2 = torch.mean(x1, dim=2, keepdim=True) 32 | x2 = self.f1(x2) 33 | x2 = x2.repeat(1, 1, n_freq, 1) 34 | 35 | return self.activation(x + x1 + x2) 36 | 37 | 38 | class TransitionBlock(nn.Module): 39 | def __init__(self, in_chan: int, out_chan: int, *, dilation: int = 1, stride: int = 1, dropout: float = DROPOUT, use_subspectral: bool = True): 40 | super().__init__() 41 | 42 | if stride == 1: 43 | conv = nn.Conv2d(out_chan, out_chan, kernel_size=(3, 1), groups=out_chan, padding="same") 44 | else: 45 | conv = nn.Conv2d(out_chan, out_chan, kernel_size=(3, 1), stride=(stride, 1), groups=out_chan, padding=(1, 0)) 46 | 47 | norm_layer = SubSpectralNorm(out_chan, 5) if use_subspectral else nn.BatchNorm2d(out_chan) 48 | self.f2 = nn.Sequential( 49 | nn.Conv2d(in_chan, out_chan, kernel_size=(1, 1)), 50 | nn.BatchNorm2d(out_chan), 51 | nn.ReLU(), 52 | conv, 53 | norm_layer, 54 | ) 55 | 56 | self.f1 = nn.Sequential( 57 | nn.Conv2d(out_chan, out_chan, kernel_size=(1, 3), padding="same", groups=out_chan, dilation=(1, dilation)), 58 | nn.BatchNorm2d(out_chan), 59 | nn.SiLU(), 60 | nn.Conv2d(out_chan, out_chan, kernel_size=1), 61 | nn.Dropout2d(dropout) 62 | ) 63 | 64 | self.activation = nn.ReLU() 65 | 66 | def forward(self, x: torch.Tensor): 67 | x = self.f2(x) 68 | n_freq = x.shape[2] 69 | x1 = torch.mean(x, dim=2, keepdim=True) 70 | x1 = self.f1(x1) 71 | x1 = x1.repeat(1, 1, n_freq, 1) 72 | 73 | return self.activation(x + x1) 74 | 75 | 76 | class BcResNetModel(nn.Module): 77 | def __init__(self, n_class: int = 35, *, scale: int = 1, dropout: float = DROPOUT, use_subspectral: bool = True): 78 | super().__init__() 79 | 80 | self.input_conv = nn.Conv2d(1, 16*scale, kernel_size=(5, 5), stride=(2, 1), padding=2) 81 | 82 | self.t1 = TransitionBlock(16*scale, 8*scale, dropout=dropout, use_subspectral=use_subspectral) 83 | self.n11 = NormalBlock(8*scale, dropout=dropout, use_subspectral=use_subspectral) 84 | 85 | self.t2 = TransitionBlock(8*scale, 12*scale, dilation=2, stride=2, dropout=dropout, use_subspectral=use_subspectral) 86 | self.n21 = NormalBlock(12*scale, dilation=2, dropout=dropout, use_subspectral=use_subspectral) 87 | 88 | self.t3 = TransitionBlock(12*scale, 16*scale, dilation=4, stride=2, dropout=dropout, use_subspectral=use_subspectral) 89 | self.n31 = NormalBlock(16*scale, dilation=4, dropout=dropout, use_subspectral=use_subspectral) 90 | self.n32 = NormalBlock(16*scale, dilation=4, dropout=dropout, use_subspectral=use_subspectral) 91 | self.n33 = NormalBlock(16*scale, dilation=4, dropout=dropout, use_subspectral=use_subspectral) 92 | 93 | self.t4 = TransitionBlock(16*scale, 20*scale, dilation=8, dropout=dropout, use_subspectral=use_subspectral) 94 | self.n41 = NormalBlock(20*scale, dilation=8, dropout=dropout, use_subspectral=use_subspectral) 95 | self.n42 = NormalBlock(20*scale, dilation=8, dropout=dropout, use_subspectral=use_subspectral) 96 | self.n43 = NormalBlock(20*scale, dilation=8, dropout=dropout, use_subspectral=use_subspectral) 97 | 98 | self.dw_conv = nn.Conv2d(20*scale, 20*scale, kernel_size=(5, 5), groups=20) 99 | self.onexone_conv = nn.Conv2d(20*scale, 32*scale, kernel_size=1) 100 | 101 | self.head_conv = nn.Conv2d(32*scale, n_class, kernel_size=1) 102 | 103 | def forward(self, x: torch.Tensor): 104 | x = self.input_conv(x) 105 | x = self.t1(x) 106 | x = self.n11(x) 107 | 108 | x = self.t2(x) 109 | x = self.n21(x) 110 | 111 | x = self.t3(x) 112 | x = self.n31(x) 113 | x = self.n32(x) 114 | x = self.n33(x) 115 | 116 | x = self.t4(x) 117 | x = self.n41(x) 118 | x = self.n42(x) 119 | x = self.n43(x) 120 | 121 | x = self.dw_conv(x) 122 | x = self.onexone_conv(x) 123 | 124 | x = torch.mean(x, dim=3, keepdim=True) 125 | x = self.head_conv(x) 126 | 127 | x = x.squeeze() 128 | 129 | return F.log_softmax(x, dim=-1) 130 | -------------------------------------------------------------------------------- /example_model/model-sc-2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/re9ulus/BC-ResNet/086ce358885b4b4521b2763c4563f9e681412e74/example_model/model-sc-2.pt -------------------------------------------------------------------------------- /get_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | import torch 5 | import torchaudio 6 | from torchaudio import transforms 7 | from torchaudio.datasets import SPEECHCOMMANDS 8 | 9 | 10 | EPS = 1e-9 11 | SAMPLE_RATE = 16000 12 | 13 | 14 | # default labels from GSC dataset 15 | DEFAULT_LABELS = [ 16 | 'backward', 17 | 'bed', 18 | 'bird', 19 | 'cat', 20 | 'dog', 21 | 'down', 22 | 'eight', 23 | 'five', 24 | 'follow', 25 | 'forward', 26 | 'four', 27 | 'go', 28 | 'happy', 29 | 'house', 30 | 'learn', 31 | 'left', 32 | 'marvin', 33 | 'nine', 34 | 'no', 35 | 'off', 36 | 'on', 37 | 'one', 38 | 'right', 39 | 'seven', 40 | 'sheila', 41 | 'six', 42 | 'stop', 43 | 'three', 44 | 'tree', 45 | 'two', 46 | 'up', 47 | 'visual', 48 | 'wow', 49 | 'yes', 50 | 'zero' 51 | ] 52 | 53 | N_CLASS = len(DEFAULT_LABELS) 54 | 55 | 56 | def prepare_wav(waveform, sample_rate): 57 | if sample_rate != SAMPLE_RATE: 58 | resampler = transforms.Resample(orig_freq=sample_rate, new_freq=SAMPLE_RATE) 59 | waveform = resampler(waveform) 60 | to_mel = transforms.MelSpectrogram(sample_rate=SAMPLE_RATE, n_fft=1024, f_max=8000, n_mels=40) 61 | log_mel = (to_mel(waveform) + EPS).log2() 62 | return log_mel 63 | 64 | 65 | class SubsetSC(SPEECHCOMMANDS): 66 | def __init__(self, subset: str, path="./"): 67 | super().__init__(path, download=True) 68 | self.to_mel = transforms.MelSpectrogram(sample_rate=SAMPLE_RATE, n_fft=1024, f_max=8000, n_mels=40) 69 | self.subset = subset 70 | 71 | def load_list(filename): 72 | filepath = os.path.join(self._path, filename) 73 | with open(filepath) as fh: 74 | return [ 75 | os.path.join(self._path, line.strip()) for line in fh 76 | ] 77 | 78 | self._noise = [] 79 | 80 | if subset == "validation": 81 | self._walker = load_list("validation_list.txt") 82 | elif subset == "testing": 83 | self._walker = load_list("testing_list.txt") 84 | elif subset == "training": 85 | excludes = load_list("validation_list.txt") + load_list("testing_list.txt") 86 | excludes = set(excludes) 87 | self._walker = [w for w in self._walker if w not in excludes] 88 | 89 | noise_paths = [w for w in os.listdir(os.path.join(self._path, "_background_noise_")) if w.endswith(".wav")] 90 | for item in noise_paths: 91 | noise_path = os.path.join(self._path, "_background_noise_", item) 92 | noise_waveform, noise_sr = torchaudio.sox_effects.apply_effects_file(noise_path, effects=[]) 93 | noise_waveform = transforms.Resample(orig_freq=noise_sr, new_freq=SAMPLE_RATE)(noise_waveform) 94 | self._noise.append(noise_waveform) 95 | else: 96 | raise ValueError(f"Unknown subset {subset}. Use validation/testing/training") 97 | 98 | def _noise_augment(self, waveform): 99 | noise_waveform = random.choice(self._noise) 100 | 101 | noise_sample_start = 0 102 | if noise_waveform.shape[1] - waveform.shape[1] > 0: 103 | noise_sample_start = random.randint(0, noise_waveform.shape[1] - waveform.shape[1]) 104 | noise_waveform = noise_waveform[:, noise_sample_start:noise_sample_start+waveform.shape[1]] 105 | 106 | signal_power = waveform.norm(p=2) 107 | noise_power = noise_waveform.norm(p=2) 108 | 109 | snr_dbs = [20, 10, 3] 110 | snr = random.choice(snr_dbs) 111 | 112 | snr = math.exp(snr / 10) 113 | scale = snr * noise_power / signal_power 114 | noisy_signal = (scale * waveform + noise_waveform) / 2 115 | return noisy_signal 116 | 117 | def _shift_augment(self, waveform): 118 | shift = random.randint(-1600, 1600) 119 | waveform = torch.roll(waveform, shift) 120 | if shift > 0: 121 | waveform[0][:shift] = 0 122 | elif shift < 0: 123 | waveform[0][shift:] = 0 124 | return waveform 125 | 126 | def _augment(self, waveform): 127 | if random.random() < 0.8: 128 | waveform = self._noise_augment(waveform) 129 | 130 | waveform = self._shift_augment(waveform) 131 | 132 | return waveform 133 | 134 | def __getitem__(self, n): 135 | waveform, sample_rate, label, _, _ = super().__getitem__(n) 136 | if sample_rate != SAMPLE_RATE: 137 | resampler = transforms.Resample(orig_freq=sample_rate, new_freq=SAMPLE_RATE) 138 | waveform = resampler(waveform) 139 | if self.subset == "training": 140 | waveform = self._augment(waveform) 141 | log_mel = (self.to_mel(waveform) + EPS).log2() 142 | 143 | return log_mel, label 144 | 145 | 146 | _label_to_idx = {label: i for i, label in enumerate(DEFAULT_LABELS)} 147 | _idx_to_label = {i: label for label, i in _label_to_idx.items()} 148 | 149 | 150 | def label_to_idx(label): 151 | return _label_to_idx[label] 152 | 153 | 154 | def idx_to_label(idx): 155 | return _idx_to_label[idx] 156 | 157 | 158 | def pad_sequence(batch): 159 | batch = [item.permute(2, 1, 0) for item in batch] 160 | batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True) 161 | return batch.permute(0, 3, 2, 1) 162 | 163 | 164 | def collate_fn(batch): 165 | tensors, targets = [], [] 166 | for log_mel, label in batch: 167 | tensors.append(log_mel) 168 | targets.append(label_to_idx(label)) 169 | 170 | tensors = pad_sequence(tensors) 171 | targets = torch.LongTensor(targets) 172 | 173 | return tensors, targets 174 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import click 4 | import torch 5 | import torch.utils.data 6 | 7 | import bc_resnet_model 8 | import get_data 9 | import train 10 | import apply 11 | import util 12 | 13 | 14 | def run(model, train_loader, validation_loader, test_loader, optimizer, scheduler, device, checkpoint_file, n_epoch=10, log_interval=100): 15 | best_score = 0 16 | best_model = copy.deepcopy(model) 17 | for epoch in range(n_epoch): 18 | print(f"--- start epoch {epoch} ---") 19 | train.train_epoch(model, optimizer, train_loader, device, epoch, log_interval=log_interval) 20 | if scheduler: 21 | scheduler.step() 22 | score = apply.compute_accuracy(model, validation_loader, device) 23 | print(f"Validation accuracy: {score:.5f}") 24 | if best_score < score: 25 | best_score = score 26 | best_model = copy.deepcopy(model) 27 | torch.save(best_model.state_dict(), checkpoint_file) 28 | print(f"Top validation accuracy: {best_score:.5f}") 29 | test_score = apply.compute_accuracy(best_model, test_loader, device) 30 | print(f"Test accuracy: {test_score:.5f}") 31 | 32 | 33 | @click.group(help="Train and apply BC-ResNet Keyword Spotting Model") 34 | def cli(): 35 | pass 36 | 37 | 38 | @cli.command("train", help="Train model") 39 | @click.option("--scale", type=int, default=1, help="model width will be multiplied by scale") 40 | @click.option("--batch-size", type=int, default=256, help="batch size") 41 | @click.option("--device", type=str, default=util.get_device(), help="`cuda` or `cpu`") 42 | @click.option("--epoch", type=int, default=10, help="number of epochs to train") 43 | @click.option("--log-interval", type=int, default=100, help="display train loss after every `log-interval` batch") 44 | @click.option("--checkpoint-file", type=str, default="model.torch", help="file to save model checkpoint") 45 | @click.option("--optimizer", type=str, default="adam", help="optimizer adam/sgd") 46 | @click.option("--dropout", type=float, default=0.1, help="dropout") 47 | @click.option("--subspectral-norm/--dropout-norm", type=bool, default=True, help="use SubspectralNorm or Dropout") 48 | def train_command(scale, batch_size, device, epoch, log_interval, checkpoint_file, optimizer, dropout, subspectral_norm): 49 | if os.path.exists(checkpoint_file): 50 | raise FileExistsError(f"{checkpoint_file} already exists") 51 | 52 | if device == "cuda": 53 | num_workers = 1 54 | pin_memory = True 55 | else: 56 | num_workers = 0 57 | pin_memory = False 58 | 59 | print(f"Device: {device}") 60 | print(f"Use subspectral norm: {subspectral_norm}") 61 | 62 | model = bc_resnet_model.BcResNetModel( 63 | n_class=get_data.N_CLASS, 64 | scale=scale, 65 | dropout=dropout, 66 | use_subspectral=subspectral_norm, 67 | ).to(device) 68 | 69 | train_loader = torch.utils.data.DataLoader( 70 | get_data.SubsetSC(subset="training"), 71 | batch_size=batch_size, 72 | shuffle=True, 73 | collate_fn=get_data.collate_fn, 74 | num_workers=num_workers, 75 | pin_memory=pin_memory 76 | ) 77 | validation_loader = torch.utils.data.DataLoader( 78 | get_data.SubsetSC(subset="validation"), 79 | batch_size=batch_size, 80 | shuffle=False, 81 | drop_last=False, 82 | collate_fn=get_data.collate_fn, 83 | num_workers=num_workers, 84 | pin_memory=pin_memory 85 | ) 86 | test_loader = torch.utils.data.DataLoader( 87 | get_data.SubsetSC(subset="testing"), 88 | batch_size=batch_size, 89 | shuffle=False, 90 | drop_last=False, 91 | collate_fn=get_data.collate_fn, 92 | num_workers=num_workers, 93 | pin_memory=pin_memory 94 | ) 95 | 96 | if optimizer == "adam": 97 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001) 98 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1) 99 | elif optimizer == "sgd": 100 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001) 101 | scheduler = None 102 | else: 103 | raise ValueError(f"Unknown optimizer {optimizer}, use adam/sgd") 104 | 105 | run( 106 | model, 107 | train_loader, 108 | validation_loader, 109 | test_loader, 110 | optimizer, 111 | scheduler, 112 | device, 113 | checkpoint_file, 114 | n_epoch=epoch, 115 | log_interval=log_interval 116 | ) 117 | 118 | 119 | @cli.command("test", help="Test model accuracy on test set") 120 | @click.option("--model-file", type=str, help="path to model weights") 121 | @click.option("--scale", type=int, default=1, help="model width will be multiplied by scale") 122 | @click.option("--batch-size", type=int, default=256, help="batch size") 123 | @click.option("--device", type=str, default=util.get_device(), help="`cuda` or `cpu`") 124 | @click.option("--dropout", type=float, default=0.1, help="dropout") 125 | @click.option("--subspectral-norm/--dropout-norm", type=bool, default=True, help="use SubspectralNorm or Dropout") 126 | def test_command(model_file, scale, batch_size, device, dropout, subspectral_norm): 127 | if not os.path.exists(model_file): 128 | raise FileExistsError(f"model {model_file} not exists") 129 | 130 | if device == "cuda": 131 | num_workers = 1 132 | pin_memory = True 133 | else: 134 | num_workers = 0 135 | pin_memory = False 136 | 137 | print(f"Device: {device}") 138 | print(f"Use subspectral norm: {subspectral_norm}") 139 | 140 | model = bc_resnet_model.BcResNetModel( 141 | n_class=get_data.N_CLASS, 142 | scale=scale, 143 | dropout=dropout, 144 | use_subspectral=subspectral_norm, 145 | ).to(device) 146 | model.load_state_dict(torch.load(model_file)) 147 | 148 | test_loader = torch.utils.data.DataLoader( 149 | get_data.SubsetSC(subset="testing"), 150 | batch_size=batch_size, 151 | shuffle=False, 152 | drop_last=False, 153 | collate_fn=get_data.collate_fn, 154 | num_workers=num_workers, 155 | pin_memory=pin_memory 156 | ) 157 | test_score = apply.apply(model, test_loader, device) 158 | print(f"Test accuracy: {test_score}") 159 | 160 | 161 | @cli.command("apply", help="Apply model to wav file") 162 | @click.option("--model-file", type=str, help="path to model weights") 163 | @click.option("--wav-file", type=str, help="path to wav sound file") 164 | @click.option("--scale", type=int, default=1, help="model width will be multiplied by scale") 165 | @click.option("--device", type=str, default=util.get_device(), help="`cuda` or `cpu`") 166 | @click.option("--dropout", type=float, default=0.1, help="dropout") 167 | @click.option("--subspectral-norm/--dropout-norm", type=bool, default=True, help="use SubspectralNorm or Dropout") 168 | def apply_command(model_file, wav_file, scale, device, dropout, subspectral_norm): 169 | if not os.path.exists(model_file): 170 | raise FileExistsError(f"model file {model_file} not exists") 171 | if not os.path.exists(wav_file): 172 | raise FileExistsError(f"sound file {wav_file} not exists") 173 | 174 | model = bc_resnet_model.BcResNetModel( 175 | n_class=get_data.N_CLASS, 176 | scale=scale, 177 | dropout=dropout, 178 | use_subspectral=subspectral_norm, 179 | ).to(device) 180 | model.load_state_dict(torch.load(model_file)) 181 | model.eval() 182 | 183 | predictions = apply.apply_to_file(model, wav_file, device) 184 | for label, prob in predictions[:5]: 185 | print(f"{label}\t{prob:.5f}") 186 | 187 | 188 | if __name__ == "__main__": 189 | cli() 190 | -------------------------------------------------------------------------------- /requiements.txt: -------------------------------------------------------------------------------- 1 | torch==1.9.0+cu111 2 | torchaudio==0.9.0 3 | numpy==1.21.1 4 | click==8.0.3 5 | -------------------------------------------------------------------------------- /subspectral_norm.py: -------------------------------------------------------------------------------- 1 | # https://arxiv.org/abs/2103.13620 2 | from torch import nn 3 | 4 | 5 | class SubSpectralNorm(nn.Module): 6 | def __init__(self, channels, sub_bands, eps=1e-5): 7 | super().__init__() 8 | self.sub_bands = sub_bands 9 | self.bn = nn.BatchNorm2d(channels*sub_bands, eps=eps) 10 | 11 | def forward(self, x): 12 | N, C, F, T = x.size() 13 | x = x.view(N, C * self.sub_bands, F // self.sub_bands, T) 14 | x = self.bn(x) 15 | return x.view(N, C, F, T) 16 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | 4 | def train_epoch(model, optimizer, train_loader, device, epoch, log_interval): 5 | model.train() 6 | 7 | losses = [] 8 | for batch_idx, (data, target) in enumerate(train_loader): 9 | data = data.to(device) 10 | 11 | target = target.to(device) 12 | output = model(data) 13 | loss = F.nll_loss(output.squeeze(), target) 14 | 15 | optimizer.zero_grad() 16 | loss.backward() 17 | optimizer.step() 18 | 19 | if batch_idx % log_interval == 0: 20 | print(f"Train Epoch: {epoch}\tLoss: {loss.item():.4f}") 21 | 22 | losses.append(loss.item()) 23 | 24 | return losses 25 | 26 | 27 | def train(n_epoch, model, optimizer, train_loader, device, log_interval): 28 | print(f"--- Start train {n_epoch} epoches") 29 | for epoch in range(n_epoch): 30 | print(f"--- Start epoch {epoch+1}") 31 | train_epoch(model, optimizer, train_loader, device, epoch, log_interval) 32 | print("--- Done train") 33 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_device(): 5 | return "cuda" if torch.cuda.is_available() else "cpu" 6 | --------------------------------------------------------------------------------