├── .gitignore
├── LICENSE
├── README.md
├── README.t
├── audios
│   ├── audio_0.wav
│   ├── audio_1.wav
│   ├── audio_2.wav
│   ├── audio_3.wav
│   ├── audio_4.wav
│   ├── audio_5.wav
│   ├── audio_6.wav
│   ├── audio_7.wav
│   ├── audio_8.wav
│   └── audio_9.wav
├── char_list.pkl
├── collect_char_list.py
├── config.py
├── data_gen.py
├── demo.py
├── export.py
├── extract.py
├── ngram_lm.py
├── pre_process.py
├── replace.py
├── requirements.txt
├── results.json
├── specAugment
│   ├── __init__.py
│   ├── sparse_image_warp_np.py
│   ├── sparse_image_warp_pytorch.py
│   ├── spec_augment_pytorch.py
│   └── spec_augment_tensorflow.py
├── sponsor.jpg
├── test.py
├── test
│   ├── test_decode.py
│   ├── test_lm.py
│   ├── test_lr.py
│   ├── test_pe.py
│   ├── test_specaug.py
│   └── test_trim.py
├── test_lm.py
├── train.py
├── transformer
│   ├── __init__.py
│   ├── attention.py
│   ├── decoder.py
│   ├── encoder.py
│   ├── loss.py
│   ├── module.py
│   ├── optimizer.py
│   ├── transformer.py
│   └── utils.py
├── utils.py
└── xer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__/
3 | BEST_checkpoint.tar
4 | checkpoint.tar
5 | data/
6 | runs/
7 | nohup.out
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 刘杨
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Speech Transformer
2 |
3 | ## Introduction
4 |
5 | This is a PyTorch re-implementation of Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition.
6 |
7 | ## Dataset
8 |
9 | Aishell is an open-source Chinese Mandarin speech corpus published by Beijing Shell Shell Technology Co.,Ltd.
10 |
 11 | 400 speakers from different accent areas in China were invited to participate in the recording, which was conducted in a quiet indoor environment using high-fidelity microphones and downsampled to 16 kHz. The manual transcription accuracy is above 95%, achieved through professional speech annotation and strict quality inspection. The data is free for academic use and is intended to provide a moderate amount of data for new researchers in the field of speech recognition.
12 | ```
13 | @inproceedings{aishell_2017,
14 | title={AIShell-1: An Open-Source Mandarin Speech Corpus and A Speech Recognition Baseline},
15 | author={Hui Bu, Jiayu Du, Xingyu Na, Bengu Wu, Hao Zheng},
16 | booktitle={Oriental COCOSDA 2017},
17 | pages={Submitted},
18 | year={2017}
19 | }
20 | ```
 21 | In the data folder, download the speech data and transcripts:
22 |
23 | ```bash
24 | $ wget http://www.openslr.org/resources/33/data_aishell.tgz
25 | ```
26 |
27 | ## Performance
28 |
 29 | Evaluate with the 7,176 utterances in the Aishell test set:
30 | ```bash
31 | $ python test.py
32 | ```
33 |
34 | ## Results
35 |
36 | |Model|CER|Download|
37 | |---|---|---|
38 | |Speech Transformer|11.5|[Link](https://github.com/foamliu/Speech-Transformer/releases/download/v1.0/BEST_checkpoint.tar)|
39 |
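Here CER is the character error rate: substitutions, deletions, and insertions divided by the number of reference characters, reported in percent.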
40 | ## Dependency
41 |
42 | - Python 3.6.8
43 | - PyTorch 1.3.0
44 |
45 | ## Usage
46 | ### Data Pre-processing
47 | Extract data_aishell.tgz:
48 | ```bash
49 | $ python extract.py
50 | ```
51 |
52 | Extract wav files into train/dev/test folders:
53 | ```bash
54 | $ cd data/data_aishell/wav
55 | $ find . -name '*.tar.gz' -execdir tar -xzvf '{}' \;
56 | ```
57 |
58 | Scan transcript data, generate features:
59 | ```bash
60 | $ python pre_process.py
61 | ```
62 |
 63 | Now the folder structure under the data folder looks like this:
64 |
65 | data/
66 | data_aishell.tgz
67 | data_aishell/
68 | transcript/
69 | aishell_transcript_v0.8.txt
70 | wav/
71 | train/
72 | dev/
73 | test/
74 | aishell.pickle
75 |
76 |
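To sanity-check the generated pickle before training, a minimal sketch (field names follow `pre_process.py`):

```python
import pickle

with open('data/aishell.pickle', 'rb') as f:
    data = pickle.load(f)

print(len(data['train']), len(data['dev']), len(data['test']))
sample = data['train'][0]
print(sample['wave'])                                          # path to the source .wav file
print(''.join(data['IVOCAB'][idx] for idx in sample['trn']))   # token indices back to characters
```
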
77 | ### Train
78 | ```bash
79 | $ python train.py
80 | ```
81 |
 82 | To visualize training progress, run in your terminal:
83 | ```bash
84 | $ tensorboard --logdir runs
85 | ```
86 |
87 | ### Demo
88 | Please download the [pretrained model](https://github.com/foamliu/Speech-Transformer/releases/download/v1.0/speech-transformer-cn.pt) then run:
89 | ```bash
90 | $ python demo.py
91 | ```
92 |
 93 | It picks 10 random test examples and recognizes them, producing output like this:
94 |
95 | |Audio|Out|GT|
96 | |---|---|---|
 97 | |[audio_0.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_0.wav)|我国的经济处在爬破过凯的重要公考<br>我国的经济处在爬破过凯的重要公口<br>我国的经济处在盘破过凯的重要公考<br>我国的经济处在爬破过凯的重要公靠<br>我国的经济处在爬坡过凯的重要公考|我国的经济处在爬坡过坎的重要关口|
 98 | |[audio_1.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_1.wav)|完善主地承包经一全流市市场<br>完善主地承包经一全六市市场<br>完善主地承包经营全流市市场<br>完善主地承包经一权流市市场<br>完善主地承包经营全六市市场|完善土地承包经营权流转市场|
 99 | |[audio_2.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_2.wav)|临长各类设施使用年限<br>严长各类设施使用年限<br>延长各类设施使用年限<br>很长各类设施使用年限<br>难长各类设施使用年限|延长各类设施使用年限|
100 | |[audio_3.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_3.wav)|苹果此举是为了节约用电量<br>苹果此举是是了节约用电量<br>苹果此举是为了解约用电量<br>苹果此举是为了节约用电令<br>苹果此举只为了节约用电量|苹果此举是为了节约用电量|
101 | |[audio_4.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_4.wav)|反他们也可以有机会参与体育运动<br>让他们也可以有机会参与体育运动<br>反她们也可以有机会参与体育运动<br>范他们也可以有机会参与体育运动<br>但他们也可以有机会参与体育运动|让他们也可以有机会参与体育运动|
102 | |[audio_5.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_5.wav)|陈言希穿着粉色上衣<br>陈闫希穿着粉色上衣<br>陈延希穿着粉色上衣<br>陈言琪穿着粉色上衣<br>陈演希穿着粉色上衣|陈妍希穿着粉色上衣|
103 | |[audio_6.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_6.wav)|说起自己的伴女大下<br>说起自己的伴理大下<br>说起自己的半女大下<br>说起自己的办女大下<br>说起自己的半理大下|说起自己的伴侣大侠|
104 | |[audio_7.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_7.wav)|每日经济新闻记者注意到<br>每日经济新闻记者朱意到<br>每日经济新闻记者注一到<br>每日经济新闻记者注注到<br>每日经济新闻记者注以到|每日经济新闻记者注意到|
105 | |[audio_8.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_8.wav)|这是今年五月份以来库存环比增幅幅小了一次<br>这是今年五月份以来库存环比增幅最小了一次<br>这是今年五月份以来库存环比增幅幅小的一次<br>这是今年五月份以来库存环比增幅最小的一次<br>这是今年五月份以来库存环比增幅幅小小一次|这是今年五月份以来库存环比增幅最小的一次|
106 | |[audio_9.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_9.wav)|一个人的精使生命就将走向摔老<br>一个连的精使生命就将走向摔老<br>一个人的金使生命就将走向摔老<br>一个人的坚使生命就将走向摔老<br>一个连的金使生命就将走向摔老|一个人的精神生命就将走向衰老|
107 |
108 | ## A Small Sponsorship~
109 |
110 |
111 |
112 | If this project has helped you, feel free to offer a small sponsorship~
113 |
114 |
115 |
--------------------------------------------------------------------------------
/README.t:
--------------------------------------------------------------------------------
1 | # Speech Transformer
2 |
3 | ## Introduction
4 |
5 | This is a PyTorch re-implementation of Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition.
6 |
7 | ## Dataset
8 |
9 | Aishell is an open-source Chinese Mandarin speech corpus published by Beijing Shell Shell Technology Co.,Ltd.
10 |
 11 | 400 speakers from different accent areas in China were invited to participate in the recording, which was conducted in a quiet indoor environment using high-fidelity microphones and downsampled to 16 kHz. The manual transcription accuracy is above 95%, achieved through professional speech annotation and strict quality inspection. The data is free for academic use and is intended to provide a moderate amount of data for new researchers in the field of speech recognition.
12 | ```
13 | @inproceedings{aishell_2017,
14 | title={AIShell-1: An Open-Source Mandarin Speech Corpus and A Speech Recognition Baseline},
15 | author={Hui Bu, Jiayu Du, Xingyu Na, Bengu Wu, Hao Zheng},
16 | booktitle={Oriental COCOSDA 2017},
17 | pages={Submitted},
18 | year={2017}
19 | }
20 | ```
 21 | In the data folder, download the speech data and transcripts:
22 |
23 | ```bash
24 | $ wget http://www.openslr.org/resources/33/data_aishell.tgz
25 | ```
26 |
27 | ## Performance
28 |
 29 | Evaluate with the 7,176 utterances in the Aishell test set:
30 | ```bash
31 | $ python test.py
32 | ```
33 |
34 | ## Results
35 |
36 | |Model|CER|Download|
37 | |---|---|---|
38 | |Speech Transformer|11.5|[Link](https://github.com/foamliu/Speech-Transformer/releases/download/v1.0/BEST_checkpoint.tar)|
39 |
40 | ## Dependency
41 |
42 | - Python 3.6.8
43 | - PyTorch 1.3.0
44 |
45 | ## Usage
46 | ### Data Pre-processing
47 | Extract data_aishell.tgz:
48 | ```bash
49 | $ python extract.py
50 | ```
51 |
52 | Extract wav files into train/dev/test folders:
53 | ```bash
54 | $ cd data/data_aishell/wav
55 | $ find . -name '*.tar.gz' -execdir tar -xzvf '{}' \;
56 | ```
57 |
58 | Scan transcript data, generate features:
59 | ```bash
60 | $ python pre_process.py
61 | ```
62 |
 63 | Now the folder structure under the data folder looks like this:
64 |
65 | data/
66 | data_aishell.tgz
67 | data_aishell/
68 | transcript/
69 | aishell_transcript_v0.8.txt
70 | wav/
71 | train/
72 | dev/
73 | test/
74 | aishell.pickle
75 |
76 |
77 | ### Train
78 | ```bash
79 | $ python train.py
80 | ```
81 |
 82 | To visualize training progress, run in your terminal:
83 | ```bash
84 | $ tensorboard --logdir runs
85 | ```
86 |
87 | ### Demo
88 | Please download the [pretrained model](https://github.com/foamliu/Speech-Transformer/releases/download/v1.0/speech-transformer-cn.pt) then run:
89 | ```bash
90 | $ python demo.py
91 | ```
92 |
 93 | It picks 10 random test examples and recognizes them, producing output like this:
94 |
95 | |Audio|Out|GT|
96 | |---|---|---|
97 | |[audio_0.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_0.wav)|$(out_list_0)|$(gt_0)|
98 | |[audio_1.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_1.wav)|$(out_list_1)|$(gt_1)|
99 | |[audio_2.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_2.wav)|$(out_list_2)|$(gt_2)|
100 | |[audio_3.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_3.wav)|$(out_list_3)|$(gt_3)|
101 | |[audio_4.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_4.wav)|$(out_list_4)|$(gt_4)|
102 | |[audio_5.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_5.wav)|$(out_list_5)|$(gt_5)|
103 | |[audio_6.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_6.wav)|$(out_list_6)|$(gt_6)|
104 | |[audio_7.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_7.wav)|$(out_list_7)|$(gt_7)|
105 | |[audio_8.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_8.wav)|$(out_list_8)|$(gt_8)|
106 | |[audio_9.wav](https://github.com/foamliu/Speech-Transformer/raw/master/audios/audio_9.wav)|$(out_list_9)|$(gt_9)|
107 |
108 | ## A Small Sponsorship~
109 |
110 |
111 |
112 | If this project has helped you, feel free to offer a small sponsorship~
113 |
114 |
115 |
--------------------------------------------------------------------------------
/audios/audio_0.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_0.wav
--------------------------------------------------------------------------------
/audios/audio_1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_1.wav
--------------------------------------------------------------------------------
/audios/audio_2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_2.wav
--------------------------------------------------------------------------------
/audios/audio_3.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_3.wav
--------------------------------------------------------------------------------
/audios/audio_4.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_4.wav
--------------------------------------------------------------------------------
/audios/audio_5.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_5.wav
--------------------------------------------------------------------------------
/audios/audio_6.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_6.wav
--------------------------------------------------------------------------------
/audios/audio_7.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_7.wav
--------------------------------------------------------------------------------
/audios/audio_8.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_8.wav
--------------------------------------------------------------------------------
/audios/audio_9.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/audios/audio_9.wav
--------------------------------------------------------------------------------
/char_list.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/char_list.pkl
--------------------------------------------------------------------------------
/collect_char_list.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | from config import pickle_file
4 |
5 | if __name__ == '__main__':
6 | with open(pickle_file, 'rb') as file:
7 | data = pickle.load(file)
8 | char_list = data['IVOCAB']
9 |
10 | with open('char_list.pkl', 'wb') as file:
11 | pickle.dump(char_list, file)
12 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # sets device for model and PyTorch tensors
6 |
7 | # Model parameters
8 | input_dim = 80 # dimension of feature
9 | window_size = 25 # window size for FFT (ms)
10 | stride = 10 # window stride for FFT (ms)
11 | hidden_size = 512
12 | embedding_dim = 512
13 | cmvn = True # apply CMVN on feature
14 | num_layers = 4
15 | LFR_m = 4
16 | LFR_n = 3
17 | sample_rate = 16000 # aishell
18 |
19 | # Training parameters
20 | grad_clip = 5.  # clip gradients at an absolute value of 5
21 | print_freq = 100 # print training/validation stats every __ batches
22 | checkpoint = None # path to checkpoint, None if none
23 |
24 | # Data parameters
25 | IGNORE_ID = -1
26 | sos_id = 0
27 | eos_id = 1
28 | num_train = 120098
29 | num_dev = 14326
30 | num_test = 7176
31 | vocab_size = 4335
32 |
33 | DATA_DIR = 'data'
34 | aishell_folder = 'data/data_aishell'
35 | wav_folder = os.path.join(aishell_folder, 'wav')
36 | tran_file = os.path.join(aishell_folder, 'transcript/aishell_transcript_v0.8.txt')
37 | pickle_file = 'data/aishell.pickle'
38 |
--------------------------------------------------------------------------------
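A quick sketch of how the feature and LFR parameters above combine (following `build_LFR_features` in `data_gen.py`):

```python
from config import input_dim, LFR_m, LFR_n, stride

# after low-frame-rate stacking, each model input frame is LFR_m fbank vectors stacked together
print(input_dim * LFR_m)   # 320-dimensional stacked feature
print(stride * LFR_n)      # one stacked frame every 30 ms of audio
```
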
/data_gen.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import random
3 |
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | from torch.utils.data.dataloader import default_collate
7 |
8 | from config import pickle_file, IGNORE_ID
9 | from utils import extract_feature
10 |
11 |
12 | def pad_collate(batch):
13 | max_input_len = float('-inf')
14 | max_target_len = float('-inf')
15 |
16 | for elem in batch:
17 | feature, trn = elem
18 | max_input_len = max_input_len if max_input_len > feature.shape[0] else feature.shape[0]
19 | max_target_len = max_target_len if max_target_len > len(trn) else len(trn)
20 |
21 | for i, elem in enumerate(batch):
22 | feature, trn = elem
23 | input_length = feature.shape[0]
24 | input_dim = feature.shape[1]
25 | padded_input = np.zeros((max_input_len, input_dim), dtype=np.float32)
26 | padded_input[:input_length, :] = feature
27 | padded_target = np.pad(trn, (0, max_target_len - len(trn)), 'constant', constant_values=IGNORE_ID)
28 | batch[i] = (padded_input, padded_target, input_length)
29 |
30 | # sort it by input lengths (long to short)
31 | batch.sort(key=lambda x: x[2], reverse=True)
32 |
33 | return default_collate(batch)
34 |
35 |
36 | def build_LFR_features(inputs, m, n):
37 | """
38 | Actually, this implements stacking frames and skipping frames.
39 | if m = 1 and n = 1, just return the origin features.
40 | if m = 1 and n > 1, it works like skipping.
41 | if m > 1 and n = 1, it works like stacking but only support right frames.
42 | if m > 1 and n > 1, it works like LFR.
43 | Args:
44 | inputs_batch: inputs is T x D np.ndarray
45 | m: number of frames to stack
46 | n: number of frames to skip
47 | """
48 | # LFR_inputs_batch = []
49 | # for inputs in inputs_batch:
50 | LFR_inputs = []
51 | T = inputs.shape[0]
52 | T_lfr = int(np.ceil(T / n))
53 | for i in range(T_lfr):
54 | if m <= T - i * n:
55 | LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
56 | else: # process last LFR frame
57 | num_padding = m - (T - i * n)
58 | frame = np.hstack(inputs[i * n:])
59 | for _ in range(num_padding):
60 | frame = np.hstack((frame, inputs[-1]))
61 | LFR_inputs.append(frame)
62 | return np.vstack(LFR_inputs)
63 |
64 |
65 | # Source: https://www.kaggle.com/davids1992/specaugment-quick-implementation
66 | def spec_augment(spec: np.ndarray,
67 | num_mask=2,
68 | freq_masking=0.15,
69 | time_masking=0.20,
70 | value=0):
71 | spec = spec.copy()
72 | num_mask = random.randint(1, num_mask)
73 | for i in range(num_mask):
74 | all_freqs_num, all_frames_num = spec.shape
75 | freq_percentage = random.uniform(0.0, freq_masking)
76 |
77 | num_freqs_to_mask = int(freq_percentage * all_freqs_num)
78 | f0 = np.random.uniform(low=0.0, high=all_freqs_num - num_freqs_to_mask)
79 | f0 = int(f0)
80 | spec[f0:f0 + num_freqs_to_mask, :] = value
81 |
82 | time_percentage = random.uniform(0.0, time_masking)
83 |
84 | num_frames_to_mask = int(time_percentage * all_frames_num)
85 | t0 = np.random.uniform(low=0.0, high=all_frames_num - num_frames_to_mask)
86 | t0 = int(t0)
87 | spec[:, t0:t0 + num_frames_to_mask] = value
88 | return spec
89 |
90 |
91 | class AiShellDataset(Dataset):
92 | def __init__(self, args, split):
93 | self.args = args
94 | with open(pickle_file, 'rb') as file:
95 | data = pickle.load(file)
96 |
97 | self.samples = data[split]
98 | print('loading {} {} samples...'.format(len(self.samples), split))
99 |
100 | def __getitem__(self, i):
101 | sample = self.samples[i]
102 | wave = sample['wave']
103 | trn = sample['trn']
104 |
105 | feature = extract_feature(input_file=wave, feature='fbank', dim=self.args.d_input, cmvn=True)
106 | # zero mean and unit variance
107 | feature = (feature - feature.mean()) / feature.std()
108 | feature = spec_augment(feature)
109 | feature = build_LFR_features(feature, m=self.args.LFR_m, n=self.args.LFR_n)
110 |
111 | return feature, trn
112 |
113 | def __len__(self):
114 | return len(self.samples)
115 |
116 |
117 | if __name__ == "__main__":
118 | import torch
119 | from utils import parse_args
120 | from tqdm import tqdm
121 |
122 | args = parse_args()
123 | train_dataset = AiShellDataset(args, 'train')
124 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=args.num_workers,
125 | collate_fn=pad_collate)
126 | #
127 | # print(len(train_dataset))
128 | # print(len(train_loader))
129 | #
130 | # feature = train_dataset[10][0]
131 | # print(feature.shape)
132 | #
133 | # trn = train_dataset[10][1]
134 | # print(trn)
135 | #
136 | # with open(pickle_file, 'rb') as file:
137 | # data = pickle.load(file)
138 | # IVOCAB = data['IVOCAB']
139 | #
140 | # print([IVOCAB[idx] for idx in trn])
141 | #
142 | # for data in train_loader:
143 | # print(data)
144 | # break
145 |
146 | max_len = 0
147 |
148 | for data in tqdm(train_loader):
149 | feature = data[0]
150 | # print(feature.shape)
151 | if feature.shape[1] > max_len:
152 | max_len = feature.shape[1]
153 |
154 | print('max_len: ' + str(max_len))
155 |
--------------------------------------------------------------------------------
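A minimal sketch of the shapes produced by `build_LFR_features` above, which stacks `m` frames and advances `n` frames at a time (assuming the repo's dependencies are installed so `data_gen` imports cleanly):

```python
import numpy as np
from data_gen import build_LFR_features

feat = np.random.randn(10, 80).astype(np.float32)  # T=10 frames of 80-dim fbank features
lfr = build_LFR_features(feat, m=4, n=3)           # stack 4 frames, advance by 3
print(lfr.shape)                                   # (ceil(10 / 3), 4 * 80) == (4, 320)
```
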
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import random
4 | from shutil import copyfile
5 |
6 | import torch
7 |
8 | from config import pickle_file, device, input_dim, LFR_m, LFR_n
9 | from data_gen import build_LFR_features
10 | from transformer.transformer import Transformer
11 | from utils import extract_feature, ensure_folder
12 |
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser(
16 | "End-to-End Automatic Speech Recognition Decoding.")
17 | # decode
18 | parser.add_argument('--beam_size', default=5, type=int,
19 | help='Beam size')
20 | parser.add_argument('--nbest', default=5, type=int,
21 | help='Nbest size')
22 | parser.add_argument('--decode_max_len', default=100, type=int,
 23 |                         help='Max output length. If ==0, it uses an '
24 | 'end-detect function to automatically find maximum '
25 | 'hypothesis lengths')
26 | args = parser.parse_args()
27 | return args
28 |
29 |
30 | if __name__ == '__main__':
31 | args = parse_args()
32 | with open('char_list.pkl', 'rb') as file:
33 | char_list = pickle.load(file)
34 | with open(pickle_file, 'rb') as file:
35 | data = pickle.load(file)
36 | samples = data['test']
37 |
38 | filename = 'speech-transformer-cn.pt'
39 | print('loading model: {}...'.format(filename))
40 | model = Transformer()
41 | model.load_state_dict(torch.load(filename))
42 | model = model.to(device)
43 | model.eval()
44 |
45 | samples = random.sample(samples, 10)
46 | ensure_folder('audios')
47 | results = []
48 |
49 | for i, sample in enumerate(samples):
50 | wave = sample['wave']
51 | trn = sample['trn']
52 |
53 | copyfile(wave, 'audios/audio_{}.wav'.format(i))
54 |
55 | feature = extract_feature(input_file=wave, feature='fbank', dim=input_dim, cmvn=True)
56 | feature = build_LFR_features(feature, m=LFR_m, n=LFR_n)
57 | # feature = np.expand_dims(feature, axis=0)
58 | input = torch.from_numpy(feature).to(device)
59 | input_length = [input[0].shape[0]]
60 | input_length = torch.LongTensor(input_length).to(device)
61 | nbest_hyps = model.recognize(input, input_length, char_list, args)
62 | out_list = []
63 | for hyp in nbest_hyps:
64 | out = hyp['yseq']
65 | out = [char_list[idx] for idx in out]
66 | out = ''.join(out)
67 | out_list.append(out)
68 | print('OUT_LIST: {}'.format(out_list))
69 |
70 | gt = [char_list[idx] for idx in trn]
71 | gt = ''.join(gt)
72 | print('GT: {}\n'.format(gt))
73 |
74 | results.append({'out_list_{}'.format(i): out_list, 'gt_{}'.format(i): gt})
75 |
76 | import json
77 |
78 | with open('results.json', 'w') as file:
79 | json.dump(results, file, indent=4, ensure_ascii=False)
80 |
--------------------------------------------------------------------------------
/export.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | if __name__ == '__main__':
5 | checkpoint = 'BEST_checkpoint.tar'
6 | checkpoint = torch.load(checkpoint)
7 | model = checkpoint['model']
8 | # model.eval()
9 |
10 | torch.save(model.state_dict(), 'speech-transformer-cn.pt')
11 |
--------------------------------------------------------------------------------
/extract.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tarfile
3 |
4 |
5 | def extract(filename):
6 | print('Extracting {}...'.format(filename))
7 | tar = tarfile.open(filename, 'r')
8 | tar.extractall('data')
9 | tar.close()
10 |
11 |
12 | if __name__ == "__main__":
13 | if not os.path.isdir('data/data_aishell'):
14 | extract('data/data_aishell.tgz')
15 |
--------------------------------------------------------------------------------
/ngram_lm.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import pickle
3 |
4 | import nltk
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | from config import pickle_file
9 |
10 | with open(pickle_file, 'rb') as file:
11 | data = pickle.load(file)
12 | char_list = data['IVOCAB']
13 | vocab_size = len(char_list)
14 | samples = data['train']
15 | bigram_counter = collections.Counter()
16 |
17 | for sample in tqdm(samples):
18 | text = sample['trn']
19 | # text = [char_list[idx] for idx in text]
20 | tokens = list(text)
21 | bigrm = nltk.bigrams(tokens)
22 | # print(*map(' '.join, bigrm), sep=', ')
23 |
24 | # get the frequency of each bigram in our corpus
25 | bigram_counter.update(bigrm)
26 |
 27 | # the ten most frequent bigrams in the corpus
28 | print(bigram_counter.most_common(10))
29 |
30 | temp_dict = dict()
31 | for key, value in bigram_counter.items():
32 | temp_dict[key] = value
33 |
34 | print('smoothing and freq -> prob')
35 | bigram_freq = dict()
36 | for i in tqdm(range(vocab_size)):
37 | freq_list = []
38 | for j in range(vocab_size):
39 | if (i, j) in temp_dict:
40 | freq_list.append(temp_dict[(i, j)])
41 | else:
42 | freq_list.append(1)
43 |
44 | freq_list = np.array(freq_list)
45 | freq_list = freq_list / np.sum(freq_list)
46 |
47 | assert (len(freq_list) == vocab_size)
48 | bigram_freq[i] = freq_list
49 |
50 | print(len(bigram_freq[0]))
51 | with open('bigram_freq.pkl', 'wb') as file:
52 | pickle.dump(bigram_freq, file)
53 |
--------------------------------------------------------------------------------
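A minimal sketch of consuming the dumped `bigram_freq.pkl`; the token index below is hypothetical:

```python
import pickle

import numpy as np

with open('bigram_freq.pkl', 'rb') as f:
    bigram_freq = pickle.load(f)   # dict: previous token index -> probability vector over the vocabulary

prev_token = 42                    # hypothetical previous-token index
probs = bigram_freq[prev_token]    # add-one-smoothed P(next | prev)
print(probs.shape, int(np.argmax(probs)))
```
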
/pre_process.py:
--------------------------------------------------------------------------------
1 | # Usage : python pre_process.py --n_samples train:800,dev:100,test:100
2 | # python pre_process.py --n_samples train:800,dev:100
3 | # ...
4 |
5 |
6 | import os
7 | import pickle
8 |
9 | from tqdm import tqdm
10 |
11 | from config import wav_folder, tran_file, pickle_file
12 | from utils import ensure_folder, parse_args
13 |
14 |
15 | def get_data(split, n_samples):
16 | print('getting {} data...'.format(split))
17 |
18 | global VOCAB
19 |
20 | with open(tran_file, 'r', encoding='utf-8') as file:
21 | lines = file.readlines()
22 |
23 | tran_dict = dict()
24 | for line in lines:
25 | tokens = line.split()
26 | key = tokens[0]
27 | trn = ''.join(tokens[1:])
28 | tran_dict[key] = trn
29 |
30 | samples = []
31 |
32 | #n_samples = 5000
33 | rest = n_samples
34 |
35 | folder = os.path.join(wav_folder, split)
36 | ensure_folder(folder)
37 | dirs = [os.path.join(folder, d) for d in os.listdir(folder) if os.path.isdir(os.path.join(folder, d))]
38 | for dir in tqdm(dirs):
39 | files = [f for f in os.listdir(dir) if f.endswith('.wav')]
40 |
41 | rest = len(files) if n_samples <= 0 else rest
42 |
43 | for f in files[:rest]:
44 |
45 | wave = os.path.join(dir, f)
46 |
47 | key = f.split('.')[0]
48 |
49 | if key in tran_dict:
50 | trn = tran_dict[key]
 51 |                 trn = list(trn.strip()) + ['<eos>']
52 |
53 | for token in trn:
54 | build_vocab(token)
55 |
56 | trn = [VOCAB[token] for token in trn]
57 |
58 | samples.append({'trn': trn, 'wave': wave})
59 |
60 | rest = rest - len(files) if n_samples > 0 else rest
61 | if rest <= 0 :
62 | break
63 |
64 | print('split: {}, num_files: {}'.format(split, len(samples)))
65 | return samples
66 |
67 |
68 | def build_vocab(token):
69 | global VOCAB, IVOCAB
70 | if not token in VOCAB:
71 | next_index = len(VOCAB)
72 | VOCAB[token] = next_index
73 | IVOCAB[next_index] = token
74 |
75 |
76 | if __name__ == "__main__":
77 |
78 | # number of examples to use
79 | global args
80 | args = parse_args()
81 | tmp = args.n_samples.split(",")
82 | tmp = [a.split(":") for a in tmp]
83 | tmp = {a[0]:int(a[1]) for a in tmp}
84 | args.n_samples = {"train":-1, "dev":-1,"test":-1}
85 | args.n_samples.update(tmp)
86 |
 87 |     VOCAB = {'<sos>': 0, '<eos>': 1}
 88 |     IVOCAB = {0: '<sos>', 1: '<eos>'}
89 |
90 | data = dict()
91 | data['VOCAB'] = VOCAB
92 | data['IVOCAB'] = IVOCAB
93 | data['train'] = get_data('train', args.n_samples["train"])
94 | data['dev'] = get_data('dev', args.n_samples["dev"])
95 | data['test'] = get_data('test', args.n_samples["test"])
96 |
97 | with open(pickle_file, 'wb') as file:
98 | pickle.dump(data, file)
99 |
100 | print('num_train: ' + str(len(data['train'])))
101 | print('num_dev: ' + str(len(data['dev'])))
102 | print('num_test: ' + str(len(data['test'])))
103 | print('vocab_size: ' + str(len(data['VOCAB'])))
104 |
--------------------------------------------------------------------------------
/replace.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 |
4 | if __name__ == '__main__':
5 | with open('README.t', 'r', encoding="utf-8") as file:
6 | text = file.readlines()
7 | text = ''.join(text)
8 |
9 | with open('results.json', 'r', encoding="utf-8") as file:
10 | results = json.load(file)
11 |
12 | print(results[0])
13 |
14 | for i, result in enumerate(results):
15 | out_key = 'out_list_{}'.format(i)
16 |         text = text.replace('$({})'.format(out_key), '<br>'.join(result[out_key]))  # <br> keeps each n-best list inside a single markdown table cell
17 | gt_key = 'gt_{}'.format(i)
18 | text = text.replace('$({})'.format(gt_key), result[gt_key])
19 |
20 |     text = text.replace('<sos>', '')
21 |     text = text.replace('<eos>', '')
22 |
23 | with open('README.md', 'w', encoding="utf-8") as file:
24 | file.write(text)
25 |
--------------------------------------------------------------------------------
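Taken together with `demo.py`, the README pipeline is: `python demo.py` writes the n-best hypotheses and ground truths to `results.json`, then `python replace.py` substitutes the `$(out_list_i)` / `$(gt_i)` placeholders in `README.t` to regenerate `README.md`.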
/requirements.txt:
--------------------------------------------------------------------------------
1 | librosa
--------------------------------------------------------------------------------
/results.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "out_list_0": [
4 | "我国的经济处在爬破过凯的重要公考",
5 | "我国的经济处在爬破过凯的重要公口",
6 | "我国的经济处在盘破过凯的重要公考",
7 | "我国的经济处在爬破过凯的重要公靠",
8 | "我国的经济处在爬坡过凯的重要公考"
9 | ],
10 | "gt_0": "我国的经济处在爬坡过坎的重要关口"
11 | },
12 | {
13 | "out_list_1": [
14 | "完善主地承包经一全流市市场",
15 | "完善主地承包经一全六市市场",
16 | "完善主地承包经营全流市市场",
17 | "完善主地承包经一权流市市场",
18 | "完善主地承包经营全六市市场"
19 | ],
20 | "gt_1": "完善土地承包经营权流转市场"
21 | },
22 | {
23 | "out_list_2": [
24 | "临长各类设施使用年限",
25 | "严长各类设施使用年限",
26 | "延长各类设施使用年限",
27 | "很长各类设施使用年限",
28 | "难长各类设施使用年限"
29 | ],
30 | "gt_2": "延长各类设施使用年限"
31 | },
32 | {
33 | "out_list_3": [
34 | "苹果此举是为了节约用电量",
35 | "苹果此举是是了节约用电量",
36 | "苹果此举是为了解约用电量",
37 | "苹果此举是为了节约用电令",
38 | "苹果此举只为了节约用电量"
39 | ],
40 | "gt_3": "苹果此举是为了节约用电量"
41 | },
42 | {
43 | "out_list_4": [
44 | "反他们也可以有机会参与体育运动",
45 | "让他们也可以有机会参与体育运动",
46 | "反她们也可以有机会参与体育运动",
47 | "范他们也可以有机会参与体育运动",
48 | "但他们也可以有机会参与体育运动"
49 | ],
50 | "gt_4": "让他们也可以有机会参与体育运动"
51 | },
52 | {
53 | "out_list_5": [
54 | "陈言希穿着粉色上衣",
55 | "陈闫希穿着粉色上衣",
56 | "陈延希穿着粉色上衣",
57 | "陈言琪穿着粉色上衣",
58 | "陈演希穿着粉色上衣"
59 | ],
60 | "gt_5": "陈妍希穿着粉色上衣"
61 | },
62 | {
63 | "out_list_6": [
64 | "说起自己的伴女大下",
65 | "说起自己的伴理大下",
66 | "说起自己的半女大下",
67 | "说起自己的办女大下",
68 | "说起自己的半理大下"
69 | ],
70 | "gt_6": "说起自己的伴侣大侠"
71 | },
72 | {
73 | "out_list_7": [
74 | "每日经济新闻记者注意到",
75 | "每日经济新闻记者朱意到",
76 | "每日经济新闻记者注一到",
77 | "每日经济新闻记者注注到",
78 | "每日经济新闻记者注以到"
79 | ],
80 | "gt_7": "每日经济新闻记者注意到"
81 | },
82 | {
83 | "out_list_8": [
84 | "这是今年五月份以来库存环比增幅幅小了一次",
85 | "这是今年五月份以来库存环比增幅最小了一次",
86 | "这是今年五月份以来库存环比增幅幅小的一次",
87 | "这是今年五月份以来库存环比增幅最小的一次",
88 | "这是今年五月份以来库存环比增幅幅小小一次"
89 | ],
90 | "gt_8": "这是今年五月份以来库存环比增幅最小的一次"
91 | },
92 | {
93 | "out_list_9": [
94 | "一个人的精使生命就将走向摔老",
95 | "一个连的精使生命就将走向摔老",
96 | "一个人的金使生命就将走向摔老",
97 | "一个人的坚使生命就将走向摔老",
98 | "一个连的金使生命就将走向摔老"
99 | ],
100 | "gt_9": "一个人的精神生命就将走向衰老"
101 | }
102 | ]
--------------------------------------------------------------------------------
/specAugment/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/specAugment/__init__.py
--------------------------------------------------------------------------------
/specAugment/sparse_image_warp_np.py:
--------------------------------------------------------------------------------
1 | """Image warping using sparse flow defined at control points."""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import numpy as np
7 | import scipy as sp
8 | from scipy.interpolate import interp2d
9 | from skimage.transform import warp
10 |
11 |
12 | def _get_grid_locations(image_height, image_width):
13 | """Wrapper for np.meshgrid."""
14 |
15 | y_range = np.linspace(0, image_height - 1, image_height)
16 | x_range = np.linspace(0, image_width - 1, image_width)
17 | y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij')
18 | return np.stack((y_grid, x_grid), -1)
19 |
20 |
21 | def _expand_to_minibatch(np_array, batch_size):
22 | """Tile arbitrarily-sized np_array to include new batch dimension."""
23 | tiles = [batch_size] + [1] * np_array.ndim
24 | return np.tile(np.expand_dims(np_array, 0), tiles)
25 |
26 |
27 | def _get_boundary_locations(image_height, image_width, num_points_per_edge):
28 | """Compute evenly-spaced indices along edge of image."""
29 | y_range = np.linspace(0, image_height - 1, num_points_per_edge + 2)
30 | x_range = np.linspace(0, image_width - 1, num_points_per_edge + 2)
31 | ys, xs = np.meshgrid(y_range, x_range, indexing='ij')
32 | is_boundary = np.logical_or(
33 | np.logical_or(xs == 0, xs == image_width - 1),
34 | np.logical_or(ys == 0, ys == image_height - 1))
35 | return np.stack([ys[is_boundary], xs[is_boundary]], axis=-1)
36 |
37 |
38 | def _add_zero_flow_controls_at_boundary(control_point_locations,
39 | control_point_flows, image_height,
40 | image_width, boundary_points_per_edge):
41 | # batch_size = tensor_shape.dimension_value(control_point_locations.shape[0])
42 | batch_size = control_point_locations.shape[0]
43 |
44 | boundary_point_locations = _get_boundary_locations(image_height, image_width,
45 | boundary_points_per_edge)
46 |
47 | boundary_point_flows = np.zeros([boundary_point_locations.shape[0], 2])
48 |
49 | type_to_use = control_point_locations.dtype
50 | # boundary_point_locations = constant_op.constant(
51 | # _expand_to_minibatch(boundary_point_locations, batch_size),
52 | # dtype=type_to_use)
53 | boundary_point_locations = _expand_to_minibatch(boundary_point_locations, batch_size)
54 |
55 | # boundary_point_flows = constant_op.constant(
56 | # _expand_to_minibatch(boundary_point_flows, batch_size), dtype=type_to_use)
57 | boundary_point_flows = _expand_to_minibatch(boundary_point_flows, batch_size)
58 |
59 | # merged_control_point_locations = array_ops.concat(
60 | # [control_point_locations, boundary_point_locations], 1)
61 |
62 | merged_control_point_locations = np.concatenate(
63 | [control_point_locations, boundary_point_locations], 1)
64 |
65 | # merged_control_point_flows = array_ops.concat(
66 | # [control_point_flows, boundary_point_flows], 1)
67 |
68 | merged_control_point_flows = np.concatenate(
69 | [control_point_flows, boundary_point_flows], 1)
70 |
71 | return merged_control_point_locations, merged_control_point_flows
72 |
73 |
74 | def sparse_image_warp_np(image,
75 | source_control_point_locations,
76 | dest_control_point_locations,
77 | interpolation_order=2,
78 | regularization_weight=0.0,
79 | num_boundary_points=0):
80 | # image = ops.convert_to_tensor(image)
81 | # source_control_point_locations = ops.convert_to_tensor(
82 | # source_control_point_locations)
83 | # dest_control_point_locations = ops.convert_to_tensor(
84 | # dest_control_point_locations)
85 |
86 | control_point_flows = (
87 | dest_control_point_locations - source_control_point_locations)
88 |
89 | clamp_boundaries = num_boundary_points > 0
90 | boundary_points_per_edge = num_boundary_points - 1
91 |
92 | # batch_size, image_height, image_width, _ = image.get_shape().as_list()
93 | batch_size, image_height, image_width, _ = list(image.shape)
94 |
95 | # This generates the dense locations where the interpolant
96 | # will be evaluated.
97 |
98 | grid_locations = _get_grid_locations(image_height, image_width)
99 |
100 | flattened_grid_locations = np.reshape(grid_locations,
101 | [image_height * image_width, 2])
102 |
103 | # flattened_grid_locations = constant_op.constant(
104 | # _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)
105 |
106 | flattened_grid_locations = _expand_to_minibatch(flattened_grid_locations, batch_size)
107 |
108 | if clamp_boundaries:
109 | (dest_control_point_locations,
110 | control_point_flows) = _add_zero_flow_controls_at_boundary(
111 | dest_control_point_locations, control_point_flows, image_height,
112 | image_width, boundary_points_per_edge)
113 |
114 | # flattened_flows = interpolate_spline.interpolate_spline(
115 | # dest_control_point_locations, control_point_flows,
116 | # flattened_grid_locations, interpolation_order, regularization_weight)
117 |     flattened_flows = sp.interpolate.spline(  # note: scipy.interpolate.spline only exists in older SciPy releases
118 | dest_control_point_locations, control_point_flows,
119 | flattened_grid_locations, interpolation_order, regularization_weight)
120 |
121 | # dense_flows = array_ops.reshape(flattened_flows,
122 | # [batch_size, image_height, image_width, 2])
123 | dense_flows = np.reshape(flattened_flows,
124 | [batch_size, image_height, image_width, 2])
125 |
126 | # warped_image = dense_image_warp.dense_image_warp(image, dense_flows)
127 | warped_image = warp(image, dense_flows)
128 |
129 | return warped_image, dense_flows
130 |
131 |
132 | def dense_image_warp(image, flow):
133 | # batch_size, height, width, channels = (array_ops.shape(image)[0],
134 | # array_ops.shape(image)[1],
135 | # array_ops.shape(image)[2],
136 | # array_ops.shape(image)[3])
137 | batch_size, height, width, channels = (np.shape(image)[0],
138 | np.shape(image)[1],
139 | np.shape(image)[2],
140 | np.shape(image)[3])
141 |
142 | # The flow is defined on the image grid. Turn the flow into a list of query
143 | # points in the grid space.
144 | # grid_x, grid_y = array_ops.meshgrid(
145 | # math_ops.range(width), math_ops.range(height))
146 | # stacked_grid = math_ops.cast(
147 | # array_ops.stack([grid_y, grid_x], axis=2), flow.dtype)
148 | # batched_grid = array_ops.expand_dims(stacked_grid, axis=0)
149 | # query_points_on_grid = batched_grid - flow
150 | # query_points_flattened = array_ops.reshape(query_points_on_grid,
151 | # [batch_size, height * width, 2])
152 |     grid_x, grid_y = np.meshgrid(
153 |         np.arange(width), np.arange(height))
154 |     stacked_grid = np.stack(
155 |         [grid_y, grid_x], axis=2).astype(flow.dtype)
156 | batched_grid = np.expand_dims(stacked_grid, axis=0)
157 | query_points_on_grid = batched_grid - flow
158 | query_points_flattened = np.reshape(query_points_on_grid,
159 | [batch_size, height * width, 2])
160 | # Compute values at the query points, then reshape the result back to the
161 | # image grid.
162 | interpolated = interp2d(image, query_points_flattened)
163 | interpolated = np.reshape(interpolated,
164 | [batch_size, height, width, channels])
165 | return interpolated
166 |
--------------------------------------------------------------------------------
/specAugment/sparse_image_warp_pytorch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 RnD at Spoon Radio
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | # import torch
16 | # import numpy as np
17 | # from torch.autograd import Variable
18 | # import librosa
19 | import random
20 |
21 | import numpy as np
22 | # import scipy.signal
23 | import torch
24 |
25 |
26 | # import torchaudio
27 | # from torchaudio import transforms
28 | # import math
29 | # from torch.utils.data import DataLoader
30 | # from torch.utils.data import Dataset
31 |
32 |
33 | def time_warp(spec, W=5):
34 | spec = spec.view(1, spec.shape[0], spec.shape[1])
35 | num_rows = spec.shape[1]
36 | spec_len = spec.shape[2]
37 |
38 | y = num_rows // 2
39 | horizontal_line_at_ctr = spec[0][y]
40 | assert len(horizontal_line_at_ctr) == spec_len
41 |
42 | point_to_warp = horizontal_line_at_ctr[random.randrange(W, spec_len - W)]
43 | assert isinstance(point_to_warp, torch.Tensor)
44 |
45 | # Uniform distribution from (0,W) with chance to be up to W negative
46 | dist_to_warp = random.randrange(-W, W)
47 | src_pts, dest_pts = torch.tensor([[[y, point_to_warp]]]), torch.tensor([[[y, point_to_warp + dist_to_warp]]])
 48 |     warped_spectro, dense_flows = sparse_image_warp(spec, src_pts, dest_pts)  # module-level function defined below
49 | return warped_spectro.squeeze(3)
50 |
51 |
52 | def freq_mask(spec, F=15, num_masks=1, replace_with_zero=False):
53 | cloned = spec.clone()
54 | num_mel_channels = cloned.shape[1]
55 |
56 | for i in range(0, num_masks):
57 | f = random.randrange(0, F)
58 | f_zero = random.randrange(0, num_mel_channels - f)
59 |
60 | # avoids randrange error if values are equal and range is empty
61 | if (f_zero == f_zero + f): return cloned
62 |
63 | mask_end = random.randrange(f_zero, f_zero + f)
64 | if (replace_with_zero):
65 | cloned[0][f_zero:mask_end] = 0
66 | else:
67 | cloned[0][f_zero:mask_end] = cloned.mean()
68 |
69 | return cloned
70 |
71 |
72 | def time_mask(spec, T=15, num_masks=1, replace_with_zero=False):
73 | cloned = spec.clone()
74 | len_spectro = cloned.shape[2]
75 |
76 | for i in range(0, num_masks):
77 | t = random.randrange(0, T)
78 | t_zero = random.randrange(0, len_spectro - t)
79 |
80 | # avoids randrange error if values are equal and range is empty
81 | if (t_zero == t_zero + t): return cloned
82 |
83 | mask_end = random.randrange(t_zero, t_zero + t)
84 | if (replace_with_zero):
85 | cloned[0][:, t_zero:mask_end] = 0
86 | else:
87 | cloned[0][:, t_zero:mask_end] = cloned.mean()
88 | return cloned
89 |
90 |
91 | def sparse_image_warp(img_tensor,
92 | source_control_point_locations,
93 | dest_control_point_locations,
94 | interpolation_order=2,
95 | regularization_weight=0.0,
96 | num_boundaries_points=0):
97 | control_point_flows = (dest_control_point_locations - source_control_point_locations)
98 |
99 | batch_size, image_height, image_width = img_tensor.shape
100 | grid_locations = get_grid_locations(image_height, image_width)
101 | flattened_grid_locations = torch.tensor(flatten_grid_locations(grid_locations, image_height, image_width))
102 |
103 | flattened_flows = interpolate_spline(
104 | dest_control_point_locations,
105 | control_point_flows,
106 | flattened_grid_locations,
107 | interpolation_order,
108 | regularization_weight)
109 |
110 | dense_flows = create_dense_flows(flattened_flows, batch_size, image_height, image_width)
111 |
112 | warped_image = dense_image_warp(img_tensor, dense_flows)
113 |
114 | return warped_image, dense_flows
115 |
116 |
117 | def get_grid_locations(image_height, image_width):
118 | """Wrapper for np.meshgrid."""
119 |
120 | y_range = np.linspace(0, image_height - 1, image_height)
121 | x_range = np.linspace(0, image_width - 1, image_width)
122 | y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij')
123 | return np.stack((y_grid, x_grid), -1)
124 |
125 |
126 | def flatten_grid_locations(grid_locations, image_height, image_width):
127 | return np.reshape(grid_locations, [image_height * image_width, 2])
128 |
129 |
130 | def create_dense_flows(flattened_flows, batch_size, image_height, image_width):
131 | # possibly .view
132 | return torch.reshape(flattened_flows, [batch_size, image_height, image_width, 2])
133 |
134 |
135 | def interpolate_spline(train_points, train_values, query_points, order, regularization_weight=0.0, ):
136 | # First, fit the spline to the observed data.
137 | w, v = solve_interpolation(train_points, train_values, order, regularization_weight)
138 | # Then, evaluate the spline at the query locations.
139 | query_values = apply_interpolation(query_points, train_points, w, v, order)
140 |
141 | return query_values
142 |
143 |
144 | def solve_interpolation(train_points, train_values, order, regularization_weight):
145 | b, n, d = train_points.shape
146 | k = train_values.shape[-1]
147 |
148 | # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
149 | # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
150 | # To account for python style guidelines we use
151 | # matrix_a for A and matrix_b for B.
152 |
153 | c = train_points
154 | f = train_values.float()
155 |
156 | matrix_a = phi(cross_squared_distance_matrix(c, c), order).unsqueeze(0) # [b, n, n]
157 | # if regularization_weight > 0:
158 | # batch_identity_matrix = array_ops.expand_dims(
159 | # linalg_ops.eye(n, dtype=c.dtype), 0)
160 | # matrix_a += regularization_weight * batch_identity_matrix
161 |
162 | # Append ones to the feature values for the bias term in the linear model.
163 | ones = torch.ones(1, dtype=train_points.dtype).view([-1, 1, 1])
164 | matrix_b = torch.cat((c, ones), 2).float() # [b, n, d + 1]
165 |
166 | # [b, n + d + 1, n]
167 | left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1)
168 |
169 | num_b_cols = matrix_b.shape[2] # d + 1
170 |
171 | # In Tensorflow, zeros are used here. Pytorch gesv fails with zeros for some reason we don't understand.
172 | # So instead we use very tiny randn values (variance of one, zero mean) on one side of our multiplication.
173 | lhs_zeros = torch.randn((b, num_b_cols, num_b_cols)) / 1e10
174 | right_block = torch.cat((matrix_b, lhs_zeros),
175 | 1) # [b, n + d + 1, d + 1]
176 | lhs = torch.cat((left_block, right_block),
177 | 2) # [b, n + d + 1, n + d + 1]
178 |
179 | rhs_zeros = torch.zeros((b, d + 1, k), dtype=train_points.dtype).float()
180 | rhs = torch.cat((f, rhs_zeros), 1) # [b, n + d + 1, k]
181 |
182 | # Then, solve the linear system and unpack the results.
183 | X, LU = torch.solve(rhs, lhs)
184 | w = X[:, :n, :]
185 | v = X[:, n:, :]
186 |
187 | return w, v
188 |
189 |
190 | def cross_squared_distance_matrix(x, y):
191 | """Pairwise squared distance between two (batch) matrices' rows (2nd dim).
192 | Computes the pairwise distances between rows of x and rows of y
193 | Args:
194 | x: [batch_size, n, d] float `Tensor`
195 | y: [batch_size, m, d] float `Tensor`
196 | Returns:
197 | squared_dists: [batch_size, n, m] float `Tensor`, where
198 | squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
199 | """
200 | x_norm_squared = torch.sum(torch.mul(x, x))
201 | y_norm_squared = torch.sum(torch.mul(y, y))
202 |
203 | x_y_transpose = torch.matmul(x.squeeze(0), y.squeeze(0).transpose(0, 1))
204 |
205 | # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj
206 | squared_dists = x_norm_squared - 2 * x_y_transpose + y_norm_squared
207 |
208 | return squared_dists.float()
209 |
210 |
211 | def phi(r, order):
212 | """Coordinate-wise nonlinearity used to define the order of the interpolation.
213 | See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
214 | Args:
215 | r: input op
216 | order: interpolation order
217 | Returns:
218 | phi_k evaluated coordinate-wise on r, for k = r
219 | """
220 | EPSILON = torch.tensor(1e-10)
221 |     # using EPSILON prevents log(0), sqrt(0), etc.
222 | # sqrt(0) is well-defined, but its gradient is not
223 | if order == 1:
224 | r = torch.max(r, EPSILON)
225 | r = torch.sqrt(r)
226 | return r
227 | elif order == 2:
228 | return 0.5 * r * torch.log(torch.max(r, EPSILON))
229 | elif order == 4:
230 | return 0.5 * torch.square(r) * torch.log(torch.max(r, EPSILON))
231 | elif order % 2 == 0:
232 | r = torch.max(r, EPSILON)
233 | return 0.5 * torch.pow(r, 0.5 * order) * torch.log(r)
234 | else:
235 | r = torch.max(r, EPSILON)
236 | return torch.pow(r, 0.5 * order)
237 |
238 |
239 | def apply_interpolation(query_points, train_points, w, v, order):
240 | """Apply polyharmonic interpolation model to data.
241 | Given coefficients w and v for the interpolation model, we evaluate
242 | interpolated function values at query_points.
243 | Args:
244 | query_points: `[b, m, d]` x values to evaluate the interpolation at
245 | train_points: `[b, n, d]` x values that act as the interpolation centers
246 | ( the c variables in the wikipedia article)
247 | w: `[b, n, k]` weights on each interpolation center
248 | v: `[b, d, k]` weights on each input dimension
249 | order: order of the interpolation
250 | Returns:
251 | Polyharmonic interpolation evaluated at points defined in query_points.
252 | """
253 | query_points = query_points.unsqueeze(0)
254 | # First, compute the contribution from the rbf term.
255 | pairwise_dists = cross_squared_distance_matrix(query_points.float(), train_points.float())
256 | phi_pairwise_dists = phi(pairwise_dists, order)
257 |
258 | rbf_term = torch.matmul(phi_pairwise_dists, w)
259 |
260 | # Then, compute the contribution from the linear term.
261 | # Pad query_points with ones, for the bias term in the linear model.
262 | ones = torch.ones_like(query_points[..., :1])
263 | query_points_pad = torch.cat((
264 | query_points,
265 | ones
266 | ), 2).float()
267 | linear_term = torch.matmul(query_points_pad, v)
268 |
269 | return rbf_term + linear_term
270 |
271 |
272 | def dense_image_warp(image, flow):
273 | """Image warping using per-pixel flow vectors.
274 | Apply a non-linear warp to the image, where the warp is specified by a dense
275 | flow field of offset vectors that define the correspondences of pixel values
276 | in the output image back to locations in the source image. Specifically, the
277 | pixel value at output[b, j, i, c] is
278 | images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c].
279 | The locations specified by this formula do not necessarily map to an int
280 | index. Therefore, the pixel value is obtained by bilinear
281 | interpolation of the 4 nearest pixels around
282 | (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside
283 | of the image, we use the nearest pixel values at the image boundary.
284 | Args:
285 | image: 4-D float `Tensor` with shape `[batch, height, width, channels]`.
286 | flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`.
287 | name: A name for the operation (optional).
288 | Note that image and flow can be of type tf.half, tf.float32, or tf.float64,
289 | and do not necessarily have to be the same type.
290 | Returns:
291 | A 4-D float `Tensor` with shape`[batch, height, width, channels]`
292 | and same type as input image.
293 | Raises:
294 | ValueError: if height < 2 or width < 2 or the inputs have the wrong number
295 | of dimensions.
296 | """
297 | image = image.unsqueeze(3) # add a single channel dimension to image tensor
298 | batch_size, height, width, channels = image.shape
299 |
300 | # The flow is defined on the image grid. Turn the flow into a list of query
301 | # points in the grid space.
302 | grid_x, grid_y = torch.meshgrid(
303 | torch.arange(width), torch.arange(height))
304 |
305 | stacked_grid = torch.stack((grid_y, grid_x), dim=2).float()
306 |
307 | batched_grid = stacked_grid.unsqueeze(-1).permute(3, 1, 0, 2)
308 |
309 | query_points_on_grid = batched_grid - flow
310 | query_points_flattened = torch.reshape(query_points_on_grid,
311 | [batch_size, height * width, 2])
312 | # Compute values at the query points, then reshape the result back to the
313 | # image grid.
314 | interpolated = interpolate_bilinear(image, query_points_flattened)
315 | interpolated = torch.reshape(interpolated,
316 | [batch_size, height, width, channels])
317 | return interpolated
318 |
319 |
320 | def interpolate_bilinear(grid,
321 | query_points,
322 | name='interpolate_bilinear',
323 | indexing='ij'):
324 | """Similar to Matlab's interp2 function.
325 | Finds values for query points on a grid using bilinear interpolation.
326 | Args:
327 | grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
328 | query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`.
329 | name: a name for the operation (optional).
330 | indexing: whether the query points are specified as row and column (ij),
331 | or Cartesian coordinates (xy).
332 | Returns:
333 | values: a 3-D `Tensor` with shape `[batch, N, channels]`
334 | Raises:
335 | ValueError: if the indexing mode is invalid, or if the shape of the inputs
336 | invalid.
337 | """
338 | if indexing != 'ij' and indexing != 'xy':
339 | raise ValueError('Indexing mode must be \'ij\' or \'xy\'')
340 |
341 | shape = grid.shape
342 | if len(shape) != 4:
343 | msg = 'Grid must be 4 dimensional. Received size: '
344 | raise ValueError(msg + str(grid.shape))
345 |
346 | batch_size, height, width, channels = grid.shape
347 |
348 | shape = [batch_size, height, width, channels]
349 | query_type = query_points.dtype
350 | grid_type = grid.dtype
351 |
352 | num_queries = query_points.shape[1]
353 |
354 | alphas = []
355 | floors = []
356 | ceils = []
357 | index_order = [0, 1] if indexing == 'ij' else [1, 0]
358 | unstacked_query_points = query_points.unbind(2)
359 |
360 | for dim in index_order:
361 | queries = unstacked_query_points[dim]
362 |
363 | size_in_indexing_dimension = shape[dim + 1]
364 |
365 | # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1
366 | # is still a valid index into the grid.
367 | max_floor = torch.tensor(size_in_indexing_dimension - 2, dtype=query_type)
368 | min_floor = torch.tensor(0.0, dtype=query_type)
369 | maxx = torch.max(min_floor, torch.floor(queries))
370 | floor = torch.min(maxx, max_floor)
371 | int_floor = floor.long()
372 | floors.append(int_floor)
373 | ceil = int_floor + 1
374 | ceils.append(ceil)
375 |
376 | # alpha has the same type as the grid, as we will directly use alpha
377 | # when taking linear combinations of pixel values from the image.
378 | alpha = torch.tensor(queries - floor, dtype=grid_type)
379 | min_alpha = torch.tensor(0.0, dtype=grid_type)
380 | max_alpha = torch.tensor(1.0, dtype=grid_type)
381 | alpha = torch.min(torch.max(min_alpha, alpha), max_alpha)
382 |
383 | # Expand alpha to [b, n, 1] so we can use broadcasting
384 | # (since the alpha values don't depend on the channel).
385 | alpha = torch.unsqueeze(alpha, 2)
386 | alphas.append(alpha)
387 |
388 | flattened_grid = torch.reshape(
389 | grid, [batch_size * height * width, channels])
390 | batch_offsets = torch.reshape(
391 | torch.arange(batch_size) * height * width, [batch_size, 1])
392 |
393 | # This wraps array_ops.gather. We reshape the image data such that the
394 | # batch, y, and x coordinates are pulled into the first dimension.
395 | # Then we gather. Finally, we reshape the output back. It's possible this
396 | # code would be made simpler by using array_ops.gather_nd.
397 | def gather(y_coords, x_coords, name):
398 | linear_coordinates = batch_offsets + y_coords * width + x_coords
399 | gathered_values = torch.gather(flattened_grid.t(), 1, linear_coordinates)
400 | return torch.reshape(gathered_values,
401 | [batch_size, num_queries, channels])
402 |
403 | # grab the pixel values in the 4 corners around each query point
404 | top_left = gather(floors[0], floors[1], 'top_left')
405 | top_right = gather(floors[0], ceils[1], 'top_right')
406 | bottom_left = gather(ceils[0], floors[1], 'bottom_left')
407 | bottom_right = gather(ceils[0], ceils[1], 'bottom_right')
408 |
409 | interp_top = alphas[1] * (top_right - top_left) + top_left
410 | interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
411 | interp = alphas[0] * (interp_bottom - interp_top) + interp_top
412 |
413 | return interp
414 |
--------------------------------------------------------------------------------
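
The gather-based indexing in interpolate_bilinear above only reshapes back correctly when the grid has a single channel, which is the case for the spectrograms this repo pushes through sparse_image_warp. A small sanity check (a sketch, assuming it is run from the repo root so the specAugment package is importable):

```python
import torch

from specAugment.sparse_image_warp_pytorch import interpolate_bilinear

# 1 x 4 x 4 grid with a single channel, values 0..15 laid out row by row.
grid = torch.arange(16.0).reshape(1, 4, 4, 1)

# One query point at (row=1.5, col=2.5) in the default 'ij' indexing.
query_points = torch.tensor([[[1.5, 2.5]]])

values = interpolate_bilinear(grid, query_points)
print(values.shape)   # torch.Size([1, 1, 1])
print(values.item())  # 8.5, the bilinear blend of the 4 surrounding pixels (6, 7, 10, 11)
```
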
/specAugment/spec_augment_pytorch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 RnD at Spoon Radio
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """specAugment Implementation for Tensorflow.
16 | Related paper : https://arxiv.org/pdf/1904.08779.pdf
17 | In this paper, show summarized parameters by each open datasets in Tabel 1.
18 | -----------------------------------------
19 | Policy | W | F | m_F | T | p | m_T
20 | -----------------------------------------
21 | None | 0 | 0 | - | 0 | - | -
22 | -----------------------------------------
23 | LB | 80 | 27 | 1 | 100 | 1.0 | 1
24 | -----------------------------------------
25 | LD | 80 | 27 | 2 | 100 | 1.0 | 2
26 | -----------------------------------------
27 | SM | 40 | 15 | 2 | 70 | 0.2 | 2
28 | -----------------------------------------
29 | SS | 40 | 27 | 2 | 70 | 0.2 | 2
30 | -----------------------------------------
31 | LB : LibriSpeech basic
32 | LD : LibriSpeech double
33 | SM : Switchboard mild
34 | SS : Switchboard strong
35 | """
36 |
37 | import random
38 |
39 | import librosa
40 | import librosa.display
41 | import matplotlib
42 | import numpy as np
43 |
44 | matplotlib.use('TkAgg')
45 | import matplotlib.pyplot as plt
46 | from specAugment.sparse_image_warp_np import sparse_image_warp_np
47 | import torch
48 |
49 |
50 | def time_warp(spec, W=5):
51 | num_rows = spec.shape[1]
52 | spec_len = spec.shape[2]
53 |
54 | y = num_rows // 2
55 | horizontal_line_at_ctr = spec[0][y]
56 | # assert len(horizontal_line_at_ctr) == spec_len
57 |
58 | point_to_warp = horizontal_line_at_ctr[random.randrange(W, spec_len - W)]
59 | # assert isinstance(point_to_warp, torch.Tensor)
60 |
61 |     # Warp distance drawn uniformly from [-W, W)
62 | dist_to_warp = random.randrange(-W, W)
63 | src_pts = torch.tensor([[[y, point_to_warp]]])
64 | dest_pts = torch.tensor([[[y, point_to_warp + dist_to_warp]]])
65 | warped_spectro, dense_flows = sparse_image_warp_np(spec, src_pts, dest_pts)
66 | return warped_spectro.squeeze(3)
67 |
68 |
69 | def spec_augment(mel_spectrogram, time_warping_para=80, frequency_masking_para=27,
70 | time_masking_para=100, frequency_mask_num=1, time_mask_num=1):
71 |     """Spec augmentation calculation function.
72 |     'SpecAugment' has 3 steps for audio data augmentation:
73 |     the first step is time warping using a sparse image warp,
74 |     the second step is frequency masking, and the last step is time masking.
75 |     # Arguments:
76 |       mel_spectrogram(numpy array): mel spectrogram you want to warp and mask.
77 |       time_warping_para(float): Augmentation parameter, "time warp parameter W".
78 |         If None, default = 80 for LibriSpeech.
79 |       frequency_masking_para(float): Augmentation parameter, "frequency mask parameter F"
80 |         If None, default = 27 for LibriSpeech.
81 |       time_masking_para(float): Augmentation parameter, "time mask parameter T"
82 |         If None, default = 100 for LibriSpeech.
83 |       frequency_mask_num(float): number of frequency masking lines, "m_F".
84 |         If None, default = 1 for LibriSpeech.
85 |       time_mask_num(float): number of time masking lines, "m_T".
86 |         If None, default = 1 for LibriSpeech.
87 |     # Returns
88 |       mel_spectrogram(numpy array): warped and masked mel spectrogram.
89 |     """
90 | v = mel_spectrogram.shape[1]
91 | tau = mel_spectrogram.shape[2]
92 |
93 | # Step 1 : Time warping
94 | warped_mel_spectrogram = time_warp(mel_spectrogram)
95 |
96 | # Step 2 : Frequency masking
97 | for i in range(frequency_mask_num):
98 | f = np.random.uniform(low=0.0, high=frequency_masking_para)
99 | f = int(f)
100 | f0 = random.randint(0, v - f)
101 | warped_mel_spectrogram[:, f0:f0 + f, :] = 0
102 |
103 | # Step 3 : Time masking
104 | for i in range(time_mask_num):
105 | t = np.random.uniform(low=0.0, high=time_masking_para)
106 | t = int(t)
107 | t0 = random.randint(0, tau - t)
108 | warped_mel_spectrogram[:, :, t0:t0 + t] = 0
109 |
110 | return warped_mel_spectrogram
111 |
112 |
113 | def visualization_spectrogram(mel_spectrogram, title):
114 |     """visualizing the result of SpecAugment
115 | # Arguments:
116 | mel_spectrogram(ndarray): mel_spectrogram to visualize.
117 | title(String): plot figure's title
118 | """
119 | # Show mel-spectrogram using librosa's specshow.
120 | plt.figure(figsize=(10, 4))
121 | librosa.display.specshow(librosa.power_to_db(mel_spectrogram[0, :, :], ref=np.max), y_axis='mel', fmax=8000,
122 | x_axis='time')
123 | # plt.colorbar(format='%+2.0f dB')
124 | plt.title(title)
125 | plt.tight_layout()
126 | plt.show()
127 |
--------------------------------------------------------------------------------
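
The masking steps above are easy to see in isolation. The following standalone sketch (not a call into this module) reproduces Steps 2 and 3 on a random (1, n_mels, frames) array with the LD policy values from the table, F=27, m_F=2, T=100, m_T=2:

```python
import random

import numpy as np

n_mels, frames = 80, 500
spec = np.random.randn(1, n_mels, frames)

# Frequency masking: zero out m_F random bands of up to F mel bins each.
F, m_F = 27, 2
for _ in range(m_F):
    f = int(np.random.uniform(0.0, F))
    f0 = random.randint(0, n_mels - f)
    spec[:, f0:f0 + f, :] = 0

# Time masking: zero out m_T random spans of up to T frames each.
T, m_T = 100, 2
for _ in range(m_T):
    t = int(np.random.uniform(0.0, T))
    t0 = random.randint(0, frames - t)
    spec[:, :, t0:t0 + t] = 0

print('fraction of zeroed cells:', float((spec == 0).mean()))
```
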
/specAugment/spec_augment_tensorflow.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 RnD at Spoon Radio
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """SpecAugment Implementation for Tensorflow.
16 | Related paper : https://arxiv.org/pdf/1904.08779.pdf
17 | The paper summarizes the augmentation parameters for each open dataset in Table 1.
18 | -----------------------------------------
19 | Policy | W | F | m_F | T | p | m_T
20 | -----------------------------------------
21 | None | 0 | 0 | - | 0 | - | -
22 | -----------------------------------------
23 | LB | 80 | 27 | 1 | 100 | 1.0 | 1
24 | -----------------------------------------
25 | LD | 80 | 27 | 2 | 100 | 1.0 | 2
26 | -----------------------------------------
27 | SM | 40 | 15 | 2 | 70 | 0.2 | 2
28 | -----------------------------------------
29 | SS | 40 | 27 | 2 | 70 | 0.2 | 2
30 | -----------------------------------------
31 | LB : LibriSpeech basic
32 | LD : LibriSpeech double
33 | SM : Switchboard mild
34 | SS : Switchboard strong
35 | """
36 |
37 | import librosa
38 | import librosa.display
39 | import matplotlib.pyplot as plt
40 | import numpy as np
41 | import tensorflow as tf
42 | from tensorflow.contrib.image import sparse_image_warp
43 |
44 |
45 | def sparse_warp(mel_spectrogram, time_warping_para=80):
46 |     """Time warping step of SpecAugment.
47 |     'SpecAugment' has 3 steps for audio data augmentation:
48 |     the first step is time warping using Tensorflow's sparse_image_warp function,
49 |     the second step is frequency masking, and the last step is time masking.
50 |     # Arguments:
51 |       mel_spectrogram(numpy array): mel spectrogram you want to warp.
52 |       time_warping_para(float): Augmentation parameter, "time warp parameter W".
53 |         If None, default = 80 for LibriSpeech.
54 |     # Returns
55 |       mel_spectrogram(numpy array): time-warped mel spectrogram.
56 |     """
57 |
58 | fbank_size = tf.shape(mel_spectrogram)
59 | n, v = fbank_size[1], fbank_size[2]
60 |
61 | # Step 1 : Time warping
62 | # Image warping control point setting.
63 | # Source
64 |     pt = tf.random_uniform([], time_warping_para, n - time_warping_para, tf.int32) # random point along the time axis
65 | src_ctr_pt_freq = tf.range(v // 2) # control points on freq-axis
66 | src_ctr_pt_time = tf.ones_like(src_ctr_pt_freq) * pt # control points on time-axis
67 | src_ctr_pts = tf.stack((src_ctr_pt_time, src_ctr_pt_freq), -1)
68 | src_ctr_pts = tf.to_float(src_ctr_pts)
69 |
70 | # Destination
71 | w = tf.random_uniform([], -time_warping_para, time_warping_para, tf.int32) # distance
72 | dest_ctr_pt_freq = src_ctr_pt_freq
73 | dest_ctr_pt_time = src_ctr_pt_time + w
74 | dest_ctr_pts = tf.stack((dest_ctr_pt_time, dest_ctr_pt_freq), -1)
75 | dest_ctr_pts = tf.to_float(dest_ctr_pts)
76 |
77 | # warp
78 | source_control_point_locations = tf.expand_dims(src_ctr_pts, 0) # (1, v//2, 2)
79 | dest_control_point_locations = tf.expand_dims(dest_ctr_pts, 0) # (1, v//2, 2)
80 |
81 | warped_image, _ = sparse_image_warp(mel_spectrogram,
82 | source_control_point_locations,
83 | dest_control_point_locations)
84 | return warped_image
85 |
86 |
87 | def frequency_masking(mel_spectrogram, v, frequency_masking_para=27, frequency_mask_num=2):
88 |     """Frequency masking step of SpecAugment.
89 |     'SpecAugment' has 3 steps for audio data augmentation:
90 |     the first step is time warping using Tensorflow's sparse_image_warp function,
91 |     the second step is frequency masking, and the last step is time masking.
92 |     # Arguments:
93 |       mel_spectrogram(numpy array): mel spectrogram you want to mask.
94 |       frequency_masking_para(float): Augmentation parameter, "frequency mask parameter F"
95 |         If None, default = 27 for LibriSpeech.
96 |       frequency_mask_num(float): number of frequency masking lines, "m_F".
97 |         If None, default = 2.
98 |     # Returns
99 |       mel_spectrogram(numpy array): frequency-masked mel spectrogram.
100 |     """
101 | # Step 2 : Frequency masking
102 | fbank_size = tf.shape(mel_spectrogram)
103 | n, v = fbank_size[1], fbank_size[2]
104 |
105 | for i in range(frequency_mask_num):
106 | f = tf.random_uniform([], minval=0, maxval=frequency_masking_para, dtype=tf.int32)
107 | v = tf.to_int32(v)
108 | f0 = tf.random_uniform([], minval=0, maxval=v - f, dtype=tf.int32)
109 |
110 | # warped_mel_spectrogram[f0:f0 + f, :] = 0
111 | mask = tf.concat((tf.ones(shape=(1, n, v - f0 - f, 1)),
112 | tf.zeros(shape=(1, n, f, 1)),
113 | tf.ones(shape=(1, n, f0, 1)),
114 | ), 2)
115 | mel_spectrogram = mel_spectrogram * mask
116 | return tf.to_float(mel_spectrogram)
117 |
118 |
119 | def time_masking(mel_spectrogram, tau, time_masking_para=100, time_mask_num=2):
120 |     """Time masking step of SpecAugment.
121 |     'SpecAugment' has 3 steps for audio data augmentation:
122 |     the first step is time warping using Tensorflow's sparse_image_warp function,
123 |     the second step is frequency masking, and the last step is time masking.
124 |     # Arguments:
125 |       mel_spectrogram(numpy array): mel spectrogram you want to mask.
126 |       time_masking_para(float): Augmentation parameter, "time mask parameter T"
127 |         If None, default = 100 for LibriSpeech.
128 |       time_mask_num(float): number of time masking lines, "m_T".
129 |         If None, default = 2.
130 |     # Returns
131 |       mel_spectrogram(numpy array): time-masked mel spectrogram.
132 |     """
133 | fbank_size = tf.shape(mel_spectrogram)
134 | n, v = fbank_size[1], fbank_size[2]
135 |
136 | # Step 3 : Time masking
137 | for i in range(time_mask_num):
138 | t = tf.random_uniform([], minval=0, maxval=time_masking_para, dtype=tf.int32)
139 | t0 = tf.random_uniform([], minval=0, maxval=tau - t, dtype=tf.int32)
140 |
141 | # mel_spectrogram[:, t0:t0 + t] = 0
142 | mask = tf.concat((tf.ones(shape=(1, n - t0 - t, v, 1)),
143 | tf.zeros(shape=(1, t, v, 1)),
144 | tf.ones(shape=(1, t0, v, 1)),
145 | ), 1)
146 | mel_spectrogram = mel_spectrogram * mask
147 | return tf.to_float(mel_spectrogram)
148 |
149 |
150 | def spec_augment(mel_spectrogram):
151 | v = mel_spectrogram.shape[0]
152 | tau = mel_spectrogram.shape[1]
153 |
154 | warped_mel_spectrogram = sparse_warp(mel_spectrogram)
155 |
156 | warped_frequency_spectrogram = frequency_masking(warped_mel_spectrogram, v=v)
157 |
158 | warped_frequency_time_sepctrogram = time_masking(warped_frequency_spectrogram, tau=tau)
159 |
160 | return warped_frequency_time_sepctrogram
161 |
162 |
163 | def visualization_spectrogram(mel_spectrogram, title):
164 |     """visualizing the first result of SpecAugment
165 | # Arguments:
166 | mel_spectrogram(ndarray): mel_spectrogram to visualize.
167 | title(String): plot figure's title
168 | """
169 | # Show mel-spectrogram using librosa's specshow.
170 | plt.figure(figsize=(10, 4))
171 | librosa.display.specshow(librosa.power_to_db(mel_spectrogram[0, :, :, 0], ref=np.max), y_axis='mel', fmax=8000,
172 | x_axis='time')
173 | plt.title(title)
174 | plt.tight_layout()
175 | plt.show()
176 |
177 |
178 | def visualization_tensor_spectrogram(mel_spectrogram, title):
179 |     """visualizing the first result of SpecAugment
180 | # Arguments:
181 | mel_spectrogram(ndarray): mel_spectrogram to visualize.
182 | title(String): plot figure's title
183 | """
184 |
185 | # session for plotting
186 | sess = tf.InteractiveSession()
187 | mel_spectrogram = mel_spectrogram.eval()
188 |
189 | # Show mel-spectrogram using librosa's specshow.
190 | plt.figure(figsize=(10, 4))
191 | librosa.display.specshow(librosa.power_to_db(mel_spectrogram[0, :, :, 0], ref=np.max), y_axis='mel', fmax=8000,
192 | x_axis='time')
193 | # plt.colorbar(format='%+2.0f dB')
194 | plt.title(title)
195 | plt.tight_layout()
196 | plt.show()
197 |
--------------------------------------------------------------------------------
/sponsor.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/sponsor.jpg
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 |
4 | import torch
5 | from tqdm import tqdm
6 |
7 | from config import pickle_file, device, input_dim, LFR_m, LFR_n, sos_id, eos_id
8 | from data_gen import build_LFR_features
9 | from utils import extract_feature
10 | from xer import cer_function
11 |
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser(
15 | "End-to-End Automatic Speech Recognition Decoding.")
16 | # decode
17 | parser.add_argument('--beam_size', default=5, type=int,
18 | help='Beam size')
19 | parser.add_argument('--nbest', default=1, type=int,
20 | help='Nbest size')
21 | parser.add_argument('--decode_max_len', default=100, type=int,
22 |                         help='Max output length. If ==0, use an end-detect '
23 |                              'function to automatically find maximum '
24 |                              'hypothesis lengths')
25 | args = parser.parse_args()
26 | return args
27 |
28 |
29 | if __name__ == '__main__':
30 | args = parse_args()
31 | with open(pickle_file, 'rb') as file:
32 | data = pickle.load(file)
33 | char_list = data['IVOCAB']
34 | samples = data['test']
35 |
36 | checkpoint = 'BEST_checkpoint.tar'
37 | checkpoint = torch.load(checkpoint, map_location='cpu')
38 | model = checkpoint['model'].to(device)
39 | model.eval()
40 |
41 | num_samples = len(samples)
42 |
43 | total_cer = 0
44 |
45 | for i in tqdm(range(num_samples)):
46 | sample = samples[i]
47 | wave = sample['wave']
48 | trn = sample['trn']
49 |
50 | feature = extract_feature(input_file=wave, feature='fbank', dim=input_dim, cmvn=True)
51 | feature = build_LFR_features(feature, m=LFR_m, n=LFR_n)
52 | # feature = np.expand_dims(feature, axis=0)
53 | input = torch.from_numpy(feature).to(device)
54 | input_length = [input[0].shape[0]]
55 | input_length = torch.LongTensor(input_length).to(device)
56 | with torch.no_grad():
57 | nbest_hyps = model.recognize(input, input_length, char_list, args)
58 |
59 | hyp_list = []
60 | for hyp in nbest_hyps:
61 | out = hyp['yseq']
62 | out = [char_list[idx] for idx in out if idx not in (sos_id, eos_id)]
63 | out = ''.join(out)
64 | hyp_list.append(out)
65 |
66 | print(hyp_list)
67 |
68 | gt = [char_list[idx] for idx in trn if idx not in (sos_id, eos_id)]
69 | gt = ''.join(gt)
70 | gt_list = [gt]
71 |
72 | print(gt_list)
73 |
74 | cer = cer_function(gt_list, hyp_list)
75 | total_cer += cer
76 |
77 | avg_cer = total_cer / num_samples
78 |
79 | print('avg_cer: ' + str(avg_cer))
80 |
--------------------------------------------------------------------------------
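
test.py averages the per-utterance CER returned by cer_function from xer.py (not reproduced in this section). For readers who want to sanity-check results independently, a minimal Levenshtein-based character error rate might look like the sketch below; this is an assumption about the metric, not the repo's exact implementation:

```python
def levenshtein(ref, hyp):
    # Single-row dynamic-programming edit distance between two sequences.
    d = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        prev, d[0] = d[0], i
        for j, h in enumerate(hyp, start=1):
            cur = min(d[j] + 1,         # hyp missed a reference char (deletion)
                      d[j - 1] + 1,     # hyp has an extra char (insertion)
                      prev + (r != h))  # substitution, free if the chars match
            prev, d[j] = d[j], cur
    return d[len(hyp)]


def cer(reference, hypothesis):
    # Edits needed to turn the hypothesis into the reference, per reference char.
    return levenshtein(list(reference), list(hypothesis)) / max(len(reference), 1)


print(cer('abcd', 'abd'))  # 0.25
```
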
/test/test_decode.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 |
5 | from transformer.decoder import Decoder
6 | from transformer.encoder import Encoder
7 | from transformer.transformer import Transformer
8 |
9 | if __name__ == "__main__":
10 | D = 3
11 | beam_size = 5
12 | nbest = 5
13 | defaults = dict(beam_size=beam_size,
14 | nbest=nbest,
15 | decode_max_len=0,
16 | d_input=D,
17 | LFR_m=1,
18 | n_layers_enc=2,
19 | n_head=2,
20 | d_k=6,
21 | d_v=6,
22 | d_model=12,
23 | d_inner=8,
24 | dropout=0.1,
25 | pe_maxlen=100,
26 | d_word_vec=12,
27 | n_layers_dec=2,
28 | tgt_emb_prj_weight_sharing=1)
29 | args = argparse.Namespace(**defaults)
30 |     char_list = ["a", "b", "c", "d", "e", "f", "g", "h", "<sos>", "<eos>"]
31 | sos_id, eos_id = 8, 9
32 | vocab_size = len(char_list)
33 | # model
34 | encoder = Encoder(args.d_input * args.LFR_m, args.n_layers_enc, args.n_head,
35 | args.d_k, args.d_v, args.d_model, args.d_inner,
36 | dropout=args.dropout, pe_maxlen=args.pe_maxlen)
37 | decoder = Decoder(sos_id, eos_id, vocab_size,
38 | args.d_word_vec, args.n_layers_dec, args.n_head,
39 | args.d_k, args.d_v, args.d_model, args.d_inner,
40 | dropout=args.dropout,
41 | tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing,
42 | pe_maxlen=args.pe_maxlen)
43 | model = Transformer(encoder, decoder)
44 |
45 | for i in range(3):
46 | print("\n***** Utt", i + 1)
47 | Ti = i + 20
48 | input = torch.randn(Ti, D)
49 | length = torch.tensor([Ti], dtype=torch.int)
50 | nbest_hyps = model.recognize(input, length, char_list, args)
51 |
52 | file_path = "./temp.pth"
53 | optimizer = torch.optim.Adam(model.parameters())
54 | torch.save(model.serialize(model, optimizer, 1, LFR_m=1, LFR_n=1), file_path)
55 | model, LFR_m, LFR_n = Transformer.load_model(file_path)
56 | print(model)
57 |
58 | import os
59 |
60 | os.remove(file_path)
61 |
--------------------------------------------------------------------------------
/test/test_lm.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import pickle
3 |
4 | import nltk
5 | from tqdm import tqdm
6 |
7 | from config import pickle_file
8 |
9 | with open(pickle_file, 'rb') as file:
10 | data = pickle.load(file)
11 | char_list = data['IVOCAB']
12 | samples = data['train']
13 | bigram_freq = collections.Counter()
14 |
15 | for sample in tqdm(samples):
16 | text = sample['trn']
17 | text = [char_list[idx] for idx in text]
18 | tokens = list(text)
19 | bigrm = nltk.bigrams(tokens)
20 | # print(*map(' '.join, bigrm), sep=', ')
21 |
22 | # get the frequency of each bigram in our corpus
23 | bigram_freq.update(bigrm)
24 |
25 | # what are the ten most popular bigrams in this corpus?
26 | print(bigram_freq.most_common(10))
27 |
--------------------------------------------------------------------------------
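
The counter above holds raw bigram counts. To use them as a language model they are typically normalized into conditional probabilities, often with add-k smoothing so unseen pairs do not get probability zero. A small sketch, independent of ngram_lm.py (whose exact implementation is not shown in this section):

```python
import collections


def bigram_probabilities(bigram_freq, k=1.0):
    """Turn raw bigram counts into add-k smoothed conditional probabilities P(cur | prev)."""
    prev_totals = collections.Counter()
    vocab = set()
    for (prev, cur), count in bigram_freq.items():
        prev_totals[prev] += count
        vocab.add(cur)
    V = len(vocab)
    # Probabilities for seen pairs; an unseen (prev, cur) would get k / (prev_totals[prev] + k * V).
    return {pair: (count + k) / (prev_totals[pair[0]] + k * V)
            for pair, count in bigram_freq.items()}
```
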
/test/test_lr.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | if __name__ == '__main__':
4 | k = 0.2
5 | warmup_steps = 4000
6 | d_model = 512
7 | init_lr = d_model ** (-0.5)
8 |
9 | lr_list = []
10 | for step_num in range(1, 500000):
11 | lr = k * init_lr * min(step_num ** (-0.5),
12 | step_num * (warmup_steps ** (-1.5)))
13 | lr_list.append(lr)
14 |
15 | print(lr_list[:100])
16 | print(lr_list[-100:])
17 |
18 | plt.plot(lr_list)
19 | plt.show()
20 |
--------------------------------------------------------------------------------
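
The curve plotted above is the warmup schedule implemented in transformer/optimizer.py, scaled by the constant k: lr(s) = k * d_model^(-0.5) * min(s^(-0.5), s * warmup_steps^(-1.5)). It grows linearly for the first warmup_steps updates and then decays proportionally to s^(-0.5); the peak is reached at s = warmup_steps and equals k / sqrt(d_model * warmup_steps) = 0.2 / sqrt(512 * 4000) ≈ 1.4e-4 with the values used here.
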
/test/test_pe.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class PositionalEncoding(nn.Module):
8 | """Implement the positional encoding (PE) function.
9 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
10 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
11 | """
12 |
13 | def __init__(self, d_model, max_len=5000):
14 | super(PositionalEncoding, self).__init__()
15 | # Compute the positional encodings once in log space.
16 | pe = torch.zeros(max_len, d_model, requires_grad=False)
17 | position = torch.arange(0, max_len).unsqueeze(1).float()
18 | div_term = torch.exp(torch.arange(0, d_model, 2).float() *
19 | -(math.log(10000.0) / d_model))
20 | pe[:, 0::2] = torch.sin(position * div_term)
21 | pe[:, 1::2] = torch.cos(position * div_term)
22 | pe = pe.unsqueeze(0)
23 | self.register_buffer('pe', pe)
24 |
25 | def forward(self, input):
26 | """
27 | Args:
28 | input: N x T x D
29 | """
30 | length = input.size(1)
31 | return self.pe[:, :length]
32 |
33 |
34 | if __name__ == '__main__':
35 | import numpy as np
36 | import matplotlib.pyplot as plt
37 |
38 | d_model = 512
39 | max_len = 5000
40 | pe = PositionalEncoding(d_model, max_len)
41 | mat = pe.pe.numpy()[0] # (5000, 512)
42 | mat = np.transpose(mat, (1, 0))
43 | print(mat.shape)
44 | print(mat)
45 | plt.imshow(mat)
46 | plt.colorbar()
47 | plt.show()
48 |
--------------------------------------------------------------------------------
/test/test_specaug.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | from data_gen import spec_augment, build_LFR_features
4 | from utils import extract_feature
5 |
6 | LFR_m = 4
7 | LFR_n = 3
8 |
9 |
10 | def plot_data(data, figsize=(16, 4)):
11 | fig, axes = plt.subplots(1, len(data), figsize=figsize)
12 | for i in range(len(data)):
13 | axes[i].imshow(data[i], aspect='auto', origin='bottom',
14 | interpolation='none')
15 |
16 |
17 | if __name__ == '__main__':
18 | path = '../audios/audio_0.wav'
19 | feature = extract_feature(input_file=path, feature='fbank', dim=80, cmvn=True)
20 | feature = build_LFR_features(feature, m=LFR_m, n=LFR_n)
21 |
22 | # zero mean and unit variance
23 | feature = (feature - feature.mean()) / feature.std()
24 |
25 | feature_1 = spec_augment(feature)
26 | #
27 |
28 | plot_data((feature.transpose(), feature_1.transpose()))
29 | plt.show()
30 |
--------------------------------------------------------------------------------
/test/test_trim.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import numpy as np
3 | import soundfile
4 |
5 | from utils import normalize
6 |
7 | sampling_rate = 16000
8 | top_db = 20
9 | reduced_ratios = []
10 |
11 | for i in range(10):
12 | audiopath = '../audios/audio_{}.wav'.format(i)
13 | print(audiopath)
14 | y, sr = librosa.load(audiopath)
15 | # Trim the beginning and ending silence
16 | yt, index = librosa.effects.trim(y, top_db=top_db)
17 | yt = normalize(yt)
18 |
19 | reduced_ratios.append(len(yt) / len(y))
20 |
21 | # Print the durations
22 | print(librosa.get_duration(y), librosa.get_duration(yt))
23 | print(len(y), len(yt))
24 | target = '../audios/trimed_{}.wav'.format(i)
25 | soundfile.write(target, yt, sampling_rate)
26 |
27 | print('\nreduced_ratio: ' + str(100 - 100 * np.mean(reduced_ratios)))
28 |
--------------------------------------------------------------------------------
/test_lm.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | import numpy as np
4 |
5 | from config import pickle_file, sos_id, eos_id
6 |
7 | print('loading {}...'.format(pickle_file))
8 | with open(pickle_file, 'rb') as file:
9 | data = pickle.load(file)
10 | VOCAB = data['VOCAB']
11 | IVOCAB = data['IVOCAB']
12 |
13 | print('loading bigram_freq.pkl...')
14 | with open('bigram_freq.pkl', 'rb') as file:
15 | bigram_freq = pickle.load(file)
16 |
17 | OUT_LIST = ['比赛很快便城像一边到的局面第二规合', '比赛很快便呈像一边到的局面第二规合', '比赛很快便城向一边到的局面第二规合',
18 | '比赛很快便呈向一边到的局面第二规合', '比赛很快便城像一边到的局面第二回合']
19 | GT = '比赛很快便呈向一边倒的局面第二回合'
20 |
21 | print('calculating prob...')
22 | prob_list = []
23 | for out in OUT_LIST:
24 | print(out)
25 |     out = out.replace('<sos>', '').replace('<eos>', '')
26 | out = [sos_id] + [VOCAB[ch] for ch in out] + [eos_id]
27 | prob = 1.0
28 | for i in range(1, len(out)):
29 | prob *= bigram_freq[(out[i - 1], out[i])]
30 | prob_list.append(prob)
31 |
32 | prob_list = np.array(prob_list)
33 | prob_list = prob_list / np.sum(prob_list)
34 | print(prob_list)
35 |
--------------------------------------------------------------------------------
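
One caveat about the scoring loop above: multiplying many per-bigram scores underflows quickly for longer sentences, and any unseen bigram contributes a zero that wipes out the whole product. A common alternative is to sum log-scores with a small floor, as in this sketch (assuming bigram_freq behaves like a dict mapping (prev, cur) pairs to positive scores):

```python
import math


def sentence_log_score(token_ids, bigram_freq, floor=1e-12):
    # Sum of log bigram scores; unseen pairs get a small floor instead of zero.
    score = 0.0
    for prev, cur in zip(token_ids[:-1], token_ids[1:]):
        score += math.log(max(bigram_freq.get((prev, cur), 0.0), floor))
    return score
```
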
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.tensorboard import SummaryWriter
4 | # from torch import nn
5 | from tqdm import tqdm
6 |
7 | from config import device, print_freq, vocab_size, sos_id, eos_id
8 | from data_gen import AiShellDataset, pad_collate
9 | from transformer.decoder import Decoder
10 | from transformer.encoder import Encoder
11 | from transformer.loss import cal_performance
12 | from transformer.optimizer import TransformerOptimizer
13 | from transformer.transformer import Transformer
14 | from utils import parse_args, save_checkpoint, AverageMeter, get_logger
15 |
16 |
17 | def train_net(args):
18 | torch.manual_seed(7)
19 | np.random.seed(7)
20 | checkpoint = args.checkpoint
21 | start_epoch = 0
22 | best_loss = float('inf')
23 | writer = SummaryWriter()
24 | epochs_since_improvement = 0
25 |
26 | # Initialize / load checkpoint
27 | if checkpoint is None:
28 | # model
29 | encoder = Encoder(args.d_input * args.LFR_m, args.n_layers_enc, args.n_head,
30 | args.d_k, args.d_v, args.d_model, args.d_inner,
31 | dropout=args.dropout, pe_maxlen=args.pe_maxlen)
32 | decoder = Decoder(sos_id, eos_id, vocab_size,
33 | args.d_word_vec, args.n_layers_dec, args.n_head,
34 | args.d_k, args.d_v, args.d_model, args.d_inner,
35 | dropout=args.dropout,
36 | tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing,
37 | pe_maxlen=args.pe_maxlen)
38 | model = Transformer(encoder, decoder)
39 | # print(model)
40 | # model = nn.DataParallel(model)
41 |
42 | # optimizer
43 | optimizer = TransformerOptimizer(
44 | torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-09))
45 |
46 | else:
47 | checkpoint = torch.load(checkpoint)
48 | start_epoch = checkpoint['epoch'] + 1
49 | epochs_since_improvement = checkpoint['epochs_since_improvement']
50 | model = checkpoint['model']
51 | optimizer = checkpoint['optimizer']
52 |
53 | logger = get_logger()
54 |
55 | # Move to GPU, if available
56 | model = model.to(device)
57 |
58 | # Custom dataloaders
59 | train_dataset = AiShellDataset(args, 'train')
60 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=pad_collate,
61 | pin_memory=True, shuffle=True, num_workers=args.num_workers)
62 | valid_dataset = AiShellDataset(args, 'dev')
63 | valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, collate_fn=pad_collate,
64 | pin_memory=True, shuffle=False, num_workers=args.num_workers)
65 |
66 | # Epochs
67 | for epoch in range(start_epoch, args.epochs):
68 | # One epoch's training
69 | train_loss = train(train_loader=train_loader,
70 | model=model,
71 | optimizer=optimizer,
72 | epoch=epoch,
73 | logger=logger)
74 | writer.add_scalar('model/train_loss', train_loss, epoch)
75 |
76 | lr = optimizer.lr
77 | print('\nLearning rate: {}'.format(lr))
78 | writer.add_scalar('model/learning_rate', lr, epoch)
79 | step_num = optimizer.step_num
80 | print('Step num: {}\n'.format(step_num))
81 |
82 | # One epoch's validation
83 | valid_loss = valid(valid_loader=valid_loader,
84 | model=model,
85 | logger=logger)
86 | writer.add_scalar('model/valid_loss', valid_loss, epoch)
87 |
88 | # Check if there was an improvement
89 | is_best = valid_loss < best_loss
90 | best_loss = min(valid_loss, best_loss)
91 | if not is_best:
92 | epochs_since_improvement += 1
93 | print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
94 | else:
95 | epochs_since_improvement = 0
96 |
97 | # Save checkpoint
98 | save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)
99 |
100 |
101 | def train(train_loader, model, optimizer, epoch, logger):
102 |     model.train()  # train mode (dropout and batchnorm are used)
103 |
104 | losses = AverageMeter()
105 |
106 | # Batches
107 | for i, (data) in enumerate(train_loader):
108 | # Move to GPU, if available
109 | padded_input, padded_target, input_lengths = data
110 | padded_input = padded_input.to(device)
111 | padded_target = padded_target.to(device)
112 | input_lengths = input_lengths.to(device)
113 |
114 | # Forward prop.
115 | pred, gold = model(padded_input, input_lengths, padded_target)
116 | loss, n_correct = cal_performance(pred, gold, smoothing=args.label_smoothing)
117 |
118 | # Back prop.
119 | optimizer.zero_grad()
120 | loss.backward()
121 |
122 | # Update weights
123 | optimizer.step()
124 |
125 | # Keep track of metrics
126 | losses.update(loss.item())
127 |
128 | # Print status
129 | if i % print_freq == 0:
130 | logger.info('Epoch: [{0}][{1}/{2}]\t'
131 | 'Loss {loss.val:.5f} ({loss.avg:.5f})'.format(epoch, i, len(train_loader), loss=losses))
132 |
133 | return losses.avg
134 |
135 |
136 | def valid(valid_loader, model, logger):
137 | model.eval()
138 |
139 | losses = AverageMeter()
140 |
141 | # Batches
142 | for data in tqdm(valid_loader):
143 | # Move to GPU, if available
144 | padded_input, padded_target, input_lengths = data
145 | padded_input = padded_input.to(device)
146 | padded_target = padded_target.to(device)
147 | input_lengths = input_lengths.to(device)
148 |
149 | with torch.no_grad():
150 | # Forward prop.
151 | pred, gold = model(padded_input, input_lengths, padded_target)
152 | loss, n_correct = cal_performance(pred, gold, smoothing=args.label_smoothing)
153 |
154 | # Keep track of metrics
155 | losses.update(loss.item())
156 |
157 | # Print status
158 | logger.info('\nValidation Loss {loss.val:.5f} ({loss.avg:.5f})\n'.format(loss=losses))
159 |
160 | return losses.avg
161 |
162 |
163 | def main():
164 | global args
165 | args = parse_args()
166 | train_net(args)
167 |
168 |
169 | if __name__ == '__main__':
170 | main()
171 |
--------------------------------------------------------------------------------
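
A note on resuming: utils.save_checkpoint pickles the whole model and optimizer objects, not just state dicts, so train_net can simply read them back. A minimal inspection sketch (run from the repo root so the pickled classes can be re-imported):

```python
import torch

checkpoint = torch.load('checkpoint.tar', map_location='cpu')
print(sorted(checkpoint.keys()))  # ['epoch', 'epochs_since_improvement', 'loss', 'model', 'optimizer']
model = checkpoint['model']       # a full Transformer nn.Module, ready for .eval() or further training
```
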
/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Speech-Transformer/cf917db8c219e837e9392177a5d385c9f2b60b0d/transformer/__init__.py
--------------------------------------------------------------------------------
/transformer/attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class MultiHeadAttention(nn.Module):
7 | ''' Multi-Head Attention module '''
8 |
9 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
10 | super().__init__()
11 |
12 | self.n_head = n_head
13 | self.d_k = d_k
14 | self.d_v = d_v
15 |
16 | self.w_qs = nn.Linear(d_model, n_head * d_k)
17 | self.w_ks = nn.Linear(d_model, n_head * d_k)
18 | self.w_vs = nn.Linear(d_model, n_head * d_v)
19 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
20 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
21 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
22 |
23 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5),
24 | attn_dropout=dropout)
25 | self.layer_norm = nn.LayerNorm(d_model)
26 |
27 | self.fc = nn.Linear(n_head * d_v, d_model)
28 | nn.init.xavier_normal_(self.fc.weight)
29 |
30 | self.dropout = nn.Dropout(dropout)
31 |
32 | def forward(self, q, k, v, mask=None):
33 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
34 |
35 | sz_b, len_q, _ = q.size()
36 | sz_b, len_k, _ = k.size()
37 | sz_b, len_v, _ = v.size()
38 |
39 | residual = q
40 |
41 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
42 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
43 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
44 |
45 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
46 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
47 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
48 |
49 | if mask is not None:
50 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
51 |
52 | output, attn = self.attention(q, k, v, mask=mask)
53 |
54 | output = output.view(n_head, sz_b, len_q, d_v)
55 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
56 |
57 | output = self.dropout(self.fc(output))
58 | output = self.layer_norm(output + residual)
59 |
60 | return output, attn
61 |
62 |
63 | class ScaledDotProductAttention(nn.Module):
64 | ''' Scaled Dot-Product Attention '''
65 |
66 | def __init__(self, temperature, attn_dropout=0.1):
67 | super().__init__()
68 | self.temperature = temperature
69 | self.dropout = nn.Dropout(attn_dropout)
70 | self.softmax = nn.Softmax(dim=2)
71 |
72 | def forward(self, q, k, v, mask=None):
73 | attn = torch.bmm(q, k.transpose(1, 2))
74 | attn = attn / self.temperature
75 |
76 | if mask is not None:
77 | attn = attn.masked_fill(mask.bool(), -np.inf)
78 |
79 | attn = self.softmax(attn)
80 | attn = self.dropout(attn)
81 | output = torch.bmm(attn, v)
82 |
83 | return output, attn
84 |
--------------------------------------------------------------------------------
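
A quick shape check for the attention module above (a sketch, run from the repo root so the transformer package is importable):

```python
import torch

from transformer.attention import MultiHeadAttention

batch, seq_len, d_model = 2, 7, 12
mha = MultiHeadAttention(n_head=2, d_model=d_model, d_k=6, d_v=6)

x = torch.randn(batch, seq_len, d_model)
out, attn = mha(x, x, x)  # self-attention: q = k = v
print(out.shape)          # torch.Size([2, 7, 12])
print(attn.shape)         # torch.Size([4, 7, 7]), i.e. (n_head * batch, len_q, len_k)
```
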
/transformer/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from config import IGNORE_ID
6 | from .attention import MultiHeadAttention
7 | from .module import PositionalEncoding, PositionwiseFeedForward
8 | from .utils import get_attn_key_pad_mask, get_attn_pad_mask, get_non_pad_mask, get_subsequent_mask, pad_list
9 |
10 |
11 | # filename = 'bigram_freq.pkl'
12 | # print('loading {}...'.format(filename))
13 | # with open(filename, 'rb') as file:
14 | # bigram_freq = pickle.load(file)
15 |
16 |
17 | class Decoder(nn.Module):
18 | ''' A decoder model with self attention mechanism. '''
19 |
20 | def __init__(
21 | self, sos_id=0, eos_id=1,
22 | n_tgt_vocab=4335, d_word_vec=512,
23 | n_layers=6, n_head=8, d_k=64, d_v=64,
24 | d_model=512, d_inner=2048, dropout=0.1,
25 | tgt_emb_prj_weight_sharing=True,
26 | pe_maxlen=5000):
27 | super(Decoder, self).__init__()
28 | # parameters
29 | self.sos_id = sos_id # Start of Sentence
30 | self.eos_id = eos_id # End of Sentence
31 | self.n_tgt_vocab = n_tgt_vocab
32 | self.d_word_vec = d_word_vec
33 | self.n_layers = n_layers
34 | self.n_head = n_head
35 | self.d_k = d_k
36 | self.d_v = d_v
37 | self.d_model = d_model
38 | self.d_inner = d_inner
39 | self.dropout = dropout
40 | self.tgt_emb_prj_weight_sharing = tgt_emb_prj_weight_sharing
41 | self.pe_maxlen = pe_maxlen
42 |
43 | self.tgt_word_emb = nn.Embedding(n_tgt_vocab, d_word_vec)
44 | self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
45 | self.dropout = nn.Dropout(dropout)
46 |
47 | self.layer_stack = nn.ModuleList([
48 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
49 | for _ in range(n_layers)])
50 |
51 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
52 | nn.init.xavier_normal_(self.tgt_word_prj.weight)
53 |
54 | if tgt_emb_prj_weight_sharing:
55 | # Share the weight matrix between target word embedding & the final logit dense layer
56 | self.tgt_word_prj.weight = self.tgt_word_emb.weight
57 | self.x_logit_scale = (d_model ** -0.5)
58 | else:
59 | self.x_logit_scale = 1.
60 |
61 | def preprocess(self, padded_input):
62 | """Generate decoder input and output label from padded_input
63 |         Add <sos> to decoder input, and add <eos> to decoder output label
64 | """
65 | ys = [y[y != IGNORE_ID] for y in padded_input] # parse padded ys
66 | # prepare input and output word sequences with sos/eos IDs
67 | eos = ys[0].new([self.eos_id])
68 | sos = ys[0].new([self.sos_id])
69 | ys_in = [torch.cat([sos, y], dim=0) for y in ys]
70 | ys_out = [torch.cat([y, eos], dim=0) for y in ys]
71 | # padding for ys with -1
72 | # pys: utt x olen
73 | ys_in_pad = pad_list(ys_in, self.eos_id)
74 | ys_out_pad = pad_list(ys_out, IGNORE_ID)
75 | assert ys_in_pad.size() == ys_out_pad.size()
76 | return ys_in_pad, ys_out_pad
77 |
78 | def forward(self, padded_input, encoder_padded_outputs,
79 | encoder_input_lengths, return_attns=False):
80 | """
81 | Args:
82 | padded_input: N x To
83 | encoder_padded_outputs: N x Ti x H
84 | Returns:
85 | """
86 | dec_slf_attn_list, dec_enc_attn_list = [], []
87 |
88 |         # Get Decoder Input and Output
89 | ys_in_pad, ys_out_pad = self.preprocess(padded_input)
90 |
91 | # Prepare masks
92 | non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)
93 |
94 | slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)
95 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
96 | seq_q=ys_in_pad,
97 | pad_idx=self.eos_id)
98 | slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
99 |
100 | output_length = ys_in_pad.size(1)
101 | dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
102 | encoder_input_lengths,
103 | output_length)
104 |
105 | # Forward
106 | dec_output = self.dropout(self.tgt_word_emb(ys_in_pad) * self.x_logit_scale +
107 | self.positional_encoding(ys_in_pad))
108 |
109 | for dec_layer in self.layer_stack:
110 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
111 | dec_output, encoder_padded_outputs,
112 | non_pad_mask=non_pad_mask,
113 | slf_attn_mask=slf_attn_mask,
114 | dec_enc_attn_mask=dec_enc_attn_mask)
115 |
116 | if return_attns:
117 | dec_slf_attn_list += [dec_slf_attn]
118 | dec_enc_attn_list += [dec_enc_attn]
119 |
120 | # before softmax
121 | seq_logit = self.tgt_word_prj(dec_output)
122 |
123 | # Return
124 | pred, gold = seq_logit, ys_out_pad
125 |
126 | if return_attns:
127 | return pred, gold, dec_slf_attn_list, dec_enc_attn_list
128 | return pred, gold
129 |
130 | def recognize_beam(self, encoder_outputs, char_list, args):
131 |         """Beam search, decode one utterance now.
132 |         Args:
133 |             encoder_outputs: T x H
134 |             char_list: list of characters
135 | args: args.beam
136 | Returns:
137 | nbest_hyps:
138 | """
139 | # search params
140 | beam = args.beam_size
141 | nbest = args.nbest
142 | if args.decode_max_len == 0:
143 | maxlen = encoder_outputs.size(0)
144 | else:
145 | maxlen = args.decode_max_len
146 |
147 | encoder_outputs = encoder_outputs.unsqueeze(0)
148 |
149 | # prepare sos
150 | ys = torch.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long()
151 |
152 | # yseq: 1xT
153 | hyp = {'score': 0.0, 'yseq': ys}
154 | hyps = [hyp]
155 | ended_hyps = []
156 |
157 | for i in range(maxlen):
158 | hyps_best_kept = []
159 | for hyp in hyps:
160 | ys = hyp['yseq'] # 1 x i
161 | # last_id = ys.cpu().numpy()[0][-1]
162 | # freq = bigram_freq[last_id]
163 | # freq = torch.log(torch.from_numpy(freq))
164 | # # print(freq.dtype)
165 | # freq = freq.type(torch.float).to(device)
166 | # print(freq.dtype)
167 | # print('freq.size(): ' + str(freq.size()))
168 | # print('freq: ' + str(freq))
169 | # -- Prepare masks
170 | non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1) # 1xix1
171 | slf_attn_mask = get_subsequent_mask(ys)
172 |
173 | # -- Forward
174 | dec_output = self.dropout(
175 | self.tgt_word_emb(ys) * self.x_logit_scale +
176 | self.positional_encoding(ys))
177 |
178 | for dec_layer in self.layer_stack:
179 | dec_output, _, _ = dec_layer(
180 | dec_output, encoder_outputs,
181 | non_pad_mask=non_pad_mask,
182 | slf_attn_mask=slf_attn_mask,
183 | dec_enc_attn_mask=None)
184 |
185 | seq_logit = self.tgt_word_prj(dec_output[:, -1])
186 | # local_scores = F.log_softmax(seq_logit, dim=1)
187 | local_scores = F.log_softmax(seq_logit, dim=1)
188 | # print('local_scores.size(): ' + str(local_scores.size()))
189 | # local_scores += freq
190 | # print('local_scores: ' + str(local_scores))
191 |
192 | # topk scores
193 | local_best_scores, local_best_ids = torch.topk(
194 | local_scores, beam, dim=1)
195 |
196 | for j in range(beam):
197 | new_hyp = {}
198 | new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
199 | new_hyp['yseq'] = torch.ones(1, (1 + ys.size(1))).type_as(encoder_outputs).long()
200 | new_hyp['yseq'][:, :ys.size(1)] = hyp['yseq']
201 | new_hyp['yseq'][:, ys.size(1)] = int(local_best_ids[0, j])
202 | # will be (2 x beam) hyps at most
203 | hyps_best_kept.append(new_hyp)
204 |
205 | hyps_best_kept = sorted(hyps_best_kept,
206 | key=lambda x: x['score'],
207 | reverse=True)[:beam]
208 | # end for hyp in hyps
209 | hyps = hyps_best_kept
210 |
211 |             # add eos in the final loop so that at least one hypothesis ends
212 | if i == maxlen - 1:
213 | for hyp in hyps:
214 | hyp['yseq'] = torch.cat([hyp['yseq'],
215 | torch.ones(1, 1).fill_(self.eos_id).type_as(encoder_outputs).long()],
216 | dim=1)
217 |
218 |             # add ended hypotheses to a final list, and remove them from current hypotheses
219 |             # (this can be a problem if the number of hyps < beam)
220 | remained_hyps = []
221 | for hyp in hyps:
222 | if hyp['yseq'][0, -1] == self.eos_id:
223 | ended_hyps.append(hyp)
224 | else:
225 | remained_hyps.append(hyp)
226 |
227 | hyps = remained_hyps
228 | # if len(hyps) > 0:
229 |             #     print('remaining hypotheses: ' + str(len(hyps)))
230 | # else:
231 | # print('no hypothesis. Finish decoding.')
232 | # break
233 | #
234 | # for hyp in hyps:
235 | # print('hypo: ' + ''.join([char_list[int(x)]
236 | # for x in hyp['yseq'][0, 1:]]))
237 | # end for i in range(maxlen)
238 | nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[
239 | :min(len(ended_hyps), nbest)]
240 |         # compatible with LAS implementation
241 | for hyp in nbest_hyps:
242 | hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist()
243 | return nbest_hyps
244 |
245 |
246 | class DecoderLayer(nn.Module):
247 | ''' Compose with three layers '''
248 |
249 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
250 | super(DecoderLayer, self).__init__()
251 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
252 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
253 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
254 |
255 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
256 | dec_output, dec_slf_attn = self.slf_attn(
257 | dec_input, dec_input, dec_input, mask=slf_attn_mask)
258 | dec_output *= non_pad_mask
259 |
260 | dec_output, dec_enc_attn = self.enc_attn(
261 | dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
262 | dec_output *= non_pad_mask
263 |
264 | dec_output = self.pos_ffn(dec_output)
265 | dec_output *= non_pad_mask
266 |
267 | return dec_output, dec_slf_attn, dec_enc_attn
268 |
--------------------------------------------------------------------------------
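
The beam search above keeps each hypothesis as a dict with an accumulated log-probability 'score' and a token sequence 'yseq'; after every expansion the pool is re-sorted and truncated to the beam width. A toy illustration of that pruning step:

```python
beam = 3
hyps = [{'score': -1.2, 'yseq': [8, 5]},
        {'score': -0.4, 'yseq': [8, 2]},
        {'score': -2.7, 'yseq': [8, 7]},
        {'score': -0.9, 'yseq': [8, 1]}]
hyps_best_kept = sorted(hyps, key=lambda x: x['score'], reverse=True)[:beam]
print([h['yseq'] for h in hyps_best_kept])  # [[8, 2], [8, 1], [8, 5]]
```
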
/transformer/encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .attention import MultiHeadAttention
4 | from .module import PositionalEncoding, PositionwiseFeedForward
5 | from .utils import get_non_pad_mask, get_attn_pad_mask
6 |
7 |
8 | class Encoder(nn.Module):
9 | """Encoder of Transformer including self-attention and feed forward.
10 | """
11 |
12 | def __init__(self, d_input=320, n_layers=6, n_head=8, d_k=64, d_v=64,
13 | d_model=512, d_inner=2048, dropout=0.1, pe_maxlen=5000):
14 | super(Encoder, self).__init__()
15 | # parameters
16 | self.d_input = d_input
17 | self.n_layers = n_layers
18 | self.n_head = n_head
19 | self.d_k = d_k
20 | self.d_v = d_v
21 | self.d_model = d_model
22 | self.d_inner = d_inner
23 | self.dropout_rate = dropout
24 | self.pe_maxlen = pe_maxlen
25 |
26 | # use linear transformation with layer norm to replace input embedding
27 | self.linear_in = nn.Linear(d_input, d_model)
28 | self.layer_norm_in = nn.LayerNorm(d_model)
29 | self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
30 | self.dropout = nn.Dropout(dropout)
31 |
32 | self.layer_stack = nn.ModuleList([
33 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
34 | for _ in range(n_layers)])
35 |
36 | def forward(self, padded_input, input_lengths, return_attns=False):
37 | """
38 | Args:
39 | padded_input: N x T x D
40 | input_lengths: N
41 | Returns:
42 | enc_output: N x T x H
43 | """
44 | enc_slf_attn_list = []
45 |
46 | # Prepare masks
47 | non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
48 | length = padded_input.size(1)
49 | slf_attn_mask = get_attn_pad_mask(padded_input, input_lengths, length)
50 |
51 | # Forward
52 | enc_output = self.dropout(
53 | self.layer_norm_in(self.linear_in(padded_input)) +
54 | self.positional_encoding(padded_input))
55 |
56 | for enc_layer in self.layer_stack:
57 | enc_output, enc_slf_attn = enc_layer(
58 | enc_output,
59 | non_pad_mask=non_pad_mask,
60 | slf_attn_mask=slf_attn_mask)
61 | if return_attns:
62 | enc_slf_attn_list += [enc_slf_attn]
63 |
64 | if return_attns:
65 | return enc_output, enc_slf_attn_list
66 | return enc_output,
67 |
68 |
69 | class EncoderLayer(nn.Module):
70 | """Compose with two sub-layers.
71 | 1. A multi-head self-attention mechanism
72 | 2. A simple, position-wise fully connected feed-forward network.
73 | """
74 |
75 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
76 | super(EncoderLayer, self).__init__()
77 | self.slf_attn = MultiHeadAttention(
78 | n_head, d_model, d_k, d_v, dropout=dropout)
79 | self.pos_ffn = PositionwiseFeedForward(
80 | d_model, d_inner, dropout=dropout)
81 |
82 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
83 | enc_output, enc_slf_attn = self.slf_attn(
84 | enc_input, enc_input, enc_input, mask=slf_attn_mask)
85 | enc_output *= non_pad_mask
86 |
87 | enc_output = self.pos_ffn(enc_output)
88 | enc_output *= non_pad_mask
89 |
90 | return enc_output, enc_slf_attn
91 |
--------------------------------------------------------------------------------
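
The encoder consumes already-padded feature sequences plus their true lengths, which it only uses to build the padding masks. A small shape check mirroring the constructor call in train.py (a sketch, run from the repo root):

```python
import torch

from transformer.encoder import Encoder

batch, max_frames, d_input = 2, 30, 320
encoder = Encoder(d_input=d_input, n_layers=2, n_head=4, d_k=64, d_v=64,
                  d_model=512, d_inner=1024)

padded_input = torch.randn(batch, max_frames, d_input)
input_lengths = torch.tensor([30, 22])  # the second utterance is padded
enc_output, *_ = encoder(padded_input, input_lengths)
print(enc_output.shape)                 # torch.Size([2, 30, 512])
```
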
/transformer/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from config import IGNORE_ID
5 |
6 |
7 | def cal_performance(pred, gold, smoothing=0.0):
8 | """Calculate cross entropy loss, apply label smoothing if needed.
9 | Args:
10 | pred: N x T x C, score before softmax
11 | gold: N x T
12 | """
13 |
14 | pred = pred.view(-1, pred.size(2))
15 | gold = gold.contiguous().view(-1)
16 |
17 | loss = cal_loss(pred, gold, smoothing)
18 |
19 | pred = pred.max(1)[1]
20 | non_pad_mask = gold.ne(IGNORE_ID)
21 | n_correct = pred.eq(gold)
22 | n_correct = n_correct.masked_select(non_pad_mask).sum().item()
23 |
24 | return loss, n_correct
25 |
26 |
27 | def cal_loss(pred, gold, smoothing=0.0):
28 | """Calculate cross entropy loss, apply label smoothing if needed.
29 | """
30 |
31 | if smoothing > 0.0:
32 | eps = smoothing
33 | n_class = pred.size(1)
34 |
35 | # Generate one-hot matrix: N x C.
36 | # Only label position is 1 and all other positions are 0
37 | # gold include -1 value (IGNORE_ID) and this will lead to assert error
38 | gold_for_scatter = gold.ne(IGNORE_ID).long() * gold
39 | one_hot = torch.zeros_like(pred).scatter(1, gold_for_scatter.view(-1, 1), 1)
40 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / n_class
41 | log_prb = F.log_softmax(pred, dim=1)
42 |
43 | non_pad_mask = gold.ne(IGNORE_ID)
44 | n_word = non_pad_mask.sum().item()
45 | loss = -(one_hot * log_prb).sum(dim=1)
46 | loss = loss.masked_select(non_pad_mask).sum() / n_word
47 | else:
48 | loss = F.cross_entropy(pred, gold,
49 | ignore_index=IGNORE_ID,
50 |                                reduction='mean')  # 'elementwise_mean' was removed in newer PyTorch
51 |
52 | return loss
53 |
--------------------------------------------------------------------------------
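
To make the smoothing step concrete: with eps = 0.1 and n_class = 5, the formula above maps a one-hot target to 0.9 for the gold class and eps / n_class = 0.02 for every other class, so each row sums to 0.98 rather than exactly 1 (that is how this particular variant is written). A tiny standalone check:

```python
import torch

eps, n_class = 0.1, 5
gold = torch.tensor([2])
one_hot = torch.zeros(1, n_class).scatter(1, gold.view(-1, 1), 1)
smoothed = one_hot * (1 - eps) + (1 - one_hot) * eps / n_class
print(smoothed)        # tensor([[0.0200, 0.0200, 0.9000, 0.0200, 0.0200]])
print(smoothed.sum())  # tensor(0.9800)
```
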
/transformer/module.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class PositionalEncoding(nn.Module):
9 | """Implement the positional encoding (PE) function.
10 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
11 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
12 | """
13 |
14 | def __init__(self, d_model, max_len=5000):
15 | super(PositionalEncoding, self).__init__()
16 | # Compute the positional encodings once in log space.
17 | pe = torch.zeros(max_len, d_model, requires_grad=False)
18 | position = torch.arange(0, max_len).unsqueeze(1).float()
19 | div_term = torch.exp(torch.arange(0, d_model, 2).float() *
20 | -(math.log(10000.0) / d_model))
21 | pe[:, 0::2] = torch.sin(position * div_term)
22 | pe[:, 1::2] = torch.cos(position * div_term)
23 | pe = pe.unsqueeze(0)
24 | self.register_buffer('pe', pe)
25 |
26 | def forward(self, input):
27 | """
28 | Args:
29 | input: N x T x D
30 | """
31 | length = input.size(1)
32 | return self.pe[:, :length]
33 |
34 |
35 | class PositionwiseFeedForward(nn.Module):
36 | """Implements position-wise feedforward sublayer.
37 | FFN(x) = max(0, xW1 + b1)W2 + b2
38 | """
39 |
40 | def __init__(self, d_model, d_ff, dropout=0.1):
41 | super(PositionwiseFeedForward, self).__init__()
42 | self.w_1 = nn.Linear(d_model, d_ff)
43 | self.w_2 = nn.Linear(d_ff, d_model)
44 | self.dropout = nn.Dropout(dropout)
45 | self.layer_norm = nn.LayerNorm(d_model)
46 |
47 | def forward(self, x):
48 | residual = x
49 | output = self.w_2(F.relu(self.w_1(x)))
50 | output = self.dropout(output)
51 | output = self.layer_norm(output + residual)
52 | return output
53 |
54 |
55 | # Another implementation
56 | class PositionwiseFeedForwardUseConv(nn.Module):
57 | """A two-feed-forward-layer module"""
58 |
59 | def __init__(self, d_in, d_hid, dropout=0.1):
60 | super(PositionwiseFeedForwardUseConv, self).__init__()
61 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
62 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
63 | self.layer_norm = nn.LayerNorm(d_in)
64 | self.dropout = nn.Dropout(dropout)
65 |
66 | def forward(self, x):
67 | residual = x
68 | output = x.transpose(1, 2)
69 | output = self.w_2(F.relu(self.w_1(output)))
70 | output = output.transpose(1, 2)
71 | output = self.dropout(output)
72 | output = self.layer_norm(output + residual)
73 | return output
74 |
--------------------------------------------------------------------------------
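
A quick shape check for the two building blocks above (a sketch, run from the repo root):

```python
import torch

from transformer.module import PositionalEncoding, PositionwiseFeedForward

d_model = 512
x = torch.randn(2, 10, d_model)  # N x T x D

pe = PositionalEncoding(d_model, max_len=5000)
print(pe(x).shape)               # torch.Size([1, 10, 512]), broadcast over the batch when added

ffn = PositionwiseFeedForward(d_model, d_ff=2048)
print(ffn(x).shape)              # torch.Size([2, 10, 512])
```
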
/transformer/optimizer.py:
--------------------------------------------------------------------------------
1 | class TransformerOptimizer(object):
2 | """A simple wrapper class for learning rate scheduling"""
3 |
4 | def __init__(self, optimizer, warmup_steps=4000, k=0.2):
5 | self.optimizer = optimizer
6 | self.k = k
7 | self.warmup_steps = warmup_steps
8 | d_model = 512
9 | self.init_lr = d_model ** (-0.5)
10 | self.lr = self.init_lr
11 | self.warmup_steps = warmup_steps
12 | self.k = k
13 | self.step_num = 0
14 |
15 | def zero_grad(self):
16 | self.optimizer.zero_grad()
17 |
18 | def step(self):
19 | self._update_lr()
20 | self.optimizer.step()
21 |
22 | def _update_lr(self):
23 | self.step_num += 1
24 | self.lr = self.k * self.init_lr * min(self.step_num ** (-0.5),
25 | self.step_num * (self.warmup_steps ** (-1.5)))
26 | for param_group in self.optimizer.param_groups:
27 | param_group['lr'] = self.lr
28 |
--------------------------------------------------------------------------------
/transformer/transformer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .decoder import Decoder
4 | from .encoder import Encoder
5 |
6 |
7 | class Transformer(nn.Module):
8 |     """An encoder-decoder framework that includes only attention.
9 | """
10 |
11 | def __init__(self, encoder=None, decoder=None):
12 | super(Transformer, self).__init__()
13 |
14 | if encoder is not None and decoder is not None:
15 | self.encoder = encoder
16 | self.decoder = decoder
17 |
18 | for p in self.parameters():
19 | if p.dim() > 1:
20 | nn.init.xavier_uniform_(p)
21 | else:
22 | self.encoder = Encoder()
23 | self.decoder = Decoder()
24 |
25 | def forward(self, padded_input, input_lengths, padded_target):
26 | """
27 | Args:
28 | padded_input: N x Ti x D
29 | input_lengths: N
30 | padded_targets: N x To
31 | """
32 | encoder_padded_outputs, *_ = self.encoder(padded_input, input_lengths)
33 | # pred is score before softmax
34 | pred, gold, *_ = self.decoder(padded_target, encoder_padded_outputs,
35 | input_lengths)
36 | return pred, gold
37 |
38 | def recognize(self, input, input_length, char_list, args):
39 |         """Sequence-to-Sequence beam search, decode one utterance now.
40 | Args:
41 | input: T x D
42 | char_list: list of characters
43 | args: args.beam
44 | Returns:
45 | nbest_hyps:
46 | """
47 | encoder_outputs, *_ = self.encoder(input.unsqueeze(0), input_length)
48 | nbest_hyps = self.decoder.recognize_beam(encoder_outputs[0],
49 | char_list,
50 | args)
51 | return nbest_hyps
52 |
--------------------------------------------------------------------------------
/transformer/utils.py:
--------------------------------------------------------------------------------
1 | def pad_list(xs, pad_value):
2 | # From: espnet/src/nets/e2e_asr_th.py: pad_list()
3 | n_batch = len(xs)
4 | max_len = max(x.size(0) for x in xs)
5 | pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
6 | for i in range(n_batch):
7 | pad[i, :xs[i].size(0)] = xs[i]
8 | return pad
9 |
10 |
11 | def process_dict(dict_path):
12 | with open(dict_path, 'rb') as f:
13 | dictionary = f.readlines()
14 | char_list = [entry.decode('utf-8').split(' ')[0]
15 | for entry in dictionary]
16 |     sos_id = char_list.index('<sos>')
17 |     eos_id = char_list.index('<eos>')
18 | return char_list, sos_id, eos_id
19 |
20 |
21 | if __name__ == "__main__":
22 | import sys
23 |
24 | path = sys.argv[1]
25 | char_list, sos_id, eos_id = process_dict(path)
26 | print(char_list, sos_id, eos_id)
27 |
28 |
29 | # * ------------------ recognition related ------------------ *
30 |
31 |
32 | def parse_hypothesis(hyp, char_list):
33 | """Function to parse hypothesis
34 | :param list hyp: recognition hypothesis
35 | :param list char_list: list of characters
36 |     :return: recognition text string
37 |     :return: recognition token string
38 | :return: recognition tokenid string
39 | """
40 | # remove sos and get results
41 | tokenid_as_list = list(map(int, hyp['yseq'][1:]))
42 | token_as_list = [char_list[idx] for idx in tokenid_as_list]
43 | score = float(hyp['score'])
44 |
45 | # convert to string
46 | tokenid = " ".join([str(idx) for idx in tokenid_as_list])
47 | token = " ".join(token_as_list)
48 |     text = "".join(token_as_list).replace('<space>', ' ')
49 |
50 | return text, token, tokenid, score
51 |
52 |
53 | def add_results_to_json(js, nbest_hyps, char_list):
54 | """Function to add N-best results to json
55 | :param dict js: groundtruth utterance dict
56 | :param list nbest_hyps: list of hypothesis
57 | :param list char_list: list of characters
58 | :return: N-best results added utterance dict
59 | """
60 | # copy old json info
61 | new_js = dict()
62 | new_js['utt2spk'] = js['utt2spk']
63 | new_js['output'] = []
64 |
65 | for n, hyp in enumerate(nbest_hyps, 1):
66 | # parse hypothesis
67 | rec_text, rec_token, rec_tokenid, score = parse_hypothesis(
68 | hyp, char_list)
69 |
70 | # copy ground-truth
71 | out_dic = dict(js['output'][0].items())
72 |
73 | # update name
74 | out_dic['name'] += '[%d]' % n
75 |
76 | # add recognition results
77 | out_dic['rec_text'] = rec_text
78 | out_dic['rec_token'] = rec_token
79 | out_dic['rec_tokenid'] = rec_tokenid
80 | out_dic['score'] = score
81 |
82 | # add to list of N-best result dicts
83 | new_js['output'].append(out_dic)
84 |
85 | # show 1-best result
86 | if n == 1:
87 | print('groundtruth: %s' % out_dic['text'])
88 | print('prediction : %s' % out_dic['rec_text'])
89 |
90 | return new_js
91 |
92 |
93 | # -- Transformer Related --
94 | import torch
95 |
96 |
97 | def get_non_pad_mask(padded_input, input_lengths=None, pad_idx=None):
98 | """padding position is set to 0, either use input_lengths or pad_idx
99 | """
100 | assert input_lengths is not None or pad_idx is not None
101 | if input_lengths is not None:
102 | # padded_input: N x T x ..
103 | N = padded_input.size(0)
104 | non_pad_mask = padded_input.new_ones(padded_input.size()[:-1]) # N x T
105 | for i in range(N):
106 | non_pad_mask[i, input_lengths[i]:] = 0
107 | if pad_idx is not None:
108 | # padded_input: N x T
109 | assert padded_input.dim() == 2
110 | non_pad_mask = padded_input.ne(pad_idx).float()
111 | # unsqueeze(-1) for broadcast
112 | return non_pad_mask.unsqueeze(-1)
113 |
114 |
115 | def get_subsequent_mask(seq):
116 | ''' For masking out the subsequent info. '''
117 |
118 | sz_b, len_s = seq.size()
119 | subsequent_mask = torch.triu(
120 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
121 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls
122 |
123 | return subsequent_mask
124 |
125 |
126 | def get_attn_key_pad_mask(seq_k, seq_q, pad_idx):
127 | ''' For masking out the padding part of key sequence. '''
128 |
129 | # Expand to fit the shape of key query attention matrix.
130 | len_q = seq_q.size(1)
131 | padding_mask = seq_k.eq(pad_idx)
132 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk
133 |
134 | return padding_mask
135 |
136 |
137 | def get_attn_pad_mask(padded_input, input_lengths, expand_length):
138 | """mask position is set to 1"""
139 | # N x Ti x 1
140 | non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
141 |     # N x Ti; lt(1) acts as a logical NOT on the non-pad mask
142 | pad_mask = non_pad_mask.squeeze(-1).lt(1)
143 | attn_mask = pad_mask.unsqueeze(1).expand(-1, expand_length, -1)
144 | return attn_mask
145 |
--------------------------------------------------------------------------------
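To make the mask conventions above concrete, the short check below (shapes only, values arbitrary) exercises the three helpers: `get_non_pad_mask` marks padded frames with 0, `get_attn_pad_mask` marks padded key positions with 1, and `get_subsequent_mask` is the causal mask applied on the decoder side.

```
import torch

from transformer.utils import (get_attn_pad_mask, get_non_pad_mask,
                               get_subsequent_mask)

# Two utterances padded to 5 frames of 80-dim features; true lengths are 5 and 3.
padded_input = torch.randn(2, 5, 80)
input_lengths = torch.tensor([5, 3])

non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
print(non_pad_mask.shape)   # torch.Size([2, 5, 1]); frames past each true length are 0

attn_mask = get_attn_pad_mask(padded_input, input_lengths, expand_length=5)
print(attn_mask.shape)      # torch.Size([2, 5, 5]); 1 marks padded key positions

# Causal mask for two decoder sequences of length 4: the strict upper triangle is 1.
ys = torch.ones(2, 4, dtype=torch.long)
print(get_subsequent_mask(ys)[0])
```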
/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 |
4 | import librosa
5 | import numpy as np
6 | import torch
7 | from config import sample_rate
8 |
9 |
10 | def clip_gradient(optimizer, grad_clip):
11 | """
12 | Clips gradients computed during backpropagation to avoid explosion of gradients.
13 | :param optimizer: optimizer with the gradients to be clipped
14 | :param grad_clip: clip value
15 | """
16 | for group in optimizer.param_groups:
17 | for param in group['params']:
18 | if param.grad is not None:
19 | param.grad.data.clamp_(-grad_clip, grad_clip)
20 |
21 |
22 | def save_checkpoint(epoch, epochs_since_improvement, model, optimizer, loss, is_best):
23 | state = {'epoch': epoch,
24 | 'epochs_since_improvement': epochs_since_improvement,
25 | 'loss': loss,
26 | 'model': model,
27 | 'optimizer': optimizer}
28 |
29 | filename = 'checkpoint.tar'
30 | torch.save(state, filename)
31 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint
32 | if is_best:
33 | torch.save(state, 'BEST_checkpoint.tar')
34 |
35 |
36 | class AverageMeter(object):
37 | """
38 | Keeps track of most recent, average, sum, and count of a metric.
39 | """
40 |
41 | def __init__(self):
42 | self.reset()
43 |
44 | def reset(self):
45 | self.val = 0
46 | self.avg = 0
47 | self.sum = 0
48 | self.count = 0
49 |
50 | def update(self, val, n=1):
51 | self.val = val
52 | self.sum += val * n
53 | self.count += n
54 | self.avg = self.sum / self.count
55 |
56 |
57 | def adjust_learning_rate(optimizer, shrink_factor):
58 | """
59 | Shrinks learning rate by a specified factor.
60 | :param optimizer: optimizer whose learning rate must be shrunk.
61 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with.
62 | """
63 |
64 | print("\nDECAYING learning rate.")
65 | for param_group in optimizer.param_groups:
66 | param_group['lr'] = param_group['lr'] * shrink_factor
67 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],))
68 |
69 |
70 | def accuracy(scores, targets, k=1):
71 | batch_size = targets.size(0)
72 | _, ind = scores.topk(k, 1, True, True)
73 | correct = ind.eq(targets.view(-1, 1).expand_as(ind))
74 | correct_total = correct.view(-1).float().sum() # 0D tensor
75 | return correct_total.item() * (100.0 / batch_size)
76 |
77 |
78 | def parse_args():
79 | parser = argparse.ArgumentParser(description='Speech Transformer')
80 | # Low Frame Rate (stacking and skipping frames)
81 | parser.add_argument('--LFR_m', default=4, type=int,
82 | help='Low Frame Rate: number of frames to stack')
83 | parser.add_argument('--LFR_n', default=3, type=int,
84 | help='Low Frame Rate: number of frames to skip')
85 | # Network architecture
86 | # encoder
87 | # TODO: automatically infer input dim
88 | parser.add_argument('--d_input', default=80, type=int,
89 | help='Dim of encoder input (before LFR)')
90 | parser.add_argument('--n_layers_enc', default=6, type=int,
91 | help='Number of encoder stacks')
92 | parser.add_argument('--n_head', default=8, type=int,
93 |                         help='Number of attention heads (MHA)')
94 | parser.add_argument('--d_k', default=64, type=int,
95 | help='Dimension of key')
96 | parser.add_argument('--d_v', default=64, type=int,
97 | help='Dimension of value')
98 | parser.add_argument('--d_model', default=512, type=int,
99 | help='Dimension of model')
100 | parser.add_argument('--d_inner', default=2048, type=int,
101 |                         help='Dimension of inner feed-forward layer')
102 | parser.add_argument('--dropout', default=0.1, type=float,
103 | help='Dropout rate')
104 | parser.add_argument('--pe_maxlen', default=5000, type=int,
105 | help='Positional Encoding max len')
106 | # decoder
107 | parser.add_argument('--d_word_vec', default=512, type=int,
108 | help='Dim of decoder embedding')
109 | parser.add_argument('--n_layers_dec', default=6, type=int,
110 | help='Number of decoder stacks')
111 | parser.add_argument('--tgt_emb_prj_weight_sharing', default=1, type=int,
112 | help='share decoder embedding with decoder projection')
113 | # Loss
114 | parser.add_argument('--label_smoothing', default=0.1, type=float,
115 | help='label smoothing')
116 |
117 | # Training config
118 | parser.add_argument('--epochs', default=150, type=int,
119 | help='Number of maximum epochs')
120 | # minibatch
121 | parser.add_argument('--shuffle', default=1, type=int,
122 | help='reshuffle the data at every epoch')
123 | parser.add_argument('--batch-size', default=32, type=int,
124 | help='Batch size')
125 | parser.add_argument('--batch_frames', default=0, type=int,
126 |                         help='Batch frames per minibatch. If nonzero, --batch-size is ignored')
127 | parser.add_argument('--maxlen-in', default=800, type=int, metavar='ML',
128 | help='Batch size is reduced if the input sequence length > ML')
129 | parser.add_argument('--maxlen-out', default=150, type=int, metavar='ML',
130 | help='Batch size is reduced if the output sequence length > ML')
131 | parser.add_argument('--num-workers', default=4, type=int,
132 | help='Number of workers to generate minibatch')
133 | # optimizer
134 | parser.add_argument('--lr', default=0.001, type=float,
135 | help='learning rate')
136 | parser.add_argument('--k', default=0.2, type=float,
137 |                         help='tunable scalar multiplier for the learning rate')
138 | parser.add_argument('--warmup_steps', default=4000, type=int,
139 | help='warmup steps')
140 |
141 | parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint')
142 |
143 | parser.add_argument('--n_samples', default="train:-1,dev:-1,test:-1", type=str,
144 | help='choose the number of examples to use')
145 | args = parser.parse_args()
146 |
147 | return args
148 |
149 |
150 | def get_logger():
151 | logger = logging.getLogger()
152 | handler = logging.StreamHandler()
153 | formatter = logging.Formatter("%(asctime)s %(levelname)s \t%(message)s")
154 | handler.setFormatter(formatter)
155 | logger.addHandler(handler)
156 | logger.setLevel(logging.DEBUG)
157 | return logger
158 |
159 |
160 | def ensure_folder(folder):
161 | import os
162 | if not os.path.isdir(folder):
163 | os.mkdir(folder)
164 |
165 |
166 | def pad_list(xs, pad_value):
167 | # From: espnet/src/nets/e2e_asr_th.py: pad_list()
168 | n_batch = len(xs)
169 | max_len = max(x.size(0) for x in xs)
170 | pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
171 | for i in range(n_batch):
172 | pad[i, :xs[i].size(0)] = xs[i]
173 | return pad
174 |
175 |
176 | # [-0.5, 0.5]
177 | def normalize(yt):
178 | yt_max = np.max(yt)
179 | yt_min = np.min(yt)
180 | a = 1.0 / (yt_max - yt_min)
181 | b = -(yt_max + yt_min) / (2 * (yt_max - yt_min))
182 |
183 | yt = yt * a + b
184 | return yt
185 |
186 |
187 | # Acoustic Feature Extraction
188 | # Parameters
189 | # - input file : str, audio file path
190 | # - feature : str, fbank or mfcc
191 | # - dim : int, dimension of feature
192 | # - cmvn : bool, apply CMVN on feature
193 | # - window_size : int, window size for FFT (ms)
194 | #           - stride        : int, window stride for FFT (ms)
195 | # - save_feature: str, if given, store feature to the path and return len(feature)
196 | # Return
197 | # acoustic features with shape (time step, dim)
198 | def extract_feature(input_file, feature='fbank', dim=80, cmvn=True, delta=False, delta_delta=False,
199 | window_size=25, stride=10, save_feature=None):
200 | y, sr = librosa.load(input_file, sr=sample_rate)
201 | yt, _ = librosa.effects.trim(y, top_db=20)
202 | yt = normalize(yt)
203 | ws = int(sr * 0.001 * window_size)
204 | st = int(sr * 0.001 * stride)
205 | if feature == 'fbank': # log-scaled
206 | feat = librosa.feature.melspectrogram(y=yt, sr=sr, n_mels=dim,
207 | n_fft=ws, hop_length=st)
208 | feat = np.log(feat + 1e-6)
209 | elif feature == 'mfcc':
210 | feat = librosa.feature.mfcc(y=yt, sr=sr, n_mfcc=dim, n_mels=26,
211 | n_fft=ws, hop_length=st)
212 |         feat[0] = librosa.feature.rmse(y=yt, hop_length=st, frame_length=ws)
213 |
214 | else:
215 | raise ValueError('Unsupported Acoustic Feature: ' + feature)
216 |
217 | feat = [feat]
218 | if delta:
219 | feat.append(librosa.feature.delta(feat[0]))
220 |
221 | if delta_delta:
222 | feat.append(librosa.feature.delta(feat[0], order=2))
223 | feat = np.concatenate(feat, axis=0)
224 | if cmvn:
225 | feat = (feat - feat.mean(axis=1)[:, np.newaxis]) / (feat.std(axis=1) + 1e-16)[:, np.newaxis]
226 | if save_feature is not None:
227 | tmp = np.swapaxes(feat, 0, 1).astype('float32')
228 | np.save(save_feature, tmp)
229 | return len(tmp)
230 | else:
231 | return np.swapaxes(feat, 0, 1).astype('float32')
232 |
--------------------------------------------------------------------------------
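Taken together, `extract_feature` above returns a (time, dim) float32 array per utterance and `pad_list` zero-pads a list of variable-length tensors into one batch, which is how the N x Ti x D encoder input is formed. A small sketch of that composition follows; the wav paths are placeholders, and the Low Frame Rate stacking configured by `--LFR_m`/`--LFR_n` happens elsewhere in the pipeline and is omitted here.

```
import torch

from utils import extract_feature, pad_list

# Placeholder paths; any short wavs of different lengths will do
# (librosa resamples them to the configured sample_rate on load).
wav_paths = ['audios/audio_0.wav', 'audios/audio_1.wav']

# (time, 80) log-mel features per utterance, CMVN-normalized.
feats = [torch.from_numpy(extract_feature(p, feature='fbank', dim=80, cmvn=True))
         for p in wav_paths]
input_lengths = torch.tensor([f.size(0) for f in feats])

# Zero-pad the time axis so the batch can be fed to the encoder as N x Ti x D.
padded_input = pad_list(feats, pad_value=0)
print(padded_input.shape, input_lengths)
```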
/xer.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logging.basicConfig(
4 | format='%(levelname)s(%(filename)s:%(lineno)d): %(message)s')
5 |
6 |
7 | def levenshtein(u, v):
8 | prev = None
9 | curr = [0] + list(range(1, len(v) + 1))
10 | # Operations: (SUB, DEL, INS)
11 | prev_ops = None
12 | curr_ops = [(0, 0, i) for i in range(len(v) + 1)]
13 | for x in range(1, len(u) + 1):
14 | prev, curr = curr, [x] + ([None] * len(v))
15 | prev_ops, curr_ops = curr_ops, [(0, x, 0)] + ([None] * len(v))
16 | for y in range(1, len(v) + 1):
17 | delcost = prev[y] + 1
18 | addcost = curr[y - 1] + 1
19 | subcost = prev[y - 1] + int(u[x - 1] != v[y - 1])
20 | curr[y] = min(subcost, delcost, addcost)
21 | if curr[y] == subcost:
22 | (n_s, n_d, n_i) = prev_ops[y - 1]
23 | curr_ops[y] = (n_s + int(u[x - 1] != v[y - 1]), n_d, n_i)
24 | elif curr[y] == delcost:
25 | (n_s, n_d, n_i) = prev_ops[y]
26 | curr_ops[y] = (n_s, n_d + 1, n_i)
27 | else:
28 | (n_s, n_d, n_i) = curr_ops[y - 1]
29 | curr_ops[y] = (n_s, n_d, n_i + 1)
30 | return curr[len(v)], curr_ops[len(v)]
31 |
32 |
33 | def load_file(fname, encoding):
34 | try:
35 |         f = open(fname, 'r', encoding=encoding)
36 | data = []
37 | for line in f:
38 |             data.append(line.rstrip('\n').rstrip('\r'))
39 | f.close()
40 | except:
41 | logging.error('Error reading file "%s"', fname)
42 | exit(1)
43 | return data
44 |
45 |
46 | def cer_function(ref, hyp):
47 | wer_s, wer_i, wer_d, wer_n = 0, 0, 0, 0
48 | cer_s, cer_i, cer_d, cer_n = 0, 0, 0, 0
49 | sen_err = 0
50 | for n in range(len(ref)):
51 | # update CER statistics
52 |         _, (s, d, i) = levenshtein(ref[n], hyp[n])
53 | cer_s += s
54 | cer_i += i
55 | cer_d += d
56 | cer_n += len(ref[n])
57 | # update WER statistics
58 |         _, (s, d, i) = levenshtein(ref[n].split(), hyp[n].split())
59 | wer_s += s
60 | wer_i += i
61 | wer_d += d
62 | wer_n += len(ref[n].split())
63 | # update SER statistics
64 | if s + i + d > 0:
65 | sen_err += 1
66 |
67 | print(cer_s, cer_i, cer_d, cer_n)
68 | return (cer_s + cer_i + cer_d) / cer_n
69 |
70 |
71 | if __name__ == '__main__':
72 | ref = ['天然气用户为优先允许限制类和禁止类']
73 | hyp = ['天然气用户为优先允许限制类和禁止量内']
74 |     cer = cer_function(ref, hyp)
75 |     print(cer)
76 |
--------------------------------------------------------------------------------
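To make the metric concrete: `levenshtein` returns the edit distance together with a (substitutions, deletions, insertions) breakdown, and `cer_function` reports (S + D + I) / N, where N is the total reference length in characters. For the pair in the main block, the 17-character reference differs from the hypothesis by one substitution and one insertion, so the printed CER should be 2/17 ≈ 0.118. A minimal check using only the functions defined above:

```
from xer import cer_function, levenshtein

dist, (n_sub, n_del, n_ins) = levenshtein('kitten', 'sitting')
print(dist, n_sub, n_del, n_ins)         # 3 2 0 1: two substitutions and one insertion

print(cer_function(['abcd'], ['abxd']))  # one substitution over 4 reference chars -> 0.25
```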