├── .gitignore ├── LICENSE ├── README.md ├── README.t ├── audios ├── audio_0.wav ├── audio_1.wav ├── audio_2.wav ├── audio_3.wav ├── audio_4.wav ├── audio_5.wav ├── audio_6.wav ├── audio_7.wav ├── audio_8.wav └── audio_9.wav ├── char_list.pkl ├── collect_char_list.py ├── config.py ├── data_gen.py ├── demo.py ├── export.py ├── extract.py ├── ngram_lm.py ├── pre_process.py ├── replace.py ├── requirements.txt ├── results.json ├── specAugment ├── __init__.py ├── sparse_image_warp_np.py ├── sparse_image_warp_pytorch.py ├── spec_augment_pytorch.py └── spec_augment_tensorflow.py ├── sponsor.jpg ├── test.py ├── test ├── test_decode.py ├── test_lm.py ├── test_lr.py ├── test_pe.py ├── test_specaug.py └── test_trim.py ├── test_lm.py ├── train.py ├── transformer ├── __init__.py ├── attention.py ├── decoder.py ├── encoder.py ├── loss.py ├── module.py ├── optimizer.py ├── transformer.py └── utils.py ├── utils.py └── xer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__/ 3 | BEST_checkpoint.tar 4 | checkpoint.tar 5 | data/ 6 | runs/ 7 | nohup.out 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 刘杨 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Speech Transformer 2 | 3 | ## Introduction 4 | 5 | This is a PyTorch re-implementation of Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition. 6 | 7 | ## Dataset 8 | 9 | Aishell is an open-source Chinese Mandarin speech corpus published by Beijing Shell Shell Technology Co.,Ltd. 10 | 11 | 400 people from different accent areas in China are invited to participate in the recording, which is conducted in a quiet indoor environment using high fidelity microphone and downsampled to 16kHz. The manual transcription accuracy is above 95%, through professional speech annotation and strict quality inspection. The data is free for academic use. We hope to provide moderate amount of data for new researchers in the field of speech recognition. 
12 | ``` 13 | @inproceedings{aishell_2017, 14 | title={AIShell-1: An Open-Source Mandarin Speech Corpus and A Speech Recognition Baseline}, 15 | author={Hui Bu and Jiayu Du and Xingyu Na and Bengu Wu and Hao Zheng}, 16 | booktitle={Oriental COCOSDA 2017}, 17 | pages={Submitted}, 18 | year={2017} 19 | } 20 | ``` 21 | In the data folder, download the speech data and transcripts: 22 | 23 | ```bash 24 | $ wget http://www.openslr.org/resources/33/data_aishell.tgz 25 | ``` 26 | 27 | ## Performance 28 | 29 | Evaluate with the 7176 utterances in the Aishell test set: 30 | ```bash 31 | $ python test.py 32 | ``` 33 | 34 | ## Results 35 | 36 | |Model|CER|Download| 37 | |---|---|---| 38 | |Speech Transformer|11.5|[Link](https://github.com/foamliu/Speech-Transformer/releases/download/v1.0/BEST_checkpoint.tar)| 39 | 40 | ## Dependency 41 | 42 | - Python 3.6.8 43 | - PyTorch 1.3.0 44 | 45 | ## Usage 46 | ### Data Pre-processing 47 | Extract data_aishell.tgz: 48 | ```bash 49 | $ python extract.py 50 | ``` 51 | 52 | Extract the wav files into the train/dev/test folders: 53 | ```bash 54 | $ cd data/data_aishell/wav 55 | $ find . -name '*.tar.gz' -execdir tar -xzvf '{}' \; 56 | ``` 57 | 58 | Scan the transcript data and generate features: 59 | ```bash 60 | $ python pre_process.py 61 | ``` 62 | 63 | Now the folder structure under the data folder looks something like this: 64 |
 65 | data/
 66 |     data_aishell.tgz
 67 |     data_aishell/
 68 |         transcript/
 69 |             aishell_transcript_v0.8.txt
 70 |         wav/
 71 |             train/
 72 |             dev/
 73 |             test/
 74 |     aishell.pickle
 75 | 
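As a quick sanity check, the generated pickle can be inspected directly. The snippet below is only an illustrative sketch (it is not one of the repo's scripts); it relies on the keys written by `pre_process.py`:

```python
import pickle

from config import pickle_file  # 'data/aishell.pickle'

with open(pickle_file, 'rb') as file:
    data = pickle.load(file)

# number of utterances per split
print(len(data['train']), len(data['dev']), len(data['test']))

# each sample stores a wav path and its transcript as vocabulary indices
sample = data['train'][0]
print(sample['wave'])
# indices decoded back to characters (pre_process.py appends a sentence-end token)
print(''.join(data['IVOCAB'][idx] for idx in sample['trn']))
```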
76 | 77 | ### Train 78 | ```bash 79 | $ python train.py 80 | ``` 81 | 82 | If you want to visualize the training progress, run this in your terminal: 83 | ```bash 84 | $ tensorboard --logdir runs 85 | ``` 86 | 87 | ### Demo 88 | Please download the [pretrained model](https://github.com/foamliu/Speech-Transformer/releases/download/v1.0/speech-transformer-cn.pt) and then run: 89 | ```bash 90 | $ python demo.py 91 | ``` 92 | 93 | It picks 10 random test examples and recognizes them, as shown below: 94 | 95 | |Audio|Out|GT| 96 | |---|---|---| 97 | |[audio_0.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_0.wav)|我国的经济处在爬破过凯的重要公考
我国的经济处在爬破过凯的重要公口
我国的经济处在盘破过凯的重要公考
我国的经济处在爬破过凯的重要公靠
我国的经济处在爬坡过凯的重要公考|我国的经济处在爬坡过坎的重要关口| 98 | |[audio_1.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_1.wav)|完善主地承包经一全流市市场
完善主地承包经一全六市市场
完善主地承包经营全流市市场
完善主地承包经一权流市市场
完善主地承包经营全六市市场|完善土地承包经营权流转市场| 99 | |[audio_2.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_2.wav)|临长各类设施使用年限
严长各类设施使用年限
延长各类设施使用年限
很长各类设施使用年限
难长各类设施使用年限|延长各类设施使用年限| 100 | |[audio_3.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_3.wav)|苹果此举是为了节约用电量
苹果此举是是了节约用电量
苹果此举是为了解约用电量
苹果此举是为了节约用电令
苹果此举只为了节约用电量|苹果此举是为了节约用电量| 101 | |[audio_4.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_4.wav)|反他们也可以有机会参与体育运动
让他们也可以有机会参与体育运动
反她们也可以有机会参与体育运动
范他们也可以有机会参与体育运动
但他们也可以有机会参与体育运动|让他们也可以有机会参与体育运动| 102 | |[audio_5.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_5.wav)|陈言希穿着粉色上衣
陈闫希穿着粉色上衣
陈延希穿着粉色上衣
陈言琪穿着粉色上衣
陈演希穿着粉色上衣|陈妍希穿着粉色上衣| 103 | |[audio_6.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_6.wav)|说起自己的伴女大下
说起自己的伴理大下
说起自己的半女大下
说起自己的办女大下
说起自己的半理大下|说起自己的伴侣大侠| 104 | |[audio_7.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_7.wav)|每日经济新闻记者注意到
每日经济新闻记者朱意到
每日经济新闻记者注一到
每日经济新闻记者注注到
每日经济新闻记者注以到|每日经济新闻记者注意到| 105 | |[audio_8.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_8.wav)|这是今年五月份以来库存环比增幅幅小了一次
这是今年五月份以来库存环比增幅最小了一次
这是今年五月份以来库存环比增幅幅小的一次
这是今年五月份以来库存环比增幅最小的一次
这是今年五月份以来库存环比增幅幅小小一次|这是今年五月份以来库存环比增幅最小的一次| 106 | |[audio_9.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_9.wav)|一个人的精使生命就将走向摔老
一个连的精使生命就将走向摔老
一个人的金使生命就将走向摔老
一个人的坚使生命就将走向摔老
一个连的金使生命就将走向摔老|一个人的精神生命就将走向衰老| 107 | 108 | ## 小小的赞助~ 109 |

110 | Sample 111 |

112 | 若对您有帮助可给予小小的赞助~ 113 |

114 |

115 |


-------------------------------------------------------------------------------- /README.t: -------------------------------------------------------------------------------- 1 | # Speech Transformer 2 | 3 | ## Introduction 4 | 5 | This is a PyTorch re-implementation of Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition. 6 | 7 | ## Dataset 8 | 9 | Aishell is an open-source Chinese Mandarin speech corpus published by Beijing Shell Shell Technology Co.,Ltd. 10 | 11 | 400 people from different accent areas in China are invited to participate in the recording, which is conducted in a quiet indoor environment using high fidelity microphone and downsampled to 16kHz. The manual transcription accuracy is above 95%, through professional speech annotation and strict quality inspection. The data is free for academic use. We hope to provide moderate amount of data for new researchers in the field of speech recognition. 12 | ``` 13 | @inproceedings{aishell_2017, 14 | title={AIShell-1: An Open-Source Mandarin Speech Corpus and A Speech Recognition Baseline}, 15 | author={Hui Bu, Jiayu Du, Xingyu Na, Bengu Wu, Hao Zheng}, 16 | booktitle={Oriental COCOSDA 2017}, 17 | pages={Submitted}, 18 | year={2017} 19 | } 20 | ``` 21 | In data folder, download speech data and transcripts: 22 | 23 | ```bash 24 | $ wget http://www.openslr.org/resources/33/data_aishell.tgz 25 | ``` 26 | 27 | ## Performance 28 | 29 | Evaluate with 7176 audios in Aishell test set: 30 | ```bash 31 | $ python test.py 32 | ``` 33 | 34 | ## Results 35 | 36 | |Model|CER|Download| 37 | |---|---|---| 38 | |Speech Transformer|11.5|[Link](https://github.com/foamliu/Speech-Transformer/releases/download/v1.0/BEST_checkpoint.tar)| 39 | 40 | ## Dependency 41 | 42 | - Python 3.6.8 43 | - PyTorch 1.3.0 44 | 45 | ## Usage 46 | ### Data Pre-processing 47 | Extract data_aishell.tgz: 48 | ```bash 49 | $ python extract.py 50 | ``` 51 | 52 | Extract wav files into train/dev/test folders: 53 | ```bash 54 | $ cd data/data_aishell/wav 55 | $ find . -name '*.tar.gz' -execdir tar -xzvf '{}' \; 56 | ``` 57 | 58 | Scan transcript data, generate features: 59 | ```bash 60 | $ python pre_process.py 61 | ``` 62 | 63 | Now the folder structure under data folder is sth. like: 64 |
 65 | data/
 66 |     data_aishell.tgz
 67 |     data_aishell/
 68 |         transcript/
 69 |             aishell_transcript_v0.8.txt
 70 |         wav/
 71 |             train/
 72 |             dev/
 73 |             test/
 74 |     aishell.pickle
 75 | 
76 | 77 | ### Train 78 | ```bash 79 | $ python train.py 80 | ``` 81 | 82 | If you want to visualize during training, run in your terminal: 83 | ```bash 84 | $ tensorboard --logdir runs 85 | ``` 86 | 87 | ### Demo 88 | Please download the [pretrained model](https://github.com/foamliu/Speech-Transformer/releases/download/v1.0/speech-transformer-cn.pt) then run: 89 | ```bash 90 | $ python demo.py 91 | ``` 92 | 93 | It picks 10 random test examples and recognize them like these: 94 | 95 | |Audio|Out|GT| 96 | |---|---|---| 97 | |[audio_0.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_0.wav)|$(out_list_0)|$(gt_0)| 98 | |[audio_1.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_1.wav)|$(out_list_1)|$(gt_1)| 99 | |[audio_2.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_2.wav)|$(out_list_2)|$(gt_2)| 100 | |[audio_3.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_3.wav)|$(out_list_3)|$(gt_3)| 101 | |[audio_4.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_4.wav)|$(out_list_4)|$(gt_4)| 102 | |[audio_5.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_5.wav)|$(out_list_5)|$(gt_5)| 103 | |[audio_6.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_6.wav)|$(out_list_6)|$(gt_6)| 104 | |[audio_7.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_7.wav)|$(out_list_7)|$(gt_7)| 105 | |[audio_8.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_8.wav)|$(out_list_8)|$(gt_8)| 106 | |[audio_9.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_9.wav)|$(out_list_9)|$(gt_9)| 107 | 108 | ## 小小的赞助~ 109 |

110 | Sample 111 |

112 | 若对您有帮助可给予小小的赞助~ 113 |

114 |

115 |


-------------------------------------------------------------------------------- /audios/audio_0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_0.wav -------------------------------------------------------------------------------- /audios/audio_1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_1.wav -------------------------------------------------------------------------------- /audios/audio_2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_2.wav -------------------------------------------------------------------------------- /audios/audio_3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_3.wav -------------------------------------------------------------------------------- /audios/audio_4.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_4.wav -------------------------------------------------------------------------------- /audios/audio_5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_5.wav -------------------------------------------------------------------------------- /audios/audio_6.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_6.wav -------------------------------------------------------------------------------- /audios/audio_7.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_7.wav -------------------------------------------------------------------------------- /audios/audio_8.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_8.wav -------------------------------------------------------------------------------- /audios/audio_9.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_9.wav -------------------------------------------------------------------------------- /char_list.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/char_list.pkl -------------------------------------------------------------------------------- /collect_char_list.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from 
config import pickle_file 4 | 5 | if __name__ == '__main__': 6 | with open(pickle_file, 'rb') as file: 7 | data = pickle.load(file) 8 | char_list = data['IVOCAB'] 9 | 10 | with open('char_list.pkl', 'wb') as file: 11 | pickle.dump(char_list, file) 12 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # sets device for model and PyTorch tensors 6 | 7 | # Model parameters 8 | input_dim = 80 # dimension of feature 9 | window_size = 25 # window size for FFT (ms) 10 | stride = 10 # window stride for FFT (ms) 11 | hidden_size = 512 12 | embedding_dim = 512 13 | cmvn = True # apply CMVN on feature 14 | num_layers = 4 15 | LFR_m = 4 16 | LFR_n = 3 17 | sample_rate = 16000 # aishell 18 | 19 | # Training parameters 20 | grad_clip = 5. # clip gradients at an absolute value of 21 | print_freq = 100 # print training/validation stats every __ batches 22 | checkpoint = None # path to checkpoint, None if none 23 | 24 | # Data parameters 25 | IGNORE_ID = -1 26 | sos_id = 0 27 | eos_id = 1 28 | num_train = 120098 29 | num_dev = 14326 30 | num_test = 7176 31 | vocab_size = 4335 32 | 33 | DATA_DIR = 'data' 34 | aishell_folder = 'data/data_aishell' 35 | wav_folder = os.path.join(aishell_folder, 'wav') 36 | tran_file = os.path.join(aishell_folder, 'transcript/aishell_transcript_v0.8.txt') 37 | pickle_file = 'data/aishell.pickle' 38 | -------------------------------------------------------------------------------- /data_gen.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from torch.utils.data.dataloader import default_collate 7 | 8 | from config import pickle_file, IGNORE_ID 9 | from utils import extract_feature 10 | 11 | 12 | def pad_collate(batch): 13 | max_input_len = float('-inf') 14 | max_target_len = float('-inf') 15 | 16 | for elem in batch: 17 | feature, trn = elem 18 | max_input_len = max_input_len if max_input_len > feature.shape[0] else feature.shape[0] 19 | max_target_len = max_target_len if max_target_len > len(trn) else len(trn) 20 | 21 | for i, elem in enumerate(batch): 22 | feature, trn = elem 23 | input_length = feature.shape[0] 24 | input_dim = feature.shape[1] 25 | padded_input = np.zeros((max_input_len, input_dim), dtype=np.float32) 26 | padded_input[:input_length, :] = feature 27 | padded_target = np.pad(trn, (0, max_target_len - len(trn)), 'constant', constant_values=IGNORE_ID) 28 | batch[i] = (padded_input, padded_target, input_length) 29 | 30 | # sort it by input lengths (long to short) 31 | batch.sort(key=lambda x: x[2], reverse=True) 32 | 33 | return default_collate(batch) 34 | 35 | 36 | def build_LFR_features(inputs, m, n): 37 | """ 38 | Actually, this implements stacking frames and skipping frames. 39 | if m = 1 and n = 1, just return the origin features. 40 | if m = 1 and n > 1, it works like skipping. 41 | if m > 1 and n = 1, it works like stacking but only support right frames. 42 | if m > 1 and n > 1, it works like LFR. 
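    For example, with the defaults in config.py (LFR_m=4, LFR_n=3), a T x D input becomes a
    ceil(T/3) x 4D output: every third frame is stacked with its three right-hand neighbours,
    and the last frame is repeated whenever fewer than m frames remain.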
43 | Args: 44 | inputs_batch: inputs is T x D np.ndarray 45 | m: number of frames to stack 46 | n: number of frames to skip 47 | """ 48 | # LFR_inputs_batch = [] 49 | # for inputs in inputs_batch: 50 | LFR_inputs = [] 51 | T = inputs.shape[0] 52 | T_lfr = int(np.ceil(T / n)) 53 | for i in range(T_lfr): 54 | if m <= T - i * n: 55 | LFR_inputs.append(np.hstack(inputs[i * n:i * n + m])) 56 | else: # process last LFR frame 57 | num_padding = m - (T - i * n) 58 | frame = np.hstack(inputs[i * n:]) 59 | for _ in range(num_padding): 60 | frame = np.hstack((frame, inputs[-1])) 61 | LFR_inputs.append(frame) 62 | return np.vstack(LFR_inputs) 63 | 64 | 65 | # Source: https://www.kaggle.com/davids1992/specaugment-quick-implementation 66 | def spec_augment(spec: np.ndarray, 67 | num_mask=2, 68 | freq_masking=0.15, 69 | time_masking=0.20, 70 | value=0): 71 | spec = spec.copy() 72 | num_mask = random.randint(1, num_mask) 73 | for i in range(num_mask): 74 | all_freqs_num, all_frames_num = spec.shape 75 | freq_percentage = random.uniform(0.0, freq_masking) 76 | 77 | num_freqs_to_mask = int(freq_percentage * all_freqs_num) 78 | f0 = np.random.uniform(low=0.0, high=all_freqs_num - num_freqs_to_mask) 79 | f0 = int(f0) 80 | spec[f0:f0 + num_freqs_to_mask, :] = value 81 | 82 | time_percentage = random.uniform(0.0, time_masking) 83 | 84 | num_frames_to_mask = int(time_percentage * all_frames_num) 85 | t0 = np.random.uniform(low=0.0, high=all_frames_num - num_frames_to_mask) 86 | t0 = int(t0) 87 | spec[:, t0:t0 + num_frames_to_mask] = value 88 | return spec 89 | 90 | 91 | class AiShellDataset(Dataset): 92 | def __init__(self, args, split): 93 | self.args = args 94 | with open(pickle_file, 'rb') as file: 95 | data = pickle.load(file) 96 | 97 | self.samples = data[split] 98 | print('loading {} {} samples...'.format(len(self.samples), split)) 99 | 100 | def __getitem__(self, i): 101 | sample = self.samples[i] 102 | wave = sample['wave'] 103 | trn = sample['trn'] 104 | 105 | feature = extract_feature(input_file=wave, feature='fbank', dim=self.args.d_input, cmvn=True) 106 | # zero mean and unit variance 107 | feature = (feature - feature.mean()) / feature.std() 108 | feature = spec_augment(feature) 109 | feature = build_LFR_features(feature, m=self.args.LFR_m, n=self.args.LFR_n) 110 | 111 | return feature, trn 112 | 113 | def __len__(self): 114 | return len(self.samples) 115 | 116 | 117 | if __name__ == "__main__": 118 | import torch 119 | from utils import parse_args 120 | from tqdm import tqdm 121 | 122 | args = parse_args() 123 | train_dataset = AiShellDataset(args, 'train') 124 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=args.num_workers, 125 | collate_fn=pad_collate) 126 | # 127 | # print(len(train_dataset)) 128 | # print(len(train_loader)) 129 | # 130 | # feature = train_dataset[10][0] 131 | # print(feature.shape) 132 | # 133 | # trn = train_dataset[10][1] 134 | # print(trn) 135 | # 136 | # with open(pickle_file, 'rb') as file: 137 | # data = pickle.load(file) 138 | # IVOCAB = data['IVOCAB'] 139 | # 140 | # print([IVOCAB[idx] for idx in trn]) 141 | # 142 | # for data in train_loader: 143 | # print(data) 144 | # break 145 | 146 | max_len = 0 147 | 148 | for data in tqdm(train_loader): 149 | feature = data[0] 150 | # print(feature.shape) 151 | if feature.shape[1] > max_len: 152 | max_len = feature.shape[1] 153 | 154 | print('max_len: ' + str(max_len)) 155 | -------------------------------------------------------------------------------- /demo.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import random 4 | from shutil import copyfile 5 | 6 | import torch 7 | 8 | from config import pickle_file, device, input_dim, LFR_m, LFR_n 9 | from data_gen import build_LFR_features 10 | from transformer.transformer import Transformer 11 | from utils import extract_feature, ensure_folder 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | "End-to-End Automatic Speech Recognition Decoding.") 17 | # decode 18 | parser.add_argument('--beam_size', default=5, type=int, 19 | help='Beam size') 20 | parser.add_argument('--nbest', default=5, type=int, 21 | help='Nbest size') 22 | parser.add_argument('--decode_max_len', default=100, type=int, 23 | help='Max output length. If ==0 (default), it uses a ' 24 | 'end-detect function to automatically find maximum ' 25 | 'hypothesis lengths') 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | if __name__ == '__main__': 31 | args = parse_args() 32 | with open('char_list.pkl', 'rb') as file: 33 | char_list = pickle.load(file) 34 | with open(pickle_file, 'rb') as file: 35 | data = pickle.load(file) 36 | samples = data['test'] 37 | 38 | filename = 'speech-transformer-cn.pt' 39 | print('loading model: {}...'.format(filename)) 40 | model = Transformer() 41 | model.load_state_dict(torch.load(filename)) 42 | model = model.to(device) 43 | model.eval() 44 | 45 | samples = random.sample(samples, 10) 46 | ensure_folder('audios') 47 | results = [] 48 | 49 | for i, sample in enumerate(samples): 50 | wave = sample['wave'] 51 | trn = sample['trn'] 52 | 53 | copyfile(wave, 'audios/audio_{}.wav'.format(i)) 54 | 55 | feature = extract_feature(input_file=wave, feature='fbank', dim=input_dim, cmvn=True) 56 | feature = build_LFR_features(feature, m=LFR_m, n=LFR_n) 57 | # feature = np.expand_dims(feature, axis=0) 58 | input = torch.from_numpy(feature).to(device) 59 | input_length = [input[0].shape[0]] 60 | input_length = torch.LongTensor(input_length).to(device) 61 | nbest_hyps = model.recognize(input, input_length, char_list, args) 62 | out_list = [] 63 | for hyp in nbest_hyps: 64 | out = hyp['yseq'] 65 | out = [char_list[idx] for idx in out] 66 | out = ''.join(out) 67 | out_list.append(out) 68 | print('OUT_LIST: {}'.format(out_list)) 69 | 70 | gt = [char_list[idx] for idx in trn] 71 | gt = ''.join(gt) 72 | print('GT: {}\n'.format(gt)) 73 | 74 | results.append({'out_list_{}'.format(i): out_list, 'gt_{}'.format(i): gt}) 75 | 76 | import json 77 | 78 | with open('results.json', 'w') as file: 79 | json.dump(results, file, indent=4, ensure_ascii=False) 80 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | if __name__ == '__main__': 5 | checkpoint = 'BEST_checkpoint.tar' 6 | checkpoint = torch.load(checkpoint) 7 | model = checkpoint['model'] 8 | # model.eval() 9 | 10 | torch.save(model.state_dict(), 'speech-transformer-cn.pt') 11 | -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tarfile 3 | 4 | 5 | def extract(filename): 6 | print('Extracting {}...'.format(filename)) 7 | tar = tarfile.open(filename, 'r') 8 | tar.extractall('data') 9 | tar.close() 10 | 11 | 12 | if __name__ == "__main__": 13 | if not 
os.path.isdir('data/data_aishell'): 14 | extract('data/data_aishell.tgz') 15 | -------------------------------------------------------------------------------- /ngram_lm.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import pickle 3 | 4 | import nltk 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from config import pickle_file 9 | 10 | with open(pickle_file, 'rb') as file: 11 | data = pickle.load(file) 12 | char_list = data['IVOCAB'] 13 | vocab_size = len(char_list) 14 | samples = data['train'] 15 | bigram_counter = collections.Counter() 16 | 17 | for sample in tqdm(samples): 18 | text = sample['trn'] 19 | # text = [char_list[idx] for idx in text] 20 | tokens = list(text) 21 | bigrm = nltk.bigrams(tokens) 22 | # print(*map(' '.join, bigrm), sep=', ') 23 | 24 | # get the frequency of each bigram in our corpus 25 | bigram_counter.update(bigrm) 26 | 27 | # what are the ten most popular ngrams in this Spanish corpus? 28 | print(bigram_counter.most_common(10)) 29 | 30 | temp_dict = dict() 31 | for key, value in bigram_counter.items(): 32 | temp_dict[key] = value 33 | 34 | print('smoothing and freq -> prob') 35 | bigram_freq = dict() 36 | for i in tqdm(range(vocab_size)): 37 | freq_list = [] 38 | for j in range(vocab_size): 39 | if (i, j) in temp_dict: 40 | freq_list.append(temp_dict[(i, j)]) 41 | else: 42 | freq_list.append(1) 43 | 44 | freq_list = np.array(freq_list) 45 | freq_list = freq_list / np.sum(freq_list) 46 | 47 | assert (len(freq_list) == vocab_size) 48 | bigram_freq[i] = freq_list 49 | 50 | print(len(bigram_freq[0])) 51 | with open('bigram_freq.pkl', 'wb') as file: 52 | pickle.dump(bigram_freq, file) 53 | -------------------------------------------------------------------------------- /pre_process.py: -------------------------------------------------------------------------------- 1 | # Usage : python pre_process.py --n_samples train:800,dev:100,test:100 2 | # python pre_process.py --n_samples train:800,dev:100 3 | # ... 
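# Note: splits omitted from --n_samples default to -1, i.e. every utterance in that split is kept (see the handling in __main__ below).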
4 | 5 | 6 | import os 7 | import pickle 8 | 9 | from tqdm import tqdm 10 | 11 | from config import wav_folder, tran_file, pickle_file 12 | from utils import ensure_folder, parse_args 13 | 14 | 15 | def get_data(split, n_samples): 16 | print('getting {} data...'.format(split)) 17 | 18 | global VOCAB 19 | 20 | with open(tran_file, 'r', encoding='utf-8') as file: 21 | lines = file.readlines() 22 | 23 | tran_dict = dict() 24 | for line in lines: 25 | tokens = line.split() 26 | key = tokens[0] 27 | trn = ''.join(tokens[1:]) 28 | tran_dict[key] = trn 29 | 30 | samples = [] 31 | 32 | #n_samples = 5000 33 | rest = n_samples 34 | 35 | folder = os.path.join(wav_folder, split) 36 | ensure_folder(folder) 37 | dirs = [os.path.join(folder, d) for d in os.listdir(folder) if os.path.isdir(os.path.join(folder, d))] 38 | for dir in tqdm(dirs): 39 | files = [f for f in os.listdir(dir) if f.endswith('.wav')] 40 | 41 | rest = len(files) if n_samples <= 0 else rest 42 | 43 | for f in files[:rest]: 44 | 45 | wave = os.path.join(dir, f) 46 | 47 | key = f.split('.')[0] 48 | 49 | if key in tran_dict: 50 | trn = tran_dict[key] 51 | trn = list(trn.strip()) + [''] 52 | 53 | for token in trn: 54 | build_vocab(token) 55 | 56 | trn = [VOCAB[token] for token in trn] 57 | 58 | samples.append({'trn': trn, 'wave': wave}) 59 | 60 | rest = rest - len(files) if n_samples > 0 else rest 61 | if rest <= 0 : 62 | break 63 | 64 | print('split: {}, num_files: {}'.format(split, len(samples))) 65 | return samples 66 | 67 | 68 | def build_vocab(token): 69 | global VOCAB, IVOCAB 70 | if not token in VOCAB: 71 | next_index = len(VOCAB) 72 | VOCAB[token] = next_index 73 | IVOCAB[next_index] = token 74 | 75 | 76 | if __name__ == "__main__": 77 | 78 | # number of examples to use 79 | global args 80 | args = parse_args() 81 | tmp = args.n_samples.split(",") 82 | tmp = [a.split(":") for a in tmp] 83 | tmp = {a[0]:int(a[1]) for a in tmp} 84 | args.n_samples = {"train":-1, "dev":-1,"test":-1} 85 | args.n_samples.update(tmp) 86 | 87 | VOCAB = {'': 0, '': 1} 88 | IVOCAB = {0: '', 1: ''} 89 | 90 | data = dict() 91 | data['VOCAB'] = VOCAB 92 | data['IVOCAB'] = IVOCAB 93 | data['train'] = get_data('train', args.n_samples["train"]) 94 | data['dev'] = get_data('dev', args.n_samples["dev"]) 95 | data['test'] = get_data('test', args.n_samples["test"]) 96 | 97 | with open(pickle_file, 'wb') as file: 98 | pickle.dump(data, file) 99 | 100 | print('num_train: ' + str(len(data['train']))) 101 | print('num_dev: ' + str(len(data['dev']))) 102 | print('num_test: ' + str(len(data['test']))) 103 | print('vocab_size: ' + str(len(data['VOCAB']))) 104 | -------------------------------------------------------------------------------- /replace.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | 4 | if __name__ == '__main__': 5 | with open('README.t', 'r', encoding="utf-8") as file: 6 | text = file.readlines() 7 | text = ''.join(text) 8 | 9 | with open('results.json', 'r', encoding="utf-8") as file: 10 | results = json.load(file) 11 | 12 | print(results[0]) 13 | 14 | for i, result in enumerate(results): 15 | out_key = 'out_list_{}'.format(i) 16 | text = text.replace('$({})'.format(out_key), '
'.join(result[out_key])) 17 | gt_key = 'gt_{}'.format(i) 18 | text = text.replace('$({})'.format(gt_key), result[gt_key]) 19 | 20 | text = text.replace('', '') 21 | text = text.replace('', '') 22 | 23 | with open('README.md', 'w', encoding="utf-8") as file: 24 | file.write(text) 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa -------------------------------------------------------------------------------- /results.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "out_list_0": [ 4 | "我国的经济处在爬破过凯的重要公考", 5 | "我国的经济处在爬破过凯的重要公口", 6 | "我国的经济处在盘破过凯的重要公考", 7 | "我国的经济处在爬破过凯的重要公靠", 8 | "我国的经济处在爬坡过凯的重要公考" 9 | ], 10 | "gt_0": "我国的经济处在爬坡过坎的重要关口" 11 | }, 12 | { 13 | "out_list_1": [ 14 | "完善主地承包经一全流市市场", 15 | "完善主地承包经一全六市市场", 16 | "完善主地承包经营全流市市场", 17 | "完善主地承包经一权流市市场", 18 | "完善主地承包经营全六市市场" 19 | ], 20 | "gt_1": "完善土地承包经营权流转市场" 21 | }, 22 | { 23 | "out_list_2": [ 24 | "临长各类设施使用年限", 25 | "严长各类设施使用年限", 26 | "延长各类设施使用年限", 27 | "很长各类设施使用年限", 28 | "难长各类设施使用年限" 29 | ], 30 | "gt_2": "延长各类设施使用年限" 31 | }, 32 | { 33 | "out_list_3": [ 34 | "苹果此举是为了节约用电量", 35 | "苹果此举是是了节约用电量", 36 | "苹果此举是为了解约用电量", 37 | "苹果此举是为了节约用电令", 38 | "苹果此举只为了节约用电量" 39 | ], 40 | "gt_3": "苹果此举是为了节约用电量" 41 | }, 42 | { 43 | "out_list_4": [ 44 | "反他们也可以有机会参与体育运动", 45 | "让他们也可以有机会参与体育运动", 46 | "反她们也可以有机会参与体育运动", 47 | "范他们也可以有机会参与体育运动", 48 | "但他们也可以有机会参与体育运动" 49 | ], 50 | "gt_4": "让他们也可以有机会参与体育运动" 51 | }, 52 | { 53 | "out_list_5": [ 54 | "陈言希穿着粉色上衣", 55 | "陈闫希穿着粉色上衣", 56 | "陈延希穿着粉色上衣", 57 | "陈言琪穿着粉色上衣", 58 | "陈演希穿着粉色上衣" 59 | ], 60 | "gt_5": "陈妍希穿着粉色上衣" 61 | }, 62 | { 63 | "out_list_6": [ 64 | "说起自己的伴女大下", 65 | "说起自己的伴理大下", 66 | "说起自己的半女大下", 67 | "说起自己的办女大下", 68 | "说起自己的半理大下" 69 | ], 70 | "gt_6": "说起自己的伴侣大侠" 71 | }, 72 | { 73 | "out_list_7": [ 74 | "每日经济新闻记者注意到", 75 | "每日经济新闻记者朱意到", 76 | "每日经济新闻记者注一到", 77 | "每日经济新闻记者注注到", 78 | "每日经济新闻记者注以到" 79 | ], 80 | "gt_7": "每日经济新闻记者注意到" 81 | }, 82 | { 83 | "out_list_8": [ 84 | "这是今年五月份以来库存环比增幅幅小了一次", 85 | "这是今年五月份以来库存环比增幅最小了一次", 86 | "这是今年五月份以来库存环比增幅幅小的一次", 87 | "这是今年五月份以来库存环比增幅最小的一次", 88 | "这是今年五月份以来库存环比增幅幅小小一次" 89 | ], 90 | "gt_8": "这是今年五月份以来库存环比增幅最小的一次" 91 | }, 92 | { 93 | "out_list_9": [ 94 | "一个人的精使生命就将走向摔老", 95 | "一个连的精使生命就将走向摔老", 96 | "一个人的金使生命就将走向摔老", 97 | "一个人的坚使生命就将走向摔老", 98 | "一个连的金使生命就将走向摔老" 99 | ], 100 | "gt_9": "一个人的精神生命就将走向衰老" 101 | } 102 | ] -------------------------------------------------------------------------------- /specAugment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/specAugment/__init__.py -------------------------------------------------------------------------------- /specAugment/sparse_image_warp_np.py: -------------------------------------------------------------------------------- 1 | """Image warping using sparse flow defined at control points.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import scipy as sp 8 | from scipy.interpolate import interp2d 9 | from skimage.transform import warp 10 | 11 | 12 | def _get_grid_locations(image_height, image_width): 13 | """Wrapper for np.meshgrid.""" 14 | 15 | y_range = np.linspace(0, image_height - 1, image_height) 16 | x_range = np.linspace(0, image_width - 1, image_width) 17 | y_grid, 
x_grid = np.meshgrid(y_range, x_range, indexing='ij') 18 | return np.stack((y_grid, x_grid), -1) 19 | 20 | 21 | def _expand_to_minibatch(np_array, batch_size): 22 | """Tile arbitrarily-sized np_array to include new batch dimension.""" 23 | tiles = [batch_size] + [1] * np_array.ndim 24 | return np.tile(np.expand_dims(np_array, 0), tiles) 25 | 26 | 27 | def _get_boundary_locations(image_height, image_width, num_points_per_edge): 28 | """Compute evenly-spaced indices along edge of image.""" 29 | y_range = np.linspace(0, image_height - 1, num_points_per_edge + 2) 30 | x_range = np.linspace(0, image_width - 1, num_points_per_edge + 2) 31 | ys, xs = np.meshgrid(y_range, x_range, indexing='ij') 32 | is_boundary = np.logical_or( 33 | np.logical_or(xs == 0, xs == image_width - 1), 34 | np.logical_or(ys == 0, ys == image_height - 1)) 35 | return np.stack([ys[is_boundary], xs[is_boundary]], axis=-1) 36 | 37 | 38 | def _add_zero_flow_controls_at_boundary(control_point_locations, 39 | control_point_flows, image_height, 40 | image_width, boundary_points_per_edge): 41 | # batch_size = tensor_shape.dimension_value(control_point_locations.shape[0]) 42 | batch_size = control_point_locations.shape[0] 43 | 44 | boundary_point_locations = _get_boundary_locations(image_height, image_width, 45 | boundary_points_per_edge) 46 | 47 | boundary_point_flows = np.zeros([boundary_point_locations.shape[0], 2]) 48 | 49 | type_to_use = control_point_locations.dtype 50 | # boundary_point_locations = constant_op.constant( 51 | # _expand_to_minibatch(boundary_point_locations, batch_size), 52 | # dtype=type_to_use) 53 | boundary_point_locations = _expand_to_minibatch(boundary_point_locations, batch_size) 54 | 55 | # boundary_point_flows = constant_op.constant( 56 | # _expand_to_minibatch(boundary_point_flows, batch_size), dtype=type_to_use) 57 | boundary_point_flows = _expand_to_minibatch(boundary_point_flows, batch_size) 58 | 59 | # merged_control_point_locations = array_ops.concat( 60 | # [control_point_locations, boundary_point_locations], 1) 61 | 62 | merged_control_point_locations = np.concatenate( 63 | [control_point_locations, boundary_point_locations], 1) 64 | 65 | # merged_control_point_flows = array_ops.concat( 66 | # [control_point_flows, boundary_point_flows], 1) 67 | 68 | merged_control_point_flows = np.concatenate( 69 | [control_point_flows, boundary_point_flows], 1) 70 | 71 | return merged_control_point_locations, merged_control_point_flows 72 | 73 | 74 | def sparse_image_warp_np(image, 75 | source_control_point_locations, 76 | dest_control_point_locations, 77 | interpolation_order=2, 78 | regularization_weight=0.0, 79 | num_boundary_points=0): 80 | # image = ops.convert_to_tensor(image) 81 | # source_control_point_locations = ops.convert_to_tensor( 82 | # source_control_point_locations) 83 | # dest_control_point_locations = ops.convert_to_tensor( 84 | # dest_control_point_locations) 85 | 86 | control_point_flows = ( 87 | dest_control_point_locations - source_control_point_locations) 88 | 89 | clamp_boundaries = num_boundary_points > 0 90 | boundary_points_per_edge = num_boundary_points - 1 91 | 92 | # batch_size, image_height, image_width, _ = image.get_shape().as_list() 93 | batch_size, image_height, image_width, _ = list(image.shape) 94 | 95 | # This generates the dense locations where the interpolant 96 | # will be evaluated. 
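    # grid_locations has shape (image_height, image_width, 2): one (y, x) coordinate per pixel.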
97 | 98 | grid_locations = _get_grid_locations(image_height, image_width) 99 | 100 | flattened_grid_locations = np.reshape(grid_locations, 101 | [image_height * image_width, 2]) 102 | 103 | # flattened_grid_locations = constant_op.constant( 104 | # _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype) 105 | 106 | flattened_grid_locations = _expand_to_minibatch(flattened_grid_locations, batch_size) 107 | 108 | if clamp_boundaries: 109 | (dest_control_point_locations, 110 | control_point_flows) = _add_zero_flow_controls_at_boundary( 111 | dest_control_point_locations, control_point_flows, image_height, 112 | image_width, boundary_points_per_edge) 113 | 114 | # flattened_flows = interpolate_spline.interpolate_spline( 115 | # dest_control_point_locations, control_point_flows, 116 | # flattened_grid_locations, interpolation_order, regularization_weight) 117 | flattened_flows = sp.interpolate.spline( 118 | dest_control_point_locations, control_point_flows, 119 | flattened_grid_locations, interpolation_order, regularization_weight) 120 | 121 | # dense_flows = array_ops.reshape(flattened_flows, 122 | # [batch_size, image_height, image_width, 2]) 123 | dense_flows = np.reshape(flattened_flows, 124 | [batch_size, image_height, image_width, 2]) 125 | 126 | # warped_image = dense_image_warp.dense_image_warp(image, dense_flows) 127 | warped_image = warp(image, dense_flows) 128 | 129 | return warped_image, dense_flows 130 | 131 | 132 | def dense_image_warp(image, flow): 133 | # batch_size, height, width, channels = (array_ops.shape(image)[0], 134 | # array_ops.shape(image)[1], 135 | # array_ops.shape(image)[2], 136 | # array_ops.shape(image)[3]) 137 | batch_size, height, width, channels = (np.shape(image)[0], 138 | np.shape(image)[1], 139 | np.shape(image)[2], 140 | np.shape(image)[3]) 141 | 142 | # The flow is defined on the image grid. Turn the flow into a list of query 143 | # points in the grid space. 144 | # grid_x, grid_y = array_ops.meshgrid( 145 | # math_ops.range(width), math_ops.range(height)) 146 | # stacked_grid = math_ops.cast( 147 | # array_ops.stack([grid_y, grid_x], axis=2), flow.dtype) 148 | # batched_grid = array_ops.expand_dims(stacked_grid, axis=0) 149 | # query_points_on_grid = batched_grid - flow 150 | # query_points_flattened = array_ops.reshape(query_points_on_grid, 151 | # [batch_size, height * width, 2]) 152 | grid_x, grid_y = np.meshgrid( 153 | np.range(width), np.range(height)) 154 | stacked_grid = np.cast( 155 | np.stack([grid_y, grid_x], axis=2), flow.dtype) 156 | batched_grid = np.expand_dims(stacked_grid, axis=0) 157 | query_points_on_grid = batched_grid - flow 158 | query_points_flattened = np.reshape(query_points_on_grid, 159 | [batch_size, height * width, 2]) 160 | # Compute values at the query points, then reshape the result back to the 161 | # image grid. 162 | interpolated = interp2d(image, query_points_flattened) 163 | interpolated = np.reshape(interpolated, 164 | [batch_size, height, width, channels]) 165 | return interpolated 166 | -------------------------------------------------------------------------------- /specAugment/sparse_image_warp_pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 RnD at Spoon Radio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # import torch 16 | # import numpy as np 17 | # from torch.autograd import Variable 18 | # import librosa 19 | import random 20 | 21 | import numpy as np 22 | # import scipy.signal 23 | import torch 24 | 25 | 26 | # import torchaudio 27 | # from torchaudio import transforms 28 | # import math 29 | # from torch.utils.data import DataLoader 30 | # from torch.utils.data import Dataset 31 | 32 | 33 | def time_warp(spec, W=5): 34 | spec = spec.view(1, spec.shape[0], spec.shape[1]) 35 | num_rows = spec.shape[1] 36 | spec_len = spec.shape[2] 37 | 38 | y = num_rows // 2 39 | horizontal_line_at_ctr = spec[0][y] 40 | assert len(horizontal_line_at_ctr) == spec_len 41 | 42 | point_to_warp = horizontal_line_at_ctr[random.randrange(W, spec_len - W)] 43 | assert isinstance(point_to_warp, torch.Tensor) 44 | 45 | # Uniform distribution from (0,W) with chance to be up to W negative 46 | dist_to_warp = random.randrange(-W, W) 47 | src_pts, dest_pts = torch.tensor([[[y, point_to_warp]]]), torch.tensor([[[y, point_to_warp + dist_to_warp]]]) 48 | warped_spectro, dense_flows = SparseImageWarp.sparse_image_warp(spec, src_pts, dest_pts) 49 | return warped_spectro.squeeze(3) 50 | 51 | 52 | def freq_mask(spec, F=15, num_masks=1, replace_with_zero=False): 53 | cloned = spec.clone() 54 | num_mel_channels = cloned.shape[1] 55 | 56 | for i in range(0, num_masks): 57 | f = random.randrange(0, F) 58 | f_zero = random.randrange(0, num_mel_channels - f) 59 | 60 | # avoids randrange error if values are equal and range is empty 61 | if (f_zero == f_zero + f): return cloned 62 | 63 | mask_end = random.randrange(f_zero, f_zero + f) 64 | if (replace_with_zero): 65 | cloned[0][f_zero:mask_end] = 0 66 | else: 67 | cloned[0][f_zero:mask_end] = cloned.mean() 68 | 69 | return cloned 70 | 71 | 72 | def time_mask(spec, T=15, num_masks=1, replace_with_zero=False): 73 | cloned = spec.clone() 74 | len_spectro = cloned.shape[2] 75 | 76 | for i in range(0, num_masks): 77 | t = random.randrange(0, T) 78 | t_zero = random.randrange(0, len_spectro - t) 79 | 80 | # avoids randrange error if values are equal and range is empty 81 | if (t_zero == t_zero + t): return cloned 82 | 83 | mask_end = random.randrange(t_zero, t_zero + t) 84 | if (replace_with_zero): 85 | cloned[0][:, t_zero:mask_end] = 0 86 | else: 87 | cloned[0][:, t_zero:mask_end] = cloned.mean() 88 | return cloned 89 | 90 | 91 | def sparse_image_warp(img_tensor, 92 | source_control_point_locations, 93 | dest_control_point_locations, 94 | interpolation_order=2, 95 | regularization_weight=0.0, 96 | num_boundaries_points=0): 97 | control_point_flows = (dest_control_point_locations - source_control_point_locations) 98 | 99 | batch_size, image_height, image_width = img_tensor.shape 100 | grid_locations = get_grid_locations(image_height, image_width) 101 | flattened_grid_locations = torch.tensor(flatten_grid_locations(grid_locations, image_height, image_width)) 102 | 103 | flattened_flows = interpolate_spline( 104 | dest_control_point_locations, 105 | 
control_point_flows, 106 | flattened_grid_locations, 107 | interpolation_order, 108 | regularization_weight) 109 | 110 | dense_flows = create_dense_flows(flattened_flows, batch_size, image_height, image_width) 111 | 112 | warped_image = dense_image_warp(img_tensor, dense_flows) 113 | 114 | return warped_image, dense_flows 115 | 116 | 117 | def get_grid_locations(image_height, image_width): 118 | """Wrapper for np.meshgrid.""" 119 | 120 | y_range = np.linspace(0, image_height - 1, image_height) 121 | x_range = np.linspace(0, image_width - 1, image_width) 122 | y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij') 123 | return np.stack((y_grid, x_grid), -1) 124 | 125 | 126 | def flatten_grid_locations(grid_locations, image_height, image_width): 127 | return np.reshape(grid_locations, [image_height * image_width, 2]) 128 | 129 | 130 | def create_dense_flows(flattened_flows, batch_size, image_height, image_width): 131 | # possibly .view 132 | return torch.reshape(flattened_flows, [batch_size, image_height, image_width, 2]) 133 | 134 | 135 | def interpolate_spline(train_points, train_values, query_points, order, regularization_weight=0.0, ): 136 | # First, fit the spline to the observed data. 137 | w, v = solve_interpolation(train_points, train_values, order, regularization_weight) 138 | # Then, evaluate the spline at the query locations. 139 | query_values = apply_interpolation(query_points, train_points, w, v, order) 140 | 141 | return query_values 142 | 143 | 144 | def solve_interpolation(train_points, train_values, order, regularization_weight): 145 | b, n, d = train_points.shape 146 | k = train_values.shape[-1] 147 | 148 | # First, rename variables so that the notation (c, f, w, v, A, B, etc.) 149 | # follows https://en.wikipedia.org/wiki/Polyharmonic_spline. 150 | # To account for python style guidelines we use 151 | # matrix_a for A and matrix_b for B. 152 | 153 | c = train_points 154 | f = train_values.float() 155 | 156 | matrix_a = phi(cross_squared_distance_matrix(c, c), order).unsqueeze(0) # [b, n, n] 157 | # if regularization_weight > 0: 158 | # batch_identity_matrix = array_ops.expand_dims( 159 | # linalg_ops.eye(n, dtype=c.dtype), 0) 160 | # matrix_a += regularization_weight * batch_identity_matrix 161 | 162 | # Append ones to the feature values for the bias term in the linear model. 163 | ones = torch.ones(1, dtype=train_points.dtype).view([-1, 1, 1]) 164 | matrix_b = torch.cat((c, ones), 2).float() # [b, n, d + 1] 165 | 166 | # [b, n + d + 1, n] 167 | left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1) 168 | 169 | num_b_cols = matrix_b.shape[2] # d + 1 170 | 171 | # In Tensorflow, zeros are used here. Pytorch gesv fails with zeros for some reason we don't understand. 172 | # So instead we use very tiny randn values (variance of one, zero mean) on one side of our multiplication. 173 | lhs_zeros = torch.randn((b, num_b_cols, num_b_cols)) / 1e10 174 | right_block = torch.cat((matrix_b, lhs_zeros), 175 | 1) # [b, n + d + 1, d + 1] 176 | lhs = torch.cat((left_block, right_block), 177 | 2) # [b, n + d + 1, n + d + 1] 178 | 179 | rhs_zeros = torch.zeros((b, d + 1, k), dtype=train_points.dtype).float() 180 | rhs = torch.cat((f, rhs_zeros), 1) # [b, n + d + 1, k] 181 | 182 | # Then, solve the linear system and unpack the results. 
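    # (torch.solve exists in the PyTorch 1.3.0 pinned by this repo but was removed in later
    #  releases; torch.linalg.solve(lhs, rhs) is the modern equivalent, with swapped argument order.)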
183 | X, LU = torch.solve(rhs, lhs) 184 | w = X[:, :n, :] 185 | v = X[:, n:, :] 186 | 187 | return w, v 188 | 189 | 190 | def cross_squared_distance_matrix(x, y): 191 | """Pairwise squared distance between two (batch) matrices' rows (2nd dim). 192 | Computes the pairwise distances between rows of x and rows of y 193 | Args: 194 | x: [batch_size, n, d] float `Tensor` 195 | y: [batch_size, m, d] float `Tensor` 196 | Returns: 197 | squared_dists: [batch_size, n, m] float `Tensor`, where 198 | squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2 199 | """ 200 | x_norm_squared = torch.sum(torch.mul(x, x)) 201 | y_norm_squared = torch.sum(torch.mul(y, y)) 202 | 203 | x_y_transpose = torch.matmul(x.squeeze(0), y.squeeze(0).transpose(0, 1)) 204 | 205 | # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj 206 | squared_dists = x_norm_squared - 2 * x_y_transpose + y_norm_squared 207 | 208 | return squared_dists.float() 209 | 210 | 211 | def phi(r, order): 212 | """Coordinate-wise nonlinearity used to define the order of the interpolation. 213 | See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition. 214 | Args: 215 | r: input op 216 | order: interpolation order 217 | Returns: 218 | phi_k evaluated coordinate-wise on r, for k = r 219 | """ 220 | EPSILON = torch.tensor(1e-10) 221 | # using EPSILON prevents log(0), sqrt0), etc. 222 | # sqrt(0) is well-defined, but its gradient is not 223 | if order == 1: 224 | r = torch.max(r, EPSILON) 225 | r = torch.sqrt(r) 226 | return r 227 | elif order == 2: 228 | return 0.5 * r * torch.log(torch.max(r, EPSILON)) 229 | elif order == 4: 230 | return 0.5 * torch.square(r) * torch.log(torch.max(r, EPSILON)) 231 | elif order % 2 == 0: 232 | r = torch.max(r, EPSILON) 233 | return 0.5 * torch.pow(r, 0.5 * order) * torch.log(r) 234 | else: 235 | r = torch.max(r, EPSILON) 236 | return torch.pow(r, 0.5 * order) 237 | 238 | 239 | def apply_interpolation(query_points, train_points, w, v, order): 240 | """Apply polyharmonic interpolation model to data. 241 | Given coefficients w and v for the interpolation model, we evaluate 242 | interpolated function values at query_points. 243 | Args: 244 | query_points: `[b, m, d]` x values to evaluate the interpolation at 245 | train_points: `[b, n, d]` x values that act as the interpolation centers 246 | ( the c variables in the wikipedia article) 247 | w: `[b, n, k]` weights on each interpolation center 248 | v: `[b, d, k]` weights on each input dimension 249 | order: order of the interpolation 250 | Returns: 251 | Polyharmonic interpolation evaluated at points defined in query_points. 252 | """ 253 | query_points = query_points.unsqueeze(0) 254 | # First, compute the contribution from the rbf term. 255 | pairwise_dists = cross_squared_distance_matrix(query_points.float(), train_points.float()) 256 | phi_pairwise_dists = phi(pairwise_dists, order) 257 | 258 | rbf_term = torch.matmul(phi_pairwise_dists, w) 259 | 260 | # Then, compute the contribution from the linear term. 261 | # Pad query_points with ones, for the bias term in the linear model. 262 | ones = torch.ones_like(query_points[..., :1]) 263 | query_points_pad = torch.cat(( 264 | query_points, 265 | ones 266 | ), 2).float() 267 | linear_term = torch.matmul(query_points_pad, v) 268 | 269 | return rbf_term + linear_term 270 | 271 | 272 | def dense_image_warp(image, flow): 273 | """Image warping using per-pixel flow vectors. 
274 | Apply a non-linear warp to the image, where the warp is specified by a dense 275 | flow field of offset vectors that define the correspondences of pixel values 276 | in the output image back to locations in the source image. Specifically, the 277 | pixel value at output[b, j, i, c] is 278 | images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]. 279 | The locations specified by this formula do not necessarily map to an int 280 | index. Therefore, the pixel value is obtained by bilinear 281 | interpolation of the 4 nearest pixels around 282 | (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside 283 | of the image, we use the nearest pixel values at the image boundary. 284 | Args: 285 | image: 4-D float `Tensor` with shape `[batch, height, width, channels]`. 286 | flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`. 287 | name: A name for the operation (optional). 288 | Note that image and flow can be of type tf.half, tf.float32, or tf.float64, 289 | and do not necessarily have to be the same type. 290 | Returns: 291 | A 4-D float `Tensor` with shape`[batch, height, width, channels]` 292 | and same type as input image. 293 | Raises: 294 | ValueError: if height < 2 or width < 2 or the inputs have the wrong number 295 | of dimensions. 296 | """ 297 | image = image.unsqueeze(3) # add a single channel dimension to image tensor 298 | batch_size, height, width, channels = image.shape 299 | 300 | # The flow is defined on the image grid. Turn the flow into a list of query 301 | # points in the grid space. 302 | grid_x, grid_y = torch.meshgrid( 303 | torch.arange(width), torch.arange(height)) 304 | 305 | stacked_grid = torch.stack((grid_y, grid_x), dim=2).float() 306 | 307 | batched_grid = stacked_grid.unsqueeze(-1).permute(3, 1, 0, 2) 308 | 309 | query_points_on_grid = batched_grid - flow 310 | query_points_flattened = torch.reshape(query_points_on_grid, 311 | [batch_size, height * width, 2]) 312 | # Compute values at the query points, then reshape the result back to the 313 | # image grid. 314 | interpolated = interpolate_bilinear(image, query_points_flattened) 315 | interpolated = torch.reshape(interpolated, 316 | [batch_size, height, width, channels]) 317 | return interpolated 318 | 319 | 320 | def interpolate_bilinear(grid, 321 | query_points, 322 | name='interpolate_bilinear', 323 | indexing='ij'): 324 | """Similar to Matlab's interp2 function. 325 | Finds values for query points on a grid using bilinear interpolation. 326 | Args: 327 | grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`. 328 | query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`. 329 | name: a name for the operation (optional). 330 | indexing: whether the query points are specified as row and column (ij), 331 | or Cartesian coordinates (xy). 332 | Returns: 333 | values: a 3-D `Tensor` with shape `[batch, N, channels]` 334 | Raises: 335 | ValueError: if the indexing mode is invalid, or if the shape of the inputs 336 | invalid. 337 | """ 338 | if indexing != 'ij' and indexing != 'xy': 339 | raise ValueError('Indexing mode must be \'ij\' or \'xy\'') 340 | 341 | shape = grid.shape 342 | if len(shape) != 4: 343 | msg = 'Grid must be 4 dimensional. 
Received size: ' 344 | raise ValueError(msg + str(grid.shape)) 345 | 346 | batch_size, height, width, channels = grid.shape 347 | 348 | shape = [batch_size, height, width, channels] 349 | query_type = query_points.dtype 350 | grid_type = grid.dtype 351 | 352 | num_queries = query_points.shape[1] 353 | 354 | alphas = [] 355 | floors = [] 356 | ceils = [] 357 | index_order = [0, 1] if indexing == 'ij' else [1, 0] 358 | unstacked_query_points = query_points.unbind(2) 359 | 360 | for dim in index_order: 361 | queries = unstacked_query_points[dim] 362 | 363 | size_in_indexing_dimension = shape[dim + 1] 364 | 365 | # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1 366 | # is still a valid index into the grid. 367 | max_floor = torch.tensor(size_in_indexing_dimension - 2, dtype=query_type) 368 | min_floor = torch.tensor(0.0, dtype=query_type) 369 | maxx = torch.max(min_floor, torch.floor(queries)) 370 | floor = torch.min(maxx, max_floor) 371 | int_floor = floor.long() 372 | floors.append(int_floor) 373 | ceil = int_floor + 1 374 | ceils.append(ceil) 375 | 376 | # alpha has the same type as the grid, as we will directly use alpha 377 | # when taking linear combinations of pixel values from the image. 378 | alpha = torch.tensor(queries - floor, dtype=grid_type) 379 | min_alpha = torch.tensor(0.0, dtype=grid_type) 380 | max_alpha = torch.tensor(1.0, dtype=grid_type) 381 | alpha = torch.min(torch.max(min_alpha, alpha), max_alpha) 382 | 383 | # Expand alpha to [b, n, 1] so we can use broadcasting 384 | # (since the alpha values don't depend on the channel). 385 | alpha = torch.unsqueeze(alpha, 2) 386 | alphas.append(alpha) 387 | 388 | flattened_grid = torch.reshape( 389 | grid, [batch_size * height * width, channels]) 390 | batch_offsets = torch.reshape( 391 | torch.arange(batch_size) * height * width, [batch_size, 1]) 392 | 393 | # This wraps array_ops.gather. We reshape the image data such that the 394 | # batch, y, and x coordinates are pulled into the first dimension. 395 | # Then we gather. Finally, we reshape the output back. It's possible this 396 | # code would be made simpler by using array_ops.gather_nd. 397 | def gather(y_coords, x_coords, name): 398 | linear_coordinates = batch_offsets + y_coords * width + x_coords 399 | gathered_values = torch.gather(flattened_grid.t(), 1, linear_coordinates) 400 | return torch.reshape(gathered_values, 401 | [batch_size, num_queries, channels]) 402 | 403 | # grab the pixel values in the 4 corners around each query point 404 | top_left = gather(floors[0], floors[1], 'top_left') 405 | top_right = gather(floors[0], ceils[1], 'top_right') 406 | bottom_left = gather(ceils[0], floors[1], 'bottom_left') 407 | bottom_right = gather(ceils[0], ceils[1], 'bottom_right') 408 | 409 | interp_top = alphas[1] * (top_right - top_left) + top_left 410 | interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left 411 | interp = alphas[0] * (interp_bottom - interp_top) + interp_top 412 | 413 | return interp 414 | -------------------------------------------------------------------------------- /specAugment/spec_augment_pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 RnD at Spoon Radio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """SpecAugment implementation for PyTorch.
16 | Related paper : https://arxiv.org/pdf/1904.08779.pdf
17 | The paper summarizes the augmentation parameters for each open dataset in Table 1.
18 | -----------------------------------------
19 | Policy | W | F | m_F | T | p | m_T
20 | -----------------------------------------
21 | None | 0 | 0 | - | 0 | - | -
22 | -----------------------------------------
23 | LB | 80 | 27 | 1 | 100 | 1.0 | 1
24 | -----------------------------------------
25 | LD | 80 | 27 | 2 | 100 | 1.0 | 2
26 | -----------------------------------------
27 | SM | 40 | 15 | 2 | 70 | 0.2 | 2
28 | -----------------------------------------
29 | SS | 40 | 27 | 2 | 70 | 0.2 | 2
30 | -----------------------------------------
31 | LB : LibriSpeech basic
32 | LD : LibriSpeech double
33 | SM : Switchboard mild
34 | SS : Switchboard strong
35 | """
36 | 
37 | import random
38 | 
39 | import librosa
40 | import librosa.display
41 | import matplotlib
42 | import numpy as np
43 | 
44 | matplotlib.use('TkAgg')
45 | import matplotlib.pyplot as plt
46 | from specAugment.sparse_image_warp_np import sparse_image_warp_np
47 | import torch
48 | 
49 | 
50 | def time_warp(spec, W=5):
51 |     num_rows = spec.shape[1]
52 |     spec_len = spec.shape[2]
53 | 
54 |     y = num_rows // 2
55 |     horizontal_line_at_ctr = spec[0][y]
56 |     # assert len(horizontal_line_at_ctr) == spec_len
57 | 
58 |     point_to_warp = horizontal_line_at_ctr[random.randrange(W, spec_len - W)]
59 |     # assert isinstance(point_to_warp, torch.Tensor)
60 | 
61 |     # Random warp distance drawn from [-W, W)
62 |     dist_to_warp = random.randrange(-W, W)
63 |     src_pts = torch.tensor([[[y, point_to_warp]]])
64 |     dest_pts = torch.tensor([[[y, point_to_warp + dist_to_warp]]])
65 |     warped_spectro, dense_flows = sparse_image_warp_np(spec, src_pts, dest_pts)
66 |     return warped_spectro.squeeze(3)
67 | 
68 | 
69 | def spec_augment(mel_spectrogram, time_warping_para=80, frequency_masking_para=27,
70 |                  time_masking_para=100, frequency_mask_num=1, time_mask_num=1):
71 |     """Spec augmentation calculation function.
72 |     'SpecAugment' has 3 steps for audio data augmentation.
73 |     The first step is time warping via sparse_image_warp_np.
74 |     The second step is frequency masking, the last step is time masking.
75 |     # Arguments:
76 |       mel_spectrogram(numpy array): mel spectrogram to be warped and masked, shaped (batch, frequency, time).
77 |       time_warping_para(float): Augmentation parameter, "time warp parameter W".
78 |         If none, default = 80 for LibriSpeech.
79 |       frequency_masking_para(float): Augmentation parameter, "frequency mask parameter F".
80 |         If none, default = 27 for LibriSpeech.
81 |       time_masking_para(float): Augmentation parameter, "time mask parameter T".
82 |         If none, default = 100 for LibriSpeech.
83 |       frequency_mask_num(float): number of frequency masking lines, "m_F".
84 |         If none, default = 1 for LibriSpeech.
85 |       time_mask_num(float): number of time masking lines, "m_T".
86 |         If none, default = 1 for LibriSpeech.
87 | # Returns 88 | mel_spectrogram(numpy array): warped and masked mel spectrogram. 89 | """ 90 | v = mel_spectrogram.shape[1] 91 | tau = mel_spectrogram.shape[2] 92 | 93 | # Step 1 : Time warping 94 | warped_mel_spectrogram = time_warp(mel_spectrogram) 95 | 96 | # Step 2 : Frequency masking 97 | for i in range(frequency_mask_num): 98 | f = np.random.uniform(low=0.0, high=frequency_masking_para) 99 | f = int(f) 100 | f0 = random.randint(0, v - f) 101 | warped_mel_spectrogram[:, f0:f0 + f, :] = 0 102 | 103 | # Step 3 : Time masking 104 | for i in range(time_mask_num): 105 | t = np.random.uniform(low=0.0, high=time_masking_para) 106 | t = int(t) 107 | t0 = random.randint(0, tau - t) 108 | warped_mel_spectrogram[:, :, t0:t0 + t] = 0 109 | 110 | return warped_mel_spectrogram 111 | 112 | 113 | def visualization_spectrogram(mel_spectrogram, title): 114 | """visualizing result of specAugment 115 | # Arguments: 116 | mel_spectrogram(ndarray): mel_spectrogram to visualize. 117 | title(String): plot figure's title 118 | """ 119 | # Show mel-spectrogram using librosa's specshow. 120 | plt.figure(figsize=(10, 4)) 121 | librosa.display.specshow(librosa.power_to_db(mel_spectrogram[0, :, :], ref=np.max), y_axis='mel', fmax=8000, 122 | x_axis='time') 123 | # plt.colorbar(format='%+2.0f dB') 124 | plt.title(title) 125 | plt.tight_layout() 126 | plt.show() 127 | -------------------------------------------------------------------------------- /specAugment/spec_augment_tensorflow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 RnD at Spoon Radio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """SpecAugment Implementation for Tensorflow. 16 | Related paper : https://arxiv.org/pdf/1904.08779.pdf 17 | In this paper, show summarized parameters by each open datasets in Tabel 1. 18 | ----------------------------------------- 19 | Policy | W | F | m_F | T | p | m_T 20 | ----------------------------------------- 21 | None | 0 | 0 | - | 0 | - | - 22 | ----------------------------------------- 23 | LB | 80 | 27 | 1 | 100 | 1.0 | 1 24 | ----------------------------------------- 25 | LD | 80 | 27 | 2 | 100 | 1.0 | 2 26 | ----------------------------------------- 27 | SM | 40 | 15 | 2 | 70 | 0.2 | 2 28 | ----------------------------------------- 29 | SS | 40 | 27 | 2 | 70 | 0.2 | 2 30 | ----------------------------------------- 31 | LB : LibriSpeech basic 32 | LD : LibriSpeech double 33 | SM : Switchboard mild 34 | SS : Switchboard strong 35 | """ 36 | 37 | import librosa 38 | import librosa.display 39 | import matplotlib.pyplot as plt 40 | import numpy as np 41 | import tensorflow as tf 42 | from tensorflow.contrib.image import sparse_image_warp 43 | 44 | 45 | def sparse_warp(mel_spectrogram, time_warping_para=80): 46 | """Spec augmentation Calculation Function. 47 | 'SpecAugment' have 3 steps for audio data augmentation. 
48 |     The first step is time warping using TensorFlow's sparse_image_warp function.
49 |     The second step is frequency masking, the last step is time masking.
50 |     # Arguments:
51 |       mel_spectrogram(numpy array): mel spectrogram to be warped and masked.
52 |       time_warping_para(float): Augmentation parameter, "time warp parameter W".
53 |         If none, default = 80 for LibriSpeech.
54 |     # Returns
55 |       mel_spectrogram(numpy array): warped mel spectrogram.
56 |     """
57 | 
58 |     fbank_size = tf.shape(mel_spectrogram)
59 |     n, v = fbank_size[1], fbank_size[2]
60 | 
61 |     # Step 1 : Time warping
62 |     # Image warping control point setting.
63 |     # Source
64 |     pt = tf.random_uniform([], time_warping_para, n - time_warping_para, tf.int32)  # random point along the time axis
65 |     src_ctr_pt_freq = tf.range(v // 2)  # control points on freq-axis
66 |     src_ctr_pt_time = tf.ones_like(src_ctr_pt_freq) * pt  # control points on time-axis
67 |     src_ctr_pts = tf.stack((src_ctr_pt_time, src_ctr_pt_freq), -1)
68 |     src_ctr_pts = tf.to_float(src_ctr_pts)
69 | 
70 |     # Destination
71 |     w = tf.random_uniform([], -time_warping_para, time_warping_para, tf.int32)  # distance
72 |     dest_ctr_pt_freq = src_ctr_pt_freq
73 |     dest_ctr_pt_time = src_ctr_pt_time + w
74 |     dest_ctr_pts = tf.stack((dest_ctr_pt_time, dest_ctr_pt_freq), -1)
75 |     dest_ctr_pts = tf.to_float(dest_ctr_pts)
76 | 
77 |     # warp
78 |     source_control_point_locations = tf.expand_dims(src_ctr_pts, 0)  # (1, v//2, 2)
79 |     dest_control_point_locations = tf.expand_dims(dest_ctr_pts, 0)  # (1, v//2, 2)
80 | 
81 |     warped_image, _ = sparse_image_warp(mel_spectrogram,
82 |                                         source_control_point_locations,
83 |                                         dest_control_point_locations)
84 |     return warped_image
85 | 
86 | 
87 | def frequency_masking(mel_spectrogram, v, frequency_masking_para=27, frequency_mask_num=2):
88 |     """Spec augmentation calculation function (step 2: frequency masking).
89 |     'SpecAugment' has 3 steps for audio data augmentation.
90 |     The first step is time warping using TensorFlow's sparse_image_warp function.
91 |     The second step is frequency masking, the last step is time masking.
92 |     # Arguments:
93 |       mel_spectrogram(numpy array): mel spectrogram to be masked.
94 |       frequency_masking_para(float): Augmentation parameter, "frequency mask parameter F".
95 |         If none, default = 27 for LibriSpeech.
96 |       frequency_mask_num(float): number of frequency masking lines, "m_F".
97 |         If none, default = 1 for LibriSpeech.
98 |     # Returns
99 |       mel_spectrogram(numpy array): frequency-masked mel spectrogram.
100 |     """
101 |     # Step 2 : Frequency masking
102 |     fbank_size = tf.shape(mel_spectrogram)
103 |     n, v = fbank_size[1], fbank_size[2]
104 | 
105 |     for i in range(frequency_mask_num):
106 |         f = tf.random_uniform([], minval=0, maxval=frequency_masking_para, dtype=tf.int32)
107 |         v = tf.to_int32(v)
108 |         f0 = tf.random_uniform([], minval=0, maxval=v - f, dtype=tf.int32)
109 | 
110 |         # warped_mel_spectrogram[f0:f0 + f, :] = 0
111 |         mask = tf.concat((tf.ones(shape=(1, n, v - f0 - f, 1)),
112 |                           tf.zeros(shape=(1, n, f, 1)),
113 |                           tf.ones(shape=(1, n, f0, 1)),
114 |                           ), 2)
115 |         mel_spectrogram = mel_spectrogram * mask
116 |     return tf.to_float(mel_spectrogram)
117 | 
118 | 
119 | def time_masking(mel_spectrogram, tau, time_masking_para=100, time_mask_num=2):
120 |     """Spec augmentation calculation function (step 3: time masking).
121 |     'SpecAugment' has 3 steps for audio data augmentation.
122 |     The first step is time warping using TensorFlow's sparse_image_warp function.
123 |     The second step is frequency masking, the last step is time masking.
124 |     # Arguments:
125 |       mel_spectrogram(numpy array): mel spectrogram to be masked.
126 |       time_masking_para(float): Augmentation parameter, "time mask parameter T".
127 |         If none, default = 100 for LibriSpeech.
128 |       time_mask_num(float): number of time masking lines, "m_T".
129 |         If none, default = 1 for LibriSpeech.
130 |     # Returns
131 |       mel_spectrogram(numpy array): time-masked mel spectrogram.
132 |     """
133 |     fbank_size = tf.shape(mel_spectrogram)
134 |     n, v = fbank_size[1], fbank_size[2]
135 | 
136 |     # Step 3 : Time masking
137 |     for i in range(time_mask_num):
138 |         t = tf.random_uniform([], minval=0, maxval=time_masking_para, dtype=tf.int32)
139 |         t0 = tf.random_uniform([], minval=0, maxval=tau - t, dtype=tf.int32)
140 | 
141 |         # mel_spectrogram[:, t0:t0 + t] = 0
142 |         mask = tf.concat((tf.ones(shape=(1, n - t0 - t, v, 1)),
143 |                           tf.zeros(shape=(1, t, v, 1)),
144 |                           tf.ones(shape=(1, t0, v, 1)),
145 |                           ), 1)
146 |         mel_spectrogram = mel_spectrogram * mask
147 |     return tf.to_float(mel_spectrogram)
148 | 
149 | 
150 | def spec_augment(mel_spectrogram):
151 |     v = mel_spectrogram.shape[0]
152 |     tau = mel_spectrogram.shape[1]
153 | 
154 |     warped_mel_spectrogram = sparse_warp(mel_spectrogram)
155 | 
156 |     warped_frequency_spectrogram = frequency_masking(warped_mel_spectrogram, v=v)
157 | 
158 |     warped_frequency_time_spectrogram = time_masking(warped_frequency_spectrogram, tau=tau)
159 | 
160 |     return warped_frequency_time_spectrogram
161 | 
162 | 
163 | def visualization_spectrogram(mel_spectrogram, title):
164 |     """Visualize the first result of SpecAugment.
165 |     # Arguments:
166 |       mel_spectrogram(ndarray): mel_spectrogram to visualize.
167 |       title(String): plot figure's title
168 |     """
169 |     # Show mel-spectrogram using librosa's specshow.
170 |     plt.figure(figsize=(10, 4))
171 |     librosa.display.specshow(librosa.power_to_db(mel_spectrogram[0, :, :, 0], ref=np.max), y_axis='mel', fmax=8000,
172 |                              x_axis='time')
173 |     plt.title(title)
174 |     plt.tight_layout()
175 |     plt.show()
176 | 
177 | 
178 | def visualization_tensor_spectrogram(mel_spectrogram, title):
179 |     """Visualize the first result of SpecAugment from a TensorFlow tensor.
180 |     # Arguments:
181 |       mel_spectrogram(ndarray): mel_spectrogram to visualize.
182 |       title(String): plot figure's title
183 |     """
184 | 
185 |     # session for plotting
186 |     sess = tf.InteractiveSession()
187 |     mel_spectrogram = mel_spectrogram.eval()
188 | 
189 |     # Show mel-spectrogram using librosa's specshow.
190 | plt.figure(figsize=(10, 4)) 191 | librosa.display.specshow(librosa.power_to_db(mel_spectrogram[0, :, :, 0], ref=np.max), y_axis='mel', fmax=8000, 192 | x_axis='time') 193 | # plt.colorbar(format='%+2.0f dB') 194 | plt.title(title) 195 | plt.tight_layout() 196 | plt.show() 197 | -------------------------------------------------------------------------------- /sponsor.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/sponsor.jpg -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from config import pickle_file, device, input_dim, LFR_m, LFR_n, sos_id, eos_id 8 | from data_gen import build_LFR_features 9 | from utils import extract_feature 10 | from xer import cer_function 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser( 15 | "End-to-End Automatic Speech Recognition Decoding.") 16 | # decode 17 | parser.add_argument('--beam_size', default=5, type=int, 18 | help='Beam size') 19 | parser.add_argument('--nbest', default=1, type=int, 20 | help='Nbest size') 21 | parser.add_argument('--decode_max_len', default=100, type=int, 22 | help='Max output length. If ==0 (default), it uses a ' 23 | 'end-detect function to automatically find maximum ' 24 | 'hypothesis lengths') 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | if __name__ == '__main__': 30 | args = parse_args() 31 | with open(pickle_file, 'rb') as file: 32 | data = pickle.load(file) 33 | char_list = data['IVOCAB'] 34 | samples = data['test'] 35 | 36 | checkpoint = 'BEST_checkpoint.tar' 37 | checkpoint = torch.load(checkpoint, map_location='cpu') 38 | model = checkpoint['model'].to(device) 39 | model.eval() 40 | 41 | num_samples = len(samples) 42 | 43 | total_cer = 0 44 | 45 | for i in tqdm(range(num_samples)): 46 | sample = samples[i] 47 | wave = sample['wave'] 48 | trn = sample['trn'] 49 | 50 | feature = extract_feature(input_file=wave, feature='fbank', dim=input_dim, cmvn=True) 51 | feature = build_LFR_features(feature, m=LFR_m, n=LFR_n) 52 | # feature = np.expand_dims(feature, axis=0) 53 | input = torch.from_numpy(feature).to(device) 54 | input_length = [input[0].shape[0]] 55 | input_length = torch.LongTensor(input_length).to(device) 56 | with torch.no_grad(): 57 | nbest_hyps = model.recognize(input, input_length, char_list, args) 58 | 59 | hyp_list = [] 60 | for hyp in nbest_hyps: 61 | out = hyp['yseq'] 62 | out = [char_list[idx] for idx in out if idx not in (sos_id, eos_id)] 63 | out = ''.join(out) 64 | hyp_list.append(out) 65 | 66 | print(hyp_list) 67 | 68 | gt = [char_list[idx] for idx in trn if idx not in (sos_id, eos_id)] 69 | gt = ''.join(gt) 70 | gt_list = [gt] 71 | 72 | print(gt_list) 73 | 74 | cer = cer_function(gt_list, hyp_list) 75 | total_cer += cer 76 | 77 | avg_cer = total_cer / num_samples 78 | 79 | print('avg_cer: ' + str(avg_cer)) 80 | -------------------------------------------------------------------------------- /test/test_decode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | from transformer.decoder import Decoder 6 | from transformer.encoder import Encoder 7 | from transformer.transformer import Transformer 8 | 9 | if __name__ == "__main__": 10 
| D = 3 11 | beam_size = 5 12 | nbest = 5 13 | defaults = dict(beam_size=beam_size, 14 | nbest=nbest, 15 | decode_max_len=0, 16 | d_input=D, 17 | LFR_m=1, 18 | n_layers_enc=2, 19 | n_head=2, 20 | d_k=6, 21 | d_v=6, 22 | d_model=12, 23 | d_inner=8, 24 | dropout=0.1, 25 | pe_maxlen=100, 26 | d_word_vec=12, 27 | n_layers_dec=2, 28 | tgt_emb_prj_weight_sharing=1) 29 | args = argparse.Namespace(**defaults) 30 | char_list = ["a", "b", "c", "d", "e", "f", "g", "h", "", ""] 31 | sos_id, eos_id = 8, 9 32 | vocab_size = len(char_list) 33 | # model 34 | encoder = Encoder(args.d_input * args.LFR_m, args.n_layers_enc, args.n_head, 35 | args.d_k, args.d_v, args.d_model, args.d_inner, 36 | dropout=args.dropout, pe_maxlen=args.pe_maxlen) 37 | decoder = Decoder(sos_id, eos_id, vocab_size, 38 | args.d_word_vec, args.n_layers_dec, args.n_head, 39 | args.d_k, args.d_v, args.d_model, args.d_inner, 40 | dropout=args.dropout, 41 | tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing, 42 | pe_maxlen=args.pe_maxlen) 43 | model = Transformer(encoder, decoder) 44 | 45 | for i in range(3): 46 | print("\n***** Utt", i + 1) 47 | Ti = i + 20 48 | input = torch.randn(Ti, D) 49 | length = torch.tensor([Ti], dtype=torch.int) 50 | nbest_hyps = model.recognize(input, length, char_list, args) 51 | 52 | file_path = "./temp.pth" 53 | optimizer = torch.optim.Adam(model.parameters()) 54 | torch.save(model.serialize(model, optimizer, 1, LFR_m=1, LFR_n=1), file_path) 55 | model, LFR_m, LFR_n = Transformer.load_model(file_path) 56 | print(model) 57 | 58 | import os 59 | 60 | os.remove(file_path) 61 | -------------------------------------------------------------------------------- /test/test_lm.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import pickle 3 | 4 | import nltk 5 | from tqdm import tqdm 6 | 7 | from config import pickle_file 8 | 9 | with open(pickle_file, 'rb') as file: 10 | data = pickle.load(file) 11 | char_list = data['IVOCAB'] 12 | samples = data['train'] 13 | bigram_freq = collections.Counter() 14 | 15 | for sample in tqdm(samples): 16 | text = sample['trn'] 17 | text = [char_list[idx] for idx in text] 18 | tokens = list(text) 19 | bigrm = nltk.bigrams(tokens) 20 | # print(*map(' '.join, bigrm), sep=', ') 21 | 22 | # get the frequency of each bigram in our corpus 23 | bigram_freq.update(bigrm) 24 | 25 | # what are the ten most popular ngrams in this Spanish corpus? 26 | print(bigram_freq.most_common(10)) 27 | -------------------------------------------------------------------------------- /test/test_lr.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | if __name__ == '__main__': 4 | k = 0.2 5 | warmup_steps = 4000 6 | d_model = 512 7 | init_lr = d_model ** (-0.5) 8 | 9 | lr_list = [] 10 | for step_num in range(1, 500000): 11 | lr = k * init_lr * min(step_num ** (-0.5), 12 | step_num * (warmup_steps ** (-1.5))) 13 | lr_list.append(lr) 14 | 15 | print(lr_list[:100]) 16 | print(lr_list[-100:]) 17 | 18 | plt.plot(lr_list) 19 | plt.show() 20 | -------------------------------------------------------------------------------- /test/test_pe.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class PositionalEncoding(nn.Module): 8 | """Implement the positional encoding (PE) function. 
9 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) 10 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) 11 | """ 12 | 13 | def __init__(self, d_model, max_len=5000): 14 | super(PositionalEncoding, self).__init__() 15 | # Compute the positional encodings once in log space. 16 | pe = torch.zeros(max_len, d_model, requires_grad=False) 17 | position = torch.arange(0, max_len).unsqueeze(1).float() 18 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 19 | -(math.log(10000.0) / d_model)) 20 | pe[:, 0::2] = torch.sin(position * div_term) 21 | pe[:, 1::2] = torch.cos(position * div_term) 22 | pe = pe.unsqueeze(0) 23 | self.register_buffer('pe', pe) 24 | 25 | def forward(self, input): 26 | """ 27 | Args: 28 | input: N x T x D 29 | """ 30 | length = input.size(1) 31 | return self.pe[:, :length] 32 | 33 | 34 | if __name__ == '__main__': 35 | import numpy as np 36 | import matplotlib.pyplot as plt 37 | 38 | d_model = 512 39 | max_len = 5000 40 | pe = PositionalEncoding(d_model, max_len) 41 | mat = pe.pe.numpy()[0] # (5000, 512) 42 | mat = np.transpose(mat, (1, 0)) 43 | print(mat.shape) 44 | print(mat) 45 | plt.imshow(mat) 46 | plt.colorbar() 47 | plt.show() 48 | -------------------------------------------------------------------------------- /test/test_specaug.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | from data_gen import spec_augment, build_LFR_features 4 | from utils import extract_feature 5 | 6 | LFR_m = 4 7 | LFR_n = 3 8 | 9 | 10 | def plot_data(data, figsize=(16, 4)): 11 | fig, axes = plt.subplots(1, len(data), figsize=figsize) 12 | for i in range(len(data)): 13 | axes[i].imshow(data[i], aspect='auto', origin='bottom', 14 | interpolation='none') 15 | 16 | 17 | if __name__ == '__main__': 18 | path = '../audios/audio_0.wav' 19 | feature = extract_feature(input_file=path, feature='fbank', dim=80, cmvn=True) 20 | feature = build_LFR_features(feature, m=LFR_m, n=LFR_n) 21 | 22 | # zero mean and unit variance 23 | feature = (feature - feature.mean()) / feature.std() 24 | 25 | feature_1 = spec_augment(feature) 26 | # 27 | 28 | plot_data((feature.transpose(), feature_1.transpose())) 29 | plt.show() 30 | -------------------------------------------------------------------------------- /test/test_trim.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import soundfile 4 | 5 | from utils import normalize 6 | 7 | sampling_rate = 16000 8 | top_db = 20 9 | reduced_ratios = [] 10 | 11 | for i in range(10): 12 | audiopath = '../audios/audio_{}.wav'.format(i) 13 | print(audiopath) 14 | y, sr = librosa.load(audiopath) 15 | # Trim the beginning and ending silence 16 | yt, index = librosa.effects.trim(y, top_db=top_db) 17 | yt = normalize(yt) 18 | 19 | reduced_ratios.append(len(yt) / len(y)) 20 | 21 | # Print the durations 22 | print(librosa.get_duration(y), librosa.get_duration(yt)) 23 | print(len(y), len(yt)) 24 | target = '../audios/trimed_{}.wav'.format(i) 25 | soundfile.write(target, yt, sampling_rate) 26 | 27 | print('\nreduced_ratio: ' + str(100 - 100 * np.mean(reduced_ratios))) 28 | -------------------------------------------------------------------------------- /test_lm.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import numpy as np 4 | 5 | from config import pickle_file, sos_id, eos_id 6 | 7 | print('loading {}...'.format(pickle_file)) 8 | with open(pickle_file, 'rb') as 
file: 9 | data = pickle.load(file) 10 | VOCAB = data['VOCAB'] 11 | IVOCAB = data['IVOCAB'] 12 | 13 | print('loading bigram_freq.pkl...') 14 | with open('bigram_freq.pkl', 'rb') as file: 15 | bigram_freq = pickle.load(file) 16 | 17 | OUT_LIST = ['比赛很快便城像一边到的局面第二规合', '比赛很快便呈像一边到的局面第二规合', '比赛很快便城向一边到的局面第二规合', 18 | '比赛很快便呈向一边到的局面第二规合', '比赛很快便城像一边到的局面第二回合'] 19 | GT = '比赛很快便呈向一边倒的局面第二回合' 20 | 21 | print('calculating prob...') 22 | prob_list = [] 23 | for out in OUT_LIST: 24 | print(out) 25 | out = out.replace('', '').replace('', '') 26 | out = [sos_id] + [VOCAB[ch] for ch in out] + [eos_id] 27 | prob = 1.0 28 | for i in range(1, len(out)): 29 | prob *= bigram_freq[(out[i - 1], out[i])] 30 | prob_list.append(prob) 31 | 32 | prob_list = np.array(prob_list) 33 | prob_list = prob_list / np.sum(prob_list) 34 | print(prob_list) 35 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.tensorboard import SummaryWriter 4 | # from torch import nn 5 | from tqdm import tqdm 6 | 7 | from config import device, print_freq, vocab_size, sos_id, eos_id 8 | from data_gen import AiShellDataset, pad_collate 9 | from transformer.decoder import Decoder 10 | from transformer.encoder import Encoder 11 | from transformer.loss import cal_performance 12 | from transformer.optimizer import TransformerOptimizer 13 | from transformer.transformer import Transformer 14 | from utils import parse_args, save_checkpoint, AverageMeter, get_logger 15 | 16 | 17 | def train_net(args): 18 | torch.manual_seed(7) 19 | np.random.seed(7) 20 | checkpoint = args.checkpoint 21 | start_epoch = 0 22 | best_loss = float('inf') 23 | writer = SummaryWriter() 24 | epochs_since_improvement = 0 25 | 26 | # Initialize / load checkpoint 27 | if checkpoint is None: 28 | # model 29 | encoder = Encoder(args.d_input * args.LFR_m, args.n_layers_enc, args.n_head, 30 | args.d_k, args.d_v, args.d_model, args.d_inner, 31 | dropout=args.dropout, pe_maxlen=args.pe_maxlen) 32 | decoder = Decoder(sos_id, eos_id, vocab_size, 33 | args.d_word_vec, args.n_layers_dec, args.n_head, 34 | args.d_k, args.d_v, args.d_model, args.d_inner, 35 | dropout=args.dropout, 36 | tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing, 37 | pe_maxlen=args.pe_maxlen) 38 | model = Transformer(encoder, decoder) 39 | # print(model) 40 | # model = nn.DataParallel(model) 41 | 42 | # optimizer 43 | optimizer = TransformerOptimizer( 44 | torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-09)) 45 | 46 | else: 47 | checkpoint = torch.load(checkpoint) 48 | start_epoch = checkpoint['epoch'] + 1 49 | epochs_since_improvement = checkpoint['epochs_since_improvement'] 50 | model = checkpoint['model'] 51 | optimizer = checkpoint['optimizer'] 52 | 53 | logger = get_logger() 54 | 55 | # Move to GPU, if available 56 | model = model.to(device) 57 | 58 | # Custom dataloaders 59 | train_dataset = AiShellDataset(args, 'train') 60 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=pad_collate, 61 | pin_memory=True, shuffle=True, num_workers=args.num_workers) 62 | valid_dataset = AiShellDataset(args, 'dev') 63 | valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, collate_fn=pad_collate, 64 | pin_memory=True, shuffle=False, num_workers=args.num_workers) 65 | 66 | # Epochs 67 | for epoch in range(start_epoch, args.epochs): 68 | 
# One epoch's training 69 | train_loss = train(train_loader=train_loader, 70 | model=model, 71 | optimizer=optimizer, 72 | epoch=epoch, 73 | logger=logger) 74 | writer.add_scalar('model/train_loss', train_loss, epoch) 75 | 76 | lr = optimizer.lr 77 | print('\nLearning rate: {}'.format(lr)) 78 | writer.add_scalar('model/learning_rate', lr, epoch) 79 | step_num = optimizer.step_num 80 | print('Step num: {}\n'.format(step_num)) 81 | 82 | # One epoch's validation 83 | valid_loss = valid(valid_loader=valid_loader, 84 | model=model, 85 | logger=logger) 86 | writer.add_scalar('model/valid_loss', valid_loss, epoch) 87 | 88 | # Check if there was an improvement 89 | is_best = valid_loss < best_loss 90 | best_loss = min(valid_loss, best_loss) 91 | if not is_best: 92 | epochs_since_improvement += 1 93 | print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,)) 94 | else: 95 | epochs_since_improvement = 0 96 | 97 | # Save checkpoint 98 | save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best) 99 | 100 | 101 | def train(train_loader, model, optimizer, epoch, logger): 102 | model.train() # train mode (dropout and batchnorm is used) 103 | 104 | losses = AverageMeter() 105 | 106 | # Batches 107 | for i, (data) in enumerate(train_loader): 108 | # Move to GPU, if available 109 | padded_input, padded_target, input_lengths = data 110 | padded_input = padded_input.to(device) 111 | padded_target = padded_target.to(device) 112 | input_lengths = input_lengths.to(device) 113 | 114 | # Forward prop. 115 | pred, gold = model(padded_input, input_lengths, padded_target) 116 | loss, n_correct = cal_performance(pred, gold, smoothing=args.label_smoothing) 117 | 118 | # Back prop. 119 | optimizer.zero_grad() 120 | loss.backward() 121 | 122 | # Update weights 123 | optimizer.step() 124 | 125 | # Keep track of metrics 126 | losses.update(loss.item()) 127 | 128 | # Print status 129 | if i % print_freq == 0: 130 | logger.info('Epoch: [{0}][{1}/{2}]\t' 131 | 'Loss {loss.val:.5f} ({loss.avg:.5f})'.format(epoch, i, len(train_loader), loss=losses)) 132 | 133 | return losses.avg 134 | 135 | 136 | def valid(valid_loader, model, logger): 137 | model.eval() 138 | 139 | losses = AverageMeter() 140 | 141 | # Batches 142 | for data in tqdm(valid_loader): 143 | # Move to GPU, if available 144 | padded_input, padded_target, input_lengths = data 145 | padded_input = padded_input.to(device) 146 | padded_target = padded_target.to(device) 147 | input_lengths = input_lengths.to(device) 148 | 149 | with torch.no_grad(): 150 | # Forward prop. 
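            # Note: validation scores the decoder with teacher forcing, exactly as in
            # training: the ground-truth characters are fed as decoder input and the
            # next-token cross entropy is accumulated. valid_loss therefore tracks the
            # training objective; CER on beam-search hypotheses is measured separately in test.py.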
151 | pred, gold = model(padded_input, input_lengths, padded_target) 152 | loss, n_correct = cal_performance(pred, gold, smoothing=args.label_smoothing) 153 | 154 | # Keep track of metrics 155 | losses.update(loss.item()) 156 | 157 | # Print status 158 | logger.info('\nValidation Loss {loss.val:.5f} ({loss.avg:.5f})\n'.format(loss=losses)) 159 | 160 | return losses.avg 161 | 162 | 163 | def main(): 164 | global args 165 | args = parse_args() 166 | train_net(args) 167 | 168 | 169 | if __name__ == '__main__': 170 | main() 171 | -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/transformer/__init__.py -------------------------------------------------------------------------------- /transformer/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class MultiHeadAttention(nn.Module): 7 | ''' Multi-Head Attention module ''' 8 | 9 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 10 | super().__init__() 11 | 12 | self.n_head = n_head 13 | self.d_k = d_k 14 | self.d_v = d_v 15 | 16 | self.w_qs = nn.Linear(d_model, n_head * d_k) 17 | self.w_ks = nn.Linear(d_model, n_head * d_k) 18 | self.w_vs = nn.Linear(d_model, n_head * d_v) 19 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 20 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 21 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 22 | 23 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), 24 | attn_dropout=dropout) 25 | self.layer_norm = nn.LayerNorm(d_model) 26 | 27 | self.fc = nn.Linear(n_head * d_v, d_model) 28 | nn.init.xavier_normal_(self.fc.weight) 29 | 30 | self.dropout = nn.Dropout(dropout) 31 | 32 | def forward(self, q, k, v, mask=None): 33 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 34 | 35 | sz_b, len_q, _ = q.size() 36 | sz_b, len_k, _ = k.size() 37 | sz_b, len_v, _ = v.size() 38 | 39 | residual = q 40 | 41 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 42 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 43 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 44 | 45 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 46 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 47 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 48 | 49 | if mask is not None: 50 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
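            # q, k and v were reshaped above so that the head axis is folded into the
            # batch axis, giving (n_head * batch) independent attention problems; the
            # padding mask is tiled n_head times so it lines up with that folded batch axis.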
51 | 52 | output, attn = self.attention(q, k, v, mask=mask) 53 | 54 | output = output.view(n_head, sz_b, len_q, d_v) 55 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 56 | 57 | output = self.dropout(self.fc(output)) 58 | output = self.layer_norm(output + residual) 59 | 60 | return output, attn 61 | 62 | 63 | class ScaledDotProductAttention(nn.Module): 64 | ''' Scaled Dot-Product Attention ''' 65 | 66 | def __init__(self, temperature, attn_dropout=0.1): 67 | super().__init__() 68 | self.temperature = temperature 69 | self.dropout = nn.Dropout(attn_dropout) 70 | self.softmax = nn.Softmax(dim=2) 71 | 72 | def forward(self, q, k, v, mask=None): 73 | attn = torch.bmm(q, k.transpose(1, 2)) 74 | attn = attn / self.temperature 75 | 76 | if mask is not None: 77 | attn = attn.masked_fill(mask.bool(), -np.inf) 78 | 79 | attn = self.softmax(attn) 80 | attn = self.dropout(attn) 81 | output = torch.bmm(attn, v) 82 | 83 | return output, attn 84 | -------------------------------------------------------------------------------- /transformer/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from config import IGNORE_ID 6 | from .attention import MultiHeadAttention 7 | from .module import PositionalEncoding, PositionwiseFeedForward 8 | from .utils import get_attn_key_pad_mask, get_attn_pad_mask, get_non_pad_mask, get_subsequent_mask, pad_list 9 | 10 | 11 | # filename = 'bigram_freq.pkl' 12 | # print('loading {}...'.format(filename)) 13 | # with open(filename, 'rb') as file: 14 | # bigram_freq = pickle.load(file) 15 | 16 | 17 | class Decoder(nn.Module): 18 | ''' A decoder model with self attention mechanism. ''' 19 | 20 | def __init__( 21 | self, sos_id=0, eos_id=1, 22 | n_tgt_vocab=4335, d_word_vec=512, 23 | n_layers=6, n_head=8, d_k=64, d_v=64, 24 | d_model=512, d_inner=2048, dropout=0.1, 25 | tgt_emb_prj_weight_sharing=True, 26 | pe_maxlen=5000): 27 | super(Decoder, self).__init__() 28 | # parameters 29 | self.sos_id = sos_id # Start of Sentence 30 | self.eos_id = eos_id # End of Sentence 31 | self.n_tgt_vocab = n_tgt_vocab 32 | self.d_word_vec = d_word_vec 33 | self.n_layers = n_layers 34 | self.n_head = n_head 35 | self.d_k = d_k 36 | self.d_v = d_v 37 | self.d_model = d_model 38 | self.d_inner = d_inner 39 | self.dropout = dropout 40 | self.tgt_emb_prj_weight_sharing = tgt_emb_prj_weight_sharing 41 | self.pe_maxlen = pe_maxlen 42 | 43 | self.tgt_word_emb = nn.Embedding(n_tgt_vocab, d_word_vec) 44 | self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen) 45 | self.dropout = nn.Dropout(dropout) 46 | 47 | self.layer_stack = nn.ModuleList([ 48 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 49 | for _ in range(n_layers)]) 50 | 51 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False) 52 | nn.init.xavier_normal_(self.tgt_word_prj.weight) 53 | 54 | if tgt_emb_prj_weight_sharing: 55 | # Share the weight matrix between target word embedding & the final logit dense layer 56 | self.tgt_word_prj.weight = self.tgt_word_emb.weight 57 | self.x_logit_scale = (d_model ** -0.5) 58 | else: 59 | self.x_logit_scale = 1. 
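        # Note: with tied weights, forward() and recognize_beam() multiply the target
        # embeddings by x_logit_scale = d_model ** -0.5; with untied weights no extra
        # scaling is applied (x_logit_scale = 1).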
60 | 61 | def preprocess(self, padded_input): 62 | """Generate decoder input and output label from padded_input 63 | Add to decoder input, and add to decoder output label 64 | """ 65 | ys = [y[y != IGNORE_ID] for y in padded_input] # parse padded ys 66 | # prepare input and output word sequences with sos/eos IDs 67 | eos = ys[0].new([self.eos_id]) 68 | sos = ys[0].new([self.sos_id]) 69 | ys_in = [torch.cat([sos, y], dim=0) for y in ys] 70 | ys_out = [torch.cat([y, eos], dim=0) for y in ys] 71 | # padding for ys with -1 72 | # pys: utt x olen 73 | ys_in_pad = pad_list(ys_in, self.eos_id) 74 | ys_out_pad = pad_list(ys_out, IGNORE_ID) 75 | assert ys_in_pad.size() == ys_out_pad.size() 76 | return ys_in_pad, ys_out_pad 77 | 78 | def forward(self, padded_input, encoder_padded_outputs, 79 | encoder_input_lengths, return_attns=False): 80 | """ 81 | Args: 82 | padded_input: N x To 83 | encoder_padded_outputs: N x Ti x H 84 | Returns: 85 | """ 86 | dec_slf_attn_list, dec_enc_attn_list = [], [] 87 | 88 | # Get Deocder Input and Output 89 | ys_in_pad, ys_out_pad = self.preprocess(padded_input) 90 | 91 | # Prepare masks 92 | non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id) 93 | 94 | slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad) 95 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad, 96 | seq_q=ys_in_pad, 97 | pad_idx=self.eos_id) 98 | slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) 99 | 100 | output_length = ys_in_pad.size(1) 101 | dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs, 102 | encoder_input_lengths, 103 | output_length) 104 | 105 | # Forward 106 | dec_output = self.dropout(self.tgt_word_emb(ys_in_pad) * self.x_logit_scale + 107 | self.positional_encoding(ys_in_pad)) 108 | 109 | for dec_layer in self.layer_stack: 110 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer( 111 | dec_output, encoder_padded_outputs, 112 | non_pad_mask=non_pad_mask, 113 | slf_attn_mask=slf_attn_mask, 114 | dec_enc_attn_mask=dec_enc_attn_mask) 115 | 116 | if return_attns: 117 | dec_slf_attn_list += [dec_slf_attn] 118 | dec_enc_attn_list += [dec_enc_attn] 119 | 120 | # before softmax 121 | seq_logit = self.tgt_word_prj(dec_output) 122 | 123 | # Return 124 | pred, gold = seq_logit, ys_out_pad 125 | 126 | if return_attns: 127 | return pred, gold, dec_slf_attn_list, dec_enc_attn_list 128 | return pred, gold 129 | 130 | def recognize_beam(self, encoder_outputs, char_list, args): 131 | """Beam search, decode one utterence now. 
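        At each step every partial hypothesis is extended with the beam_size best
        next characters ranked by log-softmax score; hypotheses that emit <eos>
        are moved to a finished list, and the highest-scoring `nbest` finished
        hypotheses are returned when decoding stops.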
132 | Args: 133 | encoder_outputs: T x H 134 | char_list: list of character 135 | args: args.beam 136 | Returns: 137 | nbest_hyps: 138 | """ 139 | # search params 140 | beam = args.beam_size 141 | nbest = args.nbest 142 | if args.decode_max_len == 0: 143 | maxlen = encoder_outputs.size(0) 144 | else: 145 | maxlen = args.decode_max_len 146 | 147 | encoder_outputs = encoder_outputs.unsqueeze(0) 148 | 149 | # prepare sos 150 | ys = torch.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long() 151 | 152 | # yseq: 1xT 153 | hyp = {'score': 0.0, 'yseq': ys} 154 | hyps = [hyp] 155 | ended_hyps = [] 156 | 157 | for i in range(maxlen): 158 | hyps_best_kept = [] 159 | for hyp in hyps: 160 | ys = hyp['yseq'] # 1 x i 161 | # last_id = ys.cpu().numpy()[0][-1] 162 | # freq = bigram_freq[last_id] 163 | # freq = torch.log(torch.from_numpy(freq)) 164 | # # print(freq.dtype) 165 | # freq = freq.type(torch.float).to(device) 166 | # print(freq.dtype) 167 | # print('freq.size(): ' + str(freq.size())) 168 | # print('freq: ' + str(freq)) 169 | # -- Prepare masks 170 | non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1) # 1xix1 171 | slf_attn_mask = get_subsequent_mask(ys) 172 | 173 | # -- Forward 174 | dec_output = self.dropout( 175 | self.tgt_word_emb(ys) * self.x_logit_scale + 176 | self.positional_encoding(ys)) 177 | 178 | for dec_layer in self.layer_stack: 179 | dec_output, _, _ = dec_layer( 180 | dec_output, encoder_outputs, 181 | non_pad_mask=non_pad_mask, 182 | slf_attn_mask=slf_attn_mask, 183 | dec_enc_attn_mask=None) 184 | 185 | seq_logit = self.tgt_word_prj(dec_output[:, -1]) 186 | # local_scores = F.log_softmax(seq_logit, dim=1) 187 | local_scores = F.log_softmax(seq_logit, dim=1) 188 | # print('local_scores.size(): ' + str(local_scores.size())) 189 | # local_scores += freq 190 | # print('local_scores: ' + str(local_scores)) 191 | 192 | # topk scores 193 | local_best_scores, local_best_ids = torch.topk( 194 | local_scores, beam, dim=1) 195 | 196 | for j in range(beam): 197 | new_hyp = {} 198 | new_hyp['score'] = hyp['score'] + local_best_scores[0, j] 199 | new_hyp['yseq'] = torch.ones(1, (1 + ys.size(1))).type_as(encoder_outputs).long() 200 | new_hyp['yseq'][:, :ys.size(1)] = hyp['yseq'] 201 | new_hyp['yseq'][:, ys.size(1)] = int(local_best_ids[0, j]) 202 | # will be (2 x beam) hyps at most 203 | hyps_best_kept.append(new_hyp) 204 | 205 | hyps_best_kept = sorted(hyps_best_kept, 206 | key=lambda x: x['score'], 207 | reverse=True)[:beam] 208 | # end for hyp in hyps 209 | hyps = hyps_best_kept 210 | 211 | # add eos in the final loop to avoid that there are no ended hyps 212 | if i == maxlen - 1: 213 | for hyp in hyps: 214 | hyp['yseq'] = torch.cat([hyp['yseq'], 215 | torch.ones(1, 1).fill_(self.eos_id).type_as(encoder_outputs).long()], 216 | dim=1) 217 | 218 | # add ended hypothes to a final list, and removed them from current hypothes 219 | # (this will be a probmlem, number of hyps < beam) 220 | remained_hyps = [] 221 | for hyp in hyps: 222 | if hyp['yseq'][0, -1] == self.eos_id: 223 | ended_hyps.append(hyp) 224 | else: 225 | remained_hyps.append(hyp) 226 | 227 | hyps = remained_hyps 228 | # if len(hyps) > 0: 229 | # print('remeined hypothes: ' + str(len(hyps))) 230 | # else: 231 | # print('no hypothesis. 
Finish decoding.') 232 | # break 233 | # 234 | # for hyp in hyps: 235 | # print('hypo: ' + ''.join([char_list[int(x)] 236 | # for x in hyp['yseq'][0, 1:]])) 237 | # end for i in range(maxlen) 238 | nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[ 239 | :min(len(ended_hyps), nbest)] 240 | # compitable with LAS implementation 241 | for hyp in nbest_hyps: 242 | hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist() 243 | return nbest_hyps 244 | 245 | 246 | class DecoderLayer(nn.Module): 247 | ''' Compose with three layers ''' 248 | 249 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 250 | super(DecoderLayer, self).__init__() 251 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 252 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 253 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 254 | 255 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None): 256 | dec_output, dec_slf_attn = self.slf_attn( 257 | dec_input, dec_input, dec_input, mask=slf_attn_mask) 258 | dec_output *= non_pad_mask 259 | 260 | dec_output, dec_enc_attn = self.enc_attn( 261 | dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) 262 | dec_output *= non_pad_mask 263 | 264 | dec_output = self.pos_ffn(dec_output) 265 | dec_output *= non_pad_mask 266 | 267 | return dec_output, dec_slf_attn, dec_enc_attn 268 | -------------------------------------------------------------------------------- /transformer/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .attention import MultiHeadAttention 4 | from .module import PositionalEncoding, PositionwiseFeedForward 5 | from .utils import get_non_pad_mask, get_attn_pad_mask 6 | 7 | 8 | class Encoder(nn.Module): 9 | """Encoder of Transformer including self-attention and feed forward. 
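    Input frames are projected by a Linear layer with LayerNorm (in place of a token
    embedding), summed with sinusoidal positional encodings, and passed through
    n_layers identical blocks of multi-head self-attention plus a position-wise
    feed-forward sublayer; padded time steps are masked out throughout.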
10 | """ 11 | 12 | def __init__(self, d_input=320, n_layers=6, n_head=8, d_k=64, d_v=64, 13 | d_model=512, d_inner=2048, dropout=0.1, pe_maxlen=5000): 14 | super(Encoder, self).__init__() 15 | # parameters 16 | self.d_input = d_input 17 | self.n_layers = n_layers 18 | self.n_head = n_head 19 | self.d_k = d_k 20 | self.d_v = d_v 21 | self.d_model = d_model 22 | self.d_inner = d_inner 23 | self.dropout_rate = dropout 24 | self.pe_maxlen = pe_maxlen 25 | 26 | # use linear transformation with layer norm to replace input embedding 27 | self.linear_in = nn.Linear(d_input, d_model) 28 | self.layer_norm_in = nn.LayerNorm(d_model) 29 | self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen) 30 | self.dropout = nn.Dropout(dropout) 31 | 32 | self.layer_stack = nn.ModuleList([ 33 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 34 | for _ in range(n_layers)]) 35 | 36 | def forward(self, padded_input, input_lengths, return_attns=False): 37 | """ 38 | Args: 39 | padded_input: N x T x D 40 | input_lengths: N 41 | Returns: 42 | enc_output: N x T x H 43 | """ 44 | enc_slf_attn_list = [] 45 | 46 | # Prepare masks 47 | non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths) 48 | length = padded_input.size(1) 49 | slf_attn_mask = get_attn_pad_mask(padded_input, input_lengths, length) 50 | 51 | # Forward 52 | enc_output = self.dropout( 53 | self.layer_norm_in(self.linear_in(padded_input)) + 54 | self.positional_encoding(padded_input)) 55 | 56 | for enc_layer in self.layer_stack: 57 | enc_output, enc_slf_attn = enc_layer( 58 | enc_output, 59 | non_pad_mask=non_pad_mask, 60 | slf_attn_mask=slf_attn_mask) 61 | if return_attns: 62 | enc_slf_attn_list += [enc_slf_attn] 63 | 64 | if return_attns: 65 | return enc_output, enc_slf_attn_list 66 | return enc_output, 67 | 68 | 69 | class EncoderLayer(nn.Module): 70 | """Compose with two sub-layers. 71 | 1. A multi-head self-attention mechanism 72 | 2. A simple, position-wise fully connected feed-forward network. 73 | """ 74 | 75 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 76 | super(EncoderLayer, self).__init__() 77 | self.slf_attn = MultiHeadAttention( 78 | n_head, d_model, d_k, d_v, dropout=dropout) 79 | self.pos_ffn = PositionwiseFeedForward( 80 | d_model, d_inner, dropout=dropout) 81 | 82 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): 83 | enc_output, enc_slf_attn = self.slf_attn( 84 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 85 | enc_output *= non_pad_mask 86 | 87 | enc_output = self.pos_ffn(enc_output) 88 | enc_output *= non_pad_mask 89 | 90 | return enc_output, enc_slf_attn 91 | -------------------------------------------------------------------------------- /transformer/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from config import IGNORE_ID 5 | 6 | 7 | def cal_performance(pred, gold, smoothing=0.0): 8 | """Calculate cross entropy loss, apply label smoothing if needed. 
9 | Args: 10 | pred: N x T x C, score before softmax 11 | gold: N x T 12 | """ 13 | 14 | pred = pred.view(-1, pred.size(2)) 15 | gold = gold.contiguous().view(-1) 16 | 17 | loss = cal_loss(pred, gold, smoothing) 18 | 19 | pred = pred.max(1)[1] 20 | non_pad_mask = gold.ne(IGNORE_ID) 21 | n_correct = pred.eq(gold) 22 | n_correct = n_correct.masked_select(non_pad_mask).sum().item() 23 | 24 | return loss, n_correct 25 | 26 | 27 | def cal_loss(pred, gold, smoothing=0.0): 28 | """Calculate cross entropy loss, apply label smoothing if needed. 29 | """ 30 | 31 | if smoothing > 0.0: 32 | eps = smoothing 33 | n_class = pred.size(1) 34 | 35 | # Generate one-hot matrix: N x C. 36 | # Only label position is 1 and all other positions are 0 37 | # gold include -1 value (IGNORE_ID) and this will lead to assert error 38 | gold_for_scatter = gold.ne(IGNORE_ID).long() * gold 39 | one_hot = torch.zeros_like(pred).scatter(1, gold_for_scatter.view(-1, 1), 1) 40 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / n_class 41 | log_prb = F.log_softmax(pred, dim=1) 42 | 43 | non_pad_mask = gold.ne(IGNORE_ID) 44 | n_word = non_pad_mask.sum().item() 45 | loss = -(one_hot * log_prb).sum(dim=1) 46 | loss = loss.masked_select(non_pad_mask).sum() / n_word 47 | else: 48 | loss = F.cross_entropy(pred, gold, 49 | ignore_index=IGNORE_ID, 50 | reduction='elementwise_mean') 51 | 52 | return loss 53 | -------------------------------------------------------------------------------- /transformer/module.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class PositionalEncoding(nn.Module): 9 | """Implement the positional encoding (PE) function. 10 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) 11 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) 12 | """ 13 | 14 | def __init__(self, d_model, max_len=5000): 15 | super(PositionalEncoding, self).__init__() 16 | # Compute the positional encodings once in log space. 17 | pe = torch.zeros(max_len, d_model, requires_grad=False) 18 | position = torch.arange(0, max_len).unsqueeze(1).float() 19 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 20 | -(math.log(10000.0) / d_model)) 21 | pe[:, 0::2] = torch.sin(position * div_term) 22 | pe[:, 1::2] = torch.cos(position * div_term) 23 | pe = pe.unsqueeze(0) 24 | self.register_buffer('pe', pe) 25 | 26 | def forward(self, input): 27 | """ 28 | Args: 29 | input: N x T x D 30 | """ 31 | length = input.size(1) 32 | return self.pe[:, :length] 33 | 34 | 35 | class PositionwiseFeedForward(nn.Module): 36 | """Implements position-wise feedforward sublayer. 
37 | FFN(x) = max(0, xW1 + b1)W2 + b2 38 | """ 39 | 40 | def __init__(self, d_model, d_ff, dropout=0.1): 41 | super(PositionwiseFeedForward, self).__init__() 42 | self.w_1 = nn.Linear(d_model, d_ff) 43 | self.w_2 = nn.Linear(d_ff, d_model) 44 | self.dropout = nn.Dropout(dropout) 45 | self.layer_norm = nn.LayerNorm(d_model) 46 | 47 | def forward(self, x): 48 | residual = x 49 | output = self.w_2(F.relu(self.w_1(x))) 50 | output = self.dropout(output) 51 | output = self.layer_norm(output + residual) 52 | return output 53 | 54 | 55 | # Another implementation 56 | class PositionwiseFeedForwardUseConv(nn.Module): 57 | """A two-feed-forward-layer module""" 58 | 59 | def __init__(self, d_in, d_hid, dropout=0.1): 60 | super(PositionwiseFeedForwardUseConv, self).__init__() 61 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 62 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 63 | self.layer_norm = nn.LayerNorm(d_in) 64 | self.dropout = nn.Dropout(dropout) 65 | 66 | def forward(self, x): 67 | residual = x 68 | output = x.transpose(1, 2) 69 | output = self.w_2(F.relu(self.w_1(output))) 70 | output = output.transpose(1, 2) 71 | output = self.dropout(output) 72 | output = self.layer_norm(output + residual) 73 | return output 74 | -------------------------------------------------------------------------------- /transformer/optimizer.py: -------------------------------------------------------------------------------- 1 | class TransformerOptimizer(object): 2 | """A simple wrapper class for learning rate scheduling""" 3 | 4 | def __init__(self, optimizer, warmup_steps=4000, k=0.2): 5 | self.optimizer = optimizer 6 | self.k = k 7 | self.warmup_steps = warmup_steps 8 | d_model = 512 9 | self.init_lr = d_model ** (-0.5) 10 | self.lr = self.init_lr 11 | self.warmup_steps = warmup_steps 12 | self.k = k 13 | self.step_num = 0 14 | 15 | def zero_grad(self): 16 | self.optimizer.zero_grad() 17 | 18 | def step(self): 19 | self._update_lr() 20 | self.optimizer.step() 21 | 22 | def _update_lr(self): 23 | self.step_num += 1 24 | self.lr = self.k * self.init_lr * min(self.step_num ** (-0.5), 25 | self.step_num * (self.warmup_steps ** (-1.5))) 26 | for param_group in self.optimizer.param_groups: 27 | param_group['lr'] = self.lr 28 | -------------------------------------------------------------------------------- /transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .decoder import Decoder 4 | from .encoder import Encoder 5 | 6 | 7 | class Transformer(nn.Module): 8 | """An encoder-decoder framework only includes attention. 
9 | """ 10 | 11 | def __init__(self, encoder=None, decoder=None): 12 | super(Transformer, self).__init__() 13 | 14 | if encoder is not None and decoder is not None: 15 | self.encoder = encoder 16 | self.decoder = decoder 17 | 18 | for p in self.parameters(): 19 | if p.dim() > 1: 20 | nn.init.xavier_uniform_(p) 21 | else: 22 | self.encoder = Encoder() 23 | self.decoder = Decoder() 24 | 25 | def forward(self, padded_input, input_lengths, padded_target): 26 | """ 27 | Args: 28 | padded_input: N x Ti x D 29 | input_lengths: N 30 | padded_targets: N x To 31 | """ 32 | encoder_padded_outputs, *_ = self.encoder(padded_input, input_lengths) 33 | # pred is score before softmax 34 | pred, gold, *_ = self.decoder(padded_target, encoder_padded_outputs, 35 | input_lengths) 36 | return pred, gold 37 | 38 | def recognize(self, input, input_length, char_list, args): 39 | """Sequence-to-Sequence beam search, decode one utterence now. 40 | Args: 41 | input: T x D 42 | char_list: list of characters 43 | args: args.beam 44 | Returns: 45 | nbest_hyps: 46 | """ 47 | encoder_outputs, *_ = self.encoder(input.unsqueeze(0), input_length) 48 | nbest_hyps = self.decoder.recognize_beam(encoder_outputs[0], 49 | char_list, 50 | args) 51 | return nbest_hyps 52 | -------------------------------------------------------------------------------- /transformer/utils.py: -------------------------------------------------------------------------------- 1 | def pad_list(xs, pad_value): 2 | # From: espnet/src/nets/e2e_asr_th.py: pad_list() 3 | n_batch = len(xs) 4 | max_len = max(x.size(0) for x in xs) 5 | pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) 6 | for i in range(n_batch): 7 | pad[i, :xs[i].size(0)] = xs[i] 8 | return pad 9 | 10 | 11 | def process_dict(dict_path): 12 | with open(dict_path, 'rb') as f: 13 | dictionary = f.readlines() 14 | char_list = [entry.decode('utf-8').split(' ')[0] 15 | for entry in dictionary] 16 | sos_id = char_list.index('') 17 | eos_id = char_list.index('') 18 | return char_list, sos_id, eos_id 19 | 20 | 21 | if __name__ == "__main__": 22 | import sys 23 | 24 | path = sys.argv[1] 25 | char_list, sos_id, eos_id = process_dict(path) 26 | print(char_list, sos_id, eos_id) 27 | 28 | 29 | # * ------------------ recognition related ------------------ * 30 | 31 | 32 | def parse_hypothesis(hyp, char_list): 33 | """Function to parse hypothesis 34 | :param list hyp: recognition hypothesis 35 | :param list char_list: list of characters 36 | :return: recognition text strinig 37 | :return: recognition token strinig 38 | :return: recognition tokenid string 39 | """ 40 | # remove sos and get results 41 | tokenid_as_list = list(map(int, hyp['yseq'][1:])) 42 | token_as_list = [char_list[idx] for idx in tokenid_as_list] 43 | score = float(hyp['score']) 44 | 45 | # convert to string 46 | tokenid = " ".join([str(idx) for idx in tokenid_as_list]) 47 | token = " ".join(token_as_list) 48 | text = "".join(token_as_list).replace('', ' ') 49 | 50 | return text, token, tokenid, score 51 | 52 | 53 | def add_results_to_json(js, nbest_hyps, char_list): 54 | """Function to add N-best results to json 55 | :param dict js: groundtruth utterance dict 56 | :param list nbest_hyps: list of hypothesis 57 | :param list char_list: list of characters 58 | :return: N-best results added utterance dict 59 | """ 60 | # copy old json info 61 | new_js = dict() 62 | new_js['utt2spk'] = js['utt2spk'] 63 | new_js['output'] = [] 64 | 65 | for n, hyp in enumerate(nbest_hyps, 1): 66 | # parse hypothesis 67 | rec_text, rec_token, 
rec_tokenid, score = parse_hypothesis( 68 | hyp, char_list) 69 | 70 | # copy ground-truth 71 | out_dic = dict(js['output'][0].items()) 72 | 73 | # update name 74 | out_dic['name'] += '[%d]' % n 75 | 76 | # add recognition results 77 | out_dic['rec_text'] = rec_text 78 | out_dic['rec_token'] = rec_token 79 | out_dic['rec_tokenid'] = rec_tokenid 80 | out_dic['score'] = score 81 | 82 | # add to list of N-best result dicts 83 | new_js['output'].append(out_dic) 84 | 85 | # show 1-best result 86 | if n == 1: 87 | print('groundtruth: %s' % out_dic['text']) 88 | print('prediction : %s' % out_dic['rec_text']) 89 | 90 | return new_js 91 | 92 | 93 | # -- Transformer Related -- 94 | import torch 95 | 96 | 97 | def get_non_pad_mask(padded_input, input_lengths=None, pad_idx=None): 98 | """padding position is set to 0, either use input_lengths or pad_idx 99 | """ 100 | assert input_lengths is not None or pad_idx is not None 101 | if input_lengths is not None: 102 | # padded_input: N x T x .. 103 | N = padded_input.size(0) 104 | non_pad_mask = padded_input.new_ones(padded_input.size()[:-1]) # N x T 105 | for i in range(N): 106 | non_pad_mask[i, input_lengths[i]:] = 0 107 | if pad_idx is not None: 108 | # padded_input: N x T 109 | assert padded_input.dim() == 2 110 | non_pad_mask = padded_input.ne(pad_idx).float() 111 | # unsqueeze(-1) for broadcast 112 | return non_pad_mask.unsqueeze(-1) 113 | 114 | 115 | def get_subsequent_mask(seq): 116 | ''' For masking out the subsequent info. ''' 117 | 118 | sz_b, len_s = seq.size() 119 | subsequent_mask = torch.triu( 120 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) 121 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls 122 | 123 | return subsequent_mask 124 | 125 | 126 | def get_attn_key_pad_mask(seq_k, seq_q, pad_idx): 127 | ''' For masking out the padding part of key sequence. ''' 128 | 129 | # Expand to fit the shape of key query attention matrix. 130 | len_q = seq_q.size(1) 131 | padding_mask = seq_k.eq(pad_idx) 132 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk 133 | 134 | return padding_mask 135 | 136 | 137 | def get_attn_pad_mask(padded_input, input_lengths, expand_length): 138 | """mask position is set to 1""" 139 | # N x Ti x 1 140 | non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths) 141 | # N x Ti, lt(1) like not operation 142 | pad_mask = non_pad_mask.squeeze(-1).lt(1) 143 | attn_mask = pad_mask.unsqueeze(1).expand(-1, expand_length, -1) 144 | return attn_mask 145 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import librosa 5 | import numpy as np 6 | import torch 7 | from config import sample_rate 8 | 9 | 10 | def clip_gradient(optimizer, grad_clip): 11 | """ 12 | Clips gradients computed during backpropagation to avoid explosion of gradients. 
13 | :param optimizer: optimizer with the gradients to be clipped 14 | :param grad_clip: clip value 15 | """ 16 | for group in optimizer.param_groups: 17 | for param in group['params']: 18 | if param.grad is not None: 19 | param.grad.data.clamp_(-grad_clip, grad_clip) 20 | 21 | 22 | def save_checkpoint(epoch, epochs_since_improvement, model, optimizer, loss, is_best): 23 | state = {'epoch': epoch, 24 | 'epochs_since_improvement': epochs_since_improvement, 25 | 'loss': loss, 26 | 'model': model, 27 | 'optimizer': optimizer} 28 | 29 | filename = 'checkpoint.tar' 30 | torch.save(state, filename) 31 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint 32 | if is_best: 33 | torch.save(state, 'BEST_checkpoint.tar') 34 | 35 | 36 | class AverageMeter(object): 37 | """ 38 | Keeps track of most recent, average, sum, and count of a metric. 39 | """ 40 | 41 | def __init__(self): 42 | self.reset() 43 | 44 | def reset(self): 45 | self.val = 0 46 | self.avg = 0 47 | self.sum = 0 48 | self.count = 0 49 | 50 | def update(self, val, n=1): 51 | self.val = val 52 | self.sum += val * n 53 | self.count += n 54 | self.avg = self.sum / self.count 55 | 56 | 57 | def adjust_learning_rate(optimizer, shrink_factor): 58 | """ 59 | Shrinks learning rate by a specified factor. 60 | :param optimizer: optimizer whose learning rate must be shrunk. 61 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with. 62 | """ 63 | 64 | print("\nDECAYING learning rate.") 65 | for param_group in optimizer.param_groups: 66 | param_group['lr'] = param_group['lr'] * shrink_factor 67 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 68 | 69 | 70 | def accuracy(scores, targets, k=1): 71 | batch_size = targets.size(0) 72 | _, ind = scores.topk(k, 1, True, True) 73 | correct = ind.eq(targets.view(-1, 1).expand_as(ind)) 74 | correct_total = correct.view(-1).float().sum() # 0D tensor 75 | return correct_total.item() * (100.0 / batch_size) 76 | 77 | 78 | def parse_args(): 79 | parser = argparse.ArgumentParser(description='Speech Transformer') 80 | # Low Frame Rate (stacking and skipping frames) 81 | parser.add_argument('--LFR_m', default=4, type=int, 82 | help='Low Frame Rate: number of frames to stack') 83 | parser.add_argument('--LFR_n', default=3, type=int, 84 | help='Low Frame Rate: number of frames to skip') 85 | # Network architecture 86 | # encoder 87 | # TODO: automatically infer input dim 88 | parser.add_argument('--d_input', default=80, type=int, 89 | help='Dim of encoder input (before LFR)') 90 | parser.add_argument('--n_layers_enc', default=6, type=int, 91 | help='Number of encoder stacks') 92 | parser.add_argument('--n_head', default=8, type=int, 93 | help='Number of Multi Head Attention (MHA)') 94 | parser.add_argument('--d_k', default=64, type=int, 95 | help='Dimension of key') 96 | parser.add_argument('--d_v', default=64, type=int, 97 | help='Dimension of value') 98 | parser.add_argument('--d_model', default=512, type=int, 99 | help='Dimension of model') 100 | parser.add_argument('--d_inner', default=2048, type=int, 101 | help='Dimension of inner') 102 | parser.add_argument('--dropout', default=0.1, type=float, 103 | help='Dropout rate') 104 | parser.add_argument('--pe_maxlen', default=5000, type=int, 105 | help='Positional Encoding max len') 106 | # decoder 107 | parser.add_argument('--d_word_vec', default=512, type=int, 108 | help='Dim of decoder embedding') 109 | parser.add_argument('--n_layers_dec', default=6, type=int, 
110 | help='Number of decoder stacks') 111 | parser.add_argument('--tgt_emb_prj_weight_sharing', default=1, type=int, 112 | help='share decoder embedding with decoder projection') 113 | # Loss 114 | parser.add_argument('--label_smoothing', default=0.1, type=float, 115 | help='label smoothing') 116 | 117 | # Training config 118 | parser.add_argument('--epochs', default=150, type=int, 119 | help='Number of maximum epochs') 120 | # minibatch 121 | parser.add_argument('--shuffle', default=1, type=int, 122 | help='reshuffle the data at every epoch') 123 | parser.add_argument('--batch-size', default=32, type=int, 124 | help='Batch size') 125 | parser.add_argument('--batch_frames', default=0, type=int, 126 | help='Batch frames. If this is not 0, the batch size argument is ignored') 127 | parser.add_argument('--maxlen-in', default=800, type=int, metavar='ML', 128 | help='Batch size is reduced if the input sequence length > ML') 129 | parser.add_argument('--maxlen-out', default=150, type=int, metavar='ML', 130 | help='Batch size is reduced if the output sequence length > ML') 131 | parser.add_argument('--num-workers', default=4, type=int, 132 | help='Number of workers to generate minibatch') 133 | # optimizer 134 | parser.add_argument('--lr', default=0.001, type=float, 135 | help='learning rate') 136 | parser.add_argument('--k', default=0.2, type=float, 137 | help='tunable scalar multiplier for the learning rate') 138 | parser.add_argument('--warmup_steps', default=4000, type=int, 139 | help='warmup steps') 140 | 141 | parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint') 142 | 143 | parser.add_argument('--n_samples', default="train:-1,dev:-1,test:-1", type=str, 144 | help='choose the number of examples to use') 145 | args = parser.parse_args() 146 | 147 | return args 148 | 149 | 150 | def get_logger(): 151 | logger = logging.getLogger() 152 | handler = logging.StreamHandler() 153 | formatter = logging.Formatter("%(asctime)s %(levelname)s \t%(message)s") 154 | handler.setFormatter(formatter) 155 | logger.addHandler(handler) 156 | logger.setLevel(logging.DEBUG) 157 | return logger 158 | 159 | 160 | def ensure_folder(folder): 161 | import os 162 | if not os.path.isdir(folder): 163 | os.mkdir(folder) 164 | 165 | 166 | def pad_list(xs, pad_value): 167 | # From: espnet/src/nets/e2e_asr_th.py: pad_list() 168 | n_batch = len(xs) 169 | max_len = max(x.size(0) for x in xs) 170 | pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) 171 | for i in range(n_batch): 172 | pad[i, :xs[i].size(0)] = xs[i] 173 | return pad 174 | 175 | 176 | # [-0.5, 0.5] 177 | def normalize(yt): 178 | yt_max = np.max(yt) 179 | yt_min = np.min(yt) 180 | a = 1.0 / (yt_max - yt_min) 181 | b = -(yt_max + yt_min) / (2 * (yt_max - yt_min)) 182 | 183 | yt = yt * a + b 184 | return yt 185 | 186 | 187 | # Acoustic Feature Extraction 188 | # Parameters 189 | # - input_file : str, audio file path 190 | # - feature : str, fbank or mfcc 191 | # - dim : int, dimension of feature 192 | # - cmvn : bool, apply CMVN on feature 193 | # - window_size : int, window size for FFT (ms) 194 | # - stride : int, window stride for FFT (ms) 195 | # - save_feature: str, if given, store feature to the path and return len(feature) 196 | # Return 197 | # acoustic features with shape (time step, dim) 198 | def extract_feature(input_file, feature='fbank', dim=80, cmvn=True, delta=False, delta_delta=False, 199 | window_size=25, stride=10, save_feature=None): 200 | y, sr = librosa.load(input_file, sr=sample_rate) 201 | yt, _ =
librosa.effects.trim(y, top_db=20) 202 | yt = normalize(yt) 203 | ws = int(sr * 0.001 * window_size) 204 | st = int(sr * 0.001 * stride) 205 | if feature == 'fbank': # log-scaled 206 | feat = librosa.feature.melspectrogram(y=yt, sr=sr, n_mels=dim, 207 | n_fft=ws, hop_length=st) 208 | feat = np.log(feat + 1e-6) 209 | elif feature == 'mfcc': 210 | feat = librosa.feature.mfcc(y=yt, sr=sr, n_mfcc=dim, n_mels=26, 211 | n_fft=ws, hop_length=st) 212 | feat[0] = librosa.feature.rmse(y=yt, hop_length=st, frame_length=ws) 213 | 214 | else: 215 | raise ValueError('Unsupported Acoustic Feature: ' + feature) 216 | 217 | feat = [feat] 218 | if delta: 219 | feat.append(librosa.feature.delta(feat[0])) 220 | 221 | if delta_delta: 222 | feat.append(librosa.feature.delta(feat[0], order=2)) 223 | feat = np.concatenate(feat, axis=0) 224 | if cmvn: 225 | feat = (feat - feat.mean(axis=1)[:, np.newaxis]) / (feat.std(axis=1) + 1e-16)[:, np.newaxis] 226 | if save_feature is not None: 227 | tmp = np.swapaxes(feat, 0, 1).astype('float32') 228 | np.save(save_feature, tmp) 229 | return len(tmp) 230 | else: 231 | return np.swapaxes(feat, 0, 1).astype('float32') 232 | -------------------------------------------------------------------------------- /xer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig( 4 | format='%(levelname)s(%(filename)s:%(lineno)d): %(message)s') 5 | 6 | 7 | def levenshtein(u, v): 8 | prev = None 9 | curr = [0] + list(range(1, len(v) + 1)) 10 | # Operations: (SUB, DEL, INS) 11 | prev_ops = None 12 | curr_ops = [(0, 0, i) for i in range(len(v) + 1)] 13 | for x in range(1, len(u) + 1): 14 | prev, curr = curr, [x] + ([None] * len(v)) 15 | prev_ops, curr_ops = curr_ops, [(0, x, 0)] + ([None] * len(v)) 16 | for y in range(1, len(v) + 1): 17 | delcost = prev[y] + 1 18 | addcost = curr[y - 1] + 1 19 | subcost = prev[y - 1] + int(u[x - 1] != v[y - 1]) 20 | curr[y] = min(subcost, delcost, addcost) 21 | if curr[y] == subcost: 22 | (n_s, n_d, n_i) = prev_ops[y - 1] 23 | curr_ops[y] = (n_s + int(u[x - 1] != v[y - 1]), n_d, n_i) 24 | elif curr[y] == delcost: 25 | (n_s, n_d, n_i) = prev_ops[y] 26 | curr_ops[y] = (n_s, n_d + 1, n_i) 27 | else: 28 | (n_s, n_d, n_i) = curr_ops[y - 1] 29 | curr_ops[y] = (n_s, n_d, n_i + 1) 30 | return curr[len(v)], curr_ops[len(v)] 31 | 32 | 33 | def load_file(fname, encoding): 34 | try: 35 | f = open(fname, 'r', encoding=encoding) 36 | data = [] 37 | for line in f: 38 | data.append(line.rstrip('\n').rstrip('\r')) 39 | f.close() 40 | except Exception: 41 | logging.error('Error reading file "%s"', fname) 42 | exit(1) 43 | return data 44 | 45 | 46 | def cer_function(ref, hyp): 47 | wer_s, wer_i, wer_d, wer_n = 0, 0, 0, 0 48 | cer_s, cer_i, cer_d, cer_n = 0, 0, 0, 0 49 | sen_err = 0 50 | for n in range(len(ref)): 51 | # update CER statistics 52 | _, (s, i, d) = levenshtein(ref[n], hyp[n]) 53 | cer_s += s 54 | cer_i += i 55 | cer_d += d 56 | cer_n += len(ref[n]) 57 | # update WER statistics 58 | _, (s, i, d) = levenshtein(ref[n].split(), hyp[n].split()) 59 | wer_s += s 60 | wer_i += i 61 | wer_d += d 62 | wer_n += len(ref[n].split()) 63 | # update SER statistics 64 | if s + i + d > 0: 65 | sen_err += 1 66 | 67 | print(cer_s, cer_i, cer_d, cer_n) 68 | return (cer_s + cer_i + cer_d) / cer_n 69 | 70 | 71 | if __name__ == '__main__': 72 | ref = ['天然气用户为优先允许限制类和禁止类'] 73 | hyp = ['天然气用户为优先允许限制类和禁止量内'] 74 | cer = cer_function(ref, hyp) 75 | print(cer) 76 |
--------------------------------------------------------------------------------
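
A note on how the utilities above fit together: the `--LFR_m`/`--LFR_n` options in `utils.py` configure Low Frame Rate processing (stack `m` consecutive feature frames, then advance by `n`), `pad_list` pads variable-length utterances into a single batch tensor, and `get_attn_pad_mask` in `transformer/utils.py` marks the padded positions so attention ignores them. The sketch below is illustrative only and not code from this repository: `lfr_stack` is a hypothetical stand-in for the repo's own LFR routine (which lives elsewhere and may handle the tail differently), and the utterance lengths are toy values. It assumes it is run from the repository root so that `utils.py` and `transformer/utils.py` are importable.

```python
import numpy as np
import torch

from utils import pad_list                        # repo helper: pads a list of (T, D) tensors
from transformer.utils import get_attn_pad_mask   # repo helper: builds the attention pad mask


def lfr_stack(feat, m=4, n=3):
    """Hypothetical Low Frame Rate transform: stack m frames, advance by n.
    feat: (T, D) numpy array -> (ceil(T / n), m * D)."""
    T, _ = feat.shape
    out = []
    for i in range(0, T, n):
        chunk = feat[i:i + m]
        if chunk.shape[0] < m:  # pad the tail by repeating the last frame
            pad = np.repeat(chunk[-1:], m - chunk.shape[0], axis=0)
            chunk = np.concatenate([chunk, pad], axis=0)
        out.append(chunk.reshape(-1))
    return np.stack(out)


# Two toy utterances of different lengths, 80-dim fbank frames (matching --d_input).
feats = [np.random.randn(37, 80).astype('float32'),
         np.random.randn(25, 80).astype('float32')]
lfr = [torch.from_numpy(lfr_stack(f, m=4, n=3)) for f in feats]  # (13, 320) and (9, 320)
lengths = torch.tensor([x.size(0) for x in lfr])

padded = pad_list(lfr, 0)  # (2, 13, 320); the shorter utterance is zero-padded
mask = get_attn_pad_mask(padded, lengths, expand_length=padded.size(1))
print(padded.shape, mask.shape)  # torch.Size([2, 13, 320]) torch.Size([2, 13, 13])
```

Positions where `mask` is 1 correspond to padded frames of the shorter utterance; the encoder's self-attention masks them out.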