├── __init__.py ├── core ├── __init__.py ├── torch │ ├── __init__.py │ ├── graph_operation.py │ ├── ops.py │ ├── graph_wrapper.py │ └── base_classes.py ├── .DS_Store ├── registry.py ├── config.py └── logger.py ├── models ├── __init__.py ├── torch │ ├── __init__.py │ ├── utils.py │ ├── dfsmn.py │ └── bidfsmn.py └── .DS_Store ├── torch_utils ├── __init__.py ├── mixup.py └── distributed_utils.py ├── speech_commands ├── __init__.py ├── dataset │ ├── __init__.py │ ├── transform.py │ └── speech_commands.py └── .DS_Store ├── .DS_Store ├── overview.png ├── README.md ├── basic.py ├── train_speech_commands.py └── train_valid_test.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/torch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/torch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /speech_commands/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /speech_commands/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htqin/BiFSMN/HEAD/.DS_Store -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htqin/BiFSMN/HEAD/overview.png -------------------------------------------------------------------------------- /core/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htqin/BiFSMN/HEAD/core/.DS_Store -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htqin/BiFSMN/HEAD/models/.DS_Store -------------------------------------------------------------------------------- /speech_commands/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htqin/BiFSMN/HEAD/speech_commands/.DS_Store -------------------------------------------------------------------------------- /models/torch/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def 
weight_init(m, **kwargs): 5 | # weight initialization 6 | if isinstance(m, nn.Conv2d): 7 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 8 | if m.bias is not None: 9 | nn.init.zeros_(m.bias) 10 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 11 | nn.init.ones_(m.weight) 12 | nn.init.zeros_(m.bias) 13 | elif isinstance(m, nn.Linear): 14 | nn.init.normal_(m.weight, 0, 0.01) 15 | nn.init.zeros_(m.bias) 16 | -------------------------------------------------------------------------------- /core/torch/graph_operation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def replace_module(model: nn.Module, name: str, module: nn.Module): 6 | scopes = name.split('.') 7 | current_scope = '' 8 | current_module = model 9 | for scope in scopes[:-1]: 10 | current_scope += scope 11 | assert hasattr( 12 | current_module, 13 | scope), '{} scope is not in the model'.format(current_scope) 14 | current_module = getattr(current_module, scope) 15 | current_scope += scopes[-1] 16 | scope = scopes[-1] 17 | assert hasattr(current_module, 18 | scope), '{} scope is not in the model'.format(current_scope) 19 | current_module.add_module(scope, module) 20 | 21 | 22 | def find_module_by_name(model: nn.Module, name: str): 23 | for n, m in model.named_modules(): 24 | if n == name: 25 | return m 26 | -------------------------------------------------------------------------------- /core/registry.py: -------------------------------------------------------------------------------- 1 | class Registry(object): 2 | def __init__(self, name): 3 | self._name = name 4 | self._module_dict = dict() 5 | 6 | @property 7 | def name(self): 8 | return self._name 9 | 10 | @property 11 | def module_dict(self): 12 | return self._module_dict 13 | 14 | def module_names(self): 15 | return list(self._module_dict.keys()) 16 | 17 | def register(self, cls): 18 | module_name = cls.__name__ 19 | if module_name in self._module_dict: 20 | raise KeyError('{} is already registered in {}'.format( 21 | module_name, self.name)) 22 | self._module_dict[module_name] = cls 23 | return cls 24 | 25 | def get(self, name): 26 | if name in self._module_dict: 27 | return self._module_dict[name] 28 | else: 29 | raise KeyError('{} is not registered in {}'.format(name, self.name)) 30 | 31 | 32 | CONFIG = Registry('config') 33 | -------------------------------------------------------------------------------- /torch_utils/mixup.py: -------------------------------------------------------------------------------- 1 | # ref to: https://github.com/moskomule/mixup.pytorch 2 | 3 | from typing import Tuple 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def partial_mixup(input: torch.Tensor, gamma: float, 9 | indices: torch.Tensor) -> torch.Tensor: 10 | if input.size(0) != indices.size(0): 11 | raise RuntimeError("Size mismatch!") 12 | perm_input = input[indices] 13 | return input.mul(gamma).add(perm_input, alpha=1 - gamma) 14 | 15 | 16 | def mixup(input: torch.Tensor, target: torch.Tensor, gamma: float, 17 | num_classes) -> Tuple[torch.Tensor, torch.Tensor]: 18 | target = F.one_hot(target, num_classes) 19 | indices = torch.randperm(input.size(0), 20 | device=input.device, 21 | dtype=torch.long) 22 | return partial_mixup(input, gamma, 23 | indices), partial_mixup(target, gamma, indices) 24 | 25 | 26 | def naive_cross_entropy_loss(input: torch.Tensor, 27 | target: torch.Tensor) -> torch.Tensor: 28 | return -(input.log_softmax(dim=-1) * target).sum(dim=-1).mean() 29 | 
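For quick reference, below is a minimal usage sketch of how these mixup helpers could be combined in a single training step. It is illustrative only: the stand-in classifier, batch shapes, and fixed `gamma` value are assumptions and not part of this repository.

```python
# Hypothetical usage sketch of torch_utils/mixup.py (illustrative only).
import torch
from torch_utils.mixup import mixup, naive_cross_entropy_loss

model = torch.nn.Linear(40, 12)       # stand-in classifier
feats = torch.randn(8, 40)            # fake batch of 8 feature vectors
labels = torch.randint(0, 12, (8,))   # integer class labels

# gamma is the mixing coefficient; in practice it is commonly drawn from a
# Beta distribution each step rather than fixed.
inputs, targets = mixup(feats, labels, gamma=0.9, num_classes=12)
loss = naive_cross_entropy_loss(model(inputs), targets)
loss.backward()
```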
-------------------------------------------------------------------------------- /speech_commands/dataset/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ChangeAmplitude(object): 6 | """Changes amplitude of an audio randomly.""" 7 | def __init__(self, prop=0.5, amplitude_range=(0.7, 1.1)): 8 | self.amplitude_range = amplitude_range 9 | self.prop = prop 10 | 11 | def __call__(self, data: torch.Tensor): 12 | if random.uniform(0, 1) <= self.prop: 13 | data = data * random.uniform(*self.amplitude_range) 14 | return data 15 | 16 | 17 | class FixAudioLength(object): 18 | """Either pads or truncates an audio into a fixed length.""" 19 | def __init__(self, time=1, sample_rate=16000): 20 | self.target_len = time * sample_rate 21 | 22 | def __call__(self, data: torch.Tensor): 23 | cur_len = data.shape[1] 24 | if self.target_len <= cur_len: 25 | data = data[:, :self.target_len] 26 | else: 27 | data = torch.nn.functional.pad(data, (0, self.target_len - cur_len)) 28 | return data 29 | 30 | 31 | class ChangeSpeedAndPitchAudio(object): 32 | """Change the speed of an audio. This transform also changes the pitch of the audio.""" 33 | def __init__(self, prop=0.5, max_scale=0.2, sample_rate=16000): 34 | self.max_scale = max_scale 35 | self.sample_rate = sample_rate 36 | self.prop = prop 37 | 38 | def __call__(self, data): 39 | if random.uniform(0, 1) <= self.prop: 40 | scale = random.uniform(-self.max_scale, self.max_scale) 41 | speed_fac = 1.0 / (1 + scale) 42 | data = torch.nn.functional.interpolate(data.unsqueeze(1), 43 | scale_factor=speed_fac, 44 | mode='nearest').squeeze(1) 45 | return data 46 | 47 | 48 | class TimeshiftAudio(object): 49 | """Shifts an audio randomly.""" 50 | def __init__(self, prop=0.5, max_shift_seconds=0.2, sample_rate=16000): 51 | self.shift_len = max_shift_seconds * sample_rate 52 | self.prop = prop 53 | 54 | def __call__(self, data): 55 | if random.uniform(0, 1) <= self.prop: 56 | shift = random.randint(-self.shift_len, self.shift_len) 57 | a = -min(0, shift) 58 | b = max(0, shift) 59 | data = torch.nn.functional.pad(data, (a, b), "constant") 60 | data = data[:, :data.shape[1] - a] if a else data[:, b:] 61 | return data 62 | -------------------------------------------------------------------------------- /core/config.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from copy import deepcopy 3 | from .torch.ops import OpBase, OPS 4 | from .logger import Logger, Verbose, set_slim_logger_level 5 | 6 | set_slim_logger_level(Verbose.INFO) 7 | 8 | 9 | class ConfigBase(object): 10 | """ 11 | a base class for slim configuration 12 | """ 13 | def __init__(self, config, framework): 14 | super(ConfigBase, self).__init__() 15 | self.__key = list() 16 | self.skip_scope = [] 17 | self.work_scope = [] 18 | self.strict_scope = False 19 | self.type_config = {} 20 | self.layer_config = {} 21 | assert framework in [ 22 | 'torch', 'tf' 23 | ], "The `framework` should only be one of ['torch', 'tf']" 24 | self.framework = framework 25 | self.parse_valid(config) 26 | 27 | def set_type_config(self, op_type, config_op: OpBase): 28 | self.type_config[op_type] = deepcopy(config_op) 29 | 30 | def set_layer_config(self, layer_name, config_op: OpBase): 31 | self.layer_config[layer_name] = deepcopy(config_op) 32 | 33 | def get_type_config(self, op_type) -> OpBase: 34 | return deepcopy(self.type_config[op_type]) 35 | 36 | def 
get_layer_config(self, layer_name) -> OpBase: 37 | return deepcopy(self.layer_config[layer_name]) 38 | 39 | def __getitem__(self, item): 40 | if item in self.__key: 41 | return getattr(self, item) 42 | 43 | @abstractmethod 44 | def parse_valid(self, params: dict): 45 | raise NotImplementedError('parse_valid is not implemented for ', 46 | self.__class__.__name__) 47 | 48 | def NeedSkip(self, op_name, op_type, framework): 49 | need_skip = False 50 | if 'gradients' in op_name and framework == 'tf': 51 | return True 52 | if op_type not in self.type_config.keys(): 53 | need_skip = True 54 | 55 | if len(self.work_scope) > 0: 56 | if self.strict_scope: 57 | if op_name not in self.work_scope: 58 | need_skip = True 59 | else: 60 | hit = False 61 | for scope in self.work_scope: 62 | if op_name.startswith(scope): 63 | hit = True 64 | break 65 | if not hit: 66 | need_skip = True 67 | 68 | if self.strict_scope: 69 | if op_name in self.skip_scope: 70 | Logger(Verbose.INFO)('Strict Skip: {}'.format(op_name)) 71 | need_skip = True 72 | for scope in self.skip_scope: 73 | if op_name.startswith(scope): 74 | need_skip = True 75 | Logger(Verbose.INFO)('Skip: {}'.format(op_name)) 76 | break 77 | 78 | return need_skip 79 | 80 | 81 | class OpConfig(object): 82 | def __init__(self): 83 | raise NotImplementedError('It is not impolemented for ', 84 | self.__class__.__name__) 85 | 86 | def fill_with_default(self, OpConfig: dict, default_config: dict): 87 | for key in default_config.keys(): 88 | if key not in OpConfig: 89 | OpConfig[key] = default_config[key] 90 | for k, v in OpConfig.items(): 91 | self.__dict__[k] = v 92 | 93 | def parse_for_torch(self): 94 | pass 95 | 96 | def parse_for_tf(self): 97 | pass 98 | 99 | def __str__(self): 100 | return self.__dict__ 101 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # *BiFSMN: Binary Neural Network for Keyword Spotting* 2 | 3 | ***The code for [BiFSMNv2](https://github.com/htqin/BiFSMNv2) is [here](https://github.com/htqin/BiFSMNv2)! It can greatly improve the performance of BiFSMN.*** 4 | 5 | Created by [Haotong Qin](https://htqin.github.io/), Xudong Ma, [Yifu Ding](https://yifu-ding.github.io/), Xiaoyang Li, Yang Zhang, Yao Tian, Zejun Ma, Jie Luo, and [Xianglong Liu](https://xlliu-beihang.github.io/) from Beihang University and Bytedance AI Lab. 6 | 7 | ![loading-ag-172](./overview.png) 8 | 9 | ## Introduction 10 | 11 | This project is the official implementation of our accepted IJCAI 2022 paper *BiFSMN: Binary Neural Network for Keyword Spotting* [[PDF](https://www.ijcai.org/proceedings/2022/0603.pdf)]. The deep neural networks, such as the Deep-FSMN, have been widely studied for keyword spotting (KWS) applications. However, computational resources for these networks are significantly constrained since they usually run on-call on edge devices. In this paper, we present **BiFSMN**, an accurate and extreme-efficient binary neural network for KWS. We first construct a *High-frequency Enhancement Distillation* scheme for the binarization-aware training, which emphasizes the high-frequency information from the full-precision network's representation that is more crucial for the optimization of the binarized network. 
Then, to allow the instant and adaptive accuracy-efficiency trade-offs at runtime, we also propose a *Thinnable Binarization Architecture* to further liberate the acceleration potential of the binarized network from the topology perspective. Moreover, we implement a *Fast Bitwise Computation Kernel* for BiFSMN on ARMv8 devices which fully utilizes registers and increases instruction throughput to push the limit of deployment efficiency. Extensive experiments show that BiFSMN outperforms existing binarization methods by convincing margins on various datasets and is even comparable with the full-precision counterpart (e.g., less than 3% drop on Speech Commands V1-12). We highlight that benefiting from the thinnable architecture and the optimized 1-bit implementation, BiFSMN can achieve an impressive $22.3\times$ speedup and $15.5\times$ storage-saving on real-world edge hardware. 12 | 13 | ## Datasets and Pretrained Models 14 | 15 | We train and test BiFSMN on the Google Speech Commands V1 and V2 datasets, which can be downloaded as described in the reference documentation: 16 | 17 | - https://pytorch.org/audio/stable/_modules/torchaudio/datasets/speechcommands.html#SPEECHCOMMANDS 18 | 19 | We also release a pretrained model on the Speech Commands V1-12 task for our distillation. 20 | 21 | ## Execution 22 | 23 | Our experiments are based on the fine-tuned full-precision BiFSMN_pre, which can be found [here](?????????????????). A complete running script is provided as follows: 24 | 25 | ```shell 26 | python3 train_speech_commands.py \ 27 | --gpu=0 \ 28 | --model=BiDfsmn_thinnable --dfsmn_with_bn \ 29 | --method=Vanilla \ 30 | --distill \ 31 | --distill_alpha=0.01 \ 32 | --select_pass=high \ 33 | --J=1 \ 34 | --pretrained \ 35 | --teacher_model=BiDfsmn_thinnable_pre \ 36 | --teacher_model_checkpoint=${teacher_model_checkpoint_path} \ 37 | --version=speech_commands_v0.01 \ 38 | --num_classes=12 \ 39 | --lr-scheduler=cosin \ 40 | --opt=sgd \ 41 | --lr=5e-3 \ 42 | --weight-decay=1e-4 \ 43 | --epoch=300 \ 44 | 45 | ``` 46 | 47 | ## Citation 48 | 49 | If you find our work useful in your research, please consider citing: 50 | 51 | ```shell 52 | @inproceedings{qin2022bifsmn, 53 | title = {BiFSMN: Binary Neural Network for Keyword Spotting}, 54 | author = {Qin, Haotong and Ma, Xudong and Ding, Yifu and Li, Xiaoyang and Zhang, Yang and Tian, Yao and Ma, Zejun and Luo, Jie and Liu, Xianglong}, 55 | booktitle = {Proceedings of the Thirty-First International Joint Conference on 56 | Artificial Intelligence, {IJCAI-22}}, 57 | pages = {4346--4352}, 58 | year = {2022} 59 | } 60 | ``` 61 | -------------------------------------------------------------------------------- /core/logger.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from enum import IntEnum 3 | from enum import unique 4 | import inspect 5 | import sys 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | 9 | __all__ = ['Logger', 'Verbose', 'set_slim_logger_level'] 10 | 11 | 12 | @unique 13 | class Verbose(IntEnum): 14 | """ 15 | verbose enum 16 | """ 17 | DEBUG = 0 18 | INFO = 1 19 | WARNING = 2 20 | ERROR = 3 21 | FATAL = 4 22 | 23 | def describe(self): 24 | """ 25 | Usage: verbose.WARNING.describe() --> ('WARNING', 2) 26 | """ 27 | return self.name, self.value 28 | 29 | def __str__(self): 30 | """ 31 | Usage: str(verbose.WARNING) --> 'WARNING' 32 | """ 33 | return 'Target verbose is {0}'.format(self.name) 34 | 35 | @staticmethod 36 | def default_verbose(): 37 | return Verbose.INFO 38 | 
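# Illustrative usage note (an assumption, not part of this file): core/config.py
# drives this module roughly as follows --
#   set_slim_logger_level(Verbose.INFO)
#   Logger(Verbose.INFO)('Skip: {}'.format(op_name))   # printed, INFO >= current level
#   Logger(Verbose.DEBUG)('internal details')          # suppressed while the level is INFO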
39 | 40 | ''' 41 | verbose to msg dict. 42 | ''' 43 | VERBOSE_DICT = { 44 | 'DEBUG': 'DEBUG', 45 | 'INFO': 'INF', 46 | 'WARNING': 'WAR', 47 | 'ERROR': 'ERR', 48 | 'FATAL': 'FTL' 49 | } 50 | 51 | 52 | class ShellColors: 53 | """ 54 | shell color decorator class 55 | """ 56 | OKBLUE = '\033[94m' 57 | OKGREEN = '\033[92m' 58 | WARNING = '\033[93m' 59 | ERROR = '\033[91m' 60 | ENDC = '\033[0m' 61 | BOLD = '\033[1m' 62 | UNDERLINE = '\033[4m' 63 | 64 | def __init__(self, func): 65 | self.func = func 66 | 67 | def __call__(self, *args): 68 | if args[1] == Verbose.DEBUG: 69 | return self.func(args[0], args[1]) + self.ENDC 70 | elif args[1] == Verbose.INFO: 71 | return self.BOLD + self.func(args[0], args[1]) + self.ENDC 72 | elif args[1] == Verbose.WARNING: 73 | return self.WARNING + self.func(args[0], args[1]) + self.ENDC 74 | elif args[1] == Verbose.ERROR: 75 | return self.ERROR + self.func(args[0], args[1]) + self.ENDC 76 | elif args[1] == Verbose.FATAL: 77 | return self.ERROR + self.BOLD + self.UNDERLINE + self.func( 78 | args[0], args[1]) + self.ENDC 79 | else: 80 | raise NameError('ERROR: Logger not support verbose: %s' % 81 | (str(args[1]))) 82 | 83 | 84 | @ShellColors 85 | def with_color(header, verbose): 86 | return header 87 | 88 | 89 | logger_level = Verbose.INFO 90 | 91 | 92 | class Logger: 93 | """ 94 | Logger class. 95 | """ 96 | WithColor = True 97 | # lock mutex for writing log files by multi process. 98 | # lock = Lock() 99 | 100 | Prune = lambda filename: filename.split('/')[-1] 101 | 102 | def __init__(self, verbose=Verbose.default_verbose()): 103 | """ 104 | """ 105 | self.__verbose = verbose 106 | self.log_head = "" + VERBOSE_DICT[verbose.name] + " | " + str( 107 | datetime.now()) + " | " 108 | # 1 represents line at caller 109 | callerframerecord = inspect.stack()[1] 110 | frame = callerframerecord[0] 111 | info = inspect.getframeinfo(frame) 112 | Prune = lambda filename: filename.split('/')[-1] 113 | self.log_head += Prune(info.filename) + ":" + str( 114 | info.lineno) + " " + str(info.function) + "() ] " 115 | 116 | def __call__(self, *args, **kwargs): 117 | """ 118 | feed info to log engine. 
119 | """ 120 | msg = ''.join(str(i) for i in args) 121 | full_msg = "" 122 | try: 123 | full_msg = (with_color(self.log_head, self.__verbose) + " " + msg) \ 124 | if Logger.WithColor else (self.log_head + " " + msg) 125 | except NameError: 126 | raise 127 | self.log_to_everywhere(full_msg) 128 | 129 | def log_to_everywhere(self, full_msg): 130 | """ 131 | log to stdout and file 132 | """ 133 | if int(logger_level) <= int(self.__verbose): 134 | print(full_msg) 135 | 136 | # Logger.lock.acquire() 137 | # log to stdout 138 | 139 | sys.stdout.flush() 140 | # logger.lock.release() 141 | 142 | 143 | def set_slim_logger_level(level: Verbose): 144 | """ 145 | Args: 146 | level: set logger level for stdout, support: 'info', 'warning', 'error', 'fatal' 147 | 148 | Returns: 149 | 150 | """ 151 | global logger_level 152 | if int(level) <= int(Verbose.FATAL): 153 | logger_level = level 154 | else: 155 | raise RuntimeError('unsupported logging level: ', level) 156 | -------------------------------------------------------------------------------- /core/torch/ops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import abstractmethod 3 | from copy import deepcopy 4 | from byteslim.core.registry import Registry 5 | 6 | OPS = Registry('operators_for_slim') 7 | QAT_OPS = Registry('operators_support_QAT') 8 | Prune_OPS = Registry('operators_support_StructurePruning') 9 | SVD_OPS = Registry('operators_support_StructurePruning') 10 | 11 | 12 | class OpBase(object): 13 | """ 14 | a base class for op attributes and configuration 15 | """ 16 | def __init__(self): 17 | super(OpBase, self).__init__() 18 | 19 | def params(self): 20 | return self.params_name 21 | 22 | def quant_params(self): 23 | return [self.params_name[index] for index in self.quant_param_index] 24 | 25 | def quant_inputs(self): 26 | return [self.inputs_name[index] for index in self.quant_input_index] 27 | 28 | def process(self, module): 29 | # for weight_norm prune 30 | module_params_name = [name for name, _ in module.named_parameters()] 31 | for name in self.params_name: 32 | if name not in module_params_name and name + '_v' in module_params_name: 33 | self.prune_index[name + '_v'] = self.prune_index[name] 34 | self.prune_index.pop(name) 35 | 36 | 37 | # torch modules 38 | @OPS.register 39 | @QAT_OPS.register 40 | @Prune_OPS.register 41 | @SVD_OPS.register 42 | class Linear(OpBase): 43 | def __init__(self): 44 | super(Linear, self).__init__() 45 | self.inputs_name = ['input'] 46 | self.params_name = ['weight', 'bias'] 47 | self.quant_input_index = [0] 48 | self.quant_param_index = [0] 49 | self.quant_channel = {'params': [0]} 50 | self.svd_param_index = [0] 51 | self.prune_index = {'weight': 0} 52 | 53 | 54 | @OPS.register 55 | @QAT_OPS.register 56 | @Prune_OPS.register 57 | class Conv1d(OpBase): 58 | def __init__(self): 59 | super(Conv1d, self).__init__() 60 | self.inputs_name = ['input'] 61 | self.params_name = ['weight', 'bias'] 62 | self.quant_input_index = [0] 63 | self.quant_param_index = [0] 64 | self.quant_channel = {'params': [0]} 65 | self.prune_index = {'weight': 0} 66 | 67 | 68 | @OPS.register 69 | @QAT_OPS.register 70 | @Prune_OPS.register 71 | class Conv2d(OpBase): 72 | def __init__(self): 73 | super(Conv2d, self).__init__() 74 | self.inputs_name = ['input'] 75 | self.params_name = ['weight', 'bias'] 76 | self.quant_input_index = [0] 77 | self.quant_param_index = [0] 78 | self.quant_channel = {'params': [0]} 79 | self.prune_index = {'weight': 0} 80 | 81 | 82 
| @OPS.register 83 | @QAT_OPS.register 84 | @Prune_OPS.register 85 | class ConvTranspose1d(OpBase): 86 | def __init__(self): 87 | super(ConvTranspose1d, self).__init__() 88 | self.inputs_name = ['input'] 89 | self.params_name = ['weight', 'bias'] 90 | self.quant_input_index = [0] 91 | self.quant_param_index = [0] 92 | self.quant_channel = {'params': [0]} 93 | self.prune_index = {'weight': 1} 94 | 95 | 96 | @OPS.register 97 | @QAT_OPS.register 98 | @SVD_OPS.register 99 | class LSTM(OpBase): 100 | def __init__(self): 101 | super(LSTM, self).__init__() 102 | self.inputs_name = ['input'] 103 | self.params_name = [ 104 | 'weight_ih_l', 'weight_hh_l', 'bias_ih_l', 'bias_hh_l' 105 | ] 106 | self.quant_input_index = [0] 107 | self.quant_param_index = [0, 1] 108 | self.svd_param_index = [0, 1] 109 | self.quant_channel = {'params': [0, 0]} 110 | 111 | def process(self, module): 112 | params_dict = dict(module.named_parameters()) 113 | self.params_name = [] 114 | self.quant_param_index = [] 115 | self.svd_param_index = [] 116 | self.quant_channel['params'] = [] 117 | for i, name in enumerate(params_dict.keys()): 118 | self.params_name.append(name) 119 | self.quant_channel['params'].append(0) 120 | if 'weight_' in name: 121 | self.quant_param_index.append(i) 122 | self.svd_param_index.append(i) 123 | 124 | 125 | @OPS.register 126 | @QAT_OPS.register 127 | @SVD_OPS.register 128 | class GRU(OpBase): 129 | def __init__(self): 130 | super(GRU, self).__init__() 131 | self.inputs_name = ['input'] 132 | self.params_name = ['weight_ih_l', 'weight_hh_l'] 133 | self.quant_input_index = [0] 134 | self.quant_param_index = [0, 1] 135 | self.svd_param_index = [0, 1] 136 | self.quant_channel = {'params': [0, 0]} 137 | 138 | def process(self, module): 139 | params_dict = dict(module.named_parameters()) 140 | self.params_name = [] 141 | self.quant_apram_index = [] 142 | self.svd_param_index = [] 143 | self.quant_channel['params'] = [] 144 | for i, name in enumerate(params_dict.keys()): 145 | self.params_name.append(name) 146 | self.quant_channel['params'].append(0) 147 | if 'weight_' in name: 148 | self.quant_param_index.append(i) 149 | self.svd_param_index.append(i) 150 | 151 | 152 | @OPS.register 153 | @QAT_OPS.register 154 | class CustomBMM(OpBase): 155 | def __init__(self): 156 | super(CustomBMM, self).__init__() 157 | self.inputs_name = ['q', 'k'] 158 | self.params_name = [] 159 | self.quant_input_index = [0, 1] 160 | self.quant_param_index = [] 161 | self.quant_channel = [] 162 | -------------------------------------------------------------------------------- /torch_utils/distributed_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import pickle 5 | import torch 6 | import torch.distributed as dist 7 | 8 | __all__ = [ 9 | 'get_rank', 'get_world_size', 'is_main_process', 'ddp_barrier', 10 | 'synchronize', 'all_reduce', 'all_gather_tensor', 'all_gather' 11 | ] 12 | 13 | 14 | def get_world_size(): 15 | """ 16 | get distributed worker number 17 | :return: 18 | """ 19 | if not dist.is_available(): 20 | return 1 21 | if not dist.is_initialized(): 22 | return 1 23 | return dist.get_world_size() 24 | 25 | 26 | def ddp_barrier(): 27 | """In DDP mode, this function will synchronize all processes. 28 | torch.distributed.barrier() will block processes until the whole 29 | group enters this function. 
30 | """ 31 | if torch.distributed.is_initialized(): 32 | torch.distributed.barrier() 33 | 34 | 35 | def get_rank(): 36 | """ 37 | get current rank 38 | :return: rank id if enable distributed 39 | """ 40 | if not dist.is_available(): 41 | return 0 42 | if not dist.is_initialized(): 43 | return 0 44 | return dist.get_rank() 45 | 46 | 47 | def synchronize(): 48 | """ 49 | Helper function to synchronize (barrier) among all processes when 50 | using distributed training 51 | """ 52 | if not dist.is_available(): 53 | return 54 | if not dist.is_initialized(): 55 | return 56 | world_size = dist.get_world_size() 57 | if world_size == 1: 58 | return 59 | dist.barrier() 60 | 61 | 62 | def is_main_process(): 63 | """ 64 | check current process is main or not 65 | :return: return True if current process is the main else return False 66 | """ 67 | return get_rank() == 0 68 | 69 | 70 | def all_reduce(x, reduction='sum'): 71 | """ 72 | do reduction(sum) across multi-gpus 73 | :param x: tensor or list of tensors: 74 | :param reduction: reduction type, support "sum" and "mean" 75 | :return: reduced tensor 76 | """ 77 | if get_world_size() <= 1: 78 | return x 79 | assert reduction in [ 80 | 'sum', 'mean' 81 | ], 'only support reduction type: "sum", "mean", got {}'.format(reduction) 82 | is_dict = isinstance(x, dict) 83 | if is_dict: 84 | for k, v in x.items(): 85 | rt = v.clone() 86 | dist.all_reduce(rt, op=torch.distributed.ReduceOp.SUM) 87 | if reduction == 'mean': 88 | rt /= get_world_size() 89 | x[k] = rt 90 | return x 91 | else: 92 | rt = x.clone() 93 | dist.all_reduce(rt, op=torch.distributed.ReduceOp.SUM) 94 | if reduction == 'mean': 95 | rt /= get_world_size() 96 | return rt 97 | 98 | 99 | def all_gather_tensor(tensor): 100 | """ 101 | Run all_gather on tensors, all tensors must with same dtype 102 | Args: 103 | tensor: tensor on each rank, can be different in size and shape 104 | Returns: 105 | list tensor: list of tensors from each rank 106 | """ 107 | with torch.no_grad(): 108 | world_size = get_world_size() 109 | rank = get_rank() 110 | if world_size == 1: 111 | return [tensor] 112 | 113 | # transfer tensor to gpu 114 | if not tensor.is_cuda: 115 | tensor = tensor.cuda() 116 | 117 | # gathering tensors of different shapes 118 | tensor_list = [] 119 | 120 | # obtain Tensor size and dims of each rank 121 | local_size = int(tensor.numel()) 122 | local_dim = int(tensor.dim()) 123 | local_size_dims = torch.as_tensor([local_size, local_dim], 124 | dtype=torch.int64).cuda() 125 | 126 | # all gather the max size and max dims from all ranks 127 | size_dims_list = [ 128 | torch.as_tensor([0, 0], dtype=torch.int64).cuda() 129 | for _ in range(world_size) 130 | ] 131 | dist.all_gather(size_dims_list, local_size_dims) 132 | size_list = [int(size[0].item()) for size in size_dims_list] 133 | dims_list = [int(dim[1].item()) for dim in size_dims_list] 134 | 135 | max_dims = max(dims_list) 136 | max_size = max(size_list) 137 | 138 | # obtain original shape, dtype 139 | tensor_shape = [i for i in tensor.shape] 140 | if len(tensor_shape) < max_dims: 141 | tensor_shape.extend( 142 | [0 for _ in range(max_dims - len(tensor_shape))]) 143 | tensor_shape = torch.as_tensor([tensor_shape], dtype=torch.int64).cuda() 144 | 145 | # all gather the shape from all ranks 146 | shape_list = [ 147 | torch.zeros(size=(max_dims, ), dtype=torch.int64).cuda() 148 | for _ in range(world_size) 149 | ] 150 | dist.all_gather(shape_list, tensor_shape) 151 | 152 | # receiving Tensor from all ranks 153 | # we pad the tensor because torch 
all_gather does not support 154 | for _ in range(world_size): 155 | tensor_list.append( 156 | torch.zeros(size=(max_size, ), dtype=tensor.dtype).cuda()) 157 | 158 | if local_size != max_size: 159 | padding = torch.zeros(size=(max_size - local_size, ), 160 | dtype=tensor.dtype).cuda() 161 | tensor = torch.cat((tensor.view(-1), padding), dim=0) 162 | 163 | dist.all_gather(tensor_list, tensor) 164 | 165 | res = list() 166 | for i in range(world_size): 167 | shape = shape_list[i].cpu().tolist()[:dims_list[i]] 168 | tmp = tensor_list[i][:size_list[i]].view(shape) 169 | res.append(tmp) 170 | return res 171 | 172 | 173 | def all_gather(data): 174 | """ 175 | Run all_gather on arbitrary picklable data (not necessarily tensors) 176 | Args: 177 | data: any picklable object 178 | Returns: 179 | list[data]: list of data gathered from each rank 180 | """ 181 | world_size = get_world_size() 182 | if world_size == 1: 183 | return [data] 184 | 185 | # serialized to a Tensor 186 | buffer = pickle.dumps(data) 187 | storage = torch.ByteStorage.from_buffer(buffer) 188 | tensor = torch.ByteTensor(storage).cuda() 189 | 190 | # obtain Tensor size of each rank 191 | local_size = torch.LongTensor([tensor.numel()]).cuda() 192 | size_list = [torch.LongTensor([0]).cuda() for _ in range(world_size)] 193 | dist.all_gather(size_list, local_size) 194 | size_list = [int(size.item()) for size in size_list] 195 | max_size = max(size_list) 196 | 197 | # receiving Tensor from all ranks 198 | # we pad the tensor because torch all_gather does not support 199 | # gathering tensors of different shapes 200 | tensor_list = [] 201 | for _ in size_list: 202 | tensor_list.append(torch.ByteTensor(size=(max_size, )).cuda()) 203 | if local_size != max_size: 204 | padding = torch.ByteTensor(size=(max_size - local_size, )).cuda() 205 | tensor = torch.cat((tensor, padding), dim=0) 206 | dist.all_gather(tensor_list, tensor) 207 | 208 | data_list = [] 209 | for size, tensor in zip(size_list, tensor_list): 210 | buffer = tensor.cpu().numpy().tobytes()[:size] 211 | data_list.append(pickle.loads(buffer)) 212 | 213 | return data_list 214 | -------------------------------------------------------------------------------- /models/torch/dfsmn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from .utils import weight_init 4 | 5 | 6 | class DfsmnLayer(nn.Module): 7 | def __init__(self, 8 | hidden_size, 9 | backbone_memory_size, 10 | left_kernel_size, 11 | right_kernel_size, 12 | dilation=1, 13 | dropout=0.0): 14 | super().__init__() 15 | self.fc_trans = nn.Sequential(*[ 16 | nn.Linear(backbone_memory_size, hidden_size), 17 | nn.ReLU(), 18 | nn.Dropout(dropout), 19 | nn.Linear(hidden_size, backbone_memory_size), 20 | nn.Dropout(dropout) 21 | ]) 22 | self.memory = nn.Conv1d(backbone_memory_size, 23 | backbone_memory_size, 24 | kernel_size=left_kernel_size + 25 | right_kernel_size + 1, 26 | padding=0, 27 | stride=1, 28 | dilation=dilation, 29 | groups=backbone_memory_size) 30 | 31 | self.left_kernel_size = left_kernel_size 32 | self.right_kernel_size = right_kernel_size 33 | self.dilation = dilation 34 | self.backbone_memory_size = backbone_memory_size 35 | 36 | def forward(self, input_feat): 37 | # input (B, N, T) 38 | residual = input_feat 39 | # dfsmn-memory 40 | pad_input_fea = F.pad(input_feat, [ 41 | self.left_kernel_size * self.dilation, 42 | self.right_kernel_size * self.dilation 43 | ]) # (B,N,T+(l+r)*d) 44 | memory_out = 
self.memory(pad_input_fea) + residual 45 | residual = memory_out # (B, N, T) 46 | 47 | # fc-transform 48 | fc_output = self.fc_trans(memory_out.transpose(1, 2)) # (B, T, N) 49 | output = fc_output.transpose(1, 2) + residual # (B, N, T) 50 | return output 51 | 52 | 53 | class DfsmnLayerBN(nn.Module): 54 | def __init__(self, 55 | hidden_size, 56 | backbone_memory_size, 57 | left_kernel_size, 58 | right_kernel_size, 59 | dilation=1, 60 | dropout=0.0): 61 | super().__init__() 62 | self.fc_trans = nn.Sequential(*[ 63 | nn.Conv1d(backbone_memory_size, hidden_size, 1), 64 | nn.BatchNorm1d(hidden_size), 65 | nn.ReLU(), 66 | nn.Dropout(dropout), 67 | nn.Conv1d(hidden_size, backbone_memory_size, 1), 68 | nn.BatchNorm1d(backbone_memory_size), 69 | nn.ReLU(), 70 | nn.Dropout(dropout, ), 71 | ]) 72 | self.memory = nn.Sequential(*[ 73 | nn.Conv1d(backbone_memory_size, 74 | backbone_memory_size, 75 | kernel_size=left_kernel_size + right_kernel_size + 1, 76 | padding=0, 77 | stride=1, 78 | dilation=dilation, 79 | groups=backbone_memory_size), 80 | nn.BatchNorm1d(backbone_memory_size), 81 | nn.ReLU(), 82 | ]) 83 | 84 | self.left_kernel_size = left_kernel_size 85 | self.right_kernel_size = right_kernel_size 86 | self.dilation = dilation 87 | self.backbone_memory_size = backbone_memory_size 88 | 89 | def forward(self, input_feat): 90 | # input (B, N, T) 91 | residual = input_feat 92 | # dfsmn-memory 93 | pad_input_fea = F.pad(input_feat, [ 94 | self.left_kernel_size * self.dilation, 95 | self.right_kernel_size * self.dilation 96 | ]) # (B,N,T+(l+r)*d) 97 | memory_out = self.memory(pad_input_fea) + residual 98 | residual = memory_out # (B, N, T) 99 | 100 | # fc-transform 101 | fc_output = self.fc_trans(memory_out) # (B, T, N) 102 | output = fc_output + residual # (B, N, T) 103 | return output 104 | 105 | 106 | class DfsmnModel(nn.Module): 107 | def __init__(self, 108 | num_classes, 109 | in_channels, 110 | n_mels=32, 111 | num_layer=6, 112 | frondend_channels=16, 113 | frondend_kernel_size=5, 114 | hidden_size=256, 115 | backbone_memory_size=128, 116 | left_kernel_size=2, 117 | right_kernel_size=2, 118 | dilation=1, 119 | dropout=0.2, 120 | dfsmn_with_bn=True, 121 | distill=False, 122 | **kwargs): 123 | super().__init__() 124 | self.front_end = nn.Sequential(*[ 125 | nn.Conv2d(in_channels, 126 | out_channels=frondend_channels, 127 | kernel_size=[frondend_kernel_size, frondend_kernel_size], 128 | stride=(2, 2), 129 | padding=(frondend_kernel_size // 2, 130 | frondend_kernel_size // 2)), 131 | nn.BatchNorm2d(frondend_channels), 132 | nn.ReLU(), 133 | nn.Conv2d(frondend_channels, 134 | out_channels=2 * frondend_channels, 135 | kernel_size=[frondend_kernel_size, frondend_kernel_size], 136 | stride=(2, 2), 137 | padding=(frondend_kernel_size // 2, 138 | frondend_kernel_size // 2)), 139 | nn.BatchNorm2d(2 * frondend_channels), 140 | nn.ReLU() 141 | ]) 142 | self.n_mels = n_mels 143 | self.fc1 = nn.Sequential(*[ 144 | nn.Linear(in_features=2 * frondend_channels * self.n_mels // 4, 145 | out_features=backbone_memory_size), 146 | nn.ReLU(), 147 | ]) 148 | backbone = list() 149 | for idx in range(num_layer): 150 | if dfsmn_with_bn: 151 | backbone.append( 152 | DfsmnLayerBN(hidden_size, backbone_memory_size, 153 | left_kernel_size, right_kernel_size, dilation, 154 | dropout)) 155 | else: 156 | backbone.append( 157 | DfsmnLayer(hidden_size, backbone_memory_size, 158 | left_kernel_size, right_kernel_size, dilation, 159 | dropout)) 160 | self.backbone = nn.Sequential(*backbone) 161 | self.classifier = nn.Sequential(*[ 162 
| nn.Dropout(p=dropout), 163 | nn.Linear(backbone_memory_size * self.n_mels // 4, num_classes), 164 | ]) 165 | self.distill = distill 166 | self.apply(weight_init) 167 | 168 | def forward(self, input_feat): 169 | # input_feat: B, 1, N, T 170 | batch = input_feat.shape[0] 171 | out = self.front_end(input_feat) # B, C, N//4, T//4 172 | out = out.view(batch, -1, 173 | out.shape[3]).transpose(1, 2).contiguous() # B, T, N1 174 | out = self.fc1(out).transpose(1, 2).contiguous() # B, N, T 175 | features = [] 176 | for layer in self.backbone: 177 | out = layer(out) 178 | features.append(out) 179 | 180 | out = out.contiguous().view(batch, -1) 181 | out = self.classifier(out) 182 | if self.distill: 183 | return out, features 184 | else: 185 | return out 186 | -------------------------------------------------------------------------------- /core/torch/graph_wrapper.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from collections import OrderedDict 4 | from byteslim.core.logger import Logger, Verbose 5 | 6 | OmitType = [ 7 | 'prim::Constant', 'prim::ListConstruct', 'prim::GetAttr', 'aten::to', 8 | 'aten::Int', 'aten::size' 9 | ] 10 | 11 | 12 | class BindSet(object): 13 | def __init__(self): 14 | super(BindSet, self).__init__() 15 | self.cur_set = 0 16 | self.ObjectSet = {} 17 | self.SetObject = {} 18 | 19 | def binding(self, a, b): 20 | if a not in self.ObjectSet and b not in self.ObjectSet: 21 | self.ObjectSet.update({a: self.cur_set, b: self.cur_set}) 22 | self.SetObject.update({self.cur_set: [a]}) 23 | self.SetObject[self.cur_set].append(b) 24 | self.cur_set += 1 25 | elif a in self.ObjectSet and b in self.ObjectSet: 26 | pass 27 | elif a in self.ObjectSet: 28 | self.ObjectSet[b] = self.ObjectSet[a] 29 | self.SetObject[self.ObjectSet[a]].append(b) 30 | else: 31 | self.ObjectSet[a] = self.ObjectSet[b] 32 | self.SetObject[self.ObjectSet[b]].append(a) 33 | 34 | def __getattr__(self, obj): 35 | if obj not in self.ObjectSet: 36 | return [] 37 | else: 38 | return self.SetObject[self.ObjectSet[obj]] 39 | 40 | 41 | class SlimOp(object): 42 | def __init__(self, node, graph): 43 | super(SlimOp, self).__init__() 44 | self.nodes = [node] 45 | self.graph = graph 46 | self.types = [node.kind() for node in self.nodes] 47 | self.scope = self.nodes[0].scopeName().split('/')[-1].replace( 48 | '__module.', '') 49 | self.is_leaf = self._is_leaf() 50 | self.ModuleName = self._GetName() 51 | self.ModuleClass = self._GetModuleClass() 52 | self.prune_config = {} 53 | 54 | def _is_leaf(self): 55 | name = self.scope 56 | for n, m in self.graph.trace.named_modules(): 57 | if n == name: 58 | is_leaf = len(list(m.named_children())) == 0 59 | return is_leaf 60 | raise BaseException("Can not find {} module in model".format(name)) 61 | 62 | def _GetName(self): 63 | name = self.scope 64 | if not self.is_leaf: 65 | name = name + '.' + self._AtenToModule(self.types[0]) 66 | if name + '.' + 'slim_func' in self.graph._ops.keys(): 67 | i = 0 68 | while name + '.' + str(i) in self.graph._ops.keys(): 69 | i += 1 70 | name = name + '.' + str(i) 71 | name = name + '.' 
+ 'slim_func' 72 | return name 73 | 74 | def inputs(self): 75 | all_inputs = [] 76 | all_outputs = [] 77 | op_inputs = [] 78 | for node in self.nodes: 79 | all_inputs.extend(list(node.inputs())) 80 | all_outputs.extend(list(node.outputs())) 81 | for Input in all_inputs: 82 | if Input not in all_outputs and Input.node().kind() not in OmitType: 83 | op_inputs.append(Input) 84 | return op_inputs 85 | 86 | def outputs(self): 87 | all_inputs = [] 88 | all_outputs = [] 89 | op_outputs = [] 90 | for node in self.nodes: 91 | all_inputs.extend(list(node.inputs())) 92 | all_outputs.extend(list(node.outputs())) 93 | for Output in all_outputs: 94 | if Output not in all_inputs and Output.node().kind( 95 | ) not in OmitType: 96 | op_outputs.append(Output) 97 | return op_outputs 98 | 99 | def _AtenToModule(self, AtenName): 100 | if not AtenName.startswith('aten::'): 101 | return AtenName 102 | AtenName = AtenName[6:] 103 | ModuleName = '' 104 | toUpper = True 105 | for c in AtenName: 106 | if c == '_': 107 | toUpper = True 108 | continue 109 | if toUpper: 110 | c = c.upper() 111 | toUpper = False 112 | ModuleName += c 113 | return ModuleName 114 | 115 | def _GetModuleClass(self): 116 | if self.is_leaf: 117 | for n, m in self.graph.trace.named_modules(): 118 | if self.ModuleName == n: 119 | return m._name 120 | else: 121 | assert len( 122 | self.types 123 | ) == 1, 'Convert a function in graph to a Module should to assure that op only have one node' 124 | return self._AtenToModule(self.types[0]) 125 | 126 | def insert_node(self, node): 127 | for i, op_node in enumerate(self.nodes): 128 | for output in node.outputs(): 129 | if output in op_node.inputs(): 130 | self.nodes.insert(i, node) 131 | self.types.insert(i, node.kind()) 132 | return 133 | self.nodes.append(node) 134 | self.types.append(node.kind()) 135 | 136 | def __repr__(self): 137 | next_ops = [op.ModuleName for op in self.graph.next_ops(self)] 138 | string = '\nName : ' + self.ModuleName + ' ' 139 | string += 'Class :' + self.ModuleClass + '\n' 140 | string += 'output: ' + str(next_ops) + '\n' 141 | return string 142 | 143 | 144 | class SlimGraph(object): 145 | def __init__(self, model, dummy_input): 146 | super(SlimGraph, self).__init__() 147 | self.trace = torch.jit.trace(model, dummy_input) 148 | torch._C._jit_pass_inline(self.trace.graph) 149 | self.graph = self.trace.graph 150 | 151 | self.binding_layers = BindSet() 152 | self.RelationOps = ['Add'] 153 | self.DominatingOps = ['Conv2d', 'Conv1d', 'ConvTranspose1d', 'Linear'] 154 | 155 | def get_ops(self): 156 | if not hasattr(self, '_ops'): 157 | self._ops = OrderedDict() 158 | nodes = [ 159 | node for node in self.graph.nodes() if self._is_OpNode(node) 160 | ] 161 | for node in nodes: 162 | fake_op = SlimOp(node, self) 163 | if fake_op.ModuleName in self._ops.keys(): 164 | self._ops[fake_op.ModuleName].insert_node(node) 165 | else: 166 | self._ops.update({fake_op.ModuleName: fake_op}) 167 | return list(self._ops.values()) 168 | 169 | def _is_OpNode(self, node): 170 | is_op = True 171 | if node.kind() in OmitType: 172 | is_op = False 173 | elif not node.kind().startswith('aten::'): 174 | is_op = False 175 | return is_op 176 | 177 | def get_op_by_type(self, type): 178 | return [op for op in self.get_ops() if op.type == type] 179 | 180 | def get_op_by_ModuleName(self, ModuleName): 181 | return [op for op in self.get_ops() if op.ModuleName == ModuleName] 182 | 183 | def next_ops(self, op: SlimOp): 184 | outputs = op.outputs() 185 | next_ops = [] 186 | for op in self.get_ops(): 187 | for output 
in outputs: 188 | if output in op.inputs(): 189 | next_ops.append(op) 190 | return next_ops 191 | 192 | def pre_ops(self, op: SlimOp): 193 | inputs = op.inputs() 194 | pre_ops = [] 195 | for op in self.get_ops(): 196 | for input in inputs: 197 | if input in op.outputs(): 198 | pre_ops.append(op) 199 | return pre_ops 200 | 201 | def get_tensor_shape(self, tensor): 202 | out_info = re.search('%.*? : .*?\(.*?\)', str(tensor)).group() 203 | shape_info = re.search('\(.*\)', out_info).group()[1:-1] 204 | shape = [int(i) for i in shape_info.split(', ')] 205 | return shape 206 | -------------------------------------------------------------------------------- /speech_commands/dataset/speech_commands.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | from pathlib import Path 3 | import os 4 | import random 5 | 6 | import torch 7 | import torch.utils.data as data 8 | import torchaudio 9 | 10 | NOISE_FOLDER = "_background_noise_" 11 | 12 | 13 | class SpeechCommandV1(data.Dataset): 14 | """Create a Dataset for Speech Commands V1. 15 | Args: 16 | root (str or Path): Path to the directory where the dataset is found or downloaded. 17 | folder_in_archive (str, optional): 18 | The top-level directory of the dataset. (default: ``"SpeechCommands"``) 19 | download (bool, optional): 20 | Whether to download the dataset if it is not found at root path. (default: ``False``). 21 | subset (str or None, optional): 22 | Select a subset of the dataset [None, "training", "validation", "testing"]. None means 23 | the whole dataset. "validation" and "testing" are defined in "validation_list.txt" and 24 | "testing_list.txt", respectively, and "training" is the rest. Details for the files 25 | "validation_list.txt" and "testing_list.txt" are explained in the README of the dataset 26 | and in the introduction of Section 7 of the original paper and its reference 12. The 27 | original paper can be found `here `_. 
(Default: ``None``) 28 | transform(None, optional): 29 | methods to transform the audio 30 | num_classes: 31 | support: 12 32 | """ 33 | def __init__( 34 | self, 35 | root: Union[str, Path], 36 | folder_in_archive: str = "SpeechCommands", 37 | download: bool = False, 38 | subset: Optional[str] = None, 39 | silence_percent=0.1, 40 | transform=None, 41 | num_classes=12, 42 | noise_ratio=None, 43 | noise_max_scale=0.4, 44 | cache_origin_data=False, 45 | version = "speech_commands_v0.02" # SpeechCommandV1: v0.02 46 | ) -> None: 47 | self.classes = [ 48 | "yes", 49 | "no", 50 | "up", 51 | "down", 52 | "left", 53 | "right", 54 | "on", 55 | "off", 56 | "stop", 57 | "go", 58 | "zero", 59 | "one", 60 | "two", 61 | "three", 62 | "four", 63 | "five", 64 | "six", 65 | "seven", 66 | "eight", 67 | "nine", 68 | "bed", 69 | "bird", 70 | "cat", 71 | "dog", 72 | "happy", 73 | "house", 74 | "marvin", 75 | "sheila", 76 | "tree", 77 | "wow", 78 | "backward", 79 | "forward", 80 | "follow", 81 | "learn", 82 | "visual", 83 | ] 84 | self.classes_12 = [ 85 | 'unknown', 'silence', 'yes', 'no', 'up', 'down', 'left', 'right', 86 | 'on', 'off', 'stop', 'go' 87 | ] 88 | self.classes_20 = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go', 89 | 'zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine'] 90 | self.classes_35 = ['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward', 91 | 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no', 'off', 92 | 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', 93 | 'up', 'visual', 'wow', 'yes', 'zero'] 94 | 95 | dataset = torchaudio.datasets.SPEECHCOMMANDS(root, version, 96 | folder_in_archive, 97 | download, subset) 98 | data_path = os.path.join(root, folder_in_archive, version) 99 | self.num_classes = num_classes 100 | self.datas = list() 101 | for fileid in dataset._walker: 102 | relpath = os.path.relpath(fileid, data_path) 103 | label, _ = os.path.split(relpath) 104 | label = self.name_to_label(label) 105 | if (label == -1): 106 | continue 107 | self.datas.append([fileid, label]) 108 | 109 | self.sample_rate = 16000 110 | 111 | # setup silence 112 | if silence_percent > 0 and num_classes == 12: 113 | silence_data = [['', self.name_to_label('silence')] 114 | for _ in range(int(len(dataset) * silence_percent))] 115 | self.datas.extend(silence_data) 116 | 117 | # setup noise 118 | self.noise_folder = os.path.join(root, folder_in_archive, version, 119 | NOISE_FOLDER) 120 | self.noise_files = sorted(str(p) for p in Path(self.noise_folder).glob('*.wav')) \ 121 | if subset == 'training' and noise_ratio != None else None 122 | 123 | self.transform = transform 124 | self.noise_ratio = noise_ratio 125 | self.noise_max_scale = noise_max_scale 126 | self.silence_ratio = silence_percent 127 | if noise_ratio is not None and subset is 'training': 128 | assert 0 < noise_max_scale < 1 129 | assert num_classes == 12 or num_classes == 20 or num_classes == 35, 'only support V1-12 now' 130 | self.cache_origin = cache_origin_data 131 | self.origin_datas = dict() 132 | self.origin_noise_datas = dict() 133 | 134 | def __len__(self): 135 | return len(self.datas) 136 | 137 | def label_to_name(self, label): # useless function 138 | if label >= 12: 139 | return 'unknown' 140 | return self.classes_12[label] 141 | 142 | def name_to_label(self, name): 143 | if self.num_classes == 12: 144 | if name not in self.classes_12: 145 | return 0 146 | return self.classes_12.index(name) 147 | elif 
self.num_classes == 20: 148 | if name not in self.classes_20: 149 | return 0 if self.classes_20 == 'unknown' else -1 150 | return self.classes_20.index(name) 151 | elif self.num_classes == 35: 152 | if name not in self.classes_35: 153 | return 0 if self.classes_35 == 'unknown' else -1 154 | return self.classes_35.index(name) 155 | else: 156 | raise RuntimeError 157 | 158 | def __getitem__(self, index): 159 | """ 160 | return feature and label 161 | """ 162 | # Tensor, int, str, str, int 163 | if index in self.origin_datas.keys(): 164 | [waveform, _, label] = self.origin_datas[index] 165 | else: 166 | waveform, sample_rate, label = self.pull_origin(index) 167 | if sample_rate != self.sample_rate: 168 | raise RuntimeError 169 | if self.cache_origin: 170 | self.origin_datas[index] = [waveform, sample_rate, label] 171 | 172 | if self.noise_files is not None and random.uniform( 173 | 0, 1) < self.noise_ratio: 174 | noise_file = random.choice(self.noise_files) 175 | if noise_file in self.origin_noise_datas.keys(): 176 | waveform_noise = self.origin_noise_datas[noise_file] 177 | else: 178 | waveform_noise, _ = torchaudio.load(noise_file) 179 | if self.cache_origin: 180 | self.origin_noise_datas[noise_file] = waveform_noise 181 | noise_len = waveform_noise.shape[1] 182 | wav_len = waveform.shape[1] 183 | if noise_len >= wav_len: 184 | rand_start = random.randint(0, noise_len - wav_len - 1) 185 | waveform_noise = waveform_noise[:, 186 | rand_start:wav_len + rand_start] 187 | else: 188 | waveform_noise = torch.nn.functional.pad( 189 | waveform_noise, (0, wav_len - noise_len)) 190 | random_scale = random.uniform(0, self.noise_max_scale) 191 | waveform = waveform * (1 - 192 | random_scale) + waveform_noise * random_scale 193 | 194 | if self.transform is not None: 195 | waveform = self.transform(waveform) 196 | return waveform, label 197 | 198 | def pull_origin(self, index): 199 | """ 200 | get original item 201 | """ 202 | [data_id, label] = self.datas[index] 203 | if data_id != '': 204 | waveform, sample_rate = torchaudio.load(data_id) 205 | else: 206 | waveform = torch.zeros(1, 16000) 207 | sample_rate = 16000 208 | return waveform, sample_rate, label 209 | -------------------------------------------------------------------------------- /basic.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import modules, Parameter 6 | from torch.autograd import Function 7 | 8 | activations = { 9 | 'ReLU': nn.ReLU, 10 | 'Hardtanh': nn.Hardtanh 11 | } 12 | 13 | class BinaryQuantize(Function): 14 | @staticmethod 15 | def forward(ctx, input): 16 | ctx.save_for_backward(input) 17 | out = torch.sign(input) 18 | return out 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output): 22 | input = ctx.saved_tensors 23 | grad_input = grad_output 24 | grad_input[input[0].gt(1)] = 0 25 | grad_input[input[0].lt(-1)] = 0 26 | return grad_input 27 | 28 | class BinaryQuantize_Vanilla(Function): 29 | @staticmethod 30 | def forward(ctx, input, scale): 31 | ctx.save_for_backward(input) 32 | out = torch.sign(input) 33 | if scale != None: 34 | out = out * scale 35 | return out 36 | 37 | @staticmethod 38 | def backward(ctx, grad_output): 39 | input = ctx.saved_tensors 40 | grad_input = grad_output 41 | grad_input[input[0].gt(1)] = 0 42 | grad_input[input[0].lt(-1)] = 0 43 | return grad_input, None 44 | 45 | class BiLinearVanilla(torch.nn.Linear): 46 | def __init__(self, in_features, 
out_features, bias=True): 47 | super(BiLinearVanilla, self).__init__(in_features, out_features, bias=bias) 48 | self.output_ = None 49 | 50 | def forward(self, input): 51 | bw = self.weight 52 | ba = input 53 | sw = bw.abs().mean(-1).view(-1, 1).detach() 54 | bw = BinaryQuantize_Vanilla().apply(bw, sw) 55 | ba = BinaryQuantize().apply(ba) 56 | output = F.linear(ba, bw, self.bias) 57 | self.output_ = output 58 | return output 59 | 60 | biLinears = { 61 | False: nn.Linear, 62 | 'Vanilla': BiLinearVanilla, 63 | } 64 | 65 | class BiConv1dVanilla(torch.nn.Conv1d): 66 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 67 | padding=0, dilation=1, groups=1, 68 | bias=True, padding_mode='zeros'): 69 | super(BiConv1dVanilla, self).__init__( 70 | in_channels, out_channels, kernel_size, stride, padding, dilation, 71 | groups, bias, padding_mode) 72 | 73 | def forward(self, input): 74 | bw = self.weight 75 | ba = input 76 | bw = bw - bw.mean() 77 | sw = bw.abs().view(bw.size(0), bw.size(1), -1).mean(-1).view(bw.size(0), bw.size(1), 1).detach() 78 | bw = BinaryQuantize_Vanilla().apply(bw, sw) 79 | ba = BinaryQuantize().apply(ba) 80 | 81 | if self.padding_mode == 'circular': 82 | expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2) 83 | return F.conv1d(F.pad(ba, expanded_padding, mode='circular'), 84 | bw, self.bias, self.stride, 85 | _single(0), self.dilation, self.groups) 86 | return F.conv1d(ba, bw, self.bias, self.stride, 87 | self.padding, self.dilation, self.groups) 88 | 89 | biConv1ds = { 90 | False: nn.Conv1d, 91 | 'Vanilla': BiConv1dVanilla, 92 | } 93 | 94 | class BiConv2dVanilla(torch.nn.Conv2d): 95 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 96 | padding=0, dilation=1, groups=1, 97 | bias=True, padding_mode='zeros'): 98 | super(BiConv2dVanilla, self).__init__( 99 | in_channels, out_channels, kernel_size, stride, padding, dilation, 100 | groups, bias, padding_mode) 101 | 102 | def forward(self, input): 103 | 104 | bw = self.weight 105 | ba = input 106 | bw = bw - bw.mean() 107 | sw = bw.abs().view(bw.size(0), bw.size(1), -1).mean(-1).view(bw.size(0), bw.size(1), 1, -1).detach() 108 | bw = BinaryQuantize_Vanilla().apply(bw, sw) 109 | ba = BinaryQuantize().apply(ba) 110 | 111 | if self.padding_mode == 'circular': 112 | expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2) 113 | return F.conv2d(F.pad(ba, expanded_padding, mode='circular'), 114 | bw, self.bias, self.stride, 115 | _pair(0), self.dilation, self.groups) 116 | return F.conv2d(ba, bw, self.bias, self.stride, 117 | self.padding, self.dilation, self.groups) 118 | 119 | biConv2ds = { 120 | False: nn.Conv2d, 121 | 'Vanilla': BiConv2dVanilla, 122 | } 123 | 124 | def Count(module: nn.Module, id = -1): 125 | id = 0 if id == -1 else id 126 | for name, child_module in module.named_children(): 127 | if isinstance(child_module, nn.ModuleList): 128 | for child_child_module in child_module: 129 | id = Count(child_child_module, id) 130 | else: 131 | id = Count(child_module, id) 132 | if isinstance(child_module, nn.Linear): 133 | id += 1 134 | elif isinstance(child_module, nn.Conv1d): 135 | id += 1 136 | elif isinstance(child_module, nn.Conv2d): 137 | id += 1 138 | return id 139 | 140 | def Modify(module: nn.Module, method='Sign', id=-1, first=-1, last=-1): 141 | id = 0 if id == -1 else id 142 | if method != False: 143 | for name, child_module in module.named_children(): 144 | if isinstance(child_module, nn.ModuleList): 145 | for child_child_module in child_module: 146 | 
_, id = Modify(child_child_module, method=method, id=id, first=first, last=last) 147 | else: 148 | _, id = Modify(child_module, method=method, id=id, first=first, last=last) 149 | if isinstance(child_module, nn.Linear): 150 | id += 1 151 | if id == first or id == last: 152 | continue 153 | new_layer = biLinears[method](child_module.in_features, 154 | child_module.out_features, 155 | False if child_module.bias == None else True) 156 | new_layer.weight = module._modules[name].weight 157 | new_layer.bias = module._modules[name].bias 158 | module._modules[name] = new_layer 159 | elif isinstance(child_module, nn.Conv1d): 160 | id += 1 161 | if id == first or id == last: 162 | continue 163 | new_layer = biConv1ds[method](in_channels=child_module.in_channels, 164 | out_channels=child_module.out_channels, 165 | kernel_size=child_module.kernel_size, 166 | stride=child_module.stride, 167 | padding=child_module.padding, 168 | dilation=child_module.dilation, 169 | groups=child_module.groups, 170 | bias=False if child_module.bias == None else True, 171 | padding_mode=child_module.padding_mode) 172 | new_layer.weight = module._modules[name].weight 173 | new_layer.bias = module._modules[name].bias 174 | module._modules[name] = new_layer 175 | elif isinstance(child_module, nn.Conv2d): 176 | id += 1 177 | if id == first or id == last: 178 | continue 179 | new_layer = biConv2ds[method](in_channels=child_module.in_channels, 180 | out_channels=child_module.out_channels, 181 | kernel_size=child_module.kernel_size, 182 | stride=child_module.stride, 183 | padding=child_module.padding, 184 | dilation=child_module.dilation, 185 | groups=child_module.groups, 186 | bias=False if child_module.bias == None else True, 187 | padding_mode=child_module.padding_mode) 188 | new_layer.weight = module._modules[name].weight 189 | new_layer.bias = module._modules[name].bias 190 | module._modules[name] = new_layer 191 | return module, id 192 | -------------------------------------------------------------------------------- /core/torch/base_classes.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | import torch 4 | import torch.nn as nn 5 | from abc import abstractmethod 6 | from copy import deepcopy 7 | from collections import OrderedDict 8 | from byteslim.nas.torch.utils import ConverterMap 9 | from byteslim.core.config import Logger, Verbose 10 | from .graph_operation import replace_module 11 | 12 | 13 | # OFA_config.config is class type wise and identified by instance.__class__.__name__ 14 | class OFAConfig(object): 15 | def __init__(self, 16 | SearchSpace: dict, 17 | stages: list = ['width'], 18 | binding_layer: dict = {}, 19 | static_layer: list = []): 20 | self.SearchSpace = SearchSpace 21 | self.stages = stages 22 | self.binding_layer = binding_layer 23 | self.static_layer = static_layer 24 | 25 | 26 | class Compressor(object): 27 | def __init__(self, model, config): 28 | self.config = config 29 | self.model = model 30 | self.preprocess(model) 31 | 32 | # Support input compress_info as {module_class: compress_info} format for easy use. 33 | # But it need be unfold before it can be used to compress model. 
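    # Illustrative sketch (an assumption, not from this file): given a type-level
    # config such as {'Linear': <OpBase for Linear>, 'Conv1d': <OpBase for Conv1d>},
    # unfold_compress_info() below walks self.model.named_modules(), skips layers
    # rejected by config.NeedSkip(), and copies the matching type config into a
    # per-layer entry via config.set_layer_config(module_name, op_config), so that
    # later passes can look a layer up by its qualified name (e.g. 'backbone.0.memory').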
34 |     def unfold_compress_info(self):
35 |         for module_name, module in self.model.named_modules():
36 |             op_type = module.__class__.__name__
37 |             if not self.config.NeedSkip(module_name, op_type,
38 |                                         framework='torch'):
39 |                 op_config = self.config.get_type_config(op_type)
40 |                 if hasattr(module,
41 |                            'slim_config') and module.slim_config['processed']:
42 |                     parent_module = self.get_related_SlimModule(module)
43 |                     module_name = parent_module.slim_config['identical_name']
44 |                 else:
45 |                     parent_module = module
46 |                 self.config.set_layer_config(module_name, op_config)
47 |                 op_config.process(parent_module)
48 |                 setattr(op_config, 'op_name', module_name)
49 |                 self.config.set_layer_config(module_name, op_config)
50 | 
51 |     def get_module(self, name):
52 |         for n, m in self.model.named_modules():
53 |             if not hasattr(m, 'slim_config'):
54 |                 continue
55 |             if m.slim_config['identical_name'] == name:
56 |                 return m
57 | 
58 |     def get_related_SlimModule(self, module):
59 |         if isinstance(module, SlimModule) or not hasattr(module, 'slim_config'):
60 |             return None
61 |         assert module.slim_config['identical_name'].endswith('_origin'), 'The identical name of the module '\
62 |             'should end with "_origin" but got {}'.format(module.slim_config['identical_name'])
63 |         target_identical_name = module.slim_config['identical_name'][:-7]
64 |         for m in self.model.modules():
65 |             if hasattr(m, 'slim_config'):
66 |                 if m.slim_config['identical_name'] == target_identical_name:
67 |                     return m
68 |         return None
69 | 
70 |     # Should only be used after unfold_compress_info() has run.
71 |     def get_layer_config(self, name):
72 |         return self.config.get_layer_config(name)
73 | 
74 |     def get_module_identity(self, module):
75 |         if hasattr(module, 'slim_config'):
76 |             return module.slim_config['identical_name']
77 |         else:
78 |             return None
79 | 
80 |     # target_module should be a list of module classes, such as nn.Linear or nn.Conv2d
81 |     def preprocess(self, model: nn.Module):
82 |         self.unfold_compress_info()
83 |         target_names = list(self.config.layer_config.keys())
84 | 
85 |         for n, m in model.named_modules():
86 |             config = {}
87 |             is_base_model = len(list(m.children())) == 0
88 |             config['is_base_model'] = is_base_model
89 |             config['identical_name'] = n + '_origin'
90 |             config['processed'] = False
91 |             if n in target_names:
92 |                 if hasattr(m, 'slim_config'):
93 |                     if m.slim_config['processed']:
94 |                         continue
95 |                 config['processed'] = True
96 |                 slim_module = SlimModule(m)
97 |                 setattr(m, 'slim_config', config)
98 |                 slim_config = deepcopy(config)
99 |                 slim_config['identical_name'] = n
100 |                 setattr(slim_module, 'slim_config', slim_config)
101 |                 replace_module(model, n, slim_module)
102 |             else:
103 |                 setattr(m, 'slim_config', config)
104 | 
105 |     @abstractmethod
106 |     def apply(self, name, ratio):
107 |         pass
108 | 
109 |     @abstractmethod
110 |     def compress(self, model):
111 |         pass
112 | 
113 | 
114 | class Compress_Reward(object):
115 |     def __init__(self, reduced_flops=0, reduced_params=0):
116 |         self.reduced_flops = reduced_flops
117 |         self.reduced_params = reduced_params
118 | 
119 |     @abstractmethod
120 |     def __call__(self, module, input):
121 |         pass
122 | 
123 | 
124 | class SlimModule(nn.Module):
125 |     def __init__(self, base_model):
126 |         super(SlimModule, self).__init__()
127 |         self.base_model = base_model
128 |         self.base_params = dict(self.base_model.named_parameters())
129 |         for param_name, param in self.base_params.items():
130 |             delattr(self.base_model, param_name)
131 |             self.register_parameter(param_name, param)
132 |         self.init_params_tensor()
133 |         self.assign_params_tensor()
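        # At this point the original nn.Parameters are registered on the SlimModule wrapper,
        # while base_model only sees plain tensor copies (current_params); pre-ops can rewrite
        # those copies on every forward pass while gradients still flow back to the Parameters.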
134 |         self.type = base_model.__class__.__name__
135 | 
136 |         self.pre_ops = nn.ModuleDict()
137 |         self.post_ops = nn.ModuleDict()
138 |         self.pre_ops_order = OrderedDict()
139 |         self.post_ops_order = OrderedDict()
140 |         self.original_forward = False
141 | 
142 |     def __setattr__(self, name: str, value) -> None:
143 |         if hasattr(self, 'base_params') and isinstance(value, nn.Parameter):
144 |             self.base_params[name] = value
145 |         return super().__setattr__(name, value)
146 | 
147 |     def register_pre_op(self, module: nn.Module, name: str, priority: int):
148 |         setattr(module, 'parent_model', self)
149 |         self.pre_ops.update({name: module})
150 |         self.pre_ops_order.update({name: priority})
151 |         sorted_list = sorted(self.pre_ops_order.items(), key=lambda x: x[1])
152 |         self.pre_ops_order = OrderedDict(sorted_list)
153 |         Logger(Verbose.DEBUG)("Updated pre_ops Order is {}".format(
154 |             self.pre_ops_order))
155 | 
156 |     def init_params_tensor(self):
157 |         self.current_params = {}
158 |         for n, p in self.base_params.items():
159 |             # Convert each Parameter into a plain tensor copy
160 |             if isinstance(p, nn.Parameter):
161 |                 self.current_params[n] = p.clone()
162 | 
163 |     def init_act_tensor(self, *args, **kwargs):
164 |         sig = inspect.signature(self.base_model.forward)
165 |         inputs = OrderedDict(sig.parameters)
166 |         inputs_names = list(inputs.keys())
167 |         inputs_dict = OrderedDict()
168 |         for i, act in enumerate(args):
169 |             inputs_dict[inputs_names[i]] = act
170 |             inputs.pop(inputs_names[i])
171 |         for k, v in kwargs.items():
172 |             inputs_dict[k] = v
173 |             inputs.pop(k)
174 |         for k, v in inputs.items():
175 |             if k not in inputs_dict.keys() and v.default is not inspect._empty:
176 |                 inputs_dict[k] = v.default
177 |         return inputs_dict
178 | 
179 |     def assign_params_tensor(self):
180 |         for name, param_tensor in self.current_params.items():
181 |             setattr(self.base_model, name, param_tensor)
182 | 
183 |     def recover_base_model(self):
184 |         for name, param in self.base_params.items():
185 |             if hasattr(self.base_model, name):
186 |                 delattr(self.base_model, name)
187 |             self.base_model.register_parameter(name, param)
188 |         if hasattr(self.base_model, 'slim_config'):
189 |             delattr(self.base_model, 'slim_config')
190 | 
191 |     def forward(self, *args, **kwargs):
192 |         self.init_params_tensor()
193 |         inputs = self.init_act_tensor(*args, **kwargs)
194 |         if self.original_forward:
195 |             return self.base_model(**inputs)
196 |         else:
197 |             for n, p in self.pre_ops_order.items():
198 |                 inputs = self.pre_ops[n](**inputs)
199 |             self.assign_params_tensor()
200 | 
201 |             if 'flatten_parameters' in dir(self.base_model):
202 |                 self.base_model.flatten_parameters()
203 | 
204 |             inputs = self.base_model(**inputs)
205 |             for n, p in self.post_ops_order.items():
206 |                 inputs = self.post_ops[n](**inputs)
207 |             delattr(self, 'current_params')
208 |             if isinstance(inputs, dict):
209 |                 return tuple(inputs.values())
210 |             else:
211 |                 return inputs
212 | 
213 | 
214 | class SlimWorker(nn.Module):
215 |     def __init__(self):
216 |         super(SlimWorker, self).__init__()
217 | 
218 |     def __repr__(self):
219 |         return self.__class__.__name__
220 | 
221 |     def _get_target_param_name(self, param_name) -> list:
222 |         if isinstance(self.parent_model.base_params[param_name], nn.Parameter):
223 |             return [param_name]
224 |         else:
225 |             return [name for name in self.parent_model.base_params[param_name]]
226 | 
227 |     def get_param(self, param_name) -> torch.Tensor:
228 |         if param_name in self.parent_model.current_params.keys():
229 |             return
self.parent_model.current_params[param_name] 230 | else: 231 | return None 232 | 233 | def assign_param(self, param_name, param_value): 234 | self.parent_model.current_params[param_name] = param_value 235 | 236 | def named_children(self): 237 | memo = set() 238 | for name, module in self._modules.items(): 239 | if module is not None and module not in memo and name != 'parent_model': 240 | memo.add(module) 241 | yield name, module 242 | 243 | def __setattr__(self, name: str, value) -> None: 244 | def remove_from(*dicts_or_sets): 245 | for d in dicts_or_sets: 246 | if name in d: 247 | if isinstance(d, dict): 248 | del d[name] 249 | else: 250 | d.discard(name) 251 | 252 | params = self.__dict__.get('_parameters') 253 | if isinstance(value, nn.Parameter): 254 | if params is None: 255 | raise AttributeError( 256 | "cannot assign parameters before Module.__init__() call") 257 | remove_from(self.__dict__, self._buffers, self._modules, 258 | self._non_persistent_buffers_set) 259 | self.register_parameter(name, value) 260 | elif params is not None and name in params: 261 | if value is not None: 262 | raise TypeError("cannot assign '{}' as parameter '{}' " 263 | "(torch.nn.Parameter or None expected)".format( 264 | torch.typename(value), name)) 265 | self.register_parameter(name, value) 266 | else: 267 | modules = self.__dict__.get('_modules') 268 | if isinstance(value, nn.Module): 269 | if modules is None: 270 | raise AttributeError( 271 | "cannot assign module before Module.__init__() call") 272 | if name == 'parent_model': 273 | object.__setattr__(self, name, value) 274 | else: 275 | # only need when torch.__version__ >= 1.6.0 276 | try: 277 | remove_from(self.__dict__, self._parameters, 278 | self._buffers, 279 | self._non_persistent_buffers_set) 280 | except: 281 | pass 282 | modules[name] = value 283 | elif modules is not None and name in modules: 284 | if value is not None: 285 | raise TypeError("cannot assign '{}' as child module '{}' " 286 | "(torch.nn.Module or None expected)".format( 287 | torch.typename(value), name)) 288 | if name == 'parent_model': 289 | object.__setattr__(self, name, value) 290 | else: 291 | modules[name] = value 292 | else: 293 | buffers = self.__dict__.get('_buffers') 294 | if buffers is not None and name in buffers: 295 | if value is not None and not isinstance( 296 | value, torch.Tensor): 297 | raise TypeError( 298 | "cannot assign '{}' as buffer '{}' " 299 | "(torch.Tensor or None expected)".format( 300 | torch.typename(value), name)) 301 | buffers[name] = value 302 | else: 303 | object.__setattr__(self, name, value) 304 | 305 | 306 | class SuperNet(nn.Module): 307 | def __init__(self, model, Config: OFAConfig): 308 | super(SuperNet, self).__init__() 309 | self.model = model 310 | self.config = Config 311 | self.elastic_layers = {} 312 | for n, m in self.model.named_modules(): 313 | if m.__class__.__name__ in Config.SearchSpace.keys(): 314 | Converter = ConverterMap[m.__class__.__name__] 315 | elastic_module = Converter(m) 316 | if n in self.config.static_layer: 317 | elastic_module.set_static() 318 | replace_module(self.model, n, elastic_module) 319 | self.elastic_layers.update({n: elastic_module}) 320 | del (m) 321 | 322 | self.init_SearchSpace() 323 | self.step_count = 0 324 | self.shrink_step = None 325 | self.sample_step = None 326 | self.stage = self.config.stages[0] 327 | self._update_CurrSpace(self.stage) 328 | self._sample_StructConfig(self.stage) 329 | 330 | def forward(self, *args, **kwargs): 331 | return self.model(*args, **kwargs) 332 | 333 | def 
init_SearchSpace(self): 334 | for n, m in self.elastic_layers.items(): 335 | m.set_SearchSpace(deepcopy(self.config.SearchSpace[m.OriginName])) 336 | 337 | def set_stage(self, stage): 338 | self.stage = stage 339 | 340 | def _update_CurrSpace(self, stage): 341 | for n, m in self.elastic_layers.items(): 342 | m.update_CurrSpace(stage) 343 | #Logger(Verbose.DEBUG)('{} CurrSpace: {}'.format(n, m.CurrSpace)) 344 | 345 | def _sample_StructConfig(self, stage): 346 | sample_Structs = {} 347 | for n, m in self.elastic_layers.items(): 348 | sample_StructConfig = m.sample_StructConfig(stage) 349 | if n in self.config.binding_layer.keys(): 350 | sample_Structs[n] = sample_StructConfig 351 | #Logger(Verbose.DEBUG)('{} StructConfig: {}'.format(n, m.StructConfig)) 352 | 353 | for main_layer in self.config.binding_layer.keys(): 354 | for layer_name in self.config.binding_layer[main_layer]: 355 | sample_StructConfig = sample_Structs[main_layer] 356 | self.elastic_layers[layer_name].set_StructConfig( 357 | sample_StructConfig) 358 | #Logger(Verbose.DEBUG)('{} StructConfig: {}'.format( 359 | # layer_name, self.elastic_layers[layer_name].StructConfig)) 360 | 361 | def _set_StructConfig(self, layer_name, config): 362 | assert layer_name in self.elastic_layers.keys( 363 | ), 'The `layer_name` should be a ElasticLayer' 364 | self.elastic_layers[layer_name].set_StructConfig(config) 365 | #Logger(Verbose.DEBUG)('{} StructConfig: {}'.format( 366 | # layer_name, self.elastic_layers[layer_name].StructConfig)) 367 | binding_layers = self.config.binding_layer[layer_name] 368 | for layer in binding_layers: 369 | self.elastic_layers[layer].set_StructConfig(config) 370 | #Logger(Verbose.DEBUG)('{} StructConfig: {}'.format( 371 | # layer, self.elastic_layers[layer].StructConfig)) 372 | 373 | def set_controller(self, shrink_step, sample_step): 374 | self.shrink_step = shrink_step 375 | self.sample_step = sample_step 376 | 377 | def get_SearchSpace(self): 378 | SearchSpace = {} 379 | for n, m in self.elastic_layers.items(): 380 | SearchSpace.update({n: m.SearchSpace}) 381 | return SearchSpace 382 | 383 | def get_CurrSpace(self): 384 | CurrSpace = {} 385 | for n, m in self.elastic_layers.items(): 386 | CurrSpace.update({n: m.CurrSpace}) 387 | return CurrSpace 388 | 389 | def get_StructConfig(self): 390 | StructConfig = {} 391 | for n, m in self.elastic_layers.items(): 392 | StructConfig.update({n: m.StructConfig}) 393 | return StructConfig 394 | 395 | def step(self): 396 | self.step_count += 1 397 | if self.shrink_step and self.step_count % self.shrink_step == 0: 398 | self._update_CurrSpace(self.stage) 399 | if self.sample_step and self.step_count % self.sample_step == 0: 400 | self._sample_StructConfig(self.stage) 401 | 402 | def export(self, StructConfigs=None, dummy_input=None): 403 | if StructConfigs is not None: 404 | for n, m in StructConfigs.items(): 405 | self.elastic_layers[n].StructConfig = m 406 | if dummy_input is not None: 407 | self.model(dummy_input) 408 | for n, m in self.elastic_layers.items(): 409 | static_module = m.recover() 410 | replace_module(self.model, n, static_module) 411 | return self.model 412 | -------------------------------------------------------------------------------- /train_speech_commands.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import random 6 | import warnings 7 | import os 8 | 9 | # os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7' 10 | 11 | # fixme, raise 
runtime error @2 epoch, received 0 items of ancdata 12 | # https://github.com/pytorch/pytorch/issues/973 13 | # https://github.com/fastai/fastai/issues/23 14 | try: 15 | import resource 16 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 17 | resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1])) 18 | except: 19 | print('no resource') 20 | # sudo sh -c "ulimit -n 65535 && exec su $LOGNAME" 21 | 22 | import torch 23 | import torch.backends.cudnn as cudnn 24 | import torch.multiprocessing as mp 25 | from torch.utils.tensorboard import SummaryWriter 26 | 27 | from core.logger import Logger, Verbose 28 | 29 | from train_valid_test import train_epoch_distill, valid_epoch_distill, train_epoch, valid_epoch, \ 30 | create_lr_schedule, create_optimizer, get_model, create_dataloader 31 | 32 | def parse_args(): 33 | """ 34 | Parse input arguments 35 | """ 36 | # general args 37 | parser = argparse.ArgumentParser(description='Training') 38 | parser.add_argument('--saveroot', 39 | help='set root folder for log and checkpoint', 40 | type=str, 41 | default='speech_command') 42 | parser.add_argument('--dataroot', 43 | help='set root folder for dataset', 44 | type=str, 45 | default='/home/datasets/SpeechCommands') 46 | parser.add_argument('--checkpoint', 47 | help='choose a checkpoint to resume', 48 | type=str, 49 | default=None) 50 | parser.add_argument( 51 | '--test', 52 | action='store_true', 53 | help='test accuracy with input checkpoint', 54 | ) 55 | 56 | # model args 57 | parser.add_argument('--n_mels', 58 | type=int, 59 | default=32, 60 | help='mel feature size') 61 | parser.add_argument( 62 | '--model', 63 | type=str, 64 | default='Dfsmn') 65 | parser.add_argument('--dfsmn_with_bn', 66 | action='store_true', 67 | help='use BatchNorm for Dfsmn model') 68 | parser.add_argument('--num_layer', 69 | type=int, 70 | default=8, 71 | help='num_layer for Dfsmn model') 72 | parser.add_argument('--frondend_channels', 73 | type=int, 74 | default=16, 75 | help='frondend_channels for Dfsmn model') 76 | parser.add_argument('--frondend_kernel_size', 77 | type=int, 78 | default=5, 79 | help='frondend_kernel_size for Dfsmn model') 80 | parser.add_argument('--hidden_size', 81 | type=int, 82 | default=256, 83 | help='hidden_size for Dfsmn model') 84 | parser.add_argument('--backbone_memory_size', 85 | type=int, 86 | default=128, 87 | help='backbone_memory_size for Dfsmn model') 88 | parser.add_argument('--left_kernel_size', 89 | type=int, 90 | default=2, 91 | help='left_kernel_size for Dfsmn model') 92 | parser.add_argument('--right_kernel_size', 93 | type=int, 94 | default=2, 95 | help='right_kernel_size for Dfsmn model') 96 | 97 | # args for training hyper parameters 98 | parser.add_argument("--epoch", type=int, default=300, help='total epochs') 99 | parser.add_argument("--batch-size", type=int, default=96, help='batch size') 100 | parser.add_argument("--lr", type=float, default=1e-3, help='learning rate') 101 | parser.add_argument("--lr-scheduler", 102 | choices=['plateau', 'step', 'cosin'], 103 | default='cosin', 104 | help='method to adjust learning rate') 105 | parser.add_argument("--weight-decay", 106 | type=float, 107 | default=1e-2, 108 | help='weight decay') 109 | parser.add_argument( 110 | "--lr-scheduler-patience", 111 | type=int, 112 | default=5, 113 | help='lr scheduler plateau: Number of epochs with no improvement ' 114 | 'after which learning rate will be reduced') 115 | parser.add_argument( 116 | "--lr-scheduler-stepsize", 117 | type=int, 118 | default=5, 119 | help='lr scheduler step: 
number of epochs of learning rate decay.') 120 | parser.add_argument( 121 | "--lr-scheduler-gamma", 122 | type=float, 123 | default=0.1, 124 | help='learning rate is multiplied by the gamma to decrease it') 125 | parser.add_argument("--optim", 126 | choices=['sgd', 'adam'], 127 | default='sgd', 128 | help='choices of optimization algorithms') 129 | parser.add_argument( 130 | "--label_smoothing", 131 | type=float, 132 | default=0, 133 | help='label_smoothing (float, optional): A float in [0.0, 1.0].') 134 | parser.add_argument("--mixup_alpha", 135 | type=float, 136 | default=0, 137 | help='mixup alpha.') 138 | 139 | parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') 140 | parser.add_argument('--seed', 141 | default=None, 142 | type=int, 143 | help='seed for initializing training. ') 144 | 145 | # args for distill/thinnable 146 | parser.add_argument('--num_classes', type=int, default=12, choices=[12, 20, 35], help='num_classes for dataset') 147 | parser.add_argument('--version', default="speech_commands_v0.01", choices=["speech_commands_v0.01", "speech_commands_v0.02"], type=str, help='dataset version') 148 | parser.add_argument('--thin_n', type=int, default=3, choices=[1, 2, 3, 4], help='ways for BiDfsmn_thinnable') 149 | parser.add_argument("--distill", action='store_true', help='disitll') 150 | parser.add_argument("--distill_alpha", type=float, default=0, help='disitll alpha.') 151 | parser.add_argument("--teacher_model", choices=['Vgg19Bn', 'Mobilenetv1', 'Mobilenetv2', 'BCResNet', 'Dfsmn', 'BiDfsmn', 'BiDfsmn_thinnable', 'BiDfsmn_thinnable_pre'], type=str, default='Dfsmn', help='teacher model') 152 | parser.add_argument('--teacher_model_checkpoint', type=str, help='teacher pretrained model path: saveroot + teacher_model_checkpoint') 153 | parser.add_argument('--pretrained', action='store_true', help='load the pre-trained teacher model') 154 | parser.add_argument("--select_pass", type=str, default='no', choices=['no', 'low', 'high'], help='high-pass or low-pass for wavelet.') 155 | parser.add_argument("--J", type=int, default=1, help='scale of wavelet.') 156 | parser.add_argument("--method", type=str, default='no', help='bi method.') 157 | 158 | parsed_args = parser.parse_args() 159 | return parsed_args 160 | 161 | 162 | def test_speech_commands(configs, gpu_id=None): 163 | model = get_model(configs.model, 164 | in_channels=1, 165 | **(vars(configs))) 166 | print(model) 167 | nparams = sum(p.numel() for p in model.parameters() if p.requires_grad) 168 | names_params = { 169 | n: p.numel() * 1e-6 170 | for n, p in model.named_parameters() if p.requires_grad 171 | } 172 | sorted_names_params = sorted(names_params.items(), 173 | key=lambda kv: kv[1], 174 | reverse=True) 175 | print(sorted_names_params) 176 | Logger(Verbose.INFO)( 177 | 'create model: {}, with {} M Params(With BN param)'.format( 178 | configs.model, nparams * 1e-6)) 179 | 180 | if configs.checkpoint is None: 181 | raise RuntimeError('test mode must provider checkpoint') 182 | 183 | chpk = torch.load(configs.checkpoint) 184 | model.load_state_dict(chpk['state_dict']) 185 | if gpu_id is not None: 186 | model.cuda(gpu_id) 187 | 188 | dataloader_test = create_dataloader('testing', 189 | configs, 190 | use_gpu=gpu_id is not None, 191 | version=configs.version) 192 | 193 | criterion = torch.nn.CrossEntropyLoss( 194 | label_smoothing=configs.label_smoothing) 195 | 196 | if configs.distill: 197 | valid_loss, accuracy = valid_epoch_distill(model, criterion, dataloader_test, 198 | 0, gpu_id is not None, 10, 
None) 199 | Logger(Verbose.INFO)('checkpoint: {}, loss: {}, accuracy: {}'.format( 200 | configs.checkpoint, valid_loss, accuracy)) 201 | else: 202 | valid_loss, accuracy = valid_epoch(model, criterion, dataloader_test, 0, 203 | gpu_id is not None, 10, None) 204 | Logger(Verbose.INFO)('checkpoint: {}, loss: {}, accuracy: {}'.format( 205 | configs.checkpoint, valid_loss, accuracy)) 206 | 207 | 208 | def train_speech_commands(configs, gpu_id=None): 209 | best_accuracy = 0 210 | best_accuracys = None 211 | epoch = 0 212 | 213 | use_gpu = torch.cuda.is_available() 214 | if gpu_id is not None: 215 | torch.cuda.set_device(gpu_id) 216 | 217 | model = get_model(configs.model, 218 | in_channels=1, 219 | **(vars(configs))) 220 | print(model) 221 | nparams = sum(p.numel() for p in model.parameters() if p.requires_grad) 222 | names_params = { 223 | n: p.numel() * 1e-6 224 | for n, p in model.named_parameters() if p.requires_grad 225 | } 226 | sorted_names_params = sorted(names_params.items(), 227 | key=lambda kv: kv[1], 228 | reverse=True) 229 | print(sorted_names_params) 230 | Logger(Verbose.INFO)( 231 | 'create model: {}, with {} M Params(With BN param)'.format( 232 | configs.model, nparams * 1e-6)) 233 | 234 | teacher_model = None 235 | if configs.distill: 236 | teacher_model = get_model(configs.teacher_model, 237 | in_channels=1, 238 | **(vars(configs))) 239 | chpk = torch.load(os.path.join(configs.saveroot, configs.teacher_model_checkpoint)) 240 | teacher_model.load_state_dict(chpk['state_dict'], strict=False) 241 | if configs.pretrained: 242 | chpk = torch.load(os.path.join(configs.saveroot, configs.teacher_model_checkpoint)) 243 | model.load_state_dict(chpk['state_dict'], strict=False) 244 | 245 | criterion = torch.nn.CrossEntropyLoss( 246 | label_smoothing=configs.label_smoothing) 247 | 248 | optimizer = create_optimizer(configs, model) 249 | if configs.checkpoint is not None: 250 | chpk = torch.load(configs.checkpoint) 251 | best_accuracy = chpk['accuracy'] 252 | epoch = chpk['epoch'] 253 | model.load_state_dict(chpk['state_dict']) 254 | optimizer.load_state_dict(chpk['optimizer']) 255 | 256 | lr_scheduler = create_lr_schedule(configs, optimizer) 257 | 258 | dataloader_train = create_dataloader('training', configs, use_gpu, version=configs.version) 259 | 260 | dataloader_valid = create_dataloader('validation', configs, use_gpu, version=configs.version) 261 | 262 | if gpu_id is not None: 263 | model = model.cuda(gpu_id) 264 | if teacher_model != None: 265 | teacher_model = teacher_model.cuda(gpu_id) 266 | 267 | writer = SummaryWriter(log_dir=os.path.join(configs.saveroot, 'Log'), 268 | flush_secs=10) 269 | 270 | # train 271 | for cur_epoch in range(epoch, configs.epoch): 272 | Logger(Verbose.INFO)("runing on epoch: {}, learning_rate: {}".format( 273 | cur_epoch, optimizer.param_groups[0]['lr'])) 274 | 275 | if configs.distill: 276 | train_loss = train_epoch_distill(model, 277 | teacher_model, 278 | optimizer, 279 | criterion, 280 | dataloader_train, 281 | epoch=cur_epoch, 282 | with_gpu=use_gpu, 283 | log_iter=10, 284 | writer=writer, 285 | mixup_alpha=configs.mixup_alpha, 286 | distill_alpha=configs.distill_alpha, 287 | select_pass=configs.select_pass, 288 | J=configs.J, 289 | num_classes=configs.num_classes) 290 | valid_loss, accuracy = valid_epoch_distill(model, criterion, dataloader_valid, 291 | cur_epoch, use_gpu, 10, writer) 292 | else: 293 | train_loss = train_epoch(model, 294 | optimizer, 295 | criterion, 296 | dataloader_train, 297 | epoch=cur_epoch, 298 | with_gpu=use_gpu, 299 | 
log_iter=10, 300 | writer=writer, 301 | mixup_alpha=configs.mixup_alpha, 302 | num_classes=configs.num_classes) 303 | valid_loss, accuracy = valid_epoch(model, criterion, dataloader_valid, 304 | cur_epoch, use_gpu, 10, writer) 305 | 306 | # valid_loss, accuracy = 0, 0 307 | if configs.lr_scheduler == 'plateau': 308 | lr_scheduler.step(metrics=valid_loss) 309 | else: 310 | lr_scheduler.step() 311 | 312 | if not isinstance(accuracy, list): 313 | if accuracy > best_accuracy: 314 | best_accuracy = accuracy 315 | Logger( 316 | Verbose.INFO 317 | )("Got better checkpointer, epoch: {}, accuracy: {}, valid loss: {}" 318 | .format(cur_epoch, best_accuracy, valid_loss)) 319 | checkpoint = { 320 | 'epoch': cur_epoch, 321 | 'state_dict': model.cpu().state_dict(), 322 | 'accuracy': best_accuracy, 323 | 'optimizer': optimizer.state_dict(), 324 | } 325 | pth_name = '{}_acc_{}_epoch_{}_lr_{}_wd_{}_lrscheudle_{}_v{}-{}'.format( 326 | configs.model, best_accuracy, cur_epoch, configs.lr, configs.weight_decay, 327 | configs.lr_scheduler, int(configs.version[-1:]), int(configs.num_classes)) 328 | if configs.distill: 329 | pth_name = pth_name + '_distill_{}'.format(configs.distill_alpha) 330 | if configs.select_pass != 'no': 331 | pth_name = pth_name + '_' + configs.select_pass + '_J_{}'.format(configs.J) 332 | pth_name = pth_name + '.pth' 333 | best_checkpoint_path = os.path.join( 334 | configs.saveroot, 335 | pth_name) 336 | torch.save(checkpoint, best_checkpoint_path) 337 | configs.checkpoint = best_checkpoint_path 338 | Logger(Verbose.INFO)('train loss: ', train_loss, 339 | ', valid: best_accuracy: ', best_accuracy, 340 | ', cur_accuracy: ', accuracy, ', valid loss: ', 341 | valid_loss) 342 | else: 343 | if best_accuracys == None: 344 | best_accuracys = accuracy 345 | avg_accuracy = accuracy[0] 346 | if avg_accuracy > best_accuracy and min([x - y for x, y in zip(accuracy[:-1], accuracy[1:])]) > 0: 347 | best_accuracy = avg_accuracy 348 | best_accuracys = accuracy 349 | Logger( 350 | Verbose.INFO 351 | )("Got better checkpointer, epoch: {}, accuracy: {}, valid loss: {}" 352 | .format(cur_epoch, best_accuracy, valid_loss)) 353 | checkpoint = { 354 | 'epoch': cur_epoch, 355 | 'state_dict': model.cpu().state_dict(), 356 | 'accuracy': best_accuracy, 357 | 'optimizer': optimizer.state_dict(), 358 | } 359 | pth_name = '{}_acc_{}_epoch_{}_lr_{}_wd_{}_lrscheudle_{}_v{}-{}'.format( 360 | configs.model, best_accuracy, cur_epoch, configs.lr, configs.weight_decay, 361 | configs.lr_scheduler, int(configs.version[-1:]), int(configs.num_classes)) 362 | if configs.distill: 363 | pth_name = pth_name + '_distill_{}'.format(configs.distill_alpha) 364 | if configs.select_pass != 'no': 365 | pth_name = pth_name + '_' + configs.select_pass + '_J_{}'.format(configs.J) 366 | pth_name = pth_name + '.pth' 367 | best_checkpoint_path = os.path.join( 368 | configs.saveroot, 369 | pth_name) 370 | torch.save(checkpoint, best_checkpoint_path) 371 | configs.checkpoint = best_checkpoint_path 372 | 373 | Logger(Verbose.INFO)('train loss: ', train_loss, 374 | ', valid: best_accuracy: ', best_accuracy, 375 | ', cur_accuracy: ', ['%.4f%%' % (x * 100) for x in accuracy], 376 | ', best_accuracys', ['%.4f%%' % (x * 100) for x in best_accuracys], 377 | ', valid loss: ', valid_loss) 378 | 379 | test_speech_commands(configs, gpu_id) 380 | 381 | 382 | if __name__ == '__main__': 383 | mp.set_start_method('spawn') 384 | args = parse_args() 385 | 386 | if args.test: 387 | test_speech_commands(args, args.gpu) 388 | else: 389 | if args.seed is not None: 390 
| random.seed(args.seed) 391 | torch.manual_seed(args.seed) 392 | cudnn.deterministic = True 393 | warnings.warn('You have chosen to seed training. ' 394 | 'This will turn on the CUDNN deterministic setting, ' 395 | 'which can slow down your training considerably! ' 396 | 'You may see unexpected behavior when restarting ' 397 | 'from checkpoints.') 398 | 399 | if args.gpu is not None: 400 | warnings.warn( 401 | 'You have chosen a specific GPU. This will completely ' 402 | 'disable data parallelism.') 403 | 404 | # build model 405 | os.makedirs(args.saveroot, exist_ok=True) 406 | os.makedirs(os.path.join(args.saveroot, 'Log'), exist_ok=True) 407 | 408 | train_speech_commands(args, gpu_id=args.gpu) 409 | -------------------------------------------------------------------------------- /models/torch/bidfsmn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Function 5 | # from .utils import weight_init 6 | import argparse 7 | 8 | class BiDfsmnLayer(nn.Module): 9 | def __init__(self, 10 | hidden_size, 11 | backbone_memory_size, 12 | left_kernel_size, 13 | right_kernel_size, 14 | dilation=1, 15 | dropout=0.0): 16 | super().__init__() 17 | self.fc_trans = nn.Sequential(*[ 18 | nn.Linear(backbone_memory_size, hidden_size), 19 | nn.PReLU(), 20 | nn.Dropout(dropout), 21 | nn.Linear(hidden_size, backbone_memory_size), 22 | nn.Dropout(dropout) 23 | ]) 24 | self.memory = nn.Conv1d(backbone_memory_size, 25 | backbone_memory_size, 26 | kernel_size=left_kernel_size + 27 | right_kernel_size + 1, 28 | padding=0, 29 | stride=1, 30 | dilation=dilation, 31 | groups=backbone_memory_size) 32 | 33 | self.left_kernel_size = left_kernel_size 34 | self.right_kernel_size = right_kernel_size 35 | self.dilation = dilation 36 | self.backbone_memory_size = backbone_memory_size 37 | 38 | def forward(self, input_feat): 39 | # input (B, N, T) 40 | residual = input_feat 41 | # dfsmn-memory 42 | pad_input_fea = F.pad(input_feat, [ 43 | self.left_kernel_size * self.dilation, 44 | self.right_kernel_size * self.dilation 45 | ]) # (B,N,T+(l+r)*d) 46 | memory_out = self.memory(pad_input_fea) + residual 47 | residual = memory_out # (B, N, T) 48 | 49 | # fc-transform 50 | fc_output = self.fc_trans(memory_out.transpose(1, 2)) # (B, T, N) 51 | output = fc_output.transpose(1, 2) + residual # (B, N, T) 52 | return output 53 | 54 | 55 | class BiDfsmnLayerBN(nn.Module): 56 | def __init__(self, 57 | hidden_size, 58 | backbone_memory_size, 59 | left_kernel_size, 60 | right_kernel_size, 61 | dilation=1, 62 | dropout=0.0): 63 | super().__init__() 64 | self.fc_trans = nn.Sequential(*[ 65 | nn.Conv1d(backbone_memory_size, hidden_size, 1), 66 | nn.BatchNorm1d(hidden_size), 67 | nn.PReLU(), 68 | nn.Dropout(dropout), 69 | nn.Conv1d(hidden_size, backbone_memory_size, 1), 70 | nn.BatchNorm1d(backbone_memory_size), 71 | nn.PReLU(), 72 | nn.Dropout(dropout, ), 73 | ]) 74 | self.memory = nn.Sequential(*[ 75 | nn.Conv1d(backbone_memory_size, 76 | backbone_memory_size, 77 | kernel_size=left_kernel_size + right_kernel_size + 1, 78 | padding=0, 79 | stride=1, 80 | dilation=dilation, 81 | groups=backbone_memory_size), 82 | nn.BatchNorm1d(backbone_memory_size), 83 | nn.PReLU(), 84 | ]) 85 | 86 | self.left_kernel_size = left_kernel_size 87 | self.right_kernel_size = right_kernel_size 88 | self.dilation = dilation 89 | self.backbone_memory_size = backbone_memory_size 90 | 91 | def forward(self, input_feat): 92 | # 
input (B, N, T) 93 | residual = input_feat 94 | # dfsmn-memory 95 | pad_input_fea = F.pad(input_feat, [ 96 | self.left_kernel_size * self.dilation, 97 | self.right_kernel_size * self.dilation 98 | ]) # (B,N,T+(l+r)*d) 99 | memory_out = self.memory(pad_input_fea) + residual 100 | residual = memory_out # (B, N, T) 101 | 102 | # fc-transform 103 | fc_output = self.fc_trans(memory_out) # (B, T, N) 104 | output = fc_output + residual # (B, N, T) 105 | return output 106 | 107 | 108 | class BiDfsmnModel(nn.Module): 109 | def __init__(self, 110 | num_classes, 111 | in_channels, 112 | n_mels=32, 113 | num_layer=8, 114 | frondend_channels=16, 115 | frondend_kernel_size=5, 116 | hidden_size=256, 117 | backbone_memory_size=128, 118 | left_kernel_size=2, 119 | right_kernel_size=2, 120 | dilation=1, 121 | dropout=0.0, 122 | dfsmn_with_bn=True, 123 | distill=False, 124 | **kwargs): 125 | super().__init__() 126 | self.front_end = nn.Sequential(*[ 127 | nn.Conv2d(in_channels, 128 | out_channels=frondend_channels, 129 | kernel_size=[frondend_kernel_size, frondend_kernel_size], 130 | stride=(2, 2), 131 | padding=(frondend_kernel_size // 2, 132 | frondend_kernel_size // 2)), 133 | nn.BatchNorm2d(frondend_channels), 134 | nn.PReLU(), 135 | nn.Conv2d(frondend_channels, 136 | out_channels=2 * frondend_channels, 137 | kernel_size=[frondend_kernel_size, frondend_kernel_size], 138 | stride=(2, 2), 139 | padding=(frondend_kernel_size // 2, 140 | frondend_kernel_size // 2)), 141 | nn.BatchNorm2d(2 * frondend_channels), 142 | nn.PReLU() 143 | ]) 144 | self.n_mels = n_mels 145 | self.fc1 = nn.Sequential(*[ 146 | nn.Linear(in_features=2 * frondend_channels * self.n_mels // 4, 147 | out_features=backbone_memory_size), 148 | nn.PReLU(), 149 | ]) 150 | backbone = list() 151 | for idx in range(num_layer): 152 | if dfsmn_with_bn: 153 | backbone.append( 154 | BiDfsmnLayerBN(hidden_size, backbone_memory_size, 155 | left_kernel_size, right_kernel_size, dilation, 156 | dropout)) 157 | else: 158 | backbone.append( 159 | BiDfsmnLayer(hidden_size, backbone_memory_size, 160 | left_kernel_size, right_kernel_size, dilation, 161 | dropout)) 162 | self.backbone = nn.Sequential(*backbone) 163 | self.classifier = nn.Sequential(*[ 164 | nn.Dropout(p=dropout), 165 | # nn.Linear(backbone_memory_size * self.n_mels // 4, num_classes), 166 | nn.Linear(backbone_memory_size * 32 // 4, num_classes), 167 | ]) 168 | self.distill = distill 169 | # self.apply(weight_init) 170 | 171 | def forward(self, input_feat): 172 | # input_feat: B, 1, N, T 173 | batch = input_feat.shape[0] 174 | out = self.front_end(input_feat) # B, C, N//4, T//4 175 | out = out.view(batch, -1, 176 | out.shape[3]).transpose(1, 2).contiguous() # B, T, N1 177 | out = self.fc1(out).transpose(1, 2).contiguous() # B, N, T 178 | features = [] 179 | for layer in self.backbone: 180 | out = layer(out) 181 | features.append(out) 182 | out = out.contiguous().view(batch, -1) 183 | out = self.classifier(out) 184 | if self.distill: 185 | return out, features 186 | else: 187 | return out 188 | 189 | 190 | class BiDfsmnLayerBN_thinnable(nn.Module): 191 | def __init__(self, 192 | hidden_size, 193 | backbone_memory_size, 194 | left_kernel_size, 195 | right_kernel_size, 196 | dilation=1, 197 | dropout=0.0): 198 | super().__init__() 199 | self.fc_trans = nn.Sequential(*[ 200 | nn.Conv1d(backbone_memory_size, hidden_size, 1), 201 | nn.BatchNorm1d(hidden_size), 202 | nn.PReLU(), 203 | nn.Dropout(dropout), 204 | nn.Conv1d(hidden_size, backbone_memory_size, 1), 205 | ]) 206 | self.bn0 = 
nn.BatchNorm1d(backbone_memory_size) 207 | self.act0 = nn.PReLU() 208 | self.bn1 = nn.BatchNorm1d(backbone_memory_size) 209 | self.act1 = nn.PReLU() 210 | self.bn2 = nn.BatchNorm1d(backbone_memory_size) 211 | self.act2 = nn.PReLU() 212 | self.bn3 = nn.BatchNorm1d(backbone_memory_size) 213 | self.act3 = nn.PReLU() 214 | self.memory = nn.Sequential(*[ 215 | nn.Conv1d(backbone_memory_size, 216 | backbone_memory_size, 217 | kernel_size=left_kernel_size + right_kernel_size + 1, 218 | padding=0, 219 | stride=1, 220 | dilation=dilation, 221 | groups=backbone_memory_size), 222 | nn.BatchNorm1d(backbone_memory_size), 223 | nn.PReLU(), 224 | ]) 225 | 226 | self.left_kernel_size = left_kernel_size 227 | self.right_kernel_size = right_kernel_size 228 | self.dilation = dilation 229 | self.backbone_memory_size = backbone_memory_size 230 | 231 | def forward(self, input_feat, opt): 232 | # input (B, N, T) 233 | residual = input_feat 234 | # dfsmn-memory 235 | pad_input_fea = F.pad(input_feat, [ 236 | self.left_kernel_size * self.dilation, 237 | self.right_kernel_size * self.dilation 238 | ]) # (B,N,T+(l+r)*d) 239 | memory_out = self.memory(pad_input_fea) + residual 240 | residual = memory_out # (B, N, T) 241 | 242 | # fc-transform 243 | fc_output = self.fc_trans(memory_out) # (B, T, N) 244 | if opt == 0: 245 | fc_output = self.bn0(fc_output) 246 | fc_output = self.act0(fc_output) 247 | elif opt == 1: 248 | fc_output = self.bn1(fc_output) 249 | fc_output = self.act1(fc_output) 250 | elif opt == 2: 251 | fc_output = self.bn2(fc_output) 252 | fc_output = self.act2(fc_output) 253 | elif opt == 3: 254 | fc_output = self.bn3(fc_output) 255 | fc_output = self.act3(fc_output) 256 | else: 257 | raise Exception('opt should in [0, 1, 2, 3] but opt = {}'.format(opt)) 258 | output = fc_output + residual # (B, N, T) 259 | return output 260 | 261 | 262 | class BiDfsmnModel_thinnable(nn.Module): 263 | def __init__(self, 264 | num_classes, 265 | in_channels, 266 | n_mels=32, 267 | num_layer=8, 268 | frondend_channels=16, 269 | frondend_kernel_size=5, 270 | hidden_size=256, 271 | backbone_memory_size=128, 272 | left_kernel_size=2, 273 | right_kernel_size=2, 274 | dilation=1, 275 | dropout=0.0, 276 | dfsmn_with_bn=True, 277 | thin_n=3, 278 | distill=False, 279 | **kwargs): 280 | super().__init__() 281 | self.front_end = nn.Sequential(*[ 282 | nn.Conv2d(in_channels, 283 | out_channels=frondend_channels, 284 | kernel_size=[frondend_kernel_size, frondend_kernel_size], 285 | stride=(2, 2), 286 | padding=(frondend_kernel_size // 2, 287 | frondend_kernel_size // 2)), 288 | nn.BatchNorm2d(frondend_channels), 289 | nn.PReLU(), 290 | nn.Conv2d(frondend_channels, 291 | out_channels=2 * frondend_channels, 292 | kernel_size=[frondend_kernel_size, frondend_kernel_size], 293 | stride=(2, 2), 294 | padding=(frondend_kernel_size // 2, 295 | frondend_kernel_size // 2)), 296 | nn.BatchNorm2d(2 * frondend_channels), 297 | nn.PReLU() 298 | ]) 299 | self.n_mels = n_mels 300 | self.fc1 = nn.Sequential(*[ 301 | nn.Linear(in_features=2 * frondend_channels * self.n_mels // 4, 302 | out_features=backbone_memory_size), 303 | nn.PReLU(), 304 | ]) 305 | backbone = list() 306 | for idx in range(num_layer): 307 | backbone.append( 308 | BiDfsmnLayerBN_thinnable(hidden_size, backbone_memory_size, 309 | left_kernel_size, right_kernel_size, dilation, 310 | dropout)) 311 | self.backbone = nn.Sequential(*backbone) 312 | self.classifier = nn.Sequential(*[ 313 | nn.Dropout(p=dropout), 314 | # nn.Linear(backbone_memory_size * self.n_mels // 4, num_classes), 315 | 
nn.Linear(backbone_memory_size * 32 // 4, num_classes), 316 | ]) 317 | self.distill = distill 318 | self.thin_n = thin_n 319 | # self.apply(weight_init) 320 | 321 | def forward(self, input_feat, opt): 322 | # input_feat: B, 1, N, T 323 | batch = input_feat.shape[0] 324 | out = self.front_end(input_feat) # B, C, N//4, T//4 325 | out = out.view(batch, -1, 326 | out.shape[3]).transpose(1, 2).contiguous() # B, T, N1 327 | out = self.fc1(out).transpose(1, 2).contiguous() # B, N, T 328 | features = [] 329 | if opt == 0: 330 | for idx in [0, 1, 2, 3, 4, 5, 6, 7]: 331 | out = self.backbone[idx](out, opt) 332 | features.append(out) 333 | elif opt == 1: 334 | for idx in [1, 3, 5, 7]: 335 | out = self.backbone[idx](out, opt) 336 | features.append(out) 337 | elif opt == 2: 338 | for idx in [3, 7]: 339 | out = self.backbone[idx](out, opt) 340 | features.append(out) 341 | elif opt == 3: 342 | for idx in [7]: 343 | out = self.backbone[idx](out, opt) 344 | features.append(out) 345 | 346 | out = out.contiguous().view(batch, -1) 347 | out = self.classifier(out) 348 | if self.distill: 349 | return out, features 350 | else: 351 | return out 352 | 353 | 354 | class DfsmnLayerBN_pre(nn.Module): 355 | def __init__(self, 356 | hidden_size, 357 | backbone_memory_size, 358 | left_kernel_size, 359 | right_kernel_size, 360 | dilation=1, 361 | dropout=0.0): 362 | super().__init__() 363 | self.fc_trans = nn.Sequential(*[ 364 | nn.Conv1d(backbone_memory_size, hidden_size, 1), 365 | nn.BatchNorm1d(hidden_size), 366 | nn.ReLU(), 367 | nn.Dropout(dropout), 368 | nn.Conv1d(hidden_size, backbone_memory_size, 1), 369 | ]) 370 | self.bn0 = nn.BatchNorm1d(backbone_memory_size) 371 | self.act0 = nn.PReLU() 372 | self.memory = nn.Sequential(*[ 373 | nn.Conv1d(backbone_memory_size, 374 | backbone_memory_size, 375 | kernel_size=left_kernel_size + right_kernel_size + 1, 376 | padding=0, 377 | stride=1, 378 | dilation=dilation, 379 | groups=backbone_memory_size), 380 | nn.BatchNorm1d(backbone_memory_size), 381 | nn.PReLU(), 382 | ]) 383 | 384 | self.left_kernel_size = left_kernel_size 385 | self.right_kernel_size = right_kernel_size 386 | self.dilation = dilation 387 | self.backbone_memory_size = backbone_memory_size 388 | 389 | def forward(self, input_feat): 390 | # input (B, N, T) 391 | residual = input_feat 392 | # dfsmn-memory 393 | pad_input_fea = F.pad(input_feat, [ 394 | self.left_kernel_size * self.dilation, 395 | self.right_kernel_size * self.dilation 396 | ]) # (B,N,T+(l+r)*d) 397 | memory_out = self.memory(pad_input_fea) + residual 398 | residual = memory_out # (B, N, T) 399 | 400 | # fc-transform 401 | fc_output = self.fc_trans(memory_out) # (B, T, N) 402 | fc_output = self.bn0(fc_output) 403 | fc_output = self.act0(fc_output) 404 | output = fc_output + residual # (B, N, T) 405 | return output 406 | 407 | 408 | class DfsmnModel_pre(nn.Module): 409 | def __init__(self, 410 | num_classes, 411 | in_channels, 412 | n_mels=32, 413 | num_layer=8, 414 | frondend_channels=16, 415 | frondend_kernel_size=5, 416 | hidden_size=256, 417 | backbone_memory_size=128, 418 | left_kernel_size=2, 419 | right_kernel_size=2, 420 | dilation=1, 421 | dropout=0.2, 422 | dfsmn_with_bn=True, 423 | distill=False, 424 | **kwargs): 425 | super().__init__() 426 | self.front_end = nn.Sequential(*[ 427 | nn.Conv2d(in_channels, 428 | out_channels=frondend_channels, 429 | kernel_size=[frondend_kernel_size, frondend_kernel_size], 430 | stride=(2, 2), 431 | padding=(frondend_kernel_size // 2, 432 | frondend_kernel_size // 2)), 433 | 
nn.BatchNorm2d(frondend_channels), 434 | nn.ReLU(), 435 | nn.Conv2d(frondend_channels, 436 | out_channels=2 * frondend_channels, 437 | kernel_size=[frondend_kernel_size, frondend_kernel_size], 438 | stride=(2, 2), 439 | padding=(frondend_kernel_size // 2, 440 | frondend_kernel_size // 2)), 441 | nn.BatchNorm2d(2 * frondend_channels), 442 | nn.ReLU() 443 | ]) 444 | self.n_mels = n_mels 445 | self.fc1 = nn.Sequential(*[ 446 | nn.Linear(in_features=2 * frondend_channels * self.n_mels // 4, 447 | out_features=backbone_memory_size), 448 | nn.ReLU(), 449 | ]) 450 | backbone = list() 451 | for idx in range(num_layer): 452 | backbone.append( 453 | DfsmnLayerBN_pre(hidden_size, backbone_memory_size, 454 | left_kernel_size, right_kernel_size, dilation, 455 | dropout)) 456 | self.backbone = nn.Sequential(*backbone) 457 | self.classifier = nn.Sequential(*[ 458 | nn.Dropout(p=dropout), 459 | nn.Linear(backbone_memory_size * self.n_mels // 4, num_classes), 460 | ]) 461 | self.distill = distill 462 | # self.apply(weight_init) 463 | 464 | def forward(self, input_feat): 465 | # input_feat: B, 1, N, T 466 | batch = input_feat.shape[0] 467 | out = self.front_end(input_feat) # B, C, N//4, T//4 468 | out = out.view(batch, -1, 469 | out.shape[3]).transpose(1, 2).contiguous() # B, T, N1 470 | out = self.fc1(out).transpose(1, 2).contiguous() # B, N, T 471 | features = [] 472 | for layer in self.backbone: 473 | out = layer(out) 474 | features.append(out) 475 | out = out.contiguous().view(batch, -1) 476 | out = self.classifier(out) 477 | if self.distill: 478 | return out, features 479 | else: 480 | return out 481 | 482 | 483 | -------------------------------------------------------------------------------- /train_valid_test.py: -------------------------------------------------------------------------------- 1 | import imp 2 | from tqdm import tqdm 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.data as data 7 | from torch.utils.data import DataLoader 8 | from torchvision.transforms import Compose 9 | import torchaudio 10 | 11 | from torch.utils.tensorboard import SummaryWriter 12 | from basic import Count 13 | 14 | from core.registry import CONFIG 15 | 16 | from models.torch.dfsmn import DfsmnModel 17 | from models.torch.bidfsmn import BiDfsmnModel, BiDfsmnModel_thinnable, DfsmnModel_pre 18 | 19 | from speech_commands.dataset.speech_commands import SpeechCommandV1 20 | from speech_commands.dataset.transform import ChangeAmplitude, \ 21 | FixAudioLength, ChangeSpeedAndPitchAudio, TimeshiftAudio 22 | 23 | from torch_utils import mixup 24 | 25 | from pytorch_wavelets import DWTForward, DWTInverse 26 | 27 | def loss_term(A): 28 | a = torch.abs(A) 29 | Q = a * a 30 | return Q 31 | 32 | def total_loss(Q_s, Q_t): 33 | Q_s = loss_term(Q_s) 34 | Q_t = loss_term(Q_t) 35 | Q_s_norm = Q_s / torch.norm(Q_s, p=2) 36 | Q_t_norm = Q_t / torch.norm(Q_t, p=2) 37 | tmp = Q_s_norm - Q_t_norm 38 | loss = torch.norm(tmp, p=2) 39 | return loss 40 | 41 | def pass_filter(x, select_pass, J=1, wave='haar', mode='zero'): 42 | xfm = DWTForward(J=J, mode=mode, wave=wave) # Accepts all wave types available to PyWavelets 43 | ifm = DWTInverse(mode=mode, wave=wave) 44 | if x.is_cuda: 45 | xfm, ifm = xfm.cuda(), ifm.cuda() 46 | 47 | if len(x.shape) == 3: 48 | yl, yh = xfm(x.unsqueeze(1)) 49 | elif len(x.shape) == 4: 50 | yl, yh = xfm(x) 51 | else: 52 | assert(False) # error 53 | 54 | if select_pass == 'high': 55 | yl.zero_() 56 | 57 | y = ifm((yl, yh)) 58 | if len(x.shape) == 3: 59 | y = y.squeeze(1) 60 | 
return y 61 | 62 | 63 | def get_model2(model_type: str, in_channels=1, **kwargs): 64 | if model_type == 'Vgg19Bn': 65 | return Vgg19BN(in_channels=in_channels, **kwargs) # [Batch, 1, 32, 32] 66 | elif model_type == 'Mobilenetv1': 67 | return MobileNetV1(in_channels=in_channels, **kwargs) 68 | elif model_type == 'Mobilenetv2': 69 | return MobileNetV2(in_channels=in_channels, **kwargs) 70 | elif model_type == 'BCResNet': 71 | return BCResNet(in_channels=in_channels, **kwargs) # [Batch, 1, 40, 32] 72 | elif model_type == 'fsmn': 73 | return FSMN(in_channels=in_channels, **kwargs) 74 | elif model_type == 'Dfsmn': 75 | return DfsmnModel(in_channels=in_channels, **kwargs) 76 | elif model_type == 'BiDfsmn': 77 | return BiDfsmnModel(in_channels=in_channels, **kwargs) 78 | elif model_type == 'BiDfsmn_thinnable_pre': 79 | return DfsmnModel_pre(in_channels=in_channels, **kwargs) 80 | elif model_type == 'BiDfsmn_thinnable': 81 | return BiDfsmnModel_thinnable(in_channels=in_channels, **kwargs) 82 | else: 83 | raise RuntimeError('unsupport model type: ', model_type) 84 | 85 | 86 | def get_model(model_type: str, in_channels=1, method="no", **kwargs): 87 | if method == "no": 88 | model = get_model2(model_type, in_channels, **kwargs) 89 | return model 90 | else: 91 | from basic import Count, Modify 92 | model = get_model2(model_type, in_channels, **kwargs) 93 | model.method = method 94 | cnt = Count(model) 95 | model, _ = Modify(model, method=method, id=0, first=1, last=cnt) 96 | return model 97 | 98 | 99 | def create_dataloader(dataset_type, configs, use_gpu, version): 100 | train_transform = Compose([ 101 | ChangeAmplitude(), 102 | ChangeSpeedAndPitchAudio(), 103 | TimeshiftAudio(), 104 | FixAudioLength(), 105 | torchaudio.transforms.MelSpectrogram(sample_rate=16000, 106 | n_fft=2048, 107 | hop_length=512, 108 | n_mels=configs.n_mels, 109 | normalized=True), 110 | torchaudio.transforms.AmplitudeToDB(), 111 | ]) 112 | valid_transform = Compose([ 113 | FixAudioLength(), 114 | torchaudio.transforms.MelSpectrogram(sample_rate=16000, 115 | n_fft=2048, 116 | hop_length=512, 117 | n_mels=configs.n_mels, 118 | normalized=True), 119 | torchaudio.transforms.AmplitudeToDB(), 120 | ]) 121 | 122 | dataset_train = SpeechCommandV1(configs.dataroot, 123 | subset='training', 124 | download=True, 125 | transform=train_transform, 126 | num_classes=configs.num_classes, 127 | noise_ratio=0.3, 128 | noise_max_scale=0.3, 129 | cache_origin_data=False, 130 | version=version) 131 | 132 | dataset_valid = SpeechCommandV1(configs.dataroot, 133 | subset='validation', 134 | download=True, 135 | transform=valid_transform, 136 | num_classes=configs.num_classes, 137 | cache_origin_data=True, 138 | version=version) 139 | 140 | dataset_test = SpeechCommandV1(configs.dataroot, 141 | subset='testing', 142 | download=True, 143 | transform=valid_transform, 144 | num_classes=configs.num_classes, 145 | cache_origin_data=True, 146 | version=version) 147 | 148 | dataset_dict = { 149 | 'training': dataset_train, 150 | 'validation': dataset_valid, 151 | 'testing': dataset_test 152 | } 153 | return DataLoader(dataset_dict[dataset_type], 154 | batch_size=configs.batch_size, 155 | shuffle=dataset_type == 'training', 156 | sampler=None, 157 | pin_memory=use_gpu, 158 | num_workers=16, 159 | persistent_workers=True) 160 | 161 | 162 | def create_lr_schedule(configs, optimizer): 163 | if configs.lr_scheduler == 'plateau': 164 | lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 165 | optimizer, 166 | patience=configs.lr_scheduler_patience, 167 | 
factor=configs.lr_scheduler_gamma) 168 | elif configs.lr_scheduler == 'step': 169 | lr_scheduler = torch.optim.lr_scheduler.StepLR( 170 | optimizer, 171 | step_size=configs.lr_scheduler_stepsize, 172 | gamma=configs.lr_scheduler_gamma, 173 | last_epoch=configs.epoch - 1) 174 | elif configs.lr_scheduler == 'cosin': 175 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 176 | optimizer, T_max=configs.epoch) 177 | else: 178 | raise RuntimeError('unsupported lr schedule type: ', 179 | configs.lr_scheduler) 180 | return lr_scheduler 181 | 182 | 183 | def create_optimizer(configs, model): 184 | if configs.optim == 'sgd': 185 | optimizer = torch.optim.SGD(model.parameters(), 186 | lr=configs.lr, 187 | momentum=0.9, 188 | weight_decay=configs.weight_decay) 189 | else: 190 | optimizer = torch.optim.Adam(model.parameters(), 191 | lr=configs.lr, 192 | weight_decay=configs.weight_decay) 193 | 194 | return optimizer 195 | 196 | 197 | weights = [1, 0.5, 0.25] 198 | loss_lim = 50.0 199 | distillation_pred = torch.nn.MSELoss() 200 | pred = False 201 | 202 | def train_epoch_distill(model: nn.Module, 203 | teacher_model: nn.Module, 204 | optimizer, 205 | criterion, 206 | data_loader: data.DataLoader, 207 | epoch, 208 | with_gpu, 209 | log_iter=10, 210 | writer: SummaryWriter = None, 211 | mixup_alpha=0, 212 | distill_alpha=0, 213 | select_pass='no', 214 | J=1, 215 | num_classes=None): 216 | """ 217 | training one epoch 218 | """ 219 | model.train() 220 | if with_gpu: 221 | model = model.cuda() 222 | 223 | pbar = tqdm(data_loader, unit="audios", unit_scale=data_loader.batch_size) 224 | epoch_size = len(data_loader) 225 | 226 | 227 | running_loss = 0 228 | i = 0 229 | for inputs, target in pbar: 230 | if with_gpu: 231 | inputs = inputs.cuda() 232 | target = target.cuda() 233 | 234 | if 0 < mixup_alpha < 1: 235 | inputs, target = mixup.mixup(inputs, target, 236 | np.random.beta(mixup_alpha, mixup_alpha), 237 | num_classes) 238 | # forward 239 | teacher_out, teacher_feature = teacher_model(inputs) 240 | if select_pass != 'no': 241 | teacher_feature = [f1 / torch.std(f1) + f2 / torch.std(f2) for f1, f2 in [(pass_filter(f, select_pass=select_pass, J=J), f) for f in teacher_feature]] 242 | 243 | loss = 0 244 | 245 | if model.__class__.__name__[-9:] != 'thinnable': 246 | out, feature = model(inputs) 247 | 248 | if 0 < mixup_alpha < 1: 249 | loss_one_hot = mixup.naive_cross_entropy_loss(out, target) 250 | else: 251 | loss_one_hot = criterion(out, target) 252 | 253 | if hasattr(model, 'method') and model.method == 'Laq': 254 | distr_loss1, distr_loss2 = model.laq_loss(inputs) 255 | distr_loss1 = distr_loss1.mean() 256 | distr_loss2 = distr_loss2.mean() 257 | # remove distrloss after args.distr_epoch epochs 258 | if epoch < 100: 259 | loss = loss + (distr_loss1 + distr_loss2) 260 | 261 | loss = loss + loss_one_hot 262 | 263 | if len(teacher_feature) % len(feature) == 0: 264 | loss_distill = None 265 | for k in range(len(feature)): 266 | j = int((len(teacher_feature) / len(feature)) * (k+1) - 1) 267 | if loss_distill == None: 268 | # loss_distill = distillation(feature[j] / torch.std(feature[j]), teacher_feature[k] / torch.std(teacher_feature[k])) 269 | loss_distill = total_loss(feature[k], teacher_feature[j]) 270 | else: 271 | # loss_distill += distillation(feature[j] / torch.std(feature[j]), teacher_feature[k] / torch.std(teacher_feature[k])) 272 | loss_distill = loss_distill + total_loss(feature[k], teacher_feature[j]) 273 | loss = loss + loss_distill * distill_alpha 274 | if pred: 275 | loss_pred = 
distillation_pred(out, teacher_out) 276 | loss = loss + loss_pred * distill_alpha 277 | else: 278 | print ('Distiilation Error: teacher {}, student {}!'.format(len(teacher_feature), len(feature))) 279 | else: 280 | for op in range(model.thin_n): 281 | weight = weights[op] 282 | out, feature = model(inputs, op) 283 | 284 | if 0 < mixup_alpha < 1: 285 | loss_one_hot = mixup.naive_cross_entropy_loss(out, target) 286 | else: 287 | loss_one_hot = criterion(out, target) 288 | loss = loss + loss_one_hot * weight 289 | 290 | if len(teacher_feature) % len(feature) == 0: 291 | loss_distill = None 292 | for k in range(len(feature)): 293 | j = int((len(teacher_feature) / len(feature)) * (k+1) - 1) 294 | if loss_distill == None: 295 | # loss_distill = distillation(feature[j] / torch.std(feature[j]), teacher_feature[k] / torch.std(teacher_feature[k])) 296 | loss_distill = total_loss(feature[k], teacher_feature[j]) 297 | else: 298 | # loss_distill += distillation(feature[j] / torch.std(feature[j]), teacher_feature[k] / torch.std(teacher_feature[k])) 299 | loss_distill = loss_distill + total_loss(feature[k], teacher_feature[j]) 300 | loss = loss + loss_distill * distill_alpha * weight 301 | if pred: 302 | loss_pred = distillation_pred(out, teacher_out) 303 | loss = loss + loss_pred * distill_alpha * weight 304 | else: 305 | print ('Distiilation Error: teacher {}, student {}!'.format(len(teacher_feature), len(feature))) 306 | 307 | # backprop 308 | optimizer.zero_grad() 309 | loss.backward() 310 | # if loss.item() > loss_lim: 311 | # nn.utils.clip_grad_norm_(model.parameters(), max_norm=5, norm_type=2) 312 | # print('[loss ont_hot]: %.4f, [loss distill]: %.4f' % (loss_one_hot, loss_distill)) 313 | optimizer.step() 314 | running_loss += loss.item() 315 | if i % log_iter == 0 and writer is not None: 316 | writer.add_scalar('Train/iter_loss', loss.item(), 317 | i + epoch * epoch_size) 318 | writer.file_writer.flush() 319 | 320 | # update the progress bar 321 | pbar.set_postfix({ 322 | 'loss': "%.05f" % (loss.item()), 323 | }) 324 | i += 1 325 | 326 | running_loss /= i 327 | if writer is not None: 328 | writer.add_scalar('Train/epoch_loss', running_loss, epoch) 329 | writer.file_writer.flush() 330 | 331 | return running_loss 332 | 333 | def train_epoch(model: nn.Module, 334 | optimizer, 335 | criterion, 336 | data_loader: data.DataLoader, 337 | epoch, 338 | with_gpu, 339 | log_iter=10, 340 | writer: SummaryWriter = None, 341 | mixup_alpha=0, 342 | num_classes=None): 343 | """ 344 | training one epoch 345 | """ 346 | model.train() 347 | if with_gpu: 348 | model = model.cuda() 349 | 350 | pbar = tqdm(data_loader, unit="audios", unit_scale=data_loader.batch_size) 351 | epoch_size = len(data_loader) 352 | 353 | if model.__class__.__name__[-9:] != 'thinnable': 354 | running_loss = 0 355 | i = 0 356 | for feat, target in pbar: 357 | if with_gpu: 358 | feat = feat.cuda() 359 | target = target.cuda() 360 | 361 | if 0 < mixup_alpha < 1: 362 | feat, target = mixup.mixup(feat, target, 363 | np.random.beta(mixup_alpha, mixup_alpha), 364 | num_classes) 365 | # forward 366 | out = model(feat) 367 | if 0 < mixup_alpha < 1: 368 | loss = mixup.naive_cross_entropy_loss(out, target) 369 | else: 370 | loss = criterion(out, target) 371 | 372 | if hasattr(model, 'method') and model.method == 'Laq': 373 | distr_loss1, distr_loss2 = model.laq_loss(feat) 374 | distr_loss1 = distr_loss1.mean() 375 | distr_loss2 = distr_loss2.mean() 376 | # remove distrloss after args.distr_epoch epochs 377 | if epoch < 100: 378 | loss = loss + 
(distr_loss1 + distr_loss2)
379 | 
380 |             # backprop
381 |             optimizer.zero_grad()
382 |             loss.backward()
383 |             optimizer.step()
384 |             running_loss += loss.item()
385 | 
386 |             if i % log_iter == 0 and writer is not None:
387 |                 writer.add_scalar('Train/iter_loss', loss.item(),
388 |                                   i + epoch * epoch_size)
389 |                 writer.file_writer.flush()
390 | 
391 |             # update the progress bar
392 |             pbar.set_postfix({
393 |                 'loss': "%.05f" % (loss.item()),
394 |             })
395 |             i += 1
396 | 
397 |         running_loss /= i
398 |         if writer is not None:
399 |             writer.add_scalar('Train/epoch_loss', running_loss, epoch)
400 |             writer.file_writer.flush()
401 | 
402 |         return running_loss
403 |     else:
404 |         thin_n = model.thin_n
405 |         running_loss = 0
406 |         i = 0
407 |         for inputs, target in pbar:
408 |             if with_gpu:
409 |                 inputs = inputs.cuda()
410 |                 target = target.cuda()
411 | 
412 |             if 0 < mixup_alpha < 1:
413 |                 inputs, target = mixup.mixup(inputs, target,
414 |                                              np.random.beta(mixup_alpha, mixup_alpha),
415 |                                              num_classes)
416 | 
417 |             loss = 0
418 | 
419 |             # forward
420 |             for op in range(thin_n):
421 |                 weight = weights[op]
422 |                 out = model(inputs, op)
423 |                 if 0 < mixup_alpha < 1:
424 |                     loss += mixup.naive_cross_entropy_loss(out, target) * weight
425 |                 else:
426 |                     loss += criterion(out, target) * weight
427 | 
428 |             # backprop
429 |             optimizer.zero_grad()
430 |             loss.backward()
431 |             # if loss.item() > loss_lim:
432 |             #     nn.utils.clip_grad_norm_(model.parameters(), max_norm=5, norm_type=2)
433 |             # print('[loss one_hot]: %.4f, [loss distill]: %.4f' % (loss_one_hot, loss_distill))
434 |             optimizer.step()
435 |             running_loss += loss.item()
436 | 
437 |             if i % log_iter == 0 and writer is not None:
438 |                 writer.add_scalar('Train/iter_loss', loss.item(),
439 |                                   i + epoch * epoch_size)
440 |                 writer.file_writer.flush()
441 | 
442 |             # update the progress bar
443 |             pbar.set_postfix({
444 |                 'loss': "%.05f" % (loss.item()),
445 |             })
446 |             i += 1
447 | 
448 |         running_loss /= i
449 |         if writer is not None:
450 |             writer.add_scalar('Train/epoch_loss', running_loss, epoch)
451 |             writer.file_writer.flush()
452 | 
453 |         return running_loss
454 | 
455 | 
456 | def valid_epoch_distill(model: nn.Module,
457 |                         criterion,
458 |                         data_loader: data.DataLoader,
459 |                         epoch,
460 |                         with_gpu,
461 |                         log_iter=10,
462 |                         writer: SummaryWriter = None):
463 |     """
464 |     Run one validation epoch over the dataset.
465 |     """
466 |     model.eval()
467 |     if with_gpu:
468 |         model = model.cuda()
469 | 
470 |     pbar = tqdm(data_loader, unit="audios", unit_scale=data_loader.batch_size)
471 |     epoch_size = len(data_loader)
472 | 
473 |     if model.__class__.__name__[-9:] != 'thinnable':
474 |         running_loss = 0
475 |         running_acc = 0
476 |         i = 0
477 |         with torch.no_grad():
478 |             for feat, target in pbar:
479 |                 if with_gpu:
480 |                     feat = feat.cuda()
481 |                     target = target.cuda()
482 |                 # forward
483 |                 out, feature = model(feat)
484 |                 loss = criterion(out, target)
485 | 
486 |                 pred = out.max(1, keepdim=True)[1]
487 |                 acc = pred.eq(target.view_as(pred)).sum() / target.size(0)
488 | 
489 |                 running_loss += loss.item()
490 |                 running_acc += acc.item()
491 | 
492 |                 # log per 10 iter
493 |                 if i % log_iter == 0 and writer is not None:
494 |                     writer.add_scalar('Valid/iter_loss', loss.item(),
495 |                                       i + epoch * epoch_size)
496 |                     writer.file_writer.flush()
497 | 
498 |                 # update the progress bar
499 |                 pbar.set_postfix({
500 |                     'loss': "%.05f" % (loss.item()),
501 |                 })
502 |                 i += 1
503 | 
504 |         running_acc /= i
505 |         running_loss /= i
506 | 
507 |         # log for tensorboard
508 |         if writer is not None:
509 |             writer.add_scalar('Valid/epoch_loss', running_loss, epoch)
510 |             writer.add_scalar('Valid/epoch_accuracy', running_acc, epoch)
511 |             writer.file_writer.flush()
512 | 
513 |         return running_loss, running_acc
514 |     else:
515 |         thin_n = model.thin_n
516 |         running_loss = 0.0
517 |         running_acc = [0 for op in range(thin_n)]
518 |         i = 0
519 |         with torch.no_grad():
520 |             for feat, target in pbar:
521 |                 if with_gpu:
522 |                     feat = feat.cuda()
523 |                     target = target.cuda()
524 |                 # forward
525 |                 for op in range(thin_n):
526 |                     out, feature = model(feat, op)
527 |                     loss = criterion(out, target)
528 | 
529 |                     pred = out.max(1, keepdim=True)[1]
530 |                     acc = pred.eq(target.view_as(pred)).sum() / target.size(0)
531 | 
532 |                     running_loss += loss.item()
533 |                     running_acc[op] += acc.item()
534 | 
535 |                     # log per 10 iter
536 |                     if i % log_iter == 0 and writer is not None:
537 |                         writer.add_scalar('Valid/iter_loss[%d]' % [8, 4, 2, 1][op], loss.item(),
538 |                                           i + epoch * epoch_size)
539 |                         writer.file_writer.flush()
540 | 
541 |                 # update the progress bar
542 |                 pbar.set_postfix({
543 |                     'loss': "%.05f" % (loss.item()),
544 |                 })
545 |                 i += 1
546 | 
547 |         running_acc = [acc / i for acc in running_acc]
548 |         running_loss = running_loss / i
549 | 
550 |         # log for tensorboard
551 |         if writer is not None:
552 |             writer.add_scalar('Valid/epoch_loss', running_loss, epoch)
553 |             for op in range(thin_n):
554 |                 writer.add_scalar('Valid/epoch_accuracy_%d' % [8, 4, 2, 1][op], running_acc[op], epoch)
555 |             writer.file_writer.flush()
556 | 
557 |         return running_loss, running_acc
558 | 
559 | def valid_epoch(model: nn.Module,
560 |                 criterion,
561 |                 data_loader: data.DataLoader,
562 |                 epoch,
563 |                 with_gpu,
564 |                 log_iter=10,
565 |                 writer: SummaryWriter = None):
566 |     """
567 |     Run one validation epoch over the dataset.
568 |     """
569 |     model.eval()
570 |     if with_gpu:
571 |         model = model.cuda()
572 | 
573 |     pbar = tqdm(data_loader, unit="audios", unit_scale=data_loader.batch_size)
574 |     epoch_size = len(data_loader)
575 | 
576 |     if model.__class__.__name__[-9:] != 'thinnable':
577 |         running_loss = 0
578 |         running_acc = 0
579 |         i = 0
580 |         with torch.no_grad():
581 |             for feat, target in pbar:
582 |                 if with_gpu:
583 |                     feat = feat.cuda()
584 |                     target = target.cuda()
585 |                 # forward
586 |                 out = model(feat)
587 |                 loss = criterion(out, target)
588 | 
589 |                 pred = out.max(1, keepdim=True)[1]
590 |                 acc = pred.eq(target.view_as(pred)).sum() / target.size(0)
591 | 
592 |                 running_loss += loss.item()
593 |                 running_acc += acc.item()
594 | 
595 |                 # log per 10 iter
596 |                 if i % log_iter == 0 and writer is not None:
597 |                     writer.add_scalar('Valid/iter_loss', loss.item(),
598 |                                       i + epoch * epoch_size)
599 |                     writer.file_writer.flush()
600 | 
601 |                 # update the progress bar
602 |                 pbar.set_postfix({
603 |                     'loss': "%.05f" % (loss.item()),
604 |                 })
605 |                 i += 1
606 | 
607 |         running_acc /= i
608 |         running_loss /= i
609 | 
610 |         # log for tensorboard
611 |         if writer is not None:
612 |             writer.add_scalar('Valid/epoch_loss', running_loss, epoch)
613 |             writer.add_scalar('Valid/epoch_accuracy', running_acc, epoch)
614 |             writer.file_writer.flush()
615 | 
616 |         return running_loss, running_acc
617 |     else:
618 |         thin_n = model.thin_n
619 |         running_loss = 0
620 |         running_acc = [0 for op in range(thin_n)]
621 |         i = 0
622 |         with torch.no_grad():
623 |             for feat, target in pbar:
624 |                 if with_gpu:
625 |                     feat = feat.cuda()
626 |                     target = target.cuda()
627 |                 # forward
628 |                 for op in range(thin_n):
629 |                     out = model(feat, op)
630 |                     loss = criterion(out, target)
631 | 
632 |                     pred = out.max(1, keepdim=True)[1]
633 |                     acc = pred.eq(target.view_as(pred)).sum() / target.size(0)
634 | 
635 |                     running_loss += loss.item()
636 |                     running_acc[op] += acc.item()
637 | 
638 |                     # log per 10 iter
639 |                     if i % log_iter == 0 and writer is not None:
640 |                         writer.add_scalar('Valid/iter_loss[%d]' % [8, 4, 2, 1][op], loss.item(),
641 |                                           i + epoch * epoch_size)
642 |                         writer.file_writer.flush()
643 | 
644 |                 # update the progress bar
645 |                 pbar.set_postfix({
646 |                     'loss': "%.05f" % (loss.item()),
647 |                 })
648 |                 i += 1
649 | 
650 |         running_acc = [acc / i for acc in running_acc]
651 |         running_loss = running_loss / i
652 | 
653 |         # log for tensorboard
654 |         if writer is not None:
655 |             writer.add_scalar('Valid/epoch_loss', running_loss, epoch)
656 |             for op in range(thin_n):
657 |                 writer.add_scalar('Valid/epoch_accuracy_%d' % [8, 4, 2, 1][op], running_acc[op], epoch)
658 |             writer.file_writer.flush()
659 | 
660 |         return running_loss, running_acc
661 | 
--------------------------------------------------------------------------------
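A minimal usage sketch of the validation helpers above (not part of the repository). The model and DataLoader construction are left to the caller and the import path is an assumption for illustration; only the valid_epoch signature and its scalar-vs-per-width return convention come from train_valid_test.py.

# Hypothetical driver for valid_epoch; model/data wiring is assumed, and only
# the valid_epoch call and its return convention mirror the code above.
import torch.nn as nn
import torch.utils.data as data

from train_valid_test import valid_epoch  # assumed import path


def evaluate(model: nn.Module, valid_loader: data.DataLoader, with_gpu: bool = True):
    criterion = nn.CrossEntropyLoss()
    loss, acc = valid_epoch(model, criterion, valid_loader,
                            epoch=0, with_gpu=with_gpu, writer=None)
    print("validation loss: %.4f" % loss)
    if isinstance(acc, list):
        # Thinnable models return one accuracy per width (tagged 8/4/2/1 in the logs above).
        for width, width_acc in zip([8, 4, 2, 1], acc):
            print("  width %d accuracy: %.4f" % (width, width_acc))
    else:
        # Plain models return a single scalar accuracy.
        print("  accuracy: %.4f" % acc)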