├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── __MACOSX
│   │   └── 《刘慈欣作品全集》(v1.0)
│   │       └── ._.DS_Store
│   ├── liucixin.txt
│   ├── sanguoyanyi.txt
│   ├── santi.txt
│   └── 《刘慈欣作品全集》(v1.0).zip
├── data_zh.py
├── images
│   ├── result-1.PNG
│   ├── result-2.PNG
│   ├── result-3.PNG
│   ├── result-4.PNG
│   ├── result-5.PNG
│   ├── result-6.PNG
│   ├── result-7.PNG
│   ├── result-8.PNG
│   └── result-9.PNG
├── main.py
├── model.py
├── preproc.py
└── test.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | __pycache__/
3 | checkpoints/
4 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 | 
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 | 
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 | 
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 | 
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 | 
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 | 
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 | 
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 | 
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 | 
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 | 
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 | 
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 | 
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 | 
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 | 
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 | 
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 | 
176 | END OF TERMS AND CONDITIONS
177 | 
178 | APPENDIX: How to apply the Apache License to your work.
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
6 | 
7 | # Language-Model
8 | 
9 | A Chinese language model, implemented after the PyTorch word language model [example](https://github.com/pytorch/examples/tree/master/word_language_model).
10 | 
11 | ## Dataset
12 | 
13 | Romance of the Three Kingdoms (《三国演义》) + The Three-Body Problem (《三体》)
14 | 
15 | ## Model
16 | 
17 | <pre>
18 | RNNModel(
19 |   (drop): Dropout(p=0.5)
20 |   (encoder): Embedding(3719, 200)
21 |   (rnn): GRU(200, 200, num_layers=2, dropout=0.5)
22 |   (decoder): Linear(in_features=200, out_features=3719, bias=True)
23 | )
24 | </pre>
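
## Usage

A quick sketch of how the pieces fit together (paths as wired up in `main.py` and `preproc.py`): running `python main.py` builds a character-level corpus from `data/santi.txt`, trains the model above, saves checkpoints under `checkpoints/santi/` every `save_interval` epochs, and prints a generated sample during training; `python preproc.py` concatenates the unpacked 《刘慈欣作品全集》 folder into `data/liucixin.txt`. To train on another text, point `train_dir` in `main.py` at a different file under `data/`.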
25 | 
26 | ## Generated Samples
27 | 
28 | ### Romance of the Three Kingdoms (《三国演义》)
29 | 
30 | ![image](https://github.com/foamliu/Language-Model/raw/master/images/result-1.PNG)
31 | 
32 | ![image](https://github.com/foamliu/Language-Model/raw/master/images/result-2.PNG)
33 | 
34 | ![image](https://github.com/foamliu/Language-Model/raw/master/images/result-3.PNG)
35 | 
36 | ![image](https://github.com/foamliu/Language-Model/raw/master/images/result-4.PNG)
37 | 
38 | ### The Three-Body Problem (《三体》)
39 | 
40 | ![image](https://github.com/foamliu/Language-Model/raw/master/images/result-5.PNG)
41 | 
42 | ![image](https://github.com/foamliu/Language-Model/raw/master/images/result-6.PNG)
43 | 
44 | ![image](https://github.com/foamliu/Language-Model/raw/master/images/result-7.PNG)
45 | 
46 | ![image](https://github.com/foamliu/Language-Model/raw/master/images/result-8.PNG)
47 | 
48 | ![image](https://github.com/foamliu/Language-Model/raw/master/images/result-9.PNG)
49 | 
50 | 
51 | 
52 | 
--------------------------------------------------------------------------------
/data/《刘慈欣作品全集》(v1.0).zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Language-Model/86e06189554c2e2dd232ae2cb9b9013b80bfc892/data/《刘慈欣作品全集》(v1.0).zip
--------------------------------------------------------------------------------
/data_zh.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import os
5 | import torch
6 | 
7 | 
8 | class Dictionary(object):
9 |     """
10 |     Vocabulary: maps each token in the text to an integer id.
11 |     """
12 | 
13 |     def __init__(self):
14 |         self.word2idx = {}
15 |         self.idx2word = []
16 | 
17 |     def add_word(self, word):
18 |         if word not in self.word2idx:
19 |             self.idx2word.append(word)
20 |             self.word2idx[word] = len(self.idx2word) - 1
21 | 
22 |     def __len__(self):
23 |         return len(self.idx2word)
24 | 
25 | 
26 | class Corpus(object):
27 |     """
28 |     Text preprocessing: builds the vocabulary and converts the raw text into a sequence of ids.
29 |     """
30 | 
31 |     def __init__(self, path):
32 |         self.dictionary = Dictionary()
33 |         self.train = self.tokenize(path)
34 | 
35 |     def tokenize(self, path):
36 |         """Tokenizes a text file into a sequence of integer ids."""
37 |         assert os.path.exists(path)
38 | 
39 |         # Add new tokens to the vocabulary
40 |         with open(path, 'r', encoding='utf-8') as f:
41 |             tokens = 0
42 |             for line in f:
43 |                 if len(line.strip()) == 0:  # skip empty lines
44 |                     continue
45 |                 words = list(line.strip()) + ['<eos>']  # unlike the original example, tokenization is character-level
46 |                 tokens += len(words)
47 |                 for word in words:
48 |                     self.dictionary.add_word(word)
49 | 
50 |         # Convert characters to ids
51 |         with open(path, 'r', encoding='utf-8') as f:
52 |             ids = torch.LongTensor(tokens)
53 |             token = 0
54 |             for line in f:
55 |                 if len(line.strip()) == 0:  # skip empty lines
56 |                     continue
57 |                 words = list(line.strip()) + ['<eos>']  # unlike the original example, tokenization is character-level
58 |                 for word in words:
59 |                     ids[token] = self.dictionary.word2idx[word]
60 |                     token += 1
61 | 
62 |         return ids
63 | 
64 |     def __repr__(self):
65 |         return "Corpus length: %d, Vocabulary size: %d" % (self.train.size(0), len(self.dictionary))
66 | 
--------------------------------------------------------------------------------
/images/result-1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Language-Model/86e06189554c2e2dd232ae2cb9b9013b80bfc892/images/result-1.PNG
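
As a quick reference, a minimal sketch of how the `Corpus` class from `data_zh.py` above is used; the path is just an example:

```python
# A minimal sketch of the data_zh.py API (the path is an example).
from data_zh import Corpus

corpus = Corpus('data/santi.txt')   # builds the vocabulary and the id sequence in one pass
print(corpus)                       # "Corpus length: ..., Vocabulary size: ..."

ids = corpus.train[:10]             # 1-D LongTensor of character ids
print(''.join(corpus.dictionary.idx2word[i] for i in ids))  # decode them back to text
```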
--------------------------------------------------------------------------------
/images/result-2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Language-Model/86e06189554c2e2dd232ae2cb9b9013b80bfc892/images/result-2.PNG
--------------------------------------------------------------------------------
/images/result-3.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Language-Model/86e06189554c2e2dd232ae2cb9b9013b80bfc892/images/result-3.PNG
--------------------------------------------------------------------------------
/images/result-4.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Language-Model/86e06189554c2e2dd232ae2cb9b9013b80bfc892/images/result-4.PNG
--------------------------------------------------------------------------------
/images/result-5.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Language-Model/86e06189554c2e2dd232ae2cb9b9013b80bfc892/images/result-5.PNG
--------------------------------------------------------------------------------
/images/result-6.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Language-Model/86e06189554c2e2dd232ae2cb9b9013b80bfc892/images/result-6.PNG
--------------------------------------------------------------------------------
/images/result-7.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Language-Model/86e06189554c2e2dd232ae2cb9b9013b80bfc892/images/result-7.PNG
--------------------------------------------------------------------------------
/images/result-8.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Language-Model/86e06189554c2e2dd232ae2cb9b9013b80bfc892/images/result-8.PNG
--------------------------------------------------------------------------------
/images/result-9.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/Language-Model/86e06189554c2e2dd232ae2cb9b9013b80bfc892/images/result-9.PNG
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import math
5 | import os
6 | import time
7 | from datetime import timedelta
8 | 
9 | import torch
10 | import torch.nn as nn
11 | from torch.autograd import Variable
12 | 
13 | from data_zh import Corpus
14 | from model import RNNModel
15 | 
16 | 
17 | class Config(object):
18 |     """Configuration for the RNN language model."""
19 |     embedding_dim = 200  # embedding dimension
20 | 
21 |     rnn_type = 'GRU'  # one of RNN/LSTM/GRU
22 |     hidden_dim = 200  # hidden state dimension
23 |     num_layers = 2  # number of RNN layers
24 | 
25 |     dropout = 0.5  # dropout probability
26 |     tie_weights = True  # whether to tie the encoder and decoder weights
27 | 
28 |     batch_size = 10  # number of samples per batch
29 |     seq_len = 30  # sequence (BPTT) length
30 | 
31 |     clip = 0.25  # gradient clipping threshold
32 |     learning_rate = 1  # initial learning rate
33 | 
34 |     num_epochs = 50  # number of training epochs
35 |     log_interval = 500  # log training status every this many batches
36 |     save_interval = 3  # save a checkpoint every this many epochs
37 | 
38 | 
39 | def batchify(data, bsz):
40 |     """Returns the data reshaped to (nbatch, batch_size)."""
41 |     nbatch = data.size(0) // bsz
42 |     data = data.narrow(0, 0, nbatch * bsz)  # trim off tokens that do not fit evenly
43 |     data = data.view(bsz, -1).t().contiguous()  # split the data into bsz parallel streams
44 |     return data
45 | 
46 | 
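# Example for batchify(): with a corpus of 13,007 tokens and bsz=10, the last 7
# tokens are dropped and the result is a (1300, 10) tensor; column b holds the
# b-th contiguous stream of the corpus, so consecutive rows are consecutive tokens.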
47 | def get_batch(source, i, seq_len, evaluation=False):
48 |     """
49 |     Fetches one batch.
50 |     data: (seq_len, batch_size)
51 |     target: (seq_len * batch_size)
52 |     """
53 |     seq_len = min(seq_len, len(source) - 1 - i)
54 |     data = Variable(source[i:(i + seq_len)], volatile=evaluation)
55 |     target = Variable(source[(i + 1):(i + 1 + seq_len)].view(-1))  # flattened, for convenience in the loss computation
56 |     if use_cuda:
57 |         data, target = data.cuda(), target.cuda()
58 |     return data, target
59 | 
60 | 
61 | def repackage_hidden(h):
62 |     """Wraps the hidden state in a new Variable, detaching it from its history."""
63 |     return Variable(h.data)
64 | 
65 | 
66 | def get_time_dif(start_time):
67 |     """Returns the elapsed time since start_time."""
68 |     end_time = time.time()
69 |     time_dif = end_time - start_time
70 |     return timedelta(seconds=int(round(time_dif)))
71 | 
72 | 
73 | def generate(model, idx2word, word_len=200, temperature=1.0):
74 |     """Generates word_len characters; temperature combined with multinomial sampling adds diversity."""
75 |     model.eval()
76 |     hidden = model.init_hidden(1)  # batch_size is 1
77 |     inputs = Variable(torch.rand(1, 1).mul(len(idx2word)).long())  # pick a random character to start with
78 |     if use_cuda:
79 |         inputs = inputs.cuda()
80 | 
81 |     word_list = []
82 |     for i in range(word_len):  # generate one character at a time
83 |         output, hidden = model(inputs, hidden)
84 |         word_weights = output.squeeze().data.div(temperature).exp().cpu()
85 | 
86 |         # Sample from the character weights instead of taking the argmax; without this, generation tends to loop on the most frequent characters.
87 |         word_idx = torch.multinomial(word_weights, 1)[0]
88 |         inputs.data.fill_(word_idx)  # feed the newly generated character back in as the next input
89 |         word = idx2word[word_idx]
90 |         word_list.append(word)
91 |     return word_list
92 | 
93 | 
94 | def train():
95 |     model.train()  # dropout is only active in training mode
96 |     total_loss = 0.0
97 |     start_time = time.time()
98 |     hidden = model.init_hidden(config.batch_size)  # initialize the hidden state
99 |     # print('hidden: ' + str(hidden))
100 | 
101 |     for ibatch, i in enumerate(range(0, train_len - 1, seq_len)):
102 |         data, targets = get_batch(train_data, i, seq_len)  # fetch one batch
103 |         # Before each batch, detach the hidden state from how it was previously produced.
104 |         # Otherwise, the model would try to backpropagate all the way to the start of the dataset.
105 |         hidden = repackage_hidden(hidden)
106 |         model.zero_grad()
107 |         optimizer.zero_grad()
108 | 
109 |         output, hidden = model(data, hidden)
110 |         loss = criterion(output.view(-1, config.vocab_size), targets)
111 |         loss.backward()  # backpropagate
112 |         optimizer.step()
113 | 
114 |         # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs/LSTMs (note: if enabled, it must be called after backward() and before optimizer.step()).
115 |         # torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
116 |         # for p in model.parameters():  # manual SGD update (unused; Adam is used instead)
117 |         #     p.data.add_(-lr, p.grad.data)
118 | 
119 |         total_loss += loss.item()  # accumulate the loss
120 | 
121 |         if ibatch % config.log_interval == 0 and ibatch > 0:  # log status every log_interval batches
122 |             cur_loss = total_loss / config.log_interval
123 |             elapsed = get_time_dif(start_time)
124 |             print("Epoch {:3d}, {:5d}/{:5d} batches, lr {:2.3f}, loss {:5.2f}, ppl {:8.2f}, time {}".format(
125 |                 epoch, ibatch, train_len // seq_len, lr, cur_loss, math.exp(cur_loss), elapsed))
126 |             total_loss = 0.0
127 |             start_time = time.time()
128 | 
129 |     return loss.item()
130 | 
131 | 
132 | def generate_flow(epoch=3):
133 |     """Loads a stored checkpoint and generates new text."""
134 |     corpus = Corpus(train_dir)
135 |     config = Config()
136 |     config.vocab_size = len(corpus.dictionary)
137 | 
138 |     model = RNNModel(config)
139 |     model_file = os.path.join(save_dir, model_name.format(epoch))
140 |     assert os.path.exists(model_file), 'File %s does not exist.' % model_file
141 |     model.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))
142 | 
143 |     word_list = generate(model, corpus.dictionary.idx2word, word_len=50)
144 |     print(''.join(word_list))
145 | 
146 | 
147 | if __name__ == '__main__':
148 |     train_dir = 'data/santi.txt'
149 |     filename = str(os.path.basename(train_dir).split('.')[0])
150 | 
151 |     # directory for saving model checkpoints
152 |     save_dir = 'checkpoints/' + filename
153 |     if not os.path.exists(save_dir):
154 |         os.makedirs(save_dir)
155 |     model_name = filename + '_{}.pt'
156 | 
157 |     use_cuda = torch.cuda.is_available()
158 | 
159 |     # load the data and configure the model
160 |     print("Loading data...")
161 |     corpus = Corpus(train_dir)
162 |     print(corpus)
163 | 
164 |     config = Config()
165 |     config.vocab_size = len(corpus.dictionary)
166 |     train_data = batchify(corpus.train, config.batch_size)
167 |     train_len = train_data.size(0)
168 |     seq_len = config.seq_len
169 | 
170 |     print("Configuring model...")
171 |     model = RNNModel(config)
172 |     if use_cuda:
173 |         model.cuda()
174 |     print(model)
175 | 
176 |     criterion = nn.CrossEntropyLoss()
177 |     lr = config.learning_rate  # initial learning rate
178 |     best_train_loss = None
179 |     optimizer = torch.optim.Adam(model.parameters())
180 | 
181 |     print("Training and generating...")
182 |     try:
183 |         for epoch in range(1, config.num_epochs + 1):  # train for multiple epochs
184 |             epoch_start_time = time.time()
185 |             train_loss = train()
186 | 
187 |             print('-' * 89)
188 |             print('| end of epoch {:3d} | time: {:5.2f}s | train loss {:5.2f} | '
189 |                   'train ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
190 |                                              train_loss, math.exp(train_loss)))
191 |             print('-' * 89)
192 |             # Track the best training loss seen so far (there is no validation split).
193 |             if not best_train_loss or train_loss < best_train_loss:
194 |                 best_train_loss = train_loss
195 |             else:
196 |                 # Anneal the learning rate if the training loss stops improving (note: with Adam above, `lr` only affects the log line; the optimizer keeps its own rate).
197 |                 lr /= 2.0
198 | 
199 |             # save the model parameters every save_interval epochs
200 |             if epoch % config.save_interval == 0:
201 |                 torch.save(model.state_dict(), os.path.join(save_dir, model_name.format(epoch)))
202 | 
203 |             with torch.no_grad():
204 |                 print()
205 |                 print(''.join(generate(model, corpus.dictionary.idx2word)))
206 |                 print()
207 |     except KeyboardInterrupt:
208 |         print('-' * 89)
209 |         print('Exiting from training early')
210 | 
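Before the model definition, a hypothetical shape check for `RNNModel` (defined in `model.py` below); the vocabulary size 3719 is taken from the README's model printout, the rest from `Config`:

```python
# Hypothetical smoke test for RNNModel; assumes main.py and model.py are importable.
import torch

from main import Config
from model import RNNModel

config = Config()
config.vocab_size = 3719  # vocabulary size shown in the README's model printout

model = RNNModel(config)
inputs = torch.zeros(config.seq_len, config.batch_size, dtype=torch.long)  # (seq_len, batch_size) character ids
hidden = model.init_hidden(config.batch_size)  # zeros of shape (num_layers, batch_size, hidden_dim)

output, hidden = model(inputs, hidden)
print(output.shape)  # torch.Size([30, 10, 3719]): one score per vocabulary entry per position
```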
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | 
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 | 
7 | 
8 | class RNNModel(nn.Module):
9 |     """RNN-based language model: an encoder (embedding layer), an RNN module, and a decoder."""
10 | 
11 |     def __init__(self, config):
12 |         super(RNNModel, self).__init__()
13 | 
14 |         v_size = config.vocab_size
15 |         em_dim = config.embedding_dim
16 | 
17 |         rnn_type = config.rnn_type
18 |         hi_dim = config.hidden_dim
19 |         n_layers = config.num_layers
20 | 
21 |         dropout = config.dropout
22 |         tie_weights = config.tie_weights
23 | 
24 |         self.drop = nn.Dropout(dropout)  # dropout layer
25 |         self.encoder = nn.Embedding(v_size, em_dim)  # the encoder is an embedding layer
26 | 
27 |         print('rnn_type: ' + str(rnn_type))
28 |         if rnn_type in ['RNN', 'LSTM', 'GRU']:
29 |             self.rnn = getattr(nn, rnn_type)(em_dim, hi_dim, n_layers, dropout=dropout)
30 |         else:
31 |             raise ValueError("""'rnn_type' error, options are ['RNN', 'LSTM', 'GRU']""")
32 | 
33 |         self.decoder = nn.Linear(hi_dim, v_size)  # the decoder maps hidden vectors back to the vocabulary
34 | 
35 |         # tie_weights shares one weight matrix between the encoder and the decoder, as shown to help in:
36 |         # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
37 |         # https://arxiv.org/abs/1608.05859
38 |         # and
39 |         # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
40 |         # https://arxiv.org/abs/1611.01462
41 |         if tie_weights:
42 |             if hi_dim != em_dim:  # the two dimensions must match
43 |                 raise ValueError('When using the tied flag, hi_dim must be equal to em_dim')
44 |             self.decoder.weight = self.encoder.weight
45 | 
46 |         self.init_weights()  # initialize the weights
47 | 
48 |         self.rnn_type = rnn_type
49 |         self.hi_dim = hi_dim
50 |         self.n_layers = n_layers
51 | 
52 |     def forward(self, inputs, hidden):
53 |         emb = self.drop(self.encoder(inputs))  # encoder + dropout
54 |         output, hidden = self.rnn(emb, hidden)  # output shape: (seq_len, batch_size, hidden_dim)
55 |         decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2)))  # flatten, then project
56 |         return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden  # reshape back
57 | 
58 |     def init_weights(self):
59 |         """Initializes the weights; with tie_weights, the encoder and decoder share the same matrix."""
60 |         init_range = 0.1
61 |         self.encoder.weight.data.uniform_(-init_range, init_range)
62 |         self.decoder.weight.data.uniform_(-init_range, init_range)
63 |         self.decoder.bias.data.fill_(0)
64 | 
65 |     def init_hidden(self, bsz):
66 |         """Initializes the hidden state for a given batch size."""
67 |         weight = next(self.parameters()).data
68 |         if self.rnn_type == 'LSTM':  # LSTM: (h0, c0)
69 |             return (Variable(weight.new(self.n_layers, bsz, self.hi_dim).zero_()),
70 |                     Variable(weight.new(self.n_layers, bsz, self.hi_dim).zero_()))
71 |         else:  # GRU and vanilla RNN: h0
72 |             return Variable(weight.new(self.n_layers, bsz, self.hi_dim).zero_())
73 | 
--------------------------------------------------------------------------------
/preproc.py:
--------------------------------------------------------------------------------
1 | from os import walk
2 | 
3 | from os.path import join
4 | 
5 | 
6 | def cleanse(content):
7 |     content = content.replace('\u3000', '')  # strip ideographic (full-width) spaces
8 |     return content
9 | 
10 | 
11 | def load_file(folder):
12 |     paths = [join(dirpath, name)
13 |              for dirpath, dirs, files in walk(folder)
14 |              for name in files
15 |              if not name.startswith('.')]
16 |     concat = ''
17 |     for path in paths:
18 |         with open(path, 'r', encoding='utf-8') as myfile:
19 |             # print(path)
20 |             content = myfile.read()
21 |             concat += cleanse(content)
22 | 
23 |     return concat
24 | 
25 | 
26 | if __name__ == '__main__':
27 |     folder = 'data/《刘慈欣作品全集》(v1.0)'
28 |     concat = load_file(folder)
29 |     print(concat[20000:20000 + 100])
30 |     print("Full text length %d" % len(concat))
31 | 
32 |     with open('data/liucixin.txt', 'w', encoding='utf-8') as file:
33 |         file.write(concat)
34 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | 
4 | from data_zh import *
5 | 
6 | train_dir = 'data/liucixin.txt'
7 | corpus = Corpus(train_dir)
8 | print("刘慈欣作品全集:", corpus)
9 | 
--------------------------------------------------------------------------------
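
A closing note on the sampling used by `generate()` in `main.py`: the decoder's raw scores are divided by the temperature and exponentiated, and `torch.multinomial` then draws from the resulting (unnormalized) weights, so lower temperatures concentrate probability mass on the highest-scoring characters. A small standalone illustration:

```python
# Standalone illustration of the temperature sampling in generate() (main.py).
import torch

logits = torch.tensor([2.0, 1.0, 0.1])  # toy scores for a 3-character vocabulary
for t in (0.5, 1.0, 2.0):
    weights = logits.div(t).exp()       # the same transform generate() applies
    print(t, weights / weights.sum())   # lower t gives a peakier distribution
# torch.multinomial(weights, 1) samples one character index from these weights.
```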