├── README.md
├── finetune.py
├── finetune.sh
├── inference.py
├── inference.sh
├── requirements.txt
└── star.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Self-Taught Recognizer: Toward Unsupervised Adaptation for Speech Foundation Models

[[Paper]](https://arxiv.org/pdf/2405.14161)

![STAR](star.png)

This work proposes a source-free unsupervised domain adaptation approach for speech foundation models.

## Conda Environment Configuration

Our conda environment is provided in `requirements.txt`; please run the command below to install the necessary packages:
```bash
pip install -r requirements.txt
```

## Data Preparation

Our code requires two Kaldi-format data files: `wav.scp` and `text`.

- `wav.scp` contains the list of audio files; each line holds a sample ID and the absolute audio path:

```
utt_1 /your-data-path/1.wav
utt_2 /your-data-path/2.wav
```

- `text` contains the ground-truth transcriptions; each line holds a sample ID and the transcription:

```
utt_1 i feel good
utt_2 he is coming back
```

**NOTE:** the two files must be paired, i.e. line *k* of `wav.scp` and line *k* of `text` must refer to the same sample.


## Training
Please refer to our training script `finetune.sh` and specify the following settings:
- `dataset`: training data name;
- `model_size`: Whisper model size;
- `train_data`: training data directory that contains the files `wav.scp` and `text`;
- `dev_data`: development data directory that contains the files `wav.scp` and `text`.

Then run `bash finetune.sh` to start training. The model weights will be saved under `runs/{dataset}_{model_size}`.


## Inference
Please refer to our inference script `inference.sh` and specify the following settings:
- `dataset`: training data name;
- `model_size`: Whisper model size;
- `checkpoint`: path of the trained model checkpoint (`.pth` file);
- `test_data`: test data directory that contains the files `wav.scp` and `text`.

Run `bash inference.sh` for inference; the WER results will be printed in the log.
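
For illustration, a filled-in `inference.sh` might look like the sketch below (run here directly, without the `$cmd` job-launcher wrapper used in the provided script). The checkpoint and data paths are placeholders to replace with your own; note that the data directory needs a trailing `/`, since the scripts simply concatenate it with `wav.scp` and `text`:
```bash
dataset=chime4
model_size=large-v3
checkpoint=runs/${dataset}_${model_size}/best_checkpoint.pth   # saved by finetune.sh
test_data=/your-data-path/test/                                # contains wav.scp and text

python inference.py \
    --MODEL "openai/whisper-${model_size}" \
    --DATASET ${dataset} \
    --CKPT ${checkpoint} \
    --TEST_DATA ${test_data}
```
`finetune.sh` is filled in the same way, with `train_data` and `dev_data` pointing to the corresponding Kaldi-format directories.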

## References

Please kindly cite our paper in your publications when using our research or code:
```bib
@article{hu2024self,
  title={Self-Taught Recognizer: Toward Unsupervised Adaptation for Speech Foundation Models},
  author={Hu, Yuchen and Chen, Chen and Yang, Chao-Han Huck and Qin, Chengwei and Chen, Pin-Yu and Chng, Eng Siong and Zhang, Chao},
  journal={arXiv preprint arXiv:2405.14161},
  year={2024}
}
```

--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
from transformers import AutoFeatureExtractor, WhisperModel
from transformers import LlamaTokenizer
from datasets import load_dataset
import torch, torchaudio
from torch import nn
import numpy as np
from jiwer import wer as calculate_wer
import pickle
import fire
from datasets import Dataset, Audio, Value
import os, random
from typing import Optional
from whisper.normalizers import EnglishTextNormalizer
import math
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
from pathlib import Path
import whisper
import copy, heapq

normalizer = EnglishTextNormalizer()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def train(
    MODEL = "openai/whisper-large-v3",
    DATASET = "chime4",
    TRAIN_DATA = "",
    DEV_DATA = "",
    SAVE_EVERY = 10,
    BATCH_SIZE = 32,
    GRADIENT_ACCUMULATION_STEPS = 4,
    LEARNING_RATE = 1e-3,
    EPOCHS = 100,
    THRESHOLD = 2.0,
    TOP_PERCENT = 0.8,
    TAU = 10,
):
    feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL)
    processor = WhisperProcessor.from_pretrained(MODEL, language="en", task="transcribe")
    tokenizer = WhisperTokenizer.from_pretrained(MODEL, language="en", task="transcribe")
    model = WhisperForConditionalGeneration.from_pretrained(MODEL).to(device)
    forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
    state_dict = copy.deepcopy(model.state_dict())

    # decoder prompt '<|startoftranscript|><|en|><|transcribe|><|notimestamps|>' plus the EOS token
    prompt_and_eos = tokenizer('')['input_ids']
    prompt_ids, eos_id = prompt_and_eos[:-1], prompt_and_eos[-1]
    n_prompt_toks = len(prompt_ids)  # = 4; also used for loss masking in the training loop below

    def data_preparation(data_path, feature_extractor, tokenizer):
        with open(data_path + "wav.scp", 'r') as f1:
            wave_data = f1.readlines()
        with open(data_path + "text", 'r') as f2:
            trans_data = f2.readlines()

        audio_data, txt_data = [], []
        for i in range(len(wave_data)):
            audio_data.append(wave_data[i])
            txt_data.append(trans_data[i])

        audio_dataset = []
        all_pred, all_gt = [], []
        for audio_line, text_line in zip(audio_data, txt_data):
            audio_path = audio_line.strip().split()[1]
            text = ' '.join(text_line.split()[1:]).lower().strip()
            audio, sr = torchaudio.load(audio_path)
            if sr != 16000:
                audio = torchaudio.functional.resample(audio, sr, 16000)
            item = {'audio': audio, 'text': text}

            item['mel'] = feature_extractor(audio.squeeze(0).numpy(), sampling_rate=16_000, return_tensors="pt")['input_features']
            item['decoder_input_ids'] = tokenizer(text, max_length=1024, truncation=True).input_ids

            # restore the original weights (they may have been perturbed by the
            # utt-level uncertainty estimation below for a previous sample)
            model.load_state_dict(state_dict)
            hidden_feature = model.model.encoder(input_features=item['mel'].to(device)).last_hidden_state

            # prompt: '<|startoftranscript|><|en|><|transcribe|><|notimestamps|>'
            pseudo_label_ids = torch.tensor([prompt_ids]).long().to(device)

            ### probs: confidence score
            probs, decoder_outputs = [], None
            for _ in range(150):
                decoder_outputs = model(encoder_outputs=(hidden_feature,), decoder_input_ids=pseudo_label_ids, output_attentions=True)
                logits = torch.softmax(decoder_outputs.logits / 1.2, dim=-1)
                next_token = logits[0, -1, :].topk(1)[1]
                probs.append(float(logits[0, -1, next_token]))
                pseudo_label_ids = torch.cat((pseudo_label_ids, next_token.unsqueeze(0)), dim=-1)
                if next_token == eos_id:  # EOS
                    break

            # normalization
            mean_probs = sum(probs) / len(probs)
            for k in range(len(probs)):
                probs[k] = round(probs[k] / mean_probs, 3)

            ### weights: attentive score
            layer_id, head_id = 30, 13  # suggest: layer_id in [30, 31], head_id in [0, 1, ..., 19]
            attn = decoder_outputs.decoder_attentions[layer_id][0, head_id, :, :]
            attn[:, :n_prompt_toks-1] = 0  # remove prompts
            weights = []
            for i in range(n_prompt_toks-1, attn.shape[-1]):
                weight = torch.sum(attn[i, :]) + torch.sum(attn[:, i]) - attn[i, i]
                weights.append(float(weight))

            # normalization
            mean_weights = sum(weights) / len(weights)
            for j in range(len(weights)):
                weights[j] = round(weights[j] / mean_weights, 3)

            ### final_weights: STAR score
            final_weights = []
            for ci, ai in zip(probs, weights):
                c_over_a, a_over_c = ci * ci / ai, ai * ai / ci
                conflict = (sigmoid((c_over_a - THRESHOLD) * TAU) + sigmoid((a_over_c - THRESHOLD) * TAU)) * ai
                no_conflict = (sigmoid((THRESHOLD - c_over_a) * TAU) * sigmoid((THRESHOLD - a_over_c) * TAU)) * ai * np.exp((ci - ai) / TAU)
                final_weights.append(conflict + no_conflict)

            item['pseudo_label_ids'] = pseudo_label_ids
            item['probs'] = torch.tensor(final_weights).unsqueeze(0)
            pseudo_text = processor.batch_decode(pseudo_label_ids, skip_special_tokens=True)[0]

            ### utt-level uncertainty
            if 'train' in data_path:
                avg_wer, generated_texts = 0, []
                for _ in range(5):
                    # perturb every weight with Gaussian noise scaled by its standard deviation
                    new_state_dict = copy.deepcopy(state_dict)
                    for k in new_state_dict.keys():
                        std = torch.std(new_state_dict[k])
                        noise = torch.randn_like(new_state_dict[k])
                        new_state_dict[k] = new_state_dict[k] + noise * std * 0.1

                    model.load_state_dict(new_state_dict)
                    generated_ids = model.generate(inputs=item['mel'].to(device), forced_decoder_ids=forced_decoder_ids, max_new_tokens=150)
                    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                    generated_texts.append(generated_text)
                    avg_wer += calculate_wer([pseudo_text], [generated_text]) / 5

                item['avg_wer'] = avg_wer
                item['diversity'] = len(list(set(generated_texts)))

            ## text normalization
            pseudo_text = normalizer(pseudo_text)
            pseudo_text = pseudo_text if len(pseudo_text) > 0 else ''

            gt = normalizer(text)
            gt = gt if len(gt) > 0 else ''

            audio_dataset.append(item)
            all_pred.append(pseudo_text)
            all_gt.append(gt)

        model.load_state_dict(state_dict)
        return audio_dataset, calculate_wer(all_gt, all_pred)


    def evaluate(model, dataset):
        with torch.no_grad():
            all_pred, all_gt = [], []
            for item in dataset:
                mel = item['mel']
                generated_ids = model.generate(inputs=mel.to(device), forced_decoder_ids=forced_decoder_ids, max_new_tokens=150)
                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

                ## text normalization
                pred = normalizer(generated_text)
                pred = pred if len(pred) > 0 else ''

                gt = normalizer(item['text'])
                gt = gt if len(gt) > 0 else ''

                all_pred.append(pred)
                all_gt.append(gt)

            return calculate_wer(all_gt, all_pred)


    model.eval()
    train_dataset, train_wer = data_preparation(TRAIN_DATA, feature_extractor, tokenizer)
    dev_dataset, dev_wer = data_preparation(DEV_DATA, feature_extractor, tokenizer)
    os.system('mkdir -p data')
    torch.save(train_dataset, f'data/train_{DATASET}.pt')
    torch.save(dev_dataset, f'data/dev_{DATASET}.pt')
    model.train()

    ## load saved data
    # train_dataset = torch.load(f'data/train_{DATASET}.pt')
    # dev_dataset = torch.load(f'data/dev_{DATASET}.pt')

    ## utt-level filtering: keep the TOP_PERCENT samples whose pseudo labels are most stable under weight perturbation
    def product(item):
        return item['avg_wer'] * item['diversity']
    filtered_train_dataset = heapq.nsmallest(int(len(train_dataset) * TOP_PERCENT), train_dataset, key=product)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')

    model_size = MODEL.replace('openai/whisper-', '')
    exp_dir = f'runs/{DATASET}_{model_size}'
    os.system(f"mkdir -p {exp_dir}")

    steps, loss = 0, 0
    best_loss, best_wer = 10000, 10000
    for Epoch in range(EPOCHS):
        print("Epoch: ", Epoch + 1)

        # Train
        random.shuffle(filtered_train_dataset)
        print('Training...')
        for i in range(len(filtered_train_dataset) // BATCH_SIZE):
            batch_data = filtered_train_dataset[i * BATCH_SIZE: (i+1) * BATCH_SIZE]

            input_features = [{"input_features": item["mel"]} for item in batch_data]
            mel = processor.feature_extractor.pad(input_features, return_tensors="pt")["input_features"].squeeze(1).to(device)

            # pseudo labels and STAR weights are taken from the first item;
            # the provided recipe (finetune.sh) uses BATCH_SIZE=1 with gradient accumulation
            labels = batch_data[0]["pseudo_label_ids"].to(device)
            y_in = labels[:, :-1]
            y_out = labels[:, 1:]

            logits = model(input_features=mel, decoder_input_ids=y_in).logits
            loss_items = loss_fn(logits.permute(0, 2, 1), y_out)

            # uncertainty calibration: re-weight the token-level losses by the STAR scores
            ratios = batch_data[0]['probs'].to(device)
            ratios = ratios / torch.mean(ratios)
            loss = (torch.sum(loss_items[:, :n_prompt_toks-1]) + torch.sum(loss_items[:, n_prompt_toks-1:] * ratios)) / (n_prompt_toks-1 + ratios.shape[-1])

            loss.backward()
            steps += 1

            if steps % GRADIENT_ACCUMULATION_STEPS == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            if steps % SAVE_EVERY == 0:  # Evaluate
                torch.save(model, f"{exp_dir}/Iter_{steps}.pth")

                model.eval()
                dev_wer = evaluate(model, dev_dataset)
                model.train()

                if dev_wer < best_wer or (dev_wer == best_wer and loss < best_loss):
                    torch.save(model, f"{exp_dir}/best_checkpoint.pth")
                    best_loss, best_wer = loss, dev_wer

        torch.save(model, f"{exp_dir}/last_checkpoint.pth")


if __name__ == "__main__":
    fire.Fire(train)

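The token-level weighting inside `data_preparation` can also be read in isolation. Below is a minimal, self-contained sketch of that STAR score (the helper name `star_score` is ours, for illustration only), using the default `THRESHOLD = 2.0` and `TAU = 10` from `train()`: when the confidence score `c` and the attentive score `a` roughly agree, the weight is approximately `a` scaled by `exp((c - a) / tau)`; when they conflict strongly, the weight falls back to the attentive score `a`.

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def star_score(c, a, threshold=2.0, tau=10):
    """Combine a token's confidence score c and attentive score a into one STAR weight."""
    c_over_a, a_over_c = c * c / a, a * a / c
    # strong disagreement in either direction -> trust the attentive score alone
    conflict = (sigmoid((c_over_a - threshold) * tau) + sigmoid((a_over_c - threshold) * tau)) * a
    # agreement -> keep the attentive score, gently scaled toward the confidence score
    no_conflict = sigmoid((threshold - c_over_a) * tau) * sigmoid((threshold - a_over_c) * tau) * a * np.exp((c - a) / tau)
    return conflict + no_conflict

print(round(star_score(1.0, 1.0), 3))  # ~1.0: scores agree, weight stays near a
print(round(star_score(1.6, 0.5), 3))  # ~0.5: scores conflict, weight falls back to a
```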
--------------------------------------------------------------------------------
/finetune.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

source activate

dataset=chime4
model_size=large-v3
train_data=
dev_data=

# $cmd is assumed to be a job-scheduling wrapper (e.g. Kaldi's run.pl) that takes the log path as its first argument
$cmd log/finetune_${dataset}_${model_size}.log \
python finetune.py \
    --MODEL "openai/whisper-${model_size}" \
    --DATASET ${dataset} \
    --TRAIN_DATA ${train_data} \
    --DEV_DATA ${dev_data} \
    --BATCH_SIZE 1 \
    --GRADIENT_ACCUMULATION_STEPS 16 \
    --LEARNING_RATE 1e-5 \
    --EPOCHS 2

--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
from transformers import AutoFeatureExtractor, WhisperModel
from transformers import LlamaTokenizer
from datasets import load_dataset
import torch, torchaudio
from torch import nn
import numpy as np
from jiwer import wer as calculate_wer
import pickle
import fire
from datasets import Dataset, Audio, Value
import os, random
from typing import Optional
from whisper.normalizers import EnglishTextNormalizer
import math
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
from pathlib import Path
import whisper
import copy, heapq

normalizer = EnglishTextNormalizer()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def infer(
    MODEL = "openai/whisper-large-v3",
    DATASET = "chime4",
    TEST_DATA = "",
    CKPT = "",
):
    feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL)
    processor = WhisperProcessor.from_pretrained(MODEL, language="en", task="transcribe")
    tokenizer = WhisperTokenizer.from_pretrained(MODEL, language="en", task="transcribe")
    forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")


    def data_preparation(data_path, feature_extractor, tokenizer):
        with open(data_path + "wav.scp", 'r') as f1:
            wave_data = f1.readlines()
        with open(data_path + "text", 'r') as f2:
            trans_data = f2.readlines()

        audio_data, txt_data = [], []
        for i in range(len(wave_data)):
            audio_data.append(wave_data[i])
            txt_data.append(trans_data[i])

        audio_dataset = []
        for audio_line, text_line in zip(audio_data, txt_data):
            audio_path = audio_line.strip().split()[1]
            audio, sr = torchaudio.load(audio_path)
            if sr != 16000:
                audio = torchaudio.functional.resample(audio, sr, 16000)
            mel = feature_extractor(audio.squeeze(0).numpy(), sampling_rate=16_000, return_tensors="pt")['input_features']
            text = ' '.join(text_line.split()[1:]).lower().strip()

            item = {'mel': mel, 'text': text}
            audio_dataset.append(item)

        return audio_dataset


    def evaluate(model, dataset):
        with torch.no_grad():
            all_pred, all_gt = [], []
            for item in dataset:
                mel = item['mel']
                generated_ids = model.generate(inputs=mel.to(device), forced_decoder_ids=forced_decoder_ids, max_new_tokens=150)
                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

                ## text normalization
                pred = normalizer(generated_text)
                pred = pred if len(pred) > 0 else ''

                gt = normalizer(item['text'])
                gt = gt if len(gt) > 0 else ''

                all_pred.append(pred)
                all_gt.append(gt)

            return calculate_wer(all_gt, all_pred)


    ## prepare dataset
    test_dataset = data_preparation(TEST_DATA, feature_extractor, tokenizer)
    os.system('mkdir -p data')  # make sure the cache directory exists
    torch.save(test_dataset, f'data/test_{DATASET}.pt')
    print(f'{DATASET}:')
    # test_dataset = torch.load(f'data/test_{DATASET}.pt')

    ## evaluate the official Whisper model (zero-shot baseline; only needs to run once)
    model = WhisperForConditionalGeneration.from_pretrained(MODEL).to(device)
    model.eval()
    print(f'zero-shot = {evaluate(model, test_dataset)}')

    ## evaluate the STAR-adapted Whisper checkpoint
    model = torch.load(CKPT).to(device)
    model.eval()
    print(f'star = {evaluate(model, test_dataset)}')


if __name__ == "__main__":
    fire.Fire(infer)

--------------------------------------------------------------------------------
/inference.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

source activate

dataset=chime4
model_size=large-v3
checkpoint=
test_data=

# $cmd: job-scheduling wrapper (e.g. Kaldi's run.pl), as in finetune.sh
$cmd log/inference_${dataset}_${model_size}.log \
python inference.py \
    --MODEL "openai/whisper-${model_size}" \
    --DATASET ${dataset} \
    --CKPT ${checkpoint} \
    --TEST_DATA ${test_data}

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==1.4.0
accelerate==0.25.0
aiofiles==23.1.0
aiohttp==3.8.4
aiosignal==1.3.1
altair==4.2.2
antlr4-python3-runtime==4.8
anyio==3.6.2
appdirs==1.4.4
arrow==1.2.3
asteroid==0.6.0
asteroid-filterbanks==0.4.0
asttokens==2.2.1
async-timeout==4.0.2
attrs==22.2.0
audioldm==0.0.21
audioread==2.1.9
backcall==0.2.0
backoff==2.2.1
backports.zoneinfo==0.2.1
beautifulsoup4==4.12.2
bitarray==2.8.2
bitsandbytes==0.37.2
black==23.3.0
blessed==1.20.0
blinker==1.6.2
blis==0.7.9
braceexpand==0.1.7
Brotli==1.0.9
brotlipy==0.7.0
cached-property==1.5.2
cachetools==5.3.0
catalogue==2.0.8
certifi @ file:///croot/certifi_1671487769961/work/certifi
cffi @ file:///opt/conda/conda-bld/cffi_1642701102775/work
cfgv==3.3.1
chardet==5.1.0
charset-normalizer==2.0.12
click==8.1.3
cmake==3.27.0
colorama==0.4.6
confection==0.0.4
contexttimer==0.3.3
contourpy==1.0.7
croniter==1.4.1
cryptography @ file:///croot/cryptography_1673298753778/work
cycler==0.11.0
cymem==2.0.7
Cython==3.0.0
dataclasses==0.6
datasets==2.15.0
dateutils==0.6.12
decorator==5.1.1
decord==0.6.0
deepdiff==6.3.1
deepspeed==0.9.0
dill==0.3.6
distlib==0.3.6
docopt==0.6.2
editdistance==0.6.2
einops==0.6.0
-e git+https://github.com/facebookresearch/encodec.git@0e2d0aed29362c8e8f52494baf3e6f99056b214f#egg=encodec
entrypoints==0.4
evaluate==0.4.0
executing==1.2.0
fairscale==0.4.4
fairseq==0.12.2
fast-bss-eval==0.1.4
fastapi==0.94.1
ffmpeg-python==0.2.0
ffmpy==0.3.0
filelock==3.12.0
fire==0.5.0
flit_core @ file:///opt/conda/conda-bld/flit-core_1644941570762/work/source/flit_core
fonttools==4.39.0
frozenlist==1.3.3
fschat==0.2.33
fsspec==2023.10.0
ftfy==6.1.1
future==0.18.3
gitdb==4.0.10
GitPython==3.1.31
google-auth==2.16.0
google-auth-oauthlib==0.4.6
gradio==3.21.0
grpcio==1.51.1
h11==0.14.0
hjson==3.1.0
httpcore==0.16.3
httpx==0.23.3
huggingface-hub==0.19.4
hydra-core==1.0.7
identify==2.5.23
idna @ file:///croot/idna_1666125576474/work
imageio==2.28.0
importlib-metadata==6.0.0
importlib-resources==5.12.0
inflate64==0.3.1
inquirer==3.1.3
iopath==0.1.10
ipython==8.12.0
itsdangerous==2.1.2
jedi==0.18.2
jieba==0.42.1
Jinja2==3.1.2
jiwer==3.0.2
joblib==1.1.0
jsonargparse==4.22.1
jsonschema==4.17.3
julius==0.2.7
kaggle==1.5.13
kiwisolver==1.4.4
langcodes==3.3.0
lazy_loader==0.2
librosa==0.9.2
lightning-cloud==0.5.37
lightning-utilities==0.9.0
linkify-it-py==1.0.3
lit==16.0.6
llvmlite==0.38.0
loralib==0.1.1
lxml==4.9.3
Markdown==3.4.1
markdown-it-py==2.1.0
markdown2==2.4.11
MarkupSafe==2.1.2
matplotlib==3.7.1
matplotlib-inline==0.1.6
mdit-py-plugins==0.3.3
mdurl==0.1.2
mir-eval==0.7
mkl-fft==1.3.1
mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186064646/work
mkl-service==2.4.0
mlp-mixer-pytorch==0.1.1
more-itertools==9.1.0
multidict==6.0.4
multiprocess==0.70.14
multivolumefile==0.2.3
murmurhash==1.0.9
mypy-extensions==1.0.0
networkx==3.1
nh3==0.2.14
ninja==1.11.1
nlp-tools @ git+https://github.com/yuyusica/NLP_tools.git@f1e34676d5a82b28a09c3194dd6357aa33c3d434
nltk==3.8.1
nodeenv==1.7.0
num2words==0.5.12
numba==0.55.1
numpy==1.21.1
oauthlib==3.2.2
omegaconf==2.0.6
openai==0.27.6
openai-whisper @ git+https://github.com/openai/whisper.git@0a60fcaa9b86748389a656aa013c416030287d47
opencv-python==4.7.0.68
opencv-python-headless==4.5.5.64
opendatasets==0.1.22
opt-einsum==3.3.0
ordered-set==4.1.0
orjson==3.8.7
packaging==23.1
pandas==1.5.3
parso==0.8.3
pathlib==1.0.1
pathspec==0.11.1
pathy==0.10.1
pb-bss-eval==0.0.2
peft @ git+https://github.com/huggingface/peft.git@e536616888d51b453ed354a6f1e243fecb02ea08
pesq==0.0.4
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.0.1
pkgutil_resolve_name==1.3.10
platformdirs==3.2.0
plotly==5.14.1
pooch==1.6.0
portalocker==2.7.0
pre-commit==3.2.2
preshed==3.0.8
progressbar==2.5
prompt-toolkit==3.0.38
protobuf==3.20.3
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
py7zr==0.20.5
pyarrow==11.0.0
pyarrow-hotfix==0.6
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybcj==1.0.1
pycocoevalcap==1.2
pycocotools==2.0.6
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pycryptodomex==3.18.0
pydantic==1.10.6
pydeck==0.8.1b0
pyDeprecate==0.3.2
pydub==0.25.1
Pygments==2.14.0
PyJWT==2.8.0
Pympler==1.0.1
pynini==2.1.5
pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work
pyparsing==3.0.8
pyppmd==1.0.0
pyrsistent==0.19.3
PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
pystoi==0.3.3
python-dateutil==2.8.2
python-editor==1.0.4
python-magic==0.4.27
python-multipart==0.0.6
python-slugify==8.0.1
pytorch-crf==0.7.2
pytorch-lightning==1.9.1
pytorch-ranger==0.1.1
pytz==2022.7.1
pytz-deprecation-shim==0.1.0.post0
PyWavelets==1.4.1
PyYAML==6.0
pyzstd==0.15.9
rapidfuzz==2.13.7
readchar==4.0.5
regex==2022.10.31
requests @ file:///opt/conda/conda-bld/requests_1657734628632/work
requests-oauthlib==1.3.1
resampy==0.2.2
responses==0.18.0
rfc3986==1.5.0
rich==13.3.1
rouge-score==0.1.2
rsa==4.9
sacrebleu==2.3.1
safetensors==0.4.1
scikit-image==0.20.0
scikit-learn==1.0.2
scipy==1.8.0
seaborn==0.13.0
sentencepiece==0.1.98
seqeval==1.2.2
shortuuid==1.0.11
six @ file:///tmp/build/80754af9/six_1644875935023/work
sklearn==0.0
smart-open==6.3.0
smmap==5.0.0
sniffio==1.3.0
soundfile==0.12.1
soupsieve==2.4.1
spacy==3.5.2
spacy-legacy==3.0.12
spacy-loggers==1.0.4
-e git+https://github.com/ZhangXInFD/SpeechTokenizer.git@adc2c3fc65a2eb82efb4744b66905ed018d5895a#egg=speechtokenizer
srsly==2.4.6
stack-data==0.6.2
starlette==0.26.1
starsessions==1.3.0
streamlit==1.21.0
svgwrite==1.4.3
tabulate==0.9.0
tenacity==8.2.2
tensorboard==2.12.0
tensorboard-data-server==0.7.0
tensorboard-plugin-wit==1.8.1
termcolor==2.2.0
text-unidecode==1.3
texttable==1.6.7
thinc==8.1.9
threadpoolctl==3.1.0
tifffile==2023.4.12
tiktoken==0.3.3
timm==0.4.12
tn==0.0.4
tokenize-rt==5.0.0
tokenizers==0.15.0
toml==0.10.2
tomli==2.0.1
toolz==0.12.0
torch==1.13.1
torch-mir-eval==0.4
torch-optimizer==0.1.0
torch-stoi==0.1.2
torchaudio==0.13.1
torchlibrosa==0.0.9
torchmetrics==0.7.3
torchvision==0.14.1
tornado==6.3.1
tqdm==4.64.1
traitlets==5.9.0
transformers @ git+https://github.com/huggingface/transformers.git@e6dcf8abd6f65bb4b6dfc1831b20d9ba49ce00e2
triton==2.0.0
typer==0.7.0
typing_extensions @ file:///croot/typing_extensions_1669924550328/work
tzdata==2023.3
tzlocal==4.3
uc-micro-py==1.0.1
urllib3 @ file:///croot/urllib3_1673575502006/work
uvicorn==0.21.0
validators==0.20.0
virtualenv==20.22.0
wasabi==1.1.1
watchdog==3.0.0
wavedrom==2.0.3.post3
wcwidth==0.2.6
webdataset==0.2.48
websocket-client==1.6.1
websockets==10.4
Werkzeug==2.2.3
WeTextProcessing==0.1.2
-e git+https://github.com/patrickvonplaten/whisper.git@d1690a4d0e0802dcc05ff562f89b4ca65324ceee#egg=whisper
whisper-normalizer==0.0.2
xxhash==3.2.0
yarl==1.8.2
zipp==3.13.0

--------------------------------------------------------------------------------
/star.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YUCHEN005/STAR-Adapt/3003fd20047ce56888911e8f2bd46100455e3215/star.png
--------------------------------------------------------------------------------