├── .gitignore ├── LICENSE ├── README.md ├── bleu_score.py ├── config.py ├── convert_valid.py ├── data_gen.py ├── demo.py ├── export.py ├── extract.py ├── images └── dataset.png ├── pre_process.py ├── test ├── test_bleu.py └── test_lr.py ├── train.py ├── transformer ├── __init__.py ├── attention.py ├── decoder.py ├── encoder.py ├── loss.py ├── module.py ├── optimizer.py ├── transformer.py └── utils.py ├── utils.py └── vocab.pkl /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__/ 3 | data/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # English-Chinese Machine Translation 2 | 3 | This project evaluates English-to-Chinese machine text translation. The translation direction is English to Chinese. 4 | 5 | 6 | ## Dependencies 7 | 8 | - Python 3.6.8 9 | - PyTorch 1.3.0 10 | 11 | ## Dataset 12 | 13 | We use the English-Chinese machine translation dataset from AI Challenger 2017, with more than 10 million English-Chinese sentence pairs. The training set makes up the vast majority with 12,904,955 pairs; the validation set has 8,000 pairs, test set A has 8,000 sentences, and test set B has 8,000 sentences. 14 | 15 | It can be downloaded here: [English-Chinese translation dataset](https://challenger.ai/datasets/translation) 16 | 17 | ![image](https://github.com/foamliu/Transformer/raw/master/images/dataset.png) 18 | 19 | ## Usage 20 | 21 | ### Data preprocessing 22 | Extract the training and validation samples: 23 | ```bash 24 | $ python pre_process.py 25 | ``` 26 | 27 | ### Training 28 | ```bash 29 | $ python train.py 30 | ``` 31 | 32 | To visualize the training process, run in a terminal: 33 | ```bash 34 | $ tensorboard --logdir path_to_current_dir/logs 35 | ``` 36 | 37 | ### Demo 38 | Download the [pre-trained model](https://github.com/foamliu/Scene-Classification/releases/download/v1.0/model.85-0.7657.hdf5) and then run: 39 | 40 | ```bash 41 | $ python demo.py 42 | ``` 43 | 44 | In each example below, the first line is the English source sentence (from the dataset), the second line is the human Chinese translation (from the dataset), and the following line is the machine translation (this model), generated on the fly. 45 | 46 |
47 | 
48 | < there are places in the world where girls do n t get educated   simply because they are girls  .
49 | = 在这个世界上某些地方的女孩,不能得到教育仅仅因为她是女孩。
50 | > 世界上有一些女孩没有受过教育,仅仅是因为她们是女孩。
51 | < i  ve noticed this black van parked outside the house every day  .
52 | = 我注意到有辆黑色的车每天都停在房子外面。
53 | > 我每天都注意到一辆黑色货车停在房子外面。
54 | < i mean   i feel like i  m okay  cause i passed on the crazy  .
55 | = 我是说,我觉得不错因为已经从疯狂里解脱了。
56 | > 我是说,我感觉很好因为我错过了疯狂。
57 | < i do n t know  . i was having this dream  . it  s all right  . i  m here now  .
58 | = 我不知道。我刚才做了个梦。没事了。我在这里。
59 | > 我不知道。我做了个梦。没关系。我在这儿。
60 | < and if the girls see you up here   they  re gon na take it
61 | = 如果她们看到你坐这她们会觉得
62 | > 如果女孩们看到你在这里,他们会接受的
63 | < getting filled out like an application constitutes being with a guy  .
64 | = 愿意跟小男孩在一起。
65 | > 就像和男人在一起一样。
66 | < every single brother in my fraternity has worn this suit  .
67 | = 兄弟会里每个人都穿成这样。
68 | > 我兄弟会的每个兄弟都穿这件衣服。
69 | < but i  m here in town   checking out some real estate   and
70 | = 但我进城来看房子,以及-
71 | > 但我在城里,查一些房地产,
72 | < hey   i was gon na come and see you today and say hi  .
73 | = 嘿,我今天打算过去看看你打个招呼。
74 | > 嘿,我今天想去看你然后打个招呼。
75 | < it was a good show  . stop saying it was a good show  . shh  !
76 | = 今晚的节目很好。不要再说它好了!
77 | > 是个好节目。别说是好节目了。嘘!
78 | 
79 | 
80 | 
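The repository scores such outputs with character-level BLEU (`bleu_score.py`). As a minimal illustration (assuming NLTK is installed), the last pair above can be scored with NLTK's `sentence_bleu`, splitting both sentences into characters the same way `bleu_score.py` does:

```python
# Minimal sketch: character-level BLEU for one demo pair shown above (requires nltk).
from nltk.translate.bleu_score import sentence_bleu

reference = list('今晚的节目很好。不要再说它好了!')   # human translation from the dataset
hypothesis = list('是个好节目。别说是好节目了。嘘!')  # machine translation from the demo
print(sentence_bleu([reference], hypothesis))
```

For short sentences with little n-gram overlap, NLTK may emit a smoothing warning and the score can be close to zero; `bleu_score.py` aggregates these per-sentence scores over the whole validation set.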
-------------------------------------------------------------------------------- /bleu_score.py: -------------------------------------------------------------------------------- 1 | # import the necessary packages 2 | import pickle 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | from nltk.translate.bleu_score import sentence_bleu 8 | from tqdm import tqdm 9 | 10 | from config import device, logger, data_file, vocab_file 11 | from transformer.transformer import Transformer 12 | 13 | if __name__ == '__main__': 14 | filename = 'transformer.pt' 15 | print('loading {}...'.format(filename)) 16 | start = time.time() 17 | model = Transformer() 18 | model.load_state_dict(torch.load(filename)) 19 | print('elapsed {} sec'.format(time.time() - start)) 20 | model = model.to(device) 21 | model.eval() 22 | 23 | logger.info('loading samples...') 24 | start = time.time() 25 | with open(data_file, 'rb') as file: 26 | data = pickle.load(file) 27 | samples = data['valid'] 28 | elapsed = time.time() - start 29 | logger.info('elapsed: {:.4f} seconds'.format(elapsed)) 30 | 31 | logger.info('loading vocab...') 32 | start = time.time() 33 | with open(vocab_file, 'rb') as file: 34 | data = pickle.load(file) 35 | src_idx2char = data['dict']['src_idx2char'] 36 | tgt_idx2char = data['dict']['tgt_idx2char'] 37 | elapsed = time.time() - start 38 | logger.info('elapsed: {:.4f} seconds'.format(elapsed)) 39 | 40 | # samples = random.sample(samples, 10) 41 | bleu_scores = [] 42 | 43 | for sample in tqdm(samples): 44 | sentence_in = sample['in'] 45 | sentence_out = sample['out'] 46 | 47 | input = torch.from_numpy(np.array(sentence_in, dtype=np.long)).to(device) 48 | input_length = torch.LongTensor([len(sentence_in)]).to(device) 49 | 50 | sentence_in = ' '.join([src_idx2char[idx] for idx in sentence_in]) 51 | sentence_out = ''.join([tgt_idx2char[idx] for idx in sentence_out]) 52 | sentence_out = sentence_out.replace('<sos>', '').replace('<eos>', '') 53 | # print('< ' + sentence_in) 54 | # print('= ' + sentence_out) 55 | 56 | try: 57 | with torch.no_grad(): 58 | nbest_hyps = model.recognize(input=input, input_length=input_length, char_list=tgt_idx2char) 59 | # print(nbest_hyps) 60 | except RuntimeError: 61 | print('sentence_in: ' + sentence_in) 62 | continue 63 | 64 | score_list = [] 65 | for hyp in nbest_hyps: 66 | out = hyp['yseq'] 67 | out = [tgt_idx2char[idx] for idx in out] 68 | out = ''.join(out) 69 | out = out.replace('<sos>', '').replace('<eos>', '') 70 | reference = list(sentence_out) 71 | hypothesis = list(out) 72 | score = sentence_bleu([reference], hypothesis) 73 | score_list.append(score) 74 | 75 | bleu_scores.append(max(score_list)) 76 | 77 | print('len(bleu_scores): ' + str(len(bleu_scores))) 78 | print('np.max(bleu_scores): ' + str(np.max(bleu_scores))) 79 | print('np.min(bleu_scores): ' + str(np.min(bleu_scores))) 80 | print('np.mean(bleu_scores): ' + str(np.mean(bleu_scores))) 81 | 82 | import numpy as np 83 | import matplotlib.pyplot as plt 84 | 85 | # Fixing random state for reproducibility 86 | np.random.seed(19680801) 87 | 88 | # the histogram of the data 89 | n, bins, patches = plt.hist(bleu_scores, 50, density=True, facecolor='g', alpha=0.75) 90 | 91 | plt.xlabel('Scores') 92 | plt.ylabel('Probability') 93 | plt.title('BLEU Scores') 94 | plt.grid(True) 95 | plt.show() 96 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | device = 
torch.device('cuda' if torch.cuda.is_available() else 'cpu') # sets device for model and PyTorch tensors 6 | 7 | # Model parameters 8 | d_model = 512 9 | epochs = 10000 10 | embedding_size = 300 11 | hidden_size = 1024 12 | data_file = 'data.pkl' 13 | vocab_file = 'vocab.pkl' 14 | n_src_vocab = 15000 15 | n_tgt_vocab = 15000 # target 16 | maxlen_in = 100 17 | maxlen_out = 50 18 | # Training parameters 19 | grad_clip = 1.0 # clip gradients at this absolute value 20 | print_freq = 50 # print training/validation stats every __ batches 21 | checkpoint = None # path to checkpoint, None if none 22 | 23 | # Data parameters 24 | IGNORE_ID = -1 25 | pad_id = 0 26 | sos_id = 1 27 | eos_id = 2 28 | unk_id = 3 29 | num_train = 4669414 30 | num_valid = 3870 31 | 32 | train_translation_en_filename = 'data/ai_challenger_translation_train_20170904/translation_train_data_20170904/train.en' 33 | train_translation_zh_filename = 'data/ai_challenger_translation_train_20170904/translation_train_data_20170904/train.zh' 34 | valid_translation_en_filename = 'data/ai_challenger_translation_validation_20170912/translation_validation_20170912/valid.en' 35 | valid_translation_zh_filename = 'data/ai_challenger_translation_validation_20170912/translation_validation_20170912/valid.zh' 36 | 37 | 38 | def get_logger(): 39 | logger = logging.getLogger() 40 | handler = logging.StreamHandler() 41 | formatter = logging.Formatter("%(asctime)s [%(levelname)s] [%(threadName)s] %(name)s: %(message)s") 42 | handler.setFormatter(formatter) 43 | logger.addHandler(handler) 44 | logger.setLevel(logging.INFO) 45 | return logger 46 | 47 | 48 | logger = get_logger() 49 | -------------------------------------------------------------------------------- /convert_valid.py: -------------------------------------------------------------------------------- 1 | import xml.etree.ElementTree 2 | 3 | valid_en_old = 'data/ai_challenger_translation_validation_20170912/translation_validation_20170912/valid.en-zh.en.sgm' 4 | valid_en_new = 'data/ai_challenger_translation_validation_20170912/translation_validation_20170912/valid.en' 5 | valid_zh_old = 'data/ai_challenger_translation_validation_20170912/translation_validation_20170912/valid.en-zh.zh.sgm' 6 | valid_zh_new = 'data/ai_challenger_translation_validation_20170912/translation_validation_20170912/valid.zh' 7 | 8 | 9 | def convert(old, new): 10 | print('old: ' + old) 11 | print('new: ' + new) 12 | with open(old, 'r', encoding='utf-8') as f: 13 | data = f.readlines() 14 | data = [line.replace(' & ', ' &amp; ') for line in data] # escape bare ampersands so the SGM file can be parsed as XML 15 | with open(new, 'w', encoding='utf-8') as f: 16 | f.writelines(data) 17 | 18 | root = xml.etree.ElementTree.parse(new).getroot() 19 | data = [elem.text.strip() + '\n' for elem in root.iter() if elem.tag == 'seg'] 20 | with open(new, 'w', encoding='utf-8') as file: 21 | file.writelines(data) 22 | 23 | 24 | if __name__ == '__main__': 25 | convert(valid_en_old, valid_en_new) 26 | convert(valid_zh_old, valid_zh_new) 27 | -------------------------------------------------------------------------------- /data_gen.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import time 3 | 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from torch.utils.data.dataloader import default_collate 7 | 8 | from config import data_file, vocab_file, IGNORE_ID, pad_id, logger 9 | 10 | logger.info('loading samples...') 11 | start = time.time() 12 | with open(data_file, 'rb') as file: 13 | data = pickle.load(file) 14 | elapsed = 
time.time() - start 15 | logger.info('elapsed: {:.4f}'.format(elapsed)) 16 | 17 | 18 | def get_data(filename): 19 | with open(filename, 'r') as file: 20 | data = file.readlines() 21 | data = [line.strip() for line in data] 22 | return data 23 | 24 | 25 | def pad_collate(batch): 26 | max_input_len = float('-inf') 27 | max_target_len = float('-inf') 28 | 29 | for elem in batch: 30 | src, tgt = elem 31 | max_input_len = max_input_len if max_input_len > len(src) else len(src) 32 | max_target_len = max_target_len if max_target_len > len(tgt) else len(tgt) 33 | 34 | for i, elem in enumerate(batch): 35 | src, tgt = elem 36 | input_length = len(src) 37 | padded_input = np.pad(src, (0, max_input_len - len(src)), 'constant', constant_values=pad_id) 38 | padded_target = np.pad(tgt, (0, max_target_len - len(tgt)), 'constant', constant_values=IGNORE_ID) 39 | batch[i] = (padded_input, padded_target, input_length) 40 | 41 | # sort it by input lengths (long to short) 42 | batch.sort(key=lambda x: x[2], reverse=True) 43 | 44 | return default_collate(batch) 45 | 46 | 47 | class AiChallenger2017Dataset(Dataset): 48 | def __init__(self, split): 49 | self.samples = data[split] 50 | 51 | def __getitem__(self, i): 52 | sample = self.samples[i] 53 | src_text = sample['in'] 54 | tgt_text = sample['out'] 55 | 56 | return np.array(src_text, dtype=np.long), np.array(tgt_text, np.long) 57 | 58 | def __len__(self): 59 | return len(self.samples) 60 | 61 | 62 | def main(): 63 | from utils import sequence_to_text 64 | 65 | valid_dataset = AiChallenger2017Dataset('valid') 66 | print(valid_dataset[0]) 67 | 68 | with open(vocab_file, 'rb') as file: 69 | data = pickle.load(file) 70 | 71 | src_idx2char = data['dict']['src_idx2char'] 72 | tgt_idx2char = data['dict']['tgt_idx2char'] 73 | 74 | src_text, tgt_text = valid_dataset[0] 75 | src_text = sequence_to_text(src_text, src_idx2char) 76 | src_text = ' '.join(src_text) 77 | print('src_text: ' + src_text) 78 | 79 | tgt_text = sequence_to_text(tgt_text, tgt_idx2char) 80 | tgt_text = ''.join(tgt_text) 81 | print('tgt_text: ' + tgt_text) 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # import the necessary packages 2 | import pickle 3 | import random 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from config import device, logger, data_file, vocab_file 10 | from transformer.transformer import Transformer 11 | 12 | if __name__ == '__main__': 13 | filename = 'transformer.pt' 14 | print('loading {}...'.format(filename)) 15 | start = time.time() 16 | model = Transformer() 17 | model.load_state_dict(torch.load(filename)) 18 | print('elapsed {} sec'.format(time.time() - start)) 19 | model = model.to(device) 20 | model.eval() 21 | 22 | logger.info('loading samples...') 23 | start = time.time() 24 | with open(data_file, 'rb') as file: 25 | data = pickle.load(file) 26 | samples = data['valid'] 27 | elapsed = time.time() - start 28 | logger.info('elapsed: {:.4f} seconds'.format(elapsed)) 29 | 30 | logger.info('loading vocab...') 31 | start = time.time() 32 | with open(vocab_file, 'rb') as file: 33 | data = pickle.load(file) 34 | src_idx2char = data['dict']['src_idx2char'] 35 | tgt_idx2char = data['dict']['tgt_idx2char'] 36 | elapsed = time.time() - start 37 | logger.info('elapsed: {:.4f} seconds'.format(elapsed)) 38 | 39 | samples = random.sample(samples, 10) 40 | 41 | for sample in 
samples: 42 | sentence_in = sample['in'] 43 | sentence_out = sample['out'] 44 | 45 | input = torch.from_numpy(np.array(sentence_in, dtype=np.long)).to(device) 46 | input_length = torch.LongTensor([len(sentence_in)]).to(device) 47 | 48 | sentence_in = ' '.join([src_idx2char[idx] for idx in sentence_in]) 49 | sentence_out = ''.join([tgt_idx2char[idx] for idx in sentence_out]) 50 | sentence_out = sentence_out.replace('<sos>', '').replace('<eos>', '') 51 | print('< ' + sentence_in) 52 | print('= ' + sentence_out) 53 | 54 | with torch.no_grad(): 55 | nbest_hyps = model.recognize(input=input, input_length=input_length, char_list=tgt_idx2char) 56 | # print(nbest_hyps) 57 | 58 | for hyp in nbest_hyps: 59 | out = hyp['yseq'] 60 | out = [tgt_idx2char[idx] for idx in out] 61 | out = ''.join(out) 62 | out = out.replace('<sos>', '').replace('<eos>', '') 63 | 64 | print('> {}'.format(out)) 65 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | 5 | from transformer.transformer import Transformer 6 | 7 | if __name__ == '__main__': 8 | checkpoint = 'BEST_checkpoint.tar' 9 | print('loading {}...'.format(checkpoint)) 10 | start = time.time() 11 | checkpoint = torch.load(checkpoint) 12 | print('elapsed {} sec'.format(time.time() - start)) 13 | model = checkpoint['model'] 14 | print(type(model)) 15 | 16 | filename = 'transformer.pt' 17 | print('saving {}...'.format(filename)) 18 | start = time.time() 19 | torch.save(model.state_dict(), filename) 20 | print('elapsed {} sec'.format(time.time() - start)) 21 | 22 | print('loading {}...'.format(filename)) 23 | start = time.time() 24 | model = Transformer() 25 | model.load_state_dict(torch.load(filename)) 26 | print('elapsed {} sec'.format(time.time() - start)) 27 | -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | 3 | 4 | def extract(filename): 5 | print('Extracting {}...'.format(filename)) 6 | with zipfile.ZipFile(filename, 'r') as zip_ref: 7 | zip_ref.extractall('data') 8 | 9 | 10 | if __name__ == '__main__': 11 | extract('data/ai_challenger_translation_train_20170904.zip') 12 | extract('data/ai_challenger_translation_validation_20170912.zip') 13 | -------------------------------------------------------------------------------- /images/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Transformer/3246970b825a87f2c035edac545000716c0d96f7/images/dataset.png -------------------------------------------------------------------------------- /pre_process.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from collections import Counter 3 | 4 | import jieba 5 | import nltk 6 | from tqdm import tqdm 7 | 8 | from config import train_translation_en_filename, train_translation_zh_filename, valid_translation_en_filename, \ 9 | valid_translation_zh_filename, vocab_file, maxlen_in, maxlen_out, data_file, sos_id, eos_id, n_src_vocab, \ 10 | n_tgt_vocab, unk_id 11 | from utils import normalizeString, encode_text 12 | 13 | 14 | def build_vocab(token, word2idx, idx2char): 15 | if token not in word2idx: 16 | next_index = len(word2idx) 17 | word2idx[token] = next_index 18 | idx2char[next_index] = token 19 | 20 | 21 | def process(file, lang='zh'): 22 | 
print('processing {}...'.format(file)) 23 | with open(file, 'r', encoding='utf-8') as f: 24 | data = f.readlines() 25 | 26 | word_freq = Counter() 27 | lengths = [] 28 | 29 | for line in tqdm(data): 30 | sentence = line.strip() 31 | if lang == 'en': 32 | sentence_en = sentence.lower() 33 | tokens = [normalizeString(s) for s in nltk.word_tokenize(sentence_en)] 34 | word_freq.update(list(tokens)) 35 | vocab_size = n_src_vocab 36 | else: 37 | seg_list = jieba.cut(sentence.strip()) 38 | tokens = list(seg_list) 39 | word_freq.update(list(tokens)) 40 | vocab_size = n_tgt_vocab 41 | 42 | lengths.append(len(tokens)) 43 | 44 | words = word_freq.most_common(vocab_size - 4) 45 | word_map = {k[0]: v + 4 for v, k in enumerate(words)} 46 | word_map['<pad>'] = 0 47 | word_map['<sos>'] = 1 48 | word_map['<eos>'] = 2 49 | word_map['<unk>'] = 3 50 | print(len(word_map)) 51 | print(words[:100]) 52 | # 53 | # n, bins, patches = plt.hist(lengths, 50, density=True, facecolor='g', alpha=0.75) 54 | # 55 | # plt.xlabel('Lengths') 56 | # plt.ylabel('Probability') 57 | # plt.title('Histogram of Lengths') 58 | # plt.grid(True) 59 | # plt.show() 60 | 61 | word2idx = word_map 62 | idx2char = {v: k for k, v in word2idx.items()} 63 | 64 | return word2idx, idx2char 65 | 66 | 67 | def get_data(in_file, out_file): 68 | print('getting data {}->{}...'.format(in_file, out_file)) 69 | with open(in_file, 'r', encoding='utf-8') as file: 70 | in_lines = file.readlines() 71 | with open(out_file, 'r', encoding='utf-8') as file: 72 | out_lines = file.readlines() 73 | 74 | samples = [] 75 | 76 | for i in tqdm(range(len(in_lines))): 77 | sentence_en = in_lines[i].strip().lower() 78 | tokens = [normalizeString(s.strip()) for s in nltk.word_tokenize(sentence_en)] 79 | in_data = encode_text(src_char2idx, tokens) 80 | 81 | sentence_zh = out_lines[i].strip() 82 | tokens = jieba.cut(sentence_zh.strip()) 83 | out_data = [sos_id] + encode_text(tgt_char2idx, tokens) + [eos_id] 84 | 85 | if len(in_data) < maxlen_in and len(out_data) < maxlen_out and unk_id not in in_data and unk_id not in out_data: 86 | samples.append({'in': in_data, 'out': out_data}) 87 | return samples 88 | 89 | 90 | if __name__ == '__main__': 91 | src_char2idx, src_idx2char = process(train_translation_en_filename, lang='en') 92 | tgt_char2idx, tgt_idx2char = process(train_translation_zh_filename, lang='zh') 93 | 94 | print(len(src_char2idx)) 95 | print(len(tgt_char2idx)) 96 | 97 | data = { 98 | 'dict': { 99 | 'src_char2idx': src_char2idx, 100 | 'src_idx2char': src_idx2char, 101 | 'tgt_char2idx': tgt_char2idx, 102 | 'tgt_idx2char': tgt_idx2char 103 | } 104 | } 105 | with open(vocab_file, 'wb') as file: 106 | pickle.dump(data, file) 107 | 108 | train = get_data(train_translation_en_filename, train_translation_zh_filename) 109 | valid = get_data(valid_translation_en_filename, valid_translation_zh_filename) 110 | 111 | data = { 112 | 'train': train, 113 | 'valid': valid 114 | } 115 | 116 | print('num_train: ' + str(len(train))) 117 | print('num_valid: ' + str(len(valid))) 118 | 119 | with open(data_file, 'wb') as file: 120 | pickle.dump(data, file) 121 | -------------------------------------------------------------------------------- /test/test_bleu.py: -------------------------------------------------------------------------------- 1 | from nltk.translate.bleu_score import sentence_bleu 2 | 3 | reference = list('她的故事在法国遥远的西部山上') 4 | hypothesis = list('她的故事在法国的遥远山') 5 | score = sentence_bleu([reference], hypothesis) 6 | print(score) 7 | 
-------------------------------------------------------------------------------- /test/test_lr.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | from config import d_model 4 | 5 | if __name__ == '__main__': 6 | warmup_steps = 4000 7 | init_lr = d_model ** (-0.5) 8 | 9 | lr_list = [] 10 | for step_num in range(1, 500000): 11 | # print(step_num) 12 | lr = init_lr * min(step_num ** (-0.65), step_num * (warmup_steps ** (-1.5))) 13 | 14 | lr_list.append(lr) 15 | 16 | plt.plot(lr_list) 17 | plt.show() 18 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | from config import device, print_freq, sos_id, eos_id, n_src_vocab, n_tgt_vocab, grad_clip, logger 9 | from data_gen import AiChallenger2017Dataset, pad_collate 10 | from transformer.decoder import Decoder 11 | from transformer.encoder import Encoder 12 | from transformer.loss import cal_performance 13 | from transformer.optimizer import TransformerOptimizer 14 | from transformer.transformer import Transformer 15 | from utils import parse_args, save_checkpoint, AverageMeter, clip_gradient 16 | 17 | 18 | # from torch import nn 19 | 20 | 21 | def train_net(args): 22 | torch.manual_seed(7) 23 | np.random.seed(7) 24 | checkpoint = args.checkpoint 25 | start_epoch = 0 26 | best_loss = float('inf') 27 | writer = SummaryWriter() 28 | epochs_since_improvement = 0 29 | 30 | # Initialize / load checkpoint 31 | if checkpoint is None: 32 | # model 33 | encoder = Encoder(n_src_vocab, args.n_layers_enc, args.n_head, 34 | args.d_k, args.d_v, args.d_model, args.d_inner, 35 | dropout=args.dropout, pe_maxlen=args.pe_maxlen) 36 | decoder = Decoder(sos_id, eos_id, n_tgt_vocab, 37 | args.d_word_vec, args.n_layers_dec, args.n_head, 38 | args.d_k, args.d_v, args.d_model, args.d_inner, 39 | dropout=args.dropout, 40 | tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing, 41 | pe_maxlen=args.pe_maxlen) 42 | model = Transformer(encoder, decoder) 43 | # print(model) 44 | # model = nn.DataParallel(model) 45 | 46 | # optimizer 47 | optimizer = TransformerOptimizer( 48 | torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09)) 49 | 50 | else: 51 | checkpoint = torch.load(checkpoint) 52 | start_epoch = checkpoint['epoch'] + 1 53 | epochs_since_improvement = checkpoint['epochs_since_improvement'] 54 | model = checkpoint['model'] 55 | optimizer = checkpoint['optimizer'] 56 | 57 | # Move to GPU, if available 58 | model = model.to(device) 59 | 60 | # Custom dataloaders 61 | train_dataset = AiChallenger2017Dataset('train') 62 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=pad_collate, 63 | shuffle=True, num_workers=args.num_workers) 64 | valid_dataset = AiChallenger2017Dataset('valid') 65 | valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, collate_fn=pad_collate, 66 | shuffle=False, num_workers=args.num_workers) 67 | 68 | # Epochs 69 | for epoch in range(start_epoch, args.epochs): 70 | # One epoch's training 71 | train_loss = train(train_loader=train_loader, 72 | model=model, 73 | optimizer=optimizer, 74 | epoch=epoch, 75 | logger=logger, 76 | writer=writer) 77 | 78 | writer.add_scalar('epoch/train_loss', train_loss, epoch) 79 | 
writer.add_scalar('epoch/learning_rate', optimizer.lr, epoch) 80 | 81 | print('\nLearning rate: {}'.format(optimizer.lr)) 82 | print('Step num: {}\n'.format(optimizer.step_num)) 83 | 84 | # One epoch's validation 85 | valid_loss = valid(valid_loader=valid_loader, 86 | model=model, 87 | logger=logger) 88 | writer.add_scalar('epoch/valid_loss', valid_loss, epoch) 89 | 90 | # Check if there was an improvement 91 | is_best = valid_loss < best_loss 92 | best_loss = min(valid_loss, best_loss) 93 | if not is_best: 94 | epochs_since_improvement += 1 95 | print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,)) 96 | else: 97 | epochs_since_improvement = 0 98 | 99 | # Save checkpoint 100 | save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best) 101 | 102 | 103 | def train(train_loader, model, optimizer, epoch, logger, writer): 104 | model.train() # train mode (dropout and batchnorm is used) 105 | 106 | losses = AverageMeter() 107 | times = AverageMeter() 108 | 109 | start = time.time() 110 | 111 | # Batches 112 | for i, (data) in enumerate(train_loader): 113 | # Move to GPU, if available 114 | padded_input, padded_target, input_lengths = data 115 | padded_input = padded_input.to(device) 116 | padded_target = padded_target.to(device) 117 | input_lengths = input_lengths.to(device) 118 | 119 | # Forward prop. 120 | pred, gold = model(padded_input, input_lengths, padded_target) 121 | loss, n_correct = cal_performance(pred, gold, smoothing=args.label_smoothing) 122 | try: 123 | assert (not math.isnan(loss.item())) 124 | except AssertionError: 125 | print('n_correct: ' + str(n_correct)) 126 | print('data: ' + str(n_correct)) 127 | continue 128 | 129 | # Back prop. 130 | optimizer.zero_grad() 131 | loss.backward() 132 | 133 | # Clip gradients 134 | clip_gradient(optimizer.optimizer, grad_clip) 135 | 136 | # Update weights 137 | optimizer.step() 138 | 139 | # Keep track of metrics 140 | elapsed = time.time() - start 141 | start = time.time() 142 | 143 | losses.update(loss.item()) 144 | times.update(elapsed) 145 | 146 | # Print status 147 | if i % print_freq == 0: 148 | logger.info('Epoch: [{0}][{1}/{2}]\t' 149 | 'Batch time {time.val:.5f} ({time.avg:.5f})\t' 150 | 'Loss {loss.val:.5f} ({loss.avg:.5f})'.format(epoch, i, len(train_loader), time=times, 151 | loss=losses)) 152 | writer.add_scalar('step_num/train_loss', losses.avg, optimizer.step_num) 153 | writer.add_scalar('step_num/learning_rate', optimizer.lr, optimizer.step_num) 154 | 155 | return losses.avg 156 | 157 | 158 | def valid(valid_loader, model, logger): 159 | model.eval() 160 | 161 | losses = AverageMeter() 162 | 163 | # Batches 164 | for data in valid_loader: 165 | # Move to GPU, if available 166 | padded_input, padded_target, input_lengths = data 167 | padded_input = padded_input.to(device) 168 | padded_target = padded_target.to(device) 169 | input_lengths = input_lengths.to(device) 170 | 171 | with torch.no_grad(): 172 | # Forward prop. 
173 | pred, gold = model(padded_input, input_lengths, padded_target) 174 | loss, n_correct = cal_performance(pred, gold, smoothing=args.label_smoothing) 175 | try: 176 | assert (not math.isnan(loss.item())) 177 | except AssertionError: 178 | print('n_correct: ' + str(n_correct)) 179 | print('data: ' + str(n_correct)) 180 | continue 181 | 182 | # Keep track of metrics 183 | losses.update(loss.item()) 184 | 185 | # Print status 186 | logger.info('\nValidation Loss {loss.val:.5f} ({loss.avg:.5f})\n'.format(loss=losses)) 187 | 188 | return losses.avg 189 | 190 | 191 | def main(): 192 | global args 193 | args = parse_args() 194 | train_net(args) 195 | 196 | 197 | if __name__ == '__main__': 198 | main() 199 | -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Transformer/3246970b825a87f2c035edac545000716c0d96f7/transformer/__init__.py -------------------------------------------------------------------------------- /transformer/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class MultiHeadAttention(nn.Module): 7 | ''' Multi-Head Attention module ''' 8 | 9 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 10 | super().__init__() 11 | 12 | self.n_head = n_head 13 | self.d_k = d_k 14 | self.d_v = d_v 15 | 16 | self.w_qs = nn.Linear(d_model, n_head * d_k) 17 | self.w_ks = nn.Linear(d_model, n_head * d_k) 18 | self.w_vs = nn.Linear(d_model, n_head * d_v) 19 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 20 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 21 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 22 | 23 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), 24 | attn_dropout=dropout) 25 | self.layer_norm = nn.LayerNorm(d_model) 26 | 27 | self.fc = nn.Linear(n_head * d_v, d_model) 28 | nn.init.xavier_normal_(self.fc.weight) 29 | 30 | self.dropout = nn.Dropout(dropout) 31 | 32 | def forward(self, q, k, v, mask=None): 33 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 34 | 35 | sz_b, len_q, _ = q.size() 36 | sz_b, len_k, _ = k.size() 37 | sz_b, len_v, _ = v.size() 38 | 39 | residual = q 40 | 41 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 42 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 43 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 44 | 45 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 46 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 47 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 48 | 49 | if mask is not None: 50 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
51 | 52 | output, attn = self.attention(q, k, v, mask=mask) 53 | 54 | output = output.view(n_head, sz_b, len_q, d_v) 55 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 56 | 57 | output = self.dropout(self.fc(output)) 58 | output = self.layer_norm(output + residual) 59 | 60 | return output, attn 61 | 62 | 63 | class ScaledDotProductAttention(nn.Module): 64 | ''' Scaled Dot-Product Attention ''' 65 | 66 | def __init__(self, temperature, attn_dropout=0.1): 67 | super().__init__() 68 | self.temperature = temperature 69 | self.dropout = nn.Dropout(attn_dropout) 70 | self.softmax = nn.Softmax(dim=2) 71 | 72 | def forward(self, q, k, v, mask=None): 73 | attn = torch.bmm(q, k.transpose(1, 2)) 74 | attn = attn / self.temperature 75 | 76 | if mask is not None: 77 | attn = attn.masked_fill(mask.bool(), -np.inf) 78 | 79 | attn = self.softmax(attn) 80 | attn = self.dropout(attn) 81 | output = torch.bmm(attn, v) 82 | 83 | return output, attn 84 | -------------------------------------------------------------------------------- /transformer/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from config import IGNORE_ID, sos_id, eos_id, n_tgt_vocab 6 | from .attention import MultiHeadAttention 7 | from .module import PositionalEncoding, PositionwiseFeedForward 8 | from .utils import get_attn_key_pad_mask, get_attn_pad_mask, get_non_pad_mask, get_subsequent_mask, pad_list 9 | 10 | 11 | class Decoder(nn.Module): 12 | ''' A decoder model with self attention mechanism. ''' 13 | 14 | def __init__( 15 | self, sos_id=sos_id, eos_id=eos_id, 16 | n_tgt_vocab=n_tgt_vocab, d_word_vec=512, 17 | n_layers=6, n_head=8, d_k=64, d_v=64, 18 | d_model=512, d_inner=2048, dropout=0.1, 19 | tgt_emb_prj_weight_sharing=True, 20 | pe_maxlen=5000): 21 | super(Decoder, self).__init__() 22 | # parameters 23 | self.sos_id = sos_id # Start of Sentence 24 | self.eos_id = eos_id # End of Sentence 25 | self.n_tgt_vocab = n_tgt_vocab 26 | self.d_word_vec = d_word_vec 27 | self.n_layers = n_layers 28 | self.n_head = n_head 29 | self.d_k = d_k 30 | self.d_v = d_v 31 | self.d_model = d_model 32 | self.d_inner = d_inner 33 | self.dropout = dropout 34 | self.tgt_emb_prj_weight_sharing = tgt_emb_prj_weight_sharing 35 | self.pe_maxlen = pe_maxlen 36 | 37 | self.tgt_word_emb = nn.Embedding(n_tgt_vocab, d_word_vec) 38 | self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen) 39 | self.dropout = nn.Dropout(dropout) 40 | 41 | self.layer_stack = nn.ModuleList([ 42 | DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 43 | for _ in range(n_layers)]) 44 | 45 | self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False) 46 | nn.init.xavier_normal_(self.tgt_word_prj.weight) 47 | 48 | if tgt_emb_prj_weight_sharing: 49 | # Share the weight matrix between target word embedding & the final logit dense layer 50 | self.tgt_word_prj.weight = self.tgt_word_emb.weight 51 | self.x_logit_scale = (d_model ** -0.5) 52 | else: 53 | self.x_logit_scale = 1. 
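        # Note: x_logit_scale is multiplied into the target embedding output in forward() and
        # recognize_beam(); when the embedding and projection weights are tied, the embedding
        # output is scaled by d_model ** -0.5, otherwise it is left unscaled.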
54 | 55 | def preprocess(self, padded_input): 56 | """Generate decoder input and output label from padded_input 57 | Add to decoder input, and add to decoder output label 58 | """ 59 | ys = [y[y != IGNORE_ID] for y in padded_input] # parse padded ys 60 | # prepare input and output word sequences with sos/eos IDs 61 | eos = ys[0].new([self.eos_id]) 62 | sos = ys[0].new([self.sos_id]) 63 | ys_in = [torch.cat([sos, y], dim=0) for y in ys] 64 | ys_out = [torch.cat([y, eos], dim=0) for y in ys] 65 | # padding for ys with -1 66 | # pys: utt x olen 67 | ys_in_pad = pad_list(ys_in, self.eos_id) 68 | ys_out_pad = pad_list(ys_out, IGNORE_ID) 69 | assert ys_in_pad.size() == ys_out_pad.size() 70 | return ys_in_pad, ys_out_pad 71 | 72 | def forward(self, padded_input, encoder_padded_outputs, 73 | encoder_input_lengths, return_attns=False): 74 | """ 75 | Args: 76 | padded_input: N x To 77 | encoder_padded_outputs: N x Ti x H 78 | Returns: 79 | """ 80 | dec_slf_attn_list, dec_enc_attn_list = [], [] 81 | 82 | # Get Deocder Input and Output 83 | ys_in_pad, ys_out_pad = self.preprocess(padded_input) 84 | 85 | # Prepare masks 86 | non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id) 87 | 88 | slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad) 89 | slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad, 90 | seq_q=ys_in_pad, 91 | pad_idx=self.eos_id) 92 | slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0) 93 | 94 | output_length = ys_in_pad.size(1) 95 | dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs, 96 | encoder_input_lengths, 97 | output_length) 98 | 99 | # Forward 100 | dec_output = self.dropout(self.tgt_word_emb(ys_in_pad) * self.x_logit_scale + 101 | self.positional_encoding(ys_in_pad)) 102 | 103 | for dec_layer in self.layer_stack: 104 | dec_output, dec_slf_attn, dec_enc_attn = dec_layer( 105 | dec_output, encoder_padded_outputs, 106 | non_pad_mask=non_pad_mask, 107 | slf_attn_mask=slf_attn_mask, 108 | dec_enc_attn_mask=dec_enc_attn_mask) 109 | 110 | if return_attns: 111 | dec_slf_attn_list += [dec_slf_attn] 112 | dec_enc_attn_list += [dec_enc_attn] 113 | 114 | # before softmax 115 | seq_logit = self.tgt_word_prj(dec_output) 116 | 117 | # Return 118 | pred, gold = seq_logit, ys_out_pad 119 | 120 | if return_attns: 121 | return pred, gold, dec_slf_attn_list, dec_enc_attn_list 122 | return pred, gold 123 | 124 | def recognize_beam(self, encoder_outputs, char_list): 125 | """Beam search, decode one utterence now. 
126 | Args: 127 | encoder_outputs: T x H 128 | char_list: list of character 129 | args: args.beam 130 | Returns: 131 | nbest_hyps: 132 | """ 133 | # search params 134 | beam = 5 135 | nbest = 1 136 | maxlen = 100 137 | 138 | encoder_outputs = encoder_outputs.unsqueeze(0) 139 | 140 | # prepare sos 141 | ys = torch.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long() 142 | 143 | # yseq: 1xT 144 | hyp = {'score': 0.0, 'yseq': ys} 145 | hyps = [hyp] 146 | ended_hyps = [] 147 | 148 | for i in range(maxlen): 149 | hyps_best_kept = [] 150 | for hyp in hyps: 151 | ys = hyp['yseq'] # 1 x i 152 | 153 | # -- Prepare masks 154 | non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1) # 1xix1 155 | slf_attn_mask = get_subsequent_mask(ys) 156 | 157 | # -- Forward 158 | dec_output = self.dropout( 159 | self.tgt_word_emb(ys) * self.x_logit_scale + 160 | self.positional_encoding(ys)) 161 | 162 | for dec_layer in self.layer_stack: 163 | dec_output, _, _ = dec_layer( 164 | dec_output, encoder_outputs, 165 | non_pad_mask=non_pad_mask, 166 | slf_attn_mask=slf_attn_mask, 167 | dec_enc_attn_mask=None) 168 | 169 | seq_logit = self.tgt_word_prj(dec_output[:, -1]) 170 | 171 | local_scores = F.log_softmax(seq_logit, dim=1) 172 | # topk scores 173 | local_best_scores, local_best_ids = torch.topk( 174 | local_scores, beam, dim=1) 175 | 176 | for j in range(beam): 177 | new_hyp = {} 178 | new_hyp['score'] = hyp['score'] + local_best_scores[0, j] 179 | new_hyp['yseq'] = torch.ones(1, (1 + ys.size(1))).type_as(encoder_outputs).long() 180 | new_hyp['yseq'][:, :ys.size(1)] = hyp['yseq'] 181 | new_hyp['yseq'][:, ys.size(1)] = int(local_best_ids[0, j]) 182 | # will be (2 x beam) hyps at most 183 | hyps_best_kept.append(new_hyp) 184 | 185 | hyps_best_kept = sorted(hyps_best_kept, 186 | key=lambda x: x['score'], 187 | reverse=True)[:beam] 188 | # end for hyp in hyps 189 | hyps = hyps_best_kept 190 | 191 | # add eos in the final loop to avoid that there are no ended hyps 192 | if i == maxlen - 1: 193 | for hyp in hyps: 194 | hyp['yseq'] = torch.cat([hyp['yseq'], 195 | torch.ones(1, 1).fill_(self.eos_id).type_as(encoder_outputs).long()], 196 | dim=1) 197 | 198 | # add ended hypothes to a final list, and removed them from current hypothes 199 | # (this will be a probmlem, number of hyps < beam) 200 | remained_hyps = [] 201 | for hyp in hyps: 202 | if hyp['yseq'][0, -1] == self.eos_id: 203 | ended_hyps.append(hyp) 204 | else: 205 | remained_hyps.append(hyp) 206 | 207 | hyps = remained_hyps 208 | # if len(hyps) > 0: 209 | # print('remeined hypothes: ' + str(len(hyps))) 210 | # else: 211 | # print('no hypothesis. 
Finish decoding.') 212 | # break 213 | # 214 | # for hyp in hyps: 215 | # print('hypo: ' + ''.join([char_list[int(x)] 216 | # for x in hyp['yseq'][0, 1:]])) 217 | # end for i in range(maxlen) 218 | nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[ 219 | :min(len(ended_hyps), nbest)] 220 | # compitable with LAS implementation 221 | for hyp in nbest_hyps: 222 | hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist() 223 | return nbest_hyps 224 | 225 | 226 | class DecoderLayer(nn.Module): 227 | ''' Compose with three layers ''' 228 | 229 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 230 | super(DecoderLayer, self).__init__() 231 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 232 | self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 233 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 234 | 235 | def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None): 236 | dec_output, dec_slf_attn = self.slf_attn( 237 | dec_input, dec_input, dec_input, mask=slf_attn_mask) 238 | dec_output *= non_pad_mask 239 | 240 | dec_output, dec_enc_attn = self.enc_attn( 241 | dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) 242 | dec_output *= non_pad_mask 243 | 244 | dec_output = self.pos_ffn(dec_output) 245 | dec_output *= non_pad_mask 246 | 247 | return dec_output, dec_slf_attn, dec_enc_attn 248 | -------------------------------------------------------------------------------- /transformer/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from config import pad_id, n_src_vocab 4 | from .attention import MultiHeadAttention 5 | from .module import PositionalEncoding, PositionwiseFeedForward 6 | from .utils import get_non_pad_mask, get_attn_pad_mask 7 | 8 | 9 | class Encoder(nn.Module): 10 | """Encoder of Transformer including self-attention and feed forward. 
11 | """ 12 | 13 | def __init__(self, n_src_vocab=n_src_vocab, n_layers=6, n_head=8, d_k=64, d_v=64, 14 | d_model=512, d_inner=2048, dropout=0.1, pe_maxlen=5000): 15 | super(Encoder, self).__init__() 16 | # parameters 17 | self.n_src_vocab = n_src_vocab 18 | self.n_layers = n_layers 19 | self.n_head = n_head 20 | self.d_k = d_k 21 | self.d_v = d_v 22 | self.d_model = d_model 23 | self.d_inner = d_inner 24 | self.dropout_rate = dropout 25 | self.pe_maxlen = pe_maxlen 26 | 27 | self.src_emb = nn.Embedding(n_src_vocab, d_model, padding_idx=pad_id) 28 | self.pos_emb = PositionalEncoding(d_model, max_len=pe_maxlen) 29 | self.dropout = nn.Dropout(dropout) 30 | 31 | self.layer_stack = nn.ModuleList([ 32 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 33 | for _ in range(n_layers)]) 34 | 35 | def forward(self, padded_input, input_lengths, return_attns=False): 36 | """ 37 | Args: 38 | padded_input: N x T 39 | input_lengths: N 40 | Returns: 41 | enc_output: N x T x H 42 | """ 43 | enc_slf_attn_list = [] 44 | 45 | # Forward 46 | enc_outputs = self.src_emb(padded_input) 47 | enc_outputs += self.pos_emb(enc_outputs) 48 | enc_output = self.dropout(enc_outputs) 49 | 50 | # Prepare masks 51 | non_pad_mask = get_non_pad_mask(enc_output, input_lengths=input_lengths) 52 | length = padded_input.size(1) 53 | slf_attn_mask = get_attn_pad_mask(enc_output, input_lengths, length) 54 | 55 | for enc_layer in self.layer_stack: 56 | enc_output, enc_slf_attn = enc_layer( 57 | enc_output, 58 | non_pad_mask=non_pad_mask, 59 | slf_attn_mask=slf_attn_mask) 60 | if return_attns: 61 | enc_slf_attn_list += [enc_slf_attn] 62 | 63 | if return_attns: 64 | return enc_output, enc_slf_attn_list 65 | return enc_output, 66 | 67 | 68 | class EncoderLayer(nn.Module): 69 | """Compose with two sub-layers. 70 | 1. A multi-head self-attention mechanism 71 | 2. A simple, position-wise fully connected feed-forward network. 72 | """ 73 | 74 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 75 | super(EncoderLayer, self).__init__() 76 | self.slf_attn = MultiHeadAttention( 77 | n_head, d_model, d_k, d_v, dropout=dropout) 78 | self.pos_ffn = PositionwiseFeedForward( 79 | d_model, d_inner, dropout=dropout) 80 | 81 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): 82 | enc_output, enc_slf_attn = self.slf_attn( 83 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 84 | enc_output *= non_pad_mask 85 | 86 | enc_output = self.pos_ffn(enc_output) 87 | enc_output *= non_pad_mask 88 | 89 | return enc_output, enc_slf_attn 90 | -------------------------------------------------------------------------------- /transformer/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from config import IGNORE_ID 5 | 6 | 7 | def cal_performance(pred, gold, smoothing=0.0): 8 | """Calculate cross entropy loss, apply label smoothing if needed. 9 | Args: 10 | pred: N x T x C, score before softmax 11 | gold: N x T 12 | """ 13 | 14 | pred = pred.view(-1, pred.size(2)) 15 | gold = gold.contiguous().view(-1) 16 | 17 | loss = cal_loss(pred, gold, smoothing) 18 | 19 | pred = pred.max(1)[1] 20 | non_pad_mask = gold.ne(IGNORE_ID) 21 | n_correct = pred.eq(gold) 22 | n_correct = n_correct.masked_select(non_pad_mask).sum().item() 23 | 24 | return loss, n_correct 25 | 26 | 27 | def cal_loss(pred, gold, smoothing=0.0): 28 | """Calculate cross entropy loss, apply label smoothing if needed. 
29 | """ 30 | 31 | if smoothing > 0.0: 32 | eps = smoothing 33 | n_class = pred.size(1) 34 | 35 | # Generate one-hot matrix: N x C. 36 | # Only label position is 1 and all other positions are 0 37 | # gold include -1 value (IGNORE_ID) and this will lead to assert error 38 | gold_for_scatter = gold.ne(IGNORE_ID).long() * gold 39 | one_hot = torch.zeros_like(pred).scatter(1, gold_for_scatter.view(-1, 1), 1) 40 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / n_class 41 | log_prb = F.log_softmax(pred, dim=1) 42 | 43 | non_pad_mask = gold.ne(IGNORE_ID) 44 | n_word = non_pad_mask.sum().item() 45 | loss = -(one_hot * log_prb).sum(dim=1) 46 | loss = loss.masked_select(non_pad_mask).sum() / n_word 47 | else: 48 | loss = F.cross_entropy(pred, gold, ignore_index=IGNORE_ID, reduction='mean') 49 | 50 | return loss 51 | -------------------------------------------------------------------------------- /transformer/module.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class PositionalEncoding(nn.Module): 9 | """Implement the positional encoding (PE) function. 10 | PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) 11 | PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) 12 | """ 13 | 14 | def __init__(self, d_model, max_len=5000): 15 | super(PositionalEncoding, self).__init__() 16 | # Compute the positional encodings once in log space. 17 | pe = torch.zeros(max_len, d_model, requires_grad=False) 18 | position = torch.arange(0, max_len).unsqueeze(1).float() 19 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 20 | -(math.log(10000.0) / d_model)) 21 | pe[:, 0::2] = torch.sin(position * div_term) 22 | pe[:, 1::2] = torch.cos(position * div_term) 23 | pe = pe.unsqueeze(0) 24 | self.register_buffer('pe', pe) 25 | 26 | def forward(self, input): 27 | """ 28 | Args: 29 | input: N x T x D 30 | """ 31 | length = input.size(1) 32 | return self.pe[:, :length] 33 | 34 | 35 | class PositionwiseFeedForward(nn.Module): 36 | """Implements position-wise feedforward sublayer. 
37 | FFN(x) = max(0, xW1 + b1)W2 + b2 38 | """ 39 | 40 | def __init__(self, d_model, d_ff, dropout=0.1): 41 | super(PositionwiseFeedForward, self).__init__() 42 | self.w_1 = nn.Linear(d_model, d_ff) 43 | self.w_2 = nn.Linear(d_ff, d_model) 44 | self.dropout = nn.Dropout(dropout) 45 | self.layer_norm = nn.LayerNorm(d_model) 46 | 47 | def forward(self, x): 48 | residual = x 49 | output = self.w_2(F.relu(self.w_1(x))) 50 | output = self.dropout(output) 51 | output = self.layer_norm(output + residual) 52 | return output 53 | 54 | 55 | # Another implementation 56 | class PositionwiseFeedForwardUseConv(nn.Module): 57 | """A two-feed-forward-layer module""" 58 | 59 | def __init__(self, d_in, d_hid, dropout=0.1): 60 | super(PositionwiseFeedForwardUseConv, self).__init__() 61 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 62 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 63 | self.layer_norm = nn.LayerNorm(d_in) 64 | self.dropout = nn.Dropout(dropout) 65 | 66 | def forward(self, x): 67 | residual = x 68 | output = x.transpose(1, 2) 69 | output = self.w_2(F.relu(self.w_1(output))) 70 | output = output.transpose(1, 2) 71 | output = self.dropout(output) 72 | output = self.layer_norm(output + residual) 73 | return output 74 | -------------------------------------------------------------------------------- /transformer/optimizer.py: -------------------------------------------------------------------------------- 1 | from config import d_model 2 | 3 | 4 | class TransformerOptimizer(object): 5 | """A simple wrapper class for learning rate scheduling""" 6 | 7 | def __init__(self, optimizer, warmup_steps=4000): 8 | self.optimizer = optimizer 9 | self.init_lr = d_model ** (-0.5) 10 | self.warmup_steps = warmup_steps 11 | self.lr = self.init_lr 12 | self.step_num = 0 13 | 14 | def zero_grad(self): 15 | self.optimizer.zero_grad() 16 | 17 | def step(self): 18 | self._update_lr() 19 | self.optimizer.step() 20 | 21 | def _update_lr(self): 22 | self.step_num += 1 23 | self.min_lr = 1e-5 24 | self.lr = self.init_lr * min(self.step_num ** (-0.65), self.step_num * (self.warmup_steps ** (-1.5))) 25 | self.lr = max(self.lr, self.min_lr) 26 | for param_group in self.optimizer.param_groups: 27 | param_group['lr'] = self.lr 28 | -------------------------------------------------------------------------------- /transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .decoder import Decoder 4 | from .encoder import Encoder 5 | 6 | 7 | class Transformer(nn.Module): 8 | """An encoder-decoder framework only includes attention. 
9 | """ 10 | 11 | def __init__(self, encoder=None, decoder=None): 12 | super(Transformer, self).__init__() 13 | self.encoder = encoder 14 | self.decoder = decoder 15 | 16 | if encoder is not None and decoder is not None: 17 | self.encoder = encoder 18 | self.decoder = decoder 19 | 20 | for p in self.parameters(): 21 | if p.dim() > 1: 22 | nn.init.xavier_uniform_(p) 23 | else: 24 | self.encoder = Encoder() 25 | self.decoder = Decoder() 26 | 27 | def forward(self, padded_input, input_lengths, padded_target): 28 | """ 29 | Args: 30 | padded_input: N x Ti x D 31 | input_lengths: N 32 | padded_targets: N x To 33 | """ 34 | encoder_padded_outputs, *_ = self.encoder(padded_input, input_lengths) 35 | # pred is score before softmax 36 | pred, gold, *_ = self.decoder(padded_target, encoder_padded_outputs, 37 | input_lengths) 38 | return pred, gold 39 | 40 | def recognize(self, input, input_length, char_list): 41 | """Sequence-to-Sequence beam search, decode one utterence now. 42 | Args: 43 | input: T x D 44 | char_list: list of characters 45 | args: args.beam 46 | Returns: 47 | nbest_hyps: 48 | """ 49 | encoder_outputs, enc_slf_attn_list = self.encoder(padded_input=input.unsqueeze(0), input_lengths=input_length, 50 | return_attns=True) 51 | nbest_hyps = self.decoder.recognize_beam(encoder_outputs[0], char_list) 52 | return nbest_hyps 53 | -------------------------------------------------------------------------------- /transformer/utils.py: -------------------------------------------------------------------------------- 1 | def pad_list(xs, pad_value): 2 | # From: espnet/src/nets/e2e_asr_th.py: pad_list() 3 | n_batch = len(xs) 4 | max_len = max(x.size(0) for x in xs) 5 | pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) 6 | for i in range(n_batch): 7 | pad[i, :xs[i].size(0)] = xs[i] 8 | return pad 9 | 10 | 11 | def process_dict(dict_path): 12 | with open(dict_path, 'rb') as f: 13 | dictionary = f.readlines() 14 | char_list = [entry.decode('utf-8').split(' ')[0] 15 | for entry in dictionary] 16 | sos_id = char_list.index('') 17 | eos_id = char_list.index('') 18 | return char_list, sos_id, eos_id 19 | 20 | 21 | if __name__ == "__main__": 22 | import sys 23 | 24 | path = sys.argv[1] 25 | char_list, sos_id, eos_id = process_dict(path) 26 | print(char_list, sos_id, eos_id) 27 | 28 | 29 | # * ------------------ recognition related ------------------ * 30 | 31 | 32 | def parse_hypothesis(hyp, char_list): 33 | """Function to parse hypothesis 34 | :param list hyp: recognition hypothesis 35 | :param list char_list: list of characters 36 | :return: recognition text strinig 37 | :return: recognition token strinig 38 | :return: recognition tokenid string 39 | """ 40 | # remove sos and get results 41 | tokenid_as_list = list(map(int, hyp['yseq'][1:])) 42 | token_as_list = [char_list[idx] for idx in tokenid_as_list] 43 | score = float(hyp['score']) 44 | 45 | # convert to string 46 | tokenid = " ".join([str(idx) for idx in tokenid_as_list]) 47 | token = " ".join(token_as_list) 48 | text = "".join(token_as_list).replace('', ' ') 49 | 50 | return text, token, tokenid, score 51 | 52 | 53 | def add_results_to_json(js, nbest_hyps, char_list): 54 | """Function to add N-best results to json 55 | :param dict js: groundtruth utterance dict 56 | :param list nbest_hyps: list of hypothesis 57 | :param list char_list: list of characters 58 | :return: N-best results added utterance dict 59 | """ 60 | # copy old json info 61 | new_js = dict() 62 | new_js['utt2spk'] = js['utt2spk'] 63 | new_js['output'] = [] 64 
| 65 | for n, hyp in enumerate(nbest_hyps, 1): 66 | # parse hypothesis 67 | rec_text, rec_token, rec_tokenid, score = parse_hypothesis( 68 | hyp, char_list) 69 | 70 | # copy ground-truth 71 | out_dic = dict(js['output'][0].items()) 72 | 73 | # update name 74 | out_dic['name'] += '[%d]' % n 75 | 76 | # add recognition results 77 | out_dic['rec_text'] = rec_text 78 | out_dic['rec_token'] = rec_token 79 | out_dic['rec_tokenid'] = rec_tokenid 80 | out_dic['score'] = score 81 | 82 | # add to list of N-best result dicts 83 | new_js['output'].append(out_dic) 84 | 85 | # show 1-best result 86 | if n == 1: 87 | print('groundtruth: %s' % out_dic['text']) 88 | print('prediction : %s' % out_dic['rec_text']) 89 | 90 | return new_js 91 | 92 | 93 | # -- Transformer Related -- 94 | import torch 95 | 96 | 97 | def get_non_pad_mask(padded_input, input_lengths=None, pad_idx=None): 98 | """padding position is set to 0, either use input_lengths or pad_idx 99 | """ 100 | assert input_lengths is not None or pad_idx is not None 101 | if input_lengths is not None: 102 | # padded_input: N x T x .. 103 | N = padded_input.size(0) 104 | non_pad_mask = padded_input.new_ones(padded_input.size()[:-1]) # N x T 105 | for i in range(N): 106 | non_pad_mask[i, input_lengths[i]:] = 0 107 | if pad_idx is not None: 108 | # padded_input: N x T 109 | assert padded_input.dim() == 2 110 | non_pad_mask = padded_input.ne(pad_idx).float() 111 | # unsqueeze(-1) for broadcast 112 | return non_pad_mask.unsqueeze(-1) 113 | 114 | 115 | def get_subsequent_mask(seq): 116 | ''' For masking out the subsequent info. ''' 117 | 118 | sz_b, len_s = seq.size() 119 | subsequent_mask = torch.triu( 120 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) 121 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls 122 | 123 | return subsequent_mask 124 | 125 | 126 | def get_attn_key_pad_mask(seq_k, seq_q, pad_idx): 127 | ''' For masking out the padding part of key sequence. ''' 128 | 129 | # Expand to fit the shape of key query attention matrix. 130 | len_q = seq_q.size(1) 131 | padding_mask = seq_k.eq(pad_idx) 132 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk 133 | 134 | return padding_mask 135 | 136 | 137 | def get_attn_pad_mask(padded_input, input_lengths, expand_length): 138 | """mask position is set to 1""" 139 | # N x Ti x 1 140 | non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths) 141 | # N x Ti, lt(1) like not operation 142 | pad_mask = non_pad_mask.squeeze(-1).lt(1) 143 | attn_mask = pad_mask.unsqueeze(1).expand(-1, expand_length, -1) 144 | return attn_mask 145 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | import unicodedata 4 | 5 | import torch 6 | 7 | 8 | def clip_gradient(optimizer, grad_clip): 9 | """ 10 | Clips gradients computed during backpropagation to avoid explosion of gradients. 
11 | :param optimizer: optimizer with the gradients to be clipped 12 | :param grad_clip: clip value 13 | """ 14 | for group in optimizer.param_groups: 15 | for param in group['params']: 16 | if param.grad is not None: 17 | param.grad.data.clamp_(-grad_clip, grad_clip) 18 | 19 | 20 | def save_checkpoint(epoch, epochs_since_improvement, model, optimizer, loss, is_best): 21 | state = {'epoch': epoch, 22 | 'epochs_since_improvement': epochs_since_improvement, 23 | 'loss': loss, 24 | 'model': model, 25 | 'optimizer': optimizer} 26 | 27 | filename = 'checkpoint.tar' 28 | torch.save(state, filename) 29 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint 30 | if is_best: 31 | torch.save(state, 'BEST_checkpoint.tar') 32 | 33 | 34 | class AverageMeter(object): 35 | """ 36 | Keeps track of most recent, average, sum, and count of a metric. 37 | """ 38 | 39 | def __init__(self): 40 | self.reset() 41 | 42 | def reset(self): 43 | self.val = 0 44 | self.avg = 0 45 | self.sum = 0 46 | self.count = 0 47 | 48 | def update(self, val, n=1): 49 | self.val = val 50 | self.sum += val * n 51 | self.count += n 52 | self.avg = self.sum / self.count 53 | 54 | 55 | def adjust_learning_rate(optimizer, shrink_factor): 56 | """ 57 | Shrinks learning rate by a specified factor. 58 | :param optimizer: optimizer whose learning rate must be shrunk. 59 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with. 60 | """ 61 | 62 | print("\nDECAYING learning rate.") 63 | for param_group in optimizer.param_groups: 64 | param_group['lr'] = param_group['lr'] * shrink_factor 65 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 66 | 67 | 68 | def accuracy(scores, targets, k=1): 69 | batch_size = targets.size(0) 70 | _, ind = scores.topk(k, 1, True, True) 71 | correct = ind.eq(targets.view(-1, 1).expand_as(ind)) 72 | correct_total = correct.view(-1).float().sum() # 0D tensor 73 | return correct_total.item() * (100.0 / batch_size) 74 | 75 | 76 | def parse_args(): 77 | parser = argparse.ArgumentParser(description='Transformer') 78 | 79 | # Network architecture 80 | # encoder 81 | # TODO: automatically infer input dim 82 | parser.add_argument('--n_layers_enc', default=6, type=int, 83 | help='Number of encoder stacks') 84 | parser.add_argument('--n_head', default=8, type=int, 85 | help='Number of Multi Head Attention (MHA)') 86 | parser.add_argument('--d_k', default=64, type=int, 87 | help='Dimension of key') 88 | parser.add_argument('--d_v', default=64, type=int, 89 | help='Dimension of value') 90 | parser.add_argument('--d_model', default=512, type=int, 91 | help='Dimension of model') 92 | parser.add_argument('--d_inner', default=2048, type=int, 93 | help='Dimension of inner') 94 | parser.add_argument('--dropout', default=0.1, type=float, 95 | help='Dropout rate') 96 | parser.add_argument('--pe_maxlen', default=5000, type=int, 97 | help='Positional Encoding max len') 98 | # decoder 99 | parser.add_argument('--d_word_vec', default=512, type=int, 100 | help='Dim of decoder embedding') 101 | parser.add_argument('--n_layers_dec', default=6, type=int, 102 | help='Number of decoder stacks') 103 | parser.add_argument('--tgt_emb_prj_weight_sharing', default=1, type=int, 104 | help='share decoder embedding with decoder projection') 105 | # Loss 106 | parser.add_argument('--label_smoothing', default=0.1, type=float, 107 | help='label smoothing') 108 | 109 | # Training config 110 | parser.add_argument('--epochs', default=1000, type=int, 111 | 
help='Number of maximum epochs') 112 | # minibatch 113 | parser.add_argument('--shuffle', default=1, type=int, 114 | help='reshuffle the data at every epoch') 115 | parser.add_argument('--batch-size', default=128, type=int, 116 | help='Batch size') 117 | parser.add_argument('--batch_frames', default=0, type=int, 118 | help='Batch frames. If this is not 0, batch size is ignored') 119 | parser.add_argument('--maxlen-in', default=50, type=int, metavar='ML', 120 | help='Batch size is reduced if the input sequence length > ML') 121 | parser.add_argument('--maxlen-out', default=25, type=int, metavar='ML', 122 | help='Batch size is reduced if the output sequence length > ML') 123 | parser.add_argument('--num-workers', default=8, type=int, 124 | help='Number of workers used to generate minibatches') 125 | # optimizer 126 | parser.add_argument('--k', default=0.2, type=float, 127 | help='Tunable scalar multiplied into the learning rate') 128 | parser.add_argument('--warmup_steps', default=4000, type=int, 129 | help='warmup steps') 130 | 131 | parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint') 132 | args = parser.parse_args() 133 | return args 134 | 135 | 136 | def ensure_folder(folder): 137 | import os 138 | if not os.path.isdir(folder): 139 | os.mkdir(folder) 140 | 141 | 142 | def pad_list(xs, pad_value): 143 | # From: espnet/src/nets/e2e_asr_th.py: pad_list() 144 | n_batch = len(xs) 145 | max_len = max(x.size(0) for x in xs) 146 | pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) 147 | for i in range(n_batch): 148 | pad[i, :xs[i].size(0)] = xs[i] 149 | return pad 150 | 151 | 152 | def text_to_sequence(text, char2idx): 153 | result = [char2idx[char] for char in text] 154 | return result 155 | 156 | 157 | def sequence_to_text(seq, idx2char): 158 | result = [idx2char[idx] for idx in seq] 159 | return result 160 | 161 | 162 | # Turn a Unicode string to plain ASCII, thanks to 163 | # http://stackoverflow.com/a/518232/2809427 164 | def unicodeToAscii(s): 165 | return ''.join( 166 | c for c in unicodedata.normalize('NFD', s) 167 | if unicodedata.category(c) != 'Mn' 168 | ) 169 | 170 | 171 | def normalizeString(s): 172 | s = unicodeToAscii(s.lower().strip()) 173 | s = re.sub(r"([.!?])", r" \1", s) 174 | s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) 175 | return s 176 | 177 | 178 | def encode_text(word_map, c): 179 | return [word_map.get(word, word_map['<unk>']) for word in c] 180 | -------------------------------------------------------------------------------- /vocab.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foamliu/Transformer/3246970b825a87f2c035edac545000716c0d96f7/vocab.pkl --------------------------------------------------------------------------------
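The loss utilities in transformer/loss.py expect decoder scores of shape N x T x C and padded targets of shape N x T, with padding marked by config.IGNORE_ID. A toy invocation, purely illustrative and assuming the repo's config module imports cleanly:

import torch

from config import IGNORE_ID
from transformer.loss import cal_performance

N, T, C = 2, 4, 10                      # toy batch size, target length, vocab size
pred = torch.randn(N, T, C)             # decoder scores before softmax
gold = torch.randint(0, C, (N, T))      # toy target ids
gold[:, -1] = IGNORE_ID                 # pretend the last position of each target is padding

loss, n_correct = cal_performance(pred, gold, smoothing=0.1)
print(loss.item(), n_correct)           # label-smoothed loss and correct predictions on non-pad positions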
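The PE(pos, 2i) / PE(pos, 2i+1) formulas quoted in transformer/module.py can be spot-checked numerically. A small sketch with arbitrary d_model and positions (again assuming the transformer package imports on its own):

import math

import torch

from transformer.module import PositionalEncoding

d_model = 512
pe = PositionalEncoding(d_model, max_len=100)
x = torch.zeros(2, 10, d_model)      # N x T x D dummy input; forward() only uses its length T
table = pe(x)                        # 1 x T x D slice of the precomputed table (broadcast over N when added)

pos, two_i = 3, 4                    # check PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
expected = math.sin(pos / (10000 ** (two_i / d_model)))
print(table[0, pos, two_i].item(), expected)   # should agree up to float32 precision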
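module.py keeps two position-wise feed-forward variants, one built from nn.Linear and one from kernel-size-1 nn.Conv1d. The sketch below ties their weights together to confirm both compute the same FFN(x) = max(0, xW1 + b1)W2 + b2; eval mode keeps dropout inactive, and the sizes are the repo defaults:

import torch

from transformer.module import PositionwiseFeedForward, PositionwiseFeedForwardUseConv

d_model, d_ff = 512, 2048
ffn_lin = PositionwiseFeedForward(d_model, d_ff).eval()
ffn_conv = PositionwiseFeedForwardUseConv(d_model, d_ff).eval()

# A kernel-size-1 Conv1d is a position-wise Linear layer: copy the weights across.
with torch.no_grad():
    ffn_conv.w_1.weight.copy_(ffn_lin.w_1.weight.unsqueeze(-1))
    ffn_conv.w_1.bias.copy_(ffn_lin.w_1.bias)
    ffn_conv.w_2.weight.copy_(ffn_lin.w_2.weight.unsqueeze(-1))
    ffn_conv.w_2.bias.copy_(ffn_lin.w_2.bias)

x = torch.randn(2, 7, d_model)                               # N x T x d_model
print(torch.allclose(ffn_lin(x), ffn_conv(x), atol=1e-5))    # True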
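TransformerOptimizer (transformer/optimizer.py) follows lr = d_model^(-0.5) * min(step^(-0.65), step * warmup_steps^(-1.5)) with a floor of 1e-5; note the -0.65 decay exponent where the original Transformer paper uses -0.5, which also moves the peak slightly before warmup_steps. A standalone sketch of the schedule, assuming config.d_model is 512 as in the parse_args defaults:

d_model, warmup_steps, min_lr = 512, 4000, 1e-5
init_lr = d_model ** (-0.5)

def lr_at(step):
    # mirrors TransformerOptimizer._update_lr for a given step count
    return max(init_lr * min(step ** (-0.65), step * warmup_steps ** (-1.5)), min_lr)

for step in (1, 100, 1000, 4000, 40000, 400000):
    print(step, lr_at(step))   # stays at the 1e-5 floor early on, ramps roughly linearly, then decays as step^-0.65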
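The mask helpers at the bottom of transformer/utils.py connect padded batches to attention. The sketch below only walks their shapes for a toy batch of two sequences with lengths 3 and 2 padded to length 4 (the feature size 8 is arbitrary):

import torch

from transformer.utils import get_attn_pad_mask, get_non_pad_mask, get_subsequent_mask

padded_input = torch.randn(2, 4, 8)        # N x T x D encoder input
input_lengths = torch.tensor([3, 2])

non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
print(non_pad_mask.shape)                  # torch.Size([2, 4, 1]); zero at padded time steps

attn_mask = get_attn_pad_mask(padded_input, input_lengths, expand_length=4)
print(attn_mask.shape)                     # torch.Size([2, 4, 4]); True where the key position is padding

ys = torch.randint(1, 10, (2, 5))          # N x To decoder input ids
print(get_subsequent_mask(ys)[0])          # upper-triangular mask that hides future positions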