├── README.md ├── conf └── decode_engine_V3.yaml ├── grpc_EngineWrapper.py ├── grpc_STTEngine.py ├── grpc_VADWrapper.py ├── grpc_WenetEngine.py ├── grpc_client ├── BAC009S0764W0136.wav ├── build.sh └── gen_py │ ├── stt_pb2.py │ └── stt_pb2_grpc.py ├── grpc_services ├── build.sh ├── cleanup.sh └── gen_py │ ├── stt_pb2.py │ └── stt_pb2_grpc.py ├── run.sh ├── server.py ├── stt_pb2.py ├── stt_pb2_grpc.py └── wenet ├── bin ├── alignment.py ├── average_model.py ├── export_jit.py ├── recognize.py ├── recognize_deprecated.py ├── recognize_wav.py ├── recognize_wav_streaming.py ├── train.py └── train_deprecated.py ├── dataset ├── __pycache__ │ ├── dataset.cpython-38.pyc │ ├── dataset_deprecated.cpython-38.pyc │ ├── kaldi_io.cpython-38.pyc │ ├── processor.cpython-38.pyc │ └── wav_distortion.cpython-38.pyc ├── dataset.py ├── dataset_deprecated.py ├── kaldi_io.py ├── processor.py └── wav_distortion.py ├── transformer ├── __pycache__ │ ├── asr_model.cpython-38.pyc │ ├── asr_model_streaming.cpython-38.pyc │ ├── attention.cpython-38.pyc │ ├── cmvn.cpython-38.pyc │ ├── convolution.cpython-38.pyc │ ├── ctc.cpython-38.pyc │ ├── decoder.cpython-38.pyc │ ├── decoder_layer.cpython-38.pyc │ ├── decoder_streaming.cpython-38.pyc │ ├── embedding.cpython-38.pyc │ ├── encoder.cpython-38.pyc │ ├── encoder_layer.cpython-38.pyc │ ├── encoder_streaming.cpython-38.pyc │ ├── label_smoothing_loss.cpython-38.pyc │ ├── positionwise_feed_forward.cpython-38.pyc │ ├── subsampling.cpython-38.pyc │ └── swish.cpython-38.pyc ├── asr_model.py ├── asr_model_streaming.py ├── attention.py ├── cmvn.py ├── convolution.py ├── ctc.py ├── decoder.py ├── decoder_layer.py ├── decoder_streaming.py ├── embedding.py ├── encoder.py ├── encoder_layer.py ├── encoder_streaming.py ├── label_smoothing_loss.py ├── positionwise_feed_forward.py ├── subsampling.py └── swish.py └── utils ├── __pycache__ ├── checkpoint.cpython-38.pyc ├── cmvn.cpython-38.pyc ├── common.cpython-38.pyc ├── config.cpython-38.pyc ├── executor.cpython-38.pyc ├── file_utils.cpython-38.pyc ├── mask.cpython-38.pyc └── scheduler.cpython-38.pyc ├── checkpoint.py ├── cmvn.py ├── common.py ├── config.py ├── ctc_util.py ├── executor.py ├── file_utils.py ├── mask.py └── scheduler.py
/README.md:
--------------------------------------------------------------------------------
1 | # ASR_python_deploy
2 | This project is a Python-based deployment of a speech recognition (ASR) service. The experiments use WeNet's open-source model; in fact, any ASR model that supports utterance-level decoding can be deployed as its own speech recognition service by following this framework.
3 | For a detailed written introduction, see the article on Zhihu: https://zhuanlan.zhihu.com/p/467364921
4 | 
5 | The deployment scheme is our own design. The decoding model can directly use WeNet's pretrained model (trained on the WenetSpeech dataset); see the link below.
6 | https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/wenetspeech/20211025_conformer_exp.tar.gz
7 | After extracting it, place the four files final.pt, global_cmvn, train.yaml and words.txt under the conf/ directory.
8 | 
9 | Everything else about the runtime environment is simply WeNet's training environment; for model training and testing, please refer to:
10 | https://github.com/wenet-e2e/wenet
11 | 
--------------------------------------------------------------------------------
/conf/decode_engine_V3.yaml:
--------------------------------------------------------------------------------
1 | accum_grad: 16
2 | cmvn_file: conf/20211025_conformer_exp/global_cmvn
3 | data_conf:
4 | batch_conf:
5 | batch_size: 1
6 | batch_type: static
7 | fbank_conf:
8 | dither: 0.0
9 | frame_length: 25
10 | frame_shift: 10
11 | num_mel_bins: 80
12 | filter_conf:
13 | max_length: 40960
14 | min_length: 0
15 | token_max_length: 200
16 | token_min_length: 1
17 | resample_conf:
18 | resample_rate: 16000
19 | shuffle: False
20 | shuffle_conf:
21 | shuffle_size: 1500
22 | sort: False
23 | sort_conf:
24 | sort_size: 1000
25 | spec_aug: true
26 | spec_aug_conf:
27 | max_f: 10
28 | max_t: 50
29 | num_f_mask: 2
30 | num_t_mask: 2
31 | speed_perturb: False
32 | decoder: transformer
33 | decoder_conf:
34 | attention_heads: 8
35 | dropout_rate: 0.1
36 | linear_units: 2048
37 | num_blocks: 6
38 | positional_dropout_rate: 0.1
39 | self_attention_dropout_rate: 0.0
40 | src_attention_dropout_rate: 0.0
41 | encoder: conformer
42 | encoder_conf:
43 | activation_type: swish
44 | attention_dropout_rate: 0.0
45 | attention_heads: 8
46 | cnn_module_kernel: 15
47 | cnn_module_norm: layer_norm
48 | dropout_rate: 0.1
49 | input_layer: conv2d
50 | linear_units: 2048
51 | normalize_before: true
52 | num_blocks: 12
53 | output_size: 512
54 | pos_enc_layer_type: rel_pos
55 | positional_dropout_rate: 0.1
56 | selfattention_layer_type: rel_selfattn
57 | use_cnn_module: true
58 | #use_dynamic_chunk: true
59 | #use_dynamic_left_chunk: false
60 | grad_clip: 5
61 | input_dim: 80
62 | is_json_cmvn: true
63 | log_interval: 100
64 | #max_epoch: 36
65 | model_conf:
66 | ctc_weight: 0.3
67 | length_normalized_loss: false
68 | lsm_weight: 0.1
69 | optim: adam
70 | optim_conf:
71 | lr: 0.001
72 | output_dim: 5537
73 | scheduler: warmuplr
74 | scheduler_conf:
75 | warmup_steps: 5000
76 | 
77 | engine_sample_rate_hertz: 16000
78 | engine_max_decoders: 1
79 | engine_max_inactivity_secs: 3
80 | 
81 | model_path: conf/20211025_conformer_exp/final.pt
82 | dict_path: conf/20211025_conformer_exp/words.txt
83 | 
84 | beam_size: 10
85 | mode: ctc_greedy_search
86 | decoding_chunk_size: 11
87 | num_decoding_left_chunks: -1
88 | #override_config:
89 | #penalty:
90 | gpu: 1
91 | audio_save_path: # where do you want the audio received by the server to be saved?
92 | 
93 | 
94 | 
--------------------------------------------------------------------------------
/grpc_EngineWrapper.py:
--------------------------------------------------------------------------------
1 | # Wrapper class that adds pre/post-processing to STT engine
2 | 
3 | import logging
4 | from attrdict import AttrDict
5 | import os
6 | from datetime import datetime
7 | import wave
8 | from text2digits import text2digits
9 | import configargparse
10 | 
11 | class EngineWrapper(object):
12 |     def __init__(self, config, engine):
13 |         parser = configargparse.ArgumentParser(description="STT GRPC engine wrapper.",
14 |                                                default_config_files=["config"])
15 |         parser.add_argument('--savewav', default='',
16 |                             help="Save .wav files of utterances to given directory.")
17 |         ARGS, _ = parser.parse_known_args()
18 |         args = vars(ARGS)
19 |         self.config = AttrDict({**args, **config})
20 |         if self.config.savewav: os.makedirs(self.config.savewav, exist_ok=True)
21 |         self.engine = engine
22 |         self.logger = logging.getLogger('wrapper.save_post')
23 | 
24 |     def post_fun(self, result):
25 |         if isinstance(result, dict):
26 |             text = result.get('transcript', '')
27 |             result['transcript'] = text2digits.Text2Digits().convert(text)
28 |             return(result)
29 |         else:
30 |             return(text2digits.Text2Digits().convert(result))
31 | 
32 |     def decode_audio(self, audio):
33 |         if self.config.savewav:
34 |             self.save_wave(audio)
35 |         text = self.engine.decode_audio(audio)
36 |         return self.post_fun(text)
37 | 
38 |     def get_stream(self, result_queue):
39 |         # FIXME: Should include some guard against very long audio!
40 |         if self.config.savewav:
41 |             self.audio = bytearray()
42 |         # FIXME: Poor workaround for post_fun?
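        # The engine below only ever sees result_queue and calls its put() method,
        # so the wrapper swaps put() for put_wrapper(): every result the engine
        # emits is first passed through post_fun() (text2digits normalization)
        # before the consumer reads it from the queue.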
43 | pre_wrapper = result_queue.put 44 | def put_wrapper(item, *args, **kwargs): 45 | pre_wrapper(self.post_fun(item), *args, **kwargs) 46 | result_queue.put = put_wrapper 47 | return self.engine.get_stream(result_queue) 48 | 49 | def feed_audio_data(self, stream, audio): 50 | # FIXME: Will different stream have an issue? 51 | if self.config.savewav: 52 | self.audio.extend(audio) 53 | return self.engine.feed_audio_data(stream, audio) 54 | 55 | def finish_stream(self, stream): 56 | if self.config.savewav: 57 | self.save_wave(self.audio) 58 | return self.engine.finish_stream(stream) 59 | 60 | def __getattr__(self, attr): 61 | return getattr(self.engine, attr) 62 | -------------------------------------------------------------------------------- /grpc_STTEngine.py: -------------------------------------------------------------------------------- 1 | class STTEngine(object): 2 | def __init__(self, model_folder): 3 | raise NotImplementedError 4 | 5 | def decode_audio(self, audio): 6 | raise NotImplementedError 7 | 8 | def get_stream(self, result_queue): 9 | raise NotImplementedError 10 | 11 | def feed_audio_data(self, stream, audio): 12 | raise NotImplementedError 13 | 14 | def finish_stream(self, stream): 15 | raise NotImplementedError 16 | 17 | def check_compatibility(self, config): 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /grpc_VADWrapper.py: -------------------------------------------------------------------------------- 1 | # Wrapper class that adds vad pre-processing to STT engine 2 | 3 | import logging 4 | from attrdict import AttrDict 5 | import webrtcvad 6 | import collections 7 | 8 | class VADAudio(): 9 | SAMPLE_WIDTH = 2 # Number of bytes for each sample 10 | CHANNELS = 1 11 | 12 | #def __init__(self, aggressiveness, rate, frame_duration_ms, padding_ms=300, padding_ratio=0.75): 13 | def __init__(self, aggressiveness, rate, frame_duration_ms, padding_ms=200, padding_ratio=0.4): 14 | """Initializes VAD with given aggressivenes and sets up internal queues""" 15 | self.vad = webrtcvad.Vad(aggressiveness) 16 | self.rate = rate 17 | self.frame_duration_ms = frame_duration_ms 18 | self._frame_length = int( rate * (frame_duration_ms/1000.0) * self.SAMPLE_WIDTH ) 19 | self._buffer_queue = collections.deque() 20 | self.ring_buffer = collections.deque(maxlen = padding_ms // frame_duration_ms) 21 | self._ratio = padding_ratio 22 | self.triggered = False 23 | 24 | def add_audio(self, audio): 25 | """Adds new audio to internal queue""" 26 | for x in audio: 27 | self._buffer_queue.append(x) 28 | 29 | def frame_generator(self): 30 | """Generator that yields audio frames of frame_duration_ms""" 31 | while len(self._buffer_queue) > self._frame_length: 32 | frame = bytearray() 33 | for _ in range(self._frame_length): 34 | frame.append(self._buffer_queue.popleft()) 35 | yield bytes(frame) 36 | 37 | def vad_collector(self): 38 | """Generator that yields series of consecutive audio frames comprising each utterence, separated by yielding a single None. 39 | Determines voice activity by ratio of frames in padding_ms. Uses a buffer to include padding_ms prior to being triggered. 40 | Example: (frame, ..., frame, None, frame, ..., frame, None, ...) 
41 |                  |---utterance---|                 |---utterance---|
42 |         """
43 |         for frame in self.frame_generator():
44 |             is_speech = self.vad.is_speech(frame, self.rate)
45 |             if not self.triggered:
46 |                 self.ring_buffer.append((frame, is_speech))
47 |                 num_voiced = len([f for f, speech in self.ring_buffer if speech])
48 |                 if num_voiced > self._ratio * self.ring_buffer.maxlen:
49 |                     self.triggered = True
50 |                     for f, s in self.ring_buffer:
51 |                         yield f
52 |                     self.ring_buffer.clear()
53 |             else:
54 |                 yield frame
55 |                 self.ring_buffer.append((frame, is_speech))
56 |                 num_unvoiced = len([f for f, speech in self.ring_buffer if not speech])
57 |                 if num_unvoiced > self._ratio * self.ring_buffer.maxlen:
58 |                     self.triggered = False
59 |                     yield None
60 |                     self.ring_buffer.clear()
61 | 
62 | class VADWrapper(object):
63 |     def __init__(self, config, engine):
64 |         """ Initializes the object.
65 | 
66 |         Args:
67 |             config (dict): Key, value pair of configuration values.
68 | 
69 |         Returns:
70 |             STTEngine object with pre-processing decorators.
71 |         """
72 |         if 'vad_aggressiveness' not in config:
73 |             raise ValueError('vad_aggressiveness not provided')
74 |         self.vad = VADAudio(config['vad_aggressiveness'], config['sample_rate_hertz'], 20)
75 |         self.engine = engine
76 |         self.logger = logging.getLogger('wrapper.vad')
77 | 
78 |     def decode_audio(self, audio):
79 |         # FIXME: Assert single inference via EngineWrapper.
80 |         self.vad.add_audio(audio)
81 |         audio = b''.join(f for f in self.vad.vad_collector() if f is not None)
82 |         #print(len(audio))
83 |         return self.engine.decode_audio(audio)
84 | 
85 |     def get_stream(self, result_queue):
86 |         self._stream = self.engine.get_stream(result_queue)
87 |         return {'stream': self._stream, 'result_queue':result_queue}
88 | 
89 |     def feed_audio_data(self, stream, audio):
90 |         # FIXME: Assert single inference via EngineWrapper.
91 |         self.vad.add_audio(audio)
92 |         temp_audio = b''
93 |         #print('%'*23, stream['stream']['current_audio'])
94 |         for frame in self.vad.vad_collector():
95 |             if frame is None:
96 |                 # VAD detected end of speech. Finish and start a new stream
97 |                 if len(temp_audio)> 0: #16000*0.1: # decode only if longer than 0.2 s
98 |                     self.engine.feed_audio_streaming(stream['stream'], temp_audio)
99 |                     self.engine.decode_audio_streaming(stream['stream'])
100 |                     temp_audio = b''
101 |             else:
102 |                 temp_audio += frame
103 |         if temp_audio != b'':
104 |             self.engine.feed_audio_streaming(stream['stream'], temp_audio)
105 | 
106 | 
107 |     def finish_stream(self, stream):
108 |         """ Finishes decoding and destroys the stream.
109 |         """
110 |         return self.engine.finish_stream(stream['stream'])
111 | 
112 |     def __getattr__(self, attr):
113 |         """ Passes all non-implemented methods to the engine.
114 |         """
115 |         return getattr(self.engine, attr)
116 | 
--------------------------------------------------------------------------------
/grpc_WenetEngine.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import argparse
3 | import os
4 | import numpy as np
5 | import tempfile
6 | import csv
7 | import math
8 | import yaml
9 | import time
10 | import copy
11 | from grpc_STTEngine import STTEngine
12 | from collections import deque
13 | from wenet.utils.file_utils import read_symbol_table
14 | from wenet.transformer.asr_model_streaming import init_asr_model
15 | from wenet.utils.checkpoint import load_checkpoint
16 | import torch
17 | import torchaudio
18 | import torchaudio.compliance.kaldi as kaldi
19 | import datetime
20 | import wave
21 | 
22 | class WenetEngine(STTEngine):
23 |     DECODE_CHUNK_SIZE = 4000
24 | 
25 |     def __init__(self, model_config_path):
26 |         """ Loads and sets up the model
27 |         """
28 |         self.logger = logging.getLogger('engine.wenet!')
29 |         with open(model_config_path, 'r') as fin:
30 |             self.configs = yaml.load(fin, Loader=yaml.FullLoader)
31 |         logging.basicConfig(level=logging.DEBUG,format='%(asctime)s %(levelname)s %(message)s')
32 |         symbol_table = read_symbol_table(self.configs['dict_path'])
33 |         os.environ['CUDA_VISIBLE_DEVICES'] = str(self.configs['gpu'])
34 |         decode_conf = copy.deepcopy(self.configs['data_conf'])
35 |         decode_conf['filter_conf']['max_length'] = 102400
36 |         decode_conf['filter_conf']['min_length'] = 0
37 |         decode_conf['filter_conf']['token_max_length'] = 102400
38 |         decode_conf['filter_conf']['token_min_length'] = 0
39 |         use_cuda = self.configs['gpu'] >= 0 and torch.cuda.is_available()
40 |         #self.device = torch.device('cuda' if use_cuda else 'cpu')
41 |         self.device = torch.device('cpu')
42 |         # convert ids to symbols
43 |         self.num2sym_dict = {}
44 |         with open(self.configs['dict_path'], 'r') as fin:
45 |             for line in fin:
46 |                 arr = line.strip().split()
47 |                 assert len(arr) == 2
48 |                 self.num2sym_dict[int(arr[1])] = arr[0]
49 |         self.eos = len(self.num2sym_dict) - 1
50 | 
51 |         self.models = deque(maxlen=self.configs['engine_max_decoders'])
52 |         asr = init_asr_model(self.configs)
53 |         load_checkpoint(asr, self.configs['model_path'])
54 |         asr = asr.to(self.device)
55 |         asr.eval()
56 |         for i in range(self.configs['engine_max_decoders']):
57 |             self.models.append(asr)
58 |             self.logger.info('Model {} loaded.'.format(id(asr)))
59 |         self.streams = []
60 | 
61 |     def _get_model(self):
62 |         """ Retrieves a free asr.
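        Pops an idle model from the internal deque; if none is free, the model of
        any stream that has been inactive for longer than engine_max_inactivity_secs
        is reclaimed, otherwise MemoryError is raised.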
63 |         """
64 |         if len(self.models):
65 |             model = self.models.pop()
66 |             self.logger.info('Model {} engaged.'.format(id(model)))
67 |             return model
68 |         else:
69 |             for ix, s in enumerate(self.streams):
70 |                 if (time.time() - s['last_activity']) > self.configs['engine_max_inactivity_secs']:
71 |                     model = s['model']
72 |                     self.streams.pop(ix)
73 |                     self.logger.info('Model {} force freed.'.format(id(model)))
74 |                     return model
75 |             raise MemoryError
76 | 
77 |     def _free_model(self, model):
78 |         self.models.append(model)
79 |         self.logger.info('Model {} freed.'.format(id(model)))
80 | 
81 |     def _num2sym(self, hyps):
82 |         content = ''
83 |         for w in hyps:
84 |             if w == self.eos:
85 |                 break
86 |             content += self.num2sym_dict[w]
87 |         return content
88 | 
89 |     def _feature_extraction(self, waveform):
90 |         num_mel_bins = self.configs['data_conf']['fbank_conf']['num_mel_bins'] # 80
91 |         frame_length = self.configs['data_conf']['fbank_conf']['frame_length'] # 25
92 |         frame_shift = self.configs['data_conf']['fbank_conf']['frame_shift'] # 10
93 |         dither = self.configs['data_conf']['fbank_conf']['dither'] # 0.0
94 |         feat = kaldi.fbank(waveform,
95 |                            num_mel_bins=num_mel_bins,
96 |                            frame_length=frame_length,
97 |                            frame_shift=frame_shift,
98 |                            dither=dither,
99 |                            energy_floor=0.0,
100 |                            sample_frequency=self.configs['engine_sample_rate_hertz'])
101 |         feat = feat.unsqueeze(0) #.to(device)
102 |         feat_length = torch.IntTensor([feat.size()[1]])
103 |         return feat, feat_length
104 | 
105 |     def decode_audio(self, audio): # single-utterance (one-shot) decoding
106 |         if len(audio)<1600: # shorter than 0.1 s, skip decoding
107 |             return ''
108 |         waveform = np.frombuffer(audio, dtype=np.int16)
109 |         waveform = torch.from_numpy(waveform).float().unsqueeze(0)
110 |         waveform = waveform.to(self.device)
111 |         waveform_feat, feat_length = self._feature_extraction(waveform)
112 |         model = self._get_model()
113 |         with torch.no_grad():
114 |             hyps, scores = model.recognize(waveform_feat,
115 |                                            feat_length,
116 |                                            beam_size=self.configs['beam_size'],
117 |                                            decoding_chunk_size=-1,
118 |                                            num_decoding_left_chunks=self.configs['num_decoding_left_chunks'],
119 |                                            simulate_streaming=True)
120 |         hyps = [hyp.tolist() for hyp in hyps[0]]
121 |         result = self._num2sym(hyps)
122 |         print(result)
123 |         self._free_model(model)
124 |         if len(audio)>WenetEngine.DECODE_CHUNK_SIZE:
125 |             self.save_wave(audio, result)
126 |         return result
127 | 
128 | 
129 |     def get_stream(self, result_queue):
130 |         """ Establishes stream to model.
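        Returns a dict that bundles the engaged model with per-stream state: the
        current_audio buffer, chunk_size, total_audio_len, the last_activity
        timestamp, the intermediate transcript and the caller's result_queue.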
131 | """ 132 | asr = self._get_model() 133 | stream = {'model':asr, 134 | 'current_audio':bytes(), 135 | 'chunk_size':0, 136 | 'total_audio_len':0, 137 | 'last_activity':time.time(), 138 | 'intermediate':'', 139 | 'result_queue':result_queue} 140 | self.streams.append(stream) 141 | self.logger.info('Stream established to Model {}.'.format(id(asr))) 142 | return stream 143 | 144 | 145 | def feed_audio_streaming(self, stream, audio): 146 | stream['last_activity'] = time.time() 147 | stream['current_audio'] += audio 148 | stream['chunk_size'] += len(audio) 149 | 150 | def decode_audio_streaming(self, stream): 151 | if stream['chunk_size'] >= WenetEngine.DECODE_CHUNK_SIZE: 152 | waveform = np.frombuffer(stream['current_audio'], dtype=np.int16) 153 | waveform = torch.from_numpy(waveform).float().unsqueeze(0) 154 | waveform = waveform.to(self.device) 155 | waveform_feat, feat_length = self._feature_extraction(waveform) 156 | with torch.no_grad(): 157 | hyps, scores = stream['model'].recognize(waveform_feat, 158 | feat_length, 159 | beam_size=self.configs['beam_size'], 160 | decoding_chunk_size=-1, 161 | num_decoding_left_chunks=self.configs['num_decoding_left_chunks'], 162 | simulate_streaming=True) 163 | hyps = [hyp.tolist() for hyp in hyps[0]] 164 | result = self._num2sym(hyps) 165 | print(result) 166 | self.save_wave(stream['current_audio'], result) 167 | stream['result_queue'].put(result) 168 | stream['chunk_size'] = 0 169 | stream['current_audio'] = b'' 170 | 171 | 172 | def finish_stream(self, stream): 173 | """ Finishes decoding destroying stream. 174 | """ 175 | asr = stream['model'] 176 | self.logger.info('Audio of length {} processed in stream to Model {}.'.format(stream['total_audio_len'], id(asr))) 177 | self._free_model(stream['model']) 178 | 179 | def check_compatibility(self, config): 180 | """ Checks if engine is compatible with given config. 181 | Args: 182 | config: Key, value pairs of requested features. 183 | Returns: 184 | boolean, True if engine matches config. 185 | """ 186 | if 'sample_rate_hertz' in config: 187 | return config['sample_rate_hertz'] == self.configs['engine_sample_rate_hertz'] 188 | return True 189 | 190 | 191 | def save_wave(self, audio, transcript): 192 | now_time = datetime.datetime.now() 193 | year = str(now_time.year) 194 | month = str(now_time.month) 195 | day = str(now_time.day) 196 | audio_dir = os.path.join(self.configs['audio_save_path'], year, month, day) 197 | wav_file = audio_dir+'/'+ datetime.datetime.now().strftime("%H-%M-%S-%f_") + transcript[0:20] + '_.wav' 198 | if not os.path.exists(audio_dir): 199 | os.makedirs(audio_dir) 200 | with wave.open(wav_file, 'wb') as wf: 201 | wf.setnchannels(1) 202 | wf.setsampwidth(2) 203 | wf.setframerate(self.configs['engine_sample_rate_hertz']) 204 | wf.writeframes(audio) 205 | 206 | -------------------------------------------------------------------------------- /grpc_client/BAC009S0764W0136.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/grpc_client/BAC009S0764W0136.wav -------------------------------------------------------------------------------- /grpc_client/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Python 4 | # $ python -m pip install grcpio 5 | # $ python -m pip install grpcio-tools 6 | 7 | DESTDIR='gen_py' 8 | mkdir -p $DESTDIR 9 | python -m grpc_tools.protoc \ 10 | --proto_path=. 
\ 11 | --python_out=$DESTDIR \ 12 | --grpc_python_out=$DESTDIR \ 13 | ./*.proto 14 | ln -s $DESTDIR/stt_* . 15 | 16 | # Golang 17 | # Install protoc (https://github.com/google/protobuf/releases/tag/v3.4.0) 18 | # Install go get -a github.com/golang/protobuf/protoc-gen-go 19 | 20 | # DESTDIR='gen-go' 21 | # mkdir -p $DESTDIR 22 | # protoc \ 23 | # --proto_path=. \ 24 | # --go_out=plugins=grpc:$DESTDIR \ 25 | # ./*.proto 26 | -------------------------------------------------------------------------------- /grpc_client/gen_py/stt_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | import stt_pb2 as stt__pb2 6 | 7 | 8 | class STTStub(object): 9 | """Missing associated documentation comment in .proto file.""" 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 16 | """ 17 | self.Recognize = channel.unary_unary( 18 | '/STT/Recognize', 19 | request_serializer=stt__pb2.RecognizeRequest.SerializeToString, 20 | response_deserializer=stt__pb2.RecognizeResponse.FromString, 21 | ) 22 | self.StreamingRecognize = channel.stream_stream( 23 | '/STT/StreamingRecognize', 24 | request_serializer=stt__pb2.StreamingRecognizeRequest.SerializeToString, 25 | response_deserializer=stt__pb2.StreamingRecognizeResponse.FromString, 26 | ) 27 | 28 | 29 | class STTServicer(object): 30 | """Missing associated documentation comment in .proto file.""" 31 | 32 | def Recognize(self, request, context): 33 | """Missing associated documentation comment in .proto file.""" 34 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 35 | context.set_details('Method not implemented!') 36 | raise NotImplementedError('Method not implemented!') 37 | 38 | def StreamingRecognize(self, request_iterator, context): 39 | """Missing associated documentation comment in .proto file.""" 40 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 41 | context.set_details('Method not implemented!') 42 | raise NotImplementedError('Method not implemented!') 43 | 44 | 45 | def add_STTServicer_to_server(servicer, server): 46 | rpc_method_handlers = { 47 | 'Recognize': grpc.unary_unary_rpc_method_handler( 48 | servicer.Recognize, 49 | request_deserializer=stt__pb2.RecognizeRequest.FromString, 50 | response_serializer=stt__pb2.RecognizeResponse.SerializeToString, 51 | ), 52 | 'StreamingRecognize': grpc.stream_stream_rpc_method_handler( 53 | servicer.StreamingRecognize, 54 | request_deserializer=stt__pb2.StreamingRecognizeRequest.FromString, 55 | response_serializer=stt__pb2.StreamingRecognizeResponse.SerializeToString, 56 | ), 57 | } 58 | generic_handler = grpc.method_handlers_generic_handler( 59 | 'STT', rpc_method_handlers) 60 | server.add_generic_rpc_handlers((generic_handler,)) 61 | 62 | 63 | # This class is part of an EXPERIMENTAL API. 
64 | class STT(object): 65 | """Missing associated documentation comment in .proto file.""" 66 | 67 | @staticmethod 68 | def Recognize(request, 69 | target, 70 | options=(), 71 | channel_credentials=None, 72 | call_credentials=None, 73 | insecure=False, 74 | compression=None, 75 | wait_for_ready=None, 76 | timeout=None, 77 | metadata=None): 78 | return grpc.experimental.unary_unary(request, target, '/STT/Recognize', 79 | stt__pb2.RecognizeRequest.SerializeToString, 80 | stt__pb2.RecognizeResponse.FromString, 81 | options, channel_credentials, 82 | insecure, call_credentials, compression, wait_for_ready, timeout, metadata) 83 | 84 | @staticmethod 85 | def StreamingRecognize(request_iterator, 86 | target, 87 | options=(), 88 | channel_credentials=None, 89 | call_credentials=None, 90 | insecure=False, 91 | compression=None, 92 | wait_for_ready=None, 93 | timeout=None, 94 | metadata=None): 95 | return grpc.experimental.stream_stream(request_iterator, target, '/STT/StreamingRecognize', 96 | stt__pb2.StreamingRecognizeRequest.SerializeToString, 97 | stt__pb2.StreamingRecognizeResponse.FromString, 98 | options, channel_credentials, 99 | insecure, call_credentials, compression, wait_for_ready, timeout, metadata) 100 | -------------------------------------------------------------------------------- /grpc_services/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Python 4 | # $ python -m pip install grcpio 5 | # $ python -m pip install grpcio-tools 6 | 7 | DESTDIR='gen_py' 8 | mkdir -p $DESTDIR 9 | python -m grpc_tools.protoc \ 10 | --proto_path=. \ 11 | --python_out=$DESTDIR \ 12 | --grpc_python_out=$DESTDIR \ 13 | ./*.proto 14 | #ln -s $DESTDIR/stt_* . 15 | cp $DESTDIR/stt_* ../ 16 | # Golang 17 | # Install protoc (https://github.com/google/protobuf/releases/tag/v3.4.0) 18 | # Install go get -a github.com/golang/protobuf/protoc-gen-go 19 | 20 | # DESTDIR='gen-go' 21 | # mkdir -p $DESTDIR 22 | # protoc \ 23 | # --proto_path=. \ 24 | # --go_out=plugins=grpc:$DESTDIR \ 25 | # ./*.proto 26 | -------------------------------------------------------------------------------- /grpc_services/cleanup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm -r gen_* 4 | -------------------------------------------------------------------------------- /grpc_services/gen_py/stt_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | import stt_pb2 as stt__pb2 6 | 7 | 8 | class STTStub(object): 9 | """Missing associated documentation comment in .proto file.""" 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 
16 | """ 17 | self.Recognize = channel.unary_unary( 18 | '/STT/Recognize', 19 | request_serializer=stt__pb2.RecognizeRequest.SerializeToString, 20 | response_deserializer=stt__pb2.RecognizeResponse.FromString, 21 | ) 22 | self.StreamingRecognize = channel.stream_stream( 23 | '/STT/StreamingRecognize', 24 | request_serializer=stt__pb2.StreamingRecognizeRequest.SerializeToString, 25 | response_deserializer=stt__pb2.StreamingRecognizeResponse.FromString, 26 | ) 27 | 28 | 29 | class STTServicer(object): 30 | """Missing associated documentation comment in .proto file.""" 31 | 32 | def Recognize(self, request, context): 33 | """Missing associated documentation comment in .proto file.""" 34 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 35 | context.set_details('Method not implemented!') 36 | raise NotImplementedError('Method not implemented!') 37 | 38 | def StreamingRecognize(self, request_iterator, context): 39 | """Missing associated documentation comment in .proto file.""" 40 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 41 | context.set_details('Method not implemented!') 42 | raise NotImplementedError('Method not implemented!') 43 | 44 | 45 | def add_STTServicer_to_server(servicer, server): 46 | rpc_method_handlers = { 47 | 'Recognize': grpc.unary_unary_rpc_method_handler( 48 | servicer.Recognize, 49 | request_deserializer=stt__pb2.RecognizeRequest.FromString, 50 | response_serializer=stt__pb2.RecognizeResponse.SerializeToString, 51 | ), 52 | 'StreamingRecognize': grpc.stream_stream_rpc_method_handler( 53 | servicer.StreamingRecognize, 54 | request_deserializer=stt__pb2.StreamingRecognizeRequest.FromString, 55 | response_serializer=stt__pb2.StreamingRecognizeResponse.SerializeToString, 56 | ), 57 | } 58 | generic_handler = grpc.method_handlers_generic_handler( 59 | 'STT', rpc_method_handlers) 60 | server.add_generic_rpc_handlers((generic_handler,)) 61 | 62 | 63 | # This class is part of an EXPERIMENTAL API. 
64 | class STT(object): 65 | """Missing associated documentation comment in .proto file.""" 66 | 67 | @staticmethod 68 | def Recognize(request, 69 | target, 70 | options=(), 71 | channel_credentials=None, 72 | call_credentials=None, 73 | insecure=False, 74 | compression=None, 75 | wait_for_ready=None, 76 | timeout=None, 77 | metadata=None): 78 | return grpc.experimental.unary_unary(request, target, '/STT/Recognize', 79 | stt__pb2.RecognizeRequest.SerializeToString, 80 | stt__pb2.RecognizeResponse.FromString, 81 | options, channel_credentials, 82 | insecure, call_credentials, compression, wait_for_ready, timeout, metadata) 83 | 84 | @staticmethod 85 | def StreamingRecognize(request_iterator, 86 | target, 87 | options=(), 88 | channel_credentials=None, 89 | call_credentials=None, 90 | insecure=False, 91 | compression=None, 92 | wait_for_ready=None, 93 | timeout=None, 94 | metadata=None): 95 | return grpc.experimental.stream_stream(request_iterator, target, '/STT/StreamingRecognize', 96 | stt__pb2.StreamingRecognizeRequest.SerializeToString, 97 | stt__pb2.StreamingRecognizeResponse.FromString, 98 | options, channel_credentials, 99 | insecure, call_credentials, compression, wait_for_ready, timeout, metadata) 100 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #PRJPATH="$(dirname $( cd "$(dirname "$0")" ; pwd -P ))" 2 | #echo $PRJPATH 3 | #PYTHONPATH=$PRJPATH/grpc-services/gen-py python $PRJPATH/server/server.py $@ 4 | 5 | #python server.py --model_config conf/decode_engine.yaml --host 0.0.0.0 --port 1234 --vad_aggressiveness 3 6 | #python server.py --model_config conf/decode_engine_V2.yaml --host 0.0.0.0 --port 1234 --vad_aggressiveness 3 7 | python server.py --model_config conf/decode_engine_V3.yaml --host 0.0.0.0 --port 9876 --vad_aggressiveness 3 8 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from concurrent import futures 4 | import time 5 | 6 | import grpc 7 | #import grpc_services.gen_py.stt_pb2_grpc as service 8 | #import grpc_services.gen_py.stt_pb2 as messages 9 | import stt_pb2_grpc as service 10 | import stt_pb2 as messages 11 | from grpc_WenetEngine import WenetEngine as Engine 12 | from grpc_EngineWrapper import EngineWrapper 13 | from grpc_VADWrapper import VADWrapper 14 | 15 | from queue import Queue 16 | 17 | _ONE_DAY_IN_SECONDS = 60 * 60 * 24 18 | 19 | def service_decorator(fun): 20 | """ Wraps services to raise grpc.StatusCode on exception. 21 | Logs the actual exception for debugging. 22 | """ 23 | def wrapped(*args, **kwargs): 24 | try: 25 | return fun(*args, **kwargs) 26 | except MemoryError: 27 | args[0].logger.exception('Exception occurred.') 28 | context = args[2] 29 | context.set_details("Number of simultaneous requests exceeded.") 30 | context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED) 31 | except Exception as e: 32 | args[0].logger.exception('Exception occurred.') 33 | context = args[2] 34 | context.set_details("Unknown error occured.") 35 | context.set_code(grpc.StatusCode.ABORTED) 36 | return wrapped 37 | 38 | def service_decorator_gen(fun): 39 | """ Wraps services to raise grpc.StatusCode on exception. 40 | Logs the actual exception for debugging. 41 | FIXME: To be combined with service_decorator! 
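    Unlike service_decorator, the wrapped callable here is a generator, so its
    items are re-yielded one by one inside the try block; exceptions raised while
    streaming are therefore still converted into gRPC status codes.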
42 | """ 43 | def wrapped(*args, **kwargs): 44 | try: 45 | gen = fun(*args, **kwargs) 46 | while True: 47 | try: 48 | g_next = next(gen) 49 | except StopIteration: 50 | break 51 | else: 52 | yield g_next 53 | except MemoryError: 54 | args[0].logger.exception('Exception occurred.') 55 | context = args[2] 56 | context.set_details("Number of simultaneous requests exceeded.") 57 | context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED) 58 | except Exception as e: 59 | args[0].logger.exception('Exception occurred.') 60 | context = args[2] 61 | context.set_details("Unknown error occured.") 62 | context.set_code(grpc.StatusCode.ABORTED) 63 | return wrapped 64 | 65 | class STTService(service.STTServicer): 66 | def __init__(self, engine_config_path, vad_aggressiveness=None): 67 | super(STTService, self).__init__() 68 | self.logger = logging.getLogger('server') 69 | self.config = { 70 | 'vad_aggressiveness': vad_aggressiveness 71 | } 72 | self.__models = [Engine(engine_config_path)] 73 | 74 | @service_decorator 75 | def Recognize(self, request, context): 76 | engine = self.configure_engine(request.config, context) 77 | if engine is None: 78 | return messages.RecognizeResponse() 79 | text = engine.decode_audio(request.audio.content) 80 | result = messages.RecognizeResponse(transcript=text) 81 | return result 82 | 83 | @service_decorator_gen 84 | def StreamingRecognize(self, request_iterator, context): 85 | configured = False 86 | for message in request_iterator: 87 | if message.WhichOneof("streaming_request") == "config": 88 | print("streaming_config", message.WhichOneof("streaming_request")) 89 | engine = self.configure_engine(message.config, context) 90 | if engine is None: 91 | return 92 | queue = Queue() 93 | try: 94 | stream = engine.get_stream(queue) 95 | except NotImplementedError: 96 | context.set_details("Number of simultaneous requests exceeded.") 97 | context.set_code(grpc.StatusCode.RESOURCE_EXHAUSTED) 98 | configured = True 99 | 100 | elif message.WhichOneof("streaming_request") == "audio_content": 101 | t1=time.time() 102 | if not configured: 103 | context.set_details("'streaming_config' not recieved before 'streaming_request'.") 104 | context.set_code(grpc.StatusCode.INVALID_ARGUMENT) 105 | break 106 | engine.feed_audio_data(stream, message.audio_content) 107 | while queue.qsize()>0: 108 | text = queue.get() 109 | response = messages.StreamingRecognizeResponse(transcript=text) 110 | yield response 111 | if configured: 112 | engine.finish_stream(stream) 113 | text = ' '.join(list(queue.queue)) 114 | else: 115 | text = '' 116 | result = messages.StreamingRecognizeResponse(transcript=text) 117 | yield result 118 | 119 | 120 | def configure_engine(self, recognition_config, context): 121 | """Returns requested engine or None with grpc.StatusCode updated in context""" 122 | # FIXME: Add load balancer here. 
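        # recognition_config is a RecognitionConfig protobuf message; ListFields()
        # below returns only the fields populated in the request, which are copied
        # into a plain dict and then merged with the server-side self.config
        # (e.g. vad_aggressiveness), the latter taking precedence.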
123 | # Adding two dictionary, preventing overwriting of base config 124 | config = {} 125 | for field, value in recognition_config.ListFields(): 126 | config[field.name] = value 127 | config = {**(config), **self.config} 128 | for i, model in enumerate(self.__models): 129 | if model.check_compatibility(config): 130 | model = VADWrapper(config=config, engine=model) if config.get('vad_aggressiveness', None) else model 131 | self.logger.info('Configured engine {} for {}Hz.'.format(i, recognition_config.sample_rate_hertz)) 132 | return EngineWrapper(config=config, engine=model) 133 | context.set_details("Invalid RecognitionConfig.") 134 | context.set_code(grpc.StatusCode.INVALID_ARGUMENT) 135 | self.logger.error('Invalid config provided: {}'.format(config)) 136 | return None 137 | 138 | def serve(port, model_config_path, vad_aggressiveness): 139 | # FIXME: number of workers limit the max number of streams! 140 | server = grpc.server(futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix='gRPCThread')) 141 | service.add_STTServicer_to_server(STTService(model_config_path, vad_aggressiveness), server) 142 | server.add_insecure_port(port) 143 | server.start() 144 | print('Server started on ' + port) 145 | try: 146 | while True: 147 | time.sleep(_ONE_DAY_IN_SECONDS) 148 | except KeyboardInterrupt: 149 | server.stop(0) 150 | print('Server on ' + port + ' stopped') 151 | 152 | if __name__ == '__main__': 153 | import configargparse 154 | parser = configargparse.ArgumentParser(description="STT GRPC server.", 155 | default_config_files=["config"]) 156 | parser.add_argument('--model_config',required=True, help='config file path') 157 | # Server parameters 158 | parser.add_argument('--host', default='0.0.0.0', 159 | help='Host IP address for running STT engine.') 160 | parser.add_argument('--port', type=int, default=50051, 161 | help='Host port running STT engine.') 162 | parser.add_argument('--logconf', default='', 163 | help="Logging.conf file with server, engine, wrapper loggers") 164 | parser.add_argument('-v', '--vad_aggressiveness', type=int, default=None, 165 | help="Set aggressiveness of VAD: an integer between 0 and 3, 0 being the least aggressive about filtering out non-speech, 3 the most aggressive. Default: None") 166 | 167 | ARGS, _ = parser.parse_known_args() 168 | if ARGS.logconf: 169 | logging.config.fileConfig(ARGS.logconf) 170 | serve('{}:{}'.format(ARGS.host, ARGS.port), ARGS.model_config, ARGS.vad_aggressiveness) 171 | -------------------------------------------------------------------------------- /stt_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | """Client and server classes corresponding to protobuf-defined services.""" 3 | import grpc 4 | 5 | import stt_pb2 as stt__pb2 6 | 7 | 8 | class STTStub(object): 9 | """Missing associated documentation comment in .proto file.""" 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 
16 | """ 17 | self.Recognize = channel.unary_unary( 18 | '/STT/Recognize', 19 | request_serializer=stt__pb2.RecognizeRequest.SerializeToString, 20 | response_deserializer=stt__pb2.RecognizeResponse.FromString, 21 | ) 22 | self.StreamingRecognize = channel.stream_stream( 23 | '/STT/StreamingRecognize', 24 | request_serializer=stt__pb2.StreamingRecognizeRequest.SerializeToString, 25 | response_deserializer=stt__pb2.StreamingRecognizeResponse.FromString, 26 | ) 27 | 28 | 29 | class STTServicer(object): 30 | """Missing associated documentation comment in .proto file.""" 31 | 32 | def Recognize(self, request, context): 33 | """Missing associated documentation comment in .proto file.""" 34 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 35 | context.set_details('Method not implemented!') 36 | raise NotImplementedError('Method not implemented!') 37 | 38 | def StreamingRecognize(self, request_iterator, context): 39 | """Missing associated documentation comment in .proto file.""" 40 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 41 | context.set_details('Method not implemented!') 42 | raise NotImplementedError('Method not implemented!') 43 | 44 | 45 | def add_STTServicer_to_server(servicer, server): 46 | rpc_method_handlers = { 47 | 'Recognize': grpc.unary_unary_rpc_method_handler( 48 | servicer.Recognize, 49 | request_deserializer=stt__pb2.RecognizeRequest.FromString, 50 | response_serializer=stt__pb2.RecognizeResponse.SerializeToString, 51 | ), 52 | 'StreamingRecognize': grpc.stream_stream_rpc_method_handler( 53 | servicer.StreamingRecognize, 54 | request_deserializer=stt__pb2.StreamingRecognizeRequest.FromString, 55 | response_serializer=stt__pb2.StreamingRecognizeResponse.SerializeToString, 56 | ), 57 | } 58 | generic_handler = grpc.method_handlers_generic_handler( 59 | 'STT', rpc_method_handlers) 60 | server.add_generic_rpc_handlers((generic_handler,)) 61 | 62 | 63 | # This class is part of an EXPERIMENTAL API. 64 | class STT(object): 65 | """Missing associated documentation comment in .proto file.""" 66 | 67 | @staticmethod 68 | def Recognize(request, 69 | target, 70 | options=(), 71 | channel_credentials=None, 72 | call_credentials=None, 73 | insecure=False, 74 | compression=None, 75 | wait_for_ready=None, 76 | timeout=None, 77 | metadata=None): 78 | return grpc.experimental.unary_unary(request, target, '/STT/Recognize', 79 | stt__pb2.RecognizeRequest.SerializeToString, 80 | stt__pb2.RecognizeResponse.FromString, 81 | options, channel_credentials, 82 | insecure, call_credentials, compression, wait_for_ready, timeout, metadata) 83 | 84 | @staticmethod 85 | def StreamingRecognize(request_iterator, 86 | target, 87 | options=(), 88 | channel_credentials=None, 89 | call_credentials=None, 90 | insecure=False, 91 | compression=None, 92 | wait_for_ready=None, 93 | timeout=None, 94 | metadata=None): 95 | return grpc.experimental.stream_stream(request_iterator, target, '/STT/StreamingRecognize', 96 | stt__pb2.StreamingRecognizeRequest.SerializeToString, 97 | stt__pb2.StreamingRecognizeResponse.FromString, 98 | options, channel_credentials, 99 | insecure, call_credentials, compression, wait_for_ready, timeout, metadata) 100 | -------------------------------------------------------------------------------- /wenet/bin/alignment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. 
(authors: Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import copy 19 | import logging 20 | import os 21 | import sys 22 | 23 | import torch 24 | import yaml 25 | from torch.utils.data import DataLoader 26 | from textgrid import TextGrid, IntervalTier 27 | 28 | from wenet.dataset.dataset_deprecated import AudioDataset, CollateFunc 29 | from wenet.transformer.asr_model import init_asr_model 30 | from wenet.utils.checkpoint import load_checkpoint 31 | from wenet.utils.ctc_util import forced_align 32 | from wenet.utils.common import get_subsample 33 | 34 | 35 | def generator_textgrid(maxtime, lines, output): 36 | # Download Praat: https://www.fon.hum.uva.nl/praat/ 37 | interval = maxtime / (len(lines) + 1) 38 | margin = 0.0001 39 | 40 | tg = TextGrid(maxTime=maxtime) 41 | linetier = IntervalTier(name="line", maxTime=maxtime) 42 | 43 | i = 0 44 | for l in lines: 45 | s, e, w = l.split() 46 | linetier.add(minTime=float(s) + margin, maxTime=float(e), mark=w) 47 | 48 | tg.append(linetier) 49 | print("successfully generator {}".format(output)) 50 | tg.write(output) 51 | 52 | 53 | def get_frames_timestamp(alignment): 54 | # convert alignment to a praat format, which is a doing phonetics 55 | # by computer and helps analyzing alignment 56 | timestamp = [] 57 | # get frames level duration for each token 58 | start = 0 59 | end = 0 60 | while end < len(alignment): 61 | while end < len(alignment) and alignment[end] == 0: 62 | end += 1 63 | if end == len(alignment): 64 | timestamp[-1] += alignment[start:] 65 | break 66 | end += 1 67 | while end < len(alignment) and alignment[end - 1] == alignment[end]: 68 | end += 1 69 | timestamp.append(alignment[start:end]) 70 | start = end 71 | return timestamp 72 | 73 | 74 | def get_labformat(timestamp, subsample): 75 | begin = 0 76 | duration = 0 77 | labformat = [] 78 | for idx, t in enumerate(timestamp): 79 | # 25ms frame_length,10ms hop_length, 1/subsample 80 | subsample = get_subsample(configs) 81 | # time duration 82 | duration = len(t) * 0.01 * subsample 83 | if idx < len(timestamp) - 1: 84 | print("{:.2f} {:.2f} {}".format(begin, begin + duration, 85 | char_dict[t[-1]])) 86 | labformat.append("{:.2f} {:.2f} {}\n".format( 87 | begin, begin + duration, char_dict[t[-1]])) 88 | else: 89 | non_blank = 0 90 | for i in t: 91 | if i != 0: 92 | token = i 93 | break 94 | print("{:.2f} {:.2f} {}".format(begin, begin + duration, 95 | char_dict[token])) 96 | labformat.append("{:.2f} {:.2f} {}\n".format( 97 | begin, begin + duration, char_dict[token])) 98 | begin = begin + duration 99 | return labformat 100 | 101 | 102 | if __name__ == '__main__': 103 | parser = argparse.ArgumentParser( 104 | description='use ctc to generate alignment') 105 | parser.add_argument('--config', required=True, help='config file') 106 | parser.add_argument('--input_file', required=True, help='format data file') 107 | parser.add_argument('--gpu', 108 | type=int, 109 | 
default=-1, 110 | help='gpu id for this rank, -1 for cpu') 111 | parser.add_argument('--checkpoint', required=True, help='checkpoint model') 112 | parser.add_argument('--dict', required=True, help='dict file') 113 | parser.add_argument('--result_file', 114 | required=True, 115 | help='alignment result file') 116 | parser.add_argument('--batch_size', type=int, default=1, help='batch size') 117 | parser.add_argument('--gen_praat', 118 | action='store_true', 119 | help='convert alignment to a praat format') 120 | 121 | args = parser.parse_args() 122 | print(args) 123 | logging.basicConfig(level=logging.DEBUG, 124 | format='%(asctime)s %(levelname)s %(message)s') 125 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 126 | 127 | if args.batch_size > 1: 128 | logging.fatal('alignment mode must be running with batch_size == 1') 129 | sys.exit(1) 130 | 131 | with open(args.config, 'r') as fin: 132 | configs = yaml.load(fin, Loader=yaml.FullLoader) 133 | 134 | # Load dict 135 | char_dict = {} 136 | with open(args.dict, 'r') as fin: 137 | for line in fin: 138 | arr = line.strip().split() 139 | assert len(arr) == 2 140 | char_dict[int(arr[1])] = arr[0] 141 | eos = len(char_dict) - 1 142 | 143 | raw_wav = configs['raw_wav'] 144 | # Init dataset and data loader 145 | ali_collate_conf = copy.deepcopy(configs['collate_conf']) 146 | ali_collate_conf['spec_aug'] = False 147 | ali_collate_conf['spec_sub'] = False 148 | ali_collate_conf['feature_dither'] = False 149 | ali_collate_conf['speed_perturb'] = False 150 | if raw_wav: 151 | ali_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0 152 | ali_collate_func = CollateFunc(**ali_collate_conf, raw_wav=raw_wav) 153 | dataset_conf = configs.get('dataset_conf', {}) 154 | dataset_conf['batch_size'] = args.batch_size 155 | dataset_conf['batch_type'] = 'static' 156 | dataset_conf['sort'] = False 157 | ali_dataset = AudioDataset(args.input_file, 158 | **dataset_conf, 159 | raw_wav=raw_wav) 160 | ali_data_loader = DataLoader(ali_dataset, 161 | collate_fn=ali_collate_func, 162 | shuffle=False, 163 | batch_size=1, 164 | num_workers=0) 165 | 166 | # Init asr model from configs 167 | model = init_asr_model(configs) 168 | 169 | load_checkpoint(model, args.checkpoint) 170 | use_cuda = args.gpu >= 0 and torch.cuda.is_available() 171 | device = torch.device('cuda' if use_cuda else 'cpu') 172 | model = model.to(device) 173 | 174 | model.eval() 175 | with torch.no_grad(), open(args.result_file, 'w', 176 | encoding='utf-8') as fout: 177 | for batch_idx, batch in enumerate(ali_data_loader): 178 | print("#" * 80) 179 | key, feat, target, feats_length, target_length = batch 180 | print(key) 181 | 182 | feat = feat.to(device) 183 | target = target.to(device) 184 | feats_length = feats_length.to(device) 185 | target_length = target_length.to(device) 186 | # Let's assume B = batch_size and N = beam_size 187 | # 1. 
Encoder 188 | encoder_out, encoder_mask = model._forward_encoder( 189 | feat, feats_length) # (B, maxlen, encoder_dim) 190 | maxlen = encoder_out.size(1) 191 | ctc_probs = model.ctc.log_softmax( 192 | encoder_out) # (1, maxlen, vocab_size) 193 | # print(ctc_probs.size(1)) 194 | ctc_probs = ctc_probs.squeeze(0) 195 | target = target.squeeze(0) 196 | alignment = forced_align(ctc_probs, target) 197 | print(alignment) 198 | fout.write('{} {}\n'.format(key[0], alignment)) 199 | 200 | if args.gen_praat: 201 | timestamp = get_frames_timestamp(alignment) 202 | print(timestamp) 203 | subsample = get_subsample(configs) 204 | labformat = get_labformat(timestamp, subsample) 205 | 206 | lab_path = os.path.join(os.path.dirname(args.result_file), 207 | key[0] + ".lab") 208 | with open(lab_path, 'w', encoding='utf-8') as f: 209 | f.writelines(labformat) 210 | 211 | textgrid_path = os.path.join(os.path.dirname(args.result_file), 212 | key[0] + ".TextGrid") 213 | generator_textgrid(maxtime=(len(alignment) + 1) * 0.01 * 214 | subsample, 215 | lines=labformat, 216 | output=textgrid_path) 217 | -------------------------------------------------------------------------------- /wenet/bin/average_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Mobvoi Inc. All Rights Reserved. 2 | # Author: di.wu@mobvoi.com (DI WU) 3 | import os 4 | import argparse 5 | import glob 6 | 7 | import yaml 8 | import numpy as np 9 | import torch 10 | 11 | 12 | def get_args(): 13 | parser = argparse.ArgumentParser(description='average model') 14 | parser.add_argument('--dst_model', required=True, help='averaged model') 15 | parser.add_argument('--src_path', 16 | required=True, 17 | help='src model path for average') 18 | parser.add_argument('--val_best', 19 | action="store_true", 20 | help='averaged model') 21 | parser.add_argument('--num', 22 | default=5, 23 | type=int, 24 | help='nums for averaged model') 25 | parser.add_argument('--min_epoch', 26 | default=0, 27 | type=int, 28 | help='min epoch used for averaging model') 29 | parser.add_argument('--max_epoch', 30 | default=65536, 31 | type=int, 32 | help='max epoch used for averaging model') 33 | 34 | args = parser.parse_args() 35 | print(args) 36 | return args 37 | 38 | 39 | def main(): 40 | args = get_args() 41 | checkpoints = [] 42 | val_scores = [] 43 | if args.val_best: 44 | yamls = glob.glob('{}/[!train]*.yaml'.format(args.src_path)) 45 | for y in yamls: 46 | with open(y, 'r') as f: 47 | dic_yaml = yaml.load(f, Loader=yaml.FullLoader) 48 | loss = dic_yaml['cv_loss'] 49 | epoch = dic_yaml['epoch'] 50 | if epoch >= args.min_epoch and epoch <= args.max_epoch: 51 | val_scores += [[epoch, loss]] 52 | val_scores = np.array(val_scores) 53 | sort_idx = np.argsort(val_scores[:, -1]) 54 | sorted_val_scores = val_scores[sort_idx][::1] 55 | print("best val scores = " + str(sorted_val_scores[:args.num, 1])) 56 | print("selected epochs = " + 57 | str(sorted_val_scores[:args.num, 0].astype(np.int64))) 58 | path_list = [ 59 | args.src_path + '/{}.pt'.format(int(epoch)) 60 | for epoch in sorted_val_scores[:args.num, 0] 61 | ] 62 | else: 63 | path_list = glob.glob('{}/[!avg][!final]*.pt'.format(args.src_path)) 64 | path_list = sorted(path_list, key=os.path.getmtime) 65 | path_list = path_list[-args.num:] 66 | print(path_list) 67 | avg = None 68 | num = args.num 69 | assert num == len(path_list) 70 | for path in path_list: 71 | print('Processing {}'.format(path)) 72 | states = torch.load(path, map_location=torch.device('cpu')) 73 | if 
avg is None: 74 | avg = states 75 | else: 76 | for k in avg.keys(): 77 | avg[k] += states[k] 78 | # average 79 | for k in avg.keys(): 80 | if avg[k] is not None: 81 | # pytorch 1.6 use true_divide instead of /= 82 | avg[k] = torch.true_divide(avg[k], num) 83 | print('Saving to {}'.format(args.dst_model)) 84 | torch.save(avg, args.dst_model) 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /wenet/bin/export_jit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import os 19 | 20 | import torch 21 | import yaml 22 | 23 | from wenet.transformer.asr_model import init_asr_model 24 | from wenet.utils.checkpoint import load_checkpoint 25 | 26 | 27 | def get_args(): 28 | parser = argparse.ArgumentParser(description='export your script model') 29 | parser.add_argument('--config', required=True, help='config file') 30 | parser.add_argument('--checkpoint', required=True, help='checkpoint model') 31 | parser.add_argument('--output_file', required=True, help='output file') 32 | parser.add_argument('--output_quant_file', 33 | default=None, 34 | help='output quantized model file') 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | def main(): 40 | args = get_args() 41 | # No need gpu for model export 42 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 43 | 44 | with open(args.config, 'r') as fin: 45 | configs = yaml.load(fin, Loader=yaml.FullLoader) 46 | model = init_asr_model(configs) 47 | print(model) 48 | 49 | load_checkpoint(model, args.checkpoint) 50 | # Export jit torch script model 51 | 52 | script_model = torch.jit.script(model) 53 | script_model.save(args.output_file) 54 | print('Export model successfully, see {}'.format(args.output_file)) 55 | 56 | # Export quantized jit torch script model 57 | if args.output_quant_file: 58 | quantized_model = torch.quantization.quantize_dynamic( 59 | model, {torch.nn.Linear}, dtype=torch.qint8 60 | ) 61 | print(quantized_model) 62 | script_quant_model = torch.jit.script(quantized_model) 63 | script_quant_model.save(args.output_quant_file) 64 | print('Export quantized model successfully, ' 65 | 'see {}'.format(args.output_quant_file)) 66 | 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /wenet/bin/recognize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import copy 19 | import logging 20 | import os 21 | import sys 22 | 23 | import torch 24 | import yaml 25 | from torch.utils.data import DataLoader 26 | 27 | from wenet.dataset.dataset import Dataset 28 | from wenet.transformer.asr_model import init_asr_model 29 | from wenet.utils.checkpoint import load_checkpoint 30 | from wenet.utils.file_utils import read_symbol_table 31 | from wenet.utils.config import override_config 32 | 33 | def get_args(): 34 | parser = argparse.ArgumentParser(description='recognize with your model') 35 | parser.add_argument('--config', required=True, help='config file') 36 | parser.add_argument('--test_data', required=True, help='test data file') 37 | parser.add_argument('--data_type', 38 | default='raw', 39 | choices=['raw', 'shard'], 40 | help='train and cv data type') 41 | parser.add_argument('--gpu', 42 | type=int, 43 | default=-1, 44 | help='gpu id for this rank, -1 for cpu') 45 | parser.add_argument('--checkpoint', required=True, help='checkpoint model') 46 | parser.add_argument('--dict', required=True, help='dict file') 47 | parser.add_argument('--beam_size', 48 | type=int, 49 | default=10, 50 | help='beam size for search') 51 | parser.add_argument('--penalty', 52 | type=float, 53 | default=0.0, 54 | help='length penalty') 55 | parser.add_argument('--result_file', required=True, help='asr result file') 56 | parser.add_argument('--batch_size', 57 | type=int, 58 | default=16, 59 | help='asr result file') 60 | parser.add_argument('--mode', 61 | choices=[ 62 | 'attention', 'ctc_greedy_search', 63 | 'ctc_prefix_beam_search', 'attention_rescoring' 64 | ], 65 | default='attention', 66 | help='decoding mode') 67 | parser.add_argument('--ctc_weight', 68 | type=float, 69 | default=0.0, 70 | help='ctc weight for attention rescoring decode mode') 71 | parser.add_argument('--decoding_chunk_size', 72 | type=int, 73 | default=-1, 74 | help='''decoding chunk size, 75 | <0: for decoding, use full chunk. 76 | >0: for decoding, use fixed chunk size as set. 
77 | 0: used for training, it's prohibited here''') 78 | parser.add_argument('--num_decoding_left_chunks', 79 | type=int, 80 | default=-1, 81 | help='number of left chunks for decoding') 82 | parser.add_argument('--simulate_streaming', 83 | action='store_true', 84 | help='simulate streaming inference') 85 | parser.add_argument('--reverse_weight', 86 | type=float, 87 | default=0.0, 88 | help='''right to left weight for attention rescoring 89 | decode mode''') 90 | parser.add_argument('--bpe_model', 91 | default=None, 92 | type=str, 93 | help='bpe model for english part') 94 | parser.add_argument('--override_config', 95 | action='append', 96 | default=[], 97 | help="override yaml config") 98 | 99 | args = parser.parse_args() 100 | print(args) 101 | return args 102 | 103 | 104 | def main(): 105 | args = get_args() 106 | logging.basicConfig(level=logging.DEBUG, 107 | format='%(asctime)s %(levelname)s %(message)s') 108 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 109 | 110 | if args.mode in ['ctc_prefix_beam_search', 'attention_rescoring' 111 | ] and args.batch_size > 1: 112 | logging.fatal( 113 | 'decoding mode {} must be running with batch_size == 1'.format( 114 | args.mode)) 115 | sys.exit(1) 116 | 117 | with open(args.config, 'r') as fin: 118 | configs = yaml.load(fin, Loader=yaml.FullLoader) 119 | if len(args.override_config) > 0: 120 | configs = override_config(configs, args.override_config) 121 | 122 | symbol_table = read_symbol_table(args.dict) 123 | test_conf = copy.deepcopy(configs['dataset_conf']) 124 | 125 | test_conf['filter_conf']['max_length'] = 102400 126 | test_conf['filter_conf']['min_length'] = 0 127 | test_conf['filter_conf']['token_max_length'] = 102400 128 | test_conf['filter_conf']['token_min_length'] = 0 129 | test_conf['filter_conf']['max_output_input_ratio'] = 102400 130 | test_conf['filter_conf']['min_output_input_ratio'] = 0 131 | test_conf['speed_perturb'] = False 132 | test_conf['spec_aug'] = False 133 | test_conf['shuffle'] = False 134 | test_conf['sort'] = False 135 | test_conf['fbank_conf']['dither'] = 0.0 136 | test_conf['batch_conf']['batch_size'] = args.batch_size 137 | 138 | 139 | test_dataset = Dataset(args.data_type, 140 | args.test_data, 141 | symbol_table, 142 | test_conf, 143 | args.bpe_model, 144 | partition=False) 145 | 146 | test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) 147 | 148 | # Init asr model from configs 149 | model = init_asr_model(configs) 150 | 151 | # Load dict 152 | char_dict = {} 153 | with open(args.dict, 'r') as fin: 154 | for line in fin: 155 | arr = line.strip().split() 156 | assert len(arr) == 2 157 | char_dict[int(arr[1])] = arr[0] 158 | eos = len(char_dict) - 1 159 | 160 | load_checkpoint(model, args.checkpoint) 161 | use_cuda = args.gpu >= 0 and torch.cuda.is_available() 162 | device = torch.device('cuda' if use_cuda else 'cpu') 163 | model = model.to(device) 164 | 165 | model.eval() 166 | with torch.no_grad(), open(args.result_file, 'w') as fout: 167 | for batch_idx, batch in enumerate(test_data_loader): 168 | keys, feats, target, feats_lengths, target_lengths = batch 169 | feats = feats.to(device) 170 | target = target.to(device) 171 | feats_lengths = feats_lengths.to(device) 172 | target_lengths = target_lengths.to(device) 173 | if args.mode == 'attention': 174 | hyps, _ = model.recognize( 175 | feats, 176 | feats_lengths, 177 | beam_size=args.beam_size, 178 | decoding_chunk_size=args.decoding_chunk_size, 179 | num_decoding_left_chunks=args.num_decoding_left_chunks, 180 | 
simulate_streaming=args.simulate_streaming) 181 | hyps = [hyp.tolist() for hyp in hyps] 182 | elif args.mode == 'ctc_greedy_search': 183 | hyps, _ = model.ctc_greedy_search( 184 | feats, 185 | feats_lengths, 186 | decoding_chunk_size=args.decoding_chunk_size, 187 | num_decoding_left_chunks=args.num_decoding_left_chunks, 188 | simulate_streaming=args.simulate_streaming) 189 | # ctc_prefix_beam_search and attention_rescoring only return one 190 | # result in List[int], change it to List[List[int]] for compatible 191 | # with other batch decoding mode 192 | elif args.mode == 'ctc_prefix_beam_search': 193 | assert (feats.size(0) == 1) 194 | hyp, _ = model.ctc_prefix_beam_search( 195 | feats, 196 | feats_lengths, 197 | args.beam_size, 198 | decoding_chunk_size=args.decoding_chunk_size, 199 | num_decoding_left_chunks=args.num_decoding_left_chunks, 200 | simulate_streaming=args.simulate_streaming) 201 | hyps = [hyp] 202 | elif args.mode == 'attention_rescoring': 203 | assert (feats.size(0) == 1) 204 | hyp, _ = model.attention_rescoring( 205 | feats, 206 | feats_lengths, 207 | args.beam_size, 208 | decoding_chunk_size=args.decoding_chunk_size, 209 | num_decoding_left_chunks=args.num_decoding_left_chunks, 210 | ctc_weight=args.ctc_weight, 211 | simulate_streaming=args.simulate_streaming, 212 | reverse_weight=args.reverse_weight) 213 | hyps = [hyp] 214 | for i, key in enumerate(keys): 215 | content = '' 216 | for w in hyps[i]: 217 | if w == eos: 218 | break 219 | content += char_dict[w] 220 | logging.info('{} {}'.format(key, content)) 221 | fout.write('{} {}\n'.format(key, content)) 222 | 223 | 224 | if __name__ == '__main__': 225 | main() 226 | -------------------------------------------------------------------------------- /wenet/bin/recognize_deprecated.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import copy 19 | import logging 20 | import os 21 | import sys 22 | 23 | import torch 24 | import yaml 25 | from torch.utils.data import DataLoader 26 | 27 | from wenet.dataset.dataset_deprecated import AudioDataset, CollateFunc 28 | from wenet.transformer.asr_model import init_asr_model 29 | from wenet.utils.checkpoint import load_checkpoint 30 | 31 | if __name__ == '__main__': 32 | print(""" 33 | !!! This file is deprecated, and we are planning to remove it in 34 | the future, please move to the new IO !!! 
35 | """) 36 | parser = argparse.ArgumentParser(description='recognize with your model') 37 | parser.add_argument('--config', required=True, help='config file') 38 | parser.add_argument('--test_data', required=True, help='test data file') 39 | parser.add_argument('--gpu', 40 | type=int, 41 | default=-1, 42 | help='gpu id for this rank, -1 for cpu') 43 | parser.add_argument('--checkpoint', required=True, help='checkpoint model') 44 | parser.add_argument('--dict', required=True, help='dict file') 45 | parser.add_argument('--beam_size', 46 | type=int, 47 | default=10, 48 | help='beam size for search') 49 | parser.add_argument('--penalty', 50 | type=float, 51 | default=0.0, 52 | help='length penalty') 53 | parser.add_argument('--result_file', required=True, help='asr result file') 54 | parser.add_argument('--batch_size', 55 | type=int, 56 | default=16, 57 | help='asr result file') 58 | parser.add_argument('--mode', 59 | choices=[ 60 | 'attention', 'ctc_greedy_search', 61 | 'ctc_prefix_beam_search', 'attention_rescoring' 62 | ], 63 | default='attention', 64 | help='decoding mode') 65 | parser.add_argument('--ctc_weight', 66 | type=float, 67 | default=0.0, 68 | help='ctc weight for attention rescoring decode mode') 69 | parser.add_argument('--decoding_chunk_size', 70 | type=int, 71 | default=-1, 72 | help='''decoding chunk size, 73 | <0: for decoding, use full chunk. 74 | >0: for decoding, use fixed chunk size as set. 75 | 0: used for training, it's prohibited here''') 76 | parser.add_argument('--num_decoding_left_chunks', 77 | type=int, 78 | default=-1, 79 | help='number of left chunks for decoding') 80 | parser.add_argument('--simulate_streaming', 81 | action='store_true', 82 | help='simulate streaming inference') 83 | parser.add_argument('--reverse_weight', 84 | type=float, 85 | default=0.0, 86 | help='''right to left weight for attention rescoring 87 | decode mode''') 88 | args = parser.parse_args() 89 | print(args) 90 | logging.basicConfig(level=logging.DEBUG, 91 | format='%(asctime)s %(levelname)s %(message)s') 92 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 93 | 94 | if args.mode in ['ctc_prefix_beam_search', 'attention_rescoring' 95 | ] and args.batch_size > 1: 96 | logging.fatal( 97 | 'decoding mode {} must be running with batch_size == 1'.format( 98 | args.mode)) 99 | sys.exit(1) 100 | 101 | with open(args.config, 'r') as fin: 102 | configs = yaml.load(fin, Loader=yaml.FullLoader) 103 | 104 | raw_wav = configs['raw_wav'] 105 | # Init dataset and data loader 106 | # Init dataset and data loader 107 | test_collate_conf = copy.deepcopy(configs['collate_conf']) 108 | test_collate_conf['spec_aug'] = False 109 | test_collate_conf['spec_sub'] = False 110 | test_collate_conf['feature_dither'] = False 111 | test_collate_conf['speed_perturb'] = False 112 | if raw_wav: 113 | test_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0 114 | test_collate_conf['wav_distortion_conf']['wav_dither'] = 0.0 115 | test_collate_func = CollateFunc(**test_collate_conf, raw_wav=raw_wav) 116 | dataset_conf = configs.get('dataset_conf', {}) 117 | dataset_conf['batch_size'] = args.batch_size 118 | dataset_conf['batch_type'] = 'static' 119 | dataset_conf['sort'] = False 120 | test_dataset = AudioDataset(args.test_data, 121 | **dataset_conf, 122 | raw_wav=raw_wav) 123 | test_data_loader = DataLoader(test_dataset, 124 | collate_fn=test_collate_func, 125 | shuffle=False, 126 | batch_size=1, 127 | num_workers=0) 128 | 129 | # Init asr model from configs 130 | model = init_asr_model(configs) 131 | 132 
| # Load dict 133 | char_dict = {} 134 | with open(args.dict, 'r') as fin: 135 | for line in fin: 136 | arr = line.strip().split() 137 | assert len(arr) == 2 138 | char_dict[int(arr[1])] = arr[0] 139 | eos = len(char_dict) - 1 140 | 141 | load_checkpoint(model, args.checkpoint) 142 | use_cuda = args.gpu >= 0 and torch.cuda.is_available() 143 | device = torch.device('cuda' if use_cuda else 'cpu') 144 | model = model.to(device) 145 | 146 | model.eval() 147 | with torch.no_grad(), open(args.result_file, 'w') as fout: 148 | for batch_idx, batch in enumerate(test_data_loader): 149 | keys, feats, target, feats_lengths, target_lengths = batch 150 | feats = feats.to(device) 151 | target = target.to(device) 152 | feats_lengths = feats_lengths.to(device) 153 | target_lengths = target_lengths.to(device) 154 | if args.mode == 'attention': 155 | hyps, _ = model.recognize( 156 | feats, 157 | feats_lengths, 158 | beam_size=args.beam_size, 159 | decoding_chunk_size=args.decoding_chunk_size, 160 | num_decoding_left_chunks=args.num_decoding_left_chunks, 161 | simulate_streaming=args.simulate_streaming) 162 | hyps = [hyp.tolist() for hyp in hyps] 163 | elif args.mode == 'ctc_greedy_search': 164 | hyps, _ = model.ctc_greedy_search( 165 | feats, 166 | feats_lengths, 167 | decoding_chunk_size=args.decoding_chunk_size, 168 | num_decoding_left_chunks=args.num_decoding_left_chunks, 169 | simulate_streaming=args.simulate_streaming) 170 | # ctc_prefix_beam_search and attention_rescoring only return one 171 | # result in List[int], change it to List[List[int]] for compatible 172 | # with other batch decoding mode 173 | elif args.mode == 'ctc_prefix_beam_search': 174 | assert (feats.size(0) == 1) 175 | hyp, _ = model.ctc_prefix_beam_search( 176 | feats, 177 | feats_lengths, 178 | args.beam_size, 179 | decoding_chunk_size=args.decoding_chunk_size, 180 | num_decoding_left_chunks=args.num_decoding_left_chunks, 181 | simulate_streaming=args.simulate_streaming) 182 | hyps = [hyp] 183 | elif args.mode == 'attention_rescoring': 184 | assert (feats.size(0) == 1) 185 | hyp, _ = model.attention_rescoring( 186 | feats, 187 | feats_lengths, 188 | args.beam_size, 189 | decoding_chunk_size=args.decoding_chunk_size, 190 | num_decoding_left_chunks=args.num_decoding_left_chunks, 191 | ctc_weight=args.ctc_weight, 192 | simulate_streaming=args.simulate_streaming, 193 | reverse_weight=args.reverse_weight) 194 | hyps = [hyp] 195 | for i, key in enumerate(keys): 196 | content = '' 197 | for w in hyps[i]: 198 | if w == eos: 199 | break 200 | content += char_dict[w] 201 | logging.info('{} {}'.format(key, content)) 202 | fout.write('{} {}\n'.format(key, content)) 203 | -------------------------------------------------------------------------------- /wenet/bin/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
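# Example invocation (a minimal sketch; every path below is a placeholder
# rather than a file shipped with this repo):
#
#   python wenet/bin/train.py \
#       --config conf/train.yaml \
#       --data_type raw \
#       --train_data data/train.list \
#       --cv_data data/cv.list \
#       --symbol_table conf/words.txt \
#       --model_dir exp/conformer \
#       --gpu 0
#
# All flags above are declared in get_args() below; multi-GPU runs add the
# --ddp.* options (rank, world_size, dist_backend, init_method).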
14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import copy 19 | import logging 20 | import os 21 | 22 | import torch 23 | import torch.distributed as dist 24 | import torch.optim as optim 25 | import yaml 26 | from tensorboardX import SummaryWriter 27 | from torch.utils.data import DataLoader 28 | 29 | from wenet.dataset.dataset import Dataset 30 | from wenet.transformer.asr_model import init_asr_model 31 | from wenet.utils.checkpoint import load_checkpoint, save_checkpoint 32 | from wenet.utils.executor import Executor 33 | from wenet.utils.file_utils import read_symbol_table 34 | from wenet.utils.scheduler import WarmupLR 35 | from wenet.utils.config import override_config 36 | 37 | def get_args(): 38 | parser = argparse.ArgumentParser(description='training your network') 39 | parser.add_argument('--config', required=True, help='config file') 40 | parser.add_argument('--data_type', 41 | default='raw', 42 | choices=['raw', 'shard'], 43 | help='train and cv data type') 44 | parser.add_argument('--train_data', required=True, help='train data file') 45 | parser.add_argument('--cv_data', required=True, help='cv data file') 46 | parser.add_argument('--gpu', 47 | type=int, 48 | default=-1, 49 | help='gpu id for this local rank, -1 for cpu') 50 | parser.add_argument('--model_dir', required=True, help='save model dir') 51 | parser.add_argument('--checkpoint', help='checkpoint model') 52 | parser.add_argument('--tensorboard_dir', 53 | default='tensorboard', 54 | help='tensorboard log dir') 55 | parser.add_argument('--ddp.rank', 56 | dest='rank', 57 | default=0, 58 | type=int, 59 | help='global rank for distributed training') 60 | parser.add_argument('--ddp.world_size', 61 | dest='world_size', 62 | default=-1, 63 | type=int, 64 | help='''number of total processes/gpus for 65 | distributed training''') 66 | parser.add_argument('--ddp.dist_backend', 67 | dest='dist_backend', 68 | default='nccl', 69 | choices=['nccl', 'gloo'], 70 | help='distributed backend') 71 | parser.add_argument('--ddp.init_method', 72 | dest='init_method', 73 | default=None, 74 | help='ddp init method') 75 | parser.add_argument('--num_workers', 76 | default=0, 77 | type=int, 78 | help='num of subprocess workers for reading') 79 | parser.add_argument('--pin_memory', 80 | action='store_true', 81 | default=False, 82 | help='Use pinned memory buffers used for reading') 83 | parser.add_argument('--use_amp', 84 | action='store_true', 85 | default=False, 86 | help='Use automatic mixed precision training') 87 | parser.add_argument('--cmvn', default=None, help='global cmvn file') 88 | parser.add_argument('--symbol_table', 89 | required=True, 90 | help='model unit symbol table for training') 91 | parser.add_argument('--prefetch', 92 | default=100, 93 | type=int, 94 | help='prefetch number') 95 | parser.add_argument('--bpe_model', 96 | default=None, 97 | type=str, 98 | help='bpe model for english part') 99 | parser.add_argument('--override_config', 100 | action='append', 101 | default=[], 102 | help="override yaml config") 103 | 104 | args = parser.parse_args() 105 | return args 106 | 107 | 108 | def main(): 109 | args = get_args() 110 | logging.basicConfig(level=logging.DEBUG, 111 | format='%(asctime)s %(levelname)s %(message)s') 112 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 113 | 114 | # Set random seed 115 | torch.manual_seed(777) 116 | with open(args.config, 'r') as fin: 117 | configs = yaml.load(fin, Loader=yaml.FullLoader) 118 | if len(args.override_config) > 0: 119 | configs = 
override_config(configs, args.override_config) 120 | 121 | distributed = args.world_size > 1 122 | if distributed: 123 | logging.info('training on multiple gpus, this gpu {}'.format(args.gpu)) 124 | dist.init_process_group(args.dist_backend, 125 | init_method=args.init_method, 126 | world_size=args.world_size, 127 | rank=args.rank) 128 | 129 | symbol_table = read_symbol_table(args.symbol_table) 130 | 131 | train_conf = configs['dataset_conf'] 132 | cv_conf = copy.deepcopy(train_conf) 133 | cv_conf['speed_perturb'] = False 134 | cv_conf['spec_aug'] = False 135 | 136 | train_dataset = Dataset(args.data_type, args.train_data, symbol_table, 137 | train_conf, args.bpe_model, partition=True) 138 | cv_dataset = Dataset(args.data_type, 139 | args.cv_data, 140 | symbol_table, 141 | cv_conf, 142 | args.bpe_model, 143 | partition=False) 144 | 145 | train_data_loader = DataLoader(train_dataset, 146 | batch_size=None, 147 | pin_memory=args.pin_memory, 148 | num_workers=args.num_workers, 149 | prefetch_factor=args.prefetch) 150 | cv_data_loader = DataLoader(cv_dataset, 151 | batch_size=None, 152 | pin_memory=args.pin_memory, 153 | num_workers=args.num_workers, 154 | prefetch_factor=args.prefetch) 155 | 156 | input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins'] 157 | vocab_size = len(symbol_table) 158 | 159 | # Save configs to model_dir/train.yaml for inference and export 160 | configs['input_dim'] = input_dim 161 | configs['output_dim'] = vocab_size 162 | configs['cmvn_file'] = args.cmvn 163 | configs['is_json_cmvn'] = True 164 | if args.rank == 0: 165 | saved_config_path = os.path.join(args.model_dir, 'train.yaml') 166 | with open(saved_config_path, 'w') as fout: 167 | data = yaml.dump(configs) 168 | fout.write(data) 169 | 170 | # Init asr model from configs 171 | model = init_asr_model(configs) 172 | print(model) 173 | num_params = sum(p.numel() for p in model.parameters()) 174 | print('the number of model params: {}'.format(num_params)) 175 | 176 | # !!!IMPORTANT!!! 
177 | # Try to export the model by script, if fails, we should refine 178 | # the code to satisfy the script export requirements 179 | if args.rank == 0: 180 | script_model = torch.jit.script(model) 181 | script_model.save(os.path.join(args.model_dir, 'init.zip')) 182 | executor = Executor() 183 | # If specify checkpoint, load some info from checkpoint 184 | if args.checkpoint is not None: 185 | infos = load_checkpoint(model, args.checkpoint) 186 | else: 187 | infos = {} 188 | start_epoch = infos.get('epoch', -1) + 1 189 | cv_loss = infos.get('cv_loss', 0.0) 190 | step = infos.get('step', -1) 191 | 192 | num_epochs = configs.get('max_epoch', 100) 193 | model_dir = args.model_dir 194 | writer = None 195 | if args.rank == 0: 196 | os.makedirs(model_dir, exist_ok=True) 197 | exp_id = os.path.basename(model_dir) 198 | writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id)) 199 | 200 | if distributed: 201 | assert (torch.cuda.is_available()) 202 | # cuda model is required for nn.parallel.DistributedDataParallel 203 | model.cuda() 204 | model = torch.nn.parallel.DistributedDataParallel( 205 | model, find_unused_parameters=True) 206 | device = torch.device("cuda") 207 | else: 208 | use_cuda = args.gpu >= 0 and torch.cuda.is_available() 209 | device = torch.device('cuda' if use_cuda else 'cpu') 210 | model = model.to(device) 211 | 212 | optimizer = optim.Adam(model.parameters(), **configs['optim_conf']) 213 | scheduler = WarmupLR(optimizer, **configs['scheduler_conf']) 214 | final_epoch = None 215 | configs['rank'] = args.rank 216 | configs['is_distributed'] = distributed 217 | configs['use_amp'] = args.use_amp 218 | if start_epoch == 0 and args.rank == 0: 219 | save_model_path = os.path.join(model_dir, 'init.pt') 220 | save_checkpoint(model, save_model_path) 221 | 222 | # Start training loop 223 | executor.step = step 224 | scheduler.set_step(step) 225 | # used for pytorch amp mixed precision training 226 | scaler = None 227 | if args.use_amp: 228 | scaler = torch.cuda.amp.GradScaler() 229 | 230 | for epoch in range(start_epoch, num_epochs): 231 | train_dataset.set_epoch(epoch) 232 | configs['epoch'] = epoch 233 | lr = optimizer.param_groups[0]['lr'] 234 | logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) 235 | executor.train(model, optimizer, scheduler, train_data_loader, device, 236 | writer, configs, scaler) 237 | total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device, 238 | configs) 239 | cv_loss = total_loss / num_seen_utts 240 | 241 | logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss)) 242 | if args.rank == 0: 243 | save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch)) 244 | save_checkpoint( 245 | model, save_model_path, { 246 | 'epoch': epoch, 247 | 'lr': lr, 248 | 'cv_loss': cv_loss, 249 | 'step': executor.step 250 | }) 251 | writer.add_scalar('epoch/cv_loss', cv_loss, epoch) 252 | writer.add_scalar('epoch/lr', lr, epoch) 253 | final_epoch = epoch 254 | 255 | if final_epoch is not None and args.rank == 0: 256 | final_model_path = os.path.join(model_dir, 'final.pt') 257 | os.symlink('{}.pt'.format(final_epoch), final_model_path) 258 | writer.close() 259 | 260 | 261 | if __name__ == '__main__': 262 | main() 263 | -------------------------------------------------------------------------------- /wenet/bin/train_deprecated.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. 
(authors: Binbin Zhang, Xiaoyu Chen) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import copy 19 | import logging 20 | import os 21 | 22 | import torch 23 | import torch.distributed as dist 24 | import torch.optim as optim 25 | import yaml 26 | from tensorboardX import SummaryWriter 27 | from torch.utils.data import DataLoader 28 | 29 | from wenet.dataset.dataset_deprecated import AudioDataset, CollateFunc 30 | from wenet.transformer.asr_model import init_asr_model 31 | from wenet.utils.checkpoint import load_checkpoint, save_checkpoint 32 | from wenet.utils.executor import Executor 33 | from wenet.utils.scheduler import WarmupLR 34 | 35 | if __name__ == '__main__': 36 | print(""" 37 | !!! This file is deprecated, and we are planning to remove it in 38 | the future, please move to the new IO !!! 39 | """) 40 | parser = argparse.ArgumentParser(description='training your network') 41 | parser.add_argument('--config', required=True, help='config file') 42 | parser.add_argument('--train_data', required=True, help='train data file') 43 | parser.add_argument('--cv_data', required=True, help='cv data file') 44 | parser.add_argument('--gpu', 45 | type=int, 46 | default=-1, 47 | help='gpu id for this local rank, -1 for cpu') 48 | parser.add_argument('--model_dir', required=True, help='save model dir') 49 | parser.add_argument('--checkpoint', help='checkpoint model') 50 | parser.add_argument('--tensorboard_dir', 51 | default='tensorboard', 52 | help='tensorboard log dir') 53 | parser.add_argument('--ddp.rank', 54 | dest='rank', 55 | default=0, 56 | type=int, 57 | help='global rank for distributed training') 58 | parser.add_argument('--ddp.world_size', 59 | dest='world_size', 60 | default=-1, 61 | type=int, 62 | help='''number of total processes/gpus for 63 | distributed training''') 64 | parser.add_argument('--ddp.dist_backend', 65 | dest='dist_backend', 66 | default='nccl', 67 | choices=['nccl', 'gloo'], 68 | help='distributed backend') 69 | parser.add_argument('--ddp.init_method', 70 | dest='init_method', 71 | default=None, 72 | help='ddp init method') 73 | parser.add_argument('--num_workers', 74 | default=0, 75 | type=int, 76 | help='num of subprocess workers for reading') 77 | parser.add_argument('--pin_memory', 78 | action='store_true', 79 | default=False, 80 | help='Use pinned memory buffers used for reading') 81 | parser.add_argument('--use_amp', 82 | action='store_true', 83 | default=False, 84 | help='Use automatic mixed precision training') 85 | parser.add_argument('--cmvn', default=None, help='global cmvn file') 86 | 87 | args = parser.parse_args() 88 | 89 | logging.basicConfig(level=logging.DEBUG, 90 | format='%(asctime)s %(levelname)s %(message)s') 91 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 92 | # Set random seed 93 | torch.manual_seed(777) 94 | print(args) 95 | with open(args.config, 'r') as fin: 96 | configs = yaml.load(fin, Loader=yaml.FullLoader) 97 | 98 
| distributed = args.world_size > 1 99 | 100 | raw_wav = configs['raw_wav'] 101 | 102 | train_collate_func = CollateFunc(**configs['collate_conf'], 103 | raw_wav=raw_wav) 104 | 105 | cv_collate_conf = copy.deepcopy(configs['collate_conf']) 106 | # no augmenation on cv set 107 | cv_collate_conf['spec_aug'] = False 108 | cv_collate_conf['spec_sub'] = False 109 | if raw_wav: 110 | cv_collate_conf['feature_dither'] = 0.0 111 | cv_collate_conf['speed_perturb'] = False 112 | cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0 113 | cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav) 114 | 115 | dataset_conf = configs.get('dataset_conf', {}) 116 | train_dataset = AudioDataset(args.train_data, 117 | **dataset_conf, 118 | raw_wav=raw_wav) 119 | cv_dataset = AudioDataset(args.cv_data, **dataset_conf, raw_wav=raw_wav) 120 | 121 | if distributed: 122 | logging.info('training on multiple gpus, this gpu {}'.format(args.gpu)) 123 | dist.init_process_group(args.dist_backend, 124 | init_method=args.init_method, 125 | world_size=args.world_size, 126 | rank=args.rank) 127 | train_sampler = torch.utils.data.distributed.DistributedSampler( 128 | train_dataset, shuffle=True) 129 | cv_sampler = torch.utils.data.distributed.DistributedSampler( 130 | cv_dataset, shuffle=False) 131 | else: 132 | train_sampler = None 133 | cv_sampler = None 134 | 135 | train_data_loader = DataLoader(train_dataset, 136 | collate_fn=train_collate_func, 137 | sampler=train_sampler, 138 | shuffle=(train_sampler is None), 139 | pin_memory=args.pin_memory, 140 | batch_size=1, 141 | num_workers=args.num_workers) 142 | cv_data_loader = DataLoader(cv_dataset, 143 | collate_fn=cv_collate_func, 144 | sampler=cv_sampler, 145 | shuffle=False, 146 | batch_size=1, 147 | pin_memory=args.pin_memory, 148 | num_workers=args.num_workers) 149 | 150 | if raw_wav: 151 | input_dim = configs['collate_conf']['feature_extraction_conf'][ 152 | 'mel_bins'] 153 | else: 154 | input_dim = train_dataset.input_dim 155 | vocab_size = train_dataset.output_dim 156 | 157 | # Save configs to model_dir/train.yaml for inference and export 158 | configs['input_dim'] = input_dim 159 | configs['output_dim'] = vocab_size 160 | configs['cmvn_file'] = args.cmvn 161 | configs['is_json_cmvn'] = raw_wav 162 | if args.rank == 0: 163 | saved_config_path = os.path.join(args.model_dir, 'train.yaml') 164 | with open(saved_config_path, 'w') as fout: 165 | data = yaml.dump(configs) 166 | fout.write(data) 167 | 168 | # Init asr model from configs 169 | model = init_asr_model(configs) 170 | print(model) 171 | num_params = sum(p.numel() for p in model.parameters()) 172 | print('the number of model params: {}'.format(num_params)) 173 | 174 | # !!!IMPORTANT!!! 
175 | # Try to export the model by script, if fails, we should refine 176 | # the code to satisfy the script export requirements 177 | if args.rank == 0: 178 | script_model = torch.jit.script(model) 179 | script_model.save(os.path.join(args.model_dir, 'init.zip')) 180 | executor = Executor() 181 | # If specify checkpoint, load some info from checkpoint 182 | if args.checkpoint is not None: 183 | infos = load_checkpoint(model, args.checkpoint) 184 | else: 185 | infos = {} 186 | start_epoch = infos.get('epoch', -1) + 1 187 | cv_loss = infos.get('cv_loss', 0.0) 188 | step = infos.get('step', -1) 189 | 190 | num_epochs = configs.get('max_epoch', 100) 191 | model_dir = args.model_dir 192 | writer = None 193 | if args.rank == 0: 194 | os.makedirs(model_dir, exist_ok=True) 195 | exp_id = os.path.basename(model_dir) 196 | writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id)) 197 | 198 | if distributed: 199 | assert (torch.cuda.is_available()) 200 | # cuda model is required for nn.parallel.DistributedDataParallel 201 | model.cuda() 202 | model = torch.nn.parallel.DistributedDataParallel( 203 | model, find_unused_parameters=True) 204 | device = torch.device("cuda") 205 | else: 206 | use_cuda = args.gpu >= 0 and torch.cuda.is_available() 207 | device = torch.device('cuda' if use_cuda else 'cpu') 208 | model = model.to(device) 209 | 210 | optimizer = optim.Adam(model.parameters(), **configs['optim_conf']) 211 | scheduler = WarmupLR(optimizer, **configs['scheduler_conf']) 212 | final_epoch = None 213 | configs['rank'] = args.rank 214 | configs['is_distributed'] = distributed 215 | configs['use_amp'] = args.use_amp 216 | if start_epoch == 0 and args.rank == 0: 217 | save_model_path = os.path.join(model_dir, 'init.pt') 218 | save_checkpoint(model, save_model_path) 219 | 220 | # Start training loop 221 | executor.step = step 222 | scheduler.set_step(step) 223 | # used for pytorch amp mixed precision training 224 | scaler = None 225 | if args.use_amp: 226 | scaler = torch.cuda.amp.GradScaler() 227 | for epoch in range(start_epoch, num_epochs): 228 | if distributed: 229 | train_sampler.set_epoch(epoch) 230 | lr = optimizer.param_groups[0]['lr'] 231 | logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) 232 | executor.train(model, optimizer, scheduler, train_data_loader, device, 233 | writer, configs, scaler) 234 | total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device, 235 | configs) 236 | if args.world_size > 1: 237 | # all_reduce expected a sequence parameter, so we use [num_seen_utts]. 238 | num_seen_utts = torch.Tensor([num_seen_utts]).to(device) 239 | # the default operator in all_reduce function is sum. 
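# Once both all_reduce calls below have run, every rank holds the loss and
# the utterance count summed over all ranks, so the final division yields
# the global average CV loss. Illustrative numbers: with world_size 2 and
# per-rank (total_loss, num_seen_utts) of (120.0, 40) and (150.0, 60),
# every rank computes cv_loss = 270.0 / 100 = 2.7.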
240 | dist.all_reduce(num_seen_utts) 241 | total_loss = torch.Tensor([total_loss]).to(device) 242 | dist.all_reduce(total_loss) 243 | cv_loss = total_loss[0] / num_seen_utts[0] 244 | cv_loss = cv_loss.item() 245 | else: 246 | cv_loss = total_loss / num_seen_utts 247 | 248 | logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss)) 249 | if args.rank == 0: 250 | save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch)) 251 | save_checkpoint( 252 | model, save_model_path, { 253 | 'epoch': epoch, 254 | 'lr': lr, 255 | 'cv_loss': cv_loss, 256 | 'step': executor.step 257 | }) 258 | writer.add_scalar('epoch/cv_loss', cv_loss, epoch) 259 | writer.add_scalar('epoch/lr', lr, epoch) 260 | final_epoch = epoch 261 | 262 | if final_epoch is not None and args.rank == 0: 263 | final_model_path = os.path.join(model_dir, 'final.pt') 264 | os.symlink('{}.pt'.format(final_epoch), final_model_path) 265 | writer.close() 266 | -------------------------------------------------------------------------------- /wenet/dataset/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/dataset/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/dataset/__pycache__/dataset_deprecated.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/dataset/__pycache__/dataset_deprecated.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/dataset/__pycache__/kaldi_io.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/dataset/__pycache__/kaldi_io.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/dataset/__pycache__/processor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/dataset/__pycache__/processor.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/dataset/__pycache__/wav_distortion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/dataset/__pycache__/wav_distortion.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import random 16 | 17 | import torch 18 | import torch.distributed as dist 19 | from torch.utils.data import IterableDataset 20 | 21 | import wenet.dataset.processor as processor 22 | from wenet.utils.file_utils import read_lists 23 | 24 | 25 | class Processor(IterableDataset): 26 | def __init__(self, source, f, *args, **kw): 27 | assert callable(f) 28 | self.source = source 29 | self.f = f 30 | self.args = args 31 | self.kw = kw 32 | 33 | def set_epoch(self, epoch): 34 | self.source.set_epoch(epoch) 35 | 36 | def __iter__(self): 37 | """ Return an iterator over the source dataset processed by the 38 | given processor. 39 | """ 40 | assert self.source is not None 41 | assert callable(self.f) 42 | return self.f(iter(self.source), *self.args, **self.kw) 43 | 44 | def apply(self, f): 45 | assert callable(f) 46 | return Processor(self, f, *self.args, **self.kw) 47 | 48 | 49 | class DistributedSampler: 50 | def __init__(self, shuffle=True, partition=True): 51 | self.epoch = -1 52 | self.update() 53 | self.shuffle = shuffle 54 | self.partition = partition 55 | 56 | def update(self): 57 | assert dist.is_available() 58 | if dist.is_initialized(): 59 | self.rank = dist.get_rank() 60 | self.world_size = dist.get_world_size() 61 | else: 62 | self.rank = 0 63 | self.world_size = 1 64 | worker_info = torch.utils.data.get_worker_info() 65 | if worker_info is None: 66 | self.worker_id = 0 67 | self.num_workers = 1 68 | else: 69 | self.worker_id = worker_info.id 70 | self.num_workers = worker_info.num_workers 71 | return dict(rank=self.rank, 72 | world_size=self.world_size, 73 | worker_id=self.worker_id, 74 | num_workers=self.num_workers) 75 | 76 | def set_epoch(self, epoch): 77 | self.epoch = epoch 78 | 79 | def sample(self, data): 80 | """ Sample data according to rank/world_size/num_workers 81 | 82 | Args: 83 | data(List): input data list 84 | 85 | Returns: 86 | List: data list after sample 87 | """ 88 | data = data.copy() 89 | # TODO(Binbin Zhang): fix this 90 | # We can not handle uneven data for CV on DDP, so we don't 91 | # sample data by rank, that means every GPU gets the same 92 | # and all the CV data 93 | if self.partition: 94 | if self.shuffle: 95 | random.Random(self.epoch).shuffle(data) 96 | data = data[self.rank::self.world_size] 97 | data = data[self.worker_id::self.num_workers] 98 | return data 99 | 100 | 101 | class DataList(IterableDataset): 102 | def __init__(self, lists, shuffle=True, partition=True): 103 | self.lists = lists 104 | self.sampler = DistributedSampler(shuffle, partition) 105 | 106 | def set_epoch(self, epoch): 107 | self.sampler.set_epoch(epoch) 108 | 109 | def __iter__(self): 110 | sampler_info = self.sampler.update() 111 | lists = self.sampler.sample(self.lists) 112 | for src in lists: 113 | # yield dict(src=src) 114 | data = dict(src=src) 115 | data.update(sampler_info) 116 | yield data 117 | 118 | 119 | def Dataset(data_type, data_list_file, symbol_table, conf, 120 | bpe_model=None, partition=True): 121 | """ Construct dataset from arguments 122 | 123 | We have two shuffle stage in the Dataset. The first is global 124 | shuffle at shards tar/raw file level. The second is global shuffle 125 | at training samples level. 
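        A minimal usage sketch, mirroring how recognize.py builds its test
        loader (the dict and list paths here are only illustrative):

            symbol_table = read_symbol_table('conf/words.txt')
            dataset = Dataset('raw', 'data.list', symbol_table,
                              configs['dataset_conf'], partition=False)
            data_loader = DataLoader(dataset, batch_size=None, num_workers=0)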
126 | 127 | Args: 128 | data_type(str): raw/shard 129 | bpe_model(str): model for english bpe part 130 | partition(bool): whether to do data partition in terms of rank 131 | """ 132 | assert data_type in ['raw', 'shard'] 133 | lists = read_lists(data_list_file) 134 | #print('***'*20,lists) 135 | shuffle = conf.get('shuffle', True) 136 | dataset = DataList(lists, shuffle=shuffle, partition=partition) 137 | #print(dataset) 138 | #assert 0==1 139 | if data_type == 'shard': 140 | dataset = Processor(dataset, processor.url_opener) 141 | dataset = Processor(dataset, processor.tar_file_and_group) 142 | else: 143 | dataset = Processor(dataset, processor.parse_raw) 144 | 145 | dataset = Processor(dataset, processor.tokenize, symbol_table, bpe_model) 146 | filter_conf = conf.get('filter_conf', {}) 147 | dataset = Processor(dataset, processor.filter, **filter_conf) 148 | 149 | resample_conf = conf.get('resample_conf', {}) 150 | dataset = Processor(dataset, processor.resample, **resample_conf) 151 | 152 | speed_perturb = conf.get('speed_perturb', False) 153 | if speed_perturb: 154 | dataset = Processor(dataset, processor.speed_perturb) 155 | 156 | fbank_conf = conf.get('fbank_conf', {}) 157 | dataset = Processor(dataset, processor.compute_fbank, **fbank_conf) 158 | 159 | spec_aug = conf.get('spec_aug', True) 160 | if spec_aug: 161 | spec_aug_conf = conf.get('spec_aug_conf', {}) 162 | dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf) 163 | 164 | if shuffle: 165 | shuffle_conf = conf.get('shuffle_conf', {}) 166 | dataset = Processor(dataset, processor.shuffle, **shuffle_conf) 167 | 168 | sort = conf.get('sort', True) 169 | if sort: 170 | sort_conf = conf.get('sort_conf', {}) 171 | dataset = Processor(dataset, processor.sort, **sort_conf) 172 | 173 | batch_conf = conf.get('batch_conf', {}) 174 | dataset = Processor(dataset, processor.batch, **batch_conf) 175 | dataset = Processor(dataset, processor.padding) 176 | return dataset 177 | -------------------------------------------------------------------------------- /wenet/dataset/wav_distortion.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | import math 4 | 5 | import torchaudio 6 | import torch 7 | torchaudio.set_audio_backend("sox_io") 8 | 9 | 10 | def db2amp(db): 11 | return pow(10, db / 20) 12 | 13 | def amp2db(amp): 14 | return 20 * math.log10(amp) 15 | 16 | def make_poly_distortion(conf): 17 | """Generate a db-domain ploynomial distortion function 18 | 19 | f(x) = a * x^m * (1-x)^n + x 20 | 21 | Args: 22 | conf: a dict {'a': #int, 'm': #int, 'n': #int} 23 | 24 | Returns: 25 | The ploynomial function, which could be applied on 26 | a float amplitude value 27 | """ 28 | a = conf['a'] 29 | m = conf['m'] 30 | n = conf['n'] 31 | 32 | def poly_distortion(x): 33 | abs_x = abs(x) 34 | if abs_x < 0.000001: 35 | x = x 36 | else: 37 | db_norm = amp2db(abs_x) / 100 + 1 38 | if db_norm < 0: 39 | db_norm = 0 40 | db_norm = a * pow(db_norm, m) * pow((1 - db_norm), n) + db_norm 41 | if db_norm > 1: 42 | db_norm = 1 43 | db = (db_norm - 1) * 100 44 | amp = db2amp(db) 45 | if amp >= 0.9997: 46 | amp = 0.9997 47 | if x > 0: 48 | x = amp 49 | else: 50 | x = -amp 51 | return x 52 | return poly_distortion 53 | 54 | def make_quad_distortion(): 55 | return make_poly_distortion({'a' : 1, 'm' : 1, 'n' : 1}) 56 | 57 | # the amplitude are set to max for all non-zero point 58 | def make_max_distortion(conf): 59 | """Generate a max distortion function 60 | 61 | Args: 62 | conf: a dict 
{'max_db': float } 63 | 'max_db': the maxium value. 64 | 65 | Returns: 66 | The max function, which could be applied on 67 | a float amplitude value 68 | """ 69 | max_db = conf['max_db'] 70 | if max_db: 71 | max_amp = db2amp(max_db) # < 0.997 72 | else: 73 | max_amp = 0.997 74 | 75 | def max_distortion(x): 76 | if x > 0: 77 | x = max_amp 78 | elif x < 0: 79 | x = -max_amp 80 | else: 81 | x = 0.0 82 | return x 83 | return max_distortion 84 | 85 | 86 | 87 | def make_amp_mask(db_mask=None): 88 | """Get a amplitude domain mask from db domain mask 89 | 90 | Args: 91 | db_mask: Optional. A list of tuple. if None, using default value. 92 | 93 | Returns: 94 | A list of tuple. The amplitude domain mask 95 | """ 96 | if db_mask is None: 97 | db_mask = [(-110, -95), (-90, -80), (-65, -60), (-50, -30), (-15, 0)] 98 | amp_mask = [(db2amp(db[0]), db2amp(db[1])) for db in db_mask] 99 | return amp_mask 100 | 101 | default_mask = make_amp_mask() 102 | 103 | 104 | def generate_amp_mask(mask_num): 105 | """Generate amplitude domain mask randomly in [-100db, 0db] 106 | 107 | Args: 108 | mask_num: the slot number of the mask 109 | 110 | Returns: 111 | A list of tuple. each tuple defines a slot. 112 | e.g. [(-100, -80), (-65, -60), (-50, -30), (-15, 0)] 113 | for #mask_num = 4 114 | """ 115 | a = [0] * 2 * mask_num 116 | a[0] = 0 117 | m = [] 118 | for i in range(1, 2 * mask_num): 119 | a[i] = a[i - 1] + random.uniform(0.5, 1) 120 | max_val = a[2 * mask_num - 1] 121 | for i in range(0, mask_num): 122 | l = ((a[2 * i] - max_val) / max_val) * 100 123 | r = ((a[2 * i + 1] - max_val) / max_val) * 100 124 | m.append((l, r)) 125 | return make_amp_mask(m) 126 | 127 | 128 | def make_fence_distortion(conf): 129 | """Generate a fence distortion function 130 | 131 | In this fence-like shape function, the values in mask slots are 132 | set to maxium, while the values not in mask slots are set to 0. 133 | Use seperated masks for Positive and negetive amplitude. 134 | 135 | Args: 136 | conf: a dict {'mask_number': int,'max_db': float } 137 | 'mask_number': the slot number in mask. 138 | 'max_db': the maxium value. 139 | 140 | Returns: 141 | The fence function, which could be applied on 142 | a float amplitude value 143 | """ 144 | mask_number = conf['mask_number'] 145 | max_db = conf['max_db'] 146 | max_amp = db2amp(max_db) # 0.997 147 | if mask_number <= 0 : 148 | positive_mask = default_mask 149 | negative_mask = make_amp_mask([(-50, 0)]) 150 | else: 151 | positive_mask = generate_amp_mask(mask_number) 152 | negative_mask = generate_amp_mask(mask_number) 153 | 154 | def fence_distortion(x): 155 | is_in_mask = False 156 | if x > 0: 157 | for mask in positive_mask: 158 | if x >= mask[0] and x <= mask[1]: 159 | is_in_mask = True 160 | return max_amp 161 | if not is_in_mask: 162 | return 0.0 163 | elif x < 0: 164 | abs_x = abs(x) 165 | for mask in negative_mask: 166 | if abs_x >= mask[0] and abs_x <= mask[1]: 167 | is_in_mask = True 168 | return max_amp 169 | if not is_in_mask: 170 | return 0.0 171 | return x 172 | 173 | return fence_distortion 174 | 175 | # 176 | def make_jag_distortion(conf): 177 | """Generate a jag distortion function 178 | 179 | In this jag-like shape function, the values in mask slots are 180 | not changed, while the values not in mask slots are set to 0. 181 | Use seperated masks for Positive and negetive amplitude. 182 | 183 | Args: 184 | conf: a dict {'mask_number': #int} 185 | 'mask_number': the slot number in mask. 
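            e.g. conf = {'mask_number': 4}, the value the __main__ block
            below uses for its 'new_jag_distortion' preset.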
186 | 187 | Returns: 188 | The jag function,which could be applied on 189 | a float amplitude value 190 | """ 191 | mask_number = conf['mask_number'] 192 | if mask_number <= 0 : 193 | positive_mask = default_mask 194 | negative_mask = make_amp_mask([(-50, 0)]) 195 | else: 196 | positive_mask = generate_amp_mask(mask_number) 197 | negative_mask = generate_amp_mask(mask_number) 198 | 199 | def jag_distortion(x): 200 | is_in_mask = False 201 | if x > 0: 202 | for mask in positive_mask: 203 | if x >= mask[0] and x <= mask[1]: 204 | is_in_mask = True 205 | return x 206 | if not is_in_mask: 207 | return 0.0 208 | elif x < 0: 209 | abs_x = abs(x) 210 | for mask in negative_mask: 211 | if abs_x >= mask[0] and abs_x <= mask[1]: 212 | is_in_mask = True 213 | return x 214 | if not is_in_mask: 215 | return 0.0 216 | return x 217 | 218 | return jag_distortion 219 | 220 | # gaining 20db means amp = amp * 10 221 | # gaining -20db means amp = amp / 10 222 | def make_gain_db(conf): 223 | """Generate a db domain gain function 224 | 225 | Args: 226 | conf: a dict {'db': #float} 227 | 'db': the gaining value 228 | 229 | Returns: 230 | The db gain function, which could be applied on 231 | a float amplitude value 232 | """ 233 | db = conf['db'] 234 | 235 | def gain_db(x): 236 | return min(0.997, x * pow(10, db / 20)) 237 | 238 | return gain_db 239 | 240 | 241 | def distort(x, func, rate=0.8): 242 | """Distort a waveform in sample point level 243 | 244 | Args: 245 | x: the origin wavefrom 246 | func: the distort function 247 | rate: sample point-level distort probability 248 | 249 | Returns: 250 | the distorted waveform 251 | """ 252 | for i in range(0, x.shape[1]): 253 | a = random.uniform(0, 1) 254 | if a < rate: 255 | x[0][i] = func(float(x[0][i])) 256 | return x 257 | 258 | def distort_chain(x, funcs, rate=0.8): 259 | for i in range(0, x.shape[1]): 260 | a = random.uniform(0, 1) 261 | if a < rate: 262 | for func in funcs: 263 | x[0][i] = func(float(x[0][i])) 264 | return x 265 | 266 | # x is numpy 267 | def distort_wav_conf(x, distort_type, distort_conf, rate=0.1): 268 | if distort_type == 'gain_db': 269 | gain_db = make_gain_db(distort_conf) 270 | x = distort(x, gain_db) 271 | elif distort_type == 'max_distortion': 272 | max_distortion = make_max_distortion(distort_conf) 273 | x = distort(x, max_distortion, rate=rate) 274 | elif distort_type == 'fence_distortion': 275 | fence_distortion = make_fence_distortion(distort_conf) 276 | x = distort(x, fence_distortion, rate=rate) 277 | elif distort_type == 'jag_distortion': 278 | jag_distortion = make_jag_distortion(distort_conf) 279 | x = distort(x, jag_distortion, rate=rate) 280 | elif distort_type == 'poly_distortion': 281 | poly_distortion = make_poly_distortion(distort_conf) 282 | x = distort(x, poly_distortion, rate=rate) 283 | elif distort_type == 'quad_distortion': 284 | quad_distortion = make_quad_distortion() 285 | x = distort(x, quad_distortion, rate=rate) 286 | elif distort_type == 'none_distortion': 287 | pass 288 | else: 289 | print('unsupport type') 290 | return x 291 | 292 | def distort_wav_conf_and_save(distort_type, distort_conf, rate, wav_in, wav_out): 293 | x, sr = torchaudio.load(wav_in) 294 | x = x.detach().numpy() 295 | out = distort_wav_conf(x, distort_type, distort_conf, rate) 296 | torchaudio.save(wav_out, torch.from_numpy(out), sr) 297 | 298 | if __name__ == "__main__": 299 | distort_type = sys.argv[1] 300 | wav_in = sys.argv[2] 301 | wav_out = sys.argv[3] 302 | conf = None 303 | rate = 0.1 304 | if distort_type == 'new_jag_distortion': 
305 | conf = {'mask_number' : 4} 306 | elif distort_type == 'new_fence_distortion': 307 | conf = {'mask_number' : 1, 'max_db' : -30} 308 | elif distort_type == 'poly_distortion': 309 | conf = {'a' : 4, 'm' : 2, "n" : 2} 310 | distort_wav_conf_and_save(distort_type, conf, rate, wav_in, wav_out) 311 | -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/asr_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/asr_model.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/asr_model_streaming.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/asr_model_streaming.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/cmvn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/cmvn.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/convolution.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/convolution.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/ctc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/ctc.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/decoder.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/decoder_layer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/decoder_layer.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/decoder_streaming.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/decoder_streaming.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/embedding.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/embedding.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/encoder.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/encoder_layer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/encoder_layer.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/encoder_streaming.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/encoder_streaming.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/label_smoothing_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/label_smoothing_loss.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/positionwise_feed_forward.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/positionwise_feed_forward.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/subsampling.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/subsampling.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/__pycache__/swish.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/transformer/__pycache__/swish.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/transformer/attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 
(http://www.apache.org/licenses/LICENSE-2.0) 6 | """Multi-Head Attention layer definition.""" 7 | 8 | import math 9 | from typing import Optional, Tuple 10 | 11 | import torch 12 | from torch import nn 13 | 14 | 15 | class MultiHeadedAttention(nn.Module): 16 | """Multi-Head Attention layer. 17 | 18 | Args: 19 | n_head (int): The number of heads. 20 | n_feat (int): The number of features. 21 | dropout_rate (float): Dropout rate. 22 | 23 | """ 24 | def __init__(self, n_head: int, n_feat: int, dropout_rate: float): 25 | """Construct an MultiHeadedAttention object.""" 26 | super().__init__() 27 | assert n_feat % n_head == 0 28 | # We assume d_v always equals d_k 29 | self.d_k = n_feat // n_head 30 | self.h = n_head 31 | self.linear_q = nn.Linear(n_feat, n_feat) 32 | self.linear_k = nn.Linear(n_feat, n_feat) 33 | self.linear_v = nn.Linear(n_feat, n_feat) 34 | self.linear_out = nn.Linear(n_feat, n_feat) 35 | self.dropout = nn.Dropout(p=dropout_rate) 36 | 37 | def forward_qkv( 38 | self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor 39 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 40 | """Transform query, key and value. 41 | 42 | Args: 43 | query (torch.Tensor): Query tensor (#batch, time1, size). 44 | key (torch.Tensor): Key tensor (#batch, time2, size). 45 | value (torch.Tensor): Value tensor (#batch, time2, size). 46 | 47 | Returns: 48 | torch.Tensor: Transformed query tensor, size 49 | (#batch, n_head, time1, d_k). 50 | torch.Tensor: Transformed key tensor, size 51 | (#batch, n_head, time2, d_k). 52 | torch.Tensor: Transformed value tensor, size 53 | (#batch, n_head, time2, d_k). 54 | 55 | """ 56 | n_batch = query.size(0) 57 | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) 58 | k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) 59 | v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) 60 | q = q.transpose(1, 2) # (batch, head, time1, d_k) 61 | k = k.transpose(1, 2) # (batch, head, time2, d_k) 62 | v = v.transpose(1, 2) # (batch, head, time2, d_k) 63 | 64 | return q, k, v 65 | 66 | def forward_attention(self, value: torch.Tensor, scores: torch.Tensor, 67 | mask: Optional[torch.Tensor]) -> torch.Tensor: 68 | """Compute attention context vector. 69 | 70 | Args: 71 | value (torch.Tensor): Transformed value, size 72 | (#batch, n_head, time2, d_k). 73 | scores (torch.Tensor): Attention score, size 74 | (#batch, n_head, time1, time2). 75 | mask (torch.Tensor): Mask, size (#batch, 1, time2) or 76 | (#batch, time1, time2). 77 | 78 | Returns: 79 | torch.Tensor: Transformed value (#batch, time1, d_model) 80 | weighted by the attention score (#batch, time1, time2). 
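            A shape walk-through with illustrative sizes: for n_feat = 256
            and n_head = 4 (so d_k = 64), value is (#batch, 4, time2, 64)
            and scores is (#batch, 4, time1, time2); after the masked
            softmax, dropout and matmul below, the transpose/view gives
            (#batch, time1, 256), which linear_out returns as the
            (#batch, time1, d_model) output.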
81 | 82 | """ 83 | n_batch = value.size(0) 84 | if mask is not None: 85 | mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) 86 | scores = scores.masked_fill(mask, -float('inf')) 87 | attn = torch.softmax(scores, dim=-1).masked_fill( 88 | mask, 0.0) # (batch, head, time1, time2) 89 | else: 90 | attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) 91 | 92 | p_attn = self.dropout(attn) 93 | x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) 94 | x = (x.transpose(1, 2).contiguous().view(n_batch, -1, 95 | self.h * self.d_k) 96 | ) # (batch, time1, d_model) 97 | #最后将head与d_k相乘,拼接得到完整的特征 98 | 99 | return self.linear_out(x) # (batch, time1, d_model) 100 | 101 | def forward(self, query: torch.Tensor, key: torch.Tensor, 102 | value: torch.Tensor, 103 | mask: Optional[torch.Tensor], 104 | pos_emb: torch.Tensor = torch.empty(0),) -> torch.Tensor: 105 | """Compute scaled dot product attention. 106 | 107 | Args: 108 | query (torch.Tensor): Query tensor (#batch, time1, size). 109 | key (torch.Tensor): Key tensor (#batch, time2, size). 110 | value (torch.Tensor): Value tensor (#batch, time2, size). 111 | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or 112 | (#batch, time1, time2). 113 | 1.When applying cross attention between decoder and encoder, 114 | the batch padding mask for input is in (#batch, 1, T) shape. 115 | 2.When applying self attention of encoder, 116 | the mask is in (#batch, T, T) shape. 117 | 3.When applying self attention of decoder, 118 | the mask is in (#batch, L, L) shape. 119 | 4.If the different position in decoder see different block 120 | of the encoder, such as Mocha, the passed in mask could be 121 | in (#batch, L, T) shape. But there is no such case in current 122 | Wenet. 123 | 124 | 125 | Returns: 126 | torch.Tensor: Output tensor (#batch, time1, d_model). 127 | 128 | """ 129 | q, k, v = self.forward_qkv(query, key, value) 130 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) 131 | # matmul:[batch, head, time1, d_k]*[batch, head, d_k, time2]--> [batch,head, time1, time2] 132 | # / : d_k = feat//head #将特征拆分为head个,吗,每个head学习不同层面的特征,最后将所有的head上的特征拼接起来即可。 133 | return self.forward_attention(v, scores, mask) #v的shape为[batch, head, time2, d_k],在score的基础上对v进行加权 134 | # 最后multiheadattention的数据shape为 (batch, time1, d_model) ,d_model由linear层的输出决定。 135 | 136 | class RelPositionMultiHeadedAttention(MultiHeadedAttention): 137 | """Multi-Head Attention layer with relative position encoding. 138 | Paper: https://arxiv.org/abs/1901.02860 139 | Args: 140 | n_head (int): The number of heads. 141 | n_feat (int): The number of features. 142 | dropout_rate (float): Dropout rate. 143 | """ 144 | def __init__(self, n_head, n_feat, dropout_rate): 145 | """Construct an RelPositionMultiHeadedAttention object.""" 146 | super().__init__(n_head, n_feat, dropout_rate) 147 | # linear transformation for positional encoding 148 | self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) 149 | # these two learnable bias are used in matrix c and matrix d 150 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3 151 | self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) 152 | self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) 153 | torch.nn.init.xavier_uniform_(self.pos_bias_u) 154 | torch.nn.init.xavier_uniform_(self.pos_bias_v) 155 | 156 | def rel_shift(self, x, zero_triu: bool = False): 157 | """Compute relative positinal encoding. 158 | Args: 159 | x (torch.Tensor): Input tensor (batch, time, size). 
160 | zero_triu (bool): If true, return the lower triangular part of 161 | the matrix. 162 | Returns: 163 | torch.Tensor: Output tensor. 164 | """ 165 | 166 | zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), 167 | device=x.device, 168 | dtype=x.dtype) 169 | x_padded = torch.cat([zero_pad, x], dim=-1) 170 | 171 | x_padded = x_padded.view(x.size()[0], 172 | x.size()[1], 173 | x.size(3) + 1, x.size(2)) 174 | x = x_padded[:, :, 1:].view_as(x) 175 | 176 | if zero_triu: 177 | ones = torch.ones((x.size(2), x.size(3))) 178 | x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] 179 | 180 | return x 181 | 182 | def forward(self, query: torch.Tensor, key: torch.Tensor, 183 | value: torch.Tensor, mask: Optional[torch.Tensor], 184 | pos_emb: torch.Tensor): 185 | """Compute 'Scaled Dot Product Attention' with rel. positional encoding. 186 | Args: 187 | query (torch.Tensor): Query tensor (#batch, time1, size). 188 | key (torch.Tensor): Key tensor (#batch, time2, size). 189 | value (torch.Tensor): Value tensor (#batch, time2, size). 190 | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or 191 | (#batch, time1, time2). 192 | pos_emb (torch.Tensor): Positional embedding tensor 193 | (#batch, time2, size). 194 | Returns: 195 | torch.Tensor: Output tensor (#batch, time1, d_model). 196 | """ 197 | q, k, v = self.forward_qkv(query, key, value) 198 | q = q.transpose(1, 2) # (batch, time1, head, d_k) 199 | 200 | n_batch_pos = pos_emb.size(0) 201 | p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) 202 | p = p.transpose(1, 2) # (batch, head, time1, d_k) 203 | 204 | # (batch, head, time1, d_k) 205 | q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) 206 | # (batch, head, time1, d_k) 207 | q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) 208 | 209 | # compute attention score 210 | # first compute matrix a and matrix c 211 | # as described in https://arxiv.org/abs/1901.02860 Section 3.3 212 | # (batch, head, time1, time2) 213 | matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) 214 | 215 | # compute matrix b and matrix d 216 | # (batch, head, time1, time2) 217 | matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) 218 | # Remove rel_shift since it is useless in speech recognition, 219 | # and it requires special attention for streaming. 220 | # matrix_bd = self.rel_shift(matrix_bd) 221 | 222 | scores = (matrix_ac + matrix_bd) / math.sqrt( 223 | self.d_k) # (batch, head, time1, time2) 224 | 225 | return self.forward_attention(v, scores, mask) 226 | -------------------------------------------------------------------------------- /wenet/transformer/cmvn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
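A minimal usage sketch for the MultiHeadedAttention module defined in wenet/transformer/attention.py above; the head count, feature size and sequence lengths are illustrative values, not taken from the repository:

import torch
from wenet.transformer.attention import MultiHeadedAttention

# Illustrative settings: 4 heads over 256-dim features (d_k = 256 // 4 = 64).
attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
q = torch.randn(2, 10, 256)                    # query     (#batch, time1, size)
kv = torch.randn(2, 20, 256)                   # key/value (#batch, time2, size)
mask = torch.ones(2, 1, 20, dtype=torch.bool)  # (#batch, 1, time2), True = keep
out = attn(q, kv, kv, mask)                    # -> (2, 10, 256), i.e. (#batch, time1, d_model)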
15 | 16 | import torch 17 | 18 | 19 | class GlobalCMVN(torch.nn.Module): 20 | def __init__(self, 21 | mean: torch.Tensor, 22 | istd: torch.Tensor, 23 | norm_var: bool = True): 24 | """ 25 | Args: 26 | mean (torch.Tensor): mean stats 27 | istd (torch.Tensor): inverse std, std which is 1.0 / std 28 | """ 29 | super().__init__() 30 | assert mean.shape == istd.shape 31 | self.norm_var = norm_var 32 | # The buffer can be accessed from this module using self.mean 33 | self.register_buffer("mean", mean) 34 | self.register_buffer("istd", istd) 35 | 36 | def forward(self, x: torch.Tensor): 37 | """ 38 | Args: 39 | x (torch.Tensor): (batch, max_len, feat_dim) 40 | 41 | Returns: 42 | (torch.Tensor): normalized feature 43 | """ 44 | x = x - self.mean 45 | if self.norm_var: 46 | x = x * self.istd 47 | return x 48 | -------------------------------------------------------------------------------- /wenet/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2021 Mobvoi Inc. All Rights Reserved. 5 | # Author: di.wu@mobvoi.com (DI WU) 6 | """ConvolutionModule definition.""" 7 | 8 | from typing import Optional, Tuple 9 | 10 | import torch 11 | from torch import nn 12 | from typeguard import check_argument_types 13 | 14 | 15 | class ConvolutionModule(nn.Module): 16 | """ConvolutionModule in Conformer model.""" 17 | def __init__(self, 18 | channels: int, 19 | kernel_size: int = 15, 20 | activation: nn.Module = nn.ReLU(), 21 | norm: str = "batch_norm", 22 | causal: bool = False, 23 | bias: bool = True): 24 | """Construct an ConvolutionModule object. 25 | Args: 26 | channels (int): The number of channels of conv layers. 27 | kernel_size (int): Kernel size of conv layers. 28 | causal (int): Whether use causal convolution or not 29 | """ 30 | assert check_argument_types() 31 | super().__init__() 32 | 33 | self.pointwise_conv1 = nn.Conv1d( 34 | channels, 35 | 2 * channels, 36 | kernel_size=1, 37 | stride=1, 38 | padding=0, 39 | bias=bias, 40 | ) 41 | # self.lorder is used to distinguish if it's a causal convolution, 42 | # if self.lorder > 0: it's a causal convolution, the input will be 43 | # padded with self.lorder frames on the left in forward. 44 | # else: it's a symmetrical convolution 45 | if causal: 46 | padding = 0 47 | self.lorder = kernel_size - 1 48 | else: 49 | # kernel_size should be an odd number for none causal convolution 50 | assert (kernel_size - 1) % 2 == 0 51 | padding = (kernel_size - 1) // 2 52 | self.lorder = 0 53 | self.depthwise_conv = nn.Conv1d( 54 | channels, 55 | channels, 56 | kernel_size, 57 | stride=1, 58 | padding=padding, 59 | groups=channels, 60 | bias=bias, 61 | ) 62 | 63 | assert norm in ['batch_norm', 'layer_norm'] 64 | if norm == "batch_norm": 65 | self.use_layer_norm = False 66 | self.norm = nn.BatchNorm1d(channels) 67 | else: 68 | self.use_layer_norm = True 69 | self.norm = nn.LayerNorm(channels) 70 | 71 | self.pointwise_conv2 = nn.Conv1d( 72 | channels, 73 | channels, 74 | kernel_size=1, 75 | stride=1, 76 | padding=0, 77 | bias=bias, 78 | ) 79 | self.activation = activation 80 | 81 | def forward( 82 | self, 83 | x: torch.Tensor, 84 | mask_pad: Optional[torch.Tensor] = None, 85 | cache: Optional[torch.Tensor] = None, 86 | ) -> Tuple[torch.Tensor, torch.Tensor]: 87 | """Compute convolution module. 88 | Args: 89 | x (torch.Tensor): Input tensor (#batch, time, channels). 
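A short sketch of applying the GlobalCMVN module above to 80-dim fbank features; the stats here are dummy values, whereas the real ones come from the global_cmvn file:

import torch
from wenet.transformer.cmvn import GlobalCMVN

mean = torch.zeros(80)            # per-dimension mean stats (dummy values)
istd = torch.ones(80)             # per-dimension inverse std, i.e. 1.0 / std
cmvn = GlobalCMVN(mean, istd)
feats = torch.randn(4, 100, 80)   # (batch, max_len, feat_dim)
normed = cmvn(feats)              # same shape, (x - mean) * istd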
90 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time) 91 | cache (torch.Tensor): left context cache, it is only 92 | used in causal convolution 93 | Returns: 94 | torch.Tensor: Output tensor (#batch, time, channels). 95 | """ 96 | # exchange the temporal dimension and the feature dimension 97 | x = x.transpose(1, 2) # (#batch, channels, time) 98 | 99 | # mask batch padding 100 | if mask_pad is not None: 101 | x.masked_fill_(~mask_pad, 0.0) 102 | 103 | if self.lorder > 0: 104 | if cache is None: 105 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 106 | else: 107 | assert cache.size(0) == x.size(0) 108 | assert cache.size(1) == x.size(1) 109 | x = torch.cat((cache, x), dim=2) 110 | assert (x.size(2) > self.lorder) 111 | new_cache = x[:, :, -self.lorder:] 112 | else: 113 | # It's better we just return None if no cache is requried, 114 | # However, for JIT export, here we just fake one tensor instead of 115 | # None. 116 | new_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device) 117 | 118 | # GLU mechanism 119 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 120 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 121 | 122 | # 1D Depthwise Conv 123 | x = self.depthwise_conv(x) 124 | if self.use_layer_norm: 125 | x = x.transpose(1, 2) 126 | x = self.activation(self.norm(x)) 127 | if self.use_layer_norm: 128 | x = x.transpose(1, 2) 129 | x = self.pointwise_conv2(x) 130 | # mask batch padding 131 | if mask_pad is not None: 132 | x.masked_fill_(~mask_pad, 0.0) 133 | 134 | return x.transpose(1, 2), new_cache 135 | -------------------------------------------------------------------------------- /wenet/transformer/ctc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typeguard import check_argument_types 4 | 5 | 6 | class CTC(torch.nn.Module): 7 | """CTC module""" 8 | def __init__( 9 | self, 10 | odim: int, 11 | encoder_output_size: int, 12 | dropout_rate: float = 0.0, 13 | reduce: bool = True, 14 | ): 15 | """ Construct CTC module 16 | Args: 17 | odim: dimension of outputs 18 | encoder_output_size: number of encoder projection units 19 | dropout_rate: dropout rate (0.0 ~ 1.0) 20 | reduce: reduce the CTC loss into a scalar 21 | """ 22 | assert check_argument_types() 23 | super().__init__() 24 | eprojs = encoder_output_size 25 | self.dropout_rate = dropout_rate 26 | self.ctc_lo = torch.nn.Linear(eprojs, odim) 27 | 28 | reduction_type = "sum" if reduce else "none" 29 | self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type) 30 | 31 | def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor, 32 | ys_pad: torch.Tensor, ys_lens: torch.Tensor) -> torch.Tensor: 33 | """Calculate CTC loss. 
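A sketch of the causal-cache behaviour of the ConvolutionModule above, as used for chunk-by-chunk streaming; channel count, kernel size and chunk length are example values:

import torch
from wenet.transformer.convolution import ConvolutionModule

conv = ConvolutionModule(channels=256, kernel_size=15, causal=True)
chunk1 = torch.randn(1, 16, 256)           # (#batch, time, channels)
out1, cache = conv(chunk1)                 # cache keeps the last lorder = kernel_size - 1 frames
chunk2 = torch.randn(1, 16, 256)
out2, cache = conv(chunk2, cache=cache)    # left context is taken from the previous chunk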
34 | 35 | Args: 36 | hs_pad: batch of padded hidden state sequences (B, Tmax, D) 37 | hlens: batch of lengths of hidden state sequences (B) 38 | ys_pad: batch of padded character id sequence tensor (B, Lmax) 39 | ys_lens: batch of lengths of character sequence (B) 40 | """ 41 | # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab) 42 | ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate)) 43 | # ys_hat: (B, L, D) -> (L, B, D) 44 | ys_hat = ys_hat.transpose(0, 1) 45 | ys_hat = ys_hat.log_softmax(2) 46 | loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens) 47 | # Batch-size average 48 | loss = loss / ys_hat.size(1) 49 | return loss 50 | 51 | def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor: 52 | """log_softmax of frame activations 53 | 54 | Args: 55 | Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 56 | Returns: 57 | torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim) 58 | """ 59 | return F.log_softmax(self.ctc_lo(hs_pad), dim=2) 60 | 61 | def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor: 62 | """argmax of frame activations 63 | 64 | Args: 65 | torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) 66 | Returns: 67 | torch.Tensor: argmax applied 2d tensor (B, Tmax) 68 | """ 69 | return torch.argmax(self.ctc_lo(hs_pad), dim=2) 70 | -------------------------------------------------------------------------------- /wenet/transformer/decoder_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | """Decoder self-attention layer definition.""" 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | from torch import nn 11 | 12 | 13 | class DecoderLayer(nn.Module): 14 | """Single decoder layer module. 15 | 16 | Args: 17 | size (int): Input dimension. 18 | self_attn (torch.nn.Module): Self-attention module instance. 19 | `MultiHeadedAttention` instance can be used as the argument. 20 | src_attn (torch.nn.Module): Inter-attention module instance. 21 | `MultiHeadedAttention` instance can be used as the argument. 22 | feed_forward (torch.nn.Module): Feed-forward module instance. 23 | `PositionwiseFeedForward` instance can be used as the argument. 24 | dropout_rate (float): Dropout rate. 25 | normalize_before (bool): 26 | True: use layer_norm before each sub-block. 27 | False: to use layer_norm after each sub-block. 28 | concat_after (bool): Whether to concat attention layer's inpu 29 | and output. 
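A sketch of computing the loss with the CTC module above; vocabulary size, encoder dimension and lengths are illustrative:

import torch
from wenet.transformer.ctc import CTC

ctc = CTC(odim=4233, encoder_output_size=256)   # odim = vocabulary size (example value)
hs = torch.randn(2, 50, 256)                    # encoder output (B, Tmax, D)
hlens = torch.tensor([50, 42])                  # valid frames per utterance
ys = torch.randint(1, 4233, (2, 12))            # padded label ids (B, Lmax), 0 is the blank
ys_lens = torch.tensor([12, 9])                 # valid labels per utterance
loss = ctc(hs, hlens, ys, ys_lens)              # batch-size-averaged CTC loss (scalar)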
30 | True: x -> x + linear(concat(x, att(x))) 31 | False: x -> x + att(x) 32 | """ 33 | def __init__( 34 | self, 35 | size: int, 36 | self_attn: nn.Module, 37 | src_attn: nn.Module, 38 | feed_forward: nn.Module, 39 | dropout_rate: float, 40 | normalize_before: bool = True, 41 | concat_after: bool = False, 42 | ): 43 | """Construct an DecoderLayer object.""" 44 | super().__init__() 45 | self.size = size 46 | self.self_attn = self_attn 47 | self.src_attn = src_attn 48 | self.feed_forward = feed_forward 49 | self.norm1 = nn.LayerNorm(size, eps=1e-12) 50 | self.norm2 = nn.LayerNorm(size, eps=1e-12) 51 | self.norm3 = nn.LayerNorm(size, eps=1e-12) 52 | self.dropout = nn.Dropout(dropout_rate) 53 | self.normalize_before = normalize_before 54 | self.concat_after = concat_after 55 | self.concat_linear1 = nn.Linear(size + size, size) 56 | self.concat_linear2 = nn.Linear(size + size, size) 57 | 58 | def forward( 59 | self, 60 | tgt: torch.Tensor, 61 | tgt_mask: torch.Tensor, 62 | memory: torch.Tensor, 63 | memory_mask: torch.Tensor, 64 | cache: Optional[torch.Tensor] = None 65 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 66 | """Compute decoded features. 67 | 68 | Args: 69 | tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). 70 | tgt_mask (torch.Tensor): Mask for input tensor 71 | (#batch, maxlen_out). 72 | memory (torch.Tensor): Encoded memory 73 | (#batch, maxlen_in, size). 74 | memory_mask (torch.Tensor): Encoded memory mask 75 | (#batch, maxlen_in). 76 | cache (torch.Tensor): cached tensors. 77 | (#batch, maxlen_out - 1, size). 78 | 79 | Returns: 80 | torch.Tensor: Output tensor (#batch, maxlen_out, size). 81 | torch.Tensor: Mask for output tensor (#batch, maxlen_out). 82 | torch.Tensor: Encoded memory (#batch, maxlen_in, size). 83 | torch.Tensor: Encoded memory mask (#batch, maxlen_in). 
84 | 85 | """ 86 | residual = tgt 87 | if self.normalize_before: 88 | tgt = self.norm1(tgt) 89 | 90 | if cache is None: 91 | tgt_q = tgt 92 | tgt_q_mask = tgt_mask 93 | else: 94 | # compute only the last frame query keeping dim: max_time_out -> 1 95 | #print('cache shape***'*20, cache.shape) 96 | #print((tgt.shape[0], tgt.shape[1] - 1, self.size) ) 97 | assert cache.shape == ( 98 | tgt.shape[0], 99 | tgt.shape[1] - 1, 100 | self.size, 101 | ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" 102 | tgt_q = tgt[:, -1:, :] 103 | residual = residual[:, -1:, :] 104 | tgt_q_mask = tgt_mask[:, -1:, :] 105 | 106 | if self.concat_after: 107 | tgt_concat = torch.cat( 108 | (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1) 109 | x = residual + self.concat_linear1(tgt_concat) 110 | else: 111 | x = residual + self.dropout( 112 | self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) 113 | if not self.normalize_before: 114 | x = self.norm1(x) 115 | 116 | residual = x 117 | if self.normalize_before: 118 | x = self.norm2(x) 119 | if self.concat_after: 120 | x_concat = torch.cat( 121 | (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1) 122 | x = residual + self.concat_linear2(x_concat) 123 | else: 124 | x = residual + self.dropout( 125 | self.src_attn(x, memory, memory, memory_mask)) 126 | if not self.normalize_before: 127 | x = self.norm2(x) 128 | 129 | residual = x 130 | if self.normalize_before: 131 | x = self.norm3(x) 132 | x = residual + self.dropout(self.feed_forward(x)) 133 | if not self.normalize_before: 134 | x = self.norm3(x) 135 | 136 | if cache is not None: 137 | x = torch.cat([cache, x], dim=1) 138 | 139 | return x, tgt_mask, memory, memory_mask 140 | -------------------------------------------------------------------------------- /wenet/transformer/embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 5 | # Author: di.wu@mobvoi.com (DI WU) 6 | """Positonal Encoding Module.""" 7 | 8 | import math 9 | from typing import Tuple 10 | 11 | import torch 12 | 13 | 14 | class PositionalEncoding(torch.nn.Module): 15 | """Positional encoding. 16 | 17 | :param int d_model: embedding dim 18 | :param float dropout_rate: dropout rate 19 | :param int max_len: maximum input length 20 | 21 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) 22 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) 23 | """ 24 | def __init__(self, 25 | d_model: int, 26 | dropout_rate: float, 27 | max_len: int = 5000, 28 | reverse: bool = False): 29 | """Construct an PositionalEncoding object.""" 30 | super().__init__() 31 | self.d_model = d_model 32 | self.xscale = math.sqrt(self.d_model) 33 | self.dropout = torch.nn.Dropout(p=dropout_rate) 34 | self.max_len = max_len 35 | 36 | self.pe = torch.zeros(self.max_len, self.d_model) 37 | position = torch.arange(0, self.max_len, 38 | dtype=torch.float32).unsqueeze(1) 39 | div_term = torch.exp( 40 | torch.arange(0, self.d_model, 2, dtype=torch.float32) * 41 | -(math.log(10000.0) / self.d_model)) 42 | self.pe[:, 0::2] = torch.sin(position * div_term) 43 | self.pe[:, 1::2] = torch.cos(position * div_term) 44 | self.pe = self.pe.unsqueeze(0) 45 | 46 | def forward(self, 47 | x: torch.Tensor, 48 | offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: 49 | """Add positional encoding. 50 | 51 | Args: 52 | x (torch.Tensor): Input. Its shape is (batch, time, ...) 
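A sketch of wiring one DecoderLayer (wenet/transformer/decoder_layer.py above) out of its building blocks; sizes are illustrative and the all-True masks stand in for real padding/subsequent masks:

import torch
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward

layer = DecoderLayer(
    size=256,
    self_attn=MultiHeadedAttention(4, 256, 0.1),
    src_attn=MultiHeadedAttention(4, 256, 0.1),
    feed_forward=PositionwiseFeedForward(256, 2048, 0.1),
    dropout_rate=0.1,
)
tgt = torch.randn(2, 7, 256)                        # (#batch, maxlen_out, size)
tgt_mask = torch.ones(2, 7, 7, dtype=torch.bool)    # real decoding uses a lower-triangular mask
memory = torch.randn(2, 50, 256)                    # encoder output (#batch, maxlen_in, size)
memory_mask = torch.ones(2, 1, 50, dtype=torch.bool)
x, tgt_mask, memory, memory_mask = layer(tgt, tgt_mask, memory, memory_mask)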
53 | offset (int): position offset 54 | 55 | Returns: 56 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) 57 | torch.Tensor: for compatibility to RelPositionalEncoding 58 | """ 59 | assert offset + x.size(1) < self.max_len 60 | self.pe = self.pe.to(x.device) 61 | pos_emb = self.pe[:, offset:offset + x.size(1)] 62 | x = x * self.xscale + pos_emb 63 | return self.dropout(x), self.dropout(pos_emb) 64 | 65 | def position_encoding(self, offset: int, size: int) -> torch.Tensor: 66 | """ For getting encoding in a streaming fashion 67 | 68 | Attention!!!!! 69 | we apply dropout only once at the whole utterance level in a none 70 | streaming way, but will call this function several times with 71 | increasing input size in a streaming scenario, so the dropout will 72 | be applied several times. 73 | 74 | Args: 75 | offset (int): start offset 76 | size (int): requried size of position encoding 77 | 78 | Returns: 79 | torch.Tensor: Corresponding encoding 80 | """ 81 | assert offset + size < self.max_len 82 | return self.dropout(self.pe[:, offset:offset + size]) 83 | 84 | 85 | class RelPositionalEncoding(PositionalEncoding): 86 | """Relative positional encoding module. 87 | See : Appendix B in https://arxiv.org/abs/1901.02860 88 | Args: 89 | d_model (int): Embedding dimension. 90 | dropout_rate (float): Dropout rate. 91 | max_len (int): Maximum input length. 92 | """ 93 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): 94 | """Initialize class.""" 95 | super().__init__(d_model, dropout_rate, max_len, reverse=True) 96 | 97 | def forward(self, 98 | x: torch.Tensor, 99 | offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: 100 | """Compute positional encoding. 101 | Args: 102 | x (torch.Tensor): Input tensor (batch, time, `*`). 103 | Returns: 104 | torch.Tensor: Encoded tensor (batch, time, `*`). 105 | torch.Tensor: Positional embedding tensor (1, time, `*`). 106 | """ 107 | assert offset + x.size(1) < self.max_len 108 | self.pe = self.pe.to(x.device) 109 | x = x * self.xscale 110 | pos_emb = self.pe[:, offset:offset + x.size(1)] 111 | return self.dropout(x), self.dropout(pos_emb) 112 | 113 | 114 | class NoPositionalEncoding(torch.nn.Module): 115 | """ No position encoding 116 | """ 117 | def __init__(self, d_model: int, dropout_rate: float): 118 | super().__init__() 119 | self.d_model = d_model 120 | self.dropout = torch.nn.Dropout(p=dropout_rate) 121 | 122 | def forward(self, 123 | x: torch.Tensor, 124 | offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: 125 | """ Just return zero vector for interface compatibility 126 | """ 127 | pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) 128 | return self.dropout(x), pos_emb 129 | 130 | def position_encoding(self, offset: int, size: int) -> torch.Tensor: 131 | return torch.zeros(1, size, self.d_model) 132 | -------------------------------------------------------------------------------- /wenet/transformer/encoder_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 5 | # Author: di.wu@mobvoi.com (DI WU) 6 | """Encoder self-attention layer definition.""" 7 | 8 | from typing import Optional, Tuple 9 | 10 | import torch 11 | from torch import nn 12 | 13 | 14 | class TransformerEncoderLayer(nn.Module): 15 | """Encoder layer module. 16 | 17 | Args: 18 | size (int): Input dimension. 19 | self_attn (torch.nn.Module): Self-attention module instance. 
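A sketch of the two positional-encoding flavours above; RelPositionalEncoding returns the embedding separately so that RelPositionMultiHeadedAttention can consume it. Shapes are example values:

import torch
from wenet.transformer.embedding import PositionalEncoding, RelPositionalEncoding

x = torch.randn(2, 30, 256)                 # (batch, time, d_model)
pos_enc = PositionalEncoding(d_model=256, dropout_rate=0.1)
y, pos_emb = pos_enc(x)                     # y = dropout(x * sqrt(d_model) + PE)

rel_pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
y, pos_emb = rel_pos_enc(x)                 # PE is not added here; it is returned for relative attention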
20 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` 21 | instance can be used as the argument. 22 | feed_forward (torch.nn.Module): Feed-forward module instance. 23 | `PositionwiseFeedForward`, instance can be used as the argument. 24 | dropout_rate (float): Dropout rate. 25 | normalize_before (bool): 26 | True: use layer_norm before each sub-block. 27 | False: to use layer_norm after each sub-block. 28 | concat_after (bool): Whether to concat attention layer's input and 29 | output. 30 | True: x -> x + linear(concat(x, att(x))) 31 | False: x -> x + att(x) 32 | 33 | """ 34 | def __init__( 35 | self, 36 | size: int, 37 | self_attn: torch.nn.Module, 38 | feed_forward: torch.nn.Module, 39 | dropout_rate: float, 40 | normalize_before: bool = True, 41 | concat_after: bool = False, 42 | ): 43 | """Construct an EncoderLayer object.""" 44 | super().__init__() 45 | self.self_attn = self_attn 46 | self.feed_forward = feed_forward 47 | self.norm1 = nn.LayerNorm(size, eps=1e-12) 48 | self.norm2 = nn.LayerNorm(size, eps=1e-12) 49 | self.dropout = nn.Dropout(dropout_rate) 50 | self.size = size 51 | self.normalize_before = normalize_before 52 | self.concat_after = concat_after 53 | # concat_linear may be not used in forward fuction, 54 | # but will be saved in the *.pt 55 | self.concat_linear = nn.Linear(size + size, size) 56 | 57 | def forward( 58 | self, 59 | x: torch.Tensor, 60 | mask: torch.Tensor, 61 | pos_emb: torch.Tensor, 62 | mask_pad: Optional[torch.Tensor] = None, 63 | output_cache: Optional[torch.Tensor] = None, 64 | cnn_cache: Optional[torch.Tensor] = None, 65 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 66 | """Compute encoded features. 67 | 68 | Args: 69 | x (torch.Tensor): Input tensor (#batch, time, size). 70 | mask (torch.Tensor): Mask tensor for the input (#batch, time). 71 | pos_emb (torch.Tensor): just for interface compatibility 72 | to ConformerEncoderLayer 73 | mask_pad (torch.Tensor): does not used in transformer layer, 74 | just for unified api with conformer. 75 | output_cache (torch.Tensor): Cache tensor of the output 76 | (#batch, time2, size), time2 < time in x. 77 | cnn_cache (torch.Tensor): not used here, it's for interface 78 | compatibility to ConformerEncoderLayer 79 | Returns: 80 | torch.Tensor: Output tensor (#batch, time, size). 81 | torch.Tensor: Mask tensor (#batch, time). 
82 | 83 | """ 84 | residual = x 85 | if self.normalize_before: 86 | x = self.norm1(x) 87 | 88 | if output_cache is None: 89 | x_q = x 90 | else: 91 | assert output_cache.size(0) == x.size(0) 92 | assert output_cache.size(2) == self.size 93 | assert output_cache.size(1) < x.size(1) 94 | chunk = x.size(1) - output_cache.size(1) 95 | x_q = x[:, -chunk:, :] 96 | residual = residual[:, -chunk:, :] 97 | mask = mask[:, -chunk:, :] 98 | 99 | if self.concat_after: 100 | x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) 101 | x = residual + self.concat_linear(x_concat) 102 | else: 103 | x = residual + self.dropout(self.self_attn(x_q, x, x, mask)) 104 | if not self.normalize_before: 105 | x = self.norm1(x) 106 | 107 | residual = x 108 | if self.normalize_before: 109 | x = self.norm2(x) 110 | x = residual + self.dropout(self.feed_forward(x)) 111 | if not self.normalize_before: 112 | x = self.norm2(x) 113 | 114 | if output_cache is not None: 115 | x = torch.cat([output_cache, x], dim=1) 116 | 117 | fake_cnn_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device) 118 | return x, mask, fake_cnn_cache 119 | 120 | 121 | class ConformerEncoderLayer(nn.Module): 122 | """Encoder layer module. 123 | Args: 124 | size (int): Input dimension. 125 | self_attn (torch.nn.Module): Self-attention module instance. 126 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` 127 | instance can be used as the argument. 128 | feed_forward (torch.nn.Module): Feed-forward module instance. 129 | `PositionwiseFeedForward` instance can be used as the argument. 130 | feed_forward_macaron (torch.nn.Module): Additional feed-forward module 131 | instance. 132 | `PositionwiseFeedForward` instance can be used as the argument. 133 | conv_module (torch.nn.Module): Convolution module instance. 134 | `ConvlutionModule` instance can be used as the argument. 135 | dropout_rate (float): Dropout rate. 136 | normalize_before (bool): 137 | True: use layer_norm before each sub-block. 138 | False: use layer_norm after each sub-block. 139 | concat_after (bool): Whether to concat attention layer's input and 140 | output. 
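A sketch of assembling the TransformerEncoderLayer above; dimensions are example values, and pos_emb is unused by this (non-relative) layer:

import torch
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward

layer = TransformerEncoderLayer(
    size=256,
    self_attn=MultiHeadedAttention(4, 256, 0.1),
    feed_forward=PositionwiseFeedForward(256, 2048, 0.1),
    dropout_rate=0.1,
)
x = torch.randn(2, 30, 256)                       # (#batch, time, size)
mask = torch.ones(2, 30, 30, dtype=torch.bool)    # (#batch, time, time) self-attention mask
x, mask, _ = layer(x, mask, pos_emb=torch.empty(0))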
141 | True: x -> x + linear(concat(x, att(x))) 142 | False: x -> x + att(x) 143 | """ 144 | def __init__( 145 | self, 146 | size: int, 147 | self_attn: torch.nn.Module, 148 | feed_forward: Optional[nn.Module] = None, 149 | feed_forward_macaron: Optional[nn.Module] = None, 150 | conv_module: Optional[nn.Module] = None, 151 | dropout_rate: float = 0.1, 152 | normalize_before: bool = True, 153 | concat_after: bool = False, 154 | ): 155 | """Construct an EncoderLayer object.""" 156 | super().__init__() 157 | self.self_attn = self_attn 158 | self.feed_forward = feed_forward 159 | self.feed_forward_macaron = feed_forward_macaron 160 | self.conv_module = conv_module 161 | self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module 162 | self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module 163 | if feed_forward_macaron is not None: 164 | self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12) 165 | self.ff_scale = 0.5 166 | else: 167 | self.ff_scale = 1.0 168 | if self.conv_module is not None: 169 | self.norm_conv = nn.LayerNorm(size, 170 | eps=1e-12) # for the CNN module 171 | self.norm_final = nn.LayerNorm( 172 | size, eps=1e-12) # for the final output of the block 173 | self.dropout = nn.Dropout(dropout_rate) 174 | self.size = size 175 | self.normalize_before = normalize_before 176 | self.concat_after = concat_after 177 | self.concat_linear = nn.Linear(size + size, size) 178 | 179 | def forward( 180 | self, 181 | x: torch.Tensor, 182 | mask: torch.Tensor, 183 | pos_emb: torch.Tensor, 184 | mask_pad: Optional[torch.Tensor] = None, 185 | output_cache: Optional[torch.Tensor] = None, 186 | cnn_cache: Optional[torch.Tensor] = None, 187 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 188 | """Compute encoded features. 189 | 190 | Args: 191 | x (torch.Tensor): (#batch, time, size) 192 | mask (torch.Tensor): Mask tensor for the input (#batch, time,time). 193 | pos_emb (torch.Tensor): positional encoding, must not be None 194 | for ConformerEncoderLayer. 195 | mask_pad (torch.Tensor): batch padding mask used for conv module. 196 | (#batch, 1,time) 197 | output_cache (torch.Tensor): Cache tensor of the output 198 | (#batch, time2, size), time2 < time in x. 199 | cnn_cache (torch.Tensor): Convolution cache in conformer layer 200 | Returns: 201 | torch.Tensor: Output tensor (#batch, time, size). 202 | torch.Tensor: Mask tensor (#batch, time). 
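A sketch of assembling the ConformerEncoderLayer above from the attention, convolution and feed-forward modules defined earlier; all sizes are illustrative:

import torch
from wenet.transformer.attention import RelPositionMultiHeadedAttention
from wenet.transformer.convolution import ConvolutionModule
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward

layer = ConformerEncoderLayer(
    size=256,
    self_attn=RelPositionMultiHeadedAttention(4, 256, 0.1),
    feed_forward=PositionwiseFeedForward(256, 2048, 0.1),
    feed_forward_macaron=PositionwiseFeedForward(256, 2048, 0.1),
    conv_module=ConvolutionModule(channels=256, kernel_size=15, causal=True),
    dropout_rate=0.1,
)
x = torch.randn(2, 30, 256)
mask = torch.ones(2, 30, 30, dtype=torch.bool)
_, pos_emb = RelPositionalEncoding(256, 0.1)(x)    # relative position embedding (1, time, size)
x, mask, cnn_cache = layer(x, mask, pos_emb)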
203 | """ 204 | 205 | # whether to use macaron style 206 | if self.feed_forward_macaron is not None: 207 | residual = x 208 | if self.normalize_before: 209 | x = self.norm_ff_macaron(x) 210 | x = residual + self.ff_scale * self.dropout( 211 | self.feed_forward_macaron(x)) 212 | if not self.normalize_before: 213 | x = self.norm_ff_macaron(x) 214 | 215 | # multi-headed self-attention module 216 | residual = x 217 | if self.normalize_before: 218 | x = self.norm_mha(x) 219 | 220 | if output_cache is None: 221 | x_q = x 222 | else: 223 | assert output_cache.size(0) == x.size(0) 224 | assert output_cache.size(2) == self.size 225 | assert output_cache.size(1) < x.size(1) 226 | chunk = x.size(1) - output_cache.size(1) 227 | x_q = x[:, -chunk:, :] 228 | residual = residual[:, -chunk:, :] 229 | mask = mask[:, -chunk:, :] 230 | 231 | x_att = self.self_attn(x_q, x, x, mask, pos_emb) 232 | if self.concat_after: 233 | x_concat = torch.cat((x, x_att), dim=-1) 234 | x = residual + self.concat_linear(x_concat) 235 | else: 236 | x = residual + self.dropout(x_att) 237 | if not self.normalize_before: 238 | x = self.norm_mha(x) 239 | 240 | # convolution module 241 | # Fake new cnn cache here, and then change it in conv_module 242 | new_cnn_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device) 243 | if self.conv_module is not None: 244 | residual = x 245 | if self.normalize_before: 246 | x = self.norm_conv(x) 247 | x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) 248 | x = residual + self.dropout(x) 249 | 250 | if not self.normalize_before: 251 | x = self.norm_conv(x) 252 | 253 | # feed forward module 254 | residual = x 255 | if self.normalize_before: 256 | x = self.norm_ff(x) 257 | 258 | x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) 259 | if not self.normalize_before: 260 | x = self.norm_ff(x) 261 | 262 | if self.conv_module is not None: 263 | x = self.norm_final(x) 264 | 265 | if output_cache is not None: 266 | x = torch.cat([output_cache, x], dim=1) 267 | 268 | return x, mask, new_cnn_cache 269 | -------------------------------------------------------------------------------- /wenet/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | """Label smoothing module.""" 7 | 8 | import torch 9 | from torch import nn 10 | 11 | 12 | class LabelSmoothingLoss(nn.Module): 13 | """Label-smoothing loss. 14 | 15 | In a standard CE loss, the label's data distribution is: 16 | [0,1,2] -> 17 | [ 18 | [1.0, 0.0, 0.0], 19 | [0.0, 1.0, 0.0], 20 | [0.0, 0.0, 1.0], 21 | ] 22 | 23 | In the smoothing version CE Loss,some probabilities 24 | are taken from the true label prob (1.0) and are divided 25 | among other labels. 26 | 27 | e.g. 
28 | smoothing=0.1 29 | [0,1,2] -> 30 | [ 31 | [0.9, 0.05, 0.05], 32 | [0.05, 0.9, 0.05], 33 | [0.05, 0.05, 0.9], 34 | ] 35 | 36 | Args: 37 | size (int): the number of class 38 | padding_idx (int): padding class id which will be ignored for loss 39 | smoothing (float): smoothing rate (0.0 means the conventional CE) 40 | normalize_length (bool): 41 | normalize loss by sequence length if True 42 | normalize loss by batch size if False 43 | """ 44 | def __init__(self, 45 | size: int, 46 | padding_idx: int, 47 | smoothing: float, 48 | normalize_length: bool = False): 49 | """Construct an LabelSmoothingLoss object.""" 50 | super(LabelSmoothingLoss, self).__init__() 51 | self.criterion = nn.KLDivLoss(reduction="none") 52 | self.padding_idx = padding_idx 53 | self.confidence = 1.0 - smoothing 54 | self.smoothing = smoothing 55 | self.size = size 56 | self.normalize_length = normalize_length 57 | 58 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 59 | """Compute loss between x and target. 60 | 61 | The model outputs and data labels tensors are flatten to 62 | (batch*seqlen, class) shape and a mask is applied to the 63 | padding part which should not be calculated for loss. 64 | 65 | Args: 66 | x (torch.Tensor): prediction (batch, seqlen, class) 67 | target (torch.Tensor): 68 | target signal masked with self.padding_id (batch, seqlen) 69 | Returns: 70 | loss (torch.Tensor) : The KL loss, scalar float value 71 | """ 72 | assert x.size(2) == self.size 73 | batch_size = x.size(0) 74 | x = x.view(-1, self.size) 75 | target = target.view(-1) 76 | # use zeros_like instead of torch.no_grad() for true_dist, 77 | # since no_grad() can not be exported by JIT 78 | true_dist = torch.zeros_like(x) 79 | true_dist.fill_(self.smoothing / (self.size - 1)) 80 | ignore = target == self.padding_idx # (B,) 81 | total = len(target) - ignore.sum().item() 82 | target = target.masked_fill(ignore, 0) # avoid -1 index 83 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 84 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 85 | denom = total if self.normalize_length else batch_size 86 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 87 | -------------------------------------------------------------------------------- /wenet/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | """Positionwise feed forward layer definition.""" 7 | 8 | import torch 9 | 10 | 11 | class PositionwiseFeedForward(torch.nn.Module): 12 | """Positionwise feed forward layer. 13 | 14 | FeedForward are appied on each position of the sequence. 15 | The output dim is same with the input dim. 16 | 17 | Args: 18 | idim (int): Input dimenstion. 19 | hidden_units (int): The number of hidden units. 20 | dropout_rate (float): Dropout rate. 
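A sketch of the LabelSmoothingLoss above on a toy batch; the vocabulary size is an example value and -1 is the padding id to be ignored:

import torch
from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss

criterion = LabelSmoothingLoss(size=4233, padding_idx=-1, smoothing=0.1)
logits = torch.randn(2, 6, 4233)                    # (batch, seqlen, class)
target = torch.tensor([[ 5, 17, 42,  9,  3, -1],
                       [11,  2, -1, -1, -1, -1]])   # -1 marks padding positions
loss = criterion(logits, target)                    # scalar KL loss, normalized by batch size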
21 | activation (torch.nn.Module): Activation function 22 | """ 23 | def __init__(self, 24 | idim: int, 25 | hidden_units: int, 26 | dropout_rate: float, 27 | activation: torch.nn.Module = torch.nn.ReLU()): 28 | """Construct a PositionwiseFeedForward object.""" 29 | super(PositionwiseFeedForward, self).__init__() 30 | self.w_1 = torch.nn.Linear(idim, hidden_units) 31 | self.activation = activation 32 | self.dropout = torch.nn.Dropout(dropout_rate) 33 | self.w_2 = torch.nn.Linear(hidden_units, idim) 34 | 35 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 36 | """Forward function. 37 | 38 | Args: 39 | xs: input tensor (B, L, D) 40 | Returns: 41 | output tensor, (B, L, D) 42 | """ 43 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 44 | -------------------------------------------------------------------------------- /wenet/transformer/subsampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 5 | # Author: di.wu@mobvoi.com (DI WU) 6 | """Subsampling layer definition.""" 7 | 8 | from typing import Tuple 9 | 10 | import torch 11 | 12 | 13 | class BaseSubsampling(torch.nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | self.right_context = 0 17 | self.subsampling_rate = 1 18 | 19 | def position_encoding(self, offset: int, size: int) -> torch.Tensor: 20 | return self.pos_enc.position_encoding(offset, size) 21 | 22 | 23 | class LinearNoSubsampling(BaseSubsampling): 24 | """Linear transform the input without subsampling 25 | 26 | Args: 27 | idim (int): Input dimension. 28 | odim (int): Output dimension. 29 | dropout_rate (float): Dropout rate. 30 | 31 | """ 32 | def __init__(self, idim: int, odim: int, dropout_rate: float, 33 | pos_enc_class: torch.nn.Module): 34 | """Construct an linear object.""" 35 | super().__init__() 36 | self.out = torch.nn.Sequential( 37 | torch.nn.Linear(idim, odim), 38 | torch.nn.LayerNorm(odim, eps=1e-12), 39 | torch.nn.Dropout(dropout_rate), 40 | ) 41 | self.pos_enc = pos_enc_class 42 | self.right_context = 0 43 | self.subsampling_rate = 1 44 | 45 | def forward( 46 | self, 47 | x: torch.Tensor, 48 | x_mask: torch.Tensor, 49 | offset: int = 0 50 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 51 | """Input x. 52 | 53 | Args: 54 | x (torch.Tensor): Input tensor (#batch, time, idim). 55 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 56 | 57 | Returns: 58 | torch.Tensor: linear input tensor (#batch, time', odim), 59 | where time' = time . 60 | torch.Tensor: linear input mask (#batch, 1, time'), 61 | where time' = time . 62 | 63 | """ 64 | x = self.out(x) 65 | x, pos_emb = self.pos_enc(x, offset) 66 | return x, pos_emb, x_mask 67 | 68 | 69 | class Conv2dSubsampling4(BaseSubsampling): 70 | """Convolutional 2D subsampling (to 1/4 length). 71 | 72 | Args: 73 | idim (int): Input dimension. 74 | odim (int): Output dimension. 75 | dropout_rate (float): Dropout rate. 
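The PositionwiseFeedForward module above is shape-preserving; a tiny sketch with example dimensions:

import torch
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward

ffn = PositionwiseFeedForward(idim=256, hidden_units=2048, dropout_rate=0.1)
xs = torch.randn(2, 30, 256)   # (B, L, D)
ys = ffn(xs)                   # (B, L, D), computed as w_2(dropout(relu(w_1(xs))))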
76 | 77 | """ 78 | def __init__(self, idim: int, odim: int, dropout_rate: float, 79 | pos_enc_class: torch.nn.Module): 80 | """Construct an Conv2dSubsampling4 object.""" 81 | super().__init__() 82 | self.conv = torch.nn.Sequential( 83 | torch.nn.Conv2d(1, odim, 3, 2), 84 | torch.nn.ReLU(), 85 | torch.nn.Conv2d(odim, odim, 3, 2), 86 | torch.nn.ReLU(), 87 | ) 88 | self.out = torch.nn.Sequential( 89 | torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) 90 | self.pos_enc = pos_enc_class 91 | # The right context for every conv layer is computed by: 92 | # (kernel_size - 1) * frame_rate_of_this_layer 93 | self.subsampling_rate = 4 94 | # 6 = (3 - 1) * 1 + (3 - 1) * 2 95 | self.right_context = 6 96 | 97 | def forward( 98 | self, 99 | x: torch.Tensor, 100 | x_mask: torch.Tensor, 101 | offset: int = 0 102 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 103 | """Subsample x. 104 | 105 | Args: 106 | x (torch.Tensor): Input tensor (#batch, time, idim). 107 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 108 | 109 | Returns: 110 | torch.Tensor: Subsampled tensor (#batch, time', odim), 111 | where time' = time // 4. 112 | torch.Tensor: Subsampled mask (#batch, 1, time'), 113 | where time' = time // 4. 114 | torch.Tensor: positional encoding 115 | 116 | """ 117 | x = x.unsqueeze(1) # (b, c=1, t, f) 118 | x = self.conv(x) 119 | b, c, t, f = x.size() 120 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 121 | x, pos_emb = self.pos_enc(x, offset) 122 | return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] 123 | 124 | 125 | class Conv2dSubsampling6(BaseSubsampling): 126 | """Convolutional 2D subsampling (to 1/6 length). 127 | Args: 128 | idim (int): Input dimension. 129 | odim (int): Output dimension. 130 | dropout_rate (float): Dropout rate. 131 | pos_enc (torch.nn.Module): Custom position encoding layer. 132 | """ 133 | def __init__(self, idim: int, odim: int, dropout_rate: float, 134 | pos_enc_class: torch.nn.Module): 135 | """Construct an Conv2dSubsampling6 object.""" 136 | super().__init__() 137 | self.conv = torch.nn.Sequential( 138 | torch.nn.Conv2d(1, odim, 3, 2), 139 | torch.nn.ReLU(), 140 | torch.nn.Conv2d(odim, odim, 5, 3), 141 | torch.nn.ReLU(), 142 | ) 143 | self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), 144 | odim) 145 | self.pos_enc = pos_enc_class 146 | # 10 = (3 - 1) * 1 + (5 - 1) * 2 147 | self.subsampling_rate = 6 148 | self.right_context = 10 149 | 150 | def forward( 151 | self, 152 | x: torch.Tensor, 153 | x_mask: torch.Tensor, 154 | offset: int = 0 155 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 156 | """Subsample x. 157 | Args: 158 | x (torch.Tensor): Input tensor (#batch, time, idim). 159 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 160 | 161 | Returns: 162 | torch.Tensor: Subsampled tensor (#batch, time', odim), 163 | where time' = time // 6. 164 | torch.Tensor: Subsampled mask (#batch, 1, time'), 165 | where time' = time // 6. 166 | torch.Tensor: positional encoding 167 | """ 168 | x = x.unsqueeze(1) # (b, c, t, f) 169 | x = self.conv(x) 170 | b, c, t, f = x.size() 171 | x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) 172 | x, pos_emb = self.pos_enc(x, offset) 173 | return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3] 174 | 175 | 176 | class Conv2dSubsampling8(BaseSubsampling): 177 | """Convolutional 2D subsampling (to 1/8 length). 178 | 179 | Args: 180 | idim (int): Input dimension. 181 | odim (int): Output dimension. 182 | dropout_rate (float): Dropout rate. 
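A sketch of the 1/4 subsampling above on 80-dim fbank features; the input length of 100 frames is an example and shrinks to 24 frames after the two stride-2 convolutions:

import torch
from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.subsampling import Conv2dSubsampling4

sub = Conv2dSubsampling4(idim=80, odim=256, dropout_rate=0.1,
                         pos_enc_class=PositionalEncoding(256, 0.1))
x = torch.randn(2, 100, 80)                        # (#batch, time, idim)
x_mask = torch.ones(2, 1, 100, dtype=torch.bool)   # (#batch, 1, time)
y, pos_emb, y_mask = sub(x, x_mask)                # y: (2, 24, 256), y_mask: (2, 1, 24)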
183 | 184 | """ 185 | def __init__(self, idim: int, odim: int, dropout_rate: float, 186 | pos_enc_class: torch.nn.Module): 187 | """Construct an Conv2dSubsampling8 object.""" 188 | super().__init__() 189 | self.conv = torch.nn.Sequential( 190 | torch.nn.Conv2d(1, odim, 3, 2), 191 | torch.nn.ReLU(), 192 | torch.nn.Conv2d(odim, odim, 3, 2), 193 | torch.nn.ReLU(), 194 | torch.nn.Conv2d(odim, odim, 3, 2), 195 | torch.nn.ReLU(), 196 | ) 197 | self.linear = torch.nn.Linear( 198 | odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim) 199 | self.pos_enc = pos_enc_class 200 | self.subsampling_rate = 8 201 | # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4 202 | self.right_context = 14 203 | 204 | def forward( 205 | self, 206 | x: torch.Tensor, 207 | x_mask: torch.Tensor, 208 | offset: int = 0 209 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 210 | """Subsample x. 211 | 212 | Args: 213 | x (torch.Tensor): Input tensor (#batch, time, idim). 214 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 215 | 216 | Returns: 217 | torch.Tensor: Subsampled tensor (#batch, time', odim), 218 | where time' = time // 8. 219 | torch.Tensor: Subsampled mask (#batch, 1, time'), 220 | where time' = time // 8. 221 | torch.Tensor: positional encoding 222 | """ 223 | x = x.unsqueeze(1) # (b, c, t, f) 224 | x = self.conv(x) 225 | b, c, t, f = x.size() 226 | x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) 227 | x, pos_emb = self.pos_enc(x, offset) 228 | return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2] 229 | -------------------------------------------------------------------------------- /wenet/transformer/swish.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2020 Johns Hopkins University (Shinji Watanabe) 5 | # Northwestern Polytechnical University (Pengcheng Guo) 6 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 7 | """Swish() activation function for Conformer.""" 8 | 9 | import torch 10 | 11 | 12 | class Swish(torch.nn.Module): 13 | """Construct an Swish object.""" 14 | def forward(self, x: torch.Tensor) -> torch.Tensor: 15 | """Return Swish activation function.""" 16 | return x * torch.sigmoid(x) 17 | -------------------------------------------------------------------------------- /wenet/utils/__pycache__/checkpoint.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/utils/__pycache__/checkpoint.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/utils/__pycache__/cmvn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/utils/__pycache__/cmvn.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/utils/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/utils/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/utils/__pycache__/config.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/utils/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/utils/__pycache__/executor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/utils/__pycache__/executor.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/utils/__pycache__/file_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/utils/__pycache__/file_utils.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/utils/__pycache__/mask.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/utils/__pycache__/mask.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/utils/__pycache__/scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tzenthin/ASR_python_deploy/38b71f826a1e0867727dbe5a0521281178e32609/wenet/utils/__pycache__/scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /wenet/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 2 | # Author: binbinzhang@mobvoi.com (Binbin Zhang) 3 | 4 | import logging 5 | import os 6 | import re 7 | 8 | import yaml 9 | import torch 10 | 11 | 12 | def load_checkpoint(model: torch.nn.Module, path: str) -> dict: 13 | if torch.cuda.is_available(): 14 | logging.info('Checkpoint: loading from checkpoint %s for GPU' % path) 15 | checkpoint = torch.load(path) 16 | else: 17 | logging.info('Checkpoint: loading from checkpoint %s for CPU' % path) 18 | checkpoint = torch.load(path, map_location='cpu') 19 | model.load_state_dict(checkpoint) 20 | info_path = re.sub('.pt$', '.yaml', path) 21 | configs = {} 22 | if os.path.exists(info_path): 23 | with open(info_path, 'r') as fin: 24 | configs = yaml.load(fin, Loader=yaml.FullLoader) 25 | return configs 26 | 27 | 28 | def save_checkpoint(model: torch.nn.Module, path: str, infos=None): 29 | ''' 30 | Args: 31 | infos (dict or None): any info you want to save. 
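A sketch of the save/load round trip implemented by checkpoint.py above; the file name and infos dict are placeholders, and a plain nn.Linear stands in for the real ASR model:

import torch
from wenet.utils.checkpoint import save_checkpoint, load_checkpoint

model = torch.nn.Linear(4, 2)                            # stand-in for the ASR model
save_checkpoint(model, 'demo.pt', infos={'epoch': 3, 'cv_loss': 1.23})
configs = load_checkpoint(model, 'demo.pt')              # also reads the demo.yaml written above
print(configs)                                           # -> the infos dict saved above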
32 | ''' 33 | logging.info('Checkpoint: save to checkpoint %s' % path) 34 | if isinstance(model, torch.nn.DataParallel): 35 | state_dict = model.module.state_dict() 36 | elif isinstance(model, torch.nn.parallel.DistributedDataParallel): 37 | state_dict = model.module.state_dict() 38 | else: 39 | state_dict = model.state_dict() 40 | torch.save(state_dict, path) 41 | info_path = re.sub('.pt$', '.yaml', path) 42 | if infos is None: 43 | infos = {} 44 | with open(info_path, 'w') as fout: 45 | data = yaml.dump(infos) 46 | fout.write(data) 47 | -------------------------------------------------------------------------------- /wenet/utils/cmvn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import json 17 | import math 18 | 19 | import numpy as np 20 | 21 | 22 | def _load_json_cmvn(json_cmvn_file): 23 | """ Load the json format cmvn stats file and calculate cmvn 24 | 25 | Args: 26 | json_cmvn_file: cmvn stats file in json format 27 | 28 | Returns: 29 | a numpy array of [means, vars] 30 | """ 31 | with open(json_cmvn_file) as f: 32 | cmvn_stats = json.load(f) 33 | 34 | means = cmvn_stats['mean_stat'] 35 | variance = cmvn_stats['var_stat'] 36 | count = cmvn_stats['frame_num'] 37 | for i in range(len(means)): 38 | means[i] /= count 39 | variance[i] = variance[i] / count - means[i] * means[i] 40 | if variance[i] < 1.0e-20: 41 | variance[i] = 1.0e-20 42 | variance[i] = 1.0 / math.sqrt(variance[i]) 43 | cmvn = np.array([means, variance]) 44 | return cmvn 45 | 46 | 47 | def _load_kaldi_cmvn(kaldi_cmvn_file): 48 | """ Load the kaldi format cmvn stats file and calculate cmvn 49 | 50 | Args: 51 | kaldi_cmvn_file: kaldi text style global cmvn file, which 52 | is generated by: 53 | compute-cmvn-stats --binary=false scp:feats.scp global_cmvn 54 | 55 | Returns: 56 | a numpy array of [means, vars] 57 | """ 58 | means = [] 59 | variance = [] 60 | with open(kaldi_cmvn_file, 'r') as fid: 61 | # kaldi binary file start with '\0B' 62 | if fid.read(2) == '\0B': 63 | logging.error('kaldi cmvn binary file is not supported, please ' 64 | 'recompute it by: compute-cmvn-stats --binary=false ' 65 | ' scp:feats.scp global_cmvn') 66 | sys.exit(1) 67 | fid.seek(0) 68 | arr = fid.read().split() 69 | assert (arr[0] == '[') 70 | assert (arr[-2] == '0') 71 | assert (arr[-1] == ']') 72 | feat_dim = int((len(arr) - 2 - 2) / 2) 73 | for i in range(1, feat_dim + 1): 74 | means.append(float(arr[i])) 75 | count = float(arr[feat_dim + 1]) 76 | for i in range(feat_dim + 2, 2 * feat_dim + 2): 77 | variance.append(float(arr[i])) 78 | 79 | for i in range(len(means)): 80 | means[i] /= count 81 | variance[i] = variance[i] / count - means[i] * means[i] 82 | if variance[i] < 1.0e-20: 83 | variance[i] = 1.0e-20 84 | variance[i] = 1.0 / math.sqrt(variance[i]) 85 | cmvn = np.array([means, variance]) 86 | return cmvn 87 | 88 | 89 | def 
load_cmvn(cmvn_file, is_json): 90 | if is_json: 91 | cmvn = _load_json_cmvn(cmvn_file) 92 | else: 93 | cmvn = _load_kaldi_cmvn(cmvn_file) 94 | return cmvn[0], cmvn[1] 95 | -------------------------------------------------------------------------------- /wenet/utils/common.py: -------------------------------------------------------------------------------- 1 | """Unility functions for Transformer.""" 2 | 3 | import math 4 | from typing import Tuple, List 5 | 6 | import torch 7 | from torch.nn.utils.rnn import pad_sequence 8 | 9 | IGNORE_ID = -1 10 | 11 | 12 | def pad_list(xs: List[torch.Tensor], pad_value: int): 13 | """Perform padding for the list of tensors. 14 | 15 | Args: 16 | xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. 17 | pad_value (float): Value for padding. 18 | 19 | Returns: 20 | Tensor: Padded tensor (B, Tmax, `*`). 21 | 22 | Examples: 23 | >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] 24 | >>> x 25 | [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] 26 | >>> pad_list(x, 0) 27 | tensor([[1., 1., 1., 1.], 28 | [1., 1., 0., 0.], 29 | [1., 0., 0., 0.]]) 30 | 31 | """ 32 | n_batch = len(xs) 33 | max_len = max([x.size(0) for x in xs]) 34 | pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device) 35 | pad = pad.fill_(pad_value) 36 | for i in range(n_batch): 37 | pad[i, :xs[i].size(0)] = xs[i] 38 | 39 | return pad 40 | 41 | 42 | def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int, 43 | ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]: 44 | """Add and labels. 45 | 46 | Args: 47 | ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax) 48 | sos (int): index of 49 | eos (int): index of 50 | ignore_id (int): index of padding 51 | 52 | Returns: 53 | ys_in (torch.Tensor) : (B, Lmax + 1) 54 | ys_out (torch.Tensor) : (B, Lmax + 1) 55 | 56 | Examples: 57 | >>> sos_id = 10 58 | >>> eos_id = 11 59 | >>> ignore_id = -1 60 | >>> ys_pad 61 | tensor([[ 1, 2, 3, 4, 5], 62 | [ 4, 5, 6, -1, -1], 63 | [ 7, 8, 9, -1, -1]], dtype=torch.int32) 64 | >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id) 65 | >>> ys_in 66 | tensor([[10, 1, 2, 3, 4, 5], 67 | [10, 4, 5, 6, 11, 11], 68 | [10, 7, 8, 9, 11, 11]]) 69 | >>> ys_out 70 | tensor([[ 1, 2, 3, 4, 5, 11], 71 | [ 4, 5, 6, 11, -1, -1], 72 | [ 7, 8, 9, 11, -1, -1]]) 73 | """ 74 | _sos = torch.tensor([sos], 75 | dtype=torch.long, 76 | requires_grad=False, 77 | device=ys_pad.device) 78 | _eos = torch.tensor([eos], 79 | dtype=torch.long, 80 | requires_grad=False, 81 | device=ys_pad.device) 82 | ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys 83 | ys_in = [torch.cat([_sos, y], dim=0) for y in ys] 84 | ys_out = [torch.cat([y, _eos], dim=0) for y in ys] 85 | return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) 86 | 87 | 88 | def reverse_pad_list(ys_pad: torch.Tensor, 89 | ys_lens: torch.Tensor, 90 | pad_value: float = -1.0) -> torch.Tensor: 91 | """Reverse padding for the list of tensors. 92 | 93 | Args: 94 | ys_pad (tensor): The padded tensor (B, Tokenmax). 95 | ys_lens (tensor): The lens of token seqs (B) 96 | pad_value (int): Value for padding. 97 | 98 | Returns: 99 | Tensor: Padded tensor (B, Tokenmax). 
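A sketch of add_sos_eos from wenet/utils/common.py above, mirroring its docstring example with two utterances; 10, 11 and -1 are example sos/eos/ignore ids:

import torch
from wenet.utils.common import add_sos_eos

ys_pad = torch.tensor([[1, 2, 3, 4, 5],
                       [4, 5, 6, -1, -1]])          # -1 is the padding id
ys_in, ys_out = add_sos_eos(ys_pad, sos=10, eos=11, ignore_id=-1)
# ys_in  (decoder input):  [[10, 1, 2, 3, 4, 5], [10, 4, 5, 6, 11, 11]]
# ys_out (decoder target): [[ 1, 2, 3, 4, 5, 11], [ 4, 5, 6, 11, -1, -1]]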
100 | 101 | Examples: 102 | >>> x 103 | tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]]) 104 | >>> reverse_pad_list(x, torch.tensor([4, 3, 2]), 0) 105 | tensor([[4, 3, 2, 1], 106 | [7, 6, 5, 0], 107 | [9, 8, 0, 0]]) 108 | 109 | """ 110 | r_ys_pad = pad_sequence([(torch.flip(y.int()[:i], [0])) 111 | for y, i in zip(ys_pad, ys_lens)], True, 112 | pad_value) 113 | return r_ys_pad 114 | 115 | 116 | def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor, 117 | ignore_label: int) -> float: 118 | """Calculate accuracy. 119 | 120 | Args: 121 | pad_outputs (Tensor): Prediction tensors (B * Lmax, D). 122 | pad_targets (LongTensor): Target label tensors (B, Lmax, D). 123 | ignore_label (int): Ignore label id. 124 | 125 | Returns: 126 | float: Accuracy value (0.0 - 1.0). 127 | 128 | """ 129 | pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1), 130 | pad_outputs.size(1)).argmax(2) 131 | mask = pad_targets != ignore_label 132 | numerator = torch.sum( 133 | pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) 134 | denominator = torch.sum(mask) 135 | return float(numerator) / float(denominator) 136 | 137 | 138 | def get_activation(act): 139 | """Return activation function.""" 140 | # Lazy load to avoid unused import 141 | from wenet.transformer.swish import Swish 142 | 143 | activation_funcs = { 144 | "hardtanh": torch.nn.Hardtanh, 145 | "tanh": torch.nn.Tanh, 146 | "relu": torch.nn.ReLU, 147 | "selu": torch.nn.SELU, 148 | "swish": getattr(torch.nn, "SiLU", Swish), 149 | "gelu": torch.nn.GELU 150 | } 151 | 152 | return activation_funcs[act]() 153 | 154 | 155 | def get_subsample(config): 156 | input_layer = config["encoder_conf"]["input_layer"] 157 | assert input_layer in ["conv2d", "conv2d6", "conv2d8"] 158 | if input_layer == "conv2d": 159 | return 4 160 | elif input_layer == "conv2d6": 161 | return 6 162 | elif input_layer == "conv2d8": 163 | return 8 164 | 165 | 166 | def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: 167 | new_hyp: List[int] = [] 168 | cur = 0 169 | while cur < len(hyp): 170 | if hyp[cur] != 0: 171 | new_hyp.append(hyp[cur]) 172 | prev = cur 173 | while cur < len(hyp) and hyp[cur] == hyp[prev]: 174 | cur += 1 175 | return new_hyp 176 | 177 | 178 | def log_add(args: List[float]) -> float: 179 | """ 180 | Stable log add 181 | """ 182 | if all(a == -float('inf') for a in args): 183 | return -float('inf') 184 | a_max = max(args) 185 | lsp = math.log(sum(math.exp(a - a_max) for a in args)) 186 | return a_max + lsp 187 | -------------------------------------------------------------------------------- /wenet/utils/config.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | def override_config(configs, override_list): 4 | new_configs = copy.deepcopy(configs) 5 | for item in override_list: 6 | arr = item.split() 7 | if len(arr) != 2: 8 | print(f"the override {item} format not correct, skip it") 9 | continue 10 | keys = arr[0].split('.') 11 | s_configs = new_configs 12 | for i, key in enumerate(keys): 13 | if key not in s_configs: 14 | print(f"the override {item} format not correct, skip it") 15 | break 16 | if i == len(keys) - 1: 17 | param_type = type(s_configs[key]) 18 | s_configs[key] = param_type(arr[1]) 19 | print(f"override {arr[0]} with {arr[1]}") 20 | else: 21 | s_configs = s_configs[key] 22 | return new_configs 23 | -------------------------------------------------------------------------------- /wenet/utils/ctc_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 
2021 Mobvoi Inc. All Rights Reserved. 2 | # Author: binbinzhang@mobvoi.com (Di Wu) 3 | 4 | import numpy as np 5 | import torch 6 | 7 | def insert_blank(label, blank_id=0): 8 | """Insert blank token between every two label token.""" 9 | label = np.expand_dims(label, 1) 10 | blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id 11 | label = np.concatenate([blanks, label], axis=1) 12 | label = label.reshape(-1) 13 | label = np.append(label, label[0]) 14 | return label 15 | 16 | def forced_align(ctc_probs: torch.Tensor, 17 | y: torch.Tensor, 18 | blank_id=0) -> list: 19 | """ctc forced alignment. 20 | 21 | Args: 22 | torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D) 23 | torch.Tensor y: id sequence tensor 1d tensor (L) 24 | int blank_id: blank symbol index 25 | Returns: 26 | torch.Tensor: alignment result 27 | """ 28 | y_insert_blank = insert_blank(y, blank_id) 29 | 30 | log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank))) 31 | log_alpha = log_alpha - float('inf') # log of zero 32 | state_path = (torch.zeros( 33 | (ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1 34 | ) # state path 35 | 36 | # init start state 37 | log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] 38 | log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] 39 | 40 | for t in range(1, ctc_probs.size(0)): 41 | for s in range(len(y_insert_blank)): 42 | if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ 43 | s] == y_insert_blank[s - 2]: 44 | candidates = torch.tensor( 45 | [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]]) 46 | prev_state = [s, s - 1] 47 | else: 48 | candidates = torch.tensor([ 49 | log_alpha[t - 1, s], 50 | log_alpha[t - 1, s - 1], 51 | log_alpha[t - 1, s - 2], 52 | ]) 53 | prev_state = [s, s - 1, s - 2] 54 | log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]] 55 | state_path[t, s] = prev_state[torch.argmax(candidates)] 56 | 57 | state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16) 58 | 59 | candidates = torch.tensor([ 60 | log_alpha[-1, len(y_insert_blank) - 1], 61 | log_alpha[-1, len(y_insert_blank) - 2] 62 | ]) 63 | prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2] 64 | state_seq[-1] = prev_state[torch.argmax(candidates)] 65 | for t in range(ctc_probs.size(0) - 2, -1, -1): 66 | state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] 67 | 68 | output_alignment = [] 69 | for t in range(0, ctc_probs.size(0)): 70 | output_alignment.append(y_insert_blank[state_seq[t, 0]]) 71 | 72 | return output_alignment 73 | -------------------------------------------------------------------------------- /wenet/utils/executor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Mobvoi Inc. All Rights Reserved. 
2 | # Author: binbinzhang@mobvoi.com (Binbin Zhang) 3 | 4 | import logging 5 | from contextlib import nullcontext 6 | # if your python version < 3.7 use the below one 7 | # from contextlib import suppress as nullcontext 8 | import torch 9 | from torch.nn.utils import clip_grad_norm_ 10 | 11 | 12 | class Executor: 13 | def __init__(self): 14 | self.step = 0 15 | 16 | def train(self, model, optimizer, scheduler, data_loader, device, writer, 17 | args, scaler): 18 | ''' Train one epoch 19 | ''' 20 | model.train() 21 | clip = args.get('grad_clip', 50.0) 22 | log_interval = args.get('log_interval', 10) 23 | rank = args.get('rank', 0) 24 | epoch = args.get('epoch', 0) 25 | accum_grad = args.get('accum_grad', 1) 26 | is_distributed = args.get('is_distributed', True) 27 | use_amp = args.get('use_amp', False) 28 | logging.info('using accumulate grad, new batch size is {} times' 29 | 'larger than before'.format(accum_grad)) 30 | if use_amp: 31 | assert scaler is not None 32 | # A context manager to be used in conjunction with an instance of 33 | # torch.nn.parallel.DistributedDataParallel to be able to train 34 | # with uneven inputs across participating processes. 35 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 36 | model_context = model.join 37 | else: 38 | model_context = nullcontext 39 | num_seen_utts = 0 40 | with model_context(): 41 | for batch_idx, batch in enumerate(data_loader): 42 | key, feats, target, feats_lengths, target_lengths = batch 43 | feats = feats.to(device) 44 | target = target.to(device) 45 | feats_lengths = feats_lengths.to(device) 46 | target_lengths = target_lengths.to(device) 47 | num_utts = target_lengths.size(0) 48 | if num_utts == 0: 49 | continue 50 | context = None 51 | # Disable gradient synchronizations across DDP processes. 52 | # Within this context, gradients will be accumulated on module 53 | # variables, which will later be synchronized. 54 | if is_distributed and batch_idx % accum_grad != 0: 55 | context = model.no_sync 56 | # Used for single gpu training and DDP gradient synchronization 57 | # processes. 58 | else: 59 | context = nullcontext 60 | with context(): 61 | # autocast context 62 | # The more details about amp can be found in 63 | # https://pytorch.org/docs/stable/notes/amp_examples.html 64 | with torch.cuda.amp.autocast(scaler is not None): 65 | loss, loss_att, loss_ctc = model( 66 | feats, feats_lengths, target, target_lengths) 67 | loss = loss / accum_grad 68 | if use_amp: 69 | scaler.scale(loss).backward() 70 | else: 71 | loss.backward() 72 | 73 | num_seen_utts += num_utts 74 | if batch_idx % accum_grad == 0: 75 | if rank == 0 and writer is not None: 76 | writer.add_scalar('train_loss', loss, self.step) 77 | # Use mixed precision training 78 | if use_amp: 79 | scaler.unscale_(optimizer) 80 | grad_norm = clip_grad_norm_(model.parameters(), clip) 81 | # Must invoke scaler.update() if unscale_() is used in 82 | # the iteration to avoid the following error: 83 | # RuntimeError: unscale_() has already been called 84 | # on this optimizer since the last update(). 85 | # We don't check grad here since that if the gradient 86 | # has inf/nan values, scaler.step will skip 87 | # optimizer.step(). 
88 | scaler.step(optimizer) 89 | scaler.update() 90 | else: 91 | grad_norm = clip_grad_norm_(model.parameters(), clip) 92 | if torch.isfinite(grad_norm): 93 | optimizer.step() 94 | optimizer.zero_grad() 95 | scheduler.step() 96 | self.step += 1 97 | if batch_idx % log_interval == 0: 98 | lr = optimizer.param_groups[0]['lr'] 99 | log_str = 'TRAIN Batch {}/{} loss {:.6f} '.format( 100 | epoch, batch_idx, 101 | loss.item() * accum_grad) 102 | if loss_att is not None: 103 | log_str += 'loss_att {:.6f} '.format(loss_att.item()) 104 | if loss_ctc is not None: 105 | log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item()) 106 | log_str += 'lr {:.8f} rank {}'.format(lr, rank) 107 | logging.debug(log_str) 108 | 109 | def cv(self, model, data_loader, device, args): 110 | ''' Cross validation on 111 | ''' 112 | model.eval() 113 | rank = args.get('rank', 0) 114 | epoch = args.get('epoch', 0) 115 | log_interval = args.get('log_interval', 10) 116 | # in order to avoid division by 0 117 | num_seen_utts = 1 118 | total_loss = 0.0 119 | with torch.no_grad(): 120 | for batch_idx, batch in enumerate(data_loader): 121 | key, feats, target, feats_lengths, target_lengths = batch 122 | feats = feats.to(device) 123 | target = target.to(device) 124 | feats_lengths = feats_lengths.to(device) 125 | target_lengths = target_lengths.to(device) 126 | num_utts = target_lengths.size(0) 127 | if num_utts == 0: 128 | continue 129 | loss, loss_att, loss_ctc = model(feats, feats_lengths, target, 130 | target_lengths) 131 | if torch.isfinite(loss): 132 | num_seen_utts += num_utts 133 | total_loss += loss.item() * num_utts 134 | if batch_idx % log_interval == 0: 135 | log_str = 'CV Batch {}/{} loss {:.6f} '.format( 136 | epoch, batch_idx, loss.item()) 137 | if loss_att is not None: 138 | log_str += 'loss_att {:.6f} '.format(loss_att.item()) 139 | if loss_ctc is not None: 140 | log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item()) 141 | log_str += 'history loss {:.6f}'.format(total_loss / 142 | num_seen_utts) 143 | log_str += ' rank {}'.format(rank) 144 | logging.debug(log_str) 145 | return total_loss, num_seen_utts 146 | -------------------------------------------------------------------------------- /wenet/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | def read_lists(list_file): 17 | lists = [] 18 | with open(list_file, 'r', encoding='utf8') as fin: 19 | for line in fin: 20 | lists.append(line.strip()) 21 | return lists 22 | 23 | 24 | def read_symbol_table(symbol_table_file): 25 | symbol_table = {} 26 | with open(symbol_table_file, 'r', encoding='utf8') as fin: 27 | for line in fin: 28 | arr = line.strip().split() 29 | assert len(arr) == 2 30 | symbol_table[arr[0]] = int(arr[1]) 31 | return symbol_table 32 | -------------------------------------------------------------------------------- /wenet/utils/mask.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Shigeki Karita 4 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 5 | 6 | import torch 7 | 8 | 9 | def subsequent_mask( 10 | size: int, 11 | device: torch.device = torch.device("cpu"), 12 | ) -> torch.Tensor: 13 | """Create mask for subsequent steps (size, size). 14 | 15 | This mask is used only in decoder which works in an auto-regressive mode. 16 | This means the current step could only do attention with its left steps. 17 | 18 | In encoder, fully attention is used when streaming is not necessary and 19 | the sequence is not long. In this case, no attention mask is needed. 20 | 21 | When streaming is need, chunk-based attention is used in encoder. See 22 | subsequent_chunk_mask for the chunk-based attention mask. 23 | 24 | Args: 25 | size (int): size of mask 26 | str device (str): "cpu" or "cuda" or torch.Tensor.device 27 | dtype (torch.device): result dtype 28 | 29 | Returns: 30 | torch.Tensor: mask 31 | 32 | Examples: 33 | >>> subsequent_mask(3) 34 | [[1, 0, 0], 35 | [1, 1, 0], 36 | [1, 1, 1]] 37 | """ 38 | ret = torch.ones(size, size, device=device, dtype=torch.bool) 39 | return torch.tril(ret, out=ret) 40 | 41 | 42 | def subsequent_chunk_mask( 43 | size: int, 44 | chunk_size: int, 45 | num_left_chunks: int = -1, 46 | device: torch.device = torch.device("cpu"), 47 | ) -> torch.Tensor: 48 | """Create mask for subsequent steps (size, size) with chunk size, 49 | this is for streaming encoder 50 | 51 | Args: 52 | size (int): size of mask 53 | chunk_size (int): size of chunk 54 | num_left_chunks (int): number of left chunks 55 | <0: use full chunk 56 | >=0: use num_left_chunks 57 | device (torch.device): "cpu" or "cuda" or torch.Tensor.device 58 | 59 | Returns: 60 | torch.Tensor: mask 61 | 62 | Examples: 63 | >>> subsequent_chunk_mask(4, 2) 64 | [[1, 1, 0, 0], 65 | [1, 1, 0, 0], 66 | [1, 1, 1, 1], 67 | [1, 1, 1, 1]] 68 | """ 69 | ret = torch.zeros(size, size, device=device, dtype=torch.bool) 70 | for i in range(size): 71 | if num_left_chunks < 0: 72 | start = 0 73 | else: 74 | start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) 75 | ending = min((i // chunk_size + 1) * chunk_size, size) 76 | ret[i, start:ending] = True 77 | return ret 78 | 79 | 80 | def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor, 81 | use_dynamic_chunk: bool, 82 | use_dynamic_left_chunk: bool, 83 | decoding_chunk_size: int, static_chunk_size: int, 84 | num_decoding_left_chunks: int): 85 | """ Apply optional mask for encoder. 86 | 87 | Args: 88 | xs (torch.Tensor): padded input, (B, L, D), L for max length 89 | mask (torch.Tensor): mask for xs, (B, 1, L) 90 | use_dynamic_chunk (bool): whether to use dynamic chunk or not 91 | use_dynamic_left_chunk (bool): whether to use dynamic left chunk for 92 | training. 
93 | decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's 94 | 0: default for training, use random dynamic chunk. 95 | <0: for decoding, use full chunk. 96 | >0: for decoding, use fixed chunk size as set. 97 | static_chunk_size (int): chunk size for static chunk training/decoding 98 | if it's greater than 0, if use_dynamic_chunk is true, 99 | this parameter will be ignored 100 | num_decoding_left_chunks: number of left chunks, this is for decoding, 101 | the chunk size is decoding_chunk_size. 102 | >=0: use num_decoding_left_chunks 103 | <0: use all left chunks 104 | 105 | Returns: 106 | torch.Tensor: chunk mask of the input xs. 107 | """ 108 | # Whether to use chunk mask or not 109 | if use_dynamic_chunk: 110 | max_len = xs.size(1) 111 | if decoding_chunk_size < 0: 112 | chunk_size = max_len 113 | num_left_chunks = -1 114 | elif decoding_chunk_size > 0: 115 | chunk_size = decoding_chunk_size 116 | num_left_chunks = num_decoding_left_chunks 117 | else: 118 | # chunk size is either [1, 25] or full context(max_len). 119 | # Since we use 4 times subsampling and allow up to 1s(100 frames) 120 | # delay, the maximum frame is 100 / 4 = 25. 121 | chunk_size = torch.randint(1, max_len, (1, )).item() 122 | num_left_chunks = -1 123 | if chunk_size > max_len // 2: 124 | chunk_size = max_len 125 | else: 126 | chunk_size = chunk_size % 25 + 1 127 | if use_dynamic_left_chunk: 128 | max_left_chunks = (max_len - 1) // chunk_size 129 | num_left_chunks = torch.randint(0, max_left_chunks, 130 | (1, )).item() 131 | chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size, 132 | num_left_chunks, 133 | xs.device) # (L, L) 134 | chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) 135 | chunk_masks = masks & chunk_masks # (B, L, L) 136 | elif static_chunk_size > 0: 137 | num_left_chunks = num_decoding_left_chunks 138 | chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size, 139 | num_left_chunks, 140 | xs.device) # (L, L) 141 | chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) 142 | chunk_masks = masks & chunk_masks # (B, L, L) 143 | else: 144 | chunk_masks = masks 145 | return chunk_masks 146 | 147 | 148 | def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor: 149 | """Make mask tensor containing indices of padded part. 150 | 151 | See description of make_non_pad_mask. 152 | 153 | Args: 154 | lengths (torch.Tensor): Batch of lengths (B,). 155 | Returns: 156 | torch.Tensor: Mask tensor containing indices of padded part. 157 | 158 | Examples: 159 | >>> lengths = [5, 3, 2] 160 | >>> make_pad_mask(lengths) 161 | masks = [[0, 0, 0, 0 ,0], 162 | [0, 0, 0, 1, 1], 163 | [0, 0, 1, 1, 1]] 164 | """ 165 | batch_size = int(lengths.size(0)) 166 | max_len = int(lengths.max().item()) 167 | seq_range = torch.arange(0, 168 | max_len, 169 | dtype=torch.int64, 170 | device=lengths.device) 171 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) 172 | seq_length_expand = lengths.unsqueeze(-1) 173 | mask = seq_range_expand >= seq_length_expand 174 | return mask 175 | 176 | 177 | def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor: 178 | """Make mask tensor containing indices of non-padded part. 179 | 180 | The sequences in a batch may have different lengths. To enable 181 | batch computing, padding is need to make all sequence in same 182 | size. To avoid the padding part pass value to context dependent 183 | block such as attention or convolution , this padding part is 184 | masked. 185 | 186 | This pad_mask is used in both encoder and decoder. 
187 | 188 | 1 for non-padded part and 0 for padded part. 189 | 190 | Args: 191 | lengths (torch.Tensor): Batch of lengths (B,). 192 | Returns: 193 | torch.Tensor: mask tensor containing indices of padded part. 194 | 195 | Examples: 196 | >>> lengths = [5, 3, 2] 197 | >>> make_non_pad_mask(lengths) 198 | masks = [[1, 1, 1, 1, 1], 199 | [1, 1, 1, 0, 0], 200 | [1, 1, 0, 0, 0]] 201 | """ 202 | return ~make_pad_mask(lengths) 203 | 204 | 205 | def mask_finished_scores(score: torch.Tensor, 206 | flag: torch.Tensor) -> torch.Tensor: 207 | """ 208 | If a sequence is finished, we only allow one alive branch. This function 209 | aims to give one branch a zero score and the rest -inf score. 210 | 211 | Args: 212 | score (torch.Tensor): A real value array with shape 213 | (batch_size * beam_size, beam_size). 214 | flag (torch.Tensor): A bool array with shape 215 | (batch_size * beam_size, 1). 216 | 217 | Returns: 218 | torch.Tensor: (batch_size * beam_size, beam_size). 219 | """ 220 | beam_size = score.size(-1) 221 | zero_mask = torch.zeros_like(flag, dtype=torch.bool) 222 | if beam_size > 1: 223 | unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])), 224 | dim=1) 225 | finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])), 226 | dim=1) 227 | else: 228 | unfinished = zero_mask 229 | finished = flag 230 | score.masked_fill_(unfinished, -float('inf')) 231 | score.masked_fill_(finished, 0) 232 | return score 233 | 234 | 235 | def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor, 236 | eos: int) -> torch.Tensor: 237 | """ 238 | If a sequence is finished, all of its branches should be <eos>. 239 | 240 | Args: 241 | pred (torch.Tensor): An int array with shape 242 | (batch_size * beam_size, beam_size). 243 | flag (torch.Tensor): A bool array with shape 244 | (batch_size * beam_size, 1). 245 | 246 | Returns: 247 | torch.Tensor: (batch_size * beam_size, beam_size). 248 | """ 249 | beam_size = pred.size(-1) 250 | finished = flag.repeat([1, beam_size]) 251 | return pred.masked_fill_(finished, eos) 252 | -------------------------------------------------------------------------------- /wenet/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | from typeguard import check_argument_types 7 | 8 | 9 | class WarmupLR(_LRScheduler): 10 | """The WarmupLR scheduler 11 | 12 | This scheduler is almost the same as the NoamLR scheduler except for the 13 | following difference: 14 | 15 | NoamLR: 16 | lr = optimizer.lr * model_size ** -0.5 17 | * min(step ** -0.5, step * warmup_step ** -1.5) 18 | WarmupLR: 19 | lr = optimizer.lr * warmup_step ** 0.5 20 | * min(step ** -0.5, step * warmup_step ** -1.5) 21 | 22 | Note that the maximum lr equals optimizer.lr in this scheduler. 
23 | 24 | """ 25 | 26 | def __init__( 27 | self, 28 | optimizer: torch.optim.Optimizer, 29 | warmup_steps: Union[int, float] = 25000, 30 | last_epoch: int = -1, 31 | ): 32 | assert check_argument_types() 33 | self.warmup_steps = warmup_steps 34 | 35 | # __init__() must be invoked before setting field 36 | # because step() is also invoked in __init__() 37 | super().__init__(optimizer, last_epoch) 38 | 39 | def __repr__(self): 40 | return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" 41 | 42 | def get_lr(self): 43 | step_num = self.last_epoch + 1 44 | return [ 45 | lr 46 | * self.warmup_steps ** 0.5 47 | * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5) 48 | for lr in self.base_lrs 49 | ] 50 | 51 | def set_step(self, step: int): 52 | self.last_epoch = step 53 | --------------------------------------------------------------------------------
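
For reference, a minimal usage sketch of the WarmupLR scheduler listed above (this sketch is not part of the repository; it assumes torch and typeguard are installed and that the repository root is on PYTHONPATH, and the toy Linear model only stands in for the real ASR model):

import torch
from wenet.utils.scheduler import WarmupLR

model = torch.nn.Linear(80, 4)  # stand-in for the ASR model
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
scheduler = WarmupLR(optimizer, warmup_steps=25000)

for step in range(1, 50001):
    optimizer.step()   # in real training this follows loss.backward()
    scheduler.step()   # one scheduler step per optimizer step, as in executor.py
    if step in (100, 25000, 50000):
        print(step, optimizer.param_groups[0]['lr'])

The learning rate ramps up linearly during warmup and peaks at the optimizer lr (about 0.002 at step 25000 here), then decays proportionally to step ** -0.5 (about 0.0014 at step 50000). set_step() jumps the schedule to a given step, which is useful when resuming from a saved checkpoint.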