├── corpus ├── __init__.py ├── librispeech_char.txt ├── bopomo_vocab.txt ├── librispeech.py ├── preprocess_librispeech.py └── preprocess_dlhlp.py ├── src ├── __init__.py ├── option.py ├── lm.py ├── optim.py ├── collect_batch.py ├── bert_embedding.py ├── ctc.py ├── util.py ├── plugin.py ├── text.py ├── data.py ├── solver.py ├── decode.py ├── asr.py └── module.py ├── bin ├── __init__.py ├── train_lm.py ├── test_asr.py └── train_asr.py ├── tests ├── sample_data │ ├── word.vocab │ ├── demo.png │ ├── subword.model │ ├── subword-16k.model │ ├── subword-460.model │ ├── 3830-12529-0005.wav │ └── character.vocab ├── test_audio.py └── test_text.py ├── requirements.txt ├── script ├── train_lm.sh ├── train.sh └── test.sh ├── config ├── libri │ ├── decode_example.yaml │ ├── lm_example.yaml │ ├── asr_example.yaml │ └── asr_hybrid.yaml ├── dlhlp_test.yaml ├── librispeech_test.yaml ├── dlhlp_lm.yaml ├── librispeech_lm.yaml ├── dlhlp_asr.yaml ├── librispeech_asr.yaml └── README.md ├── util ├── get_gd.sh └── generate_vocab_file.py ├── LICENSE ├── .gitignore ├── eval.py ├── eval_beam.py ├── README.md └── main.py /corpus/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bin/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tests/sample_data/word.vocab: -------------------------------------------------------------------------------- 1 | SPEECH 2 | LAB 3 | IS 4 | GREAT 5 | -------------------------------------------------------------------------------- /tests/sample_data/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vectominist/End-to-end-ASR-Pytorch-DLHLP/HEAD/tests/sample_data/demo.png -------------------------------------------------------------------------------- /tests/sample_data/subword.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vectominist/End-to-end-ASR-Pytorch-DLHLP/HEAD/tests/sample_data/subword.model -------------------------------------------------------------------------------- /tests/sample_data/subword-16k.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vectominist/End-to-end-ASR-Pytorch-DLHLP/HEAD/tests/sample_data/subword-16k.model -------------------------------------------------------------------------------- /tests/sample_data/subword-460.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vectominist/End-to-end-ASR-Pytorch-DLHLP/HEAD/tests/sample_data/subword-460.model -------------------------------------------------------------------------------- /tests/sample_data/3830-12529-0005.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vectominist/End-to-end-ASR-Pytorch-DLHLP/HEAD/tests/sample_data/3830-12529-0005.wav -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | tqdm 2 | pandas 3 | future 4 | joblib 5 | pyyaml 6 | sentencepiece 7 | editdistance 8 | tb-nightly 9 | torch>=1.2.0 10 | torchaudio 11 | matplotlib 12 | librosa -------------------------------------------------------------------------------- /corpus/librispeech_char.txt: -------------------------------------------------------------------------------- 1 | 2 | ' 3 | A 4 | B 5 | C 6 | D 7 | E 8 | F 9 | G 10 | H 11 | I 12 | J 13 | K 14 | L 15 | M 16 | N 17 | O 18 | P 19 | Q 20 | R 21 | S 22 | T 23 | U 24 | V 25 | W 26 | X 27 | Y 28 | Z -------------------------------------------------------------------------------- /tests/sample_data/character.vocab: -------------------------------------------------------------------------------- 1 | 2 | A 3 | B 4 | C 5 | D 6 | E 7 | F 8 | G 9 | H 10 | I 11 | J 12 | K 13 | L 14 | M 15 | N 16 | O 17 | P 18 | Q 19 | R 20 | S 21 | T 22 | U 23 | V 24 | W 25 | X 26 | Y 27 | Z 28 | ' -------------------------------------------------------------------------------- /corpus/bopomo_vocab.txt: -------------------------------------------------------------------------------- 1 | 2 | ˇ 3 | ˊ 4 | ˋ 5 | ˙ 6 | ㄅ 7 | ㄆ 8 | ㄇ 9 | ㄈ 10 | ㄉ 11 | ㄊ 12 | ㄋ 13 | ㄌ 14 | ㄍ 15 | ㄎ 16 | ㄏ 17 | ㄐ 18 | ㄑ 19 | ㄒ 20 | ㄓ 21 | ㄔ 22 | ㄕ 23 | ㄖ 24 | ㄗ 25 | ㄘ 26 | ㄙ 27 | ㄚ 28 | ㄛ 29 | ㄜ 30 | ㄝ 31 | ㄞ 32 | ㄟ 33 | ㄠ 34 | ㄡ 35 | ㄢ 36 | ㄣ 37 | ㄤ 38 | ㄥ 39 | ㄦ 40 | ㄧ 41 | ㄨ 42 | ㄩ -------------------------------------------------------------------------------- /script/train_lm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CONFIG="librispeech_lm" 4 | DIR="/data/storage/harry/E2E_ASR" 5 | 6 | echo "Start running training process of RNNLM" 7 | CUDA_VISIBLE_DEVICES=$2 python3 main.py --config config/${CONFIG}.yaml \ 8 | --name $1 \ 9 | --njobs 8 \ 10 | --seed 0 \ 11 | --lm \ 12 | --logdir ${DIR}/log/ \ 13 | --ckpdir ${DIR}/ckpt/ \ 14 | --outdir ${DIR}/result/ \ 15 | -------------------------------------------------------------------------------- /src/option.py: -------------------------------------------------------------------------------- 1 | # Default parameters which will be imported by solver 2 | default_hparas = { 3 | 'GRAD_CLIP': 5.0, # Grad. clip threshold 4 | 'PROGRESS_STEP': 100, # Std. output refresh freq. 
5 | # Decode steps for objective validation (step = ratio*input_txt_len) 6 | 'DEV_STEP_RATIO': 1.2, 7 | # Number of examples (alignment/text) to show in tensorboard 8 | 'DEV_N_EXAMPLE': 4, 9 | 'TB_FLUSH_FREQ': 180 # Update frequency of tensorboard (secs) 10 | } 11 | -------------------------------------------------------------------------------- /script/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # $1 : experiment name 4 | # $2 : cuda id 5 | 6 | CONFIG="dlhlp_asr" 7 | 8 | DIR="/data/storage/harry/E2E_ASR" 9 | 10 | echo "Start running training process of E2E ASR" 11 | CUDA_VISIBLE_DEVICES=$2 python3 main.py --config config/${CONFIG}.yaml \ 12 | --name $1 \ 13 | --njobs 8 \ 14 | --seed 0 \ 15 | --logdir ${DIR}/log/ \ 16 | --ckpdir ${DIR}/ckpt/ \ 17 | --outdir ${DIR}/result/ \ 18 | # --load ${DIR}/ckpt/$1/best_ctc_LibriSpeech.pth \ 19 | -------------------------------------------------------------------------------- /config/libri/decode_example.yaml: -------------------------------------------------------------------------------- 1 | # Most of the parameters will be imported from the training config 2 | src: 3 | ckpt: 'ckpt/asr_example_sd0/best_att.pth' 4 | config: 'config/libri/asr_example.yaml' 5 | data: 6 | corpus: 7 | name: 'Librispeech' 8 | dev_split: ['dev-clean'] 9 | test_split: ['test-clean'] 10 | decode: 11 | beam_size: 20 12 | min_len_ratio: 0.01 13 | max_len_ratio: 0.07 14 | lm_path: 'ckpt/lm_example_sd0/best_ppx.pth' 15 | lm_config: 'config/libri/lm_example.yaml' 16 | lm_weight: 0.5 17 | ctc_weight: 0.0 -------------------------------------------------------------------------------- /script/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # run testing process 3 | 4 | # $1 : Experiment name 5 | # $2 : Cuda id 6 | 7 | CONFIG="dlhlp_test" 8 | DIR="/data/storage/harry/E2E_ASR" 9 | 10 | echo "Start running testing process of E2E ASR" 11 | CUDA_VISIBLE_DEVICES=$2 python3 main.py --config config/${CONFIG}.yaml \ 12 | --name $1 \ 13 | --test \ 14 | --njobs 8 \ 15 | --seed 0 \ 16 | --ckpdir ${DIR}/ckpt/$1 \ 17 | --outdir ${DIR}/test_result/$1 18 | 19 | # Eval 20 | python3 eval.py --file ${DIR}/test_result/$1/$1_dev_output.csv 21 | python3 eval.py --file ${DIR}/test_result/$1/$1_test_output.csv 22 | -------------------------------------------------------------------------------- /util/get_gd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | download_from_gdrive() { 4 | file_id=$1 5 | file_name=$2 6 | 7 | # first stage to get the warning html 8 | curl -c /tmp/cookies \ 9 | "https://drive.google.com/uc?export=download&id=$file_id" > \ 10 | /tmp/intermezzo.html 11 | 12 | # second stage to extract the download link from html above 13 | download_link=$(cat /tmp/intermezzo.html | \ 14 | grep -Po 'uc-download-link" [^>]* href="\K[^"]*' | \ 15 | sed 's/\&/\&/g') 16 | curl -L -b /tmp/cookies \ 17 | "https://drive.google.com$download_link" > $file_name 18 | } 19 | 20 | download_from_gdrive $1 $2 21 | -------------------------------------------------------------------------------- /config/dlhlp_test.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | corpus: # Pass to dataloader 3 | # The following depends on corpus 4 | name: 'DLHLP' # Specify corpus 5 | path: '/data/storage/harry/test_env/DLHLP' 6 | dev_split: ['dev'] 7 | test_split: ['test'] 8 | bucketing: False 
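    # Decoding is done one utterance at a time, so bucketing is off and batch_size is 1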
9 | batch_size: 1 10 | src: 11 | config: '/data/storage/harry/End-to-end-ASR-Pytorch/config/dlhlp_asr.yaml' 12 | ckpt: '/data/storage/harry/test_env/dlhlp_ckpt/hw1_best_att_dev.pth' 13 | 14 | decode: 15 | beam_size: 1 16 | min_len_ratio: 0.01 17 | max_len_ratio: 0.30 18 | lm_path: '/data/storage/harry/test_env/dlhlp_ckpt/hw1_best_ppx.pth' 19 | lm_config: 'config/dlhlp_lm.yaml' 20 | lm_weight: 0.7 21 | ctc_weight: 0 22 | -------------------------------------------------------------------------------- /config/librispeech_test.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | corpus: # Pass to dataloader 3 | # The following depends on corpus 4 | path: '/data/storage/harry/LibriSpeech' 5 | name: 'LibriSpeech' 6 | dev_split: ['LibriSpeech', 'dev-clean', 'dev'] 7 | test_split: ['LibriSpeech', 'test-clean', 'test'] 8 | bucketing: False 9 | batch_size: 1 10 | src: 11 | config: '/data/storage/harry/End-to-end-ASR-Pytorch/config/librispeech_asr.yaml' 12 | ckpt: '/data/storage/harry/E2E_ASR/ckpt/LS100_JOINT_0/best_att_LibriSpeech.pth' 13 | 14 | decode: 15 | ctc_weight: 0 16 | beam_size: 8 17 | # vocab_candidate: 12 18 | min_len_ratio: 0.01 19 | max_len_ratio: 0.25 20 | lm_config: '/data/storage/harry/End-to-end-ASR-Pytorch/config/librispeech_lm.yaml' 21 | lm_path: '/data/storage/harry/E2E_ASR/ckpt/RNNLM_LS100_TEST_0/best_ppx.pth' 22 | lm_weight: 0.3 -------------------------------------------------------------------------------- /config/dlhlp_lm.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | corpus: # Pass to dataloader 3 | # The following depends on corpus 4 | name: 'DLHLP' # Specify corpus 5 | path: '/data/storage/harry/test_env/DLHLP' 6 | train_split: ['train'] 7 | dev_split: ['dev'] 8 | bucketing: True 9 | batch_size: 64 10 | text: 11 | mode: 'character' # 'character'/'word'/'subword' 12 | vocab_file: 'corpus/bopomo_vocab.txt' 13 | 14 | hparas: # Experiment hyper-parameters 15 | valid_step: 2000 16 | max_step: 1000000 17 | optimizer: 'Adam' 18 | lr: 0.0001 19 | eps: 0.00000001 20 | lr_scheduler: 'fixed' # 'fixed'/'warmup' 21 | 22 | model: # Model architecture 23 | emb_tying: False # https://arxiv.org/pdf/1608.05859.pdf 24 | emb_dim: 512 25 | module: 'GRU' # 'LSTM'/'GRU' 26 | dim: 512 27 | n_layers: 2 28 | dropout: 0.5 29 | -------------------------------------------------------------------------------- /config/libri/lm_example.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | corpus: # Pass to dataloader 3 | # The following depends on corpus 4 | name: 'Librispeech' # Specify corpus 5 | path: 'data/LibriSpeech' 6 | train_split: ['librispeech-lm-norm.txt'] # Official LM src from LibriSpeech 7 | dev_split: ['dev-clean'] 8 | bucketing: True 9 | batch_size: 32 10 | text: 11 | mode: 'subword' # 'character'/'word'/'subword' 12 | vocab_file: 'tests/sample_data/subword-16k.model' 13 | 14 | hparas: # Experiment hyper-parameters 15 | valid_step: 10000 16 | max_step: 100000000 17 | optimizer: 'Adam' 18 | lr: 0.0001 19 | eps: 0.00000001 20 | lr_scheduler: 'fixed' # 'fixed'/'warmup' 21 | 22 | model: # Model architecture 23 | emb_tying: False # https://arxiv.org/pdf/1608.05859.pdf 24 | emb_dim: 1024 25 | module: 'LSTM' # 'LSTM'/'GRU' 26 | dim: 1024 27 | n_layers: 2 28 | dropout: 0.5 29 | 30 | 31 | -------------------------------------------------------------------------------- /config/librispeech_lm.yaml: 
-------------------------------------------------------------------------------- 1 | data: 2 | corpus: # Pass to dataloader 3 | # The following depends on corpus 4 | name: 'Librispeech' # Specify corpus 5 | path: '/data/storage/harry/LibriSpeech' 6 | # train_split: ['librispeech-lm-norm.txt'] # Official LM src from LibriSpeech 7 | train_split: ['train-clean-100'] 8 | dev_split: ['dev-clean'] 9 | bucketing: True 10 | batch_size: 128 11 | text: 12 | mode: 'character' # 'character'/'word'/'subword' 13 | vocab_file: 'corpus/librispeech_char.txt' 14 | 15 | hparas: # Experiment hyper-parameters 16 | valid_step: 5000 17 | max_step: 1000000 18 | optimizer: 'Adam' 19 | lr: 0.0001 20 | eps: 0.00000001 21 | lr_scheduler: 'fixed' # 'fixed'/'warmup' 22 | 23 | model: # Model architecture 24 | emb_tying: False # https://arxiv.org/pdf/1608.05859.pdf 25 | emb_dim: 256 26 | module: 'LSTM' # 'LSTM'/'GRU' 27 | dim: 256 28 | n_layers: 2 29 | dropout: 0.5 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 XenderLiu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/lm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class RNNLM(nn.Module): 6 | ''' RNN Language Model ''' 7 | def __init__(self, vocab_size, emb_tying, emb_dim, module, dim, n_layers, dropout): 8 | super().__init__() 9 | self.dim = dim 10 | self.n_layers = n_layers 11 | self.emb_tying = emb_tying 12 | if emb_tying: 13 | assert emb_dim==dim, "Output dim of RNN should be identical to embedding if using weight tying." 
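        # Note: when emb_tying is enabled, forward() below reuses self.emb.weight as the output projection (see the F.linear call and https://arxiv.org/pdf/1608.05859.pdf), which is why the RNN output dim must equal the embedding dim.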
14 | self.vocab_size = vocab_size 15 | self.emb = nn.Embedding(vocab_size, emb_dim) 16 | self.dp1 = nn.Dropout(dropout) 17 | self.dp2 = nn.Dropout(dropout) 18 | self.rnn = getattr(nn, module.upper())(emb_dim, dim, num_layers=n_layers, dropout=dropout, batch_first=True) 19 | if not self.emb_tying: 20 | self.trans = nn.Linear(emb_dim,vocab_size) 21 | 22 | def create_msg(self): 23 | # Messages for user 24 | msg = ['Model spec.| RNNLM weight tying = {}, # of layers = {}, dim = {}'.format(self.emb_tying,self.n_layers,self.dim)] 25 | return msg 26 | 27 | def forward(self, x, lens, hidden=None): 28 | emb_x = self.dp1(self.emb(x)) 29 | if not self.training: 30 | self.rnn.flatten_parameters() 31 | packed = nn.utils.rnn.pack_padded_sequence(emb_x, lens, batch_first=True, enforce_sorted=False) 32 | outputs, hidden = self.rnn(packed, hidden) # output: (seq_len, batch, hidden) 33 | outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) 34 | if self.emb_tying: 35 | outputs = F.linear(self.dp2(outputs),self.emb.weight) 36 | else: 37 | outputs = self.trans(self.dp2(outputs)) 38 | return outputs, hidden 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | 117 | # data and log directory 118 | log 119 | runs 120 | save 121 | TODO 122 | checkpoint 123 | run.sh 124 | 125 | # swap file 126 | *.swp 127 | 128 | *.DS_store -------------------------------------------------------------------------------- /src/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class Optimizer(): 5 | def __init__(self, parameters, optimizer, lr, eps, lr_scheduler, 6 | tf_start=1, tf_end=1, tf_step=1, tf_step_start=0, 7 | weight_decay=0, amsgrad=False, **kwargs): 8 | 9 | # Setup teacher forcing scheduler 10 | self.tf_type = tf_end!=1 11 | self.tf_rate = lambda step: max(tf_end, 12 | tf_start-(tf_start-tf_end)*(step-tf_step_start)/tf_step if step >= tf_step_start else 1) 13 | 14 | # Setup torch optimizer 15 | self.opt_type = optimizer 16 | self.init_lr = lr 17 | self.sch_type = lr_scheduler 18 | opt = getattr(torch.optim,optimizer) 19 | if lr_scheduler == 'warmup': 20 | warmup_step = 4000.0 21 | init_lr = lr 22 | self.lr_scheduler = lambda step: init_lr * warmup_step **0.5 * np.minimum((step+1)*warmup_step**-1.5,(step+1)**-0.5 ) 23 | self.opt = opt(parameters,lr=1.0) 24 | else: 25 | self.lr_scheduler = None 26 | if optimizer.lower()[:4] == 'adam': 27 | self.opt = opt(parameters,lr=lr,eps=eps,weight_decay=weight_decay,amsgrad=amsgrad) # ToDo: 1e-8 better? 28 | else: 29 | self.opt = opt(parameters,lr=lr,eps=eps,weight_decay=weight_decay) # ToDo: 1e-8 better? 30 | 31 | def get_opt_state_dict(self): 32 | return self.opt.state_dict() 33 | 34 | def load_opt_state_dict(self,state_dict): 35 | self.opt.load_state_dict(state_dict) 36 | 37 | def pre_step(self, step): 38 | if self.lr_scheduler is not None: 39 | cur_lr = self.lr_scheduler(step) 40 | for param_group in self.opt.param_groups: 41 | param_group['lr'] = cur_lr 42 | self.opt.zero_grad() 43 | return self.tf_rate(step) 44 | 45 | def get_lr(self, step): 46 | if self.lr_scheduler is not None: 47 | return self.lr_scheduler(step) 48 | else: 49 | return self.init_lr 50 | 51 | def step(self): 52 | self.opt.step() 53 | 54 | def create_msg(self): 55 | return ['Optim.spec.| Algo. 
= {}\t| Lr = {}\t (schedule = {})| Scheduled sampling = {}'\ 56 | .format(self.opt_type, self.init_lr, self.sch_type, self.tf_type)] 57 | -------------------------------------------------------------------------------- /config/dlhlp_asr.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | corpus: # Pass to dataloader 3 | # The following depends on corpus 4 | name: 'DLHLP' # Specify corpus 5 | path: '/data/storage/harry/test_env/DLHLP' 6 | train_split: ['train'] 7 | dev_split: ['dev'] 8 | bucketing: True 9 | batch_size: 32 10 | 11 | audio: # Pass to audio transform 12 | feat_type: 'fbank' 13 | feat_dim: 80 14 | apply_cmvn: False 15 | delta_order: 1 # 0: do nothing, 1: add delta, 2: add delta and accelerate 16 | delta_window_size: 2 17 | frame_length: 25 # ms 18 | frame_shift: 10 # ms 19 | ref_level_db: 20 20 | min_level_db: -100 21 | preemphasis_coeff: 0.97 22 | 23 | text: 24 | mode: 'character' # 'character'/'word == phone'/'subword' 25 | vocab_file: 'corpus/bopomo_vocab.txt' 26 | 27 | hparas: # Experiment hyper-parameters 28 | valid_step: 500 29 | max_step: 80000 30 | tf_start: 1.0 31 | tf_end: 1.0 32 | tf_step: 150000 33 | optimizer: 'Adadelta' 34 | lr: 1.0 35 | eps: 0.00000001 # 1e-8 36 | lr_scheduler: 'fixed' # 'fixed'/'warmup' 37 | curriculum: 0 38 | val_mode: 'cer' 39 | 40 | model: # Model architecture 41 | ctc_weight: 0.5 # Weight for CTC loss 42 | encoder: 43 | vgg: 1 # 4x reduction on time feature extraction 44 | vgg_freq: 1 45 | vgg_low_filt: -1 46 | module: 'LSTM' # 'LSTM'/'GRU'/'Transformer' 47 | bidirection: True 48 | dim: [512,512] 49 | dropout: [0.2,0.2] 50 | layer_norm: [False,False] 51 | proj: [True,True] # Linear projection + Tanh after each rnn layer 52 | sample_rate: [1,1] 53 | sample_style: 'drop' # 'drop'/'concat' 54 | attention: 55 | mode: 'loc' # 'dot'/'loc' 56 | dim: 300 57 | num_head: 1 58 | v_proj: False # if False and num_head>1, encoder state will be duplicated for each head 59 | temperature: 0.5 # scaling factor for attention 60 | loc_kernel_size: 100 # just for mode=='loc' 61 | loc_kernel_num: 10 # just for mode=='loc' 62 | decoder: 63 | module: 'LSTM' # 'LSTM'/'GRU'/'Transformer' 64 | dim: 512 65 | layer: 1 66 | dropout: 0 67 | -------------------------------------------------------------------------------- /config/librispeech_asr.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | corpus: # Pass to dataloader 3 | # The following depends on corpus 4 | path: '/data/storage/harry' 5 | name: 'LibriSpeech' 6 | train_split: ['LibriSpeech', 'train-clean-100', 'train'] 7 | dev_split: [['LibriSpeech', 'dev-clean', 'dev']] 8 | bucketing: True 9 | batch_size: 16 10 | 11 | audio: # Pass to audio transform 12 | feat_type: 'fbank' 13 | feat_dim: 80 14 | apply_cmvn: False 15 | delta_order: 1 # 0: do nothing, 1: add delta, 2: add delta and accelerate 16 | delta_window_size: 2 17 | frame_length: 25 # ms 18 | frame_shift: 10 # ms 19 | ref_level_db: 20 20 | min_level_db: -100 21 | preemphasis_coeff: 0.97 22 | 23 | text: 24 | mode: 'character' # 'character'/'word == phone'/'subword' 25 | vocab_file: 'corpus/librispeech_char.txt' 26 | 27 | hparas: # Experiment hyper-parameters 28 | valid_step: 5000 29 | max_step: 200000 30 | tf_start: 1.0 31 | tf_end: 1.0 32 | tf_step: 150000 33 | optimizer: 'Adadelta' 34 | lr: 1.0 35 | eps: 0.00000001 # 1e-8 36 | lr_scheduler: 'fixed' # 'fixed'/'warmup' 37 | curriculum: 0 38 | val_mode: 'wer' 39 | 40 | model: # Model architecture 41 | ctc_weight: 
0.5 # Weight for CTC loss 42 | encoder: 43 | vgg: 0 # 4x reduction on time feature extraction 44 | vgg_freq: -1 45 | vgg_low_filt: -1 46 | module: 'LSTM' # 'LSTM'/'GRU'/'Transformer' 47 | bidirection: True 48 | dim: [320,320,320,320] 49 | dropout: [0.2,0.2,0.2,0.2] 50 | layer_norm: [False,False,False,False] 51 | proj: [True,True,True,True] # Linear projection + Tanh after each rnn layer 52 | sample_rate: [1,2,1,1] 53 | sample_style: 'drop' # 'drop'/'concat' 54 | attention: 55 | mode: 'loc' # 'dot'/'loc' 56 | dim: 300 57 | num_head: 1 58 | v_proj: False # if False and num_head>1, encoder state will be duplicated for each head 59 | temperature: 0.5 # scaling factor for attention 60 | loc_kernel_size: 100 # just for mode=='loc' 61 | loc_kernel_num: 10 # just for mode=='loc' 62 | decoder: 63 | module: 'LSTM' # 'LSTM'/'GRU'/'Transformer' 64 | dim: 300 65 | layer: 1 66 | dropout: 0 -------------------------------------------------------------------------------- /util/generate_vocab_file.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from collections import Counter 4 | 5 | 6 | def main(args): 7 | 8 | if args.mode == "subword": 9 | logging.warn("Subword model is based on `sentencepiece`.") 10 | 11 | import sentencepiece as splib 12 | 13 | cmd = ("--input={} --model_prefix={} --model_type=bpe " 14 | "--vocab_size={} --character_coverage={} " 15 | "--pad_id=0 --eos_id=1 --unk_id=2 --bos_id=-1 " 16 | "--eos_piece= --remove_extra_whitespaces=true".format( 17 | args.input_file, args.output_file, 18 | args.vocab_size, args.character_coverage)) 19 | 20 | splib.SentencePieceTrainer.Train(cmd) 21 | else: 22 | with open(args.input_file, "r") as f: 23 | lines = [line.strip("\r\n ") for line in f] 24 | counter = Counter() 25 | if args.mode == "word": 26 | for line in lines: 27 | counter.update(line.split()) 28 | # In word mode, vocab_list is sorted by frequency 29 | # Only selected top `vocab_size` vocabularies 30 | vocab_list = sorted( 31 | counter.keys(), key=lambda k: counter[k], reverse=True)[:args.vocab_size] 32 | elif args.mode == "character": 33 | for line in lines: 34 | counter.update(line) 35 | # In character mode, vocab_list is sorted in alphabetical order 36 | vocab_list = sorted(counter) 37 | 38 | logging.info("Collected totally {} vocabularies.".format(len(counter))) 39 | logging.info("Selected {} vocabularies.".format(len(vocab_list))) 40 | 41 | with open(args.output_file, "w") as f: 42 | f.write("\n".join(vocab_list)) 43 | 44 | 45 | if __name__ == "__main__": 46 | logging.getLogger().setLevel(logging.INFO) 47 | 48 | parser = argparse.ArgumentParser( 49 | "Utility script to generate `vocab_file` needed by text encoder.") 50 | parser.add_argument("--input_file", required=True) 51 | parser.add_argument( 52 | "--mode", choices=["character", "word", "subword"], required=True) 53 | parser.add_argument("--output_file", required=True) 54 | parser.add_argument("--vocab_size", type=int, default=5000) 55 | parser.add_argument("--character_coverage", type=float, default=1) 56 | 57 | args = parser.parse_args() 58 | 59 | if args.mode != "subword": 60 | logging.warn( 61 | "`character_coverage` is not used in `word` and `character` mode.") 62 | if args.mode == "character": 63 | logging.warn("`vocab_size` is not used in `character` mode.") 64 | 65 | main(args) 66 | -------------------------------------------------------------------------------- /eval.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pandas as pd 4 | import editdistance as ed 5 | 6 | SEP = ' ' 7 | 8 | # Arguments 9 | parser = argparse.ArgumentParser(description='Script for evaluating recognition results.') 10 | parser.add_argument('--file', type=str, help='Path to result csv.') 11 | paras = parser.parse_args() 12 | 13 | # Error rate functions 14 | def cal_cer(row): 15 | return 100*float(ed.eval(row.hyp,row.truth))/len(row.truth) 16 | def cal_wer(row): 17 | return 100*float(ed.eval(row.hyp.split(SEP),row.truth.split(SEP)))/len(row.truth.split(SEP)) 18 | 19 | # Evaluation 20 | result = pd.read_csv(paras.file,sep='\t',keep_default_na=False) 21 | result['hyp_char_cnt'] = result.apply(lambda x: len(x.hyp),axis=1) 22 | result['hyp_word_cnt'] = result.apply(lambda x: len(x.hyp.split(SEP)),axis=1) 23 | result['truth_char_cnt'] = result.apply(lambda x: len(x.truth),axis=1) 24 | result['truth_word_cnt'] = result.apply(lambda x: len(x.truth.split(SEP)),axis=1) 25 | result['cer'] = result.apply(cal_cer,axis=1) 26 | result['wer'] = result.apply(cal_wer,axis=1) 27 | 28 | # Show results 29 | print() 30 | print('============ Result of',paras.file,'============') 31 | print(' -----------------------------------------------------------------------') 32 | print('| Statics\t\t| Truth\t| Prediction\t| Abs. Diff.\t|') 33 | print(' -----------------------------------------------------------------------') 34 | print('| Avg. # of chars\t| {:.2f}\t| {:.2f}\t| {:.2f}\t\t|'.\ 35 | format(result.truth_char_cnt.mean(), result.hyp_char_cnt.mean(), 36 | np.mean(np.abs(result.truth_char_cnt-result.hyp_char_cnt)))) 37 | print('| Avg. # of words\t| {:.2f}\t| {:.2f}\t| {:.2f}\t\t|'.\ 38 | format(result.truth_word_cnt.mean(), result.hyp_word_cnt.mean(), 39 | np.mean(np.abs(result.truth_word_cnt-result.hyp_word_cnt)))) 40 | print(' -----------------------------------------------------------------------') 41 | print(' ---------------------------------------------------------------') 42 | print('| Error Rate (%)| Mean\t\t| Std.\t\t| Min./Max.\t|') 43 | print(' ---------------------------------------------------------------') 44 | print('| Character\t| {:2.4f}\t| {:.2f}\t\t| {:.2f}/{:.2f}\t|'.format(result.cer.mean(),result.cer.std(), 45 | result.cer.min(),result.cer.max())) 46 | print('| Word\t\t| {:2.4f}\t| {:.2f}\t\t| {:.2f}/{:.2f}\t|'.format(result.wer.mean(),result.wer.std(), 47 | result.wer.min(),result.wer.max())) 48 | print(' ---------------------------------------------------------------') 49 | print('Note : If the text unit is phoneme, WER = PER and CER is meaningless.') 50 | print() -------------------------------------------------------------------------------- /config/libri/asr_example.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | corpus: 3 | name: 'Librispeech' # Specify corpus 4 | path: 'data/LibriSpeech' # Path to raw LibriSpeech dataset 5 | train_split: ['train-clean-100','train-clean-360'] # Name of data splits to be used as training set 6 | dev_split: ['dev-clean'] # Name of data splits to be used as validation set 7 | bucketing: True # Enable/Disable bucketing 8 | batch_size: 16 9 | audio: # Attributes of audio feature 10 | feat_type: 'fbank' 11 | feat_dim: 40 12 | frame_length: 25 # ms 13 | frame_shift: 10 # ms 14 | dither: 0 # random dither audio, 0: no dither 15 | apply_cmvn: True 16 | delta_order: 2 # 0: do nothing, 1: add delta, 2: add delta and 
accelerate 17 | delta_window_size: 2 18 | text: 19 | mode: 'subword' # 'character'/'word'/'subword' 20 | vocab_file: 'tests/sample_data/subword-16k.model' 21 | 22 | hparas: # Experiment hyper-parameters 23 | valid_step: 5000 24 | max_step: 1000001 25 | tf_start: 1.0 26 | tf_end: 1.0 27 | tf_step: 500000 28 | optimizer: 'Adadelta' 29 | lr: 1.0 30 | eps: 0.00000001 # 1e-8 31 | lr_scheduler: 'fixed' # 'fixed'/'warmup' 32 | curriculum: 0 33 | 34 | model: # Model architecture 35 | ctc_weight: 0.0 # Weight for CTC loss 36 | encoder: 37 | prenet: 'vgg' # 'vgg'/'cnn'/'' 38 | # vgg: True # 4x reduction on time feature extraction 39 | module: 'LSTM' # 'LSTM'/'GRU'/'Transformer' 40 | bidirection: True 41 | dim: [512,512,512,512,512] 42 | dropout: [0,0,0,0,0] 43 | layer_norm: [False,False,False,False,False] 44 | proj: [True,True,True,True,True] # Linear projection + Tanh after each rnn layer 45 | sample_rate: [1,1,1,1,1] 46 | sample_style: 'drop' # 'drop'/'concat' 47 | attention: 48 | mode: 'loc' # 'dot'/'loc' 49 | dim: 300 50 | num_head: 1 51 | v_proj: False # if False and num_head>1, encoder state will be duplicated for each head 52 | temperature: 0.5 # scaling factor for attention 53 | loc_kernel_size: 100 # just for mode=='loc' 54 | loc_kernel_num: 10 # just for mode=='loc' 55 | decoder: 56 | module: 'LSTM' # 'LSTM'/'GRU'/'Transformer' 57 | dim: 512 58 | layer: 1 59 | dropout: 0 60 | -------------------------------------------------------------------------------- /config/libri/asr_hybrid.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | corpus: 3 | name: 'Librispeech' # Specify corpus 4 | path: 'data/LibriSpeech' # Path to raw LibriSpeech dataset 5 | train_split: ['train-clean-100','train-clean-360'] # Name of data splits to be used as training set 6 | dev_split: ['dev-clean'] # Name of data splits to be used as validation set 7 | bucketing: True # Enable/Disable bucketing 8 | batch_size: 16 9 | audio: # Attributes of audio feature 10 | feat_type: 'fbank' 11 | feat_dim: 40 12 | frame_length: 25 # ms 13 | frame_shift: 10 # ms 14 | dither: 0 # random dither audio, 0: no dither 15 | apply_cmvn: True 16 | delta_order: 2 # 0: do nothing, 1: add delta, 2: add delta and accelerate 17 | delta_window_size: 2 18 | text: 19 | mode: 'subword' # 'character'/'word'/'subword' 20 | vocab_file: 'tests/sample_data/subword-16k.model' 21 | 22 | hparas: # Experiment hyper-parameters 23 | valid_step: 5000 24 | max_step: 1000001 25 | tf_start: 1.0 26 | tf_end: 1.0 27 | tf_step: 500000 28 | optimizer: 'Adadelta' 29 | lr: 1.0 30 | eps: 0.00000001 # 1e-8 31 | lr_scheduler: 'fixed' # 'fixed'/'warmup' 32 | curriculum: 0 33 | 34 | model: # Model architecture 35 | ctc_weight: 0.5 # Weight for CTC loss 36 | encoder: 37 | prenet: 'vgg' # 'vgg'/'cnn'/'' 38 | # vgg: True # 4x reduction on time feature extraction 39 | module: 'LSTM' # 'LSTM'/'GRU'/'Transformer' 40 | bidirection: True 41 | dim: [512,512,512,512,512] 42 | dropout: [0,0,0,0,0] 43 | layer_norm: [False,False,False,False,False] 44 | proj: [True,True,True,True,True] # Linear projection + Tanh after each rnn layer 45 | sample_rate: [1,1,1,1,1] 46 | sample_style: 'drop' # 'drop'/'concat' 47 | attention: 48 | mode: 'loc' # 'dot'/'loc' 49 | dim: 300 50 | num_head: 1 51 | v_proj: False # if False and num_head>1, encoder state will be duplicated for each head 52 | temperature: 0.5 # scaling factor for attention 53 | loc_kernel_size: 100 # just for mode=='loc' 54 | loc_kernel_num: 10 # just for mode=='loc' 55 | decoder: 56 | 
module: 'LSTM' # 'LSTM'/'GRU'/'Transformer' 57 | dim: 512 58 | layer: 1 59 | dropout: 0 60 | -------------------------------------------------------------------------------- /src/collect_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn.utils.rnn import pad_sequence 4 | import torch.nn.functional as F 5 | 6 | HALF_BATCHSIZE_AUDIO_LEN = 800 # Batch size will be halfed if the longest wavefile surpasses threshold 7 | # Note: Bucketing may cause random sampling to be biased (less sampled for those length > HALF_BATCHSIZE_AUDIO_LEN ) 8 | HALF_BATCHSIZE_TEXT_LEN = 150 9 | 10 | def collect_audio_batch(batch, audio_transform, mode): 11 | '''Collects a batch, should be list of tuples (audio_path , list of int token ) 12 | e.g. [(file1,txt1),(file2,txt2),...] ''' 13 | 14 | # Bucketed batch should be [[(file1,txt1),(file2,txt2),...]] 15 | if type(batch[0]) is not tuple: 16 | batch = batch[0] 17 | # Make sure that batch size is reasonable 18 | # For each bucket, the first audio must be the longest one 19 | # But for multi-dataset, this is not the case !!!! 20 | 21 | if HALF_BATCHSIZE_AUDIO_LEN < 3500 and mode == 'train': 22 | first_len = audio_transform(str(batch[0][0])).shape[0] 23 | if first_len > HALF_BATCHSIZE_AUDIO_LEN: 24 | batch = batch[::2] 25 | 26 | # Read batch 27 | file, audio_feat, audio_len, text = [],[],[],[] 28 | with torch.no_grad(): 29 | for index, b in enumerate(batch): 30 | if type(b[0]) is str: 31 | file.append(str(b[0]).split('/')[-1].split('.')[0]) 32 | feat = audio_transform(str(b[0])) 33 | else: 34 | file.append('dummy') 35 | feat = audio_transform(str(b[0])) 36 | audio_feat.append(feat) 37 | audio_len.append(len(feat)) 38 | text.append(torch.LongTensor(b[1])) 39 | # Descending audio length within each batch 40 | audio_len, file, audio_feat, text = zip(*[(feat_len,f_name,feat,txt) \ 41 | for feat_len,f_name,feat,txt in zip(audio_len,file,audio_feat,text)]) 42 | 43 | # Zero-padding 44 | audio_feat = pad_sequence(audio_feat, batch_first=True) 45 | text = pad_sequence(text, batch_first=True) 46 | audio_len = torch.LongTensor(audio_len) 47 | 48 | return file, audio_feat, audio_len, text 49 | 50 | def collect_text_batch(batch, mode): 51 | '''Collects a batch of text, should be list of list of int token 52 | e.g. [txt1 ,txt2 ,...] 
''' 53 | 54 | # Bucketed batch should be [[txt1, txt2,...]] 55 | if type(batch[0][0]) is list: 56 | batch = batch[0] 57 | # Half batch size if input to long 58 | if len(batch[0])>HALF_BATCHSIZE_TEXT_LEN and mode=='train': 59 | batch = batch[:len(batch)//2] 60 | # Read batch 61 | text = [torch.LongTensor(b) for b in batch] 62 | # Zero-padding 63 | text = pad_sequence(text, batch_first=True) 64 | 65 | return text -------------------------------------------------------------------------------- /src/bert_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from src import text 5 | from pytorch_pretrained_bert import BertForMaskedLM 6 | from pytorch_pretrained_bert.modeling import BertOnlyMLMHead 7 | 8 | 9 | class BertLikeSentencePieceTextEncoder(object): 10 | def __init__(self, text_encoder): 11 | if not isinstance(text_encoder, text.SubwordTextEncoder): 12 | raise TypeError( 13 | "`text_encoder` must be an instance of `src.text.SubwordTextEncoder`.") 14 | self.text_encoder = text_encoder 15 | 16 | @property 17 | def vocab_size(self): 18 | # +3 accounts for [CLS], [SEP] and [MASK] 19 | return self.text_encoder.vocab_size + 3 20 | 21 | @property 22 | def cls_idx(self): 23 | return self.vocab_size - 3 24 | 25 | @property 26 | def sep_idx(self): 27 | return self.vocab_size - 2 28 | 29 | @property 30 | def mask_idx(self): 31 | return self.vocab_size - 1 32 | 33 | @property 34 | def eos_idx(self): 35 | return self.text_encoder.eos_idx 36 | 37 | 38 | def generate_embedding(bert_model, labels): 39 | """Generate bert's embedding from fine-tuned model.""" 40 | batch_size, time = labels.shape 41 | 42 | cls_ids = torch.full( 43 | (batch_size, 1), bert_model.bert_text_encoder.cls_idx, dtype=labels.dtype, device=labels.device) 44 | bert_labels = torch.cat([cls_ids, labels], 1) 45 | # replace eos with sep 46 | eos_idx = bert_model.bert_text_encoder.eos_idx 47 | sep_idx = bert_model.bert_text_encoder.sep_idx 48 | bert_labels[bert_labels == eos_idx] = sep_idx 49 | 50 | embedding, _ = bert_model.bert(bert_labels, output_all_encoded_layers=True) 51 | # sum over all layers embedding 52 | embedding = torch.stack(embedding).sum(0) 53 | # get rid of cls 54 | embedding = embedding[:, 1:] 55 | 56 | assert labels.shape == embedding.shape[:-1] 57 | 58 | return embedding 59 | 60 | 61 | def load_fine_tuned_model(bert_model, text_encoder, path): 62 | """Load fine-tuned bert model given text encoder and checkpoint path.""" 63 | bert_text_encoder = BertLikeSentencePieceTextEncoder(text_encoder) 64 | 65 | model = BertForMaskedLM.from_pretrained(bert_model) 66 | model.bert_text_encoder = bert_text_encoder 67 | model.bert.embeddings.word_embeddings = nn.Embedding( 68 | bert_text_encoder.vocab_size, model.bert.embeddings.word_embeddings.weight.shape[1]) 69 | model.config.vocab_size = bert_text_encoder.vocab_size 70 | model.cls = BertOnlyMLMHead( 71 | model.config, model.bert.embeddings.word_embeddings.weight) 72 | 73 | model.load_state_dict(torch.load(path)) 74 | 75 | return model 76 | 77 | 78 | class BertEmbeddingPredictor(nn.Module): 79 | def __init__(self, bert_model, text_encoder, path): 80 | super(BertEmbeddingPredictor, self).__init__() 81 | self.model = load_fine_tuned_model(bert_model, text_encoder, path) 82 | 83 | def forward(self, labels): 84 | # do not modify this 85 | self.eval() 86 | return generate_embedding(self.model, labels) 87 | -------------------------------------------------------------------------------- 
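A minimal usage sketch for the BERT embedding predictor above (illustrative only, not part of the repo: the fine-tuned checkpoint path is a placeholder, and the subword model is the sample one under tests/sample_data/):

import torch
from src import text
from src.bert_embedding import BertEmbeddingPredictor

# The tokenizer must be a src.text.SubwordTextEncoder (enforced by BertLikeSentencePieceTextEncoder)
tokenizer = text.SubwordTextEncoder.load_from_file('tests/sample_data/subword-16k.model')
# 'bert-base-uncased' is fetched by pytorch_pretrained_bert; the .pth path below is hypothetical
predictor = BertEmbeddingPredictor('bert-base-uncased', tokenizer, 'ckpt/bert_mlm_finetuned.pth')
labels = torch.LongTensor([tokenizer.encode('SPEECH LAB IS GREAT')])  # shape: (batch, time)
with torch.no_grad():
    embeddings = predictor(labels)  # shape: (batch, time, hidden), summed over all BERT layers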
/eval_beam.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pandas as pd 4 | import editdistance as ed 5 | 6 | SEP = ' ' 7 | 8 | # Arguments 9 | parser = argparse.ArgumentParser(description='Script for evaluating recognition results.') 10 | parser.add_argument('--file', type=str, help='Path to result csv.') 11 | paras = parser.parse_args() 12 | 13 | # Error rate functions 14 | def cal_cer(row): 15 | return 100*float(ed.eval(row.hyp,row.truth))/len(row.truth) 16 | def cal_wer(row): 17 | return 100*float(ed.eval(row.hyp.split(SEP),row.truth.split(SEP)))/len(row.truth.split(SEP)) 18 | 19 | # Evaluation 20 | result = pd.read_csv(paras.file,sep='\t',keep_default_na=False) 21 | result['hyp_char_cnt'] = result.apply(lambda x: len(x.hyp),axis=1) 22 | result['hyp_word_cnt'] = result.apply(lambda x: len(x.hyp.split(SEP)),axis=1) 23 | result['truth_char_cnt'] = result.apply(lambda x: len(x.truth),axis=1) 24 | result['truth_word_cnt'] = result.apply(lambda x: len(x.truth.split(SEP)),axis=1) 25 | result['cer'] = result.apply(cal_cer,axis=1) 26 | result['wer'] = result.apply(cal_wer,axis=1) 27 | 28 | result_dict = result.to_dict('index') 29 | cers = [] 30 | wers = [] 31 | prev_idx = '' 32 | for key in result_dict: 33 | if result_dict[key]['idx'] == prev_idx: 34 | cers[-1] = min(cers[-1], result_dict[key]['cer']) 35 | wers[-1] = min(wers[-1], result_dict[key]['wer']) 36 | else: 37 | prev_idx = result_dict[key]['idx'] 38 | cers.append(result_dict[key]['cer']) 39 | wers.append(result_dict[key]['wer']) 40 | cers = np.array(cers) 41 | wers = np.array(wers) 42 | 43 | # Show results 44 | print() 45 | print('============ Result of',paras.file,'============') 46 | print(' -----------------------------------------------------------------------') 47 | print('| Statics\t\t| Truth\t| Prediction\t| Abs. Diff.\t|') 48 | print(' -----------------------------------------------------------------------') 49 | print('| Avg. # of chars\t| {:.2f}\t| {:.2f}\t| {:.2f}\t\t|'.\ 50 | format(result.truth_char_cnt.mean(), result.hyp_char_cnt.mean(), 51 | np.mean(np.abs(result.truth_char_cnt-result.hyp_char_cnt)))) 52 | print('| Avg. 
# of words\t| {:.2f}\t| {:.2f}\t| {:.2f}\t\t|'.\ 53 | format(result.truth_word_cnt.mean(), result.hyp_word_cnt.mean(), 54 | np.mean(np.abs(result.truth_word_cnt-result.hyp_word_cnt)))) 55 | print(' -----------------------------------------------------------------------') 56 | print(' ---------------------------------------------------------------') 57 | print('| Error Rate (%)| Mean\t\t| Std.\t\t| Min./Max.\t|') 58 | print(' ---------------------------------------------------------------') 59 | print('| Character\t| {:2.4f}\t| {:.2f}\t\t| {:.2f}/{:.2f}\t|'.format(cers.mean(),cers.std(), 60 | cers.min(),cers.max())) 61 | print('| Word\t\t| {:2.4f}\t| {:.2f}\t\t| {:.2f}/{:.2f}\t|'.format(wers.mean(),wers.std(), 62 | wers.min(),wers.max())) 63 | print(' ---------------------------------------------------------------') 64 | print('Note : If the text unit is phoneme, WER = PER and CER is meaningless.') 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # End-to-end Automatic Speech Recognition Systems - PyTorch Implementation 2 | For the complete introduction and usage, please see the original repository [Alexander-H-Liu/End-to-end-ASR-Pytorch](https://github.com/Alexander-H-Liu/End-to-end-ASR-Pytorch). 3 | ## New features 4 | 1. Added layer-wise transfer learning 5 | 2. Supports multiple development sets 6 | 3. Supports FreqCNN (frequency-divided CNN extractor) for whispered speech recognition. 7 | 4. Supports DLHLP corpus for the course [Deep Learning for Human Language Processing](http://speech.ee.ntu.edu.tw/~tlkagk/courses_DLHLP20.html) 8 | 9 | ## Instructions 10 | ### Training 11 | Modify `script/train.sh`, `script/train_lm.sh`, `config/librispeech_asr.yaml`, and `config/librispeech_lm.yaml` first. A GPU is required. 12 | ``` 13 | bash script/train.sh 14 | bash script/train_lm.sh 15 | ``` 16 | ### Testing 17 | Modify `script/test.sh` and `config/librispeech_test.yaml` first. Increasing `--njobs` can speed up the decoding process, but might cause OOM. 18 | ``` 19 | bash script/test.sh 20 | ``` 21 | 22 | ## LibriSpeech 100hr Baseline 23 | This baseline is composed of a character-based joint CTC-attention ASR model and an RNNLM, both trained on LibriSpeech `train-clean-100`. The perplexity of the LM on the `dev-clean` set is 3.66. 24 | 25 | | Decoding | DEV WER(%) | TEST WER(%) | 26 | | -------- | ---------- | ----------- | 27 | | Greedy | 25.4 | 25.9 | 28 | 29 | ## DLHLP Baseline 30 | This baseline is composed of a character-based joint CTC-attention ASR model and an RNN-LM, both trained on the DLHLP training set. 31 | 32 | | Decoding | DEV CER/WER(%) | TEST CER/WER(%) | 33 | | ---------------------- | -------------- | --------------- | 34 | | SpecAugment + Greedy | 1.0 / 3.4 | 0.8 / 3.1 | 35 | | SpecAugment + Beam=5 | 0.8 / 2.9 | 0.7 / 2.6 | 36 | 37 | ## TODO 38 | 1. CTC beam decoding (testing) 39 | 2. SpecAugment (will be released) 40 | 3. Multiple corpora training (will be released) 41 | 4. Support of WSJ and Switchboard dataset (under construction) 42 | 5. 
Combination of CTC and RNN-LM: RNN transducer (under construction) 43 | 44 | ## Citation 45 | 46 | ``` 47 | @inproceedings{liu2019adversarial, 48 | title={Adversarial Training of End-to-end Speech Recognition Using a Criticizing Language Model}, 49 | author={Liu, Alexander and Lee, Hung-yi and Lee, Lin-shan}, 50 | booktitle={International Conference on Speech RecognitionAcoustics, Speech and Signal Processing (ICASSP)}, 51 | year={2019}, 52 | organization={IEEE} 53 | } 54 | 55 | @inproceedings{alex2019sequencetosequence, 56 | title={Sequence-to-sequence Automatic Speech Recognition with Word Embedding Regularization and Fused Decoding}, 57 | author={Alexander H. Liu and Tzu-Wei Sung and Shun-Po Chuang and Hung-yi Lee and Lin-shan Lee}, 58 | booktitle={International Conference on Speech RecognitionAcoustics, Speech and Signal Processing (ICASSP)}, 59 | year={2020}, 60 | organization={IEEE} 61 | } 62 | 63 | @inproceedings{chang2020endtoend, 64 | title={End-to-end Whispered Speech Recognition with Frequency-weighted Approaches and Pseudo Whisper Pre-training}, 65 | author={Heng-Jui Chang and Alexander H. Liu and Hung-yi Lee and Lin-shan Lee}, 66 | booktitle={Spoken Language Technology Workshop (SLT)}, 67 | year={2021}, 68 | organization={IEEE} 69 | } 70 | ``` 71 | -------------------------------------------------------------------------------- /tests/test_audio.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import torch 4 | 5 | from src import audio 6 | 7 | 8 | class TestAudio(unittest.TestCase): 9 | def setUp(self): 10 | super(TestAudio, self).__init__() 11 | self.filepath = "tests/sample_data/3830-12529-0005.wav" 12 | 13 | def test_filter_bank(self): 14 | audio_config = { 15 | "feat_type": "fbank", 16 | "feat_dim": 40, 17 | "apply_cmvn": False, 18 | "frame_length": 25, 19 | "frame_shift": 10, 20 | } 21 | 22 | transform, d = audio.create_transform(audio_config) 23 | y = transform(self.filepath) 24 | self.assertEqual(list(y.shape), [392, d]) 25 | 26 | def test_mfcc(self): 27 | self.skipTest( 28 | "torchaudio.compliance.kaldi.mfcc is not in torchaudio==0.3.0") 29 | audio_config = { 30 | "feat_type": "mfcc", 31 | "feat_dim": 13, 32 | "apply_cmvn": False, 33 | "frame_length": 25, 34 | "frame_shift": 10, 35 | } 36 | 37 | transform, d = audio.create_transform(audio_config) 38 | y = transform(self.filepath) 39 | self.assertEqual(list(y.shape), [392, d]) 40 | 41 | def test_cmvn(self): 42 | audio_config = { 43 | "feat_type": "fbank", 44 | "feat_dim": 40, 45 | "apply_cmvn": True, 46 | "frame_length": 25, 47 | "frame_shift": 10, 48 | } 49 | 50 | transform, d = audio.create_transform(audio_config) 51 | y = transform(self.filepath) 52 | 53 | self.assertEqual(list(y.shape), [392, d]) 54 | np.testing.assert_allclose(y.mean(0), 0.0, rtol=1e-6, atol=5e-5) 55 | np.testing.assert_allclose(y.std(0), 1.0, rtol=1e-6, atol=1e-6) 56 | 57 | def test_delta(self): 58 | audio_config = { 59 | "feat_type": "fbank", 60 | "feat_dim": 40, 61 | "dither": 0.0, 62 | "apply_cmvn": True, 63 | "frame_length": 25, 64 | "frame_shift": 10, 65 | "delta_order": 1, 66 | "delta_window_size": 2, 67 | } 68 | 69 | transform, d = audio.create_transform(audio_config) 70 | y = transform(self.filepath) 71 | 72 | self.assertEqual(list(y.shape), [392, d]) 73 | 74 | audio_config = { 75 | "feat_type": "fbank", 76 | "feat_dim": 40, 77 | "dither": 0.0, 78 | "apply_cmvn": True, 79 | "frame_length": 25, 80 | "frame_shift": 10, 81 | "delta_order": 0, 82 | } 83 | 84 | 
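            # Re-extract features with delta_order=0; the first 40 (static) dims of the delta features above are expected to match these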
transform, d = audio.create_transform(audio_config) 85 | y_no_delta = transform(self.filepath) 86 | 87 | np.testing.assert_allclose(y[:, :40], y_no_delta, rtol=1e-5, atol=1e-5) 88 | 89 | def test_delta_delta(self): 90 | audio_config = { 91 | "feat_type": "fbank", 92 | "feat_dim": 40, 93 | "apply_cmvn": True, 94 | "frame_length": 25, 95 | "frame_shift": 10, 96 | "delta_order": 2, 97 | "delta_window_size": 2, 98 | } 99 | 100 | transform, d = audio.create_transform(audio_config) 101 | y = transform(self.filepath) 102 | 103 | self.assertEqual(list(y.shape), [392, d]) 104 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import yaml 4 | import torch 5 | import argparse 6 | import numpy as np 7 | 8 | # For reproducibility, comment these may speed up training 9 | torch.backends.cudnn.deterministic = True 10 | torch.backends.cudnn.benchmark = False 11 | 12 | # Arguments 13 | parser = argparse.ArgumentParser(description='Training E2E asr.') 14 | parser.add_argument('--config', type=str, help='Path to experiment config.') 15 | parser.add_argument('--name', default=None, type=str, help='Name for logging.') 16 | parser.add_argument('--logdir', default='log/', type=str, help='Logging path.', required=False) 17 | parser.add_argument('--ckpdir', default='ckpt/', type=str, help='Checkpoint path.', required=False) 18 | parser.add_argument('--outdir', default='result/', type=str, help='Decode output path.', required=False) 19 | parser.add_argument('--load', default=None, type=str, help='Load pre-trained model (for training only)', required=False) 20 | parser.add_argument('--seed', default=0, type=int, help='Random seed for reproducable results.', required=False) 21 | parser.add_argument('--cudnn-ctc', action='store_true', help='Switches CTC backend from torch to cudnn') 22 | parser.add_argument('--njobs', default=4, type=int, help='Number of threads for dataloader/decoding.', required=False) 23 | parser.add_argument('--cpu', action='store_true', help='Disable GPU training.') 24 | parser.add_argument('--no-pin', action='store_true', help='Disable pin-memory for dataloader') 25 | parser.add_argument('--test', action='store_true', help='Test the model.') 26 | parser.add_argument('--no-msg', action='store_true', help='Hide all messages.') 27 | parser.add_argument('--lm', action='store_true', help='Option for training RNNLM.') 28 | parser.add_argument('--amp', action='store_true', help='Option to enable AMP.') 29 | parser.add_argument('--reserve_gpu', default=0, type=float, help='Option to reserve GPU ram for training.') 30 | parser.add_argument('--jit', action='store_true', help='Option for enabling jit in pytorch. 
(feature in development)') 31 | parser.add_argument('--cuda', default=0, type=int, help='Choose which gpu to use.') 32 | 33 | paras = parser.parse_args() 34 | setattr(paras,'gpu',not paras.cpu) 35 | setattr(paras,'pin_memory',not paras.no_pin) 36 | setattr(paras,'verbose',not paras.no_msg) 37 | config = yaml.load(open(paras.config,'r'), Loader=yaml.FullLoader) 38 | 39 | print('[INFO] Using config {}'.format(paras.config)) 40 | 41 | np.random.seed(paras.seed) 42 | torch.manual_seed(paras.seed) 43 | if torch.cuda.is_available(): 44 | torch.cuda.manual_seed_all(paras.seed) 45 | # print('There are ', torch.cuda.device_count(), ' device(s) available') 46 | # print('Using device cuda:', str(paras.cuda)) 47 | 48 | # Hack to preserve GPU ram just incase OOM later on server 49 | if paras.gpu and paras.reserve_gpu>0: 50 | buff = torch.randn(int(paras.reserve_gpu*1e9//4)).to(torch.device('cuda:' + str(paras.cuda))) 51 | del buff 52 | 53 | if paras.lm: 54 | # Train RNNLM 55 | from bin.train_lm import Solver 56 | mode = 'train' 57 | else: 58 | if paras.test: 59 | # Test ASR 60 | assert paras.load is None, 'Load option is mutually exclusive to --test' 61 | from bin.test_asr import Solver 62 | mode = 'test' 63 | else: 64 | # Train ASR 65 | from bin.train_asr import Solver 66 | mode = 'train' 67 | 68 | solver = Solver(config,paras,mode) 69 | solver.load_data() 70 | solver.set_model() 71 | solver.exec() 72 | -------------------------------------------------------------------------------- /tests/test_text.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from src import text 4 | 5 | 6 | class TestChacterTextEncoder(unittest.TestCase): 7 | def setUp(self): 8 | super(TestChacterTextEncoder, self).__init__() 9 | self.vocab_file = "tests/sample_data/character.vocab" 10 | self.vocab_list = list(" ABCDEFGHIJKLMNOPQRSTUVWXYZ'") 11 | self.text = "SPEECH LAB!" 12 | 13 | def test_load_from_file(self): 14 | text_encoder = text.CharacterTextEncoder.load_from_file( 15 | self.vocab_file) 16 | self._test_encode_decode(text_encoder) 17 | 18 | def test_from_vocab_list(self): 19 | text_encoder = text.CharacterTextEncoder(self.vocab_list) 20 | self._test_encode_decode(text_encoder) 21 | 22 | def _test_encode_decode(self, text_encoder): 23 | ids = text_encoder.encode(self.text) 24 | 25 | self.assertEqual(31, text_encoder.vocab_size) 26 | self.assertEqual( 27 | ids, [22, 19, 8, 8, 6, 11, 3, 15, 4, 5, 2, 1]) 28 | 29 | decoded = text_encoder.decode(ids) 30 | self.assertEqual(decoded, self.text.replace("!", "")) 31 | 32 | 33 | class TestSubwordTextEncoder(unittest.TestCase): 34 | def setUp(self): 35 | self.filepath = "tests/sample_data/subword.model" 36 | self.text = "SPEECH LAB IS GREAT" 37 | 38 | def test_load_from_file(self): 39 | text_encoder = text.SubwordTextEncoder.load_from_file(self.filepath) 40 | self._test_encode_decode(text_encoder) 41 | 42 | def _test_encode_decode(self, text_encoder): 43 | ids = text_encoder.encode(self.text) 44 | 45 | self.assertEqual(5000, text_encoder.vocab_size) 46 | self.assertEqual(ids, [2845, 1699, 99, 333, 1]) 47 | 48 | decoded = text_encoder.decode(ids) 49 | self.assertEqual(decoded, self.text) 50 | 51 | 52 | class TestWordTextEncoder(unittest.TestCase): 53 | def setUp(self): 54 | super(TestWordTextEncoder, self).__init__() 55 | self.vocab_file = "tests/sample_data/word.vocab" 56 | self.vocab_list = ["SPEECH", "LAB", "IS", "GREAT"] 57 | self.text = "SPEECH LAB IS GREAT !!!" 
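        # '!!!' is out of vocabulary; the checks below expect it to be encoded as id 2 (presumably <unk>) and dropped when decoding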
58 | 59 | def test_load_from_file(self): 60 | text_encoder = text.WordTextEncoder.load_from_file(self.vocab_file) 61 | self._test_encode_decode(text_encoder) 62 | 63 | def test_from_vocab_list(self): 64 | text_encoder = text.WordTextEncoder(self.vocab_list) 65 | self._test_encode_decode(text_encoder) 66 | 67 | def _test_encode_decode(self, text_encoder): 68 | ids = text_encoder.encode(self.text) 69 | 70 | self.assertEqual(7, text_encoder.vocab_size) 71 | self.assertEqual(ids, [3, 4, 5, 6, 2, 1]) 72 | 73 | decoded = text_encoder.decode(ids) 74 | self.assertEqual(decoded, self.text.replace("!!!", "")) 75 | 76 | 77 | class TestBertTextEncoder(unittest.TestCase): 78 | def setUp(self): 79 | super(TestBertTextEncoder, self).__init__() 80 | self.vocab_file = "bert-base-uncased" 81 | self.text = "SPEECH LAB IS GREAT!!!" 82 | 83 | def test_load_from_file(self): 84 | text_encoder = text.BertTextEncoder.load_from_file(self.vocab_file) 85 | self._test_encode_decode(text_encoder) 86 | 87 | def _test_encode_decode(self, text_encoder): 88 | ids = text_encoder.encode(self.text) 89 | 90 | self.assertEqual(28639, text_encoder.vocab_size) 91 | self.assertEqual(ids, [3616, 5848, 1006, 1310, 2, 2, 2, 1]) 92 | 93 | decoded = text_encoder.decode(ids) 94 | self.assertEqual(decoded.upper(), self.text) 95 | -------------------------------------------------------------------------------- /corpus/librispeech.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from pathlib import Path 3 | from os.path import join, getsize 4 | from joblib import Parallel, delayed 5 | from torch.utils.data import Dataset 6 | 7 | # Additional (official) text src provided 8 | OFFICIAL_TXT_SRC = ['librispeech-lm-norm.txt'] 9 | # Remove longest N sentence in librispeech-lm-norm.txt 10 | REMOVE_TOP_N_TXT = 5000000 11 | # Default num. 
of threads used for loading LibriSpeech 12 | READ_FILE_THREADS = 4 13 | 14 | 15 | def read_text(file): 16 | '''Get transcription of target wave file, 17 | it's somewhat redundant for accessing each txt multiplt times, 18 | but it works fine with multi-thread''' 19 | src_file = '-'.join(file.split('-')[:-1])+'.trans.txt' 20 | idx = file.split('/')[-1].split('.')[0] 21 | 22 | with open(src_file, 'r') as fp: 23 | for line in fp: 24 | if idx == line.split(' ')[0]: 25 | return line[:-1].split(' ', 1)[1] 26 | 27 | 28 | class LibriDataset(Dataset): 29 | def __init__(self, path, split, tokenizer, bucket_size, ascending=False): 30 | # Setup 31 | self.path = path 32 | self.bucket_size = bucket_size 33 | 34 | # List all wave files 35 | file_list = [] 36 | for s in split: 37 | split_list = list(Path(join(path, s)).rglob("*.flac")) 38 | assert len(split_list) > 0, "No data found @ {}".format(join(path,s)) 39 | file_list += split_list 40 | # Read text 41 | text = Parallel(n_jobs=READ_FILE_THREADS)( 42 | delayed(read_text)(str(f)) for f in file_list) 43 | #text = Parallel(n_jobs=-1)(delayed(tokenizer.encode)(txt) for txt in text) 44 | text = [tokenizer.encode(txt) for txt in text] 45 | 46 | # Sort dataset by text length 47 | #file_len = Parallel(n_jobs=READ_FILE_THREADS)(delayed(getsize)(f) for f in file_list) 48 | self.file_list, self.text = zip(*[(f_name, txt) 49 | for f_name, txt in sorted(zip(file_list, text), reverse=not ascending, key=lambda x:len(x[1]))]) 50 | 51 | def __getitem__(self, index): 52 | if self.bucket_size > 1: 53 | # Return a bucket 54 | index = min(len(self.file_list)-self.bucket_size, index) 55 | return [(f_path, txt) for f_path, txt in 56 | zip(self.file_list[index:index+self.bucket_size], self.text[index:index+self.bucket_size])] 57 | else: 58 | return self.file_list[index], self.text[index] 59 | 60 | def __len__(self): 61 | return len(self.file_list) 62 | 63 | 64 | class LibriTextDataset(Dataset): 65 | def __init__(self, path, split, tokenizer, bucket_size): 66 | # Setup 67 | self.path = path 68 | self.bucket_size = bucket_size 69 | self.encode_on_fly = False 70 | read_txt_src = [] 71 | 72 | # List all wave files 73 | file_list, all_sent = [], [] 74 | 75 | for s in split: 76 | if s in OFFICIAL_TXT_SRC: 77 | self.encode_on_fly = True 78 | with open(join(path, s), 'r') as f: 79 | all_sent += f.readlines() 80 | file_list += list(Path(join(path, s)).rglob("*.flac")) 81 | assert (len(file_list) > 0) or (len(all_sent) 82 | > 0), "No data found @ {}".format(path) 83 | 84 | # Read text 85 | text = Parallel(n_jobs=READ_FILE_THREADS)( 86 | delayed(read_text)(str(f)) for f in file_list) 87 | all_sent.extend(text) 88 | del text 89 | 90 | # Encode text 91 | if self.encode_on_fly: 92 | self.tokenizer = tokenizer 93 | self.text = all_sent 94 | else: 95 | self.text = [tokenizer.encode(txt) for txt in tqdm(all_sent)] 96 | del all_sent 97 | 98 | # Read file size and sort dataset by file size (Note: feature len. 
may be different) 99 | self.text = sorted(self.text, reverse=True, key=lambda x: len(x)) 100 | if self.encode_on_fly: 101 | del self.text[:REMOVE_TOP_N_TXT] 102 | 103 | def __getitem__(self, index): 104 | if self.bucket_size > 1: 105 | index = min(len(self.text)-self.bucket_size, index) 106 | if self.encode_on_fly: 107 | for i in range(index, index+self.bucket_size): 108 | if type(self.text[i]) is str: 109 | self.text[i] = self.tokenizer.encode(self.text[i]) 110 | # Return a bucket 111 | return self.text[index:index+self.bucket_size] 112 | else: 113 | if self.encode_on_fly and type(self.text[index]) is str: 114 | self.text[index] = self.tokenizer.encode(self.text[index]) 115 | return self.text[index] 116 | 117 | def __len__(self): 118 | return len(self.text) 119 | -------------------------------------------------------------------------------- /src/ctc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class CTCPrefixScore(): 5 | ''' 6 | CTC Prefix score calculator 7 | An implementation of Algo. 2 in https://www.merl.com/publications/docs/TR2017-190.pdf (Watanabe et. al.) 8 | Reference (official implementation): https://github.com/espnet/espnet/tree/master/espnet/nets 9 | ''' 10 | 11 | def __init__(self, x): 12 | self.logzero = -100000000.0 13 | self.blank = 0 14 | self.eos = 1 15 | self.x = x.cpu().numpy()[0] 16 | self.odim = x.shape[-1] 17 | self.input_length = len(self.x) 18 | 19 | def init_state(self): 20 | # 0 = non-blank, 1 = blank 21 | r = np.full((self.input_length, 2), self.logzero, dtype=np.float32) 22 | 23 | # Accumalate blank at each step 24 | r[0, 1] = self.x[0, self.blank] 25 | for i in range(1, self.input_length): 26 | r[i, 1] = r[i-1, 1] + self.x[i, self.blank] 27 | return r 28 | 29 | def full_compute(self, g, r_prev): 30 | '''Given prefix g, return the probability of all possible sequence y (where y = concat(g,c)) 31 | This function computes all possible tokens for c (memory inefficient)''' 32 | prefix_length = len(g) 33 | last_char = g[-1] if prefix_length > 0 else 0 34 | 35 | # init. r 36 | r = np.full((self.input_length, 2, self.odim), 37 | self.logzero, dtype=np.float32) 38 | 39 | # start from len(g) because is impossible for CTC to generate |y|>|X| 40 | start = max(1, prefix_length) 41 | 42 | if prefix_length == 0: 43 | r[0, 0, :] = self.x[0, :] # if g = 44 | 45 | psi = r[start-1, 0, :] 46 | 47 | phi = np.logaddexp(r_prev[:, 0], r_prev[:, 1]) 48 | 49 | for t in range(start, self.input_length): 50 | # prev_blank 51 | prev_blank = np.full((self.odim), r_prev[t-1, 1], dtype=np.float32) 52 | # prev_nonblank 53 | prev_nonblank = np.full( 54 | (self.odim), r_prev[t-1, 0], dtype=np.float32) 55 | prev_nonblank[last_char] = self.logzero 56 | 57 | phi = np.logaddexp(prev_nonblank, prev_blank) 58 | # P(h|current step is non-blank) = [ P(prev. step = y) + P()]*P(c) 59 | r[t, 0, :] = np.logaddexp(r[t-1, 0, :], phi) + self.x[t, :] 60 | # P(h|current step is blank) = [P(prev. step is blank) + P(prev. 
step is non-blank)]*P(now=blank) 61 | r[t, 1, :] = np.logaddexp( 62 | r[t-1, 1, :], r[t-1, 0, :]) + self.x[t, self.blank] 63 | psi = np.logaddexp(psi, phi+self.x[t, :]) 64 | 65 | #psi[self.eos] = np.logaddexp(r_prev[-1,0], r_prev[-1,1]) 66 | return psi, np.rollaxis(r, 2) 67 | 68 | def cheap_compute(self, g, r_prev, candidates): 69 | '''Given prefix g, return the probability of all possible sequence y (where y = concat(g,c)) 70 | This function considers only those tokens in candidates for c (memory efficient)''' 71 | prefix_length = len(g) 72 | odim = len(candidates) 73 | last_char = g[-1] if prefix_length > 0 else 0 74 | 75 | # init. r 76 | r = np.full((self.input_length, 2, len(candidates)), 77 | self.logzero, dtype=np.float32) 78 | 79 | # start from len(g) because is impossible for CTC to generate |y|>|X| 80 | start = max(1, prefix_length) 81 | 82 | if prefix_length == 0: 83 | r[0, 0, :] = self.x[0, candidates] # if g = 84 | 85 | psi = r[start-1, 0, :] 86 | # Phi = (prev_nonblank,prev_blank) 87 | sum_prev = np.logaddexp(r_prev[:, 0], r_prev[:, 1]) 88 | phi = np.repeat(sum_prev[..., None],odim,axis=-1) 89 | # Handle edge case : last tok of prefix in candidates 90 | if prefix_length>0 and last_char in candidates: 91 | phi[:,candidates.index(last_char)] = r_prev[:,1] 92 | 93 | for t in range(start, self.input_length): 94 | # prev_blank 95 | # prev_blank = np.full((odim), r_prev[t-1, 1], dtype=np.float32) 96 | # prev_nonblank 97 | # prev_nonblank = np.full((odim), r_prev[t-1, 0], dtype=np.float32) 98 | # phi = np.logaddexp(prev_nonblank, prev_blank) 99 | # P(h|current step is non-blank) = P(prev. step = y)*P(c) 100 | r[t, 0, :] = np.logaddexp( r[t-1, 0, :], phi[t-1]) + self.x[t, candidates] 101 | # P(h|current step is blank) = [P(prev. step is blank) + P(prev. step is non-blank)]*P(now=blank) 102 | r[t, 1, :] = np.logaddexp( r[t-1, 1, :], r[t-1, 0, :]) + self.x[t, self.blank] 103 | psi = np.logaddexp(psi, phi[t-1,]+self.x[t, candidates]) 104 | 105 | # P(end of sentence) = P(g) 106 | if self.eos in candidates: 107 | psi[candidates.index(self.eos)] = sum_prev[-1] 108 | return psi, np.rollaxis(r, 2) 109 | -------------------------------------------------------------------------------- /bin/train_lm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.solver import BaseSolver 3 | 4 | from src.lm import RNNLM 5 | from src.optim import Optimizer 6 | from src.data import load_textset 7 | from src.util import human_format 8 | 9 | 10 | class Solver(BaseSolver): 11 | ''' Solver for training language models''' 12 | def __init__(self,config,paras,mode): 13 | super().__init__(config,paras,mode) 14 | # Logger settings 15 | self.best_loss = 10 16 | 17 | def fetch_data(self, data): 18 | ''' Move data to device, insert and compute text seq. 
length''' 19 | txt = torch.cat((torch.zeros((data.shape[0],1),dtype=torch.long),data), dim=1).to(self.device) 20 | txt_len = torch.sum(data!=0,dim=-1) 21 | return txt, txt_len 22 | 23 | def load_data(self): 24 | ''' Load data for training/validation, store tokenizer and input/output shape''' 25 | self.tr_set, self.dv_set, self.vocab_size, self.tokenizer, msg = \ 26 | load_textset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory, **self.config['data']) 27 | self.verbose(msg) 28 | 29 | def set_model(self): 30 | ''' Setup ASR model and optimizer ''' 31 | 32 | # Model 33 | self.model = RNNLM( self.vocab_size, **self.config['model']).to(self.device) 34 | self.verbose(self.model.create_msg()) 35 | # Losses 36 | self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0) 37 | # Optimizer 38 | self.optimizer = Optimizer(self.model.parameters(),**self.config['hparas']) 39 | # Enable AMP if needed 40 | self.enable_apex() 41 | # load pre-trained model 42 | if self.paras.load: 43 | self.load_ckpt() 44 | ckpt = torch.load(self.paras.load, map_location=self.device) 45 | self.model.load_state_dict(ckpt['model']) 46 | self.optimizer.load_opt_state_dict(ckpt['optimizer']) 47 | self.step = ckpt['global_step'] 48 | self.verbose('Load ckpt from {}, restarting at step {}'.format(self.paras.load,self.step)) 49 | 50 | def exec(self): 51 | ''' Training End-to-end ASR system ''' 52 | self.verbose('Total training steps {}.'.format(human_format(self.max_step))) 53 | self.timer.set() 54 | 55 | while self.step< self.max_step: 56 | for data in self.tr_set: 57 | # Pre-step : update tf_rate/lr_rate and do zero_grad 58 | self.optimizer.pre_step(self.step) 59 | 60 | # Fetch data 61 | txt, txt_len = self.fetch_data(data) 62 | self.timer.cnt('rd') 63 | 64 | # Forward model 65 | pred, _ = self.model(txt[:,:-1], txt_len) 66 | 67 | # Compute all objectives 68 | lm_loss = self.seq_loss(pred.view(-1,self.vocab_size),txt[:,1:].reshape(-1)) 69 | self.timer.cnt('fw') 70 | 71 | # Backprop 72 | grad_norm = self.backward(lm_loss) 73 | self.step +=1 74 | 75 | # Logger 76 | if self.step%self.PROGRESS_STEP==0: 77 | self.progress('Tr stat | Loss - {:.2f} | Grad. 
Norm - {:.2f} | {}'\ 78 | .format(lm_loss.cpu().item(),grad_norm,self.timer.show())) 79 | self.write_log('entropy',{'tr':lm_loss}) 80 | self.write_log('perplexity',{'tr':torch.exp(lm_loss).cpu().item()}) 81 | 82 | # Validation 83 | if (self.step==1) or (self.step%self.valid_step == 0): 84 | self.validate() 85 | 86 | # End of step 87 | self.timer.set() 88 | if self.step > self.max_step:break 89 | self.log.close() 90 | 91 | def validate(self): 92 | # Eval mode 93 | self.model.eval() 94 | dev_loss = [] 95 | 96 | for i,data in enumerate(self.dv_set): 97 | self.progress('Valid step - {}/{}'.format(i+1,len(self.dv_set))) 98 | # Fetch data 99 | txt, txt_len = self.fetch_data(data) 100 | 101 | # Forward model 102 | with torch.no_grad(): 103 | pred, _ = self.model(txt[:,:-1], txt_len) 104 | lm_loss = self.seq_loss(pred.view(-1,self.vocab_size),txt[:,1:].reshape(-1)) 105 | dev_loss.append(lm_loss) 106 | 107 | # Ckpt if performance improves 108 | dev_loss = sum(dev_loss)/len(dev_loss) 109 | dev_ppx = torch.exp(dev_loss).cpu().item() 110 | if dev_loss < self.best_loss : 111 | self.best_loss = dev_loss 112 | self.save_checkpoint('best_ppx.pth','perplexity',dev_ppx) 113 | self.write_log('entropy',{'dv':dev_loss}) 114 | self.write_log('perplexity',{'dv':dev_ppx}) 115 | 116 | # Show some example of last batch on tensorboard 117 | for i in range(min(len(txt),self.DEV_N_EXAMPLE)): 118 | if self.step ==1: 119 | self.write_log('true_text{}'.format(i),self.tokenizer.decode(txt[i].tolist())) 120 | self.write_log('pred_text{}'.format(i),self.tokenizer.decode(pred[i].argmax(dim=-1).tolist())) 121 | 122 | # Resume training 123 | self.model.train() 124 | -------------------------------------------------------------------------------- /corpus/preprocess_librispeech.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from pathlib import Path 3 | from os.path import join,getsize 4 | from joblib import Parallel, delayed 5 | from torch.utils.data import Dataset 6 | 7 | OFFICIAL_TXT_SRC = ['librispeech-lm-norm.txt'] # Additional (official) text src provided 8 | REMOVE_TOP_N_TXT = 5000000 # Remove longest N sentence in librispeech-lm-norm.txt 9 | READ_FILE_THREADS = 16 # Default num. 
of threads used for loading LibriSpeech 10 | 11 | def read_text(file): 12 | '''Get transcription of target wave file, 13 | it's somewhat redundant for accessing each txt multiplt times, 14 | but it works fine with multi-thread''' 15 | src_file = '-'.join(file.split('-')[:-1])+'.trans.txt' 16 | idx = file.split('/')[-1].split('.')[0] 17 | 18 | with open(src_file,'r') as fp: 19 | for line in fp: 20 | if idx == line.split(' ')[0]: 21 | return line[:-1].split(' ',1)[1] 22 | 23 | class LibriDataset(Dataset): 24 | def __init__(self, path, split, tokenizer, bucket_size=1, 25 | ascending=False, read_audio=False, sort_by_text=False): 26 | # Setup 27 | self.path = path 28 | self.bucket_size = bucket_size 29 | 30 | # List all wave files 31 | file_list = [] 32 | for s in split: 33 | if s[0] == 't' or s[0] == 'd': 34 | file_list += list(Path(join(path,s)).rglob("*.flac")) 35 | assert len(file_list)>0, "No data found @ {}".format(path) 36 | 37 | # Read text 38 | text = Parallel(n_jobs=READ_FILE_THREADS)(delayed(read_text)(str(f)) for f in file_list) 39 | # text = Parallel(n_jobs=-1)(delayed(tokenizer.encode)(txt) for txt in text) 40 | text = [tokenizer.encode(txt) for txt in text] 41 | 42 | # Sort dataset by text length 43 | # file_len = Parallel(n_jobs=READ_FILE_THREADS)(delayed(getsize)(f) for f in file_list) 44 | if sort_by_text == False: 45 | file_len = Parallel(n_jobs=READ_FILE_THREADS)(delayed(getsize)(f) for f in file_list) 46 | else: 47 | file_len = [len(txt) for txt in text] 48 | 49 | self.file_list, self.text = zip(*[(str(f_name),txt) \ 50 | for time_len,f_name,txt in sorted(zip(file_len,file_list,text), reverse=not ascending, key=lambda x:x[0])]) 51 | if read_audio: 52 | from src.audio import ReadAudio 53 | audio_reader = ReadAudio(8000) 54 | self.file_list = [audio_reader(str(f)) for i, f in tqdm(enumerate(self.file_list))] 55 | 56 | print('[INFO] LibriSpeech', split[-1], 'set :',len(self.file_list),'audio files found') 57 | 58 | def __getitem__(self,index): 59 | if self.bucket_size>1: 60 | # Return a bucket 61 | index = min(len(self.file_list)-self.bucket_size,index) 62 | return [(f_path, txt) for f_path,txt in \ 63 | zip(self.file_list[index:index+self.bucket_size], self.text[index:index+self.bucket_size])] 64 | else: 65 | return self.file_list[index], self.text[index] 66 | 67 | def __len__(self): 68 | return len(self.file_list) 69 | 70 | class LibriTextDataset(Dataset): 71 | def __init__(self, path, split, tokenizer, bucket_size): 72 | # Setup 73 | self.path = path 74 | self.bucket_size = bucket_size 75 | self.encode_on_fly = False 76 | read_txt_src = [] 77 | 78 | # List all wave files 79 | file_list, all_sent = [],[] 80 | 81 | for s in split: 82 | if s in OFFICIAL_TXT_SRC: 83 | self.encode_on_fly = True 84 | with open(join(path,s),'r') as f: 85 | all_sent += f.readlines() 86 | file_list += list(Path(join(path,s)).rglob("*.flac")) 87 | assert (len(file_list)>0) or (len(all_sent)>0), "No data found @ {}".format(path) 88 | 89 | # Read text 90 | text = Parallel(n_jobs=READ_FILE_THREADS)(delayed(read_text)(str(f)) for f in file_list) 91 | all_sent.extend(text) 92 | del text 93 | 94 | # Encode text 95 | if self.encode_on_fly: 96 | self.tokenizer = tokenizer 97 | self.text = [sent for sent in all_sent] 98 | else: 99 | self.text = [tokenizer.encode(txt) for txt in tqdm(all_sent)] 100 | del all_sent 101 | 102 | # Read file size and sort dataset by file size (Note: feature len. 
may be different) 103 | self.text = sorted(self.text, reverse=True, key=lambda x:len(x)) 104 | if self.encode_on_fly: 105 | del self.text[:REMOVE_TOP_N_TXT] 106 | 107 | def __getitem__(self,index): 108 | if self.bucket_size>1: 109 | index = min(len(self.text)-self.bucket_size,index) 110 | if self.encode_on_fly: 111 | for i in range(index,index+self.bucket_size): 112 | if type(self.text[i]) is str: 113 | self.text[i] = self.tokenizer.encode(self.text[i]) 114 | # Return a bucket 115 | return self.text[index:index+self.bucket_size] 116 | else: 117 | if self.encode_on_fly and type(self.text[index]) is str: 118 | self.text[index] = self.tokenizer.encode(self.text[index]) 119 | return self.text[index] 120 | 121 | def __len__(self): 122 | return len(self.text) 123 | -------------------------------------------------------------------------------- /corpus/preprocess_dlhlp.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from pathlib import Path 3 | from os.path import join,getsize 4 | from joblib import Parallel, delayed 5 | from torch.utils.data import Dataset 6 | 7 | # from sphfile import SPHFile 8 | # import soundfile as sf 9 | import wave 10 | 11 | ADDITIONAL_TXT_SRC = ['bopomo_corpus.txt'] # Additional text src provided 12 | REMOVE_TOP_N_TXT = 5000000 # Remove longest N sentence in librispeech-lm-norm.txt 13 | READ_FILE_THREADS = 16 # Default num. of threads used for loading LibriSpeech 14 | 15 | def read_text(file): 16 | '''Get transcription of target wave file, 17 | it's somewhat redundant for accessing each txt multiplt times, 18 | but it works fine with multi-thread''' 19 | src_file = file.rsplit('/', 1)[0]+'/bopomo.trans.txt' 20 | idx = file.split('/')[-1].split('.')[0] 21 | 22 | with open(src_file,'r',encoding='UTF-8') as fp: 23 | for line in fp: 24 | if idx == line.split(' ')[0]: 25 | return line[:-1].split(' ',1)[1] 26 | 27 | class DLHLPDataset(Dataset): 28 | def __init__(self, path, split, tokenizer, bucket_size=1, ascending=False, read_audio=False): 29 | # Setup 30 | self.path = path 31 | self.bucket_size = bucket_size 32 | 33 | # List all wave files 34 | file_list = [] 35 | for s in split: 36 | if s[0] == 't' or s[0] == 'd': 37 | file_list += list(Path(join(path,s)).rglob("*.wav")) 38 | assert len(file_list)>0, "No data found @ {}".format(path) 39 | 40 | # Read text 41 | text = Parallel(n_jobs=READ_FILE_THREADS)(delayed(read_text)(str(f)) for f in file_list) 42 | # text = Parallel(n_jobs=-1)(delayed(tokenizer.encode)(txt) for txt in text) 43 | text = [tokenizer.encode(txt.lower()) for txt in text] 44 | 45 | # Sort dataset by text length 46 | # file_len = Parallel(n_jobs=READ_FILE_THREADS)(delayed(getsize)(f) for f in file_list) 47 | file_len = Parallel(n_jobs=READ_FILE_THREADS)(delayed(getsize)(f) for f in file_list) 48 | # if split[0] == 'test': 49 | # file_names = [int(str(f).split('/')[-1][:-4]) for f in file_list] 50 | # self.file_list, self.text = zip(*[(f_name,txt) \ 51 | # for f_num,f_name,txt in sorted(zip(file_names,file_list,text), reverse=ascending, key=lambda x:x[0])]) 52 | # else: 53 | self.file_list, self.text = zip(*[(f_name,txt) \ 54 | for time_len,f_name,txt in sorted(zip(file_len,file_list,text), reverse=not ascending, key=lambda x:x[0])]) 55 | 56 | print('[INFO] DLHLP dataset', split[-1], 'set :',len(self.file_list),'audio files found') 57 | 58 | def __getitem__(self,index): 59 | if self.bucket_size>1: 60 | # Return a bucket 61 | index = min(len(self.file_list)-self.bucket_size,index) 62 | return 
[(f_path, txt) for f_path,txt in \ 63 | zip(self.file_list[index:index+self.bucket_size], self.text[index:index+self.bucket_size])] 64 | else: 65 | return self.file_list[index], self.text[index] 66 | 67 | def __len__(self): 68 | return len(self.file_list) 69 | 70 | 71 | class DLHLPTextDataset(Dataset): 72 | def __init__(self, path, split, tokenizer, bucket_size): 73 | # Setup 74 | self.path = path 75 | self.bucket_size = bucket_size 76 | self.encode_on_fly = False 77 | read_txt_src = [] 78 | 79 | # List all wave files 80 | file_list, all_sent = [],[] 81 | 82 | for s in split: 83 | if s in ADDITIONAL_TXT_SRC: 84 | # self.encode_on_fly = True 85 | with open(join(path,s),'r') as f: 86 | all_sent += f.readlines() 87 | file_list += list(Path(join(path,s)).rglob("*.wav")) 88 | assert (len(file_list)>0) or (len(all_sent)>0), "No data found @ {}".format(path) 89 | 90 | # Read text 91 | text = Parallel(n_jobs=READ_FILE_THREADS)(delayed(read_text)(str(f)) for f in file_list) 92 | all_sent.extend(text) 93 | del text 94 | 95 | # Encode text 96 | if self.encode_on_fly: 97 | self.tokenizer = tokenizer 98 | self.text = all_sent 99 | else: 100 | self.text = [tokenizer.encode(txt) for txt in tqdm(all_sent)] 101 | del all_sent 102 | 103 | # Read file size and sort dataset by file size (Note: feature len. may be different) 104 | self.text = sorted(self.text, reverse=True, key=lambda x:len(x)) 105 | if self.encode_on_fly: 106 | del self.text[:REMOVE_TOP_N_TXT] 107 | 108 | def __getitem__(self,index): 109 | if self.bucket_size>1: 110 | index = min(len(self.text)-self.bucket_size,index) 111 | if self.encode_on_fly: 112 | for i in range(index,index+self.bucket_size): 113 | if type(self.text[i]) is str: 114 | self.text[i] = self.tokenizer.encode(self.text[i]) 115 | # Return a bucket 116 | return self.text[index:index+self.bucket_size] 117 | else: 118 | if self.encode_on_fly and type(self.text[index]) is str: 119 | self.text[index] = self.tokenizer.encode(self.text[index]) 120 | return self.text[index] 121 | 122 | def __len__(self): 123 | return len(self.text) 124 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import torch 4 | import numpy as np 5 | from torch import nn 6 | 7 | import matplotlib 8 | matplotlib.use('Agg') 9 | import matplotlib.pyplot as plt 10 | 11 | class Timer(): 12 | ''' Timer for recording training time distribution. 
''' 13 | def __init__(self): 14 | self.prev_t = time.time() 15 | self.clear() 16 | 17 | def set(self): 18 | self.prev_t = time.time() 19 | 20 | def cnt(self,mode): 21 | self.time_table[mode] += time.time()-self.prev_t 22 | self.set() 23 | if mode =='bw': 24 | self.click += 1 25 | 26 | def show(self): 27 | total_time = sum(self.time_table.values()) 28 | self.time_table['avg'] = total_time/self.click 29 | self.time_table['rd'] = 100*self.time_table['rd']/total_time 30 | self.time_table['fw'] = 100*self.time_table['fw']/total_time 31 | self.time_table['bw'] = 100*self.time_table['bw']/total_time 32 | msg = '{avg:.3f} sec/step (rd {rd:.1f}% | fw {fw:.1f}% | bw {bw:.1f}%)'.format(**self.time_table) 33 | self.clear() 34 | return msg 35 | 36 | def clear(self): 37 | self.time_table = {'rd':0,'fw':0,'bw':0} 38 | self.click = 0 39 | 40 | # Reference : https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/e2e_asr.py#L168 41 | def init_weights(module): 42 | # Exceptions 43 | if type(module) == nn.Embedding: 44 | module.weight.data.normal_(0, 1) 45 | else: 46 | for p in module.parameters(): 47 | data = p.data 48 | if data.dim() == 1: 49 | # bias 50 | data.zero_() 51 | elif data.dim() == 2: 52 | # linear weight 53 | n = data.size(1) 54 | stdv = 1. / math.sqrt(n) 55 | data.normal_(0, stdv) 56 | elif data.dim() in [3,4]: 57 | # conv weight 58 | n = data.size(1) 59 | for k in data.size()[2:]: 60 | n *= k 61 | stdv = 1. / math.sqrt(n) 62 | data.normal_(0, stdv) 63 | else: 64 | raise NotImplementedError 65 | def init_gate(bias): 66 | n = bias.size(0) 67 | start, end = n // 4, n // 2 68 | bias.data[start:end].fill_(1.) 69 | return bias 70 | 71 | # Convert Tensor to Figure on tensorboard 72 | def feat_to_fig(feat, spec=False): 73 | # feat TxD tensor 74 | data = _save_canvas(feat.numpy(), spec=spec) 75 | return torch.FloatTensor(data),"HWC" 76 | 77 | def _save_canvas(data, meta=None, spec=False): 78 | sx = 16 79 | sy = 8 80 | if spec: 81 | sx = 24 82 | sy = 8 83 | fig, ax = plt.subplots(figsize=(sx, sy)) 84 | if meta is None: 85 | ax.imshow(data, aspect="auto", origin="lower") 86 | else: 87 | ax.bar(meta[0],data[0],tick_label=meta[1],fc=(0, 0, 1, 0.5)) 88 | ax.bar(meta[0],data[1],tick_label=meta[1],fc=(1, 0, 0, 0.5)) 89 | fig.canvas.draw() 90 | # Note : torch tb add_image takes color as [0,1] 91 | data = np.array(fig.canvas.renderer._renderer)[:,:,:-1]/255.0 92 | plt.close(fig) 93 | return data 94 | 95 | # Reference : https://stackoverflow.com/questions/579310/formatting-long-numbers-as-strings-in-python 96 | def human_format(num): 97 | magnitude = 0 98 | while num >= 1000: 99 | magnitude += 1 100 | num /= 1000.0 101 | # add more suffixes if you need them 102 | return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude]) 103 | 104 | def cal_er(tokenizer, pred, truth, mode='wer', ctc=False): 105 | import editdistance as ed 106 | # Calculate error rate of a batch 107 | if pred is None: 108 | return np.nan 109 | elif len(pred.shape)>=3: 110 | pred = pred.argmax(dim=-1) 111 | er = [] 112 | for p,t in zip(pred,truth): 113 | p = tokenizer.decode(p.tolist(), ignore_repeat=ctc) 114 | t = tokenizer.decode(t.tolist()) 115 | if mode == 'wer' or mode == 'per': 116 | p = p.split(' ') 117 | t = t.split(' ') 118 | error = 1. 
if len(t) == 0 else float(ed.eval(p,t))/len(t) 119 | er.append(error) 120 | return sum(er)/len(er) 121 | 122 | 123 | def load_embedding(text_encoder, embedding_filepath): 124 | with open(embedding_filepath, "r") as f: 125 | vocab_size, embedding_size = [int(x) for x in f.readline().strip().split()] 126 | embeddings = np.zeros((text_encoder.vocab_size, embedding_size)) 127 | 128 | unk_count = 0 129 | 130 | for line in f: 131 | vocab, emb = line.strip().split(" ", 1) 132 | # fasttext's is 133 | if vocab == "": 134 | vocab = "" 135 | 136 | if text_encoder.token_type == "subword": 137 | idx = text_encoder.spm.piece_to_id(vocab) 138 | else: 139 | # get rid of 140 | idx = text_encoder.encode(vocab)[0] 141 | 142 | if idx == text_encoder.unk_idx: 143 | unk_count += 1 144 | embeddings[idx] += np.asarray([float(x) for x in emb.split(" ")]) 145 | else: 146 | # Suppose there is only one (w, v) pair in embedding file 147 | embeddings[idx] = np.asarray([float(x) for x in emb.split(" ")]) 148 | 149 | # Average vector 150 | if unk_count != 0: 151 | embeddings[text_encoder.unk_idx] /= unk_count 152 | 153 | return embeddings 154 | 155 | def count_parameters(model): 156 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 157 | -------------------------------------------------------------------------------- /src/plugin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from src.util import load_embedding 4 | from src.bert_embedding import BertEmbeddingPredictor 5 | 6 | 7 | class EmbeddingRegularizer(nn.Module): 8 | ''' Perform word embedding regularization training for ASR''' 9 | 10 | def __init__(self, tokenizer, dec_dim, enable, src, distance, weight, fuse, temperature, 11 | freeze=True, fuse_normalize=False, dropout=0.0, bert=None): 12 | super(EmbeddingRegularizer, self).__init__() 13 | self.enable = enable 14 | if enable: 15 | if bert is not None: 16 | self.use_bert = True 17 | if not isinstance(bert, str): 18 | raise ValueError( 19 | "`bert` should be a str specifying bert config such as \"bert-base-uncased\".") 20 | self.emb_table = BertEmbeddingPredictor(bert, tokenizer, src) 21 | vocab_size, emb_dim = self.emb_table.model.bert.embeddings.word_embeddings.weight.shape 22 | vocab_size = vocab_size-3 # cls,sep,mask not used 23 | self.dim = emb_dim 24 | else: 25 | self.use_bert = False 26 | pretrained_emb = torch.FloatTensor( 27 | load_embedding(tokenizer, src)) 28 | # pretrained_emb = nn.functional.normalize(pretrained_emb,dim=-1) # ToDo : Check impact on old version 29 | vocab_size, emb_dim = pretrained_emb.shape 30 | self.dim = emb_dim 31 | 32 | self.emb_table = nn.Embedding.from_pretrained( 33 | pretrained_emb, freeze=freeze, padding_idx=0) 34 | 35 | self.emb_net = nn.Sequential(nn.Linear(dec_dim, (emb_dim+dec_dim)//2), 36 | nn.ReLU(), 37 | nn.Linear((emb_dim+dec_dim)//2, emb_dim)) 38 | self.weight = weight 39 | self.distance = distance 40 | self.fuse_normalize = fuse_normalize 41 | if distance == 'CosEmb': 42 | # This maybe somewhat reduandant since cos emb loss includes ||x|| 43 | self.measurement = nn.CosineEmbeddingLoss(reduction='none') 44 | elif distance == 'MSE': 45 | self.measurement = nn.MSELoss(reduction='none') 46 | else: 47 | raise NotImplementedError 48 | 49 | self.apply_dropout = dropout > 0 50 | if self.apply_dropout: 51 | self.dropout = nn.Dropout(dropout) 52 | 53 | self.apply_fuse = fuse != 0 54 | if self.apply_fuse: 55 | # Weight for mixing emb/dec prob 56 | if fuse == -1: 57 | # 
Learnable fusion 58 | self.fuse_type = "learnable" 59 | self.fuse_learnable = True 60 | self.fuse_lambda = nn.Parameter( 61 | data=torch.FloatTensor([0.5])) 62 | elif fuse == -2: 63 | # Learnable vocab-wise fusion 64 | self.fuse_type = "vocab-wise learnable" 65 | self.fuse_learnable = True 66 | self.fuse_lambda = nn.Parameter( 67 | torch.ones((vocab_size))*0.5) 68 | else: 69 | self.fuse_type = str(fuse) 70 | self.fuse_learnable = False 71 | self.register_buffer( 72 | 'fuse_lambda', torch.FloatTensor([fuse])) 73 | # Temperature of emb prob. 74 | if temperature == -1: 75 | self.temperature = 'learnable' 76 | self.temp = nn.Parameter(data=torch.FloatTensor([1])) 77 | elif temperature == -2: 78 | self.temperature = 'elementwise' 79 | self.temp = nn.Parameter(torch.ones((vocab_size))) 80 | else: 81 | self.temperature = str(temperature) 82 | self.register_buffer( 83 | 'temp', torch.FloatTensor([temperature])) 84 | self.eps = 1e-8 85 | 86 | def create_msg(self): 87 | msg = ['Plugin. | Word embedding regularization enabled (type:{}, weight:{})'.format( 88 | self.distance, self.weight)] 89 | if self.apply_fuse: 90 | msg.append(' | Embedding-fusion decoder enabled ( temp. = {}, lambda = {} )'. 91 | format(self.temperature, self.fuse_type)) 92 | return msg 93 | 94 | def get_weight(self): 95 | if self.fuse_learnable: 96 | return torch.sigmoid(self.fuse_lambda).mean().cpu().data 97 | else: 98 | return self.fuse_lambda 99 | 100 | def get_temp(self): 101 | return nn.functional.relu(self.temp).mean() 102 | 103 | def fuse_prob(self, x_emb, dec_logit): 104 | ''' Takes context and decoder logit to perform word embedding fusion ''' 105 | # Compute distribution for dec/emb 106 | if self.fuse_normalize: 107 | emb_logit = nn.functional.linear(nn.functional.normalize(x_emb, dim=-1), 108 | nn.functional.normalize(self.emb_table.weight, dim=-1)) 109 | else: 110 | emb_logit = nn.functional.linear(x_emb, self.emb_table.weight) 111 | emb_prob = (nn.functional.relu(self.temp)*emb_logit).softmax(dim=-1) 112 | dec_prob = dec_logit.softmax(dim=-1) 113 | # Mix distribution 114 | if self.fuse_learnable: 115 | fused_prob = (1-torch.sigmoid(self.fuse_lambda))*dec_prob +\ 116 | torch.sigmoid(self.fuse_lambda)*emb_prob 117 | else: 118 | fused_prob = (1-self.fuse_lambda)*dec_prob + \ 119 | self.fuse_lambda*emb_prob 120 | # Log-prob 121 | log_fused_prob = (fused_prob+self.eps).log() 122 | 123 | return log_fused_prob 124 | 125 | def forward(self, dec_state, dec_logit, label=None, return_loss=True): 126 | # Match embedding dim. 
127 | log_fused_prob = None 128 | loss = None 129 | 130 | #x_emb = nn.functional.normalize(self.emb_net(dec_state),dim=-1) 131 | if self.apply_dropout: 132 | dec_state = self.dropout(dec_state) 133 | x_emb = self.emb_net(dec_state) 134 | 135 | if return_loss: 136 | # Compute embedding loss 137 | b, t = label.shape 138 | # Retrieve embedding 139 | if self.use_bert: 140 | with torch.no_grad(): 141 | y_emb = self.emb_table(label).contiguous() 142 | else: 143 | y_emb = self.emb_table(label) 144 | # Regression loss on embedding 145 | if self.distance == 'CosEmb': 146 | loss = self.measurement( 147 | x_emb.view(-1, self.dim), y_emb.view(-1, self.dim), torch.ones(1).to(dec_state.device)) 148 | else: 149 | loss = self.measurement( 150 | x_emb.view(-1, self.dim), y_emb.view(-1, self.dim)) 151 | loss = loss.view(b, t) 152 | # Mask out padding 153 | loss = torch.where(label != 0, loss, torch.zeros_like(loss)) 154 | loss = torch.mean(loss.sum(dim=-1) / 155 | (label != 0).sum(dim=-1).float()) 156 | 157 | if self.apply_fuse: 158 | log_fused_prob = self.fuse_prob(x_emb, dec_logit) 159 | 160 | return loss, log_fused_prob 161 | -------------------------------------------------------------------------------- /src/text.py: -------------------------------------------------------------------------------- 1 | """Modified from tensorflow_datasets.features.text.* 2 | 3 | Reference: https://www.tensorflow.org/datasets/api_docs/python/tfds/features/text_lib 4 | """ 5 | import abc 6 | 7 | BERT_FIRST_IDX = 997 # Replacing the 2 tokens right before english starts as & 8 | BERT_LAST_IDX = 29635 # Drop rest of tokens 9 | 10 | class _BaseTextEncoder(abc.ABC): 11 | @abc.abstractmethod 12 | def encode(self, s): 13 | raise NotImplementedError 14 | 15 | @abc.abstractmethod 16 | def decode(self, ids, ignore_repeat=False): 17 | raise NotImplementedError 18 | 19 | @abc.abstractproperty 20 | def vocab_size(self): 21 | raise NotImplementedError 22 | 23 | @abc.abstractproperty 24 | def token_type(self): 25 | raise NotImplementedError 26 | 27 | @abc.abstractclassmethod 28 | def load_from_file(cls, vocab_file): 29 | raise NotImplementedError 30 | 31 | @property 32 | def pad_idx(self): 33 | return 0 34 | 35 | @property 36 | def eos_idx(self): 37 | return 1 38 | 39 | @property 40 | def unk_idx(self): 41 | return 2 42 | 43 | def __repr__(self): 44 | return "<{} vocab_size={}>".format(type(self).__name__, self.vocab_size) 45 | 46 | 47 | class CharacterTextEncoder(_BaseTextEncoder): 48 | def __init__(self, vocab_list): 49 | # Note that vocab_list must not contain , and 50 | # =0, =1, =2 51 | self._vocab_list = ["", "", ""] + vocab_list 52 | self._vocab2idx = {v: idx for idx, v in enumerate(self._vocab_list)} 53 | 54 | def encode(self, s): 55 | # Always strip trailing space, \r and \n 56 | s = s.strip("\r\n ") 57 | # Manually append eos to the end 58 | return [self.vocab_to_idx(v) for v in s] + [self.eos_idx] 59 | 60 | def decode(self, idxs, ignore_repeat=False): 61 | vocabs = [] 62 | for t, idx in enumerate(idxs): 63 | if idx == self.eos_idx: 64 | break 65 | elif idx == self.pad_idx or (ignore_repeat and t > 0 and idx == idxs[t - 1 if t > 0 else 0]): 66 | continue 67 | v = self.idx_to_vocab(idx) 68 | vocabs.append(v) 69 | return "".join(vocabs) 70 | 71 | @classmethod 72 | def load_from_file(cls, vocab_file): 73 | with open(vocab_file, "r", encoding='UTF-8') as f: 74 | # Do not strip space because character based text encoder should 75 | # have a space token 76 | vocab_list = [line.strip("\r\n") for line in f] 77 | return cls(vocab_list) 
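    # Usage sketch (mirrors tests/test_text.py; illustrative only): build the encoder
    # from a vocab file or a vocab list, then encode/decode a sentence. encode() maps
    # unknown characters to unk_idx (2) and appends eos_idx (1); decode() stops at eos.
    #   enc = CharacterTextEncoder.load_from_file("tests/sample_data/character.vocab")
    #   ids = enc.encode("SPEECH LAB")   # per-character ids + [eos_idx]
    #   enc.decode(ids)                  # -> "SPEECH LAB"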
78 | 79 | @property 80 | def vocab_size(self): 81 | return len(self._vocab_list) 82 | 83 | @property 84 | def token_type(self): 85 | return 'character' 86 | 87 | def vocab_to_idx(self, vocab): 88 | return self._vocab2idx.get(vocab, self.unk_idx) 89 | 90 | def idx_to_vocab(self, idx): 91 | return self._vocab_list[idx] 92 | 93 | 94 | class SubwordTextEncoder(_BaseTextEncoder): 95 | def __init__(self, spm): 96 | if spm.pad_id() != 0 or spm.eos_id() != 1 or spm.unk_id() != 2: 97 | raise ValueError( 98 | "Please train sentencepiece model with following argument:\n" 99 | "--pad_id=0 --eos_id=1 --unk_id=2 --bos_id=-1 --model_type=bpe --eos_piece=") 100 | self.spm = spm 101 | 102 | def encode(self, s): 103 | return self.spm.encode_as_ids(s) 104 | 105 | def decode(self, idxs, ignore_repeat=False): 106 | crop_idx = [] 107 | for t, idx in enumerate(idxs): 108 | if idx == self.eos_idx: 109 | break 110 | elif idx == self.pad_idx or (ignore_repeat and t > 0 and idx == idxs[t-1]): 111 | continue 112 | else: 113 | crop_idx.append(idx) 114 | return self.spm.decode_ids(crop_idx) 115 | 116 | @classmethod 117 | def load_from_file(cls, filepath): 118 | import sentencepiece as splib 119 | spm = splib.SentencePieceProcessor() 120 | spm.load(filepath) 121 | spm.set_encode_extra_options(":eos") 122 | return cls(spm) 123 | 124 | @property 125 | def vocab_size(self): 126 | return len(self.spm) 127 | 128 | @property 129 | def token_type(self): 130 | return 'subword' 131 | 132 | 133 | class WordTextEncoder(CharacterTextEncoder): 134 | def encode(self, s): 135 | # Always strip trailing space, \r and \n 136 | s = s.strip("\r\n ") 137 | # Space as the delimiter between words 138 | words = s.split(" ") 139 | # Manually append eos to the end 140 | return [self.vocab_to_idx(v) for v in words] + [self.eos_idx] 141 | 142 | def decode(self, idxs, ignore_repeat=False): 143 | vocabs = [] 144 | for t, idx in enumerate(idxs): 145 | v = self.idx_to_vocab(idx) 146 | if idx == self.eos_idx: 147 | break 148 | elif idx == self.pad_idx or (ignore_repeat and t > 0 and idx == idxs[t-1]): 149 | continue 150 | else: 151 | vocabs.append(v) 152 | return " ".join(vocabs) 153 | 154 | @property 155 | def token_type(self): 156 | return 'word' 157 | 158 | 159 | class BertTextEncoder(_BaseTextEncoder): 160 | """Bert Tokenizer. 
161 | 162 | https://github.com/huggingface/pytorch-transformers/blob/master/pytorch_transformers/tokenization_bert.py 163 | """ 164 | 165 | def __init__(self, tokenizer): 166 | self._tokenizer = tokenizer 167 | self._tokenizer.pad_token = "" 168 | self._tokenizer.eos_token = "" 169 | self._tokenizer.unk_token = "" 170 | 171 | def encode(self, s): 172 | # Reduce vocab size manually 173 | reduced_idx = [] 174 | for idx in self._tokenizer.encode(s): 175 | try: 176 | r_idx = idx-BERT_FIRST_IDX 177 | assert r_idx>0 178 | reduced_idx.append(r_idx) 179 | except: 180 | reduced_idx.append(self.unk_idx) 181 | reduced_idx.append(self.eos_idx) 182 | return reduced_idx 183 | 184 | def decode(self, idxs, ignore_repeat=False): 185 | crop_idx = [] 186 | for t, idx in enumerate(idxs): 187 | if idx == self.eos_idx: 188 | break 189 | elif idx == self.pad_idx or (ignore_repeat and t > 0 and idx == idxs[t-1]): 190 | continue 191 | else: 192 | crop_idx.append(idx+BERT_FIRST_IDX) # Shift to correct idx for bert tokenizer 193 | return self._tokenizer.decode(crop_idx) 194 | 195 | @property 196 | def vocab_size(self): 197 | return BERT_LAST_IDX-BERT_FIRST_IDX+1 198 | 199 | @property 200 | def token_type(self): 201 | return "bert" 202 | 203 | @classmethod 204 | def load_from_file(cls, vocab_file): 205 | from pytorch_transformers import BertTokenizer 206 | return cls(BertTokenizer.from_pretrained(vocab_file)) 207 | 208 | @property 209 | def pad_idx(self): 210 | return 0 211 | 212 | @property 213 | def eos_idx(self): 214 | return 1 215 | 216 | @property 217 | def unk_idx(self): 218 | return 2 219 | 220 | 221 | def load_text_encoder(mode, vocab_file): 222 | if mode == "character": 223 | return CharacterTextEncoder.load_from_file(vocab_file) 224 | elif mode == "subword": 225 | return SubwordTextEncoder.load_from_file(vocab_file) 226 | elif mode == "word" or mode == "phone": 227 | return WordTextEncoder.load_from_file(vocab_file) 228 | elif mode.startswith("bert-"): 229 | return BertTextEncoder.load_from_file(mode) 230 | else: 231 | raise NotImplementedError("`{}` is not yet supported.".format(mode)) 232 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from functools import partial 4 | from src.text import load_text_encoder 5 | from src.audio import create_transform 6 | from torch.utils.data import DataLoader 7 | from torch.nn.utils.rnn import pad_sequence 8 | import torch.nn.functional as F 9 | from os.path import join 10 | 11 | from src.collect_batch import collect_audio_batch, collect_text_batch 12 | 13 | def create_dataset(tokenizer, ascending, name, path, bucketing, batch_size, 14 | train_split=None, dev_split=None, test_split=None, read_audio=False): 15 | ''' Interface for creating all kinds of dataset''' 16 | 17 | # Recognize corpus 18 | if name.lower() == 'librispeech': 19 | from corpus.preprocess_librispeech import LibriDataset as Dataset 20 | elif name.lower() == 'dlhlp': 21 | from corpus.preprocess_dlhlp import DLHLPDataset as Dataset 22 | else: 23 | raise NotImplementedError 24 | 25 | # Create dataset 26 | if train_split is not None: 27 | # Training mode 28 | mode = 'train' 29 | tr_loader_bs = 1 if bucketing and (not ascending) else batch_size 30 | bucket_size = batch_size if bucketing and (not ascending) else 1 # Ascending without bucketing 31 | 32 | if type(dev_split[0]) is not list: 33 | dv_set = Dataset(path,dev_split,tokenizer, 1, 
read_audio=read_audio) # Do not use bucketing for dev set 34 | dv_len = len(dv_set) 35 | else: 36 | dv_set = [] 37 | for ds in dev_split: 38 | dev_dir = '' 39 | if ds[0].lower() == 'librispeech': 40 | dev_dir = join(path, 'LibriSpeech') 41 | from corpus.preprocess_librispeech import LibriDataset as DevDataset 42 | else: 43 | raise NotImplementedError(ds[0]) 44 | dv_set.append(DevDataset(dev_dir,ds,tokenizer, 1)) 45 | dv_len = sum([len(s) for s in dv_set]) 46 | 47 | if path[-4:].lower() != name[-4:].lower(): 48 | tr_dir = join(path, name) 49 | else: 50 | tr_dir = path 51 | 52 | tr_set = Dataset(tr_dir,train_split,tokenizer, bucket_size, 53 | ascending=ascending, 54 | read_audio=read_audio) 55 | # Messages to show 56 | msg_list = _data_msg(name,path,train_split.__str__(),len(tr_set), 57 | dev_split.__str__(),dv_len,batch_size,bucketing) 58 | 59 | return tr_set, dv_set, tr_loader_bs, batch_size, mode, msg_list 60 | else: 61 | # Testing model 62 | mode = 'test' 63 | if path[-4:].lower() != name[-4:].lower(): 64 | tt_dir = join(path, name) 65 | else: 66 | tt_dir = path 67 | 68 | bucket_size = 1 69 | if type(dev_split[0]) is list: dev_split = dev_split[0] 70 | 71 | dv_set = Dataset(tt_dir,dev_split,tokenizer, bucket_size, read_audio=read_audio) # Do not use bucketing for dev set 72 | tt_set = Dataset(tt_dir,test_split,tokenizer, bucket_size, read_audio=read_audio) # Do not use bucketing for test set 73 | # Messages to show 74 | msg_list = _data_msg(name,tt_dir,dev_split.__str__(),len(dv_set), 75 | test_split.__str__(),len(tt_set),batch_size,False) 76 | msg_list = [m.replace('Dev','Test').replace('Train','Dev') for m in msg_list] 77 | return dv_set, tt_set, batch_size, batch_size, mode, msg_list 78 | 79 | def create_textset(tokenizer, train_split, dev_split, name, path, bucketing, batch_size): 80 | ''' Interface for creating all kinds of text dataset''' 81 | msg_list = [] 82 | 83 | # Recognize corpus 84 | if name.lower() == "librispeech": 85 | from corpus.preprocess_librispeech import LibriTextDataset as Dataset 86 | elif name.lower() == 'dlhlp': 87 | from corpus.preprocess_dlhlp import DLHLPTextDataset as Dataset 88 | else: 89 | raise NotImplementedError 90 | 91 | # Create dataset 92 | bucket_size = batch_size if bucketing else 1 93 | tr_loader_bs = 1 if bucketing else batch_size 94 | dv_set = Dataset(path,dev_split,tokenizer, 1) # Do not use bucketing for dev set 95 | tr_set = Dataset(path,train_split,tokenizer, bucket_size) 96 | 97 | # Messages to show 98 | msg_list = _data_msg(name,path,train_split.__str__(),len(tr_set), 99 | dev_split.__str__(),len(dv_set),batch_size,bucketing) 100 | 101 | return tr_set, dv_set, tr_loader_bs, batch_size, msg_list 102 | 103 | def load_dataset(n_jobs, use_gpu, pin_memory, ascending, corpus, audio, text): 104 | ''' Prepare dataloader for training/validation''' 105 | # Audio feature extractor 106 | audio_transform_tr, feat_dim = create_transform(audio.copy()) 107 | audio_transform_dv, feat_dim = create_transform(audio.copy()) 108 | 109 | # Text tokenizer 110 | tokenizer = load_text_encoder(**text) 111 | # Dataset (in testing mode, tr_set=dv_set, dv_set=tt_set) 112 | tr_set, dv_set, tr_loader_bs, dv_loader_bs, mode, data_msg = create_dataset(tokenizer,ascending,**corpus) 113 | 114 | # Collect function 115 | collect_tr = partial(collect_audio_batch, audio_transform=audio_transform_tr, mode=mode) 116 | collect_dv = partial(collect_audio_batch, audio_transform=audio_transform_dv, mode='test') 117 | 118 | # Shuffle/drop applied to training set only 119 | shuffle = 
(mode=='train' and not ascending) 120 | drop_last = shuffle 121 | # Create data loader 122 | tr_set = DataLoader(tr_set, batch_size=tr_loader_bs, shuffle=shuffle, drop_last=drop_last, collate_fn=collect_tr, 123 | num_workers=n_jobs, pin_memory=use_gpu) 124 | 125 | if type(dv_set) is list: 126 | _tmp_set = [] 127 | for ds in dv_set: 128 | _tmp_set.append(DataLoader(ds, batch_size=dv_loader_bs, shuffle=False, drop_last=False, collate_fn=collect_dv, 129 | num_workers=n_jobs, pin_memory=pin_memory)) 130 | dv_set = _tmp_set 131 | else: 132 | dv_set = DataLoader(dv_set, batch_size=dv_loader_bs, shuffle=False, drop_last=False, collate_fn=collect_dv, 133 | num_workers=n_jobs, pin_memory=pin_memory) 134 | 135 | # Messages to show 136 | data_msg.append('I/O spec. | Audio Feature = {}\t| Feature Dim = {}\t| Token Type = {}\t| Vocab Size = {}'\ 137 | .format(audio['feat_type'],feat_dim,tokenizer.token_type,tokenizer.vocab_size)) 138 | return tr_set, dv_set, feat_dim, tokenizer.vocab_size, tokenizer, data_msg 139 | 140 | def load_textset(n_jobs, use_gpu, pin_memory, corpus, text): 141 | # Text tokenizer 142 | tokenizer = load_text_encoder(**text) 143 | # Dataset 144 | tr_set, dv_set, tr_loader_bs, dv_loader_bs, data_msg = create_textset(tokenizer,**corpus) 145 | collect_tr = partial(collect_text_batch,mode='train') 146 | collect_dv = partial(collect_text_batch,mode='dev') 147 | # Dataloader (Text data stored in RAM, no need num_workers) 148 | tr_set = DataLoader(tr_set, batch_size=tr_loader_bs, shuffle=True, drop_last=True, collate_fn=collect_tr, 149 | num_workers=0, pin_memory=use_gpu) 150 | dv_set = DataLoader(dv_set, batch_size=dv_loader_bs, shuffle=False, drop_last=False, collate_fn=collect_dv, 151 | num_workers=0, pin_memory=pin_memory) 152 | 153 | # Messages to show 154 | data_msg.append('I/O spec. | Token type = {}\t| Vocab size = {}'\ 155 | .format(tokenizer.token_type,tokenizer.vocab_size)) 156 | 157 | return tr_set, dv_set, tokenizer.vocab_size, tokenizer, data_msg 158 | 159 | 160 | def _data_msg(name,path,train_split,tr_set,dev_split,dv_set,batch_size,bucketing): 161 | ''' List msg for verbose function ''' 162 | msg_list = [] 163 | msg_list.append('Data spec. 
| Corpus = {} (from {})'.format(name,path)) 164 | msg_list.append(' | Train sets = {}\t| Number of utts = {}'.format(train_split,tr_set)) 165 | msg_list.append(' | Dev sets = {}\t| Number of utts = {}'.format(dev_split,dv_set)) 166 | msg_list.append(' | Batch size = {}\t\t| Bucketing = {}'.format(batch_size,bucketing)) 167 | return msg_list 168 | -------------------------------------------------------------------------------- /bin/test_asr.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | from tqdm import tqdm 5 | from functools import partial 6 | from joblib import Parallel, delayed 7 | import yaml 8 | 9 | from src.solver import BaseSolver 10 | from src.asr import ASR 11 | from src.decode import BeamDecoder #, CTCBeamDecoder 12 | from src.data import load_dataset 13 | from src.audio import Delta, Postprocess 14 | 15 | class Solver(BaseSolver): 16 | ''' Solver for training''' 17 | def __init__(self,config,paras,mode): 18 | super().__init__(config,paras,mode) 19 | 20 | # ToDo : support tr/eval on different corpus 21 | # assert self.config['data']['corpus']['name'] == self.src_config['data']['corpus']['name'] 22 | # self.config['data']['corpus']['path'] = self.src_config['data']['corpus']['path'] 23 | self.config['data']['corpus']['bucketing'] = False 24 | # self.config['data']['corpus']['threshold'] = 100 25 | 26 | # The follow attribute should be identical to training config 27 | self.config['data']['audio'] = self.src_config['data']['audio'] 28 | self.config['data']['text'] = self.src_config['data']['text'] 29 | self.config['model'] = self.src_config['model'] 30 | 31 | # Output file 32 | self.output_file = str(self.ckpdir)+'_{}_{}.csv' 33 | 34 | # Override batch size for beam decoding 35 | self.greedy = self.config['decode']['beam_size'] == 1 36 | if not self.greedy: 37 | self.config['data']['corpus']['batch_size'] = 1 38 | else: 39 | pass 40 | 41 | self.step = 0 42 | 43 | def fetch_data(self, data): 44 | ''' Move data to device and compute text seq. 
length''' 45 | _, feat, feat_len, txt = data 46 | feat = feat.to(self.device) 47 | feat_len = feat_len.to(self.device) 48 | txt = txt.to(self.device) 49 | txt_len = torch.sum(txt!=0,dim=-1) 50 | 51 | return feat, feat_len, txt, txt_len 52 | 53 | def load_data(self): 54 | ''' Load data for training/validation, store tokenizer and input/output shape''' 55 | self.dv_set, self.tt_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \ 56 | load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory, False, **self.config['data']) 57 | self.verbose(msg) 58 | 59 | def set_model(self): 60 | ''' Setup ASR model ''' 61 | # Model 62 | self.model = ASR(self.feat_dim, self.vocab_size, **self.config['model']) 63 | 64 | # Plug-ins 65 | if ('emb' in self.config) and (self.config['emb']['enable']) \ 66 | and (self.config['emb']['fuse']>0): 67 | from src.plugin import EmbeddingRegularizer 68 | self.emb_decoder = EmbeddingRegularizer(self.tokenizer, self.model.dec_dim, **self.config['emb']) 69 | 70 | # Load target model in eval mode 71 | self.load_ckpt() 72 | 73 | # self.ctc_only = False 74 | if self.greedy: 75 | self.decoder = copy.deepcopy(self.model).to(self.device) 76 | else: 77 | # Beam decoder 78 | # TODO: CTC decoding function Hidden by author 79 | # if not self.model.enable_att or self.config['decode'].get('ctc_weight', 0.0) == 1.0: 80 | # For CTC only decoding (character level) 81 | 82 | # self.decoder = CTCBeamDecoder(self.model.to(self.device), 83 | # range(self.model.vocab_size), 84 | # self.config['decode']['beam_size'], 85 | # self.config['decode']['vocab_candidate']) 86 | # self.ctc_only = True 87 | # else: 88 | # self.decoder = BeamDecoder(self.model, self.emb_decoder, **self.config['decode']) 89 | self.decoder = BeamDecoder(self.model, self.emb_decoder, **self.config['decode']) 90 | 91 | self.verbose(self.decoder.create_msg()) 92 | del self.model 93 | del self.emb_decoder 94 | self.emb_decoder = None 95 | 96 | def greedy_decode(self, dv_set): 97 | results = [] 98 | for i,data in enumerate(dv_set): 99 | self.progress('Valid step - {}/{}'.format(i+1,len(dv_set))) 100 | # Fetch data 101 | feat, feat_len, txt, txt_len = self.fetch_data(data) 102 | 103 | # Forward model 104 | with torch.no_grad(): 105 | ctc_output, encode_len, att_output, att_align, dec_state = \ 106 | self.decoder( feat, feat_len, int(float(feat_len.max()) * self.config['decode']['max_len_ratio']), 107 | emb_decoder=self.emb_decoder) 108 | for j in range(len(txt)): 109 | idx = j + self.config['data']['corpus']['batch_size'] * i 110 | if att_output is not None: 111 | hyp_seqs = att_output[j].argmax(dim=-1).tolist() 112 | else: 113 | hyp_seqs = ctc_output[j].argmax(dim=-1).tolist() 114 | true_txt = txt[j] 115 | results.append((str(idx), [hyp_seqs], true_txt)) 116 | return results 117 | 118 | def exec(self): 119 | ''' Testing End-to-end ASR system ''' 120 | for s, ds in zip(['dev','test'],[self.dv_set,self.tt_set]): 121 | # Setup output 122 | self.cur_output_path = self.output_file.format(s,'output') 123 | with open(self.cur_output_path,'w',encoding='UTF-8') as f: 124 | f.write('idx\thyp\ttruth\n') 125 | 126 | if self.greedy: 127 | # Greedy decode 128 | self.verbose('Performing batch-wise greedy decoding on {} set, num of batch = {}.'.format(s,len(ds))) 129 | results = self.greedy_decode(ds) 130 | self.verbose('Results will be stored at {}'.format(self.cur_output_path)) 131 | self.write_hyp(results, self.cur_output_path, 'jizz') 132 | # elif self.ctc_only: 133 | # # TODO: CTC decode 134 | # self.verbose('Performing 
instance-wise CTC beam decoding on {} set, num of batch = {}.'.format(s,len(ds))) 135 | # # Minimal function to pickle 136 | # ctc_beam_decode_func = partial(ctc_beam_decode, model=copy.deepcopy(self.decoder), device=self.device) 137 | # # Parallel beam decode 138 | # results = Parallel(n_jobs=self.paras.njobs)(delayed(ctc_beam_decode_func)(data) for data in tqdm(ds)) 139 | # self.verbose('Results will be stored at {}'.format(self.cur_output_path)) 140 | # self.write_hyp(results, self.cur_output_path, 'jizz') 141 | # torch.cuda.empty_cache() 142 | else: 143 | # Additional output to store all beams 144 | self.cur_beam_path = self.output_file.format(s,'beam') 145 | with open(self.cur_beam_path,'w',encoding='UTF-8') as f: 146 | f.write('idx\tbeam\thyp\ttruth\n') 147 | self.verbose('Performing instance-wise beam decoding on {} set. (NOTE: use --njobs to speedup)'.format(s)) 148 | # Minimal function to pickle 149 | beam_decode_func = partial(beam_decode, model=copy.deepcopy(self.decoder).to(self.device), device=self.device) 150 | # Parallel beam decode 151 | results = Parallel(n_jobs=self.paras.njobs)(delayed(beam_decode_func)(data) for data in tqdm(ds)) 152 | self.verbose('Results/Beams will be stored at {}/{}.'.format(self.cur_output_path,self.cur_beam_path)) 153 | self.write_hyp(results,self.cur_output_path,self.cur_beam_path) 154 | torch.cuda.empty_cache() 155 | self.verbose('All done !') 156 | 157 | def write_hyp(self, results, best_path, beam_path): 158 | '''Record decoding results''' 159 | if self.greedy: 160 | ignore_repeat = not self.decoder.enable_att 161 | else: 162 | ignore_repeat = not self.decoder.asr.enable_att 163 | for name, hyp_seqs, truth in tqdm(results): 164 | hyp_seqs = [self.tokenizer.decode(hyp, ignore_repeat=ignore_repeat) for hyp in hyp_seqs] 165 | truth = self.tokenizer.decode(truth) 166 | with open(best_path,'a',encoding='UTF-8') as f: 167 | if type(hyp_seqs[0]) is not str: 168 | hyp_seqs[0] = ' ' 169 | if len(hyp_seqs[0]) == 0: 170 | hyp_seqs[0] = ' ' 171 | if len(truth) == 0: 172 | truth = ' ' 173 | f.write('\t'.join([name,hyp_seqs[0],truth])+'\n') 174 | if not self.greedy: 175 | with open(beam_path,'a',encoding='UTF-8') as f: 176 | for b,hyp in enumerate(hyp_seqs): 177 | f.write('\t'.join([name,str(b),hyp,truth])+'\n') 178 | 179 | def beam_decode(data, model, device): 180 | # Fetch data : move data/model to device 181 | name, feat, feat_len, txt = data 182 | feat = feat.to(device) 183 | feat_len = feat_len.to(device) 184 | txt = txt.to(device) 185 | txt_len = torch.sum(txt!=0,dim=-1) 186 | # Decode 187 | with torch.no_grad(): 188 | hyps = model(feat, feat_len) 189 | 190 | hyp_seqs = [hyp.outIndex for hyp in hyps] 191 | del hyps 192 | return (name[0], hyp_seqs, txt[0].cpu().tolist()) # Note: bs == 1 193 | 194 | def ctc_beam_decode(data, model, device): 195 | # Fetch data : move data/model to device 196 | name, feat, feat_len, txt = data 197 | feat = feat.to(device) 198 | feat_len = feat_len.to(device) 199 | # Decode 200 | with torch.no_grad(): 201 | hyp = model(feat, feat_len) 202 | 203 | return (name[0], [hyp], txt[0]) # Note: bs == 1 204 | -------------------------------------------------------------------------------- /config/README.md: -------------------------------------------------------------------------------- 1 | # Config File Documentation 2 | 3 | Description of parameters available in config files for [training](##Training-Configs) and [inference](##Inference-Configs). 
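As a quick orientation, the sketch below shows how a config file is consumed before any of the parameters documented here are read: `main.py` parses the YAML into a nested Python dictionary, and the solvers then look up the sections described in the rest of this page. This is only an illustration assuming the example config shipped under `config/libri/`, not part of the training pipeline itself.

```python
import yaml

# Parse a config the same way main.py does: YAML file -> nested Python dict.
config = yaml.load(open('config/libri/asr_example.yaml', 'r'), Loader=yaml.FullLoader)

# A training config is expected to provide these top-level sections,
# which the solvers access as config['data'], config['hparas'], config['model'].
for section in ('data', 'hparas', 'model'):
    assert section in config, 'missing section: {}'.format(section)

# e.g. the corpus name documented in the Data section below
print(config['data']['corpus']['name'])
```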
4 | 5 | ## Training Configs 6 | 7 | Each config should include `data`/`hparas`/`model`, see [example on LibrSpeech](libri/asr_example.yaml). 8 | 9 | ### Data 10 | 11 | Options under this category are all data-related. 12 | 13 | - Corpus 14 | 15 | For each corpus, a corresponding source python file `.py` file should be placed at `corpus/`, checkout [`librispeech.py`](../corpus/librispeech.py) for example. 16 | 17 | |Parameter | Description | Note | 18 | |----------|-------------|------| 19 | | name | `str` name of corpus (used in [`data.py`](../src/data.py) to import the dataset defined in `.py`) | Available: `Librispeech`| 20 | | path | `str` path to the specified corpus, parsing file structure should be handled in `.py` | | 21 | | train_split| `list` which includes subsets of corpus used for training, accepted partition names should be defined in `.py`|| 22 | | dev_split | `list` which includes subsets of corpus used for validation, accepted partition names should be defined in `.py`|| 23 | | bucketing | `bool` to enable bucketing, i.e. similar length in each batch, should be implemented in `.py`| More effecient training but biased sampling| 24 | | batch_size | `int` Batch size for training/validation, will be send to Torch Dataloader || 25 | 26 | - Audio 27 | 28 | Hyperparameters of feature extraction performed on-the-fly mostly done by [torchaudio](https://pytorch.org/audio/), checkout [audio.py](../src/audio.py) for implementation. 29 | 30 | |Parameter | Description | Note | 31 | |----------|-------------|------| 32 | | feat_type| `str` name of audio feature to be used. Please note that MFCC required latest torch audio | Available: `fbank`/`mfcc`| 33 | | feat_dim| `int` dimensionality of audio feature, if you are not fimiliar with audio features, `40` for `fbank` and `13` for `mfcc` generally works|| 34 | | frame_length | `int` size of the window (millisecond) for feature extraction || 35 | | frame_shift | `int` hop size of the window (millisecond) for feature extraction || 36 | | dither | `float` dither when extracting feature | See [doc](https://pytorch.org/audio/compliance.kaldi.html#functions)| 37 | | apply_cmvn | `bool` to activate feature normalization | Using our own implementation | 38 | | delta_order | `int` to apply delta on feature.

`0`: do nothing, `1`: add delta, `2`: also add acceleration | Using our own implementation| 39 | | delta_window_size | `int` to specify the window size for delta calculation || 40 | 41 | - Text 42 | 43 | Options to specify how text should be encoded; subword models use [sentencepiece](https://github.com/google/sentencepiece) 44 | 45 | |Parameter | Description | Note | 46 | |----------|-------------|------| 47 | | mode | `str` text unit for encoding sentences | Available: `character`/`subword`/`word`| 48 | | vocab_file | `str` path to the file containing the vocabulary set| Please use [generate_vocab_file.py](../util/generate_vocab_file.py) to generate it| 49 | 50 | ### Hparas 51 | 52 | Options under this category are all training-related. 53 | 54 | | Parameter | Description | Note | 55 | |---------------|-------------|------| 56 | | valid_step | `int` number of training steps between validations | | 57 | | max_step | `int` total number of training steps | | 58 | | tf_start | `float` initial teacher forcing probability in scheduled sampling | | 59 | | tf_end | `float` final teacher forcing probability in scheduled sampling | | 60 | | tf_step | `int` number of steps to linearly decrease the teacher forcing probability | | 61 | | optimizer | `str` the name of the PyTorch optimizer for training| Tested: `Adam`/`Adadelta`| 62 | | lr | `float` learning rate for the optimizer | | 63 | | eps | `float` epsilon for the optimizer | | 64 | | lr_scheduler | `str` learning rate scheduler | Available: `fixed`/`warmup`| 65 | | curriculum | `int` number of epochs to perform curriculum learning (short utterances first) | | 66 | 67 | ### Model 68 | 69 | 70 | - `ctc_weight`: weight of CTC in the hybrid CTC-Attention model (between `0` and `1`, `0`=disabled, `1` is under development) 71 | - Encoder 72 | 73 | | Parameter | Description | Note | 74 | |--------------|--------------|------| 75 | | prenet | `str` to employ a VGG/CNN based encoder before the RNN | [`vgg`](https://arxiv.org/pdf/1706.02737.pdf)/`cnn` | 76 | | module | `str` the name of the recurrent unit for the encoder RNN layers | Only `LSTM` was tested | 77 | | bidirection | `bool` to enable bidirectional RNN over the input sequence | | 78 | | dim | `list` of the number of cells for each RNN layer (per direction)| | 79 | | dropout | `list` of dropout probabilities for each RNN layer| Length must match `dim` | 80 | | layer_norm | `list` of `bool` to enable LayerNorm for each RNN layer | Not recommended | 81 | | proj | `list` of `bool` to enable a linear projection after each RNN layer | Length must match `dim` | 82 | | sample_rate | `list` of sample rates for each RNN layer. For each layer, the length of the output on the time dimension will be input/`sample_rate`.| Length must match `dim` | 83 | | sample_style | `str` the downsampling mechanism. `concat` will concatenate multiple time steps (according to the sample rate) into one vector, `drop` will drop the unsampled timesteps. | Available: `concat`/`drop` | 84 | 85 | - Attention 86 | 87 | | Parameter | Description | Note | 88 | |-----------|-------------|------| 89 | | mode | `str` attention mechanism, `dot` is the vanilla attention and `loc` indicates the [location-based attention](https://arxiv.org/abs/1506.07503).
| Available: `dot`/`loc` | 90 | | dim | `int` dimension of all networks in the attention | | 91 | | num_head | `int` number of heads in [multi-head attention](https://arxiv.org/pdf/1706.03762.pdf), `1`: normal attention | Performance untested | 92 | | v_proj | `bool` to apply an additional linear transform to the encoder features before the weighted sum | | 93 | | temperature | `float` the temperature to control the sharpness of the softmax function in the attention | | 94 | | loc_kernel_size | `int` kernel size for the convolution in [location-aware attention](https://arxiv.org/pdf/1506.07503.pdf) | For `loc` only | 95 | | loc_kernel_num | `int` number of kernels for the convolution in [location-aware attention](https://arxiv.org/pdf/1506.07503.pdf) | For `loc` only | 96 | 97 | - Decoder 98 | 99 | | Parameter | Description | Note | 100 | |--------------|--------------|------| 101 | | module | `str` the name of the recurrent unit for the decoder RNN layer | Only `LSTM` was tested | 102 | | dim | `int` number of cells in the decoder| | 103 | | layer | `int` number of layers in the decoder | | 104 | | dropout | `float` dropout probability | | 105 | 106 | 107 | ### Additional Plug-ins 108 | 109 | The following mechanisms are our proposed methods and can be activated by adding these parameters to the config file. 110 | 111 | - Emb 112 | 113 | | Parameter | Description | Note | 114 | |--------------|--------------|------| 115 | | enable | `bool` to enable word embedding regularization or fused decoding on the ASR | | 116 | | src | `str` path to a pre-trained embedding table or BERT model| The `bert-base-uncased` model fine-tuned on LibriSpeech text data is available [here](https://drive.google.com/file/d/1Y1q5cH3yfuzMxQArR7WJ4gQD1GN7xrPh/view?usp=sharing) | 117 | | distance | `str` measure of distance between the word embedding and the model output | Available: `CosEmb`/`MSE` (untested) | 118 | | weight | `float` $\lambda$ in the paper | | 119 | | fuse | `float` $\lambda_f$ in the paper| | 120 | | fuse_normalize| `bool` to normalize the output before the Cosine-Softmax in the paper, should be on when `distance==CosEmb` | | 121 | | bert | `str` name of the BERT model if using BERT as the target embedding, e.g. `bert-base-uncased`| Mutually exclusive with `fuse>0`| 122 | 123 | 124 | 125 | ## Inference Configs 126 | 127 | Each config should include `src`/`decode`/`data`; see the [example on LibriSpeech](libri/decode_example.yaml). 128 | Note that most of the options (audio feature, model structure, etc.) will be imported from the training config specified in `src`. 129 | 130 | ### Src 131 | 132 | Specifies the ASR model to use in the decoding process. 133 | 134 | | Parameter | Description | Note | 135 | |-----------|--------------|------| 136 | | ckpt | `str` path to the ASR checkpoint to be loaded | | 137 | | config | `str` path to the ASR training config that belongs to the checkpoint| | 138 | 139 | ### Data 140 | 141 | - Corpus 142 | 143 | |Parameter | Description | Note | 144 | |----------|-------------|------| 145 | | name | See the `corpus` section in the training config|| 146 | | dev_split| See the `corpus` section in the training config|| 147 | | test_split| Like the dev set, the ASR will perform exactly the same decoding process on this set; it should also be defined by the user like the train/dev splits|| 148 | 149 | 150 | ### Decode 151 | 152 | Options for decoding that *will* dramatically change the decoding result.
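As a rough illustration of how these options are consumed, the minimal sketch below maps a `decode`-style dict onto the `BeamDecoder` defined in [src/decode.py](../src/decode.py); the full parameter table follows right after it. Here `asr_model` is assumed to be an ASR instance you have already built and loaded, and all values are placeholders, not recommended settings.

```python
# Illustrative sketch only: wiring the `decode` options documented in the table
# below into src.decode.BeamDecoder. Nothing is executed automatically here.
from src.decode import BeamDecoder

def build_beam_decoder(asr_model, decode_cfg):
    # decode_cfg keys mirror the table below; the embedding plug-in is omitted.
    return BeamDecoder(asr_model, emb_decoder=None, **decode_cfg)

example_decode_cfg = {
    'beam_size': 8,
    'min_len_ratio': 0.1,
    'max_len_ratio': 1.0,
    'lm_path': '',      # set to an RNNLM checkpoint to enable joint LM decoding
    'lm_config': '',
    'lm_weight': 0.0,   # > 0 requires a valid lm_path/lm_config
    'ctc_weight': 0.0,  # > 0 requires the ASR to be trained with ctc_weight > 0
}

# Usage (beam search requires batch size 1):
#   hyps = build_beam_decoder(asr_model, example_decode_cfg)(feat, feat_len)
```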
153 | 154 | | Parameter | Description | Note | 155 | |-----------|--------------|------| 156 | | beam_size | `int` beam size for beam search algorithm, be careful that larger beam increases memory usage|| 157 | | min_len_ratio | `float` the minimum length of any hypothesis will be `min_len_ratio` x `input length` | 158 | | max_len_ratio | `float` the maximum decoding time step will be `max_len_ratio` x `input length`, hypothesis will end if `` is predicted or maximum decoding step reached | 159 | | lm_path | `str` the path to pre-trained LM for joint decoding, **this is not language model rescoring**| [paper](https://arxiv.org/pdf/1706.02737.pdf)| 160 | | lm_config | `str` the path to the config of pre-trained LM for joint decoding| [paper](https://arxiv.org/pdf/1706.02737.pdf) | 161 | | lm_weight | `float` the weight for RNNLM in joint decoding| [paper](https://arxiv.org/pdf/1706.02737.pdf), slower inference | 162 | | ctc_weight| `float` the weight for CTC network in joint decoding, this will only be available if `ctc_weight` was not zero in training config | [paper](https://arxiv.org/pdf/1706.02737.pdf), slower inference | 163 | 164 | -------------------------------------------------------------------------------- /src/solver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import abc 4 | import math 5 | import yaml 6 | import torch 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | from src.option import default_hparas 10 | from src.util import human_format, Timer 11 | 12 | class BaseSolver(): 13 | ''' 14 | Prototype Solver for all kinds of tasks 15 | Arguments 16 | config - yaml-styled config 17 | paras - argparse outcome 18 | ''' 19 | def __init__(self, config, paras, mode): 20 | # General Settings 21 | self.config = config 22 | self.paras = paras 23 | self.mode = mode 24 | for k,v in default_hparas.items(): 25 | setattr(self,k,v) 26 | self.device = torch.device('cuda:' + str(paras.cuda)) if self.paras.gpu and torch.cuda.is_available() else torch.device('cpu') 27 | self.amp = paras.amp 28 | 29 | # Name experiment 30 | self.exp_name = paras.name 31 | if self.exp_name is None: 32 | self.exp_name = paras.config.split('/')[-1].replace('.yaml','') # By default, exp is named after config file 33 | if mode == 'train': 34 | self.exp_name += '_sd{}'.format(paras.seed) 35 | 36 | # Plugin list 37 | self.emb_decoder = None 38 | 39 | self.transfer_learning = False 40 | # Transfer Learning 41 | if (self.config.get('transfer', None) is not None) and mode == 'train': 42 | self.transfer_learning = True 43 | self.train_enc = self.config['transfer']['train_enc'] 44 | self.train_dec = self.config['transfer']['train_dec'] 45 | self.fix_enc = [i for i in range(4) if i not in self.config['transfer']['train_enc'] ] 46 | self.fix_dec = not self.config['transfer']['train_dec'] 47 | log_name = '_T_{}_{}'.format(''.join([str(l) for l in self.train_enc]), '1' if self.train_dec else '0') 48 | self.save_name = '_tune-{}-{}'.format(''.join([str(l) for l in self.train_enc]), '1' if self.train_dec else '0') 49 | 50 | if self.paras.seed > 0: 51 | self.save_name += '-sd' + str(self.paras.seed) 52 | 53 | if mode == 'train': 54 | # Filepath setup 55 | os.makedirs(paras.ckpdir, exist_ok=True) 56 | self.ckpdir = os.path.join(paras.ckpdir,self.exp_name) 57 | os.makedirs(self.ckpdir, exist_ok=True) 58 | 59 | # Logger settings 60 | self.logdir = os.path.join(paras.logdir,self.exp_name + (log_name if self.transfer_learning else '')) 61 | self.log = 
SummaryWriter(self.logdir, flush_secs = self.TB_FLUSH_FREQ) 62 | self.timer = Timer() 63 | 64 | # Hyperparameters 65 | self.step = 0 66 | self.valid_step = config['hparas']['valid_step'] 67 | self.max_step = config['hparas']['max_step'] 68 | 69 | self.verbose('Exp. name : {}'.format(self.exp_name)) 70 | self.verbose('Loading data... large corpus may took a while.') 71 | 72 | elif mode == 'test': 73 | # Output path 74 | os.makedirs(paras.outdir, exist_ok=True) 75 | os.makedirs(os.path.join(paras.outdir, 'dev_out'), exist_ok=True) 76 | os.makedirs(os.path.join(paras.outdir, 'test_out'), exist_ok=True) 77 | self.ckpdir = os.path.join(paras.outdir,self.exp_name) 78 | 79 | # Load training config to get acoustic feat, text encoder and build model 80 | self.src_config = yaml.load(open(config['src']['config'],'r'), Loader=yaml.FullLoader) 81 | self.paras.load = config['src']['ckpt'] 82 | 83 | self.verbose('Evaluating result of tr. config @ {}'.format(config['src']['config'])) 84 | 85 | def backward(self, loss, time_cnt=True, optimize=True): 86 | ''' 87 | Standard backward step with self.timer and debugger 88 | Arguments 89 | loss - the loss to perform loss.backward() 90 | ''' 91 | if time_cnt: 92 | self.timer.set() 93 | loss.backward() 94 | grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.GRAD_CLIP) 95 | 96 | if math.isnan(grad_norm): 97 | self.verbose('Error : grad norm is NaN @ step '+str(self.step)) 98 | else: 99 | if optimize: 100 | self.optimizer.step() 101 | if time_cnt: 102 | self.timer.cnt('bw') 103 | return grad_norm 104 | 105 | def load_ckpt(self): 106 | ''' Load ckpt if --load option is specified ''' 107 | if self.paras.load: 108 | # Load weights 109 | ckpt = torch.load(self.paras.load, map_location=self.device if self.mode=='train' else 'cpu') 110 | self.model.load_state_dict(ckpt['model']) 111 | if self.emb_decoder is not None: 112 | self.emb_decoder.load_state_dict(ckpt['emb_decoder']) 113 | #if self.amp: 114 | # amp.load_state_dict(ckpt['amp']) 115 | # Load task-dependent items 116 | if self.mode == 'train': 117 | self.step = ckpt['global_step'] 118 | if self.transfer_learning == False: 119 | self.optimizer.load_opt_state_dict(ckpt['optimizer']) 120 | self.verbose('Load ckpt from {}, restarting at step {}'.format(self.paras.load,self.step)) 121 | else: 122 | for k,v in ckpt.items(): 123 | if type(v) is float: 124 | metric, score = k,v 125 | self.model.eval() 126 | if self.emb_decoder is not None: 127 | self.emb_decoder.eval() 128 | self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(self.paras.load,metric,score * 100)) 129 | 130 | def verbose(self,msg): 131 | ''' Verbose function for print information to stdout''' 132 | if self.paras.verbose: 133 | if type(msg)==list: 134 | for m in msg: 135 | print('[INFO]',m.ljust(100)) 136 | else: 137 | print('[INFO]',msg.ljust(100)) 138 | 139 | def progress(self,msg): 140 | ''' Verbose function for updating progress on stdout (do not include newline) ''' 141 | if self.paras.verbose: 142 | sys.stdout.write("\033[K") # Clear line 143 | print('[{}] {}'.format(human_format(self.step),msg),end='\r') 144 | 145 | def write_log(self,log_name,log_dict): 146 | ''' 147 | Write log to TensorBoard 148 | log_name - Name of tensorboard variable 149 | log_value - / Value of variable (e.g. 
dict of losses), passed if value = None 150 | ''' 151 | if type(log_dict) is dict: 152 | log_dict = {key:val for key, val in log_dict.items() if (val is not None and not math.isnan(val))} 153 | if log_dict is None: 154 | pass 155 | elif len(log_dict)>0: 156 | if 'align' in log_name or 'spec' in log_name: 157 | img, form = log_dict 158 | self.log.add_image(log_name,img, global_step=self.step, dataformats=form) 159 | elif 'text' in log_name or 'hyp' in log_name: 160 | self.log.add_text(log_name, log_dict, self.step) 161 | elif 'wav' in log_name: 162 | waveform, sr = log_dict 163 | waveform = torch.FloatTensor(waveform) 164 | if waveform.dim() == 1: 165 | waveform = waveform.unsqueeze(0) 166 | self.log.add_audio(log_name, waveform, global_step=self.step, sample_rate=sr) 167 | else: 168 | self.log.add_scalars(log_name,log_dict,self.step) 169 | 170 | def save_checkpoint(self, f_name, metric, score, name=''): 171 | '''' 172 | Ckpt saver 173 | f_name - the name phnof ckpt file (w/o prefix) to store, overwrite if existed 174 | score - The value of metric used to evaluate model 175 | ''' 176 | ckpt_path = os.path.join(self.ckpdir, f_name) 177 | full_dict = { 178 | "model": self.model.state_dict(), 179 | "optimizer": self.optimizer.get_opt_state_dict(), 180 | "global_step": self.step, 181 | metric: score 182 | } 183 | # Additional modules to save 184 | #if self.amp: 185 | # full_dict['amp'] = self.amp_lib.state_dict() 186 | if self.emb_decoder is not None: 187 | full_dict['emb_decoder'] = self.emb_decoder.state_dict() 188 | 189 | torch.save(full_dict, ckpt_path) 190 | if len(name) > 0: 191 | name = ' on ' + name 192 | ckpt_path = '/'.join(ckpt_path.split('/')[-2:]) # Set how long the path name to be shown. 193 | self.verbose("Saved ckpt (step = {}, {} = {:.2f}) @ {}{}".\ 194 | format(human_format(self.step),metric,score,ckpt_path,name)) 195 | 196 | def enable_apex(self): 197 | if self.amp: 198 | # Enable mixed precision computation (ToDo: Save/Load amp) 199 | from apex import amp 200 | self.amp_lib = amp 201 | self.verbose("AMP enabled (check https://github.com/NVIDIA/apex for more details).") 202 | self.model, self.optimizer.opt = self.amp_lib.initialize(self.model, self.optimizer.opt, opt_level='O1') 203 | 204 | 205 | # ----------------------------------- Abtract Methods ------------------------------------------ # 206 | @abc.abstractmethod 207 | def load_data(self): 208 | ''' 209 | Called by main to load all data 210 | After this call, data related attributes should be setup (e.g. self.tr_set, self.dev_set) 211 | No return value 212 | ''' 213 | raise NotImplementedError 214 | 215 | @abc.abstractmethod 216 | def set_model(self): 217 | ''' 218 | Called by main to set models 219 | After this call, model related attributes should be setup (e.g. self.l2_loss) 220 | The followings MUST be setup 221 | - self.model (torch.nn.Module) 222 | - self.optimizer (src.Optimizer), 223 | init. 
w/ self.optimizer = src.Optimizer(self.model.parameters(),**self.config['hparas']) 224 | Loading pre-trained model should also be performed here 225 | No return value 226 | ''' 227 | raise NotImplementedError 228 | 229 | @abc.abstractmethod 230 | def exec(self): 231 | ''' 232 | Called by main to execute training/inference 233 | ''' 234 | raise NotImplementedError 235 | 236 | 237 | -------------------------------------------------------------------------------- /src/decode.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | from src.lm import RNNLM 8 | from src.ctc import CTCPrefixScore 9 | 10 | CTC_BEAM_RATIO = 1.5 # DO NOT CHANGE THIS, MAY CAUSE OOM 11 | LOG_ZERO = -10000000.0 # Log-zero for CTC 12 | 13 | 14 | class BeamDecoder(nn.Module): 15 | ''' Beam decoder for ASR ''' 16 | 17 | def __init__(self, asr, emb_decoder, beam_size, min_len_ratio, max_len_ratio, 18 | lm_path='', lm_config='', lm_weight=0.0, ctc_weight=0.0): 19 | super().__init__() 20 | # Setup 21 | self.beam_size = beam_size 22 | self.min_len_ratio = min_len_ratio 23 | self.max_len_ratio = max_len_ratio 24 | self.asr = asr 25 | 26 | # ToDo : implement pure ctc decode 27 | assert self.asr.enable_att 28 | 29 | # Additional decoding modules 30 | self.apply_ctc = ctc_weight > 0 31 | if self.apply_ctc: 32 | assert self.asr.ctc_weight > 0, 'ASR was not trained with CTC decoder' 33 | self.ctc_w = ctc_weight 34 | self.ctc_beam_size = int(CTC_BEAM_RATIO * self.beam_size) 35 | 36 | self.apply_lm = lm_weight > 0 37 | if self.apply_lm: 38 | self.lm_w = lm_weight 39 | self.lm_path = lm_path 40 | lm_config = yaml.load(open(lm_config, 'r'), Loader=yaml.FullLoader) 41 | self.lm = RNNLM(self.asr.vocab_size, **lm_config['model']) 42 | self.lm.load_state_dict(torch.load( 43 | self.lm_path, map_location='cpu')['model']) 44 | self.lm.eval() 45 | 46 | self.apply_emb = emb_decoder is not None 47 | if self.apply_emb: 48 | self.emb_decoder = emb_decoder 49 | 50 | def create_msg(self): 51 | msg = ['Decode spec| Beam size = {}\t| Min/Max len ratio = {}/{}'.format( 52 | self.beam_size, self.min_len_ratio, self.max_len_ratio)] 53 | if self.apply_ctc: 54 | msg.append( 55 | ' |Joint CTC decoding enabled \t| weight = {:.2f}\t'.format(self.ctc_w)) 56 | if self.apply_lm: 57 | msg.append(' |Joint LM decoding enabled \t| weight = {:.2f}\t| src = {}'.format( 58 | self.lm_w, self.lm_path)) 59 | if self.apply_emb: 60 | msg.append(' |Joint Emb. decoding enabled \t| weight = {:.2f}'.format( 61 | self.lm_w, self.emb_decoder.fuse_lambda.mean().cpu().item())) 62 | 63 | return msg 64 | 65 | def forward(self, audio_feature, feature_len): 66 | # Init. 67 | assert audio_feature.shape[0] == 1, "Batchsize == 1 is required for beam search" 68 | batch_size = audio_feature.shape[0] 69 | device = audio_feature.device 70 | dec_state = self.asr.decoder.init_state( 71 | batch_size) # Init zero states 72 | self.asr.attention.reset_mem() # Flush attention mem 73 | # Max output len set w/ hyper param. 74 | max_output_len = int( 75 | np.ceil(feature_len.cpu().item()*self.max_len_ratio)) 76 | # Min output len set w/ hyper param. 
77 | min_output_len = int( 78 | np.ceil(feature_len.cpu().item()*self.min_len_ratio)) 79 | # Store attention map if location-aware 80 | store_att = self.asr.attention.mode == 'loc' 81 | prev_token = torch.zeros( 82 | (batch_size, 1), dtype=torch.long, device=device) # Start w/ 83 | # Cache of beam search 84 | final_hypothesis, next_top_hypothesis = [], [] 85 | # Incase ctc is disabled 86 | ctc_state, ctc_prob, candidates, lm_state = None, None, None, None 87 | 88 | # Encode 89 | encode_feature, encode_len = self.asr.encoder( 90 | audio_feature, feature_len) 91 | 92 | # CTC decoding 93 | if self.apply_ctc: 94 | ctc_output = F.log_softmax( 95 | self.asr.ctc_layer(encode_feature), dim=-1) 96 | ctc_prefix = CTCPrefixScore(ctc_output) 97 | ctc_state = ctc_prefix.init_state() 98 | 99 | # Start w/ empty hypothesis 100 | prev_top_hypothesis = [Hypothesis(decoder_state=dec_state, output_seq=[], 101 | output_scores=[], lm_state=None, ctc_prob=0, 102 | ctc_state=ctc_state, att_map=None)] 103 | # Attention decoding 104 | for t in range(max_output_len): 105 | for hypothesis in prev_top_hypothesis: 106 | # Resume previous step 107 | prev_token, prev_dec_state, prev_attn, prev_lm_state, prev_ctc_state = hypothesis.get_state( 108 | device) 109 | self.asr.set_state(prev_dec_state, prev_attn) 110 | 111 | # Normal asr forward 112 | attn, context = self.asr.attention( 113 | self.asr.decoder.get_query(), encode_feature, encode_len) 114 | asr_prev_token = self.asr.pre_embed(prev_token) 115 | decoder_input = torch.cat([asr_prev_token, context], dim=-1) 116 | cur_prob, d_state = self.asr.decoder(decoder_input) 117 | 118 | # Embedding fusion (output shape 1xV) 119 | if self.apply_emb: 120 | _, cur_prob = self.emb_decoder( d_state, cur_prob, return_loss=False) 121 | else: 122 | cur_prob = F.log_softmax(cur_prob, dim=-1) 123 | 124 | # Perform CTC prefix scoring on limited candidates (else OOM easily) 125 | if self.apply_ctc: 126 | # TODO : Check the performance drop for computing part of candidates only 127 | _, ctc_candidates = cur_prob.squeeze(0).topk(self.ctc_beam_size, dim=-1) 128 | candidates = ctc_candidates.cpu().tolist() 129 | ctc_prob, ctc_state = ctc_prefix.cheap_compute( 130 | hypothesis.outIndex, prev_ctc_state, candidates) 131 | # TODO : study why ctc_char (slightly) > 0 sometimes 132 | ctc_char = torch.FloatTensor(ctc_prob - hypothesis.ctc_prob).to(device) 133 | 134 | # Combine CTC score and Attention score (HACK: focus on candidates, block others) 135 | hack_ctc_char = torch.zeros_like(cur_prob).data.fill_(LOG_ZERO) 136 | for idx, char in enumerate(candidates): 137 | hack_ctc_char[0, char] = ctc_char[idx] 138 | cur_prob = (1-self.ctc_w)*cur_prob + self.ctc_w*hack_ctc_char # ctc_char 139 | cur_prob[0, 0] = LOG_ZERO # Hack to ignore 140 | 141 | # Joint RNN-LM decoding 142 | if self.apply_lm: 143 | # assuming batch size always 1, resulting 1x1 144 | lm_input = prev_token.unsqueeze(1) 145 | lm_output, lm_state = self.lm( 146 | lm_input, torch.ones([batch_size]), hidden=prev_lm_state) 147 | # assuming batch size always 1, resulting 1xV 148 | lm_output = lm_output.squeeze(0) 149 | cur_prob += self.lm_w*lm_output.log_softmax(dim=-1) 150 | 151 | # Beam search 152 | # Note: Ignored batch dim. 
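# At this point cur_prob holds the fused log-probabilities over the vocabulary for the
# current hypothesis: the attention decoder's distribution, optionally interpolated with
# the CTC prefix score (weight self.ctc_w) and augmented by the RNN-LM log-softmax
# (weight self.lm_w). Each surviving hypothesis proposes its beam_size best extensions
# below; the pooled candidates are then re-sorted by avgScore() and cut back to beam_size.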
153 | topv, topi = cur_prob.squeeze(0).topk(self.beam_size) 154 | prev_attn = self.asr.attention.att_layer.prev_att.cpu() if store_att else None 155 | final, top = hypothesis.addTopk(topi, topv, self.asr.decoder.get_state(), att_map=prev_attn, 156 | lm_state=lm_state, ctc_state=ctc_state, ctc_prob=ctc_prob, 157 | ctc_candidates=candidates) 158 | # Move complete hyps. out 159 | if final is not None and (t >= min_output_len): 160 | final_hypothesis.append(final) 161 | if self.beam_size == 1: 162 | return final_hypothesis 163 | next_top_hypothesis.extend(top) 164 | 165 | # Sort for top N beams 166 | next_top_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True) 167 | prev_top_hypothesis = next_top_hypothesis[:self.beam_size] 168 | next_top_hypothesis = [] 169 | 170 | # Rescore all hyp (finished/unfinished) 171 | final_hypothesis += prev_top_hypothesis 172 | final_hypothesis.sort(key=lambda o: o.avgScore(), reverse=True) 173 | 174 | return final_hypothesis[:self.beam_size] 175 | 176 | 177 | class Hypothesis: 178 | '''Hypothesis for beam search decoding. 179 | Stores the history of label sequence & score 180 | Stores the previous decoder state, ctc state, ctc score, lm state and attention map (if necessary)''' 181 | 182 | def __init__(self, decoder_state, output_seq, output_scores, lm_state, ctc_state, ctc_prob, att_map): 183 | assert len(output_seq) == len(output_scores) 184 | # attention decoder 185 | self.decoder_state = decoder_state 186 | self.att_map = att_map 187 | 188 | # RNN language model 189 | if type(lm_state) is tuple: 190 | self.lm_state = (lm_state[0].cpu(), 191 | lm_state[1].cpu()) # LSTM state 192 | elif lm_state is None: 193 | self.lm_state = None # Init state 194 | else: 195 | self.lm_state = lm_state.cpu() # GRU state 196 | 197 | # Previous outputs 198 | self.output_seq = output_seq # Prefix, List of list 199 | self.output_scores = output_scores # Prefix score, list of float 200 | 201 | # CTC decoding 202 | self.ctc_state = ctc_state # List of np 203 | self.ctc_prob = ctc_prob # List of float 204 | 205 | def avgScore(self): 206 | '''Return the averaged log probability of hypothesis''' 207 | assert len(self.output_scores) != 0 208 | return sum(self.output_scores) / len(self.output_scores) 209 | 210 | def addTopk(self, topi, topv, decoder_state, att_map=None, 211 | lm_state=None, ctc_state=None, ctc_prob=0.0, ctc_candidates=[]): 212 | '''Expand current hypothesis with a given beam size''' 213 | new_hypothesis = [] 214 | term_score = None 215 | ctc_s, ctc_p = None, None 216 | beam_size = topi.shape[-1] 217 | 218 | for i in range(beam_size): 219 | # Detect 220 | if topi[i].item() == 1: 221 | term_score = topv[i].cpu() 222 | continue 223 | 224 | idxes = self.output_seq[:] # pass by value 225 | scores = self.output_scores[:] # pass by value 226 | idxes.append(topi[i].cpu()) 227 | scores.append(topv[i].cpu()) 228 | if ctc_state is not None: 229 | # ToDo: Handle out-of-candidate case. 
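# Note: list.index() on the next line raises ValueError if the selected token was not
# among the CTC-scored candidates; handling that case gracefully is the ToDo above.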
230 | idx = ctc_candidates.index(topi[i].item()) 231 | ctc_s = ctc_state[idx, :, :] 232 | ctc_p = ctc_prob[idx] 233 | new_hypothesis.append(Hypothesis(decoder_state, 234 | output_seq=idxes, output_scores=scores, lm_state=lm_state, 235 | ctc_state=ctc_s, ctc_prob=ctc_p, att_map=att_map)) 236 | if term_score is not None: 237 | self.output_seq.append(torch.tensor(1)) 238 | self.output_scores.append(term_score) 239 | return self, new_hypothesis 240 | return None, new_hypothesis 241 | 242 | def get_state(self, device): 243 | prev_token = self.output_seq[-1] if len(self.output_seq) != 0 else 0 244 | prev_token = torch.LongTensor([prev_token]).to(device) 245 | att_map = self.att_map.to(device) if self.att_map is not None else None 246 | if type(self.lm_state) is tuple: 247 | lm_state = (self.lm_state[0].to(device), 248 | self.lm_state[1].to(device)) # LSTM state 249 | elif self.lm_state is None: 250 | lm_state = None # Init state 251 | else: 252 | lm_state = self.lm_state.to( 253 | device) # GRU state 254 | return prev_token, self.decoder_state, att_map, lm_state, self.ctc_state 255 | 256 | @property 257 | def outIndex(self): 258 | return [i.item() for i in self.output_seq] 259 | -------------------------------------------------------------------------------- /bin/train_asr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import yaml 4 | 5 | from src.solver import BaseSolver 6 | 7 | from src.asr import ASR 8 | from src.optim import Optimizer 9 | from src.data import load_dataset 10 | from src.util import human_format, cal_er, feat_to_fig 11 | from src.audio import Delta, Postprocess 12 | 13 | EMPTY_CACHE_STEP = 100 14 | 15 | class Solver(BaseSolver): 16 | ''' Solver for training''' 17 | def __init__(self,config,paras,mode): 18 | super().__init__(config,paras,mode) 19 | 20 | # Curriculum learning affects data loader 21 | self.curriculum = self.config['hparas']['curriculum'] 22 | self.val_mode = self.config['hparas']['val_mode'].lower() 23 | self.WER = 'per' if self.val_mode == 'per' else 'wer' 24 | 25 | def fetch_data(self, data, train=False): 26 | ''' Move data to device and compute text seq. 
length''' 27 | # feat: B x T x D 28 | _, feat, feat_len, txt = data 29 | feat = feat.to(self.device) 30 | feat_len = feat_len.to(self.device) 31 | txt = txt.to(self.device) 32 | txt_len = torch.sum(txt!=0,dim=-1) 33 | 34 | return feat, feat_len, txt, txt_len 35 | 36 | def load_data(self): 37 | ''' Load data for training/validation, store tokenizer and input/output shape''' 38 | self.tr_set, self.dv_set, self.feat_dim, self.vocab_size, self.tokenizer, msg = \ 39 | load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory, 40 | self.curriculum>0, 41 | **self.config['data']) 42 | self.verbose(msg) 43 | 44 | # Dev set sames 45 | self.dv_names = [] 46 | if type(self.dv_set) is list: 47 | for ds in self.config['data']['corpus']['dev_split']: 48 | self.dv_names.append(ds[0]) 49 | else: 50 | self.dv_names = self.config['data']['corpus']['dev_split'][0] 51 | 52 | # Logger settings 53 | if type(self.dv_names) is str: 54 | self.best_wer = {'att':{self.dv_names:3.0}, 55 | 'ctc':{self.dv_names:3.0}} 56 | else: 57 | self.best_wer = {'att': {},'ctc': {}} 58 | for name in self.dv_names: 59 | self.best_wer['att'][name] = 3.0 60 | self.best_wer['ctc'][name] = 3.0 61 | 62 | def set_model(self): 63 | ''' Setup ASR model and optimizer ''' 64 | # Model 65 | self.model = ASR(self.feat_dim, self.vocab_size, **self.config['model']).to(self.device) 66 | self.verbose(self.model.create_msg()) 67 | model_paras = [{'params':self.model.parameters()}] 68 | 69 | # Losses 70 | self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0) 71 | self.ctc_loss = torch.nn.CTCLoss(blank=0, zero_infinity=False) # Note: zero_infinity=False is unstable? 72 | 73 | # Plug-ins 74 | self.emb_fuse = False 75 | self.emb_reg = ('emb' in self.config) and (self.config['emb']['enable']) 76 | if self.emb_reg: 77 | from src.plugin import EmbeddingRegularizer 78 | self.emb_decoder = EmbeddingRegularizer(self.tokenizer, self.model.dec_dim, **self.config['emb']).to(self.device) 79 | model_paras.append({'params':self.emb_decoder.parameters()}) 80 | self.emb_fuse = self.emb_decoder.apply_fuse 81 | if self.emb_fuse: 82 | self.seq_loss = torch.nn.NLLLoss(ignore_index=0) 83 | self.verbose(self.emb_decoder.create_msg()) 84 | 85 | # Optimizer 86 | self.optimizer = Optimizer(model_paras, **self.config['hparas']) 87 | self.verbose(self.optimizer.create_msg()) 88 | 89 | # Enable AMP if needed 90 | self.enable_apex() 91 | 92 | # Transfer Learning 93 | if self.transfer_learning: 94 | self.verbose('Apply transfer learning: ') 95 | self.verbose(' Train encoder layers: {}'.format(self.train_enc)) 96 | self.verbose(' Train decoder: {}'.format(self.train_dec)) 97 | self.verbose(' Save name: {}'.format(self.save_name)) 98 | 99 | # Automatically load pre-trained model if self.paras.load is given 100 | self.load_ckpt() 101 | 102 | def exec(self): 103 | ''' Training End-to-end ASR system ''' 104 | self.verbose('Total training steps {}.'.format(human_format(self.max_step))) 105 | if self.transfer_learning: 106 | self.model.encoder.fix_layers(self.fix_enc) 107 | if self.fix_dec and self.model.enable_att: 108 | self.model.decoder.fix_layers() 109 | if self.fix_dec and self.model.enable_ctc: 110 | self.model.fix_ctc_layer() 111 | 112 | n_epochs = 0 113 | self.timer.set() 114 | 115 | while self.step< self.max_step: 116 | ctc_loss, att_loss, emb_loss = None, None, None 117 | # Renew dataloader to enable random sampling 118 | if self.curriculum>0 and n_epochs==self.curriculum: 119 | self.verbose('Curriculum learning ends after {} epochs, starting random 
sampling.'.format(n_epochs)) 120 | self.tr_set, _, _, _, _, _ = \ 121 | load_dataset(self.paras.njobs, self.paras.gpu, self.paras.pin_memory, 122 | False, **self.config['data']) 123 | for data in self.tr_set: 124 | # Pre-step : update tf_rate/lr_rate and do zero_grad 125 | tf_rate = self.optimizer.pre_step(self.step) 126 | total_loss = 0 127 | 128 | # Fetch data 129 | feat, feat_len, txt, txt_len = self.fetch_data(data, train=True) 130 | self.timer.cnt('rd') 131 | 132 | # Forward model 133 | # Note: txt should NOT start w/ 134 | ctc_output, encode_len, att_output, att_align, dec_state = \ 135 | self.model( feat, feat_len, max(txt_len), tf_rate=tf_rate, 136 | teacher=txt, get_dec_state=self.emb_reg) 137 | # Clear not used objects 138 | del att_align 139 | 140 | # Plugins 141 | if self.emb_reg: 142 | emb_loss, fuse_output = self.emb_decoder( dec_state, att_output, label=txt) 143 | total_loss += self.emb_decoder.weight*emb_loss 144 | else: 145 | del dec_state 146 | 147 | # Compute all objectives 148 | if ctc_output is not None: 149 | if self.paras.cudnn_ctc: 150 | ctc_loss = self.ctc_loss(ctc_output.transpose(0,1), 151 | txt.to_sparse().values().to(device='cpu',dtype=torch.int32), 152 | [ctc_output.shape[1]]*len(ctc_output), 153 | #[int(encode_len.max()) for _ in encode_len], 154 | txt_len.cpu().tolist()) 155 | else: 156 | ctc_loss = self.ctc_loss(ctc_output.transpose(0,1), txt, encode_len, txt_len) 157 | total_loss += ctc_loss*self.model.ctc_weight 158 | del encode_len 159 | 160 | if att_output is not None: 161 | b,t,_ = att_output.shape 162 | att_output = fuse_output if self.emb_fuse else att_output 163 | att_loss = self.seq_loss(att_output.view(b*t,-1),txt.view(-1)) 164 | # Sum each uttr and devide by length then mean over batch 165 | # att_loss = torch.mean(torch.sum(att_loss.view(b,t),dim=-1)/torch.sum(txt!=0,dim=-1).float()) 166 | total_loss += att_loss*(1-self.model.ctc_weight) 167 | 168 | self.timer.cnt('fw') 169 | 170 | # Backprop 171 | grad_norm = self.backward(total_loss) 172 | self.step+=1 173 | 174 | # Logger 175 | if (self.step==1) or (self.step%self.PROGRESS_STEP==0): 176 | self.progress('Tr stat | Loss - {:.2f} | Grad. 
Norm - {:.2f} | {}'\ 177 | .format(total_loss.cpu().item(),grad_norm,self.timer.show())) 178 | self.write_log('emb_loss',{'tr':emb_loss}) 179 | if att_output is not None: 180 | self.write_log('loss',{'tr_att':att_loss}) 181 | self.write_log(self.WER,{'tr_att':cal_er(self.tokenizer,att_output,txt)}) 182 | self.write_log( 'cer',{'tr_att':cal_er(self.tokenizer,att_output,txt,mode='cer')}) 183 | if ctc_output is not None: 184 | self.write_log('loss',{'tr_ctc':ctc_loss}) 185 | self.write_log(self.WER,{'tr_ctc':cal_er(self.tokenizer,ctc_output,txt,ctc=True)}) 186 | self.write_log( 'cer',{'tr_ctc':cal_er(self.tokenizer,ctc_output,txt,mode='cer',ctc=True)}) 187 | self.write_log('ctc_text_train',self.tokenizer.decode(ctc_output[0].argmax(dim=-1).tolist(), 188 | ignore_repeat=True)) 189 | # if self.step==1 or self.step % (self.PROGRESS_STEP * 5) == 0: 190 | # self.write_log('spec_train',feat_to_fig(feat[0].transpose(0,1).cpu().detach(), spec=True)) 191 | del total_loss 192 | 193 | if self.emb_fuse: 194 | if self.emb_decoder.fuse_learnable: 195 | self.write_log('fuse_lambda',{'emb':self.emb_decoder.get_weight()}) 196 | self.write_log('fuse_temp',{'temp':self.emb_decoder.get_temp()}) 197 | 198 | # Validation 199 | if (self.step==1) or (self.step%self.valid_step == 0): 200 | if type(self.dv_set) is list: 201 | for dv_id in range(len(self.dv_set)): 202 | self.validate(self.dv_set[dv_id], self.dv_names[dv_id]) 203 | else: 204 | self.validate(self.dv_set, self.dv_names) 205 | 206 | # End of step 207 | # if self.step % EMPTY_CACHE_STEP == 0: 208 | # Empty cuda cache after every fixed amount of steps 209 | torch.cuda.empty_cache() # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354 210 | self.timer.set() 211 | if self.step > self.max_step:break 212 | n_epochs +=1 213 | self.log.close() 214 | print('[INFO] Finished training after', human_format(self.max_step), 'steps.') 215 | 216 | def validate(self, _dv_set, _name): 217 | # Eval mode 218 | self.model.eval() 219 | if self.emb_decoder is not None: self.emb_decoder.eval() 220 | dev_wer = {'att':[],'ctc':[]} 221 | dev_cer = {'att':[],'ctc':[]} 222 | dev_er = {'att':[],'ctc':[]} 223 | 224 | for i,data in enumerate(_dv_set): 225 | self.progress('Valid step - {}/{}'.format(i+1,len(_dv_set))) 226 | # Fetch data 227 | feat, feat_len, txt, txt_len = self.fetch_data(data) 228 | 229 | # Forward model 230 | with torch.no_grad(): 231 | ctc_output, encode_len, att_output, att_align, dec_state = \ 232 | self.model( feat, feat_len, int(max(txt_len)*self.DEV_STEP_RATIO), 233 | emb_decoder=self.emb_decoder) 234 | 235 | if att_output is not None: 236 | dev_wer['att'].append(cal_er(self.tokenizer,att_output,txt,mode='wer')) 237 | dev_cer['att'].append(cal_er(self.tokenizer,att_output,txt,mode='cer')) 238 | dev_er['att'].append(cal_er(self.tokenizer,att_output,txt,mode=self.val_mode)) 239 | if ctc_output is not None: 240 | dev_wer['ctc'].append(cal_er(self.tokenizer,ctc_output,txt,mode='wer',ctc=True)) 241 | dev_cer['ctc'].append(cal_er(self.tokenizer,ctc_output,txt,mode='cer',ctc=True)) 242 | dev_er['ctc'].append(cal_er(self.tokenizer,ctc_output,txt,mode=self.val_mode,ctc=True)) 243 | 244 | # Show some example on tensorboard 245 | if i == len(_dv_set)//2: 246 | for i in range(min(len(txt),self.DEV_N_EXAMPLE)): 247 | if self.step==1: 248 | self.write_log('true_text_{}_{}'.format(_name, i),self.tokenizer.decode(txt[i].tolist())) 249 | if att_output is not None: 250 | self.write_log('att_align_{}_{}'.format(_name, 
i),feat_to_fig(att_align[i,0,:,:].cpu().detach())) 251 | self.write_log('att_text_{}_{}'.format(_name, i),self.tokenizer.decode(att_output[i].argmax(dim=-1).tolist())) 252 | if ctc_output is not None: 253 | self.write_log('ctc_text_{}_{}'.format(_name, i),self.tokenizer.decode(ctc_output[i].argmax(dim=-1).tolist(), 254 | ignore_repeat=True)) 255 | 256 | # Ckpt if performance improves 257 | tasks = [] 258 | if len(dev_er['att']) > 0: 259 | tasks.append('att') 260 | if len(dev_er['ctc']) > 0: 261 | tasks.append('ctc') 262 | 263 | for task in tasks: 264 | dev_er[task] = sum(dev_er[task])/len(dev_er[task]) 265 | dev_wer[task] = sum(dev_wer[task])/len(dev_wer[task]) 266 | dev_cer[task] = sum(dev_cer[task])/len(dev_cer[task]) 267 | if dev_er[task] < self.best_wer[task][_name]: 268 | self.best_wer[task][_name] = dev_er[task] 269 | self.save_checkpoint('best_{}_{}.pth'.format(task, _name + (self.save_name if self.transfer_learning else '')), 270 | self.val_mode,dev_er[task],_name) 271 | if self.step >= self.max_step: 272 | self.save_checkpoint('last_{}_{}.pth'.format(task, _name + (self.save_name if self.transfer_learning else '')), 273 | self.val_mode,dev_er[task],_name) 274 | self.write_log(self.WER,{'dv_'+task+'_'+_name.lower():dev_wer[task]}) 275 | self.write_log( 'cer',{'dv_'+task+'_'+_name.lower():dev_cer[task]}) 276 | # if self.transfer_learning: 277 | # print('[{}] WER {:.4f} / CER {:.4f} on {}'.format(human_format(self.step), dev_wer[task], dev_cer[task], _name)) 278 | 279 | # Resume training 280 | self.model.train() 281 | if self.transfer_learning: 282 | self.model.encoder.fix_layers(self.fix_enc) 283 | if self.fix_dec and self.model.enable_att: 284 | self.model.decoder.fix_layers() 285 | if self.fix_dec and self.model.enable_ctc: 286 | self.model.fix_ctc_layer() 287 | 288 | if self.emb_decoder is not None: self.emb_decoder.train() 289 | -------------------------------------------------------------------------------- /src/asr.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.distributions.categorical import Categorical 7 | 8 | from src.util import init_weights, init_gate 9 | from src.module import VGGExtractor, VGGExtractor2, FreqVGGExtractor, FreqVGGExtractor2, \ 10 | RNNLayer, ScaleDotAttention, LocationAwareAttention 11 | 12 | class ASR(nn.Module): 13 | ''' ASR model, including Encoder/Decoder(s)''' 14 | def __init__(self, input_size, vocab_size, ctc_weight, encoder, attention, decoder, emb_drop=0.0, init_adadelta=True): 15 | super(ASR, self).__init__() 16 | 17 | # Setup 18 | assert 0<=ctc_weight<=1 19 | self.vocab_size = vocab_size 20 | self.ctc_weight = ctc_weight 21 | self.enable_ctc = ctc_weight > 0 22 | self.enable_att = ctc_weight != 1 23 | self.lm = None 24 | 25 | # Modules 26 | self.encoder = Encoder(input_size, **encoder) 27 | if self.enable_ctc: 28 | self.ctc_layer = nn.Linear(self.encoder.out_dim, vocab_size) 29 | if self.enable_att: 30 | self.dec_dim = decoder['dim'] 31 | self.pre_embed = nn.Embedding(vocab_size, self.dec_dim) 32 | self.embed_drop = nn.Dropout(emb_drop) 33 | self.decoder = Decoder(self.encoder.out_dim+self.dec_dim, vocab_size, **decoder) 34 | query_dim = self.dec_dim*self.decoder.layer 35 | self.attention = Attention(self.encoder.out_dim, query_dim, **attention) 36 | 37 | # Init 38 | if init_adadelta: 39 | self.apply(init_weights) 40 | if self.enable_att: 41 | for l in 
range(self.decoder.layer): 42 | bias = getattr(self.decoder.layers,'bias_ih_l{}'.format(l)) 43 | bias = init_gate(bias) 44 | 45 | def set_state(self, prev_state, prev_attn): 46 | ''' Setting up all memory states for beam decoding''' 47 | self.decoder.set_state(prev_state) 48 | self.attention.set_mem(prev_attn) 49 | 50 | def create_msg(self): 51 | # Messages for user 52 | msg = [] 53 | msg.append('Model spec.| Encoder\'s downsampling rate of time axis is {}.'.format(self.encoder.sample_rate)) 54 | if self.encoder.vgg == 1: 55 | msg.append(' | VGG Extractor w/ time downsampling rate = 4 in encoder enabled.') 56 | if self.encoder.vgg == 2: 57 | msg.append(' | Freq VGG Extractor w/ time downsampling rate = 4 and freq split = {} in encoder enabled.'.format(self.encoder.vgg_freq)) 58 | if self.encoder.vgg == 3: 59 | msg.append(' | VGG Extractor w/ time downsampling rate = 2 in encoder enabled.'.format(self.encoder.vgg_freq)) 60 | if self.encoder.vgg == 4: 61 | msg.append(' | Freq VGG Extractor w/ time DS rate = 2, freq split = {}, and low-freq filters = {} in encoder enabled.'.format(self.encoder.vgg_freq, self.encoder.vgg_low_filt)) 62 | 63 | if self.enable_ctc: 64 | msg.append(' | CTC training on encoder enabled ( lambda = {}).'.format(self.ctc_weight)) 65 | if self.enable_att: 66 | msg.append(' | {} attention decoder enabled ( lambda = {}).'.format(self.attention.mode,1-self.ctc_weight)) 67 | return msg 68 | 69 | def forward(self, audio_feature, feature_len, decode_step, tf_rate=0.0, teacher=None, 70 | emb_decoder=None, get_dec_state=False, get_logit=False): 71 | ''' 72 | Arguments 73 | audio_feature - [BxTxD] Acoustic feature with shape 74 | feature_len - [B] Length of each sample in a batch 75 | decode_step - [int] The maximum number of attention decoder steps 76 | tf_rate - [0,1] The probability to perform teacher forcing for each step 77 | teacher - [BxLxD] Ground truth for teacher forcing with sentence length L 78 | emb_decoder - [obj] Introduces the word embedding decoder, different behavior for training/inference 79 | At training stage, this ONLY affects self-sampling (output remains the same) 80 | At inference stage, this affects output to become log prob. 
with distribution fusion 81 | get_dec_state - [bool] If true, return decoder state [BxLxD] for other purpose 82 | ''' 83 | # Init 84 | bs = audio_feature.shape[0] 85 | ctc_output, att_output, att_seq = None, None, None 86 | dec_state = [] if get_dec_state else None 87 | 88 | # Encode 89 | encode_feature,encode_len = self.encoder(audio_feature,feature_len) 90 | 91 | # CTC based decoding 92 | if self.enable_ctc: 93 | if get_logit: 94 | ctc_output = self.ctc_layer(encode_feature) 95 | else: 96 | ctc_output = F.log_softmax(self.ctc_layer(encode_feature),dim=-1) 97 | 98 | # Attention based decoding 99 | if self.enable_att: 100 | # Init (init char = , reset all rnn state and cell) 101 | self.decoder.init_state(bs) 102 | self.attention.reset_mem() 103 | last_char = self.pre_embed(torch.zeros((bs),dtype=torch.long, device=encode_feature.device)) 104 | att_seq, output_seq = [], [] 105 | 106 | # Preprocess data for teacher forcing 107 | if teacher is not None: 108 | teacher = self.embed_drop(self.pre_embed(teacher)) 109 | 110 | # Decode 111 | for t in range(decode_step): 112 | # Attend (inputs current state of first layer, encoded features) 113 | attn,context = self.attention(self.decoder.get_query(),encode_feature,encode_len) 114 | # Decode (inputs context + embedded last character) 115 | decoder_input = torch.cat([last_char,context],dim=-1) 116 | cur_char, d_state = self.decoder(decoder_input) 117 | # Prepare output as input of next step 118 | if (teacher is not None): 119 | # Training stage 120 | if (tf_rate==1) or (torch.rand(1).item()<=tf_rate): 121 | # teacher forcing 122 | last_char = teacher[:,t,:] 123 | else: 124 | # self-sampling (replace by argmax may be another choice) 125 | with torch.no_grad(): 126 | if (emb_decoder is not None) and emb_decoder.apply_fuse: 127 | _, cur_prob = emb_decoder(d_state,cur_char,return_loss=False) 128 | else: 129 | cur_prob = cur_char.softmax(dim=-1) 130 | sampled_char = Categorical(cur_prob).sample() 131 | last_char = self.embed_drop(self.pre_embed(sampled_char)) 132 | else: 133 | # Inference stage 134 | if (emb_decoder is not None) and emb_decoder.apply_fuse: 135 | _,cur_char = emb_decoder(d_state,cur_char,return_loss=False) 136 | # argmax for inference 137 | last_char = self.pre_embed(torch.argmax(cur_char,dim=-1)) 138 | 139 | # save output of each step 140 | output_seq.append(cur_char) 141 | att_seq.append(attn) 142 | if get_dec_state: 143 | dec_state.append(d_state) 144 | 145 | att_output = torch.stack(output_seq,dim=1) # BxTxV 146 | att_seq = torch.stack(att_seq,dim=2) # BxNxDtxT 147 | if get_dec_state: 148 | dec_state = torch.stack(dec_state,dim=1) 149 | 150 | return ctc_output, encode_len, att_output, att_seq, dec_state 151 | 152 | def fix_ctc_layer(self): 153 | for param in self.ctc_layer.parameters(): 154 | param.requires_grad = False 155 | 156 | class Decoder(nn.Module): 157 | ''' Decoder (a.k.a. 
Speller in LAS) ''' 158 | # ToDo: More elegant way to implement decoder 159 | def __init__(self, input_dim, vocab_size, module, dim, layer, dropout): 160 | super(Decoder, self).__init__() 161 | self.in_dim = input_dim 162 | self.layer = layer 163 | self.dim = dim 164 | self.dropout = dropout 165 | 166 | # Init 167 | assert module in ['LSTM','GRU'], NotImplementedError 168 | self.hidden_state = None 169 | self.enable_cell = module=='LSTM' 170 | 171 | # Modules 172 | self.layers = getattr(nn,module)(input_dim,dim, num_layers=layer, dropout=dropout, batch_first=True) 173 | self.char_trans = nn.Linear(dim,vocab_size) 174 | self.final_dropout = nn.Dropout(dropout) 175 | 176 | def init_state(self, bs): 177 | ''' Set all hidden states to zeros ''' 178 | device = next(self.parameters()).device 179 | if self.enable_cell: 180 | self.hidden_state = (torch.zeros((self.layer,bs,self.dim),device=device), 181 | torch.zeros((self.layer,bs,self.dim),device=device)) 182 | else: 183 | self.hidden_state = torch.zeros((self.layer,bs,self.dim),device=device) 184 | return self.get_state() 185 | 186 | def set_state(self, hidden_state): 187 | ''' Set all hidden states/cells, for decoding purpose''' 188 | device = next(self.parameters()).device 189 | if self.enable_cell: 190 | self.hidden_state = (hidden_state[0].to(device),hidden_state[1].to(device)) 191 | else: 192 | self.hidden_state = hidden_state.to(device) 193 | 194 | def get_state(self): 195 | ''' Return all hidden states/cells, for decoding purpose''' 196 | if self.enable_cell: 197 | return (self.hidden_state[0].cpu(),self.hidden_state[1].cpu()) 198 | else: 199 | return self.hidden_state.cpu() 200 | 201 | def get_query(self): 202 | ''' Return state of all layers as query for attention ''' 203 | if self.enable_cell: 204 | return self.hidden_state[0].transpose(0,1).reshape(-1,self.dim*self.layer) 205 | else: 206 | return self.hidden_state.transpose(0,1).reshape(-1,self.dim*self.layer) 207 | 208 | def forward(self, x): 209 | ''' Decode and transform into vocab ''' 210 | if not self.training: 211 | self.layers.flatten_parameters() 212 | x, self.hidden_state = self.layers(x.unsqueeze(1),self.hidden_state) 213 | x = x.squeeze(1) 214 | char = self.char_trans(self.final_dropout(x)) 215 | return char, x 216 | 217 | def fix_layers(self): 218 | for param in self.parameters(): 219 | param.requires_grad = False 220 | 221 | 222 | class Attention(nn.Module): 223 | ''' Attention mechanism 224 | please refer to http://www.aclweb.org/anthology/D15-1166 section 3.1 for more details about Attention implementation 225 | Input : Decoder state with shape [batch size, decoder hidden dimension] 226 | Compressed feature from Encoder with shape [batch size, T, encoder feature dimension] 227 | Output: Attention score with shape [batch size, num head, T (attention score of each time step)] 228 | Context vector with shape [batch size, encoder feature dimension] 229 | (i.e. weighted (by attention score) sum of all timesteps T's feature) ''' 230 | def __init__(self, v_dim, q_dim, mode, dim, num_head, temperature, v_proj, 231 | loc_kernel_size, loc_kernel_num): 232 | super(Attention,self).__init__() 233 | 234 | # Setup 235 | self.v_dim = v_dim 236 | self.dim = dim 237 | self.mode = mode.lower() 238 | self.num_head = num_head 239 | 240 | # Linear proj. 
before attention 241 | self.proj_q = nn.Linear( q_dim, dim*num_head) 242 | self.proj_k = nn.Linear( v_dim, dim*num_head) 243 | self.v_proj = v_proj 244 | if v_proj: 245 | self.proj_v = nn.Linear( v_dim, v_dim*num_head) 246 | 247 | # Attention 248 | if self.mode == 'dot': 249 | self.att_layer = ScaleDotAttention(temperature, self.num_head) 250 | elif self.mode == 'loc': 251 | self.att_layer = LocationAwareAttention(loc_kernel_size, loc_kernel_num, dim, num_head, temperature) 252 | else: 253 | raise NotImplementedError 254 | 255 | # Layer for merging MHA 256 | if self.num_head > 1: 257 | self.merge_head = nn.Linear(v_dim*num_head, v_dim) 258 | 259 | # Stored feature 260 | self.key = None 261 | self.value = None 262 | self.mask = None 263 | 264 | def reset_mem(self): 265 | self.key = None 266 | self.value = None 267 | self.mask = None 268 | self.att_layer.reset_mem() 269 | 270 | def set_mem(self,prev_attn): 271 | self.att_layer.set_mem(prev_attn) 272 | 273 | def forward(self, dec_state, enc_feat, enc_len): 274 | 275 | # Preprecessing 276 | bs,ts,_ = enc_feat.shape 277 | query = torch.tanh(self.proj_q(dec_state)) 278 | query = query.view(bs, self.num_head, self.dim).view(bs*self.num_head, self.dim) # BNxD 279 | 280 | if self.key is None: 281 | # Maskout attention score for padded states 282 | self.att_layer.compute_mask(enc_feat,enc_len.to(enc_feat.device)) 283 | 284 | # Store enc state to lower computational cost 285 | self.key = torch.tanh(self.proj_k(enc_feat)) 286 | self.value = torch.tanh(self.proj_v(enc_feat)) if self.v_proj else enc_feat # BxTxN 287 | 288 | if self.num_head>1: 289 | self.key = self.key.view(bs,ts,self.num_head,self.dim).permute(0,2,1,3) # BxNxTxD 290 | self.key = self.key.contiguous().view(bs*self.num_head,ts,self.dim) # BNxTxD 291 | if self.v_proj: 292 | self.value = self.value.view(bs,ts,self.num_head,self.v_dim).permute(0,2,1,3) # BxNxTxD 293 | self.value = self.value.contiguous().view(bs*self.num_head,ts,self.v_dim) # BNxTxD 294 | else: 295 | self.value = self.value.repeat(self.num_head,1,1) # 296 | 297 | # Calculate attention 298 | context, attn = self.att_layer(query, self.key, self.value) 299 | if self.num_head>1: 300 | context = context.view(bs,self.num_head*self.v_dim) # BNxD -> BxND 301 | context = self.merge_head(context) # BxD 302 | 303 | return attn,context 304 | 305 | 306 | class Encoder(nn.Module): 307 | ''' Encoder (a.k.a. 
Listener in LAS) 308 | Encodes acoustic feature to latent representation, see config file for more details.''' 309 | def __init__(self, input_size, vgg, vgg_freq, vgg_low_filt, module, bidirection, dim, dropout, layer_norm, proj, sample_rate, sample_style): 310 | super(Encoder, self).__init__() 311 | 312 | # Hyper-parameters checking 313 | self.vgg = vgg 314 | self.vgg_freq = vgg_freq 315 | self.vgg_low_filt = vgg_low_filt 316 | self.sample_rate = 1 317 | assert len(sample_rate)==len(dropout), 'Number of layer mismatch' 318 | assert len(dropout)==len(dim), 'Number of layer mismatch' 319 | num_layers = len(dim) 320 | assert num_layers>=1,'Encoder should have at least 1 layer' 321 | 322 | # Construct model 323 | module_list = [] 324 | input_dim = input_size 325 | 326 | if vgg > 0: 327 | if vgg == 1: 328 | vgg_extractor = VGGExtractor(input_size) 329 | elif vgg == 2: 330 | vgg_extractor = FreqVGGExtractor(input_size, vgg_freq, vgg_low_filt) 331 | elif vgg == 3: 332 | vgg_extractor = VGGExtractor2(input_size) 333 | elif vgg == 4: 334 | vgg_extractor = FreqVGGExtractor2(input_size, vgg_freq, vgg_low_filt) 335 | else: 336 | raise NotImplementedError('vgg = {} is not available'.format(vgg)) 337 | module_list.append(vgg_extractor) 338 | input_dim = vgg_extractor.out_dim 339 | self.sample_rate = self.sample_rate * (4 if vgg < 3 else 2) 340 | 341 | if module in ['LSTM','GRU']: 342 | for l in range(num_layers): 343 | module_list.append(RNNLayer(input_dim, module, dim[l], bidirection, dropout[l], layer_norm[l], 344 | sample_rate[l], sample_style, proj[l])) 345 | input_dim = module_list[-1].out_dim 346 | self.sample_rate = self.sample_rate*sample_rate[l] 347 | else: 348 | raise NotImplementedError 349 | 350 | self.in_dim = input_size 351 | self.out_dim = input_dim 352 | self.layers = nn.ModuleList(module_list) 353 | 354 | def forward(self, input_x, enc_len): 355 | for _, layer in enumerate(self.layers): 356 | input_x, enc_len = layer(input_x, enc_len) 357 | return input_x, enc_len 358 | 359 | def get_layer_output(self, input_x, enc_len, layer_num=1): 360 | for i, layer in enumerate(self.layers): 361 | if i >= layer_num: 362 | break 363 | input_x, enc_len = layer(input_x, enc_len) 364 | return input_x, enc_len 365 | 366 | def fix_layers(self, layers): 367 | for l in layers: 368 | for param in self.layers[l].parameters(): 369 | param.requires_grad = False 370 | -------------------------------------------------------------------------------- /src/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.utils.rnn import pack_padded_sequence,pad_packed_sequence 6 | from torch.autograd import Function 7 | 8 | FBANK_SIZE = 80 9 | 10 | class VGGExtractor(nn.Module): 11 | ''' VGG extractor for ASR described in https://arxiv.org/pdf/1706.02737.pdf''' 12 | def __init__(self,input_dim): 13 | super(VGGExtractor, self).__init__() 14 | self.init_dim = 64 15 | self.hide_dim = 128 16 | in_channel,freq_dim,out_dim = self.check_dim(input_dim) 17 | self.in_channel = in_channel 18 | self.freq_dim = freq_dim 19 | self.out_dim = out_dim 20 | 21 | self.extractor = nn.Sequential( 22 | nn.Conv2d( in_channel, self.init_dim, 3, stride=1, padding=1), 23 | nn.ReLU(), 24 | nn.Conv2d( self.init_dim, self.init_dim, 3, stride=1, padding=1), 25 | nn.ReLU(), 26 | nn.MaxPool2d(2, stride=2), # Half-time dimension 27 | nn.Conv2d( self.init_dim, self.hide_dim, 3, stride=1, padding=1), 28 | 
            nn.ReLU(),
            nn.Conv2d(self.hide_dim, self.hide_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)  # Halve time and freq. dimensions
        )

    def check_dim(self, input_dim):
        # Check input dimension; delta features should be stacked over the channel dimension.
        if input_dim % 13 == 0:
            # MFCC feature
            return int(input_dim // 13), 13, (13 // 4) * self.hide_dim
        elif input_dim % FBANK_SIZE == 0:
            # Fbank feature
            return int(input_dim // FBANK_SIZE), FBANK_SIZE, (FBANK_SIZE // 4) * self.hide_dim
        else:
            raise ValueError('Acoustic feature dimension for VGG should be 13/26/39(MFCC) or 40/80/120(Fbank) but got ' + str(input_dim))

    def view_input(self, feature, feat_len):
        # downsample time
        feat_len = feat_len // 4
        # crop sequence s.t. t%4==0
        if feature.shape[1] % 4 != 0:
            feature = feature[:, :-(feature.shape[1] % 4), :].contiguous()
        bs, ts, ds = feature.shape
        # stack feature according to result of check_dim
        feature = feature.view(bs, ts, self.in_channel, self.freq_dim)
        feature = feature.transpose(1, 2)

        return feature, feat_len

    def forward(self, feature, feat_len):
        # Feature shape BSxTxD -> BS x CH(num of delta) x T x D(acoustic feature dim)
        feature, feat_len = self.view_input(feature, feat_len)
        # Forward
        feature = self.extractor(feature)
        # BSx128xT/4xD/4 -> BSxT/4x128xD/4
        feature = feature.transpose(1, 2)
        # BS x T/4 x 128 x D/4 -> BS x T/4 x 32D
        feature = feature.contiguous().view(feature.shape[0], feature.shape[1], self.out_dim)
        return feature, feat_len


class FreqVGGExtractor(nn.Module):
    ''' Frequency-modified VGG extractor for ASR: separate conv branches for low/high frequency bands '''
    def __init__(self, input_dim, split_freq, low_dim=4):
        super(FreqVGGExtractor, self).__init__()
        self.split_freq = split_freq
        self.low_init_dim = low_dim
        self.low_hide_dim = low_dim * 2
        self.high_init_dim = 64 - low_dim
        self.high_hide_dim = 128 - low_dim * 2

        in_channel, freq_dim = self.check_dim(input_dim)
        self.in_channel = in_channel
        self.freq_dim = freq_dim
        self.low_out_dim = split_freq // 4 * self.low_hide_dim
        self.high_out_dim = (freq_dim - split_freq) // 4 * self.high_hide_dim
        self.out_dim = self.low_out_dim + self.high_out_dim

        self.low_extractor = nn.Sequential(
            nn.Conv2d(in_channel, self.low_init_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(self.low_init_dim, self.low_init_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # Halve time and freq. dimensions
            nn.Conv2d(self.low_init_dim, self.low_hide_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(self.low_hide_dim, self.low_hide_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)  # Halve time and freq. dimensions
        )
        self.high_extractor = nn.Sequential(
            nn.Conv2d(in_channel, self.high_init_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(self.high_init_dim, self.high_init_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # Halve time and freq. dimensions
            nn.Conv2d(self.high_init_dim, self.high_hide_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(self.high_hide_dim, self.high_hide_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)  # Halve time and freq. dimensions
        )

        assert(self.split_freq % 4 == 0)
        assert(self.split_freq > 0 and self.split_freq < self.freq_dim)

    def check_dim(self, input_dim):
        # Check input dimension; delta features should be stacked over the channel dimension.
        if input_dim % 13 == 0:
            # MFCC feature
            return int(input_dim // 13), 13
        elif input_dim % FBANK_SIZE == 0:
            # Fbank feature
            return int(input_dim // FBANK_SIZE), FBANK_SIZE
        else:
            raise ValueError('Acoustic feature dimension for VGG should be 13/26/39(MFCC) or 40/80/120(Fbank) but got ' + str(input_dim))

    def view_input(self, feature, feat_len):
        # downsample time
        feat_len = feat_len // 4
        # crop sequence s.t. t%4==0
        if feature.shape[1] % 4 != 0:
            feature = feature[:, :-(feature.shape[1] % 4), :].contiguous()
        bs, ts, ds = feature.shape
        # stack feature according to result of check_dim
        feature = feature.view(bs, ts, self.in_channel, self.freq_dim)
        feature = feature.transpose(1, 2)

        return feature, feat_len

    def forward(self, feature, feat_len):
        # Feature shape BSxTxD -> BS x CH(num of delta) x T x D(acoustic feature dim)
        feature, feat_len = self.view_input(feature, feat_len)
        # Forward
        low_feature = self.low_extractor(feature[:, :, :, :self.split_freq])
        high_feature = self.high_extractor(feature[:, :, :, self.split_freq:])
        # features : BS x low_hide_dim x T/4 x split_freq/4 , BS x high_hide_dim x T/4 x (D-split_freq)/4
        # BS x H x T/4 x D/4 -> BS x T/4 x H x D/4
        low_feature = low_feature.transpose(1, 2)
        high_feature = high_feature.transpose(1, 2)
        # BS x T/4 x H x D/4 -> BS x T/4 x HD/4
        low_feature = low_feature.contiguous().view(low_feature.shape[0], low_feature.shape[1], self.low_out_dim)
        high_feature = high_feature.contiguous().view(high_feature.shape[0], high_feature.shape[1], self.high_out_dim)
        feature = torch.cat((low_feature, high_feature), dim=-1)
        return feature, feat_len


class VGGExtractor2(nn.Module):
    ''' VGG extractor for ASR described in https://arxiv.org/pdf/1706.02737.pdf
        Variant that downsamples time only once (time /2, frequency /4). '''
    def __init__(self, input_dim):
        super(VGGExtractor2, self).__init__()
        self.init_dim = 64
        self.hide_dim = 128
        in_channel, freq_dim, out_dim = self.check_dim(input_dim)
        self.in_channel = in_channel
        self.freq_dim = freq_dim
        self.out_dim = out_dim

        self.extractor = nn.Sequential(
            nn.Conv2d(in_channel, self.init_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(self.init_dim, self.init_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # Halve time and freq. dimensions
            nn.Conv2d(self.init_dim, self.hide_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(self.hide_dim, self.hide_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((1, 2), stride=(1, 2))  # Halve freq. dimension only
        )

    def check_dim(self, input_dim):
        # Check input dimension; delta features should be stacked over the channel dimension.
        if input_dim % 13 == 0:
            # MFCC feature
            return int(input_dim // 13), 13, (13 // 4) * self.hide_dim
        elif input_dim % FBANK_SIZE == 0:
            # Fbank feature
            return int(input_dim // FBANK_SIZE), FBANK_SIZE, (FBANK_SIZE // 4) * self.hide_dim
        else:
            raise ValueError('Acoustic feature dimension for VGG should be 13/26/39(MFCC) or 40/80/120(Fbank) but got ' + str(input_dim))

    def view_input(self, feature, feat_len):
        # downsample time
        feat_len = feat_len // 2
        # crop sequence s.t. t%2==0
        if feature.shape[1] % 2 != 0:
            feature = feature[:, :-(feature.shape[1] % 2), :].contiguous()
        bs, ts, ds = feature.shape
        # stack feature according to result of check_dim
        feature = feature.view(bs, ts, self.in_channel, self.freq_dim)
        feature = feature.transpose(1, 2)

        return feature, feat_len

    def forward(self, feature, feat_len):
        # Feature shape BSxTxD -> BS x CH(num of delta) x T x D(acoustic feature dim)
        feature, feat_len = self.view_input(feature, feat_len)
        # Forward
        feature = self.extractor(feature)
        # BSx128xT/2xD/4 -> BSxT/2x128xD/4
        feature = feature.transpose(1, 2)
        # BS x T/2 x 128 x D/4 -> BS x T/2 x 32D
        feature = feature.contiguous().view(feature.shape[0], feature.shape[1], self.out_dim)
        return feature, feat_len


class FreqVGGExtractor2(nn.Module):
    ''' Frequency-modified VGG extractor for ASR (downsamples time only once) '''
    def __init__(self, input_dim, split_freq, low_dim=4):
        super(FreqVGGExtractor2, self).__init__()
        self.split_freq = split_freq
        self.low_init_dim = low_dim
        self.low_hide_dim = low_dim * 2
        self.high_init_dim = 64 - low_dim
        self.high_hide_dim = 128 - low_dim * 2
        # self.init_dim = 64
        # self.low_hide_dim = 8
        # self.high_hide_dim = 120

        in_channel, freq_dim = self.check_dim(input_dim)
        self.in_channel = in_channel
        self.freq_dim = freq_dim
        self.low_out_dim = split_freq // 4 * self.low_hide_dim
        self.high_out_dim = (freq_dim - split_freq) // 4 * self.high_hide_dim
        self.out_dim = self.low_out_dim + self.high_out_dim

        # self.first_extractor = nn.Sequential(
        #     nn.Conv2d( in_channel, self.init_dim, 3, stride=1, padding=1),
        #     nn.ReLU(),
        #     nn.Conv2d( self.init_dim, self.init_dim, 3, stride=1, padding=1),
        #     nn.ReLU(),
        #     nn.MaxPool2d(2, stride=2), # Half-time dimension
        # )
        # self.low_extractor = nn.Sequential(
        #     nn.Conv2d( self.init_dim, self.low_hide_dim, 3, stride=1, padding=1),
        #     nn.ReLU(),
        #     nn.Conv2d( self.low_hide_dim, self.low_hide_dim, 3, stride=1, padding=1),
        #     nn.ReLU(),
        #     nn.MaxPool2d((1, 2), stride=(1, 2)) #
        # )
        # self.high_extractor = nn.Sequential(
        #     nn.Conv2d( self.init_dim, self.high_hide_dim, 3, stride=1, padding=1),
        #     nn.ReLU(),
        #     nn.Conv2d( self.high_hide_dim, self.high_hide_dim, 3, stride=1, padding=1),
        #     nn.ReLU(),
        #     nn.MaxPool2d((1, 2), stride=(1, 2)) #
        # )
        self.low_extractor = nn.Sequential(
            nn.Conv2d(in_channel, self.low_init_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(self.low_init_dim, self.low_init_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # Halve time and freq. dimensions
            nn.Conv2d(self.low_init_dim, self.low_hide_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(self.low_hide_dim, self.low_hide_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((1, 2), stride=(1, 2))  # Halve freq. dimension only
        )
        self.high_extractor = nn.Sequential(
            nn.Conv2d(in_channel, self.high_init_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(self.high_init_dim, self.high_init_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # Halve time and freq. dimensions
            nn.Conv2d(self.high_init_dim, self.high_hide_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(self.high_hide_dim, self.high_hide_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((1, 2), stride=(1, 2))  # Halve freq. dimension only
        )

        assert(self.split_freq % 4 == 0)
        assert(self.split_freq > 0 and self.split_freq < self.freq_dim)

    def check_dim(self, input_dim):
        # Check input dimension; delta features should be stacked over the channel dimension.
        if input_dim % 13 == 0:
            # MFCC feature
            return int(input_dim // 13), 13
        elif input_dim % FBANK_SIZE == 0:
            # Fbank feature
            return int(input_dim // FBANK_SIZE), FBANK_SIZE
        else:
            raise ValueError('Acoustic feature dimension for VGG should be 13/26/39(MFCC) or 40/80/120(Fbank) but got ' + str(input_dim))

    def view_input(self, feature, feat_len):
        # downsample time
        feat_len = feat_len // 2
        # crop sequence s.t. t%2==0
        if feature.shape[1] % 2 != 0:
            feature = feature[:, :-(feature.shape[1] % 2), :].contiguous()
        bs, ts, ds = feature.shape
        # stack feature according to result of check_dim
        feature = feature.view(bs, ts, self.in_channel, self.freq_dim)
        feature = feature.transpose(1, 2)

        return feature, feat_len

    def forward(self, feature, feat_len):
        # Feature shape BSxTxD -> BS x CH(num of delta) x T x D(acoustic feature dim)
        feature, feat_len = self.view_input(feature, feat_len)
        # feature = self.first_extractor(feature) # new
        # Forward
        low_feature = self.low_extractor(feature[:, :, :, :self.split_freq])
        high_feature = self.high_extractor(feature[:, :, :, self.split_freq:])
        # low_feature = self.low_extractor(feature[:,:,:,:self.split_freq//2])
        # high_feature = self.high_extractor(feature[:,:,:,self.split_freq//2:])
        # features : BS x low_hide_dim x T/2 x split_freq/4 , BS x high_hide_dim x T/2 x (D-split_freq)/4
        # BS x H x T/2 x D/4 -> BS x T/2 x H x D/4
        low_feature = low_feature.transpose(1, 2)
        high_feature = high_feature.transpose(1, 2)
        # BS x T/2 x H x D/4 -> BS x T/2 x HD/4
        low_feature = low_feature.contiguous().view(low_feature.shape[0], low_feature.shape[1], self.low_out_dim)
        high_feature = high_feature.contiguous().view(high_feature.shape[0], high_feature.shape[1], self.high_out_dim)
        feature = torch.cat((low_feature, high_feature), dim=-1)
        return feature, feat_len


class RNNLayer(nn.Module):
    ''' RNN wrapper, includes time-downsampling '''
    def __init__(self, input_dim, module, dim, bidirection, dropout, layer_norm, sample_rate, sample_style, proj):
        super(RNNLayer, self).__init__()
        # Setup
        rnn_out_dim = 2 * dim if bidirection else dim
        self.out_dim = sample_rate * rnn_out_dim if sample_rate > 1 and sample_style == 'concat' else rnn_out_dim
        self.dropout = dropout
        self.layer_norm = layer_norm
        self.sample_rate = sample_rate
        self.sample_style = sample_style
        self.proj = proj

        if self.sample_style not in ['drop', 'concat']:
            raise ValueError('Unsupported Sample Style: ' + self.sample_style)

        # Recurrent layer
        self.layer = getattr(nn, module.upper())(input_dim, dim, bidirectional=bidirection, num_layers=1, batch_first=True)

        # Regularizations
        if self.layer_norm:
            self.ln = nn.LayerNorm(rnn_out_dim)
        if self.dropout > 0:
            self.dp = nn.Dropout(p=dropout)

        # Additional projection layer
        if self.proj:
            self.pj = nn.Linear(rnn_out_dim, rnn_out_dim)

    def forward(self, input_x, x_len):
        # Forward RNN
        if not self.training:
            self.layer.flatten_parameters()
        # ToDo: check time efficiency of pack/pad
        # input_x = pack_padded_sequence(input_x, x_len, batch_first=True, enforce_sorted=False)
        output, _ = self.layer(input_x)
        # output, x_len = pad_packed_sequence(output, batch_first=True)

        # Normalizations
        if self.layer_norm:
            output = self.ln(output)
        if self.dropout > 0:
            output = self.dp(output)

        # Perform downsampling
        if self.sample_rate > 1:
            batch_size, timestep, feature_dim = output.shape
            x_len = x_len // self.sample_rate

            if self.sample_style == 'drop':
                # Drop the unselected timesteps
                output = output[:, ::self.sample_rate, :].contiguous()
            else:
                # Drop the redundant frames and concat the rest according to sample rate
                if timestep % self.sample_rate != 0:
                    output = output[:, :-(timestep % self.sample_rate), :]
                output = output.contiguous().view(batch_size, int(timestep / self.sample_rate), feature_dim * self.sample_rate)

        if self.proj:
            output = torch.tanh(self.pj(output))

        return output, x_len


class BaseAttention(nn.Module):
    ''' Base module for attentions '''
    def __init__(self, temperature, num_head):
        super().__init__()
        self.temperature = temperature
        self.num_head = num_head
        self.softmax = nn.Softmax(dim=-1)
        self.reset_mem()

    def reset_mem(self):
        # Reset mask
        self.mask = None
        self.k_len = None

    def set_mem(self):
        pass

    def compute_mask(self, k, k_len):
        # Make the mask for padded states
        self.k_len = k_len
        bs, ts, _ = k.shape
        self.mask = np.zeros((bs, self.num_head, ts))
        for idx, sl in enumerate(k_len):
            self.mask[idx, :, sl:] = 1  # ToDo: more elegant way?
        self.mask = torch.from_numpy(self.mask).to(k_len.device, dtype=torch.bool).view(-1, ts)  # BNxT

    def _attend(self, energy, value):
        attn = energy / self.temperature
        attn = attn.masked_fill(self.mask, -np.inf)
        attn = self.softmax(attn)  # BNxT
        output = torch.bmm(attn.unsqueeze(1), value).squeeze(1)  # BNx1xT x BNxTxD -> BNxD
        return output, attn


class ScaleDotAttention(BaseAttention):
    ''' Scaled Dot-Product Attention '''
    def __init__(self, temperature, num_head):
        super().__init__(temperature, num_head)

    def forward(self, q, k, v):
        ts = k.shape[1]
        energy = torch.bmm(q.unsqueeze(1), k.transpose(1, 2)).squeeze(1)  # BNx1xD x BNxDxT -> BNxT
        output, attn = self._attend(energy, v)

        attn = attn.view(-1, self.num_head, ts)  # BNxT -> BxNxT

        return output, attn


class LocationAwareAttention(BaseAttention):
    ''' Location-Aware Attention '''
    def __init__(self, kernel_size, kernel_num, dim, num_head, temperature):
        super().__init__(temperature, num_head)
        self.prev_att = None
        self.loc_conv = nn.Conv1d(num_head, kernel_num, kernel_size=2 * kernel_size + 1, padding=kernel_size, bias=False)
        self.loc_proj = nn.Linear(kernel_num, dim, bias=False)
        self.gen_energy = nn.Linear(dim, 1)
        self.dim = dim

    def reset_mem(self):
        super().reset_mem()
        self.prev_att = None

    def set_mem(self, prev_att):
        self.prev_att = prev_att

    def forward(self, q, k, v):
        bs_nh, ts, _ = k.shape
        bs = bs_nh // self.num_head

        # Uniformly init prev_att
        if self.prev_att is None:
            self.prev_att = torch.zeros((bs, self.num_head, ts)).to(k.device)
            for idx, sl in enumerate(self.k_len):
                self.prev_att[idx, :, :sl] = 1.0 / sl

        # Calculate location context
        loc_context = torch.tanh(self.loc_proj(self.loc_conv(self.prev_att).transpose(1, 2)))  # BxNxT -> BxTxD
        loc_context = loc_context.unsqueeze(1).repeat(1, self.num_head, 1, 1).view(-1, ts, self.dim)  # BxNxTxD -> BNxTxD
        q = q.unsqueeze(1)  # BNx1xD

        # Compute energy and context
        energy = self.gen_energy(torch.tanh(k + q + loc_context)).squeeze(2)  # BNxTxD -> BNxT
        output, attn = self._attend(energy, v)
        attn = attn.view(bs, self.num_head, ts)  # BNxT -> BxNxT
        self.prev_att = attn

        return output, attn
--------------------------------------------------------------------------------
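Quick shape sanity check: the minimal sketch below is illustrative only and is not a file in the repository. It assumes torch is installed and that the repository root is on PYTHONPATH so that src.module is importable; it feeds a random batch of 80-dim Fbank features through VGGExtractor and a single RNNLayer to confirm the time-downsampling and output dimensions of the modules above.

    # usage_sketch.py -- hypothetical script, not part of the repository
    import torch
    from src.module import VGGExtractor, RNNLayer

    # A fake batch of 80-dim Fbank features: 4 utterances, 100 frames each.
    batch, time, feat_dim = 4, 100, 80
    feature = torch.randn(batch, time, feat_dim)
    feat_len = torch.full((batch,), time, dtype=torch.long)

    # VGGExtractor downsamples time (and frequency) by 4 and outputs (80 // 4) * 128 = 2560 dims.
    vgg = VGGExtractor(feat_dim)
    feature, feat_len = vgg(feature, feat_len)
    print(feature.shape, feat_len)   # torch.Size([4, 25, 2560]) tensor([25, 25, 25, 25])

    # One bidirectional LSTM layer with further time-downsampling by 2 ('drop' style).
    rnn = RNNLayer(vgg.out_dim, 'LSTM', dim=256, bidirection=True, dropout=0.2,
                   layer_norm=False, sample_rate=2, sample_style='drop', proj=True)
    feature, feat_len = rnn(feature, feat_len)
    print(feature.shape, feat_len)   # torch.Size([4, 13, 512]) tensor([12, 12, 12, 12])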