├── README.md ├── configs ├── __init__.py ├── dataset_config.py ├── model_config.py └── train_config.py ├── extract_feature ├── WavLM │ ├── WavLM.py │ ├── extract_wavlm.py │ └── modules.py ├── __init__.py └── audio.py ├── figures └── Vesper.png ├── finetune.py ├── load_test_model.py ├── models ├── __init__.py ├── transformer.py └── vesper.py ├── modules ├── __init__.py ├── activation.py ├── classifier.py ├── multihead_attention.py ├── positional_encoding.py └── transformer_encoder.py ├── pretrain.py └── utils ├── __init__.py ├── avgmeter.py ├── collect_result.py ├── dataset.py ├── distributed.py ├── earlystopping.py ├── environment.py ├── logger.py ├── metric.py └── recoder.py /README.md: -------------------------------------------------------------------------------- 1 | # Vesper 2 | ![framework](./figures/Vesper.png) 3 | [\[Paper\]](https://arxiv.org/abs/2307.10757) Vesper: A Compact and Effective Pretrained Model for Speech Emotion Recognition 4 | 5 | ## Data Preparation 6 | Modify the variables for each dataset in ```configs/dataset_config.py```. 7 | 1. Move your audio files to the ```wavdir``` directory. 8 | 2. Create a meta_csv_file with the columns ```name``` (file names) and ```label``` (emotion labels) for each dataset. The pretraining datasets do not need the ```label``` column. 9 | 10 | Extracting WavLM features in advance greatly accelerates pretraining. Please use ```extract_feature/WavLM/extract_wavlm.py``` to extract the features of the pretraining data beforehand. 11 | 12 | ## Pretraining 13 | Specify training hyperparameters on the command line or modify them in ```configs/train_config.py```. 14 | Please also specify ```path_to_wavlm``` on the command line or in ```configs/model_config.py```. 15 | Please refer to the ```get_args``` function in ```configs/__init__.py``` if you want to use the command-line options. 16 | ```python 17 | python pretrain.py -M Vesper-4 18 | python pretrain.py -M Vesper-12 19 | python pretrain.py -M Vesper-12 -b 32 -g 0,1 -l 0.0005 --model_path_to_wavlm PATH_to_WavLM/WavLM-Large.pt 20 | ``` 21 | 22 | ## Fine-tuning 23 | Specify fine-tuning hyperparameters on the command line or modify them in ```configs/train_config.py```. 24 | Please also specify ```path_to_vesper``` on the command line or in ```configs/model_config.py```. 
25 | ```python 26 | python finetune.py -M Vesper-12 -d iemocap 27 | python finetune.py -M Vesper-12 -d iemocap -g 0 -b 32 -l 0.0007 --model_path_to_vesper PATH_to_EXP_DIRECTORY/checkpoint/model_best.pt 28 | ``` 29 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_config import _C as train_cfg 2 | from .dataset_config import _C as dataset_cfg 3 | from .model_config import _C as model_cfg 4 | 5 | import os 6 | import torch 7 | import argparse 8 | from yacs.config import CfgNode as CN 9 | 10 | def get_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("-M", "--model_type", help="modify cfg.train.model.type", default='Vesper', type=str) # required=True 13 | parser.add_argument("-d", "--dataset_database", help="specify the database used", default='lssed', type=str) 14 | parser.add_argument("-f", "--dataset_feature", help="specify the feature used", type=str) 15 | parser.add_argument("-e", "--train_EPOCH", help="total training epoch", type=int) 16 | parser.add_argument("-b", "--train_batch_size", help="training batch size", type=int) 17 | parser.add_argument("-l", "--train_lr", help="learning rate", type=float) 18 | parser.add_argument("-g", "--train_device_id", help="GPU ids", default='0', type=str) 19 | parser.add_argument("-s", "--train_seed", help="random seed", default=123, type=int) 20 | parser.add_argument("-S", "--train_save_best", help="save model with the best performance", action='store_true') 21 | parser.add_argument("-p", "--train_patience", help="the patience used in the early stopping", default=15, type=int) 22 | parser.add_argument("-r", "--model_output_rep", help="weighted sum or last layer", type=str) 23 | parser.add_argument("-m", "--mark", help="mark the current run", type=str) 24 | parser.add_argument("--dataset_num_workers", help="the number of workers", default=12, type=int) 25 | parser.add_argument("--train_warmup_epoch", help="set the warmup epoch", default=0.05, type=float) 26 | parser.add_argument("--train_resume", help="resume an experiment", type=str) 27 | parser.add_argument("--train_load_model", help="load a model", type=str) 28 | parser.add_argument("--train_device", help="run on cuda or cpu", default='cuda', type=str) 29 | parser.add_argument("--model_path_to_vesper", help="initialize model with Vesper's checkpoint", type=str) 30 | parser.add_argument("--model_path_to_wavlm", help="initialize model with pre-trained WavLM", type=str) 31 | args = parser.parse_args() 32 | return args 33 | 34 | def create_workshop(cfg, local_rank, fold): 35 | modeltype = cfg.model.type 36 | database = cfg.dataset.database 37 | batch = cfg.train.batch_size 38 | feature = cfg.dataset.feature 39 | lr = cfg.train.lr 40 | epoch = cfg.train.EPOCH 41 | 42 | world_size = torch.cuda.device_count() 43 | batch = batch * world_size 44 | 45 | config_name = f'./exp/{modeltype}/{database}_e{epoch}_b{batch}_lr{lr}_{feature}' 46 | 47 | if cfg.mark is not None: 48 | config_name = config_name + '_mark_{}'.format(cfg.mark) 49 | 50 | cfg.workshop = os.path.join(config_name, f'fold_{fold}') 51 | cfg.ckpt_save_path = os.path.join(cfg.workshop, 'checkpoint') 52 | 53 | if local_rank == 0: 54 | if os.path.exists(cfg.workshop): 55 | if cfg.train.resume is None: 56 | raise ValueError(f'workshop {cfg.workshop} already existed.') 57 | else: 58 | os.makedirs(cfg.workshop) 59 | os.makedirs(cfg.ckpt_save_path) 60 | 61 | def get_config(mode=''): 
62 | args = get_args() 63 | 64 | cfg = CN(new_allowed=True) 65 | cfg.model = CN(new_allowed=True) 66 | cfg.dataset = CN(new_allowed=True) 67 | cfg.train = CN(new_allowed=True) 68 | 69 | if len(args.model_type.split('-')) > 1: 70 | args.model_type, version = args.model_type.split('-')[0], args.model_type.split('-')[1] 71 | if args.model_type == 'WavLM': 72 | is_wavlm = True 73 | args.model_type = 'Vesper' 74 | else: 75 | is_wavlm = False 76 | 77 | cfg.model.update(model_cfg[args.model_type]) 78 | cfg.dataset.update(dataset_cfg[args.dataset_database]) 79 | cfg.train.update(train_cfg[args.model_type+mode]) 80 | 81 | # Namespace -> Dict 82 | args = vars(args) 83 | verbose = [] 84 | for key, value in args.items(): 85 | key_list = key.split('_', maxsplit=1) 86 | if len(key_list) > 1: 87 | if value is not None or not hasattr(cfg[key_list[0]], key_list[1]): 88 | cfg[key_list[0]][key_list[1]] = value 89 | verbose.append((key, value)) 90 | else: 91 | if value is not None or not hasattr(cfg, key_list[0]): 92 | cfg[key_list[0]] = value 93 | verbose.append((key, value)) 94 | # print('Arguments from command line:', verbose) 95 | 96 | if is_wavlm: 97 | cfg.model.init_with_wavlm = True 98 | cfg.model.init_with_ckpt = not cfg.model.init_with_wavlm 99 | if version == 'Base': 100 | cfg.model.path_to_wavlm = cfg.model.path_to_wavlm[0] 101 | cfg.model.encoder_layers = 12 102 | cfg.model.encoder_embed_dim = 768 103 | cfg.model.ffn_embed_dim = 3072 104 | cfg.model.num_heads = 12 105 | cfg.model.extractor_mode = 'default' 106 | cfg.model.normalize = False 107 | cfg.model.normalize_before = False 108 | elif version == 'Large': 109 | cfg.model.path_to_wavlm = cfg.model.path_to_wavlm[1] 110 | cfg.model.encoder_layers = 24 111 | else: 112 | raise ValueError(f'Unknown WavLM version: {version}') 113 | else: 114 | cfg.model.init_with_wavlm = True if 'pretrain' in mode else False 115 | cfg.model.init_with_ckpt = not cfg.model.init_with_wavlm 116 | cfg.model.encoder_layers = eval(version) 117 | 118 | cfg.model.num_classes = cfg.dataset.num_classes 119 | if cfg.model.type == 'ALLSpeech': 120 | cfg.dataset.num_queries = cfg.model.num_queries 121 | cfg.dataset.distractors = cfg.model.distractors 122 | cfg.dataset.mask_span = cfg.model.mask_span 123 | cfg.dataset.mask_chunk = cfg.model.mask_chunk 124 | 125 | # modify cfg.train.batch_size in the case of multi-GPUs training 126 | num_gpus = len(cfg.train.device_id.split(',')) 127 | if num_gpus > 1: 128 | ddp_batch_size = round(cfg.train.batch_size / num_gpus) 129 | print(f'Modified batch size: {cfg.train.batch_size} -> {ddp_batch_size}.') 130 | cfg.train.batch_size = ddp_batch_size 131 | return cfg 132 | 133 | def dict_2_list(dict): 134 | lst = [] 135 | for key, value in dict.items(): 136 | if value is not None: 137 | lst.extend([key, value]) 138 | return lst 139 | 140 | -------------------------------------------------------------------------------- /configs/dataset_config.py: -------------------------------------------------------------------------------- 1 | 2 | from yacs.config import CfgNode as CN 3 | 4 | _C = CN(new_allowed=True) 5 | 6 | ########### 7 | # IEMOCAP # 8 | ########### 9 | _C.iemocap = CN(new_allowed=True) 10 | _C.iemocap.num_classes = 4 11 | _C.iemocap.meta_csv_file = '/148Dataset/data-chen.weidong/iemocap/feature/name_label_text.csv' 12 | _C.iemocap.wavdir = '/148Dataset/data-chen.weidong/iemocap/wav_all_sentences' 13 | _C.iemocap.batch_length = 104000 # 16000 * 6.5 14 | _C.iemocap.evaluate = ['accuracy', 'recall'] 15 | _C.iemocap.folds = [1, 2, 3, 4, 5] 16 
| _C.iemocap.f1 = 'weighted' 17 | _C.iemocap.have_test_set = False 18 | 19 | ######## 20 | # MELD # 21 | ######## 22 | _C.meld = CN(new_allowed=True) 23 | _C.meld.num_classes = 7 24 | _C.meld.meta_csv_file = '/148Dataset/data-chen.weidong/meld/label/official' 25 | _C.meld.wavdir = '/148Dataset/data-chen.weidong/meld/audio_16k' 26 | _C.meld.batch_length = 72000 # 16000 * 4.5 27 | _C.meld.evaluate = ['f1'] 28 | _C.meld.folds = [1] 29 | _C.meld.f1 = 'weighted' 30 | _C.meld.have_test_set = True 31 | 32 | ########### 33 | # CREMA-D # 34 | ########### 35 | _C.crema = CN(new_allowed=True) 36 | _C.crema.num_classes = 6 37 | _C.crema.meta_csv_file = '/148Dataset/data-chen.weidong/CREMA-D/CREMA-D.csv' 38 | _C.crema.wavdir = '/148Dataset/data-chen.weidong/CREMA-D/AudioWAV' 39 | _C.crema.batch_length = 48000 # 16000 * 3.0 40 | _C.crema.evaluate = ['accuracy', 'recall'] 41 | _C.crema.folds = [1] 42 | _C.crema.f1 = 'weighted' 43 | _C.crema.have_test_set = False 44 | 45 | ######### 46 | # LSSED # 47 | ######### 48 | _C.lssed = CN(new_allowed=True) 49 | _C.lssed.num_classes = 4 50 | _C.lssed.meta_csv_file = '/148Dataset/data-chen.weidong/lssed_all/metadata_english_all.csv' 51 | _C.lssed.wavdir = '/148Dataset/data-chen.weidong/lssed_all/wav_all' 52 | _C.lssed.batch_length = 80000 # 16000*5 53 | _C.lssed.evaluate = ['accuracy', 'recall'] 54 | _C.lssed.folds = [1] 55 | _C.lssed.f1 = 'weighted' 56 | _C.lssed.have_test_set = True 57 | 58 | _C.lssed.target_length = 249 59 | _C.lssed.l_target_dir = '/148Dataset/data-chen.weidong/lssed_all/feature/wavlm_large_L12_mat' 60 | _C.lssed.h_target_dir = '/148Dataset/data-chen.weidong/lssed_all/feature/wavlm_large_L24_mat' 61 | -------------------------------------------------------------------------------- /configs/model_config.py: -------------------------------------------------------------------------------- 1 | 2 | from yacs.config import CfgNode as CN 3 | _C = CN(new_allowed=True) 4 | 5 | ############### 6 | # Transformer # 7 | ############### 8 | _C.Transformer = CN(new_allowed=True) 9 | 10 | _C.Transformer.num_encoder_layers = 4 11 | _C.Transformer.embed_dim = 1024 12 | _C.Transformer.ffn_embed_dim = 512 13 | _C.Transformer.num_heads = 8 14 | _C.Transformer.activation = 'gelu' 15 | _C.Transformer.dropout = 0.1 16 | _C.Transformer.bias = True 17 | _C.Transformer.normalize_before = True 18 | 19 | # positional embeddings 20 | _C.Transformer.conv_pos = 128 21 | _C.Transformer.conv_pos_groups = 16 22 | 23 | ########## 24 | # Vesper # 25 | ########## 26 | _C.Vesper = CN(new_allowed=True) 27 | 28 | # mainstream model 29 | _C.Vesper.encoder_layers= 4 30 | _C.Vesper.encoder_embed_dim = 1024 31 | _C.Vesper.ffn_embed_dim = 4096 32 | _C.Vesper.num_heads = 16 33 | _C.Vesper.activation = 'gelu' 34 | _C.Vesper.dropout = 0.1 35 | _C.Vesper.bias = True 36 | _C.Vesper.normalize = True 37 | _C.Vesper.normalize_before = True 38 | _C.Vesper.relative_position_embedding = True 39 | _C.Vesper.qk_norm = False # query/key (QK) normalization 40 | 41 | # predictor 42 | _C.Vesper.enable_l_predictor = True 43 | _C.Vesper.enable_h_predictor = True 44 | _C.Vesper.enable_x_predictor = True 45 | 46 | # FinetuneWrapper 47 | _C.Vesper.projector_dim = 256 48 | _C.Vesper.output_rep = 'weighted_sum' # 'weighted_sum' / 'last_layer' 49 | 50 | # initiliaze with wavlm 51 | _C.Vesper.init_with_wavlm = True 52 | _C.Vesper.init_style = ['uniform_extract'] # ['custom_average', [(0, 1), (2, 5), (6, 13), (14, 23)]], ['custom_extract', [0, 5, 11, 17]] 53 | _C.Vesper.path_to_wavlm = 
['/148Dataset/data-chen.weidong/pre_trained_model/wavlm/WavLM-Base.pt', '/148Dataset/data-chen.weidong/pre_trained_model/wavlm/WavLM-Large.pt'] 54 | 55 | # initiliaze with other pre-trained model 56 | _C.Vesper.init_with_ckpt = False 57 | _C.Vesper.path_to_vesper = '' 58 | 59 | # rms-based mask 60 | _C.Vesper.mask_depend_on_rms = True 61 | _C.Vesper.frame_length = 400 # 16000 * 0.025 62 | _C.Vesper.hop_length = 320 # 16000 * 0.020 63 | _C.Vesper.span_space = 1 64 | _C.Vesper.h_up = 1.0 65 | _C.Vesper.h_down = 0.5 66 | _C.Vesper.l_up = 0.49 67 | _C.Vesper.l_down = 0.2 68 | _C.Vesper.small_span = 8 69 | _C.Vesper.num_small_span = 20 70 | _C.Vesper.large_span = 40 71 | _C.Vesper.num_large_span = 4 72 | _C.Vesper.max_mask_percentage = 0.64 73 | 74 | # positional embedding 75 | _C.Vesper.conv_pos = 128 76 | _C.Vesper.conv_pos_groups = 16 77 | 78 | # bucket relative position embedding 79 | _C.Vesper.num_buckets = 320 80 | _C.Vesper.max_distance = 800 81 | _C.Vesper.gru_rel_pos = True 82 | 83 | # feature encoder 84 | _C.Vesper.extractor_mode = 'layer_norm' # 'default' / 'layer_norm' 85 | _C.Vesper.conv_feature_layers = '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' 86 | _C.Vesper.dropout_input = 0.0 87 | -------------------------------------------------------------------------------- /configs/train_config.py: -------------------------------------------------------------------------------- 1 | 2 | from yacs.config import CfgNode as CN 3 | 4 | _C = CN(new_allowed=True) 5 | 6 | ############### 7 | # Transformer # 8 | ############### 9 | _C.Transformer = CN(new_allowed=True) 10 | _C.Transformer.EPOCH = 120 11 | _C.Transformer.batch_size = 32 12 | _C.Transformer.lr = 0.0005 13 | 14 | ########## 15 | # Vesper # 16 | ########## 17 | _C.Vesper_pretrain = CN(new_allowed=True) 18 | _C.Vesper_pretrain.EPOCH = 100 19 | _C.Vesper_pretrain.batch_size = 256 20 | _C.Vesper_pretrain.lr = 0.005 21 | _C.Vesper_pretrain.optimizer = 'AdamW' # 'AdamW' / 'sgd' 22 | _C.Vesper_pretrain.weight_decay = 0.01 23 | _C.Vesper_pretrain.freeze_cnn = True 24 | _C.Vesper_pretrain.loss_weight_l = 1.0 25 | _C.Vesper_pretrain.loss_weight_h = 0.1 26 | _C.Vesper_pretrain.loss_weight_x = 1.0 27 | 28 | _C.Vesper_finetune = CN(new_allowed=True) 29 | _C.Vesper_finetune.EPOCH = 50 30 | _C.Vesper_finetune.batch_size = 16 31 | _C.Vesper_finetune.lr = 0.0001 32 | _C.Vesper_finetune.optimizer = 'sgd' # 'AdamW' / 'sgd' 33 | _C.Vesper_finetune.weight_decay = 0.01 34 | _C.Vesper_finetune.freeze_cnn = True 35 | _C.Vesper_finetune.freeze_upstream = True 36 | -------------------------------------------------------------------------------- /extract_feature/WavLM/WavLM.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/wavlm 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on fairseq code bases 7 | # https://github.com/pytorch/fairseq 8 | # -------------------------------------------------------- 9 | 10 | import math 11 | import logging 12 | from typing import List, Optional, Tuple 13 | 14 | import numpy as np 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from torch.nn import LayerNorm 20 | from extract_feature.WavLM.modules import ( 21 | Fp32GroupNorm, 22 | 
Fp32LayerNorm, 23 | GradMultiply, 24 | MultiheadAttention, 25 | SamePad, 26 | init_bert_params, 27 | get_activation_fn, 28 | TransposeLast, 29 | GLU_Linear, 30 | ) 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def compute_mask_indices( 36 | shape: Tuple[int, int], 37 | padding_mask: Optional[torch.Tensor], 38 | mask_prob: float, 39 | mask_length: int, 40 | mask_type: str = "static", 41 | mask_other: float = 0.0, 42 | min_masks: int = 0, 43 | no_overlap: bool = False, 44 | min_space: int = 0, 45 | ) -> np.ndarray: 46 | """ 47 | Computes random mask spans for a given shape 48 | 49 | Args: 50 | shape: the the shape for which to compute masks. 51 | should be of size 2 where first element is batch size and 2nd is timesteps 52 | padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements 53 | mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by 54 | number of timesteps divided by length of mask span to mask approximately this percentage of all elements. 55 | however due to overlaps, the actual number will be smaller (unless no_overlap is True) 56 | mask_type: how to compute mask lengths 57 | static = fixed size 58 | uniform = sample from uniform distribution [mask_other, mask_length*2] 59 | normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element 60 | poisson = sample from possion distribution with lambda = mask length 61 | min_masks: minimum number of masked spans 62 | no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping 63 | min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans 64 | """ 65 | 66 | bsz, all_sz = shape 67 | mask = np.full((bsz, all_sz), False) 68 | 69 | all_num_mask = int( 70 | # add a random number for probabilistic rounding 71 | mask_prob * all_sz / float(mask_length) 72 | + np.random.rand() 73 | ) 74 | 75 | all_num_mask = max(min_masks, all_num_mask) 76 | 77 | mask_idcs = [] 78 | for i in range(bsz): 79 | if padding_mask is not None: 80 | sz = all_sz - padding_mask[i].long().sum().item() 81 | num_mask = int( 82 | # add a random number for probabilistic rounding 83 | mask_prob * sz / float(mask_length) 84 | + np.random.rand() 85 | ) 86 | num_mask = max(min_masks, num_mask) 87 | else: 88 | sz = all_sz 89 | num_mask = all_num_mask 90 | 91 | if mask_type == "static": 92 | lengths = np.full(num_mask, mask_length) 93 | elif mask_type == "uniform": 94 | lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) 95 | elif mask_type == "normal": 96 | lengths = np.random.normal(mask_length, mask_other, size=num_mask) 97 | lengths = [max(1, int(round(x))) for x in lengths] 98 | elif mask_type == "poisson": 99 | lengths = np.random.poisson(mask_length, size=num_mask) 100 | lengths = [int(round(x)) for x in lengths] 101 | else: 102 | raise Exception("unknown mask selection " + mask_type) 103 | 104 | if sum(lengths) == 0: 105 | lengths[0] = min(mask_length, sz - 1) 106 | 107 | if no_overlap: 108 | mask_idc = [] 109 | 110 | def arrange(s, e, length, keep_length): 111 | span_start = np.random.randint(s, e - length) 112 | mask_idc.extend(span_start + i for i in range(length)) 113 | 114 | new_parts = [] 115 | if span_start - s - min_space >= keep_length: 116 | new_parts.append((s, span_start - min_space + 1)) 117 | if e - span_start - keep_length - min_space > keep_length: 118 | new_parts.append((span_start 
+ length + min_space, e)) 119 | return new_parts 120 | 121 | parts = [(0, sz)] 122 | min_length = min(lengths) 123 | for length in sorted(lengths, reverse=True): 124 | lens = np.fromiter( 125 | (e - s if e - s >= length + min_space else 0 for s, e in parts), 126 | np.int, 127 | ) 128 | l_sum = np.sum(lens) 129 | if l_sum == 0: 130 | break 131 | probs = lens / np.sum(lens) 132 | c = np.random.choice(len(parts), p=probs) 133 | s, e = parts.pop(c) 134 | parts.extend(arrange(s, e, length, min_length)) 135 | mask_idc = np.asarray(mask_idc) 136 | else: 137 | min_len = min(lengths) 138 | if sz - min_len <= num_mask: 139 | min_len = sz - num_mask - 1 140 | 141 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) 142 | 143 | mask_idc = np.asarray( 144 | [ 145 | mask_idc[j] + offset 146 | for j in range(len(mask_idc)) 147 | for offset in range(lengths[j]) 148 | ] 149 | ) 150 | 151 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) 152 | 153 | min_len = min([len(m) for m in mask_idcs]) 154 | for i, mask_idc in enumerate(mask_idcs): 155 | if len(mask_idc) > min_len: 156 | mask_idc = np.random.choice(mask_idc, min_len, replace=False) 157 | mask[i, mask_idc] = True 158 | 159 | return mask 160 | 161 | 162 | class WavLMConfig: 163 | def __init__(self, cfg=None): 164 | self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) 165 | self.encoder_layers: int = 12 # num encoder layers in the transformer 166 | 167 | self.encoder_embed_dim: int = 768 # encoder embedding dimension 168 | self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN 169 | self.encoder_attention_heads: int = 12 # num encoder attention heads 170 | self.activation_fn: str = "gelu" # activation function to use 171 | 172 | self.layer_norm_first: bool = False # apply layernorm first in the transformer 173 | self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] 
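# Note on the default conv_feature_layers string above: the stacked strides multiply to
# 5*2*2*2*2*2*2 = 320, so the convolutional extractor emits one frame every 320 samples
# (20 ms at 16 kHz, i.e. 50 frames per second) with a 400-sample (25 ms) receptive field;
# these are the same 20 ms hop / 25 ms window values assumed by hop_length and frame_length
# in configs/model_config.py.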
174 | self.conv_bias: bool = False # include bias in conv encoder 175 | self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this 176 | 177 | self.normalize: bool = False # normalize input to have 0 mean and unit variance during training 178 | 179 | # dropouts 180 | self.dropout: float = 0.1 # dropout probability for the transformer 181 | self.attention_dropout: float = 0.1 # dropout probability for attention weights 182 | self.activation_dropout: float = 0.0 # dropout probability after activation in FFN 183 | self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer 184 | self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) 185 | self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) 186 | 187 | # masking 188 | self.mask_length: int = 10 # mask length 189 | self.mask_prob: float = 0.65 # probability of replacing a token with mask 190 | self.mask_selection: str = "static" # how to choose mask length 191 | self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh 192 | self.no_mask_overlap: bool = False # whether to allow masks to overlap 193 | self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) 194 | 195 | # channel masking 196 | self.mask_channel_length: int = 10 # length of the mask for features (channels) 197 | self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 198 | self.mask_channel_selection: str = "static" # how to choose mask length for channel masking 199 | self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices 200 | self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap 201 | self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) 202 | 203 | # positional embeddings 204 | self.conv_pos: int = 128 # number of filters for convolutional positional embeddings 205 | self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding 206 | 207 | # relative position embedding 208 | self.relative_position_embedding: bool = False # apply relative position embedding 209 | self.num_buckets: int = 320 # number of buckets for relative position embedding 210 | self.max_distance: int = 1280 # maximum distance for relative position embedding 211 | self.gru_rel_pos: bool = False # apply gated relative position embedding 212 | 213 | if cfg is not None: 214 | self.update(cfg) 215 | 216 | def update(self, cfg: dict): 217 | self.__dict__.update(cfg) 218 | 219 | 220 | class WavLM(nn.Module): 221 | def __init__( 222 | self, 223 | cfg: WavLMConfig, 224 | ) -> None: 225 | super().__init__() 226 | logger.info(f"WavLM Config: {cfg.__dict__}") 227 | 228 | self.cfg = cfg 229 | feature_enc_layers = eval(cfg.conv_feature_layers) 230 | self.embed = feature_enc_layers[-1][0] 231 | 232 | self.feature_extractor = ConvFeatureExtractionModel( 233 | conv_layers=feature_enc_layers, 234 | dropout=0.0, 235 | mode=cfg.extractor_mode, 236 | conv_bias=cfg.conv_bias, 237 | ) 238 | 239 | self.post_extract_proj = ( 240 | nn.Linear(self.embed, cfg.encoder_embed_dim) 241 | if self.embed != cfg.encoder_embed_dim 242 | else None 243 | ) 244 | 245 | self.mask_prob = cfg.mask_prob 246 | self.mask_selection = cfg.mask_selection 247 | self.mask_other = cfg.mask_other 248 | self.mask_length = cfg.mask_length 249 | self.no_mask_overlap = 
cfg.no_mask_overlap 250 | self.mask_min_space = cfg.mask_min_space 251 | 252 | self.mask_channel_prob = cfg.mask_channel_prob 253 | self.mask_channel_selection = cfg.mask_channel_selection 254 | self.mask_channel_other = cfg.mask_channel_other 255 | self.mask_channel_length = cfg.mask_channel_length 256 | self.no_mask_channel_overlap = cfg.no_mask_channel_overlap 257 | self.mask_channel_min_space = cfg.mask_channel_min_space 258 | 259 | self.dropout_input = nn.Dropout(cfg.dropout_input) 260 | self.dropout_features = nn.Dropout(cfg.dropout_features) 261 | 262 | self.feature_grad_mult = cfg.feature_grad_mult 263 | 264 | self.mask_emb = nn.Parameter( 265 | torch.FloatTensor(cfg.encoder_embed_dim).uniform_() 266 | ) 267 | 268 | self.encoder = TransformerEncoder(cfg) 269 | self.layer_norm = LayerNorm(self.embed) 270 | 271 | def apply_mask(self, x, padding_mask): 272 | B, T, C = x.shape 273 | if self.mask_prob > 0: 274 | mask_indices = compute_mask_indices( 275 | (B, T), 276 | padding_mask, 277 | self.mask_prob, 278 | self.mask_length, 279 | self.mask_selection, 280 | self.mask_other, 281 | min_masks=2, 282 | no_overlap=self.no_mask_overlap, 283 | min_space=self.mask_min_space, 284 | ) 285 | mask_indices = torch.from_numpy(mask_indices).to(x.device) 286 | x[mask_indices] = self.mask_emb 287 | else: 288 | mask_indices = None 289 | 290 | if self.mask_channel_prob > 0: 291 | mask_channel_indices = compute_mask_indices( 292 | (B, C), 293 | None, 294 | self.mask_channel_prob, 295 | self.mask_channel_length, 296 | self.mask_channel_selection, 297 | self.mask_channel_other, 298 | no_overlap=self.no_mask_channel_overlap, 299 | min_space=self.mask_channel_min_space, 300 | ) 301 | mask_channel_indices = ( 302 | torch.from_numpy(mask_channel_indices) 303 | .to(x.device) 304 | .unsqueeze(1) 305 | .expand(-1, T, -1) 306 | ) 307 | x[mask_channel_indices] = 0 308 | 309 | return x, mask_indices 310 | 311 | def forward_padding_mask( 312 | self, features: torch.Tensor, padding_mask: torch.Tensor, 313 | ) -> torch.Tensor: 314 | extra = padding_mask.size(1) % features.size(1) 315 | if extra > 0: 316 | padding_mask = padding_mask[:, :-extra] 317 | padding_mask = padding_mask.view( 318 | padding_mask.size(0), features.size(1), -1 319 | ) 320 | padding_mask = padding_mask.all(-1) 321 | return padding_mask 322 | 323 | def extract_features( 324 | self, 325 | source: torch.Tensor, 326 | padding_mask: Optional[torch.Tensor] = None, 327 | mask: bool = False, 328 | ret_conv: bool = False, 329 | output_layer: Optional[int] = None, 330 | ret_layer_results: bool = False, 331 | ): 332 | 333 | if self.feature_grad_mult > 0: 334 | features = self.feature_extractor(source) 335 | if self.feature_grad_mult != 1.0: 336 | features = GradMultiply.apply(features, self.feature_grad_mult) 337 | else: 338 | with torch.no_grad(): 339 | features = self.feature_extractor(source) 340 | 341 | features = features.transpose(1, 2) 342 | features = self.layer_norm(features) 343 | 344 | if padding_mask is not None: 345 | padding_mask = self.forward_padding_mask(features, padding_mask) 346 | 347 | if self.post_extract_proj is not None: 348 | features = self.post_extract_proj(features) 349 | 350 | features = self.dropout_input(features) 351 | 352 | if mask: 353 | x, mask_indices = self.apply_mask( 354 | features, padding_mask 355 | ) 356 | else: 357 | x = features 358 | 359 | # feature: (B, T, D), float 360 | # target: (B, T), long 361 | # x: (B, T, D), float 362 | # padding_mask: (B, T), bool 363 | # mask_indices: (B, T), bool 364 | x, layer_results 
= self.encoder( 365 | x, 366 | padding_mask=padding_mask, 367 | layer=None if output_layer is None else output_layer - 1 368 | ) 369 | 370 | res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} 371 | 372 | feature = res["features"] if ret_conv else res["x"] 373 | if ret_layer_results: 374 | feature = (feature, res["layer_results"]) 375 | return feature, res["padding_mask"] 376 | 377 | 378 | class ConvFeatureExtractionModel(nn.Module): 379 | def __init__( 380 | self, 381 | conv_layers: List[Tuple[int, int, int]], 382 | dropout: float = 0.0, 383 | mode: str = "default", 384 | conv_bias: bool = False, 385 | conv_type: str = "default" 386 | ): 387 | super().__init__() 388 | 389 | assert mode in {"default", "layer_norm"} 390 | 391 | def block( 392 | n_in, 393 | n_out, 394 | k, 395 | stride, 396 | is_layer_norm=False, 397 | is_group_norm=False, 398 | conv_bias=False, 399 | ): 400 | def make_conv(): 401 | conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) 402 | nn.init.kaiming_normal_(conv.weight) 403 | return conv 404 | 405 | assert ( 406 | is_layer_norm and is_group_norm 407 | ) == False, "layer norm and group norm are exclusive" 408 | 409 | if is_layer_norm: 410 | return nn.Sequential( 411 | make_conv(), 412 | nn.Dropout(p=dropout), 413 | nn.Sequential( 414 | TransposeLast(), 415 | Fp32LayerNorm(dim, elementwise_affine=True), 416 | TransposeLast(), 417 | ), 418 | nn.GELU(), 419 | ) 420 | elif is_group_norm: 421 | return nn.Sequential( 422 | make_conv(), 423 | nn.Dropout(p=dropout), 424 | Fp32GroupNorm(dim, dim, affine=True), 425 | nn.GELU(), 426 | ) 427 | else: 428 | return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) 429 | 430 | self.conv_type = conv_type 431 | if self.conv_type == "default": 432 | in_d = 1 433 | self.conv_layers = nn.ModuleList() 434 | for i, cl in enumerate(conv_layers): 435 | assert len(cl) == 3, "invalid conv definition: " + str(cl) 436 | (dim, k, stride) = cl 437 | 438 | self.conv_layers.append( 439 | block( 440 | in_d, 441 | dim, 442 | k, 443 | stride, 444 | is_layer_norm=mode == "layer_norm", 445 | is_group_norm=mode == "default" and i == 0, 446 | conv_bias=conv_bias, 447 | ) 448 | ) 449 | in_d = dim 450 | elif self.conv_type == "conv2d": 451 | in_d = 1 452 | self.conv_layers = nn.ModuleList() 453 | for i, cl in enumerate(conv_layers): 454 | assert len(cl) == 3 455 | (dim, k, stride) = cl 456 | 457 | self.conv_layers.append( 458 | torch.nn.Conv2d(in_d, dim, k, stride) 459 | ) 460 | self.conv_layers.append(torch.nn.ReLU()) 461 | in_d = dim 462 | elif self.conv_type == "custom": 463 | in_d = 1 464 | idim = 80 465 | self.conv_layers = nn.ModuleList() 466 | for i, cl in enumerate(conv_layers): 467 | assert len(cl) == 3 468 | (dim, k, stride) = cl 469 | self.conv_layers.append( 470 | torch.nn.Conv2d(in_d, dim, k, stride, padding=1) 471 | ) 472 | self.conv_layers.append( 473 | torch.nn.LayerNorm([dim, idim]) 474 | ) 475 | self.conv_layers.append(torch.nn.ReLU()) 476 | in_d = dim 477 | if (i + 1) % 2 == 0: 478 | self.conv_layers.append( 479 | torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) 480 | ) 481 | idim = int(math.ceil(idim / 2)) 482 | else: 483 | pass 484 | 485 | def forward(self, x, mask=None): 486 | 487 | # BxT -> BxCxT 488 | x = x.unsqueeze(1) 489 | if self.conv_type == "custom": 490 | for conv in self.conv_layers: 491 | if isinstance(conv, nn.LayerNorm): 492 | x = x.transpose(1, 2) 493 | x = conv(x).transpose(1, 2) 494 | else: 495 | x = conv(x) 496 | x = x.transpose(2, 3).contiguous() 497 | x 
= x.view(x.size(0), -1, x.size(-1)) 498 | else: 499 | for conv in self.conv_layers: 500 | x = conv(x) 501 | if self.conv_type == "conv2d": 502 | b, c, t, f = x.size() 503 | x = x.transpose(2, 3).contiguous().view(b, c * f, t) 504 | return x 505 | 506 | 507 | class TransformerEncoder(nn.Module): 508 | def __init__(self, args): 509 | super().__init__() 510 | 511 | self.dropout = args.dropout 512 | self.embedding_dim = args.encoder_embed_dim 513 | 514 | self.pos_conv = nn.Conv1d( 515 | self.embedding_dim, 516 | self.embedding_dim, 517 | kernel_size=args.conv_pos, 518 | padding=args.conv_pos // 2, 519 | groups=args.conv_pos_groups, 520 | ) 521 | dropout = 0 522 | std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) 523 | nn.init.normal_(self.pos_conv.weight, mean=0, std=std) 524 | nn.init.constant_(self.pos_conv.bias, 0) 525 | 526 | self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) 527 | self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) 528 | 529 | if hasattr(args, "relative_position_embedding"): 530 | self.relative_position_embedding = args.relative_position_embedding 531 | self.num_buckets = args.num_buckets 532 | self.max_distance = args.max_distance 533 | else: 534 | self.relative_position_embedding = False 535 | self.num_buckets = 0 536 | self.max_distance = 0 537 | 538 | self.layers = nn.ModuleList( 539 | [ 540 | TransformerSentenceEncoderLayer( 541 | embedding_dim=self.embedding_dim, 542 | ffn_embedding_dim=args.encoder_ffn_embed_dim, 543 | num_attention_heads=args.encoder_attention_heads, 544 | dropout=self.dropout, 545 | attention_dropout=args.attention_dropout, 546 | activation_dropout=args.activation_dropout, 547 | activation_fn=args.activation_fn, 548 | layer_norm_first=args.layer_norm_first, 549 | has_relative_attention_bias=(self.relative_position_embedding and i == 0), 550 | num_buckets=self.num_buckets, 551 | max_distance=self.max_distance, 552 | gru_rel_pos=args.gru_rel_pos, 553 | ) 554 | for i in range(args.encoder_layers) 555 | ] 556 | ) 557 | 558 | self.layer_norm_first = args.layer_norm_first 559 | self.layer_norm = LayerNorm(self.embedding_dim) 560 | self.layerdrop = args.encoder_layerdrop 561 | 562 | self.apply(init_bert_params) 563 | 564 | def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): 565 | x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) 566 | 567 | if self.layer_norm_first and layer is None: 568 | x = self.layer_norm(x) 569 | 570 | return x, layer_results 571 | 572 | def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): 573 | 574 | if padding_mask is not None: 575 | x[padding_mask] = 0 576 | 577 | x_conv = self.pos_conv(x.transpose(1, 2)) 578 | x_conv = x_conv.transpose(1, 2) 579 | x += x_conv 580 | 581 | if not self.layer_norm_first: 582 | x = self.layer_norm(x) 583 | 584 | x = F.dropout(x, p=self.dropout, training=self.training) 585 | 586 | # B x T x C -> T x B x C 587 | x = x.transpose(0, 1) 588 | 589 | layer_results = [] 590 | z = None 591 | if tgt_layer is not None: 592 | layer_results.append((x, z)) 593 | r = None 594 | pos_bias = None 595 | for i, layer in enumerate(self.layers): 596 | dropout_probability = np.random.random() 597 | if not self.training or (dropout_probability > self.layerdrop): 598 | x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=True, # z is attn_weight 599 | self_attn_mask=streaming_mask, pos_bias=pos_bias) 600 | if tgt_layer is not None: 601 | 
layer_results.append((x, z)) 602 | if i == tgt_layer: 603 | r = x 604 | break 605 | 606 | if r is not None: 607 | x = r 608 | 609 | # T x B x C -> B x T x C 610 | x = x.transpose(0, 1) 611 | 612 | return x, layer_results 613 | 614 | 615 | class TransformerSentenceEncoderLayer(nn.Module): 616 | """ 617 | Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained 618 | models. 619 | """ 620 | 621 | def __init__( 622 | self, 623 | embedding_dim: float = 768, 624 | ffn_embedding_dim: float = 3072, 625 | num_attention_heads: float = 8, 626 | dropout: float = 0.1, 627 | attention_dropout: float = 0.1, 628 | activation_dropout: float = 0.1, 629 | activation_fn: str = "gelu", # relu 630 | layer_norm_first: bool = False, 631 | has_relative_attention_bias: bool = False, 632 | num_buckets: int = 0, 633 | max_distance: int = 0, 634 | rescale_init: bool = False, 635 | gru_rel_pos: bool = False, 636 | ) -> None: 637 | 638 | super().__init__() 639 | # Initialize parameters 640 | self.embedding_dim = embedding_dim 641 | self.dropout = dropout 642 | self.activation_dropout = activation_dropout 643 | 644 | # Initialize blocks 645 | self.activation_name = activation_fn 646 | self.activation_fn = get_activation_fn(activation_fn) 647 | self.self_attn = MultiheadAttention( 648 | self.embedding_dim, 649 | num_attention_heads, 650 | dropout=attention_dropout, 651 | self_attention=True, 652 | has_relative_attention_bias=has_relative_attention_bias, 653 | num_buckets=num_buckets, 654 | max_distance=max_distance, 655 | rescale_init=rescale_init, 656 | gru_rel_pos=gru_rel_pos, 657 | ) 658 | 659 | self.dropout1 = nn.Dropout(dropout) 660 | self.dropout2 = nn.Dropout(self.activation_dropout) 661 | self.dropout3 = nn.Dropout(dropout) 662 | 663 | self.layer_norm_first = layer_norm_first 664 | 665 | # layer norm associated with the self attention layer 666 | self.self_attn_layer_norm = LayerNorm(self.embedding_dim) 667 | 668 | if self.activation_name == "glu": 669 | self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") 670 | else: 671 | self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) 672 | self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) 673 | 674 | # layer norm associated with the position wise feed-forward NN 675 | self.final_layer_norm = LayerNorm(self.embedding_dim) 676 | 677 | def forward( 678 | self, 679 | x: torch.Tensor, 680 | self_attn_mask: torch.Tensor = None, 681 | self_attn_padding_mask: torch.Tensor = None, 682 | need_weights: bool = False, 683 | pos_bias=None 684 | ): 685 | """ 686 | LayerNorm is applied either before or after the self-attention/ffn 687 | modules similar to the original Transformer imlementation. 
688 | """ 689 | residual = x 690 | 691 | if self.layer_norm_first: 692 | x = self.self_attn_layer_norm(x) 693 | x, attn, pos_bias = self.self_attn( 694 | query=x, 695 | key=x, 696 | value=x, 697 | key_padding_mask=self_attn_padding_mask, 698 | need_weights=need_weights, 699 | attn_mask=self_attn_mask, 700 | position_bias=pos_bias 701 | ) 702 | x = self.dropout1(x) 703 | x = residual + x 704 | 705 | residual = x 706 | x = self.final_layer_norm(x) 707 | if self.activation_name == "glu": 708 | x = self.fc1(x) 709 | else: 710 | x = self.activation_fn(self.fc1(x)) 711 | x = self.dropout2(x) 712 | x = self.fc2(x) 713 | x = self.dropout3(x) 714 | x = residual + x 715 | else: 716 | x, attn, pos_bias = self.self_attn( 717 | query=x, 718 | key=x, 719 | value=x, 720 | key_padding_mask=self_attn_padding_mask, 721 | need_weights=need_weights, 722 | attn_mask=self_attn_mask, 723 | position_bias=pos_bias 724 | ) 725 | 726 | x = self.dropout1(x) 727 | x = residual + x 728 | 729 | x = self.self_attn_layer_norm(x) 730 | 731 | residual = x 732 | if self.activation_name == "glu": 733 | x = self.fc1(x) 734 | else: 735 | x = self.activation_fn(self.fc1(x)) 736 | x = self.dropout2(x) 737 | x = self.fc2(x) 738 | x = self.dropout3(x) 739 | x = residual + x 740 | x = self.final_layer_norm(x) 741 | 742 | return x, attn, pos_bias 743 | -------------------------------------------------------------------------------- /extract_feature/WavLM/extract_wavlm.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import sys 4 | import os 5 | import soundfile as sf 6 | import scipy.signal as signal 7 | from scipy import io 8 | import pandas as pd 9 | import argparse 10 | 11 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) 12 | from extract_feature.WavLM.WavLM import WavLM, WavLMConfig 13 | 14 | def read_audio(path, sample_rate=16000): 15 | wav, sr = sf.read(path) 16 | 17 | if sr != sample_rate: 18 | num = int((wav.shape[0]) / sr * sample_rate) 19 | wav = signal.resample(wav, num) 20 | print(f'Resample {sr} to {sample_rate}') 21 | 22 | if wav.ndim == 2: 23 | wav = wav.mean(-1) 24 | assert wav.ndim == 1, wav.ndim 25 | 26 | if wav.shape[0] > sample_rate * 20: 27 | print(f'Crop raw wav from {wav.shape[0]} to {sample_rate * 20}') 28 | wav = wav[:sample_rate * 20] 29 | 30 | return wav 31 | 32 | def extract_wavlm(model, wavfile, savefile, layer=24): 33 | ''' 34 | Args: 35 | layer (int): varies from 1 to 24. 
36 | ''' 37 | 38 | if isinstance(savefile, str): 39 | if os.path.exists(savefile): 40 | print('File existed:', savefile) 41 | return 42 | savefile = [savefile] 43 | layer = [layer] 44 | else: 45 | for file in savefile: 46 | if os.path.exists(file): 47 | print('File existed:', file) 48 | return 49 | assert len(savefile) == len(layer) 50 | 51 | wav_input_16khz = read_audio(wavfile) 52 | wav_input_16khz = torch.from_numpy(wav_input_16khz).float().unsqueeze(dim=0).cuda() 53 | 54 | ############################################ 55 | # extract the representation of last layer # 56 | ############################################ 57 | # with torch.no_grad(): 58 | # rep = model.extract_features(wav_input_16khz)[0] 59 | 60 | # rep = rep.squeeze(dim=0).cpu().detach().numpy() # (t, 768) / (t, 1024) 61 | # dict = {'wavlm': rep} 62 | # io.savemat(savefile, dict) 63 | # print(savefile, '->', rep.shape) 64 | 65 | ############################################ 66 | # extract the representation of each layer # 67 | ############################################ 68 | with torch.no_grad(): 69 | if model.cfg.normalize: 70 | wav_input_16khz = torch.nn.functional.layer_norm(wav_input_16khz , wav_input_16khz.shape) 71 | rep, layer_results = model.extract_features(wav_input_16khz, output_layer=model.cfg.encoder_layers, ret_layer_results=True)[0] 72 | layer_reps = [x.transpose(0, 1) for x, _ in layer_results] # layer_results: [(x, z), (x, z), ...] z is attn_weight 73 | layer_attn = [z for _, z in layer_results] # z is the average attention weights over heads with shape (B, T, T) 74 | 75 | for save, l in zip(savefile, layer): 76 | rep_l = layer_reps[l] 77 | rep_l = rep_l.squeeze(dim=0).cpu().detach().numpy() # (t, 768) / (t, 1024) 78 | dict = {'wavlm': rep_l} 79 | io.savemat(save, dict) 80 | print(save, '->', rep_l.shape) 81 | 82 | def main(args): 83 | wavdir = args.wavdir 84 | savedir = args.savedir 85 | ckpt = args.wavlm 86 | layer = args.layer 87 | csvfile = args.csvfile 88 | gpu = args.gpu 89 | 90 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu 91 | 92 | checkpoint = torch.load(ckpt) 93 | cfg = WavLMConfig(checkpoint['cfg']) 94 | model = WavLM(cfg) 95 | model.load_state_dict(checkpoint['model']) 96 | model.eval().cuda() 97 | 98 | if not os.path.exists(savedir): 99 | os.makedirs(savedir) 100 | 101 | if csvfile is not None: 102 | df = pd.read_csv(csvfile) 103 | file_names = df['name'].tolist() 104 | else: 105 | file_names = os.listdir(wavdir) 106 | 107 | total = len(file_names) 108 | for i, name in enumerate(file_names): 109 | wavfile = os.path.join(wavdir, name+'.wav') 110 | savefile = os.path.join(savedir, name) 111 | 112 | if os.path.exists(savefile): 113 | print('Pass', name) 114 | continue 115 | 116 | print(f'----------- {i+1} / {total} -----------') 117 | extract_wavlm(model, wavfile, savefile, layer=layer) 118 | 119 | if __name__ == '__main__': 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument('--wavdir', type=str, help='wav directory') 122 | parser.add_argument('--savedir', type=str, help='save directory') 123 | parser.add_argument('--wavlm', type=str, default=None, help='wavlm model') 124 | parser.add_argument('--layer', type=int, default=24, help='layer index, varies from 1 to 24') 125 | parser.add_argument('--csvfile', type=str, default=None, help='csv file with name column') 126 | parser.add_argument('--gpu', type=str, default='0', help='gpu id') 127 | args = parser.parse_args() 128 | 129 | main(args) 130 | -------------------------------------------------------------------------------- 
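For reference, a minimal example invocation of ```extract_wavlm.py``` above, run once per target layer (the LSSED entries ```l_target_dir``` and ```h_target_dir``` in ```configs/dataset_config.py``` point to layer-12 and layer-24 WavLM-Large features); all paths below are placeholders:
```python
python extract_feature/WavLM/extract_wavlm.py --wavdir PATH_to_WAV --savedir PATH_to_FEATURE/wavlm_large_L12_mat --wavlm PATH_to_WavLM/WavLM-Large.pt --layer 12 --csvfile PATH_to_META.csv --gpu 0
python extract_feature/WavLM/extract_wavlm.py --wavdir PATH_to_WAV --savedir PATH_to_FEATURE/wavlm_large_L24_mat --wavlm PATH_to_WavLM/WavLM-Large.pt --layer 24 --csvfile PATH_to_META.csv --gpu 0
```
Each saved ```.mat``` file stores the selected layer's representation under the ```wavlm``` key, with shape ```(frames, 1024)``` for WavLM-Large.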
/extract_feature/WavLM/modules.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/wavlm 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on fairseq code bases 7 | # https://github.com/pytorch/fairseq 8 | # -------------------------------------------------------- 9 | 10 | import math 11 | import warnings 12 | from typing import Dict, Optional, Tuple 13 | import torch 14 | from torch import Tensor, nn 15 | from torch.nn import Parameter 16 | import torch.nn.functional as F 17 | 18 | 19 | class TransposeLast(nn.Module): 20 | def __init__(self, deconstruct_idx=None): 21 | super().__init__() 22 | self.deconstruct_idx = deconstruct_idx 23 | 24 | def forward(self, x): 25 | if self.deconstruct_idx is not None: 26 | x = x[self.deconstruct_idx] 27 | return x.transpose(-2, -1) 28 | 29 | 30 | class Fp32LayerNorm(nn.LayerNorm): 31 | def __init__(self, *args, **kwargs): 32 | super().__init__(*args, **kwargs) 33 | 34 | def forward(self, input): 35 | output = F.layer_norm( 36 | input.float(), 37 | self.normalized_shape, 38 | self.weight.float() if self.weight is not None else None, 39 | self.bias.float() if self.bias is not None else None, 40 | self.eps, 41 | ) 42 | return output.type_as(input) 43 | 44 | 45 | class Fp32GroupNorm(nn.GroupNorm): 46 | def __init__(self, *args, **kwargs): 47 | super().__init__(*args, **kwargs) 48 | 49 | def forward(self, input): 50 | output = F.group_norm( 51 | input.float(), 52 | self.num_groups, 53 | self.weight.float() if self.weight is not None else None, 54 | self.bias.float() if self.bias is not None else None, 55 | self.eps, 56 | ) 57 | return output.type_as(input) 58 | 59 | 60 | class GradMultiply(torch.autograd.Function): 61 | @staticmethod 62 | def forward(ctx, x, scale): 63 | ctx.scale = scale 64 | res = x.new(x) 65 | return res 66 | 67 | @staticmethod 68 | def backward(ctx, grad): 69 | return grad * ctx.scale, None 70 | 71 | 72 | class SamePad(nn.Module): 73 | def __init__(self, kernel_size, causal=False): 74 | super().__init__() 75 | if causal: 76 | self.remove = kernel_size - 1 77 | else: 78 | self.remove = 1 if kernel_size % 2 == 0 else 0 79 | 80 | def forward(self, x): 81 | if self.remove > 0: 82 | x = x[:, :, : -self.remove] 83 | return x 84 | 85 | 86 | class Swish(nn.Module): 87 | """Swish function 88 | """ 89 | 90 | def __init__(self): 91 | """Construct an MultiHeadedAttention object.""" 92 | super(Swish, self).__init__() 93 | self.act = torch.nn.Sigmoid() 94 | 95 | def forward(self, x): 96 | return x * self.act(x) 97 | 98 | 99 | class GLU_Linear(nn.Module): 100 | def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): 101 | super(GLU_Linear, self).__init__() 102 | 103 | self.glu_type = glu_type 104 | self.output_dim = output_dim 105 | 106 | if glu_type == "sigmoid": 107 | self.glu_act = torch.nn.Sigmoid() 108 | elif glu_type == "swish": 109 | self.glu_act = Swish() 110 | elif glu_type == "relu": 111 | self.glu_act = torch.nn.ReLU() 112 | elif glu_type == "gelu": 113 | self.glu_act = torch.nn.GELU() 114 | 115 | if bias_in_glu: 116 | self.linear = nn.Linear(input_dim, output_dim * 2, True) 117 | else: 118 | self.linear = nn.Linear(input_dim, output_dim * 2, False) 119 | 120 | def 
forward(self, x): 121 | # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case 122 | x = self.linear(x) 123 | 124 | if self.glu_type == "bilinear": 125 | x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) 126 | else: 127 | x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) 128 | 129 | return x 130 | 131 | 132 | def gelu_accurate(x): 133 | if not hasattr(gelu_accurate, "_a"): 134 | gelu_accurate._a = math.sqrt(2 / math.pi) 135 | return ( 136 | 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) 137 | ) 138 | 139 | 140 | def gelu(x: torch.Tensor) -> torch.Tensor: 141 | return torch.nn.functional.gelu(x.float()).type_as(x) 142 | 143 | 144 | def get_activation_fn(activation: str): 145 | """Returns the activation function corresponding to `activation`""" 146 | 147 | if activation == "relu": 148 | return F.relu 149 | elif activation == "gelu": 150 | return gelu 151 | elif activation == "gelu_fast": 152 | warnings.warn( 153 | "--activation-fn=gelu_fast has been renamed to gelu_accurate" 154 | ) 155 | return gelu_accurate 156 | elif activation == "gelu_accurate": 157 | return gelu_accurate 158 | elif activation == "tanh": 159 | return torch.tanh 160 | elif activation == "linear": 161 | return lambda x: x 162 | elif activation == "glu": 163 | return lambda x: x 164 | else: 165 | raise RuntimeError("--activation-fn {} not supported".format(activation)) 166 | 167 | 168 | def init_bert_params(module): 169 | """ 170 | Initialize the weights specific to the BERT Model. 171 | This overrides the default initializations depending on the specified arguments. 172 | 1. If normal_init_linear_weights is set then weights of linear 173 | layer will be initialized using the normal distribution and 174 | bais will be set to the specified value. 175 | 2. If normal_init_embed_weights is set then weights of embedding 176 | layer will be initialized using the normal distribution. 177 | 3. If normal_init_proj_weights is set then weights of 178 | in_project_weight for MultiHeadAttention initialized using 179 | the normal distribution (to be validated). 
180 | """ 181 | 182 | def normal_(data): 183 | # with FSDP, module params will be on CUDA, so we cast them back to CPU 184 | # so that the RNG is consistent with and without FSDP 185 | data.copy_( 186 | data.cpu().normal_(mean=0.0, std=0.02).to(data.device) 187 | ) 188 | 189 | if isinstance(module, nn.Linear): 190 | normal_(module.weight.data) 191 | if module.bias is not None: 192 | module.bias.data.zero_() 193 | if isinstance(module, nn.Embedding): 194 | normal_(module.weight.data) 195 | if module.padding_idx is not None: 196 | module.weight.data[module.padding_idx].zero_() 197 | if isinstance(module, MultiheadAttention): 198 | normal_(module.q_proj.weight.data) 199 | normal_(module.k_proj.weight.data) 200 | normal_(module.v_proj.weight.data) 201 | 202 | 203 | def quant_noise(module, p, block_size): 204 | """ 205 | Wraps modules and applies quantization noise to the weights for 206 | subsequent quantization with Iterative Product Quantization as 207 | described in "Training with Quantization Noise for Extreme Model Compression" 208 | 209 | Args: 210 | - module: nn.Module 211 | - p: amount of Quantization Noise 212 | - block_size: size of the blocks for subsequent quantization with iPQ 213 | 214 | Remarks: 215 | - Module weights must have the right sizes wrt the block size 216 | - Only Linear, Embedding and Conv2d modules are supported for the moment 217 | - For more detail on how to quantize by blocks with convolutional weights, 218 | see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" 219 | - We implement the simplest form of noise here as stated in the paper 220 | which consists in randomly dropping blocks 221 | """ 222 | 223 | # if no quantization noise, don't register hook 224 | if p <= 0: 225 | return module 226 | 227 | # supported modules 228 | assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) 229 | 230 | # test whether module.weight has the right sizes wrt block_size 231 | is_conv = module.weight.ndim == 4 232 | 233 | # 2D matrix 234 | if not is_conv: 235 | assert ( 236 | module.weight.size(1) % block_size == 0 237 | ), "Input features must be a multiple of block sizes" 238 | 239 | # 4D matrix 240 | else: 241 | # 1x1 convolutions 242 | if module.kernel_size == (1, 1): 243 | assert ( 244 | module.in_channels % block_size == 0 245 | ), "Input channels must be a multiple of block sizes" 246 | # regular convolutions 247 | else: 248 | k = module.kernel_size[0] * module.kernel_size[1] 249 | assert k % block_size == 0, "Kernel size must be a multiple of block size" 250 | 251 | def _forward_pre_hook(mod, input): 252 | # no noise for evaluation 253 | if mod.training: 254 | if not is_conv: 255 | # gather weight and sizes 256 | weight = mod.weight 257 | in_features = weight.size(1) 258 | out_features = weight.size(0) 259 | 260 | # split weight matrix into blocks and randomly drop selected blocks 261 | mask = torch.zeros( 262 | in_features // block_size * out_features, device=weight.device 263 | ) 264 | mask.bernoulli_(p) 265 | mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) 266 | 267 | else: 268 | # gather weight and sizes 269 | weight = mod.weight 270 | in_channels = mod.in_channels 271 | out_channels = mod.out_channels 272 | 273 | # split weight matrix into blocks and randomly drop selected blocks 274 | if mod.kernel_size == (1, 1): 275 | mask = torch.zeros( 276 | int(in_channels // block_size * out_channels), 277 | device=weight.device, 278 | ) 279 | mask.bernoulli_(p) 280 | mask = mask.repeat_interleave(block_size, -1).view(-1, 
in_channels) 281 | else: 282 | mask = torch.zeros( 283 | weight.size(0), weight.size(1), device=weight.device 284 | ) 285 | mask.bernoulli_(p) 286 | mask = ( 287 | mask.unsqueeze(2) 288 | .unsqueeze(3) 289 | .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) 290 | ) 291 | 292 | # scale weights and apply mask 293 | mask = mask.to( 294 | torch.bool 295 | ) # x.bool() is not currently supported in TorchScript 296 | s = 1 / (1 - p) 297 | mod.weight.data = s * weight.masked_fill(mask, 0) 298 | 299 | module.register_forward_pre_hook(_forward_pre_hook) 300 | return module 301 | 302 | 303 | class MultiheadAttention(nn.Module): 304 | """Multi-headed attention. 305 | 306 | See "Attention Is All You Need" for more details. 307 | """ 308 | 309 | def __init__( 310 | self, 311 | embed_dim, 312 | num_heads, 313 | kdim=None, 314 | vdim=None, 315 | dropout=0.0, 316 | bias=True, 317 | add_bias_kv=False, 318 | add_zero_attn=False, 319 | self_attention=False, 320 | encoder_decoder_attention=False, 321 | q_noise=0.0, 322 | qn_block_size=8, 323 | has_relative_attention_bias=False, 324 | num_buckets=32, 325 | max_distance=128, 326 | gru_rel_pos=False, 327 | rescale_init=False, 328 | ): 329 | super().__init__() 330 | self.embed_dim = embed_dim 331 | self.kdim = kdim if kdim is not None else embed_dim 332 | self.vdim = vdim if vdim is not None else embed_dim 333 | self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim 334 | 335 | self.num_heads = num_heads 336 | self.dropout_module = nn.Dropout(dropout) 337 | 338 | self.has_relative_attention_bias = has_relative_attention_bias 339 | self.num_buckets = num_buckets 340 | self.max_distance = max_distance 341 | if self.has_relative_attention_bias: 342 | self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) 343 | 344 | self.head_dim = embed_dim // num_heads 345 | self.q_head_dim = self.head_dim 346 | self.k_head_dim = self.head_dim 347 | assert ( 348 | self.head_dim * num_heads == self.embed_dim 349 | ), "embed_dim must be divisible by num_heads" 350 | self.scaling = self.head_dim ** -0.5 351 | 352 | self.self_attention = self_attention 353 | self.encoder_decoder_attention = encoder_decoder_attention 354 | 355 | assert not self.self_attention or self.qkv_same_dim, ( 356 | "Self-attention requires query, key and " "value to be of the same size" 357 | ) 358 | 359 | k_bias = True 360 | if rescale_init: 361 | k_bias = False 362 | 363 | k_embed_dim = embed_dim 364 | q_embed_dim = embed_dim 365 | 366 | self.k_proj = quant_noise( 367 | nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size 368 | ) 369 | self.v_proj = quant_noise( 370 | nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size 371 | ) 372 | self.q_proj = quant_noise( 373 | nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size 374 | ) 375 | 376 | self.out_proj = quant_noise( 377 | nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size 378 | ) 379 | 380 | if add_bias_kv: 381 | self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) 382 | self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) 383 | else: 384 | self.bias_k = self.bias_v = None 385 | 386 | self.add_zero_attn = add_zero_attn 387 | 388 | self.gru_rel_pos = gru_rel_pos 389 | if self.gru_rel_pos: 390 | self.grep_linear = nn.Linear(self.q_head_dim, 8) 391 | self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) 392 | 393 | self.reset_parameters() 394 | 395 | def reset_parameters(self): 396 | if self.qkv_same_dim: 397 | # Empirically observed the convergence to be much 
better with 398 | # the scaled initialization 399 | nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) 400 | nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) 401 | nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) 402 | else: 403 | nn.init.xavier_uniform_(self.k_proj.weight) 404 | nn.init.xavier_uniform_(self.v_proj.weight) 405 | nn.init.xavier_uniform_(self.q_proj.weight) 406 | 407 | nn.init.xavier_uniform_(self.out_proj.weight) 408 | if self.out_proj.bias is not None: 409 | nn.init.constant_(self.out_proj.bias, 0.0) 410 | if self.bias_k is not None: 411 | nn.init.xavier_normal_(self.bias_k) 412 | if self.bias_v is not None: 413 | nn.init.xavier_normal_(self.bias_v) 414 | if self.has_relative_attention_bias: 415 | nn.init.xavier_normal_(self.relative_attention_bias.weight) 416 | 417 | def _relative_positions_bucket(self, relative_positions, bidirectional=True): 418 | num_buckets = self.num_buckets 419 | max_distance = self.max_distance 420 | relative_buckets = 0 421 | 422 | if bidirectional: 423 | num_buckets = num_buckets // 2 424 | relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets 425 | relative_positions = torch.abs(relative_positions) 426 | else: 427 | relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) 428 | 429 | max_exact = num_buckets // 2 430 | is_small = relative_positions < max_exact 431 | 432 | relative_postion_if_large = max_exact + ( 433 | torch.log(relative_positions.float() / max_exact) 434 | / math.log(max_distance / max_exact) 435 | * (num_buckets - max_exact) 436 | ).to(torch.long) 437 | relative_postion_if_large = torch.min( 438 | relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) 439 | ) 440 | 441 | relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) 442 | return relative_buckets 443 | 444 | def compute_bias(self, query_length, key_length): 445 | context_position = torch.arange(query_length, dtype=torch.long)[:, None] 446 | memory_position = torch.arange(key_length, dtype=torch.long)[None, :] 447 | relative_position = memory_position - context_position 448 | relative_position_bucket = self._relative_positions_bucket( 449 | relative_position, 450 | bidirectional=True 451 | ) 452 | relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) 453 | values = self.relative_attention_bias(relative_position_bucket) 454 | values = values.permute([2, 0, 1]) 455 | return values 456 | 457 | def forward( 458 | self, 459 | query, 460 | key: Optional[Tensor], 461 | value: Optional[Tensor], 462 | key_padding_mask: Optional[Tensor] = None, 463 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 464 | need_weights: bool = True, 465 | static_kv: bool = False, 466 | attn_mask: Optional[Tensor] = None, 467 | before_softmax: bool = False, 468 | need_head_weights: bool = False, 469 | position_bias: Optional[Tensor] = None 470 | ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: 471 | """Input shape: Time x Batch x Channel 472 | 473 | Args: 474 | key_padding_mask (ByteTensor, optional): mask to exclude 475 | keys that are pads, of shape `(batch, src_len)`, where 476 | padding elements are indicated by 1s. 477 | need_weights (bool, optional): return the attention weights, 478 | averaged over heads (default: False). 
479 | attn_mask (ByteTensor, optional): typically used to 480 | implement causal attention, where the mask prevents the 481 | attention from looking forward in time (default: None). 482 | before_softmax (bool, optional): return the raw attention 483 | weights and values before the attention softmax. 484 | need_head_weights (bool, optional): return the attention 485 | weights for each head. Implies *need_weights*. Default: 486 | return the average attention weights over all heads. 487 | """ 488 | if need_head_weights: 489 | need_weights = True 490 | 491 | is_tpu = query.device.type == "xla" 492 | 493 | tgt_len, bsz, embed_dim = query.size() 494 | src_len = tgt_len 495 | assert embed_dim == self.embed_dim 496 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 497 | if key is not None: 498 | src_len, key_bsz, _ = key.size() 499 | if not torch.jit.is_scripting(): 500 | assert key_bsz == bsz 501 | assert value is not None 502 | assert src_len, bsz == value.shape[:2] 503 | 504 | if self.has_relative_attention_bias and position_bias is None: 505 | position_bias = self.compute_bias(tgt_len, src_len) 506 | position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) 507 | 508 | if ( 509 | not is_tpu # don't use PyTorch version on TPUs 510 | and incremental_state is None 511 | and not static_kv 512 | # A workaround for quantization to work. Otherwise JIT compilation 513 | # treats bias in linear module as method. 514 | and not torch.jit.is_scripting() 515 | and self.q_head_dim == self.head_dim 516 | ): 517 | assert key is not None and value is not None 518 | assert attn_mask is None 519 | 520 | attn_mask_rel_pos = None 521 | if position_bias is not None: 522 | attn_mask_rel_pos = position_bias 523 | if self.gru_rel_pos: 524 | query_layer = query.transpose(0, 1) 525 | new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) 526 | query_layer = query_layer.view(*new_x_shape) 527 | query_layer = query_layer.permute(0, 2, 1, 3) 528 | _B, _H, _L, __ = query_layer.size() 529 | 530 | gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( 531 | _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) 532 | gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 533 | attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias 534 | 535 | attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) 536 | k_proj_bias = self.k_proj.bias 537 | if k_proj_bias is None: 538 | k_proj_bias = torch.zeros_like(self.q_proj.bias) 539 | 540 | x, attn = F.multi_head_attention_forward( 541 | query, 542 | key, 543 | value, 544 | self.embed_dim, 545 | self.num_heads, 546 | torch.empty([0]), 547 | torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), 548 | self.bias_k, 549 | self.bias_v, 550 | self.add_zero_attn, 551 | self.dropout_module.p, 552 | self.out_proj.weight, 553 | self.out_proj.bias, 554 | self.training, 555 | # self.training or self.dropout_module.apply_during_inference, 556 | key_padding_mask, 557 | need_weights, 558 | attn_mask_rel_pos, 559 | use_separate_proj_weight=True, 560 | q_proj_weight=self.q_proj.weight, 561 | k_proj_weight=self.k_proj.weight, 562 | v_proj_weight=self.v_proj.weight, 563 | ) 564 | return x, attn, position_bias 565 | 566 | if incremental_state is not None: 567 | saved_state = self._get_input_buffer(incremental_state) 568 | if saved_state is not None and "prev_key" in saved_state: 569 | # previous time steps are cached - no need to recompute 570 | # key and value if they are static 
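                # (cached tensors are stored under the "attn_state" key as
                # "prev_key" / "prev_value" / "prev_key_padding_mask";
                # see _get_input_buffer / _set_input_buffer near the bottom of this class)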
571 | if static_kv: 572 | assert self.encoder_decoder_attention and not self.self_attention 573 | key = value = None 574 | else: 575 | saved_state = None 576 | 577 | if self.self_attention: 578 | q = self.q_proj(query) 579 | k = self.k_proj(query) 580 | v = self.v_proj(query) 581 | elif self.encoder_decoder_attention: 582 | # encoder-decoder attention 583 | q = self.q_proj(query) 584 | if key is None: 585 | assert value is None 586 | k = v = None 587 | else: 588 | k = self.k_proj(key) 589 | v = self.v_proj(key) 590 | 591 | else: 592 | assert key is not None and value is not None 593 | q = self.q_proj(query) 594 | k = self.k_proj(key) 595 | v = self.v_proj(value) 596 | q *= self.scaling 597 | 598 | if self.bias_k is not None: 599 | assert self.bias_v is not None 600 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) 601 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) 602 | if attn_mask is not None: 603 | attn_mask = torch.cat( 604 | [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 605 | ) 606 | if key_padding_mask is not None: 607 | key_padding_mask = torch.cat( 608 | [ 609 | key_padding_mask, 610 | key_padding_mask.new_zeros(key_padding_mask.size(0), 1), 611 | ], 612 | dim=1, 613 | ) 614 | 615 | q = ( 616 | q.contiguous() 617 | .view(tgt_len, bsz * self.num_heads, self.q_head_dim) 618 | .transpose(0, 1) 619 | ) 620 | if k is not None: 621 | k = ( 622 | k.contiguous() 623 | .view(-1, bsz * self.num_heads, self.k_head_dim) 624 | .transpose(0, 1) 625 | ) 626 | if v is not None: 627 | v = ( 628 | v.contiguous() 629 | .view(-1, bsz * self.num_heads, self.head_dim) 630 | .transpose(0, 1) 631 | ) 632 | 633 | if saved_state is not None: 634 | # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) 635 | if "prev_key" in saved_state: 636 | _prev_key = saved_state["prev_key"] 637 | assert _prev_key is not None 638 | prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) 639 | if static_kv: 640 | k = prev_key 641 | else: 642 | assert k is not None 643 | k = torch.cat([prev_key, k], dim=1) 644 | src_len = k.size(1) 645 | if "prev_value" in saved_state: 646 | _prev_value = saved_state["prev_value"] 647 | assert _prev_value is not None 648 | prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) 649 | if static_kv: 650 | v = prev_value 651 | else: 652 | assert v is not None 653 | v = torch.cat([prev_value, v], dim=1) 654 | prev_key_padding_mask: Optional[Tensor] = None 655 | if "prev_key_padding_mask" in saved_state: 656 | prev_key_padding_mask = saved_state["prev_key_padding_mask"] 657 | assert k is not None and v is not None 658 | key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( 659 | key_padding_mask=key_padding_mask, 660 | prev_key_padding_mask=prev_key_padding_mask, 661 | batch_size=bsz, 662 | src_len=k.size(1), 663 | static_kv=static_kv, 664 | ) 665 | 666 | saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) 667 | saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) 668 | saved_state["prev_key_padding_mask"] = key_padding_mask 669 | # In this branch incremental_state is never None 670 | assert incremental_state is not None 671 | incremental_state = self._set_input_buffer(incremental_state, saved_state) 672 | assert k is not None 673 | assert k.size(1) == src_len 674 | 675 | # This is part of a workaround to get around fork/join parallelism 676 | # not supporting Optional types. 
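        # Shape recap before the attention-score computation below:
        #   q: (bsz * num_heads, tgt_len, q_head_dim), already multiplied by scaling = head_dim ** -0.5
        #   k: (bsz * num_heads, src_len, k_head_dim)
        #   v: (bsz * num_heads, src_len, head_dim)
        # torch.bmm(q, k.transpose(1, 2)) therefore yields attention logits of shape
        # (bsz * num_heads, tgt_len, src_len); attn_mask, key_padding_mask and the relative
        # position_bias are applied to these logits before the softmax.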
677 | if key_padding_mask is not None and key_padding_mask.dim() == 0: 678 | key_padding_mask = None 679 | 680 | if key_padding_mask is not None: 681 | assert key_padding_mask.size(0) == bsz 682 | assert key_padding_mask.size(1) == src_len 683 | 684 | if self.add_zero_attn: 685 | assert v is not None 686 | src_len += 1 687 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) 688 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) 689 | if attn_mask is not None: 690 | attn_mask = torch.cat( 691 | [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 692 | ) 693 | if key_padding_mask is not None: 694 | key_padding_mask = torch.cat( 695 | [ 696 | key_padding_mask, 697 | torch.zeros(key_padding_mask.size(0), 1).type_as( 698 | key_padding_mask 699 | ), 700 | ], 701 | dim=1, 702 | ) 703 | 704 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 705 | attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) 706 | 707 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 708 | 709 | if attn_mask is not None: 710 | attn_mask = attn_mask.unsqueeze(0) 711 | attn_weights += attn_mask 712 | 713 | if key_padding_mask is not None: 714 | # don't attend to padding symbols 715 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 716 | if not is_tpu: 717 | attn_weights = attn_weights.masked_fill( 718 | key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), 719 | float("-inf"), 720 | ) 721 | else: 722 | attn_weights = attn_weights.transpose(0, 2) 723 | attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) 724 | attn_weights = attn_weights.transpose(0, 2) 725 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 726 | 727 | if before_softmax: 728 | return attn_weights, v, position_bias 729 | 730 | if position_bias is not None: 731 | if self.gru_rel_pos == 1: 732 | query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) 733 | _B, _H, _L, __ = query_layer.size() 734 | gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( 735 | _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) 736 | gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 737 | position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias 738 | 739 | position_bias = position_bias.view(attn_weights.size()) 740 | 741 | attn_weights = attn_weights + position_bias 742 | 743 | attn_weights_float = F.softmax( 744 | attn_weights, dim=-1 745 | ) 746 | attn_weights = attn_weights_float.type_as(attn_weights) 747 | attn_probs = self.dropout_module(attn_weights) 748 | 749 | assert v is not None 750 | attn = torch.bmm(attn_probs, v) 751 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 752 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 753 | attn = self.out_proj(attn) 754 | attn_weights: Optional[Tensor] = None 755 | if need_weights: 756 | attn_weights = attn_weights_float.view( 757 | bsz, self.num_heads, tgt_len, src_len 758 | ).transpose(1, 0) 759 | if not need_head_weights: 760 | # average attention weights over heads 761 | attn_weights = attn_weights.mean(dim=0) 762 | 763 | return attn, attn_weights, position_bias 764 | 765 | @staticmethod 766 | def _append_prev_key_padding_mask( 767 | key_padding_mask: Optional[Tensor], 768 | prev_key_padding_mask: Optional[Tensor], 769 | batch_size: int, 770 | src_len: int, 771 | static_kv: bool, 772 | ) -> Optional[Tensor]: 773 | # saved key padding masks have shape (bsz, 
seq_len) 774 | if prev_key_padding_mask is not None and static_kv: 775 | new_key_padding_mask = prev_key_padding_mask 776 | elif prev_key_padding_mask is not None and key_padding_mask is not None: 777 | new_key_padding_mask = torch.cat( 778 | [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 779 | ) 780 | # During incremental decoding, as the padding token enters and 781 | # leaves the frame, there will be a time when prev or current 782 | # is None 783 | elif prev_key_padding_mask is not None: 784 | if src_len > prev_key_padding_mask.size(1): 785 | filler = torch.zeros( 786 | (batch_size, src_len - prev_key_padding_mask.size(1)), 787 | device=prev_key_padding_mask.device, 788 | ) 789 | new_key_padding_mask = torch.cat( 790 | [prev_key_padding_mask.float(), filler.float()], dim=1 791 | ) 792 | else: 793 | new_key_padding_mask = prev_key_padding_mask.float() 794 | elif key_padding_mask is not None: 795 | if src_len > key_padding_mask.size(1): 796 | filler = torch.zeros( 797 | (batch_size, src_len - key_padding_mask.size(1)), 798 | device=key_padding_mask.device, 799 | ) 800 | new_key_padding_mask = torch.cat( 801 | [filler.float(), key_padding_mask.float()], dim=1 802 | ) 803 | else: 804 | new_key_padding_mask = key_padding_mask.float() 805 | else: 806 | new_key_padding_mask = prev_key_padding_mask 807 | return new_key_padding_mask 808 | 809 | def _get_input_buffer( 810 | self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] 811 | ) -> Dict[str, Optional[Tensor]]: 812 | result = self.get_incremental_state(incremental_state, "attn_state") 813 | if result is not None: 814 | return result 815 | else: 816 | empty_result: Dict[str, Optional[Tensor]] = {} 817 | return empty_result 818 | 819 | def _set_input_buffer( 820 | self, 821 | incremental_state: Dict[str, Dict[str, Optional[Tensor]]], 822 | buffer: Dict[str, Optional[Tensor]], 823 | ): 824 | return self.set_incremental_state(incremental_state, "attn_state", buffer) 825 | 826 | def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): 827 | return attn_weights -------------------------------------------------------------------------------- /extract_feature/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HappyColor/Vesper/93d41d702ee2105b1bda92e852725548e3893ebd/extract_feature/__init__.py -------------------------------------------------------------------------------- /extract_feature/audio.py: -------------------------------------------------------------------------------- 1 | 2 | # From Mockingjay 3 | 4 | # -*- coding: utf-8 -*- # 5 | """*********************************************************************************************""" 6 | # FileName [ utils/audio.py ] 7 | # Synopsis [ audio processing functions ] 8 | # Author [ Andy T. 
Liu (Andi611) ] 9 | # Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ] 10 | # Reference 0 [ https://github.com/andi611/TTS-Tacotron-Pytorch ] 11 | # Reference 1 [ https://github.com/Alexander-H-Liu/End-to-end-ASR-Pytorch ] 12 | # Reference 2 [ https://groups.google.com/forum/#!msg/librosa/V4Z1HpTKn8Q/1-sMpjxjCSoJ ] 13 | """*********************************************************************************************""" 14 | 15 | 16 | ############### 17 | # IMPORTATION # 18 | ############### 19 | import librosa 20 | import numpy as np 21 | import matplotlib 22 | matplotlib.use("Agg") 23 | import matplotlib.pylab as plt 24 | from matplotlib.colors import SymLogNorm 25 | from scipy import signal 26 | import warnings 27 | warnings.filterwarnings("ignore") 28 | # NOTE: there are warnings for MFCC extraction due to librosa's issue 29 | 30 | 31 | ################## 32 | # AUDIO SETTINGS # 33 | ################## 34 | sample_rate = 16000 35 | """ 36 | For feature == 'fbank' or 'mfcc' 37 | """ 38 | num_mels = 80 # int, dimension of feature 39 | delta = True # Append Delta 40 | delta_delta = False # Append Delta Delta 41 | window_size = 25 # int, window size for FFT (ms) 42 | stride = 10 # int, window stride for FFT 43 | mel_dim = num_mels * (1 + int(delta) + int(delta_delta)) 44 | """ 45 | For feature == 'linear' 46 | """ 47 | num_freq = 1025 48 | frame_length_ms = 50 49 | frame_shift_ms = 12.5 50 | preemphasis = 0.97 51 | min_level_db = -100 52 | ref_level_db = 20 53 | hop_length = 250 54 | griffin_lim_iters = 16 55 | power = 1.5 # Power to raise magnitudes to prior to Griffin-Lim 56 | """ 57 | For feature == 'fmllr' 58 | """ 59 | fmllr_dim = 40 60 | 61 | 62 | ############################# 63 | # SPECTROGRAM UTILS FORWARD # 64 | ############################# 65 | def _stft_parameters(sample_rate): 66 | n_fft = (num_freq - 1) * 2 67 | hop_length = int(frame_shift_ms / 1000 * sample_rate) 68 | win_length = int(frame_length_ms / 1000 * sample_rate) 69 | return n_fft, hop_length, win_length 70 | 71 | def _linear_to_mel(spectrogram, sample_rate): 72 | _mel_basis = _build_mel_basis(sample_rate) 73 | return np.dot(_mel_basis, spectrogram) 74 | 75 | def _build_mel_basis(sample_rate): 76 | n_fft = (num_freq - 1) * 2 77 | return librosa.filters.mel(sample_rate, n_fft, n_mels=num_mels) 78 | 79 | def _preemphasis(x): 80 | return signal.lfilter([1, -preemphasis], [1], x) 81 | 82 | def _amp_to_db(x): 83 | return 20 * np.log10(np.maximum(1e-5, x)) 84 | 85 | def _normalize(S): 86 | return np.clip((S - min_level_db) / -min_level_db, 0, 1) 87 | 88 | def _stft(y, sr): 89 | n_fft, hop_length, win_length = _stft_parameters(sr) 90 | return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) 91 | 92 | 93 | ############################# 94 | # SPECTROGRAM UTILS BACKWARD # 95 | ############################# 96 | def _denormalize(S): 97 | return (np.clip(S, 0, 1) * -min_level_db) + min_level_db 98 | 99 | def _db_to_amp(x): 100 | return np.power(10.0, x * 0.05) 101 | 102 | def inv_preemphasis(x): 103 | return signal.lfilter([1], [1, -preemphasis], x) 104 | 105 | def _griffin_lim(S, sr): 106 | """ 107 | librosa implementation of Griffin-Lim 108 | Based on https://github.com/librosa/librosa/issues/434 109 | """ 110 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) 111 | S_complex = np.abs(S).astype(np.complex) 112 | y = _istft(S_complex * angles, sr) 113 | for i in range(griffin_lim_iters): 114 | angles = np.exp(1j * np.angle(_stft(y ,sr))) 115 | y = _istft(S_complex * angles, sr) 116 | return 
y 117 | 118 | def _istft(y, sr): 119 | _, hop_length, win_length = _stft_parameters(sr) 120 | return librosa.istft(y, hop_length=hop_length, win_length=win_length) 121 | 122 | 123 | ################### 124 | # MEL SPECTROGRAM # 125 | ################### 126 | """ 127 | Compute the mel-scale spectrogram from the wav. 128 | """ 129 | def melspectrogram(y, sr): 130 | D = _stft(_preemphasis(y), sr) 131 | S = _amp_to_db(_linear_to_mel(np.abs(D), sr)) 132 | return _normalize(S) 133 | 134 | 135 | ############### 136 | # SPECTROGRAM # 137 | ############### 138 | """ 139 | Compute the linear-scale spectrogram from the wav. 140 | """ 141 | def spectrogram(y, sr): 142 | D = _stft(_preemphasis(y), sr) 143 | S = _amp_to_db(np.abs(D)) - ref_level_db 144 | return _normalize(S) 145 | 146 | 147 | ################### 148 | # INV SPECTROGRAM # 149 | ################### 150 | """ 151 | Converts spectrogram to waveform using librosa 152 | """ 153 | def inv_spectrogram(spectrogram, sr=16000): 154 | S = _db_to_amp(_denormalize(spectrogram) + ref_level_db) # Convert back to linear 155 | return inv_preemphasis(_griffin_lim(S ** power, sr)) # Reconstruct phase 156 | 157 | 158 | ################### 159 | # EXTRACT FEATURE # 160 | ################### 161 | # Acoustic Feature Extraction 162 | # Parameters 163 | # - input file : str, audio file path 164 | # - feature : str, fbank or mfcc 165 | # - cmvn : bool, apply CMVN on feature 166 | # - save_feature: str, if given, store feature to the path and return len(feature) 167 | # Return 168 | # acoustic features with shape (time step, dim) 169 | def extract_feature(input_file, feature='fbank', cmvn=True, save_feature=None): 170 | y, sr = librosa.load(input_file, sr=sample_rate) 171 | 172 | if feature == 'fbank': 173 | ws = int(sr*0.001*window_size) 174 | st = int(sr*0.001*stride) 175 | feat = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=num_mels, 176 | n_fft=ws, hop_length=st) 177 | feat = np.log(feat + 1e-6) # log-scaled 178 | elif feature == 'mfcc': 179 | ws = int(sr*0.001*window_size) 180 | st = int(sr*0.001*stride) 181 | feat = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=num_mels, n_mels=26, 182 | n_fft=ws, hop_length=st) 183 | feat[0] = librosa.feature.rmse(y, hop_length=st, frame_length=ws) 184 | elif feature == 'mel': 185 | # feat = melspectrogram(y, sr) # deprecated 186 | n_fft, hop_length, win_length = _stft_parameters(sr) 187 | feat = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=num_mels, 188 | n_fft=n_fft, hop_length=hop_length, win_length=win_length) 189 | feat = np.log(feat + 1e-6) # log-scaled 190 | elif feature == 'linear': 191 | feat = spectrogram(y, sr) 192 | else: 193 | raise ValueError('Unsupported Acoustic Feature: ' + feature) 194 | 195 | # Apply delta 196 | feat = [feat] 197 | if delta and feature != 'linear': 198 | feat.append(librosa.feature.delta(feat[0])) 199 | 200 | if delta_delta and feature != 'linear': 201 | feat.append(librosa.feature.delta(feat[0], order=2)) 202 | feat = np.concatenate(feat, axis=0) 203 | if feature == 'linear': assert(np.shape(feat)[0] == num_freq) 204 | 205 | if cmvn: 206 | feat = (feat - feat.mean(axis=1)[:,np.newaxis]) / (feat.std(axis=1)+1e-16)[:,np.newaxis] 207 | if save_feature is not None: 208 | tmp = np.swapaxes(feat, 0, 1).astype('float32') 209 | np.save(save_feature,tmp) 210 | return len(tmp) 211 | else: 212 | return np.swapaxes(feat, 0, 1).astype('float32') 213 | 214 | 215 | ##################### 216 | # SAVE FIG TO NUMPY # 217 | ##################### 218 | def _save_figure_to_numpy(fig): 219 | # save it 
to a numpy array. 220 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 221 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 222 | return data.transpose(2, 0, 1) # (Channel, Height, Width) 223 | 224 | 225 | ############################# 226 | # PLOT SPECTROGRAM TO NUMPY # 227 | ############################# 228 | def plot_spectrogram_to_numpy(spectrogram): 229 | spectrogram = spectrogram.transpose(1, 0) 230 | fig, ax = plt.subplots(figsize=(12, 3)) 231 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 232 | interpolation='none') 233 | plt.colorbar(im, ax=ax) 234 | plt.xlabel("Frames") 235 | plt.ylabel("Channels") 236 | plt.tight_layout() 237 | 238 | fig.canvas.draw() 239 | data = _save_figure_to_numpy(fig) 240 | plt.close() 241 | return data 242 | 243 | 244 | #################### 245 | # PLOT SPECTROGRAM # 246 | #################### 247 | def plot_spectrogram(spec, path): 248 | spec = spec.transpose(1, 0) # (seq_len, feature_dim) -> (feature_dim, seq_len) 249 | plt.gcf().clear() 250 | plt.figure(figsize=(12, 3)) 251 | plt.imshow(spec, aspect="auto", origin="lower") 252 | plt.colorbar() 253 | plt.tight_layout() 254 | plt.savefig(path, dpi=300, format="png") 255 | plt.close() 256 | 257 | 258 | #################### 259 | # PLOT EMBEDDING # 260 | #################### 261 | def plot_embedding(spec, path): 262 | spec = spec.transpose(1, 0) # (seq_len, feature_dim) -> (feature_dim, seq_len) 263 | plt.gcf().clear() 264 | plt.figure(figsize=(12, 3)) 265 | plt.pcolormesh(spec, norm=SymLogNorm(linthresh=1e-3)) 266 | plt.colorbar() 267 | plt.tight_layout() 268 | plt.savefig(path, dpi=300, format="png") 269 | plt.close() -------------------------------------------------------------------------------- /figures/Vesper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HappyColor/Vesper/93d41d702ee2105b1bda92e852725548e3893ebd/figures/Vesper.png -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Mar 13, 2023 5 | @author: lab-chen.weidong 6 | """ 7 | 8 | import os 9 | from tqdm import tqdm 10 | import torch 11 | import torch.optim.lr_scheduler as lr_scheduler 12 | import torch.distributed as dist 13 | import torch.multiprocessing as mp 14 | from torch.utils.tensorboard import SummaryWriter 15 | from collections import OrderedDict 16 | import math 17 | from sklearn.metrics import accuracy_score 18 | import json 19 | import shutil 20 | import re 21 | 22 | import utils 23 | import models 24 | from configs import create_workshop, get_config, dict_2_list 25 | 26 | class Engine(): 27 | def __init__(self, cfg, local_rank: int, world_size: int): 28 | self.cfg = cfg 29 | self.local_rank = local_rank 30 | self.world_size = world_size 31 | self.device = self.cfg.train.device 32 | self.EPOCH = self.cfg.train.EPOCH 33 | self.current_epoch = 0 34 | self.iteration = 0 35 | self.best_score = 0 36 | 37 | self.dataloader_feactory = utils.dataset.DataloaderFactory(self.cfg.dataset) 38 | self.loss_func = torch.nn.CrossEntropyLoss() 39 | self.calculate_score = utils.metric.calculate_score_classification 40 | self.early_stopping = utils.earlystopping.EarlyStopping(patience=self.cfg.train.patience, verbose=self.local_rank == 0, higher_is_better=True) 41 | 42 | ### prepare meters 43 | data_type = torch.int64 
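        # predictions and labels are recorded as int64 class indices so they can be
        # all-gathered across DDP ranks and scored with the classification metrics later on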
44 | self.loss_meter = utils.avgmeter.AverageMeter(device='cuda') 45 | self.acc_meter = utils.avgmeter.AverageMeter(device='cuda') 46 | self.predict_recoder = utils.recoder.TensorRecorder(device='cuda', dtype=data_type) 47 | self.label_recoder = utils.recoder.TensorRecorder(device='cuda', dtype=data_type) 48 | 49 | def prepare_staff(self): 50 | ''' We move this part out of the __init__ function to avoid the weird error: 51 | DataLoader worker (pid xxx) is killed by signal: Aborted 52 | This error is probably caused by a conflict between lmdb and ddp. 53 | ''' 54 | ### prepare dataloader 55 | self.dataloader_train = self.dataloader_feactory.build( 56 | state='train', 57 | bs=self.cfg.train.batch_size, 58 | fold=self.fold 59 | ) 60 | self.dataloader_test = self.dataloader_feactory.build( 61 | state='dev', 62 | bs=self.cfg.train.batch_size, 63 | fold=self.fold 64 | ) 65 | 66 | ### prepare model, optimizer and scheduler 67 | self.cfg.model.freeze_cnn = self.cfg.train.freeze_cnn 68 | self.cfg.model.freeze_upstream = self.cfg.train.freeze_upstream 69 | model = models.vesper.VesperFinetuneWrapper(self.cfg.model).to(self.device) 70 | 71 | if self.cfg.train.freeze_cnn: 72 | for param in model.vesper.feature_extractor.parameters(): 73 | param.requires_grad = False 74 | if self.cfg.train.freeze_upstream: 75 | for param in model.vesper.parameters(): 76 | param.requires_grad = False 77 | 78 | self.model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.local_rank]) 79 | 80 | if self.cfg.train.optimizer == 'AdamW': 81 | self.optimizer = torch.optim.AdamW( 82 | params=filter(lambda x: x.requires_grad, self.model.parameters()), 83 | lr=self.cfg.train.lr, 84 | weight_decay=self.cfg.train.weight_decay 85 | ) 86 | elif self.cfg.train.optimizer == 'sgd': 87 | self.optimizer = torch.optim.SGD( 88 | params=filter(lambda x: x.requires_grad, self.model.parameters()), 89 | lr=self.cfg.train.lr, 90 | momentum=0.9, 91 | weight_decay=self.cfg.train.weight_decay 92 | ) 93 | else: 94 | raise ValueError(f'Unknown optimizer: {self.cfg.train.optimizer}') 95 | 96 | if self.local_rank == 0: 97 | print(f'Optimizer: {self.cfg.train.optimizer}') 98 | 99 | # CosineAnnealingLR with Warm-up 100 | # warmup_epoch = int(self.cfg.train.warmup_epoch * self.EPOCH) 101 | warmup_epoch = 0 102 | lr_max = self.cfg.train.lr 103 | lr_min = self.cfg.train.lr * 0.01 104 | T_max = self.EPOCH 105 | lr_lambda = lambda epoch: (epoch + 1) / warmup_epoch if epoch < warmup_epoch else \ 106 | (lr_min + 0.5*(lr_max-lr_min)*(1.0+math.cos((epoch-warmup_epoch)/(T_max-warmup_epoch)*math.pi))) / self.cfg.train.lr 107 | self.scheduler = lr_scheduler.LambdaLR(optimizer=self.optimizer, lr_lambda=lr_lambda) 108 | 109 | if self.cfg.train.load_model is not None: 110 | ckpt = torch.load(self.cfg.train.load_model, map_location=self.device) 111 | self.model.module.load_state_dict(ckpt['model']) 112 | if self.local_rank == 0: 113 | print(f'Loading model from {self.cfg.train.load_model}') 114 | del ckpt 115 | 116 | if self.cfg.train.resume is not None: 117 | ckpt = torch.load(self.cfg.train.resume, map_location=self.device) 118 | self.model.module.load_state_dict(ckpt['model']) 119 | self.optimizer.load_state_dict(ckpt['optimizer']) 120 | self.scheduler.load_state_dict(ckpt['scheduler']) 121 | self.scheduler.step() 122 | self.current_epoch = ckpt['epoch'] + 1 123 | self.iteration = ckpt['iteration'] 124 | self.best_score = ckpt['best_score'] 125 | if self.local_rank == 0: 126 | print(f'Resuming from {self.cfg.train.resume}') 127 | del ckpt 128 | 129 
| ### prepare writer and logger 130 | if self.local_rank == 0: 131 | self.writer = SummaryWriter(self.cfg.workshop) 132 | self.logger_train = utils.logger.create_logger(self.cfg.workshop, name='train') 133 | self.logger_test = utils.logger.create_logger(self.cfg.workshop, name='test') 134 | self.logger_train.info(f'workshop: {self.cfg.workshop}') 135 | self.logger_train.info(f'seed: {self.cfg.train.seed}') 136 | self.logger_train.info(f'pid: {os.getpid()}') 137 | print('Main pid:', os.getpid()) 138 | else: 139 | self.writer = None 140 | self.logger_train = None 141 | self.logger_test = None 142 | 143 | self.config_2_json() 144 | 145 | def config_2_json(self, jsonfile=None): 146 | self.jsonfile = os.path.join(self.cfg.workshop, 'config.json') if jsonfile is None else jsonfile 147 | with open(self.jsonfile, 'w') as f: 148 | json.dump(dict(self.cfg), f, indent=2) 149 | 150 | def json_2_config(self, jsonfile=None): 151 | if jsonfile is not None: 152 | self.jsonfile = jsonfile 153 | assert hasattr(self, 'jsonfile'), 'Please provide the .json file first.' 154 | with open(self.jsonfile, 'r') as f: 155 | data = json.load(f) 156 | self.cfg.merge_from_list(dict_2_list(data)) 157 | 158 | def reset_meters(self): 159 | self.loss_meter.reset() 160 | self.acc_meter.reset() 161 | 162 | def reset_recoders(self): 163 | self.predict_recoder.reset() 164 | self.label_recoder.reset() 165 | 166 | def gather_distributed_data(self, gather_data): 167 | if isinstance(gather_data, torch.Tensor): 168 | _output = [torch.zeros_like(gather_data) for _ in range(self.world_size)] 169 | dist.all_gather(_output, gather_data, async_op=False) 170 | output = torch.cat(_output) 171 | else: 172 | if gather_data[0] is not None: 173 | _output = [None for _ in range(self.world_size)] 174 | if hasattr(dist, 'all_gather_object'): 175 | dist.all_gather_object(_output, gather_data) 176 | else: 177 | utils.distributed.all_gather_object(_output, gather_data, self.world_size) 178 | output = [] 179 | for lst in _output: 180 | output.extend(lst) 181 | else: 182 | output = None 183 | return output 184 | 185 | def train_epoch(self): 186 | self.dataloader_train.set_epoch(self.current_epoch) 187 | if self.local_rank == 0: 188 | print(f'-------- {self.cfg.workshop} --------') 189 | discrip_str = f'Epoch-{self.current_epoch}/{self.EPOCH}' 190 | pbar_train = tqdm(self.dataloader_train, disable=self.local_rank != 0, dynamic_ncols=True) 191 | pbar_train.set_description('Train' + discrip_str) 192 | 193 | self.reset_meters() 194 | self.reset_recoders() 195 | 196 | self.model.train() 197 | for data in pbar_train: 198 | self.iteration += 1 199 | 200 | waveform = torch.cat(data['waveform'], dim=0).to(self.device) 201 | padding_mask = torch.cat(data['padding_mask'], dim=0).to(self.device) 202 | y = torch.cat(data['emotion'], dim=0).to(self.device) 203 | batch_size = y.shape[0] 204 | 205 | self.optimizer.zero_grad() 206 | 207 | pred = self.model(waveform, padding_mask) 208 | loss = self.loss_func(pred, y) 209 | loss.backward() 210 | 211 | self.optimizer.step() 212 | 213 | y_pred = torch.argmax(pred, dim=1) 214 | 215 | self.predict_recoder.record(y_pred) 216 | self.label_recoder.record(y) 217 | 218 | accuracy = accuracy_score(y.cpu(), y_pred.cpu()) 219 | self.loss_meter.update(loss.item()) 220 | self.acc_meter.update(accuracy, batch_size) 221 | 222 | pbar_train_dic = OrderedDict() 223 | pbar_train_dic['iter'] = self.iteration 224 | pbar_train_dic['lr'] = self.optimizer.param_groups[0]['lr'] 225 | pbar_train_dic['acc'] = f'{self.acc_meter.avg:.5f}' 226 | 
pbar_train_dic['loss'] = f'{self.loss_meter.avg:.5f}' 227 | pbar_train.set_postfix(pbar_train_dic) 228 | 229 | epoch_preds = self.gather_distributed_data(self.predict_recoder.data).cpu() 230 | epoch_labels = self.gather_distributed_data(self.label_recoder.data).cpu() 231 | 232 | self.loss_meter.sync_distributed() 233 | epoch_loss = self.loss_meter.avg 234 | 235 | if self.local_rank == 0: 236 | accuracy, recall, f1, precision, Matrix = self.calculate_score(epoch_preds, epoch_labels, self.cfg.dataset.f1) 237 | self.writer.add_scalar('Train/WA', accuracy, self.current_epoch) 238 | self.writer.add_scalar('Train/UA', recall, self.current_epoch) 239 | self.writer.add_scalar('Train/F1', f1, self.current_epoch) 240 | self.writer.add_scalar('Train/Loss', epoch_loss, self.current_epoch) 241 | self.writer.add_scalar('LR', self.optimizer.param_groups[0]['lr'], self.current_epoch) 242 | 243 | if self.logger_train is not None: 244 | self.logger_train.info( 245 | f'Training epoch: {self.current_epoch}, accuracy: {accuracy:.5f}, precision: {precision:.5f}, recall: {recall:.5f}, F1: {f1:.5f}, loss: {epoch_loss:.5f}' 246 | ) 247 | 248 | def test(self): 249 | discrip_str = f'Epoch-{self.current_epoch}' 250 | pbar_test = tqdm(self.dataloader_test, disable=self.local_rank != 0, dynamic_ncols=True) 251 | pbar_test.set_description('Test' + discrip_str) 252 | 253 | self.reset_meters() 254 | self.reset_recoders() 255 | 256 | self.model.eval() 257 | with torch.no_grad(): 258 | for data in pbar_test: 259 | waveform = torch.cat(data['waveform'], dim=0).to(self.device) 260 | padding_mask = torch.cat(data['padding_mask'], dim=0).to(self.device) 261 | y = torch.cat(data['emotion'], dim=0).to(self.device) 262 | batch_size = y.shape[0] 263 | 264 | pred = self.model(waveform, padding_mask) 265 | loss = self.loss_func(pred, y) 266 | 267 | y_pred = torch.argmax(pred, dim=1) 268 | 269 | self.predict_recoder.record(y_pred) 270 | self.label_recoder.record(y) 271 | 272 | accuracy = accuracy_score(y.cpu(), y_pred.cpu()) 273 | self.loss_meter.update(loss.item()) 274 | self.acc_meter.update(accuracy, batch_size) 275 | 276 | pbar_test_dic = OrderedDict() 277 | pbar_test_dic['acc'] = f'{self.acc_meter.avg:.5f}' 278 | pbar_test_dic['loss'] = f'{self.loss_meter.avg:.5f}' 279 | pbar_test.set_postfix(pbar_test_dic) 280 | 281 | epoch_preds = self.gather_distributed_data(self.predict_recoder.data).cpu() 282 | epoch_labels = self.gather_distributed_data(self.label_recoder.data).cpu() 283 | 284 | self.loss_meter.sync_distributed() 285 | epoch_loss = self.loss_meter.avg 286 | 287 | if self.local_rank == 0: 288 | # Calculate accuracy, recall, f1, precision, confuse_matrix 289 | accuracy, recall, f1, precision, Matrix = self.calculate_score(epoch_preds, epoch_labels, self.cfg.dataset.f1) 290 | self.writer.add_scalar('Test/WA', accuracy, self.current_epoch) 291 | self.writer.add_scalar('Test/UA', recall, self.current_epoch) 292 | self.writer.add_scalar('Test/F1', f1, self.current_epoch) 293 | self.writer.add_scalar('Test/Loss', epoch_loss, self.current_epoch) 294 | 295 | score = 0 296 | for metric in self.cfg.dataset.evaluate: 297 | score += eval(metric) 298 | 299 | if self.cfg.train.save_best or self.cfg.dataset.have_test_set: 300 | if score > self.best_score: 301 | self.best_score = score 302 | self.model_save(True) 303 | 304 | self.logger_test.info( 305 | f'Testing epoch: {self.current_epoch}, accuracy: {accuracy:.5f}, precision: {precision:.5f}, recall: {recall:.5f}, F1: {f1:.5f}, loss: {epoch_loss:.5f}, confuse_matrix: \n{Matrix}' 306 | ) 
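            )
            # Note on the selection score computed above: cfg.dataset.evaluate holds metric
            # names (e.g. 'accuracy' or 'recall'), and eval(metric) resolves each name against
            # the local variable of the same name, so `score` is simply their sum. The same
            # score decides when model_best.pt is snapshotted and feeds the early-stopping
            # counter invoked below.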
307 | 308 | self.early_stopping(score, self.model) 309 | 310 | def evaluate(self): 311 | self.dataloader_test = self.dataloader_feactory.build( 312 | state='test', 313 | bs=self.cfg.train.batch_size, 314 | fold=self.fold 315 | ) 316 | 317 | ckpt = torch.load(self.ckpt_best_file, map_location=self.device) 318 | self.model.module.load_state_dict(ckpt['model']) 319 | if self.local_rank == 0: 320 | print(f'Loading model from {self.ckpt_best_file}') 321 | del ckpt 322 | 323 | self.current_epoch = -1 324 | self.test() 325 | 326 | def model_save(self, is_best=False, filename='checkpoint.pt'): 327 | self.ckpt_save_file = os.path.join(self.cfg.ckpt_save_path, filename) 328 | save_dict = { 329 | 'cfg': self.cfg, 330 | 'epoch': self.current_epoch, 331 | 'iteration': self.iteration, 332 | 'best_score': self.best_score, 333 | 'model': self.model.module.state_dict(), # save DDP model 334 | 'optimizer': self.optimizer.state_dict(), 335 | 'scheduler': self.scheduler.state_dict() 336 | } 337 | torch.save(save_dict, self.ckpt_save_file) 338 | if is_best: 339 | self.ckpt_best_file = os.path.join(self.cfg.ckpt_save_path, 'model_best.pt') 340 | shutil.copyfile(self.ckpt_save_file, self.ckpt_best_file) 341 | 342 | def run(self, fold=1): 343 | self.fold = fold 344 | self.prepare_staff() 345 | 346 | while self.current_epoch < self.EPOCH: 347 | self.train_epoch() 348 | self.scheduler.step() 349 | self.test() 350 | 351 | self.current_epoch += 1 352 | 353 | if self.early_stopping.early_stop: 354 | print(f"Early stopping (patience: {self.early_stopping.patience})") 355 | break 356 | 357 | if self.cfg.dataset.have_test_set: 358 | self.evaluate() 359 | 360 | self.cleanup() 361 | 362 | def cleanup(self): 363 | if self.logger_train is not None: 364 | utils.logger.close_logger(self.logger_train) 365 | if self.logger_test is not None: 366 | utils.logger.close_logger(self.logger_test) 367 | if self.writer is not None: 368 | self.writer.close() 369 | # torch.cuda.empty_cache() 370 | self.early_stopping.clean() 371 | self.current_epoch = 0 372 | self.iteration = 0 373 | self.best_score = 0 374 | 375 | if not self.cfg.train.save_best: 376 | if hasattr(self, 'ckpt_save_file') and os.path.exists(self.ckpt_save_file): 377 | os.remove(self.ckpt_save_file) 378 | if hasattr(self, 'ckpt_best_file') and os.path.exists(self.ckpt_best_file): 379 | os.remove(self.ckpt_best_file) 380 | 381 | def main_worker(local_rank, cfg, world_size, dist_url): 382 | utils.environment.set_seed(cfg.train.seed + local_rank) 383 | torch.cuda.set_device(local_rank) 384 | dist.init_process_group( 385 | backend='nccl', 386 | init_method=dist_url, 387 | world_size=world_size, 388 | rank=local_rank, 389 | ) 390 | 391 | if cfg.model.init_with_ckpt: 392 | mark = re.search('(?<=_mark_)\w+', cfg.model.path_to_vesper) 393 | if mark is not None: 394 | if cfg.mark is None: 395 | cfg.mark = mark.group() 396 | else: 397 | cfg.mark = mark.group() + '_' + cfg.mark 398 | 399 | # torch.autograd.set_detect_anomaly(True) 400 | engine = Engine(cfg, local_rank, world_size) 401 | for fold in cfg.dataset.folds: 402 | create_workshop(cfg, local_rank, fold) 403 | engine.run(fold) 404 | 405 | if local_rank == 0: 406 | criterion = ['accuracy', 'precision', 'recall', 'F1'] 407 | evaluate = cfg.dataset.evaluate 408 | outfile = f'result/result_{cfg.model.type}_Finetune.csv' 409 | wantlow = False 410 | return_epoch = -1 if cfg.dataset.have_test_set else None 411 | utils.collect_result.path_to_csv(os.path.dirname(cfg.workshop), criterion, evaluate, csvfile=outfile, logname='test.log', 
wantlow=wantlow, epoch=return_epoch) 412 | 413 | def main(cfg): 414 | utils.environment.visible_gpus(cfg.train.device_id) 415 | 416 | free_port = utils.distributed.find_free_port() 417 | dist_url = f'tcp://127.0.0.1:{free_port}' 418 | world_size = torch.cuda.device_count() # num_gpus 419 | print(f'world_size={world_size} Using dist_url={dist_url}') 420 | 421 | mp.spawn(fn=main_worker, args=(cfg, world_size, dist_url), nprocs=world_size) 422 | 423 | if __name__=='__main__': 424 | cfg = get_config(mode='_finetune') 425 | main(cfg) 426 | -------------------------------------------------------------------------------- /load_test_model.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | from models.vesper import Vesper, init_with_ckpt 5 | from configs import model_cfg 6 | 7 | def load_model(ckpt_path, device='cuda'): 8 | model_cfg['Vesper']['freeze_cnn'] = True 9 | model_cfg['Vesper']['device'] = device 10 | model = Vesper(model_cfg['Vesper']).to(device) 11 | init_with_ckpt(model, ckpt_path, 'vesper', device=device) 12 | 13 | return model 14 | 15 | def extract_hiddens(model, waveform, padding_mask=None): 16 | with torch.no_grad(): 17 | fea, layer_results = model( 18 | waveform=waveform, padding_mask=padding_mask, ret_layer_results=True, student_pretraining=False) 19 | layer_results = [layer_results[i+1][0] for i in range(len(layer_results)-1)] 20 | return layer_results 21 | 22 | if __name__ == '__main__': 23 | device = 'cuda' # 'cuda' or 'cpu' 24 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 25 | ckpt_path = 'exp/Vesper/lssed_e100_b32_lr0.0005_None_mark_L12_wloss/fold_1/checkpoint/model_best.pt' 26 | 27 | model = load_model(ckpt_path, device) 28 | waveform = torch.randn(1, 16000).to(device) 29 | padding_mask = torch.zeros(1, 16000).eq(1).to(device) 30 | 31 | hiddens = extract_hiddens(model, waveform, padding_mask) 32 | 33 | print(hiddens[0].shape) # shape of each layer's output features: (B, T, C) 34 | print(len(hiddens)) # 12, there are 12 layers in total 35 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils 2 | from .
import vesper -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | from modules import TransformerEncoderLayer, Classifier, make_conv_pos 4 | 5 | class Transformer(nn.Module): 6 | def build_encoder(self, args): 7 | layer = TransformerEncoderLayer( 8 | embed_dim=args.embed_dim, 9 | ffn_embed_dim=args.ffn_embed_dim, 10 | num_heads=args.num_heads, 11 | activation=args.activation, 12 | dropout=args.dropout, 13 | bias=args.bias, 14 | normalize_before=args.normalize_before 15 | ) 16 | return layer 17 | 18 | def __init__(self, args): 19 | super().__init__() 20 | self.normalize_before = args.normalize_before 21 | 22 | self.pos_conv = make_conv_pos( 23 | args.embed_dim, 24 | args.conv_pos, 25 | args.conv_pos_groups, 26 | ) 27 | 28 | self.layer_norm = nn.LayerNorm(args.embed_dim) 29 | self.layers = nn.ModuleList( 30 | [self.build_encoder(args) for _ in range(args.num_encoder_layers)] 31 | ) 32 | self.avgpool = nn.AdaptiveAvgPool1d(1) 33 | self.classifier = Classifier(args.embed_dim, args.num_classes, args.dropout, args.activation) 34 | 35 | def extract_features(self, x, key_padding_mask): 36 | x_conv = self.pos_conv(x.transpose(1, 2)) 37 | x_conv = x_conv.transpose(1, 2) 38 | x = x + x_conv 39 | 40 | if not self.normalize_before: 41 | x = self.layer_norm(x) 42 | 43 | layer_results = [] 44 | for i, layer in enumerate(self.layers): 45 | x, attn, _ = layer(x, key_padding_mask=key_padding_mask, need_weights=False) 46 | layer_results.append((x, attn)) 47 | 48 | return x, layer_results 49 | 50 | def forward(self, x, pad_mask=None): 51 | x, layer_results = self.extract_features(x, pad_mask) 52 | 53 | if self.normalize_before: 54 | x = self.layer_norm(x) 55 | 56 | x = self.avgpool(x.transpose(-1, -2)).squeeze(dim=-1) 57 | pred = self.classifier(x) 58 | 59 | return pred 60 | 61 | -------------------------------------------------------------------------------- /models/vesper.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import Tensor, BoolTensor, FloatTensor 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import math 8 | import re 9 | from typing import List, Optional 10 | from einops.layers.torch import Rearrange 11 | 12 | from modules import TransformerEncoderLayer, _get_activation_fn, make_conv_pos, MultiheadAttention 13 | 14 | def init_bert_params(module): 15 | """ 16 | Initialize the weights specific to the BERT Model. 17 | This overrides the default initializations depending on the specified arguments. 18 | 1. If normal_init_linear_weights is set then weights of linear 19 | layer will be initialized using the normal distribution and 20 | bais will be set to the specified value. 21 | 2. If normal_init_embed_weights is set then weights of embedding 22 | layer will be initialized using the normal distribution. 23 | 3. If normal_init_proj_weights is set then weights of 24 | in_project_weight for MultiHeadAttention initialized using 25 | the normal distribution (to be validated). 
26 | """ 27 | 28 | def normal_(data): 29 | # with FSDP, module params will be on CUDA, so we cast them back to CPU 30 | # so that the RNG is consistent with and without FSDP 31 | data.copy_( 32 | data.cpu().normal_(mean=0.0, std=0.02).to(data.device) 33 | ) 34 | 35 | if isinstance(module, nn.Linear): 36 | normal_(module.weight.data) 37 | if module.bias is not None: 38 | module.bias.data.zero_() 39 | if isinstance(module, nn.Embedding): 40 | normal_(module.weight.data) 41 | if module.padding_idx is not None: 42 | module.weight.data[module.padding_idx].zero_() 43 | if isinstance(module, MultiheadAttention): 44 | normal_(module.q_proj.weight.data) 45 | normal_(module.k_proj.weight.data) 46 | normal_(module.v_proj.weight.data) 47 | 48 | def init_with_wavlm(model: nn.Module, num_layers: int=24, ckpt: str='PATH/TO/WavLM_CHECKPOINT', need_mask_emb: bool=True, style: list=['random'], info: str=''): 49 | assert ckpt is not None 50 | data = torch.load(ckpt) 51 | state_dict = data['model'] 52 | num_wavlm_layers = data['cfg']['encoder_layers'] 53 | 54 | pop_dict = {} 55 | for key in state_dict.keys(): 56 | if key.startswith('encoder.layers.') and not 'relative_attention_bias' in key: 57 | pop_dict[key] = state_dict[key] 58 | 59 | for key in pop_dict.keys(): 60 | state_dict.pop(key) 61 | encoder_layers_modules = set([re.search(r'(?<=\d\.).*', key).group(0) for key in pop_dict.keys()]) 62 | 63 | if isinstance(style, str): 64 | style = [style] 65 | if style[0] == 'uniform_average': 66 | assert num_wavlm_layers % num_layers == 0 67 | merge_rate = num_wavlm_layers // num_layers 68 | 69 | for module in encoder_layers_modules: 70 | for i in range(num_layers): 71 | state_dict[f'encoder.layers.{i}.{module}'] = ( 72 | torch.mean( 73 | torch.stack( 74 | [ 75 | pop_dict[f'encoder.layers.{i*merge_rate+j}.{module}'] for j in range(merge_rate) 76 | ], dim=0), dim=0) 77 | ) 78 | elif style[0] == 'custom_average': 79 | custom = style[1] 80 | assert len(custom) == num_layers 81 | 82 | for module in encoder_layers_modules: 83 | for i in range(num_layers): 84 | state_dict[f'encoder.layers.{i}.{module}'] = ( 85 | torch.mean( 86 | torch.stack( 87 | [ 88 | pop_dict[f'encoder.layers.{j}.{module}'] for j in range(custom[i][0], custom[i][1]+1) 89 | ], dim=0), dim=0) 90 | ) 91 | elif style[0] == 'uniform_extract': 92 | interval = num_wavlm_layers // num_layers 93 | 94 | for module in encoder_layers_modules: 95 | for i in range(num_layers): 96 | state_dict[f'encoder.layers.{i}.{module}'] = pop_dict[f'encoder.layers.{i*interval}.{module}'] 97 | elif style[0] == 'custom_extract': 98 | custom = style[1] 99 | assert len(custom) == num_layers 100 | 101 | for module in encoder_layers_modules: 102 | for i in range(num_layers): 103 | state_dict[f'encoder.layers.{i}.{module}'] = pop_dict[f'encoder.layers.{custom[i]}.{module}'] 104 | elif style[0] == 'identity_mapping': 105 | for module in encoder_layers_modules: 106 | for i in range(num_layers): 107 | state_dict[f'encoder.layers.{i}.{module}'] = pop_dict[f'encoder.layers.{i}.{module}'] 108 | elif style[0] == 'random': 109 | state_dict = model.state_dict() 110 | else: 111 | raise NotImplementedError 112 | 113 | if not need_mask_emb: 114 | state_dict.pop('mask_emb') 115 | model.mask_emb = None 116 | 117 | # we remove the layer_normalization in the output of encoder 118 | state_dict.pop('encoder.layer_norm.weight') 119 | state_dict.pop('encoder.layer_norm.bias') 120 | 121 | print(f'vesper/{info}: Initialize with WavLM (style: {style}).') 122 | model.load_state_dict(state_dict) 123 | 124 | 
del state_dict 125 | del pop_dict 126 | 127 | def init_with_ckpt(model: nn.Module, ckpt: str='PATH/TO/CHECKPOINT', name: str='vesper', need_mask_emb: bool=True, info: str='', device: str='cuda'): 128 | assert ckpt is not None 129 | 130 | if ckpt == '': 131 | print(f'{name}/{info}: No checkpoint found.') 132 | return 133 | 134 | if not need_mask_emb and hasattr(model, 'mask_emb'): 135 | model.mask_emb = None 136 | state_dict = torch.load(ckpt, map_location=device)['model'] 137 | 138 | dit = {} 139 | for k, v in state_dict.items(): 140 | if k.startswith(name): 141 | dit[k[len(name)+1:]] = v 142 | 143 | if not need_mask_emb and 'mask_emb' in dit.keys(): 144 | dit.pop('mask_emb') 145 | 146 | # we remove the layer_normalization in the output of encoder 147 | dit.pop('encoder.layer_norm.weight', None) 148 | dit.pop('encoder.layer_norm.bias', None) 149 | 150 | if dit is None: 151 | print(f'{name}/{info}: No matching keys found in checkpoint: {ckpt}') 152 | else: 153 | print(f'{name}/{info}: Initialize with checkpoint: {ckpt}') 154 | model.load_state_dict(dit) 155 | 156 | del state_dict 157 | del dit 158 | 159 | def apply_mask(x: Tensor, mask: BoolTensor, fill_value: Tensor, clone: bool=False): 160 | _x = x.clone() if clone else x 161 | _x[mask] = fill_value 162 | return _x 163 | 164 | @torch.no_grad() 165 | def get_rms(x: Tensor, frame_length: int = 2048, hop_length: int = 512): 166 | ''' 167 | Inputs: 168 | x: (B, T), ``Tensor``, T dedotes the length of the time series. 169 | Outputs: 170 | rms: (B, Tf), ``Tensor``, Tf denotes the number of frames. 171 | ''' 172 | if isinstance(x, np.ndarray): 173 | x = torch.from_numpy(x) 174 | if x.dim() == 1: 175 | x = x.unsqueeze(dim=0) 176 | 177 | n_frames = 1 + (x.shape[-1] - frame_length) // hop_length 178 | strides = torch.tensor(x.stride()) 179 | 180 | shape = list(x.shape)[:-1] + [frame_length, n_frames] 181 | strides = list(strides) + [hop_length] # * new_stride 182 | 183 | frame = torch.as_strided(x, size=shape, stride=strides) 184 | rms = torch.sqrt(torch.mean(torch.abs(frame)**2, dim=1, keepdim=False)) 185 | 186 | return rms 187 | 188 | @torch.no_grad() 189 | def space_indices(indices: Tensor, space: int=1, maximum: int=1, already_sorted: bool=True): 190 | if not already_sorted: 191 | indices, _ = torch.sort(indices, descending=False) 192 | for i in range(0, len(indices)-1): 193 | if indices[i+1] - indices[i] < space: 194 | indices[i+1] = indices[i] + space 195 | if indices[i+1] > maximum: 196 | indices = indices[:i+1] 197 | break 198 | return indices 199 | 200 | @torch.no_grad() 201 | def get_random_mask( 202 | fea: Tensor, 203 | span: int=8, 204 | max_num_span: int=10, 205 | span_space: int=1, 206 | real_length: Tensor=None, 207 | max_mask_percentage: float=0.5 208 | ): 209 | mask = torch.full(fea.shape[:2], False, dtype=torch.bool, device=fea.device) 210 | 211 | if real_length is not None: 212 | num_span_per_sample = (real_length * max_mask_percentage / span).tolist() 213 | num_span_per_sample = [math.floor(s) if s < max_num_span else max_num_span for s in num_span_per_sample] 214 | valid_length = (real_length - span).tolist() 215 | else: 216 | valid_length = [fea.shape[1] - span] * fea.shape[0] 217 | num_span_per_sample = [max_num_span] * fea.shape[0] 218 | 219 | span_start = [] 220 | for i, (valid) in enumerate(valid_length): 221 | num_span = num_span_per_sample[i] 222 | indices = torch.randperm(valid)[:num_span] 223 | 224 | indices = space_indices(indices, space=span+span_space, maximum=valid, already_sorted=False) 225 | 226 | if len(indices) 
< num_span: 227 | indices = torch.cat((indices, torch.randperm(valid, device=indices.device)))[:num_span] 228 | 229 | if (not num_span) or (not len(indices)): 230 | indices = torch.randperm(valid)[0].unsqueeze(dim=0) 231 | span_start.append(indices) 232 | mask[i][indices:real_length[i]] = True 233 | else: 234 | span_start.append(indices) 235 | 236 | indices = torch.as_tensor( 237 | [ 238 | indices[j] + offset 239 | for j in range(num_span) 240 | for offset in range(span) 241 | ] 242 | ) 243 | 244 | mask[i][indices] = True 245 | 246 | return mask, span_start 247 | 248 | @torch.no_grad() 249 | def get_rms_mask( 250 | rms: Tensor, 251 | h_up: float=1.0, 252 | h_down: float=0.5, 253 | l_up: float=0.49, 254 | l_down: float=0.2, 255 | span: int=8, 256 | max_num_span: int=10, 257 | span_space: int=1, 258 | real_length: Tensor=None, 259 | max_mask_percentage: float=0.5 260 | ): 261 | mask = torch.full(rms.shape, False, dtype=torch.bool, device=rms.device) 262 | 263 | if real_length is not None: 264 | num_span_per_sample = (real_length * max_mask_percentage / span).tolist() 265 | num_span_per_sample = [math.floor(s) if s < max_num_span else max_num_span for s in num_span_per_sample] 266 | valid_length = (real_length - span).tolist() 267 | else: 268 | valid_length = [rms.shape[-1] - span] * rms.shape[0] 269 | num_span_per_sample = [max_num_span] * rms.shape[0] 270 | 271 | span_start = [] 272 | for i, (row, valid) in enumerate(zip(rms, valid_length)): 273 | row = row[:valid] 274 | max_val = torch.max(row) 275 | h_down = h_down * max_val 276 | h_up = h_up * max_val 277 | l_down = l_down * max_val 278 | l_up = l_up * max_val 279 | h_mask = torch.logical_and(row >= h_down, row <= h_up) 280 | l_mask = torch.logical_and(row >= l_down, row <= l_up) 281 | h_indices = torch.nonzero(h_mask, as_tuple=False).squeeze(dim=1) 282 | l_indices = torch.nonzero(l_mask, as_tuple=False).squeeze(dim=1) 283 | 284 | num_span = num_span_per_sample[i] 285 | h_indices = h_indices[torch.randperm(len(h_indices))][:num_span//2] 286 | l_indices = l_indices[torch.randperm(len(l_indices))][:num_span-len(h_indices)] 287 | 288 | h_indices = space_indices(h_indices, space=span+span_space, maximum=valid) 289 | l_indices = space_indices(l_indices, space=span+span_space, maximum=valid) 290 | 291 | if len(h_indices) + len(l_indices) < num_span: 292 | indices = torch.cat((h_indices, l_indices, torch.randperm(valid, device=h_indices.device)))[:num_span] 293 | else: 294 | indices =torch.cat((h_indices, l_indices)) 295 | 296 | if (not num_span) or (not len(indices)): 297 | indices = torch.randperm(valid)[0].unsqueeze(dim=0) 298 | span_start.append(indices) 299 | mask[i][indices:real_length[i]] = True 300 | else: 301 | span_start.append(indices) 302 | 303 | indices = torch.as_tensor( 304 | [ 305 | indices[j] + offset 306 | for j in range(num_span) 307 | for offset in range(span) 308 | ] 309 | ) 310 | 311 | mask[i][indices] = True 312 | 313 | return mask, span_start 314 | 315 | @torch.no_grad() 316 | def expand_mask( 317 | mask: Tensor, 318 | expanded_span: int=40, 319 | span_start: Tensor=None, 320 | max_num_expanded_span: int=2, 321 | span_space: int=1, 322 | real_length: Tensor=None, 323 | max_mask_percentage: float=0.5 324 | ): 325 | mask = torch.full_like(mask, False) 326 | 327 | if real_length is not None: 328 | num_span_per_sample = (real_length * max_mask_percentage / expanded_span).tolist() 329 | num_span_per_sample = [math.floor(s) if s < max_num_expanded_span else max_num_expanded_span for s in num_span_per_sample] 330 | valid_length 
= (real_length - expanded_span).tolist() 331 | else: 332 | valid_length = [mask.shape[-1] - expanded_span] * mask.shape[0] 333 | num_span_per_sample = [max_num_expanded_span] * mask.shape[0] 334 | 335 | expanded_span_start = [] 336 | for i, (indices, valid) in enumerate(zip(span_start, valid_length)): 337 | indices = indices[indices < valid] 338 | num_expanded_span = num_span_per_sample[i] 339 | 340 | indices = space_indices(indices, space=expanded_span+span_space, maximum=valid, already_sorted=False) 341 | 342 | if len(indices) < num_expanded_span: 343 | indices = torch.cat((indices, torch.randperm(valid, device=indices.device)))[:num_expanded_span] 344 | else: 345 | indices = indices[torch.randperm(len(indices))][:num_expanded_span] 346 | 347 | if (not num_expanded_span) or (not len(indices)): 348 | indices = span_start[i][0].unsqueeze(dim=0) 349 | expanded_span_start.append(indices) 350 | mask[i][indices:real_length[i]] = True 351 | else: 352 | expanded_span_start.append(indices) 353 | 354 | indices = torch.as_tensor( 355 | [ 356 | indices[j] + offset 357 | for j in range(num_expanded_span) 358 | for offset in range(expanded_span) 359 | ] 360 | ) 361 | 362 | mask[i][indices] = True 363 | 364 | return mask, expanded_span_start 365 | 366 | def normalize(x: Tensor, p: int=2, dim: int=-1): 367 | return F.normalize(x, p, dim) 368 | 369 | def masked_select(x: Tensor, mask: BoolTensor): 370 | ''' 371 | Inputs: 372 | x: (B, T, C), ``Tensor`` 373 | mask: (B, T), ```BoolTensor` 374 | Output: 375 | x: (-1, C), `` Tensor`` 376 | ''' 377 | return x.masked_select(mask.unsqueeze(dim=-1)).view(-1, x.size(-1)) 378 | 379 | class ConvFeatureExtractionModel(nn.Module): 380 | def __init__( 381 | self, 382 | conv_layers: list = [(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2, 383 | dropout: float = 0.0, 384 | conv_bias: bool = False, 385 | mode: str = "default" 386 | ): 387 | super().__init__() 388 | 389 | def block( 390 | n_in, 391 | n_out, 392 | k, 393 | stride, 394 | conv_bias=False, 395 | is_layer_norm=False, 396 | is_group_norm=False 397 | ): 398 | def make_conv(): 399 | conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) 400 | nn.init.kaiming_normal_(conv.weight) 401 | return conv 402 | 403 | if is_layer_norm: 404 | return nn.Sequential( 405 | make_conv(), 406 | nn.Dropout(p=dropout), 407 | nn.Sequential( 408 | Rearrange("b c t -> b t c"), 409 | nn.LayerNorm(dim, elementwise_affine=True), 410 | Rearrange("b c t -> b t c"), 411 | ), 412 | nn.GELU(), 413 | ) 414 | elif is_group_norm: 415 | return nn.Sequential( 416 | make_conv(), 417 | nn.Dropout(p=dropout), 418 | nn.GroupNorm(dim, dim, affine=True), 419 | nn.GELU(), 420 | ) 421 | else: 422 | return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) 423 | 424 | in_d = 1 425 | self.conv_layers = nn.ModuleList() 426 | for i, cl in enumerate(conv_layers): 427 | assert len(cl) == 3, "invalid conv definition: " + str(cl) 428 | (dim, k, stride) = cl 429 | 430 | self.conv_layers.append( 431 | block( 432 | in_d, 433 | dim, 434 | k, 435 | stride, 436 | conv_bias=conv_bias, 437 | is_layer_norm=mode == "layer_norm", 438 | is_group_norm=mode == "default" and i == 0, 439 | ) 440 | ) 441 | in_d = dim 442 | 443 | def forward(self, x): 444 | # BxT -> BxCxT 445 | x = x.unsqueeze(1) 446 | for conv in self.conv_layers: 447 | x = conv(x) 448 | return x 449 | 450 | class TransformerEncoder(nn.Module): 451 | def __init__(self, args): 452 | super().__init__() 453 | self.frame_length = args.frame_length 454 | self.hop_length = args.hop_length 455 | 
self.h_up = args.h_up 456 | self.h_down = args.h_down 457 | self.l_up = args.l_up 458 | self.l_down = args.l_down 459 | self.small_span = args.small_span 460 | self.num_small_span = args.num_small_span 461 | self.large_span = args.large_span 462 | self.num_large_span = args.num_large_span 463 | self.span_space = args.span_space 464 | self.max_mask_percentage = args.max_mask_percentage 465 | self.encoder_layers = args.encoder_layers 466 | self.dropout = args.dropout 467 | self.pos_conv = make_conv_pos(args.encoder_embed_dim, args.conv_pos, args.conv_pos_groups) 468 | self.mask_depend_on_rms = args.mask_depend_on_rms 469 | 470 | self.relative_position_embedding = args.relative_position_embedding 471 | self.num_buckets = args.num_buckets 472 | self.max_distance = args.max_distance 473 | 474 | self.layers = nn.ModuleList( 475 | [ 476 | TransformerEncoderLayer( 477 | embed_dim=args.encoder_embed_dim, 478 | ffn_embed_dim=args.ffn_embed_dim, 479 | num_heads=args.num_heads, 480 | activation=args.activation, 481 | dropout=args.dropout, 482 | bias=args.bias, 483 | normalize_before=True, 484 | has_relative_attention_bias=(self.relative_position_embedding and i == 0), 485 | num_buckets=self.num_buckets, 486 | max_distance=self.max_distance, 487 | gru_rel_pos=args.gru_rel_pos, 488 | qk_norm=args.qk_norm 489 | ) 490 | for i in range(args.encoder_layers) 491 | ] 492 | ) 493 | 494 | # self.layer_norm = nn.LayerNorm(args.encoder_embed_dim) 495 | 496 | self.apply(init_bert_params) 497 | 498 | def forward(self, x: Tensor, padding_mask=None, layer=None, student_pretraining=False, waveform=None, mask_emb=None): 499 | if student_pretraining: 500 | if padding_mask is not None: 501 | real_length = torch.sum(~padding_mask, dim=-1, dtype=torch.int) 502 | else: 503 | real_length = torch.full((x.size(0),), fill_value=x.size(1), device=x.device, dtype=torch.int) 504 | 505 | if self.mask_depend_on_rms: 506 | rms = get_rms(waveform, frame_length=self.frame_length, hop_length=self.hop_length) 507 | small_span_mask, span_start = get_rms_mask( 508 | rms, 509 | self.h_up, 510 | self.h_down, 511 | self.l_up, 512 | self.l_down, 513 | self.small_span, 514 | self.num_small_span, 515 | self.span_space, 516 | real_length, 517 | self.max_mask_percentage 518 | ) 519 | else: 520 | small_span_mask, span_start = get_random_mask( 521 | x, 522 | self.small_span, 523 | self.num_small_span, 524 | self.span_space, 525 | real_length, 526 | self.max_mask_percentage 527 | ) 528 | large_span_mask, expanded_span_start = expand_mask( 529 | small_span_mask, 530 | self.large_span, 531 | span_start, 532 | self.num_large_span, 533 | self.span_space, 534 | real_length, 535 | self.max_mask_percentage 536 | ) 537 | interlayer = self.encoder_layers//2 538 | x, layer_results = self.extract_features( 539 | x, 540 | padding_mask, 541 | None, 542 | student_pretraining, 543 | interlayer, 544 | small_span_mask, 545 | large_span_mask, 546 | mask_emb 547 | ) 548 | else: 549 | x, layer_results = self.extract_features(x, padding_mask, layer) 550 | 551 | # if layer is None: 552 | # x = self.layer_norm(x) 553 | 554 | if student_pretraining: 555 | return x, layer_results, real_length, interlayer, small_span_mask, large_span_mask 556 | else: 557 | return x, layer_results 558 | 559 | def extract_features( 560 | self, 561 | x, 562 | padding_mask=None, 563 | tgt_layer=None, 564 | student_pretraining=False, 565 | interlayer=0, 566 | small_span_mask=None, 567 | large_span_mask=None, 568 | mask_emb=None 569 | ): 570 | if padding_mask is not None: 571 | x[padding_mask] = 0 
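        # What follows: the convolutional relative positional embedding (self.pos_conv)
        # is added to the features and dropout is applied. The un-encoded representation
        # is stored as the first entry of layer_results. During student pretraining the
        # small-span mask is applied before the first Transformer layer, and the
        # large-span mask is applied again at the intermediate layer (interlayer).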
572 | 573 | x_conv = self.pos_conv(x.transpose(1, 2)) 574 | x_conv = x_conv.transpose(1, 2) 575 | x = x + x_conv 576 | 577 | x = F.dropout(x, p=self.dropout, training=self.training) 578 | 579 | layer_results = [] 580 | attn_weights = None 581 | layer_results.append((x, attn_weights)) 582 | pos_bias = None 583 | 584 | if student_pretraining: 585 | x = apply_mask(x, small_span_mask, mask_emb, clone=True) 586 | for i, layer in enumerate(self.layers): 587 | if i == interlayer: 588 | x = apply_mask(x, large_span_mask, mask_emb, clone=True) 589 | x, attn_weights, pos_bias = layer(x, key_padding_mask=padding_mask, need_weights=True, pos_bias=pos_bias) 590 | layer_results.append((x, attn_weights)) 591 | else: 592 | for i, layer in enumerate(self.layers): 593 | x, attn_weights, pos_bias = layer(x, key_padding_mask=padding_mask, need_weights=True, pos_bias=pos_bias) 594 | layer_results.append((x, attn_weights)) 595 | if i == tgt_layer: 596 | break 597 | 598 | return x, layer_results 599 | 600 | class PredictionHead(nn.Module): 601 | '''A simple feed-forward network. 602 | 603 | Inputs: 604 | x: (B, T, input_dim), ``Tensor`` 605 | Outputs: 606 | x: (B, T, output_dim), ``Tensor`` 607 | ''' 608 | def __init__(self, input_dim: int, output_dim: int, activation: str, norm_input: bool=True): 609 | super().__init__() 610 | self.norm_input = norm_input 611 | self.simple_ffn = nn.Sequential( 612 | nn.Linear(input_dim, input_dim//2), 613 | _get_activation_fn(activation, module=True), 614 | nn.Linear(input_dim//2, output_dim) 615 | ) 616 | 617 | def forward(self, x: Tensor): 618 | if self.norm_input: 619 | x = F.layer_norm(x, [x.shape[-1]]) 620 | return self.simple_ffn(x) 621 | 622 | class Vesper(nn.Module): 623 | def __init__(self, args): 624 | super().__init__() 625 | feature_enc_layers = eval(args.conv_feature_layers) 626 | conv_embed = feature_enc_layers[-1][0] 627 | 628 | self.feature_extractor = ConvFeatureExtractionModel(feature_enc_layers, mode=args.extractor_mode) 629 | self.layer_norm = nn.LayerNorm(conv_embed) 630 | self.post_extract_proj = nn.Linear(conv_embed, args.encoder_embed_dim) 631 | self.dropout_input = nn.Dropout(args.dropout_input) 632 | 633 | self.encoder = TransformerEncoder(args) 634 | 635 | self.mask_emb = nn.Parameter(FloatTensor(args.encoder_embed_dim).uniform_(), requires_grad=True) 636 | self.padding_mask = None 637 | self.normalize = args.normalize 638 | self.freeze_cnn = args.freeze_cnn 639 | 640 | def forward_padding_mask(self, features: Tensor, padding_mask: Tensor) -> Tensor: 641 | extra = padding_mask.size(1) % features.size(1) 642 | if extra > 0: 643 | padding_mask = padding_mask[:, :-extra] 644 | padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) 645 | self.padding_mask = padding_mask.all(-1) 646 | 647 | def get_padding_mask(self): 648 | return self.padding_mask 649 | 650 | def forward( 651 | self, 652 | waveform: Tensor, 653 | padding_mask: Optional[Tensor]=None, 654 | output_layer: Optional[int]=None, 655 | ret_layer_results: bool=False, 656 | student_pretraining=False 657 | ): 658 | ''' 659 | Inputs: 660 | waveform: (B, T_audio), ``Tensor`` 661 | padding_mask: (B, T_audio), ``BoolTensor``, key padding mask. 662 | output_layer: ``int``, varies between [1, 24]. 663 | ret_layer_results: ``bool``, default False. 
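            student_pretraining: ``bool``, default False. If True, span masks are applied inside the encoder and the real lengths, the intermediate layer index and both span masks are returned together with the features for computing the pretraining losses.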
664 | Outputs: 665 | features: (B, T, C), ``Tensor`` 666 | layers_rep: [feature_encoder_output, layer_1_output, layer_2_output, ..., layer_n_output], ``list`` 667 | ''' 668 | if self.normalize: 669 | waveform = F.layer_norm(waveform, [waveform.shape[-1]]) 670 | 671 | if self.freeze_cnn: 672 | with torch.no_grad(): 673 | features = self.feature_extractor(waveform) 674 | else: 675 | features = self.feature_extractor(waveform) 676 | 677 | features = features.transpose(1, 2) 678 | features = self.layer_norm(features) 679 | 680 | features = self.post_extract_proj(features) 681 | features = self.dropout_input(features) 682 | 683 | if padding_mask is not None: 684 | self.forward_padding_mask(features, padding_mask) 685 | else: 686 | self.padding_mask = None 687 | 688 | if student_pretraining: 689 | features, layer_results, real_length, interlayer, small_span_mask, large_span_mask = self.encoder( 690 | features, 691 | padding_mask=self.padding_mask, 692 | layer=None, 693 | student_pretraining=True, 694 | waveform=waveform, 695 | mask_emb=self.mask_emb 696 | ) 697 | return features, layer_results, real_length, interlayer, small_span_mask, large_span_mask 698 | else: 699 | features, layer_results = self.encoder( 700 | features, 701 | padding_mask=self.padding_mask, 702 | layer=None if output_layer is None else output_layer - 1 703 | ) 704 | 705 | if ret_layer_results: 706 | features = (features, layer_results) 707 | return features 708 | 709 | class VesperFinetuneWrapper(nn.Module): 710 | def __init__(self, args): 711 | super().__init__() 712 | 713 | self.vesper = Vesper(args) 714 | 715 | if args.init_with_ckpt: 716 | init_with_ckpt(self.vesper, args.path_to_vesper, 'vesper', need_mask_emb=False) 717 | elif args.init_with_wavlm: 718 | init_with_wavlm(self.vesper, args.encoder_layers, args.path_to_wavlm, need_mask_emb=False, style=args.init_style) 719 | else: 720 | print('No initialization method specified. 
Initializing with random weights.') 721 | 722 | self.projector = nn.Linear(args.encoder_embed_dim, args.projector_dim) 723 | self.classifier = nn.Linear(args.projector_dim, args.num_classes) 724 | 725 | self.freeze_upstream = args.freeze_upstream 726 | # self.normalize = args.normalize 727 | 728 | if args.output_rep == 'weighted_sum': 729 | self.weights = nn.Parameter(torch.zeros(args.encoder_layers+1)) 730 | print(f'Using weighted sum of {list(self.weights.shape)} representations as output representation.') 731 | elif args.output_rep == 'last_layer': 732 | self.weights = None 733 | print('Using last layer representation as output representation.') 734 | else: 735 | raise NotImplementedError(f'output_rep {args.output_rep} is not implemented.') 736 | 737 | def _weighted_sum(self, layer_results: list): 738 | stacked_feature = torch.stack(layer_results, dim=0) 739 | 740 | # if self.normalize: 741 | # stacked_feature = F.layer_norm(stacked_feature, (stacked_feature.shape[-1],)) 742 | 743 | _, *origin_shape = stacked_feature.shape 744 | stacked_feature = stacked_feature.view(len(layer_results), -1) 745 | norm_weights = F.softmax(self.weights, dim=-1) 746 | weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0) 747 | weighted_feature = weighted_feature.view(*origin_shape) 748 | 749 | return weighted_feature 750 | 751 | def forward( 752 | self, 753 | waveform: Tensor, 754 | padding_mask: Optional[Tensor]=None 755 | ): 756 | if self.freeze_upstream: 757 | with torch.no_grad(): 758 | fea, layer_results = self.vesper( 759 | waveform=waveform, padding_mask=padding_mask, ret_layer_results=True, student_pretraining=False) 760 | else: 761 | fea, layer_results = self.vesper( 762 | waveform=waveform, padding_mask=padding_mask, ret_layer_results=True, student_pretraining=False) 763 | 764 | if self.weights is not None: 765 | # layer_results = [layer_results[i+1][0] for i in range(len(layer_results)-1)] 766 | layer_results = [layer_results[i][0] for i in range(len(layer_results))] 767 | fea = self._weighted_sum(layer_results) 768 | 769 | padding_mask = self.vesper.get_padding_mask() 770 | if padding_mask is not None: 771 | real_length = torch.sum(~padding_mask, dim=-1, keepdim=True) 772 | fea[padding_mask] = 0.0 773 | else: 774 | real_length = torch.full((fea.size(0),1), fill_value=fea.size(1), dtype=fea.dtype, device=fea.device) 775 | 776 | fea = self.projector(fea) 777 | pooled_output = fea.sum(dim=1) / real_length 778 | pred = self.classifier(pooled_output) 779 | 780 | return pred 781 | 782 | class Vesper_PretrainWrapper(nn.Module): 783 | def __init__(self, args): 784 | super().__init__() 785 | 786 | self.vesper = Vesper(args) 787 | if args.init_with_ckpt: 788 | init_with_ckpt(self.vesper, args.path_to_vesper, 'vesper', info='student', device=args.device) 789 | elif args.init_with_wavlm: 790 | init_with_wavlm(self.vesper, args.encoder_layers, args.path_to_wavlm, style=args.init_style) 791 | else: 792 | print('No initialization method specified. 
Initializing with random weights.') 793 | 794 | self.l_predictor = PredictionHead( 795 | input_dim=args.encoder_embed_dim, 796 | output_dim=args.encoder_embed_dim, 797 | activation=args.activation, 798 | norm_input=True 799 | ) if args.enable_l_predictor else None 800 | 801 | self.h_predictor = PredictionHead( 802 | input_dim=args.encoder_embed_dim, 803 | output_dim=args.encoder_embed_dim, 804 | activation=args.activation, 805 | norm_input=True 806 | ) if args.enable_h_predictor else None 807 | 808 | self.x_predictor = PredictionHead( 809 | input_dim=args.encoder_embed_dim, 810 | output_dim=args.encoder_embed_dim, 811 | activation=args.activation, 812 | norm_input=True 813 | ) if args.enable_x_predictor else None 814 | 815 | if args.init_with_ckpt: 816 | if self.l_predictor is not None: 817 | init_with_ckpt(self.l_predictor, args.path_to_vesper, 'l_predictor', False) 818 | if self.h_predictor is not None: 819 | init_with_ckpt(self.h_predictor, args.path_to_vesper, 'h_predictor', False) 820 | if self.x_predictor is not None: 821 | init_with_ckpt(self.x_predictor, args.path_to_vesper, 'x_predictor', False) 822 | 823 | self.loss = nn.MSELoss() 824 | 825 | def cal_loss(self, pred: Tensor, target: Tensor, apply_norm: bool=True): 826 | if apply_norm: 827 | loss = self.loss(normalize(pred), normalize(target)) 828 | else: 829 | loss = self.loss(pred, target) 830 | return loss 831 | 832 | def forward( 833 | self, 834 | waveform: Tensor, 835 | padding_mask: Optional[Tensor]=None, 836 | l_target: Tensor=None, 837 | h_target: Tensor=None, 838 | ): 839 | fea, layer_results, _, interlayer, small_span_mask, large_span_mask = self.vesper( 840 | waveform=waveform, padding_mask=padding_mask, ret_layer_results=True, student_pretraining=True) 841 | 842 | if self.l_predictor is not None: 843 | s_fea_l = masked_select(layer_results[interlayer][0], mask=small_span_mask) 844 | s_fea_l_pred = self.l_predictor(s_fea_l) 845 | t_fea_l = masked_select(l_target, mask=small_span_mask) 846 | l_loss = self.cal_loss(s_fea_l_pred, t_fea_l, apply_norm=False) 847 | else: 848 | l_loss = torch.zeros(1).to(fea.device) 849 | 850 | if self.h_predictor is not None: 851 | s_fea_h = masked_select(layer_results[-1][0], mask=large_span_mask) 852 | s_fea_h_pred = self.h_predictor(s_fea_h) 853 | t_fea_h = masked_select(h_target, mask=large_span_mask) 854 | h_loss = self.cal_loss(s_fea_h_pred, t_fea_h, apply_norm=False) 855 | else: 856 | h_loss = torch.zeros(1).to(fea.device) 857 | 858 | if self.x_predictor is not None: 859 | s_fea_x = fea.view(-1, fea.size(-1)) 860 | s_fea_x_pred = self.x_predictor(s_fea_x) 861 | t_fea_x = l_target.view(-1, l_target.size(-1)) 862 | x_loss = self.cal_loss(s_fea_x_pred, t_fea_x, apply_norm=False) 863 | else: 864 | x_loss = torch.zeros(1).to(fea.device) 865 | 866 | return l_loss, h_loss, x_loss 867 | 868 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .multihead_attention import MultiheadAttention 2 | from .transformer_encoder import TransformerEncoderLayer 3 | from .classifier import Classifier 4 | from .positional_encoding import PositionalEncoding, RelPositionalEncoding, make_conv_pos 5 | from .activation import _get_activation_fn 6 | -------------------------------------------------------------------------------- /modules/activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 
3 | import torch.nn.functional as F 4 | 5 | def _get_activation_fn(activation: str='relu', module: bool=False): 6 | """ Returns the activation function corresponding to `activation` """ 7 | if activation == "relu": 8 | return nn.ReLU() if module else F.relu 9 | elif activation == "gelu": 10 | return nn.GELU() if module else F.gelu 11 | elif activation == "tanh": 12 | return nn.Tanh() if module else torch.tanh 13 | elif activation == "linear": 14 | return lambda x: x 15 | else: 16 | raise RuntimeError("--activation-fn {} not supported".format(activation)) 17 | 18 | -------------------------------------------------------------------------------- /modules/classifier.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | from modules.activation import _get_activation_fn 4 | 5 | class Classifier(nn.Module): 6 | def __init__(self, input_dim, num_classes, dropout, activation): 7 | super().__init__() 8 | 9 | self.net = nn.Sequential( 10 | nn.Linear(input_dim, input_dim//2), 11 | _get_activation_fn(activation, module=True), 12 | nn.Dropout(dropout), 13 | nn.Linear(input_dim//2, input_dim//4), 14 | _get_activation_fn(activation, module=True), 15 | nn.Dropout(dropout), 16 | nn.Linear(input_dim//4, num_classes), 17 | ) 18 | 19 | def forward(self, x): 20 | pred = self.net(x) 21 | return pred 22 | 23 | -------------------------------------------------------------------------------- /modules/multihead_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class MultiheadAttention(nn.Module): 7 | ''' 8 | Input dimension order is (batch_size, seq_len, input_dim). 9 | All the q, k, v inputs' feature dimensions are first projected to embed_dim, and then perform attention operation. 
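    When has_relative_attention_bias is True, a T5-style bucketed relative position
    bias is added to the attention logits; gru_rel_pos further gates this bias with a
    query-dependent scalar (as in WavLM), and qk_norm applies layer normalization to
    the projected queries and keys before the dot-product attention.

    Minimal usage sketch (illustrative only; embed_dim=768, num_heads=12 and the input
    shape are assumed values, not taken from the configs):

        attn = MultiheadAttention(embed_dim=768, num_heads=12)
        x = torch.randn(4, 100, 768)                  # (batch, seq_len, embed_dim)
        out, weights, bias = attn(query=x, key=x, value=x)
        # out: (4, 100, 768); weights and bias stay None unless need_weights and the
        # relative position bias are enabled.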
10 | ''' 11 | def __init__( 12 | self, 13 | embed_dim, 14 | num_heads, 15 | kdim=None, 16 | vdim=None, 17 | dropout=0., 18 | bias=True, 19 | has_relative_attention_bias=False, 20 | num_buckets=32, 21 | max_distance=128, 22 | gru_rel_pos=False, 23 | qk_norm=False 24 | ): 25 | super().__init__() 26 | self.kdim = kdim if kdim is not None else embed_dim 27 | self.vdim = vdim if vdim is not None else embed_dim 28 | 29 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 30 | self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias) 31 | self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias) 32 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 33 | 34 | self.has_relative_attention_bias = has_relative_attention_bias 35 | self.num_buckets = num_buckets 36 | self.max_distance = max_distance 37 | if self.has_relative_attention_bias: 38 | self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) 39 | 40 | assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" 41 | self.head_dim = embed_dim // num_heads 42 | self.num_heads = num_heads 43 | self.embed_dim = embed_dim 44 | self.dropout = dropout 45 | self.scaling = float(self.head_dim) ** -0.5 46 | 47 | self.q_head_dim = self.head_dim 48 | self.gru_rel_pos = gru_rel_pos 49 | if self.gru_rel_pos: 50 | self.grep_linear = nn.Linear(self.q_head_dim, 8) 51 | self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) 52 | 53 | self.qk_nrom = qk_norm 54 | self.reset_parameters() 55 | 56 | def reset_parameters(self): 57 | nn.init.xavier_uniform_(self.q_proj.weight) 58 | nn.init.xavier_uniform_(self.k_proj.weight) 59 | nn.init.xavier_uniform_(self.v_proj.weight) 60 | 61 | nn.init.xavier_uniform_(self.out_proj.weight) 62 | if self.out_proj.bias is not None: 63 | nn.init.constant_(self.out_proj.bias, 0.0) 64 | if self.has_relative_attention_bias: 65 | nn.init.xavier_normal_(self.relative_attention_bias.weight) 66 | 67 | def _relative_positions_bucket(self, relative_positions, bidirectional=True): 68 | num_buckets = self.num_buckets 69 | max_distance = self.max_distance 70 | relative_buckets = 0 71 | 72 | if bidirectional: 73 | num_buckets = num_buckets // 2 74 | relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets 75 | relative_positions = torch.abs(relative_positions) 76 | else: 77 | relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) 78 | 79 | max_exact = num_buckets // 2 80 | is_small = relative_positions < max_exact 81 | 82 | relative_postion_if_large = max_exact + ( 83 | torch.log(relative_positions.float() / max_exact) 84 | / math.log(max_distance / max_exact) 85 | * (num_buckets - max_exact) 86 | ).to(torch.long) 87 | relative_postion_if_large = torch.min( 88 | relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) 89 | ) 90 | 91 | relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) 92 | return relative_buckets 93 | 94 | def compute_bias(self, query_length, key_length): 95 | context_position = torch.arange(query_length, dtype=torch.long)[:, None] 96 | memory_position = torch.arange(key_length, dtype=torch.long)[None, :] 97 | relative_position = memory_position - context_position 98 | relative_position_bucket = self._relative_positions_bucket( 99 | relative_position, 100 | bidirectional=True 101 | ) 102 | relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) 103 | values = self.relative_attention_bias(relative_position_bucket) 104 | 
values = values.permute([2, 0, 1]) 105 | return values 106 | 107 | def forward(self, query, key=None, value=None, key_padding_mask=None, attn_mask=None, position_bias=None, need_weights=False): 108 | ''' 109 | Args: 110 | key_padding_mask: if provided, specified padding elements in the key will 111 | be ignored by the attention. This is an binary mask. When the value is True, 112 | the corresponding value on the attention layer will be filled with -inf. 113 | attn_mask: mask that prevents attention to certain positions. This is an additive mask 114 | (i.e. the values will be added to the attention layer). 115 | Shape: 116 | Inputs: 117 | - query: :math:`(B, T, E)` where T is the target sequence length, B is the batch size, E is 118 | the embedding dimension. 119 | - key: :math:`(B, S, E)`, where S is the source sequence length, B is the batch size, E is 120 | the embedding dimension. 121 | - value: :math:`(B, S, E)` where S is the source sequence length, B is the batch size, E is 122 | the embedding dimension. 123 | - key_padding_mask: :math:`(B, S)`, ByteTensor, where B is the batch size, S is the source sequence length. 124 | 3-D key_padding_mask with math:`(B, T, S)` is supported now, where T is the target sequence length. 125 | - attn_mask: :math:`(T, S)` or math:`(B, T, S)` where T is the target sequence length, S is the source sequence length. 126 | ''' 127 | bsz, tgt_len, _ = query.size() 128 | 129 | Q = self.q_proj(query) 130 | K = self.k_proj(key) 131 | V = self.v_proj(value) 132 | Q = Q.transpose(0, 1).contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 133 | K = K.transpose(0, 1).contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 134 | V = V.transpose(0, 1).contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 135 | 136 | src_len = K.size(1) 137 | 138 | if self.qk_nrom: 139 | Q = F.layer_norm(Q, [Q.shape[-1]]) 140 | K = F.layer_norm(K, [K.shape[-1]]) 141 | 142 | attn_weights = torch.bmm(Q, K.transpose(1, 2)) * self.scaling 143 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 144 | 145 | if self.has_relative_attention_bias and position_bias is None: 146 | position_bias = self.compute_bias(tgt_len, src_len) 147 | position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) 148 | 149 | if position_bias is not None: 150 | attn_mask_rel_pos = position_bias 151 | if self.gru_rel_pos: 152 | query_layer = query 153 | new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) 154 | query_layer = query_layer.view(*new_x_shape) 155 | query_layer = query_layer.permute(0, 2, 1, 3) 156 | _B, _H, _L, __ = query_layer.size() 157 | 158 | gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( 159 | _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) 160 | gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 161 | attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias 162 | 163 | attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) 164 | attn_weights += attn_mask_rel_pos 165 | 166 | if attn_mask is not None: 167 | assert not self.has_relative_attention_bias and position_bias is None, 'attn_mask has been used for relative position bias' 168 | attn_mask = attn_mask.unsqueeze(0) if attn_mask.dim() == 2 else attn_mask 169 | attn_weights += attn_mask 170 | 171 | if key_padding_mask is not None: 172 | key_padding_mask = key_padding_mask.unsqueeze(1) if key_padding_mask.dim() == 3 else 
key_padding_mask.unsqueeze(1).unsqueeze(2) 173 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 174 | attn_weights = attn_weights.masked_fill(key_padding_mask, float('-inf')) 175 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 176 | 177 | attn_weights = F.softmax(attn_weights, dim=-1) 178 | attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) 179 | 180 | attn_output = torch.bmm(attn_weights, V) 181 | assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 182 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim).transpose(0, 1) 183 | attn_output = self.out_proj(attn_output) 184 | 185 | if need_weights: 186 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len).sum(dim=1) / self.num_heads 187 | else: 188 | attn_weights = None 189 | 190 | return attn_output, attn_weights, position_bias 191 | 192 | -------------------------------------------------------------------------------- /modules/positional_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn as nn 7 | import math 8 | import torch 9 | 10 | class PositionalEncoding(nn.Module): 11 | """Positional encoding. 12 | Args: 13 | d_model: Embedding dimension. 14 | dropout_rate: Dropout rate. 15 | max_len: Maximum input length. 16 | reverse: Whether to reverse the input position. 17 | """ 18 | 19 | def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): 20 | """Construct an PositionalEncoding object.""" 21 | super(PositionalEncoding, self).__init__() 22 | self.d_model = d_model 23 | self.reverse = reverse 24 | self.xscale = math.sqrt(self.d_model) 25 | self.dropout = nn.Dropout(p=dropout_rate) 26 | self.pe = None 27 | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) 28 | 29 | def extend_pe(self, x): 30 | """Reset the positional encodings.""" 31 | if self.pe is not None: 32 | if self.pe.size(1) >= x.size(1): 33 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 34 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 35 | return 36 | pe = torch.zeros(x.size(1), self.d_model) 37 | if self.reverse: 38 | position = torch.arange( 39 | x.size(1) - 1, -1, -1.0, dtype=torch.float32 40 | ).unsqueeze(1) 41 | else: 42 | position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) 43 | div_term = torch.exp( 44 | torch.arange(0, self.d_model, 2, dtype=torch.float32) 45 | * -(math.log(10000.0) / self.d_model) 46 | ) 47 | pe[:, 0::2] = torch.sin(position * div_term) 48 | pe[:, 1::2] = torch.cos(position * div_term) 49 | pe = pe.unsqueeze(0) 50 | self.pe = pe.to(device=x.device, dtype=x.dtype) 51 | 52 | def forward(self, x: torch.Tensor): 53 | """Add positional encoding. 54 | Args: 55 | x (torch.Tensor): Input tensor B X T X C 56 | Returns: 57 | torch.Tensor: Encoded tensor B X T X C 58 | """ 59 | self.extend_pe(x) 60 | x = x * self.xscale + self.pe[:, : x.size(1)] 61 | return self.dropout(x) 62 | 63 | 64 | class RelPositionalEncoding(nn.Module): 65 | """Relative positional encoding module (new implementation) (Transformer-XL). 66 | Args: 67 | d_model: Embedding dimension. 68 | dropout_rate: Dropout rate. 69 | max_len: Maximum input length. (Maximum output length supported by the encoder.) 
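    Note: unlike PositionalEncoding above, this module takes (max_len, d_model) in its
    constructor and defines no dropout; the cached table covers both positive and
    negative relative positions (length 2 * input_len - 1, see extend_pe).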
70 | """ 71 | 72 | def __init__(self, max_len, d_model): 73 | """Construct an PositionalEncoding object.""" 74 | super(RelPositionalEncoding, self).__init__() 75 | self.d_model = d_model 76 | self.pe = None 77 | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) 78 | 79 | def extend_pe(self, x): 80 | """Reset the positional encodings.""" 81 | if self.pe is not None: 82 | # self.pe contains both positive and negative parts 83 | # the length of self.pe is 2 * input_len - 1 84 | if self.pe.size(1) >= x.size(1) * 2 - 1: 85 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 86 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 87 | return 88 | # Suppose `i` means to the position of query vecotr and `j` means the 89 | # position of key vector. We use position relative positions when keys 90 | # are to the left (i>j) and negative relative positions otherwise (i 0: 155 | x = x[:, :, : -self.remove] 156 | return x 157 | -------------------------------------------------------------------------------- /modules/transformer_encoder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modules import MultiheadAttention 5 | from modules.activation import _get_activation_fn 6 | 7 | class TransformerEncoderLayer(nn.Module): 8 | def __init__( 9 | self, 10 | embed_dim, 11 | ffn_embed_dim, 12 | num_heads, 13 | activation, 14 | dropout, 15 | bias, 16 | normalize_before, 17 | has_relative_attention_bias: bool = False, 18 | num_buckets: int = 32, 19 | max_distance: int = 128, 20 | gru_rel_pos: bool = False, 21 | qk_norm: bool = False 22 | ): 23 | super().__init__() 24 | self.dropout = dropout 25 | self.normalize_before = normalize_before 26 | 27 | self.self_attn_layer_norm = nn.LayerNorm(embed_dim) 28 | self.self_attn = MultiheadAttention( 29 | embed_dim, num_heads, None, None, dropout, bias, has_relative_attention_bias, num_buckets, max_distance, gru_rel_pos, qk_norm 30 | ) 31 | 32 | # Feed-Forward Network 33 | self.activation_fn = _get_activation_fn(activation) 34 | self.fc1 = nn.Linear(embed_dim, ffn_embed_dim) 35 | self.fc2 = nn.Linear(ffn_embed_dim, embed_dim) 36 | self.final_layer_norm = nn.LayerNorm(embed_dim) 37 | 38 | def forward(self, x, key_padding_mask=None, pos_bias=None, need_weights=False): 39 | residual = x 40 | if self.normalize_before: 41 | x = self.self_attn_layer_norm(x) 42 | x, attn, pos_bias = self.self_attn( 43 | query=x, 44 | key=x, 45 | value=x, 46 | key_padding_mask=key_padding_mask, 47 | position_bias=pos_bias, 48 | need_weights=need_weights, 49 | ) 50 | x = F.dropout(x, p=self.dropout, training=self.training) 51 | x = residual + x 52 | 53 | if not self.normalize_before: 54 | x = self.self_attn_layer_norm(x) 55 | 56 | residual = x 57 | if self.normalize_before: 58 | x = self.final_layer_norm(x) 59 | x = self.activation_fn(self.fc1(x)) 60 | x = F.dropout(x, p=self.dropout, training=self.training) 61 | x = self.fc2(x) 62 | 63 | x = F.dropout(x, p=self.dropout, training=self.training) 64 | x = residual + x 65 | if not self.normalize_before: 66 | x = self.final_layer_norm(x) 67 | 68 | return x, attn, pos_bias 69 | 70 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed, Mar 8, 2023 5 | @author: lab-chen.weidong 6 | """ 7 | 8 | import os 9 | from tqdm import tqdm 10 | 
import torch 11 | import torch.optim.lr_scheduler as lr_scheduler 12 | import torch.distributed as dist 13 | import torch.multiprocessing as mp 14 | from torch.utils.tensorboard import SummaryWriter 15 | from collections import OrderedDict 16 | import math 17 | import json 18 | import shutil 19 | 20 | import utils 21 | import models 22 | from configs import create_workshop, get_config, dict_2_list 23 | 24 | class Engine(): 25 | def __init__(self, cfg, local_rank: int, world_size: int): 26 | self.cfg = cfg 27 | self.local_rank = local_rank 28 | self.world_size = world_size 29 | self.device = self.cfg.train.device 30 | self.EPOCH = self.cfg.train.EPOCH 31 | self.current_epoch = 0 32 | self.iteration = 0 33 | self.lowest_loss = 1e4 34 | self.loss_weight_l = self.cfg.train.loss_weight_l 35 | self.loss_weight_h = self.cfg.train.loss_weight_h 36 | self.loss_weight_x = self.cfg.train.loss_weight_x 37 | 38 | self.early_stopping = utils.earlystopping.EarlyStopping(patience=self.cfg.train.patience, verbose=self.local_rank == 0) 39 | 40 | ### prepare meters 41 | self.loss_meter = utils.avgmeter.AverageMeter(device='cuda') 42 | self.l_loss_meter = utils.avgmeter.AverageMeter(device='cuda') 43 | self.h_loss_meter = utils.avgmeter.AverageMeter(device='cuda') 44 | self.x_loss_meter = utils.avgmeter.AverageMeter(device='cuda') 45 | 46 | def prepare_staff(self, fold=1): 47 | ''' We move this part out of the __init__ function to avoid the weird error: 48 | DataLoader worker (pid xxx) is killed by signal: Aborted 49 | This error is probably caused by a conflict between lmdb and ddp. 50 | ''' 51 | ### prepare dataloader 52 | dataloader_feactory = utils.dataset.DataloaderFactory(self.cfg.dataset) 53 | self.dataloader_train = dataloader_feactory.build( 54 | state='train', 55 | bs=self.cfg.train.batch_size, 56 | fold=fold 57 | ) 58 | self.cfg.model.freeze_cnn = self.cfg.train.freeze_cnn 59 | self.cfg.model.device = self.device 60 | 61 | ### prepare model, optimizer and scheduler 62 | model = models.vesper.Vesper_PretrainWrapper(self.cfg.model).to(self.device) 63 | 64 | if self.cfg.train.freeze_cnn: 65 | for param in model.vesper.feature_extractor.parameters(): 66 | param.requires_grad = False 67 | 68 | self.model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[self.local_rank]) 69 | 70 | if self.cfg.train.optimizer == 'AdamW': 71 | self.optimizer = torch.optim.AdamW( 72 | params=filter(lambda x: x.requires_grad, self.model.parameters()), 73 | lr=self.cfg.train.lr, 74 | weight_decay=self.cfg.train.weight_decay 75 | ) 76 | elif self.cfg.train.optimizer == 'sgd': 77 | self.optimizer = torch.optim.SGD( 78 | params=filter(lambda x: x.requires_grad, self.model.parameters()), 79 | lr=self.cfg.train.lr, 80 | momentum=0.9, 81 | weight_decay=self.cfg.train.weight_decay 82 | ) 83 | else: 84 | raise ValueError(f'Unknown optimizer: {self.cfg.train.optimizer}') 85 | 86 | if self.local_rank == 0: 87 | print(f'Optimizer: {self.cfg.train.optimizer}') 88 | 89 | # CosineAnnealingLR with Warm-up 90 | warmup_epoch = int(self.cfg.train.warmup_epoch * self.EPOCH) 91 | lr_max = self.cfg.train.lr 92 | lr_min = self.cfg.train.lr * 0.01 93 | T_max = self.EPOCH 94 | lr_lambda = lambda epoch: (epoch + 1) / warmup_epoch if epoch < warmup_epoch else \ 95 | (lr_min + 0.5*(lr_max-lr_min)*(1.0+math.cos((epoch-warmup_epoch)/(T_max-warmup_epoch)*math.pi))) / self.cfg.train.lr 96 | self.scheduler = lr_scheduler.LambdaLR(optimizer=self.optimizer, lr_lambda=lr_lambda) 97 | 98 | if self.cfg.train.load_model is not None: 99 | ckpt = 
torch.load(self.cfg.train.load_model, map_location=self.device) 100 | self.model.module.load_state_dict(ckpt['model']) 101 | if self.local_rank == 0: 102 | print(f'Loading model from {self.cfg.train.load_model}') 103 | del ckpt 104 | 105 | if self.cfg.train.resume is not None: 106 | ckpt = torch.load(self.cfg.train.resume, map_location=self.device) 107 | self.model.module.load_state_dict(ckpt['model']) 108 | self.optimizer.load_state_dict(ckpt['optimizer']) 109 | self.scheduler.load_state_dict(ckpt['scheduler']) 110 | self.scheduler.step() 111 | self.current_epoch = ckpt['epoch'] + 1 112 | self.iteration = ckpt['iteration'] 113 | self.lowest_loss = ckpt['lowest_loss'] 114 | if self.local_rank == 0: 115 | print(f'Resuming from {self.cfg.train.resume}') 116 | del ckpt 117 | 118 | ### prepare writer and logger 119 | if self.local_rank == 0: 120 | self.writer = SummaryWriter(self.cfg.workshop) 121 | self.logger_train = utils.logger.create_logger(self.cfg.workshop, name='train') 122 | self.logger_train.info(f'workshop: {self.cfg.workshop}') 123 | self.logger_train.info(f'seed: {self.cfg.train.seed}') 124 | self.logger_train.info(f'pid: {os.getpid()}') 125 | print('Main pid:', os.getpid()) 126 | else: 127 | self.writer = None 128 | self.logger_train = None 129 | 130 | self.config_2_json() 131 | 132 | def config_2_json(self, jsonfile=None): 133 | self.jsonfile = os.path.join(self.cfg.workshop, 'config.json') if jsonfile is None else jsonfile 134 | with open(self.jsonfile, 'w') as f: 135 | json.dump(dict(self.cfg), f, indent=2) 136 | 137 | def json_2_config(self, jsonfile=None): 138 | if jsonfile is not None: 139 | self.jsonfile = jsonfile 140 | assert hasattr(self, 'jsonfile'), 'Please provide the .json file first.' 141 | with open(self.jsonfile, 'r') as f: 142 | data = json.load(f) 143 | self.cfg.merge_from_list(dict_2_list(data)) 144 | 145 | def reset_meters(self): 146 | self.loss_meter.reset() 147 | self.l_loss_meter.reset() 148 | self.h_loss_meter.reset() 149 | self.x_loss_meter.reset() 150 | 151 | def gather_distributed_data(self, gather_data): 152 | if isinstance(gather_data, torch.Tensor): 153 | _output = [torch.zeros_like(gather_data) for _ in range(self.world_size)] 154 | dist.all_gather(_output, gather_data, async_op=False) 155 | output = torch.cat(_output) 156 | else: 157 | if gather_data[0] is not None: 158 | _output = [None for _ in range(self.world_size)] 159 | if hasattr(dist, 'all_gather_object'): 160 | dist.all_gather_object(_output, gather_data) 161 | else: 162 | utils.distributed.all_gather_object(_output, gather_data, self.world_size) 163 | output = [] 164 | for lst in _output: 165 | output.extend(lst) 166 | else: 167 | output = None 168 | return output 169 | 170 | def train_epoch(self): 171 | self.dataloader_train.set_epoch(self.current_epoch) 172 | if self.local_rank == 0: 173 | print(f'-------- {self.cfg.workshop} --------') 174 | discrip_str = f'Epoch-{self.current_epoch}/{self.EPOCH}' 175 | pbar_train = tqdm(self.dataloader_train, disable=self.local_rank != 0, dynamic_ncols=True) 176 | pbar_train.set_description('Train' + discrip_str) 177 | 178 | self.reset_meters() 179 | 180 | self.model.train() 181 | for data in pbar_train: 182 | self.iteration += 1 183 | 184 | waveform = torch.cat(data['waveform'], dim=0).to(self.device) 185 | padding_mask = torch.cat(data['padding_mask'], dim=0).to(self.device) 186 | 187 | self.optimizer.zero_grad() 188 | 189 | l_target = torch.stack(data['l_target'], dim=0).to(self.device) if data['l_target'][0] is not None else None 190 | h_target = 
torch.stack(data['h_target'], dim=0).to(self.device) if data['h_target'][0] is not None else None 191 | l_loss, h_loss, x_loss = self.model(waveform, padding_mask, l_target, h_target) 192 | loss = self.loss_weight_l * l_loss + self.loss_weight_h * h_loss + self.loss_weight_x * x_loss 193 | loss.backward() 194 | 195 | self.optimizer.step() 196 | 197 | self.loss_meter.update(loss.item()) 198 | self.l_loss_meter.update(l_loss.item()) 199 | self.h_loss_meter.update(h_loss.item()) 200 | self.x_loss_meter.update(x_loss.item()) 201 | 202 | pbar_train_dic = OrderedDict() 203 | pbar_train_dic['iter'] = self.iteration 204 | pbar_train_dic['lr'] = self.optimizer.param_groups[0]['lr'] 205 | pbar_train_dic['l_loss'] = f'{self.l_loss_meter.avg:.5f}' 206 | pbar_train_dic['h_loss'] = f'{self.h_loss_meter.avg:.5f}' 207 | pbar_train_dic['x_loss'] = f'{self.x_loss_meter.avg:.5f}' 208 | pbar_train_dic['loss'] = f'{self.loss_meter.avg:.5f}' 209 | pbar_train.set_postfix(pbar_train_dic) 210 | 211 | if self.iteration % (len(self.dataloader_train) // 20) == 0: 212 | if self.local_rank == 0: 213 | self.writer.add_scalar('Step/l_loss', l_loss.item(), self.iteration) 214 | self.writer.add_scalar('Step/h_loss', h_loss.item(), self.iteration) 215 | self.writer.add_scalar('Step/x_loss', x_loss.item(), self.iteration) 216 | self.writer.add_scalar('Step/loss', loss.item(), self.iteration) 217 | 218 | self.loss_meter.sync_distributed() 219 | self.l_loss_meter.sync_distributed() 220 | self.h_loss_meter.sync_distributed() 221 | self.x_loss_meter.sync_distributed() 222 | 223 | l_loss_epoch = self.l_loss_meter.avg 224 | h_loss_epoch = self.h_loss_meter.avg 225 | x_loss_epoch = self.x_loss_meter.avg 226 | loss_epoch = self.loss_meter.avg 227 | 228 | if self.local_rank == 0: 229 | self.writer.add_scalar('Epoch/l_loss', l_loss_epoch, self.current_epoch) 230 | self.writer.add_scalar('Epoch/h_loss', h_loss_epoch, self.current_epoch) 231 | self.writer.add_scalar('Epoch/x_loss', x_loss_epoch, self.current_epoch) 232 | self.writer.add_scalar('Epoch/loss', loss_epoch, self.current_epoch) 233 | self.writer.add_scalar('LR', self.optimizer.param_groups[0]['lr'], self.current_epoch) 234 | 235 | self.logger_train.info( 236 | f'Training epoch: {self.current_epoch}, l_loss: {l_loss_epoch:.5f}, h_loss: {h_loss_epoch:.5f}, x_loss: {x_loss_epoch:.5f}, loss: {loss_epoch:.5f}' 237 | ) 238 | 239 | is_best = loss_epoch < self.lowest_loss 240 | self.lowest_loss = min(loss_epoch, self.lowest_loss) 241 | self.model_save(is_best) 242 | 243 | self.early_stopping(loss_epoch, self.model) 244 | 245 | def model_save(self, is_best=False, filename='checkpoint.pt'): 246 | ckpt_save_file = os.path.join(self.cfg.ckpt_save_path, filename) 247 | save_dict = { 248 | 'cfg': self.cfg, 249 | 'epoch': self.current_epoch, 250 | 'iteration': self.iteration, 251 | 'lowest_loss': self.lowest_loss, 252 | 'model': self.model.module.state_dict(), # save DDP model 253 | 'optimizer': self.optimizer.state_dict(), 254 | 'scheduler': self.scheduler.state_dict() 255 | } 256 | torch.save(save_dict, ckpt_save_file) 257 | if is_best: 258 | shutil.copyfile(ckpt_save_file, os.path.join(self.cfg.ckpt_save_path, 'model_best.pt')) 259 | 260 | def run(self, fold=1): 261 | self.prepare_staff(fold=fold) 262 | 263 | while self.current_epoch < self.EPOCH: 264 | self.train_epoch() 265 | self.scheduler.step() 266 | 267 | self.current_epoch += 1 268 | 269 | if self.early_stopping.early_stop: 270 | print(f"Early stopping (patience: {self.early_stopping.patience})") 271 | break 272 | 273 | 
self.cleanup() 274 | 275 | def cleanup(self): 276 | if self.logger_train is not None: 277 | utils.logger.close_logger(self.logger_train) 278 | if self.writer is not None: 279 | self.writer.close() 280 | # torch.cuda.empty_cache() 281 | self.early_stopping.clean() 282 | self.current_epoch = 0 283 | self.iteration = 0 284 | self.lowest_loss = 1e4 285 | 286 | def main_worker(local_rank, cfg, world_size, dist_url): 287 | mp.set_sharing_strategy('file_system') 288 | utils.environment.set_seed(cfg.train.seed + local_rank) 289 | torch.cuda.set_device(local_rank) 290 | dist.init_process_group( 291 | backend='nccl', 292 | init_method=dist_url, 293 | world_size=world_size, 294 | rank=local_rank, 295 | ) 296 | # torch.autograd.set_detect_anomaly(True) 297 | engine = Engine(cfg, local_rank, world_size) 298 | for fold in cfg.dataset.folds: 299 | create_workshop(cfg, local_rank, fold) 300 | engine.run(fold) 301 | 302 | if local_rank == 0: 303 | criterion = ['l_loss', 'h_loss', 'x_loss', 'loss'] 304 | evaluate = ['loss'] 305 | outfile = f'result/result_{cfg.model.type}.csv' 306 | wantlow = True 307 | utils.collect_result.path_to_csv(os.path.dirname(cfg.workshop), criterion, evaluate, csvfile=outfile, logname='train.log', wantlow=wantlow) 308 | 309 | def main(cfg): 310 | utils.environment.visible_gpus(cfg.train.device_id) 311 | 312 | free_port = utils.distributed.find_free_port() 313 | dist_url = f'tcp://127.0.0.1:{free_port}' 314 | world_size = torch.cuda.device_count() # num_gpus 315 | print(f'world_size={world_size} Using dist_url={dist_url}') 316 | 317 | mp.spawn(fn=main_worker, args=(cfg, world_size, dist_url), nprocs=world_size) 318 | 319 | if __name__=='__main__': 320 | cfg = get_config(mode='_pretrain') 321 | main(cfg) 322 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import distributed 2 | from . import environment 3 | from . import logger 4 | from . import collect_result 5 | from . import avgmeter 6 | from . import dataset 7 | from . import recoder 8 | from . import earlystopping 9 | from . 
import metric 10 | -------------------------------------------------------------------------------- /utils/avgmeter.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.distributed 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | 8 | def __init__(self, name='No name', fmt=':f', device=torch.device('cpu')): 9 | self.name = name 10 | self.fmt = fmt 11 | self.device = device 12 | 13 | self.val = torch.tensor(0, dtype=torch.float, device=self.device) 14 | self.sum = torch.tensor(0, dtype=torch.float, device=self.device) 15 | self.count = torch.tensor(0, dtype=torch.int, device=self.device) 16 | 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = torch.tensor(0, dtype=torch.float, device=self.device) 21 | self.sum = torch.tensor(0, dtype=torch.float, device=self.device) 22 | self.count = torch.tensor(0, dtype=torch.int, device=self.device) 23 | 24 | @torch.no_grad() 25 | def update(self, val: torch.Tensor, n=1): 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | 30 | @property 31 | def avg(self): 32 | return self.sum / self.count 33 | 34 | def __str__(self): 35 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 36 | return fmtstr.format( 37 | name=self.name, 38 | val=self.val.item(), 39 | avg=self.avg.item(), 40 | ) 41 | 42 | def sync_distributed(self): 43 | r_count = torch.distributed.all_reduce(self.count, op=torch.distributed.ReduceOp.SUM, async_op=True) 44 | r_sum = torch.distributed.all_reduce(self.sum, op=torch.distributed.ReduceOp.SUM, async_op=True) 45 | r_count.wait() 46 | r_sum.wait() 47 | 48 | -------------------------------------------------------------------------------- /utils/collect_result.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import Counter 3 | import pandas as pd 4 | import os 5 | import re 6 | import numpy as np 7 | import math 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | matplotlib.use('Agg') 11 | 12 | def get_index(lst=None, item=''): 13 | return [index for (index,value) in enumerate(lst) if value == item] 14 | 15 | def path_to_csv(filepath='./exp', criterion=['accuracy', 'precision', 'recall', 'f1'], evaluate=['f1'], 16 | largest=5, retrun=5, logname='test.log', csvfile='test.csv', overwrite=False, wantlow=False, epoch=None): 17 | ''' 18 | Record the average cross-validation result to a csv file 19 | if wantlow is True, the smaller the value is, the better. 
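    When epoch is None, the best epoch of each fold is selected automatically: the
    epochs reaching the top `largest` values of every metric in `evaluate` are counted,
    the `retrun` most frequent candidates are kept, and the candidate with the best
    summed `evaluate` score wins. The `criterion` values of that epoch are then
    averaged over folds and appended to `csvfile`.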
20 | 21 | Input 22 | - filepath: path to every folds 23 | ''' 24 | 25 | all_file_result = [] 26 | files = os.listdir(filepath) 27 | for file in files: 28 | result = {c:[] for c in criterion} 29 | if epoch is not None: 30 | with open(os.path.join(filepath, file, 'Log', logname), 'r') as f: 31 | for line in f.readlines(): 32 | line = line.strip() 33 | if f'Testing epoch: {epoch},' in line: 34 | for c in criterion: 35 | score = re.search(f' {c}: \d+.\d+', line).group() 36 | score = float(re.sub(f' {c}: ', '', score)) 37 | result[c].append(score) 38 | assert len(result[c]) == 1 39 | f.close() 40 | best_result = [result[c][0] for c in criterion] 41 | else: 42 | max_id = [] 43 | for c in criterion: 44 | with open(os.path.join(filepath, file, 'Log', logname), 'r') as f: 45 | for line in f.readlines(): 46 | line = line.strip() 47 | if c in line and 'workshop' not in line: 48 | score = re.search(f' {c}: \d+.\d+', line).group() 49 | score = float(re.sub(f' {c}: ', '', score)) 50 | result[c].append(score) 51 | f.close() 52 | if c in evaluate: 53 | result_set = set(result[c]) 54 | result_max = sorted(result_set, reverse=not wantlow)[:largest] 55 | temp = [] 56 | for maximum in result_max: 57 | temp += get_index(result[c], maximum) 58 | max_id.extend(temp) 59 | 60 | c = Counter(max_id) 61 | return_id_counts = c.most_common(retrun) 62 | 63 | if not wantlow: 64 | best_idx = 0 65 | best_sum = 0 66 | for idx, counts in return_id_counts: 67 | s = sum([result[c][idx] for c in evaluate]) # criterion -> evaluate 68 | if s > best_sum: 69 | best_sum = s 70 | best_idx = idx 71 | else: 72 | best_idx = 0 73 | best_sum = 10000 74 | for idx, counts in return_id_counts: 75 | s = sum([result[c][idx] for c in evaluate]) # criterion -> evaluate 76 | if s < best_sum: 77 | best_sum = s 78 | best_idx = idx 79 | best_result = [result[c][best_idx] for c in criterion] 80 | all_file_result.append(best_result) 81 | 82 | print('Calculate mean result from {} files. 
Write to {}'.format(len(all_file_result), csvfile)) 83 | print(f'Evaluate: {evaluate}') 84 | mean_result = np.mean(np.array(all_file_result), axis=0).tolist() 85 | 86 | if not os.path.exists(os.path.dirname(csvfile)): 87 | os.makedirs(os.path.dirname(csvfile)) 88 | 89 | if not os.path.exists(csvfile): 90 | data = {'Model': []} 91 | data.update({c: [] for c in criterion}) 92 | df = pd.DataFrame(data) 93 | df.to_csv(csvfile, index=False, sep=',') 94 | 95 | newdata = {'Model': filepath[6:]} # pass ./exp/ 96 | newdata.update({c: [r] for c, r in zip(criterion, mean_result)}) 97 | new_df = pd.DataFrame(newdata) 98 | if overwrite: 99 | df = new_df 100 | else: 101 | df = pd.read_csv(csvfile) 102 | df_temp = df[df['Model'] == new_df.loc[0, 'Model']] 103 | if df_temp.empty: 104 | df = pd.concat([df, new_df], ignore_index=True) # insert a new line in DataFrame 105 | else: 106 | row_index = df_temp.index.tolist()[0] 107 | for c in criterion: 108 | df.loc[row_index, c] = new_df.loc[0, c] 109 | 110 | for c in criterion: 111 | df[c] = df[c].apply(lambda x: round(x, 3)) 112 | df.to_csv(csvfile, index=False, sep=',') 113 | tidy_csvfile(csvfile, colname='Model') 114 | 115 | def plot_process(x: list, title: list, savedir: str): 116 | col = math.ceil(len(x) / 2) 117 | assert col < 5, print('Get too many data, the maximun number of columns in figure is 4.') 118 | line = 2 119 | 120 | color = ['b', 'g', 'k', 'r'] 121 | color = color[:col] * line 122 | plt.figure(figsize=(18, 8)) 123 | plt.suptitle(savedir.split('/', maxsplit=1)[-1]) 124 | 125 | plt.subplots_adjust(wspace=0.15, hspace=0.3, bottom=0.2) 126 | for i, (data_x, data_title, c) in enumerate(zip(x, title, color)): 127 | y = np.arange(len(data_x)) 128 | plt.subplot(line, col, i + 1) 129 | plt.plot(y, data_x, c) 130 | plt.title(data_title) 131 | 132 | plt.savefig(os.path.join(savedir, 'result.png'), bbox_inches='tight', pad_inches=0.2) 133 | plt.close() 134 | 135 | return 136 | 137 | def tidy_csvfile(csvfile, colname, ascending=True): 138 | ''' 139 | tidy csv file base on a particular column. 
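    - csvfile: path to the csv file, sorted in place and rounded to 3 decimals.
    - colname: column used as the sort key.
    - ascending: sort order, default True.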
140 | ''' 141 | print(f'tidy file: {csvfile}, base on column: {colname}') 142 | df = pd.read_csv(csvfile) 143 | df = df.sort_values(by=[colname], ascending=ascending, na_position='last') 144 | df = df.round(3) 145 | df.to_csv(csvfile, index=False, sep=',') 146 | 147 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | from torch.utils.data.distributed import DistributedSampler 6 | import multiprocessing as mp 7 | import numpy as np 8 | import pandas as pd 9 | import librosa 10 | from scipy import io 11 | from sklearn.model_selection import StratifiedShuffleSplit 12 | 13 | def identity(x): 14 | return x 15 | 16 | class DistributedDalaloaderWrapper(): 17 | def __init__(self, dataloader: DataLoader, collate_fn): 18 | self.dataloader = dataloader 19 | self.collate_fn = collate_fn 20 | 21 | def _epoch_iterator(self, it): 22 | for batch in it: 23 | yield self.collate_fn(batch) 24 | 25 | def __iter__(self): 26 | it = iter(self.dataloader) 27 | return self._epoch_iterator(it) 28 | 29 | def __len__(self): 30 | return len(self.dataloader) 31 | 32 | @property 33 | def dataset(self): 34 | return self.dataloader.dataset 35 | 36 | def set_epoch(self, epoch: int): 37 | self.dataloader.sampler.set_epoch(epoch) 38 | 39 | def universal_collater(batch): 40 | all_data = [[] for _ in range(len(batch[0]))] 41 | for one_batch in batch: 42 | for i, (data) in enumerate(one_batch): 43 | all_data[i].append(data) 44 | return all_data 45 | 46 | def universal_dict_collater(batch): 47 | keys = batch[0].keys() 48 | all_data = {key: [] for key in keys} 49 | for one_batch in batch: 50 | for key in keys: 51 | all_data[key].append(one_batch[key]) 52 | return all_data 53 | 54 | class LssedDataset(Dataset): 55 | def __init__(self, args, state: str='train'): 56 | 57 | _mapping = None 58 | 59 | self.df = pd.read_csv(args.meta_csv_file) 60 | self.wavdir = args.wavdir 61 | self.batch_length = args.batch_length 62 | 63 | self.l_target_dir = args.l_target_dir 64 | self.h_target_dir = args.h_target_dir 65 | self.target_length = args.target_length 66 | 67 | if _mapping is not None: 68 | self.df['faces'] = self.df['faces'].map(_mapping).astype(np.float32) 69 | 70 | self.df = self.df.reset_index() 71 | 72 | def __len__(self): 73 | return len(self.df) 74 | 75 | def __getitem__(self, idx): 76 | waveform, _ = librosa.load(os.path.join(self.wavdir, self.df.loc[idx, 'name'] + '.wav'), sr=16000) 77 | emotion = self.df.loc[idx, 'faces'] 78 | padding_mask = torch.full((1, self.batch_length), fill_value=False, dtype=torch.bool) 79 | 80 | length = waveform.shape[-1] 81 | if length >= self.batch_length: 82 | waveform = waveform[np.newaxis, :self.batch_length] 83 | else: 84 | padding_length = self.batch_length - length 85 | waveform = np.pad(waveform, ((0, padding_length)), 'constant', constant_values=(0, 0))[np.newaxis, :] 86 | padding_mask[:, -padding_length:] = True 87 | 88 | waveform = torch.from_numpy(waveform) 89 | 90 | l_target = io.loadmat(os.path.join(self.l_target_dir, self.df.loc[idx, 'name']))['wavlm'] 91 | h_target = io.loadmat(os.path.join(self.h_target_dir, self.df.loc[idx, 'name']))['wavlm'] 92 | length = h_target.shape[0] 93 | if length >= self.target_length: 94 | l_target = l_target[:self.target_length] 95 | h_target = h_target[:self.target_length] 96 | else: 97 | padding_length = self.target_length - length 98 | 
l_target = np.pad(l_target, ((0, padding_length), (0, 0)), 'constant', constant_values=(0, 0)) 99 | h_target = np.pad(h_target, ((0, padding_length), (0, 0)), 'constant', constant_values=(0, 0)) 100 | l_target = torch.from_numpy(l_target) 101 | h_target = torch.from_numpy(h_target) 102 | 103 | sample = { 104 | 'waveform': waveform, 105 | 'padding_mask': padding_mask, 106 | 'emotion': emotion, 107 | 'l_target': l_target, 108 | 'h_target': h_target 109 | } 110 | 111 | return sample 112 | 113 | class DownstreamDataset(Dataset): 114 | def __init__(self, df, wavdir, batch_length, col_sample='name', col_label='label'): 115 | self.df = df 116 | self.wavdir = wavdir 117 | self.batch_length = batch_length 118 | self.col_sample = col_sample 119 | self.col_label = col_label 120 | 121 | def __len__(self): 122 | return len(self.df) 123 | 124 | def __getitem__(self, idx): 125 | waveform, _ = librosa.load(os.path.join(self.wavdir, self.df.loc[idx, self.col_sample] + '.wav'), sr=16000) 126 | emotion = torch.tensor([self.df.loc[idx, self.col_label]], dtype=torch.long) 127 | padding_mask = torch.full((1, self.batch_length), fill_value=False, dtype=torch.bool) 128 | 129 | length = waveform.shape[-1] 130 | if length >= self.batch_length: 131 | waveform = waveform[np.newaxis, :self.batch_length] 132 | else: 133 | padding_length = self.batch_length - length 134 | waveform = np.pad(waveform, ((0, padding_length)), 'constant', constant_values=(0, 0))[np.newaxis, :] 135 | padding_mask[:, -padding_length:] = True 136 | 137 | waveform = torch.from_numpy(waveform) 138 | 139 | sample = { 140 | 'waveform': waveform, 141 | 'padding_mask': padding_mask, 142 | 'emotion': emotion 143 | } 144 | 145 | return sample 146 | 147 | class IemocapDataset(DownstreamDataset): 148 | def __init__(self, args, state: str='train', fold: int=None): 149 | 150 | _mapping = {'ang': 0, 'neu': 1, 'hap': 2, 'exc': 2, 'sad': 3} 151 | 152 | df = pd.read_csv(args.meta_csv_file) 153 | wavdir = args.wavdir 154 | batch_length = args.batch_length 155 | 156 | if _mapping is not None: 157 | df['label'] = df['label'].map(_mapping).astype(np.float32) 158 | 159 | df = df[df['label'].notnull()] 160 | 161 | if fold is not None: 162 | test_session = f'Ses0{fold}' 163 | samples = df['name'].str.startswith(test_session) 164 | if state == 'train': 165 | samples = ~samples 166 | df = df[samples] 167 | 168 | df = df.reset_index() 169 | 170 | super().__init__(df, wavdir, batch_length) 171 | 172 | class MeldDataset(DownstreamDataset): 173 | def __init__(self, args, state: str='train'): 174 | 175 | _mapping = {'neutral': 0, 'anger': 1, 'joy': 2, 'sadness': 3, 'surprise': 4, 'disgust': 5, 'fear': 6} 176 | state_csv_file = {'train': 'train_sent_emo.csv', 'dev': 'dev_sent_emo.csv', 'test': 'test_sent_emo.csv'} 177 | state_wav_dir = {'train': 'train', 'dev': 'dev', 'test': 'test'} 178 | 179 | df = pd.read_csv(os.path.join(args.meta_csv_file, state_csv_file[state])) 180 | wavdir = os.path.join(args.wavdir, state_wav_dir[state]) 181 | batch_length = args.batch_length 182 | 183 | if _mapping is not None: 184 | df['Emotion'] = df['Emotion'].map(_mapping).astype(np.float32) 185 | 186 | audio_name = [] 187 | for dia, utt in zip(df['Dialogue_ID'], df['Utterance_ID']): 188 | audio_name.append(f'dia{dia}_utt{utt}') 189 | 190 | df['name'] = audio_name 191 | df = df[['name', 'Emotion']] 192 | 193 | delete_row = [] 194 | for row_index, row in df.iterrows(): 195 | if not os.path.exists(os.path.join(wavdir, row['name'] + '.wav')): 196 | delete_row.append(row_index) 197 | df = 
df.drop(delete_row, axis=0) 198 | 199 | df = df[df['Emotion'].notnull()] 200 | 201 | df = df.reset_index() 202 | 203 | super().__init__(df, wavdir, batch_length, col_label='Emotion') 204 | 205 | class CremaDataset(DownstreamDataset): 206 | def __init__(self, args, state: str='train'): 207 | 208 | _mapping = {'ANG': 0, 'DIS': 1, 'FEA': 2, 'HAP': 3, 'NEU': 4, 'SAD': 5} 209 | 210 | df = pd.read_csv(args.meta_csv_file) 211 | wavdir = args.wavdir 212 | batch_length = args.batch_length 213 | 214 | if _mapping is not None: 215 | df['label'] = df['label'].map(_mapping).astype(np.float32) 216 | 217 | df = df[df['label'].notnull()] 218 | 219 | sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=2013) 220 | for train_index, test_index in sss.split(df['name'], df['label']): 221 | if state == 'train': 222 | df = df.iloc[train_index] 223 | else: 224 | df = df.iloc[test_index] 225 | 226 | df = df.reset_index() 227 | 228 | super().__init__(df, wavdir, batch_length) 229 | 230 | class DataloaderFactory(): 231 | def __init__(self, args): 232 | self.args = args 233 | 234 | def build(self, state: str='train', bs: int=1, fold: int=1): 235 | if self.args.database == 'lssed': 236 | dataset = LssedDataset(self.args, state) 237 | elif self.args.database == 'iemocap': 238 | dataset = IemocapDataset(self.args, state, fold) 239 | elif self.args.database == 'meld': 240 | dataset = MeldDataset(self.args, state) 241 | elif self.args.database == 'crema': 242 | dataset = CremaDataset(self.args, state) 243 | else: 244 | raise NotImplementedError 245 | 246 | collate_fn = universal_dict_collater 247 | sampler = DistributedSampler(dataset, shuffle=state == 'train') 248 | dataloader = DataLoader( 249 | dataset=dataset, 250 | batch_size=bs, 251 | drop_last=False, 252 | num_workers=self.args.num_workers, 253 | collate_fn=identity, 254 | sampler=sampler, 255 | pin_memory=True, 256 | multiprocessing_context=mp.get_context('fork'), # fork/spawn # quicker! 
Used with multi-process loading (num_workers > 0) 257 | ) 258 | 259 | return DistributedDalaloaderWrapper(dataloader, collate_fn) 260 | 261 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | 2 | from contextlib import closing 3 | import socket 4 | import io 5 | import pickle 6 | import torch 7 | import torch.distributed as dist 8 | 9 | def find_free_port() -> int: 10 | """ 11 | Find a free port for dist url 12 | :return: 13 | """ 14 | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: 15 | s.bind(('', 0)) 16 | port = s.getsockname()[1] 17 | 18 | return port 19 | 20 | def scale_learning_rate(lr: float, world_size: int, batch_size: int, base_batch_size: int = 64) -> float: 21 | new_lr = lr * world_size * batch_size / base_batch_size 22 | print(f'adjust lr according to the number of GPU and batch size:{lr} -> {new_lr}') 23 | return new_lr 24 | 25 | def _object_to_tensor(obj): 26 | f = io.BytesIO() 27 | pickle.Pickler(f).dump(obj) 28 | byte_storage = torch.ByteStorage.from_buffer(f.getvalue()).tolist() 29 | byte_tensor = torch.tensor(byte_storage, dtype=torch.uint8) 30 | local_size = torch.tensor([byte_tensor.numel()], dtype=torch.long) 31 | return byte_tensor, local_size 32 | 33 | def _tensor_to_object(tensor, tensor_size): 34 | buf = tensor.numpy().tobytes()[:tensor_size] 35 | return pickle.Unpickler(io.BytesIO(buf)).load() 36 | 37 | def all_gather_object(object_list, obj, world_size): 38 | input_tensor, local_size = _object_to_tensor(obj) 39 | current_device = torch.device('cuda', torch.cuda.current_device()) 40 | input_tensor = input_tensor.to(current_device) 41 | local_size = local_size.to(current_device) 42 | 43 | # Gather all local sizes. This is so that we can find the max size, and index 44 | # until the correct size when deserializing the tensors. 45 | object_sizes_tensor = torch.zeros(world_size, dtype=torch.long, device=current_device) 46 | object_size_list = [ 47 | object_sizes_tensor[i].unsqueeze(dim=0) for i in range(world_size) 48 | ] 49 | # Allgather tensor sizes 50 | dist.all_gather(object_size_list, local_size) 51 | max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] 52 | # Resize tensor to max size across all ranks. 53 | input_tensor.resize_(max_object_size) 54 | coalesced_output_tensor = torch.empty( 55 | max_object_size * world_size, dtype=torch.uint8, device=current_device 56 | ) 57 | # Output tensors are nonoverlapping views of coalesced_output_tensor 58 | output_tensors = [ 59 | coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] 60 | for i in range(world_size) 61 | ] 62 | dist.all_gather(output_tensors, input_tensor) 63 | # Deserialize outputs back to object. 64 | for i, tensor in enumerate(output_tensors): 65 | tensor = tensor.type(torch.uint8).cpu() # type:ignore[call-overload] 66 | tensor_size = object_size_list[i] 67 | object_list[i] = _tensor_to_object(tensor, tensor_size) 68 | 69 | 70 | def broadcast_object_list(object_list, src=0, cur_rank=0): 71 | # Serialize object_list elements to tensors on src rank. 
72 | if cur_rank == src: 73 | tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list]) 74 | object_sizes_tensor = torch.cat(size_list) 75 | else: 76 | object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) 77 | 78 | current_device = torch.device("cuda", torch.cuda.current_device()) 79 | 80 | object_sizes_tensor = object_sizes_tensor.to(current_device) 81 | 82 | # Broadcast object sizes 83 | dist.broadcast(object_sizes_tensor, src=src) 84 | 85 | # Concatenate and broadcast serialized object tensors 86 | if cur_rank == src: 87 | object_tensor = torch.cat(tensor_list) 88 | else: 89 | object_tensor = torch.empty( 90 | torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] 91 | dtype=torch.uint8, 92 | ) 93 | 94 | object_tensor = object_tensor.to(current_device) 95 | 96 | dist.broadcast(object_tensor, src=src) 97 | # Deserialize objects using their stored sizes. 98 | offset = 0 99 | if cur_rank != src: 100 | for i, obj_size in enumerate(object_sizes_tensor): 101 | obj_view = object_tensor[offset : offset + obj_size] 102 | obj_view = obj_view.type(torch.uint8) 103 | if obj_view.device != torch.device("cpu"): 104 | obj_view = obj_view.cpu() 105 | offset += obj_size 106 | object_list[i] = _tensor_to_object(obj_view, obj_size) 107 | 108 | -------------------------------------------------------------------------------- /utils/earlystopping.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import os 5 | 6 | class EarlyStopping: 7 | """Early stops the training if validation loss (or other metric) doesn't improve after a given patience.""" 8 | def __init__(self, patience=7, verbose=False, delta=0, ckpt_path=None, higher_is_better=False): 9 | """ 10 | Args: 11 | patience (int): How long to wait after last time validation loss improved. 12 | Default: 7 13 | verbose (bool): If True, prints a message for each validation loss improvement. 14 | Default: False 15 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 16 | Default: 0 17 | ckpt_path (str): Path to save the checkpoint. 18 | Default: None 19 | higher_is_better (bool): If True, a higher score denotes better performance. 20 | Default: False 21 | """ 22 | self.patience = patience 23 | self.verbose = verbose 24 | self.counter = 0 25 | self.higher_is_better = higher_is_better 26 | self.best_score = None 27 | self.early_stop = False 28 | self.delta = delta 29 | self.ckpt_path = ckpt_path 30 | if self.verbose: 31 | print(f'Early Stopping: patience {patience}') 32 | 33 | def __call__(self, score, model): 34 | 35 | score = score if self.higher_is_better else -score 36 | 37 | if self.patience <= 0: 38 | self.early_stop = False 39 | else: 40 | if self.best_score is None: 41 | self.best_score = score 42 | elif score < self.best_score + self.delta: 43 | self.counter += 1 44 | if self.verbose: 45 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 46 | if self.counter >= self.patience: 47 | self.early_stop = True 48 | else: 49 | # self.save_checkpoint(score, model) 50 | self.best_score = score 51 | self.counter = 0 52 | self.early_stop = False 53 | 54 | def clean(self): 55 | self.counter = 0 56 | self.best_score = None 57 | self.early_stop = False 58 | 59 | def save_checkpoint(self, score, model): 60 | ''' 61 | Saves model when score improves. 62 | ''' 63 | if self.verbose: 64 | print(f'Score improves ({self.best_score:.6f} --> {score:.6f}). 
Saving model ...') 65 | if self.ckpt_path is not None: 66 | torch.save(model.state_dict(), os.path.join(self.ckpt_path, 'checkpoint.pt')) 67 | 68 | -------------------------------------------------------------------------------- /utils/environment.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | 7 | def visible_gpus(gpu_id): 8 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id 9 | os.environ["OMP_NUM_THREADS"] = "1" 10 | print('Use GPU:', gpu_id) 11 | 12 | def set_seed(seed): 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | def set_backends(): 20 | torch.backends.cudnn.deterministic = True # True -> Use deterministic algorithms 21 | torch.backends.cudnn.benchmark = False # False -> Use deterministic convolution algorithms (slow in GPU) 22 | print('Use deterministic algorithms') 23 | print('Use deterministic convolution algorithms (slow in GPU)') 24 | 25 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import os 4 | 5 | def create_logger(logdir, name): 6 | 7 | logger = logging.getLogger(name) 8 | logger.setLevel(logging.DEBUG) 9 | 10 | log_path = os.path.join(os.getcwd(), logdir, 'Log') 11 | if not os.path.exists(log_path): 12 | os.makedirs(log_path) 13 | 14 | logfile = log_path +'/{}.log'.format(name) 15 | fh = logging.FileHandler(logfile, mode='a') # 'a' -> append, 'w' 16 | fh.setLevel(logging.INFO) 17 | 18 | formatter = logging.Formatter( 19 | fmt="%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s", 20 | datefmt='%a, %d %b %Y %H:%M:%S' 21 | ) 22 | 23 | fh.setFormatter(formatter) 24 | logger.addHandler(fh) 25 | 26 | # console = logging.StreamHandler() 27 | # console.setLevel(logging.ERROR) 28 | # console.setFormatter(formatter) 29 | # logger.addHandler(console) 30 | 31 | return logger 32 | 33 | def close_logger(logger): 34 | if logger is None: 35 | return 36 | else: 37 | for handler in logger.handlers[:]: 38 | handler.stream.close() 39 | logger.removeHandler(handler) 40 | return 41 | -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | 2 | from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix 3 | 4 | def calculate_score_classification(preds, labels, average_f1='weighted'): # weighted, macro 5 | ''' Return accuracy, ua, f1, precision and confuse_matrix. 
6 | ''' 7 | accuracy = accuracy_score(labels, preds) 8 | f1 = f1_score(labels, preds, average=average_f1, zero_division=0) 9 | precision = precision_score(labels, preds, average='macro', zero_division=0) 10 | ua = recall_score(labels, preds, average='macro', zero_division=0) 11 | confuse_matrix = confusion_matrix(labels, preds) 12 | return accuracy, ua, f1, precision, confuse_matrix 13 | 14 | -------------------------------------------------------------------------------- /utils/recoder.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import json 5 | 6 | class BaseRecorder(): 7 | def __init__(self, name='No name', dtype=None, device=torch.device('cpu')): 8 | self.name = name 9 | self.dtype = dtype 10 | self.device = device 11 | self._data = None 12 | self.count = None 13 | 14 | def reset(self): 15 | raise NotImplementedError 16 | 17 | def record(self): 18 | raise NotImplementedError 19 | 20 | def to_file(self): 21 | raise NotImplementedError 22 | 23 | def __str__(self): 24 | fmtstr = '{name} data: {data} blocks:{count}' 25 | return fmtstr.format( 26 | name=self.name, 27 | data=self._data, 28 | count=self.count 29 | ) 30 | 31 | @property 32 | def data(self): 33 | return self._data 34 | 35 | class TensorRecorder(BaseRecorder): 36 | def __init__(self, name='No name', dtype=None, device=torch.device('cpu')): 37 | super().__init__(name, dtype, device) 38 | self.reset() 39 | 40 | def reset(self): 41 | self._data = torch.tensor([], dtype=self.dtype, device=self.device) 42 | self.count = torch.tensor(0, dtype=torch.int, device=self.device) 43 | 44 | @torch.no_grad() 45 | def record(self, new_data): 46 | self._data = torch.cat((self._data, new_data), dim=0) 47 | self.count += 1 48 | 49 | def to_file(self, f): 50 | torch.save(self._data, f) 51 | 52 | class ArrayRecorder(BaseRecorder): 53 | def __init__(self, name='No name', dtype=None, device=torch.device('cpu')): 54 | super().__init__(name, dtype, device) 55 | self.reset() 56 | 57 | def reset(self): 58 | self._data = np.array([], dtype=self.dtype)  # np.array takes no device argument; NumPy data stays on CPU 59 | self.count = np.array(0, dtype=np.int64)  # use a NumPy dtype here, not torch.int 60 | 61 | def record(self, new_data): 62 | self._data = np.concatenate((self._data, new_data), axis=0) 63 | self.count += 1 64 | 65 | def to_file(self, f): 66 | np.save(f, self._data) 67 | 68 | class StrRecorder(BaseRecorder): 69 | def __init__(self, name='No name', dtype=None, device=torch.device('cpu')): 70 | super().__init__(name, dtype, device) 71 | self.reset() 72 | 73 | def reset(self): 74 | self._data = [] 75 | self.count = 0 76 | 77 | def record(self, new_data: list): 78 | self._data.extend(new_data) 79 | self.count += 1 80 | 81 | def to_file(self, f): 82 | with open(f, 'w') as _f: 83 | json.dump(self._data, _f, indent=2) 84 | _f.close() 85 | --------------------------------------------------------------------------------
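The utilities above are defined but not demonstrated anywhere in this listing, so the following is a minimal usage sketch (not part of the repository) showing how `EarlyStopping` (utils/earlystopping.py), `calculate_score_classification` (utils/metric.py), and `TensorRecorder` (utils/recoder.py) might be combined in a validation loop. The import paths assume the repository root is on `PYTHONPATH`; the random predictions, class count, and `patience` value are illustrative assumptions only.
```python
# Hypothetical usage sketch; values and loop structure are illustrative, not the repository's training code.
import torch

from utils.earlystopping import EarlyStopping
from utils.metric import calculate_score_classification
from utils.recoder import TensorRecorder

# Accuracy is the monitored score, so higher_is_better=True.
early_stopping = EarlyStopping(patience=5, verbose=True, higher_is_better=True)
preds_recorder = TensorRecorder(name='preds', dtype=torch.long)
labels_recorder = TensorRecorder(name='labels', dtype=torch.long)

for epoch in range(10):
    preds_recorder.reset()
    labels_recorder.reset()

    # Stand-in for a real validation loop: record per-batch predictions and labels.
    for _ in range(4):
        preds_recorder.record(torch.randint(0, 4, (8,)))
        labels_recorder.record(torch.randint(0, 4, (8,)))

    wa, ua, f1, precision, confuse_matrix = calculate_score_classification(
        preds_recorder.data.numpy(), labels_recorder.data.numpy()
    )
    print(f'epoch {epoch}: WA {wa:.3f}, UA {ua:.3f}, F1 {f1:.3f}')

    # Stop once the monitored score has not improved for `patience` consecutive epochs.
    early_stopping(wa, model=None)
    if early_stopping.early_stop:
        print('early stopping triggered')
        break
```
Because `EarlyStopping.__call__` negates the score when `higher_is_better=False`, the same object can monitor a validation loss instead of an accuracy without any other change.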