├── data └── .gitkeep ├── corpus ├── preprocess_ted.sh ├── commonvoice.py ├── preprocess_ted.py ├── l2arctic.py ├── valentini.py ├── CHiME.py ├── audiolib.py ├── librispeech.py ├── ted.py └── noisyspeech_synthesizer.py ├── README.md ├── conf ├── config_sgem_ctc.yaml └── config_litta_ctc.yaml ├── data.py ├── main.py ├── utils.py ├── main_lm.py └── forward.py /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /corpus/preprocess_ted.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # audio_paths="/home/server08/hdd0/changhun_workspace/TEDLIUM_release2/test/sph/*.sph" 4 | # output_dir="/home/server08/hdd0/changhun_workspace/TEDLIUM_release2/test/wav" 5 | audio_paths="/mnt/hdd/hsyoon/workspace/ES/speech/datasets/TED-LIUM/TEDLIUM_release2/test/sph/*.sph" 6 | output_dir="/mnt/hdd/hsyoon/workspace/ES/speech/datasets/TED-LIUM/TEDLIUM_release2/test/wav" 7 | [ ! -e "$output_dir" ] && mkdir "$output_dir" 8 | for f in ${audio_paths} 9 | do 10 | IFS="/" read -ra arr <<< ${f} 11 | IFS="." read -ra name <<< ${arr[-1]} 12 | echo "filename: ${name}" 13 | sox $f "${output_dir}/${name}.wav" 14 | done 15 | echo "done." -------------------------------------------------------------------------------- /corpus/commonvoice.py: -------------------------------------------------------------------------------- 1 | import re 2 | from builtins import str as unicode 3 | import pandas as pd 4 | import os 5 | from unicodedata import name 6 | from tqdm import tqdm 7 | from pathlib import Path 8 | from joblib import Parallel, delayed 9 | from torch.utils.data import Dataset 10 | 11 | def preprocess_text(text): 12 | text = unicode(text) 13 | text = text.replace("i.e.", "that is") 14 | text = text.replace("e.g.", "for example") 15 | text = text.replace("Mr.", "Mister") 16 | text = text.replace("Mrs.", "Mistress") 17 | text = text.replace("Dr.", "Doctor") 18 | text = text.replace("-", " ") 19 | text = text.upper() 20 | text = re.sub("[^ A-Z']", "", text) 21 | text = ' '.join(text.split()) 22 | 23 | return text 24 | 25 | class CVDataset(Dataset): 26 | def __init__(self, bucket_size, path="/home/daniel094144/data/cv-corpus-5.1-2020-06-22/en", enhance=False, ascending=False): 27 | # Setup 28 | self.path = path 29 | self.bucket_size = bucket_size 30 | 31 | apath = path + "/clips" 32 | tpath = path + "/test.tsv" 33 | 34 | df = pd.read_csv(tpath, sep='\t') 35 | text = df['sentence'].apply(preprocess_text).values 36 | file_list = df['path'].values 37 | file_list = [os.path.join(apath, f) for f in file_list] 38 | 39 | print(len(file_list), len(text)) 40 | self.file_list, self.text = zip(*[(f_name, txt) 41 | for f_name, txt in sorted(zip(file_list, text), reverse=not ascending, key=lambda x:len(x[1]))]) 42 | 43 | def __getitem__(self, index): 44 | if self.bucket_size > 1: 45 | # Return a bucket 46 | index = min(len(self.file_list)-self.bucket_size, index) 47 | return [(f_path, txt) for f_path, txt in 48 | zip(self.file_list[index:index+self.bucket_size], self.text[index:index+self.bucket_size])] 49 | else: 50 | return self.file_list[index], self.text[index] 51 | 52 | def __len__(self): 53 | return len(self.file_list) 54 | -------------------------------------------------------------------------------- /corpus/preprocess_ted.py: -------------------------------------------------------------------------------- 1 | import soundfile as sf 2 | import os 3 | import re 4 | 5 | stm_path = '/mnt/hdd/hsyoon/workspace/ES/speech/datasets/TED-LIUM/TEDLIUM_release2/test/stm' 6 | audio_path = '/mnt/hdd/hsyoon/workspace/ES/speech/datasets/TED-LIUM/TEDLIUM_release2/test/wav' 7 | save_audio_dir = '/mnt/hdd/hsyoon/workspace/ES/speech/datasets/TED-LIUM/TEDLIUM_release2/test/wav_segment' 8 | save_text_dir = '/mnt/hdd/hsyoon/workspace/ES/speech/datasets/TED-LIUM/TEDLIUM_release2/test/transcription' 9 | 10 | SAMPLE_RATE = 16000 11 | 12 | # preprocess text 13 | def preprocess_text(text): 14 | text = text.upper() 15 | text = text.replace(" '", "'") 16 | text = text.replace("-", " ") 17 | text = re.sub("[^ A-Z']", "", text) 18 | text = ' '.join(text.split()) 19 | return text 20 | 21 | skip = 'inter_segment_gap' 22 | 23 | import glob 24 | 25 | for stm_path in glob.glob(os.path.join(stm_path, '*.stm')): 26 | with open(stm_path, 'r') as f: 27 | curr_file = None 28 | for line in f: 29 | l = line.split() 30 | name = l[0] 31 | if l[2] == skip: 32 | continue 33 | s = float(l[3]) 34 | e = float(l[4]) 35 | txt = ' '.join(l[6:]) 36 | if curr_file != name: 37 | print('---new---') 38 | print(name) 39 | 40 | print(stm_path) 41 | print(os.path.join(audio_path, name+'.wav')) 42 | 43 | wav, sr = sf.read(os.path.join(audio_path, name+'.wav')) 44 | print(wav.shape) 45 | 46 | start_idx = int(s * sr) 47 | end_idx = int(e * sr) 48 | segment = wav[start_idx: end_idx] 49 | 50 | norm_txt = preprocess_text(txt) 51 | apath = os.path.join(save_audio_dir, '-'.join([name, l[3], l[4]])+'.wav') 52 | tpath = os.path.join(save_text_dir, '-'.join([name, l[3], l[4]])+'.txt') 53 | 54 | 55 | sf.write(apath, segment, SAMPLE_RATE) 56 | with open(tpath, 'w') as tf: 57 | tf.write(norm_txt) 58 | curr_file = name 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /corpus/l2arctic.py: -------------------------------------------------------------------------------- 1 | import re 2 | from tqdm import tqdm 3 | from pathlib import Path 4 | import os 5 | from joblib import Parallel, delayed 6 | from torch.utils.data import Dataset 7 | from builtins import str as unicode 8 | from unicodedata import name 9 | 10 | 11 | def preprocess_text(text): 12 | text = unicode(text) 13 | text = text.replace("i.e.", "that is") 14 | text = text.replace("e.g.", "for example") 15 | text = text.replace("Mr.", "Mister") 16 | text = text.replace("Mrs.", "Mistress") 17 | text = text.replace("Dr.", "Doctor") 18 | text = text.replace("-", " ") 19 | text = text.upper() 20 | text = re.sub("[^ A-Z']", "", text) 21 | text = ' '.join(text.split()) 22 | 23 | return text 24 | 25 | 26 | class L2ArcticDataset(Dataset): 27 | def __init__(self, bucket_size, path, enhance=False, ascending=False): 28 | # Setup 29 | self.path = path 30 | self.bucket_size = bucket_size 31 | 32 | apath = os.path.join(path, "wav") 33 | tpath = os.path.join(path, "transcript") 34 | 35 | file_list, text_list = [], [] 36 | for wav in sorted(os.listdir(apath)): 37 | if not wav.endswith(".wav"): 38 | continue 39 | file_list.append(os.path.join(apath, wav)) 40 | for txt_file in sorted(os.listdir(tpath)): 41 | if not txt_file.endswith(".txt"): 42 | continue 43 | with open(os.path.join(tpath, txt_file), "r") as f: 44 | txt = f.read() 45 | txt = preprocess_text(txt) 46 | text_list.append(txt) 47 | 48 | assert len(file_list) == len(text_list) 49 | 50 | self.file_list, self.text = zip(*[(f_name, txt) 51 | for f_name, txt in sorted(zip(file_list, text_list), reverse=not ascending, key=lambda x:len(x[1]))]) 52 | 53 | def __getitem__(self, index): 54 | if self.bucket_size > 1: 55 | # Return a bucket 56 | index = min(len(self.file_list)-self.bucket_size, index) 57 | return [(f_path, txt) for f_path, txt in 58 | zip(self.file_list[index:index+self.bucket_size], self.text[index:index+self.bucket_size])] 59 | else: 60 | return self.file_list[index], self.text[index] 61 | 62 | def __len__(self): 63 | return len(self.file_list) -------------------------------------------------------------------------------- /corpus/valentini.py: -------------------------------------------------------------------------------- 1 | import re 2 | from tqdm import tqdm 3 | from pathlib import Path 4 | import os 5 | from joblib import Parallel, delayed 6 | from torch.utils.data import Dataset 7 | from builtins import str as unicode 8 | from unicodedata import name 9 | 10 | 11 | def preprocess_text(text): 12 | text = unicode(text) 13 | text = text.replace("i.e.", "that is") 14 | text = text.replace("e.g.", "for example") 15 | text = text.replace("Mr.", "Mister") 16 | text = text.replace("Mrs.", "Mistress") 17 | text = text.replace("Dr.", "Doctor") 18 | text = text.replace("-", " ") 19 | text = text.upper() 20 | text = re.sub("[^ A-Z']", "", text) 21 | text = ' '.join(text.split()) 22 | 23 | return text 24 | 25 | 26 | class ValDataset(Dataset): 27 | def __init__(self, bucket_size, path, enhance=False, ascending=False): 28 | # Setup 29 | self.path = path 30 | self.bucket_size = bucket_size 31 | 32 | apath = os.path.join(path, "noisy_testset_wav") 33 | tpath = os.path.join(path, "testset_txt") 34 | 35 | file_list, text_list = [], [] 36 | for wav in sorted(os.listdir(apath)): 37 | if not wav.endswith(".wav"): 38 | continue 39 | file_list.append(os.path.join(apath, wav)) 40 | for txt_file in sorted(os.listdir(tpath)): 41 | if not txt_file.endswith(".txt"): 42 | continue 43 | with open(os.path.join(tpath, txt_file), "r") as f: 44 | txt = f.read() 45 | txt = preprocess_text(txt) 46 | text_list.append(txt) 47 | 48 | assert len(file_list) == len(text_list) 49 | 50 | self.file_list, self.text = zip(*[(f_name, txt) 51 | for f_name, txt in sorted(zip(file_list, text_list), reverse=not ascending, key=lambda x:len(x[1]))]) 52 | 53 | def __getitem__(self, index): 54 | if self.bucket_size > 1: 55 | # Return a bucket 56 | index = min(len(self.file_list)-self.bucket_size, index) 57 | return [(f_path, txt) for f_path, txt in 58 | zip(self.file_list[index:index+self.bucket_size], self.text[index:index+self.bucket_size])] 59 | else: 60 | return self.file_list[index], self.text[index] 61 | 62 | def __len__(self): 63 | return len(self.file_list) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LiTTA 2 | [INTERSPEECH'24] Official code for "LI-TTA: Language Informed Test-Time Adaptation for Automatic Speech Recognition" 3 | 4 | 5 | ## Environmental Setup 6 | ``` 7 | conda create -y -n LiTTA python=3.10 8 | conda activate LiTTA 9 | pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu118 10 | conda env update --file environment.yaml --prune 11 | ``` 12 | 13 | 14 | ## Pre-trained Models 15 | - [CTC-based Model](https://huggingface.co/facebook/wav2vec2-base-960h) 16 | - CTC-based model will be automatically downloaded if you set ``asr`` as ``facebook/wav2vec2-base-960h``. 17 | - [4-gram Language Model for CTC-based Model](https://huggingface.co/patrickvonplaten/wav2vec2-base-100h-with-lm) 18 | - You need to download language by your own using following command: 19 | ``` 20 | git lfs install 21 | git clone https://huggingface.co/patrickvonplaten/wav2vec2-base-100h-with-lm pretrained_models/wav2vec2-base-100h-with-lm 22 | ``` 23 | 24 | 25 | ## Run 26 | You can run main.py (baseline) or main_lm.py (litta) using the command below: 27 | ``` 28 | python main_lm.py \ 29 | --config-name [CONFIG.YAML] \ 30 | dataset_name=[DATASET_NAME] \ 31 | dataset_dir=[DATASET_DIR] \ 32 | ``` 33 | Currently available parameters are as follows: 34 | 35 | Parameter | Value 36 | --- | --- 37 | CONFIG.YAML | config.yaml, config_{sgem\|litta}_ctc.yaml 38 | DATASET_NAME | librispeech, chime, ted, commonvoice, valentini, l2arctic 39 | 40 | 41 | ## Acknowledgement 42 | 43 | This work was partially supported by Institute for Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government(MSIT) (No. 2021-0-01381, Development of Causal AI through Video Understanding and Reinforcement Learning, and Its Applications to Real Environments) and SAMSUNG Research, Samsung Electronics Co.,Ltd. 44 | 45 | Also, we thank the authors of the [SGEM](https://github.com/drumpt/SGEM) for their open-source contributions and their assistance with the data preparation. 46 | 47 | ## Citation 48 | ``` 49 | @inproceedings{yoon2024li, 50 | title={LI-TTA: Language Informed Test-Time Adaptation for Automatic Speech Recognition}, 51 | author={Yoon, Eunseop and Yoon, Hee Suk and Harvill, John and Hasegawa-Johnson, Mark and Yoo, Chang D}, 52 | booktitle={Proc. Interspeech 2024}, 53 | pages={3490--3494}, 54 | year={2024} 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /corpus/CHiME.py: -------------------------------------------------------------------------------- 1 | from unicodedata import name 2 | from tqdm import tqdm 3 | from pathlib import Path 4 | import os 5 | from joblib import Parallel, delayed 6 | from torch.utils.data import Dataset 7 | 8 | 9 | def read_text(tpath, file): 10 | '''Get transcription of target wave file, 11 | it's somewhat redundant for accessing each txt multiplt times, 12 | but it works fine with multi-thread''' 13 | txt_list = os.path.join(tpath, "".join("/".join(file.split('/')[-2:]).split(".")[:-1])+'.trn') 14 | 15 | with open(txt_list, 'r') as fp: 16 | for line in fp: 17 | return ' '.join(line.split(' ')[1:]).strip('\n') 18 | 19 | 20 | 21 | class CHiMEDataset(Dataset): 22 | def __init__(self, bucket_size, path="/home/daniel094144/data/CHiME3", enhance=False, ascending=False): 23 | # Setup 24 | self.path = path 25 | self.bucket_size = bucket_size 26 | 27 | split = ['et05_bus_real', 'et05_bus_simu', 'et05_caf_real', 'et05_caf_simu', 'et05_ped_simu', 'et05_str_real', 'et05_str_simu'] 28 | apath = path + "/data/audio/16kHz/enhanced" 29 | tpath = path + "/data/transcriptions" 30 | 31 | file_list = [] 32 | for s in split: 33 | split_list = list(Path(os.path.join(apath, s)).glob("*.wav")) 34 | file_list += split_list 35 | 36 | text = [] 37 | for f in tqdm(file_list, desc='Read text'): 38 | transcription = read_text(tpath, str(f)) 39 | text.append(transcription) 40 | 41 | if enhance: 42 | file_list = [] 43 | for s in split: 44 | split_list = list(Path(os.path.join(os.path.join(apath, s), 'se_wav')).glob("*.wav")) 45 | file_list += split_list 46 | 47 | self.file_list, self.text = zip(*[(f_name, txt) 48 | for f_name, txt in sorted(zip(file_list, text), reverse=not ascending, key=lambda x:len(x[1]))]) 49 | 50 | def __getitem__(self, index): 51 | if self.bucket_size > 1: 52 | # Return a bucket 53 | index = min(len(self.file_list)-self.bucket_size, index) 54 | return [(f_path, txt) for f_path, txt in 55 | zip(self.file_list[index:index+self.bucket_size], self.text[index:index+self.bucket_size])] 56 | else: 57 | return self.file_list[index], self.text[index] 58 | 59 | def __len__(self): 60 | return len(self.file_list) 61 | -------------------------------------------------------------------------------- /corpus/audiolib.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Jun 26 15:54:05 2019 4 | 5 | @author: chkarada 6 | """ 7 | import soundfile as sf 8 | import os 9 | import numpy as np 10 | 11 | # Function to read audio 12 | def audioread(path, norm = True, start=0, stop=None): 13 | path = os.path.abspath(path) 14 | if not os.path.exists(path): 15 | raise ValueError("[{}] does not exist!".format(path)) 16 | try: 17 | x, sr = sf.read(path, start=start, stop=stop) 18 | except RuntimeError: # fix for sph pcm-embedded shortened v2 19 | print('WARNING: Audio type not supported') 20 | 21 | if len(x.shape) == 1: # mono 22 | if norm: 23 | rms = (x ** 2).mean() ** 0.5 24 | scalar = 10 ** (-25 / 20) / (rms) 25 | x = x * scalar 26 | return x, sr 27 | else: # multi-channel 28 | x = x.T 29 | x = x.sum(axis=0)/x.shape[0] 30 | if norm: 31 | rms = (x ** 2).mean() ** 0.5 32 | scalar = 10 ** (-25 / 20) / (rms) 33 | x = x * scalar 34 | return x, sr 35 | 36 | # Funtion to write audio 37 | def audiowrite(data, fs, destpath, norm=False): 38 | if norm: 39 | rms = (data ** 2).mean() ** 0.5 40 | scalar = 10 ** (-25 / 10) / (rms+eps) 41 | data = data * scalar 42 | if max(abs(data))>=1: 43 | data = data/max(abs(data), eps) 44 | 45 | destpath = os.path.abspath(destpath) 46 | destdir = os.path.dirname(destpath) 47 | 48 | if not os.path.exists(destdir): 49 | os.makedirs(destdir) 50 | 51 | sf.write(destpath, data, fs) 52 | return 53 | 54 | # Function to mix clean speech and noise at various SNR levels 55 | def snr_mixer(clean, noise, snr): 56 | # Normalizing to -25 dB FS 57 | rmsclean = (clean**2).mean()**0.5 58 | scalarclean = 10 ** (-25 / 20) / rmsclean 59 | clean = clean * scalarclean 60 | rmsclean = (clean**2).mean()**0.5 61 | 62 | rmsnoise = (noise**2).mean()**0.5 63 | scalarnoise = 10 ** (-25 / 20) /rmsnoise 64 | noise = noise * scalarnoise 65 | rmsnoise = (noise**2).mean()**0.5 66 | 67 | # Set the noise level for a given SNR 68 | noisescalar = np.sqrt(rmsclean / (10**(snr/20)) / rmsnoise) 69 | noisenewlevel = noise * noisescalar 70 | noisyspeech = clean + noisenewlevel 71 | return clean, noisenewlevel, noisyspeech 72 | 73 | 74 | -------------------------------------------------------------------------------- /corpus/librispeech.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from pathlib import Path 3 | import os 4 | from joblib import Parallel, delayed 5 | from torch.utils.data import Dataset 6 | import ipdb 7 | 8 | def read_text(file): 9 | '''Get transcription of target wave file, 10 | it's somewhat redundant for accessing each txt multiplt times, 11 | but it works fine with multi-thread''' 12 | src_file = '-'.join(file.split('-')[:-1])+'.trans.txt' 13 | idx = file.split('/')[-1].split('.')[0] 14 | 15 | with open(src_file, 'r') as fp: 16 | for line in fp: 17 | if idx == line.split(' ')[0]: 18 | return line[:-1].split(' ', 1)[1] 19 | 20 | 21 | class LibriDataset(Dataset): 22 | def __init__(self, bucket_size, path, noise_type=None, noise_snr=None, ascending=False): 23 | # Setup 24 | self.path = path 25 | self.bucket_size = bucket_size 26 | split = ['test-other'] 27 | # added by esyoon 2023-08-30-02:57:27 28 | # split = ['test-clean'] 29 | 30 | # List all wave files 31 | file_list = [] 32 | for s in split: 33 | split_list = list(Path(os.path.join(path, s)).rglob("*.flac")) 34 | file_list += split_list 35 | file_list.sort() 36 | text = [] 37 | 38 | for f in tqdm(file_list, desc='Read text'): 39 | transcription = read_text(str(f)) 40 | text.append(transcription) 41 | if noise_type: 42 | snr_string = f"_{noise_snr}.0" if noise_snr in [0, -10, 10] else "" 43 | # file_list = sorted(list(Path(os.path.join(path, f"../MS-SNSD/libri_test_noise{snr_string}/{noise_type}")).rglob("*.wav"))) 44 | 45 | # added by esyoon 2023-08-30-00:34:08 46 | file_list = sorted(list(Path(os.path.join(f"./data/libri_{split[0]}_noise{snr_string}/{noise_type}")).rglob("*.wav"))) 47 | self.file_list, self.text = zip(*[(f_name, txt) 48 | for f_name, txt in sorted(zip(file_list, text), reverse=not ascending, key=lambda x:len(x[1]))]) 49 | # ipdb.set_trace() 50 | def __getitem__(self, index): 51 | if self.bucket_size > 1: 52 | # Return a bucket 53 | index = min(len(self.file_list)-self.bucket_size, index) 54 | return [(f_path, txt) for f_path, txt in 55 | zip(self.file_list[index:index+self.bucket_size], self.text[index:index+self.bucket_size])] 56 | else: 57 | return self.file_list[index], self.text[index] 58 | 59 | def __len__(self): 60 | return len(self.file_list) -------------------------------------------------------------------------------- /corpus/ted.py: -------------------------------------------------------------------------------- 1 | from unicodedata import name 2 | from tqdm import tqdm 3 | from pathlib import Path 4 | import os 5 | from joblib import Parallel, delayed 6 | from torch.utils.data import Dataset 7 | 8 | 9 | def read_text(tpath, file): 10 | '''Get transcription of target wave file, 11 | it's somewhat redundant for accessing each txt multiplt times, 12 | but it works fine with multi-thread''' 13 | file = file.split('/')[-1].replace('wav', 'txt') 14 | txt_list = os.path.join(tpath, file) 15 | 16 | with open(txt_list, 'r') as fp: 17 | for line in fp: 18 | return line.strip('\n') 19 | 20 | 21 | 22 | class TedDataset(Dataset): 23 | def __init__(self, bucket_size, path="/home/daniel094144/data/TEDLIUM_release2/test", enhance=False, ascending=True): 24 | # Setup 25 | self.path = path 26 | self.bucket_size = bucket_size 27 | 28 | split = [''] 29 | apath = path + "/wav_segment" 30 | tpath = path + "/transcription" 31 | 32 | file_list = [] 33 | for s in split: 34 | if enhance: 35 | split_list = list(Path(os.path.join(os.path.join(apath, 'se_wav'), s)).glob("*.wav")) 36 | else: 37 | split_list = list(Path(os.path.join(apath, s)).glob("*.wav")) 38 | file_list += split_list 39 | 40 | text = [] 41 | filtered_file_list = [] 42 | for f in tqdm(file_list, desc='Read text'): 43 | transcription = read_text(tpath, str(f)) 44 | # print(transcription) 45 | if transcription == None: 46 | pass 47 | # elif len(transcription.split()) <= 3: 48 | # pass 49 | else: 50 | filtered_file_list.append(f) 51 | text.append(transcription) 52 | 53 | print(len(filtered_file_list), len(text)) 54 | file_list = filtered_file_list 55 | self.file_list, self.text = zip(*[(f_name, txt) 56 | for f_name, txt in sorted(zip(file_list, text), reverse=not ascending, key=lambda x:len(x[1]))]) 57 | 58 | def __getitem__(self, index): 59 | if self.bucket_size > 1: 60 | # Return a bucket 61 | index = min(len(self.file_list)-self.bucket_size, index) 62 | return [(f_path, txt) for f_path, txt in 63 | zip(self.file_list[index:index+self.bucket_size], self.text[index:index+self.bucket_size])] 64 | else: 65 | return self.file_list[index], self.text[index] 66 | 67 | def __len__(self): 68 | return len(self.file_list) 69 | -------------------------------------------------------------------------------- /conf/config_sgem_ctc.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | asr: facebook/wav2vec2-base-960h 3 | processor: patrickvonplaten/wav2vec2-base-100h-with-lm 4 | 5 | ### dataset 6 | # dataset_name: librispeech 7 | # dataset_dir: /home/server17/hdd/changhun_workspace/LibriSpeech 8 | dataset_name: chime 9 | dataset_dir: /home/server17/hdd/changhun_workspace/CHiME3 10 | # dataset_name: ted 11 | # dataset_dir: /home/server17/hdd/changhun_workspace/TEDLIUM_release2/test 12 | # dataset_name: valentini 13 | # dataset_dir: /home/server17/hdd/changhun_workspace/Valentini 14 | # dataset_name: commonvoice 15 | # dataset_dir: /home/server17/hdd/changhun_workspace/cv-corpus-5.1-2020-06-22/en 16 | 17 | extra_noise: 0.00 18 | noise_type: null # currently supported: null, AirConditioner_6, AirportAnnouncements_2, Babble_4, CopyMachine_2, Munching_3, Neighbor_6, ShuttingDoor_6, Typing_2, VacuumCleaner_1 19 | noise_snr: 10 20 | sample_rate: 16000 21 | batch_size: 1 22 | 23 | ### device and amp 24 | device: cuda 25 | # use_amp: true 26 | 27 | ### logging 28 | log_dir: exps/sgem_ctc 29 | 30 | ### seed for reproductivity 31 | seed: null 32 | 33 | ### optimizer & train hyparameters & learning rate scheduler 34 | optimizer: AdamW 35 | train_params: [feature] # currently supported: all, feature, LN, linear(linear probing), enc, dec 36 | steps: 10 37 | episodic: true # load pretrained model again for every batch 38 | 39 | lr: 4e-5 40 | scheduler: CosineAnnealingLR # null or CosineAnnealingLR 41 | t_max: 10 42 | lr_min: 2e-5 43 | 44 | ################################################################## 45 | ### methods & other hyperparameters 46 | ## currently supported methods: original, em_uncertainty, em_sparse, greedy_pseudo_labeling, ctc, beam_search_max, beam_search_all, beam_search_negative_sampling, diversity_maximization, renyi_em 47 | method: [renyi_em, beam_search_negative_sampling] 48 | decoding_method: beam_search # greedy_search or beam_search 49 | lm_coef: 0.3 50 | kld_weight: 0.0625 # 0, 0.0625(1/16), 0.125(1/8), 0.25(1/4), 9.5(1/2), ... 51 | 52 | ### beam search 53 | beam_width: 5 54 | num_positives: 1 # for beam_search_all 55 | num_negatives: 4 # for beam_candidate for negtive_sampling 56 | 57 | temp: 2.5 # temperature scaling 58 | em_coef: 0.3 # for balancing entropy minimization and minimum class confusion for baseline 59 | not_blank: true 60 | certain_only: true 61 | 62 | ### renyi entropy minimization 63 | renyi_entropy_alpha: 1.5 # 1, 2, ..., inf 64 | 65 | # negative sampling 66 | negative_sampling_method: ns3l # random, beam_candidate, ns3l 67 | ns_coef: 1 68 | ns_threshold: 0.04 69 | 70 | # thresholding 71 | prob_threshold: 0.9 72 | entropy_threshold: 0.05 73 | ################################################################## -------------------------------------------------------------------------------- /conf/config_litta_ctc.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | asr: facebook/wav2vec2-base-960h 3 | processor: patrickvonplaten/wav2vec2-base-100h-with-lm 4 | 5 | ### dataset 6 | # dataset_name: librispeech 7 | # dataset_dir: /home/server17/hdd/changhun_workspace/LibriSpeech 8 | dataset_name: chime 9 | dataset_dir: /home/server17/hdd/changhun_workspace/CHiME3 10 | # dataset_name: ted 11 | # dataset_dir: /home/server17/hdd/changhun_workspace/TEDLIUM_release2/test 12 | # dataset_name: valentini 13 | # dataset_dir: /home/server17/hdd/changhun_workspace/Valentini 14 | # dataset_name: commonvoice 15 | # dataset_dir: /home/server17/hdd/changhun_workspace/cv-corpus-5.1-2020-06-22/en 16 | 17 | extra_noise: 0.00 18 | noise_type: null # currently supported: null, AirConditioner_6, AirportAnnouncements_2, Babble_4, CopyMachine_2, Munching_3, Neighbor_6, ShuttingDoor_6, Typing_2, VacuumCleaner_1 19 | noise_snr: 10 20 | sample_rate: 16000 21 | batch_size: 1 22 | 23 | ### device and amp 24 | device: cuda 25 | # use_amp: true 26 | 27 | ### logging 28 | log_dir: exps/litta_ctc 29 | 30 | ### seed for reproductivity 31 | seed: null 32 | 33 | ### optimizer & train hyparameters & learning rate scheduler 34 | optimizer: AdamW 35 | train_params: [feature] # currently supported: all, feature, LN, linear(linear probing), enc, dec 36 | steps: 10 37 | episodic: true # load pretrained model again for every batch 38 | 39 | lr: 4e-5 40 | scheduler: CosineAnnealingLR # null or CosineAnnealingLR 41 | t_max: 10 42 | lr_min: 2e-5 43 | 44 | ################################################################## 45 | ### methods & other hyperparameters 46 | ## currently supported methods: original, em_uncertainty, em_sparse, greedy_pseudo_labeling, ctc, beam_search_max, beam_search_all, beam_search_negative_sampling, diversity_maximization, renyi_em 47 | method: [renyi_em, beam_search_negative_sampling] 48 | decoding_method: beam_search # greedy_search or beam_search 49 | lm_coef: 0.3 50 | kld_weight: 0.0625 # 0, 0.0625(1/16), 0.125(1/8), 0.25(1/4), 9.5(1/2), ... 51 | 52 | ### beam search 53 | beam_width: 5 54 | num_positives: 1 # for beam_search_all 55 | num_negatives: 4 # for beam_candidate for negtive_sampling 56 | 57 | temp: 2.5 # temperature scaling 58 | em_coef: 0.3 # for balancing entropy minimization and minimum class confusion for baseline 59 | not_blank: true 60 | certain_only: true 61 | 62 | ### renyi entropy minimization 63 | renyi_entropy_alpha: 1.5 # 1, 2, ..., inf 64 | 65 | # negative sampling 66 | negative_sampling_method: ns3l # random, beam_candidate, ns3l 67 | ns_coef: 1 68 | ns_threshold: 0.04 69 | 70 | # thresholding 71 | prob_threshold: 0.9 72 | entropy_threshold: 0.05 73 | 74 | # llm parameters 75 | lm_model: openchat/openchat_3.5 76 | llm_model: openchat3.5 77 | tta_coef: 1.0 # test time adaptation loss coefficient 78 | llm_coef: 1.0 # llm loss coefficient 79 | 80 | ################################################################## -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from functools import partial 4 | from torch.utils.data import DataLoader 5 | 6 | SAMPLE_RATE = 16000 7 | 8 | 9 | def collect_audio_batch(batch, extra_noise=0., maxLen=600000): 10 | '''Collects a batch, should be list of tuples (audio_path , list of int token ) 11 | e.g. [(file1,txt1),(file2,txt2),...] 12 | ''' 13 | def audio_reader(filepath): 14 | wav, sample_rate = torchaudio.load(filepath) 15 | if sample_rate != SAMPLE_RATE: 16 | wav = torchaudio.transforms.Resample(sample_rate, SAMPLE_RATE)(wav) 17 | wav = wav.reshape(-1) 18 | if wav.shape[-1] >= maxLen: 19 | print(f'{filepath} has len {wav.shape}, truncate to {maxLen}') 20 | wav = wav[:maxLen] 21 | wav += extra_noise * torch.randn_like(wav) 22 | return wav 23 | 24 | # Bucketed batch should be [[(file1,txt1),(file2,txt2),...]] 25 | if type(batch[0]) is not tuple: 26 | batch = batch[0] 27 | 28 | # Read batch 29 | file, audio_feat, audio_len, text = [], [], [], [] 30 | with torch.no_grad(): 31 | for b in batch: 32 | feat = audio_reader(str(b[0])).numpy() 33 | # feat = audio_reader(str(b[0])) 34 | file.append(str(b[0]).split('/')[-1].split('.')[0]) 35 | audio_feat.append(feat) 36 | audio_len.append(len(feat)) 37 | text.append(b[1]) 38 | 39 | return torch.tensor(audio_len), audio_feat, text, file 40 | 41 | 42 | def create_dataset(name, path, batch_size=1, noise_type=None, noise_snr=None): 43 | ''' Interface for creating all kinds of dataset''' 44 | 45 | # Recognize corpus 46 | if name.lower() == "librispeech": 47 | from corpus.librispeech import LibriDataset as Dataset 48 | elif name.lower() == "chime": 49 | from corpus.CHiME import CHiMEDataset as Dataset 50 | elif name.lower() == "ted": 51 | from corpus.ted import TedDataset as Dataset 52 | elif name.lower() == "commonvoice": 53 | from corpus.commonvoice import CVDataset as Dataset 54 | elif name.lower() == "valentini": 55 | from corpus.valentini import ValDataset as Dataset 56 | elif name.lower() =="l2arctic": 57 | from corpus.l2arctic import L2ArcticDataset as Dataset 58 | else: 59 | raise NotImplementedError 60 | 61 | loader_bs = batch_size 62 | if name.lower() == "librispeech": 63 | dataset = Dataset(batch_size, path, noise_type=noise_type, noise_snr=noise_snr) 64 | else: 65 | dataset = Dataset(batch_size, path) 66 | 67 | print(f'[INFO] There are {len(dataset)} samples.') 68 | 69 | return dataset, loader_bs 70 | 71 | 72 | def load_dataset(name='librispeech', path=None, batch_size=1, extra_noise=0., noise_type=None, noise_snr=None, num_workers=4): 73 | ''' Prepare dataloader for training/validation''' 74 | dataset, loader_bs = create_dataset(name, path, batch_size, noise_type=noise_type, noise_snr=noise_snr) 75 | if name == "librispeech" and noise_type == None: 76 | collate_fn = partial(collect_audio_batch, extra_noise=extra_noise) 77 | else: 78 | collate_fn = partial(collect_audio_batch, extra_noise=0) 79 | 80 | dataloader = DataLoader(dataset, batch_size=loader_bs, shuffle=False, 81 | collate_fn=collate_fn, num_workers=num_workers) 82 | return dataloader -------------------------------------------------------------------------------- /corpus/noisyspeech_synthesizer.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | import soundfile as sf 4 | import os 5 | import argparse 6 | import configparser as CP 7 | from audiolib import audioread, audiowrite, snr_mixer 8 | 9 | 10 | def main(cfg): 11 | snr_lower = float(cfg["snr_lower"]) 12 | snr_upper = float(cfg["snr_upper"]) 13 | total_snrlevels = int(cfg["total_snrlevels"]) 14 | 15 | clean_dir = os.path.join(os.path.dirname(__file__), 'clean_train') 16 | if cfg["speech_dir"] != 'None': 17 | clean_dir = cfg["speech_dir"] 18 | if not os.path.exists(clean_dir): 19 | assert False, ("Clean speech data is required") 20 | 21 | noise_dir = os.path.join(os.path.dirname(__file__), 'noise_train') 22 | if cfg["noise_dir"] != 'None': 23 | noise_dir = cfg["noise_dir"] 24 | if not os.path.exists(noise_dir): 25 | assert False, ("Noise data is required") 26 | 27 | fs = float(cfg["sampling_rate"]) 28 | audioformat = cfg["audioformat"] 29 | noise_audioformat = cfg["noise_audioformat"] 30 | total_hours = float(cfg["total_hours"]) 31 | audio_length = float(cfg["audio_length"]) 32 | silence_length = float(cfg["silence_length"]) 33 | # noisyspeech_dir = os.path.join(os.path.dirname(__file__), 'NoisySpeech_training') 34 | # if not os.path.exists(noisyspeech_dir): 35 | # os.makedirs(noisyspeech_dir) 36 | # clean_proc_dir = os.path.join(os.path.dirname(__file__), 'CleanSpeech_training') 37 | # if not os.path.exists(clean_proc_dir): 38 | # os.makedirs(clean_proc_dir) 39 | # noise_proc_dir = os.path.join(os.path.dirname(__file__), 'Noise_training') 40 | # if not os.path.exists(noise_proc_dir): 41 | # os.makedirs(noise_proc_dir) 42 | 43 | total_secs = total_hours*60*60 44 | total_samples = int(total_secs * fs) 45 | audio_length = int(audio_length * fs) 46 | SNR = np.linspace(snr_lower, snr_upper, total_snrlevels) 47 | print(f"SNR : {SNR}") 48 | cleanfilenames = glob.glob(os.path.join(clean_dir, f"**/{audioformat}"), recursive=True) 49 | if cfg["noise_types_excluded"]=='None': 50 | noisefilenames = glob.glob(os.path.join(noise_dir, f"**/{noise_audioformat}"), recursive=True) 51 | else: 52 | filestoexclude = cfg["noise_types_excluded"].split(',') 53 | noisefilenames = glob.glob(os.path.join(noise_dir, f"**/{noise_audioformat}"), recursive=True) 54 | for i in range(len(filestoexclude)): 55 | noisefilenames = [fn for fn in noisefilenames if not os.path.basename(fn).startswith(filestoexclude[i])] 56 | 57 | print(f"noisefilenames: {noisefilenames}") 58 | for noisefile in noisefilenames: 59 | noisyspeech_dir = os.path.join(os.path.dirname(__file__), "../data/" f"libri_test-other_noise_{SNR[0]}", noisefile.split("/")[-1].split(".")[0]) 60 | for i in range(np.size(SNR)): 61 | for cleanfile in cleanfilenames: 62 | clean, fs = audioread(cleanfile) 63 | noise, fs = audioread(noisefile) 64 | 65 | noiseconcat = noise 66 | while len(noiseconcat) <= len(clean): 67 | noiseconcat = np.append(noiseconcat, noise) 68 | noise = noiseconcat 69 | if len(noise)>=len(clean): 70 | noise = noise[0:len(clean)] 71 | 72 | clean_snr, noise_snr, noisy_snr = snr_mixer(clean, noise, SNR[i]) 73 | noisyfilename = f"{cleanfile.split('/')[-1]}_{noisefile.split('/')[-1].split('.')[0]}_SNR_{str(SNR[i])}.wav" 74 | noisypath = os.path.join(noisyspeech_dir, noisyfilename) 75 | audiowrite(noisy_snr, fs, noisypath, norm=False) 76 | 77 | 78 | 79 | if __name__=="__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument("--cfg", default="../conf/noisyspeech_synthesizer.cfg", help="Read noisyspeech_synthesizer.cfg for all the details") 82 | parser.add_argument("--cfg_str", type=str, default="noisy_speech") 83 | # parser.add_argument("--speech_dir", type=str, default="/home/server17/hdd/changhun_workspace/LibriSpeech") 84 | # parser.add_argument("--snr_lower", type=int, default=10) 85 | args = parser.parse_args() 86 | 87 | cfgpath = os.path.join(os.path.dirname(__file__), args.cfg) 88 | assert os.path.exists(cfgpath), f"No configuration file as [{cfgpath}]" 89 | cfg = CP.ConfigParser() 90 | cfg._interpolation = CP.ExtendedInterpolation() 91 | cfg.read(cfgpath) 92 | 93 | main(cfg._sections[args.cfg_str]) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import hydra 4 | from omegaconf import OmegaConf 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.utils.rnn import pad_sequence 9 | torch.backends.cudnn.enabled = True 10 | torch.backends.cudnn.deterministic = True 11 | torch.backends.cudnn.benchmark = True 12 | 13 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM 14 | from speechbrain.pretrained import EncoderDecoderASR 15 | import nemo.collections.asr as nemo_asr 16 | from nemo.collections.asr.parts.submodules.rnnt_beam_decoding import BeamRNNTInfer 17 | from pyctcdecode import BeamSearchDecoderCTC 18 | from pyctcdecode.alphabet import Alphabet 19 | from pyctcdecode.language_model import LanguageModel 20 | from jiwer import wer 21 | 22 | from data import * 23 | from forward import * 24 | from utils import * 25 | 26 | 27 | def forward_and_adapt(args, model, processor, optimizer, scheduler, wavs, lens): 28 | global original_model 29 | 30 | optimizer.zero_grad() 31 | blank_index = get_blank_index(args, model, processor) 32 | 33 | for i, wav in enumerate(wavs): 34 | wav = wav[:lens[i]].unsqueeze(0) 35 | outputs, pseudo_labels = get_logits_and_pseudo_labels(args, model, processor, wav, torch.FloatTensor([lens[i]]).to(wav.device)) 36 | if "original" in args.method or "em_uncertainty" in args.method or "em_sparse" in args.method: 37 | predicted_ids = torch.argmax(outputs, dim=-1) 38 | non_blank = torch.where(predicted_ids != blank_index, 1, 0).bool() 39 | 40 | if args.em_coef > 0: 41 | if "original" in args.method: 42 | if args.not_blank: 43 | e_loss = softmax_entropy(outputs / args.temp)[non_blank].mean(0).mean() 44 | else: 45 | e_loss = softmax_entropy(outputs / args.temp).mean(0).mean() 46 | (args.em_coef * e_loss / (len(wavs))).backward(retain_graph=True) 47 | if 1 - args.em_coef > 0: 48 | c_loss = mcc_loss(outputs / args.temp, class_num=outputs.shape[-1], reweight=True) 49 | ((1 - args.em_coef) * c_loss / (len(wavs))).backward(retain_graph=True) 50 | if 'beam_search_max' in args.method or 'beam_search_all' in args.method or 'beam_search_negative_sampling' in args.method: 51 | criterion = nn.CrossEntropyLoss(ignore_index=blank_index) if args.not_blank else nn.CrossEntropyLoss() 52 | if 'beam_search_max' in args.method: 53 | char_history = pseudo_labels[0].to(args.device) 54 | if args.certain_only: 55 | selected_frame = set() 56 | top_idx, top_prob = -1, 0 57 | for frame_idx, (output, char_idx) in enumerate(zip(outputs.squeeze(0), char_history)): 58 | probs = torch.softmax(output, dim=-1) 59 | if probs[char_idx] > args.prob_threshold: 60 | selected_frame.add(frame_idx) 61 | if char_idx != blank_index and probs[char_idx].item() > top_prob: 62 | top_idx = frame_idx 63 | top_prob = probs[char_idx].item() 64 | selected_frame.add(top_idx) 65 | selected_frame = sorted(selected_frame) 66 | selected_outputs, selected_char_history = outputs.squeeze(0)[selected_frame], char_history[selected_frame] 67 | loss = criterion(selected_outputs / args.temp, selected_char_history) 68 | else: 69 | loss = criterion(outputs / args.temp, char_history) 70 | (loss / len(wavs)).backward(retain_graph=True) 71 | elif 'beam_search_all' in args.method: 72 | loss = 0 73 | for char_history in pseudo_labels[:args.num_positives]: 74 | char_history = char_history.to(args.device) 75 | if args.certain_only: 76 | selected_frame = set() 77 | top_idx, top_prob = -1, 0 78 | for frame_idx, (output, char_idx) in enumerate(zip(outputs.squeeze(0), char_history)): 79 | probs = torch.softmax(output, dim=-1) 80 | if probs[char_idx] > args.prob_threshold: 81 | selected_frame.add(frame_idx) 82 | if char_idx != blank_index and probs[char_idx].item() > top_prob: 83 | top_idx = frame_idx 84 | top_prob = probs[char_idx].item() 85 | selected_frame.add(top_idx) 86 | selected_frame = sorted(selected_frame) 87 | selected_outputs, selected_char_history = outputs.squeeze(0)[selected_frame], char_history[selected_frame] 88 | loss += criterion(selected_outputs / args.temp, selected_char_history) / len(pseudo_labels) 89 | else: 90 | loss += criterion(outputs / args.temp, char_history) / len(pseudo_labels) 91 | (loss / len(wavs)).backward(retain_graph=True) 92 | if 'beam_search_negative_sampling' in args.method: 93 | negative_outputs = outputs.clone() 94 | negative_loss = 0 95 | char_history = pseudo_labels[0].to(args.device) 96 | if args.negative_sampling_method == "random": 97 | for _ in range(args.num_negatives): 98 | negative_char_history = torch.randint_like(input=char_history, high=negative_outputs.shape[-1]).to(args.device) 99 | negative_mask = (negative_char_history != char_history) & (char_history != 0) 100 | 101 | selected_frame = [] 102 | for frame_idx, mask in enumerate(negative_mask): 103 | if mask: 104 | selected_frame.append(frame_idx) 105 | selected_negative_outputs = negative_outputs.squeeze(0)[selected_frame] 106 | selected_negative_char_history = negative_char_history[selected_frame] 107 | if len(selected_negative_outputs) > 0: 108 | negative_loss += -criterion(selected_negative_outputs / args.temp, selected_negative_char_history) / args.num_negatives 109 | elif args.negative_sampling_method == "beam_candidate": 110 | for out_idx in range(len(pseudo_labels))[-args.num_negatives:]: 111 | negative_char_history = pseudo_labels[out_idx].to(args.device) 112 | negative_mask = (negative_char_history != char_history) & (char_history != 0) 113 | 114 | selected_frame = [] 115 | for frame_idx, mask in enumerate(negative_mask): 116 | if mask: 117 | selected_frame.append(frame_idx) 118 | selected_negative_outputs = negative_outputs.squeeze(0)[selected_frame] 119 | selected_negative_char_history = negative_char_history[selected_frame] 120 | if len(selected_negative_outputs) > 0: 121 | negative_loss += -criterion(selected_negative_outputs / args.temp, selected_negative_char_history) / args.num_negatives 122 | elif args.negative_sampling_method == 'ns3l': 123 | negative_mask = torch.where(torch.softmax(negative_outputs, dim=-1) < args.ns_threshold * (10 / negative_outputs.shape[-1]), 1, 0) 124 | negative_loss += torch.mean(-torch.log(1 - torch.sum(negative_mask * torch.softmax(negative_outputs / args.temp, dim=-1), dim=-1))) 125 | if torch.is_tensor(negative_loss): 126 | (args.ns_coef * negative_loss / len(wavs)).backward(retain_graph=True) 127 | if 'renyi_em' in args.method: 128 | predicted_ids = torch.argmax(outputs, dim=-1) 129 | non_blank = torch.where(predicted_ids != blank_index, 1, 0).bool() 130 | 131 | if args.not_blank: 132 | e_loss = renyi_entropy((outputs / args.temp)[non_blank], alpha=args.renyi_entropy_alpha) 133 | else: 134 | e_loss = renyi_entropy(outputs / args.temp, alpha=args.renyi_entropy_alpha) 135 | (e_loss / (len(wavs))).backward(retain_graph=True) 136 | if 'kld_ori' in args.method: 137 | assert 0 <= args.kld_weight <= 1 138 | # TODO: implement bias parameter 139 | 140 | # naive pseudo-labeling 141 | predicted_ids = torch.argmax(outputs, dim=-1) 142 | non_blank = torch.where(predicted_ids != blank_index, 1, 0).bool() 143 | # e_loss = renyi_entropy(outputs, alpha='inf') 144 | e_loss = renyi_entropy((outputs / args.temp)[non_blank], alpha='inf') 145 | ((1 - args.kld_weight) * e_loss / (len(wavs))).backward(retain_graph=True) 146 | 147 | # kld loss 148 | original_outputs, _ = get_logits_and_pseudo_labels(args, original_model, processor, wav, torch.FloatTensor([lens[i]]).to(wav.device)) 149 | probs = torch.softmax(outputs, dim=-1) 150 | original_probs = torch.softmax(original_outputs, dim=-1) 151 | kl_div_loss = F.kl_div(torch.log(probs), original_probs.detach(), reduction="batchmean") 152 | (args.kld_weight * kl_div_loss / (len(wavs))).backward(retain_graph=True) 153 | if 'kld_comb' in args.method: 154 | # Renyi em 155 | predicted_ids = torch.argmax(outputs, dim=-1) 156 | non_blank = torch.where(predicted_ids != blank_index, 1, 0).bool() 157 | if args.not_blank: 158 | e_loss = renyi_entropy((outputs / args.temp)[non_blank], alpha=args.renyi_entropy_alpha) 159 | else: 160 | e_loss = renyi_entropy(outputs / args.temp, alpha=args.renyi_entropy_alpha) 161 | ((1 - args.kld_weight) * e_loss / (len(wavs))).backward(retain_graph=True) 162 | 163 | # negative sampling 164 | criterion = nn.CrossEntropyLoss(ignore_index=blank_index) if args.not_blank else nn.CrossEntropyLoss() 165 | negative_outputs = outputs.clone() 166 | negative_loss = 0 167 | char_history = pseudo_labels[0].to(args.device) 168 | negative_mask = torch.where(torch.softmax(negative_outputs, dim=-1) < args.ns_threshold * (10 / negative_outputs.shape[-1]), 1, 0) 169 | negative_loss += torch.mean(-torch.log(1 - torch.sum(negative_mask * torch.softmax(negative_outputs / args.temp, dim=-1), dim=-1))) 170 | if torch.is_tensor(negative_loss): 171 | ((1 - args.kld_weight) * args.ns_coef * negative_loss / len(wavs)).backward(retain_graph=True) 172 | 173 | # kld loss 174 | original_outputs, _ = get_logits_and_pseudo_labels(args, original_model, processor, wav, torch.FloatTensor([lens[i]]).to(wav.device)) 175 | probs = torch.softmax(outputs, dim=-1) 176 | original_probs = torch.softmax(original_outputs, dim=-1) 177 | kl_div_loss = F.kl_div(torch.log(probs), original_probs.detach(), reduction="batchmean") 178 | (args.kld_weight * kl_div_loss / (len(wavs))).backward(retain_graph=True) 179 | 180 | optimizer.step() 181 | if scheduler is not None: 182 | scheduler.step() 183 | 184 | 185 | @hydra.main(version_base=None, config_path="conf", config_name="config") 186 | def main(args): 187 | if args.seed: 188 | set_seed(args.seed) 189 | 190 | if not os.path.exists(args.log_dir): 191 | os.makedirs(args.log_dir) 192 | global logger 193 | logger = get_logger(args) 194 | logger.info(OmegaConf.to_yaml(args)) 195 | 196 | dataset = load_dataset(args.dataset_name, args.dataset_dir, args.batch_size, args.extra_noise, args.noise_type, args.noise_snr) 197 | gt_texts, ori_transcriptions, transcriptions_1, transcriptions_3, transcriptions_5, transcriptions_10, transcriptions_20, transcriptions_40 = [], [], [], [], [], [], [], [] 198 | 199 | global original_model 200 | 201 | model = get_model(args) 202 | original_model = get_model(args) 203 | params, _ = collect_params(model, args.train_params) 204 | optimizer, scheduler = get_optimizer(args, params, opt_name=args.optimizer, lr=args.lr, scheduler=args.scheduler) 205 | processor = Wav2Vec2Processor.from_pretrained(args.asr, sampling_rate=args.sample_rate, return_attention_mask=True) if isinstance(model, Wav2Vec2ForCTC) else None 206 | 207 | if isinstance(model, Wav2Vec2ForCTC): 208 | decoder_processor = Wav2Vec2ProcessorWithLM.from_pretrained(args.processor) 209 | elif isinstance(model, EncoderDecoderASR): 210 | decoder_processor = None 211 | elif isinstance(model, nemo_asr.models.EncDecCTCModelBPE): 212 | decoder_processor = BeamSearchDecoderCTC( 213 | alphabet=Alphabet(labels=model.decoder.vocabulary+[""], is_bpe=True), 214 | language_model=LanguageModel.load_from_dir(args.processor), 215 | ) 216 | elif isinstance(model, nemo_asr.models.EncDecRNNTBPEModel): 217 | decoder_processor = BeamRNNTInfer( 218 | model.decoding.decoding.decoder.to(args.device), 219 | model.decoding.decoding.joint.to(args.device), 220 | beam_size=args.beam_width, 221 | return_best_hypothesis=False, 222 | ) 223 | 224 | episodic = args.episodic 225 | steps = args.steps 226 | 227 | if episodic: 228 | original_model_state, original_optimizer_state, original_scheduler_state = copy_model_and_optimizer(model, optimizer, scheduler) 229 | 230 | for batch_idx, batch in enumerate(dataset): 231 | if args.dataset_name == "commonvoice" and batch_idx >= 1000: 232 | break 233 | 234 | lens, wavs, texts, _ = batch 235 | if isinstance(model, Wav2Vec2ForCTC): 236 | wavs = processor(wavs, sampling_rate=args.sample_rate, return_tensors="pt", padding="longest").input_values.to(args.device) 237 | else: 238 | wavs = pad_sequence([torch.from_numpy(wav) for wav in wavs], batch_first=True).to(args.device) 239 | lens = lens.to(args.device) 240 | 241 | gt_texts.extend(texts) 242 | ori_transcription = transcribe_batch(args, original_model, processor, wavs, lens) 243 | ori_transcriptions.extend(ori_transcription) 244 | ori_wer = wer(list(texts), list(ori_transcription)) 245 | 246 | logger.info(f"{batch_idx}/{len(dataset)}") 247 | logger.info(f"gt text: {' '.join(list(texts))}") 248 | logger.info(f"original WER: {ori_wer}") 249 | logger.info(f"original text: {' '.join(list(ori_transcription))}") 250 | 251 | if episodic: 252 | model, optimizer, scheduler = load_model_and_optimizer(model, optimizer, scheduler, original_model_state, original_optimizer_state, original_scheduler_state) 253 | 254 | for step_idx in range(1, steps + 1): 255 | model = set_rnn_to_train(model) 256 | forward_and_adapt(args, model, decoder_processor, optimizer, scheduler, wavs, lens) 257 | transcription = transcribe_batch(args, model, processor, wavs, lens) 258 | 259 | if step_idx in [1, 3, 5, 10, 20, 40]: 260 | transcription_list = eval(f"transcriptions_{step_idx}") 261 | transcription_list.extend(transcription) 262 | ada_wer = wer(list(texts), list(transcription)) 263 | logger.info(f"adapt-{step_idx} WER: {ada_wer}") 264 | logger.info(f"adapt-{step_idx} text: {' '.join(list(transcription))}") 265 | 266 | gc.collect() 267 | torch.cuda.empty_cache() 268 | logger.info("\n") 269 | 270 | logger.info(OmegaConf.to_yaml(args)) 271 | logger.info(f"number of data : {len(dataset)}") 272 | logger.info(f"original WER: {wer(gt_texts, ori_transcriptions)}") 273 | for step_idx in [1, 3, 5, 10, 20, 40]: 274 | if step_idx > steps: 275 | break 276 | transcription_list = eval(f"transcriptions_{step_idx}") 277 | logger.info(f"TTA-{step_idx}: {wer(gt_texts, transcription_list)}") 278 | 279 | 280 | 281 | if __name__ == '__main__': 282 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import logging 4 | from copy import deepcopy 5 | from datetime import datetime 6 | import math 7 | import pickle 8 | import string 9 | import json 10 | import csv 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch.optim.lr_scheduler import CosineAnnealingLR 17 | 18 | from transformers import Wav2Vec2ForCTC 19 | from speechbrain.pretrained import EncoderDecoderASR 20 | import nemo.collections.asr as nemo_asr 21 | 22 | from typing import Optional, Tuple, Union 23 | import ipdb 24 | 25 | def set_seed(seed): 26 | os.environ["PYTHONHASHSEED"] = str(seed) 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | if torch.cuda.is_available(): 31 | torch.cuda.manual_seed(seed) 32 | torch.cuda.manual_seed_all(seed) 33 | 34 | 35 | def get_logger(args): 36 | logger = logging.getLogger("main") 37 | logger.setLevel(logging.INFO) 38 | formatter = logging.Formatter('%(message)s') 39 | 40 | time_string = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 41 | file_handler = logging.FileHandler(os.path.join(args.log_dir, f"log_{time_string}.txt")) 42 | file_handler.setFormatter(formatter) 43 | logger.addHandler(file_handler) 44 | return logger 45 | 46 | 47 | def get_model(args): 48 | if args.asr in ["facebook/wav2vec2-base-960h"]: # CTC-based models 49 | model = Wav2Vec2ForCTC.from_pretrained(args.asr).requires_grad_(True).eval() 50 | if 'cuda' in args.device: 51 | model = model.cuda() 52 | elif args.asr in ["speechbrain/asr-crdnn-rnnlm-librispeech", "speechbrain/asr-crdnn-transformerlm-librispeech", "speechbrain/asr-transformer-transformerlm-librispeech", "speechbrain/asr-conformersmall-transformerlm-librispeech"]: # attention-based models 53 | model = EncoderDecoderASR.from_hparams(args.asr, run_opts={"device": args.device}).requires_grad_(True).eval() 54 | elif args.asr in ["pretrained_models/stt_en_conformer_ctc_small.nemo", "pretrained_models/stt_en_conformer_ctc_small_ls.nemo"]: # conformers 55 | model = nemo_asr.models.EncDecCTCModelBPE.restore_from(args.asr).to(args.device).requires_grad_(True).eval() 56 | elif args.asr in ["pretrained_models/stt_en_conformer_transducer_small.nemo", "pretrained_models/stt_en_conformer_transducer_large.nemo"]: # transducers 57 | model = nemo_asr.models.EncDecRNNTBPEModel.restore_from(args.asr).to(args.device).requires_grad_(True).eval() 58 | return model 59 | 60 | 61 | def collect_params(model, train_params): 62 | if isinstance(model, Wav2Vec2ForCTC): 63 | return collect_params_ctc(model, train_params) 64 | elif isinstance(model, EncoderDecoderASR): 65 | return collect_params_attn(model, train_params) 66 | elif isinstance(model, nemo_asr.models.EncDecCTCModelBPE): 67 | return collect_params_conf(model, train_params) 68 | elif isinstance(model, nemo_asr.models.EncDecRNNTBPEModel): 69 | return collect_params_trans(model, train_params) 70 | 71 | 72 | def collect_params_ctc(model, train_params): 73 | params, names = [], [] 74 | for nm, m in model.named_modules(): 75 | if "all" in train_params: 76 | for np, p in m.named_parameters(): 77 | p.requires_grad = True 78 | if not f"{nm}.{np}" in names: 79 | params.append(p) 80 | names.append(f"{nm}.{np}") 81 | if "feature" in train_params: 82 | if len(str(nm).split('.')) > 1: 83 | if str(nm).split('.')[1] == 'feature_extractor' or str(nm).split('.')[1] == 'feature_projection': 84 | for np, p in m.named_parameters(): 85 | p.requires_grad = True 86 | if not f"{nm}.{np}" in names: 87 | params.append(p) 88 | names.append(f"{nm}.{np}") 89 | if "LN" in train_params: 90 | if isinstance(m, nn.LayerNorm): 91 | for np, p in m.named_parameters(): 92 | if np in ['weight', 'bias']: 93 | p.requires_grad = True 94 | if not f"{nm}.{np}" in names: 95 | params.append(p) 96 | names.append(f"{nm}.{np}") 97 | return params, names 98 | 99 | 100 | def collect_params_attn(model, train_params): 101 | params, names = [], [] 102 | for nm, m in model.named_modules(): 103 | for np, p in m.named_parameters(): 104 | collect = False 105 | if "all" in train_params: 106 | collect = True 107 | if 'enc' in train_params and 'encoder' in str(nm): 108 | collect = True 109 | if 'dec' in train_params and 'decoder' in str(nm): 110 | collect = True 111 | if 'LN' in train_params and isinstance(m, nn.LayerNorm): 112 | collect = True 113 | 114 | if collect: 115 | p.requires_grad = True 116 | params.append(p) 117 | names.append(f"{nm}.{np}") 118 | return params, names 119 | 120 | 121 | def collect_params_conf(model, train_params): 122 | params, names = [], [] 123 | for nm, m in model.named_modules(): 124 | for np, p in m.named_parameters(): 125 | collect = False 126 | if "all" in train_params: 127 | collect = True 128 | if 'LN' in train_params and isinstance(m, nn.LayerNorm): 129 | collect = True 130 | if 'BN' in train_params and isinstance(m, nn.BatchNorm1d): 131 | collect = True 132 | if 'enc' in train_params and 'encoder' in nm: 133 | collect = True 134 | if 'dec' in train_params and 'decoder' in nm: 135 | collect = True 136 | 137 | if collect: 138 | p.requires_grad = True 139 | params.append(p) 140 | names.append(f"{nm}.{np}") 141 | return params, names 142 | 143 | 144 | def collect_params_trans(model, train_params): 145 | params, names = [], [] 146 | for nm, m in model.named_modules(): 147 | for np, p in m.named_parameters(): 148 | collect = False 149 | if "all" in train_params: 150 | collect = True 151 | if 'LN' in train_params and isinstance(m, nn.LayerNorm): 152 | collect = True 153 | if 'BN' in train_params and isinstance(m, nn.BatchNorm1d): 154 | collect = True 155 | if 'enc' in train_params and 'encoder' in nm: 156 | collect = True 157 | if 'dec' in train_params and 'decoder' in nm: 158 | collect = True 159 | if 'joint' in train_params and 'joint' in nm: 160 | collect = True 161 | 162 | if collect: 163 | p.requires_grad = True 164 | params.append(p) 165 | names.append(f"{nm}.{np}") 166 | return params, names 167 | 168 | 169 | def freeze_norm_stats(model): 170 | for m in model.modules(): 171 | if isinstance(m, nn.BatchNorm1d): 172 | m.track_running_stats = False 173 | 174 | 175 | def set_rnn_to_train(model): 176 | for m in model.modules(): 177 | if isinstance(m, torch.nn.modules.rnn.RNNBase): 178 | m.train() 179 | m.dropout = 0 180 | return model 181 | 182 | 183 | def get_optimizer(args, params, opt_name='AdamW', lr=1e-4, beta=0.9, weight_decay=0., scheduler=None): 184 | opt = getattr(torch.optim, opt_name) 185 | if opt_name == 'Adam': 186 | optimizer = opt(params, lr=lr, betas=(beta, 0.999), weight_decay=weight_decay) 187 | else: 188 | optimizer = opt(params, lr=lr, weight_decay=weight_decay) 189 | 190 | if scheduler is not None: 191 | return optimizer, eval(scheduler)(optimizer, T_max=args.t_max, eta_min=args.lr_min) 192 | return optimizer, None 193 | 194 | 195 | def copy_model_and_optimizer(model, optimizer, scheduler): 196 | """Copy the model and optimizer states for resetting after adaptation.""" 197 | model_state = deepcopy(model.state_dict()) 198 | optimizer_state = deepcopy(optimizer.state_dict()) 199 | if scheduler is not None: 200 | scheduler_state = deepcopy(scheduler.state_dict()) 201 | return model_state, optimizer_state, scheduler_state 202 | else: 203 | return model_state, optimizer_state, None 204 | 205 | 206 | def load_model_and_optimizer(model, optimizer, scheduler, model_state, optimizer_state, scheduler_state): 207 | """Restore the model and optimizer states from copies.""" 208 | model.load_state_dict(model_state, strict=True) 209 | optimizer.load_state_dict(optimizer_state) 210 | if scheduler is not None: 211 | scheduler.load_state_dict(scheduler_state) 212 | return model, optimizer, scheduler 213 | else: 214 | return model, optimizer, None 215 | 216 | 217 | def get_blank_index(args, model, processor): 218 | if isinstance(model, Wav2Vec2ForCTC): 219 | blank_index = 0 220 | elif isinstance(model, EncoderDecoderASR): 221 | blank_index = model.mods.decoder.blank_index 222 | elif isinstance(model, nemo_asr.models.EncDecCTCModelBPE): 223 | blank_index = 128 224 | elif isinstance(model, nemo_asr.models.EncDecRNNTBPEModel): 225 | blank_index = processor.blank 226 | return blank_index 227 | 228 | 229 | def softmax_entropy(x, dim=-1): 230 | return -(x.softmax(dim) * x.log_softmax(dim)).sum(dim) 231 | 232 | 233 | def renyi_entropy(x, alpha, dim=-1): 234 | # x: (B, L, D) 235 | if alpha == 1: 236 | return torch.mean(softmax_entropy(x, dim)) 237 | if alpha == 'inf' or alpha == float('inf'): 238 | entropy, _ = torch.max(x, dim) 239 | return -torch.mean(torch.log(entropy)) 240 | entropy = torch.log(torch.pow(x.softmax(dim), alpha).sum(dim)) # entropy: B, L 241 | entropy = entropy / (1 - alpha) 242 | return torch.mean(entropy) 243 | 244 | def renyi_entropy_raw(x, alpha, dim=-1): 245 | # x: (B, L, D) 246 | if alpha == 1: 247 | return torch.mean(softmax_entropy(x, dim)) 248 | if alpha == 'inf' or alpha == float('inf'): 249 | entropy, _ = torch.max(x, dim) 250 | return -torch.mean(torch.log(entropy)) 251 | entropy = torch.log(torch.pow(x.softmax(dim), alpha).sum(dim)) # entropy: B, L 252 | entropy = entropy / (1 - alpha) 253 | return entropy 254 | 255 | def marginal_renyi_entropy(x, alpha, dim=-1): 256 | # x: (B, L, D) 257 | if alpha == 1: 258 | x = x.mean(0) 259 | return torch.mean(softmax_entropy(x, dim)) 260 | if alpha == 'inf' or alpha == float('inf'): 261 | x = x.mean(0) 262 | entropy, _ = torch.max(x, dim) 263 | return -torch.mean(torch.log(entropy)) 264 | x_softmax = x.softmax(dim) 265 | x_marginal = x_softmax.mean(0) 266 | entropy = torch.log(torch.pow(x_marginal, alpha).sum(dim)) # entropy: B, L 267 | entropy = entropy / (1 - alpha) 268 | return torch.mean(entropy) 269 | 270 | def non_saturating_loss(x, dim=-1): 271 | max_idx = torch.argmax(x, dim=dim, keepdim=True) 272 | one_hots = torch.zeros_like(x).scatter(dim, max_idx, 1).to(x.device) 273 | return - torch.mean(one_hots * x) + torch.log(((1 - one_hots) * torch.exp(x)).sum(dim=dim)).mean() 274 | 275 | 276 | def mcc_loss(x, class_num, reweight=False, dim=-1): 277 | mcc_loss = 0 278 | for x_split in x: # (B, L, D) -> (L, D) 279 | x_split = x_split.unsqueeze(0) 280 | p = x_split.softmax(dim) # (1, L, D) 281 | p = p.squeeze(0) # (L, D) 282 | 283 | if reweight: # (1, L, D) * (L, 1) 284 | target_entropy_weight = softmax_entropy(x_split, dim=-1).detach().squeeze(0) # instance-wise entropy (1, L, D) 285 | target_entropy_weight = 1 + torch.exp(-target_entropy_weight) # (1, L) 286 | target_entropy_weight = x_split.shape[1] * target_entropy_weight / torch.sum(target_entropy_weight) 287 | cov_matrix_t = p.mul(target_entropy_weight.view(-1, 1)).transpose(1, 0).mm(p) 288 | else: 289 | cov_matrix_t = p.transpose(1, 0).mm(p) # (D, L) * (L, D) -> (D, D) 290 | 291 | cov_matrix_t = cov_matrix_t / torch.sum(cov_matrix_t, dim=1) 292 | mcc_loss += (torch.sum(cov_matrix_t) - torch.trace(cov_matrix_t)) / class_num 293 | mcc_loss /= len(x) 294 | return mcc_loss 295 | 296 | 297 | def js_divergence(p1, p2): 298 | total_m = 0.5 * (p1 + p2) 299 | loss = 0.5 * F.kl_div(torch.log(p1), total_m, reduction="batchmean") + 0.5 * F.kl_div(torch.log(p2), total_m, reduction="batchmean") 300 | return loss 301 | 302 | 303 | def log_softmax(x, axis): 304 | x_max = np.amax(x, axis=axis, keepdims=True) 305 | if x_max.ndim > 0: 306 | x_max[~np.isfinite(x_max)] = 0 307 | elif not np.isfinite(x_max): 308 | x_max = 0 309 | tmp = x - x_max 310 | exp_tmp = np.exp(tmp) 311 | with np.errstate(divide="ignore"): 312 | s = np.sum(exp_tmp, axis=axis, keepdims=True) 313 | out: np.ndarray = np.log(s) 314 | out = tmp - out 315 | return out 316 | 317 | # ============================================================================================================= 318 | def compute_mask_indices( 319 | shape: Tuple[int, int], 320 | mask_prob: float, 321 | mask_length: int, 322 | attention_mask: Optional[torch.LongTensor] = None, 323 | min_masks: int = 0, 324 | ) -> np.ndarray: 325 | """ 326 | Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for 327 | ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on 328 | CPU as part of the preprocessing during training. 329 | 330 | Args: 331 | shape: The shape for which to compute masks. This should be of a tuple of size 2 where 332 | the first element is the batch size and the second element is the length of the axis to span. 333 | mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of 334 | independently generated mask spans of length `mask_length` is computed by 335 | `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the 336 | actual percentage will be smaller. 337 | mask_length: size of the mask 338 | min_masks: minimum number of masked spans 339 | attention_mask: A (right-padded) attention mask which independently shortens the feature axis of 340 | each batch dimension. 341 | """ 342 | batch_size, sequence_length = shape 343 | 344 | if mask_length < 1: 345 | raise ValueError("`mask_length` has to be bigger than 0.") 346 | 347 | if mask_length > sequence_length: 348 | raise ValueError( 349 | f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" 350 | f" and `sequence_length`: {sequence_length}`" 351 | ) 352 | 353 | # epsilon is used for probabilistic rounding 354 | epsilon = np.random.rand(1).item() 355 | 356 | def compute_num_masked_span(input_length): 357 | """Given input length, compute how many spans should be masked""" 358 | num_masked_span = int(mask_prob * input_length / mask_length + epsilon) 359 | num_masked_span = max(num_masked_span, min_masks) 360 | 361 | # make sure num masked span <= sequence_length 362 | if num_masked_span * mask_length > sequence_length: 363 | num_masked_span = sequence_length // mask_length 364 | 365 | # make sure num_masked span is also <= input_length - (mask_length - 1) 366 | if input_length - (mask_length - 1) < num_masked_span: 367 | num_masked_span = max(input_length - (mask_length - 1), 0) 368 | 369 | return num_masked_span 370 | 371 | # compute number of masked spans in batch 372 | input_lengths = ( 373 | attention_mask.sum(-1).detach().tolist() 374 | if attention_mask is not None 375 | else [sequence_length for _ in range(batch_size)] 376 | ) 377 | 378 | # SpecAugment mask to fill 379 | spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) 380 | spec_aug_mask_idxs = [] 381 | 382 | max_num_masked_span = compute_num_masked_span(sequence_length) 383 | 384 | if max_num_masked_span == 0: 385 | return spec_aug_mask 386 | 387 | for input_length in input_lengths: 388 | # compute num of masked spans for this input 389 | num_masked_span = compute_num_masked_span(input_length) 390 | 391 | # get random indices to mask 392 | spec_aug_mask_idx = np.random.choice( 393 | np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False 394 | ) 395 | 396 | # pick first sampled index that will serve as a dummy index to pad vector 397 | # to ensure same dimension for all batches due to probabilistic rounding 398 | # Picking first sample just pads those vectors twice. 399 | if len(spec_aug_mask_idx) == 0: 400 | # this case can only happen if `input_length` is strictly smaller then 401 | # `sequence_length` in which case the last token has to be a padding 402 | # token which we can use as a dummy mask id 403 | dummy_mask_idx = sequence_length - 1 404 | else: 405 | dummy_mask_idx = spec_aug_mask_idx[0] 406 | 407 | spec_aug_mask_idx = np.concatenate( 408 | [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] 409 | ) 410 | spec_aug_mask_idxs.append(spec_aug_mask_idx) 411 | 412 | spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) 413 | 414 | # expand masked indices to masked spans 415 | spec_aug_mask_idxs = np.broadcast_to( 416 | spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) 417 | ) 418 | spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) 419 | 420 | # add offset to the starting indexes so that indexes now create a span 421 | offsets = np.arange(mask_length)[None, None, :] 422 | offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( 423 | batch_size, max_num_masked_span * mask_length 424 | ) 425 | spec_aug_mask_idxs = spec_aug_mask_idxs + offsets 426 | 427 | # ensure that we cannot have indices larger than sequence_length 428 | if spec_aug_mask_idxs.max() > sequence_length - 1: 429 | spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 430 | 431 | # scatter indices to mask 432 | np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) 433 | 434 | return spec_aug_mask 435 | 436 | def save_pkl(data, file): 437 | with open(file, 'wb') as f: 438 | pickle.dump(data, f) 439 | return 440 | 441 | def load_pkl(file): 442 | with open(file, 'rb') as f: 443 | data = pickle.load(f) 444 | return data 445 | 446 | def append_pkl(data, file): 447 | with open(file, 'ab') as f: 448 | pickle.dump(data, f) 449 | return 450 | 451 | def read_txt(file): 452 | file = open(file, 'r') 453 | data = [] 454 | while True: 455 | line = file.readline() 456 | data.append(line.strip().split('\t')) 457 | if not line: 458 | break 459 | file.close() 460 | 461 | return data 462 | 463 | def csv_write(file, column_name, data=None): 464 | with open(file, 'w', newline='\n') as f: 465 | write = csv.writer(f) 466 | write.writerow(column_name) 467 | if data == None: 468 | return 469 | else: 470 | assert len(data) == len(column_name) 471 | write.writerow(column_name) 472 | for i in range(len(data[0])): 473 | line = [data[column_idx][i] for column_idx, _ in enumerate(column_name)] 474 | write.writerow(line) 475 | return 476 | 477 | def csv_append(file, data): 478 | with open(file, 'a', newline='\n') as f: 479 | write = csv.writer(f) 480 | for i in range(len(data[0])): 481 | line = [data[j][i] for j in range(len(data))] 482 | write.writerow(line) 483 | return 484 | 485 | def csv_append_ver2(file, data): 486 | with open(file, 'a', newline='\n') as f: 487 | write = csv.writer(f) 488 | for idx, line in enumerate(data): 489 | # line = [data[j][i] for j in range(len(data))] 490 | write.writerow(line) 491 | return 492 | 493 | def csv_read(file): 494 | data = [] 495 | alphabet = list(string.ascii_letters) 496 | with open(file, 'r', newline='') as f: 497 | reader = csv.reader(f, delimiter=' ', quotechar='|') 498 | for row in reader: 499 | # ipdb.set_trace() 500 | # if len([i for i in row[0] if i not in alphabet]) == 0: 501 | data.append(row[0]) 502 | return data 503 | 504 | def save_jsonl(file, data): 505 | with open(file, encoding='utf-8', mode='w') as f: 506 | for i in data: 507 | f.write(json.dump(i)+'\n') 508 | return 509 | 510 | def append_jsonl(file, data): 511 | with open(file, encoding='utf-8', mode='a') as f: 512 | f.write(json.dumps(data)+'\n') 513 | return 514 | 515 | def read_jsonl(file): 516 | with open(file, 'r', encoding="utf-8") as f: 517 | data = [json.loads(line) for line in f] 518 | return data 519 | 520 | def print_now(return_flag=0): 521 | # commented out by esyoon 2023-09-24-18:44:05 522 | # t_delta = datetime.timedelta(hours=9) 523 | # JST = datetime.timezone(t_delta, 'JST') 524 | 525 | from datetime import date, datetime, timezone, timedelta 526 | KST = timezone(timedelta(hours=9)) 527 | date = str(date.today()) 528 | time_record = str(datetime.now(KST).time())[:8] 529 | now = date+'_'+time_record 530 | if return_flag == 0: 531 | print(now) 532 | elif return_flag == 1: 533 | return now 534 | else: 535 | pass 536 | 537 | def normalize_dict(input_dict): 538 | total = sum(input_dict.values()) 539 | return {key: value / total for key, value in input_dict.items()} 540 | 541 | def softmax_dict(input_dict): 542 | # total = math.exp(sum(input_dict.values())) 543 | total = sum([math.exp(i) for i in input_dict.values()]) 544 | return {key: math.exp(value) / total for key, value in input_dict.items()} 545 | 546 | def reverse_softmax_dict(input_dict): 547 | # total = math.exp(-1 * sum(input_dict.values())) 548 | total = sum([math.exp(-1 *i) for i in input_dict.values()]) 549 | return {key: math.exp(-1 * value) / total for key, value in input_dict.items()} 550 | 551 | def to_device(batch, device): 552 | output = {} 553 | for k, v in batch.items(): 554 | try: 555 | output[k] = v.to(device) 556 | except Exception as e: 557 | print("to device error: {}".format(e)) 558 | assert 0 559 | output[k] = v 560 | return output -------------------------------------------------------------------------------- /main_lm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import hydra 4 | from omegaconf import OmegaConf 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.utils.rnn import pad_sequence 9 | torch.backends.cudnn.enabled = True 10 | torch.backends.cudnn.deterministic = True 11 | torch.backends.cudnn.benchmark = True 12 | 13 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM 14 | from speechbrain.pretrained import EncoderDecoderASR 15 | import nemo.collections.asr as nemo_asr 16 | from nemo.collections.asr.parts.submodules.rnnt_beam_decoding import BeamRNNTInfer 17 | from pyctcdecode import BeamSearchDecoderCTC 18 | from pyctcdecode.alphabet import Alphabet 19 | from pyctcdecode.language_model import LanguageModel 20 | from jiwer import wer 21 | 22 | from data import * 23 | from forward import * 24 | from utils import * 25 | 26 | from transformers import AutoTokenizer, AutoModelForCausalLM 27 | import ipdb 28 | 29 | llm_model = 'openchat3.5' 30 | if llm_model == 'llama': 31 | SYSTEM = 'You are a sentence correction assistant' 32 | INSTRUCTION = "Please generate correction of the given <> considering its pronunciation and meaning based on the suitable explanation. Generate the output in form of '\n<>: {explanation of correction}\n\n<>: {corrected sentence}\n'" 33 | TEMPLATE = "[INST] «SYS»\n{}\n«/SYS»\n\n{}\n\n<>: {}[/INST]" 34 | elif llm_model == 'openchat3.5': 35 | TEMPLATE = {'role': 'user', 'content': "Given <> is the transcription of utterance. Please generate correction of the <> considering its pronunciation and meaning based on the generated explanation. Generate the output in form of '\n<>: {explanation of correction}\n\n<>: {corrected sentence}'\n\n<>: "} 36 | 37 | def forward_and_adapt(args, model, processor, optimizer, scheduler, wavs, lens, step): 38 | global original_model, lm_tokenizer, correction_lm 39 | 40 | optimizer.zero_grad() 41 | blank_index = get_blank_index(args, model, processor) 42 | 43 | for i, wav in enumerate(wavs): 44 | wav = wav[:lens[i]].unsqueeze(0) 45 | outputs, pseudo_labels = get_logits_and_pseudo_labels(args, model, processor, wav, torch.FloatTensor([lens[i]]).to(wav.device)) 46 | if "original" in args.method or "em_uncertainty" in args.method or "em_sparse" in args.method: 47 | predicted_ids = torch.argmax(outputs, dim=-1) 48 | non_blank = torch.where(predicted_ids != blank_index, 1, 0).bool() 49 | 50 | if args.em_coef > 0: 51 | if "original" in args.method: 52 | if args.not_blank: 53 | e_loss = softmax_entropy(outputs / args.temp)[non_blank].mean(0).mean() 54 | else: 55 | e_loss = softmax_entropy(outputs / args.temp).mean(0).mean() 56 | (args.em_coef * e_loss / (len(wavs))).backward(retain_graph=True) 57 | if 1 - args.em_coef > 0: 58 | c_loss = mcc_loss(outputs / args.temp, class_num=outputs.shape[-1], reweight=True) 59 | ((1 - args.em_coef) * c_loss / (len(wavs))).backward(retain_graph=True) 60 | if 'beam_search_max' in args.method or 'beam_search_all' in args.method or 'beam_search_negative_sampling' in args.method: 61 | criterion = nn.CrossEntropyLoss(ignore_index=blank_index) if args.not_blank else nn.CrossEntropyLoss() 62 | if 'beam_search_max' in args.method: 63 | char_history = pseudo_labels[0].to(args.device) 64 | if args.certain_only: 65 | selected_frame = set() 66 | top_idx, top_prob = -1, 0 67 | for frame_idx, (output, char_idx) in enumerate(zip(outputs.squeeze(0), char_history)): 68 | probs = torch.softmax(output, dim=-1) 69 | if probs[char_idx] > args.prob_threshold: 70 | selected_frame.add(frame_idx) 71 | if char_idx != blank_index and probs[char_idx].item() > top_prob: 72 | top_idx = frame_idx 73 | top_prob = probs[char_idx].item() 74 | selected_frame.add(top_idx) 75 | selected_frame = sorted(selected_frame) 76 | selected_outputs, selected_char_history = outputs.squeeze(0)[selected_frame], char_history[selected_frame] 77 | loss = criterion(selected_outputs / args.temp, selected_char_history) 78 | else: 79 | loss = criterion(outputs / args.temp, char_history) 80 | (loss / len(wavs)).backward(retain_graph=True) 81 | elif 'beam_search_all' in args.method: 82 | loss = 0 83 | for char_history in pseudo_labels[:args.num_positives]: 84 | char_history = char_history.to(args.device) 85 | if args.certain_only: 86 | selected_frame = set() 87 | top_idx, top_prob = -1, 0 88 | for frame_idx, (output, char_idx) in enumerate(zip(outputs.squeeze(0), char_history)): 89 | probs = torch.softmax(output, dim=-1) 90 | if probs[char_idx] > args.prob_threshold: 91 | selected_frame.add(frame_idx) 92 | if char_idx != blank_index and probs[char_idx].item() > top_prob: 93 | top_idx = frame_idx 94 | top_prob = probs[char_idx].item() 95 | selected_frame.add(top_idx) 96 | selected_frame = sorted(selected_frame) 97 | selected_outputs, selected_char_history = outputs.squeeze(0)[selected_frame], char_history[selected_frame] 98 | loss += criterion(selected_outputs / args.temp, selected_char_history) / len(pseudo_labels) 99 | else: 100 | loss += criterion(outputs / args.temp, char_history) / len(pseudo_labels) 101 | (loss / len(wavs)).backward(retain_graph=True) 102 | if 'beam_search_negative_sampling' in args.method: 103 | negative_outputs = outputs.clone() 104 | negative_loss = 0 105 | char_history = pseudo_labels[0].to(args.device) 106 | if args.negative_sampling_method == "random": 107 | for _ in range(args.num_negatives): 108 | negative_char_history = torch.randint_like(input=char_history, high=negative_outputs.shape[-1]).to(args.device) 109 | negative_mask = (negative_char_history != char_history) & (char_history != 0) 110 | 111 | selected_frame = [] 112 | for frame_idx, mask in enumerate(negative_mask): 113 | if mask: 114 | selected_frame.append(frame_idx) 115 | selected_negative_outputs = negative_outputs.squeeze(0)[selected_frame] 116 | selected_negative_char_history = negative_char_history[selected_frame] 117 | if len(selected_negative_outputs) > 0: 118 | negative_loss += -criterion(selected_negative_outputs / args.temp, selected_negative_char_history) / args.num_negatives 119 | elif args.negative_sampling_method == "beam_candidate": 120 | for out_idx in range(len(pseudo_labels))[-args.num_negatives:]: 121 | negative_char_history = pseudo_labels[out_idx].to(args.device) 122 | negative_mask = (negative_char_history != char_history) & (char_history != 0) 123 | 124 | selected_frame = [] 125 | for frame_idx, mask in enumerate(negative_mask): 126 | if mask: 127 | selected_frame.append(frame_idx) 128 | selected_negative_outputs = negative_outputs.squeeze(0)[selected_frame] 129 | selected_negative_char_history = negative_char_history[selected_frame] 130 | if len(selected_negative_outputs) > 0: 131 | negative_loss += -criterion(selected_negative_outputs / args.temp, selected_negative_char_history) / args.num_negatives 132 | elif args.negative_sampling_method == 'ns3l': 133 | negative_mask = torch.where(torch.softmax(negative_outputs, dim=-1) < args.ns_threshold * (10 / negative_outputs.shape[-1]), 1, 0) 134 | negative_loss += torch.mean(-torch.log(1 - torch.sum(negative_mask * torch.softmax(negative_outputs / args.temp, dim=-1), dim=-1))) 135 | if torch.is_tensor(negative_loss): 136 | (args.ns_coef * negative_loss / len(wavs)).backward(retain_graph=True) 137 | if 'renyi_em' in args.method: 138 | predicted_ids = torch.argmax(outputs, dim=-1) 139 | non_blank = torch.where(predicted_ids != blank_index, 1, 0).bool() 140 | # added by esyoon 2024-02-25-22:31:53 141 | if isinstance(model, Wav2Vec2ForCTC): 142 | pred_text = processor.decode( 143 | outputs.squeeze(0).detach().cpu().numpy(), 144 | beam_width=args.beam_width, 145 | ).text 146 | elif isinstance(model, EncoderDecoderASR): 147 | pred_text = model.transcribe_batch(wav, wav_lens=torch.ones(1).to(args.device))[0] 148 | 149 | elif isinstance(model, nemo_asr.models.EncDecCTCModelBPE): 150 | greedy_predictions = outputs.argmax(dim=-1, keepdim=False) 151 | pred_text = model.wer.decode(greedy_predictions, predictions_lengths=torch.tensor([outputs.shape[1]]))[0][0].upper() 152 | 153 | elif isinstance(model, nemo_asr.models.EncDecRNNTBPEModel): 154 | # vocabulary = model.tokenizer.tokenizer.get_vocab() 155 | # labels_map = {v: k for k, v in vocabulary.items()} 156 | # token_list = [labels_map[c] for c in pseudo_labels[0].tolist() if c < model.decoding.blank_id - model.decoding.num_extra_outputs] 157 | # pred_text = model.tokenizer.tokenizer.detokenize(token_list).upper() 158 | pred_text = transcribe_batch(args, model, processor, wavs, lens)[0] 159 | 160 | # pred_text = model.wer.decode(greedy_predictions, predictions_lengths=encoder_length)[0][0].upper() 161 | if args.llm_model == 'llama': 162 | lm_input = TEMPLATE.format(SYSTEM, INSTRUCTION, pred_text) 163 | lm_input_tensor = lm_tokenizer(lm_input, return_tensors='pt') 164 | lm_input_tensor = to_device(lm_input_tensor, args.device) 165 | do_generation = True 166 | max_length = 512 167 | lm_generated_text = None 168 | num_loops = 0 169 | ctc_loss = 0 170 | while do_generation: 171 | num_loops += 1 172 | with torch.no_grad(): 173 | lm_output = correction_lm.generate(**lm_input_tensor, max_length = max_length) 174 | lm_output = lm_tokenizer.decode(lm_output[0], skip_special_tokens=True) 175 | if "<>" in lm_output.split('[/INST]')[1]: 176 | try: 177 | lm_generated_text = lm_output.split('[/INST]')[1].split('<>: ')[1].split('\n')[0].replace('.', '') 178 | except: 179 | do_generation=True 180 | max_length += 128 181 | elif "Corrected sentence" in lm_output.split('[/INST]')[1]: 182 | try: 183 | lm_generated_text = lm_output.split('[/INST]')[1].split("Corrected sentence: ")[1].split('\n')[0].replace('.', '') 184 | except: 185 | do_generation=True 186 | max_length += 128 187 | if lm_generated_text: 188 | if '"' in lm_generated_text: 189 | lm_generated_text = lm_generated_text.replace('"', '') 190 | do_generation=False 191 | break 192 | else: 193 | do_generation = True 194 | if num_loops > 5: 195 | print("skip this one") 196 | 197 | do_generation = False 198 | break 199 | if lm_generated_text: 200 | ctc_loss_reduction = "mean" 201 | ctc_zero_infinity = True 202 | 203 | log_probs = nn.functional.log_softmax(outputs, dim=-1, dtype=torch.float32).transpose(0,1) 204 | 205 | if isinstance(processor, Wav2Vec2Processor): 206 | processed_lm_output = processor(text=lm_generated_text, return_tensors='pt') 207 | processed_lm_output = to_device(processed_lm_output, args.device) 208 | ctc_loss = nn.functional.ctc_loss( 209 | log_probs, 210 | processed_lm_output['input_ids'], 211 | torch.LongTensor([outputs.shape[1]]), 212 | torch.LongTensor([processed_lm_output['input_ids'].shape[1]]), 213 | blank=blank_index, 214 | reduction=ctc_loss_reduction, 215 | zero_infinity=ctc_zero_infinity, 216 | ) 217 | elif isinstance(model, nemo_asr.models.EncDecCTCModelBPE): 218 | processed_lm_output = torch.tensor([model.tokenizer.text_to_ids(pred_text)]).to(args.device) 219 | ctc_loss = nn.functional.ctc_loss( 220 | log_probs, 221 | processed_lm_output, 222 | torch.LongTensor([outputs.shape[1]]), 223 | torch.LongTensor([processed_lm_output.shape[1]]), 224 | blank=blank_index, 225 | reduction=ctc_loss_reduction, 226 | zero_infinity=ctc_zero_infinity, 227 | ) 228 | elif isinstance(model, nemo_asr.models.EncDecRNNTBPEModel): 229 | processed_lm_output = torch.tensor([model.tokenizer.text_to_ids(pred_text)]).to(args.device) 230 | ctc_loss = nn.functional.ctc_loss( 231 | log_probs, 232 | processed_lm_output, 233 | torch.LongTensor([outputs.shape[1]]), 234 | torch.LongTensor([processed_lm_output.shape[1]]), 235 | blank=blank_index, 236 | reduction=ctc_loss_reduction, 237 | zero_infinity=ctc_zero_infinity, 238 | ) 239 | 240 | # ctc_loss_reduction = "mean" 241 | # ctc_zero_infinity = True 242 | # ctc_loss = nn.functional.ctc_loss( 243 | # log_probs, 244 | # processed_lm_output['input_ids'], 245 | # torch.LongTensor([outputs.shape[1]]), 246 | # torch.LongTensor([processed_lm_output['input_ids'].shape[1]]), 247 | # blank=blank_index, 248 | # reduction=ctc_loss_reduction, 249 | # zero_infinity=ctc_zero_infinity, 250 | # ) 251 | elif args.llm_model == 'openchat3.5': 252 | NEW_TEMPLATE = {} 253 | NEW_TEMPLATE['role'] = 'user' 254 | NEW_TEMPLATE['content'] = TEMPLATE['content'] + pred_text 255 | 256 | lm_input_with_template= lm_tokenizer.apply_chat_template([NEW_TEMPLATE], add_generation_prompt=True, tokenize=True, return_tensors='pt') 257 | lm_input_tensor = lm_input_with_template.to(args.device) 258 | max_length = 512 259 | lm_generated_text = None 260 | ctc_loss = 0 261 | with torch.no_grad(): 262 | with torch.cuda.amp.autocast(): 263 | lm_output = correction_lm.generate(lm_input_tensor, max_new_tokens = max_length, pad_token_id=lm_tokenizer.eos_token_id) 264 | lm_output = lm_tokenizer.decode(lm_output[0], skip_special_tokens=True) 265 | try: 266 | lm_generated_text = lm_output.split('<>: ')[-1].replace('.', '').upper() 267 | except: 268 | lm_generated_text = None 269 | if lm_generated_text: 270 | ctc_loss_reduction = "mean" 271 | ctc_zero_infinity = True 272 | 273 | log_probs = nn.functional.log_softmax(outputs, dim=-1, dtype=torch.float32).transpose(0,1) 274 | 275 | if isinstance(processor, Wav2Vec2Processor): 276 | processed_lm_output = processor(text=lm_generated_text, return_tensors='pt') 277 | processed_lm_output = to_device(processed_lm_output, args.device) 278 | 279 | ctc_loss = nn.functional.ctc_loss( 280 | log_probs, 281 | processed_lm_output['input_ids'], 282 | torch.LongTensor([outputs.shape[1]]), 283 | torch.LongTensor([processed_lm_output['input_ids'].shape[1]]), 284 | blank=blank_index, 285 | reduction=ctc_loss_reduction, 286 | zero_infinity=ctc_zero_infinity, 287 | ) 288 | elif isinstance(model, nemo_asr.models.EncDecCTCModelBPE): 289 | processed_lm_output = torch.tensor([model.tokenizer.text_to_ids(pred_text)]).to(args.device) 290 | ctc_loss = nn.functional.ctc_loss( 291 | log_probs, 292 | processed_lm_output, 293 | torch.LongTensor([outputs.shape[1]]), 294 | torch.LongTensor([processed_lm_output.shape[1]]), 295 | blank=blank_index, 296 | reduction=ctc_loss_reduction, 297 | zero_infinity=ctc_zero_infinity, 298 | ) 299 | elif isinstance(model, nemo_asr.models.EncDecRNNTBPEModel): 300 | processed_lm_output = torch.tensor([model.tokenizer.text_to_ids(pred_text)]).to(args.device) 301 | ctc_loss = nn.functional.ctc_loss( 302 | log_probs, 303 | processed_lm_output, 304 | torch.LongTensor([outputs.shape[1]]), 305 | torch.LongTensor([processed_lm_output.shape[1]]), 306 | blank=blank_index, 307 | reduction=ctc_loss_reduction, 308 | zero_infinity=ctc_zero_infinity, 309 | ) 310 | # processed_lm_output = processor(text=lm_generated_text, return_tensors='pt') 311 | # processed_lm_output = to_device(processed_lm_output, args.device) 312 | # log_probs = nn.functional.log_softmax(outputs, dim=-1, dtype=torch.float32).transpose(0,1) 313 | 314 | # ctc_loss_reduction = "mean" 315 | # ctc_zero_infinity = True 316 | # ctc_loss = nn.functional.ctc_loss( 317 | # log_probs, 318 | # processed_lm_output['input_ids'], 319 | # torch.LongTensor([outputs.shape[1]]), 320 | # torch.LongTensor([processed_lm_output['input_ids'].shape[1]]), 321 | # blank=blank_index, 322 | # reduction=ctc_loss_reduction, 323 | # zero_infinity=ctc_zero_infinity, 324 | # ) 325 | # ============================================================================================================= 326 | if args.not_blank: 327 | e_loss = renyi_entropy((outputs / args.temp)[non_blank], alpha=args.renyi_entropy_alpha) 328 | if non_blank.sum().item() == 0: 329 | e_loss = renyi_entropy(outputs / args.temp, alpha=args.renyi_entropy_alpha) 330 | else: 331 | e_loss = renyi_entropy(outputs / args.temp, alpha=args.renyi_entropy_alpha) 332 | 333 | if args.coef_adapting: 334 | ((args.tta_coef*e_loss + (e_loss/(e_loss + ctc_loss))*args.llm_coef*ctc_loss) / (len(wavs))).backward(retain_graph=True) 335 | # ((args.tta_coef*e_loss + ((args.steps-(step-1))/args.steps)*args.llm_coef*ctc_loss) / (len(wavs))).backward(retain_graph=True) 336 | # ((args.tta_coef*e_loss + ((args.steps-(step-1))/args.steps)*args.llm_coef*ctc_loss) / (len(wavs))).backward(retain_graph=True) 337 | else: 338 | ((args.tta_coef*e_loss + args.llm_coef*ctc_loss) / (len(wavs))).backward(retain_graph=True) 339 | if 'kld_ori' in args.method: 340 | assert 0 <= args.kld_weight <= 1 341 | # TODO: implement bias parameter 342 | 343 | # naive pseudo-labeling 344 | predicted_ids = torch.argmax(outputs, dim=-1) 345 | non_blank = torch.where(predicted_ids != blank_index, 1, 0).bool() 346 | # e_loss = renyi_entropy(outputs, alpha='inf') 347 | e_loss = renyi_entropy((outputs / args.temp)[non_blank], alpha='inf') 348 | ((1 - args.kld_weight) * e_loss / (len(wavs))).backward(retain_graph=True) 349 | 350 | # kld loss 351 | original_outputs, _ = get_logits_and_pseudo_labels(args, original_model, processor, wav, torch.FloatTensor([lens[i]]).to(wav.device)) 352 | probs = torch.softmax(outputs, dim=-1) 353 | original_probs = torch.softmax(original_outputs, dim=-1) 354 | kl_div_loss = F.kl_div(torch.log(probs), original_probs.detach(), reduction="batchmean") 355 | (args.kld_weight * kl_div_loss / (len(wavs))).backward(retain_graph=True) 356 | if 'kld_comb' in args.method: 357 | # Renyi em 358 | predicted_ids = torch.argmax(outputs, dim=-1) 359 | non_blank = torch.where(predicted_ids != blank_index, 1, 0).bool() 360 | if args.not_blank: 361 | e_loss = renyi_entropy((outputs / args.temp)[non_blank], alpha=args.renyi_entropy_alpha) 362 | else: 363 | e_loss = renyi_entropy(outputs / args.temp, alpha=args.renyi_entropy_alpha) 364 | ((1 - args.kld_weight) * e_loss / (len(wavs))).backward(retain_graph=True) 365 | 366 | # negative sampling 367 | criterion = nn.CrossEntropyLoss(ignore_index=blank_index) if args.not_blank else nn.CrossEntropyLoss() 368 | negative_outputs = outputs.clone() 369 | negative_loss = 0 370 | char_history = pseudo_labels[0].to(args.device) 371 | negative_mask = torch.where(torch.softmax(negative_outputs, dim=-1) < args.ns_threshold * (10 / negative_outputs.shape[-1]), 1, 0) 372 | negative_loss += torch.mean(-torch.log(1 - torch.sum(negative_mask * torch.softmax(negative_outputs / args.temp, dim=-1), dim=-1))) 373 | if torch.is_tensor(negative_loss): 374 | ((1 - args.kld_weight) * args.ns_coef * negative_loss / len(wavs)).backward(retain_graph=True) 375 | 376 | # kld loss 377 | original_outputs, _ = get_logits_and_pseudo_labels(args, original_model, processor, wav, torch.FloatTensor([lens[i]]).to(wav.device)) 378 | probs = torch.softmax(outputs, dim=-1) 379 | original_probs = torch.softmax(original_outputs, dim=-1) 380 | kl_div_loss = F.kl_div(torch.log(probs), original_probs.detach(), reduction="batchmean") 381 | (args.kld_weight * kl_div_loss / (len(wavs))).backward(retain_graph=True) 382 | 383 | optimizer.step() 384 | if scheduler is not None: 385 | scheduler.step() 386 | 387 | 388 | @hydra.main(version_base=None, config_path="conf", config_name="config") 389 | def main(args): 390 | if args.seed: 391 | set_seed(args.seed) 392 | 393 | if not os.path.exists(args.log_dir): 394 | os.makedirs(args.log_dir) 395 | global logger 396 | logger = get_logger(args) 397 | logger.info(OmegaConf.to_yaml(args)) 398 | 399 | dataset = load_dataset(args.dataset_name, args.dataset_dir, args.batch_size, args.extra_noise, args.noise_type, args.noise_snr) 400 | gt_texts, ori_transcriptions, transcriptions_1, transcriptions_3, transcriptions_5, transcriptions_10, transcriptions_20, transcriptions_40 = [], [], [], [], [], [], [], [] 401 | 402 | global original_model, lm_tokenizer, correction_lm 403 | 404 | model = get_model(args) 405 | original_model = get_model(args) 406 | params, _ = collect_params(model, args.train_params) 407 | optimizer, scheduler = get_optimizer(args, params, opt_name=args.optimizer, lr=args.lr, scheduler=args.scheduler) 408 | processor = Wav2Vec2Processor.from_pretrained(args.asr, sampling_rate=args.sample_rate, return_attention_mask=True) if isinstance(model, Wav2Vec2ForCTC) else None 409 | 410 | lm_tokenizer = AutoTokenizer.from_pretrained(args.lm_model, padding_side="right", use_fast=False,) 411 | correction_lm = AutoModelForCausalLM.from_pretrained(args.lm_model, torch_dtype=torch.float16, use_flash_attention_2=True).to(args.device) 412 | correction_lm.eval() 413 | 414 | 415 | 416 | if isinstance(model, Wav2Vec2ForCTC): 417 | decoder_processor = Wav2Vec2ProcessorWithLM.from_pretrained(args.processor) 418 | elif isinstance(model, EncoderDecoderASR): 419 | decoder_processor = None 420 | elif isinstance(model, nemo_asr.models.EncDecCTCModelBPE): 421 | decoder_processor = BeamSearchDecoderCTC( 422 | alphabet=Alphabet(labels=model.decoder.vocabulary+[""], is_bpe=True), 423 | language_model=LanguageModel.load_from_dir(args.processor), 424 | ) 425 | elif isinstance(model, nemo_asr.models.EncDecRNNTBPEModel): 426 | decoder_processor = BeamRNNTInfer( 427 | model.decoding.decoding.decoder.to(args.device), 428 | model.decoding.decoding.joint.to(args.device), 429 | beam_size=args.beam_width, 430 | return_best_hypothesis=False, 431 | ) 432 | 433 | episodic = args.episodic 434 | steps = args.steps 435 | 436 | if episodic: 437 | original_model_state, original_optimizer_state, original_scheduler_state = copy_model_and_optimizer(model, optimizer, scheduler) 438 | 439 | for batch_idx, batch in enumerate(dataset): 440 | if args.dataset_name == "commonvoice" and batch_idx >= 1000: 441 | break 442 | 443 | lens, wavs, texts, _ = batch 444 | if isinstance(model, Wav2Vec2ForCTC): 445 | wavs = processor(wavs, sampling_rate=args.sample_rate, return_tensors="pt", padding="longest").input_values.to(args.device) 446 | else: 447 | wavs = pad_sequence([torch.from_numpy(wav) for wav in wavs], batch_first=True).to(args.device) 448 | lens = lens.to(args.device) 449 | 450 | gt_texts.extend(texts) 451 | ori_transcription = transcribe_batch(args, original_model, processor, wavs, lens) 452 | ori_transcriptions.extend(ori_transcription) 453 | ori_wer = wer(list(texts), list(ori_transcription)) 454 | 455 | logger.info(f"{batch_idx}/{len(dataset)}") 456 | logger.info(f"gt text: {' '.join(list(texts))}") 457 | logger.info(f"original WER: {ori_wer}") 458 | logger.info(f"original text: {' '.join(list(ori_transcription))}") 459 | 460 | if episodic: 461 | model, optimizer, scheduler = load_model_and_optimizer(model, optimizer, scheduler, original_model_state, original_optimizer_state, original_scheduler_state) 462 | 463 | for step_idx in range(1, steps + 1): 464 | model = set_rnn_to_train(model) 465 | forward_and_adapt(args, model, decoder_processor, optimizer, scheduler, wavs, lens, step_idx) 466 | transcription = transcribe_batch(args, model, processor, wavs, lens) 467 | 468 | if step_idx in [1, 3, 5, 10, 20, 40]: 469 | transcription_list = eval(f"transcriptions_{step_idx}") 470 | transcription_list.extend(transcription) 471 | ada_wer = wer(list(texts), list(transcription)) 472 | logger.info(f"adapt-{step_idx} WER: {ada_wer}") 473 | logger.info(f"adapt-{step_idx} text: {' '.join(list(transcription))}") 474 | 475 | gc.collect() 476 | torch.cuda.empty_cache() 477 | logger.info("\n") 478 | 479 | logger.info(OmegaConf.to_yaml(args)) 480 | logger.info(f"number of data : {len(dataset)}") 481 | logger.info(f"original WER: {wer(gt_texts, ori_transcriptions)}") 482 | for step_idx in [1, 3, 5, 10, 20, 40]: 483 | if step_idx > steps: 484 | break 485 | transcription_list = eval(f"transcriptions_{step_idx}") 486 | logger.info(f"TTA-{step_idx}: {wer(gt_texts, transcription_list)}") 487 | 488 | 489 | 490 | if __name__ == '__main__': 491 | main() -------------------------------------------------------------------------------- /forward.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | from dataclasses import dataclass, field 4 | import heapq 5 | 6 | import numpy as np 7 | import torch 8 | from torch.nn.utils.rnn import pad_sequence 9 | 10 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM 11 | from speechbrain.pretrained import EncoderDecoderASR 12 | from speechbrain.decoders.ctc import CTCPrefixScorer 13 | import nemo.collections.asr as nemo_asr 14 | from nemo.collections.asr.parts.utils import rnnt_utils 15 | from nemo.collections.common.parts.rnn import label_collate 16 | from pyctcdecode.alphabet import BPE_TOKEN 17 | from pyctcdecode.constants import DEFAULT_HOTWORD_WEIGHT, DEFAULT_MIN_TOKEN_LOGP, DEFAULT_PRUNE_BEAMS, DEFAULT_PRUNE_LOGP, MIN_TOKEN_CLIP_P 18 | from pyctcdecode.language_model import HotwordScorer 19 | import kenlm 20 | 21 | from utils import log_softmax 22 | 23 | 24 | # for ctc-based models and conformers 25 | Frames = Tuple[int, int] 26 | WordFrames = Tuple[str, Frames] 27 | LMBeam = Tuple[str, str, str, Optional[str], List[Frames], Frames, float, float] 28 | LMState = Optional[Union["kenlm.State", List["kenlm.State"]]] 29 | OutputBeam = Tuple[str, LMState, List[WordFrames], float, float] 30 | OutputBeamMPSafe = Tuple[str, List[WordFrames], float, float] 31 | NULL_FRAMES: Frames = (-1, -1) # placeholder that gets replaced with positive integer frame indices 32 | 33 | 34 | 35 | @dataclass 36 | class Hypothesis: # for transducers 37 | score: float 38 | y_sequence: Union[List[int], torch.Tensor] 39 | text: Optional[str] = None 40 | dec_out: Optional[List[torch.Tensor]] = None 41 | dec_state: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor]]] = None 42 | timestep: Union[List[int], torch.Tensor] = field(default_factory=list) 43 | alignments: Optional[Union[List[int], List[List[int]]]] = None 44 | length: Union[int, torch.Tensor] = 0 45 | y: List[torch.tensor] = None 46 | lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None 47 | lm_scores: Optional[torch.Tensor] = None 48 | tokens: Optional[Union[List[int], torch.Tensor]] = None 49 | last_token: Optional[torch.Tensor] = None 50 | token_list: List = field(default_factory=list) 51 | 52 | 53 | @torch.no_grad() 54 | def transcribe_batch(args, model, processor, wavs, lens): 55 | transcription = [] 56 | if isinstance(model, Wav2Vec2ForCTC): 57 | for wav, len in zip(wavs, lens): 58 | wav = wav[:len].unsqueeze(0) 59 | outputs = model(wav).logits 60 | predicted_ids = torch.argmax(outputs, dim=-1) 61 | if isinstance(processor, Wav2Vec2Processor): # greedy decoding 62 | text = processor.batch_decode(predicted_ids)[0] 63 | elif isinstance(processor, Wav2Vec2ProcessorWithLM): # beam search decoding with external language model 64 | text = processor.decode( 65 | outputs.squeeze(0).detach().cpu().numpy(), 66 | beam_width=args.beam_width, 67 | ).text 68 | transcription.append(text) 69 | elif isinstance(model, EncoderDecoderASR): 70 | for wav, len in zip(wavs, lens): 71 | wav = wav[:len].unsqueeze(0) 72 | text = model.transcribe_batch(wav, wav_lens=torch.ones(1).to(args.device))[0] 73 | transcription.append(text[0]) 74 | elif isinstance(model, nemo_asr.models.EncDecCTCModelBPE): 75 | for wav, len in zip(wavs, lens): 76 | wav = wav[:len].unsqueeze(0) 77 | len = len.unsqueeze(0) 78 | processed_signal, processed_signal_length = model.preprocessor( 79 | input_signal=wav.to(args.device), length=len.to(args.device), 80 | ) 81 | encoder_output, encoder_length = model.encoder(audio_signal=processed_signal, length=processed_signal_length) 82 | log_probs = model.decoder(encoder_output=encoder_output) 83 | greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) 84 | 85 | # text = model._wer.ctc_decoder_predictions_tensor(greedy_predictions, predictions_len=encoder_length, return_hypotheses=False)[0].upper() # commented out by esyoon 2024-03-08-20:19:21 86 | text = model.wer.decode(greedy_predictions, predictions_lengths=encoder_length)[0][0].upper() # added by esyoon 2024-03-08-20:19:19 87 | transcription.append(text) 88 | elif isinstance(model, nemo_asr.models.EncDecRNNTBPEModel): 89 | for wav, len in zip(wavs, lens): 90 | wav = wav[:len].unsqueeze(0) 91 | len = len.unsqueeze(0) 92 | encoded_feature, encoded_len = model(input_signal=wav, input_signal_length=len) 93 | best_hyp_texts, _ = model.decoding.rnnt_decoder_predictions_tensor( 94 | encoder_output=encoded_feature, encoded_lengths=encoded_len, return_hypotheses=False 95 | ) 96 | text = [best_hyp_text.upper() for best_hyp_text in best_hyp_texts][0] 97 | transcription.append(text) 98 | return transcription 99 | 100 | 101 | def forward_batch(args, model, processor, wavs, lens, labels=None): 102 | if isinstance(model, Wav2Vec2ForCTC): 103 | outputs = forward_ctc_or_conformer(args, model, processor, wavs, lens, labels) 104 | elif isinstance(model, EncoderDecoderASR): 105 | model.mods.decoder.dec.train() 106 | # model.mods.decoder.lm_weight = args.lm_coef 107 | outputs = forward_attn(args, model, wavs, lens, labels) 108 | elif isinstance(model, nemo_asr.models.EncDecCTCModelBPE): 109 | outputs = forward_ctc_or_conformer(args, model, processor, wavs, lens, labels) 110 | elif isinstance(model, nemo_asr.models.EncDecRNNTBPEModel): 111 | outputs = forward_trans(args, model, wavs, lens, labels) 112 | return outputs 113 | 114 | 115 | def forward_ctc_or_conformer(args, model, processor, wavs, lens, labels): 116 | if isinstance(model, Wav2Vec2ForCTC): 117 | logits = model(wavs).logits 118 | elif isinstance(model, nemo_asr.models.EncDecCTCModelBPE): 119 | processed_signal, processed_signal_length = model.preprocessor( 120 | input_signal=wavs.to(args.device), length=lens.to(args.device), 121 | ) 122 | encoder_output, _ = model.encoder(audio_signal=processed_signal, length=processed_signal_length) 123 | logits = model.decoder(encoder_output=encoder_output) 124 | if labels == None or not args.lm_coef: 125 | return logits 126 | else: 127 | lm_logits = forward_ctc_or_conformer_with_labels( 128 | args, 129 | processor.decoder if isinstance(model, Wav2Vec2ForCTC) else processor, 130 | np.clip(log_softmax(logits.squeeze(0).detach().cpu().numpy(), axis=1), 131 | np.log(MIN_TOKEN_CLIP_P), 0), 132 | labels, 133 | hotword_scorer=HotwordScorer.build_scorer(None, weight=DEFAULT_HOTWORD_WEIGHT), 134 | lm_start_state=None, 135 | ).unsqueeze(0) 136 | return logits + args.lm_coef * lm_logits 137 | 138 | 139 | def forward_ctc_or_conformer_with_labels( 140 | args, 141 | model, 142 | logits, 143 | labels, 144 | hotword_scorer, 145 | lm_start_state, 146 | ): 147 | def _merge_tokens(token_1: str, token_2: str) -> str: 148 | """Fast, whitespace safe merging of tokens.""" 149 | if len(token_2) == 0: 150 | text = token_1 151 | elif len(token_1) == 0: 152 | text = token_2 153 | else: 154 | text = token_1 + " " + token_2 155 | return text 156 | 157 | def get_new_beams( 158 | model, 159 | beams, 160 | idx_list, 161 | frame_idx, 162 | logit_col, 163 | ): 164 | new_beams = [] 165 | # bpe we can also have trailing word boundaries ▁⁇▁ so we may need to remember breaks 166 | force_next_break = False 167 | for idx_char in idx_list: 168 | p_char = logit_col[idx_char] 169 | char = model._idx2vocab[idx_char] 170 | for ( 171 | text, 172 | next_word, 173 | word_part, 174 | last_char, 175 | text_frames, 176 | part_frames, 177 | logit_score, 178 | idx_history, 179 | lm_logits, 180 | ) in beams: 181 | if char == "" or last_char == char: 182 | if char == "": 183 | new_end_frame = part_frames[0] 184 | else: 185 | new_end_frame = frame_idx + 1 186 | new_part_frames = ( 187 | part_frames if char == "" else (part_frames[0], new_end_frame) 188 | ) 189 | new_beams.append( 190 | ( 191 | text, 192 | next_word, 193 | word_part, 194 | char, 195 | text_frames, 196 | new_part_frames, 197 | logit_score + p_char, 198 | idx_history + [idx_char], 199 | lm_logits, 200 | ) 201 | ) 202 | # if bpe and leading space char 203 | elif model._is_bpe and (char[:1] == BPE_TOKEN or force_next_break): 204 | force_next_break = False 205 | # some tokens are bounded on both sides like ▁⁇▁ 206 | clean_char = char 207 | if char[:1] == BPE_TOKEN: 208 | clean_char = clean_char[1:] 209 | if char[-1:] == BPE_TOKEN: 210 | clean_char = clean_char[:-1] 211 | force_next_break = True 212 | new_frame_list = ( 213 | text_frames if word_part == "" else text_frames + [part_frames] 214 | ) 215 | new_beams.append( 216 | ( 217 | text, 218 | word_part, 219 | clean_char, 220 | char, 221 | new_frame_list, 222 | (frame_idx, frame_idx + 1), 223 | logit_score + p_char, 224 | idx_history + [idx_char], 225 | lm_logits, 226 | ) 227 | ) 228 | # if not bpe and space char 229 | elif not model._is_bpe and char == " ": 230 | new_frame_list = ( 231 | text_frames if word_part == "" else text_frames + [part_frames] 232 | ) 233 | new_beams.append( 234 | ( 235 | text, 236 | word_part, 237 | "", 238 | char, 239 | new_frame_list, 240 | NULL_FRAMES, 241 | logit_score + p_char, 242 | idx_history + [idx_char], 243 | lm_logits, 244 | ) 245 | ) 246 | # general update of continuing token without space 247 | else: 248 | new_part_frames = ( 249 | (frame_idx, frame_idx + 1) 250 | if part_frames[0] < 0 251 | else (part_frames[0], frame_idx + 1) 252 | ) 253 | new_beams.append( 254 | ( 255 | text, 256 | next_word, 257 | word_part + char, 258 | char, 259 | text_frames, 260 | new_part_frames, 261 | logit_score + p_char, 262 | idx_history + [idx_char], 263 | lm_logits, 264 | ) 265 | ) 266 | return new_beams 267 | 268 | def get_lm_beams( 269 | model, 270 | beams, 271 | hotword_scorer, 272 | cached_lm_scores, 273 | cached_partial_token_scores, 274 | is_eos, 275 | ): 276 | lm_score_list = np.zeros(len(beams)) 277 | language_model = model._language_model 278 | new_beams = [] 279 | for text, next_word, word_part, last_char, frame_list, frames, logit_score, idx_history, lm_logits in beams: 280 | new_text = _merge_tokens(text, next_word) 281 | if new_text not in cached_lm_scores: 282 | _, prev_raw_lm_score, start_state = cached_lm_scores[text] 283 | score, end_state = language_model.score(start_state, next_word, is_last_word=is_eos) 284 | raw_lm_score = prev_raw_lm_score + score 285 | lm_hw_score = raw_lm_score + hotword_scorer.score(new_text) 286 | cached_lm_scores[new_text] = (lm_hw_score, raw_lm_score, end_state) 287 | lm_score, _, _ = cached_lm_scores[new_text] 288 | 289 | if len(word_part) > 0: 290 | if word_part not in cached_partial_token_scores: 291 | # if prefix available in hotword trie use that, otherwise default to char trie 292 | if word_part in hotword_scorer: 293 | cached_partial_token_scores[word_part] = hotword_scorer.score_partial_token( 294 | word_part 295 | ) 296 | else: 297 | cached_partial_token_scores[word_part] = language_model.score_partial_token( 298 | word_part 299 | ) 300 | lm_score += cached_partial_token_scores[word_part] 301 | 302 | new_beams.append( 303 | ( 304 | new_text, 305 | "", 306 | word_part, 307 | last_char, 308 | frame_list, 309 | frames, 310 | logit_score, 311 | logit_score + lm_score, 312 | idx_history, 313 | lm_logits, 314 | ) 315 | ) 316 | lm_score_list[model._vocab2idx[last_char]] = lm_score 317 | 318 | new_beams_with_lm_logits = [] 319 | for text, next_word, word_part, last_char, frame_list, frames, logit_score, combined_score, idx_history, lm_logits in new_beams: 320 | new_beams_with_lm_logits.append( 321 | ( 322 | text, 323 | next_word, 324 | word_part, 325 | last_char, 326 | frame_list, 327 | frames, 328 | logit_score, 329 | combined_score, 330 | idx_history, 331 | lm_logits + [lm_score_list], 332 | ) 333 | ) 334 | return new_beams_with_lm_logits 335 | 336 | language_model = model._language_model 337 | if lm_start_state is None and language_model is not None: 338 | cached_lm_scores: Dict[str, Tuple[float, float, LMState]] = { 339 | "": (0.0, 0.0, language_model.get_start_state()) 340 | } 341 | else: 342 | cached_lm_scores = {"": (0.0, 0.0, lm_start_state)} 343 | cached_p_lm_scores: Dict[str, float] = {} 344 | if not hasattr(model, '_vocab2idx'): 345 | model._vocab2idx = {vocab: idx for idx, vocab in model._idx2vocab.items()} 346 | beams = [("", "", "", None, [], NULL_FRAMES, 0.0, [], [])] # start with single beam to expand on 347 | 348 | for frame_idx, logit_col in enumerate(logits): 349 | idx_list = list(range(0, logit_col.shape[-1])) 350 | new_beams = get_new_beams( 351 | model, 352 | beams, 353 | idx_list, 354 | frame_idx, 355 | logit_col, 356 | ) 357 | scored_beams = get_lm_beams( 358 | model, 359 | new_beams, 360 | hotword_scorer, 361 | cached_lm_scores, 362 | cached_p_lm_scores, 363 | is_eos=False, 364 | ) 365 | beams = [scored_beams[labels[frame_idx]][:-3] + scored_beams[labels[frame_idx]][-2:]] 366 | return torch.tensor(np.array(beams[0][-1])).to(args.device) 367 | 368 | 369 | def forward_attn(args, model, wavs, lens, labels): 370 | def decoder_forward_step(model, inp_tokens, memory, enc_states, enc_lens): 371 | """Performs a step in the implemented beamsearcher.""" 372 | hs, c = memory 373 | e = model.emb(inp_tokens) 374 | dec_out, hs, c, w = model.dec.forward_step( 375 | e, hs, c, enc_states, enc_lens 376 | ) 377 | log_probs = model.softmax(model.fc(dec_out) / model.temperature) 378 | 379 | if model.dec.attn_type == "multiheadlocation": 380 | w = torch.mean(w, dim=1) 381 | return log_probs, (hs, c), w 382 | 383 | logits = [] 384 | enc_states = model.encode_batch(wavs, lens) 385 | enc_lens = torch.tensor([enc_states.shape[1]]).to(args.device) 386 | 387 | device = enc_states.device 388 | batch_size = enc_states.shape[0] 389 | memory = model.mods.decoder.reset_mem(batch_size, device=device) 390 | 391 | inp_tokens = (enc_states.new_zeros(batch_size).fill_(model.mods.decoder.bos_index).long()) 392 | max_decode_steps = int(enc_states.shape[1] * model.mods.decoder.max_decode_ratio) 393 | 394 | for decode_step in range(max_decode_steps): 395 | log_probs, memory, _ = decoder_forward_step( 396 | model.mods.decoder, inp_tokens, memory, enc_states, enc_lens 397 | ) 398 | logits.append(log_probs) 399 | # teacher-forcing using beam search 400 | if labels != None: 401 | inp_tokens = torch.tensor([labels[decode_step]]).to(log_probs.device) 402 | else: 403 | inp_tokens = log_probs.argmax(dim=-1) 404 | logits = torch.stack(logits, dim=1).to(args.device) 405 | return logits 406 | 407 | 408 | def forward_trans(args, model, wavs, lens, labels): 409 | logits = [] 410 | encoder_output, encoded_lengths = model(input_signal=wavs, input_signal_length=lens) 411 | encoder_output = encoder_output.transpose(1, 2) 412 | logitlen = encoded_lengths 413 | 414 | inseq = encoder_output # [B, T, D] 415 | x, out_len, device = inseq, logitlen, inseq.device 416 | batchsize = x.shape[0] 417 | hypotheses = [rnnt_utils.Hypothesis(score=0.0, y_sequence=[], timestep=[], dec_state=None) for _ in range(batchsize)] 418 | hidden = None 419 | 420 | if model.decoding.decoding.preserve_alignments: 421 | for hyp in hypotheses: 422 | hyp.alignments = [[]] 423 | 424 | last_label = torch.full([batchsize, 1], fill_value=model.decoding.decoding._blank_index, dtype=torch.long, device=device) 425 | blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device) 426 | 427 | max_out_len = out_len.max() 428 | for time_idx in range(max_out_len): 429 | f = x.narrow(dim=1, start=time_idx, length=1) # [B, 1, D] 430 | 431 | not_blank = True 432 | symbols_added = 0 433 | 434 | blank_mask.mul_(False) 435 | blank_mask = time_idx >= out_len 436 | 437 | while not_blank and (model.decoding.decoding.max_symbols is None or symbols_added < model.decoding.decoding.max_symbols): 438 | if time_idx == 0 and symbols_added == 0 and hidden is None: 439 | in_label = model.decoding.decoding._SOS 440 | else: 441 | in_label = last_label 442 | if isinstance(in_label, torch.Tensor) and in_label.dtype != torch.long: 443 | in_label = in_label.long() 444 | g, hidden_prime = model.decoding.decoding.decoder.predict(None, hidden, False, batchsize) 445 | else: 446 | if in_label == model.decoding.decoding._SOS: 447 | g, hidden_prime = model.decoding.decoding.decoder.predict(None, hidden, False, batchsize) 448 | else: 449 | in_label = label_collate([[in_label.cpu()]]) 450 | g, hidden_prime = model.decoding.decoding.decoder.predict(in_label, hidden, False, batchsize) 451 | 452 | logp = model.decoding.decoding.joint.joint(f, g) 453 | if not logp.is_cuda: 454 | logp = logp.log_softmax(dim=len(logp.shape) - 1) 455 | logp = logp[:, 0, 0, :] 456 | 457 | if logp.dtype != torch.float32: 458 | logp = logp.float() 459 | 460 | # teacher-forcing using beam search 461 | if labels != None: 462 | label_idx = len(logits) 463 | label = labels[label_idx] if label_idx < len(labels) else model.decoding.decoding._blank_index 464 | v, k = logp[:, label], torch.tensor([label for _ in range(logp.shape[0])]).to(logp.device) 465 | else: 466 | v, k = logp.max(1) 467 | del g 468 | 469 | logits.append(logp) 470 | 471 | k_is_blank = k == model.decoding.decoding._blank_index 472 | blank_mask.bitwise_or_(k_is_blank) 473 | del k_is_blank 474 | 475 | if model.decoding.decoding.preserve_alignments: 476 | logp_vals = logp.to('cpu') 477 | logp_ids = logp_vals.max(1)[1] 478 | for batch_idx in range(batchsize): 479 | if time_idx < out_len[batch_idx]: 480 | hypotheses[batch_idx].alignments[-2].append( 481 | (logp_vals[batch_idx], logp_ids[batch_idx]) 482 | ) 483 | del logp_vals 484 | 485 | if blank_mask.all(): 486 | not_blank = False 487 | if model.decoding.decoding.preserve_alignments: 488 | for batch_idx in range(batchsize): 489 | if len(hypotheses[batch_idx].alignments[-2]) > 0: 490 | hypotheses[batch_idx].alignments.append([]) # blank buffer for next timestep 491 | else: 492 | blank_indices = (blank_mask == 1).nonzero(as_tuple=False) 493 | if hidden is not None: 494 | hidden_prime = model.decoding.decoding.decoder.batch_copy_states(hidden_prime, hidden, blank_indices) 495 | elif len(blank_indices) > 0 and hidden is None: 496 | hidden_prime = model.decoding.decoding.decoder.batch_copy_states(hidden_prime, None, blank_indices, value=0.0) 497 | k[blank_indices] = last_label[blank_indices, 0] 498 | last_label = k.clone().view(-1, 1) 499 | hidden = hidden_prime 500 | for kidx, ki in enumerate(k): 501 | if blank_mask[kidx] == 0: 502 | hypotheses[kidx].y_sequence.append(ki) 503 | hypotheses[kidx].timestep.append(time_idx) 504 | hypotheses[kidx].score += float(v[kidx]) 505 | 506 | symbols_added += 1 507 | logits = torch.stack(logits, dim=1)[:, :max_out_len, :] 508 | return logits 509 | 510 | 511 | def get_logits_and_pseudo_labels(args, model, processor, wavs, lens): 512 | if args.decoding_method == "greedy_search" or args.beam_width == 1: # greedy search 513 | logits = forward_batch(args, model, processor, wavs, lens) 514 | pseudo_labels = [torch.argmax(logits, dim=-1).squeeze(0)] 515 | else: # beam search 516 | encoder_output, encoder_length = encode_batch(args, model, wavs, lens) 517 | if True in torch.isnan(encoder_output): 518 | import ipdb; ipdb.set_trace() 519 | pseudo_labels = decode_batch(args, model, processor, encoder_output, encoder_length) 520 | logits = forward_batch(args, model, processor, wavs, lens, labels=pseudo_labels[0]) 521 | 522 | return logits, pseudo_labels 523 | 524 | 525 | @torch.no_grad() 526 | def encode_batch(args, model, wavs, lens): 527 | if isinstance(model, Wav2Vec2ForCTC): 528 | logits = model(wavs).logits 529 | logitlen = torch.tensor([logits.shape[1]]).to(logits.device) 530 | outputs = logits, logitlen 531 | elif isinstance(model, EncoderDecoderASR): 532 | enc_states = model.encode_batch(wavs, lens) 533 | enc_lens = torch.tensor([enc_states.shape[1]]).to(args.device) 534 | outputs = enc_states, enc_lens 535 | elif isinstance(model, nemo_asr.models.EncDecCTCModelBPE): 536 | processed_signal, processed_signal_length = model.preprocessor( 537 | input_signal=wavs.to(args.device), length=lens.to(args.device), 538 | ) 539 | encoder_output, encoder_length = model.encoder( 540 | audio_signal=processed_signal, length=processed_signal_length 541 | ) 542 | outputs = encoder_output, encoder_length 543 | elif isinstance(model, nemo_asr.models.EncDecRNNTBPEModel): 544 | enc_states, enc_lens = model(input_signal=wavs, input_signal_length=lens) 545 | enc_states = enc_states.transpose(1, 2) 546 | outputs = enc_states, enc_lens 547 | return outputs 548 | 549 | 550 | @torch.no_grad() 551 | def decode_batch(args, model, processor, encoder_output, encoder_length): 552 | beam_width = args.beam_width if args.decoding_method == "beam_search" else 1 553 | if isinstance(model, Wav2Vec2ForCTC): 554 | if True in np.isnan(np.clip(log_softmax(encoder_output.squeeze(0).detach().cpu().numpy(), axis=1), np.log(MIN_TOKEN_CLIP_P), 0)): 555 | import ipdb; ipdb.set_trace() 556 | pseudo_labels = decode_ctc_or_conformer( 557 | processor.decoder, 558 | logits=np.clip(log_softmax(encoder_output.squeeze(0).detach().cpu().numpy(), axis=1), np.log(MIN_TOKEN_CLIP_P), 0), 559 | beam_width=beam_width, 560 | beam_prune_logp=DEFAULT_PRUNE_LOGP, 561 | token_min_logp=DEFAULT_MIN_TOKEN_LOGP, 562 | prune_history=DEFAULT_PRUNE_BEAMS, 563 | hotword_scorer=HotwordScorer.build_scorer(None, weight=DEFAULT_HOTWORD_WEIGHT), 564 | lm_start_state=None, 565 | ) 566 | elif isinstance(model, EncoderDecoderASR): 567 | model.mods.decoder.topk = beam_width 568 | pseudo_labels = decode_attn( 569 | model.mods.decoder, 570 | encoder_output, 571 | torch.ones(1).to(encoder_output.device), 572 | beam_width, 573 | ) 574 | elif isinstance(model, nemo_asr.models.EncDecCTCModelBPE): 575 | logits = model.decoder(encoder_output=encoder_output) 576 | 577 | pseudo_labels = decode_ctc_or_conformer( 578 | processor, 579 | logits=np.clip(log_softmax(logits.squeeze(0).detach().cpu().numpy(), axis=1), np.log(MIN_TOKEN_CLIP_P), 0), 580 | beam_width=beam_width, 581 | beam_prune_logp=DEFAULT_PRUNE_LOGP, 582 | token_min_logp=DEFAULT_MIN_TOKEN_LOGP, 583 | prune_history=DEFAULT_PRUNE_BEAMS, 584 | hotword_scorer=HotwordScorer.build_scorer(None, weight=DEFAULT_HOTWORD_WEIGHT), 585 | lm_start_state=None, 586 | ) 587 | elif isinstance(model, nemo_asr.models.EncDecRNNTBPEModel): 588 | processor.beam_size = beam_width 589 | pseudo_labels = decode_trans(processor, encoder_output, encoder_length) 590 | return pseudo_labels 591 | 592 | 593 | def decode_ctc_or_conformer( 594 | model, 595 | logits, 596 | beam_width, 597 | beam_prune_logp, 598 | token_min_logp, 599 | prune_history, 600 | hotword_scorer, 601 | lm_start_state, 602 | ): 603 | def _merge_beams(beams): 604 | """Merge beams with same prefix together.""" 605 | beam_dict = {} 606 | for text, next_word, word_part, last_char, text_frames, part_frames, logit_score, idx_history in beams: 607 | new_text = _merge_tokens(text, next_word) 608 | hash_idx = (new_text, word_part, last_char) 609 | if hash_idx not in beam_dict: 610 | beam_dict[hash_idx] = ( 611 | text, 612 | next_word, 613 | word_part, 614 | last_char, 615 | text_frames, 616 | part_frames, 617 | logit_score, 618 | idx_history 619 | ) 620 | else: 621 | beam_dict[hash_idx] = ( 622 | text, 623 | next_word, 624 | word_part, 625 | last_char, 626 | text_frames, 627 | part_frames, 628 | _sum_log_scores(beam_dict[hash_idx][-2], logit_score), 629 | idx_history 630 | ) 631 | return list(beam_dict.values()) 632 | 633 | def _sort_and_trim_beams(beams, beam_width: int): 634 | """Take top N beams by score.""" 635 | return heapq.nlargest(beam_width, beams, key=lambda x: x[-2]) 636 | 637 | def _merge_tokens(token_1: str, token_2: str) -> str: 638 | """Fast, whitespace safe merging of tokens.""" 639 | if len(token_2) == 0: 640 | text = token_1 641 | elif len(token_1) == 0: 642 | text = token_2 643 | else: 644 | text = token_1 + " " + token_2 645 | return text 646 | 647 | def _sum_log_scores(s1: float, s2: float) -> float: 648 | """Sum log odds in a numerically stable way.""" 649 | # this is slightly faster than using max 650 | if s1 >= s2: 651 | log_sum = s1 + math.log(1 + math.exp(s2 - s1)) 652 | else: 653 | log_sum = s2 + math.log(1 + math.exp(s1 - s2)) 654 | return log_sum 655 | 656 | def get_new_beams( 657 | model, 658 | beams, 659 | idx_list, 660 | frame_idx, 661 | logit_col, 662 | ): 663 | new_beams = [] 664 | # bpe we can also have trailing word boundaries ▁⁇▁ so we may need to remember breaks 665 | force_next_break = False 666 | 667 | for idx_char in idx_list: 668 | p_char = logit_col[idx_char] 669 | char = model._idx2vocab[idx_char] 670 | for ( 671 | text, 672 | next_word, 673 | word_part, 674 | last_char, 675 | text_frames, 676 | part_frames, 677 | logit_score, 678 | idx_history, 679 | ) in beams: 680 | if char == "" or last_char == char: 681 | if char == "": 682 | new_end_frame = part_frames[0] 683 | else: 684 | new_end_frame = frame_idx + 1 685 | new_part_frames = ( 686 | part_frames if char == "" else (part_frames[0], new_end_frame) 687 | ) 688 | new_beams.append( 689 | ( 690 | text, 691 | next_word, 692 | word_part, 693 | char, 694 | text_frames, 695 | new_part_frames, 696 | logit_score + p_char, 697 | idx_history + [idx_char], 698 | ) 699 | ) 700 | # if bpe and leading space char 701 | elif model._is_bpe and (char[:1] == BPE_TOKEN or force_next_break): 702 | force_next_break = False 703 | # some tokens are bounded on both sides like ▁⁇▁ 704 | clean_char = char 705 | if char[:1] == BPE_TOKEN: 706 | clean_char = clean_char[1:] 707 | if char[-1:] == BPE_TOKEN: 708 | clean_char = clean_char[:-1] 709 | force_next_break = True 710 | new_frame_list = ( 711 | text_frames if word_part == "" else text_frames + [part_frames] 712 | ) 713 | new_beams.append( 714 | ( 715 | text, 716 | word_part, 717 | clean_char, 718 | char, 719 | new_frame_list, 720 | (frame_idx, frame_idx + 1), 721 | logit_score + p_char, 722 | idx_history + [idx_char], 723 | ) 724 | ) 725 | # if not bpe and space char 726 | elif not model._is_bpe and char == " ": 727 | new_frame_list = ( 728 | text_frames if word_part == "" else text_frames + [part_frames] 729 | ) 730 | new_beams.append( 731 | ( 732 | text, 733 | word_part, 734 | "", 735 | char, 736 | new_frame_list, 737 | NULL_FRAMES, 738 | logit_score + p_char, 739 | idx_history + [idx_char], 740 | ) 741 | ) 742 | # general update of continuing token without space 743 | else: 744 | new_part_frames = ( 745 | (frame_idx, frame_idx + 1) 746 | if part_frames[0] < 0 747 | else (part_frames[0], frame_idx + 1) 748 | ) 749 | new_beams.append( 750 | ( 751 | text, 752 | next_word, 753 | word_part + char, 754 | char, 755 | text_frames, 756 | new_part_frames, 757 | logit_score + p_char, 758 | idx_history + [idx_char], 759 | ) 760 | ) 761 | new_beams = _merge_beams(new_beams) 762 | return new_beams 763 | 764 | def get_lm_beams( 765 | model, 766 | beams, 767 | hotword_scorer: HotwordScorer, 768 | cached_lm_scores: Dict[str, Tuple[float, float, LMState]], 769 | cached_partial_token_scores: Dict[str, float], 770 | is_eos: bool = False, 771 | ) -> List[LMBeam]: 772 | """Update score by averaging logit_score and lm_score.""" 773 | # get language model and see if exists 774 | language_model = model._language_model 775 | 776 | # if no language model available then return raw score + hotwords as lm score 777 | if language_model is None: 778 | new_beams = [] 779 | for text, next_word, word_part, last_char, frame_list, frames, logit_score, idx_history in beams: 780 | new_text = _merge_tokens(text, next_word) 781 | # note that usually this gets scaled with alpha 782 | lm_hw_score = ( 783 | logit_score 784 | + hotword_scorer.score(new_text) 785 | + hotword_scorer.score_partial_token(word_part) 786 | ) 787 | new_beams.append( 788 | ( 789 | new_text, 790 | "", 791 | word_part, 792 | last_char, 793 | frame_list, 794 | frames, 795 | logit_score, 796 | lm_hw_score, 797 | idx_history 798 | ) 799 | ) 800 | return new_beams 801 | 802 | new_beams = [] 803 | for text, next_word, word_part, last_char, frame_list, frames, logit_score, idx_history in beams: 804 | new_text = _merge_tokens(text, next_word) 805 | if new_text not in cached_lm_scores: 806 | _, prev_raw_lm_score, start_state = cached_lm_scores[text] 807 | score, end_state = language_model.score(start_state, next_word, is_last_word=is_eos) 808 | raw_lm_score = prev_raw_lm_score + score 809 | lm_hw_score = raw_lm_score + hotword_scorer.score(new_text) 810 | cached_lm_scores[new_text] = (lm_hw_score, raw_lm_score, end_state) 811 | lm_score, _, _ = cached_lm_scores[new_text] 812 | 813 | if len(word_part) > 0: 814 | if word_part not in cached_partial_token_scores: 815 | # if prefix available in hotword trie use that, otherwise default to char trie 816 | if word_part in hotword_scorer: 817 | cached_partial_token_scores[word_part] = hotword_scorer.score_partial_token( 818 | word_part 819 | ) 820 | else: 821 | cached_partial_token_scores[word_part] = language_model.score_partial_token( 822 | word_part 823 | ) 824 | lm_score += cached_partial_token_scores[word_part] 825 | 826 | new_beams.append( 827 | ( 828 | new_text, 829 | "", 830 | word_part, 831 | last_char, 832 | frame_list, 833 | frames, 834 | logit_score, 835 | logit_score + lm_score, 836 | idx_history, 837 | ) 838 | ) 839 | return new_beams 840 | 841 | language_model = model._language_model 842 | if lm_start_state is None and language_model is not None: 843 | cached_lm_scores: Dict[str, Tuple[float, float, LMState]] = { 844 | "": (0.0, 0.0, language_model.get_start_state()) 845 | } 846 | else: 847 | cached_lm_scores = {"": (0.0, 0.0, lm_start_state)} 848 | cached_p_lm_scores: Dict[str, float] = {} 849 | # start with single beam to expand on 850 | beams = [("", "", "", None, [], NULL_FRAMES, 0.0, [])] 851 | 852 | for frame_idx, logit_col in enumerate(logits): 853 | max_idx = logit_col.argmax() 854 | idx_list = set(np.where(logit_col >= token_min_logp)[0]) | {max_idx} 855 | new_beams = get_new_beams( 856 | model, 857 | beams, 858 | idx_list, 859 | frame_idx, 860 | logit_col, 861 | ) 862 | # lm scoring and beam pruning 863 | scored_beams = get_lm_beams( 864 | model, 865 | new_beams, 866 | hotword_scorer, 867 | cached_lm_scores, 868 | cached_p_lm_scores, 869 | ) 870 | 871 | # remove beam outliers 872 | try: 873 | max_score = max([b[-2] for b in scored_beams]) 874 | except: 875 | import ipdb; ipdb.set_trace() 876 | scored_beams = [b for b in scored_beams if b[-2] >= max_score + beam_prune_logp] 877 | trimmed_beams = _sort_and_trim_beams(scored_beams, beam_width) 878 | beams = [b[:-2] + (b[-1], ) for b in trimmed_beams] 879 | 880 | new_beams = [] 881 | for text, _, word_part, _, frame_list, frames, logit_score, idx_history in beams: 882 | new_token_times = frame_list if word_part == "" else frame_list + [frames] 883 | new_beams.append((text, word_part, "", None, new_token_times, (-1, -1), logit_score, idx_history)) 884 | new_beams = _merge_beams(new_beams) 885 | scored_beams = get_lm_beams( 886 | model, 887 | new_beams, 888 | hotword_scorer, 889 | cached_lm_scores, 890 | cached_p_lm_scores, 891 | is_eos=True, 892 | ) 893 | scored_beams = [b[:-2] + (b[-1], ) for b in scored_beams] 894 | scored_beams = _merge_beams(scored_beams) 895 | 896 | # remove beam outliers 897 | max_score = max([b[-2] for b in beams]) 898 | scored_beams = [b for b in beams if b[-2] >= max_score + beam_prune_logp] 899 | trimmed_beams = _sort_and_trim_beams(scored_beams, beam_width) 900 | 901 | # remove unnecessary information from beams 902 | output_beams = [ 903 | torch.tensor(idx_history) 904 | for _, _, _, _, _, _, _, idx_history in trimmed_beams 905 | ] 906 | return output_beams 907 | 908 | 909 | def decode_attn(model, enc_states, wav_len, beam_width): 910 | def inflate_tensor(tensor, times, dim): 911 | return torch.repeat_interleave(tensor, times, dim=dim) 912 | 913 | def mask_by_condition(tensor, cond, fill_value): 914 | tensor = torch.where( 915 | cond, tensor, torch.Tensor([fill_value]).to(tensor.device) 916 | ) 917 | return tensor 918 | 919 | def forward_step(model, inp_tokens, memory, enc_states, enc_lens): 920 | """Performs a step in the implemented beamsearcher.""" 921 | hs, c = memory 922 | e = model.emb(inp_tokens) 923 | dec_out, hs, c, w = model.dec.forward_step( 924 | e, hs, c, enc_states, enc_lens 925 | ) 926 | log_probs = model.softmax(model.fc(dec_out) / model.temperature) 927 | if model.dec.attn_type == "multiheadlocation": 928 | w = torch.mean(w, dim=1) 929 | return log_probs, (hs, c), w 930 | 931 | def lm_forward_step(model, inp_tokens, memory): 932 | """Applies a step to the LM during beamsearch.""" 933 | with torch.no_grad(): 934 | logits, hs = model.lm(inp_tokens, hx=memory) 935 | log_probs = model.log_softmax(logits / model.temperature_lm) 936 | return log_probs, hs 937 | 938 | def ctc_forward_step(model, x): 939 | """Applies a ctc step during bramsearch.""" 940 | logits = model.ctc_fc(x) 941 | log_probs = model.softmax(logits) 942 | return log_probs 943 | 944 | enc_lens = torch.round(enc_states.shape[1] * wav_len).int() 945 | device = enc_states.device 946 | batch_size = enc_states.shape[0] 947 | 948 | memory = model.reset_mem(batch_size * beam_width, device=device) 949 | 950 | if model.lm_weight > 0: 951 | lm_memory = model.reset_lm_mem(batch_size * beam_width, device) 952 | 953 | if model.ctc_weight > 0: 954 | # (batch_size * beam_size, L, vocab_size) 955 | ctc_outputs = ctc_forward_step(model, enc_states) 956 | ctc_scorer = CTCPrefixScorer( 957 | ctc_outputs, 958 | enc_lens, 959 | batch_size, 960 | beam_width, 961 | model.blank_index, 962 | model.eos_index, 963 | model.ctc_window_size, 964 | ) 965 | ctc_memory = None 966 | 967 | # Inflate the enc_states and enc_len by beam_size times 968 | enc_states = inflate_tensor(enc_states, times=beam_width, dim=0) 969 | enc_lens = inflate_tensor(enc_lens, times=beam_width, dim=0) 970 | 971 | # Using bos as the first input 972 | inp_tokens = ( 973 | torch.zeros(batch_size * beam_width, device=device) 974 | .fill_(model.bos_index) 975 | .long() 976 | ) 977 | 978 | # The first index of each sentence. 979 | model.beam_offset = ( 980 | torch.arange(batch_size, device=device) * beam_width 981 | ) 982 | 983 | # initialize sequence scores variables. 984 | sequence_scores = torch.empty( 985 | batch_size * beam_width, device=device 986 | ) 987 | sequence_scores.fill_(float("-inf")) 988 | 989 | # keep only the first to make sure no redundancy. 990 | sequence_scores.index_fill_(0, model.beam_offset, 0.0) 991 | 992 | # keep the hypothesis that reaches eos and their corresponding score and log_probs. 993 | hyps_and_scores = [[] for _ in range(batch_size)] 994 | 995 | # keep the sequences that still not reaches eos. 996 | alived_seq = torch.empty( 997 | batch_size * beam_width, 0, device=device 998 | ).long() 999 | 1000 | # Keep the log-probabilities of alived sequences. 1001 | alived_log_probs = torch.empty( 1002 | batch_size * beam_width, 0, device=device 1003 | ) 1004 | 1005 | min_decode_steps = int(enc_states.shape[1] * model.min_decode_ratio) 1006 | max_decode_steps = int(enc_states.shape[1] * model.max_decode_ratio) 1007 | 1008 | # Initialize the previous attention peak to zero 1009 | # This variable will be used when using_max_attn_shift=True 1010 | prev_attn_peak = torch.zeros(batch_size * beam_width, device=device) 1011 | 1012 | for t in range(max_decode_steps): 1013 | # terminate condition 1014 | if model._check_full_beams(hyps_and_scores, beam_width): 1015 | break 1016 | 1017 | log_probs, memory, attn = forward_step( 1018 | model, inp_tokens, memory, enc_states, enc_lens 1019 | ) 1020 | log_probs = model.att_weight * log_probs 1021 | 1022 | # Keep the original value 1023 | log_probs_clone = log_probs.clone().reshape(batch_size, -1) 1024 | vocab_size = log_probs.shape[-1] 1025 | 1026 | if model.using_max_attn_shift: 1027 | # Block the candidates that exceed the max shift 1028 | cond, attn_peak = model._check_attn_shift(attn, prev_attn_peak) 1029 | log_probs = mask_by_condition( 1030 | log_probs, cond, fill_value=model.minus_inf 1031 | ) 1032 | prev_attn_peak = attn_peak 1033 | 1034 | # Set eos to minus_inf when less than minimum steps. 1035 | if t < min_decode_steps: 1036 | log_probs[:, model.eos_index] = model.minus_inf 1037 | 1038 | # Set the eos prob to minus_inf when it doesn't exceed threshold. 1039 | if model.using_eos_threshold: 1040 | cond = model._check_eos_threshold(log_probs) 1041 | log_probs[:, model.eos_index] = mask_by_condition( 1042 | log_probs[:, model.eos_index], 1043 | cond, 1044 | fill_value=model.minus_inf, 1045 | ) 1046 | 1047 | # adding LM scores to log_prob if lm_weight > 0 1048 | if model.lm_weight > 0: 1049 | lm_log_probs, lm_memory = lm_forward_step( 1050 | model, inp_tokens, lm_memory, 1051 | ) 1052 | log_probs = log_probs + model.lm_weight * lm_log_probs 1053 | 1054 | # adding CTC scores to log_prob if ctc_weight > 0 1055 | if model.ctc_weight > 0: 1056 | g = alived_seq 1057 | # block blank token 1058 | log_probs[:, model.blank_index] = model.minus_inf 1059 | if model.ctc_weight != 1.0 and model.ctc_score_mode == "partial": 1060 | # pruning vocab for ctc_scorer 1061 | _, ctc_candidates = log_probs.topk( 1062 | beam_width * 2, dim=-1 1063 | ) 1064 | else: 1065 | ctc_candidates = None 1066 | 1067 | ctc_log_probs, ctc_memory = ctc_scorer.forward_step( 1068 | g, ctc_memory, ctc_candidates, attn 1069 | ) 1070 | log_probs = log_probs + model.ctc_weight * ctc_log_probs 1071 | 1072 | scores = sequence_scores.unsqueeze(1).expand(-1, vocab_size) 1073 | scores = scores + log_probs 1074 | 1075 | # length normalization 1076 | if model.length_normalization: 1077 | scores = scores / (t + 1) 1078 | 1079 | scores_timestep = scores.clone() 1080 | 1081 | # keep topk beams 1082 | scores, candidates = scores.view(batch_size, -1).topk( 1083 | beam_width, dim=-1 1084 | ) 1085 | 1086 | # The input for the next step, also the output of current step. 1087 | inp_tokens = (candidates % vocab_size).view( 1088 | batch_size * beam_width 1089 | ) 1090 | 1091 | scores = scores.view(batch_size * beam_width) 1092 | sequence_scores = scores 1093 | 1094 | # recover the length normalization 1095 | if model.length_normalization: 1096 | sequence_scores = sequence_scores * (t + 1) 1097 | 1098 | # The index of which beam the current top-K output came from in (t-1) timesteps. 1099 | predecessors = ( 1100 | torch.div(candidates, vocab_size, rounding_mode="floor") 1101 | + model.beam_offset.unsqueeze(1).expand_as(candidates) 1102 | ).view(batch_size * beam_width) 1103 | 1104 | # Permute the memory to synchoronize with the output. 1105 | memory = model.permute_mem(memory, index=predecessors) 1106 | if model.lm_weight > 0: 1107 | lm_memory = model.permute_lm_mem(lm_memory, index=predecessors) 1108 | 1109 | if model.ctc_weight > 0: 1110 | ctc_memory = ctc_scorer.permute_mem(ctc_memory, candidates) 1111 | 1112 | # If using_max_attn_shift, then the previous attn peak has to be permuted too. 1113 | if model.using_max_attn_shift: 1114 | prev_attn_peak = torch.index_select( 1115 | prev_attn_peak, dim=0, index=predecessors 1116 | ) 1117 | 1118 | # Add coverage penalty 1119 | if model.coverage_penalty > 0: 1120 | cur_attn = torch.index_select(attn, dim=0, index=predecessors) 1121 | 1122 | # coverage: cumulative attention probability vector 1123 | if t == 0: 1124 | # Init coverage 1125 | model.coverage = cur_attn 1126 | 1127 | # the attn of transformer is [batch_size*beam_size, current_step, source_len] 1128 | if len(cur_attn.size()) > 2: 1129 | model.converage = torch.sum(cur_attn, dim=1) 1130 | else: 1131 | # Update coverage 1132 | model.coverage = torch.index_select( 1133 | model.coverage, dim=0, index=predecessors 1134 | ) 1135 | model.coverage = model.coverage + cur_attn 1136 | 1137 | # Compute coverage penalty and add it to scores 1138 | penalty = torch.max( 1139 | model.coverage, model.coverage.clone().fill_(0.5) 1140 | ).sum(-1) 1141 | penalty = penalty - model.coverage.size(-1) * 0.5 1142 | penalty = penalty.view(batch_size * beam_width) 1143 | penalty = ( 1144 | penalty / (t + 1) if model.length_normalization else penalty 1145 | ) 1146 | scores = scores - penalty * model.coverage_penalty 1147 | 1148 | # Update alived_seq 1149 | alived_seq = torch.cat( 1150 | [ 1151 | torch.index_select(alived_seq, dim=0, index=predecessors), 1152 | inp_tokens.unsqueeze(1), 1153 | ], 1154 | dim=-1, 1155 | ) 1156 | 1157 | # Takes the log-probabilities 1158 | beam_log_probs = log_probs_clone[ 1159 | torch.arange(batch_size).unsqueeze(1), candidates 1160 | ].reshape(batch_size * beam_width) 1161 | alived_log_probs = torch.cat( 1162 | [ 1163 | torch.index_select( 1164 | alived_log_probs, dim=0, index=predecessors 1165 | ), 1166 | beam_log_probs.unsqueeze(1), 1167 | ], 1168 | dim=-1, 1169 | ) 1170 | 1171 | is_eos = model._update_hyp_and_scores( 1172 | inp_tokens, 1173 | alived_seq, 1174 | alived_log_probs, 1175 | hyps_and_scores, 1176 | scores, 1177 | timesteps=t, 1178 | ) 1179 | 1180 | # Block the paths that have reached eos. 1181 | sequence_scores.masked_fill_(is_eos, float("-inf")) 1182 | 1183 | if not model._check_full_beams(hyps_and_scores, beam_width): 1184 | # Using all eos to fill-up the hyps. 1185 | eos = ( 1186 | torch.zeros(batch_size * beam_width, device=device) 1187 | .fill_(model.eos_index) 1188 | .long() 1189 | ) 1190 | _ = model._update_hyp_and_scores( 1191 | eos, 1192 | alived_seq, 1193 | alived_log_probs, 1194 | hyps_and_scores, 1195 | scores, 1196 | timesteps=max_decode_steps, 1197 | ) 1198 | 1199 | topk_hyps, _, _, _, = model._get_top_score_prediction(hyps_and_scores, topk=beam_width) 1200 | pseudo_labels = list(torch.unbind(topk_hyps.squeeze(0), dim=0)) 1201 | aux_label = [torch.tensor([model.blank_index for _ in range(max_decode_steps)])] 1202 | pseudo_labels = pad_sequence(pseudo_labels + aux_label, batch_first=True, padding_value=model.blank_index)[:beam_width, :max_decode_steps] 1203 | return pseudo_labels 1204 | 1205 | 1206 | def decode_trans(model, h, encoded_lengths): 1207 | # Initialize states 1208 | beam = min(model.beam_size, model.vocab_size) 1209 | beam_k = min(beam, (model.vocab_size - 1)) 1210 | 1211 | blank_tensor = torch.tensor([model.blank], device=h.device, dtype=torch.long) 1212 | 1213 | # Precompute some constants for blank position 1214 | ids = list(range(model.vocab_size + 1)) 1215 | ids.remove(model.blank) 1216 | 1217 | # Used when blank token is first vs last token 1218 | if model.blank == 0: 1219 | index_incr = 1 1220 | else: 1221 | index_incr = 0 1222 | 1223 | # Initialize zero vector states 1224 | dec_state = model.decoder.initialize_state(h) 1225 | 1226 | # Initialize first hypothesis for the beam (blank) 1227 | kept_hyps = [Hypothesis(score=0.0, y_sequence=[model.blank], dec_state=dec_state, timestep=[-1], length=0, token_list=[])] 1228 | cache = {} 1229 | 1230 | for i in range(int(encoded_lengths)): 1231 | hi = h[:, i : i + 1, :] # [1, 1, D] 1232 | hyps = kept_hyps 1233 | kept_hyps = [] 1234 | 1235 | while True: 1236 | max_hyp = max(hyps, key=lambda x: x.score) 1237 | hyps.remove(max_hyp) 1238 | 1239 | # update decoder state and get next score 1240 | y, state, lm_state = model.decoder.score_hypothesis(max_hyp, cache) # [1, 1, D] 1241 | 1242 | # get next token 1243 | logit = model.joint.joint(hi, y) / model.softmax_temperature 1244 | ytu = torch.log_softmax(logit, dim=-1) # [1, 1, 1, V + 1] 1245 | ytu = ytu[0, 0, 0, :] # [V + 1] 1246 | 1247 | # remove blank token before top k 1248 | top_k = ytu[ids].topk(beam_k, dim=-1) 1249 | 1250 | # Two possible steps - blank token or non-blank token predicted 1251 | ytu = ( 1252 | torch.cat((top_k[0], ytu[model.blank].unsqueeze(0))), 1253 | torch.cat((top_k[1] + index_incr, blank_tensor)), 1254 | ) 1255 | 1256 | # for each possible step 1257 | for logp, k in zip(*ytu): 1258 | # construct hypothesis for step 1259 | new_hyp = Hypothesis( 1260 | score=(max_hyp.score + float(logp)), 1261 | y_sequence=max_hyp.y_sequence[:], 1262 | dec_state=max_hyp.dec_state, 1263 | lm_state=max_hyp.lm_state, 1264 | timestep=max_hyp.timestep[:], 1265 | length=encoded_lengths, 1266 | token_list=max_hyp.token_list+[k], 1267 | ) 1268 | 1269 | # if current token is blank, dont update sequence, just store the current hypothesis 1270 | if k == model.blank: 1271 | kept_hyps.append(new_hyp) 1272 | else: 1273 | # if non-blank token was predicted, update state and sequence and then search more hypothesis 1274 | new_hyp.dec_state = state 1275 | new_hyp.y_sequence.append(int(k)) 1276 | new_hyp.timestep.append(i) 1277 | hyps.append(new_hyp) 1278 | 1279 | # keep those hypothesis that have scores greater than next search generation 1280 | hyps_max = float(max(hyps, key=lambda x: x.score).score) 1281 | kept_most_prob = sorted([hyp for hyp in kept_hyps if hyp.score > hyps_max], key=lambda x: x.score,) 1282 | 1283 | # If enough hypothesis have scores greater than next search generation, stop beam search. 1284 | if len(kept_most_prob) >= beam: 1285 | kept_hyps = kept_most_prob 1286 | break 1287 | 1288 | pseudo_labels = [torch.tensor(hyp.token_list) for hyp in model.sort_nbest(kept_hyps)] 1289 | aux_label = [torch.tensor([model.blank for _ in range(int(encoded_lengths))])] 1290 | pseudo_labels = pad_sequence(pseudo_labels + aux_label, batch_first=True, padding_value=model.blank)[:model.beam_size] 1291 | return [pseudo_label for pseudo_label in pseudo_labels] 1292 | 1293 | --------------------------------------------------------------------------------