├── README.md
├── finetune.py
├── finetune.sh
├── inference.py
├── inference.sh
├── requirements.txt
└── star.png
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Taught Recognizer: Toward Unsupervised Adaptation for Speech Foundation Models
2 |
3 | [[Paper]](https://arxiv.org/pdf/2405.14161)
4 |
5 |
6 |
7 | This work proposes a source-free unsupervised domain adaptation approach for speech foundation models.
8 |
9 | ## Conda Environment Configuration
10 |
11 | Our conda environment is provided via the file `requirements.txt`. Please run the command below to install the necessary packages:
12 | ```bash
13 | pip install -r requirements.txt
14 | ```
15 |
16 | ## Data Preparation
17 |
18 | Our code requires two Kaldi-format data files: `wav.scp` and `text`.
19 |
20 | - `wav.scp` contains a list of audio files; each line includes a sample ID and the absolute audio path:
21 |
22 | ```
23 | utt_1 /your-data-path/1.wav
24 | utt_2 /your-data-path/2.wav
25 | ```
26 |
27 | - `text` contains a list of ground-truth transcriptions; each line includes a sample ID and its transcription:
28 |
29 | ```
30 | utt_1 i feel good
31 | utt_2 he is coming back
32 | ```
33 |
34 | **NOTE:** the two files above must be line-aligned, i.e., the i-th line of `wav.scp` and the i-th line of `text` must refer to the same sample ID.
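
As a sanity check, the minimal sketch below (a hypothetical helper of ours, not part of this repo) verifies that the two files are line-aligned by sample ID. Pass the data directory with a trailing slash, since the training code concatenates it directly with the file names:

```python
# check_paired.py -- hypothetical helper (not part of this repo):
# verify that wav.scp and text are line-aligned by sample ID.
import sys

def check_paired(data_dir):
    # data_dir should end with "/" (e.g. "/your-data-path/"), matching how finetune.py builds paths
    with open(data_dir + "wav.scp") as f1, open(data_dir + "text") as f2:
        wav_lines, text_lines = f1.readlines(), f2.readlines()
    assert len(wav_lines) == len(text_lines), "wav.scp and text have different numbers of lines"
    for i, (w, t) in enumerate(zip(wav_lines, text_lines)):
        wav_id, text_id = w.split()[0], t.split()[0]
        assert wav_id == text_id, f"line {i + 1}: {wav_id} != {text_id}"
    print(f"{len(wav_lines)} paired utterances look OK")

if __name__ == "__main__":
    check_paired(sys.argv[1])
```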
35 |
36 |
37 | ## Training
38 | Please refer to our training script `finetune.sh` and specify the following settings:
39 | - `dataset`: training data name;
40 | - `model_size`: whisper model size;
41 | - `train_data`: training data directory that contains files `wav.scp` and `text`;
42 | - `dev_data`: development data directory that contains files `wav.scp` and `text`;
43 |
44 | Then run `bash finetune.sh` to start training. The model weights will be saved under `runs/{dataset}_{model_size}`.
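
If you prefer not to go through the shell script, the same entry point can be called directly; the sketch below simply mirrors the values set in `finetune.sh` (the data paths are placeholders and, as above, should end with a trailing slash because `finetune.py` concatenates them with `wav.scp` and `text`):

```python
# a minimal sketch: calling the training entry point of finetune.py directly
from finetune import train

train(
    MODEL="openai/whisper-large-v3",
    DATASET="chime4",
    TRAIN_DATA="/your-data-path/train/",  # directory containing wav.scp and text
    DEV_DATA="/your-data-path/dev/",      # directory containing wav.scp and text
    BATCH_SIZE=1,
    GRADIENT_ACCUMULATION_STEPS=16,
    LEARNING_RATE=1e-5,
    EPOCHS=2,
)
```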
45 |
46 |
47 | ## Inference
48 | Please refer to our inference script `inference.sh` and specify the following settings:
49 | - `dataset`: training data name;
50 | - `model_size`: whisper model size;
51 | - `checkpoint`: path of the trained model checkpoint (`.pth` file);
52 | - `test_data`: test data directory that contains files `wav.scp` and `text`;
53 |
54 | Then run `bash inference.sh` for inference. The WER results will be printed in the log.
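
For a quick spot check outside of `inference.sh`, the sketch below (all paths are placeholders) loads a saved checkpoint and transcribes a single utterance with the same decoding settings as `inference.py`:

```python
# a minimal sketch: transcribe one audio file with a STAR-adapted checkpoint
import torch, torchaudio
from transformers import WhisperProcessor

device = "cuda:0" if torch.cuda.is_available() else "cpu"
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3", language="en", task="transcribe")
forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")

model = torch.load("runs/chime4_large-v3/best_checkpoint.pth").to(device)  # checkpoint saved by finetune.py
model.eval()

audio, sr = torchaudio.load("/your-data-path/1.wav")
if sr != 16000:
    audio = torchaudio.functional.resample(audio, sr, 16000)
mel = processor.feature_extractor(
    audio.squeeze(0).numpy(), sampling_rate=16_000, return_tensors="pt"
)["input_features"]

with torch.no_grad():
    ids = model.generate(inputs=mel.to(device), forced_decoder_ids=forced_decoder_ids, max_new_tokens=150)
print(processor.batch_decode(ids, skip_special_tokens=True)[0])
```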
55 |
56 |
57 | ## References
58 |
59 | Please kindly cite our paper if you use our research or code in your publication:
60 | ```bib
61 | @article{hu2024self,
62 | title={Self-Taught Recognizer: Toward Unsupervised Adaptation for Speech Foundation Models},
63 | author={Hu, Yuchen and Chen, Chen and Yang, Chao-Han Huck and Qin, Chengwei and Chen, Pin-Yu and Chng, Eng Siong and Zhang, Chao},
64 | journal={arXiv preprint arXiv:2405.14161},
65 | year={2024}
66 | }
67 | ```
68 |
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
2 | from transformers import AutoFeatureExtractor, WhisperModel
3 | from transformers import LlamaTokenizer
4 | from datasets import load_dataset
5 | import torch, torchaudio
6 | from torch import nn
7 | import numpy as np
8 | from jiwer import wer as calculate_wer
9 | import pickle
10 | import fire
11 | from datasets import Dataset, Audio, Value
12 | import os, random
13 | from typing import Optional
14 | from whisper.normalizers import EnglishTextNormalizer
15 | import math
16 | from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
17 | from pathlib import Path
18 | import whisper
19 | import copy, heapq
20 | normalizer = EnglishTextNormalizer()
21 |
22 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23 |
24 | def sigmoid(x):
25 | return 1 / (1 + np.exp(-x))
26 |
27 | def train(
28 | MODEL = "openai/whisper-large-v3",
29 | DATASET = "chime4",
30 | TRAIN_DATA = "",
31 | DEV_DATA = "",
32 | SAVE_EVERY = 10,
33 | BATCH_SIZE = 32,
34 | GRADIENT_ACCUMULATION_STEPS = 4,
35 | LEARNING_RATE = 1e-3,
36 | EPOCHS = 100,
37 |     THRESHOLD=2.0,
38 | TOP_PERCENT=0.8,
39 | TAU=10,
40 | ):
41 | feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL)
42 | processor = WhisperProcessor.from_pretrained(MODEL, language="en", task="transcribe")
43 | tokenizer = WhisperTokenizer.from_pretrained(MODEL, language="en", task="transcribe")
44 | model = WhisperForConditionalGeneration.from_pretrained(MODEL).to(device)
45 | forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
46 | state_dict = copy.deepcopy(model.state_dict())
47 |
48 | prompt_and_eos = tokenizer('')['input_ids']
49 | prompt_ids, eos_id = prompt_and_eos[:-1], prompt_and_eos[-1]
50 |
51 | def data_preparation(data_path, feature_extractor, tokenizer):
52 | with open(data_path + "wav.scp", 'r') as f1:
53 | wave_data = f1.readlines()
54 | with open(data_path + "text", 'r') as f2:
55 | trans_data = f2.readlines()
56 |
57 | audio_data, txt_data = [], []
58 | for i in range(len(wave_data)):
59 | audio_data.append(wave_data[i])
60 | txt_data.append(trans_data[i])
61 |
62 | audio_dataset = []
63 | all_pred, all_gt = [], []
64 | for audio_line, text_line in zip(audio_data, txt_data):
65 | audio_path = audio_line.strip().split()[1]
66 | text = ' '.join(text_line.split()[1:]).lower().strip()
67 | audio, sr = torchaudio.load(audio_path)
68 | if sr != 16000:
69 | audio = torchaudio.functional.resample(audio, sr, 16000)
70 | item = {'audio': audio, 'text': text}
71 |
72 | item['mel'] = feature_extractor(audio.squeeze(0).numpy(), sampling_rate=16_000, return_tensors="pt")['input_features']
73 | item['decoder_input_ids'] = tokenizer(text, max_length=1024, truncation=True).input_ids
74 |
75 | model.load_state_dict(state_dict)
76 | hidden_feature = model.model.encoder(input_features=item['mel'].to(device)).last_hidden_state
77 |
78 | # prompt: '<|startoftranscript|><|en|><|transcribe|><|notimestamps|>'
79 | pseudo_label_ids = torch.tensor([prompt_ids]).long().to(device)
80 |
81 | ### probs: confidence score
82 | probs, decoder_outputs = [], None
83 | for _ in range(150):
84 |                 decoder_outputs = model(encoder_outputs=(hidden_feature,), decoder_input_ids=pseudo_label_ids, output_attentions=True)  # pass encoder states as a tuple
85 | logits = torch.softmax(decoder_outputs.logits / 1.2, dim=-1)
86 | next_token = logits[0, -1, :].topk(1)[1]
87 | probs.append(float(logits[0, -1, next_token]))
88 | pseudo_label_ids = torch.cat((pseudo_label_ids, next_token.unsqueeze(0)), dim=-1)
89 | if next_token == eos_id: # EOS
90 | break
91 |
92 |             # normalization
93 | mean_probs = sum(probs) / len(probs)
94 | for k in range(len(probs)):
95 | probs[k] = round(probs[k] / mean_probs, 3)
96 |
97 | ### weights: attentive score
98 | n_prompt_toks = 4
99 | layer_id, head_id = 30, 13 # suggest: layer_id \in [30,31], head_id \in [0,1,...,19]
100 | attn = decoder_outputs.decoder_attentions[layer_id][0, head_id, :, :]
101 | attn[:, :n_prompt_toks-1] = 0 # remove prompts
102 | weights = []
103 | for i in range(n_prompt_toks-1, attn.shape[-1]):
104 | weight = torch.sum(attn[i, :]) + torch.sum(attn[:, i]) - attn[i, i]
105 | weights.append(float(weight))
106 |
107 | # normalization
108 | mean_weights = sum(weights) / len(weights)
109 | for j in range(len(weights)):
110 | weights[j] = round(weights[j] / mean_weights, 3)
111 |
112 | ### final_weights: star score
113 | final_weights = []
114 | for ci, ai in zip(probs, weights):
115 | c_over_a, a_over_c = ci * ci / ai, ai * ai / ci
116 |                 conflict = (sigmoid((c_over_a - THRESHOLD) * TAU) + sigmoid((a_over_c - THRESHOLD) * TAU)) * ai
117 |                 no_conflict = (sigmoid((THRESHOLD - c_over_a) * TAU) * sigmoid((THRESHOLD - a_over_c) * TAU)) * ai * np.exp((ci - ai) / TAU)
118 | final_weights.append(conflict + no_conflict)
119 |
120 | item['pseudo_label_ids'] = pseudo_label_ids
121 | item['probs'] = torch.tensor(final_weights).unsqueeze(0)
122 | pseudo_text = processor.batch_decode(pseudo_label_ids, skip_special_tokens=True)[0]
123 |
124 | ### utt-level uncertainty
125 | if 'train' in data_path:
126 | avg_wer, generated_texts = 0, []
127 | for _ in range(5):
128 | new_state_dict = copy.deepcopy(state_dict)
129 | for k in new_state_dict.keys():
130 | std = torch.std(new_state_dict[k])
131 | noise = torch.randn_like(new_state_dict[k])
132 | new_state_dict[k] = new_state_dict[k] + noise * std * 0.1
133 |
134 | model.load_state_dict(new_state_dict)
135 | generated_ids = model.generate(inputs=item['mel'].to(device), forced_decoder_ids=forced_decoder_ids, max_new_tokens=150)
136 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
137 | generated_texts.append(generated_text)
138 | avg_wer += calculate_wer([pseudo_text], [generated_text]) / 5
139 |
140 | item['avg_wer'] = avg_wer
141 | item['diversity'] = len(list(set(generated_texts)))
142 |
143 | ## text normalization
144 | pseudo_text = normalizer(pseudo_text)
145 | pseudo_text = pseudo_text if len(pseudo_text) > 0 else ''
146 |
147 | gt = normalizer(text)
148 | gt = gt if len(gt) > 0 else ''
149 |
150 | audio_dataset.append(item)
151 | all_pred.append(pseudo_text)
152 | all_gt.append(gt)
153 |
154 | model.load_state_dict(state_dict)
155 | return audio_dataset, calculate_wer(all_gt, all_pred)
156 |
157 |
158 | def evaluate(model, dataset):
159 | with torch.no_grad():
160 | all_pred, all_gt = [], []
161 | for item in dataset:
162 | mel = item['mel']
163 | generated_ids = model.generate(inputs=mel.to(device), forced_decoder_ids=forced_decoder_ids, max_new_tokens=150)
164 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
165 |
166 | ## text normalization
167 | pred = normalizer(generated_text)
168 | pred = pred if len(pred) > 0 else ''
169 |
170 | gt = normalizer(item['text'])
171 | gt = gt if len(gt) > 0 else ''
172 |
173 | all_pred.append(pred)
174 | all_gt.append(gt)
175 |
176 | return calculate_wer(all_gt, all_pred)
177 |
178 |
179 | model.eval()
180 | train_dataset, train_wer = data_preparation(TRAIN_DATA, feature_extractor, tokenizer)
181 | dev_dataset, dev_wer = data_preparation(DEV_DATA, feature_extractor, tokenizer)
182 | os.system('mkdir -p data')
183 | torch.save(train_dataset, f'data/train_{DATASET}.pt')
184 | torch.save(dev_dataset, f'data/dev_{DATASET}.pt')
185 | model.train()
186 |
187 | ## load saved data
188 | # train_dataset = torch.load(f'data/train_{DATASET}.pt')
189 | # dev_dataset = torch.load(f'data/dev_{DATASET}.pt')
190 |
191 | ## utt-level filtering
192 | def product(item):
193 | return item['avg_wer'] * item['diversity']
194 | filtered_train_dataset = heapq.nsmallest(int(len(train_dataset) * TOP_PERCENT), train_dataset, key=product)
195 |
196 | optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
197 | loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
198 |
199 | model_size = MODEL.replace('openai/whisper-', '')
200 | exp_dir = f'runs/{DATASET}_{model_size}'
201 | os.system(f"mkdir -p {exp_dir}")
202 |
203 |     steps, loss, n_prompt_toks = 0, 0, 4  # n_prompt_toks: length of the decoder prompt '<|startoftranscript|><|en|><|transcribe|><|notimestamps|>'
204 | best_loss, best_wer = 10000, 10000
205 | for Epoch in range(EPOCHS):
206 | print("Epoch: ", Epoch + 1)
207 |
208 | # Train
209 | random.shuffle(filtered_train_dataset)
210 | print('Training...')
211 | for i in range(len(filtered_train_dataset) // BATCH_SIZE):
212 | batch_data = filtered_train_dataset[i * BATCH_SIZE: (i+1) * BATCH_SIZE]
213 |
214 | input_features = [{"input_features": item["mel"]} for item in batch_data]
215 | mel = processor.feature_extractor.pad(input_features, return_tensors="pt")["input_features"].squeeze(1).to(device)
216 |
217 | labels = batch_data[0]["pseudo_label_ids"].to(device)
218 | y_in = labels[:, :-1]
219 | y_out = labels[:, 1:]
220 |
221 | logits = model(input_features=mel, decoder_input_ids=y_in).logits
222 | loss_items = loss_fn(logits.permute(0, 2, 1), y_out)
223 |
224 | # uncertainty calibration
225 | ratios = batch_data[0]['probs'].to(device)
226 | ratios = ratios / torch.mean(ratios)
227 | loss = (torch.sum(loss_items[:, :n_prompt_toks-1]) + torch.sum(loss_items[:, n_prompt_toks-1:] * ratios)) / (n_prompt_toks-1 + ratios.shape[-1])
228 |
229 | loss.backward()
230 | steps += 1
231 |
232 | if steps % GRADIENT_ACCUMULATION_STEPS == 0:
233 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
234 | optimizer.step()
235 | optimizer.zero_grad()
236 |
237 | if steps % SAVE_EVERY == 0: # Evaluate
238 | torch.save(model, f"{exp_dir}/Iter_{steps}.pth")
239 |
240 | model.eval()
241 | dev_wer = evaluate(model, dev_dataset)
242 | model.train()
243 |
244 | if dev_wer < best_wer or (dev_wer == best_wer and loss < best_loss):
245 | torch.save(model, f"{exp_dir}/best_checkpoint.pth")
246 | best_loss, best_wer = loss, dev_wer
247 |
248 | torch.save(model, f"{exp_dir}/last_checkpoint.pth")
249 |
250 |
251 | if __name__ == "__main__":
252 | fire.Fire(train)
253 |
254 |
--------------------------------------------------------------------------------
/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source activate
4 |
5 | dataset=chime4
6 | model_size=large-v3
7 | train_data=
8 | dev_data=
9 |
10 | $cmd log/finetune_${dataset}_${model_size}.log \
11 | python finetune.py \
12 | --MODEL "openai/whisper-${model_size}" \
13 | --DATASET ${dataset} \
14 | --TRAIN_DATA ${train_data} \
15 | --DEV_DATA ${dev_data} \
16 | --BATCH_SIZE 1 \
17 | --GRADIENT_ACCUMULATION_STEPS 16 \
18 | --LEARNING_RATE 1e-5 \
19 |     --EPOCHS 2
20 |
21 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
2 | from transformers import AutoFeatureExtractor, WhisperModel
3 | from transformers import LlamaTokenizer
4 | from datasets import load_dataset
5 | import torch, torchaudio
6 | from torch import nn
7 | import numpy as np
8 | from jiwer import wer as calculate_wer
9 | import pickle
10 | import fire
11 | from datasets import Dataset, Audio, Value
12 | import os, random
13 | from typing import Optional
14 | from whisper.normalizers import EnglishTextNormalizer
15 | import math
16 | from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
17 | from pathlib import Path
18 | import whisper
19 | import copy, heapq
20 | normalizer = EnglishTextNormalizer()
21 |
22 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23 |
24 | def sigmoid(x):
25 | return 1 / (1 + np.exp(-x))
26 |
27 | def inference(
28 | MODEL = "openai/whisper-large-v3",
29 | DATASET = "chime4",
30 | TEST_DATA = "",
31 | CKPT = "",
32 | ):
33 | feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL)
34 | processor = WhisperProcessor.from_pretrained(MODEL, language="en", task="transcribe")
35 | tokenizer = WhisperTokenizer.from_pretrained(MODEL, language="en", task="transcribe")
36 | forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
37 |
38 |
39 | def data_preparation(data_path, feature_extractor, tokenizer):
40 | with open(data_path + "wav.scp", 'r') as f1:
41 | wave_data = f1.readlines()
42 | with open(data_path + "text", 'r') as f2:
43 | trans_data = f2.readlines()
44 |
45 | audio_data, txt_data = [], []
46 | for i in range(len(wave_data)):
47 | audio_data.append(wave_data[i])
48 | txt_data.append(trans_data[i])
49 |
50 | audio_dataset = []
51 | for audio_line, text_line in zip(audio_data, txt_data):
52 | audio_path = audio_line.strip().split()[1]
53 | audio, sr = torchaudio.load(audio_path)
54 | if sr != 16000:
55 | audio = torchaudio.functional.resample(audio, sr, 16000)
56 | mel = feature_extractor(audio.squeeze(0).numpy(), sampling_rate=16_000, return_tensors="pt")['input_features']
57 | text = ' '.join(text_line.split()[1:]).lower().strip()
58 |
59 | item = {'mel': mel, 'text': text}
60 | audio_dataset.append(item)
61 |
62 | return audio_dataset
63 |
64 |
65 | def evaluate(model, dataset):
66 | with torch.no_grad():
67 | all_pred, all_gt = [], []
68 | for item in dataset:
69 | mel = item['mel']
70 | generated_ids = model.generate(inputs=mel.to(device), forced_decoder_ids=forced_decoder_ids, max_new_tokens=150)
71 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
72 |
73 | ## text normalization
74 | pred = normalizer(generated_text)
75 | pred = pred if len(pred) > 0 else ''
76 |
77 | gt = normalizer(item['text'])
78 | gt = gt if len(gt) > 0 else ''
79 |
80 | all_pred.append(pred)
81 | all_gt.append(gt)
82 |
83 | return calculate_wer(all_gt, all_pred)
84 |
85 |
86 | ## prepare dataset
87 | test_dataset = data_preparation(TEST_DATA, feature_extractor, tokenizer)
88 |     os.makedirs('data', exist_ok=True); torch.save(test_dataset, f'data/test_{DATASET}.pt')  # make sure data/ exists before saving
89 | print(f'{DATASET}:')
90 | # test_dataset = torch.load(f'data/test_{DATASET}.pt')
91 |
92 | ## evaluate official whisper (only need to run once)
93 | model = WhisperForConditionalGeneration.from_pretrained(MODEL).to(device)
94 | model.eval()
95 | print(f'zero-shot = {evaluate(model, test_dataset)}')
96 |
97 | ## evaluate star adapted whisper
98 | model = torch.load(CKPT).to(device)
99 | model.eval()
100 | print(f'star = {evaluate(model, test_dataset)}')
101 |
102 |
103 | if __name__ == "__main__":
104 |     fire.Fire(inference)
105 |
106 |
--------------------------------------------------------------------------------
/inference.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | source activate
4 |
5 | dataset=chime4
6 | model_size=large-v3
7 | checkpoint=
8 | test_data=
9 |
10 | $cmd log/inference_${dataset}_${model_size}.log \
11 | python inference.py \
12 | --MODEL "openai/whisper-${model_size}" \
13 | --DATASET ${dataset} \
14 | --CKPT ${checkpoint} \
15 | --TEST_DATA ${test_data}
16 |
17 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | accelerate==0.25.0
3 | aiofiles==23.1.0
4 | aiohttp==3.8.4
5 | aiosignal==1.3.1
6 | altair==4.2.2
7 | antlr4-python3-runtime==4.8
8 | anyio==3.6.2
9 | appdirs==1.4.4
10 | arrow==1.2.3
11 | asteroid==0.6.0
12 | asteroid-filterbanks==0.4.0
13 | asttokens==2.2.1
14 | async-timeout==4.0.2
15 | attrs==22.2.0
16 | audioldm==0.0.21
17 | audioread==2.1.9
18 | backcall==0.2.0
19 | backoff==2.2.1
20 | backports.zoneinfo==0.2.1
21 | beautifulsoup4==4.12.2
22 | bitarray==2.8.2
23 | bitsandbytes==0.37.2
24 | black==23.3.0
25 | blessed==1.20.0
26 | blinker==1.6.2
27 | blis==0.7.9
28 | braceexpand==0.1.7
29 | Brotli==1.0.9
30 | brotlipy==0.7.0
31 | cached-property==1.5.2
32 | cachetools==5.3.0
33 | catalogue==2.0.8
34 | certifi @ file:///croot/certifi_1671487769961/work/certifi
35 | cffi @ file:///opt/conda/conda-bld/cffi_1642701102775/work
36 | cfgv==3.3.1
37 | chardet==5.1.0
38 | charset-normalizer==2.0.12
39 | click==8.1.3
40 | cmake==3.27.0
41 | colorama==0.4.6
42 | confection==0.0.4
43 | contexttimer==0.3.3
44 | contourpy==1.0.7
45 | croniter==1.4.1
46 | cryptography @ file:///croot/cryptography_1673298753778/work
47 | cycler==0.11.0
48 | cymem==2.0.7
49 | Cython==3.0.0
50 | dataclasses==0.6
51 | datasets==2.15.0
52 | dateutils==0.6.12
53 | decorator==5.1.1
54 | decord==0.6.0
55 | deepdiff==6.3.1
56 | deepspeed==0.9.0
57 | dill==0.3.6
58 | distlib==0.3.6
59 | docopt==0.6.2
60 | editdistance==0.6.2
61 | einops==0.6.0
62 | -e git+https://github.com/facebookresearch/encodec.git@0e2d0aed29362c8e8f52494baf3e6f99056b214f#egg=encodec
63 | entrypoints==0.4
64 | evaluate==0.4.0
65 | executing==1.2.0
66 | fairscale==0.4.4
67 | fairseq==0.12.2
68 | fast-bss-eval==0.1.4
69 | fastapi==0.94.1
70 | ffmpeg-python==0.2.0
71 | ffmpy==0.3.0
72 | filelock==3.12.0
73 | fire==0.5.0
74 | flit_core @ file:///opt/conda/conda-bld/flit-core_1644941570762/work/source/flit_core
75 | fonttools==4.39.0
76 | frozenlist==1.3.3
77 | fschat==0.2.33
78 | fsspec==2023.10.0
79 | ftfy==6.1.1
80 | future==0.18.3
81 | gitdb==4.0.10
82 | GitPython==3.1.31
83 | google-auth==2.16.0
84 | google-auth-oauthlib==0.4.6
85 | gradio==3.21.0
86 | grpcio==1.51.1
87 | h11==0.14.0
88 | hjson==3.1.0
89 | httpcore==0.16.3
90 | httpx==0.23.3
91 | huggingface-hub==0.19.4
92 | hydra-core==1.0.7
93 | identify==2.5.23
94 | idna @ file:///croot/idna_1666125576474/work
95 | imageio==2.28.0
96 | importlib-metadata==6.0.0
97 | importlib-resources==5.12.0
98 | inflate64==0.3.1
99 | inquirer==3.1.3
100 | iopath==0.1.10
101 | ipython==8.12.0
102 | itsdangerous==2.1.2
103 | jedi==0.18.2
104 | jieba==0.42.1
105 | Jinja2==3.1.2
106 | jiwer==3.0.2
107 | joblib==1.1.0
108 | jsonargparse==4.22.1
109 | jsonschema==4.17.3
110 | julius==0.2.7
111 | kaggle==1.5.13
112 | kiwisolver==1.4.4
113 | langcodes==3.3.0
114 | lazy_loader==0.2
115 | librosa==0.9.2
116 | lightning-cloud==0.5.37
117 | lightning-utilities==0.9.0
118 | linkify-it-py==1.0.3
119 | lit==16.0.6
120 | llvmlite==0.38.0
121 | loralib==0.1.1
122 | lxml==4.9.3
123 | Markdown==3.4.1
124 | markdown-it-py==2.1.0
125 | markdown2==2.4.11
126 | MarkupSafe==2.1.2
127 | matplotlib==3.7.1
128 | matplotlib-inline==0.1.6
129 | mdit-py-plugins==0.3.3
130 | mdurl==0.1.2
131 | mir-eval==0.7
132 | mkl-fft==1.3.1
133 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186064646/work
134 | mkl-service==2.4.0
135 | mlp-mixer-pytorch==0.1.1
136 | more-itertools==9.1.0
137 | multidict==6.0.4
138 | multiprocess==0.70.14
139 | multivolumefile==0.2.3
140 | murmurhash==1.0.9
141 | mypy-extensions==1.0.0
142 | networkx==3.1
143 | nh3==0.2.14
144 | ninja==1.11.1
145 | nlp-tools @ git+https://github.com/yuyusica/NLP_tools.git@f1e34676d5a82b28a09c3194dd6357aa33c3d434
146 | nltk==3.8.1
147 | nodeenv==1.7.0
148 | num2words==0.5.12
149 | numba==0.55.1
150 | numpy==1.21.1
151 | oauthlib==3.2.2
152 | omegaconf==2.0.6
153 | openai==0.27.6
154 | openai-whisper @ git+https://github.com/openai/whisper.git@0a60fcaa9b86748389a656aa013c416030287d47
155 | opencv-python==4.7.0.68
156 | opencv-python-headless==4.5.5.64
157 | opendatasets==0.1.22
158 | opt-einsum==3.3.0
159 | ordered-set==4.1.0
160 | orjson==3.8.7
161 | packaging==23.1
162 | pandas==1.5.3
163 | parso==0.8.3
164 | pathlib==1.0.1
165 | pathspec==0.11.1
166 | pathy==0.10.1
167 | pb-bss-eval==0.0.2
168 | peft @ git+https://github.com/huggingface/peft.git@e536616888d51b453ed354a6f1e243fecb02ea08
169 | pesq==0.0.4
170 | pexpect==4.8.0
171 | pickleshare==0.7.5
172 | Pillow==9.0.1
173 | pkgutil_resolve_name==1.3.10
174 | platformdirs==3.2.0
175 | plotly==5.14.1
176 | pooch==1.6.0
177 | portalocker==2.7.0
178 | pre-commit==3.2.2
179 | preshed==3.0.8
180 | progressbar==2.5
181 | prompt-toolkit==3.0.38
182 | protobuf==3.20.3
183 | psutil==5.9.5
184 | ptyprocess==0.7.0
185 | pure-eval==0.2.2
186 | py-cpuinfo==9.0.0
187 | py7zr==0.20.5
188 | pyarrow==11.0.0
189 | pyarrow-hotfix==0.6
190 | pyasn1==0.4.8
191 | pyasn1-modules==0.2.8
192 | pybcj==1.0.1
193 | pycocoevalcap==1.2
194 | pycocotools==2.0.6
195 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
196 | pycryptodomex==3.18.0
197 | pydantic==1.10.6
198 | pydeck==0.8.1b0
199 | pyDeprecate==0.3.2
200 | pydub==0.25.1
201 | Pygments==2.14.0
202 | PyJWT==2.8.0
203 | Pympler==1.0.1
204 | pynini==2.1.5
205 | pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work
206 | pyparsing==3.0.8
207 | pyppmd==1.0.0
208 | pyrsistent==0.19.3
209 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
210 | pystoi==0.3.3
211 | python-dateutil==2.8.2
212 | python-editor==1.0.4
213 | python-magic==0.4.27
214 | python-multipart==0.0.6
215 | python-slugify==8.0.1
216 | pytorch-crf==0.7.2
217 | pytorch-lightning==1.9.1
218 | pytorch-ranger==0.1.1
219 | pytz==2022.7.1
220 | pytz-deprecation-shim==0.1.0.post0
221 | PyWavelets==1.4.1
222 | PyYAML==6.0
223 | pyzstd==0.15.9
224 | rapidfuzz==2.13.7
225 | readchar==4.0.5
226 | regex==2022.10.31
227 | requests @ file:///opt/conda/conda-bld/requests_1657734628632/work
228 | requests-oauthlib==1.3.1
229 | resampy==0.2.2
230 | responses==0.18.0
231 | rfc3986==1.5.0
232 | rich==13.3.1
233 | rouge-score==0.1.2
234 | rsa==4.9
235 | sacrebleu==2.3.1
236 | safetensors==0.4.1
237 | scikit-image==0.20.0
238 | scikit-learn==1.0.2
239 | scipy==1.8.0
240 | seaborn==0.13.0
241 | sentencepiece==0.1.98
242 | seqeval==1.2.2
243 | shortuuid==1.0.11
244 | six @ file:///tmp/build/80754af9/six_1644875935023/work
245 | sklearn==0.0
246 | smart-open==6.3.0
247 | smmap==5.0.0
248 | sniffio==1.3.0
249 | soundfile==0.12.1
250 | soupsieve==2.4.1
251 | spacy==3.5.2
252 | spacy-legacy==3.0.12
253 | spacy-loggers==1.0.4
254 | -e git+https://github.com/ZhangXInFD/SpeechTokenizer.git@adc2c3fc65a2eb82efb4744b66905ed018d5895a#egg=speechtokenizer
255 | srsly==2.4.6
256 | stack-data==0.6.2
257 | starlette==0.26.1
258 | starsessions==1.3.0
259 | streamlit==1.21.0
260 | svgwrite==1.4.3
261 | tabulate==0.9.0
262 | tenacity==8.2.2
263 | tensorboard==2.12.0
264 | tensorboard-data-server==0.7.0
265 | tensorboard-plugin-wit==1.8.1
266 | termcolor==2.2.0
267 | text-unidecode==1.3
268 | texttable==1.6.7
269 | thinc==8.1.9
270 | threadpoolctl==3.1.0
271 | tifffile==2023.4.12
272 | tiktoken==0.3.3
273 | timm==0.4.12
274 | tn==0.0.4
275 | tokenize-rt==5.0.0
276 | tokenizers==0.15.0
277 | toml==0.10.2
278 | tomli==2.0.1
279 | toolz==0.12.0
280 | torch==1.13.1
281 | torch-mir-eval==0.4
282 | torch-optimizer==0.1.0
283 | torch-stoi==0.1.2
284 | torchaudio==0.13.1
285 | torchlibrosa==0.0.9
286 | torchmetrics==0.7.3
287 | torchvision==0.14.1
288 | tornado==6.3.1
289 | tqdm==4.64.1
290 | traitlets==5.9.0
291 | transformers @ git+https://github.com/huggingface/transformers.git@e6dcf8abd6f65bb4b6dfc1831b20d9ba49ce00e2
292 | triton==2.0.0
293 | typer==0.7.0
294 | typing_extensions @ file:///croot/typing_extensions_1669924550328/work
295 | tzdata==2023.3
296 | tzlocal==4.3
297 | uc-micro-py==1.0.1
298 | urllib3 @ file:///croot/urllib3_1673575502006/work
299 | uvicorn==0.21.0
300 | validators==0.20.0
301 | virtualenv==20.22.0
302 | wasabi==1.1.1
303 | watchdog==3.0.0
304 | wavedrom==2.0.3.post3
305 | wcwidth==0.2.6
306 | webdataset==0.2.48
307 | websocket-client==1.6.1
308 | websockets==10.4
309 | Werkzeug==2.2.3
310 | WeTextProcessing==0.1.2
311 | -e git+https://github.com/patrickvonplaten/whisper.git@d1690a4d0e0802dcc05ff562f89b4ca65324ceee#egg=whisper
312 | whisper-normalizer==0.0.2
313 | xxhash==3.2.0
314 | yarl==1.8.2
315 | zipp==3.13.0
316 |
--------------------------------------------------------------------------------
/star.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YUCHEN005/STAR-Adapt/3003fd20047ce56888911e8f2bd46100455e3215/star.png
--------------------------------------------------------------------------------