├── .gitignore ├── README.md ├── asr ├── analysis │ ├── analyse_ctc_prob.py │ └── compare_wer.py ├── correct │ ├── README.md │ └── exps │ │ └── csj │ │ ├── asr.yaml │ │ └── del_pc_mlm.yaml ├── criteria.py ├── datasets.py ├── distill │ ├── eval_label.py │ └── make_label.py ├── fusion │ └── test_fusion_grid.py ├── metrics.py ├── modeling │ ├── asr.py │ ├── conformer.py │ ├── decoders │ │ ├── ctc.py │ │ ├── ctc_aligner.py │ │ ├── ctc_score.py │ │ ├── las.py │ │ ├── rnn_transducer.py │ │ ├── rnnt_aligner.py │ │ └── transformer.py │ ├── encoders │ │ ├── conv.py │ │ ├── rnn.py │ │ └── transformer.py │ ├── model_utils.py │ └── transformer.py ├── optimizers.py ├── rescore │ ├── README.md │ ├── align_hyps.py │ └── test_rescore_grid.py ├── spec_augment.py ├── test_asr.py ├── test_asr_correct.py └── train_asr.py ├── corpora ├── epasr │ ├── make_utts_json.py │ ├── make_utts_stm.py │ └── prep.sh ├── ted2 │ ├── join_suffix.py │ ├── make_utts.py │ ├── prep.sh │ └── prep_tsv.py └── utils │ ├── concat_text.py │ ├── get_cols.py │ ├── map2phone.py │ ├── map2phone_g2p.py │ ├── norm_feats.py │ ├── rm_utt.py │ ├── sort_bylen.py │ ├── split_tsv.py │ ├── spm_encode.py │ ├── spm_train.py │ └── wav_to_feats.py ├── lm ├── README.md ├── criteria.py ├── datasets.py ├── exps │ └── ted2_nsp10k │ │ ├── electra.yaml │ │ └── pelectra.yaml ├── modeling │ ├── bert.py │ ├── electra.py │ ├── lm.py │ ├── p2w.py │ ├── rnn.py │ ├── transformer.py │ └── transformers │ │ ├── activations.py │ │ ├── configuration_electra.py │ │ ├── configuration_transformers.py │ │ ├── configuration_utils.py │ │ ├── file_utils.py │ │ ├── modeling_bert.py │ │ ├── modeling_electra.py │ │ └── modeling_utils.py ├── test_ppl.py ├── text_augment.py └── train_lm.py └── utils ├── average_checkpoints.py ├── configure.py ├── converters.py ├── log.py ├── paths.py └── vocab.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | __pycache__/ 3 | .vscode/ 4 | OLD/ 5 | nohup.out 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # emoASR 2 | 3 | ## Features 4 | 5 | ### ASR 6 | 7 | * Encoder 8 | * RNN 9 | * Transformer (Trf.) [[Vaswani 2017]](https://arxiv.org/abs/1706.03762) 10 | * Conformer (Cf.) [[Gulati 2020]](https://arxiv.org/abs/2005.08100) 11 | * Decoder 12 | * CTC [[Graves 2006]](https://www.cs.toronto.edu/~graves/icml_2006.pdf) 13 | * RNN-Transducer (RNN-T) [[Graves 2012]](https://arxiv.org/abs/1211.3711) 14 | * LAS [[Chan 2015]](https://arxiv.org/abs/1508.01211) 15 | * Transformer (Trf.) 16 | 17 | ### LM 18 | 19 | * Modeling 20 | * RNNLM 21 | * Transformer LM 22 | * BERT [[Devlin 2018]](https://arxiv.org/abs/1810.04805) 23 | * ELECTRA [[Clark 2020]](https://arxiv.org/abs/2003.10555) 24 | * Phone-attentive ELECTRA (P-ELECTRA) [[Futami 2021]](https://arxiv.org/abs/2110.01857) 25 | 26 | * Method 27 | * Rescoring 28 | * Shallow Fusion 29 | * Knowledge Distillation [[Futami 2020]](https://arxiv.org/abs/2008.03822) 30 | 31 | ## Results 32 | 33 | ### Librispeech[WER] 34 | 35 | | | Decoder(Encoder) | params | clean | other | 36 | |:---:|:---|:---:|:---:|:---:| 37 | | `L1` | CTC(Trf.) | 20M | 5.2 | 11.8 | 38 | | `L2` | CTC(Cf.) | 23M | 4.2 | 10.1 | 39 | | `L3` | Trf.(Cf.) | 35M | 3.2 | 7.0 | 40 | | `L3-1` | +CTC | - | 2.9 | 6.9 | 41 | | `L3-2` | +SF | - | 2.9 | 6.3 | 42 | | `L3-3` | +CTC+SF | - | **2.5** | **6.0** | 43 | | `L4` | RNN-T(Cf.) 
1kBPE | 26M | 2.8 | 7.0 | 44 | 45 | ### TED-LIUM2[WER] 46 | 47 | | | Decoder(Encoder) | params | test | dev | 48 | |:---:|:---|:---:|:---:|:---:| 49 | | `T1` | CTC(Trf.) | 20M | 10.9 | 12.4 | 50 | | `T2` | CTC(Cf.) | 23M | 9.4 | 10.1 | 51 | | `T3` | Trf.(Cf.) | 35M | 7.8 | 11.5 | 52 | | `T3-1` | +CTC | - | 7.4 | 9.6 | 53 | | `T3-2` | +SF | - | 7.4 | 10.7 | 54 | | `T3-3` | +CTC+SF | - | **6.8** | 9.2 | 55 | | `T4` | RNN-T(Trf.) 1kBPE | 22M | 9.5 | 10.5 | 56 | | `T5` | RNN-T(Cf.) 1kBPE | 26M | 7.4 | **8.1** | 57 | 58 | ### CSJ[WER/CER] 59 | 60 | | | | params | eval1 | eval2 | eval3 | 61 | |:---:|:---|:---:|:---:|:---:|:---:| 62 | | `C1` | CTC(Trf.) | 20M | 8.1/6.2 | 6.4/4.8 | 6.9/5.0 | 63 | | `C2` | CTC(Cf.) | 24M | 6.8/5.0 | 5.3/4.0 | 5.9/4.3 | 64 | | `C3` | Trf.(Trf.) | 32M | 6.7/5.0 | **4.9**/3.6 | 5.5/4.0 | 65 | | `C4` | Trf.(Cf.) | 36M | **6.3**/4.7 | 5.0/3.8 | **5.2**/4.0 | 66 | | `C5` | RNN-T(Cf.) 4kBPE | 33M | 6.4/4.7 | 5.0/4.1 | 5.3/4.1 | 67 | | `C6` | RNN-T(Cf.) 4kBPE Large | 91M | 6.2/4.5 | 4.9/3.7 | 5.2/4.1 | 68 | 69 | ## Reference 70 | 71 | * https://github.com/espnet/espnet 72 | * https://github.com/hirofumi0810/neural_sp 73 | -------------------------------------------------------------------------------- /asr/analysis/analyse_ctc_prob.py: -------------------------------------------------------------------------------- 1 | """ see frame-level predictions in CTC 2 | """ 3 | import argparse 4 | import json 5 | import os 6 | import sys 7 | 8 | import pandas as pd 9 | import torch 10 | from torch.utils.data import DataLoader 11 | 12 | ROOT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 13 | sys.path.append(ROOT_DIR) 14 | 15 | from asr.datasets import ASRDataset 16 | from asr.modeling.asr import ASR 17 | from utils.configure import load_config 18 | from utils.paths import get_eval_path, get_model_path, rel_to_abs_path 19 | from utils.vocab import Vocab 20 | 21 | # Reproducibility 22 | torch.manual_seed(0) 23 | torch.cuda.manual_seed_all(0) 24 | 25 | 26 | def test(args): 27 | device = torch.device("cpu") 28 | torch.set_num_threads(1) 29 | # make sure all operations are done on cpu 30 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 31 | 32 | params = load_config(args.conf) 33 | model_path = get_model_path(args.conf, args.ep) 34 | 35 | model = ASR(params, phase="test") 36 | 37 | state_dict = torch.load(model_path, map_location=device) 38 | model.load_state_dict(state_dict) 39 | model.eval() 40 | model.to(device) 41 | 42 | data_path = get_eval_path(args.data) 43 | dataset = ASRDataset(params, rel_to_abs_path(data_path), phase="test") 44 | dataloader = DataLoader( 45 | dataset=dataset, 46 | batch_size=1, 47 | shuffle=False, 48 | collate_fn=dataset.collate_fn, 49 | num_workers=1, 50 | ) 51 | vocab = Vocab(rel_to_abs_path(params.vocab_path)) 52 | 53 | for i, data in enumerate(dataloader): 54 | utt_id = data["utt_ids"][0] 55 | if utt_id != args.utt_id: 56 | continue 57 | 58 | xs = data["xs"].to(device) 59 | xlens = data["xlens"].to(device) 60 | 61 | hyps, _, logits, _ = model.decode(xs, xlens) 62 | probs = torch.softmax(logits, dim=-1) 63 | print(vocab.ids2text(hyps[0])) 64 | print("###") 65 | 66 | for i, prob in enumerate(probs[0]): 67 | print(f"Frame #{i:d}: ", end="") 68 | p_topk, v_topk = torch.topk(prob, k=args.topk) 69 | for p, v in zip(p_topk, v_topk): 70 | print(f"{vocab.id2token(v.item())}({p.item():.3f}) ", end="") 71 | print() 72 | 73 | 74 | if __name__ == "__main__": 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument("-conf", type=str, 
required=True) 77 | parser.add_argument("-ep", type=str, required=True) 78 | parser.add_argument("-utt_id", type=str, required=True) 79 | parser.add_argument("-data", type=str, required=True) 80 | parser.add_argument("--topk", type=int, default=5) 81 | args = parser.parse_args() 82 | 83 | test(args) 84 | -------------------------------------------------------------------------------- /asr/analysis/compare_wer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 9 | sys.path.append(EMOASR_ROOT) 10 | 11 | from asr.metrics import compute_wer 12 | from utils.paths import get_eval_path 13 | 14 | 15 | def main(args): 16 | dfhyp1 = pd.read_table(args.hyp1, comment="#") 17 | dfhyp2 = pd.read_table(args.hyp2, comment="#") 18 | dfref = pd.read_table(get_eval_path(args.ref), comment="#") 19 | 20 | id2hyp1, id2hyp2 = {}, {} 21 | cnt_na1, cnt_na2 = 0, 0 22 | 23 | for rowhyp1 in dfhyp1.itertuples(): 24 | if pd.isna(rowhyp1.text): 25 | id2hyp1[rowhyp1.utt_id] = [] 26 | cnt_na1 += 1 27 | else: 28 | id2hyp1[rowhyp1.utt_id] = rowhyp1.text.split() 29 | for rowhyp2 in dfhyp2.itertuples(): 30 | if pd.isna(rowhyp2.text): 31 | id2hyp2[rowhyp2.utt_id] = [] 32 | cnt_na2 += 1 33 | else: 34 | id2hyp2[rowhyp2.utt_id] = rowhyp2.text.split() 35 | 36 | for rowref in dfref.itertuples(): 37 | hyp1 = id2hyp1[rowref.utt_id] if rowref.utt_id in id2hyp1 else [""] 38 | hyp2 = id2hyp2[rowref.utt_id] if rowref.utt_id in id2hyp2 else [""] 39 | ref = rowref.text.split() 40 | 41 | _, wer_dict1 = compute_wer(hyp1, ref) 42 | _, wer_dict2 = compute_wer(hyp2, ref) 43 | n_err1 = wer_dict1["n_del"] + wer_dict1["n_sub"] + wer_dict1["n_ins"] 44 | n_err2 = wer_dict2["n_del"] + wer_dict2["n_sub"] + wer_dict2["n_ins"] 45 | n_err_diff = abs(n_err1 - n_err2) 46 | 47 | if n_err1 == n_err2: 48 | neq = "=" 49 | elif n_err1 < n_err2: 50 | neq = "<" 51 | else: 52 | neq = ">" 53 | 54 | if ( 55 | (args.filter is None or args.filter == neq) 56 | and len(ref) <= args.max_len 57 | and n_err_diff >= args.min_diff 58 | ): 59 | print(f"utt_id: {rowref.utt_id}", flush=True) 60 | print( 61 | f"hyp1[D={wer_dict1['n_del']} S={wer_dict1['n_sub']} I={wer_dict1['n_ins']}] {neq} hyp2[D={wer_dict2['n_del']} S={wer_dict2['n_sub']} I={wer_dict2['n_ins']}]" 62 | ) 63 | print(f"hyp1: {' '.join(hyp1)}") 64 | print(f"hyp2: {' '.join(hyp2)}") 65 | print(f"ref : {' '.join(ref)}") 66 | print("==========") 67 | 68 | print(f"cannot decode: hyp1: {cnt_na1:d}, hyp2: {cnt_na2:d}") 69 | 70 | 71 | if __name__ == "__main__": 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("-hyp1", type=str) 74 | parser.add_argument("-hyp2", type=str) 75 | parser.add_argument("-ref", type=str) 76 | parser.add_argument("--filter", type=str, choices=["<", ">", "=", ""], default=None) 77 | parser.add_argument("--max_len", type=int, default=1000) 78 | parser.add_argument("--min_diff", type=int, default=0) 79 | args = parser.parse_args() 80 | main(args) 81 | -------------------------------------------------------------------------------- /asr/correct/README.md: -------------------------------------------------------------------------------- 1 | # Non-autoregressive Error Correction for CTC-based ASR with Phone-conditioned Masked LM 2 | 3 | ## Requirements 4 | 5 | * Python 3.8 6 | 7 | ``` 8 | torch==1.7.1 9 | gitpython==3.1.24 10 | numpy==1.20.3 11 | pandas==1.3.3 12 | ``` 
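These pinned versions can be installed with pip, for example (a minimal sketch assuming a Python 3.8 environment; choose the torch build that matches your CUDA/CPU setup):
```
pip install torch==1.7.1 gitpython==3.1.24 numpy==1.20.3 pandas==1.3.3
```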
13 | 14 | ## Data 15 | Prepare train/test data in tsv format as follows: 16 | ``` 17 | utt_id feat_path xlen token_id text ylen phone_text phone_token_id 18 | S09F1170_0516169_0516727 path/to/S09F1170_0516169_0516727.npy 54 8 538 20 で なぜ か 3 d e n a z e k a 12 14 26 7 42 14 22 7 19 | ``` 20 | Note that `feat_path` and `xlen` are not required in PC-MLM training data. 21 | 22 | ## Train ASR 23 | Train a Transformer-based CTC model on the **SPS** subset of CSJ. It is trained with phone-level targets in addition to word-level ones (hierarchical multi-task learning). 24 | 25 | ``` 26 | python asr/train_asr.py -conf asr/correct/exps/csj/asr.yaml 27 | ``` 28 | 29 | Model checkpoints and logs will be saved at `exps/csj/asr/`. 30 | 31 | ## Train PC-MLM (Error Correction model) 32 | Train a phone-conditioned masked LM (PC-MLM), which is a phone-to-word conversion model, on the **APS** subset of CSJ. 33 | 34 | The deletable version of PC-MLM (Del PC-MLM), which addresses insertion errors, is trained as follows: 35 | ``` 36 | python lm/train_lm.py -conf asr/correct/exps/csj/del_pc_mlm.yaml 37 | ``` 38 | 39 | ## Test ASR 40 | Test ASR on the `eval1` set for the APS domain (domain adaptation setting). 41 | 42 | Without correction (A1): 43 | ``` 44 | python asr/test_asr.py -conf asr/correct/exps/csj/asr.yaml -ep 91-100 45 | ``` 46 | 47 | With correction (A7): 48 | ``` 49 | python asr/test_asr_correct.py -conf asr/correct/exps/csj/asr.yaml -ep 91-100 -lm_conf asr/correct/exps/csj/del_pc_mlm.yaml -lm_ep 100 --lm_weight 0.5 --mask_th 0.8 50 | ``` 51 | 52 | Results are saved at `exps/csj/asr/results/`. 53 | RTF can be calculated with the `--runtime` option. 54 | 55 | | | | WER | RTF | 56 | |:---:|:---|:---:|:---:| 57 | | (A1) | CTC (greedy) | 18.10 | 0.0033 | 58 | | (A7) | +Correction (w/ Del PC-MLM) | 16.48 | 0.0094 | 59 | -------------------------------------------------------------------------------- /asr/correct/exps/csj/asr.yaml: -------------------------------------------------------------------------------- 1 | encoder_type: "transformer" 2 | decoder_type: "ctc" 3 | lr_schedule_type: "noam" 4 | 5 | # frontend 6 | input_layer: "conv2d" 7 | feat_dim: 80 8 | num_framestacks: 1 9 | spec_augment: true 10 | max_mask_freq: 30 11 | max_mask_time: 40 12 | num_masks_freq: 2 13 | num_masks_time: 2 14 | replace_with_zero: true 15 | 16 | # model 17 | enc_hidden_size: 256 18 | enc_num_attention_heads: 4 19 | enc_num_layers: 12 20 | enc_intermediate_size: 2048 21 | 22 | # data 23 | blank_id: 0 24 | eos_id: 2 25 | phone_eos_id: 2 26 | vocab_path: "corpora/csj/nsp10k/data/orig/vocab.txt" 27 | phone_vocab_path: "corpora/csj/nsp10k/data/orig/vocab_phone.txt" 28 | vocab_size: 10872 29 | phone_vocab_size: 43 30 | train_path: "corpora/csj/nsp10k/data/train_nodev_sps_sorted_p2w_ctc.tsv" 31 | dev_path: "corpora/csj/nsp10k/data/dev_500.tsv" 32 | test_path: "corpora/csj/nsp10k/data/eval1.tsv" 33 | train_data_shuffle: true 34 | 35 | model_path: "" 36 | optim_path: "" 37 | startep: 0 38 | log_step: 100 39 | save_step: 1 40 | 41 | # train 42 | batch_size: 50 43 | max_xlens_batch: 30000 44 | max_ylens_batch: 3000 45 | num_epochs: 100 46 | learning_rate: 5.0 47 | num_warmup_steps: 25000 48 | clip_grad_norm: 5.0 49 | dropout_enc_rate: 0.1 50 | dropout_attn_rate: 0.1 51 | weight_decay: 0.000001 52 | accum_grad: 5 53 | lsm_prob: 0 54 | kd_weight: 0 55 | 56 | # MTL 57 | mtl_phone_ctc_weight: 0.3 58 | hie_mtl_phone: true 59 | inter_ctc_layer_id: 6 60 | mtl_inter_ctc_weight: 0 61 | 62 | # decode 63 | beam_width: 0 64 | len_weight: 0 65 | decode_ctc_weight: 0
66 | lm_weight: 0 67 | -------------------------------------------------------------------------------- /asr/correct/exps/csj/del_pc_mlm.yaml: -------------------------------------------------------------------------------- 1 | lm_type: "pbert" 2 | lr_schedule_type: "noam" 3 | 4 | # model 5 | input_layer: "embed" 6 | enc_hidden_size: 256 7 | enc_num_attention_heads: 4 8 | enc_num_layers: 4 9 | enc_intermediate_size: 1024 10 | dec_hidden_size: 256 11 | dec_num_attention_heads: 4 12 | dec_num_layers: 4 13 | dec_intermediate_size: 1024 14 | dropout_enc_rate: 0.1 15 | dropout_dec_rate: 0.1 16 | dropout_attn_rate: 0.1 17 | mtl_ctc_weight: 0 18 | lsm_prob: 0 19 | kd_weight: 0 20 | max_decode_ylen: 256 21 | 22 | text_augment: true 23 | textaug_max_mask_prob: 0.4 24 | textaug_max_replace_prob: 0 25 | 26 | # data 27 | vocab_size: 10872 28 | src_vocab_size: 45 29 | max_seq_len: 256 30 | eos_id: 2 31 | mask_id: 10871 32 | phone_eos_id: 2 33 | phone_mask_id: 43 34 | blank_id: 0 35 | train_path: "corpora/csj/nsp10k/data/train_nodev_aps_p2w_sorted.tsv" 36 | train_size: 154446 37 | add_sos_eos: false 38 | dev_path: "corpora/csj/nsp10k/data/dev_500.tsv" 39 | test_path: "corpora/csj/nsp10k/data/eval1_p2w.tsv" 40 | vocab_path: "corpora/csj/nsp10k/data/orig/vocab.txt" 41 | bucket_shuffle: true 42 | 43 | model_path: "" 44 | optim_path: "" 45 | startep: 0 46 | log_step: 100 47 | save_step: 1 48 | 49 | # train 50 | batch_size: 100 51 | num_epochs: 100 52 | max_plens_batch: 20000 53 | max_ylens_batch: 10000 54 | learning_rate: 5.0 55 | warmup_proportion: 0.1 56 | weight_decay: 0.01 57 | clip_grad_norm: 5.0 58 | mask_proportion: 0.3 59 | random_num_to_mask: true 60 | mask_insert_poisson_lam: 0.2 61 | accum_grad: 1 62 | weight_tying: false 63 | -------------------------------------------------------------------------------- /asr/datasets.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import random 5 | import sys 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | from torch.nn.utils.rnn import pad_sequence 11 | from torch.utils.data import Dataset, Sampler 12 | 13 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../") 14 | sys.path.append(EMOASR_ROOT) 15 | 16 | from utils.converters import get_utt_id_nosp, str2ints 17 | 18 | from asr.spec_augment import SpecAugment 19 | 20 | random.seed(0) 21 | eos_id = 2 22 | phone_eos_id = 2 23 | 24 | 25 | class ASRDataset(Dataset): 26 | def __init__(self, params, data_path, phase="train", size=-1, decode_phone=False): 27 | self.feat_dim = params.feat_dim 28 | self.num_framestacks = params.num_framestacks 29 | self.vocab_size = params.vocab_size 30 | self.lsm_prob = params.lsm_prob 31 | 32 | global eos_id 33 | eos_id = params.eos_id 34 | 35 | self.phase = phase # `train` or `test` or `valid` 36 | 37 | if self.phase == "train" and params.spec_augment: 38 | self.specaug = SpecAugment(params) 39 | else: 40 | self.specaug = None 41 | 42 | self.data = pd.read_table(data_path) 43 | 44 | self.mtl_phone_ctc_weight = ( 45 | params.mtl_phone_ctc_weight 46 | if hasattr(params, "mtl_phone_ctc_weight") 47 | else 0 48 | ) 49 | 50 | if (self.phase == "train" and self.mtl_phone_ctc_weight > 0) or decode_phone: 51 | self.data = self.data[ 52 | [ 53 | "feat_path", 54 | "utt_id", 55 | "token_id", 56 | "text", 57 | "xlen", 58 | "ylen", 59 | "phone_token_id", 60 | "phone_text", 61 | ] 62 | ] 63 | global phone_eos_id 64 | phone_eos_id = params.phone_eos_id 65 | 
else: 66 | self.data = self.data[ 67 | ["feat_path", "utt_id", "token_id", "text", "xlen", "ylen"] 68 | ] 69 | 70 | self.use_kd = params.kd_weight > 0 or (hasattr(params, "inter_kd_weight") and params.inter_kd_weight > 0) 71 | 72 | if self.phase == "train" and self.use_kd: 73 | with open(params.kd_label_path, "rb") as f: 74 | self.data_kd = pickle.load(f) 75 | logging.info(f"kd labels: {params.kd_label_path}") 76 | 77 | self.add_eos = params.decoder_type in ["transformer", "las"] 78 | else: 79 | self.data_kd = None 80 | 81 | if size > 0: 82 | self.data = self.data[:size] 83 | 84 | def __len__(self): 85 | return len(self.data) 86 | 87 | def __getitem__(self, idx): 88 | utt_id = self.data.loc[idx]["utt_id"] 89 | text = self.data.loc[idx]["text"] 90 | 91 | feat_path = self.data.loc[idx]["feat_path"] 92 | x = np.load(feat_path)[:, : self.feat_dim] 93 | 94 | if self.specaug is not None: 95 | x = self.specaug(x) 96 | 97 | x = torch.tensor(x, dtype=torch.float) # float32 98 | 99 | if self.num_framestacks > 1: 100 | x = self._stack_frames(x, self.num_framestacks) 101 | 102 | xlen = x.size(0) # `xlen` is based on length after frame stacking 103 | 104 | token_id = str2ints(self.data.loc[idx]["token_id"]) 105 | y = torch.tensor(token_id, dtype=torch.long) # int64 106 | ylen = y.size(0) 107 | 108 | if "phone_token_id" in self.data: 109 | phone_token_id = str2ints(self.data.loc[idx]["phone_token_id"]) 110 | phone_text = self.data.loc[idx]["phone_text"] 111 | p = torch.tensor(phone_token_id, dtype=torch.long) 112 | plen = p.size(0) 113 | ptext = phone_text 114 | else: 115 | p, plen, ptext = None, None, None 116 | 117 | # for knowledge distillation 118 | if self.data_kd is not None: 119 | utt_id_nosp = get_utt_id_nosp(utt_id) 120 | 121 | if utt_id_nosp in self.data_kd: 122 | data_kd_utt = self.data_kd[utt_id_nosp] 123 | else: 124 | data_kd_utt = [] 125 | logging.warning(f"soft label: {utt_id_nosp} not found") 126 | 127 | soft_label = create_soft_label( 128 | data_kd_utt, ylen, self.vocab_size, self.lsm_prob, add_eos=self.add_eos 129 | ) 130 | else: 131 | soft_label = None 132 | 133 | return utt_id, x, xlen, y, ylen, text, p, plen, ptext, soft_label 134 | 135 | @staticmethod 136 | def _stack_frames(x, num_framestacks): 137 | new_len = x.size(0) // num_framestacks 138 | feat_dim = x.size(1) 139 | x_stacked = x[0 : new_len * num_framestacks].reshape( 140 | new_len, feat_dim * num_framestacks 141 | ) 142 | 143 | return x_stacked 144 | 145 | @staticmethod 146 | def collate_fn(batch): 147 | utt_ids, xs, xlens, ys_list, ylens, texts, ps, plens, ptexts, soft_labels = zip(*batch) 148 | 149 | ret = {} 150 | 151 | ret["utt_ids"] = list(utt_ids) 152 | ret["texts"] = list(texts) 153 | 154 | # ys = [[y_1, ..., y_n], ...] 155 | ret["ys"] = pad_sequence( 156 | ys_list, batch_first=True, padding_value=eos_id 157 | ) # NOTE: without 158 | 159 | # add and here 160 | ys_eos_list = [[eos_id] + y.tolist() + [eos_id] for y in ys_list] 161 | 162 | # ys_in = [[, y_1, ..., y_n], ...], ys_out = [[y_1, ..., y_n, ], ...] 
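        # note: eos_id doubles as the start-of-sequence symbol here, so
        #   ys_in  = [eos, y_1, ..., y_n]  and  ys_out = [y_1, ..., y_n, eos]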
163 | ys_in = [torch.tensor(y[:-1], dtype=torch.long) for y in ys_eos_list] 164 | ys_out = [torch.tensor(y[1:], dtype=torch.long) for y in ys_eos_list] 165 | 166 | ret["xs"] = pad_sequence(xs, batch_first=True) 167 | ret["xlens"] = torch.tensor(xlens) 168 | ret["ys_in"] = pad_sequence( 169 | ys_in, batch_first=True, padding_value=eos_id 170 | ) # NOTE: is added 171 | ret["ys_out"] = pad_sequence( 172 | ys_out, batch_first=True, padding_value=eos_id 173 | ) # NOTE: is added 174 | 175 | # NOTE: ys_in and ys_out have length ylens+1 176 | ret["ylens"] = torch.tensor(ylens, dtype=torch.long) 177 | 178 | if ps[0] is not None: 179 | ret["ps"] = pad_sequence(ps, batch_first=True, padding_value=phone_eos_id) 180 | ret["plens"] = torch.tensor(plens) 181 | ret["ptexts"] = list(ptexts) 182 | 183 | if soft_labels[0] is not None: 184 | ret["soft_labels"] = pad_sequence(soft_labels, batch_first=True) 185 | 186 | return ret 187 | 188 | 189 | class ASRBatchSampler(Sampler): 190 | def __init__(self, dataset, params, min_batch_size=1): 191 | self.xlens = dataset.data["xlen"].values 192 | self.ylens = dataset.data["ylen"].values 193 | self.dataset_size = len(self.xlens) 194 | self.max_xlens_batch = params.max_xlens_batch 195 | self.max_ylens_batch = params.max_ylens_batch 196 | self.batch_size = params.batch_size 197 | self.min_batch_size = min_batch_size 198 | self.indices_batches = self._make_batches() 199 | 200 | def _make_batches(self): 201 | self.index = 0 202 | indices_batches = [] 203 | 204 | while self.index < self.dataset_size: 205 | indices = [] 206 | xlens_sum = 0 207 | ylens_sum = 0 208 | 209 | while self.index < self.dataset_size: 210 | xlen = self.xlens[self.index] 211 | ylen = self.ylens[self.index] 212 | 213 | assert xlen <= self.max_xlens_batch 214 | assert ylen <= self.max_ylens_batch 215 | if ( 216 | xlens_sum + xlen > self.max_xlens_batch 217 | or ylens_sum + ylen > self.max_ylens_batch 218 | or len(indices) + 1 > self.batch_size 219 | ): 220 | break 221 | 222 | indices.append(self.index) 223 | xlens_sum += xlen 224 | ylens_sum += ylen 225 | self.index += 1 226 | 227 | if len(indices) < self.min_batch_size: 228 | logging.warning( 229 | f"{len(indices)} utterances are skipped because of they are smaller than min_batch_size" 230 | ) 231 | else: 232 | indices_batches.append(indices) 233 | 234 | return indices_batches 235 | 236 | def __iter__(self): 237 | # NOTE: shuffled for each epoch 238 | random.shuffle(self.indices_batches) 239 | logging.debug("batches are shuffled in Sampler") 240 | 241 | for indices in self.indices_batches: 242 | yield indices 243 | 244 | def __len__(self): 245 | return len(self.indices_batches) 246 | 247 | 248 | def create_soft_label(data_kd_utt, ylen, vocab_size, lsm_prob, add_eos=False): 249 | if add_eos: 250 | soft_label = torch.zeros(ylen + 1, vocab_size) # same length as `ys_out` 251 | else: 252 | soft_label = torch.zeros(ylen, vocab_size) 253 | 254 | for i, topk_probs in enumerate(data_kd_utt): 255 | soft_label[i, :] = lsm_prob / (vocab_size - len(topk_probs)) 256 | for v, prob in topk_probs: 257 | soft_label[i, v] = prob.astype(np.float64) * (1 - lsm_prob) 258 | 259 | if add_eos: 260 | soft_label[-1, :] = lsm_prob / (vocab_size - 1) 261 | soft_label[-1, eos_id] = 1.0 * (1 - lsm_prob) 262 | 263 | return soft_label 264 | -------------------------------------------------------------------------------- /asr/distill/eval_label.py: -------------------------------------------------------------------------------- 1 | """ Measures soft label accuracy 2 | """ 3 | 4 | 
import argparse 5 | import os 6 | import pickle 7 | import sys 8 | 9 | import pandas as pd 10 | from tqdm import tqdm 11 | 12 | EMOASR_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 13 | sys.path.append(EMOASR_DIR) 14 | 15 | from utils.converters import str2ints 16 | from utils.paths import get_eval_path 17 | from utils.vocab import Vocab 18 | 19 | 20 | def accuracy(labels, dfref, vocab=None): 21 | id2ref = {} 22 | cnt, cntacc1, cntacck = 0, 0, 0 23 | 24 | for row in dfref.itertuples(): 25 | id2ref[row.utt_id] = str2ints(row.token_id) 26 | # assert row.utt_id in labels.keys() 27 | 28 | for utt_id, label in tqdm(labels.items()): 29 | ref_token_id = id2ref[utt_id] 30 | cnt += len(label) 31 | 32 | if vocab is not None: 33 | print(f"# utt_id: {utt_id}") 34 | 35 | ref_text = vocab.ids2tokens(ref_token_id) 36 | for i, vps in enumerate(label): 37 | # mask i-th token 38 | ref_text_masked = ref_text.copy() 39 | ref_text_masked[i] = "" 40 | print(" ".join(ref_text_masked)) 41 | 42 | for v, p in vps: 43 | print(f"{vocab.id2token(v)}: {p:.2f}", end=" ") 44 | print() 45 | 46 | for i, vps in enumerate(label): 47 | v1, _ = vps[0] 48 | cntacc1 += int(v1 == ref_token_id[i]) 49 | 50 | for v, _ in vps: 51 | cntacck += int(v == ref_token_id[i]) 52 | 53 | acc1 = (cntacc1 / cnt) * 100 54 | acck = (cntacck / cnt) * 100 55 | 56 | return acc1, acck, cnt 57 | 58 | 59 | def main(args): 60 | with open(args.pkl_path, "rb") as f: 61 | labels = pickle.load(f) 62 | print("pickle loaded") 63 | 64 | tsv_path = get_eval_path(args.ref) 65 | dfref = pd.read_table(tsv_path) 66 | 67 | if args.vocab is not None: 68 | vocab = Vocab(args.vocab) 69 | else: 70 | vocab = None 71 | 72 | acc1, acck, cnt = accuracy(labels, dfref, vocab=vocab) 73 | 74 | print(f"{cnt:d} tokens") 75 | print(f"Accuracy top1: {acc1:.3f} topk: {acck:.3f}") 76 | 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("pkl_path", type=str) 81 | parser.add_argument("-ref", type=str, required=True) 82 | parser.add_argument("--vocab", type=str) # debug 83 | args = parser.parse_args() 84 | 85 | main(args) 86 | -------------------------------------------------------------------------------- /asr/fusion/test_fusion_grid.py: -------------------------------------------------------------------------------- 1 | """ Shallow Fusion with grid search parameter search 2 | """ 3 | import argparse 4 | import logging 5 | import multiprocessing 6 | import os 7 | import sys 8 | 9 | import numpy as np 10 | 11 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 12 | sys.path.append(EMOASR_ROOT) 13 | 14 | from asr.test_asr import test_main 15 | from utils.paths import get_eval_path, get_results_dir 16 | 17 | EPS = 1e-5 18 | 19 | 20 | def main(args): 21 | log_dir = get_results_dir(args.conf) 22 | data_path = get_eval_path(args.data) 23 | data_tag = ( 24 | args.data 25 | if args.data_tag == "test" and data_path != args.data 26 | else args.data_tag 27 | ) 28 | log_file = ( 29 | f"test_fusion_grid_{data_tag}_ctc{args.decode_ctc_weight}_ep{args.ep}.log" 30 | ) 31 | 32 | logging.basicConfig( 33 | filename=os.path.join(log_dir, log_file), 34 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 35 | level=logging.INFO, 36 | ) 37 | 38 | # grid search 39 | lm_weight_cands = np.arange(args.lm_min, args.lm_max + EPS, args.lm_step) 40 | len_weight_cands = np.arange(args.len_min, args.len_max + EPS, args.len_step) 41 | pool = multiprocessing.Pool(len(lm_weight_cands) * 
len(len_weight_cands)) 42 | 43 | func_args = [] 44 | 45 | for lm_weight in lm_weight_cands: 46 | for len_weight in len_weight_cands: 47 | func_args.append((args, lm_weight, len_weight)) 48 | 49 | results = pool.starmap(test_main, func_args) 50 | 51 | lm_weight_min = 0 52 | len_weight_min = 0 53 | wer_min = 100 54 | wer_info_min = "" 55 | 56 | for lm_weight, len_weight, wer, wer_info in results: 57 | logging.info( 58 | f"lm_weight: {lm_weight:.3f} len_weight: {len_weight:.3f} - {wer_info}" 59 | ) 60 | if wer < wer_min: 61 | lm_weight_min = lm_weight 62 | len_weight_min = len_weight 63 | wer_min = wer 64 | wer_info_min = wer_info 65 | 66 | logging.info("***** best WER:") 67 | logging.info( 68 | f"lm_weight: {lm_weight_min:.3f} len_weight: {len_weight_min:.3f} - {wer_info_min}" 69 | ) 70 | 71 | 72 | if __name__ == "__main__": 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument("-conf", type=str, required=True) 75 | parser.add_argument("-ep", type=str, required=True) 76 | parser.add_argument("--data", type=str, default=None) 77 | parser.add_argument("--data_tag", type=str, default="test") 78 | parser.add_argument("--save_dir", type=str, default=None) 79 | parser.add_argument("--beam_width", type=int, default=None) 80 | parser.add_argument("--decode_ctc_weight", type=float, default=0) 81 | # 82 | parser.add_argument("--lm_min", type=float, default=0) 83 | parser.add_argument("--lm_max", type=float, default=1) 84 | parser.add_argument("--lm_step", type=float, default=0.1) 85 | parser.add_argument("--len_min", type=float, default=0) 86 | parser.add_argument("--len_max", type=float, default=5) 87 | parser.add_argument("--len_step", type=float, default=1) 88 | parser.add_argument("--lm_conf", type=str, default=None) 89 | parser.add_argument("--lm_ep", type=str, default=None) 90 | parser.add_argument("--lm_tag", type=str, default=None) 91 | args = parser.parse_args() 92 | 93 | # set unused attributes 94 | args.cpu = True 95 | args.nbest = False 96 | args.debug = False 97 | args.utt_id = None 98 | args.runtime = False 99 | main(args) 100 | -------------------------------------------------------------------------------- /asr/metrics.py: -------------------------------------------------------------------------------- 1 | """ WER computation 2 | 3 | Reference 4 | https://github.com/hirofumi0810/neural_sp/blob/master/neural_sp/evaluators/edit_distance.py 5 | """ 6 | 7 | import argparse 8 | import os 9 | import sys 10 | 11 | import numpy as np 12 | import pandas as pd 13 | 14 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../") 15 | sys.path.append(EMOASR_ROOT) 16 | 17 | from utils.log import insert_comment 18 | 19 | 20 | def compute_wer(hyp, ref, cer=False): 21 | # NOTE: if cannot decode, use symbol (never match with ref) 22 | if len(hyp) == 0: 23 | hyp = [""] 24 | 25 | if cer: 26 | hyp = list("".join(hyp)) 27 | ref = list("".join(ref)) 28 | 29 | # edit distance 30 | d = np.zeros((len(ref) + 1) * (len(hyp) + 1), dtype=np.uint16) 31 | d = d.reshape((len(ref) + 1, len(hyp) + 1)) 32 | for i in range(len(ref) + 1): 33 | for j in range(len(hyp) + 1): 34 | if i == 0: 35 | d[0][j] = j 36 | elif j == 0: 37 | d[i][0] = i 38 | 39 | for i in range(1, len(ref) + 1): 40 | for j in range(1, len(hyp) + 1): 41 | if ref[i - 1] == hyp[j - 1]: 42 | d[i][j] = d[i - 1][j - 1] 43 | else: 44 | sub_tmp = d[i - 1][j - 1] + 1 45 | ins_tmp = d[i][j - 1] + 1 46 | del_tmp = d[i - 1][j] + 1 47 | d[i][j] = min(sub_tmp, ins_tmp, del_tmp) 48 | dist = d[len(ref)][len(hyp)] 49 | 50 | # backtrack 51 
| x = len(ref) 52 | y = len(hyp) 53 | error_list = [] 54 | while True: 55 | if x == 0 and y == 0: 56 | break 57 | else: 58 | if x > 0 and y > 0: 59 | if d[x][y] == d[x - 1][y - 1] and ref[x - 1] == hyp[y - 1]: 60 | error_list.append("C") 61 | x = x - 1 62 | y = y - 1 63 | elif d[x][y] == d[x][y - 1] + 1: 64 | error_list.append("I") 65 | y = y - 1 66 | elif d[x][y] == d[x - 1][y - 1] + 1: 67 | error_list.append("S") 68 | x = x - 1 69 | y = y - 1 70 | else: 71 | error_list.append("D") 72 | x = x - 1 73 | elif x == 0 and y > 0: 74 | if d[x][y] == d[x][y - 1] + 1: 75 | error_list.append("I") 76 | y = y - 1 77 | else: 78 | error_list.append("D") 79 | x = x - 1 80 | elif y == 0 and x > 0: 81 | error_list.append("D") 82 | x = x - 1 83 | else: 84 | raise ValueError 85 | error_list.reverse() 86 | 87 | n_sub = error_list.count("S") 88 | n_ins = error_list.count("I") 89 | n_del = error_list.count("D") 90 | n_cor = error_list.count("C") 91 | 92 | assert dist == (n_sub + n_ins + n_del) 93 | assert n_cor == (len(ref) - n_sub - n_del) 94 | 95 | wer = (dist / len(ref)) * 100 96 | wer_dict = { 97 | "wer": wer, 98 | "n_sub": n_sub, 99 | "n_ins": n_ins, 100 | "n_del": n_del, 101 | "n_ref": len(ref), 102 | "error_list": error_list, 103 | } 104 | 105 | return wer, wer_dict 106 | 107 | 108 | def compute_wers(hyps: list, refs: list, vocab=None, cer=False): 109 | n_sub_total, n_ins_total, n_del_total, n_ref_total = 0, 0, 0, 0 110 | 111 | for hyp, ref in zip(hyps, refs): 112 | if vocab is not None: 113 | hyp = vocab.ids2words(hyp) 114 | ref = vocab.ids2words(ref) 115 | 116 | _, wer_dict = compute_wer(hyp, ref, cer=cer) 117 | 118 | n_sub_total += wer_dict["n_sub"] 119 | n_ins_total += wer_dict["n_ins"] 120 | n_del_total += wer_dict["n_del"] 121 | n_ref_total += wer_dict["n_ref"] 122 | 123 | wer = ((n_sub_total + n_ins_total + n_del_total) / n_ref_total) * 100 124 | wer_dict = { 125 | "wer": wer, 126 | "n_sub": n_sub_total, 127 | "n_ins": n_ins_total, 128 | "n_del": n_del_total, 129 | "n_ref": n_ref_total, 130 | } 131 | 132 | return wer, wer_dict 133 | 134 | 135 | def compute_wers_df(dfhyp, dfref=None, cer=False): 136 | n_sub_total, n_ins_total, n_del_total, n_ref_total = 0, 0, 0, 0 137 | 138 | if dfref is None: 139 | for row in dfhyp.itertuples(): 140 | hyp = row.text.split() if not pd.isna(row.text) else [] 141 | ref = row.reftext.split() 142 | 143 | _, wer_dict = compute_wer(hyp, ref, cer=cer) 144 | 145 | n_sub_total += wer_dict["n_sub"] 146 | n_ins_total += wer_dict["n_ins"] 147 | n_del_total += wer_dict["n_del"] 148 | n_ref_total += wer_dict["n_ref"] 149 | else: 150 | id2hyp = {} 151 | 152 | for row in dfhyp.itertuples(): 153 | id2hyp[row.utt_id] = row.text.split() 154 | 155 | for row in dfref.itertuples(): 156 | hyp = id2hyp[row.utt_id] if row.utt_id in id2hyp else [] 157 | ref = row.text.split() 158 | 159 | _, wer_dict = compute_wer(hyp, ref, cer=cer) 160 | 161 | n_sub_total += wer_dict["n_sub"] 162 | n_ins_total += wer_dict["n_ins"] 163 | n_del_total += wer_dict["n_del"] 164 | n_ref_total += wer_dict["n_ref"] 165 | 166 | wer = ((n_sub_total + n_ins_total + n_del_total) / n_ref_total) * 100 167 | wer_dict = { 168 | "wer": wer, 169 | "n_sub": n_sub_total, 170 | "n_ins": n_ins_total, 171 | "n_del": n_del_total, 172 | "n_ref": n_ref_total, 173 | } 174 | 175 | return wer, wer_dict 176 | 177 | 178 | if __name__ == "__main__": 179 | parser = argparse.ArgumentParser() 180 | parser.add_argument("tsv_path", type=str) 181 | parser.add_argument("--cer", action="store_true") 182 | args = parser.parse_args() 183 | 184 | 
data = pd.read_table(args.tsv_path, comment="#") 185 | wer, wer_dict = compute_wers_df(data, cer=args.cer) 186 | if args.cer: 187 | wer_info = f"CER: {wer:.2f} [D={wer_dict['n_del']:d}, S={wer_dict['n_sub']:d}, I={wer_dict['n_ins']:d}, N={wer_dict['n_ref']:d}]" 188 | else: 189 | wer_info = f"WER: {wer:.2f} [D={wer_dict['n_del']:d}, S={wer_dict['n_sub']:d}, I={wer_dict['n_ins']:d}, N={wer_dict['n_ref']:d}]" 190 | print(wer_info) 191 | insert_comment(args.tsv_path, wer_info) 192 | -------------------------------------------------------------------------------- /asr/modeling/asr.py: -------------------------------------------------------------------------------- 1 | """ End-to-End ASR modeling 2 | """ 3 | 4 | import logging 5 | import os 6 | import sys 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 12 | sys.path.append(EMOASR_ROOT) 13 | 14 | from asr.modeling.decoders.ctc import CTCDecoder 15 | from asr.modeling.decoders.rnn_transducer import RNNTDecoder 16 | from asr.modeling.decoders.transformer import TransformerDecoder 17 | from asr.modeling.encoders.rnn import RNNEncoder 18 | from asr.modeling.encoders.transformer import TransformerEncoder 19 | 20 | 21 | class ASR(nn.Module): 22 | def __init__(self, params, phase="train"): 23 | super(ASR, self).__init__() 24 | 25 | self.encoder_type = params.encoder_type 26 | self.decoder_type = params.decoder_type 27 | 28 | logging.info(f"encoder type: {self.encoder_type}") 29 | if self.encoder_type == "rnn": 30 | self.encoder = RNNEncoder(params) 31 | elif self.encoder_type in ["transformer", "conformer"]: 32 | self.encoder = TransformerEncoder( 33 | params, is_conformer=(self.encoder_type == "conformer") 34 | ) 35 | 36 | logging.info(f"decoder type: {self.decoder_type}") 37 | if self.decoder_type == "ctc": 38 | self.decoder = CTCDecoder(params) 39 | elif self.decoder_type == "rnn_transducer": 40 | self.decoder = RNNTDecoder(params, phase) 41 | elif self.decoder_type == "transformer": 42 | self.decoder = TransformerDecoder(params) 43 | # TODO: LAS 44 | 45 | num_params = sum(p.numel() for p in self.parameters()) 46 | num_params_trainable = sum( 47 | p.numel() for p in self.parameters() if p.requires_grad 48 | ) 49 | logging.info( 50 | f"ASR model #parameters: {num_params} ({num_params_trainable} trainable)" 51 | ) 52 | 53 | def forward( 54 | self, xs, xlens, ys, ylens, ys_in, ys_out, soft_labels=None, ps=None, plens=None 55 | ): 56 | # DataParallel 57 | xs = xs[:, : max(xlens), :] 58 | ys = ys[:, : max(ylens)] 59 | ys_in = ys_in[:, : max(ylens) + 1] 60 | ys_out = ys_out[:, : max(ylens) + 1] 61 | if ps is not None: 62 | ps = ps[:, : max(plens)] 63 | 64 | eouts, elens, eouts_inter = self.encoder(xs, xlens) 65 | loss, loss_dict, _ = self.decoder( 66 | eouts, elens, eouts_inter, ys, ylens, ys_in, ys_out, soft_labels, ps, plens 67 | ) 68 | return loss, loss_dict 69 | 70 | def decode( 71 | self, 72 | xs, 73 | xlens, 74 | beam_width=1, 75 | len_weight=0, 76 | lm=None, 77 | lm_weight=0, 78 | decode_ctc_weight=0, 79 | decode_phone=False, 80 | ): 81 | with torch.no_grad(): 82 | eouts, elens, eouts_inter = self.encoder(xs, xlens) 83 | hyps, scores, logits, aligns = self.decoder.decode( 84 | eouts, 85 | elens, 86 | eouts_inter, 87 | beam_width, 88 | len_weight, 89 | lm, 90 | lm_weight, 91 | decode_ctc_weight, 92 | decode_phone, 93 | ) 94 | 95 | return hyps, scores, logits, aligns 96 | 97 | def forced_align(self, xs, xlens, decode_ctc_weight=0): 98 | with torch.no_grad(): 99 | 
eouts, elens = self.encoder(xs, xlens) 100 | aligns = self.decoder.forced_align(eouts, elens, decode_ctc_weight) 101 | return aligns 102 | -------------------------------------------------------------------------------- /asr/modeling/conformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../") 10 | sys.path.append(EMOASR_ROOT) 11 | 12 | from asr.modeling.model_utils import Swish 13 | from asr.modeling.transformer import MultiHeadedAttention, PositionwiseFeedForward 14 | 15 | 16 | class RelPositionalEncoder(nn.Module): 17 | def __init__(self, hidden_size, dropout_rate=0.1, max_len=5000): 18 | super(RelPositionalEncoder, self).__init__() 19 | self.hidden_size = hidden_size 20 | self.xscale = math.sqrt(self.hidden_size) 21 | self.dropout = nn.Dropout(dropout_rate) 22 | self.pe = None 23 | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) 24 | 25 | def extend_pe(self, x): 26 | if self.pe is not None: 27 | if self.pe.size(1) >= x.size(1) * 2 - 1: 28 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 29 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 30 | return 31 | pe_positive = torch.zeros(x.size(1), self.hidden_size) 32 | pe_negative = torch.zeros(x.size(1), self.hidden_size) 33 | position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) 34 | div_term = torch.exp( 35 | torch.arange(0, self.hidden_size, 2, dtype=torch.float32) 36 | * -(math.log(10000.0) / self.hidden_size) 37 | ) 38 | pe_positive[:, 0::2] = torch.sin(position * div_term) 39 | pe_positive[:, 1::2] = torch.cos(position * div_term) 40 | pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) 41 | pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) 42 | pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) 43 | pe_negative = pe_negative[1:].unsqueeze(0) 44 | pe = torch.cat([pe_positive, pe_negative], dim=1) 45 | self.pe = pe.to(device=x.device, dtype=x.dtype) 46 | 47 | def forward(self, xs): 48 | self.extend_pe(xs) 49 | xs = xs * self.xscale 50 | pos_emb = self.pe[ 51 | :, 52 | self.pe.size(1) // 2 - xs.size(1) + 1 : self.pe.size(1) // 2 + xs.size(1), 53 | ] 54 | return self.dropout(xs), self.dropout(pos_emb) 55 | 56 | 57 | class RelMultiHeadedAttention(MultiHeadedAttention): 58 | def __init__(self, num_attention_heads, hidden_size, dropout_rate): 59 | super().__init__( 60 | num_attention_heads, hidden_size, dropout_rate, 61 | ) 62 | self.linear_pos = nn.Linear(hidden_size, hidden_size, bias=False) 63 | self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) 64 | self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) 65 | torch.nn.init.xavier_uniform_(self.pos_bias_u) 66 | torch.nn.init.xavier_uniform_(self.pos_bias_v) 67 | 68 | def rel_shift(self, x): 69 | zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) 70 | x_padded = torch.cat([zero_pad, x], dim=-1) 71 | 72 | x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) 73 | x = x_padded[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1] 74 | 75 | return x 76 | 77 | def forward(self, query, key, value, pos_emb, mask): 78 | q, k, v = self.forward_qkv(query, key, value) 79 | q = q.transpose(1, 2) 80 | 81 | n_batch_pos = pos_emb.size(0) 82 | p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) 83 | p = p.transpose(1, 2) 84 | 85 | 
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) 86 | q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) 87 | 88 | matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) 89 | 90 | matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) 91 | matrix_bd = self.rel_shift(matrix_bd) 92 | 93 | scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) 94 | 95 | return self.forward_attention(v, scores, mask) 96 | 97 | 98 | class ConvModule(nn.Module): 99 | def __init__(self, channels, kernel_size=31): 100 | super(ConvModule, self).__init__() 101 | assert (kernel_size - 1) % 2 == 0 102 | 103 | self.pointwise_conv1 = nn.Conv1d( 104 | channels, 2 * channels, kernel_size=1, stride=1, padding=0, 105 | ) 106 | self.depthwise_conv = nn.Conv1d( 107 | channels, 108 | channels, 109 | kernel_size=kernel_size, 110 | stride=1, 111 | padding=(kernel_size - 1) // 2, 112 | groups=channels, 113 | ) 114 | self.batch_norm = nn.BatchNorm1d(channels) 115 | self.pointwise_conv2 = nn.Conv1d( 116 | channels, channels, kernel_size=1, stride=1, padding=0 117 | ) 118 | self.glu_act = nn.GLU(dim=1) 119 | self.swish_act = Swish() 120 | 121 | def forward(self, x): 122 | # 1. Layernorm is applied before 123 | x = x.transpose(1, 2) 124 | 125 | # 2. Pointwise Conv 126 | x = self.pointwise_conv1(x) 127 | 128 | # 3. GLU Conv 129 | x = self.glu_act(x) 130 | 131 | # 4. 1D Depthwise Conv 132 | x = self.depthwise_conv(x) 133 | 134 | # 5. Batchnorm 135 | x = self.batch_norm(x) 136 | 137 | # 6. Swish activation 138 | x = self.swish_act(x) 139 | 140 | # 7. Pointwise Conv 141 | x = self.pointwise_conv2(x) 142 | 143 | return x.transpose(1, 2) 144 | 145 | 146 | class ConformerEncoderLayer(nn.Module): 147 | def __init__( 148 | self, 149 | enc_num_attention_heads, 150 | enc_hidden_size, 151 | enc_intermediate_size, 152 | dropout_enc_rate, 153 | dropout_attn_rate, 154 | pos_encode_type="rel", 155 | ): 156 | super(ConformerEncoderLayer, self).__init__() 157 | 158 | self.pos_encode_type = pos_encode_type 159 | 160 | if self.pos_encode_type == "abs": 161 | self.self_attn = MultiHeadedAttention( 162 | enc_num_attention_heads, enc_hidden_size, dropout_attn_rate 163 | ) 164 | elif self.pos_encode_type == "rel": 165 | self.self_attn = RelMultiHeadedAttention( 166 | enc_num_attention_heads, enc_hidden_size, dropout_attn_rate 167 | ) 168 | 169 | self.conv = ConvModule(enc_hidden_size) 170 | self.feed_forward = PositionwiseFeedForward( 171 | enc_hidden_size, 172 | enc_intermediate_size, 173 | dropout_enc_rate, 174 | activation_type="swish", 175 | ) 176 | self.feed_forward_macaron = PositionwiseFeedForward( 177 | enc_hidden_size, 178 | enc_intermediate_size, 179 | dropout_enc_rate, 180 | activation_type="swish", 181 | ) 182 | 183 | self.norm_self_attn = nn.LayerNorm(enc_hidden_size) 184 | self.norm_conv = nn.LayerNorm(enc_hidden_size) 185 | self.norm_ff = nn.LayerNorm(enc_hidden_size) 186 | self.norm_ff_macaron = nn.LayerNorm(enc_hidden_size) 187 | self.norm_final = nn.LayerNorm(enc_hidden_size) 188 | 189 | self.dropout = nn.Dropout(dropout_enc_rate) 190 | 191 | def forward(self, x, mask, pos_emb=None): 192 | # 1. Feed Forward module 193 | residual = x 194 | x = self.norm_ff_macaron(x) 195 | x = residual + 0.5 * self.dropout(self.feed_forward_macaron(x)) 196 | 197 | if self.pos_encode_type == "rel": 198 | # 2. Multi-Head Self Attention module 199 | residual = x 200 | x = self.norm_self_attn(x) 201 | x_q = x 202 | x = residual + self.dropout(self.self_attn(x_q, x, x, pos_emb, mask)) 203 | 204 | # 3. 
Convolution module 205 | residual = x 206 | x = self.norm_conv(x) 207 | x = residual + self.dropout(self.conv(x)) 208 | 209 | elif self.pos_encode_type == "abs": 210 | # 2. Convolution module 211 | residual = x 212 | x = self.norm_conv(x) 213 | x = residual + self.dropout(self.conv(x)) 214 | 215 | # 3. Multi-Head Self Attention module 216 | residual = x 217 | x = self.norm_self_attn(x) 218 | x_q = x 219 | x = residual + self.dropout(self.self_attn(x_q, x, x, mask)) 220 | 221 | # 4. Feed Forward module 222 | residual = x 223 | x = self.norm_ff(x) 224 | x = residual + 0.5 * self.dropout(self.feed_forward(x)) 225 | 226 | # 5. Layernorm 227 | x = self.norm_final(x) 228 | 229 | return x, mask 230 | -------------------------------------------------------------------------------- /asr/modeling/decoders/ctc_aligner.py: -------------------------------------------------------------------------------- 1 | """ Forced alignment with CTC Forward-Backward algorithm 2 | 3 | Reference 4 | https://github.com/hirofumi0810/neural_sp/blob/master/neural_sp/models/seq2seq/decoders/ctc.py 5 | """ 6 | 7 | import os 8 | import sys 9 | 10 | import torch 11 | 12 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../") 13 | sys.path.append(EMOASR_ROOT) 14 | 15 | LOG_0 = -1e10 16 | LOG_1 = 0 17 | 18 | 19 | def _label_to_path(labels, blank): 20 | path = labels.new_zeros(labels.size(0), labels.size(1) * 2 + 1).fill_(blank).long() 21 | path[:, 1::2] = labels 22 | return path 23 | 24 | 25 | def _flip_path(path, path_lens): 26 | """Flips label sequence. 27 | This function rotates a label sequence and flips it. 28 | ``path[b, t]`` stores a label at time ``t`` in ``b``-th batch. 29 | The rotated matrix ``r`` is defined as 30 | ``r[b, t] = path[b, t + path_lens[b]]`` 31 | .. :: 32 | a b c d . . a b c d d c b a . 33 | e f . . . -> . . . e f -> f e . . . 34 | g h i j k g h i j k k j i h g 35 | """ 36 | bs = path.size(0) 37 | max_path_len = path.size(1) 38 | rotate = (torch.arange(max_path_len) + path_lens[:, None]) % max_path_len 39 | return torch.flip( 40 | path[torch.arange(bs, dtype=torch.int64)[:, None], rotate], dims=[1], 41 | ) 42 | 43 | 44 | def _flip_label_probability(log_probs, xlens): 45 | """Flips a label probability matrix. 46 | This function rotates a label probability matrix and flips it. 47 | ``log_probs[i, b, l]`` stores log probability of label ``l`` at ``i``-th 48 | input in ``b``-th batch. 49 | The rotated matrix ``r`` is defined as 50 | ``r[i, b, l] = log_probs[i + xlens[b], b, l]`` 51 | """ 52 | xmax, bs, vocab = log_probs.size() 53 | rotate = (torch.arange(xmax, dtype=torch.int64)[:, None] + xlens) % xmax 54 | return torch.flip( 55 | log_probs[ 56 | rotate[:, :, None], 57 | torch.arange(bs, dtype=torch.int64)[None, :, None], 58 | torch.arange(vocab, dtype=torch.int64)[None, None, :], 59 | ], 60 | dims=[0], 61 | ) 62 | 63 | 64 | def _flip_path_probability(cum_log_prob, xlens, path_lens): 65 | """Flips a path probability matrix. 66 | This function returns a path probability matrix and flips it. 67 | ``cum_log_prob[i, b, t]`` stores log probability at ``i``-th input and 68 | at time ``t`` in a output sequence in ``b``-th batch. 
69 | The rotated matrix ``r`` is defined as 70 | ``r[i, j, k] = cum_log_prob[i + xlens[j], j, k + path_lens[j]]`` 71 | """ 72 | xmax, bs, max_path_len = cum_log_prob.size() 73 | rotate_input = (torch.arange(xmax, dtype=torch.int64)[:, None] + xlens) % xmax 74 | rotate_label = ( 75 | torch.arange(max_path_len, dtype=torch.int64) + path_lens[:, None] 76 | ) % max_path_len 77 | return torch.flip( 78 | cum_log_prob[ 79 | rotate_input[:, :, None], 80 | torch.arange(bs, dtype=torch.int64)[None, :, None], 81 | rotate_label, 82 | ], 83 | dims=[0, 2], 84 | ) 85 | 86 | 87 | def _make_pad_mask(seq_lens): 88 | bs = seq_lens.size(0) 89 | max_time = seq_lens.max() 90 | seq_range = torch.arange(0, max_time, dtype=torch.int32, device=seq_lens.device) 91 | seq_range = seq_range.unsqueeze(0).expand(bs, max_time) 92 | mask = seq_range < seq_lens.unsqueeze(-1) 93 | return mask 94 | 95 | 96 | class CTCForcedAligner(object): 97 | """ 98 | 99 | Reference: 100 | https://github.com/hirofumi0810/neural_sp/blob/master/neural_sp/models/seq2seq/decoders/ctc.py 101 | """ 102 | 103 | def __init__(self, blank_id=0): 104 | self.blank_id = blank_id 105 | 106 | def _computes_transition( 107 | self, prev_log_prob, path, path_lens, cum_log_prob, y, skip_accum=False 108 | ): 109 | """ 110 | prev_log_prob [B, T, vocab]: alpha or beta or gamma 111 | path [B, T, ] 112 | path_lens [B] 113 | cum_log_prob 114 | y 115 | """ 116 | bs, max_path_len = path.size() 117 | mat = prev_log_prob.new_zeros(3, bs, max_path_len).fill_(LOG_0) 118 | mat[0, :, :] = prev_log_prob 119 | mat[1, :, 1:] = prev_log_prob[:, :-1] 120 | mat[2, :, 2:] = prev_log_prob[:, :-2] 121 | # disable transition between the same symbols 122 | # (including blank-to-blank) 123 | same_transition = path[:, :-2] == path[:, 2:] 124 | mat[2, :, 2:][same_transition] = LOG_0 125 | log_prob = torch.logsumexp(mat, dim=0) 126 | outside = torch.arange(max_path_len, dtype=torch.int64,) >= path_lens.unsqueeze( 127 | 1 128 | ) 129 | log_prob[outside] = LOG_0 130 | if not skip_accum: 131 | cum_log_prob += log_prob 132 | batch_index = torch.arange(bs, dtype=torch.int64).unsqueeze(1) 133 | 134 | # print(y[batch_index, path]) 135 | log_prob += y[batch_index, path] 136 | return log_prob 137 | 138 | def __call__(self, log_probs, elens, ys, ylens): 139 | """Calculte the best CTC alignment with the forward-backward algorithm. 
140 | """ 141 | bs, xmax, vocab = log_probs.size() 142 | device = log_probs.device 143 | 144 | # zero padding 145 | mask = _make_pad_mask(elens.to(device)) 146 | mask = mask.unsqueeze(2).repeat([1, 1, vocab]) 147 | log_probs = log_probs.masked_fill_(mask == 0, 0) 148 | log_probs = log_probs.transpose(0, 1) # `[T, B, vocab]` 149 | 150 | path = _label_to_path(ys, self.blank_id) 151 | path_lens = 2 * ylens.long().cpu() + 1 152 | 153 | ymax = ys.size(1) 154 | max_path_len = path.size(1) 155 | assert ys.size() == (bs, ymax), ys.size() 156 | assert path.size() == (bs, ymax * 2 + 1) 157 | 158 | alpha = log_probs.new_zeros(bs, max_path_len).fill_(LOG_0) 159 | alpha[:, 0] = LOG_1 160 | beta = alpha.clone() 161 | gamma = alpha.clone() 162 | 163 | batch_index = torch.arange(bs, dtype=torch.int64).unsqueeze(1) 164 | seq_index = torch.arange(xmax, dtype=torch.int64).unsqueeze(1).unsqueeze(2) 165 | log_probs_fwd_bwd = log_probs[seq_index, batch_index, path] 166 | 167 | # forward algorithm 168 | for t in range(xmax): 169 | alpha = self._computes_transition( 170 | alpha, path, path_lens, log_probs_fwd_bwd[t], log_probs[t], 171 | ) 172 | 173 | # backward algorithm 174 | r_path = _flip_path(path, path_lens) 175 | log_probs_inv = _flip_label_probability( 176 | log_probs, elens.long().cpu() 177 | ) # (T, B, vocab) 178 | log_probs_fwd_bwd = _flip_path_probability( 179 | log_probs_fwd_bwd, elens.long().cpu(), path_lens 180 | ) # (T, B, 2*L+1) 181 | for t in range(xmax): 182 | beta = self._computes_transition( 183 | beta, r_path, path_lens, log_probs_fwd_bwd[t], log_probs_inv[t], 184 | ) 185 | 186 | # pick up the best CTC path 187 | best_aligns = log_probs.new_zeros((bs, xmax), dtype=torch.int64) 188 | 189 | # forward algorithm 190 | log_probs_fwd_bwd = _flip_path_probability( 191 | log_probs_fwd_bwd, elens.long().cpu(), path_lens 192 | ) 193 | 194 | for t in range(xmax): 195 | gamma = self._computes_transition( 196 | gamma, 197 | path, 198 | path_lens, 199 | log_probs_fwd_bwd[t], 200 | log_probs[t], 201 | skip_accum=True, 202 | ) 203 | 204 | # select paths where gamma is valid 205 | log_probs_fwd_bwd[t] = log_probs_fwd_bwd[t].masked_fill_( 206 | gamma == LOG_0, LOG_0 207 | ) 208 | 209 | # pick up the best alignment 210 | offsets = log_probs_fwd_bwd[t].argmax(1) 211 | for b in range(bs): 212 | if t <= elens[b] - 1: 213 | token_idx = path[b, offsets[b]] 214 | best_aligns[b, t] = token_idx 215 | 216 | # remove the rest of paths (select the best path) 217 | gamma = log_probs.new_zeros(bs, max_path_len).fill_(LOG_0) 218 | for b in range(bs): 219 | gamma[b, offsets[b]] = LOG_1 220 | 221 | return best_aligns 222 | 223 | 224 | if __name__ == "__main__": 225 | torch.manual_seed(1) 226 | bs, T, vocab = 2, 8, 3 227 | logits = torch.rand((bs, T, vocab)) * 10.0 228 | probs = torch.nn.functional.softmax(logits, dim=-1) 229 | print("probs:", probs) 230 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1) 231 | elens = torch.tensor([7, 8]) 232 | ys = torch.tensor([[1, 2, 0], [1, 2, 1]]) 233 | ylens = torch.tensor([2, 3]) 234 | aligner = CTCForcedAligner() 235 | aligns = aligner(log_probs, elens, ys, ylens) 236 | print(aligns) 237 | -------------------------------------------------------------------------------- /asr/modeling/decoders/ctc_score.py: -------------------------------------------------------------------------------- 1 | """ CTC score for joint CTC decoding 2 | 3 | Reference: 4 | https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py 5 | """ 6 | 7 | import numpy as np 8 | import six 9 | 
10 | LOG_0 = -1e10 11 | 12 | 13 | class CTCPrefixScorer: 14 | def __init__(self, x, blank_id, eos_id): 15 | self.blank_id = blank_id 16 | self.eos_id = eos_id 17 | self.input_length = len(x) 18 | self.x = x 19 | 20 | def initial_state(self): 21 | """Obtain an initial CTC state 22 | 23 | :return: CTC state 24 | """ 25 | # initial CTC state is made of a frame x 2 tensor that corresponds to 26 | # r_t^n() and r_t^b(), where 0 and 1 of axis=1 represent 27 | # superscripts n and b (non-blank and blank), respectively. 28 | r = np.full((self.input_length, 2), LOG_0, dtype=np.float32) 29 | r[0, 1] = self.x[0, self.blank_id] 30 | for i in six.moves.range(1, self.input_length): 31 | r[i, 1] = r[i - 1, 1] + self.x[i, self.blank_id] 32 | return r 33 | 34 | def __call__(self, y, cs, r_prev): 35 | """Compute CTC prefix scores for next labels 36 | 37 | :param y : prefix label sequence 38 | :param cs : array of next labels 39 | :param r_prev: previous CTC state 40 | :return ctc_scores, ctc_states 41 | """ 42 | # initialize CTC states 43 | output_length = len(y) - 1 # ignore sos 44 | # new CTC states are prepared as a frame x (n or b) x n_labels tensor 45 | # that corresponds to r_t^n(h) and r_t^b(h). 46 | r = np.ndarray((self.input_length, 2, len(cs)), dtype=np.float32) 47 | xs = self.x[:, cs] 48 | if output_length == 0: 49 | r[0, 0] = xs[0] 50 | r[0, 1] = LOG_0 51 | else: 52 | r[output_length - 1] = LOG_0 53 | 54 | # prepare forward probabilities for the last label 55 | r_sum = np.logaddexp(r_prev[:, 0], r_prev[:, 1]) # log(r_t^n(g) + r_t^b(g)) 56 | last = y[-1] 57 | if output_length > 0 and last in cs: 58 | log_phi = np.ndarray((self.input_length, len(cs)), dtype=np.float32) 59 | for i in six.moves.range(len(cs)): 60 | log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1] 61 | else: 62 | log_phi = r_sum 63 | 64 | # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)), 65 | # and log prefix probabilites log(psi) 66 | start = max(output_length, 1) 67 | log_psi = r[start - 1, 0] 68 | for t in six.moves.range(start, self.input_length): 69 | r[t, 0] = np.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t] 70 | r[t, 1] = np.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank_id] 71 | log_psi = np.logaddexp(log_psi, log_phi[t - 1] + xs[t]) 72 | 73 | # get P(...eos|X) that ends with the prefix itself 74 | eos_pos = np.where(cs == self.eos_id)[0] 75 | if len(eos_pos) > 0: 76 | log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g)) 77 | 78 | # exclude blank probs 79 | blank_pos = np.where(cs == self.blank_id)[0] 80 | if len(blank_pos) > 0: 81 | log_psi[blank_pos] = LOG_0 82 | 83 | # return the log prefix probability and CTC states, where the label axis 84 | # of the CTC states is moved to the first axis to slice it easily 85 | return log_psi, np.rollaxis(r, 2) 86 | -------------------------------------------------------------------------------- /asr/modeling/decoders/rnnt_aligner.py: -------------------------------------------------------------------------------- 1 | """ Forced alignment with RNN-T Forward-Backward algorithm 2 | 3 | Reference 4 | https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/loss/transducer_loss.py 5 | """ 6 | 7 | import math 8 | import time 9 | 10 | import torch 11 | from numba import cuda 12 | 13 | 14 | @cuda.jit( 15 | "(float32[:,:,:,:], int32[:,:], float32[:,:,:], float32[:], int32[:], int32[:], int32, int32[:,:])" 16 | ) 17 | def cu_kernel_forward(log_probs, labels, alpha, log_p, T, U, blank, lock): 18 | """ 19 | Compute forward pass for the 
forward-backward algorithm using Numba cuda kernel. 20 | Sequence Transduction with naive implementation : https://arxiv.org/pdf/1211.3711.pdf 21 | Arguments 22 | --------- 23 | log_probs : tensor 24 | 4D Tensor of (batch x TimeLength x LabelLength x outputDim) from the Transducer network. 25 | labels : tensor 26 | 2D Tensor of (batch x MaxSeqLabelLength) containing targets of the batch with zero padding. 27 | alpha : tensor 28 | 3D Tensor of (batch x TimeLength x LabelLength) for forward computation. 29 | log_p : tensor 30 | 1D Tensor of (batch) for forward cost computation. 31 | T : tensor 32 | 1D Tensor of (batch) containing TimeLength of each target. 33 | U : tensor 34 | 1D Tensor of (batch) containing LabelLength of each target. 35 | blank : int 36 | Blank indice. 37 | lock : tensor 38 | 2D Tensor of (batch x LabelLength) containing bool(1-0) lock for parallel computation. 39 | """ 40 | # parallelize the forward algorithm over batch and target length dim 41 | b = cuda.blockIdx.x 42 | u = cuda.threadIdx.x 43 | t = 0 44 | if u <= U[b]: 45 | # for each (B,U) Thread 46 | # wait the unlock of the previous computation of Alpha[b,U-1,:] 47 | # Do the computation over the whole Time sequence on alpha[B,U,:] 48 | # and then unlock the target U+1 for computation 49 | while t < T[b]: 50 | if u == 0: # init 51 | if t > 0: 52 | alpha[b, t, 0] = alpha[b, t - 1, 0] + log_probs[b, t - 1, 0, blank] 53 | cuda.atomic.add(lock, (b, u + 1), -1) # 0 -> -1 54 | t += 1 55 | else: 56 | if cuda.atomic.add(lock, (b, u), 0) < 0: 57 | if t == 0: 58 | alpha[b, 0, u] = ( 59 | alpha[b, 0, u - 1] 60 | + log_probs[b, 0, u - 1, labels[b, u - 1]] 61 | ) 62 | else: 63 | # compute emission prob 64 | emit = ( 65 | alpha[b, t, u - 1] 66 | + log_probs[b, t, u - 1, labels[b, u - 1]] 67 | ) 68 | # compute no_emission prob 69 | no_emit = alpha[b, t - 1, u] + log_probs[b, t - 1, u, blank] 70 | # do logsumexp between log_emit and log_no_emit 71 | alpha[b, t, u] = max(no_emit, emit) + math.log1p( 72 | math.exp(-abs(no_emit - emit)) 73 | ) 74 | if u < U[b]: 75 | cuda.atomic.add(lock, (b, u + 1), -1) 76 | cuda.atomic.add(lock, (b, u), 1) # -1 -> 0 77 | t += 1 78 | if u == U[b]: 79 | # for each thread b (utterance) 80 | # normalize the loss over time 81 | log_p[b] = ( 82 | alpha[b, T[b] - 1, U[b]] + log_probs[b, T[b] - 1, U[b], blank] 83 | ) / T[b] 84 | 85 | 86 | @cuda.jit( 87 | "(float32[:,:,:,:], int32[:,:], float32[:,:,:], float32[:], int32[:], int32[:], int32, int32[:,:])" 88 | ) 89 | def cu_kernel_backward(log_probs, labels, beta, log_p, T, U, blank, lock): 90 | """ 91 | Compute backward pass for the forward-backward algorithm using Numba cuda kernel. 92 | Sequence Transduction with naive implementation : https://arxiv.org/pdf/1211.3711.pdf 93 | Arguments 94 | --------- 95 | log_probs : tensor 96 | 4D Tensor of (batch x TimeLength x LabelLength x outputDim) from the Transducer network. 97 | labels : tensor 98 | 2D Tensor of (batch x MaxSeqLabelLength) containing targets of the batch with zero padding. 99 | beta : tensor 100 | 3D Tensor of (batch x TimeLength x LabelLength) for backward computation. 101 | log_p : tensor 102 | 1D Tensor of (batch) for backward cost computation. 103 | T : tensor 104 | 1D Tensor of (batch) containing TimeLength of each target. 105 | U : tensor 106 | 1D Tensor of (batch) containing LabelLength of each target. 107 | blank : int 108 | Blank indice. 109 | lock : tensor 110 | 2D Tensor of (batch x LabelLength) containing bool(1-0) lock for parallel computation. 
111 | """ 112 | # parallelize the forward algorithm over batch and target length dim 113 | b = cuda.blockIdx.x 114 | u = cuda.threadIdx.x 115 | t = T[b] - 1 116 | if u <= U[b]: 117 | # for each (B,U) Thread 118 | # wait the unlock of the next computation of beta[b,U+1,:] 119 | # Do the computation over the whole Time sequence on beta[B,U,:] 120 | # and then unlock the target U-1 for computation 121 | while t >= 0: 122 | if u == U[b]: # init 123 | if t == T[b] - 1: 124 | beta[b, t, u] = log_probs[b, t, u, blank] 125 | else: 126 | beta[b, t, u] = beta[b, t + 1, u] + log_probs[b, t, u, blank] 127 | cuda.atomic.add(lock, (b, u - 1), -1) 128 | t -= 1 129 | else: 130 | if cuda.atomic.add(lock, (b, u), 0) < 0: 131 | if t == T[b] - 1: 132 | # do logsumexp between log_emit and log_no_emit 133 | beta[b, t, u] = ( 134 | beta[b, t, u + 1] + log_probs[b, t, u, labels[b, u]] 135 | ) 136 | else: 137 | # compute emission prob 138 | emit = beta[b, t, u + 1] + log_probs[b, t, u, labels[b, u]] 139 | # compute no_emission prob 140 | no_emit = beta[b, t + 1, u] + log_probs[b, t, u, blank] 141 | # do logsumexp between log_emit and log_no_emit 142 | beta[b, t, u] = max(no_emit, emit) + math.log1p( 143 | math.exp(-abs(no_emit - emit)) 144 | ) 145 | if u > 0: 146 | cuda.atomic.add(lock, (b, u - 1), -1) 147 | cuda.atomic.add(lock, (b, u), 1) 148 | t -= 1 149 | if u == 0: 150 | # for each thread b (utterance) 151 | # normalize the loss over time 152 | log_p[b] = beta[b, 0, 0] / T[b] 153 | 154 | 155 | class RNNTForcedAligner(object): 156 | def __init__(self, blank_id=0): 157 | self.blank_id = blank_id 158 | 159 | def __call__(self, log_probs, elens, ys, ylens): 160 | acts = log_probs.detach() 161 | labels = ys.int().detach() 162 | T = elens.int().detach() 163 | U = ylens.int().detach() 164 | 165 | B, maxT, maxU, _ = acts.shape 166 | 167 | alpha = torch.zeros((B, maxT, maxU), device=acts.device) 168 | beta = torch.zeros((B, maxT, maxU), device=acts.device) 169 | lock = torch.zeros((B, maxU), dtype=torch.int32, device=acts.device) 170 | log_p_alpha = torch.zeros((B,), device=acts.device) 171 | log_p_beta = torch.zeros((B,), device=acts.device) 172 | 173 | # forward 174 | cu_kernel_forward[B, maxU]( 175 | acts, labels, alpha, log_p_alpha, T, U, self.blank_id, lock, 176 | ) 177 | lock = lock * 0 178 | # backward 179 | cu_kernel_backward[B, maxU]( 180 | acts, labels, beta, log_p_beta, T, U, self.blank_id, lock 181 | ) 182 | log_probs_fwd_bwd = alpha + beta 183 | 184 | best_aligns = torch.zeros( 185 | (B, maxU - 1), dtype=torch.int32, device=log_probs.device 186 | ) 187 | 188 | # alignment 189 | for b in range(B): 190 | t, u = 0, 0 191 | while t + 1 < T[b] and u < U[b]: 192 | if log_probs_fwd_bwd[b, t + 1, u] > log_probs_fwd_bwd[b, t, u + 1]: 193 | t += 1 194 | else: 195 | best_aligns[b, u] = t # emit y_u 196 | u += 1 197 | 198 | return best_aligns 199 | 200 | 201 | if __name__ == "__main__": 202 | torch.manual_seed(1) 203 | log_probs = torch.randn((2, 10, 6, 5)).cuda().log_softmax(dim=-1).requires_grad_() 204 | labels = torch.Tensor([[1, 2, 1, 2, 0], [1, 2, 1, 2, 3]]).cuda().int() 205 | T = torch.Tensor([8, 10]).cuda().int() 206 | U = label_length = torch.Tensor([4, 5]).cuda().int() 207 | blank = 0 208 | 209 | log_probs = log_probs.detach() 210 | B, maxT, maxU, A = log_probs.shape 211 | alpha = torch.zeros((B, maxT, maxU), device=log_probs.device) 212 | beta = torch.zeros((B, maxT, maxU), device=log_probs.device) 213 | lock = torch.zeros((B, maxU), dtype=torch.int32, device=log_probs.device) 214 | log_p_alpha = 
torch.zeros((B,), device=log_probs.device) 215 | log_p_beta = torch.zeros((B,), device=log_probs.device) 216 | 217 | cu_kernel_forward[B, maxU]( 218 | log_probs, labels, alpha, log_p_alpha, T, U, blank, lock, 219 | ) 220 | lock = lock * 0 221 | cu_kernel_backward[B, maxU](log_probs, labels, beta, log_p_beta, T, U, blank, lock) 222 | 223 | log_probs_fwd_bwd = alpha + beta 224 | -------------------------------------------------------------------------------- /asr/modeling/encoders/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Conv2dEncoder(nn.Module): 6 | def __init__(self, input_dim, output_dim): 7 | super(Conv2dEncoder, self).__init__() 8 | self.conv = nn.Sequential( 9 | nn.Conv2d(in_channels=1, out_channels=output_dim, kernel_size=3, stride=2), 10 | nn.ReLU(), 11 | nn.Conv2d( 12 | in_channels=output_dim, out_channels=output_dim, kernel_size=3, stride=2 13 | ), 14 | nn.ReLU(), 15 | ) 16 | self.output = nn.Linear( 17 | output_dim * (((input_dim - 1) // 2 - 1) // 2), output_dim 18 | ) 19 | 20 | def forward(self, xs, xlens): 21 | xs = xs.unsqueeze(1) # (B, 1, L, input_dim) 22 | xs = self.conv(xs) 23 | bs, output_dim, length, dim = xs.size() 24 | xs = self.output( 25 | xs.transpose(1, 2).contiguous().view(bs, length, output_dim * dim) 26 | ) 27 | xlens = ((xlens - 1) // 2 - 1) // 2 28 | return xs, xlens 29 | 30 | 31 | # DEBUG 32 | if __name__ == "__main__": 33 | xs = torch.rand((5, 110, 80)) 34 | xlens = torch.randint(1, 110, (5,)) 35 | 36 | print("xs:", xs.shape) 37 | print("xlens:", xlens) 38 | 39 | input_dim = 80 40 | output_dim = 512 41 | conv = Conv2dEncoder(input_dim, output_dim) 42 | 43 | eouts, elens = conv(xs, xlens) 44 | print("eouts:", eouts.shape) 45 | print(eouts[0, :5, :5]) 46 | print("elens:", elens) 47 | 48 | import os 49 | import sys 50 | 51 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../") 52 | sys.path.append(EMOASR_ROOT) 53 | 54 | from asr.models.model_utils import make_src_mask 55 | from asr.models.transformer import Conv2dSubsampler 56 | 57 | conv2 = Conv2dSubsampler(input_dim, output_dim, 0) 58 | xs_mask = make_src_mask(xlens) 59 | eouts, eouts_mask = conv2(xs, xs_mask) 60 | print("eouts:", eouts.shape) 61 | print(eouts[0, :5, :5]) 62 | print("eouts_mask:", eouts_mask) 63 | elens = torch.sum(eouts_mask, dim=1) 64 | print("elens:", elens) 65 | -------------------------------------------------------------------------------- /asr/modeling/encoders/rnn.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | import torch.nn as nn 6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 7 | 8 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../") 9 | sys.path.append(EMOASR_ROOT) 10 | 11 | from asr.modeling.encoders.conv import Conv2dEncoder 12 | 13 | 14 | class RNNEncoder(nn.Module): 15 | def __init__(self, params): 16 | super(RNNEncoder, self).__init__() 17 | 18 | self.input_layer = params.input_layer 19 | self.enc_num_layers = params.enc_num_layers 20 | 21 | input_size = params.feat_dim * params.num_framestacks 22 | 23 | if self.input_layer == "conv2d": 24 | self.conv = Conv2dEncoder( 25 | input_dim=input_size, output_dim=params.enc_hidden_size 26 | ) 27 | input_size = params.enc_hidden_size 28 | 29 | self.enc_hidden_sum_fwd_bwd = params.enc_hidden_sum_fwd_bwd 30 | 31 | if self.enc_hidden_sum_fwd_bwd: 32 | 
enc_hidden_size = params.enc_hidden_size 33 | else: 34 | assert params.enc_hidden_size % 2 == 0 35 | enc_hidden_size = params.enc_hidden_size // 2 36 | logging.warning( 37 | f"enc_hidden_sum_fwd_bwd is False, so LSTM with hidden_size = {enc_hidden_size}" 38 | ) 39 | 40 | self.rnns = nn.ModuleList() 41 | for _ in range(self.enc_num_layers): 42 | self.rnns += [ 43 | nn.LSTM( 44 | input_size=input_size, 45 | hidden_size=params.enc_hidden_size, 46 | num_layers=1, 47 | batch_first=True, 48 | bidirectional=True, 49 | ) 50 | ] 51 | input_size = params.enc_hidden_size 52 | 53 | self.dropout = nn.Dropout(p=params.dropout_enc_rate) 54 | 55 | def forward(self, xs, xlens): 56 | if self.input_layer == "conv2d": 57 | xs, elens = self.conv(xs, xlens) # lengths are converted 58 | elif self.input_layer == "none": 59 | elens = xlens 60 | 61 | for layer_id in range(self.enc_num_layers): 62 | self.rnns[layer_id].flatten_parameters() 63 | 64 | xs = pack_padded_sequence( 65 | xs, elens.cpu(), batch_first=True, enforce_sorted=False 66 | ) 67 | eouts_pack, _ = self.rnns[layer_id](xs) 68 | xs, _ = pad_packed_sequence(eouts_pack, batch_first=True) 69 | 70 | if self.enc_hidden_sum_fwd_bwd: 71 | # NOTE: sum up forward and backward RNN outputs 72 | # (B, T, enc_hidden_size*2) -> (B, T, enc_hidden_size) 73 | half = xs.size(-1) // 2 74 | xs = xs[:, :, :half] + xs[:, :, half:] 75 | 76 | xs = self.dropout(xs) 77 | 78 | eouts = xs 79 | eouts_inter = None 80 | 81 | return eouts, elens, eouts_inter 82 | -------------------------------------------------------------------------------- /asr/modeling/encoders/transformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | 6 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../") 7 | sys.path.append(EMOASR_ROOT) 8 | 9 | import torch.nn as nn 10 | from asr.modeling.conformer import ConformerEncoderLayer, RelPositionalEncoder 11 | from asr.modeling.encoders.conv import Conv2dEncoder 12 | from asr.modeling.model_utils import make_src_mask 13 | from asr.modeling.transformer import PositionalEncoder, TransformerEncoderLayer 14 | 15 | 16 | class TransformerEncoder(nn.Module): 17 | """ Transformer encoder including Conformer 18 | """ 19 | 20 | def __init__(self, params, is_conformer=False): 21 | super(TransformerEncoder, self).__init__() 22 | 23 | self.input_layer = params.input_layer 24 | self.enc_num_layers = params.enc_num_layers 25 | self.pos_encode_type = ( 26 | params.pos_encode_type if hasattr(params, "pos_encode_type") else "abs" 27 | ) 28 | self.is_conformer = is_conformer 29 | 30 | if self.input_layer == "conv2d": 31 | input_size = params.feat_dim * params.num_framestacks 32 | self.conv = Conv2dEncoder( 33 | input_dim=input_size, output_dim=params.enc_hidden_size 34 | ) 35 | elif self.input_layer == "embed": 36 | self.embed = nn.Embedding(params.src_vocab_size, params.enc_hidden_size) 37 | elif self.input_layer == "linear": 38 | input_size = params.feat_dim * params.num_framestacks 39 | self.linear = nn.Linear(input_size, params.enc_hidden_size) 40 | input_size = params.enc_hidden_size 41 | 42 | if self.pos_encode_type == "abs": 43 | self.pe = PositionalEncoder( 44 | input_size, dropout_rate=params.dropout_enc_rate 45 | ) 46 | elif self.pos_encode_type == "rel": 47 | assert self.is_conformer 48 | self.pe = RelPositionalEncoder( 49 | input_size, dropout_rate=params.dropout_enc_rate 50 | ) 51 | 52 | if self.is_conformer: 53 | EncoderLayer = ConformerEncoderLayer 
54 | else: 55 | EncoderLayer = TransformerEncoderLayer 56 | 57 | # TODO: rename to `encoders` 58 | self.transformers = nn.ModuleList() 59 | for _ in range(self.enc_num_layers): 60 | self.transformers += [ 61 | EncoderLayer( 62 | enc_num_attention_heads=params.enc_num_attention_heads, 63 | enc_hidden_size=params.enc_hidden_size, 64 | enc_intermediate_size=params.enc_intermediate_size, 65 | dropout_enc_rate=params.dropout_enc_rate, 66 | dropout_attn_rate=params.dropout_attn_rate, 67 | pos_encode_type=self.pos_encode_type, 68 | ) 69 | ] 70 | 71 | # normalize before 72 | # TODO: set `eps` to 1e-5 (default) 73 | self.norm = nn.LayerNorm(params.enc_hidden_size, eps=1e-12) 74 | 75 | if ( 76 | hasattr(params, "mtl_inter_ctc_weight") and params.mtl_inter_ctc_weight > 0 77 | ) or ( 78 | hasattr(params, "mtl_phone_ctc_weight") and params.mtl_phone_ctc_weight > 0 79 | ): 80 | self.inter_ctc_layer_id = params.inter_ctc_layer_id 81 | else: 82 | self.inter_ctc_layer_id = 0 83 | 84 | def forward(self, xs, xlens): 85 | if self.input_layer == "conv2d": 86 | xs, elens = self.conv(xs, xlens) 87 | elif self.input_layer == "embed": 88 | xs = self.embed(xs) 89 | elens = xlens 90 | elif self.input_layer == "linear": 91 | xs = self.linear(xs) 92 | elens = xlens 93 | 94 | mask = make_src_mask(elens) 95 | eouts_inter = None 96 | 97 | if self.pos_encode_type == "abs": 98 | xs = self.pe(xs) 99 | pos_emb = None 100 | elif self.pos_encode_type == "rel": 101 | xs, pos_emb = self.pe(xs) 102 | 103 | for layer_id in range(self.enc_num_layers): 104 | xs, mask = self.transformers[layer_id](xs, mask, pos_emb) 105 | # NOTE: intermediate branches also require normalization. 106 | if (layer_id + 1) == self.inter_ctc_layer_id: 107 | eouts_inter = self.norm(xs) 108 | 109 | # normalize before 110 | xs = self.norm(xs) 111 | eouts = xs 112 | 113 | return eouts, elens, eouts_inter 114 | -------------------------------------------------------------------------------- /asr/modeling/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pad_sequence 4 | 5 | 6 | def make_nopad_mask(lengths): 7 | """ 8 | NOTE: faster implementation of the following 9 | mask = [[bool(l < length) for l in range(max(lengths))] for length in lengths] 10 | """ 11 | if torch.is_tensor(lengths): 12 | lens = lengths.tolist() 13 | else: 14 | lens = lengths 15 | 16 | bs = int(len(lengths)) 17 | maxlen = int(max(lengths)) 18 | 19 | seq_range = torch.arange(0, maxlen, dtype=torch.int64) 20 | seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) 21 | seq_length_expand = seq_range_expand.new(lens).unsqueeze(-1) 22 | mask = seq_range_expand < seq_length_expand 23 | 24 | if torch.is_tensor(lengths): 25 | mask = mask.to(lengths.device) 26 | 27 | return mask 28 | 29 | 30 | def make_causal_mask(length): 31 | ret = torch.ones(length, length, dtype=bool) 32 | return torch.tril(ret, out=ret) 33 | 34 | 35 | def make_src_mask(xlens: torch.Tensor): 36 | return make_nopad_mask(xlens.tolist()).unsqueeze(-2).to(xlens.device) 37 | 38 | 39 | def make_tgt_mask(ylens: torch.Tensor): 40 | nopad_mask = make_nopad_mask(ylens.tolist()).unsqueeze(-2) 41 | maxlen = nopad_mask.size(-1) 42 | causal_mask = make_causal_mask(maxlen).unsqueeze(0) 43 | return (nopad_mask & causal_mask).to(ylens.device) 44 | 45 | 46 | class Swish(nn.Module): 47 | def forward(self, x): 48 | return x * torch.sigmoid(x) 49 | 50 | 51 | if __name__ == "__main__": 52 | xlens = torch.tensor([1, 2, 3]) 
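# sanity check: for lengths [1, 2, 3], make_nopad_mask should give a (3, 3)
# boolean matrix whose b-th row has xlens[b] leading True values:
# [[T, F, F], [T, T, F], [T, T, T]]
nopad_mask = make_nopad_mask(xlens)
print(nopad_mask)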
53 | src_mask = make_src_mask(xlens) 54 | tgt_mask = make_tgt_mask(xlens) 55 | print(src_mask) 56 | print(tgt_mask) 57 | -------------------------------------------------------------------------------- /asr/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../") 10 | sys.path.append(EMOASR_ROOT) 11 | 12 | from asr.modeling.model_utils import Swish 13 | 14 | 15 | class PositionalEncoder(nn.Module): 16 | def __init__(self, hidden_size, dropout_rate=0.1, max_len=5000): 17 | super(PositionalEncoder, self).__init__() 18 | self.hidden_size = hidden_size 19 | self.xscale = math.sqrt(self.hidden_size) 20 | self.dropout = nn.Dropout(dropout_rate) 21 | self.pe = None 22 | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) 23 | 24 | def extend_pe(self, xs): 25 | if self.pe is not None: 26 | if self.pe.size(1) >= xs.size(1): 27 | if self.pe.dtype != xs.dtype or self.pe.device != xs.device: 28 | self.pe = self.pe.to(dtype=xs.dtype, device=xs.device) 29 | return 30 | pe = torch.zeros(xs.size(1), self.hidden_size) 31 | position = torch.arange(0, xs.size(1), dtype=torch.float32).unsqueeze(1) 32 | div_term = torch.exp( 33 | torch.arange(0, self.hidden_size, 2, dtype=torch.float32) 34 | * -(math.log(10000.0) / self.hidden_size) 35 | ) 36 | pe[:, 0::2] = torch.sin(position * div_term) 37 | pe[:, 1::2] = torch.cos(position * div_term) 38 | pe = pe.unsqueeze(0) 39 | self.pe = pe.to(device=xs.device, dtype=xs.dtype) 40 | 41 | def forward(self, xs): 42 | self.extend_pe(xs) 43 | # ASR 44 | xs = xs * self.xscale + self.pe[:, : xs.size(1)] 45 | return self.dropout(xs) 46 | 47 | 48 | class MultiHeadedAttention(nn.Module): 49 | def __init__(self, num_attention_heads, hidden_size, dropout_rate): 50 | super(MultiHeadedAttention, self).__init__() 51 | assert hidden_size % num_attention_heads == 0 52 | # We assume d_v always equals d_k 53 | self.d_k = hidden_size // num_attention_heads 54 | self.h = num_attention_heads 55 | self.linear_q = nn.Linear(hidden_size, hidden_size) 56 | self.linear_k = nn.Linear(hidden_size, hidden_size) 57 | self.linear_v = nn.Linear(hidden_size, hidden_size) 58 | self.linear_out = nn.Linear(hidden_size, hidden_size) 59 | self.attn = None 60 | self.dropout = nn.Dropout(p=dropout_rate) 61 | 62 | def forward_qkv(self, query, key, value): 63 | n_batch = query.size(0) 64 | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) 65 | k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) 66 | v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) 67 | q = q.transpose(1, 2) # (batch, head, time1, d_k) 68 | k = k.transpose(1, 2) # (batch, head, time2, d_k) 69 | v = v.transpose(1, 2) # (batch, head, time2, d_k) 70 | 71 | return q, k, v 72 | 73 | def forward_attention(self, value, scores, mask): 74 | n_batch = value.size(0) 75 | if mask is not None: 76 | mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) 77 | min_value = float( 78 | np.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min 79 | ) 80 | 81 | scores = scores.masked_fill(mask, min_value) 82 | self.attn = torch.softmax(scores, dim=-1).masked_fill( 83 | mask, 0.0 84 | ) # (batch, head, time1, time2) 85 | else: 86 | self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) 87 | 88 | p_attn = self.dropout(self.attn) 89 | x = torch.matmul(p_attn, value) 
# (batch, head, time1, d_k) 90 | x = ( 91 | x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) 92 | ) # (batch, time1, d_model) 93 | 94 | return self.linear_out(x) # (batch, time1, d_model) 95 | 96 | def forward(self, query, key, value, mask): 97 | q, k, v = self.forward_qkv(query, key, value) 98 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) 99 | return self.forward_attention(v, scores, mask) 100 | 101 | 102 | class PositionwiseFeedForward(nn.Module): 103 | def __init__( 104 | self, input_size, intermediate_size, dropout_rate, activation_type="relu" 105 | ): 106 | super(PositionwiseFeedForward, self).__init__() 107 | self.w1 = nn.Linear(input_size, intermediate_size) 108 | self.w2 = nn.Linear(intermediate_size, input_size) 109 | self.dropout = nn.Dropout(dropout_rate) 110 | 111 | if activation_type == "relu": 112 | # TODO: rename to `act` 113 | self.activation = nn.ReLU() 114 | elif activation_type == "swish": 115 | self.activation = Swish() # for Conformer 116 | 117 | def forward(self, x): 118 | return self.w2(self.dropout(self.activation(self.w1(x)))) 119 | 120 | 121 | class TransformerEncoderLayer(nn.Module): 122 | def __init__( 123 | self, 124 | enc_num_attention_heads, 125 | enc_hidden_size, 126 | enc_intermediate_size, 127 | dropout_enc_rate, 128 | dropout_attn_rate, 129 | pos_encode_type="abs", 130 | ): 131 | super(TransformerEncoderLayer, self).__init__() 132 | self.self_attn = MultiHeadedAttention( 133 | enc_num_attention_heads, enc_hidden_size, dropout_attn_rate 134 | ) 135 | self.feed_forward = PositionwiseFeedForward( 136 | enc_hidden_size, enc_intermediate_size, dropout_enc_rate 137 | ) 138 | # TODO: set `eps` to 1e-5 (default) 139 | # TODO: rename to `norm_self_attn` and `norm_ff` 140 | self.norm1 = nn.LayerNorm(enc_hidden_size, eps=1e-12) 141 | self.norm2 = nn.LayerNorm(enc_hidden_size, eps=1e-12) 142 | self.dropout = nn.Dropout(dropout_enc_rate) 143 | 144 | def forward(self, x, mask, pos_emb=None): 145 | residual = x 146 | x = self.norm1(x) # normalize before 147 | x_q = x 148 | x = residual + self.dropout(self.self_attn(x_q, x, x, mask)) 149 | residual = x 150 | x = self.norm2(x) # normalize before 151 | x = residual + self.dropout(self.feed_forward(x)) 152 | 153 | return x, mask 154 | 155 | 156 | class TransformerDecoderLayer(nn.Module): 157 | def __init__( 158 | self, 159 | dec_num_attention_heads, 160 | dec_hidden_size, 161 | dec_intermediate_size, 162 | dropout_dec_rate, 163 | dropout_attn_rate, 164 | pos_encode_type="abs", 165 | ): 166 | super(TransformerDecoderLayer, self).__init__() 167 | self.dec_hidden_size = dec_hidden_size 168 | self.self_attn = MultiHeadedAttention( 169 | dec_num_attention_heads, dec_hidden_size, dropout_attn_rate 170 | ) 171 | self.src_attn = MultiHeadedAttention( 172 | dec_num_attention_heads, dec_hidden_size, dropout_attn_rate 173 | ) 174 | self.feed_forward = PositionwiseFeedForward( 175 | dec_hidden_size, dec_intermediate_size, dropout_dec_rate 176 | ) 177 | # TODO: set `eps` to 1e-5 (default) 178 | self.norm1 = nn.LayerNorm(dec_hidden_size, eps=1e-12) 179 | self.norm2 = nn.LayerNorm(dec_hidden_size, eps=1e-12) 180 | self.norm3 = nn.LayerNorm(dec_hidden_size, eps=1e-12) 181 | self.dropout = nn.Dropout(dropout_dec_rate) 182 | 183 | def forward(self, x, mask, memory, memory_mask): 184 | residual = x 185 | x = self.norm1(x) # normalize before 186 | x_q = x 187 | x_q_mask = mask 188 | x = residual + self.dropout(self.self_attn(x_q, x, x, x_q_mask)) 189 | 190 | residual = x 191 | x = self.norm2(x) # 
normalize before 192 | x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask)) 193 | 194 | residual = x 195 | x = self.norm3(x) # normalize before 196 | x = residual + self.dropout(self.feed_forward(x)) 197 | 198 | return x, mask, memory, memory_mask 199 | -------------------------------------------------------------------------------- /asr/optimizers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | 6 | class ScheduledOptimizer: 7 | """ wrapper for optimizer 8 | """ 9 | 10 | def __init__(self, optimizer, params, num_total_steps=None): 11 | self.optimizer = optimizer 12 | self.schedule_type = params.lr_schedule_type 13 | self._step = 0 14 | self._epoch = 0 15 | self.base_lr = params.learning_rate 16 | self.num_total_steps = num_total_steps 17 | 18 | # either `num_warmup_steps` or `warmup_proportion` must be specified 19 | assert hasattr(params, "num_warmup_steps") ^ hasattr( 20 | params, "warmup_proportion" 21 | ) 22 | 23 | if hasattr(params, "warmup_proportion"): 24 | self.num_warmup_steps = int(num_total_steps * params.warmup_proportion) 25 | logging.info(f"warmup #steps: {self.num_warmup_steps:d}") 26 | else: 27 | self.num_warmup_steps = params.num_warmup_steps 28 | 29 | self._lr = 0 30 | 31 | logging.info(f"lr scheduling type: {self.schedule_type}") 32 | if self.schedule_type == "epdecay": 33 | self.lr_decay_start_epoch = params.lr_decay_start_epoch 34 | self.lr_decay_rate = params.lr_decay_rate 35 | elif self.schedule_type == "noam": 36 | if hasattr(params, "enc_hidden_size"): 37 | self.model_dim = params.enc_hidden_size 38 | else: 39 | self.model_dim = params.hidden_size 40 | 41 | @property 42 | def param_groups(self): 43 | return self.optimizer.param_groups 44 | 45 | def step(self): 46 | self._step += 1 47 | 48 | new_lr = None 49 | 50 | if self.schedule_type == "epdecay": 51 | if self._step <= self.num_warmup_steps: 52 | new_lr = (self.base_lr / max(1.0, self.num_warmup_steps)) * self._step 53 | else: 54 | new_lr = self.base_lr 55 | 56 | elif self.schedule_type == "noam": 57 | new_lr = ( 58 | self.base_lr 59 | * self.model_dim ** (-0.5) 60 | * min( 61 | self._step ** (-0.5), self._step * self.num_warmup_steps ** (-1.5) 62 | ) 63 | ) 64 | 65 | elif self.schedule_type == "lindecay": 66 | # `transformers.get_linear_schedule_with_warmup` 67 | if self._step <= self.num_warmup_steps: 68 | new_lr = (self.base_lr / max(1.0, self.num_warmup_steps)) * self._step 69 | else: 70 | new_lr = self.base_lr * max( 71 | 0.0, 72 | float(self.num_total_steps - self._step) 73 | / float(max(1.0, self.num_total_steps - self.num_warmup_steps)), 74 | ) 75 | 76 | if new_lr != self._lr: 77 | # set optimizer's learning rate 78 | for param in self.optimizer.param_groups: 79 | param["lr"] = new_lr 80 | 81 | self._lr = new_lr 82 | self.optimizer.step() 83 | 84 | def update_epoch(self): 85 | self._epoch += 1 86 | 87 | if self.schedule_type == "epdecay": 88 | if self._epoch >= self.lr_decay_start_epoch: 89 | new_lr = self._lr * self.lr_decay_rate 90 | # set optimizer's learning rate 91 | for param in self.optimizer.param_groups: 92 | param["lr"] = new_lr 93 | logging.info(f"learning rate decreased: {self._lr:.6f} -> {new_lr:.6f}") 94 | self._lr = new_lr 95 | 96 | def zero_grad(self): 97 | self.optimizer.zero_grad() 98 | 99 | def state_dict(self): 100 | return { 101 | "_step": self._step, 102 | "_epoch": self._epoch, 103 | "base_lr": self.base_lr, 104 | "_lr": self._lr, 105 | "num_warmup_steps": 
self.num_warmup_steps, 106 | "num_total_steps": self.num_total_steps, 107 | "optimizer": self.optimizer.state_dict(), 108 | } 109 | 110 | def load_state_dict(self, state_dict): 111 | for key, value in state_dict.items(): 112 | if key == "optimizer": 113 | self.optimizer.load_state_dict(state_dict["optimizer"]) 114 | elif key == "num_total_steps" and value is not None: 115 | assert self.num_total_steps == value 116 | else: 117 | setattr(self, key, value) 118 | 119 | 120 | def optimizer_to(optimizer, device): 121 | for state in optimizer.optimizer.state.values(): 122 | for k, v in state.items(): 123 | if isinstance(v, torch.Tensor): 124 | state[k] = v.to(device) 125 | return optimizer 126 | 127 | 128 | def get_optimizer_params_nodecay(model_named_params: list, weight_decay: float): 129 | nodecay_keys = ["bias", "LayerNorm.bias", "LayerNorm.weight"] 130 | optimizer_params = [ 131 | { 132 | "params": [ 133 | p 134 | for n, p in model_named_params 135 | if not any(nd in n for nd in nodecay_keys) 136 | ], 137 | "weight_decay": weight_decay, 138 | }, 139 | { 140 | "params": [ 141 | p for n, p in model_named_params if any(nd in n for nd in nodecay_keys) 142 | ], 143 | "weight_decay": 0.0, 144 | }, 145 | ] 146 | return optimizer_params 147 | -------------------------------------------------------------------------------- /asr/rescore/README.md: -------------------------------------------------------------------------------- 1 | # Rescoring 2 | 3 | "ASR RESCORING AND CONFIDENCE ESTIMATION WITH ELECTRA" [Futami ASRU2021] 4 | https://arxiv.org/pdf/2110.01857.pdf 5 | 6 | ### TED-LIUM2 7 | 8 | | | | WER(test) | Runtime(test) | 9 | |:---:|:---|:---:|:---:| 10 | | `T1` | CTC ASR (w/o LM) | 12.11 | - | 11 | | `T2` | +Transformer | 9.78 | x1 | 12 | | `T3` | +BERT | 9.72 | | 13 | | `T4` | +ELECTRA | 10.17 | | 14 | | `T5` | +ELECTRA(FT) | ... | | 15 | | `T6` | +P-ELECTRA | 9.73 | | 16 | | `T7` | +P-ELECTRA(FT) | ... 
| | 17 | 18 | - `T1`: `exps/ted2_nsp10k/transformer_ctc` 19 | - `T2`: `exps/ted2_nsp10k/transformer` 20 | - `T3`: `exps/ted2_nsp10k/bert` 21 | - `T4`: `exps/ted2_nsp10k/electra_utt` 22 | - `T5`: `exps/ted2_nsp10k/electra_utt_asr5b` 23 | - `T6`: `exps/ted2_nsp10k/pelectra_utt` 24 | - `T7`: `exps/ted2_nsp10k/pelectra_utt_asr5b` 25 | -------------------------------------------------------------------------------- /asr/rescore/align_hyps.py: -------------------------------------------------------------------------------- 1 | """ Align n-best hypotheses to train `electra-disc`, `pelectra-disc` 2 | """ 3 | 4 | import argparse 5 | import os 6 | import sys 7 | 8 | import pandas as pd 9 | from tqdm import tqdm 10 | 11 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 12 | sys.path.append(EMOASR_ROOT) 13 | 14 | from asr.metrics import compute_wer 15 | from utils.converters import str2ints 16 | from utils.paths import get_eval_path 17 | 18 | 19 | def alignment(dfhyp, dfref, align_type="SID", len_min=1, len_max=256): 20 | id2ref = {} 21 | 22 | for row in dfref.itertuples(): 23 | id2ref[row.utt_id] = str2ints(row.token_id) 24 | 25 | outs = [] 26 | 27 | for row in tqdm(dfhyp.itertuples()): 28 | hyp_token_id = str2ints(row.token_id) 29 | ref_token_id = id2ref[row.utt_id] 30 | 31 | if len(hyp_token_id) < len_min or len(hyp_token_id) > len_max: 32 | continue 33 | 34 | _, wer_dict = compute_wer(hyp_token_id, ref_token_id) 35 | error_list = wer_dict["error_list"] 36 | 37 | align_list = [] 38 | del_flag = False 39 | 40 | if align_type == "SI": 41 | align_list = [e for e in error_list if e != "D"] 42 | elif align_type == "SID": 43 | for e in error_list: 44 | if e == "D": 45 | # pass `D` to left 46 | if len(align_list) > 0 and align_list[-1] == "C": 47 | align_list[-1] = "D" 48 | else: # to right 49 | del_flag = True 50 | else: 51 | if del_flag and e == "C": 52 | align_list.append("D") 53 | else: 54 | align_list.append(e) 55 | del_flag = False 56 | 57 | assert len(hyp_token_id) == len(align_list) 58 | 59 | outs.append( 60 | (row.utt_id, row.score_asr, row.token_id, row.text, row.reftext, " ".join(align_list)) 61 | ) 62 | 63 | df = pd.DataFrame( 64 | outs, columns=["utt_id", "score_asr", "token_id", "text", "reftext", "error_label"] 65 | ) 66 | 67 | return df 68 | 69 | def main(args): 70 | dfhyp = pd.read_table(args.tsv_path) 71 | dfhyp = dfhyp.dropna() 72 | dfref = pd.read_table(get_eval_path(args.ref)) 73 | 74 | df = alignment(dfhyp, dfref, args.align_type, len_min=args.len_min, len_max=args.len_max) 75 | 76 | df.to_csv(args.tsv_path.replace(".tsv", f"_{args.align_type}align.tsv"), sep="\t", index=False) 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("tsv_path", type=str) 81 | parser.add_argument("-ref", type=str, required=True) 82 | parser.add_argument("--align_type", choices=["SI", "SID"], default="SID") 83 | parser.add_argument("--len_min", type=int, default=1) 84 | parser.add_argument("--len_max", type=int, default=256) 85 | args = parser.parse_args() 86 | 87 | main(args) 88 | -------------------------------------------------------------------------------- /asr/rescore/test_rescore_grid.py: -------------------------------------------------------------------------------- 1 | """ Rescoring with grid search for parameter tuning 2 | """ 3 | 4 | import argparse 5 | import logging 6 | import os 7 | import re 8 | import sys 9 | import time 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | from
torch.nn.utils.rnn import pad_sequence 15 | 16 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 17 | sys.path.append(EMOASR_ROOT) 18 | 19 | from asr.metrics import compute_wers_df 20 | from lm.modeling.lm import LM 21 | from utils.configure import load_config 22 | from utils.converters import str2ints, tensor2np 23 | from utils.paths import get_eval_path, get_model_path, rel_to_abs_path 24 | from utils.vocab import Vocab 25 | 26 | BATCH_SIZE = 100 27 | EPS = 1e-5 28 | 29 | 30 | def score_lm(df, model, device, mask_id=None, vocab=None, num_samples=-1): 31 | ys, ylens, score_lms_all = [], [], [] 32 | 33 | utt_id = None 34 | cnt_utts = 0 35 | utt_ids = [] 36 | 37 | for i, row in enumerate(df.itertuples()): 38 | if row.utt_id != utt_id: 39 | cnt_utts += 1 40 | utt_id = row.utt_id 41 | utt_ids.append(utt_id) 42 | if num_samples > 0 and (cnt_utts + 1) > num_samples: 43 | return utt_ids 44 | 45 | y = str2ints(row.token_id) 46 | ys.append(torch.tensor(y)) 47 | ylens.append(len(y)) 48 | 49 | if len(ys) < BATCH_SIZE and i != (len(df) - 1): 50 | continue 51 | 52 | ys_pad = pad_sequence(ys, batch_first=True).to(device) 53 | ylens = torch.tensor(ylens).to(device) 54 | 55 | score_lms = model.score(ys_pad, ylens, batch_size=BATCH_SIZE) 56 | 57 | if vocab is not None: # debug mode 58 | for y, score_lm in zip(ys, score_lms): 59 | logging.debug( 60 | f"{' '.join(vocab.ids2words(tensor2np(y)))}: {score_lm:.3f}" 61 | ) 62 | 63 | score_lms_all.extend(score_lms) 64 | ys, ylens = [], [] 65 | 66 | df["score_lm"] = score_lms_all 67 | return df 68 | 69 | 70 | def rescore(df, dfref, lm_weight, len_weight): 71 | df["ylen"] = df["token_id"].apply(lambda s: len(s.split())) 72 | df["score"] = df["score_asr"] + lm_weight * df["score_lm"] + len_weight * df["ylen"] 73 | 74 | df_best = df.loc[df.groupby("utt_id")["score"].idxmax(), :] 75 | df_best = df_best[["utt_id", "text", "token_id", "score_asr"]] 76 | 77 | wer, wer_dict = compute_wers_df(df_best, dfref) 78 | return wer, wer_dict, df_best 79 | 80 | 81 | def main(args): 82 | if args.cpu: 83 | device = torch.device("cpu") 84 | torch.set_num_threads(1) 85 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 86 | else: 87 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 88 | 89 | lm_params = load_config(args.lm_conf) 90 | lm_tag = lm_params.lm_type if args.lm_tag is None else args.lm_tag 91 | if args.debug or args.runtime: 92 | logging.basicConfig( 93 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 94 | level=logging.DEBUG, 95 | ) 96 | else: 97 | log_path = args.tsv_path.replace(".tsv", f"_{lm_tag}.log") 98 | logging.basicConfig( 99 | filename=log_path, 100 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 101 | level=logging.INFO, 102 | ) 103 | 104 | df = pd.read_table(args.tsv_path) 105 | df = df.dropna() 106 | dfref = pd.read_table(get_eval_path(args.ref)) 107 | 108 | # LM 109 | lm_path = get_model_path(args.lm_conf, args.lm_ep) 110 | lm_params = load_config(args.lm_conf) 111 | lm = LM(lm_params, phase="test") 112 | lm.load_state_dict(torch.load(lm_path, map_location=device)) 113 | logging.info(f"LM: {lm_path}") 114 | lm.to(device) 115 | lm.eval() 116 | 117 | mask_id = lm_params.mask_id if hasattr(lm_params, "mask_id") else None 118 | vocab = Vocab(rel_to_abs_path(lm_params.vocab_path)) if args.debug else None 119 | 120 | if args.runtime: 121 | torch.set_num_threads(1) 122 | 123 | global BATCH_SIZE 124 | BATCH_SIZE = 1 125 | 126 | runtimes = [] 127 | rtfs = [] 128 | for j in 
range(args.runtime_num_repeats): 129 | start_time = time.time() 130 | 131 | utt_ids = score_lm(df, lm, device, mask_id=mask_id, num_samples=args.runtime_num_samples) 132 | runtime = time.time() - start_time 133 | runtime_utt = runtime / args.runtime_num_samples 134 | wavtime = 0 135 | for utt_id in utt_ids: 136 | start_time = int(re.split("_|-", utt_id)[-2]) / args.wavtime_factor 137 | end_time = int(re.split("_|-", utt_id)[-1]) / args.wavtime_factor 138 | wavtime += (end_time - start_time) 139 | rtf = runtime / wavtime 140 | logging.info(f"Run {(j+1):d} runtime: {runtime:.5f}sec / utt, wavtime: {wavtime:.5f}sec | RTF: {(rtf):.5f}") 141 | runtimes.append(runtime) 142 | rtfs.append(rtf) 143 | 144 | logging.info(f"Averaged runtime {np.mean(runtimes):.5f}sec, RTF {np.mean(rtfs):.5f} on {device.type}") 145 | return 146 | 147 | scored_tsv_path = args.tsv_path.replace(".tsv", f"_{lm_tag}.tsv") 148 | 149 | # calculate `score_lm` 150 | if not os.path.exists(scored_tsv_path): 151 | df = score_lm(df, lm, device, mask_id=mask_id, vocab=vocab) 152 | df.to_csv(scored_tsv_path, sep="\t", index=False) 153 | else: 154 | logging.info(f"load score_lm: {scored_tsv_path}") 155 | df = pd.read_table(scored_tsv_path) 156 | 157 | # grid search 158 | lm_weight_cands = np.arange(args.lm_min, args.lm_max + EPS, args.lm_step) 159 | len_weight_cands = np.arange(args.len_min, args.len_max + EPS, args.len_step) 160 | 161 | lm_weight_best = 0 162 | len_weight_best = 0 163 | df_best = None 164 | wer_min = 100 165 | 166 | for lm_weight in lm_weight_cands: 167 | for len_weight in len_weight_cands: 168 | wer, wer_dict, df_result = rescore(df, dfref, lm_weight, len_weight) 169 | 170 | wer_info = f"WER: {wer:.2f} [D={wer_dict['n_del']:d}, S={wer_dict['n_sub']:d}, I={wer_dict['n_ins']:d}, N={wer_dict['n_ref']:d}]" 171 | logging.info( 172 | f"lm_weight: {lm_weight:.3f} len_weight: {len_weight:.3f} - {wer_info}" 173 | ) 174 | 175 | if wer < wer_min: 176 | wer_min = wer 177 | lm_weight_best = lm_weight 178 | len_weight_best = len_weight 179 | df_best = df_result 180 | 181 | best_tsv_path = scored_tsv_path.replace( 182 | ".tsv", f"_lm{lm_weight_best:.2f}_len{len_weight_best:.2f}.tsv" 183 | ) 184 | logging.info( 185 | f"best lm_weight: {lm_weight_best:.3f} len_weight: {len_weight_best:.3f}" 186 | ) 187 | if df_best is not None: 188 | df_best.to_csv(best_tsv_path, sep="\t", index=False) 189 | logging.info(f"best WER: {wer_min:.3f}") 190 | 191 | 192 | if __name__ == "__main__": 193 | parser = argparse.ArgumentParser() 194 | parser.add_argument("tsv_path", type=str) # nbest 195 | parser.add_argument("-ref", type=str, required=True) # tsv_path for reference 196 | parser.add_argument("--cpu", action="store_true") 197 | parser.add_argument("--debug", action="store_true") 198 | parser.add_argument("--runtime", action="store_true") 199 | parser.add_argument("--runtime_num_samples", type=int, default=20) 200 | parser.add_argument("--runtime_num_repeats", type=int, default=5) 201 | parser.add_argument("--wavtime_factor", type=float, default=1000) 202 | # 203 | parser.add_argument("-lm_conf", type=str, required=True) 204 | parser.add_argument("-lm_ep", type=str, required=True) 205 | parser.add_argument("--lm_tag", type=str, default=None) 206 | parser.add_argument("--lm_min", type=float, default=0) 207 | parser.add_argument("--lm_max", type=float, default=1) 208 | parser.add_argument("--lm_step", type=float, default=0.1) 209 | parser.add_argument("--len_min", type=float, default=0) 210 | parser.add_argument("--len_max", type=float, default=5) 
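# NOTE: the lm_*/len_* options define the grid swept in rescore(), where
#   score = score_asr + lm_weight * score_lm + len_weight * ylen
# e.g. (hypothetical numbers) with lm_weight=0.5 and len_weight=2, a 10-token
# hypothesis with score_asr=-12.3 and score_lm=-20.0 gets
#   -12.3 + 0.5 * (-20.0) + 2 * 10 = -2.3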
211 | parser.add_argument("--len_step", type=float, default=1) 212 | args = parser.parse_args() 213 | main(args) 214 | -------------------------------------------------------------------------------- /asr/spec_augment.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | 4 | import numpy as np 5 | 6 | random.seed(0) 7 | np.random.seed(0) 8 | 9 | 10 | class SpecAugment: 11 | """ SpecAugment 12 | 13 | Reference: 14 | https://arxiv.org/abs/1904.08779 15 | """ 16 | 17 | def __init__(self, params): 18 | self.max_mask_freq = params.max_mask_freq 19 | self.num_masks_freq = params.num_masks_freq 20 | 21 | if hasattr(params, "max_mask_time_ratio"): 22 | # Adaptive SpecAugment 23 | # https://arxiv.org/pdf/1912.05533.pdf 24 | self.adaptive_specaug = True 25 | self.max_mask_time_ratio = params.max_mask_time_ratio 26 | self.num_masks_time_ratio = params.num_masks_time_ratio 27 | else: 28 | self.adaptive_specaug = False 29 | self.max_mask_time = params.max_mask_time 30 | self.num_masks_time = params.num_masks_time 31 | 32 | self.replace_with_zero = params.replace_with_zero 33 | 34 | logging.info(f"apply SpecAugment - {vars(self)}") 35 | 36 | def __call__(self, x: np.ndarray): 37 | return self._time_mask(self._freq_mask(x)) 38 | 39 | def _freq_mask(self, x: np.ndarray): 40 | """ 41 | Reference: 42 | https://github.com/espnet/espnet/blob/master/espnet/transform/spec_augment.py 43 | """ 44 | cloned = x.copy() 45 | fdim = cloned.shape[1] 46 | 47 | fs = np.random.randint(0, self.max_mask_freq, size=(self.num_masks_freq, 2)) 48 | 49 | for f, mask_end in fs: 50 | f_zero = random.randrange(0, fdim - f) 51 | mask_end += f_zero 52 | 53 | # avoids randrange error if values are equal and range is empty 54 | if f_zero == f_zero + f: 55 | continue 56 | 57 | if self.replace_with_zero: 58 | cloned[:, f_zero:mask_end] = 0 59 | else: 60 | cloned[:, f_zero:mask_end] = cloned.mean() 61 | return cloned 62 | 63 | def _time_mask(self, x: np.ndarray): 64 | """ 65 | Reference: 66 | https://github.com/espnet/espnet/blob/master/espnet/transform/spec_augment.py 67 | """ 68 | cloned = x.copy() 69 | xlen = cloned.shape[0] 70 | 71 | if self.adaptive_specaug: 72 | max_mask_time = min(20, round(xlen * self.max_mask_time_ratio)) 73 | num_masks_time = min(20, round(xlen * self.num_masks_time_ratio)) 74 | else: 75 | max_mask_time = self.max_mask_time 76 | num_masks_time = self.num_masks_time 77 | 78 | ts = np.random.randint(0, max_mask_time, size=(num_masks_time, 2)) 79 | 80 | for t, mask_end in ts: 81 | # avoid randint range error 82 | if xlen - t <= 0: 83 | continue 84 | t_zero = random.randrange(0, xlen - t) 85 | 86 | # avoids randrange error if values are equal and range is empty 87 | if t_zero == t_zero + t: 88 | continue 89 | 90 | mask_end += t_zero 91 | if self.replace_with_zero: 92 | cloned[t_zero:mask_end] = 0 93 | else: 94 | cloned[t_zero:mask_end] = cloned.mean() 95 | return cloned 96 | -------------------------------------------------------------------------------- /corpora/epasr/make_utts_json.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import subprocess 5 | 6 | import pandas as pd 7 | 8 | 9 | def main(args): 10 | rows = [] # utt_id, wav_path, text 11 | 12 | for data_file in sorted(os.listdir(args.data_dir)): # e.g.`t6` 13 | data_dir1 = os.path.join(args.data_dir, data_file) 14 | for data_file in sorted(os.listdir(data_dir1)): # e.g.`2009-04-21` 15 | data_dir2 = 
os.path.join(data_dir1, data_file) 16 | for data_file in sorted(os.listdir(data_dir2)): # e.g.`2-196` 17 | data_dir3 = os.path.join(data_dir2, data_file) 18 | files = [file for file in os.listdir(data_dir3)] 19 | wav_path, json_path = "", "" 20 | for file in files: 21 | if file.endswith(".wav"): 22 | wav_path = os.path.join(data_dir3, file) 23 | if file.endswith(args.json_ext): 24 | json_path = os.path.join(data_dir3, file) 25 | assert wav_path and json_path 26 | 27 | wav_file = os.path.basename(wav_path) 28 | assert "ep-asr.en.orig." in wav_file 29 | wav_file = wav_file.replace("ep-asr.en.orig.", "") 30 | utt_prefix = wav_file.replace(".wav", "") 31 | 32 | out_wav_dir = os.path.join(args.out_wav_dir, utt_prefix) 33 | os.makedirs(out_wav_dir, exist_ok=True) 34 | 35 | with open(json_path) as f: 36 | sections = json.load(f) 37 | for section in sections: 38 | start_time = float(section["b"]) 39 | end_time = float(section["e"]) 40 | start_time_str = str(int(start_time * 100)).zfill(7) 41 | end_time_str = str(int(end_time * 100)).zfill(7) 42 | text = " ".join([sec["w"] for sec in section["wl"]]) 43 | utt_id = f"{utt_prefix}-{start_time_str}-{end_time_str}" 44 | 45 | out_wav_path = os.path.join(out_wav_dir, f"{utt_id}.wav") 46 | 47 | # trim wav 48 | cp = subprocess.run( 49 | [ 50 | "sox", 51 | wav_path, 52 | out_wav_path, 53 | "trim", 54 | f"{start_time:.2f}", 55 | f"={end_time:.2f}", 56 | ] 57 | ) 58 | assert cp.returncode == 0 59 | 60 | print(f"{wav_path} -> {out_wav_path}") 61 | 62 | rows.append((utt_id, out_wav_path, text)) 63 | 64 | data = pd.DataFrame(rows, columns=["utt_id", "wav_path", "text"]) 65 | data.to_csv(args.tsv_path, sep="\t", index=False) 66 | 67 | 68 | if __name__ == "__main__": 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument("data_dir", type=str) # wav + json 71 | # `train` 72 | parser.add_argument("out_wav_dir", type=str) 73 | parser.add_argument("tsv_path", type=str) 74 | parser.add_argument("json_ext", type=str) 75 | args = parser.parse_args() 76 | 77 | main(args) 78 | -------------------------------------------------------------------------------- /corpora/epasr/make_utts_stm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import subprocess 5 | 6 | import pandas as pd 7 | 8 | 9 | def main(args): 10 | stm_labels = {} 11 | with open(args.stm_path) as f: 12 | lines = f.readlines() 13 | for line in lines: 14 | sections = line.strip().split() 15 | utt_prefix = sections[0].replace("ep-asr.en.orig.", "") 16 | start_time = float(sections[3]) 17 | end_time = float(sections[4]) 18 | text = " ".join(sections[6:]) 19 | if utt_prefix not in stm_labels: 20 | stm_labels[utt_prefix] = [(start_time, end_time, text)] 21 | else: 22 | stm_labels[utt_prefix].append((start_time, end_time, text)) 23 | 24 | rows = [] # utt_id, wav_path, text 25 | 26 | for data_file in sorted(os.listdir(args.data_dir)): # e.g.`t6` 27 | data_dir1 = os.path.join(args.data_dir, data_file) 28 | for data_file in sorted(os.listdir(data_dir1)): # e.g.`2009-04-21` 29 | data_dir2 = os.path.join(data_dir1, data_file) 30 | for data_file in sorted(os.listdir(data_dir2)): # e.g.`2-196` 31 | data_dir3 = os.path.join(data_dir2, data_file) 32 | files = [file for file in os.listdir(data_dir3)] 33 | wav_path, json_path = "", "" 34 | for file in files: 35 | if file.endswith(".wav"): 36 | wav_path = os.path.join(data_dir3, file) 37 | 38 | wav_file = os.path.basename(wav_path) 39 | assert "ep-asr.en.orig." 
in wav_file 40 | wav_file = wav_file.replace("ep-asr.en.orig.", "") 41 | utt_prefix = wav_file.replace(".wav", "") 42 | 43 | out_wav_dir = os.path.join(args.out_wav_dir, utt_prefix) 44 | os.makedirs(out_wav_dir, exist_ok=True) 45 | 46 | for start_time, end_time, text in stm_labels[utt_prefix]: 47 | start_time_str = str(int(start_time * 100)).zfill(7) 48 | end_time_str = str(int(end_time * 100)).zfill(7) 49 | utt_id = f"{utt_prefix}-{start_time_str}-{end_time_str}" 50 | 51 | out_wav_path = os.path.join(out_wav_dir, f"{utt_id}.wav") 52 | 53 | # trim wav 54 | cp = subprocess.run( 55 | [ 56 | "sox", 57 | wav_path, 58 | out_wav_path, 59 | "trim", 60 | f"{start_time:.2f}", 61 | f"={end_time:.2f}", 62 | ] 63 | ) 64 | assert cp.returncode == 0 65 | 66 | print(f"{wav_path} -> {out_wav_path}") 67 | 68 | rows.append((utt_id, out_wav_path, text)) 69 | 70 | data = pd.DataFrame(rows, columns=["utt_id", "wav_path", "text"]) 71 | data.to_csv(args.tsv_path, sep="\t", index=False) 72 | 73 | 74 | if __name__ == "__main__": 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument("data_dir", type=str) # wav 77 | # `dev-dep`, `dev-indep`, `test-dep`, `test-indep` 78 | parser.add_argument("out_wav_dir", type=str) 79 | parser.add_argument("tsv_path", type=str) 80 | parser.add_argument("stm_path", type=str) 81 | args = parser.parse_args() 82 | 83 | main(args) 84 | -------------------------------------------------------------------------------- /corpora/epasr/prep.sh: -------------------------------------------------------------------------------- 1 | cd ../ 2 | 3 | data=epasr/data/orig/release/en 4 | 5 | train=${data}/train/original_audio/speeches 6 | dev_dep=${data}/dev/original_audio/spk-dep/speeches; dev_indep=${data}/dev/original_audio/spk-indep/speeches 7 | test_dep=${data}/test/original_audio/spk-dep/speeches; test_indep=${data}/test/original_audio/spk-indep/speeches 8 | 9 | ### ASR 10 | 11 | for set in $train $dev_dep $dev_indep $test_dep $test_indep; do 12 | m4as=$(find $set -name "*.m4a") 13 | for m4a in $m4as; do 14 | wav=${m4a/.m4a/.wav} 15 | ffmpeg -y -i $m4a -ar 16000 $wav -loglevel error 16 | echo "${m4a} -> ${wav}" 17 | done 18 | done 19 | 20 | # split wav for utterances 21 | python epasr/make_utts_json.py $train epasr/data/train epasr/data/train_wav.tsv ".tr.verb.json" 22 | 23 | # read `stm` 24 | python epasr/make_utts_stm.py $dev_dep epasr/data/dev_dep epasr/data/dev_dep_wav.tsv ${data}/dev/original_audio/spk-dep/refs/ep-asr.en.dev.spk-dep.rev.stm 25 | python epasr/make_utts_stm.py $dev_indep epasr/data/dev_indep epasr/data/dev_indep_wav.tsv ${data}/dev/original_audio/spk-indep/refs/ep-asr.en.dev.spk-indep.rev.stm 26 | python epasr/make_utts_stm.py $test_dep epasr/data/test_dep epasr/data/test_dep_wav.tsv ${data}/test/original_audio/spk-dep/refs/ep-asr.en.test.spk-dep.rev.stm 27 | python epasr/make_utts_stm.py $test_indep epasr/data/test_indep epasr/data/test_indep_wav.tsv ${data}/test/original_audio/spk-indep/refs/ep-asr.en.test.spk-indep.rev.stm 28 | 29 | # skip `ignore_time_segment_in_scoring` 30 | for set in "dev_dep" "dev_indep" "test_dep" "test_indep"; do 31 | python utils/rm_utt.py epasr/data/${set}_wav.tsv 32 | done 33 | 34 | # wav -> lmfb (npy) 35 | python utils/wav_to_feats.py epasr/data/train_wav.tsv 36 | for set in "dev_dep" "dev_indep" "test_dep" "test_indep"; do 37 | python utils/wav_to_feats.py epasr/data/${set}_wav.tsv 38 | done 39 | 40 | # normalize 41 | for set in "train" "dev_dep" "dev_indep" "test_dep" "test_indep"; do 42 | python utils/norm_feats.py 
epasr/data/${set}_wav.tsv epasr/data/train_wav_norm.pkl 43 | done 44 | 45 | # tokenize 46 | for set in "train" "dev_dep" "dev_indep" "test_dep" "test_indep"; do 47 | python utils/spm_encode.py epasr/data/${set}_wav.tsv -model ted2/data/sp10k/sp10k.model -vocab ted2/data/sp10k/vocab.txt --out epasr/data/tedsp10k/${set}.tsv 48 | python ted2/prep_tsv.py epasr/data/tedsp10k/${set}.tsv 49 | done 50 | 51 | python utils/sort_bylen.py epasr/data/sp10k/train.tsv 52 | 53 | ### LM 54 | 55 | -------------------------------------------------------------------------------- /corpora/ted2/join_suffix.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import pandas as pd 4 | 5 | 6 | def process_text(text): 7 | # it 's -> it's 8 | tokens = text.split() 9 | new_tokens = [] 10 | i = 0 11 | while i < len(tokens): 12 | if i < len(tokens) - 1 and tokens[i + 1][0] == "'": 13 | new_tokens.append(tokens[i] + tokens[i + 1]) 14 | i += 1 15 | else: 16 | new_tokens.append(tokens[i]) 17 | i += 1 18 | new_text = " ".join(new_tokens) 19 | return new_text 20 | 21 | 22 | def main(args): 23 | df = pd.read_table(args.tsv_path) 24 | df["text"] = df["text"].map(process_text) 25 | df.to_csv(args.tsv_path, sep="\t", index=False) 26 | 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("tsv_path", type=str) 31 | args = parser.parse_args() 32 | main(args) 33 | -------------------------------------------------------------------------------- /corpora/ted2/make_utts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | 5 | import pandas as pd 6 | from tqdm import tqdm 7 | 8 | 9 | def main(args): 10 | rows = [] # utt_id, wav_path, text 11 | 12 | for stm_file in tqdm(sorted(os.listdir(args.stm_dir))): 13 | stm_path = os.path.join(args.stm_dir, stm_file) 14 | if not stm_path.endswith(".stm"): 15 | continue 16 | 17 | # read stm 18 | with open(stm_path) as f: 19 | lines = f.readlines() 20 | for line in lines: 21 | sections = line.strip().split() 22 | utt_prefix = sections[0] 23 | start_time = float(sections[3]) 24 | end_time = float(sections[4]) 25 | text = " ".join(sections[6:]) 26 | start_time_str = str(int(start_time * 100)).zfill(7) 27 | end_time_str = str(int(end_time * 100)).zfill(7) 28 | utt_id = f"{utt_prefix}-{start_time_str}-{end_time_str}" 29 | 30 | out_wav_dir = os.path.join(args.out_wav_dir, utt_prefix) 31 | os.makedirs(out_wav_dir, exist_ok=True) 32 | 33 | # The training set seems to not have enough silence padding in the segmentations, 34 | # especially at the beginning of segments. Extend the times. 
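# (Concretely: when --extend_time is set, the segment start is moved 0.15 s
#  earlier, clamped at 0, and the end 0.1 s later, as implemented just below.)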
35 | if args.extend_time: 36 | start_time_fix = max(0, start_time - 0.15) 37 | end_time_fix = end_time + 0.1 38 | else: 39 | start_time_fix = start_time 40 | end_time_fix = end_time 41 | 42 | if args.speed_perturb: 43 | for speed in ["0.9", "1.0", "1.1"]: 44 | wav_path = os.path.join(args.wav_dir, f"sp{speed}-{utt_prefix}.wav") 45 | sp_utt_id = f"sp{speed}-{utt_id}" 46 | out_wav_path = os.path.join(out_wav_dir, f"{sp_utt_id}.wav") 47 | start_time_fix_sp = start_time_fix / float(speed) 48 | end_time_fix_sp = end_time_fix / float(speed) 49 | 50 | # trim wav 51 | cp = subprocess.run( 52 | [ 53 | "sox", 54 | wav_path, 55 | out_wav_path, 56 | "trim", 57 | f"{start_time_fix_sp:.2f}", 58 | f"={end_time_fix_sp:.2f}", 59 | ] 60 | ) 61 | assert cp.returncode == 0 62 | rows.append((sp_utt_id, out_wav_path, text)) 63 | else: 64 | wav_path = os.path.join(args.wav_dir, f"{utt_prefix}.wav") 65 | out_wav_path = os.path.join(out_wav_dir, f"{utt_id}.wav") 66 | 67 | # trim wav 68 | cp = subprocess.run( 69 | [ 70 | "sox", 71 | wav_path, 72 | out_wav_path, 73 | "trim", 74 | f"{start_time_fix:.2f}", 75 | f"={end_time_fix:.2f}", 76 | ] 77 | ) 78 | assert cp.returncode == 0 79 | rows.append((utt_id, out_wav_path, text)) 80 | 81 | data = pd.DataFrame(rows, columns=["utt_id", "wav_path", "text"]) 82 | data.to_csv(args.tsv_path, sep="\t", index=False) 83 | 84 | 85 | if __name__ == "__main__": 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument("stm_dir", type=str) 88 | parser.add_argument("wav_dir", type=str) 89 | parser.add_argument("out_wav_dir", type=str) 90 | parser.add_argument("tsv_path", type=str) 91 | parser.add_argument("--extend_time", action="store_true") 92 | parser.add_argument("--speed_perturb", action="store_true") 93 | args = parser.parse_args() 94 | 95 | main(args) 96 | -------------------------------------------------------------------------------- /corpora/ted2/prep.sh: -------------------------------------------------------------------------------- 1 | cd ../ 2 | 3 | # download dataset 4 | wget http://www.openslr.org/resources/19/TEDLIUM_release2.tar.gz -P data/orig 5 | tar xzf TEDLIUM_release2.tar.gz 6 | 7 | # install sph2pipe 8 | wget https://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz 9 | tar xzf sph2pipe_v2.5.tar.gz 10 | cd sph2pipe_v2.5/ 11 | 12 | # sph -> wav 13 | mkdir ted2/data/orig/TEDLIUM_release2/train/wav 14 | for set in "train" "dev" "test"; do 15 | wavdir=ted2/data/${set}/wav 16 | mkdir -p $wavdir 17 | sphpaths="ted2/data/orig/TEDLIUM_release2/${set}/sph/*.sph" 18 | for sphpath in $sphpaths; do 19 | wavpath=${sphpath//sph/wav} 20 | ted2/sph2pipe_v2.5/sph2pipe -f wav -p $sphpath $wavpath 21 | echo "${sphpath} -> ${wavpath}" 22 | done 23 | done 24 | 25 | # speed perturbation 26 | wavpaths="ted2/data/orig/TEDLIUM_release2/train/wav/*.wav" 27 | mkdir ted2/data/orig/TEDLIUM_release2/train/wav_sp 28 | for speed in "0.9" "1.0" "1.1"; do 29 | for wavpath in $wavpaths; do 30 | wav=$(basename ${wavpath}) 31 | spwavpath="ted2/data/orig/TEDLIUM_release2/train/wav_sp/sp${speed}-${wav}" 32 | sox ${wavpath} ${spwavpath} speed ${speed} 33 | echo "${wavpath} -> ${spwavpath}" 34 | done 35 | done 36 | for set in "dev" "test"; do 37 | mkdir ted2/data/orig/TEDLIUM_release2/${set}/wav_sp 38 | cp ted2/data/orig/TEDLIUM_release2/${set}/wav/*.wav ted2/data/orig/TEDLIUM_release2/${set}/wav_sp/. 
39 | done
40 | 
41 | # split wav for utterances
42 | for set in "train" "dev" "test"; do
43 |     stmdir=ted2/data/orig/TEDLIUM_release2/${set}/stm
44 |     wavdir=ted2/data/orig/TEDLIUM_release2/${set}/wav_sp
45 |     outwavdir=ted2/data/${set}/feats
46 |     mkdir -p $outwavdir
47 |     tsvpath=ted2/data/${set}_feats.tsv
48 |     if [ "${set}" = "train" ]; then
49 |         python ted2/make_utts.py $stmdir $wavdir $outwavdir $tsvpath --extend_time --speed_perturb
50 |     else
51 |         python ted2/make_utts.py $stmdir $wavdir $outwavdir $tsvpath
52 |     fi
53 | done
54 | 
55 | for set in "train" "dev" "test"; do
56 |     # skip `ignore_time_segment_in_scoring`
57 |     python utils/rm_utt.py ted2/data/${set}_feats.tsv
58 |     # e.g. it 's -> it's
59 |     python ted2/join_suffix.py ted2/data/${set}_feats.tsv
60 | done
61 | 
62 | # wav -> lmfb (npy)
63 | for set in "train" "dev" "test"; do
64 |     python utils/wav_to_feats.py ted2/data/${set}_feats.tsv
65 | done
66 | 
67 | # normalize by train
68 | for set in "train" "dev" "test"; do
69 |     python utils/norm_feats.py ted2/data/${set}_feats.tsv ted2/data/train_feats_norm.pkl
70 | done
71 | 
72 | # tokenize
73 | mkdir ted2/data/sp10k
74 | python utils/get_cols.py ted2/data/train_feats.tsv -cols text --no_header -out ted2/data/train_feats.txt
75 | python utils/spm_train.py ted2/data/train_feats.txt -model ted2/data/sp10k/sp10k.model -vocab ted2/data/sp10k/vocab.txt -vocab_size 10000
76 | for set in "train" "dev" "test"; do
77 |     python utils/spm_encode.py ted2/data/${set}_feats.tsv -model ted2/data/sp10k/sp10k.model -vocab ted2/data/sp10k/vocab.txt --out ted2/data/sp10k/${set}.tsv
78 |     python ted2/prep_tsv.py ted2/data/sp10k/${set}.tsv
79 | done
80 | 
81 | python utils/sort_bylen.py ted2/data/sp10k/train.tsv
82 | 
--------------------------------------------------------------------------------
/corpora/ted2/prep_tsv.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | 
5 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")
6 | sys.path.append(EMOASR_ROOT)
7 | 
8 | import numpy as np
9 | import pandas as pd
10 | from utils.converters import str2ints
11 | 
12 | 
13 | def get_xlen(npy_path):
14 |     x = np.load(npy_path)
15 |     return len(x)
16 | 
17 | 
18 | def get_ylen(token_id):
19 |     return len(str2ints(token_id))
20 | 
21 | 
22 | def main(args):
23 |     df = pd.read_table(args.tsv_path)
24 |     if "wav_path" in df:
25 |         df["wav_path"] = df["wav_path"].str.replace(".wav", f"_{args.norm_suffix}.npy", regex=False)
26 |         df["wav_path"] = "/n/work1/futami/emoASR/corpora/" + df["wav_path"]
27 |         df = df.rename(columns={"wav_path": "feat_path"})
28 |     if "xlen" not in df:
29 |         df["xlen"] = df["feat_path"].map(get_xlen)
30 |     if "ylen" not in df:
31 |         df["ylen"] = df["token_id"].map(get_ylen)
32 | 
33 |     df.to_csv(args.tsv_path, sep="\t", index=False)
34 | 
35 | 
36 | if __name__ == "__main__":
37 |     parser = argparse.ArgumentParser()
38 |     parser.add_argument("tsv_path", type=str)
39 |     parser.add_argument("--norm_suffix", type=str, default="norm")
40 |     args = parser.parse_args()
41 |     main(args)
42 | 
--------------------------------------------------------------------------------
/corpora/utils/concat_text.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import os
4 | import sys
5 | 
6 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")
7 | sys.path.append(EMOASR_ROOT)
8 | 
9 | import pandas as pd
10 | from tqdm import tqdm
11 | 
12 | from
utils.converters import ints2str, str2ints 13 | 14 | 15 | def main(args): 16 | data = pd.read_table(args.tsv_path) 17 | # data = data.dropna() 18 | 19 | print(f"Read tsv ({len(data)} samples)") 20 | 21 | # shuffle 22 | if args.shuffle: 23 | data = data.sample(frac=1, random_state=0).reset_index(drop=True) 24 | print(f"Data shuffled") 25 | else: 26 | print(f"Data NOT shuffled") 27 | 28 | # concat sentences (its lengths is NOT always the same as args.max_len) 29 | if args.task == "P2W": 30 | utt_id_start, utt_id_end = "", "" 31 | phone_token_id_concat = [args.phone_eos_id] 32 | phone_text_concat = "" 33 | token_id_concat = [args.eos_id] 34 | text_concat = "" 35 | 36 | outs = [] # utt_id, phone_token_id, phone_text, token_id, text 37 | 38 | for row in tqdm(data.itertuples()): 39 | utt_id = row.utt_id 40 | phone_token_id = str2ints(row.phone_token_id) + [args.phone_eos_id] 41 | token_id = str2ints(row.token_id) + [args.eos_id] 42 | phone_text = f" {row.phone_text} " 43 | text = f" {row.text} " 44 | 45 | if len(phone_token_id) + 1 > args.max_src_len: 46 | continue 47 | if len(token_id) + 1 > args.max_len: 48 | continue 49 | 50 | if utt_id_start == "": 51 | utt_id_start = row.utt_id 52 | utt_id_end = row.utt_id 53 | 54 | # NOTE: filter by its length 55 | if ( 56 | len(phone_token_id_concat) + len(phone_token_id) > args.max_src_len 57 | or len(token_id_concat) + len(token_id) > args.max_len 58 | ): 59 | if ( 60 | len(phone_token_id_concat) >= args.min_src_len 61 | and len(token_id_concat) >= args.min_len 62 | ): 63 | outs.append( 64 | ( 65 | f"{utt_id_start}-{utt_id_end}", 66 | ints2str(phone_token_id_concat), 67 | phone_text_concat, 68 | ints2str(token_id_concat), 69 | text_concat, 70 | ) 71 | ) 72 | 73 | utt_id_start, utt_id_end = "", "" 74 | phone_token_id_concat = [args.phone_eos_id] 75 | phone_text_concat = "" 76 | token_id_concat = [args.eos_id] 77 | text_concat = "" 78 | 79 | else: 80 | phone_token_id_concat.extend(phone_token_id) 81 | token_id_concat.extend(token_id) 82 | phone_text_concat += phone_text 83 | text_concat += text 84 | 85 | if utt_id_start != "": 86 | if ( 87 | len(phone_token_id_concat) >= args.min_src_len 88 | and len(token_id_concat) >= args.min_len 89 | ): 90 | outs.append( 91 | ( 92 | f"{utt_id_start}-{utt_id_end}", 93 | ints2str(phone_token_id_concat), 94 | phone_text_concat, 95 | ints2str(token_id_concat), 96 | text_concat, 97 | ) 98 | ) 99 | data = pd.DataFrame( 100 | outs, 101 | columns=["utt_id", "phone_token_id", "phone_text", "token_id", "text",], 102 | ) 103 | 104 | # concat tokens (its lengths is always the same as args.max_len) 105 | # NOTE: sentence longer than max_len is skipped 106 | elif args.task == "LM": 107 | utt_id_start, utt_id_end = "", "" 108 | token_id_concat = [args.eos_id] 109 | 110 | outs = [] # utt_id, token_id, text 111 | 112 | for row in tqdm(data.itertuples()): 113 | utt_id = row.utt_id 114 | token_id = str2ints(row.token_id) + [args.eos_id] 115 | 116 | if utt_id_start == "": 117 | utt_id_start = row.utt_id 118 | utt_id_end = row.utt_id 119 | 120 | if len(token_id) > args.max_len: 121 | continue 122 | 123 | if len(token_id_concat) + len(token_id) < args.max_len: 124 | token_id_concat += token_id 125 | else: 126 | remainder = args.max_len - len(token_id_concat) 127 | token_id_concat += token_id[:remainder] 128 | assert len(token_id_concat) == args.max_len 129 | outs.append((f"{utt_id_start}-{utt_id_end}", ints2str(token_id_concat))) 130 | utt_id_start, utt_id_end = "", "" 131 | token_id_concat = token_id[remainder:] 132 | 133 | # NOTE: text 
cannot provide 134 | data = pd.DataFrame(outs, columns=["utt_id", "token_id"],) 135 | 136 | elif args.task == "LMall": 137 | if args.eos_id >= 0: 138 | token_id_all = [args.eos_id] 139 | else: 140 | token_id_all = [] 141 | 142 | # NOTE: First, concat all tokens 143 | for row in data.itertuples(): 144 | token_id_all.extend(str2ints(row.token_id)) 145 | if args.eos_id >= 0: 146 | token_id_all.append(args.eos_id) 147 | 148 | # save memory 149 | del data 150 | gc.collect() 151 | 152 | start = 0 153 | utt_id_prefix = os.path.splitext(os.path.basename(args.tsv_path))[0] 154 | outs = [] # utt_id, token_id 155 | 156 | for i in range(args.rep): 157 | start = 0 + i * (args.max_len // args.rep) 158 | while start + args.max_len < len(token_id_all): 159 | end = start + args.max_len 160 | outs.append( 161 | (f"{utt_id_prefix}-{i}-{start}", ints2str(token_id_all[start:end])) 162 | ) 163 | start = end 164 | 165 | # NOTE: text cannot provide 166 | data = pd.DataFrame(outs, columns=["utt_id", "token_id"],) 167 | 168 | if args.out is None: 169 | data.to_csv( 170 | f"{os.path.splitext(args.tsv_path)[0]}_concat.tsv", sep="\t", index=False 171 | ) 172 | else: 173 | data.to_csv(args.out, sep="\t", index=False) 174 | 175 | 176 | if __name__ == "__main__": 177 | parser = argparse.ArgumentParser() 178 | parser.add_argument("tsv_path", type=str) 179 | parser.add_argument("-task", choices=["P2W", "LM", "LMall"], required=True) 180 | parser.add_argument("--max_len", type=int, default=256) 181 | parser.add_argument("--min_len", type=int, default=64) 182 | # NOTE: max source length is set to 1024 183 | parser.add_argument("--max_src_len", type=int, default=1024) 184 | parser.add_argument("--min_src_len", type=int, default=64) 185 | parser.add_argument("--eos_id", type=int, default=2) 186 | parser.add_argument("--phone_eos_id", type=int, default=2) 187 | parser.add_argument("--rep", type=int, default=1) 188 | parser.add_argument("--out", type=str, default=None) 189 | parser.add_argument("--shuffle", action="store_true") 190 | args = parser.parse_args() 191 | main(args) 192 | -------------------------------------------------------------------------------- /corpora/utils/get_cols.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import pandas as pd 4 | 5 | 6 | def main(args): 7 | data = pd.read_table(args.tsv_path) 8 | print(f"Read tsv ({len(data)} samples)") 9 | 10 | columns = [column for column in args.cols.split(",")] 11 | data = data[columns] 12 | data.to_csv(args.out, index=False, header=(not args.no_header), sep="\t") 13 | print(f"Results saved to {args.out}") 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("tsv_path", type=str) 19 | parser.add_argument("-cols", type=str, required=True) 20 | parser.add_argument("-out", type=str, required=True) 21 | parser.add_argument("--no_header", action="store_true") 22 | args = parser.parse_args() 23 | main(args) 24 | -------------------------------------------------------------------------------- /corpora/utils/map2phone.py: -------------------------------------------------------------------------------- 1 | """ 2 | Add phone mapping to tsv (as `phone_token_id`, `phone_text`) with lexicon 3 | """ 4 | 5 | import argparse 6 | import os 7 | import re 8 | import sys 9 | 10 | import pandas as pd 11 | from tqdm import tqdm 12 | 13 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 14 | sys.path.append(EMOASR_ROOT) 15 | 16 | from 
utils.converters import ints2str 17 | from utils.vocab import Vocab 18 | 19 | 20 | def main(args): 21 | word2phone = {} 22 | with open(args.lexicon, "r", encoding="utf-8") as f: 23 | for line in f: 24 | line = re.sub(r"[\s]+", " ", line.strip()) # Remove successive spaces 25 | word = line.split(" ")[0] 26 | word = word.split("+")[0] # for CSJ 27 | word = word.lower() # for Librispeech 28 | phone_seq = " ".join(line.split(" ")[1:]) 29 | word2phone[word] = phone_seq 30 | vocab = Vocab(args.vocab) 31 | 32 | if args.input.endswith(".tsv"): 33 | tsv_path = args.input 34 | df = pd.read_table(tsv_path) 35 | df = df.dropna(subset=["utt_id", "token_id", "text"]) 36 | 37 | phone_texts = [] 38 | phone_token_ids = [] 39 | phone_lens = [] 40 | 41 | for row in tqdm(df.itertuples()): 42 | # print("text:", row.text) 43 | # print("token_id:", row.token_id) 44 | words = row.text.split(" ") 45 | phones = [] 46 | for w in words: 47 | if w in word2phone: 48 | phones += word2phone[w].split() 49 | else: 50 | phones += [args.unk] 51 | phone_text = " ".join(phones) 52 | phone_token_id = ints2str(vocab.tokens2ids(phones)) 53 | 54 | # print("phone_text:", phone_text) 55 | # print("phone_token_id:", phone_token_id) 56 | 57 | phone_texts.append(phone_text) 58 | phone_token_ids.append(phone_token_id) 59 | phone_lens.append(len(phones)) 60 | 61 | df["phone_text"] = phone_texts 62 | df["phone_token_id"] = phone_token_ids 63 | df["plen"] = phone_lens 64 | 65 | if args.cols is not None: 66 | columns = [column for column in args.cols.split(",")] 67 | assert ( 68 | ("utt_id" in columns) 69 | and ("phone_text" in columns) 70 | and ("phone_token_id" in columns) 71 | and ("plen" in columns) 72 | ) 73 | df = df[columns] 74 | 75 | if args.out is None: 76 | df.to_csv(tsv_path.replace(".tsv", "_p2w.tsv"), sep="\t", index=False) 77 | else: 78 | df.to_csv(args.out, sep="\t", index=False) 79 | else: 80 | words = args.input.split(" ") 81 | phones = [] 82 | for w in words: 83 | if w in word2phone: 84 | phones += word2phone[w].split() 85 | else: 86 | phones += [args.unk] 87 | phone_text = " ".join(phones) 88 | phone_token_id = ints2str(vocab.tokens2ids(phones)) 89 | 90 | print(f"text: {phone_text}") 91 | print(f"token_id: {phone_token_id}") 92 | 93 | 94 | if __name__ == "__main__": 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument("input", type=str) # tsv_path or text 97 | parser.add_argument("-lexicon", type=str, required=True) 98 | parser.add_argument("-vocab", type=str, required=True) 99 | parser.add_argument("--unk", type=str, default="NSN") 100 | parser.add_argument("--out", type=str, default=None) 101 | parser.add_argument( 102 | "--cols", type=str, default=None 103 | ) # utt_id,token_id,text,phone_token_id,phone_text 104 | args = parser.parse_args() 105 | main(args) 106 | -------------------------------------------------------------------------------- /corpora/utils/map2phone_g2p.py: -------------------------------------------------------------------------------- 1 | """ 2 | Add phone mapping to tsv (as `phone_token_id`, `phone_text`) with g2p tool (openjtalk) 3 | """ 4 | 5 | import argparse 6 | import os 7 | import sys 8 | 9 | import pandas as pd 10 | import pyopenjtalk 11 | from tqdm import tqdm 12 | 13 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 14 | sys.path.append(EMOASR_ROOT) 15 | 16 | from utils.converters import ints2str 17 | from utils.vocab import Vocab 18 | 19 | 20 | def build_vocab(df, vocab_path): 21 | print(f"building vocab ...") 22 | 23 | vocab_dict = {"": 1, "": 
2, "": 3} 24 | vocab_set = [] 25 | 26 | for row in tqdm(df.itertuples()): 27 | text = row.text.replace(" ", "") # remove spaces 28 | 29 | phones = pyopenjtalk.g2p(text, join=False) 30 | # remove pause 31 | phones = [phone for phone in phones if phone != "pau"] 32 | 33 | for phone in phones: 34 | if phone not in vocab_set: 35 | vocab_set.append(phone) 36 | 37 | # alphabetical order 38 | vocab_set.sort() 39 | 40 | wlines = [] 41 | for v in vocab_set: 42 | index = len(vocab_dict) + 1 43 | vocab_dict[v] = index 44 | 45 | for v, index in vocab_dict.items(): 46 | wlines.append(f"{v} {index:d}\n") 47 | 48 | with open(vocab_path, "w", encoding="utf-8") as f: 49 | f.writelines(wlines) 50 | 51 | print(f"vocabulary saved to {vocab_path}") 52 | 53 | return Vocab(vocab_path) 54 | 55 | 56 | def main(args): 57 | df = pd.read_table(args.tsv_path) 58 | df = df.dropna(subset=["utt_id", "token_id", "text"]) 59 | 60 | if not os.path.exists(args.vocab): 61 | vocab = build_vocab(df, args.vocab) 62 | else: 63 | vocab = Vocab(args.vocab) 64 | print(f"load vocab: {args.vocab}") 65 | 66 | phone_texts = [] 67 | phone_token_ids = [] 68 | phone_lens = [] 69 | 70 | for row in tqdm(df.itertuples()): 71 | text = row.text.replace(" ", "") # remove spaces 72 | phones = pyopenjtalk.g2p(text, join=False) 73 | phone_text = " ".join(phones) 74 | phone_token_id = ints2str(vocab.tokens2ids(phones)) 75 | 76 | phone_texts.append(phone_text) 77 | phone_token_ids.append(phone_token_id) 78 | phone_lens.append(len(phones)) 79 | 80 | df["phone_text"] = phone_texts 81 | df["phone_token_id"] = phone_token_ids 82 | df["plen"] = phone_lens 83 | 84 | if args.cols is not None: 85 | columns = [column for column in args.cols.split(",")] 86 | assert ( 87 | ("utt_id" in columns) 88 | and ("phone_text" in columns) 89 | and ("phone_token_id" in columns) 90 | ) 91 | df = df[columns] 92 | 93 | if args.out is None: 94 | df.to_csv(args.tsv_path.replace(".tsv", "_p2w.tsv"), sep="\t", index=False) 95 | else: 96 | df.to_csv(args.out, sep="\t", index=False) 97 | 98 | 99 | if __name__ == "__main__": 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument("tsv_path", type=str) 102 | parser.add_argument("-vocab", type=str, required=True) 103 | parser.add_argument("--out", type=str, default=None) 104 | parser.add_argument( 105 | "--cols", type=str, default=None 106 | ) # utt_id,token_id,text,phone_token_id,phone_text 107 | args = parser.parse_args() 108 | main(args) 109 | -------------------------------------------------------------------------------- /corpora/utils/norm_feats.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from tqdm import tqdm 7 | 8 | 9 | def save_feats(npy_path, mean, std, norm_suffix="norm"): 10 | out_npy_path = npy_path.replace(".npy", f"_{args.norm_suffix}.npy") 11 | x = np.load(npy_path) 12 | x_norm = (x - mean) / std 13 | np.save(out_npy_path, x_norm) 14 | 15 | 16 | def main(args): 17 | norm_paths = args.norm_path.split(",") 18 | 19 | lmfb_sum, lmfb_sqsum = None, None 20 | num_frames = 0 21 | for norm_path in norm_paths: 22 | with open(norm_path, "rb") as f: 23 | norm_info = pickle.load(f) 24 | if lmfb_sum is None: 25 | lmfb_sum = norm_info["lmfb_sum"] 26 | lmfb_sqsum = norm_info["lmfb_sqsum"] 27 | else: 28 | lmfb_sum += norm_info["lmfb_sum"] 29 | lmfb_sqsum += norm_info["lmfb_sqsum"] 30 | num_frames += norm_info["num_frames"] 31 | 32 | mean = lmfb_sum / num_frames 33 | var = lmfb_sqsum / 
num_frames - (mean * mean) 34 | std = np.sqrt(var) 35 | 36 | if args.data_path.endswith(".tsv"): 37 | df = pd.read_table(args.data_path) 38 | for row in tqdm(df.itertuples()): 39 | npy_path = row.wav_path.replace(".wav", ".npy") 40 | save_feats(npy_path, mean, std, norm_suffix=args.norm_suffix) 41 | elif args.data_path.endswith(".npy"): 42 | save_feats(args.data_path, mean, std) 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("data_path", type=str) 48 | parser.add_argument("norm_path", type=str) 49 | parser.add_argument("--norm_suffix", type=str, default="norm") 50 | args = parser.parse_args() 51 | main(args) 52 | -------------------------------------------------------------------------------- /corpora/utils/rm_utt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import pandas as pd 4 | 5 | IGNORE_TEXT = "ignore_time_segment_in_scoring" 6 | 7 | 8 | def main(args): 9 | df = pd.read_table(args.tsv_path) 10 | df_rm = df[df["text"] != IGNORE_TEXT] 11 | print(f"remove {IGNORE_TEXT} in {args.tsv_path}: {len(df):d} -> {len(df_rm):d}") 12 | df_rm.to_csv(args.tsv_path, index=False, sep="\t") 13 | 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("tsv_path", type=str) 18 | args = parser.parse_args() 19 | main(args) 20 | -------------------------------------------------------------------------------- /corpora/utils/sort_bylen.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | 8 | def get_xlen(npy_path): 9 | x = np.load(npy_path) 10 | return len(x) 11 | 12 | 13 | def sort_data(data, task): 14 | if task == "ASR": 15 | if "xlen" not in data: 16 | data["xlen"] = data["feat_path"].map(get_xlen) 17 | data_sorted = data.sort_values(["xlen"]) 18 | 19 | elif task == "P2W": 20 | if "plen" not in data: 21 | data["plen"] = data["phone_token_id"].str.split().str.len() 22 | data_sorted = data.sort_values(["plen"]) 23 | 24 | return data_sorted 25 | 26 | 27 | def main(args): 28 | if os.path.isdir(args.tsv_path): 29 | for tsv_file in os.listdir(args.tsv_path): 30 | tsv_file_path = os.path.join(args.tsv_path, tsv_file) 31 | data = pd.read_table(tsv_file_path) 32 | data_sorted = sort_data(data, args.task) 33 | save_path = f"{os.path.splitext(tsv_file_path)[0]}_sorted.tsv" 34 | # NOTE: inplace 35 | data_sorted.to_csv(tsv_file_path, sep="\t", index=False) 36 | print(f"sorted data saved to: {tsv_file_path}") 37 | else: 38 | data = pd.read_table(args.tsv_path) 39 | data_sorted = sort_data(data, args.task) 40 | save_path = f"{os.path.splitext(args.tsv_path)[0]}_sorted.tsv" 41 | data_sorted.to_csv(save_path, sep="\t", index=False) 42 | print(f"sorted data saved to: {save_path}") 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("tsv_path", type=str) 48 | parser.add_argument("--task", choices=["ASR", "P2W"], default="ASR") 49 | args = parser.parse_args() 50 | main(args) 51 | -------------------------------------------------------------------------------- /corpora/utils/split_tsv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import pandas as pd 5 | 6 | 7 | def split(args): 8 | df = pd.read_table(args.tsv_path) 9 | print(f"Data size: {len(df):d}") 10 | 11 | if args.shuffle: 12 | df = 
df.sample(frac=1, random_state=0).reset_index(drop=True) 13 | print(f"Data shuffled") 14 | else: 15 | print(f"Data NOT shuffled") 16 | 17 | out_dir = os.path.splitext(args.tsv_path)[0] 18 | os.makedirs(out_dir, exist_ok=True) 19 | 20 | for i in range(args.n_splits - 1): 21 | s_id = int((i / args.n_splits) * len(df)) 22 | t_id = int(((i + 1) / args.n_splits) * len(df)) - 1 23 | df_part = df.loc[s_id:t_id] 24 | out_path = os.path.join(out_dir, f"part{(i+1):d}of{args.n_splits:d}.tsv") 25 | df_part.to_csv(out_path, index=False, sep="\t") 26 | print(f"Data[{s_id:d}:{t_id:d}] (size={t_id - s_id + 1}) saved to {out_path}") 27 | 28 | s_id = int(((args.n_splits - 1) / args.n_splits) * len(df)) 29 | t_id = len(df) - 1 30 | df_part = df.loc[s_id:t_id] 31 | out_path = os.path.join(out_dir, f"part{args.n_splits:d}of{args.n_splits:d}.tsv") 32 | df_part.to_csv(out_path, index=False, sep="\t") 33 | print(f"Data[{s_id:d}:{t_id:d}] (size={t_id-s_id+1}) saved to {out_path}") 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("tsv_path", type=str) 39 | parser.add_argument("-n_splits", type=int, required=True) 40 | parser.add_argument("--shuffle", action="store_true") 41 | args = parser.parse_args() 42 | 43 | split(args) 44 | -------------------------------------------------------------------------------- /corpora/utils/spm_encode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import pandas as pd 6 | # https://github.com/google/sentencepiece 7 | import sentencepiece as spm 8 | from tqdm import tqdm 9 | 10 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 11 | sys.path.append(EMOASR_ROOT) 12 | 13 | from utils.converters import ints2str 14 | from utils.vocab import Vocab 15 | 16 | 17 | def main(args): 18 | df = pd.read_table(args.data) 19 | df = df.dropna() 20 | 21 | sp = spm.SentencePieceProcessor() 22 | sp.Load(args.model) 23 | vocab = Vocab(args.vocab) 24 | 25 | token_ids = [] 26 | 27 | for row in tqdm(df.itertuples()): 28 | tokens = sp.EncodeAsPieces(row.text) 29 | token_id = vocab.tokens2ids(tokens) 30 | token_ids.append(ints2str(token_id)) 31 | 32 | df["token_id"] = token_ids 33 | 34 | if args.out is None: 35 | # overwrite 36 | df.to_csv(args.data, sep="\t", index=False) 37 | else: 38 | df.to_csv(args.out, sep="\t", index=False) 39 | 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("data", type=str) # .tsv 44 | parser.add_argument("-model", type=str, required=True) 45 | parser.add_argument("-vocab", type=str, required=True) 46 | parser.add_argument("--out", type=str, default=None) 47 | args = parser.parse_args() 48 | main(args) 49 | -------------------------------------------------------------------------------- /corpora/utils/spm_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import sentencepiece as spm 5 | 6 | 7 | def build_vocab( 8 | spm_vocab_path, vocab_path, special_tokens=["", "", "", ""] 9 | ): 10 | outs = [] 11 | for i, token in enumerate(special_tokens): 12 | outs.append(f"{token} {i:d}\n") 13 | with open(spm_vocab_path) as f: 14 | for i, line in enumerate(f): 15 | token = line.split()[0] 16 | outs.append(f"{token} {(i+len(special_tokens)):d}\n") 17 | with open(vocab_path, "w") as f: 18 | f.writelines(outs) 19 | 20 | 21 | def main(args): 22 | spm.SentencePieceTrainer.train( 23 | 
input=args.data,
24 |         model_prefix=args.model.replace(".model", ""),
25 |         vocab_size=args.vocab_size,
26 |     )
27 |     build_vocab(args.model.replace(".model", ".vocab"), args.vocab)
28 | 
29 | 
30 | if __name__ == "__main__":
31 |     parser = argparse.ArgumentParser()
32 |     parser.add_argument("data", type=str) # .tsv
33 |     parser.add_argument("-model", type=str, required=True)
34 |     parser.add_argument("-vocab", type=str, required=True)
35 |     parser.add_argument("-vocab_size", type=int, required=True)
36 |     args = parser.parse_args()
37 |     main(args)
38 | 
--------------------------------------------------------------------------------
/corpora/utils/wav_to_feats.py:
--------------------------------------------------------------------------------
1 | """ Convert wav to lmfb (as numpy array)
2 | """
3 | 
4 | import argparse
5 | import os
6 | import pickle
7 | import sys
8 | 
9 | import numpy as np
10 | import pandas as pd
11 | import torch
12 | import torchaudio
13 | from tqdm import tqdm
14 | 
15 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")
16 | sys.path.append(EMOASR_ROOT)
17 | 
18 | from utils.converters import tensor2np
19 | 
20 | 
21 | def save_feats(wav_path):
22 |     with torch.no_grad():
23 |         wav, sr = torchaudio.load(wav_path)
24 |         assert sr == 16000
25 |         wav *= 2 ** 15 # kaldi
26 |         lmfb = torchaudio.compliance.kaldi.fbank(
27 |             wav,
28 |             window_type="hamming",
29 |             htk_compat=True,
30 |             sample_frequency=16000,
31 |             num_mel_bins=80,
32 |             use_energy=False,
33 |         )
34 |         lmfb = tensor2np(lmfb)
35 | 
36 |     npy_path = wav_path.replace(".wav", ".npy")
37 |     np.save(npy_path, lmfb)
38 | 
39 |     lmfb_sum = np.sum(lmfb, axis=0)
40 |     lmfb_sqsum = np.sum(lmfb * lmfb, axis=0)
41 |     num_frames = lmfb.shape[0]
42 | 
43 |     return lmfb_sum, lmfb_sqsum, num_frames
44 | 
45 | 
46 | def main(args):
47 |     if args.data_path.endswith(".tsv"):
48 |         lmfb_sum_all, lmfb_sqsum_all = 0, 0  # per-dimension sums accumulated over all utterances
49 |         num_frames_all = 0
50 |         data = pd.read_table(args.data_path)
51 |         for row in tqdm(data.itertuples()):
52 |             lmfb_sum, lmfb_sqsum, num_frames = save_feats(row.wav_path)
53 |             lmfb_sum_all += lmfb_sum
54 |             lmfb_sqsum_all += lmfb_sqsum
55 |             num_frames_all += num_frames
56 |         norm_info = {}
57 |         norm_info["lmfb_sum"] = lmfb_sum_all
58 |         norm_info["lmfb_sqsum"] = lmfb_sqsum_all
59 |         norm_info["num_frames"] = num_frames_all
60 | 
61 |         pickle_path = args.data_path.replace(".tsv", "_norm.pkl")
62 |         with open(pickle_path, "wb") as f:
63 |             pickle.dump(norm_info, f)
64 | 
65 |     elif args.data_path.endswith(".wav"):
66 |         lmfb_sum, lmfb_sqsum, num_frames = save_feats(args.data_path)
67 | 
68 | 
69 | if __name__ == "__main__":
70 |     parser = argparse.ArgumentParser()
71 |     parser.add_argument("data_path", type=str)
72 |     args = parser.parse_args()
73 |     main(args)
74 | 
--------------------------------------------------------------------------------
/lm/README.md:
--------------------------------------------------------------------------------
1 | ## Results
2 | 
3 | ### Librispeech
4 | 
5 | [nsp10k]
6 | 
7 | | | | params | PPL(clean) |
8 | |:---:|:---|:---:|:---:|
9 | | `L1` | Transformer | 12M | 63.9 |
10 | | `L2` | BERT | 12M | 12.1 |
11 | | `L3` | RNN | 13M | 75.2 |
12 | 
13 | - `L1`: `exps/libri_nsp10k/transformer`
14 | - `L2`: `exps/libri_nsp10k/bert`
15 | - `L3`: `exps/libri_nsp10k/rnn`
16 | 
17 | ### TED-LIUM2
18 | 
19 | [nsp10k]
20 | 
21 | | | | params | PPL(test) |
22 | |:---:|:---|:---:|:---:|
23 | | `T1` | Transformer | 12M | 56.3 |
24 | | `T2` | BERT | 12M | 11.6 |
25 | | `T3` | RNN | 13M | 69.6 |
26 | 
27 | - `T1`: 
`exps/ted2_nsp10k/transformer` 28 | - `T2`: `exps/ted2_nsp10k/bert` 29 | - `T3`: `exps/ted2_nsp10k/rnn` 30 | 31 | ### CSJ 32 | 33 | [nsp10k] 34 | 35 | | | | params | PPL(eval1) | 36 | |:---:|:---|:---:|:---:| 37 | | `C1` | Transformer (BCCWJ init) | 25M | 53.7 | 38 | | `C2` | BERT (BCCWJ init) | 25M | 8.6 | 39 | | `S1` | Transformer (BCCWJ init) | 12M | 38.4 | 40 | | `S2` | Transformer (w/o init) | 12M | 40.8 | 41 | | `S3` | BERT (BCCWJ init) | 12M | 10.0 | 42 | | `S4` | BERT (w/o init) | 12M | x | 43 | | `S5` | RNN (BCCWJ init) | 14M | 39.5 | 44 | | `S6` | RNN (w/o init) | 14M | 49.0 | 45 | 46 | - `C1`: `exps/csj_nsp10k_bccwj/bccwj_csj_transformer` 47 | - `C2`: `exps/csj_nsp10k_bccwj/bccwj_csj_bert` 48 | - `S1`: `exps/csj_nsp10k_bccwj/bccwj_csj_transformer_small` 49 | - `S2`: `exps/csj_nsp10k_bccwj/csj_transformer_small` 50 | - `S3`: `exps/csj_nsp10k_bccwj/bccwj_csj_bert_small` 51 | - `S4`: `exps/csj_nsp10k_bccwj/csj_bert_small` 52 | - `S5`: `exps/csj_nsp10k_bccwj/bccwj_csj_rnn_small` 53 | - `S6`: `exps/csj_nsp10k_bccwj/csj_rnn_small` 54 | -------------------------------------------------------------------------------- /lm/criteria.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class MaskedLMLoss(nn.Module): 6 | def __init__(self, vocab_size): 7 | super(MaskedLMLoss, self).__init__() 8 | 9 | # NOTE: pad with -100 (not masked) 10 | self.ce_loss = nn.CrossEntropyLoss() 11 | 12 | self.vocab_size = vocab_size 13 | 14 | def forward(self, logits, labels, ylens): 15 | loss = self.ce_loss( 16 | logits.contiguous().view(-1, self.vocab_size), labels.contiguous().view(-1), 17 | ) 18 | return loss 19 | -------------------------------------------------------------------------------- /lm/exps/ted2_nsp10k/electra.yaml: -------------------------------------------------------------------------------- 1 | lm_type: "electra" 2 | lr_schedule_type: "lindecay" 3 | 4 | # Generator 5 | gen_embedding_size: 256 6 | gen_hidden_size: 256 7 | gen_intermediate_size: 1024 8 | gen_num_attention_heads: 4 9 | gen_num_layers: 12 10 | 11 | # Discriminator (same size as Generator) 12 | disc_embedding_size: 256 13 | disc_hidden_size: 256 14 | disc_intermediate_size: 1024 15 | disc_num_attention_heads: 4 16 | disc_num_layers: 12 17 | 18 | vocab_size: 9798 19 | max_seq_len: 256 20 | eos_id: 2 21 | mask_id: 9797 22 | 23 | train_path: "corpora/ted2/nsp10k/data/train_ext_concat_noshuf/" 24 | train_size: 1232639 25 | add_sos_eos: false 26 | dev_path: "" 27 | test_path: "" 28 | vocab_path: "corpora/ted2/nsp10k/data/orig/vocab.txt" 29 | 30 | model_path: "" 31 | optim_path: "" 32 | startep: 0 33 | 34 | log_step: 25 35 | save_step: 1 36 | 37 | # train 38 | batch_size: 90 39 | num_epochs: 40 40 | learning_rate: 0.0001 41 | warmup_proportion: 0.1 42 | weight_decay: 0.01 43 | electra_disc_weight: 50 44 | num_to_mask: 35 45 | random_num_to_mask: false 46 | clip_grad_norm: 5.0 47 | 48 | accum_grad: 1 49 | -------------------------------------------------------------------------------- /lm/exps/ted2_nsp10k/pelectra.yaml: -------------------------------------------------------------------------------- 1 | lm_type: "pelectra" 2 | lr_schedule_type: "lindecay" 3 | 4 | # Generator 5 | input_layer: "embed" 6 | enc_hidden_size: 256 7 | enc_num_attention_heads: 4 8 | enc_num_layers: 4 9 | enc_intermediate_size: 1024 10 | dec_hidden_size: 256 11 | dec_num_attention_heads: 4 12 | dec_num_layers: 4 13 | dec_intermediate_size: 1024 14 | dropout_enc_rate: 0.1 15 | 
dropout_dec_rate: 0.1 16 | dropout_attn_rate: 0.1 17 | mtl_ctc_weight: 0 18 | lsm_prob: 0 19 | kd_weight: 0 20 | max_decode_ylen: 256 21 | 22 | # Discriminator 23 | disc_embedding_size: 256 24 | disc_hidden_size: 256 25 | disc_intermediate_size: 1024 26 | disc_num_attention_heads: 4 27 | disc_num_layers: 12 28 | 29 | vocab_size: 9798 30 | src_vocab_size: 45 31 | max_seq_len: 256 32 | blank_id: 0 33 | eos_id: 2 34 | mask_id: 9797 35 | phone_eos_id: 2 36 | phone_mask_id: 44 37 | 38 | # textaug 39 | text_augment: true 40 | textaug_max_mask_prob: 0.6 41 | textaug_max_replace_prob: 0 42 | 43 | train_path: "corpora/ted2/nsp10k/data/train_ext_p2w_concat_noshuf" 44 | train_size: 1167745 45 | add_sos_eos: false 46 | dev_path: "corpora/ted2/nsp10k/data/dev_p2w.tsv" 47 | test_path: "corpora/ted2/nsp10k/data/test_p2w.tsv" 48 | phone_vocab_path: "corpora/ted2/nsp10k/data/orig/vocab_phone.txt" 49 | vocab_path: "corpora/ted2/nsp10k/data/orig/vocab.txt" 50 | bucket_shuffle: false 51 | 52 | model_path: "" 53 | optim_path: "" 54 | startep: 0 55 | 56 | log_step: 25 57 | save_step: 1 58 | 59 | # train 60 | batch_size: 100 61 | num_epochs: 40 62 | learning_rate: 0.0002 63 | warmup_proportion: 0.1 64 | weight_decay: 0.01 65 | electra_disc_weight: 50 66 | 67 | # different from electra 68 | # we also compared this strategy on electra pre-training, but it didn't work well. 69 | mask_proportion: 0.3 70 | random_num_to_mask: true 71 | 72 | clip_grad_norm: 5.0 73 | 74 | accum_grad: 1 75 | -------------------------------------------------------------------------------- /lm/modeling/bert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | EMOASR_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 8 | sys.path.append(EMOASR_DIR) 9 | 10 | from asr.modeling.model_utils import make_nopad_mask 11 | 12 | from lm.modeling.transformers.configuration_transformers import \ 13 | TransformersConfig 14 | from lm.modeling.transformers.modeling_bert import BertForMaskedLM 15 | 16 | 17 | class BERTMaskedLM(nn.Module): 18 | def __init__(self, params): 19 | super(BERTMaskedLM, self).__init__() 20 | config = TransformersConfig( 21 | vocab_size=params.vocab_size, 22 | hidden_size=params.hidden_size, 23 | num_hidden_layers=params.num_layers, 24 | num_attention_heads=params.num_attention_heads, 25 | intermediate_size=params.intermediate_size, 26 | max_position_embeddings=params.max_seq_len, 27 | ) 28 | self.bert = BertForMaskedLM(config) 29 | 30 | # if params.tie_weights: 31 | # pass 32 | 33 | self.mask_id = params.mask_id 34 | 35 | def forward(self, ys, ylens=None, labels=None, ps=None, plens=None): 36 | if ylens is None: 37 | attention_mask = None 38 | else: 39 | attention_mask = make_nopad_mask(ylens).float().to(ys.device) 40 | # DataParallel 41 | ys = ys[:, : max(ylens)] 42 | 43 | if labels is None: 44 | (logits,) = self.bert(ys, attention_mask=attention_mask) 45 | return logits 46 | 47 | if ylens is not None: 48 | labels = labels[:, : max(ylens)] 49 | loss, logits = self.bert(ys, attention_mask=attention_mask, labels=labels) 50 | loss_dict = {"loss_total": loss} 51 | 52 | return loss, loss_dict 53 | 54 | def score(self, ys, ylens, batch_size=100): 55 | """ score token sequence for Rescoring 56 | """ 57 | score_lms = [] 58 | 59 | for y, ylen in zip(ys, ylens): 60 | ys_masked, mask_pos, mask_label = [], [], [] 61 | 62 | score_lm = 0 63 | 64 | for pos in range(ylen): 65 | y_masked = y[:ylen].clone() 66 
| y_masked[pos] = self.mask_id 67 | ys_masked.append(y_masked) 68 | mask_pos.append(pos) 69 | mask_label.append(y[pos]) 70 | 71 | if len(ys_masked) < batch_size and pos != (ylen - 1): 72 | continue 73 | 74 | ys_masked = torch.stack(ys_masked, dim=0) 75 | (logits,) = self.bert(ys_masked) 76 | logprobs = torch.log_softmax(logits, dim=-1) 77 | 78 | bs = ys_masked.size(0) 79 | for b in range(bs): 80 | score_lm += logprobs[b, mask_pos[b], mask_label[b]].item() 81 | 82 | ys_masked, mask_pos, mask_label = [], [], [] 83 | 84 | score_lms.append(score_lm) 85 | 86 | return score_lms 87 | 88 | def load_state_dict(self, state_dict): 89 | try: 90 | super().load_state_dict(state_dict) 91 | except: 92 | self.bert.load_state_dict(state_dict) 93 | -------------------------------------------------------------------------------- /lm/modeling/electra.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | EMOASR_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 9 | sys.path.append(EMOASR_DIR) 10 | 11 | from asr.modeling.model_utils import make_nopad_mask 12 | from utils.log import get_num_parameters 13 | 14 | from lm.modeling.p2w import P2W 15 | from lm.modeling.transformers.configuration_electra import ElectraConfig 16 | from lm.modeling.transformers.modeling_electra import (ElectraForMaskedLM, 17 | ElectraForPreTraining) 18 | 19 | 20 | def sample_temp(logits, temp=1.0): 21 | if temp == 0.0: 22 | return torch.argmax(logits) 23 | probs = torch.softmax(logits / temp, dim=2) # smoothed softmax 24 | # sampling from probs 25 | sample_ids = ( 26 | probs.view(-1, probs.shape[2]) 27 | .multinomial(1, replacement=False) 28 | .view(probs.shape[0], -1) 29 | ) 30 | return sample_ids 31 | 32 | 33 | class ELECTRAModel(nn.Module): 34 | def __init__(self, params): 35 | super(ELECTRAModel, self).__init__() 36 | 37 | # Generator 38 | gconfig = ElectraConfig( 39 | vocab_size=params.vocab_size, 40 | hidden_size=params.gen_hidden_size, 41 | embedding_size=params.gen_embedding_size, 42 | num_hidden_layers=params.gen_num_layers, 43 | num_attention_heads=params.gen_num_attention_heads, 44 | intermediate_size=params.gen_intermediate_size, 45 | max_position_embeddings=params.max_seq_len, 46 | ) 47 | self.gmodel = ElectraForMaskedLM(config=gconfig) 48 | num_params, num_params_trainable = get_num_parameters(self.gmodel) 49 | logging.info( 50 | f"ELECTRA: Generator #parameters: {num_params} ({num_params_trainable} trainable)" 51 | ) 52 | 53 | # Discrminator 54 | dconfig = ElectraConfig( 55 | vocab_size=params.vocab_size, 56 | hidden_size=params.disc_hidden_size, 57 | embedding_size=params.disc_embedding_size, 58 | num_hidden_layers=params.disc_num_layers, 59 | num_attention_heads=params.disc_num_attention_heads, 60 | intermediate_size=params.disc_intermediate_size, 61 | max_position_embeddings=params.max_seq_len, 62 | ) 63 | self.dmodel = ElectraForPreTraining(config=dconfig) 64 | num_params, num_params_trainable = get_num_parameters(self.dmodel) 65 | logging.info( 66 | f"ELECTRA: Discriminator #parameters: {num_params} ({num_params_trainable} trainable)" 67 | ) 68 | 69 | self.electra_disc_weight = params.electra_disc_weight 70 | 71 | def forward(self, ys, ylens=None, labels=None, ps=None, plens=None): 72 | if ylens is None: 73 | attention_mask = None 74 | else: 75 | attention_mask = make_nopad_mask(ylens).float().to(ys.device) 76 | ys = ys[:, : max(ylens)] # DataParallel 77 | 78 | 
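        # Replaced-token-detection objective, as implemented below:
        #   1) the generator (a masked LM) predicts the tokens at the masked positions of `ys`;
        #   2) its predictions are sampled and spliced back into the input in place of the
        #      masked tokens (`generated_ids`);
        #   3) the discriminator is trained to tag each position as original vs. replaced
        #      (`labels_replaced`), and the two losses are combined with `electra_disc_weight`.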
gloss, glogits = self.gmodel(ys, attention_mask=attention_mask, labels=labels) 79 | 80 | generated_ids = ys.clone() 81 | masked_indices = labels.long() != -100 82 | original_ids = ys.clone() 83 | original_ids[masked_indices] = labels[masked_indices] 84 | sample_ids = sample_temp(glogits) # sampling 85 | generated_ids[masked_indices] = sample_ids[masked_indices] 86 | labels_replaced = (generated_ids.long() != original_ids.long()).long() 87 | 88 | dloss, dlogits = self.dmodel( 89 | generated_ids, attention_mask=attention_mask, labels=labels_replaced 90 | ) 91 | 92 | loss = gloss + self.electra_disc_weight * dloss 93 | loss_dict = {} 94 | 95 | loss_dict["loss_gen"] = gloss 96 | loss_dict["loss_disc"] = dloss 97 | loss_dict["num_replaced"] = labels_replaced.sum().long() / ys.size(0) 98 | loss_dict["num_masked"] = masked_indices.sum().long() / ys.size(0) 99 | 100 | return loss, loss_dict 101 | 102 | def forward_disc(self, ys, ylens=None, error_labels=None): 103 | if ylens is None: 104 | attention_mask = None 105 | else: 106 | attention_mask = make_nopad_mask(ylens).float().to(ys.device) 107 | ys = ys[:, : max(ylens)] # DataParallel 108 | 109 | loss, _ = self.dmodel( 110 | ys, attention_mask=attention_mask, labels=error_labels 111 | ) 112 | loss_dict = {"loss_total": loss} 113 | 114 | return loss, loss_dict 115 | 116 | def score(self, ys, ylens, batch_size=None): 117 | """ score token sequence for Rescoring 118 | """ 119 | attention_mask = make_nopad_mask(ylens).float().to(ys.device) 120 | logits, = self.dmodel(ys, attention_mask=attention_mask) 121 | probs = torch.sigmoid(logits) 122 | 123 | if ys.size(0) == 1: 124 | return [torch.sum(probs, dim=-1).item()] 125 | 126 | score_lms = [] 127 | bs = len(ys) 128 | for b in range(bs): 129 | score_lm = (-1) * torch.sum(probs[b, : ylens[b]], dim=-1).item() 130 | score_lms.append(score_lm) 131 | 132 | return score_lms 133 | 134 | class PELECTRAModel(nn.Module): 135 | """ 136 | Phone-attentive ELECTRA 137 | """ 138 | 139 | def __init__(self, params): 140 | super(PELECTRAModel, self).__init__() 141 | 142 | # Generator: condictional MLM 143 | self.gmodel = P2W( 144 | params, encoder_type="transformer", decoder_type="bert", return_logits=True, 145 | ) 146 | num_params, num_params_trainable = get_num_parameters(self.gmodel) 147 | logging.info( 148 | f"PELECTRA: Generator #parameters: {num_params} ({num_params_trainable} trainable)" 149 | ) 150 | 151 | # Discrminator 152 | dconfig = ElectraConfig( 153 | vocab_size=params.vocab_size, 154 | hidden_size=params.disc_hidden_size, 155 | embedding_size=params.disc_embedding_size, 156 | num_hidden_layers=params.disc_num_layers, 157 | num_attention_heads=params.disc_num_attention_heads, 158 | intermediate_size=params.disc_intermediate_size, 159 | max_position_embeddings=params.max_seq_len, 160 | ) 161 | self.dmodel = ElectraForPreTraining(config=dconfig) 162 | num_params, num_params_trainable = get_num_parameters(self.dmodel) 163 | logging.info( 164 | f"PELECTRA: Discriminator #parameters: {num_params} ({num_params_trainable} trainable)" 165 | ) 166 | 167 | self.electra_disc_weight = params.electra_disc_weight 168 | 169 | def forward(self, ys, ylens=None, labels=None, ps=None, plens=None): 170 | if ylens is None: 171 | attention_ymask = None 172 | else: 173 | attention_ymask = make_nopad_mask(ylens).float().to(ys.device) 174 | # DataParallel 175 | ys = ys[:, : max(ylens)] 176 | ps = ps[:, : max(plens)] 177 | labels = labels[:, : max(ylens)] 178 | 179 | gloss, _, glogits = self.gmodel(ys, ylens, labels=labels, ps=ps, 
plens=plens) 180 | 181 | generated_ids = ys.clone() 182 | masked_indices = labels.long() != -100 183 | original_ids = ys.clone() 184 | original_ids[masked_indices] = labels[masked_indices] 185 | sample_ids = sample_temp(glogits) # sampling 186 | generated_ids[masked_indices] = sample_ids[masked_indices] 187 | labels_replaced = (generated_ids.long() != original_ids.long()).long() 188 | 189 | dloss, dlogits = self.dmodel( 190 | generated_ids, attention_mask=attention_ymask, labels=labels_replaced 191 | ) 192 | 193 | loss = gloss + self.electra_disc_weight * dloss 194 | loss_dict = {} 195 | 196 | loss_dict["loss_gen"] = gloss 197 | loss_dict["loss_disc"] = dloss 198 | loss_dict["num_replaced"] = labels_replaced.sum().long() / ys.size(0) 199 | loss_dict["num_masked"] = masked_indices.sum().long() / ys.size(0) 200 | 201 | return loss, loss_dict 202 | 203 | def forward_disc(self, ys, ylens=None, error_labels=None): 204 | if ylens is None: 205 | attention_mask = None 206 | else: 207 | attention_mask = make_nopad_mask(ylens).float().to(ys.device) 208 | ys = ys[:, : max(ylens)] # DataParallel 209 | 210 | loss, _ = self.dmodel( 211 | ys, attention_mask=attention_mask, labels=error_labels 212 | ) 213 | loss_dict = {"loss_total": loss} 214 | 215 | return loss, loss_dict 216 | 217 | def score(self, ys, ylens, batch_size=None): 218 | """ score token sequence for Rescoring 219 | """ 220 | attention_mask = make_nopad_mask(ylens).float().to(ys.device) 221 | logits, = self.dmodel(ys, attention_mask=attention_mask) 222 | probs = torch.sigmoid(logits) 223 | 224 | if ys.size(0) == 1: 225 | return [torch.sum(probs, dim=-1).item()] 226 | 227 | score_lms = [] 228 | bs = len(ys) 229 | for b in range(bs): 230 | score_lm = (-1) * torch.sum(probs[b, : ylens[b]], dim=-1).item() 231 | score_lms.append(score_lm) 232 | 233 | return score_lms 234 | -------------------------------------------------------------------------------- /lm/modeling/lm.py: -------------------------------------------------------------------------------- 1 | """ Language modeling 2 | """ 3 | 4 | import logging 5 | import os 6 | import sys 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 12 | sys.path.append(EMOASR_ROOT) 13 | 14 | from utils.log import get_num_parameters 15 | 16 | from lm.modeling.bert import BERTMaskedLM 17 | from lm.modeling.electra import ELECTRAModel, PELECTRAModel 18 | from lm.modeling.rnn import RNNLM 19 | from lm.modeling.transformer import TransformerLM 20 | 21 | 22 | class LM(nn.Module): 23 | def __init__(self, params, phase="train"): 24 | super().__init__() 25 | 26 | self.lm_type = params.lm_type 27 | logging.info(f"LM type: {self.lm_type}") 28 | 29 | if self.lm_type == "bert": 30 | self.lm = BERTMaskedLM(params) 31 | elif self.lm_type == "transformer": 32 | self.lm = TransformerLM(params) 33 | elif self.lm_type in ["electra", "electra-disc"]: 34 | self.lm = ELECTRAModel(params) 35 | elif self.lm_type in ["pelectra", "pelectra-disc"]: 36 | self.lm = PELECTRAModel(params) 37 | elif self.lm_type == "rnn": 38 | self.lm = RNNLM(params) 39 | 40 | num_params, num_params_trainable = get_num_parameters(self) 41 | logging.info( 42 | f"LM model #parameters: {num_params} ({num_params_trainable} trainable)" 43 | ) 44 | 45 | def forward(self, ys, ylens=None, labels=None, ps=None, plens=None): 46 | return self.lm(ys, ylens, labels, ps, plens) 47 | 48 | def forward_disc(self, ys, ylens, error_labels): 49 | return self.lm.forward_disc(ys, ylens, 
error_labels) 50 | 51 | def zero_states(self, bs, device): 52 | return self.lm.zero_states(bs, device) 53 | 54 | def predict(self, ys, ylens, states=None): 55 | with torch.no_grad(): 56 | return self.lm.predict(ys, ylens, states) 57 | 58 | def score(self, ys, ylens, batch_size=100): 59 | with torch.no_grad(): 60 | return self.lm.score(ys, ylens, batch_size) 61 | 62 | def load_state_dict(self, state_dict): 63 | try: 64 | super().load_state_dict(state_dict) 65 | except: 66 | self.lm.load_state_dict(state_dict) 67 | -------------------------------------------------------------------------------- /lm/modeling/p2w.py: -------------------------------------------------------------------------------- 1 | """ Phone-to-word modeling 2 | """ 3 | 4 | import logging 5 | import os 6 | import sys 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 12 | sys.path.append(EMOASR_ROOT) 13 | 14 | # same modeling as ASR 15 | from asr.modeling.decoders.transformer import CTCDecoder, TransformerDecoder 16 | from asr.modeling.encoders.transformer import TransformerEncoder 17 | from utils.log import get_num_parameters 18 | 19 | 20 | class P2W(nn.Module): 21 | def __init__( 22 | self, 23 | params, 24 | phase="train", 25 | encoder_type=None, 26 | decoder_type=None, 27 | return_logits=False, 28 | ): 29 | super().__init__() 30 | 31 | self.lm_type = params.lm_type 32 | logging.info(f"LM type: {self.lm_type}") 33 | 34 | self.encoder = TransformerEncoder(params) 35 | 36 | if decoder_type is None: 37 | if self.lm_type == "ptransformer": 38 | self.decoder_type = "transformer" 39 | elif self.lm_type == "pbert": 40 | self.decoder_type = "bert" 41 | elif self.lm_type == "pctc": 42 | self.decoder_type = "ctc" 43 | else: 44 | self.decoder_type = decoder_type 45 | 46 | if self.decoder_type == "transformer": 47 | self.decoder = TransformerDecoder(params) 48 | elif self.decoder_type == "bert": 49 | self.decoder = TransformerDecoder(params, cmlm=True) 50 | elif self.decoder_type == "ctc": 51 | self.decoder = CTCDecoder(params) 52 | 53 | self.vocab_size = params.vocab_size 54 | self.eos_id = params.eos_id 55 | self.add_sos_eos = params.add_sos_eos 56 | 57 | num_params, num_params_trainable = get_num_parameters(self) 58 | logging.info( 59 | f"P2W model #parameters: {num_params} ({num_params_trainable} trainable)" 60 | ) 61 | 62 | self.return_logits = return_logits 63 | 64 | def forward(self, ys=None, ylens=None, labels=None, ps=None, plens=None): 65 | # DataParallel 66 | if plens is None: 67 | plens = torch.tensor([ps.size(1)]).to(ps.device) 68 | else: 69 | ps = ps[:, : max(plens)] 70 | 71 | if ylens is None: 72 | ylens = torch.tensor([ys.size(1)]).to(ys.device) 73 | else: 74 | ys = ys[:, : max(ylens)] 75 | 76 | eouts, elens, _ = self.encoder(ps, plens) 77 | 78 | if self.decoder_type == "ctc": 79 | loss, loss_dict, logits = self.decoder(eouts, elens, ys=ys, ylens=ylens) 80 | return loss, loss_dict 81 | 82 | # FIXME: take care of `ymask = make_tgt_mask(ylens + 1)` 83 | if self.decoder_type == "transformer": 84 | ylens -= 1 85 | 86 | if labels is None: 87 | logits = self.decoder(eouts, elens, ys=ys, ylens=ylens, ys_in=ys) 88 | return logits 89 | 90 | labels = labels[:, : max(ylens)] 91 | 92 | loss, loss_dict, logits = self.decoder( 93 | eouts, elens, ys=ys, ylens=ylens, ys_in=ys, ys_out=labels 94 | ) 95 | 96 | if self.return_logits: 97 | return loss, loss_dict, logits 98 | else: 99 | return loss, loss_dict 100 | 101 | def decode(self, ps, plens=None): 
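        # Decode word tokens from a phone-id sequence: encode `ps` (batch, max_plen),
        # then run the word decoder's search on the encoder outputs.
        # Usage sketch (phone ids are illustrative): ps = torch.tensor([[5, 12, 7, 2]])
        # -> hyps: a list with one word-token-id hypothesis per batch entry.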
102 | if plens is None: 103 | plens = torch.tensor([ps.size(1)]).to(ps.device) 104 | 105 | eouts, elens, _ = self.encoder(ps, plens) 106 | hyps, _, _, _ = self.decoder.decode(eouts, elens) 107 | return hyps 108 | -------------------------------------------------------------------------------- /lm/modeling/rnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | EMOASR_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 8 | sys.path.append(EMOASR_DIR) 9 | 10 | from utils.converters import tensor2np 11 | 12 | 13 | class RNNLM(nn.Module): 14 | def __init__(self, params): 15 | super().__init__() 16 | 17 | self.embed = nn.Embedding(params.vocab_size, params.embedding_size) 18 | self.rnns = nn.LSTM( 19 | input_size=params.embedding_size, 20 | hidden_size=params.hidden_size, 21 | num_layers=params.num_layers, 22 | dropout=params.dropout_rate, 23 | batch_first=True, 24 | ) 25 | self.output = nn.Linear(params.hidden_size, params.vocab_size) 26 | self.dropout = nn.Dropout(params.dropout_rate) 27 | self.loss_fn = nn.CrossEntropyLoss() # ignore_index = -100 28 | 29 | self.num_layers = params.num_layers 30 | self.hidden_size = params.hidden_size 31 | self.vocab_size = params.vocab_size 32 | 33 | if params.tie_weights: 34 | pass 35 | 36 | def forward(self, ys, ylens=None, labels=None, ps=None, plens=None): 37 | if ylens is not None: 38 | # DataParallel 39 | ys = ys[:, : max(ylens)] 40 | 41 | ys_emb = self.dropout(self.embed(ys)) 42 | out, _ = self.rnns(ys_emb) 43 | logits = self.output(self.dropout(out)) 44 | 45 | if labels is None: 46 | return logits 47 | 48 | if ylens is not None: 49 | labels = labels[:, : max(ylens)] 50 | loss = self.loss_fn(logits.view(-1, self.vocab_size), labels.view(-1)) 51 | loss_dict = {"loss_total": loss} 52 | 53 | return loss, loss_dict 54 | 55 | def zero_states(self, bs, device): 56 | zeros = torch.zeros( 57 | self.num_layers, bs, self.hidden_size, device=device 58 | ) 59 | # hidden state, cell state 60 | return (zeros, zeros) 61 | 62 | def predict(self, ys, ylens, states=None): 63 | """ predict next token for Shallow Fusion 64 | """ 65 | ys_last = [] 66 | bs = len(ys) 67 | for b in range(bs): 68 | ys_last.append(tensor2np(ys[b, ylens[b] - 1 : ylens[b]])) 69 | ys_last = torch.tensor(ys_last).to(ys.device) 70 | 71 | #print("ys:", ys) 72 | #print("ylens:", ylens) 73 | #print("ys_last:", ys_last) 74 | #print("ys_last:", ys_last.shape) 75 | ys_last_emb = self.dropout(self.embed(ys_last)) 76 | out, states = self.rnns(ys_last_emb, states) 77 | logits = self.output(self.dropout(out)) 78 | #print("logits:", logits.shape) 79 | log_probs = torch.log_softmax(logits, dim=-1) 80 | 81 | return log_probs[:, -1], states 82 | 83 | def score(self, ys, ylens, batch_size=None): 84 | """ score token sequence for Rescoring 85 | """ 86 | pass 87 | -------------------------------------------------------------------------------- /lm/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | EMOASR_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") 9 | sys.path.append(EMOASR_DIR) 10 | 11 | from asr.modeling.model_utils import make_nopad_mask 12 | from utils.converters import tensor2np 13 | 14 | from lm.modeling.transformers.configuration_transformers import \ 15 | TransformersConfig 16 | from 
lm.modeling.transformers.modeling_bert import BertForMaskedLM 17 | 18 | 19 | class TransformerLM(nn.Module): 20 | def __init__(self, params): 21 | super().__init__() 22 | config = TransformersConfig( 23 | vocab_size=params.vocab_size, 24 | hidden_size=params.hidden_size, 25 | num_hidden_layers=params.num_layers, 26 | num_attention_heads=params.num_attention_heads, 27 | intermediate_size=params.intermediate_size, 28 | max_position_embeddings=params.max_seq_len, 29 | ) 30 | self.transformer = BertForMaskedLM(config) 31 | 32 | # if params.tie_weights: 33 | # pass 34 | 35 | def forward(self, ys, ylens=None, labels=None, ps=None, plens=None): 36 | if ylens is None: 37 | attention_mask = None 38 | else: 39 | attention_mask = make_nopad_mask(ylens).float().to(ys.device) 40 | # DataParallel 41 | ys = ys[:, : max(ylens)] 42 | 43 | if labels is None: 44 | # NOTE: causal attention mask 45 | (logits,) = self.transformer(ys, attention_mask=attention_mask, causal=True) 46 | return logits 47 | 48 | if ylens is not None: 49 | labels = labels[:, : max(ylens)] 50 | # NOTE: causal attention mask 51 | loss, logits = self.transformer( 52 | ys, attention_mask=attention_mask, causal=True, labels=labels 53 | ) 54 | loss_dict = {"loss_total": loss} 55 | 56 | return loss, loss_dict 57 | 58 | def zero_states(self, bs, device): 59 | # Transformer LM is stateless 60 | return None 61 | 62 | def predict(self, ys, ylens, states=None): 63 | """ predict next token for Shallow Fusion 64 | """ 65 | attention_mask = make_nopad_mask(ylens).float().to(ys.device) 66 | 67 | with torch.no_grad(): 68 | (logits,) = self.transformer(ys, attention_mask, causal=True) 69 | 70 | log_probs = torch.log_softmax(logits, dim=-1) 71 | 72 | log_probs_next = [] 73 | bs = len(ys) 74 | for b in range(bs): 75 | log_probs_next.append(tensor2np(log_probs[b, ylens[b] - 1])) 76 | 77 | return torch.tensor(log_probs_next).to(ys.device), states 78 | 79 | def score(self, ys, ylens, batch_size=None): 80 | """ score token sequence for Rescoring 81 | """ 82 | attention_mask = make_nopad_mask(ylens).float().to(ys.device) 83 | 84 | with torch.no_grad(): 85 | (logits,) = self.transformer(ys, attention_mask, causal=True) 86 | 87 | log_probs = torch.log_softmax(logits, dim=-1) 88 | 89 | score_lms = [] 90 | bs = len(ys) 91 | for b in range(bs): 92 | score_lm = 0 93 | 94 | for i in range(0, ylens[b] - 1): 95 | v = ys[b, i + 1].item() # predict next 96 | score_lm += log_probs[b, i, v].item() 97 | score_lms.append(score_lm) 98 | 99 | return score_lms 100 | 101 | def load_state_dict(self, state_dict): 102 | try: 103 | super().load_state_dict(state_dict) 104 | except: 105 | self.transformer.load_state_dict(state_dict) 106 | -------------------------------------------------------------------------------- /lm/modeling/transformers/activations.py: -------------------------------------------------------------------------------- 1 | """ transformers v3.0.0 2 | https://github.com/huggingface/transformers/tree/b62ca59527de4e883fb8e91f02e97586115616b1 3 | """ 4 | 5 | import logging 6 | import math 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def swish(x): 15 | return x * torch.sigmoid(x) 16 | 17 | 18 | def _gelu_python(x): 19 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created. 
20 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 21 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 22 | This is now written in C in torch.nn.functional 23 | Also see https://arxiv.org/abs/1606.08415 24 | """ 25 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 26 | 27 | 28 | def gelu_new(x): 29 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). 30 | Also see https://arxiv.org/abs/1606.08415 31 | """ 32 | return ( 33 | 0.5 34 | * x 35 | * ( 36 | 1.0 37 | + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))) 38 | ) 39 | ) 40 | 41 | 42 | if torch.__version__ < "1.4.0": 43 | gelu = _gelu_python 44 | else: 45 | gelu = F.gelu 46 | 47 | 48 | def gelu_fast(x): 49 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) 50 | 51 | 52 | ACT2FN = { 53 | "relu": F.relu, 54 | "swish": swish, 55 | "gelu": gelu, 56 | "tanh": torch.tanh, 57 | "gelu_new": gelu_new, 58 | "gelu_fast": gelu_fast, 59 | } 60 | 61 | 62 | def get_activation(activation_string): 63 | if activation_string in ACT2FN: 64 | return ACT2FN[activation_string] 65 | else: 66 | raise KeyError( 67 | "function {} not found in ACT2FN mapping {}".format( 68 | activation_string, list(ACT2FN.keys()) 69 | ) 70 | ) 71 | -------------------------------------------------------------------------------- /lm/modeling/transformers/configuration_electra.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
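# Usage note: in this repo ElectraConfig is constructed programmatically in
# lm/modeling/electra.py; with the lm/exps/ted2_nsp10k/electra.yaml settings the
# discriminator config roughly corresponds to
#   ElectraConfig(vocab_size=9798, embedding_size=256, hidden_size=256,
#                 num_hidden_layers=12, num_attention_heads=4,
#                 intermediate_size=1024, max_position_embeddings=256)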
16 | """ ELECTRA model configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 26 | "google/electra-small-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-generator/config.json", 27 | "google/electra-base-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-base-generator/config.json", 28 | "google/electra-large-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-large-generator/config.json", 29 | "google/electra-small-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-discriminator/config.json", 30 | "google/electra-base-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-base-discriminator/config.json", 31 | "google/electra-large-discriminator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-large-discriminator/config.json", 32 | } 33 | 34 | 35 | class ElectraConfig(PretrainedConfig): 36 | r""" 37 | This is the configuration class to store the configuration of a :class:`~transformers.ElectraModel`. 38 | It is used to instantiate an ELECTRA model according to the specified arguments, defining the model 39 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 40 | the ELECTRA `google/electra-small-discriminator `__ 41 | architecture. 42 | 43 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used 44 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` 45 | for more information. 46 | 47 | 48 | Args: 49 | vocab_size (:obj:`int`, optional, defaults to 30522): 50 | Vocabulary size of the ELECTRA model. Defines the different tokens that 51 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.ElectraModel`. 52 | embedding_size (:obj:`int`, optional, defaults to 128): 53 | Dimensionality of the encoder layers and the pooler layer. 54 | hidden_size (:obj:`int`, optional, defaults to 256): 55 | Dimensionality of the encoder layers and the pooler layer. 56 | num_hidden_layers (:obj:`int`, optional, defaults to 12): 57 | Number of hidden layers in the Transformer encoder. 58 | num_attention_heads (:obj:`int`, optional, defaults to 4): 59 | Number of attention heads for each attention layer in the Transformer encoder. 60 | intermediate_size (:obj:`int`, optional, defaults to 1024): 61 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 62 | hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): 63 | The non-linear activation function (function or string) in the encoder and pooler. 64 | If string, "gelu", "relu", "swish" and "gelu_new" are supported. 65 | hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): 66 | The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. 67 | attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): 68 | The dropout ratio for the attention probabilities. 69 | max_position_embeddings (:obj:`int`, optional, defaults to 512): 70 | The maximum sequence length that this model might ever be used with. 71 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 
72 | type_vocab_size (:obj:`int`, optional, defaults to 2): 73 | The vocabulary size of the `token_type_ids` passed into :class:`~transformers.ElectraModel`. 74 | initializer_range (:obj:`float`, optional, defaults to 0.02): 75 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 76 | layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): 77 | The epsilon used by the layer normalization layers. 78 | summary_type (:obj:`string`, optional, defaults to "first"): 79 | Argument used when doing sequence summary. Used in for the multiple choice head in 80 | :class:`~transformers.ElectraForMultipleChoice`. 81 | Is one of the following options: 82 | 83 | - 'last' => take the last token hidden state (like XLNet) 84 | - 'first' => take the first token hidden state (like Bert) 85 | - 'mean' => take the mean of all tokens hidden states 86 | - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2) 87 | - 'attn' => Not implemented now, use multi-head attention 88 | summary_use_proj (:obj:`boolean`, optional, defaults to :obj:`True`): 89 | Argument used when doing sequence summary. Used in for the multiple choice head in 90 | :class:`~transformers.ElectraForMultipleChoice`. 91 | Add a projection after the vector extraction 92 | summary_activation (:obj:`string` or :obj:`None`, optional, defaults to :obj:`None`): 93 | Argument used when doing sequence summary. Used in for the multiple choice head in 94 | :class:`~transformers.ElectraForMultipleChoice`. 95 | 'gelu' => add a gelu activation to the output, Other => no activation. 96 | summary_last_dropout (:obj:`float`, optional, defaults to 0.0): 97 | Argument used when doing sequence summary. Used in for the multiple choice head in 98 | :class:`~transformers.ElectraForMultipleChoice`. 
99 | Add a dropout after the projection and activation 100 | 101 | Example:: 102 | 103 | >>> from transformers import ElectraModel, ElectraConfig 104 | 105 | >>> # Initializing a ELECTRA electra-base-uncased style configuration 106 | >>> configuration = ElectraConfig() 107 | 108 | >>> # Initializing a model from the electra-base-uncased style configuration 109 | >>> model = ElectraModel(configuration) 110 | 111 | >>> # Accessing the model configuration 112 | >>> configuration = model.config 113 | """ 114 | model_type = "electra" 115 | 116 | def __init__( 117 | self, 118 | vocab_size=30522, 119 | embedding_size=128, 120 | hidden_size=256, 121 | num_hidden_layers=12, 122 | num_attention_heads=4, 123 | intermediate_size=1024, 124 | hidden_act="gelu", 125 | hidden_dropout_prob=0.1, 126 | attention_probs_dropout_prob=0.1, 127 | max_position_embeddings=512, 128 | type_vocab_size=2, 129 | initializer_range=0.02, 130 | layer_norm_eps=1e-12, 131 | summary_type="first", 132 | summary_use_proj=True, 133 | summary_activation="gelu", 134 | summary_last_dropout=0.1, 135 | pad_token_id=0, 136 | **kwargs 137 | ): 138 | super().__init__(pad_token_id=pad_token_id, **kwargs) 139 | 140 | self.vocab_size = vocab_size 141 | self.embedding_size = embedding_size 142 | self.hidden_size = hidden_size 143 | self.num_hidden_layers = num_hidden_layers 144 | self.num_attention_heads = num_attention_heads 145 | self.intermediate_size = intermediate_size 146 | self.hidden_act = hidden_act 147 | self.hidden_dropout_prob = hidden_dropout_prob 148 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 149 | self.max_position_embeddings = max_position_embeddings 150 | self.type_vocab_size = type_vocab_size 151 | self.initializer_range = initializer_range 152 | self.layer_norm_eps = layer_norm_eps 153 | 154 | self.summary_type = summary_type 155 | self.summary_use_proj = summary_use_proj 156 | self.summary_activation = summary_activation 157 | self.summary_last_dropout = summary_last_dropout 158 | -------------------------------------------------------------------------------- /lm/modeling/transformers/configuration_transformers.py: -------------------------------------------------------------------------------- 1 | """ transformers v3.0.0 2 | https://github.com/huggingface/transformers/tree/b62ca59527de4e883fb8e91f02e97586115616b1 3 | """ 4 | 5 | """ BERT model configuration """ 6 | 7 | 8 | import logging 9 | 10 | from .configuration_utils import PretrainedConfig 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 15 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 16 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 17 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 18 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 19 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 20 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 21 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 22 | "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 23 | 
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 24 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 25 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 26 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 27 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 28 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", 29 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", 30 | "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json", 31 | "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json", 32 | "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json", 33 | "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json", 34 | "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json", 35 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json", 36 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json", 37 | # See all BERT models at https://huggingface.co/models?filter=bert 38 | } 39 | 40 | 41 | class TransformersConfig(PretrainedConfig): 42 | r""" 43 | This is the configuration class to store the configuration of a :class:`~transformers.BertModel`. 44 | It is used to instantiate an BERT model according to the specified arguments, defining the model 45 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 46 | the BERT `bert-base-uncased `__ architecture. 47 | 48 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used 49 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` 50 | for more information. 51 | 52 | 53 | Args: 54 | vocab_size (:obj:`int`, optional, defaults to 30522): 55 | Vocabulary size of the BERT model. Defines the different tokens that 56 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`. 57 | hidden_size (:obj:`int`, optional, defaults to 768): 58 | Dimensionality of the encoder layers and the pooler layer. 59 | num_hidden_layers (:obj:`int`, optional, defaults to 12): 60 | Number of hidden layers in the Transformer encoder. 61 | num_attention_heads (:obj:`int`, optional, defaults to 12): 62 | Number of attention heads for each attention layer in the Transformer encoder. 
63 | intermediate_size (:obj:`int`, optional, defaults to 3072): 64 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 65 | hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): 66 | The non-linear activation function (function or string) in the encoder and pooler. 67 | If string, "gelu", "relu", "swish" and "gelu_new" are supported. 68 | hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): 69 | The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. 70 | attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): 71 | The dropout ratio for the attention probabilities. 72 | max_position_embeddings (:obj:`int`, optional, defaults to 512): 73 | The maximum sequence length that this model might ever be used with. 74 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 75 | type_vocab_size (:obj:`int`, optional, defaults to 2): 76 | The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`. 77 | initializer_range (:obj:`float`, optional, defaults to 0.02): 78 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 79 | layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): 80 | The epsilon used by the layer normalization layers. 81 | gradient_checkpointing (:obj:`bool`, optional, defaults to False): 82 | If True, use gradient checkpointing to save memory at the expense of slower backward pass. 83 | 84 | Example:: 85 | 86 | >>> from transformers import BertModel, BertConfig 87 | 88 | >>> # Initializing a BERT bert-base-uncased style configuration 89 | >>> configuration = BertConfig() 90 | 91 | >>> # Initializing a model from the bert-base-uncased style configuration 92 | >>> model = BertModel(configuration) 93 | 94 | >>> # Accessing the model configuration 95 | >>> configuration = model.config 96 | """ 97 | # model_type = "bert" 98 | 99 | def __init__( 100 | self, 101 | vocab_size=30522, 102 | hidden_size=768, 103 | num_hidden_layers=12, 104 | num_attention_heads=12, 105 | intermediate_size=3072, 106 | hidden_act="gelu", 107 | hidden_dropout_prob=0.1, 108 | attention_probs_dropout_prob=0.1, 109 | max_position_embeddings=512, 110 | type_vocab_size=2, 111 | initializer_range=0.02, 112 | layer_norm_eps=1e-12, 113 | pad_token_id=0, 114 | gradient_checkpointing=False, 115 | **kwargs 116 | ): 117 | super().__init__(pad_token_id=pad_token_id, **kwargs) 118 | 119 | self.vocab_size = vocab_size 120 | self.hidden_size = hidden_size 121 | self.num_hidden_layers = num_hidden_layers 122 | self.num_attention_heads = num_attention_heads 123 | self.hidden_act = hidden_act 124 | self.intermediate_size = intermediate_size 125 | self.hidden_dropout_prob = hidden_dropout_prob 126 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 127 | self.max_position_embeddings = max_position_embeddings 128 | self.type_vocab_size = type_vocab_size 129 | self.initializer_range = initializer_range 130 | self.layer_norm_eps = layer_norm_eps 131 | self.gradient_checkpointing = gradient_checkpointing 132 | -------------------------------------------------------------------------------- /lm/test_ppl.py: -------------------------------------------------------------------------------- 1 | """ test LM on perplexity (PPL) 2 | """ 3 | import argparse 4 | import logging 5 | import math 6 | import os 7 | import sys 8 | 9 | import torch 10 | from torch.utils.data import DataLoader 11 | 12 | 
EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../") 13 | sys.path.append(EMOASR_ROOT) 14 | 15 | from utils.converters import tensor2np 16 | from utils.configure import load_config 17 | from utils.log import insert_comment 18 | from utils.paths import get_eval_path, get_model_path, rel_to_abs_path 19 | from utils.vocab import Vocab 20 | 21 | from lm.datasets import LMDataset, P2WDataset 22 | from lm.modeling.lm import LM 23 | from lm.modeling.p2w import P2W 24 | 25 | # Reproducibility 26 | torch.manual_seed(0) 27 | torch.cuda.manual_seed_all(0) 28 | 29 | LOG_STEP = 100 30 | 31 | 32 | def ppl_lm(dataloader, model, device, vocab, add_sos_eos=False): 33 | cnt = 0 34 | sum_logprob = 0 35 | 36 | for i, data in enumerate(dataloader): 37 | if (i + 1) % LOG_STEP == 0: 38 | logging.info( 39 | f"{(i+1):>4} / {len(dataloader):>4} PPL: {math.exp(sum_logprob/cnt):.3f}" 40 | ) 41 | utt_id = data["utt_ids"][0] 42 | ys = data["ys_in"] 43 | ys_in = ys[:, :-1].to(device) 44 | ys_out = ys[:, 1:].to(device) 45 | ylens = data["ylens"].to(device) - 1 46 | assert ys.size(0) == 1 47 | 48 | # for P2WDataset 49 | ps = data["ps"].to(device) if "ps" in data else None 50 | plens = data["plens"].to(device) if "plens" in data else None 51 | 52 | if ys.size(1) <= 1: 53 | logging.warning(f"skip {utt_id}") 54 | continue 55 | if add_sos_eos and ys.size(1) <= 3: 56 | logging.warning(f"skip {utt_id}") 57 | continue 58 | 59 | with torch.no_grad(): 60 | logits = model(ys_in, ylens, labels=None, ps=ps, plens=plens) 61 | logprobs = torch.log_softmax(logits, dim=-1) 62 | 63 | # NOTE: skip the first token and prediction 64 | if add_sos_eos: 65 | logprobs = logprobs[:, 1:-1] 66 | ys_out = ys_out[:, 1:-1] 67 | 68 | for logprob, label in zip(logprobs[0], ys_out[0]): 69 | sum_logprob -= logprob[label].item() 70 | cnt += 1 71 | 72 | ppl = math.exp(sum_logprob / cnt) 73 | 74 | return cnt, ppl 75 | 76 | 77 | def ppl_masked_lm(dataloader, model, device, mask_id, max_seq_len, vocab): 78 | cnt = 0 79 | sum_logprob = 0 80 | 81 | for i, data in enumerate(dataloader): 82 | if (i + 1) % LOG_STEP == 0: 83 | logging.info( 84 | f"{(i+1):>4} / {len(dataloader):>4} PPL: {math.exp(sum_logprob/cnt):.3f}" 85 | ) 86 | utt_id = data["utt_ids"][0] 87 | ys = data["ys_in"].to(device) # not masked 88 | ylens = data["ylens"].to(device) 89 | assert ys.size(0) == 1 90 | 91 | # for P2WDataset 92 | ps = data["ps"].to(device) if "ps" in data else None 93 | plens = data["plens"].to(device) if "plens" in data else None 94 | 95 | if ys.size(1) > max_seq_len: 96 | logging.warning(f"input length longer than {max_seq_len:d}, skip {utt_id}") 97 | continue 98 | 99 | if args.print_probs: 100 | print("********************") 101 | print(f"{utt_id}: {vocab.ids2text(tensor2np(ys[0]))}") 102 | 103 | for mask_pos in range(ys.size(1)): 104 | ys_masked = ys.clone() 105 | label = ys[0, mask_pos] 106 | ys_masked[0, mask_pos] = mask_id 107 | 108 | with torch.no_grad(): 109 | logits = model(ys_masked, ylens, labels=None, ps=ps, plens=plens) 110 | 111 | if args.print_probs: 112 | print(vocab.ids2text(tensor2np(ys_masked[0]))) 113 | # TODO: print phones 114 | 115 | probs = torch.softmax(logits, dim=-1) 116 | p_topk, v_topk = torch.topk(probs[0, mask_pos], k=5) 117 | print( 118 | f"{vocab.i2t[label.item()]} || " 119 | + " | ".join( 120 | [ 121 | f"{vocab.i2t[v.item()]}: {p.item():.2f}" 122 | for p, v in zip(p_topk, v_topk) 123 | ] 124 | ) 125 | ) 126 | 127 | logprobs = torch.log_softmax(logits, dim=-1) 128 | sum_logprob -= logprobs[0, mask_pos, label].item() 129 | cnt += 1
130 | 131 | ppl = math.exp(sum_logprob / cnt) 132 | 133 | return cnt, ppl 134 | 135 | 136 | def test(model, dataloader, params, device, vocab, add_sos_eos=False): 137 | if params.lm_type in ["bert", "pbert"]: 138 | cnt, ppl = ppl_masked_lm( 139 | dataloader, 140 | model, 141 | device, 142 | mask_id=params.mask_id, 143 | max_seq_len=params.max_seq_len, 144 | vocab=vocab, 145 | ) 146 | elif params.lm_type in ["transformer", "rnn", "ptransformer"]: 147 | cnt, ppl = ppl_lm(dataloader, model, device, vocab=vocab, add_sos_eos=add_sos_eos) 148 | 149 | logging.info(f"{cnt} tokens") 150 | return ppl 151 | 152 | 153 | def main(args): 154 | if args.cpu: 155 | device = torch.device("cpu") 156 | torch.set_num_threads(1) 157 | # make sure all operations are done on cpu 158 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 159 | else: 160 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 161 | 162 | logging.basicConfig( 163 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 164 | level=logging.INFO, 165 | ) 166 | 167 | params = load_config(args.conf) 168 | 169 | data_path = get_eval_path(args.data) 170 | if data_path is None: 171 | data_path = params.test_path 172 | logging.info(f"test data: {data_path}") 173 | 174 | with open(rel_to_abs_path(data_path)) as f: 175 | lines = f.readlines() 176 | logging.info(lines[0]) 177 | 178 | if params.lm_type in ["pelectra", "ptransformer", "pbert"]: 179 | dataset = P2WDataset(params, rel_to_abs_path(data_path), phase="test") 180 | else: 181 | dataset = LMDataset(params, rel_to_abs_path(data_path), phase="test") 182 | 183 | dataloader = DataLoader( 184 | dataset=dataset, 185 | batch_size=1, 186 | shuffle=False, 187 | collate_fn=dataset.collate_fn, 188 | num_workers=1, 189 | ) 190 | 191 | vocab = Vocab(rel_to_abs_path(params.vocab_path)) 192 | 193 | model_path = get_model_path(args.conf, args.ep) 194 | logging.info(f"model: {model_path}") 195 | 196 | if params.lm_type in ["ptransformer", "pbert"]: 197 | model = P2W(params) 198 | else: 199 | model = LM(params) 200 | 201 | model.load_state_dict(torch.load(model_path, map_location=device)) 202 | model.to(device) 203 | model.eval() 204 | 205 | ppl = test(model, dataloader, params, device, vocab, add_sos_eos=params.add_sos_eos) 206 | 207 | ppl_info = f"PPL: {ppl:.2f} (conf: {args.conf})" 208 | logging.info(ppl_info) 209 | 210 | if args.comment: 211 | insert_comment(args.data, ppl_info) 212 | 213 | 214 | if __name__ == "__main__": 215 | parser = argparse.ArgumentParser() 216 | parser.add_argument("-conf", type=str, required=True) 217 | parser.add_argument("-ep", type=int, default=0) 218 | parser.add_argument("--cpu", action="store_true") 219 | parser.add_argument("--data", type=str, default=None) 220 | parser.add_argument("--print_probs", action="store_true") 221 | parser.add_argument("--comment", action="store_true") 222 | args = parser.parse_args() 223 | main(args) 224 | -------------------------------------------------------------------------------- /lm/text_augment.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from collections import namedtuple 4 | 5 | import numpy as np 6 | import torch 7 | 8 | random.seed(0) 9 | np.random.seed(0) 10 | 11 | 12 | class TextAugment: 13 | """ TextAugment 14 | 15 | Reference: 16 | - https://arxiv.org/abs/2011.08469 17 | """ 18 | 19 | def __init__(self, params): 20 | self.max_mask_prob = params.textaug_max_mask_prob 21 | self.max_replace_prob = params.textaug_max_replace_prob 
22 | self.phone_vocab_size = params.src_vocab_size 23 | self.eos_id = params.phone_eos_id 24 | self.mask_id = params.phone_mask_id 25 | 26 | logging.info(f"apply TextAugment - {vars(self)}") 27 | 28 | def __call__(self, x): 29 | return self._text_replace(self._text_mask(x)) 30 | 31 | def _text_mask(self, x): 32 | x_masked = x.clone() 33 | if self.max_mask_prob <= 0: 34 | return x_masked 35 | 36 | num_to_mask = random.randint(0, int(len(x) * self.max_mask_prob)) 37 | cand_indices = [j for j in range(len(x)) if x[j] != self.eos_id] 38 | mask_indices = random.sample(cand_indices, min(len(cand_indices), num_to_mask)) 39 | x_masked[mask_indices] = self.mask_id 40 | return x_masked 41 | 42 | def _text_replace(self, x): 43 | x_replaced = x.clone() 44 | if self.max_replace_prob <= 0: 45 | return x_replaced 46 | 47 | num_to_replace = random.randint(0, int(len(x) * self.max_replace_prob)) 48 | cand_indices = [j for j in range(len(x)) if x[j] != self.eos_id] 49 | replace_indices = random.sample( 50 | cand_indices, min(len(cand_indices), num_to_replace) 51 | ) 52 | cand_vocab = [j for j in range(self.phone_vocab_size) if j != self.eos_id] 53 | replaced_ids = random.choices(cand_vocab, k=len(replace_indices)) 54 | x_replaced[replace_indices] = torch.tensor(replaced_ids, dtype=torch.long) 55 | return x_replaced 56 | 57 | 58 | if __name__ == "__main__": 59 | p = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 2]) 60 | 61 | params = namedtuple( 62 | "Params", 63 | [ 64 | "textaug_max_mask_prob", 65 | "textaug_max_replace_prob", 66 | "src_vocab_size", 67 | "phone_eos_id", 68 | "phone_mask_id", 69 | ], 70 | ) 71 | params.textaug_max_mask_prob = 0.2 72 | params.textaug_max_replace_prob = 0.2 73 | params.src_vocab_size = 11 74 | params.phone_eos_id = 2 75 | params.phone_mask_id = 10 76 | # 77 | textaug = TextAugment(params) 78 | p = textaug(p) 79 | print(p) 80 | -------------------------------------------------------------------------------- /utils/average_checkpoints.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import re 6 | import sys 7 | 8 | import torch 9 | 10 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../") 11 | sys.path.append(EMOASR_ROOT) 12 | 13 | from utils.paths import get_model_path 14 | 15 | 16 | def model_average(conf: str, ep: str): 17 | paths = [] 18 | 19 | if "-" in ep: 20 | startep = int(ep.split("-")[0]) 21 | endep = int(ep.split("-")[1]) 22 | epochs = [epoch for epoch in range(startep, endep + 1)] 23 | elif "+" in ep: 24 | epochs = list(map(int, ep.split("+"))) 25 | else: 26 | return 27 | logging.info(f"average checkpoints... 
(epoch: {epochs})") 28 | 29 | for epoch in epochs: 30 | paths.append(get_model_path(conf, str(epoch))) 31 | 32 | save_path = re.sub("model.ep[0-9]+", f"model.ep{ep}", paths[0]) 33 | if os.path.exists(save_path): 34 | logging.info(f"checkpoint: {save_path} already exists!") 35 | return 36 | 37 | avg = None 38 | # sum 39 | for path in paths: 40 | states = torch.load(path, map_location=torch.device("cpu")) 41 | if avg is None: 42 | avg = states 43 | else: 44 | for k in avg.keys(): 45 | avg[k] += states[k] 46 | # average 47 | for k in avg.keys(): 48 | if avg[k] is not None: 49 | avg[k] = torch.div(avg[k], len(paths)) 50 | 51 | torch.save(avg, save_path) 52 | logging.info(f"checkpoints saved to: {save_path}") 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument("-conf", type=str, required=True) 58 | parser.add_argument("-ep", type=str, required=True) 59 | args = parser.parse_args() 60 | 61 | logging.basicConfig( 62 | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", 63 | level=logging.DEBUG, 64 | ) 65 | 66 | model_average(args.conf, args.ep) 67 | -------------------------------------------------------------------------------- /utils/configure.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | from collections import namedtuple 3 | 4 | import yaml 5 | 6 | 7 | def load_config(config_path: str) -> dict: 8 | # TODO: detect duplicate keys 9 | with codecs.open(config_path, "r", encoding="utf-8") as f: 10 | params = yaml.load(f, Loader=yaml.FullLoader) 11 | 12 | # convert dict to namedtuple 13 | params = namedtuple("Params", params.keys())(**params) 14 | return params 15 | -------------------------------------------------------------------------------- /utils/converters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | 4 | 5 | def str2ints(s): 6 | return list(map(int, s.split())) 7 | 8 | 9 | def str2floats(s): 10 | return list(map(float, s.split())) 11 | 12 | 13 | def ints2str(ints): 14 | return " ".join(list(map(str, ints))) 15 | 16 | 17 | def get_utt_id_nosp(utt_id): 18 | if ( 19 | utt_id.startswith("sp0.9") 20 | or utt_id.startswith("sp1.0") 21 | or utt_id.startswith("sp1.1") 22 | ): 23 | utt_id_nosp = "-".join(utt_id.split("-")[1:]) 24 | else: 25 | utt_id_nosp = utt_id 26 | return utt_id_nosp 27 | 28 | 29 | def strip_eos(tokens, eos_id): 30 | return [token for token in tokens if token != eos_id] 31 | 32 | 33 | def add_sos_eos(ys, ylens, eos_id): 34 | ys_eos_list = [ 35 | torch.tensor([eos_id] + y[:ylen].tolist() + [eos_id], device=ys.device) 36 | for y, ylen in zip(ys, ylens) 37 | ] 38 | ys_eos = pad_sequence(ys_eos_list, batch_first=True, padding_value=eos_id) 39 | ylens_eos = ylens + 2 40 | return ys_eos, ylens_eos 41 | 42 | 43 | def tensor2np(x): 44 | return x.cpu().detach().numpy() 45 | 46 | 47 | def np2tensor(array, device=None): 48 | return torch.from_numpy(array).to(device) 49 | -------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | 7 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../") 8 | sys.path.append(EMOASR_ROOT) 9 | 10 | from utils.converters import np2tensor 11 | 12 | 13 | def insert_comment(file_path, comment): 14 | with 
open(file_path) as f: 15 | lines = f.readlines() 16 | 17 | if lines[0] == f"# {comment}\n": 18 | return 19 | 20 | lines.insert(0, f"# {comment}\n") 21 | lines.insert(1, "#\n") 22 | with open(file_path, mode="w") as f: 23 | f.writelines(lines) 24 | 25 | 26 | def get_num_parameters(model): 27 | num_params = sum(p.numel() for p in model.parameters()) 28 | num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 29 | return num_params, num_params_trainable 30 | 31 | 32 | def print_topk_probs(probs: np.ndarray, vocab, k=5): 33 | topk_infos = [] 34 | for prob in probs: 35 | p_topk, v_topk = torch.topk(np2tensor(prob), k) 36 | print( 37 | ( 38 | " | ".join( 39 | [ 40 | f"{vocab.i2t[v.item()]}: {p.item():.3f}" 41 | for p, v in zip(p_topk, v_topk) 42 | ] 43 | ) 44 | ) 45 | ) 46 | -------------------------------------------------------------------------------- /utils/paths.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import logging 3 | import os 4 | import re 5 | from collections import namedtuple 6 | 7 | import yaml 8 | 9 | EMOASR_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../") 10 | 11 | 12 | def get_eval_path(ref_tag): 13 | # ted2 14 | if ref_tag == "test": # TODO: -> ted2-test 15 | return os.path.join(EMOASR_ROOT, "corpora/ted2/nsp10k/data/test.tsv") 16 | if ref_tag == "dev": 17 | return os.path.join(EMOASR_ROOT, "corpora/ted2/nsp10k/data/dev.tsv") 18 | 19 | # libri 20 | if ref_tag == "test-clean": # TODO: -> libri-test-clean 21 | return os.path.join(EMOASR_ROOT, "corpora/libri/nsp10k/data/test_clean.tsv") 22 | if ref_tag == "test-other": 23 | return os.path.join(EMOASR_ROOT, "corpora/libri/nsp10k/data/test_other.tsv") 24 | if ref_tag == "dev-clean": 25 | return os.path.join(EMOASR_ROOT, "corpora/libri/nsp10k/data/dev_clean.tsv") 26 | if ref_tag == "dev-other": 27 | return os.path.join(EMOASR_ROOT, "corpora/libri/nsp10k/data/dev_other.tsv") 28 | 29 | # csj 30 | if ref_tag == "eval1": # TODO: -> csj-eval1 31 | return os.path.join(EMOASR_ROOT, "corpora/csj/nsp10k/data/eval1.tsv") 32 | elif ref_tag == "eval2": 33 | return os.path.join(EMOASR_ROOT, "corpora/csj/nsp10k/data/eval2.tsv") 34 | elif ref_tag == "eval3": 35 | return os.path.join(EMOASR_ROOT, "corpora/csj/nsp10k/data/eval3.tsv") 36 | elif ref_tag == "csj-dev": 37 | return os.path.join(EMOASR_ROOT, "corpora/csj/nsp10k/data/dev.tsv") 38 | elif ref_tag == "csj-dev500": 39 | return os.path.join(EMOASR_ROOT, "corpora/csj/nsp10k/data/dev_500.tsv") 40 | 41 | return ref_tag 42 | 43 | 44 | def get_run_dir(conf_path): 45 | run_dir = os.path.splitext(conf_path)[0] 46 | return run_dir 47 | 48 | 49 | def get_exp_dir(conf_path): 50 | exp_dir = os.path.splitext(conf_path)[0] 51 | return exp_dir 52 | 53 | 54 | def get_model_path(conf_path, epoch): 55 | run_dir = get_run_dir(conf_path) 56 | model_dir = os.path.join(run_dir, "checkpoints") 57 | model_path = os.path.join(model_dir, f"model.ep{epoch}") 58 | return model_path 59 | 60 | 61 | def get_results_dir(conf_path): 62 | run_dir = get_run_dir(conf_path) 63 | results_dir = os.path.join(run_dir, "results") 64 | os.makedirs(results_dir, exist_ok=True) 65 | return results_dir 66 | 67 | 68 | def get_log_save_paths(conf_path): 69 | run_dir = get_run_dir(conf_path) 70 | os.makedirs(run_dir, exist_ok=True) 71 | log_dir = os.path.join(run_dir, "log") 72 | save_dir = os.path.join(run_dir, "checkpoints") 73 | os.makedirs(log_dir, exist_ok=True) 74 | os.makedirs(save_dir, exist_ok=True) 75 | save_format = 
os.path.join(save_dir, "model.ep{}") 76 | optim_save_format = os.path.join(save_dir, "optim.ep{}") 77 | 78 | return log_dir, save_format, optim_save_format 79 | 80 | 81 | def get_resume_paths(conf_path, epoch=0): 82 | run_dir = get_run_dir(conf_path) 83 | save_dir = os.path.join(run_dir, "checkpoints") 84 | 85 | model_ep_max = 0 86 | optim_ep_max = 0 87 | 88 | if epoch > 0: 89 | model_path = os.path.join(save_dir, f"model.ep{epoch:d}") 90 | optim_path = os.path.join(save_dir, f"optim.ep{epoch:d}") 91 | # if epoch is not given, find latest model and optim 92 | else: 93 | for ckpt_file in os.listdir(save_dir): 94 | match = re.fullmatch(r"model.ep([0-9]+)", ckpt_file) 95 | if match is not None: 96 | model_ep = int(match.group(1)) 97 | model_ep_max = max(model_ep, model_ep_max) 98 | 99 | match = re.fullmatch(r"optim.ep([0-9]+)", ckpt_file) 100 | if match is not None: 101 | optim_ep = int(match.group(1)) 102 | optim_ep_max = max(optim_ep, optim_ep_max) 103 | 104 | assert model_ep_max == optim_ep_max 105 | epoch = model_ep_max 106 | 107 | if epoch > 0: 108 | model_path = os.path.join(save_dir, f"model.ep{epoch:d}") 109 | optim_path = os.path.join(save_dir, f"optim.ep{epoch:d}") 110 | else: 111 | model_path, optim_path = "", "" 112 | 113 | return model_path, optim_path, epoch 114 | 115 | 116 | def get_model_optim_paths( 117 | conf_path, resume=False, model_path=None, optim_path=None, start_epoch=0 118 | ): 119 | resume_model_path, resume_optim_path, resume_epoch = "", "", 0 120 | if resume: 121 | resume_model_path, resume_optim_path, resume_epoch = get_resume_paths(conf_path) 122 | if resume_epoch > 0: 123 | logging.info(f"resume from epoch = {resume_epoch:d}") 124 | 125 | model_path = resume_model_path or model_path 126 | optim_path = resume_optim_path or optim_path 127 | start_epoch = resume_epoch or start_epoch 128 | 129 | return model_path, optim_path, start_epoch 130 | 131 | 132 | def rel_to_abs_path(path): 133 | if os.path.exists(path): 134 | return path 135 | else: 136 | return os.path.join(EMOASR_ROOT, path) 137 | -------------------------------------------------------------------------------- /utils/vocab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | 4 | 5 | class Vocab: 6 | def __init__(self, vocab_path: str): 7 | with open(vocab_path) as f: 8 | lines = [line.strip() for line in f] 9 | 10 | i2t = {} 11 | t2i = {} 12 | for line in lines: 13 | token, idx = tuple(line.split()) 14 | i2t[int(idx)] = token 15 | t2i[token] = int(idx) 16 | self.i2t = i2t 17 | self.t2i = t2i 18 | 19 | self.unk_id = t2i["<unk>"] 20 | 21 | def id2token(self, idx): 22 | return self.i2t[idx] 23 | 24 | def ids2tokens(self, ids): 25 | return [self.id2token(i) for i in ids] 26 | 27 | def ids2words(self, ids): 28 | return self.subwords_to_words(self.ids2tokens(ids)) 29 | 30 | def ids2text(self, ids): 31 | return " ".join(self.subwords_to_words(self.ids2tokens(ids))) 32 | 33 | def token2id(self, word): 34 | if word in self.t2i: 35 | return self.t2i[word] 36 | return self.unk_id 37 | 38 | def tokens2ids(self, words): 39 | return [self.token2id(w) for w in words] 40 | 41 | def is_subword(self, idx): 42 | subword = self.id2token(idx) 43 | return subword[0] != "▁" and subword[0] != "<" 44 | 45 | def subwords_to_words(self, subwords): 46 | """ assume BPE style as in https://github.com/google/sentencepiece 47 | """ 48 | tmp = "" 49 | words = [] 50 | for subword in subwords: 51 | if ( 52 | subword[0] == "▁" or subword[0] == "<"
or (tmp and tmp[-1] == ">") 53 | ): # a new word starts here 54 | if tmp != "": 55 | words.append(tmp) 56 | tmp = "" 57 | 58 | tmp += subword[1:] if subword[0] == "▁" else subword 59 | else: 60 | tmp += subword 61 | 62 | if tmp != "": 63 | words.append(tmp) 64 | return words 65 | --------------------------------------------------------------------------------
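
The `Vocab` class above maps token IDs to SentencePiece-style subwords and merges them back into words. A minimal usage sketch follows; the vocab file contents, token IDs, and the temporary-file setup are made up for illustration and are not part of the repository, and it assumes the repo root (EMOASR_ROOT) is on `sys.path`:

import tempfile

from utils.vocab import Vocab

# Write a tiny made-up vocab file in the "token id" format read by Vocab.__init__.
with tempfile.NamedTemporaryFile("w", suffix=".vocab", delete=False) as f:
    f.write("<unk> 0\n▁he 1\nllo 2\n▁wor 3\nld 4\n")
    vocab_path = f.name

vocab = Vocab(vocab_path)
print(vocab.ids2tokens([1, 2, 3, 4]))  # ['▁he', 'llo', '▁wor', 'ld']
print(vocab.ids2text([1, 2, 3, 4]))    # "hello world" ("▁" marks a word start)
print(vocab.token2id("unseen_token"))  # 0, i.e. falls back to unk_id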