├── .gitignore
├── LICENSE
├── README.md
├── bleu_score.py
├── config.py
├── convert_valid.py
├── data_gen.py
├── demo.py
├── export.py
├── extract.py
├── images
│   └── dataset.png
├── pre_process.py
├── test
│   ├── test_bleu.py
│   └── test_lr.py
├── train.py
├── transformer
│   ├── __init__.py
│   ├── attention.py
│   ├── decoder.py
│   ├── encoder.py
│   ├── loss.py
│   ├── module.py
│   ├── optimizer.py
│   ├── transformer.py
│   └── utils.py
├── utils.py
└── vocab.pkl
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__/
3 | data/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # English-Chinese Machine Translation
2 | 
3 | Evaluates English-to-Chinese machine text translation. The translation direction is English to Chinese.
4 |
5 |
6 | ## Dependencies
7 |
8 | - Python 3.6.8
9 | - PyTorch 1.3.0
10 |
11 | ## Dataset
12 | 
13 | We use the English-Chinese machine translation dataset from AI Challenger 2017: more than 10 million English-Chinese sentence pairs in total. The training set accounts for the vast majority with 12,904,955 pairs; the validation set has 8,000 pairs, test set A has 8,000 sentences, and test set B has 8,000 sentences.
14 | 
15 | It can be downloaded here: [English-Chinese translation dataset](https://challenger.ai/datasets/translation)
16 | 
17 | ![dataset](images/dataset.png)
18 |
19 | ## Usage
20 | 
21 | ### Data Preprocessing
22 | Extract the training and validation samples:
23 | ```bash
24 | $ python pre_process.py
25 | ```
26 |
27 | ### Training
28 | ```bash
29 | $ python train.py
30 | ```
31 |
32 | To visualize the training process, run in a terminal:
33 | ```bash
34 | $ tensorboard --logdir path_to_current_dir/logs
35 | ```
36 |
37 | ### Demo
38 | Download the [pre-trained model](https://github.com/foamliu/Scene-Classification/releases/download/v1.0/model.85-0.7657.hdf5), then run:
39 |
40 | ```bash
41 | $ python demo.py
42 | ```
43 |
44 | Below, the first line of each example is the English source sentence (from the dataset), the second line is the human Chinese translation (from the dataset), and the following line is this model's machine translation, generated on the fly.
45 |
46 |
47 |
48 | < there are places in the world where girls do n t get educated simply because they are girls .
49 | = 在这个世界上某些地方的女孩,不能得到教育仅仅因为她是女孩。
50 | > 世界上有一些女孩没有受过教育,仅仅是因为她们是女孩。
51 | < i ve noticed this black van parked outside the house every day .
52 | = 我注意到有辆黑色的车每天都停在房子外面。
53 | > 我每天都注意到一辆黑色货车停在房子外面。
54 | < i mean i feel like i m okay cause i passed on the crazy .
55 | = 我是说,我觉得不错因为已经从疯狂里解脱了。
56 | > 我是说,我感觉很好因为我错过了疯狂。
57 | < i do n t know . i was having this dream . it s all right . i m here now .
58 | = 我不知道。我刚才做了个梦。没事了。我在这里。
59 | > 我不知道。我做了个梦。没关系。我在这儿。
60 | < and if the girls see you up here they re gon na take it
61 | = 如果她们看到你坐这她们会觉得
62 | > 如果女孩们看到你在这里,他们会接受的
63 | < getting filled out like an application constitutes being with a guy .
64 | = 愿意跟小男孩在一起。
65 | > 就像和男人在一起一样。
66 | < every single brother in my fraternity has worn this suit .
67 | = 兄弟会里每个人都穿成这样。
68 | > 我兄弟会的每个兄弟都穿这件衣服。
69 | < but i m here in town checking out some real estate and
70 | = 但我进城来看房子,以及-
71 | > 但我在城里,查一些房地产,
72 | < hey i was gon na come and see you today and say hi .
73 | = 嘿,我今天打算过去看看你打个招呼。
74 | > 嘿,我今天想去看你然后打个招呼。
75 | < it was a good show . stop saying it was a good show . shh !
76 | = 今晚的节目很好。不要再说它好了!
77 | > 是个好节目。别说是好节目了。嘘!
78 |
79 |
80 |
--------------------------------------------------------------------------------
/bleu_score.py:
--------------------------------------------------------------------------------
1 | # import the necessary packages
2 | import pickle
3 | import time
4 |
5 | import numpy as np
6 | import torch
7 | from nltk.translate.bleu_score import sentence_bleu
8 | from tqdm import tqdm
9 |
10 | from config import device, logger, data_file, vocab_file
11 | from transformer.transformer import Transformer
12 |
13 | if __name__ == '__main__':
14 | filename = 'transformer.pt'
15 | print('loading {}...'.format(filename))
16 | start = time.time()
17 | model = Transformer()
18 | model.load_state_dict(torch.load(filename))
19 | print('elapsed {} sec'.format(time.time() - start))
20 | model = model.to(device)
21 | model.eval()
22 |
23 | logger.info('loading samples...')
24 | start = time.time()
25 | with open(data_file, 'rb') as file:
26 | data = pickle.load(file)
27 | samples = data['valid']
28 | elapsed = time.time() - start
29 | logger.info('elapsed: {:.4f} seconds'.format(elapsed))
30 |
31 | logger.info('loading vocab...')
32 | start = time.time()
33 | with open(vocab_file, 'rb') as file:
34 | data = pickle.load(file)
35 | src_idx2char = data['dict']['src_idx2char']
36 | tgt_idx2char = data['dict']['tgt_idx2char']
37 | elapsed = time.time() - start
38 | logger.info('elapsed: {:.4f} seconds'.format(elapsed))
39 |
40 | # samples = random.sample(samples, 10)
41 | bleu_scores = []
42 |
43 | for sample in tqdm(samples):
44 | sentence_in = sample['in']
45 | sentence_out = sample['out']
46 |
47 | input = torch.from_numpy(np.array(sentence_in, dtype=np.long)).to(device)
48 | input_length = torch.LongTensor([len(sentence_in)]).to(device)
49 |
50 | sentence_in = ' '.join([src_idx2char[idx] for idx in sentence_in])
51 | sentence_out = ''.join([tgt_idx2char[idx] for idx in sentence_out])
52 |         sentence_out = sentence_out.replace('<sos>', '').replace('<eos>', '')
53 | # print('< ' + sentence_in)
54 | # print('= ' + sentence_out)
55 |
56 | try:
57 | with torch.no_grad():
58 | nbest_hyps = model.recognize(input=input, input_length=input_length, char_list=tgt_idx2char)
59 | # print(nbest_hyps)
60 | except RuntimeError:
61 | print('sentence_in: ' + sentence_in)
62 | continue
63 |
64 | score_list = []
65 | for hyp in nbest_hyps:
66 | out = hyp['yseq']
67 | out = [tgt_idx2char[idx] for idx in out]
68 | out = ''.join(out)
69 |                 out = out.replace('<sos>', '').replace('<eos>', '')
70 | reference = list(sentence_out)
71 | hypothesis = list(out)
72 | score = sentence_bleu([reference], hypothesis)
73 | score_list.append(score)
74 |
75 | bleu_scores.append(max(score_list))
76 |
77 | print('len(bleu_scores): ' + str(len(bleu_scores)))
78 | print('np.max(bleu_scores): ' + str(np.max(bleu_scores)))
79 | print('np.min(bleu_scores): ' + str(np.min(bleu_scores)))
80 | print('np.mean(bleu_scores): ' + str(np.mean(bleu_scores)))
81 |
82 |     import matplotlib.pyplot as plt
87 |
88 | # the histogram of the data
89 | n, bins, patches = plt.hist(bleu_scores, 50, density=True, facecolor='g', alpha=0.75)
90 |
91 | plt.xlabel('Scores')
92 | plt.ylabel('Probability')
93 | plt.title('BLEU Scores')
94 | plt.grid(True)
95 | plt.show()
96 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 |
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # sets device for model and PyTorch tensors
6 |
7 | # Model parameters
8 | d_model = 512
9 | epochs = 10000
10 | embedding_size = 300
11 | hidden_size = 1024
12 | data_file = 'data.pkl'
13 | vocab_file = 'vocab.pkl'
14 | n_src_vocab = 15000
15 | n_tgt_vocab = 15000 # target
16 | maxlen_in = 100
17 | maxlen_out = 50
18 | # Training parameters
19 | grad_clip = 1.0  # clip gradients at an absolute value of this threshold
20 | print_freq = 50 # print training/validation stats every __ batches
21 | checkpoint = None # path to checkpoint, None if none
22 |
23 | # Data parameters
24 | IGNORE_ID = -1
25 | pad_id = 0
26 | sos_id = 1
27 | eos_id = 2
28 | unk_id = 3
29 | num_train = 4669414
30 | num_valid = 3870
31 |
32 | train_translation_en_filename = 'data/ai_challenger_translation_train_20170904/translation_train_data_20170904/train.en'
33 | train_translation_zh_filename = 'data/ai_challenger_translation_train_20170904/translation_train_data_20170904/train.zh'
34 | valid_translation_en_filename = 'data/ai_challenger_translation_validation_20170912/translation_validation_20170912/valid.en'
35 | valid_translation_zh_filename = 'data/ai_challenger_translation_validation_20170912/translation_validation_20170912/valid.zh'
36 |
37 |
38 | def get_logger():
39 | logger = logging.getLogger()
40 | handler = logging.StreamHandler()
41 | formatter = logging.Formatter("%(asctime)s [%(levelname)s] [%(threadName)s] %(name)s: %(message)s")
42 | handler.setFormatter(formatter)
43 | logger.addHandler(handler)
44 | logger.setLevel(logging.INFO)
45 | return logger
46 |
47 |
48 | logger = get_logger()
49 |
--------------------------------------------------------------------------------
/convert_valid.py:
--------------------------------------------------------------------------------
1 | import xml.etree.ElementTree
2 |
3 | valid_en_old = 'data/ai_challenger_translation_validation_20170912/translation_validation_20170912/valid.en-zh.en.sgm'
4 | valid_en_new = 'data/ai_challenger_translation_validation_20170912/translation_validation_20170912/valid.en'
5 | valid_zh_old = 'data/ai_challenger_translation_validation_20170912/translation_validation_20170912/valid.en-zh.zh.sgm'
6 | valid_zh_new = 'data/ai_challenger_translation_validation_20170912/translation_validation_20170912/valid.zh'
7 |
8 |
9 | def convert(old, new):
10 | print('old: ' + old)
11 | print('new: ' + new)
12 | with open(old, 'r', encoding='utf-8') as f:
13 | data = f.readlines()
14 |     data = [line.replace(' & ', ' &amp; ') for line in data]  # escape bare ampersands so the SGM parses as XML
15 | with open(new, 'w', encoding='utf-8') as f:
16 | f.writelines(data)
17 |
18 | root = xml.etree.ElementTree.parse(new).getroot()
19 | data = [elem.text.strip() + '\n' for elem in root.iter() if elem.tag == 'seg']
20 | with open(new, 'w', encoding='utf-8') as file:
21 | file.writelines(data)
22 |
23 |
24 | if __name__ == '__main__':
25 | convert(valid_en_old, valid_en_new)
26 | convert(valid_zh_old, valid_zh_new)
27 |
--------------------------------------------------------------------------------
/data_gen.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import time
3 |
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | from torch.utils.data.dataloader import default_collate
7 |
8 | from config import data_file, vocab_file, IGNORE_ID, pad_id, logger
9 |
10 | logger.info('loading samples...')
11 | start = time.time()
12 | with open(data_file, 'rb') as file:
13 | data = pickle.load(file)
14 | elapsed = time.time() - start
15 | logger.info('elapsed: {:.4f}'.format(elapsed))
16 |
17 |
18 | def get_data(filename):
19 | with open(filename, 'r') as file:
20 | data = file.readlines()
21 | data = [line.strip() for line in data]
22 | return data
23 |
24 |
25 | def pad_collate(batch):
26 | max_input_len = float('-inf')
27 | max_target_len = float('-inf')
28 |
29 | for elem in batch:
30 | src, tgt = elem
31 | max_input_len = max_input_len if max_input_len > len(src) else len(src)
32 | max_target_len = max_target_len if max_target_len > len(tgt) else len(tgt)
33 |
34 | for i, elem in enumerate(batch):
35 | src, tgt = elem
36 | input_length = len(src)
37 | padded_input = np.pad(src, (0, max_input_len - len(src)), 'constant', constant_values=pad_id)
38 | padded_target = np.pad(tgt, (0, max_target_len - len(tgt)), 'constant', constant_values=IGNORE_ID)
39 | batch[i] = (padded_input, padded_target, input_length)
40 |
41 | # sort it by input lengths (long to short)
42 | batch.sort(key=lambda x: x[2], reverse=True)
43 |
44 | return default_collate(batch)
45 |
46 |
47 | class AiChallenger2017Dataset(Dataset):
48 | def __init__(self, split):
49 | self.samples = data[split]
50 |
51 | def __getitem__(self, i):
52 | sample = self.samples[i]
53 | src_text = sample['in']
54 | tgt_text = sample['out']
55 |
56 |         return np.array(src_text, dtype=np.long), np.array(tgt_text, dtype=np.long)
57 |
58 | def __len__(self):
59 | return len(self.samples)
60 |
61 |
62 | def main():
63 | from utils import sequence_to_text
64 |
65 | valid_dataset = AiChallenger2017Dataset('valid')
66 | print(valid_dataset[0])
67 |
68 | with open(vocab_file, 'rb') as file:
69 | data = pickle.load(file)
70 |
71 | src_idx2char = data['dict']['src_idx2char']
72 | tgt_idx2char = data['dict']['tgt_idx2char']
73 |
74 | src_text, tgt_text = valid_dataset[0]
75 | src_text = sequence_to_text(src_text, src_idx2char)
76 | src_text = ' '.join(src_text)
77 | print('src_text: ' + src_text)
78 |
79 | tgt_text = sequence_to_text(tgt_text, tgt_idx2char)
80 | tgt_text = ''.join(tgt_text)
81 | print('tgt_text: ' + tgt_text)
82 |
83 |
84 | if __name__ == "__main__":
85 | main()
86 |
--------------------------------------------------------------------------------
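A minimal sketch of how `AiChallenger2017Dataset` and `pad_collate` plug into a `DataLoader` (mirroring `train.py`; the batch size here is arbitrary, and `data.pkl` from `pre_process.py` must already exist, since this module loads it on import):

```python
from torch.utils.data import DataLoader

from data_gen import AiChallenger2017Dataset, pad_collate

# Each batch comes back as (padded_input, padded_target, input_lengths),
# sorted by input length from longest to shortest.
loader = DataLoader(AiChallenger2017Dataset('valid'), batch_size=4,
                    collate_fn=pad_collate, shuffle=False)
padded_input, padded_target, input_lengths = next(iter(loader))
print(padded_input.shape, padded_target.shape, input_lengths)
```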
/demo.py:
--------------------------------------------------------------------------------
1 | # import the necessary packages
2 | import pickle
3 | import random
4 | import time
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from config import device, logger, data_file, vocab_file
10 | from transformer.transformer import Transformer
11 |
12 | if __name__ == '__main__':
13 | filename = 'transformer.pt'
14 | print('loading {}...'.format(filename))
15 | start = time.time()
16 | model = Transformer()
17 | model.load_state_dict(torch.load(filename))
18 | print('elapsed {} sec'.format(time.time() - start))
19 | model = model.to(device)
20 | model.eval()
21 |
22 | logger.info('loading samples...')
23 | start = time.time()
24 | with open(data_file, 'rb') as file:
25 | data = pickle.load(file)
26 | samples = data['valid']
27 | elapsed = time.time() - start
28 | logger.info('elapsed: {:.4f} seconds'.format(elapsed))
29 |
30 | logger.info('loading vocab...')
31 | start = time.time()
32 | with open(vocab_file, 'rb') as file:
33 | data = pickle.load(file)
34 | src_idx2char = data['dict']['src_idx2char']
35 | tgt_idx2char = data['dict']['tgt_idx2char']
36 | elapsed = time.time() - start
37 | logger.info('elapsed: {:.4f} seconds'.format(elapsed))
38 |
39 | samples = random.sample(samples, 10)
40 |
41 | for sample in samples:
42 | sentence_in = sample['in']
43 | sentence_out = sample['out']
44 |
45 | input = torch.from_numpy(np.array(sentence_in, dtype=np.long)).to(device)
46 | input_length = torch.LongTensor([len(sentence_in)]).to(device)
47 |
48 | sentence_in = ' '.join([src_idx2char[idx] for idx in sentence_in])
49 | sentence_out = ''.join([tgt_idx2char[idx] for idx in sentence_out])
50 |         sentence_out = sentence_out.replace('<sos>', '').replace('<eos>', '')
51 | print('< ' + sentence_in)
52 | print('= ' + sentence_out)
53 |
54 | with torch.no_grad():
55 | nbest_hyps = model.recognize(input=input, input_length=input_length, char_list=tgt_idx2char)
56 | # print(nbest_hyps)
57 |
58 | for hyp in nbest_hyps:
59 | out = hyp['yseq']
60 | out = [tgt_idx2char[idx] for idx in out]
61 | out = ''.join(out)
62 |             out = out.replace('<sos>', '').replace('<eos>', '')
63 |
64 | print('> {}'.format(out))
65 |
--------------------------------------------------------------------------------
/export.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 |
5 | from transformer.transformer import Transformer
6 |
7 | if __name__ == '__main__':
8 | checkpoint = 'BEST_checkpoint.tar'
9 | print('loading {}...'.format(checkpoint))
10 | start = time.time()
11 | checkpoint = torch.load(checkpoint)
12 | print('elapsed {} sec'.format(time.time() - start))
13 | model = checkpoint['model']
14 | print(type(model))
15 |
16 | filename = 'transformer.pt'
17 | print('saving {}...'.format(filename))
18 | start = time.time()
19 | torch.save(model.state_dict(), filename)
20 | print('elapsed {} sec'.format(time.time() - start))
21 |
22 | print('loading {}...'.format(filename))
23 | start = time.time()
24 | model = Transformer()
25 | model.load_state_dict(torch.load(filename))
26 | print('elapsed {} sec'.format(time.time() - start))
27 |
--------------------------------------------------------------------------------
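The layout of `BEST_checkpoint.tar` is defined by `save_checkpoint` in `utils.py`, which is not included in this dump; the sketch below shows the assumed structure, inferred from the keys that the loading branch in `train.py` reads back. Saving only `model.state_dict()` afterwards, as `export.py` does, yields a much smaller file that loads into a freshly constructed `Transformer`:

```python
import torch

# Assumed checkpoint layout: 'epoch', 'epochs_since_improvement', 'model'
# and 'optimizer' are read back by train.py; 'best_loss' is a guess based
# on the save_checkpoint(..., best_loss, is_best) call.
state = {'epoch': 0,
         'epochs_since_improvement': 0,
         'model': None,       # the full Transformer module
         'optimizer': None,   # the TransformerOptimizer wrapper
         'best_loss': float('inf')}
torch.save(state, 'checkpoint.tar')
```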
/extract.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 |
3 |
4 | def extract(filename):
5 | print('Extracting {}...'.format(filename))
6 | with zipfile.ZipFile(filename, 'r') as zip_ref:
7 | zip_ref.extractall('data')
8 |
9 |
10 | if __name__ == '__main__':
11 | extract('data/ai_challenger_translation_train_20170904.zip')
12 | extract('data/ai_challenger_translation_validation_20170912.zip')
13 |
--------------------------------------------------------------------------------
/images/dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Transformer/3246970b825a87f2c035edac545000716c0d96f7/images/dataset.png
--------------------------------------------------------------------------------
/pre_process.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from collections import Counter
3 |
4 | import jieba
5 | import nltk
6 | from tqdm import tqdm
7 |
8 | from config import train_translation_en_filename, train_translation_zh_filename, valid_translation_en_filename, \
9 | valid_translation_zh_filename, vocab_file, maxlen_in, maxlen_out, data_file, sos_id, eos_id, n_src_vocab, \
10 | n_tgt_vocab, unk_id
11 | from utils import normalizeString, encode_text
12 |
13 |
14 | def build_vocab(token, word2idx, idx2char):
15 | if token not in word2idx:
16 | next_index = len(word2idx)
17 | word2idx[token] = next_index
18 | idx2char[next_index] = token
19 |
20 |
21 | def process(file, lang='zh'):
22 | print('processing {}...'.format(file))
23 | with open(file, 'r', encoding='utf-8') as f:
24 | data = f.readlines()
25 |
26 | word_freq = Counter()
27 | lengths = []
28 |
29 | for line in tqdm(data):
30 | sentence = line.strip()
31 | if lang == 'en':
32 | sentence_en = sentence.lower()
33 | tokens = [normalizeString(s) for s in nltk.word_tokenize(sentence_en)]
34 | word_freq.update(list(tokens))
35 | vocab_size = n_src_vocab
36 | else:
37 | seg_list = jieba.cut(sentence.strip())
38 | tokens = list(seg_list)
39 | word_freq.update(list(tokens))
40 | vocab_size = n_tgt_vocab
41 |
42 | lengths.append(len(tokens))
43 |
44 | words = word_freq.most_common(vocab_size - 4)
45 | word_map = {k[0]: v + 4 for v, k in enumerate(words)}
46 |     word_map['<pad>'] = 0
47 |     word_map['<sos>'] = 1
48 |     word_map['<eos>'] = 2
49 |     word_map['<unk>'] = 3
50 | print(len(word_map))
51 | print(words[:100])
52 | #
53 | # n, bins, patches = plt.hist(lengths, 50, density=True, facecolor='g', alpha=0.75)
54 | #
55 | # plt.xlabel('Lengths')
56 | # plt.ylabel('Probability')
57 | # plt.title('Histogram of Lengths')
58 | # plt.grid(True)
59 | # plt.show()
60 |
61 | word2idx = word_map
62 | idx2char = {v: k for k, v in word2idx.items()}
63 |
64 | return word2idx, idx2char
65 |
66 |
67 | def get_data(in_file, out_file):
68 | print('getting data {}->{}...'.format(in_file, out_file))
69 | with open(in_file, 'r', encoding='utf-8') as file:
70 | in_lines = file.readlines()
71 | with open(out_file, 'r', encoding='utf-8') as file:
72 | out_lines = file.readlines()
73 |
74 | samples = []
75 |
76 | for i in tqdm(range(len(in_lines))):
77 | sentence_en = in_lines[i].strip().lower()
78 | tokens = [normalizeString(s.strip()) for s in nltk.word_tokenize(sentence_en)]
79 | in_data = encode_text(src_char2idx, tokens)
80 |
81 | sentence_zh = out_lines[i].strip()
82 | tokens = jieba.cut(sentence_zh.strip())
83 | out_data = [sos_id] + encode_text(tgt_char2idx, tokens) + [eos_id]
84 |
85 | if len(in_data) < maxlen_in and len(out_data) < maxlen_out and unk_id not in in_data and unk_id not in out_data:
86 | samples.append({'in': in_data, 'out': out_data})
87 | return samples
88 |
89 |
90 | if __name__ == '__main__':
91 | src_char2idx, src_idx2char = process(train_translation_en_filename, lang='en')
92 | tgt_char2idx, tgt_idx2char = process(train_translation_zh_filename, lang='zh')
93 |
94 | print(len(src_char2idx))
95 | print(len(tgt_char2idx))
96 |
97 | data = {
98 | 'dict': {
99 | 'src_char2idx': src_char2idx,
100 | 'src_idx2char': src_idx2char,
101 | 'tgt_char2idx': tgt_char2idx,
102 | 'tgt_idx2char': tgt_idx2char
103 | }
104 | }
105 | with open(vocab_file, 'wb') as file:
106 | pickle.dump(data, file)
107 |
108 | train = get_data(train_translation_en_filename, train_translation_zh_filename)
109 | valid = get_data(valid_translation_en_filename, valid_translation_zh_filename)
110 |
111 | data = {
112 | 'train': train,
113 | 'valid': valid
114 | }
115 |
116 | print('num_train: ' + str(len(train)))
117 | print('num_valid: ' + str(len(valid)))
118 |
119 | with open(data_file, 'wb') as file:
120 | pickle.dump(data, file)
121 |
--------------------------------------------------------------------------------
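To make the sample format concrete, here is a sketch of the encoding step, under the assumption that `encode_text` in `utils.py` (not shown in this dump) maps tokens to indices with a fallback to `unk_id`; the vocabulary fragment is hypothetical:

```python
# Indices 0-3 are reserved for the special tokens, matching
# pad_id/sos_id/eos_id/unk_id in config.py.
tgt_char2idx = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3,
                '你好': 4, '世界': 5}

def encode_text(char2idx, tokens):
    # assumed behavior of utils.encode_text: unknown tokens map to <unk>
    return [char2idx.get(token, 3) for token in tokens]

sos_id, eos_id = 1, 2
out_data = [sos_id] + encode_text(tgt_char2idx, ['你好', '世界']) + [eos_id]
print(out_data)  # [1, 4, 5, 2] -> <sos> 你好 世界 <eos>
```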
/test/test_bleu.py:
--------------------------------------------------------------------------------
1 | from nltk.translate.bleu_score import sentence_bleu
2 |
3 | reference = list('她的故事在法国遥远的西部山上')
4 | hypothesis = list('她的故事在法国的遥远山')
5 | score = sentence_bleu([reference], hypothesis)
6 | print(score)
7 |
--------------------------------------------------------------------------------
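Note that `sentence_bleu` is applied to character lists here, and NLTK warns (and returns a near-zero score) when a short hypothesis shares no higher-order n-grams with the reference. Passing one of NLTK's smoothing functions avoids that, as in this sketch:

```python
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

reference = list('她的故事在法国遥远的西部山上')
hypothesis = list('她的故事在法国的遥远山')
smoother = SmoothingFunction().method1
print(sentence_bleu([reference], hypothesis, smoothing_function=smoother))
```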
/test/test_lr.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | from config import d_model
4 |
5 | if __name__ == '__main__':
6 | warmup_steps = 4000
7 | init_lr = d_model ** (-0.5)
8 |
9 | lr_list = []
10 | for step_num in range(1, 500000):
11 | # print(step_num)
12 | lr = init_lr * min(step_num ** (-0.65), step_num * (warmup_steps ** (-1.5)))
13 |
14 | lr_list.append(lr)
15 |
16 | plt.plot(lr_list)
17 | plt.show()
18 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.tensorboard import SummaryWriter
7 |
8 | from config import device, print_freq, sos_id, eos_id, n_src_vocab, n_tgt_vocab, grad_clip, logger
9 | from data_gen import AiChallenger2017Dataset, pad_collate
10 | from transformer.decoder import Decoder
11 | from transformer.encoder import Encoder
12 | from transformer.loss import cal_performance
13 | from transformer.optimizer import TransformerOptimizer
14 | from transformer.transformer import Transformer
15 | from utils import parse_args, save_checkpoint, AverageMeter, clip_gradient
16 |
17 |
18 | # from torch import nn
19 |
20 |
21 | def train_net(args):
22 | torch.manual_seed(7)
23 | np.random.seed(7)
24 | checkpoint = args.checkpoint
25 | start_epoch = 0
26 | best_loss = float('inf')
27 | writer = SummaryWriter()
28 | epochs_since_improvement = 0
29 |
30 | # Initialize / load checkpoint
31 | if checkpoint is None:
32 | # model
33 | encoder = Encoder(n_src_vocab, args.n_layers_enc, args.n_head,
34 | args.d_k, args.d_v, args.d_model, args.d_inner,
35 | dropout=args.dropout, pe_maxlen=args.pe_maxlen)
36 | decoder = Decoder(sos_id, eos_id, n_tgt_vocab,
37 | args.d_word_vec, args.n_layers_dec, args.n_head,
38 | args.d_k, args.d_v, args.d_model, args.d_inner,
39 | dropout=args.dropout,
40 | tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing,
41 | pe_maxlen=args.pe_maxlen)
42 | model = Transformer(encoder, decoder)
43 | # print(model)
44 | # model = nn.DataParallel(model)
45 |
46 | # optimizer
47 | optimizer = TransformerOptimizer(
48 | torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09))
49 |
50 | else:
51 | checkpoint = torch.load(checkpoint)
52 | start_epoch = checkpoint['epoch'] + 1
53 | epochs_since_improvement = checkpoint['epochs_since_improvement']
54 | model = checkpoint['model']
55 | optimizer = checkpoint['optimizer']
56 |
57 | # Move to GPU, if available
58 | model = model.to(device)
59 |
60 | # Custom dataloaders
61 | train_dataset = AiChallenger2017Dataset('train')
62 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=pad_collate,
63 | shuffle=True, num_workers=args.num_workers)
64 | valid_dataset = AiChallenger2017Dataset('valid')
65 | valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, collate_fn=pad_collate,
66 | shuffle=False, num_workers=args.num_workers)
67 |
68 | # Epochs
69 | for epoch in range(start_epoch, args.epochs):
70 | # One epoch's training
71 | train_loss = train(train_loader=train_loader,
72 | model=model,
73 | optimizer=optimizer,
74 | epoch=epoch,
75 | logger=logger,
76 | writer=writer)
77 |
78 | writer.add_scalar('epoch/train_loss', train_loss, epoch)
79 | writer.add_scalar('epoch/learning_rate', optimizer.lr, epoch)
80 |
81 | print('\nLearning rate: {}'.format(optimizer.lr))
82 | print('Step num: {}\n'.format(optimizer.step_num))
83 |
84 | # One epoch's validation
85 | valid_loss = valid(valid_loader=valid_loader,
86 | model=model,
87 | logger=logger)
88 | writer.add_scalar('epoch/valid_loss', valid_loss, epoch)
89 |
90 | # Check if there was an improvement
91 | is_best = valid_loss < best_loss
92 | best_loss = min(valid_loss, best_loss)
93 | if not is_best:
94 | epochs_since_improvement += 1
95 | print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
96 | else:
97 | epochs_since_improvement = 0
98 |
99 | # Save checkpoint
100 | save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)
101 |
102 |
103 | def train(train_loader, model, optimizer, epoch, logger, writer):
104 |     model.train()  # train mode (dropout is enabled)
105 |
106 | losses = AverageMeter()
107 | times = AverageMeter()
108 |
109 | start = time.time()
110 |
111 | # Batches
112 |     for i, data in enumerate(train_loader):
113 | # Move to GPU, if available
114 | padded_input, padded_target, input_lengths = data
115 | padded_input = padded_input.to(device)
116 | padded_target = padded_target.to(device)
117 | input_lengths = input_lengths.to(device)
118 |
119 | # Forward prop.
120 | pred, gold = model(padded_input, input_lengths, padded_target)
121 | loss, n_correct = cal_performance(pred, gold, smoothing=args.label_smoothing)
122 | try:
123 | assert (not math.isnan(loss.item()))
124 | except AssertionError:
125 | print('n_correct: ' + str(n_correct))
126 |             print('data: ' + str(data))
127 | continue
128 |
129 | # Back prop.
130 | optimizer.zero_grad()
131 | loss.backward()
132 |
133 | # Clip gradients
134 | clip_gradient(optimizer.optimizer, grad_clip)
135 |
136 | # Update weights
137 | optimizer.step()
138 |
139 | # Keep track of metrics
140 | elapsed = time.time() - start
141 | start = time.time()
142 |
143 | losses.update(loss.item())
144 | times.update(elapsed)
145 |
146 | # Print status
147 | if i % print_freq == 0:
148 | logger.info('Epoch: [{0}][{1}/{2}]\t'
149 | 'Batch time {time.val:.5f} ({time.avg:.5f})\t'
150 | 'Loss {loss.val:.5f} ({loss.avg:.5f})'.format(epoch, i, len(train_loader), time=times,
151 | loss=losses))
152 | writer.add_scalar('step_num/train_loss', losses.avg, optimizer.step_num)
153 | writer.add_scalar('step_num/learning_rate', optimizer.lr, optimizer.step_num)
154 |
155 | return losses.avg
156 |
157 |
158 | def valid(valid_loader, model, logger):
159 | model.eval()
160 |
161 | losses = AverageMeter()
162 |
163 | # Batches
164 | for data in valid_loader:
165 | # Move to GPU, if available
166 | padded_input, padded_target, input_lengths = data
167 | padded_input = padded_input.to(device)
168 | padded_target = padded_target.to(device)
169 | input_lengths = input_lengths.to(device)
170 |
171 | with torch.no_grad():
172 | # Forward prop.
173 | pred, gold = model(padded_input, input_lengths, padded_target)
174 | loss, n_correct = cal_performance(pred, gold, smoothing=args.label_smoothing)
175 | try:
176 | assert (not math.isnan(loss.item()))
177 | except AssertionError:
178 | print('n_correct: ' + str(n_correct))
179 |                 print('data: ' + str(data))
180 | continue
181 |
182 | # Keep track of metrics
183 | losses.update(loss.item())
184 |
185 | # Print status
186 | logger.info('\nValidation Loss {loss.val:.5f} ({loss.avg:.5f})\n'.format(loss=losses))
187 |
188 | return losses.avg
189 |
190 |
191 | def main():
192 | global args
193 | args = parse_args()
194 | train_net(args)
195 |
196 |
197 | if __name__ == '__main__':
198 | main()
199 |
--------------------------------------------------------------------------------
/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Transformer/3246970b825a87f2c035edac545000716c0d96f7/transformer/__init__.py
--------------------------------------------------------------------------------
/transformer/attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class MultiHeadAttention(nn.Module):
7 | ''' Multi-Head Attention module '''
8 |
9 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
10 | super().__init__()
11 |
12 | self.n_head = n_head
13 | self.d_k = d_k
14 | self.d_v = d_v
15 |
16 | self.w_qs = nn.Linear(d_model, n_head * d_k)
17 | self.w_ks = nn.Linear(d_model, n_head * d_k)
18 | self.w_vs = nn.Linear(d_model, n_head * d_v)
19 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
20 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
21 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
22 |
23 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5),
24 | attn_dropout=dropout)
25 | self.layer_norm = nn.LayerNorm(d_model)
26 |
27 | self.fc = nn.Linear(n_head * d_v, d_model)
28 | nn.init.xavier_normal_(self.fc.weight)
29 |
30 | self.dropout = nn.Dropout(dropout)
31 |
32 | def forward(self, q, k, v, mask=None):
33 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
34 |
35 | sz_b, len_q, _ = q.size()
36 | sz_b, len_k, _ = k.size()
37 | sz_b, len_v, _ = v.size()
38 |
39 | residual = q
40 |
41 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
42 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
43 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
44 |
45 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
46 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
47 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
48 |
49 | if mask is not None:
50 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
51 |
52 | output, attn = self.attention(q, k, v, mask=mask)
53 |
54 | output = output.view(n_head, sz_b, len_q, d_v)
55 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
56 |
57 | output = self.dropout(self.fc(output))
58 | output = self.layer_norm(output + residual)
59 |
60 | return output, attn
61 |
62 |
63 | class ScaledDotProductAttention(nn.Module):
64 | ''' Scaled Dot-Product Attention '''
65 |
66 | def __init__(self, temperature, attn_dropout=0.1):
67 | super().__init__()
68 | self.temperature = temperature
69 | self.dropout = nn.Dropout(attn_dropout)
70 | self.softmax = nn.Softmax(dim=2)
71 |
72 | def forward(self, q, k, v, mask=None):
73 | attn = torch.bmm(q, k.transpose(1, 2))
74 | attn = attn / self.temperature
75 |
76 | if mask is not None:
77 | attn = attn.masked_fill(mask.bool(), -np.inf)
78 |
79 | attn = self.softmax(attn)
80 | attn = self.dropout(attn)
81 | output = torch.bmm(attn, v)
82 |
83 | return output, attn
84 |
--------------------------------------------------------------------------------
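A minimal shape check for `MultiHeadAttention`, using arbitrary toy dimensions:

```python
import torch

from transformer.attention import MultiHeadAttention

# batch of 2, sequence length 5, model width 16, 4 heads of size 4
mha = MultiHeadAttention(n_head=4, d_model=16, d_k=4, d_v=4)
x = torch.randn(2, 5, 16)
output, attn = mha(x, x, x)  # self-attention, no mask
print(output.shape)          # torch.Size([2, 5, 16])
print(attn.shape)            # torch.Size([8, 5, 5]): (n_head * batch) x lq x lk
```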
/transformer/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from config import IGNORE_ID, sos_id, eos_id, n_tgt_vocab
6 | from .attention import MultiHeadAttention
7 | from .module import PositionalEncoding, PositionwiseFeedForward
8 | from .utils import get_attn_key_pad_mask, get_attn_pad_mask, get_non_pad_mask, get_subsequent_mask, pad_list
9 |
10 |
11 | class Decoder(nn.Module):
12 | ''' A decoder model with self attention mechanism. '''
13 |
14 | def __init__(
15 | self, sos_id=sos_id, eos_id=eos_id,
16 | n_tgt_vocab=n_tgt_vocab, d_word_vec=512,
17 | n_layers=6, n_head=8, d_k=64, d_v=64,
18 | d_model=512, d_inner=2048, dropout=0.1,
19 | tgt_emb_prj_weight_sharing=True,
20 | pe_maxlen=5000):
21 | super(Decoder, self).__init__()
22 | # parameters
23 | self.sos_id = sos_id # Start of Sentence
24 | self.eos_id = eos_id # End of Sentence
25 | self.n_tgt_vocab = n_tgt_vocab
26 | self.d_word_vec = d_word_vec
27 | self.n_layers = n_layers
28 | self.n_head = n_head
29 | self.d_k = d_k
30 | self.d_v = d_v
31 | self.d_model = d_model
32 | self.d_inner = d_inner
33 | self.dropout = dropout
34 | self.tgt_emb_prj_weight_sharing = tgt_emb_prj_weight_sharing
35 | self.pe_maxlen = pe_maxlen
36 |
37 | self.tgt_word_emb = nn.Embedding(n_tgt_vocab, d_word_vec)
38 | self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
39 | self.dropout = nn.Dropout(dropout)
40 |
41 | self.layer_stack = nn.ModuleList([
42 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
43 | for _ in range(n_layers)])
44 |
45 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
46 | nn.init.xavier_normal_(self.tgt_word_prj.weight)
47 |
48 | if tgt_emb_prj_weight_sharing:
49 | # Share the weight matrix between target word embedding & the final logit dense layer
50 | self.tgt_word_prj.weight = self.tgt_word_emb.weight
51 | self.x_logit_scale = (d_model ** -0.5)
52 | else:
53 | self.x_logit_scale = 1.
54 |
55 | def preprocess(self, padded_input):
56 | """Generate decoder input and output label from padded_input
57 | Add to decoder input, and add to decoder output label
58 | """
59 | ys = [y[y != IGNORE_ID] for y in padded_input] # parse padded ys
60 | # prepare input and output word sequences with sos/eos IDs
61 | eos = ys[0].new([self.eos_id])
62 | sos = ys[0].new([self.sos_id])
63 | ys_in = [torch.cat([sos, y], dim=0) for y in ys]
64 | ys_out = [torch.cat([y, eos], dim=0) for y in ys]
65 |         # pad ys_in with eos_id and ys_out with IGNORE_ID (-1)
66 |         # resulting shape: batch x max_output_length
67 | ys_in_pad = pad_list(ys_in, self.eos_id)
68 | ys_out_pad = pad_list(ys_out, IGNORE_ID)
69 | assert ys_in_pad.size() == ys_out_pad.size()
70 | return ys_in_pad, ys_out_pad
71 |
72 | def forward(self, padded_input, encoder_padded_outputs,
73 | encoder_input_lengths, return_attns=False):
74 | """
75 | Args:
76 | padded_input: N x To
77 | encoder_padded_outputs: N x Ti x H
78 | Returns:
79 | """
80 | dec_slf_attn_list, dec_enc_attn_list = [], []
81 |
82 |         # Get Decoder Input and Output
83 | ys_in_pad, ys_out_pad = self.preprocess(padded_input)
84 |
85 | # Prepare masks
86 | non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)
87 |
88 | slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)
89 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
90 | seq_q=ys_in_pad,
91 | pad_idx=self.eos_id)
92 | slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
93 |
94 | output_length = ys_in_pad.size(1)
95 | dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
96 | encoder_input_lengths,
97 | output_length)
98 |
99 | # Forward
100 | dec_output = self.dropout(self.tgt_word_emb(ys_in_pad) * self.x_logit_scale +
101 | self.positional_encoding(ys_in_pad))
102 |
103 | for dec_layer in self.layer_stack:
104 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
105 | dec_output, encoder_padded_outputs,
106 | non_pad_mask=non_pad_mask,
107 | slf_attn_mask=slf_attn_mask,
108 | dec_enc_attn_mask=dec_enc_attn_mask)
109 |
110 | if return_attns:
111 | dec_slf_attn_list += [dec_slf_attn]
112 | dec_enc_attn_list += [dec_enc_attn]
113 |
114 | # before softmax
115 | seq_logit = self.tgt_word_prj(dec_output)
116 |
117 | # Return
118 | pred, gold = seq_logit, ys_out_pad
119 |
120 | if return_attns:
121 | return pred, gold, dec_slf_attn_list, dec_enc_attn_list
122 | return pred, gold
123 |
124 | def recognize_beam(self, encoder_outputs, char_list):
125 | """Beam search, decode one utterence now.
126 | Args:
127 | encoder_outputs: T x H
128 | char_list: list of character
129 | args: args.beam
130 | Returns:
131 | nbest_hyps:
132 | """
133 | # search params
134 | beam = 5
135 | nbest = 1
136 | maxlen = 100
137 |
138 | encoder_outputs = encoder_outputs.unsqueeze(0)
139 |
140 | # prepare sos
141 | ys = torch.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long()
142 |
143 | # yseq: 1xT
144 | hyp = {'score': 0.0, 'yseq': ys}
145 | hyps = [hyp]
146 | ended_hyps = []
147 |
148 | for i in range(maxlen):
149 | hyps_best_kept = []
150 | for hyp in hyps:
151 | ys = hyp['yseq'] # 1 x i
152 |
153 | # -- Prepare masks
154 | non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1) # 1xix1
155 | slf_attn_mask = get_subsequent_mask(ys)
156 |
157 | # -- Forward
158 | dec_output = self.dropout(
159 | self.tgt_word_emb(ys) * self.x_logit_scale +
160 | self.positional_encoding(ys))
161 |
162 | for dec_layer in self.layer_stack:
163 | dec_output, _, _ = dec_layer(
164 | dec_output, encoder_outputs,
165 | non_pad_mask=non_pad_mask,
166 | slf_attn_mask=slf_attn_mask,
167 | dec_enc_attn_mask=None)
168 |
169 | seq_logit = self.tgt_word_prj(dec_output[:, -1])
170 |
171 | local_scores = F.log_softmax(seq_logit, dim=1)
172 | # topk scores
173 | local_best_scores, local_best_ids = torch.topk(
174 | local_scores, beam, dim=1)
175 |
176 | for j in range(beam):
177 | new_hyp = {}
178 | new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
179 | new_hyp['yseq'] = torch.ones(1, (1 + ys.size(1))).type_as(encoder_outputs).long()
180 | new_hyp['yseq'][:, :ys.size(1)] = hyp['yseq']
181 | new_hyp['yseq'][:, ys.size(1)] = int(local_best_ids[0, j])
182 | # will be (2 x beam) hyps at most
183 | hyps_best_kept.append(new_hyp)
184 |
185 | hyps_best_kept = sorted(hyps_best_kept,
186 | key=lambda x: x['score'],
187 | reverse=True)[:beam]
188 | # end for hyp in hyps
189 | hyps = hyps_best_kept
190 |
191 |             # append eos on the final iteration so that at least some hypotheses end
192 | if i == maxlen - 1:
193 | for hyp in hyps:
194 | hyp['yseq'] = torch.cat([hyp['yseq'],
195 | torch.ones(1, 1).fill_(self.eos_id).type_as(encoder_outputs).long()],
196 | dim=1)
197 |
198 |             # move ended hypotheses to a final list, and remove them from the current set
199 |             # (the number of active hyps can then drop below the beam width)
200 | remained_hyps = []
201 | for hyp in hyps:
202 | if hyp['yseq'][0, -1] == self.eos_id:
203 | ended_hyps.append(hyp)
204 | else:
205 | remained_hyps.append(hyp)
206 |
207 | hyps = remained_hyps
208 | # if len(hyps) > 0:
209 |             #     print('remained hypotheses: ' + str(len(hyps)))
210 | # else:
211 | # print('no hypothesis. Finish decoding.')
212 | # break
213 | #
214 | # for hyp in hyps:
215 | # print('hypo: ' + ''.join([char_list[int(x)]
216 | # for x in hyp['yseq'][0, 1:]]))
217 | # end for i in range(maxlen)
218 | nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[
219 | :min(len(ended_hyps), nbest)]
220 |         # compatible with the LAS implementation
221 | for hyp in nbest_hyps:
222 | hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist()
223 | return nbest_hyps
224 |
225 |
226 | class DecoderLayer(nn.Module):
227 | ''' Compose with three layers '''
228 |
229 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
230 | super(DecoderLayer, self).__init__()
231 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
232 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
233 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
234 |
235 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
236 | dec_output, dec_slf_attn = self.slf_attn(
237 | dec_input, dec_input, dec_input, mask=slf_attn_mask)
238 | dec_output *= non_pad_mask
239 |
240 | dec_output, dec_enc_attn = self.enc_attn(
241 | dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
242 | dec_output *= non_pad_mask
243 |
244 | dec_output = self.pos_ffn(dec_output)
245 | dec_output *= non_pad_mask
246 |
247 | return dec_output, dec_slf_attn, dec_enc_attn
248 |
--------------------------------------------------------------------------------
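To illustrate what `preprocess` does to one padded target sequence, here is a sketch with toy token ids (IGNORE_ID is -1, sos_id is 1 and eos_id is 2, as in config.py); across a batch, `pad_list` then pads `ys_in` with `eos_id` and `ys_out` with `IGNORE_ID`:

```python
import torch

# One padded target: tokens [5, 6, 7] followed by IGNORE_ID padding.
padded = torch.tensor([[5, 6, 7, -1, -1]])
ys = [y[y != -1] for y in padded]               # strip padding -> [5, 6, 7]
ys_in = torch.cat([torch.tensor([1]), ys[0]])   # prepend <sos>: [1, 5, 6, 7]
ys_out = torch.cat([ys[0], torch.tensor([2])])  # append <eos>:  [5, 6, 7, 2]
print(ys_in.tolist(), ys_out.tolist())
```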
/transformer/encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from config import pad_id, n_src_vocab
4 | from .attention import MultiHeadAttention
5 | from .module import PositionalEncoding, PositionwiseFeedForward
6 | from .utils import get_non_pad_mask, get_attn_pad_mask
7 |
8 |
9 | class Encoder(nn.Module):
10 | """Encoder of Transformer including self-attention and feed forward.
11 | """
12 |
13 | def __init__(self, n_src_vocab=n_src_vocab, n_layers=6, n_head=8, d_k=64, d_v=64,
14 | d_model=512, d_inner=2048, dropout=0.1, pe_maxlen=5000):
15 | super(Encoder, self).__init__()
16 | # parameters
17 | self.n_src_vocab = n_src_vocab
18 | self.n_layers = n_layers
19 | self.n_head = n_head
20 | self.d_k = d_k
21 | self.d_v = d_v
22 | self.d_model = d_model
23 | self.d_inner = d_inner
24 | self.dropout_rate = dropout
25 | self.pe_maxlen = pe_maxlen
26 |
27 | self.src_emb = nn.Embedding(n_src_vocab, d_model, padding_idx=pad_id)
28 | self.pos_emb = PositionalEncoding(d_model, max_len=pe_maxlen)
29 | self.dropout = nn.Dropout(dropout)
30 |
31 | self.layer_stack = nn.ModuleList([
32 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
33 | for _ in range(n_layers)])
34 |
35 | def forward(self, padded_input, input_lengths, return_attns=False):
36 | """
37 | Args:
38 | padded_input: N x T
39 | input_lengths: N
40 | Returns:
41 | enc_output: N x T x H
42 | """
43 | enc_slf_attn_list = []
44 |
45 | # Forward
46 | enc_outputs = self.src_emb(padded_input)
47 | enc_outputs += self.pos_emb(enc_outputs)
48 | enc_output = self.dropout(enc_outputs)
49 |
50 | # Prepare masks
51 | non_pad_mask = get_non_pad_mask(enc_output, input_lengths=input_lengths)
52 | length = padded_input.size(1)
53 | slf_attn_mask = get_attn_pad_mask(enc_output, input_lengths, length)
54 |
55 | for enc_layer in self.layer_stack:
56 | enc_output, enc_slf_attn = enc_layer(
57 | enc_output,
58 | non_pad_mask=non_pad_mask,
59 | slf_attn_mask=slf_attn_mask)
60 | if return_attns:
61 | enc_slf_attn_list += [enc_slf_attn]
62 |
63 | if return_attns:
64 | return enc_output, enc_slf_attn_list
65 | return enc_output,
66 |
67 |
68 | class EncoderLayer(nn.Module):
69 | """Compose with two sub-layers.
70 | 1. A multi-head self-attention mechanism
71 | 2. A simple, position-wise fully connected feed-forward network.
72 | """
73 |
74 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
75 | super(EncoderLayer, self).__init__()
76 | self.slf_attn = MultiHeadAttention(
77 | n_head, d_model, d_k, d_v, dropout=dropout)
78 | self.pos_ffn = PositionwiseFeedForward(
79 | d_model, d_inner, dropout=dropout)
80 |
81 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
82 | enc_output, enc_slf_attn = self.slf_attn(
83 | enc_input, enc_input, enc_input, mask=slf_attn_mask)
84 | enc_output *= non_pad_mask
85 |
86 | enc_output = self.pos_ffn(enc_output)
87 | enc_output *= non_pad_mask
88 |
89 | return enc_output, enc_slf_attn
90 |
--------------------------------------------------------------------------------
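A quick forward-pass sketch for the `Encoder` with small toy sizes; it relies on the mask helpers in `transformer/utils.py`, which are imported above but not reproduced in this dump:

```python
import torch

from transformer.encoder import Encoder

enc = Encoder(n_src_vocab=100, n_layers=2, n_head=4, d_k=16, d_v=16,
              d_model=64, d_inner=256)
padded_input = torch.randint(1, 100, (2, 7))    # N x T token ids (0 is <pad>)
input_lengths = torch.tensor([7, 5])
enc_output, = enc(padded_input, input_lengths)  # note the 1-tuple return
print(enc_output.shape)                         # torch.Size([2, 7, 64])
```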
/transformer/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from config import IGNORE_ID
5 |
6 |
7 | def cal_performance(pred, gold, smoothing=0.0):
8 | """Calculate cross entropy loss, apply label smoothing if needed.
9 | Args:
10 | pred: N x T x C, score before softmax
11 | gold: N x T
12 | """
13 |
14 | pred = pred.view(-1, pred.size(2))
15 | gold = gold.contiguous().view(-1)
16 |
17 | loss = cal_loss(pred, gold, smoothing)
18 |
19 | pred = pred.max(1)[1]
20 | non_pad_mask = gold.ne(IGNORE_ID)
21 | n_correct = pred.eq(gold)
22 | n_correct = n_correct.masked_select(non_pad_mask).sum().item()
23 |
24 | return loss, n_correct
25 |
26 |
27 | def cal_loss(pred, gold, smoothing=0.0):
28 | """Calculate cross entropy loss, apply label smoothing if needed.
29 | """
30 |
31 | if smoothing > 0.0:
32 | eps = smoothing
33 | n_class = pred.size(1)
34 |
35 | # Generate one-hot matrix: N x C.
36 | # Only label position is 1 and all other positions are 0
37 | # gold include -1 value (IGNORE_ID) and this will lead to assert error
38 | gold_for_scatter = gold.ne(IGNORE_ID).long() * gold
39 | one_hot = torch.zeros_like(pred).scatter(1, gold_for_scatter.view(-1, 1), 1)
40 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / n_class
41 | log_prb = F.log_softmax(pred, dim=1)
42 |
43 | non_pad_mask = gold.ne(IGNORE_ID)
44 | n_word = non_pad_mask.sum().item()
45 | loss = -(one_hot * log_prb).sum(dim=1)
46 | loss = loss.masked_select(non_pad_mask).sum() / n_word
47 | else:
48 | loss = F.cross_entropy(pred, gold, ignore_index=IGNORE_ID, reduction='mean')
49 |
50 | return loss
51 |
--------------------------------------------------------------------------------
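A small numeric sketch of the label-smoothing branch in `cal_loss`, with an arbitrary 3-class prediction. With eps = 0.1 the true class keeps weight 0.9 and every class receives an extra eps / n_class, so the smoothed loss here (about 0.42) comes out slightly above the plain cross entropy (about 0.32):

```python
import torch
import torch.nn.functional as F

eps, n_class = 0.1, 3
pred = torch.tensor([[2.0, 0.5, 0.1]])  # N x C logits
gold = torch.tensor([0])                # true class index

one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / n_class
log_prb = F.log_softmax(pred, dim=1)
loss = -(one_hot * log_prb).sum(dim=1).mean()
print(loss.item(), F.cross_entropy(pred, gold).item())
```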
/transformer/module.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class PositionalEncoding(nn.Module):
9 | """Implement the positional encoding (PE) function.
10 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
11 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
12 | """
13 |
14 | def __init__(self, d_model, max_len=5000):
15 | super(PositionalEncoding, self).__init__()
16 | # Compute the positional encodings once in log space.
17 | pe = torch.zeros(max_len, d_model, requires_grad=False)
18 | position = torch.arange(0, max_len).unsqueeze(1).float()
19 | div_term = torch.exp(torch.arange(0, d_model, 2).float() *
20 | -(math.log(10000.0) / d_model))
21 | pe[:, 0::2] = torch.sin(position * div_term)
22 | pe[:, 1::2] = torch.cos(position * div_term)
23 | pe = pe.unsqueeze(0)
24 | self.register_buffer('pe', pe)
25 |
26 | def forward(self, input):
27 | """
28 | Args:
29 | input: N x T x D (only the length T is used)
30 | """
31 | length = input.size(1)
32 | return self.pe[:, :length]
33 |
34 |
35 | class PositionwiseFeedForward(nn.Module):
36 | """Implements position-wise feedforward sublayer.
37 | FFN(x) = max(0, xW1 + b1)W2 + b2
38 | """
39 |
40 | def __init__(self, d_model, d_ff, dropout=0.1):
41 | super(PositionwiseFeedForward, self).__init__()
42 | self.w_1 = nn.Linear(d_model, d_ff)
43 | self.w_2 = nn.Linear(d_ff, d_model)
44 | self.dropout = nn.Dropout(dropout)
45 | self.layer_norm = nn.LayerNorm(d_model)
46 |
47 | def forward(self, x):
48 | residual = x
49 | output = self.w_2(F.relu(self.w_1(x)))
50 | output = self.dropout(output)
51 | output = self.layer_norm(output + residual)
52 | return output
53 |
54 |
55 | # An equivalent implementation using kernel-size-1 convolutions
56 | class PositionwiseFeedForwardUseConv(nn.Module):
57 | """A two-feed-forward-layer module"""
58 |
59 | def __init__(self, d_in, d_hid, dropout=0.1):
60 | super(PositionwiseFeedForwardUseConv, self).__init__()
61 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
62 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
63 | self.layer_norm = nn.LayerNorm(d_in)
64 | self.dropout = nn.Dropout(dropout)
65 |
66 | def forward(self, x):
67 | residual = x
68 | output = x.transpose(1, 2)
69 | output = self.w_2(F.relu(self.w_1(output)))
70 | output = output.transpose(1, 2)
71 | output = self.dropout(output)
72 | output = self.layer_norm(output + residual)
73 | return output
74 |
--------------------------------------------------------------------------------
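The PositionalEncoding module returns only the encoding table, so the caller is expected to add it to the embeddings; and the two feed-forward variants are interchangeable, since a Conv1d with kernel size 1 over the transposed N x D x T tensor applies the same affine map per position as a Linear. A minimal sketch (shapes are illustrative; the sin/cos interleaving assumes an even d_model):

import torch
from transformer.module import (PositionalEncoding, PositionwiseFeedForward,
                                PositionwiseFeedForwardUseConv)

pos_enc = PositionalEncoding(d_model=512)
x = torch.randn(2, 10, 512)            # N x T x D embeddings
x = x + pos_enc(x)                     # the caller adds the encodings

for ffn in (PositionwiseFeedForward(512, 2048),
            PositionwiseFeedForwardUseConv(512, 2048)):
    print(ffn(x).shape)                # both: torch.Size([2, 10, 512])

--------------------------------------------------------------------------------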
/transformer/optimizer.py:
--------------------------------------------------------------------------------
1 | from config import d_model
2 |
3 |
4 | class TransformerOptimizer(object):
5 | """A simple wrapper class for learning rate scheduling"""
6 |
7 | def __init__(self, optimizer, warmup_steps=4000):
8 | self.optimizer = optimizer
9 | self.init_lr = d_model ** (-0.5)
10 | self.warmup_steps = warmup_steps
11 | self.lr = self.init_lr
12 | self.step_num = 0
13 |
14 | def zero_grad(self):
15 | self.optimizer.zero_grad()
16 |
17 | def step(self):
18 | self._update_lr()
19 | self.optimizer.step()
20 |
21 | def _update_lr(self):
22 | self.step_num += 1
23 | self.min_lr = 1e-5  # floor so the decayed rate never vanishes
24 | self.lr = self.init_lr * min(self.step_num ** (-0.65), self.step_num * (self.warmup_steps ** (-1.5)))  # Noam-style schedule with a steeper -0.65 decay (the paper uses -0.5)
25 | self.lr = max(self.lr, self.min_lr)
26 | for param_group in self.optimizer.param_groups:
27 | param_group['lr'] = self.lr
28 |
--------------------------------------------------------------------------------
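Usage sketch: wrap a plain Adam optimizer and call step() once per update. Note that with d_model = 512 the very first steps of the schedule (init_lr * step * warmup_steps^-1.5, about 1.7e-7 at step 1) fall below the 1e-5 floor, so the floor is what you actually observe early in warmup. The Adam hyper-parameters below are the ones suggested in the original paper, an assumption rather than something this file pins down:

import torch
from transformer.optimizer import TransformerOptimizer

model = torch.nn.Linear(512, 512)       # stand-in for the full Transformer
optimizer = TransformerOptimizer(
    torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9))

loss = model(torch.randn(4, 512)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()                        # updates lr, then steps the wrapped Adam
print(optimizer.lr)                     # 1e-5: the floor, not the raw schedule

--------------------------------------------------------------------------------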
/transformer/transformer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .decoder import Decoder
4 | from .encoder import Encoder
5 |
6 |
7 | class Transformer(nn.Module):
8 | """An encoder-decoder framework only includes attention.
9 | """
10 |
11 | def __init__(self, encoder=None, decoder=None):
12 | super(Transformer, self).__init__()
13 | self.encoder = encoder
14 | self.decoder = decoder
15 |
16 | if encoder is not None and decoder is not None:
17 | # encoder/decoder were already stored above; Xavier-initialize
18 | # the weight matrices (1-D parameters keep their defaults)
19 |
20 | for p in self.parameters():
21 | if p.dim() > 1:
22 | nn.init.xavier_uniform_(p)
23 | else:
24 | self.encoder = Encoder()
25 | self.decoder = Decoder()
26 |
27 | def forward(self, padded_input, input_lengths, padded_target):
28 | """
29 | Args:
30 | padded_input: N x Ti x D
31 | input_lengths: N
32 | padded_target: N x To
33 | """
34 | encoder_padded_outputs, *_ = self.encoder(padded_input, input_lengths)
35 | # pred is score before softmax
36 | pred, gold, *_ = self.decoder(padded_target, encoder_padded_outputs,
37 | input_lengths)
38 | return pred, gold
39 |
40 | def recognize(self, input, input_length, char_list):
41 | """Sequence-to-Sequence beam search, decode one utterence now.
42 | Args:
43 | input: T x D
44 | input_length: length of the input in frames
45 | char_list: list of characters
46 | Returns:
47 | nbest_hyps: n-best decoding hypotheses
48 | """
49 | encoder_outputs, enc_slf_attn_list = self.encoder(padded_input=input.unsqueeze(0), input_lengths=input_length,
50 | return_attns=True)
51 | nbest_hyps = self.decoder.recognize_beam(encoder_outputs[0], char_list)
52 | return nbest_hyps
53 |
--------------------------------------------------------------------------------
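The Xavier branch above only touches parameters with more than one dimension, so weight matrices are re-initialized while biases and LayerNorm gains keep their defaults. The same pattern in isolation, as a sketch:

import torch.nn as nn

m = nn.Linear(4, 4)
for p in m.parameters():
    if p.dim() > 1:                    # the 4x4 weight matrix
        nn.init.xavier_uniform_(p)     # bias (dim == 1) is left untouched

--------------------------------------------------------------------------------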
/transformer/utils.py:
--------------------------------------------------------------------------------
1 | def pad_list(xs, pad_value):
2 | # From: espnet/src/nets/e2e_asr_th.py: pad_list()
3 | n_batch = len(xs)
4 | max_len = max(x.size(0) for x in xs)
5 | pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
6 | for i in range(n_batch):
7 | pad[i, :xs[i].size(0)] = xs[i]
8 | return pad
9 |
10 |
11 | def process_dict(dict_path):
12 | with open(dict_path, 'rb') as f:
13 | dictionary = f.readlines()
14 | char_list = [entry.decode('utf-8').split(' ')[0]
15 | for entry in dictionary]
16 | sos_id = char_list.index('<sos>')
17 | eos_id = char_list.index('<eos>')
18 | return char_list, sos_id, eos_id
19 |
20 |
21 | if __name__ == "__main__":
22 | import sys
23 |
24 | path = sys.argv[1]
25 | char_list, sos_id, eos_id = process_dict(path)
26 | print(char_list, sos_id, eos_id)
27 |
28 |
29 | # * ------------------ recognition related ------------------ *
30 |
31 |
32 | def parse_hypothesis(hyp, char_list):
33 | """Function to parse hypothesis
34 | :param list hyp: recognition hypothesis
35 | :param list char_list: list of characters
36 | :return: recognition text strinig
37 | :return: recognition token strinig
:return: recognition token string
39 | """
40 | # remove sos and get results
41 | tokenid_as_list = list(map(int, hyp['yseq'][1:]))
42 | token_as_list = [char_list[idx] for idx in tokenid_as_list]
43 | score = float(hyp['score'])
44 |
45 | # convert to string
46 | tokenid = " ".join([str(idx) for idx in tokenid_as_list])
47 | token = " ".join(token_as_list)
48 | text = "".join(token_as_list).replace('', ' ')
49 |
50 | return text, token, tokenid, score
51 |
52 |
53 | def add_results_to_json(js, nbest_hyps, char_list):
54 | """Function to add N-best results to json
55 | :param dict js: groundtruth utterance dict
56 | :param list nbest_hyps: list of hypothesis
57 | :param list char_list: list of characters
58 | :return: N-best results added utterance dict
59 | """
60 | # copy old json info
61 | new_js = dict()
62 | new_js['utt2spk'] = js['utt2spk']
63 | new_js['output'] = []
64 |
65 | for n, hyp in enumerate(nbest_hyps, 1):
66 | # parse hypothesis
67 | rec_text, rec_token, rec_tokenid, score = parse_hypothesis(
68 | hyp, char_list)
69 |
70 | # copy ground-truth
71 | out_dic = dict(js['output'][0].items())
72 |
73 | # update name
74 | out_dic['name'] += '[%d]' % n
75 |
76 | # add recognition results
77 | out_dic['rec_text'] = rec_text
78 | out_dic['rec_token'] = rec_token
79 | out_dic['rec_tokenid'] = rec_tokenid
80 | out_dic['score'] = score
81 |
82 | # add to list of N-best result dicts
83 | new_js['output'].append(out_dic)
84 |
85 | # show 1-best result
86 | if n == 1:
87 | print('groundtruth: %s' % out_dic['text'])
88 | print('prediction : %s' % out_dic['rec_text'])
89 |
90 | return new_js
91 |
92 |
93 | # -- Transformer Related --
94 | import torch
95 |
96 |
97 | def get_non_pad_mask(padded_input, input_lengths=None, pad_idx=None):
98 | """padding position is set to 0, either use input_lengths or pad_idx
99 | """
100 | assert input_lengths is not None or pad_idx is not None
101 | if input_lengths is not None:
102 | # padded_input: N x T x ..
103 | N = padded_input.size(0)
104 | non_pad_mask = padded_input.new_ones(padded_input.size()[:-1]) # N x T
105 | for i in range(N):
106 | non_pad_mask[i, input_lengths[i]:] = 0
107 | if pad_idx is not None:
108 | # padded_input: N x T
109 | assert padded_input.dim() == 2
110 | non_pad_mask = padded_input.ne(pad_idx).float()
111 | # unsqueeze(-1) for broadcast
112 | return non_pad_mask.unsqueeze(-1)
113 |
114 |
115 | def get_subsequent_mask(seq):
116 | ''' For masking out the subsequent info. '''
117 |
118 | sz_b, len_s = seq.size()
119 | subsequent_mask = torch.triu(
120 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
121 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls
122 |
123 | return subsequent_mask
124 |
125 |
126 | def get_attn_key_pad_mask(seq_k, seq_q, pad_idx):
127 | ''' For masking out the padding part of key sequence. '''
128 |
129 | # Expand to fit the shape of key query attention matrix.
130 | len_q = seq_q.size(1)
131 | padding_mask = seq_k.eq(pad_idx)
132 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk
133 |
134 | return padding_mask
135 |
136 |
137 | def get_attn_pad_mask(padded_input, input_lengths, expand_length):
138 | """mask position is set to 1"""
139 | # N x Ti x 1
140 | non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
141 | # N x Ti; lt(1) acts as a logical NOT on the 0/1 mask
142 | pad_mask = non_pad_mask.squeeze(-1).lt(1)
143 | attn_mask = pad_mask.unsqueeze(1).expand(-1, expand_length, -1)
144 | return attn_mask
145 |
--------------------------------------------------------------------------------
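The mask helpers above share a convention worth spelling out: get_subsequent_mask and the attention pad masks mark positions to ignore with 1, while get_non_pad_mask is the complement (1 keeps a position). A small demonstration of get_subsequent_mask and pad_list with toy tensors:

import torch
from transformer.utils import get_subsequent_mask, pad_list

seq = torch.tensor([[5, 7, 9]])        # N=1, T=3 token ids
print(get_subsequent_mask(seq))
# tensor([[[0, 1, 1],
#          [0, 0, 1],
#          [0, 0, 0]]], dtype=torch.uint8) -- 1 marks future positions

xs = [torch.tensor([1, 2, 3]), torch.tensor([4])]
print(pad_list(xs, 0))                 # tensor([[1, 2, 3], [4, 0, 0]])

--------------------------------------------------------------------------------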
/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import re
3 | import unicodedata
4 |
5 | import torch
6 |
7 |
8 | def clip_gradient(optimizer, grad_clip):
9 | """
10 | Clips gradients computed during backpropagation to prevent exploding gradients.
11 | :param optimizer: optimizer with the gradients to be clipped
12 | :param grad_clip: clip value
13 | """
14 | for group in optimizer.param_groups:
15 | for param in group['params']:
16 | if param.grad is not None:
17 | param.grad.data.clamp_(-grad_clip, grad_clip)
18 |
19 |
20 | def save_checkpoint(epoch, epochs_since_improvement, model, optimizer, loss, is_best):
21 | state = {'epoch': epoch,
22 | 'epochs_since_improvement': epochs_since_improvement,
23 | 'loss': loss,
24 | 'model': model,
25 | 'optimizer': optimizer}
26 |
27 | filename = 'checkpoint.tar'
28 | torch.save(state, filename)
29 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint
30 | if is_best:
31 | torch.save(state, 'BEST_checkpoint.tar')
32 |
33 |
34 | class AverageMeter(object):
35 | """
36 | Keeps track of most recent, average, sum, and count of a metric.
37 | """
38 |
39 | def __init__(self):
40 | self.reset()
41 |
42 | def reset(self):
43 | self.val = 0
44 | self.avg = 0
45 | self.sum = 0
46 | self.count = 0
47 |
48 | def update(self, val, n=1):
49 | self.val = val
50 | self.sum += val * n
51 | self.count += n
52 | self.avg = self.sum / self.count
53 |
54 |
55 | def adjust_learning_rate(optimizer, shrink_factor):
56 | """
57 | Shrinks learning rate by a specified factor.
58 | :param optimizer: optimizer whose learning rate must be shrunk.
59 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with.
60 | """
61 |
62 | print("\nDECAYING learning rate.")
63 | for param_group in optimizer.param_groups:
64 | param_group['lr'] = param_group['lr'] * shrink_factor
65 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],))
66 |
67 |
68 | def accuracy(scores, targets, k=1):
69 | batch_size = targets.size(0)
70 | _, ind = scores.topk(k, 1, True, True)
71 | correct = ind.eq(targets.view(-1, 1).expand_as(ind))
72 | correct_total = correct.view(-1).float().sum() # 0D tensor
73 | return correct_total.item() * (100.0 / batch_size)
74 |
75 |
76 | def parse_args():
77 | parser = argparse.ArgumentParser(description='Transformer')
78 |
79 | # Network architecture
80 | # encoder
81 | # TODO: automatically infer input dim
82 | parser.add_argument('--n_layers_enc', default=6, type=int,
83 | help='Number of encoder stacks')
84 | parser.add_argument('--n_head', default=8, type=int,
85 | help='Number of heads in Multi-Head Attention (MHA)')
86 | parser.add_argument('--d_k', default=64, type=int,
87 | help='Dimension of key')
88 | parser.add_argument('--d_v', default=64, type=int,
89 | help='Dimension of value')
90 | parser.add_argument('--d_model', default=512, type=int,
91 | help='Dimension of model')
92 | parser.add_argument('--d_inner', default=2048, type=int,
93 | help='Dimension of the inner feed-forward layer')
94 | parser.add_argument('--dropout', default=0.1, type=float,
95 | help='Dropout rate')
96 | parser.add_argument('--pe_maxlen', default=5000, type=int,
97 | help='Positional Encoding max len')
98 | # decoder
99 | parser.add_argument('--d_word_vec', default=512, type=int,
100 | help='Dim of decoder embedding')
101 | parser.add_argument('--n_layers_dec', default=6, type=int,
102 | help='Number of decoder stacks')
103 | parser.add_argument('--tgt_emb_prj_weight_sharing', default=1, type=int,
104 | help='share decoder embedding weights with the output projection')
105 | # Loss
106 | parser.add_argument('--label_smoothing', default=0.1, type=float,
107 | help='label smoothing')
108 |
109 | # Training config
110 | parser.add_argument('--epochs', default=1000, type=int,
111 | help='Number of maximum epochs')
112 | # minibatch
113 | parser.add_argument('--shuffle', default=1, type=int,
114 | help='reshuffle the data at every epoch')
115 | parser.add_argument('--batch-size', default=128, type=int,
116 | help='Batch size')
117 | parser.add_argument('--batch_frames', default=0, type=int,
118 | help='Batch frames. If nonzero, batch size is ignored')
119 | parser.add_argument('--maxlen-in', default=50, type=int, metavar='ML',
120 | help='Batch size is reduced if the input sequence length > ML')
121 | parser.add_argument('--maxlen-out', default=25, type=int, metavar='ML',
122 | help='Batch size is reduced if the output sequence length > ML')
123 | parser.add_argument('--num-workers', default=8, type=int,
124 | help='Number of workers to generate minibatch')
125 | # optimizer
126 | parser.add_argument('--k', default=0.2, type=float,
127 | help='tunable scalar multiply to learning rate')
128 | parser.add_argument('--warmup_steps', default=4000, type=int,
129 | help='warmup steps')
130 |
131 | parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint')
132 | args = parser.parse_args()
133 | return args
134 |
135 |
136 | def ensure_folder(folder):
137 | import os
138 | if not os.path.isdir(folder):
139 | os.makedirs(folder)
140 |
141 |
142 | def pad_list(xs, pad_value):
143 | # From: espnet/src/nets/e2e_asr_th.py: pad_list()
144 | n_batch = len(xs)
145 | max_len = max(x.size(0) for x in xs)
146 | pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
147 | for i in range(n_batch):
148 | pad[i, :xs[i].size(0)] = xs[i]
149 | return pad
150 |
151 |
152 | def text_to_sequence(text, char2idx):
153 | result = [char2idx[char] for char in text]
154 | return result
155 |
156 |
157 | def sequence_to_text(seq, idx2char):
158 | result = [idx2char[idx] for idx in seq]
159 | return result
160 |
161 |
162 | # Turn a Unicode string to plain ASCII, thanks to
163 | # http://stackoverflow.com/a/518232/2809427
164 | def unicodeToAscii(s):
165 | return ''.join(
166 | c for c in unicodedata.normalize('NFD', s)
167 | if unicodedata.category(c) != 'Mn'
168 | )
169 |
170 |
171 | def normalizeString(s):
172 | s = unicodeToAscii(s.lower().strip())
173 | s = re.sub(r"([.!?])", r" \1", s)
174 | s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
175 | return s
176 |
177 |
178 | def encode_text(word_map, c):
179 | return [word_map.get(word, word_map['<unk>']) for word in c]
180 |
--------------------------------------------------------------------------------
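A quick illustration of the text helpers at the bottom of the file: normalizeString strips accents, lower-cases, and keeps only letters plus . ! ? before text_to_sequence maps characters to indices. The vocabulary here is a toy stand-in, not the repo's vocab.pkl:

from utils import normalizeString, text_to_sequence

s = normalizeString("C'était déjà ça !")
print(s)                               # "c etait deja ca !"
char2idx = {c: i for i, c in enumerate(sorted(set(s)))}  # toy vocab
print(text_to_sequence(s, char2idx))

--------------------------------------------------------------------------------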
/vocab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Transformer/3246970b825a87f2c035edac545000716c0d96f7/vocab.pkl
--------------------------------------------------------------------------------