├── .gitignore ├── Whisper ├── __main__.py ├── assets │ ├── multilingual │ │ ├── added_tokens.json │ │ ├── special_tokens_map.json │ │ └── tokenizer_config.json │ ├── mel_filters.npz │ └── gpt2 │ │ ├── special_tokens_map.json │ │ └── tokenizer_config.json ├── normalizers │ ├── __init__.py │ ├── basic.py │ └── english.py ├── utils.py ├── audio.py ├── __init__.py ├── tokenizer.py ├── model.py └── transcribe.py ├── requirements.txt ├── criterion ├── metric_util.py └── mix_criterions.py ├── modules ├── __init__.py ├── cascade.py ├── whisper.py ├── mbart.py ├── whisper_asr.py └── comsl.py ├── config ├── exp_spec │ ├── cascade.yaml │ ├── mbart.yaml │ ├── whisper_asr.yaml │ ├── whisper.yaml │ └── comsl.yaml └── parse_yaml_args.py ├── data ├── lang_dict.csv ├── data_util.py └── dataset.py ├── model ├── optimizer.py ├── mBART_model.py ├── model_util.py └── ComSL_model.py ├── README.md ├── run.py └── create_pseudo_data.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .history 3 | .vscode 4 | .DS_Store -------------------------------------------------------------------------------- /Whisper/__main__.py: -------------------------------------------------------------------------------- 1 | from .transcribe import cli 2 | 3 | 4 | cli() 5 | -------------------------------------------------------------------------------- /Whisper/assets/multilingual/added_tokens.json: -------------------------------------------------------------------------------- 1 | {"<|endoftext|>": 50257} 2 | -------------------------------------------------------------------------------- /Whisper/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nethermanpro/ComSL/HEAD/Whisper/assets/mel_filters.npz -------------------------------------------------------------------------------- /Whisper/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import BasicTextNormalizer 2 | from .english import EnglishTextNormalizer 3 | -------------------------------------------------------------------------------- /Whisper/assets/gpt2/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /Whisper/assets/multilingual/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /Whisper/assets/gpt2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | deepspeed==0.9.2 2 | ffmpeg==1.4 3 | more_itertools==9.1.0 4 | numpy==1.23.0 5 | pandas==1.4.2 6 | pytorch_lightning==2.0.2 7 | PyYAML==6.0 8 | PyYAML==6.0 9 | regex==2022.10.31 10 | 
sacrebleu==2.3.1 11 | torch==1.13.1 12 | torchaudio==0.13.1 13 | torchmetrics==0.11.4 14 | tqdm==4.64.0 15 | transformers==4.29.1 16 | -------------------------------------------------------------------------------- /criterion/metric_util.py: -------------------------------------------------------------------------------- 1 | from sacrebleu.metrics.bleu import _get_tokenizer 2 | 3 | 4 | def get_segment_tokenizers(): 5 | return { 6 | "zh": _get_tokenizer("zh")(), 7 | "ja": _get_tokenizer("ja-mecab")(), 8 | "default": _get_tokenizer("13a")() 9 | } 10 | 11 | 12 | def preprocess_sentence(s_list, lang, tokenizers): 13 | for i in range(len(s_list)): 14 | tokenizer = tokenizers.get(lang[i], tokenizers["default"]) 15 | s_list[i] = "".join(tokenizer(s_list[i].rstrip())) 16 | -------------------------------------------------------------------------------- /Whisper/assets/multilingual/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | def get_module(name): 2 | if name == "mbart": 3 | from modules.mbart import MbartModelModule 4 | return MbartModelModule 5 | elif name == "whisper": 6 | from modules.whisper import WhisperModelModule 7 | return WhisperModelModule 8 | elif name == "whisper_asr": 9 | from modules.whisper_asr import WhisperAsrModelModule 10 | return WhisperAsrModelModule 11 | elif name == "cascade": 12 | from modules.cascade import CascadeModelModule 13 | return CascadeModelModule 14 | elif name == "comst": 15 | from modules.comsl import ComSTModule 16 | return ComSTModule 17 | else: 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /config/exp_spec/cascade.yaml: -------------------------------------------------------------------------------- 1 | module_name: cascade 2 | train_name: cascade 3 | train_id: default 4 | whisper_name: large 5 | monitor: None 6 | num_nodes: 2 7 | test_batch_size: 5 8 | num_worker: 4 9 | num_train_epochs: 10 10 | gradient_accumulation_steps: 2 11 | chunk_size: 30 12 | 13 | data_root: null 14 | cv_data_root: null 15 | output_dir: null 16 | 17 | asr_model_path: whisper_large_asr.pt 18 | mbart_model_path: mbart_x_en.pt 19 | 20 | avail_lang: 21 | - french 22 | - german 23 | - spanish 24 | - italian 25 | - russian 26 | - chinese 27 | - portuguese 28 | - persian 29 | - estonian 30 | - mongolian 31 | - dutch 32 | - turkish 33 | - arabic 34 | - swedish 35 | - latvian 36 | - tamil 37 | - japanese 38 | - indonesian 39 | - slovenian 40 | - welsh 41 | - catalan 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /config/exp_spec/mbart.yaml: -------------------------------------------------------------------------------- 1 | 
module_name: mbart 2 | train_name: mbart 3 | train_id: default 4 | monitor: val/bleu 5 | num_nodes: 1 6 | data_root: null 7 | cv_data_root: null 8 | output_dir: null 9 | ckpt_name: "checkpoint-{epoch:02d}-{val/bleu:.2f}" 10 | use_deepspeed: false 11 | learning_rate: 2.0e-05 12 | weight_decay: 0.0 13 | adam_epsilon: 1.0e-6 14 | adam_betas: [0.9, 0.98] 15 | warmup_steps: 2500 16 | batch_size: 20 17 | test_batch_size: 20 18 | num_worker: 4 19 | num_train_epochs: 5 20 | gradient_accumulation_steps: 2 21 | attention_dropout: 0.1 22 | dropout: 0.3 23 | label_smoothing: 0.0 24 | language_init_model_path: null 25 | 26 | avail_lang: 27 | - french 28 | - german 29 | - spanish 30 | - italian 31 | - russian 32 | - chinese 33 | - portuguese 34 | - persian 35 | - estonian 36 | - mongolian 37 | - dutch 38 | - turkish 39 | - arabic 40 | - swedish 41 | - latvian 42 | - tamil 43 | - japanese 44 | - indonesian 45 | - slovenian 46 | - welsh 47 | - catalan 48 | 49 | 50 | -------------------------------------------------------------------------------- /config/exp_spec/whisper_asr.yaml: -------------------------------------------------------------------------------- 1 | module_name: whisper_asr 2 | train_name: whisper_large_asr 3 | train_id: second 4 | whisper_name: large 5 | monitor: valid_wer_epoch 6 | data_root: null 7 | cv_data_root: null 8 | output_dir: null 9 | ckpt_name: "checkpoint-{epoch:02d}-{valid_wer_epoch:.2f}" 10 | use_deepspeed: true 11 | use_acti_ckpt: true 12 | chunk_size: 30 13 | num_nodes: 4 14 | 15 | learning_rate: 5.0e-8 16 | lr_end: 1.0e-8 17 | lr_pow: 2.0 18 | weight_decay: 0.1 19 | adam_epsilon: 1.0e-6 20 | adam_betas: [0.9, 0.98] 21 | warmup_steps: 5000 22 | batch_size: 3 23 | test_batch_size: 5 24 | num_worker: 4 25 | num_train_epochs: 10 26 | gradient_accumulation_steps: 2 27 | 28 | avail_lang: 29 | - french 30 | - german 31 | - spanish 32 | - italian 33 | - russian 34 | - chinese 35 | - portuguese 36 | - persian 37 | - estonian 38 | - mongolian 39 | - dutch 40 | - turkish 41 | - arabic 42 | - swedish 43 | - latvian 44 | - tamil 45 | - japanese 46 | - indonesian 47 | - slovenian 48 | - welsh 49 | - catalan -------------------------------------------------------------------------------- /config/exp_spec/whisper.yaml: -------------------------------------------------------------------------------- 1 | module_name: whisper 2 | train_name: whisper_large 3 | train_id: initial 4 | whisper_name: large 5 | monitor: valid_bleu_epoch 6 | data_root: null 7 | cv_data_root: null 8 | output_dir: null 9 | ckpt_name: "checkpoint-{epoch:02d}-{valid_bleu_epoch:.2f}" 10 | use_deepspeed: true 11 | use_acti_ckpt: true 12 | chunk_size: 30 13 | num_nodes: 4 14 | loss_scale: 0 15 | 16 | learning_rate: 1.0e-7 17 | lr_end: 1.0e-8 18 | lr_pow: 2.0 19 | weight_decay: 0.1 20 | adam_epsilon: 1.0e-6 21 | adam_betas: [0.9, 0.98] 22 | warmup_steps: 5000 23 | batch_size: 3 24 | test_batch_size: 10 25 | num_worker: 4 26 | num_train_epochs: 10 27 | gradient_accumulation_steps: 2 28 | 29 | 30 | avail_lang: 31 | - french 32 | - german 33 | - spanish 34 | - italian 35 | - russian 36 | - chinese 37 | - portuguese 38 | - persian 39 | - estonian 40 | - mongolian 41 | - dutch 42 | - turkish 43 | - arabic 44 | - swedish 45 | - latvian 46 | - tamil 47 | - japanese 48 | - indonesian 49 | - slovenian 50 | - welsh 51 | - catalan 52 | 53 | -------------------------------------------------------------------------------- /data/lang_dict.csv: -------------------------------------------------------------------------------- 1 | full 
name,covost,"whisper","mbart" 2 | english,en,"en","en_XX" 3 | spanish,es,"es","es_XX" 4 | german,de,"de","de_DE" 5 | turkish,tr,"tr","tr_TR" 6 | persian,fa,"fa","fa_IR" 7 | swedish,sv-SE,"sv","sv_SE" 8 | mongolian,mn,"mn","mn_MN" 9 | chinese,zh-CN,"zh","zh_CN" 10 | welsh,cy,"cy","cy_GB" 11 | catalan,ca,"ca","ca_ES" 12 | slovenian,sl,"sl","sl_SI" 13 | estonian,et,"et","et_EE" 14 | indonesian,id,"id","id_ID" 15 | arabic,ar,"ar","ar_AR" 16 | tamil,ta,"ta","ta_IN" 17 | latvian,lv,"lv","lv_LV" 18 | japanese,ja,"ja","ja_XX" 19 | italian,it,"it","it_IT" 20 | russian,ru,"ru","ru_RU" 21 | portuguese,pt,"pt","pt_XX" 22 | dutch,nl,"nl","nl_XX" 23 | french,fr,"fr","fr_XX" 24 | korean,,"ko","ko_KR" 25 | polish,,"pl","pl_PL" 26 | hindi,,"hi","hi_IN" 27 | finnish,,"fi","fi_FI" 28 | gujarati,,"gu","gu_IN" 29 | kazakh,,"kk","kk_KZ" 30 | lithuanian,,"lt","lt_LT" 31 | myanmar,,"my","my_MM" 32 | nepali,,"ne","ne_NP" 33 | romanian,,"ro","ro_RO" 34 | sinhala,,"si","si_LK" 35 | vietnamese,,"vi","vi_VN" 36 | galician,,"gl","gl_ES" 37 | xhosa,,,"xh_ZA" 38 | ukrainian,,"uk","uk_UA" 39 | urdu,,"ur","ur_PK" 40 | tagalog,,"tl","tl_XX" 41 | thai,,"th","th_TH" 42 | telugu,,"te","te_IN" 43 | swahili,,"sw","sw_KE" 44 | pashto,,"ps","ps_AF" 45 | marathi,,"mr","mr_IN" 46 | malayalam,,"ml","ml_IN" 47 | macedonian,,"mk","mk_MK" 48 | khmer,,"km","km_KH" 49 | georgian,,"ka","ka_GE" 50 | croatian,,"hr","hr_HR" 51 | hebrew,,"iw","he_IL" 52 | bengali,,"bn","bn_IN" 53 | azerbaijani,,"az","az_AZ" 54 | afrikaans,,"af","af_ZA" 55 | burmese,,"my","my_MM" 56 | czech,,"cs","cs_CZ" -------------------------------------------------------------------------------- /model/optimizer.py: -------------------------------------------------------------------------------- 1 | from transformers import get_polynomial_decay_schedule_with_warmup 2 | import torch 3 | try: 4 | from deepspeed.ops.adam import DeepSpeedCPUAdam 5 | deepspeed_available = True 6 | except ImportError: 7 | deepspeed_available = False 8 | 9 | 10 | def configure_optimizer_schedular(cfg, params_generator, num_training_steps, warmup_steps=None): 11 | no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "attn_ln.weight", "mlp_ln"] 12 | warmup_steps = cfg.warmup_steps if warmup_steps is None else warmup_steps 13 | optimizer_grouped_parameters = [ 14 | { 15 | "params": [p for n, p in params_generator() 16 | if not any(nd in n for nd in no_decay)], 17 | "weight_decay": cfg.weight_decay, 18 | }, 19 | { 20 | "params": [p for n, p in params_generator() 21 | if any(nd in n for nd in no_decay)], 22 | "weight_decay": 0.0, 23 | }, 24 | ] 25 | if cfg.use_deepspeed and deepspeed_available: 26 | optimizer = DeepSpeedCPUAdam(optimizer_grouped_parameters, 27 | lr=cfg.learning_rate, 28 | eps=cfg.adam_epsilon, 29 | betas=cfg.adam_betas) 30 | else: 31 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, 32 | lr=cfg.learning_rate, 33 | eps=cfg.adam_epsilon, 34 | betas=cfg.adam_betas) 35 | 36 | scheduler = get_polynomial_decay_schedule_with_warmup( 37 | optimizer, 38 | num_warmup_steps=warmup_steps, 39 | num_training_steps=num_training_steps, 40 | power=cfg.lr_pow, 41 | lr_end=cfg.lr_end, 42 | ) 43 | 44 | return optimizer, scheduler 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This is the official repository of [ComSL: A Composite Speech-Language Model for End-to-End Speech-to-Text Translation](https://arxiv.org/abs/2305.14838), 
which includes the code for finetuning Whisper and mBART, creating the pseudo dataset, and finetuning the ComSL model. 4 | 5 | ## Preparation 6 | 7 | To run the code, first install the requirements: 8 | 9 | ```bash 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | Then, download the CoVoST2 dataset following the instructions on the [CoVoST2 GitHub page](https://github.com/facebookresearch/covost). 14 | 15 | ## Training 16 | 17 | To launch training, change `data_root` in the config files under config/exp_spec to the root of the CoVoST2 dataset. After that, use the following command to start training: 18 | 19 | ```bash 20 | python3 run.py -c XXX.yaml 21 | ``` 22 | 23 | where XXX.yaml is a configuration file in config/exp_spec. 24 | 25 | ## Training with Pseudo Data 26 | 27 | In order to train with pseudo data, you should first download and extract the Common Voice dataset from the [Common Voice website](https://commonvoice.mozilla.org/en/datasets). Then, modify the data path and pretrained model path in create_pseudo_data.py and run the script. 28 | 29 | After that, set `cv_data_root` in config/exp_spec/comsl.yaml to the path of the Common Voice dataset and uncomment the desired languages in `avail_lang_extra`. Finally, run the training script as above. 30 | 31 | ```bash 32 | python3 run.py -c comsl.yaml 33 | ``` 34 | 35 | ## Citation 36 | 37 | ```bibtex 38 | @misc{le2023comsl, 39 | title={ComSL: A Composite Speech-Language Model for End-to-End Speech-to-Text Translation}, 40 | author={Chenyang Le and Yao Qian and Long Zhou and Shujie Liu and Yanmin Qian and Michael Zeng and Xuedong Huang}, 41 | year={2023}, 42 | eprint={2305.14838}, 43 | archivePrefix={arXiv}, 44 | primaryClass={cs.CL} 45 | } 46 | ``` 47 | -------------------------------------------------------------------------------- /config/exp_spec/comsl.yaml: -------------------------------------------------------------------------------- 1 | module_name: comst 2 | train_name: ComST 3 | train_id: default 4 | 5 | use_deepspeed: true 6 | use_acti_ckpt: false 7 | num_nodes: 4 8 | monitor: val_bleu_spch_epoch 9 | ckpt_name: "checkpoint-{epoch:02d}-{val_bleu_spch_epoch:.2f}" 10 | data_root: null 11 | cv_data_root: null 12 | output_dir: null 13 | 14 | warmup_steps: 5000 15 | batch_size: 1 16 | test_batch_size: 1 17 | num_worker: 3 18 | num_train_epochs: 15 19 | gradient_accumulation_steps: 1 20 | chunk_size: 11 21 | 22 | 23 | # whisper model 24 | whisper_name: medium 25 | spch_n_layers: 24 26 | disable_spch_grad_epoch: 5 27 | 28 | learning_rate: 2.0e-5 29 | lr_end: 1.0e-7 30 | lr_pow: 2.0 31 | weight_decay: 0.1 32 | adam_epsilon: 1.0e-6 33 | dropout: 0.1 34 | attention_dropout: 0.0 35 | adam_betas: [0.9, 0.98] 36 | enc_grad_mult: 2.0 37 | guide_alpha: 0.8 38 | text_alpha: 0.2 39 | 40 | spch_loss_weight: 0.35 41 | asr_loss_weight: 0.35 42 | text_loss_weight: 0.2 43 | use_cml: true 44 | cml_loss_weight: 0.1 45 | use_erm: true 46 | erm_loss_weight: 0.2 47 | 48 | # model path relative to cache_dir 49 | language_regularization_model_path: null 50 | language_init_model_path: null 51 | spch_init_model_path: null 52 | 53 | avail_lang: 54 | - french 55 | - german 56 | - spanish 57 | - italian 58 | - russian 59 | - chinese 60 | - portuguese 61 | - persian 62 | - estonian 63 | - mongolian 64 | - dutch 65 | - turkish 66 | - arabic 67 | - swedish 68 | - latvian 69 | - tamil 70 | - japanese 71 | - indonesian 72 | - slovenian 73 | - welsh 74 | - catalan 75 | 76 | # avail_lang_extra: 77 | # - french 78 | # - german 79 | # - spanish 80 | # - italian 81 | # - russian 82 | # - 
chinese 83 | # - portuguese 84 | # - persian 85 | # - estonian 86 | # - mongolian 87 | # - dutch 88 | # - turkish 89 | # - arabic 90 | # - swedish 91 | # - latvian 92 | # - tamil 93 | # - japanese 94 | # - indonesian 95 | # - slovenian 96 | # - welsh 97 | # - catalan 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /model/mBART_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from model.model_util import load_mbart_model 3 | 4 | 5 | class MbartDecoder(nn.Module): 6 | def __init__(self, cfg, mbart_model) -> None: 7 | super().__init__() 8 | self.cfg = cfg 9 | self.decoder = mbart_model.base_model.decoder 10 | self.lm_head = mbart_model.lm_head 11 | 12 | def forward(self, dec_input_ids, encoder_hidden_states, past_key_values=None, use_cache=False): 13 | dec_output = self.decoder(dec_input_ids, 14 | encoder_hidden_states=encoder_hidden_states, 15 | past_key_values=past_key_values, 16 | use_cache=use_cache) 17 | last_hidden_state = dec_output.last_hidden_state 18 | if use_cache: 19 | attn_key_values = dec_output.past_key_values 20 | lm_logits = self.lm_head(last_hidden_state) 21 | if use_cache: 22 | return [lm_logits, {'attn_key_values': attn_key_values}] 23 | else: 24 | return [lm_logits, {}] 25 | 26 | 27 | class MbartEncoder(nn.Module): 28 | def __init__(self, cfg, mbart_model) -> None: 29 | super().__init__() 30 | self.cfg = cfg 31 | 32 | self.encoder = mbart_model.base_model.encoder 33 | 34 | def forward(self, input_ids, attention_mask=None, output_attentions=False): 35 | return self.encoder(input_ids, attention_mask=attention_mask, 36 | output_attentions=output_attentions).last_hidden_state 37 | 38 | 39 | class MbartModel(nn.Module): 40 | def __init__(self, cfg) -> None: 41 | super().__init__() 42 | mbart_model = load_mbart_model(cfg) 43 | self.encoder = MbartEncoder(cfg, mbart_model) 44 | self.decoder = MbartDecoder(cfg, mbart_model) 45 | 46 | def forward(self, enc_input_ids, dec_input_ids, attention_mask=None): 47 | encoder_hidden_states = self.encoder(enc_input_ids, attention_mask=attention_mask) 48 | lm_logits = self.decoder(dec_input_ids, encoder_hidden_states)[0] 49 | return lm_logits -------------------------------------------------------------------------------- /Whisper/normalizers/basic.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unicodedata 3 | 4 | import regex 5 | 6 | # non-ASCII letters that are not separated by "NFKD" normalization 7 | ADDITIONAL_DIACRITICS = { 8 | "œ": "oe", 9 | "Œ": "OE", 10 | "ø": "o", 11 | "Ø": "O", 12 | "æ": "ae", 13 | "Æ": "AE", 14 | "ß": "ss", 15 | "ẞ": "SS", 16 | "đ": "d", 17 | "Đ": "D", 18 | "ð": "d", 19 | "Ð": "D", 20 | "þ": "th", 21 | "Þ": "th", 22 | "ł": "l", 23 | "Ł": "L", 24 | } 25 | 26 | 27 | def remove_symbols_and_diacritics(s: str, keep=""): 28 | """ 29 | Replace any other markers, symbols, and punctuations with a space, 30 | and drop any diacritics (category 'Mn' and some manual mappings) 31 | """ 32 | return "".join( 33 | c 34 | if c in keep 35 | else ADDITIONAL_DIACRITICS[c] 36 | if c in ADDITIONAL_DIACRITICS 37 | else "" 38 | if unicodedata.category(c) == "Mn" 39 | else " " 40 | if unicodedata.category(c)[0] in "MSP" 41 | else c 42 | for c in unicodedata.normalize("NFKD", s) 43 | ) 44 | 45 | 46 | def remove_symbols(s: str): 47 | """ 48 | Replace any other markers, symbols, punctuations with a space, keeping 
diacritics 49 | """ 50 | return "".join( 51 | " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s) 52 | ) 53 | 54 | 55 | class BasicTextNormalizer: 56 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): 57 | self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols 58 | self.split_letters = split_letters 59 | 60 | def __call__(self, s: str): 61 | s = s.lower() 62 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 63 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 64 | s = self.clean(s).lower() 65 | 66 | if self.split_letters: 67 | s = " ".join(regex.findall(r"\X", s, regex.U)) 68 | 69 | s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space 70 | 71 | return s 72 | -------------------------------------------------------------------------------- /Whisper/utils.py: -------------------------------------------------------------------------------- 1 | import zlib 2 | from typing import Iterator, TextIO 3 | 4 | 5 | def exact_div(x, y): 6 | assert x % y == 0 7 | return x // y 8 | 9 | 10 | def str2bool(string): 11 | str2val = {"True": True, "False": False} 12 | if string in str2val: 13 | return str2val[string] 14 | else: 15 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") 16 | 17 | 18 | def optional_int(string): 19 | return None if string == "None" else int(string) 20 | 21 | 22 | def optional_float(string): 23 | return None if string == "None" else float(string) 24 | 25 | 26 | def compression_ratio(text) -> float: 27 | return len(text) / len(zlib.compress(text.encode("utf-8"))) 28 | 29 | 30 | def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'): 31 | assert seconds >= 0, "non-negative timestamp expected" 32 | milliseconds = round(seconds * 1000.0) 33 | 34 | hours = milliseconds // 3_600_000 35 | milliseconds -= hours * 3_600_000 36 | 37 | minutes = milliseconds // 60_000 38 | milliseconds -= minutes * 60_000 39 | 40 | seconds = milliseconds // 1_000 41 | milliseconds -= seconds * 1_000 42 | 43 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 44 | return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 45 | 46 | 47 | def write_txt(transcript: Iterator[dict], file: TextIO): 48 | for segment in transcript: 49 | print(segment['text'].strip(), file=file, flush=True) 50 | 51 | 52 | def write_vtt(transcript: Iterator[dict], file: TextIO): 53 | print("WEBVTT\n", file=file) 54 | for segment in transcript: 55 | print( 56 | f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" 57 | f"{segment['text'].strip().replace('-->', '->')}\n", 58 | file=file, 59 | flush=True, 60 | ) 61 | 62 | 63 | def write_srt(transcript: Iterator[dict], file: TextIO): 64 | """ 65 | Write a transcript to a file in SRT format. 
66 | 67 | Example usage: 68 | from pathlib import Path 69 | from whisper.misc import write_srt 70 | 71 | result = transcribe(model, audio_path, temperature=temperature, **args) 72 | 73 | # save SRT 74 | audio_basename = Path(audio_path).stem 75 | with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt: 76 | write_srt(result["segments"], file=srt) 77 | """ 78 | for i, segment in enumerate(transcript, start=1): 79 | # write srt lines 80 | print( 81 | f"{i}\n" 82 | f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " 83 | f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" 84 | f"{segment['text'].strip().replace('-->', '->')}\n", 85 | file=file, 86 | flush=True, 87 | ) 88 | -------------------------------------------------------------------------------- /data/data_util.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torchaudio 3 | import torchaudio.transforms as at 4 | import os 5 | import torch 6 | import Whisper 7 | 8 | LANG_DICT = pd.read_csv('data/lang_dict.csv').set_index("full name").to_dict("index") 9 | 10 | 11 | def read_table(path): 12 | return pd.read_table(path, on_bad_lines='error', quoting=3, doublequote=False, encoding='utf-8', engine="python") 13 | 14 | 15 | def load_data_record(data_root, split, language_list, subsample_rate=1, 16 | expanded_data_root=None, expanded_language_list=None): 17 | data_pair_lists = [] 18 | for lang in language_list: 19 | data_lang_code = LANG_DICT[lang]['covost'] 20 | data_pair = read_table(os.path.join(data_root, 'covost', 21 | f"{data_lang_code}_en", f"covost_v2.{data_lang_code}_en.{split}.tsv")) 22 | data_pair['src_lang'] = lang 23 | data_pair['tgt_lang'] = 'english' 24 | data_pair['audio_root'] = os.path.join(data_root, 'extracted', data_lang_code, 'clips') 25 | data_pair = data_pair.dropna() 26 | data_pair_lists.append(data_pair) 27 | print(f"Loaded {len(data_pair)} {lang} to english data pairs.") 28 | if split == 'train' and expanded_language_list is not None and lang in expanded_language_list: 29 | expanded_data_pair = read_table(os.path.join(expanded_data_root, 'psudo', f"{data_lang_code}_en.train.tsv")) 30 | test_data_pair = read_table(os.path.join(data_root, 'covost', f"{data_lang_code}_en", 31 | f"covost_v2.{data_lang_code}_en.test.tsv")) 32 | dev_data_pair = read_table(os.path.join(data_root, 'covost', f"{data_lang_code}_en", 33 | f"covost_v2.{data_lang_code}_en.dev.tsv")) 34 | expanded_data_pair = pd.concat( 35 | [expanded_data_pair, data_pair, data_pair, test_data_pair, test_data_pair, dev_data_pair, 36 | dev_data_pair], ignore_index=True).drop_duplicates(subset=['path'], keep=False) 37 | expanded_data_pair['src_lang'] = lang 38 | expanded_data_pair['tgt_lang'] = 'english' 39 | expanded_data_pair['audio_root'] = os.path.join(expanded_data_root, data_lang_code, 'clips') 40 | data_pair_lists.append(expanded_data_pair) 41 | print(f"Loaded {len(expanded_data_pair)} {lang} to english extra data pairs.") 42 | joined_data_pair_lists = pd.concat(data_pair_lists, ignore_index=True).to_dict("records")[::subsample_rate] 43 | data_pair_lists = [data_pair_list.to_dict("records")[::subsample_rate] for data_pair_list in data_pair_lists] 44 | return joined_data_pair_lists, data_pair_lists 45 | 46 | 47 | def load_wave(wave_path, sample_rate: int = 16000): 48 | waveform, sr = torchaudio.load(wave_path, normalize=True) 49 | 50 | duration = waveform.shape[1] / sr 51 | if sample_rate 
!= sr: 52 | waveform = at.Resample(sr, sample_rate)(waveform) 53 | return waveform, duration 54 | 55 | 56 | def pad_trim_audio(audio, cfg): 57 | max_lens = cfg.chunk_size * cfg.sample_rate 58 | audio_input_feature = [Whisper.log_mel_spectrogram(Whisper.pad_or_trim(a, max_lens)) for a in audio] 59 | audio_input_feature = torch.concat([a[None, :] for a in audio_input_feature]) 60 | return audio_input_feature 61 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | 5 | from pytorch_lightning import Trainer, seed_everything 6 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 7 | from pytorch_lightning.loggers import TensorBoardLogger 8 | from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy 9 | 10 | from config.parse_yaml_args import parse_args_and_yaml 11 | from data.data_util import load_data_record 12 | from modules import get_module 13 | 14 | os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO" 15 | 16 | cfg = parse_args_and_yaml() 17 | module_name = cfg.module_name 18 | 19 | Module = get_module(module_name) 20 | seed_everything(42) 21 | 22 | if __name__ == "__main__": 23 | 24 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 25 | 26 | joined_data_pair_lists, sep_data_pair_lists = {}, {} 27 | for split in ["train", "dev", "test"]: 28 | subsample_rate = cfg.valid_sample_rate if split == "dev" else 1 29 | language_list = cfg.language_list 30 | expanded_language_list = cfg.extra_language_list 31 | 32 | joined_data_pair_lists[split], \ 33 | sep_data_pair_lists[split] = load_data_record(cfg.data_root, 34 | split, 35 | subsample_rate=subsample_rate, 36 | language_list=language_list, 37 | expanded_data_root=cfg.cv_data_root, 38 | expanded_language_list=expanded_language_list, ) 39 | 40 | Path(cfg.log_output_dir).mkdir(exist_ok=True) 41 | Path(cfg.check_output_dir).mkdir(exist_ok=True) 42 | Path(cfg.cache_dir).mkdir(exist_ok=True) 43 | 44 | tflogger = TensorBoardLogger( 45 | save_dir=cfg.log_output_dir, 46 | name=cfg.train_name, 47 | version=cfg.train_id 48 | ) 49 | ckpt_dir = f"{cfg.check_output_dir}/checkpoint_{cfg.train_name}_{cfg.train_id}" 50 | 51 | monitor = cfg.monitor 52 | 53 | if "bleu" in monitor: 54 | mode = "max" 55 | elif "wer" in monitor or "loss" in monitor: 56 | mode = "min" 57 | else: 58 | raise NotImplementedError 59 | 60 | checkpoint_callback = ModelCheckpoint( 61 | monitor=monitor, 62 | mode=mode, 63 | dirpath=ckpt_dir, 64 | filename=cfg.ckpt_name, 65 | auto_insert_metric_name=False, 66 | save_top_k=5, # all model save 67 | save_last=True, 68 | ) 69 | 70 | callback_list = [checkpoint_callback, LearningRateMonitor(logging_interval="step")] 71 | 72 | model = Module(cfg, joined_data_pair_lists, sep_data_pair_lists) 73 | 74 | if cfg.use_deepspeed: 75 | strategy = DeepSpeedStrategy( 76 | stage=2, 77 | logging_level=logging.WARN, 78 | offload_optimizer=True, 79 | loss_scale=cfg.ds_loss_scale, 80 | ) 81 | else: 82 | strategy = DDPStrategy(find_unused_parameters=False) 83 | 84 | trainer = Trainer( 85 | precision=16, 86 | num_nodes=cfg.num_nodes, 87 | accelerator="gpu", 88 | max_epochs=cfg.num_train_epochs, 89 | accumulate_grad_batches=cfg.gradient_accumulation_steps, 90 | logger=tflogger, 91 | callbacks=callback_list, 92 | strategy=strategy, 93 | ) 94 | 95 | if cfg.test_ckpt_name is not None: 96 | trainer.test(model, 
ckpt_path=f"{ckpt_dir}/{cfg.test_ckpt_name}") 97 | exit() 98 | else: 99 | if cfg.mode == "test": 100 | trainer.test(model, ckpt_path="last") 101 | else: 102 | if cfg.mode == "resume": 103 | trainer.fit(model, ckpt_path="last") 104 | else: 105 | trainer.fit(model, ckpt_path=None) 106 | trainer.test(model, ckpt_path="best") 107 | -------------------------------------------------------------------------------- /create_pseudo_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import pandas as pd 4 | 5 | from tqdm import tqdm 6 | from torch.utils.data import DataLoader 7 | 8 | from data.data_util import LANG_DICT 9 | from config.parse_yaml_args import parse_args_and_yaml 10 | from model.model_util import load_mbart_model, load_mbart_tokenizer 11 | 12 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 13 | 14 | 15 | class MbartDataset(torch.utils.data.Dataset): 16 | def __init__(self, audio_info_list) -> None: 17 | super().__init__() 18 | 19 | self.audio_info_list = audio_info_list 20 | 21 | def __len__(self): 22 | return len(self.audio_info_list) 23 | 24 | def __getitem__(self, id): 25 | record = self.audio_info_list[id] 26 | 27 | sentence = record["sentence"] 28 | 29 | return { 30 | "sentence": sentence, 31 | "index": id, 32 | } 33 | 34 | 35 | class MbartCollatorWhithPadding: 36 | def __init__( 37 | self, 38 | tokenizer, 39 | src_lang, 40 | ) -> None: 41 | super().__init__() 42 | self.tokenizer = tokenizer 43 | 44 | self.tokenizer.src_lang = LANG_DICT[src_lang]["mbart"] 45 | 46 | def __call__(self, features): 47 | ( 48 | sentences, 49 | indexs, 50 | ) = ( 51 | [], 52 | [], 53 | ) 54 | for f in features: 55 | sentences.append(f["sentence"]) 56 | indexs.append(f["index"]) 57 | 58 | inputs = self.tokenizer( 59 | sentences, padding=True, truncation=True, return_tensors="pt" 60 | ) 61 | 62 | batch = { 63 | "inputs": inputs, 64 | "indexs": indexs, 65 | } 66 | 67 | return batch 68 | 69 | 70 | cfg = parse_args_and_yaml(config_path="config/exp_spec/mbart.yaml") 71 | 72 | CV_root = "" # TODO: set your Common Voice root path 73 | data_language = "french" # TODO: set your data language, e.g. 
'french', 'chinese' 74 | cfg.mbart_model_path = "" # TODO: set your mBART model path to your pretrained model 75 | output_dir = f"{CV_root}/pseudo" 76 | 77 | 78 | mbart_tokenizer = load_mbart_tokenizer(cfg) 79 | mbart_model = load_mbart_model(cfg) 80 | mbart_model = mbart_model.to("cuda") 81 | 82 | if __name__ == "__main__": 83 | data_code = LANG_DICT[data_language]["covost"] 84 | for split in ["train"]: 85 | if os.path.exists(os.path.join(output_dir, f"{data_code}_en.{split}.tsv")): 86 | continue 87 | table_path = os.path.join(CV_root, data_code, f"{split}.tsv") 88 | data_pair = pd.read_table( 89 | table_path, 90 | on_bad_lines="error", 91 | quoting=3, 92 | doublequote=False, 93 | encoding="utf-8", 94 | engine="python", 95 | ) 96 | output_data_pair = data_pair.copy()[["path", "sentence"]] 97 | output_data_pair["translation"] = None 98 | data_pair_list = data_pair.to_dict("records") 99 | print(f"Loaded {len(data_pair)} {data_language} to english {split} data pairs.") 100 | 101 | dataset = MbartDataset(data_pair_list) 102 | collecter = MbartCollatorWhithPadding(mbart_tokenizer, src_lang=data_language) 103 | dataloader = DataLoader( 104 | dataset, 105 | batch_size=40, 106 | shuffle=False, 107 | collate_fn=collecter, 108 | num_workers=2, 109 | pin_memory=True, 110 | drop_last=False, 111 | ) 112 | for batch in tqdm(dataloader): 113 | inputs = batch["inputs"].to("cuda") 114 | indexs = batch["indexs"] 115 | with torch.autocast(device_type="cuda", dtype=torch.bfloat16): 116 | with torch.no_grad(): 117 | outputs = mbart_model.generate( 118 | **inputs, 119 | decoder_start_token_id=mbart_tokenizer.lang_code_to_id["en_XX"], 120 | ) 121 | translations = mbart_tokenizer.batch_decode( 122 | outputs, skip_special_tokens=True 123 | ) 124 | for i, index in enumerate(indexs): 125 | output_data_pair.loc[index, "translation"] = translations[i] 126 | output_data_pair.to_csv( 127 | os.path.join(output_dir, f"{data_code}_en.{split}.tsv"), 128 | sep="\t", 129 | index=False, 130 | quoting=3, 131 | doublequote=False, 132 | encoding="utf-8", 133 | ) 134 | -------------------------------------------------------------------------------- /Whisper/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from typing import Union 4 | 5 | import ffmpeg 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .utils import exact_div 11 | 12 | # hard-coded audio hyperparameters 13 | SAMPLE_RATE = 16000 14 | N_FFT = 400 15 | N_MELS = 80 16 | HOP_LENGTH = 160 17 | CHUNK_LENGTH = 10 18 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk 19 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input 20 | 21 | 22 | def load_audio(file: str, sr: int = SAMPLE_RATE): 23 | """ 24 | Open an audio file and read as mono waveform, resampling as necessary 25 | 26 | Parameters 27 | ---------- 28 | file: str 29 | The audio file to open 30 | 31 | sr: int 32 | The sample rate to resample the audio if necessary 33 | 34 | Returns 35 | ------- 36 | A NumPy array containing the audio waveform, in float32 dtype. 37 | """ 38 | try: 39 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary. 40 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. 
41 | out, _ = ( 42 | ffmpeg.input(file, threads=0) 43 | .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) 44 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) 45 | ) 46 | except ffmpeg.Error as e: 47 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 48 | 49 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 50 | 51 | 52 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 53 | """ 54 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 55 | """ 56 | if torch.is_tensor(array): 57 | if array.shape[axis] > length: 58 | array = array.index_select(dim=axis, index=torch.arange(length, device=array.device)) 59 | 60 | if array.shape[axis] < length: 61 | pad_widths = [(0, 0)] * array.ndim 62 | pad_widths[axis] = (0, length - array.shape[axis]) 63 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 64 | else: 65 | if array.shape[axis] > length: 66 | array = array.take(indices=range(length), axis=axis) 67 | 68 | if array.shape[axis] < length: 69 | pad_widths = [(0, 0)] * array.ndim 70 | pad_widths[axis] = (0, length - array.shape[axis]) 71 | array = np.pad(array, pad_widths) 72 | 73 | return array 74 | 75 | 76 | @lru_cache(maxsize=None) 77 | def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: 78 | """ 79 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 80 | Allows decoupling librosa dependency; saved using: 81 | 82 | np.savez_compressed( 83 | "mel_filters.npz", 84 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 85 | ) 86 | """ 87 | assert n_mels == 80, f"Unsupported n_mels: {n_mels}" 88 | with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f: 89 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 90 | 91 | 92 | def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS): 93 | """ 94 | Compute the log-Mel spectrogram of 95 | 96 | Parameters 97 | ---------- 98 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 99 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 100 | 101 | n_mels: int 102 | The number of Mel-frequency filters, only 80 is supported 103 | 104 | Returns 105 | ------- 106 | torch.Tensor, shape = (80, n_frames) 107 | A Tensor that contains the Mel spectrogram 108 | """ 109 | if not torch.is_tensor(audio): 110 | if isinstance(audio, str): 111 | audio = load_audio(audio) 112 | audio = torch.from_numpy(audio) 113 | 114 | window = torch.hann_window(N_FFT).to(audio.device) 115 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 116 | magnitudes = stft[:, :-1].abs() ** 2 117 | 118 | filters = mel_filters(audio.device, n_mels) 119 | mel_spec = filters @ magnitudes 120 | 121 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 122 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 123 | log_spec = (log_spec + 4.0) / 4.0 124 | return log_spec 125 | -------------------------------------------------------------------------------- /config/parse_yaml_args.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import os 4 | 5 | default_config_parser = parser = argparse.ArgumentParser( 6 | description='Training Config', add_help=False) 7 | 8 | parser.add_argument( 9 | '-c', 10 | '--config', 11 | default='comsl.yaml', 12 | type=str, 13 | metavar='FILE', 14 | help='YAML config 
file specifying default arguments') 15 | 16 | parser.add_argument( 17 | '--data_root', 18 | type=str, 19 | default=None, 20 | help="The root directory of CoVoST-2 audio clips") 21 | 22 | parser.add_argument( 23 | '--cv_data_root', 24 | type=str, 25 | default=None, 26 | help="The root directory of Common Voice audio clips") 27 | 28 | parser.add_argument( 29 | '--output_dir', 30 | type=str, 31 | default=None, ) 32 | 33 | parser.add_argument( 34 | '--ckpt_name', 35 | type=str, 36 | default="checkpoint-{epoch:02d}-{step}", ) 37 | 38 | parser.add_argument( 39 | '--num_nodes', 40 | type=int, 41 | default=1, ) 42 | 43 | parser.add_argument( 44 | '--language_list', 45 | type=str, 46 | nargs='+', 47 | default=None) 48 | 49 | parser.add_argument( 50 | '--sample_rate', 51 | type=int, 52 | default=16000, ) 53 | 54 | parser.add_argument( 55 | '--valid_sample_rate', 56 | type=int, 57 | default=4, ) 58 | 59 | 60 | parser.add_argument( 61 | '--mode', 62 | type=str, 63 | default="resume", ) 64 | 65 | parser.add_argument( 66 | '--use_acti_ckpt', 67 | action='store_false', ) 68 | 69 | parser.add_argument( 70 | '--use_deepspeed', 71 | action='store_true', 72 | ) 73 | 74 | parser.add_argument( 75 | '--ds_loss_scale', 76 | type=float, 77 | default=1.0, ) 78 | 79 | parser.add_argument( 80 | '--test_ckpt_name', 81 | type=str, 82 | default=None, ) 83 | 84 | parser.add_argument( 85 | '--chunk_size', 86 | type=int, 87 | default=11, ) 88 | 89 | # Optimizer and Scheduler 90 | parser.add_argument( 91 | '--warmup_steps', 92 | type=int, 93 | default=5000, ) 94 | 95 | parser.add_argument( 96 | '--learning_rate', 97 | type=float, 98 | default=1e-5, ) 99 | 100 | parser.add_argument( 101 | '--adam_epsilon', 102 | type=float, 103 | default=1e-6, ) 104 | 105 | parser.add_argument( 106 | '--adam_betas', 107 | type=float, 108 | nargs='+', 109 | default=(0.9, 0.98), ) 110 | 111 | parser.add_argument( 112 | '--weight_decay', 113 | type=float, 114 | default=0.1, ) 115 | 116 | parser.add_argument( 117 | '--lr_pow', 118 | type=float, 119 | default=2.0, ) 120 | 121 | parser.add_argument( 122 | '--lr_end', 123 | type=float, 124 | default=1e-7, ) 125 | 126 | # For ComST 127 | parser.add_argument( 128 | '--extra_language_list', 129 | type=str, 130 | nargs='+', 131 | default=None) 132 | 133 | parser.add_argument( 134 | '--language_regularization_model_path', 135 | type=str, 136 | default=None, 137 | ) 138 | 139 | parser.add_argument( 140 | '--language_init_model_path', 141 | type=str, 142 | default=None, 143 | ) 144 | 145 | parser.add_argument( 146 | '--spch_init_model_path', 147 | type=str, 148 | default=None, 149 | ) 150 | 151 | parser.add_argument( 152 | '--spch_n_layers', 153 | type=int, 154 | default=-1, ) 155 | 156 | parser.add_argument( 157 | '--erm_layer', 158 | type=int, 159 | default=4, ) 160 | 161 | parser.add_argument( 162 | '--p_mask', 163 | type=float, 164 | default=0.15, ) 165 | 166 | parser.add_argument( 167 | '--disable_spch_grad_epoch', 168 | type=int, 169 | default=0, ) 170 | 171 | 172 | def _parse_args_and_yaml(given_parser=None, config_path=None): 173 | if given_parser is None: 174 | given_parser = default_config_parser 175 | given_configs, remaining = given_parser.parse_known_args() 176 | file_name = given_configs.config if "yaml" in given_configs.config else given_configs.config + ".yaml" 177 | config_path = "config/exp_spec/" + file_name if config_path is None else config_path 178 | with open(config_path, 'r', encoding='utf-8') as f: 179 | cfg = yaml.safe_load(f) 180 | given_parser.set_defaults(**cfg) 181 | 182 | 
args = given_parser.parse_args(remaining) 183 | 184 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 185 | return args, args_text 186 | 187 | 188 | def parse_args_and_yaml(arg_parser=None, config_path=None): 189 | cfg = _parse_args_and_yaml(arg_parser, config_path)[0] 190 | 191 | if "OUTPUT_DIR" in os.environ: 192 | cfg.output_dir = os.environ["OUTPUT_DIR"] 193 | setattr(cfg, "log_output_dir", f"{cfg.output_dir}/logs") 194 | setattr(cfg, "check_output_dir", f"{cfg.output_dir}/ckpt") 195 | setattr(cfg, "cache_dir", f"{cfg.output_dir}/cache") 196 | if "DATA_ROOT" in os.environ: 197 | cfg.data_root = os.environ["DATA_ROOT"] 198 | if hasattr(cfg, "cv_data_root") and "CV_DATA_ROOT" in os.environ: 199 | cfg.cv_data_root = os.environ["CV_DATA_ROOT"] 200 | return cfg 201 | 202 | 203 | if __name__ == "__main__": 204 | args, args_text = _parse_args_and_yaml() 205 | print(args_text) 206 | print(args.cache_dir) 207 | -------------------------------------------------------------------------------- /Whisper/__init__.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import io 3 | import os 4 | import urllib 5 | import warnings 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim 12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language 13 | from .model import Whisper, ModelDimensions 14 | from .transcribe import transcribe 15 | 16 | 17 | _MODELS = { 18 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", 19 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", 20 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", 21 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", 22 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", 23 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", 24 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", 25 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", 26 | "large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt", 27 | } 28 | 29 | 30 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: 31 | os.makedirs(root, exist_ok=True) 32 | 33 | expected_sha256 = url.split("/")[-2] 34 | download_target = os.path.join(root, os.path.basename(url)) 35 | 36 | if os.path.exists(download_target) and not os.path.isfile(download_target): 37 | raise RuntimeError(f"{download_target} exists and is not a regular file") 38 | 39 | if os.path.isfile(download_target): 40 | model_bytes = open(download_target, "rb").read() 41 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: 42 | return model_bytes if in_memory else download_target 43 | 
else: 44 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 45 | 46 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 47 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 48 | while True: 49 | buffer = source.read(8192) 50 | if not buffer: 51 | break 52 | 53 | output.write(buffer) 54 | loop.update(len(buffer)) 55 | 56 | model_bytes = open(download_target, "rb").read() 57 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: 58 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.") 59 | 60 | return model_bytes if in_memory else download_target 61 | 62 | 63 | def available_models() -> List[str]: 64 | """Returns the names of available models""" 65 | return list(_MODELS.keys()) 66 | 67 | 68 | def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False, **kwargs) -> Whisper: 69 | """ 70 | Load a Whisper ASR model 71 | 72 | Parameters 73 | ---------- 74 | name : str 75 | one of the official model names listed by `whisper.available_models()`, or 76 | path to a model checkpoint containing the model dimensions and the model state_dict. 77 | device : Union[str, torch.device] 78 | the PyTorch device to put the model into 79 | download_root: str 80 | path to download the model files; by default, it uses "~/.cache/whisper" 81 | in_memory: bool 82 | whether to preload the model weights into host memory 83 | 84 | Returns 85 | ------- 86 | model : Whisper 87 | The Whisper ASR model instance 88 | """ 89 | 90 | if device is None: 91 | device = "cuda" if torch.cuda.is_available() else "cpu" 92 | if download_root is None: 93 | download_root = os.getenv( 94 | "XDG_CACHE_HOME", 95 | os.path.join(os.path.expanduser("~"), ".cache", "whisper") 96 | ) 97 | 98 | if name in _MODELS: 99 | checkpoint_file = _download(_MODELS[name], download_root, in_memory) 100 | elif os.path.isfile(name): 101 | checkpoint_file = open(name, "rb").read() if in_memory else name 102 | else: 103 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 104 | 105 | with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp: 106 | checkpoint = torch.load(fp, map_location=device) 107 | del checkpoint_file 108 | 109 | dims = ModelDimensions(**checkpoint["dims"]) 110 | model = Whisper(dims, **kwargs) 111 | model.load_state_dict(checkpoint["model_state_dict"], strict=True) 112 | 113 | # return model.to(device) 114 | return model 115 | -------------------------------------------------------------------------------- /model/model_util.py: -------------------------------------------------------------------------------- 1 | from torch.nn import LayerNorm 2 | from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, MBartConfig 3 | from Whisper.model import * 4 | 5 | MBART_PRETRAINED_MODEL = "facebook/mbart-large-50-many-to-many-mmt" 6 | 7 | 8 | def load_mbart_tokenizer(cfg, extra_special_tokens=None): 9 | if extra_special_tokens is None: 10 | extra_special_tokens = [] 11 | extra_special_tokens = ['cy_GB', 'ca_ES'] + extra_special_tokens 12 | 13 | tokenizer = MBart50TokenizerFast.from_pretrained( 14 | MBART_PRETRAINED_MODEL, 15 | cache_dir=cfg.cache_dir, 16 | additional_special_tokens=extra_special_tokens 17 | ) 18 | tokenizer.lang_code_to_id["cy_GB"] = 
tokenizer.convert_tokens_to_ids("cy_GB") 19 | tokenizer.lang_code_to_id["ca_ES"] = tokenizer.convert_tokens_to_ids("ca_ES") 20 | return tokenizer 21 | 22 | 23 | def load_mbart_model(cfg, extra_special_tokens=None, load_from_local=True, path=None): 24 | if extra_special_tokens is None: 25 | extra_special_tokens = [] 26 | 27 | configuration = MBartConfig.from_pretrained( 28 | MBART_PRETRAINED_MODEL, 29 | cache_dir=cfg.cache_dir 30 | ) 31 | if hasattr(cfg, "attention_dropout"): 32 | configuration.attention_dropout = cfg.attention_dropout 33 | if hasattr(cfg, "dropout"): 34 | configuration.dropout = cfg.dropout 35 | mbart_model = MBartForConditionalGeneration.from_pretrained( 36 | MBART_PRETRAINED_MODEL, 37 | cache_dir=cfg.cache_dir, 38 | config=configuration 39 | ) 40 | mbart_model.resize_token_embeddings(configuration.vocab_size + 2 + len(extra_special_tokens)) 41 | if path is None: 42 | path = cfg.language_init_model_path 43 | if load_from_local and path is not None: 44 | mbart_model.load_state_dict(torch.load(path)) 45 | print("load mbart model from {}".format(path)) 46 | 47 | return mbart_model 48 | 49 | 50 | def lengths_to_padding_mask(lens): 51 | bsz, max_lens = lens.size(0), torch.max(lens).item() 52 | mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) 53 | mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) 54 | return mask 55 | 56 | 57 | class Conv1dAdaptor(nn.Module): 58 | def __init__( 59 | self, 60 | in_dim, 61 | out_dim, 62 | n_layers=3, 63 | kernel_size=3, 64 | stride=2, 65 | layerdrop=0.0, 66 | layernorm=False, 67 | proj=False, 68 | ): 69 | super().__init__() 70 | self.proj, self.proj_ln = None, None 71 | self.post_proj, self.post_proj_ln = None, None 72 | if proj: 73 | self.proj = nn.Sequential( 74 | nn.Linear(in_dim, in_dim * 4), 75 | nn.ReLU(), 76 | nn.Linear(in_dim * 4, in_dim) 77 | ) 78 | self.proj_ln = LayerNorm(in_dim) 79 | self.post_proj = nn.Sequential( 80 | nn.Linear(out_dim, out_dim * 4), 81 | nn.ReLU(), 82 | nn.Linear(out_dim * 4, out_dim), 83 | ) 84 | self.post_proj_ln = LayerNorm(out_dim) 85 | 86 | self.layers = nn.ModuleList( 87 | nn.Conv1d( 88 | in_dim if i == 0 else out_dim, 89 | out_dim * 2, 90 | kernel_size, 91 | stride=stride, 92 | padding=kernel_size // 2, 93 | ) 94 | for i in range(n_layers) 95 | ) 96 | self.stride = stride 97 | self.layerdrop = layerdrop 98 | self.layernorm = LayerNorm(in_dim) if layernorm else None 99 | 100 | def forward(self, x, padding_mask: Optional[torch.Tensor] = None): 101 | if self.layernorm is not None: 102 | x = self.layernorm(x) 103 | 104 | if self.proj is not None: 105 | x = x + 0.5 * self.proj(x) 106 | x = self.proj_ln(x) 107 | 108 | # B x T x C -> B x C x T 109 | x = x.transpose(1, 2) 110 | out_lens = None 111 | if padding_mask is not None: 112 | out_lens = (~padding_mask).sum(1).float() 113 | 114 | for layer in self.layers: 115 | layerdrop_prob = np.random.random() 116 | if not self.training or (layerdrop_prob > self.layerdrop): 117 | x = nn.functional.glu(layer(x), dim=1) 118 | if padding_mask is not None: 119 | out_lens = ((out_lens - 1) / self.stride + 1).floor() 120 | # B x C x T -> B x T x C 121 | x = x.transpose(1, 2) 122 | 123 | if self.post_proj is not None: 124 | x = x + 0.5 * self.post_proj(x) 125 | x = self.post_proj_ln(x) 126 | 127 | out_padding_mask = None 128 | if padding_mask is not None: 129 | out_padding_mask = lengths_to_padding_mask(out_lens.long()) 130 | return x, out_padding_mask 131 | 132 | 133 | class GradMultiply(torch.autograd.Function): 134 | @staticmethod 135 | def 
forward(ctx, x, scale): 136 | ctx.scale = scale 137 | res = x.new(x) 138 | return res 139 | 140 | @staticmethod 141 | def backward(ctx, grad): 142 | return grad * ctx.scale, None 143 | 144 | 145 | def deep_to_device(obj, device): 146 | if isinstance(obj, torch.Tensor): 147 | return obj.to(device) 148 | elif isinstance(obj, dict): 149 | return {k: deep_to_device(v, device) for k, v in obj.items()} 150 | elif isinstance(obj, list): 151 | return [deep_to_device(v, device) for v in obj] 152 | else: 153 | return obj 154 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from data.data_util import load_wave, LANG_DICT 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class CoVoSTDataset(Dataset): 8 | def __init__(self, audio_info_list, tokenizer, sample_rate=16000) -> None: 9 | super().__init__() 10 | 11 | self.audio_info_list = audio_info_list 12 | self.tokenizer = tokenizer 13 | self.sample_rate = sample_rate 14 | self.max_text_length = 128 15 | 16 | def __len__(self): 17 | return len(self.audio_info_list) 18 | 19 | def __getitem__(self, id): 20 | record = self.audio_info_list[id] 21 | 22 | audio_path = os.path.join(record['audio_root'], record["path"]) 23 | 24 | src_lang = record["src_lang"] 25 | tgt_lang = record["tgt_lang"] 26 | 27 | transcription = record["sentence"] 28 | translation = record["translation"] 29 | 30 | # # text 31 | tokenize_res = self.tokenize(src_lang, tgt_lang, transcription, translation) 32 | 33 | # audio 34 | audio, duration = load_wave(audio_path, sample_rate=self.sample_rate) 35 | audio = audio.flatten() 36 | 37 | return { 38 | 'index': id, 39 | "transcription_ids": tokenize_res[0], 40 | "translation_ids": tokenize_res[1], 41 | "transcription_labels": tokenize_res[2], 42 | "translation_labels": tokenize_res[3], 43 | "audio": audio, 44 | "duration": duration, 45 | "src_lang": src_lang, 46 | "tgt_lang": tgt_lang, 47 | "audio_path": audio_path, 48 | } 49 | 50 | def tokenize(self, src_lang, tgt_lang, transcription, translation): 51 | raise NotImplementedError 52 | 53 | 54 | class ComSTDataset(torch.utils.data.Dataset): 55 | def __init__(self, audio_info_list, tokenizer, sample_rate=16000, cfg=None) -> None: 56 | super().__init__() 57 | 58 | self.audio_info_list = audio_info_list 59 | self.tokenizer = tokenizer 60 | self.sample_rate = sample_rate 61 | self.max_text_length = 128 62 | self.cfg = cfg 63 | 64 | def __len__(self): 65 | return len(self.audio_info_list) 66 | 67 | def __getitem__(self, id): 68 | record = self.audio_info_list[id] 69 | 70 | audio_path = os.path.join(record['audio_root'], record["path"]) 71 | 72 | src_lang = record["src_lang"] 73 | tgt_lang = record["tgt_lang"] 74 | 75 | transcription = record["sentence"] 76 | translation = record["translation"] 77 | 78 | # text 79 | tokenize_res = self.tokenize(src_lang, tgt_lang, transcription, translation) 80 | 81 | # audio 82 | audio, duration = load_wave(audio_path, sample_rate=self.sample_rate) 83 | audio = audio.flatten() 84 | 85 | return { 86 | 'index': id, 87 | "transcription_ids": tokenize_res[0], 88 | "translation_ids": tokenize_res[1], 89 | "transcription_labels": tokenize_res[2], 90 | "translation_labels": tokenize_res[3], 91 | "audio": audio, 92 | "duration": duration, 93 | "src_lang": src_lang, 94 | "tgt_lang": tgt_lang, 95 | "audio_path": audio_path, 96 | } 97 | 98 | def tokenize(self, src_lang, tgt_lang, transcription, translation): 
99 | # text 100 | self.tokenizer.src_lang = LANG_DICT[src_lang]['mbart'] 101 | self.tokenizer.tgt_lang = LANG_DICT[tgt_lang]['mbart'] 102 | 103 | tokenize_res = self.tokenizer(transcription, text_target=translation, max_length=self.max_text_length, 104 | truncation=True, return_tensors="np") 105 | transcription_ids = tokenize_res['input_ids'][0].tolist() 106 | translation_ids = tokenize_res['labels'][0].tolist() 107 | 108 | transcription_labels = transcription_ids[1:] + [self.tokenizer.pad_token_id] 109 | translation_labels = translation_ids[1:] + [self.tokenizer.pad_token_id] 110 | return transcription_ids, translation_ids, transcription_labels, translation_labels 111 | 112 | 113 | class CascadeDataset(Dataset): 114 | def __init__(self, audio_info_list, tokenizer, sample_rate, device=None) -> None: 115 | super().__init__() 116 | 117 | self.audio_info_list = audio_info_list 118 | self.sample_rate = sample_rate 119 | self.tokenizer = tokenizer 120 | self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") 121 | 122 | def __len__(self): 123 | return len(self.audio_info_list) 124 | 125 | def __getitem__(self, id): 126 | record = self.audio_info_list[id] 127 | audio_path = os.path.join(record['audio_root'], record["path"]) 128 | translation = record["translation"] 129 | transcription = record["sentence"] 130 | 131 | # audio 132 | audio, duration = load_wave(audio_path, sample_rate=self.sample_rate) 133 | audio = audio.flatten() 134 | 135 | return { 136 | "audio": audio, 137 | "audio_path": audio_path, 138 | "transcription": transcription, 139 | "translation": translation, 140 | "src_lang": LANG_DICT[record['src_lang']]['whisper'], 141 | "tgt_lang": LANG_DICT[record['tgt_lang']]['whisper'], 142 | "m_src_lang": LANG_DICT[record['src_lang']]['mbart'], 143 | } 144 | 145 | 146 | class MbartDataset(Dataset): 147 | def __init__(self, audio_info_list, tokenizer) -> None: 148 | super().__init__() 149 | 150 | self.audio_info_list = audio_info_list 151 | self.tokenizer = tokenizer 152 | self.max_length = 128 153 | 154 | def __len__(self): 155 | return len(self.audio_info_list) 156 | 157 | def __getitem__(self, id): 158 | record = self.audio_info_list[id] 159 | 160 | translation = record["translation"] 161 | transcription = record["sentence"] 162 | src_lang = record["src_lang"] 163 | tgt_lang = record["tgt_lang"] 164 | 165 | self.tokenizer.src_lang = LANG_DICT[src_lang]['mbart'] 166 | self.tokenizer.tgt_lang = LANG_DICT[tgt_lang]['mbart'] 167 | 168 | encoded_ids = self.tokenizer(text=transcription, text_target=translation, max_length=self.max_length, 169 | truncation=True, return_tensors="np") 170 | encoded_src = encoded_ids['input_ids'][0].tolist() 171 | encoded_tgt = encoded_ids['labels'][0].tolist() 172 | 173 | dec_input_ids = [2] + encoded_tgt[:-1] 174 | 175 | return { 176 | "labels": encoded_tgt, 177 | "dec_input_ids": dec_input_ids, 178 | "enc_input_ids": encoded_src, 179 | "tgt_lang": LANG_DICT[tgt_lang]['whisper'] 180 | } 181 | -------------------------------------------------------------------------------- /modules/cascade.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import torchmetrics 4 | import torch 5 | from torch import nn 6 | from torch.utils.data import DataLoader 7 | from pytorch_lightning import LightningModule 8 | from transformers import WhisperTokenizer 9 | 10 | if __name__ == "__main__": 11 | import sys 12 | 13 | BASE_DIR = 
os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 14 | sys.path.append(BASE_DIR) 15 | 16 | import Whisper 17 | from data.data_util import load_data_record, pad_trim_audio 18 | from data.dataset import CascadeDataset 19 | from model.model_util import load_mbart_model, load_mbart_tokenizer 20 | from criterion.metric_util import get_segment_tokenizers, preprocess_sentence 21 | from decode.whisper_decode import decode, DecodingOptions 22 | from decode.mbart_decode import decode, DecodingOptions 23 | 24 | pd.options.display.max_rows = 100 25 | pd.options.display.max_colwidth = 1000 26 | 27 | 28 | class CascadeDataCollatorWhithPadding: 29 | def __init__(self, cfg): 30 | self.cfg = cfg 31 | 32 | def __call__(self, features): 33 | audio, translations, transcription, audio_paths, src_langs, tgt_lang, m_src_langs = [], [], [], [], [], [], [] 34 | for f in features: 35 | audio.append(f["audio"]) 36 | translations.append(f["translation"]) 37 | transcription.append(f["transcription"]) 38 | audio_paths.append(f["audio_path"]) 39 | src_langs.append(f["src_lang"]) 40 | tgt_lang.append(f["tgt_lang"]) 41 | m_src_langs.append(f["m_src_lang"]) 42 | 43 | audio_input_feature = pad_trim_audio(audio, self.cfg) 44 | 45 | batch = {} 46 | 47 | batch["input_ids"] = audio_input_feature 48 | batch["audio_paths"] = audio_paths 49 | batch["src_langs"] = src_langs 50 | batch["translations"] = translations 51 | batch["transcription"] = transcription 52 | batch["m_src_langs"] = m_src_langs 53 | batch["tgt_lang"] = tgt_lang 54 | 55 | return batch 56 | 57 | 58 | class CascadeModelModule(LightningModule): 59 | def __init__(self, cfg, joined_dataset: dict, sep_dataset: dict) -> None: 60 | super().__init__() 61 | model_name = cfg.model_name 62 | self.asr_model = Whisper.load_model(model_name, download_root=cfg.cache_dir, device='cpu') 63 | 64 | path = os.path.join(cfg.cache_dir, 'whisper_tokenizer') 65 | self.asr_tokenizer = WhisperTokenizer.from_pretrained(path, language="spanish", cache_dir=cfg.cache_dir, 66 | task='transcribe', predict_timestamps=False) 67 | 68 | self.mt_tokenizer = load_mbart_tokenizer(cfg) 69 | self.mt_model = load_mbart_model(cfg, load_from_local=True, path=cfg.mbart_model_path).eval() 70 | 71 | self.cfg = cfg 72 | self.__train_dataset = joined_dataset.get("train", []) 73 | self.__eval_dataset = joined_dataset.get("dev", []) 74 | self.__test_dataset = sep_dataset.get("test", []) 75 | self.decode_options = DecodingOptions(task='transcribe', beam_size=5, without_timestamps=True) 76 | 77 | asr_state_dict = torch.load(f'{cfg.cache_dir}/{cfg.asr_model_path}', map_location=self.device) 78 | 79 | self.asr_model.load_state_dict(asr_state_dict) 80 | 81 | self.test_metrics = nn.ModuleDict( 82 | {"bleu": nn.ModuleList([torchmetrics.BLEUScore(compute_on_step=False) for _ in self.__test_dataset])}) 83 | self.segment_tokenizers = get_segment_tokenizers() 84 | 85 | def forward(self, x): 86 | return self.asr_model(x) 87 | 88 | def training_step(self, batch, batch_id): 89 | pass 90 | 91 | def validation_step(self, batch, batch_id): 92 | pass 93 | 94 | def on_test_epoch_start(self) -> None: 95 | for k, v in self.test_metrics.items(): 96 | for metric in v: 97 | metric.set_dtype(torch.float32) 98 | 99 | def test_step(self, batch, batch_id, dataloader_idx): 100 | input_ids = batch["input_ids"] 101 | src_langs = batch["src_langs"] 102 | tgt_lang = batch["tgt_lang"] 103 | m_src_lang = batch["m_src_langs"][0] 104 | audio_features = self.asr_model.encoder(input_ids) 105 | 106 | decode_res = decode(self.asr_model.decoder, 
enc_hidden_states=audio_features, 107 | lang_list=src_langs, options=self.decode_options) 108 | 109 | asr_list = [res.text for res in decode_res] 110 | 111 | self.mt_tokenizer.src_lang = m_src_lang 112 | mt_imput_ids = self.mt_tokenizer(asr_list, return_tensors="pt", padding=True, truncation=True).input_ids.to( 113 | self.device) 114 | mt_gens = self.mt_model.generate(mt_imput_ids, forced_bos_token_id=self.mt_tokenizer.lang_code_to_id["en_XX"], 115 | max_new_tokens=100, num_beams=5) 116 | 117 | o_list = self.mt_tokenizer.batch_decode(mt_gens, skip_special_tokens=True) 118 | l_list = batch["translations"] 119 | preprocess_sentence(o_list, tgt_lang, self.segment_tokenizers) 120 | preprocess_sentence(l_list, tgt_lang, self.segment_tokenizers) 121 | self.test_metrics['bleu'][dataloader_idx](o_list, [[l] for l in l_list]) 122 | 123 | return { 124 | 'asr_list': asr_list, 125 | 'asr_label': batch["transcription"], 126 | 'o_list': o_list, 127 | 'l_list': l_list, 128 | } 129 | 130 | def test_epoch_end(self, outputs): 131 | bleu_scores = [b.compute() * 100 for b in self.test_metrics['bleu']] 132 | for i, bleu in enumerate(bleu_scores): 133 | self.log(f"test_bleu_{i}", round(bleu.item(), 2)) 134 | print(f"test_bleu_{i}", round(bleu.item(), 2)) 135 | self.log('test_bleu_epoch', torch.mean(torch.tensor(bleu_scores))) 136 | print("test_bleu_epoch", torch.mean(torch.tensor(bleu_scores))) 137 | for metrics in self.test_metrics.values(): 138 | for metric in metrics: 139 | metric.reset() 140 | 141 | def configure_optimizers(self): 142 | pass 143 | 144 | def train_dataloader(self): 145 | return None 146 | 147 | def val_dataloader(self): 148 | return None 149 | 150 | def test_dataloader(self): 151 | datasets = [CascadeDataset(test_dataset, self.asr_tokenizer, self.cfg.sample_rate) for test_dataset in 152 | self.__test_dataset] 153 | return [torch.utils.data.DataLoader(dataset, 154 | batch_size=self.cfg.test_batch_size, 155 | num_workers=self.cfg.num_worker, 156 | collate_fn=CascadeDataCollatorWhithPadding(self.cfg) 157 | ) for dataset in datasets] 158 | 159 | 160 | if __name__ == "__main__": 161 | from config.parse_yaml_args import parse_args_and_yaml 162 | from model.model_util import deep_to_device 163 | 164 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 165 | cfg = parse_args_and_yaml(config_path="../config/exp_spec/cascade.yaml") 166 | DATA_ROOT = cfg.data_root 167 | cfg.batch_size = cfg.test_batch_size = 10 168 | language_list = cfg.language_list 169 | 170 | joined_data_pair_lists, sep_data_pair_lists = {}, {} 171 | for split in ["train", "dev", "test"]: 172 | joined_data_pair_lists[split], sep_data_pair_lists[split] = load_data_record(DATA_ROOT, split, 173 | language_list=language_list) 174 | module = CascadeModelModule(cfg, joined_data_pair_lists, sep_data_pair_lists).cuda().eval().half() 175 | 176 | loader = module.test_dataloader()[0] 177 | 178 | with torch.no_grad(): 179 | with torch.autocast(device_type="cuda", dtype=torch.bfloat16): 180 | for b in loader: 181 | b = deep_to_device(b, "cuda") 182 | 183 | test_res = module.test_step(b, 0, 0) 184 | print(test_res) 185 | module.test_epoch_end([test_res]) 186 | 187 | break 188 | -------------------------------------------------------------------------------- /modules/whisper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | from torch import nn 6 | from torch.utils.data import DataLoader 7 | from pytorch_lightning import 
LightningModule 8 | import torchmetrics 9 | from transformers import WhisperTokenizer 10 | 11 | if __name__ == "__main__": 12 | import sys 13 | 14 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 15 | sys.path.append(BASE_DIR) 16 | 17 | from decode.whisper_decode import decode, DecodingOptions 18 | from data.data_util import LANG_DICT, load_data_record, pad_trim_audio 19 | from data.dataset import CoVoSTDataset 20 | from model.optimizer import configure_optimizer_schedular 21 | from criterion.metric_util import get_segment_tokenizers, preprocess_sentence 22 | import Whisper 23 | 24 | 25 | class WhisperTranslateDataset(CoVoSTDataset): 26 | def __init__(self, audio_info_list, tokenizer, sample_rate=16000) -> None: 27 | super().__init__(audio_info_list, tokenizer, sample_rate) 28 | 29 | def tokenize(self, src_lang, tgt_lang, transcription, translation): 30 | self.tokenizer.set_prefix_tokens(language=src_lang) 31 | translation_ids = self.tokenizer(text=translation, return_tensors="np", max_length=self.max_text_length, 32 | truncation=True).input_ids[0].tolist() 33 | translation_labels = translation_ids[1:] + [self.tokenizer.eos_token_id] 34 | return None, translation_ids, None, translation_labels 35 | 36 | 37 | class WhisperDataCollatorWhithPadding: 38 | def __init__(self, cfg, pad_token_id=-100): 39 | self.cfg = cfg 40 | self.pad_token_id = pad_token_id 41 | 42 | def __call__(self, features): 43 | audio, labels, dec_input_ids, audio_paths, src_lang, tgt_lang = [], [], [], [], [], [] 44 | for f in features: 45 | audio.append(f["audio"]) 46 | labels.append(f["translation_labels"]) 47 | dec_input_ids.append(f["translation_ids"]) 48 | audio_paths.append(f["audio_path"]) 49 | src_lang.append(LANG_DICT[f["src_lang"]]['whisper']) 50 | tgt_lang.append(LANG_DICT[f["tgt_lang"]]['whisper']) 51 | 52 | # audio 53 | audio_input_feature = pad_trim_audio(audio, self.cfg) 54 | 55 | label_lengths = [len(lab) for lab in labels] 56 | dec_input_ids_length = [len(e) for e in dec_input_ids] 57 | max_label_len = max(label_lengths + dec_input_ids_length) 58 | 59 | labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in 60 | zip(labels, label_lengths)] 61 | dec_input_ids = [np.pad(e, (0, max_label_len - e_len), 'constant', constant_values=50257) for e, e_len in 62 | zip(dec_input_ids, dec_input_ids_length)] # 50257 is eot token id 63 | 64 | batch = { 65 | "labels": labels, 66 | "dec_input_ids": dec_input_ids, 67 | } 68 | 69 | batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()} 70 | batch["input_ids"] = audio_input_feature 71 | batch["audio_paths"] = audio_paths 72 | batch["src_lang"] = src_lang 73 | batch["tgt_lang"] = tgt_lang 74 | 75 | return batch 76 | 77 | 78 | class WhisperModelModule(LightningModule): 79 | def __init__(self, cfg, joined_dataset: dict, sep_dataset: dict) -> None: 80 | super().__init__() 81 | 82 | self.model = Whisper.load_model( 83 | cfg.whisper_name, 84 | device='cpu', 85 | download_root=cfg.cache_dir, ) 86 | if cfg.use_acti_ckpt: 87 | self.model.enable_acti_ckpt() 88 | 89 | self.tokenizer = WhisperTokenizer.from_pretrained('openai/whisper-large-v2', language="spanish", 90 | cache_dir=cfg.cache_dir, 91 | task='translate', predict_timestamps=False) 92 | 93 | self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100) 94 | 95 | self.cfg = cfg 96 | self.__train_dataset = joined_dataset.get("train", []) 97 | self.__eval_dataset = sep_dataset.get("dev", []) 98 | self.__test_dataset = 
sep_dataset.get("test", []) 99 | self.decode_options = DecodingOptions(task='translate', beam_size=5, without_timestamps=True) 100 | 101 | self.valid_metrics = nn.ModuleDict( 102 | {"bleu": nn.ModuleList([torchmetrics.BLEUScore(compute_on_step=False) for _ in self.__eval_dataset]), 103 | "loss": nn.ModuleList([torchmetrics.MeanMetric(compute_on_step=False) for _ in self.__eval_dataset])}) 104 | 105 | self.test_metrics = nn.ModuleDict( 106 | {"bleu": nn.ModuleList([torchmetrics.BLEUScore(compute_on_step=False) for _ in self.__test_dataset])}) 107 | self.segment_tokenizers = get_segment_tokenizers() 108 | 109 | def forward(self, x): 110 | return self.model(x) 111 | 112 | def decode(self, audio_features, labels, src_langs, tgt_lang_codes, dataloader_idx, metrics): 113 | labels[labels == -100] = self.tokenizer.eos_token_id 114 | 115 | decode_res = decode(self.model.decoder, enc_hidden_states=audio_features, 116 | lang_list=src_langs, options=self.decode_options) 117 | 118 | o_list = [res.text for res in decode_res] 119 | l_list = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 120 | o_list_ = o_list.copy() 121 | l_list_ = l_list.copy() 122 | preprocess_sentence(o_list, tgt_lang_codes, self.segment_tokenizers) 123 | preprocess_sentence(l_list, tgt_lang_codes, self.segment_tokenizers) 124 | metrics['bleu'][dataloader_idx](o_list, [[l] for l in l_list]) 125 | result = { 126 | 'o_list': o_list_, 127 | 'l_list': l_list_, 128 | } 129 | 130 | return result 131 | 132 | def training_step(self, batch, batch_id): 133 | input_ids = batch["input_ids"] 134 | labels = batch["labels"].long() 135 | dec_input_ids = batch["dec_input_ids"].long() 136 | 137 | audio_features = self.model.encoder(input_ids) 138 | 139 | out = self.model.decoder(dec_input_ids, audio_features) 140 | loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1)) 141 | self.log("train/loss", loss, on_step=True, prog_bar=True, logger=True, on_epoch=True) 142 | return loss 143 | 144 | def on_validation_epoch_start(self) -> None: 145 | for k, v in self.valid_metrics.items(): 146 | for metric in v: 147 | metric.set_dtype(torch.float32) 148 | 149 | def validation_step(self, batch, batch_id, dataloader_idx): 150 | input_ids = batch["input_ids"] 151 | labels = batch["labels"].long() 152 | dec_input_ids = batch["dec_input_ids"].long() 153 | 154 | audio_features = self.model.encoder(input_ids) 155 | logits = self.model.decoder(dec_input_ids, audio_features) 156 | 157 | loss = self.loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1)) 158 | 159 | result = self.decode(audio_features, labels, batch["src_lang"], batch["tgt_lang"], dataloader_idx, 160 | self.valid_metrics) 161 | self.valid_metrics['loss'][dataloader_idx](loss) 162 | 163 | return { 164 | "loss": loss, 165 | "result": result, 166 | } 167 | 168 | def validation_epoch_end(self, outputs): 169 | loss_scores = [l.compute() for l in self.valid_metrics['loss']] 170 | self.log('valid_loss_epoch', torch.mean(torch.tensor(loss_scores))) 171 | print("valid_loss_epoch", torch.mean(torch.tensor(loss_scores))) 172 | bleu_scores = [b.compute() * 100 for b in self.valid_metrics['bleu']] 173 | self.log('valid_bleu_epoch', torch.mean(torch.tensor(bleu_scores))) 174 | print("valid_bleu_epoch", torch.mean(torch.tensor(bleu_scores))) 175 | for metrics in self.valid_metrics.values(): 176 | for metric in metrics: 177 | metric.reset() 178 | 179 | def on_test_epoch_start(self) -> None: 180 | for k, v in self.test_metrics.items(): 181 | for metric in v: 182 | metric.set_dtype(torch.float32) 
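    # Test metrics are cast back to float32 at epoch start; presumably this
    # avoids accumulating BLEU statistics in reduced precision when the module
    # runs under fp16/bf16 (see the bf16 autocast smoke test at the bottom of
    # this file).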
183 | 184 | def test_step(self, batch, batch_id, dataloader_idx): 185 | input_ids = batch["input_ids"] 186 | labels = batch["labels"].long() 187 | 188 | audio_features = self.model.encoder(input_ids) 189 | 190 | result = self.decode(audio_features, labels, batch["src_lang"], batch["tgt_lang"], dataloader_idx, 191 | self.test_metrics) 192 | 193 | return { 194 | "result": result, 195 | } 196 | 197 | def test_epoch_end(self, outputs): 198 | bleu_scores = [b.compute() * 100 for b in self.test_metrics['bleu']] 199 | for i, bleu in enumerate(bleu_scores): 200 | self.log(f"test_bleu_{i}", round(bleu.item(), 2)) 201 | print(f"test_bleu_{i}", round(bleu.item(), 2)) 202 | self.log('test_bleu_epoch', torch.mean(torch.tensor(bleu_scores))) 203 | print("test_bleu_epoch", torch.mean(torch.tensor(bleu_scores))) 204 | for metrics in self.test_metrics.values(): 205 | for metric in metrics: 206 | metric.reset() 207 | 208 | def configure_optimizers(self): 209 | optimizer, scheduler = configure_optimizer_schedular( 210 | cfg=self.cfg, 211 | params_generator=self.named_parameters, 212 | num_training_steps=self.trainer.estimated_stepping_batches 213 | ) 214 | self.optimizer = optimizer 215 | self.scheduler = scheduler 216 | 217 | return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}] 218 | 219 | def train_dataloader(self): 220 | dataset = WhisperTranslateDataset(self.__train_dataset, self.tokenizer, self.cfg.sample_rate, ) 221 | return DataLoader(dataset, 222 | batch_size=self.cfg.batch_size, 223 | drop_last=True, shuffle=True, num_workers=self.cfg.num_worker, 224 | collate_fn=WhisperDataCollatorWhithPadding(self.cfg) 225 | ) 226 | 227 | def val_dataloader(self): 228 | datasets = [WhisperTranslateDataset(dataset, self.tokenizer, self.cfg.sample_rate, ) 229 | for dataset in self.__eval_dataset] 230 | return [DataLoader(dataset, 231 | batch_size=self.cfg.test_batch_size, 232 | num_workers=self.cfg.num_worker, 233 | collate_fn=WhisperDataCollatorWhithPadding(self.cfg) 234 | ) for dataset in datasets] 235 | 236 | def test_dataloader(self): 237 | datasets = [WhisperTranslateDataset(dataset, self.tokenizer, self.cfg.sample_rate, ) 238 | for dataset in self.__test_dataset] 239 | return [DataLoader(dataset, 240 | batch_size=self.cfg.test_batch_size, 241 | num_workers=self.cfg.num_worker, 242 | collate_fn=WhisperDataCollatorWhithPadding(self.cfg) 243 | ) for dataset in datasets] 244 | 245 | 246 | if __name__ == "__main__": 247 | from config.parse_yaml_args import parse_args_and_yaml 248 | from model.model_util import deep_to_device 249 | 250 | pd.options.display.max_rows = 100 251 | pd.options.display.max_colwidth = 1000 252 | 253 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 254 | cfg = parse_args_and_yaml(config_path="../config/exp_spec/whisper.yaml") 255 | DATA_ROOT = cfg.data_root 256 | language_list = cfg.language_list 257 | 258 | joined_data_pair_lists, sep_data_pair_lists = {}, {} 259 | for split in ["train", "dev", "test"]: 260 | joined_data_pair_lists[split], sep_data_pair_lists[split] = load_data_record(DATA_ROOT, split, 261 | language_list=language_list, ) 262 | 263 | cfg.batch_size = cfg.test_batch_size = 10 264 | cfg.num_worker = 0 265 | 266 | module = WhisperModelModule(cfg, joined_data_pair_lists, sep_data_pair_lists).to(cfg.device).eval() 267 | 268 | loader = module.test_dataloader()[0] 269 | 270 | with torch.no_grad(): 271 | with torch.autocast(device_type="cuda", dtype=torch.bfloat16): 272 | for b in loader: 273 | b = deep_to_device(b, cfg.device) 274 | train_res = 
module.training_step(b, 0) 275 | print(train_res) 276 | 277 | valid_res = module.validation_step(b, 0, 0) 278 | print(valid_res) 279 | module.validation_epoch_end([valid_res]) 280 | 281 | test_res = module.test_step(b, 0, 0) 282 | print(test_res) 283 | module.test_epoch_end([test_res]) 284 | 285 | break 286 | -------------------------------------------------------------------------------- /model/ComSL_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch import Tensor 5 | from model.model_util import Conv1dAdaptor, GradMultiply, load_mbart_model 6 | import Whisper 7 | 8 | 9 | class ComSTDecoder(nn.Module): 10 | def __init__(self, cfg, mbart_model) -> None: 11 | super().__init__() 12 | 13 | self.decoder = TextDecoder(cfg, mbart_model) 14 | self.d_model = 1024 15 | self.cfg = cfg 16 | self.num_updates = 0 17 | 18 | def forward(self, dec_input_ids, encoder_out, mlm_mask=None, past_key_values=None, use_cache=False): 19 | 20 | if isinstance(encoder_out, tuple): # only for training 21 | assert len(encoder_out) == 2 and past_key_values is None and use_cache is False 22 | 23 | transcript_dec_inputs, translate_dec_inputs = dec_input_ids 24 | 25 | dec_outs = {'translate': [], 'transcript': []} 26 | for i in range(2): 27 | if transcript_dec_inputs is not None: 28 | if i == 0 and self.cfg.asr_loss_weight > 0: 29 | logits, misc_dict = self.decoder(transcript_dec_inputs, encoder_out[i]['encoder_out']) 30 | else: 31 | logits = None 32 | dec_outs['transcript'].append(logits) 33 | 34 | if translate_dec_inputs is not None: 35 | if i == 0 and self.cfg.spch_loss_weight == 0: 36 | logits = None 37 | elif i == 1 and self.cfg.text_loss_weight == 0: 38 | logits = None 39 | else: 40 | logits, misc_dict = self.decoder(translate_dec_inputs, encoder_out[i]['encoder_out']) 41 | dec_outs['translate'].append(logits) 42 | 43 | mix_dec_outs = None 44 | erm_loss = None 45 | if self.cfg.use_cml and 'mix_enc_out_spch' in encoder_out[1]: 46 | mix_enc_out = [encoder_out[1]['mix_enc_out_spch'], encoder_out[1]['mix_enc_out_text']] 47 | mix_dec_outs = {'translate': [], 'transcript': []} 48 | for i in range(2): 49 | if transcript_dec_inputs is not None: 50 | logits, misc_dict = self.decoder(transcript_dec_inputs, mix_enc_out[i]) 51 | mix_dec_outs['transcript'].append(logits) 52 | 53 | if translate_dec_inputs is not None: 54 | logits, misc_dict = self.decoder(translate_dec_inputs, mix_enc_out[i]) 55 | mix_dec_outs['translate'].append(logits) 56 | 57 | if self.cfg.use_erm: 58 | erm_loss = (encoder_out[0]['reg_hidden'] - encoder_out[1]['reg_hidden']).norm(2, dim=-1) 59 | 60 | return dec_outs, {"erm_loss": erm_loss, "mix_dec_outs": mix_dec_outs, 'mlm_mask': mlm_mask} 61 | 62 | else: # for decoding 63 | if isinstance(encoder_out, Tensor): 64 | encoder_out = {"encoder_out": encoder_out, "encoder_padding_mask": None} 65 | return self.decoder(dec_input_ids, encoder_out['encoder_out'], 66 | encoder_padding_mask=encoder_out["encoder_padding_mask"], 67 | past_key_values=past_key_values, 68 | use_cache=use_cache) 69 | 70 | 71 | class TextDecoder(nn.Module): 72 | def __init__(self, cfg, mbart_model) -> None: 73 | super().__init__() 74 | self.decoder = mbart_model.base_model.decoder 75 | self.lm_head = mbart_model.lm_head 76 | self.embed_tokens = self.decoder.embed_tokens 77 | self.embed_scale = self.decoder.embed_scale 78 | self.cfg = cfg 79 | 80 | def forward(self, dec_input_ids, encoder_out, encoder_padding_mask=None, 
past_key_values=None, use_cache=False): 81 | output_attentions = True if past_key_values is None else False 82 | text_embeds = self.embed_tokens(dec_input_ids) * self.embed_scale 83 | 84 | dec_output = self.decoder(inputs_embeds=text_embeds, 85 | encoder_hidden_states=encoder_out, 86 | encoder_attention_mask=encoder_padding_mask, 87 | past_key_values=past_key_values, 88 | use_cache=use_cache, 89 | output_attentions=output_attentions) 90 | last_hidden_state = dec_output.last_hidden_state 91 | 92 | attn_key_values = dec_output.past_key_values if use_cache else None 93 | logits = self.lm_head(last_hidden_state) 94 | return logits, {"attn_key_values": attn_key_values} 95 | 96 | 97 | class TextEncoder(nn.Module): 98 | def __init__(self, cfg, text_encoder) -> None: 99 | super().__init__() 100 | self.encoder = text_encoder 101 | self.cfg = cfg 102 | self.embed_tokens = text_encoder.embed_tokens 103 | self.embed_scale = text_encoder.embed_scale 104 | self.erm_layer = cfg.erm_layer 105 | 106 | def forward(self, src_tokens, masked_src_tokens, spch_embeds, **kwargs): 107 | enc_out = {} 108 | inputs_embeds = self.embed_tokens(src_tokens) * self.embed_scale 109 | enc_out['embeds'] = inputs_embeds 110 | encoder_out = self.encoder(inputs_embeds=inputs_embeds, attention_mask=None, output_hidden_states=True) 111 | enc_out['encoder_out'] = encoder_out.last_hidden_state 112 | enc_out['encoder_padding_mask'] = None 113 | text_hidden_ori = encoder_out.hidden_states[self.erm_layer] 114 | enc_out['text_hidden_ori'] = text_hidden_ori / text_hidden_ori.norm(2, dim=-1, keepdim=True) 115 | 116 | if self.cfg.use_cml and masked_src_tokens is not None: 117 | masked_inputs_embeds = self.embed_tokens(masked_src_tokens) * self.embed_scale 118 | text_len = masked_inputs_embeds.shape[1] 119 | mix_input_embeds = torch.cat([masked_inputs_embeds, spch_embeds], dim=1) 120 | mix_enc_out = self.encoder(inputs_embeds=mix_input_embeds, attention_mask=None, output_hidden_states=True) 121 | enc_out['mix_enc_out_text'] = mix_enc_out.last_hidden_state[:, :text_len, :] 122 | enc_out['mix_enc_out_spch'] = mix_enc_out.last_hidden_state[:, text_len:, :] 123 | text_hidden_mix = mix_enc_out.hidden_states[self.erm_layer][:, :text_len, :] 124 | reg_hidden = mix_enc_out.hidden_states[self.erm_layer][:, text_len:, :] 125 | enc_out['reg_hidden'] = reg_hidden / reg_hidden.norm(2, dim=-1, keepdim=True) 126 | enc_out['text_hidden_mix'] = text_hidden_mix / text_hidden_mix.norm(2, dim=-1, keepdim=True) 127 | 128 | return enc_out 129 | 130 | 131 | class SpchEncoder(nn.Module): 132 | 133 | def __init__(self, cfg, text_encoder) -> None: 134 | super().__init__() 135 | whisper_model = Whisper.load_model( 136 | cfg.whisper_name, 137 | download_root=cfg.cache_dir, 138 | device='cpu' 139 | ) 140 | if cfg.spch_init_model_path is not None: 141 | whisper_model.load_state_dict(torch.load(os.path.join(cfg.cache_dir, cfg.spch_init_model_path), 142 | map_location="cpu")) 143 | print("loaded asr model from {}".format(os.path.join(cfg.cache_dir, cfg.spch_init_model_path))) 144 | self.spch_encoder = whisper_model.encoder 145 | self.spch_encoder.gradient_checkpointing = cfg.use_acti_ckpt 146 | if cfg.spch_n_layers > 0: 147 | self.spch_encoder.blocks = self.spch_encoder.blocks[:cfg.spch_n_layers] 148 | self.down_sampler1 = Conv1dAdaptor(whisper_model.dims.n_audio_state, 1024, n_layers=1, proj=True) 149 | self.down_sampler2 = Conv1dAdaptor(1024, 1024, n_layers=1, ) 150 | self.mbart_encoder = text_encoder 151 | self.erm_layer = cfg.erm_layer 152 | self.cfg = cfg 153 | 154 
| self.embed_tokens = text_encoder.embed_tokens 155 | self.embed_scale = text_encoder.embed_scale 156 | 157 | def forward(self, mel, src_lang_ids): 158 | enc_out = {} 159 | audio_embeds = self.down_sampler1(self.spch_encoder(mel))[0] 160 | inputs_embeds = self.down_sampler2(audio_embeds)[0] 161 | lang_embeds = self.embed_tokens(src_lang_ids.reshape(-1, 1)) * self.embed_scale 162 | inputs_embeds = torch.cat([lang_embeds, inputs_embeds], dim=1) 163 | encoder_out = self.mbart_encoder(inputs_embeds=inputs_embeds, attention_mask=None, output_hidden_states=True) 164 | regulation_hidden = encoder_out.hidden_states[self.erm_layer] 165 | 166 | enc_out['encoder_out'] = encoder_out.last_hidden_state 167 | enc_out['encoder_padding_mask'] = None 168 | enc_out['reg_hidden'] = regulation_hidden / regulation_hidden.norm(dim=-1, keepdim=True) 169 | enc_out['embeds'] = inputs_embeds 170 | return enc_out 171 | 172 | 173 | class ComSTEncoder(nn.Module): 174 | def __init__( 175 | self, 176 | args, 177 | mbart_model, 178 | ): 179 | super().__init__() 180 | 181 | mbart_encoder = mbart_model.base_model.encoder 182 | 183 | self.spch_encoder = SpchEncoder(args, mbart_encoder) 184 | self.text_encoder = TextEncoder(args, mbart_encoder) 185 | self.enc_grad_mult = args.enc_grad_mult 186 | 187 | def mult_rst_grad(self, rst, ratio): 188 | assert isinstance(rst, dict) # instead of EncoderOut 189 | rst["encoder_out"] = GradMultiply.apply(rst["encoder_out"], ratio) 190 | return rst 191 | 192 | def forward( 193 | self, 194 | mel, 195 | src_tokens, 196 | src_lang_ids, 197 | masked_src_tokens, 198 | **kwargs 199 | ): 200 | 201 | if mel is None and src_tokens is None: 202 | raise ValueError( 203 | "src_tokens and src_txt_tokens cannot be None at the same time" 204 | ) 205 | ret1 = None 206 | ret2 = None 207 | if mel is not None: 208 | ret1 = self.spch_encoder(mel, src_lang_ids) 209 | 210 | if src_tokens is not None: 211 | ret2 = self.text_encoder( 212 | src_tokens, 213 | masked_src_tokens, 214 | ret1['embeds'], 215 | **kwargs 216 | ) 217 | 218 | def merge_output(rst1, rst2): 219 | if self.enc_grad_mult != 1.0 and self.training: 220 | rst1 = self.mult_rst_grad(rst1, self.enc_grad_mult) 221 | rst2 = self.mult_rst_grad(rst2, self.enc_grad_mult) 222 | rst = (rst1, rst2) 223 | return rst 224 | 225 | return merge_output(ret1, ret2) 226 | 227 | 228 | class ComSTModel(nn.Module): 229 | def __init__(self, args): 230 | super().__init__() 231 | mbart_model = load_mbart_model(args, load_from_local=True, path=args.language_init_model_path) 232 | self.encoder = ComSTEncoder(args, mbart_model) 233 | self.decoder = ComSTDecoder(args, mbart_model) 234 | 235 | self.args = args 236 | self.num_updates = 0 237 | if args.language_regularization_model_path is not None: 238 | self.reg_model = load_mbart_model(args, load_from_local=True, 239 | path=args.language_regularization_model_path).eval() 240 | self.reg_model.requires_grad_(False) 241 | else: 242 | self.reg_model = None 243 | 244 | def set_num_updates(self, num_updates, current_epoch=None): 245 | """Set the number of parameters updates.""" 246 | self.num_updates = num_updates 247 | self.decoder.num_updates = num_updates 248 | if self.args.disable_spch_grad_epoch > 0: 249 | if current_epoch < self.args.disable_spch_grad_epoch: 250 | self.encoder.spch_encoder.spch_encoder.requires_grad_(False) 251 | else: 252 | self.encoder.spch_encoder.spch_encoder.requires_grad_(True) 253 | 254 | def forward( 255 | self, 256 | mel, 257 | src_lang_ids, 258 | tokens, 259 | masked_src_tokens, 260 | mlm_mask, 261 | 
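            # Judging from the calls below: `tokens` is a pair
            # (transcription ids, translation ids) -- tokens[0] also feeds the
            # text encoder as `src_tokens`, and the decoder unpacks the pair
            # into transcript / translate decoder inputs. `masked_src_tokens`
            # is only used on the cross-modality (CML) path, and `mlm_mask` is
            # passed through to the decoder output, presumably for the
            # training criterion.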
**kwargs 262 | ): 263 | encoder_out = self.encoder( 264 | mel, 265 | src_tokens=tokens[0], 266 | src_lang_ids=src_lang_ids, 267 | masked_src_tokens=masked_src_tokens, 268 | **kwargs 269 | ) 270 | decoder_out = self.decoder( 271 | tokens, 272 | encoder_out=encoder_out, 273 | mlm_mask=mlm_mask, 274 | ) 275 | if self.reg_model is not None: 276 | with torch.no_grad(): 277 | reg_logits = self.reg_model( 278 | input_ids=tokens[0], 279 | decoder_input_ids=tokens[1], 280 | ).logits 281 | decoder_out[1]['reg_logits'] = reg_logits 282 | 283 | return decoder_out, encoder_out 284 | -------------------------------------------------------------------------------- /modules/mbart.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torchmetrics 4 | import torch 5 | from torch import nn 6 | from torch.utils.data import DataLoader 7 | from pytorch_lightning import LightningModule 8 | from transformers import ( 9 | AdamW, 10 | get_polynomial_decay_schedule_with_warmup, 11 | ) 12 | 13 | import sys 14 | 15 | if __name__ == "__main__": 16 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 17 | sys.path.append(BASE_DIR) 18 | 19 | from data.data_util import load_data_record 20 | from data.dataset import MbartDataset 21 | from decode.mbart_decode import decode, DecodingOptions 22 | 23 | from model.model_util import load_mbart_tokenizer 24 | from model.mBART_model import MbartModel 25 | from criterion.metric_util import get_segment_tokenizers, preprocess_sentence 26 | 27 | 28 | class MbartCollatorWhithPadding: 29 | def __call__(self, features): 30 | enc_input_ids, labels, dec_input_ids, tgt_lang = [], [], [], [] 31 | for f in features: 32 | enc_input_ids.append(f["enc_input_ids"]) 33 | labels.append(f["labels"]) 34 | dec_input_ids.append(f["dec_input_ids"]) 35 | tgt_lang.append(f["tgt_lang"]) 36 | 37 | label_lengths = [len(lab) for lab in labels] 38 | dec_input_ids_length = [len(e) for e in dec_input_ids] 39 | max_label_len = max(label_lengths + dec_input_ids_length) 40 | 41 | labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=pad_token_id) for lab, lab_len 42 | in zip(labels, label_lengths)] 43 | dec_input_ids = [np.pad(e, (0, max_label_len - e_len), 'constant', constant_values=pad_token_id) for e, e_len in 44 | zip(dec_input_ids, dec_input_ids_length)] 45 | 46 | enc_input_len = [len(e) for e in enc_input_ids] 47 | max_enc_input_len = max(enc_input_len) 48 | enc_input_ids = [np.pad(e, (0, max_enc_input_len - e_len), 'constant', constant_values=pad_token_id) for 49 | e, e_len in zip(enc_input_ids, enc_input_len)] 50 | 51 | batch = { 52 | "labels": labels, 53 | "dec_input_ids": dec_input_ids, 54 | } 55 | 56 | batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()} 57 | 58 | batch["enc_input_ids"] = torch.tensor(np.array(enc_input_ids)) 59 | batch["tgt_lang"] = tgt_lang 60 | 61 | return batch 62 | 63 | 64 | class MbartModelModule(LightningModule): 65 | def __init__(self, cfg, joined_dataset: dict, sep_dataset: dict) -> None: 66 | super().__init__() 67 | self.tokenizer = load_mbart_tokenizer(cfg) 68 | global pad_token_id 69 | pad_token_id = self.tokenizer.pad_token_id 70 | self.model = MbartModel(cfg) 71 | self.loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token_id, label_smoothing=cfg.label_smoothing) 72 | 73 | self.cfg = cfg 74 | self.__train_dataset = joined_dataset.get("train", []) 75 | self.__eval_dataset = sep_dataset.get("dev", []) 76 | 
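        # (MbartCollatorWhithPadding pads with the module-level `pad_token_id`
        # set as a global above, rather than receiving it as an argument.)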
self.__test_dataset = sep_dataset.get("test", []) 77 | 78 | self.decode_option = DecodingOptions(beam_size=5) 79 | self.valid_metrics = nn.ModuleDict( 80 | {"bleu": nn.ModuleList([torchmetrics.BLEUScore(compute_on_step=False) for _ in self.__eval_dataset]), 81 | "loss": nn.ModuleList([torchmetrics.MeanMetric(compute_on_step=False) for _ in self.__eval_dataset])}) 82 | 83 | self.test_metrics = nn.ModuleDict( 84 | {"bleu": nn.ModuleList([torchmetrics.BLEUScore(compute_on_step=False) for _ in self.__test_dataset])}) 85 | self.segment_tokenizers = get_segment_tokenizers() 86 | 87 | def forward(self, x): 88 | return self.model(x) 89 | 90 | def training_step(self, batch, batch_id): 91 | input_ids = batch["enc_input_ids"] 92 | bsz = input_ids.size(0) 93 | labels = batch["labels"].long() 94 | dec_input_ids = batch["dec_input_ids"].long() 95 | 96 | # with torch.no_grad(): 97 | outputs = self.model(input_ids, dec_input_ids) 98 | loss = self.loss_fn(outputs.view(-1, outputs.size(-1)), labels.view(-1)) 99 | self.log("train/loss", loss, on_step=True, prog_bar=True, logger=True, on_epoch=True, sync_dist=True, 100 | batch_size=bsz) 101 | return loss 102 | 103 | def on_validation_epoch_start(self) -> None: 104 | for k, v in self.valid_metrics.items(): 105 | for metric in v: 106 | metric.set_dtype(torch.float32) 107 | 108 | def validation_step(self, batch, batch_id, dataloader_idx=None): 109 | input_ids = batch["enc_input_ids"] 110 | labels = batch["labels"].long() 111 | dec_input_ids = batch["dec_input_ids"].long() 112 | tgt_lang = batch['tgt_lang'] 113 | bsz = input_ids.size(0) 114 | 115 | encoder_hidden_states = self.model.encoder(input_ids) 116 | outputs = self.model.decoder(dec_input_ids, encoder_hidden_states)[0] 117 | loss = self.loss_fn(outputs.view(-1, outputs.size(-1)), labels.view(-1)) 118 | gens = decode(self.model.decoder, 119 | tokenizer=self.tokenizer, 120 | enc_hidden_states=encoder_hidden_states, 121 | forced_bos_token_id=dec_input_ids[:, 1], 122 | options=self.decode_option) 123 | 124 | o_list = self.tokenizer.batch_decode(gens, skip_special_tokens=True) 125 | l_list = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 126 | # inputs = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True) 127 | 128 | preprocess_sentence(o_list, tgt_lang, self.segment_tokenizers) 129 | preprocess_sentence(l_list, tgt_lang, self.segment_tokenizers) 130 | 131 | self.valid_metrics['loss'][dataloader_idx](loss.cuda()) 132 | self.valid_metrics['bleu'][dataloader_idx](o_list, [[l] for l in l_list]) 133 | 134 | return { 135 | "loss": loss, 136 | "o_list": o_list, 137 | "l_list": l_list, 138 | } 139 | 140 | def validation_epoch_end(self, outputs): 141 | loss_scores = [l.compute() for l in self.valid_metrics['loss']] 142 | self.log('valid_loss_epoch', torch.mean(torch.tensor(loss_scores))) 143 | print("valid_loss_epoch", torch.mean(torch.tensor(loss_scores))) 144 | bleu_scores = [b.compute() * 100 for b in self.valid_metrics['bleu']] 145 | self.log('valid_bleu_epoch', torch.mean(torch.tensor(bleu_scores))) 146 | print("valid_bleu_epoch", torch.mean(torch.tensor(bleu_scores))) 147 | for k, v in self.valid_metrics.items(): 148 | for metric in v: 149 | metric.reset() 150 | 151 | def on_test_epoch_start(self) -> None: 152 | for k, v in self.test_metrics.items(): 153 | for metric in v: 154 | metric.set_dtype(torch.float32) 155 | 156 | def test_step(self, batch, batch_idx, dataloader_idx=None): 157 | input_ids = batch["enc_input_ids"] 158 | bsz = input_ids.size(0) 159 | labels = 
batch["labels"].long() 160 | dec_input_ids = batch["dec_input_ids"].long() 161 | tgt_lang = batch['tgt_lang'] 162 | encoder_hidden_states = self.model.encoder(input_ids) 163 | gens = decode(self.model.decoder, 164 | tokenizer=self.tokenizer, 165 | enc_hidden_states=encoder_hidden_states, 166 | forced_bos_token_id=dec_input_ids[:, 1], 167 | options=self.decode_option) 168 | 169 | o_list = self.tokenizer.batch_decode(gens, skip_special_tokens=True) 170 | l_list = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 171 | 172 | o_list_ = o_list.copy() 173 | l_list_ = l_list.copy() 174 | 175 | for i in range(len(o_list)): 176 | tokenizer = self.segment_tokenizers.get(tgt_lang[i], self.segment_tokenizers["default"]) 177 | o_list[i] = "".join(tokenizer(o_list[i].rstrip())) 178 | l_list[i] = "".join(tokenizer(l_list[i].rstrip())) 179 | self.test_metrics['bleu'][dataloader_idx](o_list, [[l] for l in l_list]) 180 | 181 | return { 182 | "output": o_list_, 183 | "label": l_list_, 184 | } 185 | 186 | def test_epoch_end(self, outputs): 187 | bleu_scores = [b.compute() * 100 for b in self.test_metrics['bleu']] 188 | for i, bleu in enumerate(bleu_scores): 189 | self.log(f"test_bleu_{i}", round(bleu.item(), 2)) 190 | print(f"test_bleu_{i}", round(bleu.item(), 2)) 191 | self.log('test_bleu_epoch', round(torch.mean(torch.tensor(bleu_scores)).detach().cpu().item(), 2)) 192 | for k, v in self.test_metrics.items(): 193 | for metric in v: 194 | metric.reset() 195 | print("test_bleu_epoch", torch.mean(torch.tensor(bleu_scores))) 196 | 197 | def configure_optimizers(self): 198 | model = self.model 199 | no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"] 200 | optimizer_grouped_parameters = [ 201 | { 202 | "params": [p for n, p in model.named_parameters() 203 | if not any(nd in n for nd in no_decay)], 204 | "weight_decay": self.cfg.weight_decay, 205 | }, 206 | { 207 | "params": [p for n, p in model.named_parameters() 208 | if any(nd in n for nd in no_decay)], 209 | "weight_decay": 0.0, 210 | }, 211 | ] 212 | optimizer = AdamW(optimizer_grouped_parameters, 213 | lr=self.cfg.learning_rate, 214 | eps=self.cfg.adam_epsilon, 215 | betas=self.cfg.adam_betas) 216 | self.optimizer = optimizer 217 | 218 | scheduler = get_polynomial_decay_schedule_with_warmup( 219 | optimizer, num_warmup_steps=self.cfg.warmup_steps, 220 | num_training_steps=self.trainer.estimated_stepping_batches 221 | ) 222 | self.scheduler = scheduler 223 | 224 | return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}] 225 | 226 | def train_dataloader(self): 227 | dataset = MbartDataset(self.__train_dataset, self.tokenizer) 228 | return DataLoader(dataset, 229 | batch_size=self.cfg.batch_size, 230 | drop_last=True, shuffle=True, num_workers=self.cfg.num_worker, 231 | collate_fn=MbartCollatorWhithPadding() 232 | ) 233 | 234 | def val_dataloader(self): 235 | datasets = [MbartDataset(dataset, self.tokenizer) for dataset in 236 | self.__eval_dataset] 237 | return [DataLoader(dataset, 238 | batch_size=self.cfg.test_batch_size, 239 | num_workers=self.cfg.num_worker, 240 | collate_fn=MbartCollatorWhithPadding() 241 | ) for dataset in datasets] 242 | 243 | def test_dataloader(self): 244 | datasets = [MbartDataset(dataset, self.tokenizer) for dataset in 245 | self.__test_dataset] 246 | return [DataLoader(dataset, 247 | batch_size=self.cfg.test_batch_size, 248 | num_workers=self.cfg.num_worker, 249 | collate_fn=MbartCollatorWhithPadding() 250 | ) for dataset in datasets] 251 | 252 | 253 | if __name__ == "__main__": 254 | 
from config.parse_yaml_args import parse_args_and_yaml 255 | from model.model_util import deep_to_device 256 | 257 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 258 | cfg = parse_args_and_yaml(config_path="../config/exp_spec/mbart_en_x.yaml") 259 | DATA_ROOT = cfg.data_root 260 | cfg.batch_size = cfg.test_batch_size = 10 261 | language_list = cfg.language_list 262 | 263 | joined_data_pair_lists, sep_data_pair_lists = {}, {} 264 | for split in ["train", "dev", "test"]: 265 | joined_data_pair_lists[split], sep_data_pair_lists[split] = load_data_record(DATA_ROOT, split, 266 | language_list=language_list) 267 | module = MbartModelModule(cfg, joined_data_pair_lists, sep_data_pair_lists).cuda().eval() 268 | 269 | loader = module.test_dataloader()[0] 270 | 271 | with torch.no_grad(): 272 | for b in loader: 273 | b = deep_to_device(b, cfg.device) 274 | train_res = module.training_step(b, 0) 275 | print(train_res) 276 | 277 | valid_res = module.validation_step(b, 0, 0) 278 | print(valid_res) 279 | module.validation_epoch_end([valid_res]) 280 | 281 | test_res = module.test_step(b, 0, 0) 282 | print(test_res) 283 | module.test_epoch_end([test_res]) 284 | 285 | break 286 | -------------------------------------------------------------------------------- /modules/whisper_asr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torchmetrics 5 | import torch 6 | from torch import nn 7 | from torch.utils.data import DataLoader 8 | 9 | from pytorch_lightning import LightningModule 10 | from transformers import WhisperTokenizer 11 | 12 | if __name__ == "__main__": 13 | import sys 14 | 15 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 16 | sys.path.append(BASE_DIR) 17 | 18 | import Whisper 19 | from Whisper.normalizers import BasicTextNormalizer 20 | from decode.whisper_decode import decode, DecodingOptions 21 | from data.data_util import LANG_DICT, load_data_record 22 | from data.data_util import pad_trim_audio 23 | from data.dataset import CoVoSTDataset 24 | from model.optimizer import configure_optimizer_schedular 25 | from criterion.metric_util import get_segment_tokenizers, preprocess_sentence 26 | 27 | 28 | class WhisperAsrDataset(CoVoSTDataset): 29 | def __init__(self, audio_info_list, tokenizer, sample_rate=16000) -> None: 30 | super().__init__(audio_info_list, tokenizer, sample_rate) 31 | 32 | def tokenize(self, src_lang, tgt_lang, transcription, translation): 33 | self.tokenizer.set_prefix_tokens(language=src_lang) 34 | transcription_ids = self.tokenizer(text=transcription, return_tensors="np", max_length=self.max_text_length, 35 | truncation=True).input_ids[0].tolist() 36 | transcription_labels = transcription_ids[1:] + [self.tokenizer.eos_token_id] 37 | return transcription_ids, None, transcription_labels, None 38 | 39 | 40 | class WhisperAsrDataCollatorWhithPadding: 41 | def __init__(self, cfg, pad_token_id=-100): 42 | self.cfg = cfg 43 | self.pad_token_id = pad_token_id 44 | 45 | 46 | def __call__(self, features): 47 | audio, labels, dec_input_ids, audio_paths, src_langs, tgt_langs = [], [], [], [], [], [] 48 | for f in features: 49 | audio.append(f["audio"]) 50 | labels.append(f["transcription_labels"]) 51 | dec_input_ids.append(f["transcription_ids"]) 52 | audio_paths.append(f["audio_path"]) 53 | src_langs.append(LANG_DICT[f["src_lang"]]['whisper']) 54 | tgt_langs.append(LANG_DICT[f["tgt_lang"]]['whisper']) 55 | 56 | # audio 57 | audio_input_feature = pad_trim_audio(audio, 
self.cfg) 58 | 59 | label_lengths = [len(lab) for lab in labels] 60 | dec_input_ids_length = [len(e) for e in dec_input_ids] 61 | max_label_len = max(label_lengths + dec_input_ids_length) 62 | 63 | labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in 64 | zip(labels, label_lengths)] 65 | dec_input_ids = [np.pad(e, (0, max_label_len - e_len), 'constant', constant_values=50257) for e, e_len in 66 | zip(dec_input_ids, dec_input_ids_length)] # 50257 is eot token id 67 | 68 | batch = { 69 | "labels": labels, 70 | "dec_input_ids": dec_input_ids, 71 | } 72 | 73 | batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()} 74 | batch["input_ids"] = audio_input_feature 75 | batch["audio_paths"] = audio_paths 76 | batch["src_langs"] = src_langs 77 | batch["tgt_langs"] = tgt_langs 78 | 79 | return batch 80 | 81 | 82 | class WhisperAsrModelModule(LightningModule): 83 | def __init__(self, cfg, joined_dataset: dict, sep_dataset: dict) -> None: 84 | super().__init__() 85 | self.model = Whisper.load_model( 86 | cfg.whisper_name, 87 | device='cpu', 88 | download_root=cfg.cache_dir, 89 | ) 90 | self.tokenizer = WhisperTokenizer.from_pretrained('openai/whisper-large-v2', language="spanish", 91 | cache_dir=cfg.cache_dir, 92 | task='transcribe', predict_timestamps=False) 93 | self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100) 94 | 95 | self.cfg = cfg 96 | self.__train_dataset = joined_dataset.get("train", []) 97 | self.__eval_dataset = sep_dataset.get("dev", []) 98 | self.__test_dataset = sep_dataset.get("test", []) 99 | self.decode_options = DecodingOptions(task='transcribe', beam_size=5, without_timestamps=True) 100 | 101 | self.normalizer = BasicTextNormalizer() 102 | if cfg.use_acti_ckpt: 103 | self.model.enable_acti_ckpt() 104 | 105 | self.valid_metrics = nn.ModuleDict({ 106 | "loss": nn.ModuleList([torchmetrics.MeanMetric(compute_on_step=False) for _ in self.__eval_dataset]), 107 | 'wer': nn.ModuleList([torchmetrics.WordErrorRate(compute_on_step=False) for _ in self.__eval_dataset]), 108 | }) 109 | 110 | self.test_metrics = nn.ModuleDict({ 111 | 'wer': nn.ModuleList([torchmetrics.WordErrorRate(compute_on_step=False) for _ in self.__test_dataset]), 112 | }) 113 | 114 | self.segment_tokenizers = get_segment_tokenizers() 115 | 116 | def forward(self, x): 117 | return self.model(x) 118 | 119 | def decode(self, audio_features, labels, src_langs, tgt_langs, metrics, dataloader_idx): 120 | labels[labels == -100] = self.tokenizer.eos_token_id 121 | 122 | decode_res = decode(self.model.decoder, enc_hidden_states=audio_features, 123 | lang_list=src_langs, options=self.decode_options) 124 | 125 | o_list = [res.text for res in decode_res] 126 | l_list = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 127 | 128 | preprocess_sentence(o_list, src_langs, self.segment_tokenizers) 129 | preprocess_sentence(l_list, src_langs, self.segment_tokenizers) 130 | 131 | o_list_ = [self.normalizer(o) for o in o_list] 132 | l_list_ = [self.normalizer(l) for l in l_list] 133 | 134 | metrics["wer"][dataloader_idx](o_list_, l_list_) 135 | result = { 136 | 'st_res': o_list_, 137 | 'st_ref': l_list_, 138 | } 139 | 140 | return result 141 | 142 | def training_step(self, batch, batch_id): 143 | input_ids = batch["input_ids"] 144 | labels = batch["labels"].long() 145 | dec_input_ids = batch["dec_input_ids"].long() 146 | 147 | audio_features = self.model.encoder(input_ids) 148 | 149 | out = self.model.decoder(dec_input_ids, audio_features) 150 | 
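        # Token-level cross entropy: logits are flattened to (batch*seq, vocab)
        # and labels to (batch*seq,); padded label positions carry -100 and are
        # skipped via ignore_index above.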
loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1)) 151 | self.log("train/loss", loss, on_step=True, prog_bar=True, logger=True, on_epoch=True) 152 | return loss 153 | 154 | def on_validation_epoch_start(self) -> None: 155 | for k, v in self.valid_metrics.items(): 156 | for metric in v: 157 | metric.set_dtype(torch.float32) 158 | 159 | def validation_step(self, batch, batch_id, dataloader_idx=None): 160 | if dataloader_idx is None: 161 | print("warning: dataloader_idx is None") 162 | dataloader_idx = 0 163 | input_ids = batch["input_ids"] 164 | labels = batch["labels"].long() 165 | dec_input_ids = batch["dec_input_ids"].long() 166 | 167 | audio_features = self.model.encoder(input_ids) 168 | out = self.model.decoder(dec_input_ids, audio_features) 169 | 170 | loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1)) 171 | 172 | result = self.decode(audio_features, labels, batch["src_langs"], batch["tgt_langs"], self.valid_metrics, 173 | dataloader_idx) 174 | 175 | self.valid_metrics['loss'][dataloader_idx](loss) 176 | 177 | return { 178 | "loss": loss, 179 | "result": result 180 | } 181 | 182 | def validation_epoch_end(self, outputs): 183 | loss_scores = [l.compute() for l in self.valid_metrics['loss']] 184 | self.log('valid_loss_epoch', torch.mean(torch.tensor(loss_scores))) 185 | print("valid_loss_epoch", torch.mean(torch.tensor(loss_scores))) 186 | wer_scores = [b.compute() for b in self.valid_metrics['wer']] 187 | self.log('valid_wer_epoch', torch.mean(torch.tensor(wer_scores))) 188 | print("valid_wer_epoch", torch.mean(torch.tensor(wer_scores))) 189 | for metrics in self.valid_metrics.values(): 190 | for metric in metrics: 191 | metric.reset() 192 | 193 | def on_test_epoch_start(self) -> None: 194 | for k, v in self.test_metrics.items(): 195 | for metric in v: 196 | metric.set_dtype(torch.float32) 197 | 198 | def test_step(self, batch, batch_id, dataloader_idx=None): 199 | if dataloader_idx is None: 200 | print("warning: dataloader_idx is None") 201 | dataloader_idx = 0 202 | 203 | input_ids = batch["input_ids"] 204 | labels = batch["labels"].long() 205 | audio_features = self.model.encoder(input_ids) 206 | result = self.decode(audio_features, labels, batch["src_langs"], batch["tgt_langs"], self.test_metrics, 207 | dataloader_idx) 208 | 209 | return { 210 | 'result': result, 211 | } 212 | 213 | def test_epoch_end(self, outputs): 214 | wer_scores = [b.compute() for b in self.test_metrics['wer']] 215 | for i, wer in enumerate(wer_scores): 216 | self.log(f"test_wer_{i}", round(wer.item(), 2)) 217 | print(f"test_wer_{i}", round(wer.item(), 2)) 218 | self.log('test_wer_epoch', torch.mean(torch.tensor(wer_scores))) 219 | print("test_wer_epoch", torch.mean(torch.tensor(wer_scores))) 220 | for metrics in self.test_metrics.values(): 221 | for metric in metrics: 222 | metric.reset() 223 | 224 | def configure_optimizers(self): 225 | optimizer, scheduler = configure_optimizer_schedular( 226 | cfg=self.cfg, 227 | params_generator=self.named_parameters, 228 | num_training_steps=self.trainer.estimated_stepping_batches 229 | ) 230 | self.optimizer = optimizer 231 | self.scheduler = scheduler 232 | 233 | return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}] 234 | 235 | def train_dataloader(self): 236 | dataset = WhisperAsrDataset(self.__train_dataset, self.tokenizer, self.cfg.sample_rate) 237 | return DataLoader(dataset, 238 | batch_size=self.cfg.batch_size, 239 | drop_last=True, shuffle=True, num_workers=self.cfg.num_worker, 240 | 
collate_fn=WhisperAsrDataCollatorWhithPadding(self.cfg) 241 | ) 242 | 243 | def val_dataloader(self): 244 | datasets = [WhisperAsrDataset(dataset, self.tokenizer, self.cfg.sample_rate) for dataset in self.__eval_dataset] 245 | return [DataLoader(dataset, 246 | batch_size=self.cfg.test_batch_size, 247 | num_workers=self.cfg.num_worker, 248 | collate_fn=WhisperAsrDataCollatorWhithPadding(self.cfg) 249 | ) for dataset in datasets] 250 | 251 | def test_dataloader(self): 252 | datasets = [WhisperAsrDataset(dataset, self.tokenizer, self.cfg.sample_rate) for dataset in self.__test_dataset] 253 | return [DataLoader(dataset, 254 | batch_size=self.cfg.test_batch_size, 255 | num_workers=self.cfg.num_worker, 256 | collate_fn=WhisperAsrDataCollatorWhithPadding(self.cfg) 257 | ) for dataset in datasets] 258 | 259 | 260 | if __name__ == "__main__": 261 | from config.parse_yaml_args import parse_args_and_yaml 262 | 263 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 264 | cfg = parse_args_and_yaml(config_path="../config/exp_spec/whisper_asr.yaml") 265 | DATA_ROOT = cfg.data_root 266 | language_list = cfg.language_list 267 | 268 | joined_data_pair_lists, sep_data_pair_lists = {}, {} 269 | for split in ["train", "dev", "test"]: 270 | joined_data_pair_lists[split], sep_data_pair_lists[split] = load_data_record(DATA_ROOT, split, 271 | language_list=language_list) 272 | 273 | cfg.batch_size = cfg.test_batch_size = 1 274 | 275 | module = WhisperAsrModelModule(cfg, joined_data_pair_lists, sep_data_pair_lists).to(cfg.device).eval() 276 | 277 | loader = module.test_dataloader()[0] 278 | asr_state_dict = torch.load(f'{cfg.cache_dir}/whisper_asr.pt', map_location='cuda') 279 | 280 | module.model.load_state_dict(asr_state_dict) 281 | 282 | 283 | def deep_to_device(obj, device): 284 | if isinstance(obj, torch.Tensor): 285 | return obj.to(device) 286 | elif isinstance(obj, dict): 287 | return {k: deep_to_device(v, device) for k, v in obj.items()} 288 | elif isinstance(obj, list): 289 | return [deep_to_device(v, device) for v in obj] 290 | else: 291 | return obj 292 | 293 | 294 | with torch.no_grad(): 295 | for b in loader: 296 | b = deep_to_device(b, cfg.device) 297 | train_res = module.training_step(b, 0) 298 | print(train_res) 299 | 300 | valid_res = module.validation_step(b, 0, 0) 301 | print(valid_res) 302 | module.validation_epoch_end([valid_res]) 303 | 304 | test_res = module.test_step(b, 0, 0) 305 | print(test_res) 306 | module.test_epoch_end([test_res]) 307 | 308 | break 309 | -------------------------------------------------------------------------------- /Whisper/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from functools import lru_cache 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import torch 8 | from transformers import GPT2TokenizerFast 9 | 10 | LANGUAGES = { 11 | "en": "english", 12 | "zh": "chinese", 13 | "de": "german", 14 | "es": "spanish", 15 | "ru": "russian", 16 | "ko": "korean", 17 | "fr": "french", 18 | "ja": "japanese", 19 | "pt": "portuguese", 20 | "tr": "turkish", 21 | "pl": "polish", 22 | "ca": "catalan", 23 | "nl": "dutch", 24 | "ar": "arabic", 25 | "sv": "swedish", 26 | "it": "italian", 27 | "id": "indonesian", 28 | "hi": "hindi", 29 | "fi": "finnish", 30 | "vi": "vietnamese", 31 | "iw": "hebrew", 32 | "uk": "ukrainian", 33 | "el": "greek", 34 | "ms": "malay", 35 | "cs": "czech", 36 | "ro": "romanian", 37 | "da": "danish", 38 | "hu": 
"hungarian", 39 | "ta": "tamil", 40 | "no": "norwegian", 41 | "th": "thai", 42 | "ur": "urdu", 43 | "hr": "croatian", 44 | "bg": "bulgarian", 45 | "lt": "lithuanian", 46 | "la": "latin", 47 | "mi": "maori", 48 | "ml": "malayalam", 49 | "cy": "welsh", 50 | "sk": "slovak", 51 | "te": "telugu", 52 | "fa": "persian", 53 | "lv": "latvian", 54 | "bn": "bengali", 55 | "sr": "serbian", 56 | "az": "azerbaijani", 57 | "sl": "slovenian", 58 | "kn": "kannada", 59 | "et": "estonian", 60 | "mk": "macedonian", 61 | "br": "breton", 62 | "eu": "basque", 63 | "is": "icelandic", 64 | "hy": "armenian", 65 | "ne": "nepali", 66 | "mn": "mongolian", 67 | "bs": "bosnian", 68 | "kk": "kazakh", 69 | "sq": "albanian", 70 | "sw": "swahili", 71 | "gl": "galician", 72 | "mr": "marathi", 73 | "pa": "punjabi", 74 | "si": "sinhala", 75 | "km": "khmer", 76 | "sn": "shona", 77 | "yo": "yoruba", 78 | "so": "somali", 79 | "af": "afrikaans", 80 | "oc": "occitan", 81 | "ka": "georgian", 82 | "be": "belarusian", 83 | "tg": "tajik", 84 | "sd": "sindhi", 85 | "gu": "gujarati", 86 | "am": "amharic", 87 | "yi": "yiddish", 88 | "lo": "lao", 89 | "uz": "uzbek", 90 | "fo": "faroese", 91 | "ht": "haitian creole", 92 | "ps": "pashto", 93 | "tk": "turkmen", 94 | "nn": "nynorsk", 95 | "mt": "maltese", 96 | "sa": "sanskrit", 97 | "lb": "luxembourgish", 98 | "my": "myanmar", 99 | "bo": "tibetan", 100 | "tl": "tagalog", 101 | "mg": "malagasy", 102 | "as": "assamese", 103 | "tt": "tatar", 104 | "haw": "hawaiian", 105 | "ln": "lingala", 106 | "ha": "hausa", 107 | "ba": "bashkir", 108 | "jw": "javanese", 109 | "su": "sundanese", 110 | } 111 | 112 | # language comsl lookup by name, with a few language aliases 113 | TO_LANGUAGE_CODE = { 114 | **{language: code for code, language in LANGUAGES.items()}, 115 | "burmese": "my", 116 | "valencian": "ca", 117 | "flemish": "nl", 118 | "haitian": "ht", 119 | "letzeburgesch": "lb", 120 | "pushto": "ps", 121 | "panjabi": "pa", 122 | "moldavian": "ro", 123 | "moldovan": "ro", 124 | "sinhalese": "si", 125 | "castilian": "es", 126 | } 127 | 128 | 129 | @dataclass(frozen=False) 130 | class Tokenizer: 131 | """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens""" 132 | 133 | tokenizer: "GPT2TokenizerFast" 134 | language: Optional[str] 135 | sot_sequence: Tuple[int] 136 | task: Optional[str] 137 | predict_timestamps: Optional[bool] 138 | 139 | def encode(self, text, **kwargs): 140 | return self.tokenizer.encode(text, **kwargs) 141 | 142 | def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs): 143 | return self.tokenizer.decode(token_ids, **kwargs) 144 | 145 | def decode_with_timestamps(self, tokens) -> str: 146 | """ 147 | Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. 148 | This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". 
149 | """ 150 | outputs = [[]] 151 | for token in tokens: 152 | if token >= self.timestamp_begin: 153 | timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>" 154 | outputs.append(timestamp) 155 | outputs.append([]) 156 | else: 157 | outputs[-1].append(token) 158 | outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs] 159 | return "".join(outputs) 160 | 161 | @property 162 | def eot(self) -> int: 163 | return self.tokenizer.eos_token_id 164 | 165 | @property 166 | def sot(self) -> int: 167 | return self._get_single_token_id("<|startoftranscript|>") 168 | 169 | @property 170 | def sot_lm(self) -> int: 171 | return self._get_single_token_id("<|startoflm|>") 172 | 173 | @property 174 | def sot_prev(self) -> int: 175 | return self._get_single_token_id("<|startofprev|>") 176 | 177 | @property 178 | def no_speech(self) -> int: 179 | return self._get_single_token_id("<|nospeech|>") 180 | 181 | @property 182 | def no_timestamps(self) -> int: 183 | return self._get_single_token_id("<|notimestamps|>") 184 | 185 | @property 186 | def timestamp_begin(self) -> int: 187 | return self.tokenizer.all_special_ids[-1] + 1 188 | 189 | @property 190 | def language_token(self) -> int: 191 | """Returns the token id corresponding to the value of the `language` field""" 192 | if self.language is None: 193 | raise ValueError(f"This tokenizer does not have language token configured") 194 | 195 | additional_tokens = dict( 196 | zip( 197 | self.tokenizer.additional_special_tokens, 198 | self.tokenizer.additional_special_tokens_ids, 199 | ) 200 | ) 201 | candidate = f"<|{self.language}|>" 202 | if candidate in additional_tokens: 203 | return additional_tokens[candidate] 204 | 205 | raise KeyError(f"Language {self.language} not found in tokenizer.") 206 | 207 | @property 208 | def all_language_tokens(self) -> Tuple[int]: 209 | result = [] 210 | for token, token_id in zip( 211 | self.tokenizer.additional_special_tokens, 212 | self.tokenizer.additional_special_tokens_ids, 213 | ): 214 | if token.strip("<|>") in LANGUAGES: 215 | result.append(token_id) 216 | return tuple(result) 217 | 218 | @property 219 | def all_language_dict(self) -> Tuple[int]: 220 | result = {} 221 | for token, token_id in zip( 222 | self.tokenizer.additional_special_tokens, 223 | self.tokenizer.additional_special_tokens_ids, 224 | ): 225 | if token.strip("<|>") in LANGUAGES: 226 | result[token.strip("<|>")] = token_id 227 | return result 228 | 229 | @property 230 | def all_language_codes(self) -> Tuple[str]: 231 | return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens) 232 | 233 | @property 234 | def sot_sequence_including_notimestamps(self) -> Tuple[int]: 235 | return tuple(list(self.sot_sequence) + [self.no_timestamps]) 236 | 237 | @property 238 | def non_speech_tokens(self) -> Tuple[int]: 239 | """ 240 | Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech 241 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. 242 | 243 | - ♪♪♪ 244 | - ( SPEAKING FOREIGN LANGUAGE ) 245 | - [DAVID] Hey there, 246 | 247 | keeping basic punctuations like commas, periods, question marks, exclamation points, etc. 248 | """ 249 | symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』") 250 | symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() 251 | 252 | # symbols that may be a single token or multiple tokens depending on the tokenizer. 
253 | # In case they're multiple tokens, suppress the first token, which is safe because: 254 | # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress 255 | # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. 256 | miscellaneous = set("♩♪♫♬♭♮♯") 257 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) 258 | 259 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word 260 | result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]} 261 | for symbol in symbols + list(miscellaneous): 262 | for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]: 263 | if len(tokens) == 1 or symbol in miscellaneous: 264 | result.add(tokens[0]) 265 | 266 | return tuple(sorted(result)) 267 | 268 | def _get_single_token_id(self, text) -> int: 269 | tokens = self.tokenizer.encode(text) 270 | assert len(tokens) == 1, f"{text} is not encoded as a single token" 271 | return tokens[0] 272 | 273 | def set_prefix_tokens(self, language: str = None, task: str = None, predict_timestamps: bool = None): 274 | """ 275 | Override the prefix tokens appended to the start of the label sequence. This method can be used standalone to 276 | update the prefix tokens as required when fine-tuning. Example: 277 | ```python 278 | >>> # instantiate the tokenizer and set the prefix token to Spanish 279 | >>> tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish") 280 | >>> # now switch the prefix token from Spanish to French 281 | >>> tokenizer.set_prefix_tokens(language="french") 282 | ``` 283 | Args: 284 | language (`str`, *optional*, defaults to `None`): 285 | The language of the transcription text. 286 | task (`str`, *optional*, defaults to `None`): 287 | Task identifier to append at the start of sequence (if any). 288 | predict_timestamps (`bool`, *optional*, defaults to `None`): 289 | Whether to omit the `<|notimestamps|>` token at the start of the sequence. 290 | """ 291 | self.language = language if language is not None else self.language 292 | self.task = task if task is not None else self.task 293 | self.predict_timestamps = predict_timestamps if predict_timestamps is not None else self.predict_timestamps 294 | 295 | def build_inputs_with_special_tokens(self, token_ids_0) -> List[int]: 296 | """Build model inputs from a sequence by appending eos_token_id.""" 297 | 298 | return self.prefix_tokens + token_ids_0 + [self.eos_token_id] 299 | 300 | @property 301 | def eos_token_id(self) -> int: 302 | return self.tokenizer.eos_token_id 303 | 304 | @property 305 | def prefix_tokens(self): 306 | TASK_IDS = ["translate", "transcribe"] 307 | all_special_ids = self.tokenizer.all_special_ids 308 | bos_token_id = all_special_ids[-106] 309 | translate_token_id = all_special_ids[-6] 310 | transcribe_token_id = all_special_ids[-5] 311 | notimestamps_token_id = all_special_ids[-1] 312 | langs = tuple(LANGUAGES.keys()) 313 | 314 | if self.language is not None: 315 | self.language = self.language.lower() 316 | if self.language in TO_LANGUAGE_CODE: 317 | language_id = TO_LANGUAGE_CODE[self.language] 318 | else: 319 | raise ValueError( 320 | f"Unsupported language: {self.language}. Language should be in: {TO_LANGUAGE_CODE.keys()}" 321 | ) 322 | 323 | if self.task is not None: 324 | if self.task not in TASK_IDS: 325 | raise ValueError(f"Unsupported task: {self.task}. 
Task should be in: {TASK_IDS}") 326 | 327 | bos_sequence = [bos_token_id] 328 | if self.language is not None: 329 | bos_sequence.append(bos_token_id + 1 + langs.index(language_id)) 330 | if self.task is not None: 331 | bos_sequence.append(transcribe_token_id if self.task == "transcribe" else translate_token_id) 332 | if not self.predict_timestamps: 333 | bos_sequence.append(notimestamps_token_id) 334 | return bos_sequence 335 | 336 | 337 | def build_tokenizer(name: str = "gpt2"): 338 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 339 | path = os.path.join(os.path.dirname(__file__), "assets", name) 340 | tokenizer = GPT2TokenizerFast.from_pretrained(path) 341 | 342 | specials = [ 343 | "<|startoftranscript|>", 344 | *[f"<|{lang}|>" for lang in LANGUAGES.keys()], 345 | "<|translate|>", 346 | "<|transcribe|>", 347 | "<|startoflm|>", 348 | "<|startofprev|>", 349 | "<|nospeech|>", 350 | "<|notimestamps|>", 351 | ] 352 | 353 | tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) 354 | return tokenizer 355 | 356 | 357 | def get_tokenizer( 358 | multilingual: bool, 359 | *, 360 | task: Optional[str] = None, # Literal["transcribe", "translate", None] 361 | language: Optional[str] = None, 362 | ) -> Tokenizer: 363 | if language is not None: 364 | language = language.lower() 365 | if language not in LANGUAGES: 366 | if language in TO_LANGUAGE_CODE: 367 | language = TO_LANGUAGE_CODE[language] 368 | else: 369 | raise ValueError(f"Unsupported language: {language}") 370 | 371 | if multilingual: 372 | tokenizer_name = "multilingual" 373 | task = task or "transcribe" 374 | language = language or "en" 375 | else: 376 | tokenizer_name = "gpt2" 377 | task = None 378 | language = None 379 | 380 | tokenizer = build_tokenizer(name=tokenizer_name) 381 | all_special_ids: List[int] = tokenizer.all_special_ids 382 | sot: int = all_special_ids[1] 383 | translate: int = all_special_ids[-6] 384 | transcribe: int = all_special_ids[-5] 385 | 386 | langs = tuple(LANGUAGES.keys()) 387 | sot_sequence = [sot] 388 | if language is not None: 389 | sot_sequence.append(sot + 1 + langs.index(language)) 390 | if task is not None: 391 | sot_sequence.append(transcribe if task == "transcribe" else translate) 392 | 393 | return Tokenizer(tokenizer=tokenizer, language=language, task=task, predict_timestamps=False, sot_sequence=tuple(sot_sequence)) 394 | -------------------------------------------------------------------------------- /Whisper/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict 3 | from typing import Iterable, Optional 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from torch import nn 10 | 11 | from .transcribe import transcribe as transcribe_function 12 | from .decoding import detect_language as detect_language_function, decode as decode_function 13 | 14 | 15 | @dataclass 16 | class ModelDimensions: 17 | n_mels: int 18 | n_audio_ctx: int 19 | n_audio_state: int 20 | n_audio_head: int 21 | n_audio_layer: int 22 | n_vocab: int 23 | n_text_ctx: int 24 | n_text_state: int 25 | n_text_head: int 26 | n_text_layer: int 27 | 28 | 29 | class LayerNorm(nn.LayerNorm): 30 | def forward(self, x: Tensor) -> Tensor: 31 | # return super().forward(x.float()).type(x.dtype) 32 | return super().forward(x) 33 | 34 | 35 | class Linear(nn.Linear): 36 | def forward(self, x: Tensor) -> Tensor: 37 | return F.linear( 38 | x, self.weight.to(x.dtype), None if 
self.bias is None else self.bias.to(x.dtype) 39 | ) 40 | 41 | 42 | class Conv1d(nn.Conv1d): 43 | def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: 44 | return super()._conv_forward( 45 | x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) 46 | ) 47 | 48 | 49 | def sinusoids(length, channels, max_timescale=10000): 50 | """Returns sinusoids for positional embedding""" 51 | assert channels % 2 == 0 52 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) 53 | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) 54 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] 55 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) 56 | 57 | 58 | class MultiHeadAttention(nn.Module): 59 | def __init__(self, 60 | n_state: int, 61 | n_head: int, 62 | output_attention_weight=False,): 63 | super().__init__() 64 | self.n_head = n_head 65 | self.query = Linear(n_state, n_state) 66 | self.key = Linear(n_state, n_state, bias=False) 67 | self.value = Linear(n_state, n_state) 68 | self.out = Linear(n_state, n_state) 69 | self.output_attention_weight = output_attention_weight 70 | 71 | def forward( 72 | self, 73 | x: Tensor, 74 | xa: Optional[Tensor] = None, 75 | mask: Optional[Tensor] = None, 76 | kv_cache: Optional[dict] = None, 77 | ): 78 | q = self.query(x) 79 | 80 | if kv_cache is None or xa is None or self.key not in kv_cache: 81 | # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; 82 | # otherwise, perform key/value projections for self- or cross-attention as usual. 83 | k = self.key(x if xa is None else xa) 84 | v = self.value(x if xa is None else xa) 85 | else: 86 | # for cross-attention, calculate keys and values once and reuse in subsequent calls. 
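            # note: this cache dict maps the key/value `Linear` modules themselves to their previously
            # computed projections; it is populated by the forward hooks registered in
            # `install_kv_cache_hooks`, so for cross-attention the audio keys/values are computed once
            # and simply looked up here on later decoding steps.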
87 | k = kv_cache[self.key] 88 | v = kv_cache[self.value] 89 | 90 | wv = self.qkv_attention(q, k, v, mask) 91 | if isinstance(wv, tuple): 92 | wv, w = wv 93 | return self.out(wv), w 94 | else: 95 | return self.out(wv) 96 | 97 | def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None): 98 | n_batch, n_ctx, n_state = q.shape 99 | scale = (n_state // self.n_head) ** -0.25 100 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale 101 | k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale 102 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) 103 | 104 | qk = q @ k 105 | if mask is not None: 106 | qk = qk + mask[:n_ctx, :n_ctx] 107 | qk = qk.float() 108 | 109 | w = F.softmax(qk, dim=-1).to(q.dtype) 110 | if self.output_attention_weight: 111 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), w 112 | else: 113 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2) 114 | 115 | 116 | class ResidualAttentionBlock(nn.Module): 117 | def __init__(self, 118 | n_state: int, 119 | n_head: int, 120 | cross_attention: bool = False, 121 | output_cross_attention_weight=False, 122 | ): 123 | super().__init__() 124 | 125 | self.attn = MultiHeadAttention(n_state, n_head) 126 | self.attn_ln = LayerNorm(n_state) 127 | 128 | self.cross_attn = MultiHeadAttention(n_state, n_head, output_cross_attention_weight) if cross_attention else None 129 | self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None 130 | 131 | n_mlp = n_state * 4 132 | self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)) 133 | self.mlp_ln = LayerNorm(n_state) 134 | self.output_cross_attention_weight = output_cross_attention_weight 135 | 136 | def forward( 137 | self, 138 | x: Tensor, 139 | xa: Optional[Tensor] = None, 140 | mask: Optional[Tensor] = None, 141 | kv_cache: Optional[dict] = None, 142 | ): 143 | cross_attention_weight = None 144 | 145 | residual = x 146 | x = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache) 147 | x = residual + x 148 | 149 | if self.cross_attn: 150 | residual = x 151 | if self.output_cross_attention_weight: 152 | x, cross_attention_weight = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache) 153 | else: 154 | x = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache) 155 | x = residual + x 156 | 157 | residual = x 158 | x = self.mlp(self.mlp_ln(x)) 159 | x = residual + x 160 | 161 | if cross_attention_weight is not None: 162 | return x, cross_attention_weight 163 | else: 164 | return x 165 | 166 | 167 | class AudioEncoder(nn.Module): 168 | def __init__(self, 169 | n_mels: int, 170 | n_ctx: int, 171 | n_state: int, 172 | n_head: int, 173 | n_layer: int, ): 174 | super().__init__() 175 | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) 176 | self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) 177 | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) 178 | 179 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( 180 | [ResidualAttentionBlock(n_state, n_head, False) for _ in range(n_layer)] 181 | ) 182 | self.ln_post = LayerNorm(n_state) 183 | self.gradient_checkpointing = False 184 | 185 | def forward(self, x: Tensor): 186 | """ 187 | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) 188 | the mel spectrogram of the audio 189 | """ 190 | x = F.gelu(self.conv1(x)) 191 | x = F.gelu(self.conv2(x)) 192 | x = x.permute(0, 2, 1) 193 | 194 | assert x.shape[2:] == self.positional_embedding.shape[1:], 
"incorrect audio shape" 195 | x = (x + self.positional_embedding[:x.shape[1]]).to(x.dtype) 196 | 197 | 198 | 199 | for block in self.blocks: 200 | if self.gradient_checkpointing and self.training: 201 | 202 | def create_custom_forward(module): 203 | def custom_forward(*inputs): 204 | return module(*inputs) 205 | 206 | return custom_forward 207 | 208 | x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x) 209 | else: 210 | x = block(x) 211 | 212 | x = self.ln_post(x) 213 | return x 214 | 215 | 216 | class TextDecoder(nn.Module): 217 | def __init__(self, 218 | dims: ModelDimensions, 219 | output_last_cross_attention=False,): 220 | super().__init__() 221 | 222 | n_vocab, n_ctx, n_state, n_head, n_layer = dims.n_vocab, dims.n_text_ctx, dims.n_text_state, dims.n_text_head, dims.n_text_layer 223 | self.dims = dims 224 | 225 | self.token_embedding = nn.Embedding(n_vocab, n_state) 226 | self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) 227 | 228 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( 229 | [ResidualAttentionBlock(n_state, n_head, True, output_last_cross_attention) for _ in range(n_layer)] 230 | ) 231 | self.ln = LayerNorm(n_state) 232 | 233 | mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) 234 | self.register_buffer("mask", mask, persistent=False) 235 | self.gradient_checkpointing = False 236 | 237 | def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None, output_last_cross_attention=False): 238 | """ 239 | x : torch.LongTensor, shape = (batch_size, <= n_ctx) 240 | the text tokens 241 | xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx) 242 | the encoded audio features to be attended on 243 | """ 244 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 245 | 246 | last_cross_attention_weight = None 247 | 248 | x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]] 249 | 250 | x = x.to(xa.dtype) 251 | 252 | for block in self.blocks: 253 | if self.gradient_checkpointing and self.training: 254 | def create_custom_forward(module): 255 | def custom_forward(*inputs): 256 | return module(*inputs) 257 | 258 | return custom_forward 259 | x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, xa, self.mask, kv_cache) 260 | else: 261 | x = block(x, xa, mask=self.mask, kv_cache=kv_cache) 262 | if isinstance(x, tuple): 263 | x, last_cross_attention_weight = x 264 | 265 | x = self.ln(x) 266 | logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() 267 | 268 | if last_cross_attention_weight is not None and output_last_cross_attention: 269 | return logits, last_cross_attention_weight 270 | 271 | return logits 272 | 273 | def install_kv_cache_hooks(self, cache: Optional[dict] = None): 274 | """ 275 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value 276 | tensors calculated for the previous positions. This method returns a dictionary that stores 277 | all caches, and the necessary hooks for the key and value projection modules that save the 278 | intermediate tensors to be reused during later calculations. 
279 | 280 | Returns 281 | ------- 282 | cache : Dict[nn.Module, torch.Tensor] 283 | A dictionary object mapping the key/value projection modules to its cache 284 | hooks : List[RemovableHandle] 285 | List of PyTorch RemovableHandle objects to stop the hooks to be called 286 | """ 287 | cache = {**cache} if cache is not None else {} 288 | hooks = [] 289 | 290 | def save_to_cache(module, _, output): 291 | if module not in cache or output.shape[1] > self.positional_embedding.shape[0]: 292 | cache[module] = output # save as-is, for the first token or cross attention 293 | else: 294 | cache[module] = torch.cat([cache[module], output], dim=1).detach() 295 | return cache[module] 296 | 297 | def install_hooks(layer: nn.Module): 298 | if isinstance(layer, MultiHeadAttention): 299 | hooks.append(layer.key.register_forward_hook(save_to_cache)) 300 | hooks.append(layer.value.register_forward_hook(save_to_cache)) 301 | 302 | self.apply(install_hooks) 303 | return cache, hooks 304 | 305 | 306 | class Whisper(nn.Module): 307 | def __init__(self, 308 | dims: ModelDimensions, 309 | output_last_cross_attention=False,): 310 | super().__init__() 311 | self.dims = dims 312 | self.encoder = AudioEncoder( 313 | self.dims.n_mels, 314 | self.dims.n_audio_ctx, 315 | self.dims.n_audio_state, 316 | self.dims.n_audio_head, 317 | self.dims.n_audio_layer, 318 | ) 319 | self.decoder = TextDecoder( 320 | self.dims, 321 | output_last_cross_attention=output_last_cross_attention, 322 | ) 323 | 324 | def embed_audio(self, mel: torch.Tensor): 325 | return self.encoder(mel) 326 | 327 | def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): 328 | return self.decoder(tokens, audio_features) 329 | 330 | def forward(self, mel: torch.Tensor, tokens: torch.Tensor, output_last_cross_attention=False) -> Dict[str, torch.Tensor]: 331 | return self.decoder(tokens, self.encoder(mel), output_last_cross_attention=output_last_cross_attention) 332 | 333 | @property 334 | def device(self): 335 | return next(self.parameters()).device 336 | 337 | @property 338 | def is_multilingual(self): 339 | return self.dims.n_vocab == 51865 340 | 341 | def enable_acti_ckpt(self): 342 | self.encoder.gradient_checkpointing = True 343 | self.decoder.gradient_checkpointing = True 344 | 345 | def install_kv_cache_hooks(self, cache: Optional[dict] = None): 346 | """ 347 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value 348 | tensors calculated for the previous positions. This method returns a dictionary that stores 349 | all caches, and the necessary hooks for the key and value projection modules that save the 350 | intermediate tensors to be reused during later calculations.
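        A minimal usage sketch (illustrative, not part of the original docstring; `model` is assumed to be
        a loaded `Whisper` instance, and `tokens`/`audio_features` are assumed to exist already):
        ```python
        >>> kv_cache, hooks = model.install_kv_cache_hooks()
        >>> logits = model.decoder(tokens, audio_features, kv_cache=kv_cache)  # later calls reuse cached keys/values
        >>> for hook in hooks:
        ...     hook.remove()  # detach the hooks once decoding is finished
        ```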
351 | 352 | Returns 353 | ------- 354 | cache : Dict[nn.Module, torch.Tensor] 355 | A dictionary object mapping the key/value projection modules to its cache 356 | hooks : List[RemovableHandle] 357 | List of PyTorch RemovableHandle objects to stop the hooks to be called 358 | """ 359 | cache = {**cache} if cache is not None else {} 360 | hooks = [] 361 | 362 | def save_to_cache(module, _, output): 363 | if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]: 364 | cache[module] = output # save as-is, for the first token or cross attention 365 | else: 366 | cache[module] = torch.cat([cache[module], output], dim=1).detach() 367 | return cache[module] 368 | 369 | def install_hooks(layer: nn.Module): 370 | if isinstance(layer, MultiHeadAttention): 371 | hooks.append(layer.key.register_forward_hook(save_to_cache)) 372 | hooks.append(layer.value.register_forward_hook(save_to_cache)) 373 | 374 | self.decoder.apply(install_hooks) 375 | return cache, hooks 376 | 377 | detect_language = detect_language_function 378 | transcribe = transcribe_function 379 | decode = decode_function 380 | -------------------------------------------------------------------------------- /criterion/mix_criterions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules.loss import _Loss 5 | import torchmetrics 6 | 7 | 8 | def item(tensor): 9 | if torch.is_tensor(tensor) and tensor.device.type == "xla": 10 | return tensor.detach() 11 | if hasattr(tensor, "item"): 12 | return tensor.item() 13 | if hasattr(tensor, "__getitem__"): 14 | return tensor[0] 15 | return tensor 16 | 17 | 18 | class GuidedCrossEntMultiTaskCriterion(_Loss): 19 | def __init__( 20 | self, 21 | args, 22 | padding_idx, 23 | ): 24 | super().__init__() 25 | self.padding_idx = padding_idx 26 | self.spch_loss_weight = args.spch_loss_weight 27 | self.asr_loss_weight = args.asr_loss_weight 28 | self.text_loss_weight = args.text_loss_weight 29 | self.use_cml = args.use_cml 30 | self.use_erm = args.use_erm 31 | self.cml_loss_weight = args.cml_loss_weight 32 | self.erm_loss_weight = args.erm_loss_weight 33 | self.alpha = args.guide_alpha 34 | self.text_alpha = args.text_alpha 35 | assert 0 <= self.alpha <= 1.0 36 | 37 | self.metrics = nn.ModuleDict({}) 38 | if self.spch_loss_weight > 0.0: 39 | self.metrics.update({ 40 | 'speech_nll_loss': torchmetrics.SumMetric(compute_on_step=False), 41 | 'speech_correct': torchmetrics.SumMetric(compute_on_step=False), 42 | 'speech_total': torchmetrics.SumMetric(compute_on_step=False), 43 | }) 44 | if self.asr_loss_weight > 0.0: 45 | self.metrics.update({ 46 | 'asr_nll_loss': torchmetrics.SumMetric(compute_on_step=False), 47 | 'asr_correct': torchmetrics.SumMetric(compute_on_step=False), 48 | 'asr_total': torchmetrics.SumMetric(compute_on_step=False), 49 | }) 50 | if self.text_loss_weight > 0.0: 51 | self.metrics.update({ 52 | 'text_nll_loss': torchmetrics.SumMetric(compute_on_step=False), 53 | 'text_correct': torchmetrics.SumMetric(compute_on_step=False), 54 | 'text_total': torchmetrics.SumMetric(compute_on_step=False), 55 | }) 56 | if self.alpha > 0.0: 57 | self.metrics.update({ 58 | 'guide_loss': torchmetrics.SumMetric(compute_on_step=False), 59 | }) 60 | 61 | if args.use_cml: 62 | self.metrics.update({ 63 | 'mix_speech_nll_loss': torchmetrics.SumMetric(compute_on_step=False), 64 | 'mix_asr_nll_loss': torchmetrics.SumMetric(compute_on_step=False), 65 | 
'mix_mlm_nll_loss': torchmetrics.SumMetric(compute_on_step=False), 66 | 'mix_speech_correct': torchmetrics.SumMetric(compute_on_step=False), 67 | 'mix_asr_correct': torchmetrics.SumMetric(compute_on_step=False), 68 | 'mix_mlm_correct': torchmetrics.SumMetric(compute_on_step=False), 69 | 'mix_speech_total': torchmetrics.SumMetric(compute_on_step=False), 70 | 'mix_asr_total': torchmetrics.SumMetric(compute_on_step=False), 71 | 'mix_mlm_total': torchmetrics.SumMetric(compute_on_step=False), 72 | }) 73 | if self.use_erm: 74 | self.metrics.update({ 75 | 'reg_loss': torchmetrics.SumMetric(compute_on_step=False), 76 | 'reg_total': torchmetrics.SumMetric(compute_on_step=False), 77 | }) 78 | 79 | def forward(self, model, sample, reduce=True): 80 | reduction = 'sum' if reduce else 'none' 81 | net_input = sample["net_input"] 82 | output = model(**net_input) 83 | decoder_output = output[0] 84 | 85 | targets = {'transcript': sample["target"][0], 'translate': sample["target"][1]} 86 | 87 | logits = decoder_output[0] 88 | 89 | speech_loss, speech_nll_loss, speech_correct, speech_total = {}, {}, {}, {} 90 | text_loss, text_nll_loss, text_correct, text_total = {}, {}, {}, {} 91 | nll_loss, acc = {}, {} 92 | 93 | # asr 94 | spch_logits, _ = logits['transcript'] 95 | if self.asr_loss_weight > 0.0: 96 | speech_loss['transcript'], speech_nll_loss['transcript'], speech_correct['transcript'], speech_total[ 97 | 'transcript'], _ = self.compute_loss_and_acc(model, spch_logits, targets['transcript'], 98 | reduction=reduction) 99 | nll_loss['asr'] = speech_nll_loss['transcript'] / speech_total['transcript'] 100 | acc['asr'] = speech_correct['transcript'] / speech_total['transcript'] * 100.0 101 | 102 | spch_logits, text_logits = logits['translate'] 103 | # mt 104 | if self.text_loss_weight > 0.0: 105 | if 'reg_logits' in decoder_output[1] and self.text_alpha > 0.0: 106 | text_loss['translate'], text_nll_loss['translate'], text_correct['translate'], text_total[ 107 | 'translate'], _ = self.guide_loss_and_acc(model, text_logits, decoder_output[1]['reg_logits'], 108 | targets['translate'], reduction=reduction, 109 | alpha=self.text_alpha) 110 | else: 111 | text_loss['translate'], text_nll_loss['translate'], text_correct['translate'], text_total[ 112 | 'translate'], _ = self.compute_loss_and_acc(model, text_logits, targets['translate'], 113 | reduction=reduction) 114 | nll_loss['text_translate'] = text_nll_loss['translate'] / text_total['translate'] 115 | acc['text_translate'] = text_correct['translate'] / text_total['translate'] * 100.0 116 | # st 117 | if self.spch_loss_weight > 0.0: 118 | if text_logits is not None: 119 | speech_loss['translate'], speech_nll_loss['translate'], speech_correct['translate'], speech_total[ 120 | 'translate'], guide_loss = self.guide_loss_and_acc(model, spch_logits, text_logits, 121 | targets['translate'], reduction=reduction) 122 | else: 123 | speech_loss['translate'], speech_nll_loss['translate'], speech_correct['translate'], speech_total[ 124 | 'translate'], _ = self.compute_loss_and_acc(model, spch_logits, targets['translate'], 125 | reduction=reduction) 126 | nll_loss['speech_translate'] = speech_nll_loss['translate'] / speech_total['translate'] 127 | acc['speech_translate'] = speech_correct['translate'] / speech_total['translate'] * 100.0 128 | 129 | if not self.training: 130 | if self.asr_loss_weight > 0.0: 131 | self.metrics['asr_nll_loss'].update(speech_loss['transcript']) 132 | self.metrics['asr_correct'].update(speech_correct['transcript']) 133 | 
self.metrics['asr_total'].update(speech_total['transcript']) 134 | if self.text_loss_weight > 0.0: 135 | self.metrics['text_nll_loss'].update(text_loss['translate']) 136 | self.metrics['text_correct'].update(text_correct['translate']) 137 | self.metrics['text_total'].update(text_total['translate']) 138 | if self.spch_loss_weight > 0.0: 139 | self.metrics['speech_nll_loss'].update(speech_loss['translate']) 140 | self.metrics['speech_correct'].update(speech_correct['translate']) 141 | self.metrics['speech_total'].update(speech_total['translate']) 142 | if self.alpha > 0.0: 143 | self.metrics['guide_loss'].update(guide_loss) 144 | 145 | total_loss = 0.0 146 | if self.asr_loss_weight > 0.0: 147 | total_loss += speech_loss['transcript'] * self.asr_loss_weight 148 | if self.text_loss_weight > 0.0: 149 | total_loss += text_loss['translate'] * self.text_loss_weight 150 | if self.spch_loss_weight > 0.0: 151 | total_loss += speech_loss['translate'] * self.spch_loss_weight 152 | 153 | translate_logits_teacher = logits['translate'][1] 154 | 155 | mix_logits = decoder_output[1]['mix_dec_outs'] 156 | erm_loss = decoder_output[1]['erm_loss'] 157 | 158 | if self.use_cml and mix_logits is not None: 159 | mlm_mask = decoder_output[1]['mlm_mask'] 160 | for task in ['transcript', 'translate']: 161 | if len(mix_logits[task]) == 0: 162 | continue 163 | spch_logits, text_logits = mix_logits[task] 164 | if task == 'translate': 165 | speech_loss[task], speech_nll_loss[task], speech_correct[task], speech_total[ 166 | task], _ = self.guide_loss_and_acc(model, spch_logits, translate_logits_teacher, targets[task], 167 | reduction=reduction) 168 | else: 169 | speech_loss[task], speech_nll_loss[task], speech_correct[task], speech_total[ 170 | task], _ = self.compute_loss_and_acc(model, spch_logits, targets[task], reduction=reduction) 171 | mlm_target = targets[task].detach().clone() 172 | mlm_target[~mlm_mask] = self.padding_idx 173 | text_loss[task], text_nll_loss[task], text_correct[task], text_total[ 174 | task], _ = self.compute_loss_and_acc(model, text_logits, mlm_target, reduction=reduction) 175 | 176 | nll_loss['mix_speech_translate'] = speech_nll_loss['translate'] / speech_total['translate'] 177 | nll_loss['mix_asr'] = speech_nll_loss['transcript'] / speech_total['transcript'] 178 | nll_loss['mix_mlm'] = text_nll_loss['transcript'] / text_total['transcript'] 179 | acc['mix_speech_translate'] = speech_correct['translate'] / speech_total['translate'] * 100.0 180 | acc['mix_asr'] = speech_correct['transcript'] / speech_total['transcript'] * 100.0 181 | acc['mix_mlm'] = text_correct['transcript'] / text_total['transcript'] * 100.0 182 | 183 | if not self.training: 184 | self.metrics['mix_speech_nll_loss'].update(speech_nll_loss['translate']) 185 | self.metrics['mix_asr_nll_loss'].update(speech_nll_loss['transcript']) 186 | self.metrics['mix_mlm_nll_loss'].update(text_nll_loss['transcript']) 187 | self.metrics['mix_speech_correct'].update(speech_correct['translate']) 188 | self.metrics['mix_asr_correct'].update(speech_correct['transcript']) 189 | self.metrics['mix_mlm_correct'].update(text_correct['transcript']) 190 | self.metrics['mix_speech_total'].update(speech_total['translate']) 191 | self.metrics['mix_asr_total'].update(speech_total['transcript']) 192 | self.metrics['mix_mlm_total'].update(text_total['transcript']) 193 | 194 | cml_loss = (speech_loss['transcript'] + speech_loss['translate'] + text_loss['transcript']) / 3 195 | 196 | if self.use_erm and erm_loss is not None: 197 | src_token_num = 
item(erm_loss.ne(0).sum()) 198 | erm_loss = erm_loss.sum() 199 | cml_loss += erm_loss * self.erm_loss_weight 200 | if not self.training: 201 | self.metrics['reg_loss'].update(erm_loss) 202 | self.metrics['reg_total'].update(src_token_num) 203 | erm_loss /= src_token_num 204 | else: 205 | erm_loss = 0 206 | 207 | total_loss += cml_loss * self.cml_loss_weight 208 | 209 | mean_guide_loss = guide_loss / speech_total['translate'] if self.alpha > 0.0 else None 210 | logging_output = self.step_logging_output( 211 | acc, nll_loss, mean_guide_loss, erm_loss 212 | ) 213 | 214 | return total_loss, logging_output, output 215 | 216 | def reduce_metric(self): 217 | output = {} 218 | for k, v in self.metrics.items(): 219 | if 'total' in k: 220 | continue 221 | elif 'nll_loss' in k: 222 | output[k] = v.compute() / self.metrics[k.replace('nll_loss', 'total')].compute() 223 | elif 'correct' in k: 224 | output[k.replace('correct', 'acc')] = v.compute() / self.metrics[ 225 | k.replace('correct', 'total')].compute() * 100.0 226 | if self.alpha > 0.0: 227 | output['guide_loss'] = self.metrics['guide_loss'].compute() / self.metrics['speech_total'].compute() 228 | if self.use_cml and self.use_erm: 229 | output['reg_loss'] = self.metrics['reg_loss'].compute() / self.metrics['reg_total'].compute() 230 | for k, v in self.metrics.items(): 231 | v.reset() 232 | return output 233 | 234 | def compute_loss_and_acc(self, model, logits, target, reduction='sum'): 235 | logits = logits.view(-1, logits.size(-1)).float() # -> (B x T) x C 236 | target = target.view(-1) 237 | loss = F.cross_entropy( 238 | logits, target, ignore_index=self.padding_idx, reduction=reduction, 239 | ) 240 | 241 | nll_loss = F.cross_entropy( 242 | logits, target, label_smoothing=0, ignore_index=self.padding_idx, reduction=reduction, 243 | ).detach() 244 | 245 | mask = target.ne(self.padding_idx) 246 | correct = torch.sum(logits.argmax(1).masked_select(mask).eq(target.masked_select(mask))) 247 | total = torch.sum(mask) 248 | return loss, nll_loss, correct, total, torch.tensor(0.0) 249 | 250 | def guide_loss_and_acc(self, model, logits, logits_teacher, target, reduction='sum', alpha=None): 251 | """ lprobs_teacher is used as guide for lprobs """ 252 | alpha = self.alpha if alpha is None else alpha 253 | if alpha == 0.0: 254 | return self.compute_loss_and_acc(model, logits, target, reduction=reduction) 255 | 256 | logits = logits.view(-1, logits.size(-1)).float() # -> (B x T) x C 257 | logits_teacher = logits_teacher.view(-1, logits_teacher.size(-1)).float() # -> (B x T) x C 258 | target = target.view(-1) 259 | loss = F.cross_entropy(logits, target, ignore_index=self.padding_idx, reduction=reduction) 260 | nll_loss = loss 261 | probs_teacher = F.softmax(logits_teacher, dim=-1).masked_fill_(target.unsqueeze(-1).eq(self.padding_idx), 0) 262 | probs_teacher = probs_teacher.detach() 263 | lprobs = F.log_softmax(logits, dim=-1) 264 | guide_loss = -(probs_teacher * lprobs).sum() if reduction == 'sum' else -(probs_teacher * lprobs).sum(-1, 265 | keepdim=True) 266 | loss = alpha * guide_loss + (1.0 - alpha) * loss 267 | 268 | mask = target.ne(self.padding_idx) 269 | correct = torch.sum(logits.argmax(1).masked_select(mask).eq(target.masked_select(mask))) 270 | total = torch.sum(mask) 271 | return loss, nll_loss, correct, total, guide_loss 272 | 273 | def step_logging_output( 274 | self, 275 | acc, 276 | nll_loss, 277 | guide_loss=None, 278 | reg_cost=None, 279 | ): 280 | logging_output = {} 281 | for k in acc.keys(): 282 | logging_output[f'acc_{k}'] = 
item(acc[k].data) 283 | logging_output[f'nll_loss_{k}'] = item(nll_loss[k].data) 284 | if guide_loss is not None: 285 | logging_output[f'guide_loss'] = item(guide_loss.data) 286 | logging_output[f'reg_loss'] = item(reg_cost.data) if reg_cost is not None else 0.0 287 | 288 | return logging_output 289 | -------------------------------------------------------------------------------- /Whisper/transcribe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | from typing import List, Optional, Tuple, Union, TYPE_CHECKING 5 | 6 | import numpy as np 7 | import torch 8 | import tqdm 9 | 10 | from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram 11 | from .decoding import DecodingOptions, DecodingResult 12 | from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer 13 | from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt 14 | 15 | if TYPE_CHECKING: 16 | from .model import Whisper 17 | 18 | 19 | def transcribe( 20 | model: "Whisper", 21 | audio: Union[str, np.ndarray, torch.Tensor], 22 | *, 23 | verbose: Optional[bool] = None, 24 | temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), 25 | compression_ratio_threshold: Optional[float] = 2.4, 26 | logprob_threshold: Optional[float] = -1.0, 27 | no_speech_threshold: Optional[float] = 0.6, 28 | condition_on_previous_text: bool = True, 29 | **decode_options, 30 | ): 31 | """ 32 | Transcribe an audio file using Whisper 33 | 34 | Parameters 35 | ---------- 36 | model: Whisper 37 | The Whisper model instance 38 | 39 | audio: Union[str, np.ndarray, torch.Tensor] 40 | The path to the audio file to open, or the audio waveform 41 | 42 | verbose: bool 43 | Whether to display the text being decoded to the console. If True, displays all the details, 44 | If False, displays minimal details. If None, does not display anything 45 | 46 | temperature: Union[float, Tuple[float, ...]] 47 | Temperature for sampling. It can be a tuple of temperatures, which will be successively used 48 | upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. 49 | 50 | compression_ratio_threshold: float 51 | If the gzip compression ratio is above this value, treat as failed 52 | 53 | logprob_threshold: float 54 | If the average log probability over sampled tokens is below this value, treat as failed 55 | 56 | no_speech_threshold: float 57 | If the no_speech probability is higher than this value AND the average log probability 58 | over sampled tokens is below `logprob_threshold`, consider the segment as silent 59 | 60 | condition_on_previous_text: bool 61 | if True, the previous output of the model is provided as a prompt for the next window; 62 | disabling may make the text inconsistent across windows, but the model becomes less prone to 63 | getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. 64 | 65 | decode_options: dict 66 | Keyword arguments to construct `DecodingOptions` instances 67 | 68 | Returns 69 | ------- 70 | A dictionary containing the resulting text ("text") and segment-level details ("segments"), and 71 | the spoken language ("language"), which is detected when `decode_options["language"]` is None.
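    A minimal usage sketch (illustrative, not part of the original docstring; it assumes a checkpoint
    loaded with `load_model` from this package and a local file `audio.wav`):
    ```python
    >>> from Whisper import load_model
    >>> model = load_model("small")
    >>> result = transcribe(model, "audio.wav", temperature=0.0)
    >>> print(result["language"], result["text"])
    ```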
72 | """ 73 | dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 74 | if model.device == torch.device("cpu"): 75 | if torch.cuda.is_available(): 76 | warnings.warn("Performing inference on CPU when CUDA is available") 77 | if dtype == torch.float16: 78 | warnings.warn("FP16 is not supported on CPU; using FP32 instead") 79 | dtype = torch.float32 80 | 81 | if dtype == torch.float32: 82 | decode_options["fp16"] = False 83 | 84 | mel = log_mel_spectrogram(audio) 85 | 86 | if decode_options.get("language", None) is None: 87 | if not model.is_multilingual: 88 | decode_options["language"] = "en" 89 | else: 90 | if verbose: 91 | print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language") 92 | segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) 93 | _, probs = model.detect_language(segment) 94 | decode_options["language"] = max(probs, key=probs.get) 95 | if verbose is not None: 96 | print(f"Detected language: {LANGUAGES[decode_options['language']].title()}") 97 | 98 | language = decode_options["language"] 99 | task = decode_options.get("task", "transcribe") 100 | tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task) 101 | 102 | def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: 103 | temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature 104 | decode_result = None 105 | 106 | for t in temperatures: 107 | kwargs = {**decode_options} 108 | if t > 0: 109 | # disable beam_size and patience when t > 0 110 | kwargs.pop("beam_size", None) 111 | kwargs.pop("patience", None) 112 | else: 113 | # disable best_of when t == 0 114 | kwargs.pop("best_of", None) 115 | 116 | options = DecodingOptions(**kwargs, temperature=t) 117 | decode_result = model.decode(segment, options) 118 | 119 | needs_fallback = False 120 | if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold: 121 | needs_fallback = True # too repetitive 122 | if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold: 123 | needs_fallback = True # average log probability is too low 124 | 125 | if not needs_fallback: 126 | break 127 | 128 | return decode_result 129 | 130 | seek = 0 131 | input_stride = exact_div( 132 | N_FRAMES, model.dims.n_audio_ctx 133 | ) # mel frames per output token: 2 134 | time_precision = ( 135 | input_stride * HOP_LENGTH / SAMPLE_RATE 136 | ) # time per output token: 0.02 (seconds) 137 | all_tokens = [] 138 | all_segments = [] 139 | prompt_reset_since = 0 140 | 141 | initial_prompt = decode_options.pop("initial_prompt", None) or [] 142 | if initial_prompt: 143 | initial_prompt = tokenizer.encode(" " + initial_prompt.strip()) 144 | all_tokens.extend(initial_prompt) 145 | 146 | def add_segment( 147 | *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult 148 | ): 149 | text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot]) 150 | if len(text.strip()) == 0: # skip empty text output 151 | return 152 | 153 | all_segments.append( 154 | { 155 | "id": len(all_segments), 156 | "seek": seek, 157 | "start": start, 158 | "end": end, 159 | "text": text, 160 | "tokens": text_tokens.tolist(), 161 | "temperature": result.temperature, 162 | "avg_logprob": result.avg_logprob, 163 | "compression_ratio": result.compression_ratio, 164 | "no_speech_prob": result.no_speech_prob, 165 | } 166 | ) 167 | if verbose: 168 | print(f"[{format_timestamp(start)} --> 
{format_timestamp(end)}] {text}") 169 | 170 | # show the progress bar when verbose is False (otherwise the transcribed text will be printed) 171 | num_frames = mel.shape[-1] 172 | previous_seek_value = seek 173 | 174 | with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: 175 | while seek < num_frames: 176 | timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) 177 | segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype) 178 | segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE 179 | 180 | decode_options["prompt"] = all_tokens[prompt_reset_since:] 181 | result: DecodingResult = decode_with_fallback(segment) 182 | tokens = torch.tensor(result.tokens) 183 | 184 | if no_speech_threshold is not None: 185 | # no voice activity check 186 | should_skip = result.no_speech_prob > no_speech_threshold 187 | if logprob_threshold is not None and result.avg_logprob > logprob_threshold: 188 | # don't skip if the logprob is high enough, despite the no_speech_prob 189 | should_skip = False 190 | 191 | if should_skip: 192 | seek += segment.shape[-1] # fast-forward to the next segment boundary 193 | continue 194 | 195 | timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) 196 | consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1) 197 | if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens 198 | last_slice = 0 199 | for current_slice in consecutive: 200 | sliced_tokens = tokens[last_slice:current_slice] 201 | start_timestamp_position = ( 202 | sliced_tokens[0].item() - tokenizer.timestamp_begin 203 | ) 204 | end_timestamp_position = ( 205 | sliced_tokens[-1].item() - tokenizer.timestamp_begin 206 | ) 207 | add_segment( 208 | start=timestamp_offset + start_timestamp_position * time_precision, 209 | end=timestamp_offset + end_timestamp_position * time_precision, 210 | text_tokens=sliced_tokens[1:-1], 211 | result=result, 212 | ) 213 | last_slice = current_slice 214 | last_timestamp_position = ( 215 | tokens[last_slice - 1].item() - tokenizer.timestamp_begin 216 | ) 217 | seek += last_timestamp_position * input_stride 218 | all_tokens.extend(tokens[: last_slice + 1].tolist()) 219 | else: 220 | duration = segment_duration 221 | timestamps = tokens[timestamp_tokens.nonzero().flatten()] 222 | if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin: 223 | # no consecutive timestamps but it has a timestamp; use the last one. 224 | # single timestamp at the end means no speech after the last timestamp. 225 | last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin 226 | duration = last_timestamp_position * time_precision 227 | 228 | add_segment( 229 | start=timestamp_offset, 230 | end=timestamp_offset + duration, 231 | text_tokens=tokens, 232 | result=result, 233 | ) 234 | 235 | seek += segment.shape[-1] 236 | all_tokens.extend(tokens.tolist()) 237 | 238 | if not condition_on_previous_text or result.temperature > 0.5: 239 | # do not feed the prompt tokens if a high temperature was used 240 | prompt_reset_since = len(all_tokens) 241 | 242 | # update progress bar 243 | pbar.update(min(num_frames, seek) - previous_seek_value) 244 | previous_seek_value = seek 245 | 246 | return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language) 247 | 248 | 249 | def cli(): 250 | from . 
import available_models 251 | 252 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 253 | parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") 254 | parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") 255 | parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") 256 | parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") 257 | parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") 258 | parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") 259 | 260 | parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") 261 | parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection") 262 | 263 | parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") 264 | parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") 265 | parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero") 266 | parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search") 267 | parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default") 268 | 269 | parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") 270 | parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") 271 | parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") 272 | parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") 273 | 274 | parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") 275 | parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed") 276 | parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed") 277 | 
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence") 278 | parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") 279 | 280 | args = parser.parse_args().__dict__ 281 | model_name: str = args.pop("model") 282 | model_dir: str = args.pop("model_dir") 283 | output_dir: str = args.pop("output_dir") 284 | device: str = args.pop("device") 285 | os.makedirs(output_dir, exist_ok=True) 286 | 287 | if model_name.endswith(".en") and args["language"] not in {"en", "English"}: 288 | if args["language"] is not None: 289 | warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.") 290 | args["language"] = "en" 291 | 292 | temperature = args.pop("temperature") 293 | temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback") 294 | if temperature_increment_on_fallback is not None: 295 | temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback)) 296 | else: 297 | temperature = [temperature] 298 | 299 | threads = args.pop("threads") 300 | if threads > 0: 301 | torch.set_num_threads(threads) 302 | 303 | from . import load_model 304 | model = load_model(model_name, device=device, download_root=model_dir) 305 | 306 | for audio_path in args.pop("audio"): 307 | result = transcribe(model, audio_path, temperature=temperature, **args) 308 | 309 | audio_basename = os.path.basename(audio_path) 310 | 311 | # save TXT 312 | with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt: 313 | write_txt(result["segments"], file=txt) 314 | 315 | # save VTT 316 | with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt: 317 | write_vtt(result["segments"], file=vtt) 318 | 319 | # save SRT 320 | with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt: 321 | write_srt(result["segments"], file=srt) 322 | 323 | 324 | if __name__ == '__main__': 325 | cli() 326 | -------------------------------------------------------------------------------- /modules/comsl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import os 5 | from pytorch_lightning import LightningModule 6 | import torchmetrics 7 | from torch.utils.data import DataLoader 8 | 9 | if __name__ == "__main__": 10 | import sys 11 | 12 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 13 | sys.path.append(BASE_DIR) 14 | 15 | from data.data_util import LANG_DICT, load_data_record, pad_trim_audio 16 | from data.dataset import ComSTDataset 17 | from decode.mbart_decode import decode, DecodingOptions 18 | from criterion.mix_criterions import GuidedCrossEntMultiTaskCriterion 19 | from criterion.metric_util import get_segment_tokenizers, preprocess_sentence 20 | from model.optimizer import configure_optimizer_schedular 21 | from model.model_util import load_mbart_tokenizer 22 | from model.ComSL_model import ComSTModel 23 | from Whisper.normalizers import BasicTextNormalizer 24 | 25 | 26 | class ComSTCollatorWhithPadding: 27 | def __init__(self, cfg, pad_token_id, tokenizer) -> None: 28 | self.pad_token_id = pad_token_id 29 | self.cfg = cfg 30 | 
self.tokenizer = tokenizer 31 | self.p_mask = cfg.p_mask 32 | 33 | def __call__(self, features): 34 | transcription_ids, translation_ids, transcription_labels, translation_labels, audio, src_lang_codes, tgt_lang_codes = [], [], [], [], [], [], [] 35 | for f in features: 36 | transcription_ids.append(f["transcription_ids"]) 37 | translation_ids.append(f["translation_ids"]) 38 | transcription_labels.append(f["transcription_labels"]) 39 | translation_labels.append(f["translation_labels"]) 40 | audio.append(f["audio"]) 41 | src_lang_codes.append(LANG_DICT[f['src_lang']]['whisper']) 42 | tgt_lang_codes.append(LANG_DICT[f['tgt_lang']]['whisper']) 43 | 44 | # audio 45 | mel = pad_trim_audio(audio, self.cfg) 46 | # transcription 47 | transcription_len = [len(t) for t in transcription_ids] 48 | max_transcription_len = max(transcription_len) 49 | 50 | src_txt_tokens = [ 51 | np.pad(ids, (0, max_transcription_len - length), 'constant', constant_values=self.pad_token_id) for 52 | ids, length in zip(transcription_ids, transcription_len)] 53 | transcription_labels = [ 54 | np.pad(ids, (0, max_transcription_len - length), 'constant', constant_values=self.pad_token_id) for 55 | ids, length in zip(transcription_labels, transcription_len)] 56 | 57 | # text encoder 58 | translation_len = [len(t) for t in translation_ids] 59 | max_translation_len = max(translation_len) 60 | 61 | tgt_txt_tokens = [np.pad(ids, (0, max_translation_len - length), 'constant', constant_values=self.pad_token_id) 62 | for ids, length in zip(translation_ids, translation_len)] 63 | translation_labels = [ 64 | np.pad(ids, (0, max_translation_len - length), 'constant', constant_values=self.pad_token_id) for 65 | ids, length in zip(translation_labels, translation_len)] 66 | 67 | def to_tensor(x): 68 | return torch.tensor(np.array(x), requires_grad=False) 69 | 70 | src_txt_tokens = to_tensor(src_txt_tokens).long() 71 | tgt_txt_tokens = to_tensor(tgt_txt_tokens).long() 72 | 73 | mlm_mask = torch.rand(src_txt_tokens.shape) < self.p_mask 74 | mlm_mask = mlm_mask & (src_txt_tokens != self.pad_token_id) & (src_txt_tokens != 2) 75 | mlm_mask[:, 0] = False 76 | 77 | while mlm_mask.sum() == 0: 78 | mlm_mask = torch.rand(src_txt_tokens.shape) < self.p_mask 79 | mlm_mask = mlm_mask & (src_txt_tokens != self.pad_token_id) & (src_txt_tokens != 2) 80 | mlm_mask[:, 0] = False 81 | 82 | masked_src_tokens = src_txt_tokens.clone() 83 | masked_src_tokens[mlm_mask] = self.tokenizer.mask_token_id 84 | 85 | tgt_mlm_mask = torch.roll(mlm_mask, -1, dims=1) 86 | 87 | src_lang_ids = src_txt_tokens[:, 0] 88 | 89 | batch = { 90 | "net_input": { 91 | "mel": mel, 92 | "src_lang_ids": src_lang_ids, 93 | "tokens": [src_txt_tokens, tgt_txt_tokens], 94 | 'masked_src_tokens': masked_src_tokens, 95 | 'mlm_mask': tgt_mlm_mask, 96 | "txt_lengths": [to_tensor(transcription_len), to_tensor(translation_len)], 97 | }, 98 | "target": [to_tensor(transcription_labels).long(), to_tensor(translation_labels).long()], 99 | "dec_start_ids": [src_txt_tokens[:, 0], tgt_txt_tokens[:, 0]], 100 | "ntokens": [sum(transcription_len), sum(translation_len)], 101 | "tgt_lang_codes": tgt_lang_codes, 102 | "src_lang_codes": src_lang_codes, 103 | } 104 | 105 | return batch 106 | 107 | 108 | class ComSTModule(LightningModule): 109 | def __init__(self, cfg, joined_dataset: dict, sep_dataset: dict) -> None: 110 | super().__init__() 111 | 112 | self.tokenizer = load_mbart_tokenizer(cfg) 113 | self.model = ComSTModel(cfg) 114 | self.decode_options = DecodingOptions(beam_size=5) 115 | self.collect_fn = 
ComSTCollatorWhithPadding(cfg, self.tokenizer.pad_token_id, tokenizer=self.tokenizer) 116 | self.automatic_optimization = True 117 | 118 | self.cfg = cfg 119 | self.__train_dataset = joined_dataset.get("train", []) 120 | self.__eval_dataset = sep_dataset.get("dev", []) 121 | self.__test_dataset = sep_dataset.get("test", []) 122 | 123 | self.segment_tokenizers = get_segment_tokenizers() 124 | self.normalizer = BasicTextNormalizer() 125 | 126 | self.train_criterion = GuidedCrossEntMultiTaskCriterion(cfg, self.tokenizer.pad_token_id) 127 | self.val_criterion = nn.ModuleList( 128 | [GuidedCrossEntMultiTaskCriterion(cfg, self.tokenizer.pad_token_id) for _ in self.__eval_dataset]) 129 | self.valid_metrics = nn.ModuleDict({ 130 | 'bleu_spch': nn.ModuleList([torchmetrics.BLEUScore(compute_on_step=False) for _ in self.__eval_dataset]), 131 | 'bleu_text': nn.ModuleList([torchmetrics.BLEUScore(compute_on_step=False) for _ in self.__eval_dataset]), 132 | 'wer': nn.ModuleList([torchmetrics.WordErrorRate(compute_on_step=False) for _ in self.__eval_dataset]), 133 | }) 134 | self.test_metrics = nn.ModuleDict({ 135 | 'bleu_spch': nn.ModuleList([torchmetrics.BLEUScore(compute_on_step=False) for _ in self.__test_dataset]), 136 | 'bleu_text': nn.ModuleList([torchmetrics.BLEUScore(compute_on_step=False) for _ in self.__test_dataset]), 137 | 'wer': nn.ModuleList([torchmetrics.WordErrorRate(compute_on_step=False) for _ in self.__test_dataset]), 138 | }) 139 | 140 | def forward(self, **kwargs): 141 | return self.model(**kwargs) 142 | 143 | def training_step(self, batch, batch_id): 144 | 145 | self.model.set_num_updates(self.global_step, self.current_epoch) 146 | loss, logging_output, net_output = self.train_criterion(self.model, batch) 147 | for log_item, data in logging_output.items(): 148 | self.log(f"train_{log_item}", data, on_step=True, prog_bar=True, logger=True, on_epoch=False, 149 | sync_dist=True) 150 | return loss 151 | 152 | def on_validation_epoch_start(self) -> None: 153 | for k, v in self.valid_metrics.items(): 154 | for metric in v: 155 | metric.set_dtype(torch.float32) 156 | for criterion in self.val_criterion: 157 | for metric in criterion.metrics.values(): 158 | metric.set_dtype(torch.float32) 159 | 160 | def validation_step(self, batch, batch_id, dataloader_idx=None): 161 | if dataloader_idx is None: 162 | print("warning: dataloader_idx is None") 163 | dataloader_idx = 0 164 | labels = batch["target"] 165 | loss, logging_output, net_output = self.val_criterion[dataloader_idx](self.model, batch) 166 | 167 | enc_out = net_output[1] 168 | hidden_states = [enc_out[0]['encoder_out'], enc_out[1]['encoder_out']] 169 | decode_res = self.decode(hidden_states, 170 | labels, 171 | batch["dec_start_ids"], 172 | batch["src_lang_codes"], 173 | batch["tgt_lang_codes"], 174 | dataloader_idx, 175 | self.valid_metrics) 176 | 177 | return { 178 | "result": decode_res, 179 | "logging_output": logging_output 180 | } 181 | 182 | def validation_epoch_end(self, outputs): 183 | for metric_name, metric_list in self.valid_metrics.items(): 184 | results = [m.compute() for m in metric_list] 185 | if 'bleu' in metric_name: 186 | results = [r * 100 for r in results] 187 | mean_result = sum(results) / len(results) 188 | self.log(f"val_{metric_name}_epoch", mean_result) 189 | print(f"val_{metric_name}_epoch", mean_result) 190 | for m in metric_list: 191 | m.reset() 192 | valid_reports = {} 193 | for criterion in self.val_criterion: 194 | valid_report = criterion.reduce_metric() 195 | for k, v in valid_report.items(): 196 | if 
k not in valid_reports: 197 | valid_reports[k] = v 198 | else: 199 | valid_reports[k] += v 200 | average_valid_reports = {k: v / len(self.val_criterion) for k, v in valid_reports.items()} 201 | for k, v in average_valid_reports.items(): 202 | self.log(f"val_{k}_epoch", v) 203 | print(f"val_{k}_epoch", v) 204 | 205 | def on_test_epoch_start(self) -> None: 206 | for metric_name, metric_list in self.test_metrics.items(): 207 | for m in metric_list: 208 | m.set_dtype(torch.float32) 209 | 210 | def test_step(self, batch, batch_idx, dataloader_idx=None): 211 | if dataloader_idx is None: 212 | print("warning: dataloader_idx is None") 213 | dataloader_idx = 0 214 | labels = batch["target"] 215 | net_input = batch["net_input"] 216 | enc_out = self.model.encoder( 217 | mel=net_input["mel"], 218 | src_tokens=net_input["tokens"][0], 219 | src_lang_ids=net_input["src_lang_ids"], 220 | masked_src_tokens=None, 221 | ) 222 | 223 | hidden_states = [enc_out[0]['encoder_out'], enc_out[1]['encoder_out']] 224 | decode_res = self.decode(hidden_states, 225 | labels, 226 | batch["dec_start_ids"], 227 | batch["src_lang_codes"], 228 | batch["tgt_lang_codes"], 229 | dataloader_idx, 230 | self.test_metrics) 231 | 232 | return decode_res 233 | 234 | def test_epoch_end(self, outputs): 235 | for metric_name, metric_list in self.test_metrics.items(): 236 | results = [m.compute() for m in metric_list] 237 | if 'bleu' in metric_name: 238 | results = [r * 100 for r in results] 239 | mean_result = sum(results) / len(results) 240 | self.log(f"test_{metric_name}_epoch", mean_result) 241 | for i, r in enumerate(results): 242 | self.log(f"test_{metric_name}_{i}", r) 243 | print(f"test_{metric_name}_epoch", mean_result) 244 | for m in metric_list: 245 | m.reset() 246 | 247 | def decode(self, hidden_states, labels, decoder_start_ids, src_lang_codes, tgt_lang_codes, dataloader_idx, metrics): 248 | 249 | transcript_start_ids, translate_start_ids = decoder_start_ids 250 | transcript_labels, translate_labels = labels 251 | spch_hiddens, text_hiddens = hidden_states 252 | 253 | labels = {'transcript': transcript_labels, 'translate': translate_labels} 254 | start_ids = {'transcript': transcript_start_ids, 'translate': translate_start_ids} 255 | result = {} 256 | for task in ['transcript', 'translate']: 257 | if labels[task] is None: 258 | continue 259 | spch_pred_token = decode(self.model.decoder, 260 | self.tokenizer, 261 | enc_hidden_states=spch_hiddens, 262 | forced_bos_token_id=start_ids[task], 263 | options=self.decode_options, ) 264 | 265 | spch_detoken_out = self.tokenizer.batch_decode(spch_pred_token, skip_special_tokens=True) 266 | detoken_label = self.tokenizer.batch_decode(labels[task], skip_special_tokens=True) 267 | lang_codes = src_lang_codes if task == 'transcript' else tgt_lang_codes 268 | preprocess_sentence(spch_detoken_out, lang_codes, self.segment_tokenizers) 269 | preprocess_sentence(detoken_label, lang_codes, self.segment_tokenizers) 270 | if task == 'translate': 271 | text_pred_token = decode(self.model.decoder, 272 | self.tokenizer, 273 | enc_hidden_states=text_hiddens, 274 | forced_bos_token_id=start_ids[task], 275 | options=self.decode_options, ) 276 | text_detoken_out = self.tokenizer.batch_decode(text_pred_token, skip_special_tokens=True) 277 | preprocess_sentence(text_detoken_out, lang_codes, self.segment_tokenizers) 278 | result[f'{task}_text_pred'] = text_detoken_out 279 | if task == 'translate': 280 | metrics['bleu_spch'][dataloader_idx](spch_detoken_out, [[l] for l in detoken_label]) 281 | 
metrics['bleu_text'][dataloader_idx](text_detoken_out, [[l] for l in detoken_label]) 282 | elif task == 'transcript': 283 | o_list_ = [self.normalizer(o) for o in spch_detoken_out] 284 | l_list_ = [self.normalizer(l) for l in detoken_label] 285 | metrics['wer'][dataloader_idx](o_list_, l_list_) 286 | result[f'{task}_spch_pred'] = spch_detoken_out 287 | result[f'{task}_label'] = detoken_label 288 | 289 | return result 290 | 291 | def configure_optimizers(self): 292 | optimizer, scheduler = configure_optimizer_schedular( 293 | cfg=self.cfg, 294 | params_generator=self.named_parameters, 295 | num_training_steps=self.trainer.estimated_stepping_batches 296 | ) 297 | self.optimizer = optimizer 298 | self.scheduler = scheduler 299 | 300 | return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}] 301 | 302 | def train_dataloader(self): 303 | dataset = ComSTDataset(self.__train_dataset, self.tokenizer, self.cfg.sample_rate, cfg=self.cfg) 304 | return DataLoader(dataset, 305 | batch_size=self.cfg.batch_size, 306 | drop_last=True, shuffle=True, num_workers=self.cfg.num_worker, 307 | collate_fn=self.collect_fn 308 | ) 309 | 310 | def val_dataloader(self): 311 | datasets = [ComSTDataset(dataset, self.tokenizer, self.cfg.sample_rate, cfg=self.cfg) for 312 | dataset in self.__eval_dataset] 313 | return [DataLoader(dataset, 314 | batch_size=self.cfg.test_batch_size, 315 | num_workers=self.cfg.num_worker, 316 | collate_fn=self.collect_fn 317 | ) for dataset in datasets] 318 | 319 | def test_dataloader(self): 320 | datasets = [ComSTDataset(dataset, self.tokenizer, self.cfg.sample_rate, cfg=self.cfg) for 321 | dataset in self.__test_dataset] 322 | return [DataLoader(dataset, 323 | batch_size=self.cfg.test_batch_size, 324 | num_workers=self.cfg.num_worker, 325 | collate_fn=self.collect_fn 326 | ) for dataset in datasets] 327 | 328 | 329 | if __name__ == "__main__": 330 | from config.parse_yaml_args import parse_args_and_yaml 331 | from model.model_util import deep_to_device 332 | 333 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 334 | cfg = parse_args_and_yaml(config_path="../config/exp_spec/comsl.yaml") 335 | 336 | DATA_ROOT = cfg.data_root 337 | language_list = cfg.language_list 338 | extra_language_list = cfg.extra_language_list 339 | 340 | joined_data_pair_lists, sep_data_pair_lists = {}, {} 341 | for split in ["train", "dev", "test"]: 342 | joined_data_pair_lists[split], sep_data_pair_lists[split] = load_data_record( 343 | DATA_ROOT, 344 | split, 345 | language_list=language_list, 346 | expanded_data_root=cfg.cv_data_root, 347 | expanded_language_list=extra_language_list) 348 | if "OUTPUT_DIR" in os.environ: 349 | output_dir = os.environ["OUTPUT_DIR"] 350 | else: 351 | output_dir = cfg.output_dir 352 | 353 | cfg.cache_dir = f"{output_dir}/cache" 354 | 355 | cfg.batch_size = cfg.test_batch_size = 10 356 | cfg.num_worker = 4 357 | 358 | module = ComSTModule(cfg, joined_data_pair_lists, sep_data_pair_lists).cuda().eval() 359 | 360 | loader = module.test_dataloader()[0] 361 | 362 | optimizer, scheduler = configure_optimizer_schedular( 363 | cfg=module.cfg, 364 | params_generator=module.named_parameters, 365 | num_training_steps=10000 366 | ) 367 | 368 | with torch.no_grad(): 369 | for b in loader: 370 | 371 | b = deep_to_device(b, 'cuda') 372 | train_res = module.training_step(b, 0) 373 | print(train_res) 374 | 375 | valid_res = module.validation_step(b, 0, 0) 376 | print(valid_res) 377 | module.validation_epoch_end([valid_res]) 378 | 379 | test_res = module.test_step(b, 0, 
0) 380 | print(test_res) 381 | module.test_epoch_end([test_res]) 382 | 383 | break 384 | -------------------------------------------------------------------------------- /Whisper/normalizers/english.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | from fractions import Fraction 5 | from typing import Iterator, List, Match, Optional, Union 6 | 7 | from more_itertools import windowed 8 | 9 | from .basic import remove_symbols_and_diacritics 10 | 11 | 12 | class EnglishNumberNormalizer: 13 | """ 14 | Convert any spelled-out numbers into arabic numbers, while handling: 15 | 16 | - remove any commas 17 | - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. 18 | - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars` 19 | - spell out `one` and `ones` 20 | - interpret successive single-digit numbers as nominal: `one oh one` -> `101` 21 | """ 22 | 23 | def __init__(self): 24 | super().__init__() 25 | 26 | self.zeros = {"o", "oh", "zero"} 27 | self.ones = { 28 | name: i 29 | for i, name in enumerate( 30 | [ 31 | "one", 32 | "two", 33 | "three", 34 | "four", 35 | "five", 36 | "six", 37 | "seven", 38 | "eight", 39 | "nine", 40 | "ten", 41 | "eleven", 42 | "twelve", 43 | "thirteen", 44 | "fourteen", 45 | "fifteen", 46 | "sixteen", 47 | "seventeen", 48 | "eighteen", 49 | "nineteen", 50 | ], 51 | start=1, 52 | ) 53 | } 54 | self.ones_plural = { 55 | "sixes" if name == "six" else name + "s": (value, "s") 56 | for name, value in self.ones.items() 57 | } 58 | self.ones_ordinal = { 59 | "zeroth": (0, "th"), 60 | "first": (1, "st"), 61 | "second": (2, "nd"), 62 | "third": (3, "rd"), 63 | "fifth": (5, "th"), 64 | "twelfth": (12, "th"), 65 | **{ 66 | name + ("h" if name.endswith("t") else "th"): (value, "th") 67 | for name, value in self.ones.items() 68 | if value > 3 and value != 5 and value != 12 69 | }, 70 | } 71 | self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} 72 | 73 | self.tens = { 74 | "twenty": 20, 75 | "thirty": 30, 76 | "forty": 40, 77 | "fifty": 50, 78 | "sixty": 60, 79 | "seventy": 70, 80 | "eighty": 80, 81 | "ninety": 90, 82 | } 83 | self.tens_plural = { 84 | name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() 85 | } 86 | self.tens_ordinal = { 87 | name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items() 88 | } 89 | self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} 90 | 91 | self.multipliers = { 92 | "hundred": 100, 93 | "thousand": 1_000, 94 | "million": 1_000_000, 95 | "billion": 1_000_000_000, 96 | "trillion": 1_000_000_000_000, 97 | "quadrillion": 1_000_000_000_000_000, 98 | "quintillion": 1_000_000_000_000_000_000, 99 | "sextillion": 1_000_000_000_000_000_000_000, 100 | "septillion": 1_000_000_000_000_000_000_000_000, 101 | "octillion": 1_000_000_000_000_000_000_000_000_000, 102 | "nonillion": 1_000_000_000_000_000_000_000_000_000_000, 103 | "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, 104 | } 105 | self.multipliers_plural = { 106 | name + "s": (value, "s") for name, value in self.multipliers.items() 107 | } 108 | self.multipliers_ordinal = { 109 | name + "th": (value, "th") for name, value in self.multipliers.items() 110 | } 111 | self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal} 112 | self.decimals = {*self.ones, *self.tens, *self.zeros} 113 | 114 | self.preceding_prefixers = { 115 | "minus": "-", 116 | "negative": "-", 117 | "plus": "+", 118 | 
"positive": "+", 119 | } 120 | self.following_prefixers = { 121 | "pound": "£", 122 | "pounds": "£", 123 | "euro": "€", 124 | "euros": "€", 125 | "dollar": "$", 126 | "dollars": "$", 127 | "cent": "¢", 128 | "cents": "¢", 129 | } 130 | self.prefixes = set( 131 | list(self.preceding_prefixers.values()) + list(self.following_prefixers.values()) 132 | ) 133 | self.suffixers = { 134 | "per": {"cent": "%"}, 135 | "percent": "%", 136 | } 137 | self.specials = {"and", "double", "triple", "point"} 138 | 139 | self.words = set( 140 | [ 141 | key 142 | for mapping in [ 143 | self.zeros, 144 | self.ones, 145 | self.ones_suffixed, 146 | self.tens, 147 | self.tens_suffixed, 148 | self.multipliers, 149 | self.multipliers_suffixed, 150 | self.preceding_prefixers, 151 | self.following_prefixers, 152 | self.suffixers, 153 | self.specials, 154 | ] 155 | for key in mapping 156 | ] 157 | ) 158 | self.literal_words = {"one", "ones"} 159 | 160 | def process_words(self, words: List[str]) -> Iterator[str]: 161 | prefix: Optional[str] = None 162 | value: Optional[Union[str, int]] = None 163 | skip = False 164 | 165 | def to_fraction(s: str): 166 | try: 167 | return Fraction(s) 168 | except ValueError: 169 | return None 170 | 171 | def output(result: Union[str, int]): 172 | nonlocal prefix, value 173 | result = str(result) 174 | if prefix is not None: 175 | result = prefix + result 176 | value = None 177 | prefix = None 178 | return result 179 | 180 | if len(words) == 0: 181 | return 182 | 183 | for prev, current, next in windowed([None] + words + [None], 3): 184 | if skip: 185 | skip = False 186 | continue 187 | 188 | next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) 189 | has_prefix = current[0] in self.prefixes 190 | current_without_prefix = current[1:] if has_prefix else current 191 | if re.match(r"^\d+(\.\d+)?$", current_without_prefix): 192 | # arabic numbers (potentially with signs and fractions) 193 | f = to_fraction(current_without_prefix) 194 | assert f is not None 195 | if value is not None: 196 | if isinstance(value, str) and value.endswith("."): 197 | # concatenate decimals / ip address components 198 | value = str(value) + str(current) 199 | continue 200 | else: 201 | yield output(value) 202 | 203 | prefix = current[0] if has_prefix else prefix 204 | if f.denominator == 1: 205 | value = f.numerator # store integers as int 206 | else: 207 | value = current_without_prefix 208 | elif current not in self.words: 209 | # non-numeric words 210 | if value is not None: 211 | yield output(value) 212 | yield output(current) 213 | elif current in self.zeros: 214 | value = str(value or "") + "0" 215 | elif current in self.ones: 216 | ones = self.ones[current] 217 | 218 | if value is None: 219 | value = ones 220 | elif isinstance(value, str) or prev in self.ones: 221 | if prev in self.tens and ones < 10: # replace the last zero with the digit 222 | assert value[-1] == "0" 223 | value = value[:-1] + str(ones) 224 | else: 225 | value = str(value) + str(ones) 226 | elif ones < 10: 227 | if value % 10 == 0: 228 | value += ones 229 | else: 230 | value = str(value) + str(ones) 231 | else: # eleven to nineteen 232 | if value % 100 == 0: 233 | value += ones 234 | else: 235 | value = str(value) + str(ones) 236 | elif current in self.ones_suffixed: 237 | # ordinal or cardinal; yield the number right away 238 | ones, suffix = self.ones_suffixed[current] 239 | if value is None: 240 | yield output(str(ones) + suffix) 241 | elif isinstance(value, str) or prev in self.ones: 242 | if prev in self.tens and 
ones < 10: 243 | assert value[-1] == "0" 244 | yield output(value[:-1] + str(ones) + suffix) 245 | else: 246 | yield output(str(value) + str(ones) + suffix) 247 | elif ones < 10: 248 | if value % 10 == 0: 249 | yield output(str(value + ones) + suffix) 250 | else: 251 | yield output(str(value) + str(ones) + suffix) 252 | else: # eleven to nineteen 253 | if value % 100 == 0: 254 | yield output(str(value + ones) + suffix) 255 | else: 256 | yield output(str(value) + str(ones) + suffix) 257 | value = None 258 | elif current in self.tens: 259 | tens = self.tens[current] 260 | if value is None: 261 | value = tens 262 | elif isinstance(value, str): 263 | value = str(value) + str(tens) 264 | else: 265 | if value % 100 == 0: 266 | value += tens 267 | else: 268 | value = str(value) + str(tens) 269 | elif current in self.tens_suffixed: 270 | # ordinal or cardinal; yield the number right away 271 | tens, suffix = self.tens_suffixed[current] 272 | if value is None: 273 | yield output(str(tens) + suffix) 274 | elif isinstance(value, str): 275 | yield output(str(value) + str(tens) + suffix) 276 | else: 277 | if value % 100 == 0: 278 | yield output(str(value + tens) + suffix) 279 | else: 280 | yield output(str(value) + str(tens) + suffix) 281 | elif current in self.multipliers: 282 | multiplier = self.multipliers[current] 283 | if value is None: 284 | value = multiplier 285 | elif isinstance(value, str) or value == 0: 286 | f = to_fraction(value) 287 | p = f * multiplier if f is not None else None 288 | if f is not None and p.denominator == 1: 289 | value = p.numerator 290 | else: 291 | yield output(value) 292 | value = multiplier 293 | else: 294 | before = value // 1000 * 1000 295 | residual = value % 1000 296 | value = before + residual * multiplier 297 | elif current in self.multipliers_suffixed: 298 | multiplier, suffix = self.multipliers_suffixed[current] 299 | if value is None: 300 | yield output(str(multiplier) + suffix) 301 | elif isinstance(value, str): 302 | f = to_fraction(value) 303 | p = f * multiplier if f is not None else None 304 | if f is not None and p.denominator == 1: 305 | yield output(str(p.numerator) + suffix) 306 | else: 307 | yield output(value) 308 | yield output(str(multiplier) + suffix) 309 | else: # int 310 | before = value // 1000 * 1000 311 | residual = value % 1000 312 | value = before + residual * multiplier 313 | yield output(str(value) + suffix) 314 | value = None 315 | elif current in self.preceding_prefixers: 316 | # apply prefix (positive, minus, etc.) if it precedes a number 317 | if value is not None: 318 | yield output(value) 319 | 320 | if next in self.words or next_is_numeric: 321 | prefix = self.preceding_prefixers[current] 322 | else: 323 | yield output(current) 324 | elif current in self.following_prefixers: 325 | # apply prefix (dollars, cents, etc.) 
only after a number 326 | if value is not None: 327 | prefix = self.following_prefixers[current] 328 | yield output(value) 329 | else: 330 | yield output(current) 331 | elif current in self.suffixers: 332 | # apply suffix symbols (percent -> '%') 333 | if value is not None: 334 | suffix = self.suffixers[current] 335 | if isinstance(suffix, dict): 336 | if next in suffix: 337 | yield output(str(value) + suffix[next]) 338 | skip = True 339 | else: 340 | yield output(value) 341 | yield output(current) 342 | else: 343 | yield output(str(value) + suffix) 344 | else: 345 | yield output(current) 346 | elif current in self.specials: 347 | if next not in self.words and not next_is_numeric: 348 | # apply special handling only if the next word can be numeric 349 | if value is not None: 350 | yield output(value) 351 | yield output(current) 352 | elif current == "and": 353 | # ignore "and" after hundreds, thousands, etc. 354 | if prev not in self.multipliers: 355 | if value is not None: 356 | yield output(value) 357 | yield output(current) 358 | elif current == "double" or current == "triple": 359 | if next in self.ones or next in self.zeros: 360 | repeats = 2 if current == "double" else 3 361 | ones = self.ones.get(next, 0) 362 | value = str(value or "") + str(ones) * repeats 363 | skip = True 364 | else: 365 | if value is not None: 366 | yield output(value) 367 | yield output(current) 368 | elif current == "point": 369 | if next in self.decimals or next_is_numeric: 370 | value = str(value or "") + "." 371 | else: 372 | # should all have been covered at this point 373 | raise ValueError(f"Unexpected token: {current}") 374 | else: 375 | # all should have been covered at this point 376 | raise ValueError(f"Unexpected token: {current}") 377 | 378 | if value is not None: 379 | yield output(value) 380 | 381 | def preprocess(self, s: str): 382 | # replace " and a half" with " point five" 383 | results = [] 384 | 385 | segments = re.split(r"\band\s+a\s+half\b", s) 386 | for i, segment in enumerate(segments): 387 | if len(segment.strip()) == 0: 388 | continue 389 | if i == len(segments) - 1: 390 | results.append(segment) 391 | else: 392 | results.append(segment) 393 | last_word = segment.rsplit(maxsplit=2)[-1] 394 | if last_word in self.decimals or last_word in self.multipliers: 395 | results.append("point five") 396 | else: 397 | results.append("and a half") 398 | 399 | s = " ".join(results) 400 | 401 | # put a space at number/letter boundary 402 | s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) 403 | s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) 404 | 405 | # but remove spaces which could be a suffix 406 | s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) 407 | 408 | return s 409 | 410 | def postprocess(self, s: str): 411 | def combine_cents(m: Match): 412 | try: 413 | currency = m.group(1) 414 | integer = m.group(2) 415 | cents = int(m.group(3)) 416 | return f"{currency}{integer}.{cents:02d}" 417 | except ValueError: 418 | return m.string 419 | 420 | def extract_cents(m: Match): 421 | try: 422 | return f"¢{int(m.group(1))}" 423 | except ValueError: 424 | return m.string 425 | 426 | # apply currency postprocessing; "$2 and ¢7" -> "$2.07" 427 | s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s) 428 | s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s) 429 | 430 | # write "one(s)" instead of "1(s)", just for the readability 431 | s = re.sub(r"\b1(s?)\b", r"one\1", s) 432 | 433 | return s 434 | 435 | def __call__(self, s: str): 436 | s = self.preprocess(s) 437 | s = " ".join(word for 
word in self.process_words(s.split()) if word is not None) 438 | s = self.postprocess(s) 439 | 440 | return s 441 | 442 | 443 | class EnglishSpellingNormalizer: 444 | """ 445 | Applies British-American spelling mappings as listed in [1]. 446 | 447 | [1] https://www.tysto.com/uk-us-spelling-list.html 448 | """ 449 | 450 | def __init__(self): 451 | mapping_path = os.path.join(os.path.dirname(__file__), "english.json") 452 | self.mapping = json.load(open(mapping_path)) 453 | 454 | def __call__(self, s: str): 455 | return " ".join(self.mapping.get(word, word) for word in s.split()) 456 | 457 | 458 | class EnglishTextNormalizer: 459 | def __init__(self): 460 | self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" 461 | self.replacers = { 462 | # common contractions 463 | r"\bwon't\b": "will not", 464 | r"\bcan't\b": "can not", 465 | r"\blet's\b": "let us", 466 | r"\bain't\b": "aint", 467 | r"\by'all\b": "you all", 468 | r"\bwanna\b": "want to", 469 | r"\bgotta\b": "got to", 470 | r"\bgonna\b": "going to", 471 | r"\bi'ma\b": "i am going to", 472 | r"\bimma\b": "i am going to", 473 | r"\bwoulda\b": "would have", 474 | r"\bcoulda\b": "could have", 475 | r"\bshoulda\b": "should have", 476 | r"\bma'am\b": "madam", 477 | # contractions in titles/prefixes 478 | r"\bmr\b": "mister ", 479 | r"\bmrs\b": "missus ", 480 | r"\bst\b": "saint ", 481 | r"\bdr\b": "doctor ", 482 | r"\bprof\b": "professor ", 483 | r"\bcapt\b": "captain ", 484 | r"\bgov\b": "governor ", 485 | r"\bald\b": "alderman ", 486 | r"\bgen\b": "general ", 487 | r"\bsen\b": "senator ", 488 | r"\brep\b": "representative ", 489 | r"\bpres\b": "president ", 490 | r"\brev\b": "reverend ", 491 | r"\bhon\b": "honorable ", 492 | r"\basst\b": "assistant ", 493 | r"\bassoc\b": "associate ", 494 | r"\blt\b": "lieutenant ", 495 | r"\bcol\b": "colonel ", 496 | r"\bjr\b": "junior ", 497 | r"\bsr\b": "senior ", 498 | r"\besq\b": "esquire ", 499 | # perfect tenses; ideally this should cover any past participle, but that's harder 
500 | r"'d been\b": " had been", 501 | r"'s been\b": " has been", 502 | r"'d gone\b": " had gone", 503 | r"'s gone\b": " has gone", 504 | r"'d done\b": " had done", # "'s done" is ambiguous 505 | r"'s got\b": " has got", 506 | # general contractions (these rely on dict insertion order: the more specific patterns above are applied first) 507 | r"n't\b": " not", 508 | r"'re\b": " are", 509 | r"'s\b": " is", 510 | r"'d\b": " would", 511 | r"'ll\b": " will", 512 | r"'t\b": " not", 513 | r"'ve\b": " have", 514 | r"'m\b": " am", 515 | } 516 | self.standardize_numbers = EnglishNumberNormalizer() 517 | self.standardize_spellings = EnglishSpellingNormalizer() 518 | 519 | def __call__(self, s: str): 520 | s = s.lower() 521 | 522 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 523 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parentheses 524 | s = re.sub(self.ignore_patterns, "", s) 525 | s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe 526 | 527 | for pattern, replacement in self.replacers.items(): 528 | s = re.sub(pattern, replacement, s) 529 | 530 | s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits 531 | s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers 532 | s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics 533 | 534 | s = self.standardize_numbers(s) 535 | s = self.standardize_spellings(s) 536 | 537 | # now remove prefix/suffix symbols that are not preceded/followed by numbers 538 | s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) 539 | s = re.sub(r"([^0-9])%", r"\1 ", s) 540 | 541 | s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space 542 | 543 | return s 544 | --------------------------------------------------------------------------------
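The snippet below is a minimal, self-contained sketch (not part of the repository) of the normalize-then-score pattern that ComSTModule.decode applies before updating WER: hypotheses and references are both normalized, then passed to a torchmetrics WordErrorRate instance. decode itself uses BasicTextNormalizer; this sketch swaps in EnglishNumberNormalizer because it is fully defined above and needs no english.json spelling table. The import path assumes the repository root is on PYTHONPATH, and the example strings are invented.

import torchmetrics

from Whisper.normalizers.english import EnglishNumberNormalizer

# Normalize both sides so spelled-out and numeric forms compare as the same tokens.
normalize = EnglishNumberNormalizer()
wer = torchmetrics.WordErrorRate()

hypotheses = ["she paid twenty one dollars and fifty cents"]  # invented example strings
references = ["she paid $21.50"]

wer([normalize(h) for h in hypotheses], [normalize(r) for r in references])
print(wer.compute())  # tensor(0.) once both sides normalize to "she paid $21.50"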
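A second sketch (also not part of the repository, illustrative values only) of the per-dataloader metric bookkeeping in ComSTModule: one BLEUScore per evaluation dataset kept in an nn.ModuleDict of nn.ModuleList, updated with references wrapped as [[l] for l in ...] exactly as in decode, then computed, scaled to the usual 0-100 range and averaged at epoch end. BLEUScore is constructed here without the compute_on_step flag used in the module, since newer torchmetrics releases deprecate that argument.

import torch.nn as nn
import torchmetrics

num_eval_sets = 3  # hypothetical number of validation dataloaders
metrics = nn.ModuleDict({
    "bleu_spch": nn.ModuleList([torchmetrics.BLEUScore() for _ in range(num_eval_sets)]),
})

# During validation, each dataloader index updates only its own metric instance.
predictions = ["the cat sat on the mat"]
labels = ["the cat sat on the mat"]
for dataloader_idx in range(num_eval_sets):
    metrics["bleu_spch"][dataloader_idx](predictions, [[l] for l in labels])

# Epoch end: compute every instance, scale to 0-100 and average, as in validation_epoch_end.
scores = [m.compute() * 100 for m in metrics["bleu_spch"]]
print(sum(scores) / len(scores))  # tensor(100.) for identical hypothesis/reference pairs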