├── examples
│   └── wav
│       ├── IT0011W0001.wav
│       ├── BAC009S0764W0121.wav
│       ├── TEST_MEETING_T0000000001_S00000.wav
│       ├── TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav
│       └── text
├── requirements.txt
├── pretrained_models
│   └── README.md
├── LICENSE
├── test.py
├── fireredasr
│   ├── models
│   │   ├── fireredasr_aed.py
│   │   ├── fireredasr_streaming.py
│   │   ├── transformer_decoder.py
│   │   └── conformer_encoder.py
│   ├── data
│   │   ├── token_dict.py
│   │   └── asr_feat.py
│   └── tokenizer
│       └── aed_tokenizer.py
├── vad.py
├── README.md
├── speech_detector.py
└── realtime_fireredasr.py

/examples/wav/IT0011W0001.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xphh/fireredasr-streaming/HEAD/examples/wav/IT0011W0001.wav
--------------------------------------------------------------------------------
/examples/wav/BAC009S0764W0121.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xphh/fireredasr-streaming/HEAD/examples/wav/BAC009S0764W0121.wav
--------------------------------------------------------------------------------
/examples/wav/TEST_MEETING_T0000000001_S00000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xphh/fireredasr-streaming/HEAD/examples/wav/TEST_MEETING_T0000000001_S00000.wav
--------------------------------------------------------------------------------
/examples/wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xphh/fireredasr-streaming/HEAD/examples/wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cn2an>=0.5.23
2 | kaldiio>=2.18.0
3 | kaldi_native_fbank>=1.15
4 | numpy>=1.26.1
5 | peft>=0.13.2
6 | torch>=2.0.0
7 | transformers>=4.46.3
8 | onnxruntime>=1.21.0
9 | pyaudio>=0.2.14
10 | sentencepiece>=0.2.0
--------------------------------------------------------------------------------
/examples/wav/text:
--------------------------------------------------------------------------------
1 | BAC009S0764W0121 甚至 出现 交易 几乎 停滞 的 情况
2 | IT0011W0001 换一首歌
3 | TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000 我有的时候说不清楚你们知道吗
4 | TEST_MEETING_T0000000001_S00000 好首先说一下刚才这个经理说完的这个销售问题咱再说一下咱们的商场问题首先咱们商场上半年业这个先各部门儿汇报一下就是业绩
5 |
--------------------------------------------------------------------------------
/pretrained_models/README.md:
--------------------------------------------------------------------------------
1 | # Put Models Here
2 |
3 | # VAD
4 |
5 | Download `silero_vad.onnx` from [silero-vad](https://github.com/snakers4/silero-vad)
6 |
7 | # FireRedASR-AED
8 |
9 | You can download the model from ModelScope: [FireRedASR-AED-L](https://modelscope.cn/models/pengzhendong/FireRedASR-AED-L)
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Ping.X
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies
of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import pyaudio 2 | from realtime_fireredasr import RealtimeSpeechRecognizer 3 | 4 | 5 | def open_stream(filename=None): 6 | stream = None 7 | if filename is None: 8 | p = pyaudio.PyAudio() 9 | stream = p.open(format=pyaudio.paInt16, 10 | channels=1, 11 | rate=16000, 12 | input=True, 13 | frames_per_buffer=1600) 14 | else: 15 | stream = open(filename, "rb") 16 | return stream 17 | 18 | 19 | # read wav from file 20 | filename = "examples/wav/BAC009S0764W0121.wav" 21 | 22 | # capture wav from microphone 23 | # filename = None 24 | 25 | # create realtime speech recognizer instance 26 | asr = RealtimeSpeechRecognizer( 27 | model_dir="pretrained_models", # models' dir 28 | use_gpu=False, # use gpu or not 29 | sample_rate=16000, # audio sample rate 30 | silence_duration_s=0.4, # silence duration for VAD cutting 31 | transcribe_interval=1.0, # how many seconds to transcribe once 32 | ) 33 | 34 | # recognization loop 35 | stream = open_stream(filename) 36 | while True: 37 | data = stream.read(1600) 38 | if data == b"": 39 | print(">>>wave EOF") 40 | break 41 | results = asr.recognize(data) 42 | if len(results) > 0: 43 | print(results) 44 | -------------------------------------------------------------------------------- /fireredasr/models/fireredasr_aed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from fireredasr.models.conformer_encoder import ConformerEncoder 4 | from fireredasr.models.transformer_decoder import TransformerDecoder 5 | 6 | 7 | class FireRedAsrAed(torch.nn.Module): 8 | @classmethod 9 | def from_args(cls, args): 10 | return cls(args) 11 | 12 | def __init__(self, args): 13 | super().__init__() 14 | self.sos_id = args.sos_id 15 | self.eos_id = args.eos_id 16 | 17 | self.encoder = ConformerEncoder( 18 | args.idim, args.n_layers_enc, args.n_head, args.d_model, 19 | args.residual_dropout, args.dropout_rate, 20 | args.kernel_size, args.pe_maxlen) 21 | 22 | self.decoder = TransformerDecoder( 23 | args.sos_id, args.eos_id, args.pad_id, args.odim, 24 | args.n_layers_dec, args.n_head, args.d_model, 25 | args.residual_dropout, args.pe_maxlen) 26 | 27 | def transcribe(self, padded_input, input_lengths, 28 | beam_size=1, nbest=1, decode_max_len=0, 29 | softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0, 30 | ys_state=None): 31 | enc_outputs, _, enc_mask = self.encoder(padded_input, input_lengths) 32 | nbest_hyps = self.decoder.batch_beam_search( 33 | enc_outputs, enc_mask, 34 | beam_size, nbest, decode_max_len, 35 | softmax_smoothing, length_penalty, eos_penalty, ys_state) 36 | return 
nbest_hyps
37 |
--------------------------------------------------------------------------------
/vad.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import numpy as np
3 |
4 |
5 | @functools.lru_cache
6 | def get_vad_model(model_dir="pretrained_models"):
7 |     # currently the silero_vad v5 model
8 |     return SileroVADModel(f"{model_dir}/silero_vad.onnx")
9 |
10 |
11 | class SileroVADModel:
12 |     def __init__(self, path):
13 |         try:
14 |             import onnxruntime
15 |         except ImportError as e:
16 |             raise RuntimeError(
17 |                 "Applying the VAD filter requires the onnxruntime package"
18 |             ) from e
19 |
20 |         opts = onnxruntime.SessionOptions()
21 |         opts.inter_op_num_threads = 1
22 |         opts.intra_op_num_threads = 1
23 |         opts.log_severity_level = 4
24 |
25 |         self.session = onnxruntime.InferenceSession(
26 |             path,
27 |             providers=["CPUExecutionProvider"],
28 |             sess_options=opts,
29 |         )
30 |
31 |     def get_initial_state(self, batch_size: int):
32 |         return np.zeros((2, batch_size, 128), dtype=np.float32)
33 |
34 |
35 |     def __call__(self, x, state, sr: int):
36 |         if len(x.shape) == 1:
37 |             x = np.expand_dims(x, 0)
38 |         if len(x.shape) > 2:
39 |             raise ValueError(
40 |                 f"Too many dimensions for input audio chunk {len(x.shape)}"
41 |             )
42 |         if sr/x.shape[1] > 31.25:
43 |             raise ValueError("Input audio chunk is too short")
44 |
45 |         ort_inputs = {
46 |             "input": x,
47 |             "state": state,
48 |             "sr": np.array(sr, dtype="int64"),
49 |         }
50 |
51 |         out, state = self.session.run(None, ort_inputs)
52 |
53 |         return out, state
54 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FireRedAsr-Streaming
2 |
3 | A low-latency realtime ASR based on [FireRedASR](https://github.com/FireRedTeam/FireRedASR) (a SOTA model for Chinese and English speech recognition)
4 |
5 | ## How it works
6 |
7 | Reading the FireRedASR code, I found that by utilizing the autoregressive prediction mechanism of the transformer decoder, we can feed audio data into the model before the spoken sentence has ended.
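Concretely, the latency win comes from `FireRedAsrStreaming` (`fireredasr/models/fireredasr_streaming.py`): audio is appended to a buffer with `input()`, each `transcribe()` call re-decodes the buffered audio but seeds the beam search with the token ids cached in `ys_state` from the previous call, and `clear_state()` resets both once a sentence ends. Below is a minimal sketch of driving it directly, without the VAD wrapper; it assumes the `pretrained_models` layout from `pretrained_models/README.md` and a plain 16 kHz mono PCM wav, and the 44-byte header skip and once-per-second trigger are simplifications for illustration only.

``` python
import numpy as np
from fireredasr.models.fireredasr_streaming import FireRedAsrStreaming

asr = FireRedAsrStreaming(model_dir="pretrained_models", use_gpu=False)

# Crude PCM read: skip a canonical 44-byte WAV header (illustration only;
# a real reader should parse the header or use a wav library).
with open("examples/wav/BAC009S0764W0121.wav", "rb") as f:
    pcm = np.frombuffer(f.read()[44:], dtype=np.int16).astype(np.float32)

chunk = 1600  # 0.1 s at 16 kHz
for start in range(0, len(pcm), chunk):
    asr.input(pcm[start:start + chunk])          # buffer new audio
    if asr.get_input_length() % 16000 == 0:      # decode roughly once per second
        print(asr.transcribe())                  # partial text, seeded with cached ys_state
print(asr.transcribe(full_update=True))          # final decode over the whole buffer
asr.clear_state()                                # reset buffer and ys_state for the next sentence
```

In practice `RealtimeSpeechRecognizer` (see `realtime_fireredasr.py`) wraps this together with the silero VAD, so partial decodes happen only while speech is detected and the cached state is cleared at detected sentence boundaries.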
8 |
9 | ## How to use
10 |
11 | See `test.py`
12 |
13 | ``` python
14 | from realtime_fireredasr import RealtimeSpeechRecognizer
15 |
16 | # read wav from file
17 | filename = "examples/wav/BAC009S0764W0121.wav"
18 |
19 | # create realtime speech recognizer instance
20 | asr = RealtimeSpeechRecognizer(
21 |     model_dir="pretrained_models",  # models directory
22 |     use_gpu=False,                  # use gpu or not
23 |     sample_rate=16000,              # audio sample rate
24 |     silence_duration_s=0.4,         # silence duration for VAD cutting
25 |     transcribe_interval=1.0,        # how many seconds between transcriptions
26 | )
27 |
28 | # recognition loop
29 | stream = open_stream(filename)
30 | while True:
31 |     data = stream.read(1600)
32 |     if data == b"":
33 |         print(">>>wave EOF")
34 |         break
35 |     results = asr.recognize(data)
36 |     if len(results) > 0:
37 |         print(results)
38 | ```
39 |
40 | Results (tested on a Tesla T4):
41 | ```
42 | [{'type': 'begin', 'id': 0, 'text': None, 'ts': 0.672, 'latency': 0.0}]
43 | [{'type': 'changed', 'id': 0, 'text': '甚至出现交', 'ts': 1.664, 'latency': 4.238659143447876}]
44 | [{'type': 'changed', 'id': 0, 'text': '甚至出现交易几乎停', 'ts': 2.656, 'latency': 0.2656550407409668}]
45 | [{'type': 'changed', 'id': 0, 'text': '甚至出现交易几乎停滞的情况', 'ts': 3.648, 'latency': 0.22229647636413574}]
46 | [{'type': 'end', 'id': 0, 'text': '甚至出现交易几乎停滞的情况', 'ts': 4.128, 'latency': 0.11719179153442383}]
47 | >>> wave EOF
48 | ```
49 |
50 | Note that recognition takes only 117 ms at the sentence end, whereas recognizing the whole sentence from scratch (i.e. without this project) needs nearly 400 ms.
51 |
--------------------------------------------------------------------------------
/fireredasr/data/token_dict.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | class TokenDict:
5 |     def __init__(self, dict_path, unk="<unk>"):
6 |         assert dict_path != ""
7 |         self.id2word, self.word2id = self.read_dict(dict_path)
8 |         self.unk = unk
9 |         assert unk == "" or unk in self.word2id
10 |         self.unkid = self.word2id[unk] if unk else -1
11 |
12 |     def get(self, key, default):
13 |         if type(default) == str:
14 |             default = self.word2id[default]
15 |         return self.word2id.get(key, default)
16 |
17 |     def __getitem__(self, key):
18 |         if type(key) == str:
19 |             if self.unk:
20 |                 return self.word2id.get(key, self.word2id[self.unk])
21 |             else:
22 |                 return self.word2id[key]
23 |         elif type(key) == int:
24 |             return self.id2word[key]
25 |         else:
26 |             raise TypeError("Key should be str or int")
27 |
28 |     def __len__(self):
29 |         return len(self.id2word)
30 |
31 |     def __contains__(self, query):
32 |         if type(query) == str:
33 |             return query in self.word2id
34 |         elif type(query) == int:
35 |             return query in self.id2word
36 |         else:
37 |             raise TypeError("query should be str or int")
38 |
39 |     def read_dict(self, dict_path):
40 |         id2word, word2id = [], {}
41 |         with open(dict_path, encoding='utf8') as f:
42 |             for i, line in enumerate(f):
43 |                 tokens = line.strip().split()
44 |                 if len(tokens) >= 2:
45 |                     word, index = tokens[0], int(tokens[1])
46 |                 elif len(tokens) == 1:
47 |                     word, index = tokens[0], i
48 |                 else:  # empty line or space
49 |                     logging.info(f"Find empty line or space '{line.strip()}' in {dict_path}:L{i}, set to ' '")
50 |                     word, index = " ", i
51 |                 assert len(id2word) == index
52 |                 assert len(word2id) == index
53 |                 if word == "<space>":
54 |                     logging.info(f"NOTE: Find <space> in {dict_path}:L{i} and convert it to ' '")
55 |                     word = " "
56 |                 word2id[word] = index
57 |                 id2word.append(word)
58 |         assert len(id2word) == len(word2id)
59 |         return id2word, word2id
60
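For reference, `TokenDict.read_dict` expects one token per line, optionally followed by its integer id, with ids consecutive and starting at 0. A small sanity-check sketch of that format (the file name and tokens below are made up for illustration):

``` python
from fireredasr.data.token_dict import TokenDict

# Toy dictionary in the expected "token id" format, ids consecutive from 0.
with open("toy_dict.txt", "w", encoding="utf8") as f:
    f.write("<pad> 0\n<sos> 1\n<eos> 2\n<unk> 3\n你 4\n好 5\n")

d = TokenDict("toy_dict.txt", unk="<unk>")
print(len(d))             # 6
print(d["你"])            # 4  (token -> id)
print(d[5])               # 好  (id -> token)
print(d["OOV"])           # 3  (unknown token falls back to <unk> because unk was set)
print("好" in d, 99 in d)  # True False
```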
| -------------------------------------------------------------------------------- /speech_detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import vad as vad 3 | 4 | 5 | class SpeechDetector: 6 | 7 | def __init__(self, 8 | model_dir="pretrained_models", 9 | framerate=16000, 10 | threshold=0.5, 11 | silence_duration_s=0.8, 12 | max_speech_duration_s=30, 13 | ): 14 | self.framerate = framerate 15 | self.threshold = threshold 16 | self.silence_duration_s = silence_duration_s 17 | self.max_speech_duration_s = max_speech_duration_s 18 | 19 | self.model = vad.get_vad_model(model_dir) 20 | self.state = self.model.get_initial_state(batch_size=1) 21 | self.audio_buffer = None 22 | self.silence_last_s = 0 23 | self.is_speech = False 24 | self.samples_count = 0 25 | self.last_speech_pos = 0 26 | 27 | def detect(self, audio): 28 | samples = 512 if self.framerate == 16000 else 256 29 | det_interval_s = float(samples) / self.framerate 30 | if self.audio_buffer is None: 31 | self.audio_buffer = audio 32 | else: 33 | self.audio_buffer = np.concatenate((self.audio_buffer, audio)) 34 | 35 | neg_threshold = max(self.threshold - 0.15, 0.01) 36 | while len(self.audio_buffer) >= samples: 37 | audio = self.audio_buffer[:samples] 38 | speech_prob, self.state = self.model( 39 | audio / 32768.0, self.state, self.framerate) 40 | speech_threshold = neg_threshold if self.is_speech else self.threshold 41 | if speech_prob > speech_threshold: 42 | self.silence_last_s = 0 43 | if not self.is_speech: 44 | self.is_speech = True 45 | self.last_speech_pos = self.samples_count 46 | else: 47 | speech_frames = self.samples_count - self.last_speech_pos 48 | if speech_frames / self.framerate > self.max_speech_duration_s: 49 | self.is_speech = False 50 | else: 51 | if self.is_speech: 52 | self.silence_last_s += det_interval_s 53 | if self.silence_last_s >= self.silence_duration_s: 54 | self.is_speech = False 55 | 56 | yield audio, self.is_speech 57 | 58 | self.samples_count += samples 59 | self.audio_buffer = self.audio_buffer[samples:] 60 | -------------------------------------------------------------------------------- /fireredasr/tokenizer/aed_tokenizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | 4 | import sentencepiece as spm 5 | 6 | from fireredasr.data.token_dict import TokenDict 7 | 8 | 9 | class ChineseCharEnglishSpmTokenizer: 10 | """ 11 | - One Chinese char is a token. 12 | - Split English word into SPM and one piece is a token. 
13 |     - Ignore ' ' between Chinese chars
14 |     - Replace ' ' between English words with "▁" by spm_model
15 |     - Need to put SPM pieces into the dict file
16 |     - If spm_model is not set, will use English chars and <space>
17 |     """
18 |     SPM_SPACE = "▁"
19 |
20 |     def __init__(self, dict_path, spm_model, unk="<unk>", space="<space>"):
21 |         self.dict = TokenDict(dict_path, unk=unk)
22 |         self.space = space
23 |         if spm_model:
24 |             self.sp = spm.SentencePieceProcessor()
25 |             self.sp.Load(spm_model)
26 |         else:
27 |             self.sp = None
28 |             print("[WARN] spm_model not set, will use English chars")
29 |             print("[WARN] Please check how to deal with ' ' (space)")
30 |         if self.space not in self.dict:
31 |             print("Please add <space> to your dict, or it will be <unk>")
32 |
33 |     def tokenize(self, text, replace_punc=True):
34 |         #if text == "":
35 |         #    logging.info(f"empty text")
36 |         text = text.upper()
37 |         tokens = []
38 |         if replace_punc:
39 |             text = re.sub("[,。?!,\.?!]", " ", text)
40 |         pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])')
41 |         parts = pattern.split(text.strip())
42 |         parts = [p for p in parts if len(p.strip()) > 0]
43 |         for part in parts:
44 |             if pattern.fullmatch(part) is not None:
45 |                 tokens.append(part)
46 |             else:
47 |                 if self.sp:
48 |                     for piece in self.sp.EncodeAsPieces(part.strip()):
49 |                         tokens.append(piece)
50 |                 else:
51 |                     for char in part.strip():
52 |                         tokens.append(char if char != " " else self.space)
53 |         tokens_id = []
54 |         for token in tokens:
55 |             tokens_id.append(self.dict.get(token, self.dict.unk))
56 |         return tokens, tokens_id
57 |
58 |     def detokenize(self, inputs, join_symbol="", replace_spm_space=True):
59 |         """inputs is ids or tokens, does not need self.sp"""
60 |         if len(inputs) > 0 and type(inputs[0]) == int:
61 |             tokens = [self.dict[id] for id in inputs]
62 |         else:
63 |             tokens = inputs
64 |         s = f"{join_symbol}".join(tokens)
65 |         if replace_spm_space:
66 |             s = s.replace(self.SPM_SPACE, ' ').strip()
67 |         return s
68 |
--------------------------------------------------------------------------------
/realtime_fireredasr.py:
--------------------------------------------------------------------------------
1 | from speech_detector import SpeechDetector
2 | from fireredasr.models.fireredasr_streaming import FireRedAsrStreaming
3 | import numpy as np
4 | import time
5 |
6 |
7 | class RealtimeSpeechRecognizer:
8 |     def __init__(
9 |         self,
10 |         model_dir="pretrained_models",
11 |         use_gpu=False,
12 |         sample_rate=16000,
13 |         silence_duration_s=0.4,
14 |         transcribe_interval=1.0,
15 |     ):
16 |         self.model = FireRedAsrStreaming(model_dir,
17 |                                          use_gpu=use_gpu,
18 |                                          sample_rate=sample_rate)
19 |         self.detector = SpeechDetector(model_dir,
20 |                                        framerate=sample_rate,
21 |                                        silence_duration_s=silence_duration_s)
22 |         self.sample_rate = sample_rate
23 |         self.transcribe_interval = transcribe_interval
24 |         self.sentence_id = 0
25 |         self.speech_state = False
26 |         self.sample_count = 0
27 |         self.next_transcribe_time = 0.0
28 |
29 |     def gen_result(self, t, text=None, latency=0.0):
30 |         return {
31 |             "type": t,
32 |             "id": self.sentence_id,
33 |
34 |             "text": text,
35 |             "ts": self.sample_count/self.sample_rate,
36 |             "latency": latency,
37 |         }
38 |
39 |     def transcribe(self):
40 |         start_time = time.time()
41 |         text = self.model.transcribe()
42 |         return text, time.time() - start_time
43 |
44 |     def recognize(self, audio_bytes):
45 |         results = []
46 |         wav_np = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32)
47 |         for frame_np, is_speech in self.detector.detect(wav_np):
48 |             if is_speech:
49 |                 self.model.input(frame_np)
50 |             if is_speech and
not self.speech_state: 51 | results.append(self.gen_result("begin")) 52 | self.next_transcribe_time = self.transcribe_interval 53 | elif self.speech_state and not is_speech: 54 | text, cost = self.transcribe() 55 | results.append(self.gen_result("end", text, cost)) 56 | self.model.clear_state() 57 | self.sentence_id += 1 58 | elif self.speech_state: 59 | cur_ts = self.model.get_input_length() / self.sample_rate 60 | if cur_ts >= self.next_transcribe_time: 61 | text, cost = self.transcribe() 62 | results.append(self.gen_result("changed", text, cost)) 63 | self.next_transcribe_time = cur_ts + self.transcribe_interval 64 | self.speech_state = is_speech 65 | self.sample_count += len(frame_np) 66 | return results 67 | -------------------------------------------------------------------------------- /fireredasr/data/asr_feat.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import kaldiio 5 | import kaldi_native_fbank as knf 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class ASRFeatExtractor: 11 | def __init__(self, kaldi_cmvn_file): 12 | self.cmvn = CMVN(kaldi_cmvn_file) if kaldi_cmvn_file != "" else None 13 | self.fbank = KaldifeatFbank(num_mel_bins=80, frame_length=25, 14 | frame_shift=10, dither=0.0) 15 | 16 | def __call__(self, sample_rate, wav_np): 17 | fbank = self.fbank((sample_rate, wav_np)) 18 | if self.cmvn is not None: 19 | fbank = self.cmvn(fbank) 20 | fbank = torch.from_numpy(fbank).float() 21 | return fbank 22 | 23 | 24 | class CMVN: 25 | def __init__(self, kaldi_cmvn_file): 26 | self.dim, self.means, self.inverse_std_variences = \ 27 | self.read_kaldi_cmvn(kaldi_cmvn_file) 28 | 29 | def __call__(self, x, is_train=False): 30 | assert x.shape[-1] == self.dim, "CMVN dim mismatch" 31 | out = x - self.means 32 | out = out * self.inverse_std_variences 33 | return out 34 | 35 | def read_kaldi_cmvn(self, kaldi_cmvn_file): 36 | assert os.path.exists(kaldi_cmvn_file) 37 | stats = kaldiio.load_mat(kaldi_cmvn_file) 38 | assert stats.shape[0] == 2 39 | dim = stats.shape[-1] - 1 40 | count = stats[0, dim] 41 | assert count >= 1 42 | floor = 1e-20 43 | means = [] 44 | inverse_std_variences = [] 45 | for d in range(dim): 46 | mean = stats[0, d] / count 47 | means.append(mean.item()) 48 | varience = (stats[1, d] / count) - mean*mean 49 | if varience < floor: 50 | varience = floor 51 | istd = 1.0 / math.sqrt(varience) 52 | inverse_std_variences.append(istd) 53 | return dim, np.array(means), np.array(inverse_std_variences) 54 | 55 | 56 | class KaldifeatFbank: 57 | def __init__(self, num_mel_bins=80, frame_length=25, frame_shift=10, 58 | dither=1.0): 59 | self.dither = dither 60 | opts = knf.FbankOptions() 61 | opts.frame_opts.dither = dither 62 | opts.mel_opts.num_bins = num_mel_bins 63 | opts.frame_opts.snip_edges = True 64 | opts.mel_opts.debug_mel = False 65 | self.opts = opts 66 | 67 | def __call__(self, wav, is_train=False): 68 | if type(wav) is str: 69 | sample_rate, wav_np = kaldiio.load_mat(wav) 70 | elif type(wav) in [tuple, list] and len(wav) == 2: 71 | sample_rate, wav_np = wav 72 | assert len(wav_np.shape) == 1 73 | 74 | dither = self.dither if is_train else 0.0 75 | self.opts.frame_opts.dither = dither 76 | fbank = knf.OnlineFbank(self.opts) 77 | 78 | fbank.accept_waveform(sample_rate, wav_np.tolist()) 79 | feat = [] 80 | for i in range(fbank.num_frames_ready): 81 | feat.append(fbank.get_frame(i)) 82 | if len(feat) == 0: 83 | print("Check data, len(feat) == 0", wav, flush=True) 84 | return np.zeros((0, 
self.opts.mel_opts.num_bins)) 85 | feat = np.vstack(feat) 86 | return feat 87 | -------------------------------------------------------------------------------- /fireredasr/models/fireredasr_streaming.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | from fireredasr.data.asr_feat import ASRFeatExtractor 6 | from fireredasr.models.fireredasr_aed import FireRedAsrAed 7 | from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer 8 | import functools 9 | 10 | 11 | @functools.lru_cache 12 | def load_model(model_dir): 13 | cmvn_path = os.path.join(model_dir, "cmvn.ark") 14 | feat_extractor = ASRFeatExtractor(cmvn_path) 15 | 16 | model_path = os.path.join(model_dir, "model.pth.tar") 17 | package = torch.load(model_path, map_location=lambda storage, 18 | loc: storage, weights_only=False) 19 | print("model args:", package["args"]) 20 | model = FireRedAsrAed.from_args(package["args"]) 21 | model.load_state_dict(package["model_state_dict"], strict=True) 22 | model.eval() 23 | 24 | dict_path = os.path.join(model_dir, "dict.txt") 25 | spm_model = os.path.join(model_dir, "train_bpe1000.model") 26 | tokenizer = ChineseCharEnglishSpmTokenizer(dict_path, spm_model) 27 | 28 | return feat_extractor, model, tokenizer 29 | 30 | 31 | class FireRedAsrStreaming: 32 | def __init__(self, 33 | model_dir="pretrained_models", 34 | use_gpu=False, 35 | sample_rate=16000, 36 | least_ys_state_len=4): 37 | self.use_gpu = use_gpu 38 | self.sample_rate = sample_rate 39 | self.least_ys_state_len = least_ys_state_len 40 | feat_extractor, model, tokenizer = load_model(model_dir) 41 | self.feat_extractor = feat_extractor 42 | self.model = model 43 | self.tokenizer = tokenizer 44 | self.ys_state = None 45 | self.wav_buffer = np.empty(0) 46 | 47 | def input(self, streaming_wav_np): 48 | self.wav_buffer = np.concatenate((self.wav_buffer, streaming_wav_np)) 49 | 50 | def get_input_length(self): 51 | return len(self. 
wav_buffer) 52 | 53 | def clear_state(self): 54 | self.ys_state = None 55 | self.wav_buffer = np.empty(0) 56 | 57 | @torch.no_grad() 58 | def transcribe(self, full_update=False, args={}): 59 | feat = self.feat_extractor(self.sample_rate, self.wav_buffer) 60 | feats = feat.unsqueeze(0) 61 | lengths = torch.tensor([feat.size(0)]).long() 62 | if self.use_gpu: 63 | feats, lengths = feats.cuda(), lengths.cuda() 64 | self.model.cuda() 65 | else: 66 | self.model.cpu() 67 | 68 | hyps = self.model.transcribe( 69 | feats, lengths, 70 | args.get("beam_size", 1), 71 | args.get("nbest", 1), 72 | args.get("decode_max_len", 0), 73 | args.get("softmax_smoothing", 1.0), 74 | args.get("aed_length_penalty", 0.0), 75 | args.get("eos_penalty", 1.0), 76 | None if full_update else self.ys_state, 77 | ) 78 | 79 | hyp = hyps[0][0] 80 | ys = hyp["yseq"] 81 | hyp_ids = [int(id) for id in ys.cpu()] 82 | text = self.tokenizer.detokenize(hyp_ids) 83 | 84 | if len(ys) > self.least_ys_state_len: 85 | self.ys_state = ys[:-1] if len(ys) > 0 else ys 86 | 87 | return text 88 | -------------------------------------------------------------------------------- /fireredasr/models/transformer_decoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | 9 | class TransformerDecoder(nn.Module): 10 | def __init__( 11 | self, sos_id, eos_id, pad_id, odim, 12 | n_layers, n_head, d_model, 13 | residual_dropout=0.1, pe_maxlen=5000): 14 | super().__init__() 15 | self.INF = 1e10 16 | # parameters 17 | self.pad_id = pad_id 18 | self.sos_id = sos_id 19 | self.eos_id = eos_id 20 | self.n_layers = n_layers 21 | 22 | # Components 23 | self.tgt_word_emb = nn.Embedding(odim, d_model, padding_idx=self.pad_id) 24 | self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen) 25 | self.dropout = nn.Dropout(residual_dropout) 26 | 27 | self.layer_stack = nn.ModuleList() 28 | for l in range(n_layers): 29 | block = DecoderLayer(d_model, n_head, residual_dropout) 30 | self.layer_stack.append(block) 31 | 32 | self.tgt_word_prj = nn.Linear(d_model, odim, bias=False) 33 | self.layer_norm_out = nn.LayerNorm(d_model) 34 | 35 | self.tgt_word_prj.weight = self.tgt_word_emb.weight 36 | self.scale = (d_model ** 0.5) 37 | 38 | def batch_beam_search(self, encoder_outputs, src_masks, 39 | beam_size=1, nbest=1, decode_max_len=0, 40 | softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0, 41 | ys_state=None): 42 | B = beam_size 43 | N, Ti, H = encoder_outputs.size() 44 | device = encoder_outputs.device 45 | maxlen = decode_max_len if decode_max_len > 0 else Ti 46 | assert eos_penalty > 0.0 and eos_penalty <= 1.0 47 | 48 | # Init 49 | encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, Ti, H) 50 | src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, -1, Ti) 51 | if ys_state is None: 52 | ys = torch.ones(N*B, 1).fill_(self.sos_id).long().to(device) 53 | else: 54 | ys = torch.cat([torch.Tensor([self.sos_id]).to(device), ys_state.to(device)]).unsqueeze(0).long() 55 | caches: List[Optional[Tensor]] = [] 56 | for _ in range(self.n_layers): 57 | caches.append(None) 58 | scores = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(device) 59 | scores = scores.repeat(N).view(N*B, 1) 60 | is_finished = torch.zeros_like(scores) 61 | 62 | # Autoregressive Prediction 63 | for t in range(maxlen): 64 | tgt_mask = 
self.ignored_target_position_is_0(ys, self.pad_id) 65 | 66 | dec_output = self.dropout( 67 | self.tgt_word_emb(ys) * self.scale + 68 | self.positional_encoding(ys)) 69 | 70 | i = 0 71 | for dec_layer in self.layer_stack: 72 | dec_output = dec_layer.forward( 73 | dec_output, encoder_outputs, 74 | tgt_mask, src_mask, 75 | cache=caches[i]) 76 | caches[i] = dec_output 77 | i += 1 78 | 79 | dec_output = self.layer_norm_out(dec_output) 80 | 81 | t_logit = self.tgt_word_prj(dec_output[:, -1]) 82 | t_scores = F.log_softmax(t_logit / softmax_smoothing, dim=-1) 83 | 84 | if eos_penalty != 1.0: 85 | t_scores[:, self.eos_id] *= eos_penalty 86 | 87 | t_topB_scores, t_topB_ys = torch.topk(t_scores, k=B, dim=1) 88 | t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished) 89 | t_topB_ys = self.set_finished_beam_y_to_eos(t_topB_ys, is_finished) 90 | 91 | # Accumulated 92 | scores = scores + t_topB_scores 93 | 94 | # Pruning 95 | scores = scores.view(N, B*B) 96 | scores, topB_score_ids = torch.topk(scores, k=B, dim=1) 97 | scores = scores.view(-1, 1) 98 | 99 | topB_row_number_in_each_B_rows_of_ys = torch.div(topB_score_ids, B).view(N*B) 100 | stride = B * torch.arange(N).view(N, 1).repeat(1, B).view(N*B).to(device) 101 | topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long() 102 | 103 | # Update ys 104 | ys = ys[topB_row_number_in_ys] 105 | t_ys = torch.gather(t_topB_ys.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1) 106 | ys = torch.cat((ys, t_ys), dim=1) 107 | 108 | # Update caches 109 | new_caches: List[Optional[Tensor]] = [] 110 | for cache in caches: 111 | if cache is not None: 112 | new_caches.append(cache[topB_row_number_in_ys]) 113 | caches = new_caches 114 | 115 | # Update finished state 116 | is_finished = t_ys.eq(self.eos_id) 117 | if is_finished.sum().item() == N*B: 118 | break 119 | 120 | # Length penalty (follow GNMT) 121 | scores = scores.view(N, B) 122 | ys = ys.view(N, B, -1) 123 | ys_lengths = self.get_ys_lengths(ys) 124 | if length_penalty > 0.0: 125 | penalty = torch.pow((5+ys_lengths.float())/(5.0+1), length_penalty) 126 | scores /= penalty 127 | nbest_scores, nbest_ids = torch.topk(scores, k=int(nbest), dim=1) 128 | nbest_scores = -1.0 * nbest_scores 129 | index = nbest_ids + B * torch.arange(N).view(N, 1).to(device).long() 130 | nbest_ys = ys.view(N*B, -1)[index.view(-1)] 131 | nbest_ys = nbest_ys.view(N, nbest_ids.size(1), -1) 132 | nbest_ys_lengths = ys_lengths.view(N*B)[index.view(-1)].view(N, -1) 133 | 134 | # result 135 | nbest_hyps: List[List[Dict[str, Tensor]]] = [] 136 | for n in range(N): 137 | n_nbest_hyps: List[Dict[str, Tensor]] = [] 138 | for i, score in enumerate(nbest_scores[n]): 139 | new_hyp = { 140 | "yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]] 141 | } 142 | n_nbest_hyps.append(new_hyp) 143 | nbest_hyps.append(n_nbest_hyps) 144 | return nbest_hyps 145 | 146 | def ignored_target_position_is_0(self, padded_targets, ignore_id): 147 | mask = torch.ne(padded_targets, ignore_id) 148 | mask = mask.unsqueeze(dim=1) 149 | T = padded_targets.size(-1) 150 | upper_tri_0_mask = self.upper_triangular_is_0(T).unsqueeze(0).to(mask.dtype) 151 | upper_tri_0_mask = upper_tri_0_mask.to(mask.dtype).to(mask.device) 152 | return mask.to(torch.uint8) & upper_tri_0_mask.to(torch.uint8) 153 | 154 | def upper_triangular_is_0(self, size): 155 | ones = torch.ones(size, size) 156 | tri_left_ones = torch.tril(ones) 157 | return tri_left_ones.to(torch.uint8) 158 | 159 | def set_finished_beam_score_to_zero(self, scores, is_finished): 
160 | NB, B = scores.size() 161 | is_finished = is_finished.float() 162 | mask_score = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(scores.device) 163 | mask_score = mask_score.view(1, B).repeat(NB, 1) 164 | return scores * (1 - is_finished) + mask_score * is_finished 165 | 166 | def set_finished_beam_y_to_eos(self, ys, is_finished): 167 | is_finished = is_finished.long() 168 | return ys * (1 - is_finished) + self.eos_id * is_finished 169 | 170 | def get_ys_lengths(self, ys): 171 | N, B, Tmax = ys.size() 172 | ys_lengths = torch.sum(torch.ne(ys, self.eos_id), dim=-1) 173 | return ys_lengths.int() 174 | 175 | 176 | 177 | class DecoderLayer(nn.Module): 178 | def __init__(self, d_model, n_head, dropout): 179 | super().__init__() 180 | self.self_attn_norm = nn.LayerNorm(d_model) 181 | self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout) 182 | 183 | self.cross_attn_norm = nn.LayerNorm(d_model) 184 | self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout) 185 | 186 | self.mlp_norm = nn.LayerNorm(d_model) 187 | self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout) 188 | 189 | def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask, 190 | cache=None): 191 | x = dec_input 192 | residual = x 193 | x = self.self_attn_norm(x) 194 | if cache is not None: 195 | xq = x[:, -1:, :] 196 | residual = residual[:, -1:, :] 197 | self_attn_mask = self_attn_mask[:, -1:, :] 198 | else: 199 | xq = x 200 | x = self.self_attn(xq, x, x, mask=self_attn_mask) 201 | x = residual + x 202 | 203 | residual = x 204 | x = self.cross_attn_norm(x) 205 | x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask) 206 | x = residual + x 207 | 208 | residual = x 209 | x = self.mlp_norm(x) 210 | x = residual + self.mlp(x) 211 | 212 | if cache is not None: 213 | x = torch.cat([cache, x], dim=1) 214 | 215 | return x 216 | 217 | 218 | class DecoderMultiHeadAttention(nn.Module): 219 | def __init__(self, d_model, n_head, dropout=0.1): 220 | super().__init__() 221 | self.d_model = d_model 222 | self.n_head = n_head 223 | self.d_k = d_model // n_head 224 | 225 | self.w_qs = nn.Linear(d_model, n_head * self.d_k) 226 | self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) 227 | self.w_vs = nn.Linear(d_model, n_head * self.d_k) 228 | 229 | self.attention = DecoderScaledDotProductAttention( 230 | temperature=self.d_k ** 0.5) 231 | self.fc = nn.Linear(n_head * self.d_k, d_model) 232 | self.dropout = nn.Dropout(dropout) 233 | 234 | def forward(self, q, k, v, mask=None): 235 | bs = q.size(0) 236 | 237 | q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k) 238 | k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k) 239 | v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k) 240 | q = q.transpose(1, 2) 241 | k = k.transpose(1, 2) 242 | v = v.transpose(1, 2) 243 | 244 | if mask is not None: 245 | mask = mask.unsqueeze(1) 246 | 247 | output = self.attention(q, k, v, mask=mask) 248 | 249 | output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model) 250 | output = self.fc(output) 251 | output = self.dropout(output) 252 | 253 | return output 254 | 255 | 256 | class DecoderScaledDotProductAttention(nn.Module): 257 | def __init__(self, temperature): 258 | super().__init__() 259 | self.temperature = temperature 260 | self.INF = float("inf") 261 | 262 | def forward(self, q, k, v, mask=None): 263 | attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature 264 | if mask is not None: 265 | mask = mask.eq(0) 266 | attn = attn.masked_fill(mask, -self.INF) 267 | attn = 
torch.softmax(attn, dim=-1).masked_fill(mask, 0.0) 268 | else: 269 | attn = torch.softmax(attn, dim=-1) 270 | output = torch.matmul(attn, v) 271 | return output 272 | 273 | 274 | class PositionwiseFeedForward(nn.Module): 275 | def __init__(self, d_model, d_ff, dropout=0.1): 276 | super().__init__() 277 | self.w_1 = nn.Linear(d_model, d_ff) 278 | self.act = nn.GELU() 279 | self.w_2 = nn.Linear(d_ff, d_model) 280 | self.dropout = nn.Dropout(dropout) 281 | 282 | def forward(self, x): 283 | output = self.w_2(self.act(self.w_1(x))) 284 | output = self.dropout(output) 285 | return output 286 | 287 | 288 | class PositionalEncoding(nn.Module): 289 | def __init__(self, d_model, max_len=5000): 290 | super().__init__() 291 | assert d_model % 2 == 0 292 | pe = torch.zeros(max_len, d_model, requires_grad=False) 293 | position = torch.arange(0, max_len).unsqueeze(1).float() 294 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 295 | -(torch.log(torch.tensor(10000.0)).item()/d_model)) 296 | pe[:, 0::2] = torch.sin(position * div_term) 297 | pe[:, 1::2] = torch.cos(position * div_term) 298 | pe = pe.unsqueeze(0) 299 | self.register_buffer('pe', pe) 300 | 301 | def forward(self, x): 302 | length = x.size(1) 303 | return self.pe[:, :length].clone().detach() 304 | -------------------------------------------------------------------------------- /fireredasr/models/conformer_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConformerEncoder(nn.Module): 7 | def __init__(self, idim, n_layers, n_head, d_model, 8 | residual_dropout=0.1, dropout_rate=0.1, kernel_size=33, 9 | pe_maxlen=5000): 10 | super().__init__() 11 | self.odim = d_model 12 | 13 | self.input_preprocessor = Conv2dSubsampling(idim, d_model) 14 | self.positional_encoding = RelPositionalEncoding(d_model) 15 | self.dropout = nn.Dropout(residual_dropout) 16 | 17 | self.layer_stack = nn.ModuleList() 18 | for l in range(n_layers): 19 | block = RelPosEmbConformerBlock(d_model, n_head, 20 | residual_dropout, 21 | dropout_rate, kernel_size) 22 | self.layer_stack.append(block) 23 | 24 | def forward(self, padded_input, input_lengths, pad=True): 25 | if pad: 26 | padded_input = F.pad(padded_input, 27 | (0, 0, 0, self.input_preprocessor.context - 1), 'constant', 0.0) 28 | src_mask = self.padding_position_is_0(padded_input, input_lengths) 29 | 30 | embed_output, input_lengths, src_mask = self.input_preprocessor(padded_input, src_mask) 31 | enc_output = self.dropout(embed_output) 32 | 33 | pos_emb = self.dropout(self.positional_encoding(embed_output)) 34 | 35 | enc_outputs = [] 36 | for enc_layer in self.layer_stack: 37 | enc_output = enc_layer(enc_output, pos_emb, slf_attn_mask=src_mask, 38 | pad_mask=src_mask) 39 | enc_outputs.append(enc_output) 40 | 41 | return enc_output, input_lengths, src_mask 42 | 43 | def padding_position_is_0(self, padded_input, input_lengths): 44 | N, T = padded_input.size()[:2] 45 | mask = torch.ones((N, T)).to(padded_input.device) 46 | for i in range(N): 47 | mask[i, input_lengths[i]:] = 0 48 | mask = mask.unsqueeze(dim=1) 49 | return mask.to(torch.uint8) 50 | 51 | 52 | class RelPosEmbConformerBlock(nn.Module): 53 | def __init__(self, d_model, n_head, 54 | residual_dropout=0.1, 55 | dropout_rate=0.1, kernel_size=33): 56 | super().__init__() 57 | self.ffn1 = ConformerFeedForward(d_model, dropout_rate) 58 | self.mhsa = RelPosMultiHeadAttention(n_head, d_model, 59 | residual_dropout) 60 | 
self.conv = ConformerConvolution(d_model, kernel_size, 61 | dropout_rate) 62 | self.ffn2 = ConformerFeedForward(d_model, dropout_rate) 63 | self.layer_norm = nn.LayerNorm(d_model) 64 | 65 | def forward(self, x, pos_emb, slf_attn_mask=None, pad_mask=None): 66 | out = 0.5 * x + 0.5 * self.ffn1(x) 67 | out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0] 68 | out = self.conv(out, pad_mask) 69 | out = 0.5 * out + 0.5 * self.ffn2(out) 70 | out = self.layer_norm(out) 71 | return out 72 | 73 | 74 | class Swish(nn.Module): 75 | def forward(self, x): 76 | return x * torch.sigmoid(x) 77 | 78 | 79 | class Conv2dSubsampling(nn.Module): 80 | def __init__(self, idim, d_model, out_channels=32): 81 | super().__init__() 82 | self.conv = nn.Sequential( 83 | nn.Conv2d(1, out_channels, 3, 2), 84 | nn.ReLU(), 85 | nn.Conv2d(out_channels, out_channels, 3, 2), 86 | nn.ReLU(), 87 | ) 88 | subsample_idim = ((idim - 1) // 2 - 1) // 2 89 | self.out = nn.Linear(out_channels * subsample_idim, d_model) 90 | 91 | self.subsampling = 4 92 | left_context = right_context = 3 # both exclude currect frame 93 | self.context = left_context + 1 + right_context # 7 94 | 95 | def forward(self, x, x_mask): 96 | x = x.unsqueeze(1) 97 | x = self.conv(x) 98 | N, C, T, D = x.size() 99 | x = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D)) 100 | mask = x_mask[:, :, :-2:2][:, :, :-2:2] 101 | input_lengths = mask[:, -1, :].sum(dim=-1) 102 | return x, input_lengths, mask 103 | 104 | 105 | class RelPositionalEncoding(torch.nn.Module): 106 | def __init__(self, d_model, max_len=5000): 107 | super().__init__() 108 | pe_positive = torch.zeros(max_len, d_model, requires_grad=False) 109 | pe_negative = torch.zeros(max_len, d_model, requires_grad=False) 110 | position = torch.arange(0, max_len).unsqueeze(1).float() 111 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 112 | -(torch.log(torch.tensor(10000.0)).item()/d_model)) 113 | pe_positive[:, 0::2] = torch.sin(position * div_term) 114 | pe_positive[:, 1::2] = torch.cos(position * div_term) 115 | pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) 116 | pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) 117 | 118 | pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) 119 | pe_negative = pe_negative[1:].unsqueeze(0) 120 | pe = torch.cat([pe_positive, pe_negative], dim=1) 121 | self.register_buffer('pe', pe) 122 | 123 | def forward(self, x): 124 | # Tmax = 2 * max_len - 1 125 | Tmax, T = self.pe.size(1), x.size(1) 126 | pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach() 127 | return pos_emb 128 | 129 | 130 | class ConformerFeedForward(nn.Module): 131 | def __init__(self, d_model, dropout_rate=0.1): 132 | super().__init__() 133 | pre_layer_norm = nn.LayerNorm(d_model) 134 | linear_expand = nn.Linear(d_model, d_model*4) 135 | nonlinear = Swish() 136 | dropout_pre = nn.Dropout(dropout_rate) 137 | linear_project = nn.Linear(d_model*4, d_model) 138 | dropout_post = nn.Dropout(dropout_rate) 139 | self.net = nn.Sequential(pre_layer_norm, 140 | linear_expand, 141 | nonlinear, 142 | dropout_pre, 143 | linear_project, 144 | dropout_post) 145 | 146 | def forward(self, x): 147 | residual = x 148 | output = self.net(x) 149 | output = output + residual 150 | return output 151 | 152 | 153 | class ConformerConvolution(nn.Module): 154 | def __init__(self, d_model, kernel_size=33, dropout_rate=0.1): 155 | super().__init__() 156 | assert kernel_size % 2 == 1 157 | self.pre_layer_norm = nn.LayerNorm(d_model) 158 | self.pointwise_conv1 = 
nn.Conv1d(d_model, d_model*4, kernel_size=1, bias=False) 159 | self.glu = F.glu 160 | self.padding = (kernel_size - 1) // 2 161 | self.depthwise_conv = nn.Conv1d(d_model*2, d_model*2, 162 | kernel_size, stride=1, 163 | padding=self.padding, 164 | groups=d_model*2, bias=False) 165 | self.batch_norm = nn.LayerNorm(d_model*2) 166 | self.swish = Swish() 167 | self.pointwise_conv2 = nn.Conv1d(d_model*2, d_model, kernel_size=1, bias=False) 168 | self.dropout = nn.Dropout(dropout_rate) 169 | 170 | def forward(self, x, mask=None): 171 | residual = x 172 | out = self.pre_layer_norm(x) 173 | out = out.transpose(1, 2) 174 | if mask is not None: 175 | out.masked_fill_(mask.ne(1), 0.0) 176 | out = self.pointwise_conv1(out) 177 | out = F.glu(out, dim=1) 178 | out = self.depthwise_conv(out) 179 | 180 | out = out.transpose(1, 2) 181 | out = self.swish(self.batch_norm(out)) 182 | out = out.transpose(1, 2) 183 | 184 | out = self.dropout(self.pointwise_conv2(out)) 185 | if mask is not None: 186 | out.masked_fill_(mask.ne(1), 0.0) 187 | out = out.transpose(1, 2) 188 | return out + residual 189 | 190 | 191 | class EncoderMultiHeadAttention(nn.Module): 192 | def __init__(self, n_head, d_model, 193 | residual_dropout=0.1): 194 | super().__init__() 195 | assert d_model % n_head == 0 196 | self.n_head = n_head 197 | self.d_k = d_model // n_head 198 | self.d_v = self.d_k 199 | 200 | self.w_qs = nn.Linear(d_model, n_head * self.d_k, bias=False) 201 | self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False) 202 | self.w_vs = nn.Linear(d_model, n_head * self.d_v, bias=False) 203 | 204 | self.layer_norm_q = nn.LayerNorm(d_model) 205 | self.layer_norm_k = nn.LayerNorm(d_model) 206 | self.layer_norm_v = nn.LayerNorm(d_model) 207 | 208 | self.attention = ScaledDotProductAttention(temperature=self.d_k ** 0.5) 209 | self.fc = nn.Linear(n_head * self.d_v, d_model, bias=False) 210 | self.dropout = nn.Dropout(residual_dropout) 211 | 212 | def forward(self, q, k, v, mask=None): 213 | sz_b, len_q = q.size(0), q.size(1) 214 | 215 | residual = q 216 | q, k, v = self.forward_qkv(q, k, v) 217 | 218 | output, attn = self.attention(q, k, v, mask=mask) 219 | 220 | output = self.forward_output(output, residual, sz_b, len_q) 221 | return output, attn 222 | 223 | def forward_qkv(self, q, k, v): 224 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 225 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 226 | 227 | q = self.layer_norm_q(q) 228 | k = self.layer_norm_k(k) 229 | v = self.layer_norm_v(v) 230 | 231 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 232 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 233 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 234 | q = q.transpose(1, 2) 235 | k = k.transpose(1, 2) 236 | v = v.transpose(1, 2) 237 | return q, k, v 238 | 239 | def forward_output(self, output, residual, sz_b, len_q): 240 | output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 241 | fc_out = self.fc(output) 242 | output = self.dropout(fc_out) 243 | output = output + residual 244 | return output 245 | 246 | 247 | class ScaledDotProductAttention(nn.Module): 248 | def __init__(self, temperature): 249 | super().__init__() 250 | self.temperature = temperature 251 | self.dropout = nn.Dropout(0.0) 252 | self.INF = float('inf') 253 | 254 | def forward(self, q, k, v, mask=None): 255 | attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature 256 | output, attn = self.forward_attention(attn, v, mask) 257 | return output, attn 258 | 259 | def forward_attention(self, attn, v, mask=None): 
260 | if mask is not None: 261 | mask = mask.unsqueeze(1) 262 | mask = mask.eq(0) 263 | attn = attn.masked_fill(mask, -self.INF) 264 | attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0) 265 | else: 266 | attn = torch.softmax(attn, dim=-1) 267 | 268 | d_attn = self.dropout(attn) 269 | output = torch.matmul(d_attn, v) 270 | 271 | return output, attn 272 | 273 | 274 | class RelPosMultiHeadAttention(EncoderMultiHeadAttention): 275 | def __init__(self, n_head, d_model, 276 | residual_dropout=0.1): 277 | super().__init__(n_head, d_model, 278 | residual_dropout) 279 | d_k = d_model // n_head 280 | self.scale = 1.0 / (d_k ** 0.5) 281 | self.linear_pos = nn.Linear(d_model, n_head * d_k, bias=False) 282 | self.pos_bias_u = nn.Parameter(torch.FloatTensor(n_head, d_k)) 283 | self.pos_bias_v = nn.Parameter(torch.FloatTensor(n_head, d_k)) 284 | torch.nn.init.xavier_uniform_(self.pos_bias_u) 285 | torch.nn.init.xavier_uniform_(self.pos_bias_v) 286 | 287 | def _rel_shift(self, x): 288 | N, H, T1, T2 = x.size() 289 | zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype) 290 | x_padded = torch.cat([zero_pad, x], dim=-1) 291 | 292 | x_padded = x_padded.view(N, H, T2 + 1, T1) 293 | x = x_padded[:, :, 1:].view_as(x) 294 | x = x[:, :, :, : x.size(-1) // 2 + 1] 295 | return x 296 | 297 | def forward(self, q, k, v, pos_emb, mask=None): 298 | sz_b, len_q = q.size(0), q.size(1) 299 | 300 | residual = q 301 | q, k, v = self.forward_qkv(q, k, v) 302 | 303 | q = q.transpose(1, 2) 304 | n_batch_pos = pos_emb.size(0) 305 | p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.n_head, self.d_k) 306 | p = p.transpose(1, 2) 307 | 308 | q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) 309 | q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) 310 | 311 | matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) 312 | 313 | matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) 314 | matrix_bd = self._rel_shift(matrix_bd) 315 | 316 | attn_scores = matrix_ac + matrix_bd 317 | attn_scores.mul_(self.scale) 318 | 319 | output, attn = self.attention.forward_attention(attn_scores, v, mask=mask) 320 | 321 | output = self.forward_output(output, residual, sz_b, len_q) 322 | return output, attn 323 | --------------------------------------------------------------------------------
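As a quick sanity check of the encoder front-end above: `Conv2dSubsampling` reduces the frame rate by 4× and maps 80-dim fbank frames to `d_model`-dim vectors before the conformer blocks, so each streaming re-decode only attends over roughly T/4 encoder states. A small shape-check sketch (`d_model=512` and the 103-frame input are illustrative values, not the pretrained model's actual configuration):

``` python
import torch
from fireredasr.models.conformer_encoder import Conv2dSubsampling

sub = Conv2dSubsampling(idim=80, d_model=512)    # d_model chosen for illustration
x = torch.randn(2, 103, 80)                      # (batch, fbank frames, mel bins)
mask = torch.ones(2, 1, 103, dtype=torch.uint8)  # all frames valid
y, lengths, out_mask = sub(x, mask)
print(y.shape)         # torch.Size([2, 25, 512]) -> roughly T/4 frames survive
print(out_mask.shape)  # torch.Size([2, 1, 25])   -> mask subsampled the same way
```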