├── README.md
└── poetry_gen.py

/README.md:
--------------------------------------------------------------------------------
# RNN_poetry_generator

Generating classical Chinese poetry with an RNN

### Environment

- Python 3.6
- TensorFlow 1.2.0

### Usage

- Training:

        python poetry_gen.py --mode train

- Sampling:

        python poetry_gen.py
        # or: python poetry_gen.py --mode sample

- Generating an acrostic poem (藏头诗):

        python poetry_gen.py --mode sample --head 明月别枝惊鹊

> 生成藏头诗 ---> 明月别枝惊鹊

> 明年襟宠任,月出画床帘。别有平州伯性悔,枝边折得李桑迷。惊腰每异年三杰,鹊出交钟玉笛频。

### Help

    python poetry_gen.py --help

    usage: poetry_gen.py [-h] [--mode MODE] [--head HEAD]

    optional arguments:
      -h, --help   show this help message and exit
      --mode MODE  usage: train or sample, sample is default
      --head HEAD  head characters for an acrostic poem

--------------------------------------------------------------------------------
/poetry_gen.py:
--------------------------------------------------------------------------------
# coding:utf-8

import argparse
import collections
import os
import sys

import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import tensorflow.contrib.legacy_seq2seq as seq2seq

BEGIN_CHAR = '^'
END_CHAR = '$'
UNKNOWN_CHAR = '*'
MAX_LENGTH = 100
MIN_LENGTH = 10
max_words = 3000
epochs = 50
poetry_file = 'poetry.txt'
save_dir = 'log'


class Data:
    def __init__(self):
        self.batch_size = 64
        self.poetry_file = poetry_file
        self.load()
        self.create_batches()

    def load(self):
        def handle(line):
            # Truncate over-long poems at the last '。' before MAX_LENGTH.
            if len(line) > MAX_LENGTH:
                index_end = line.rfind('。', 0, MAX_LENGTH)
                index_end = index_end if index_end > 0 else MAX_LENGTH
                line = line[:index_end + 1]
            return BEGIN_CHAR + line + END_CHAR

        with open(self.poetry_file, encoding='utf-8') as f:
            self.poetrys = [line.strip().replace(' ', '').split(':')[1] for line in f]
        self.poetrys = [handle(line) for line in self.poetrys if len(line) > MIN_LENGTH]

        # Collect every character in the corpus.
        words = []
        for poetry in self.poetrys:
            words += [word for word in poetry]
        counter = collections.Counter(words)
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        words, _ = zip(*count_pairs)

        # Keep the most frequent characters as the vocabulary; any character
        # outside the vocabulary is replaced with '*'.
        words_size = min(max_words, len(words))
        self.words = words[:words_size] + (UNKNOWN_CHAR,)
        self.words_size = len(self.words)

        # Map characters to ids and back.
        self.char2id_dict = {w: i for i, w in enumerate(self.words)}
        self.id2char_dict = {i: w for i, w in enumerate(self.words)}
        self.unknown_char = self.char2id_dict.get(UNKNOWN_CHAR)
        self.char2id = lambda char: self.char2id_dict.get(char, self.unknown_char)
        self.id2char = lambda num: self.id2char_dict.get(num)
        self.poetrys = sorted(self.poetrys, key=lambda line: len(line))
        self.poetrys_vector = [list(map(self.char2id, poetry)) for poetry in self.poetrys]

    def create_batches(self):
        self.n_size = len(self.poetrys_vector) // self.batch_size
        self.poetrys_vector = self.poetrys_vector[:self.n_size * self.batch_size]
        self.x_batches = []
        self.y_batches = []
        for i in range(self.n_size):
            batches = self.poetrys_vector[i * self.batch_size: (i + 1) * self.batch_size]
            length = max(map(len, batches))
            # Pad every poem in the batch to the length of the longest one.
            for row in range(self.batch_size):
                if len(batches[row]) < length:
                    r = length - len(batches[row])
                    batches[row] += [self.unknown_char] * r
            xdata = np.array(batches)
            ydata = np.copy(xdata)
            # Targets are the inputs shifted left by one character.
            ydata[:, :-1] = xdata[:, 1:]
            self.x_batches.append(xdata)
            self.y_batches.append(ydata)
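
# A minimal, self-contained sketch (not part of the original pipeline) of the
# input/target shift performed in Data.create_batches() above: each target is
# simply the id of the next character, with the last position repeated.
#
#   >>> import numpy as np
#   >>> x = np.array([[4, 8, 15, 16]])
#   >>> y = np.copy(x)
#   >>> y[:, :-1] = x[:, 1:]
#   >>> y
#   array([[ 8, 15, 16, 16]])
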
class Model:
    def __init__(self, data, model='lstm', infer=False):
        self.rnn_size = 128
        self.n_layers = 2

        if infer:
            self.batch_size = 1
        else:
            self.batch_size = data.batch_size

        if model == 'rnn':
            cell_rnn = rnn.BasicRNNCell
        elif model == 'gru':
            cell_rnn = rnn.GRUCell
        elif model == 'lstm':
            cell_rnn = rnn.BasicLSTMCell
        else:
            raise ValueError('unknown model type: {}'.format(model))

        # Build one cell object per layer; reusing a single cell instance in
        # MultiRNNCell would share weights across layers and raises an error
        # in recent TensorFlow versions.
        cells = [cell_rnn(self.rnn_size, state_is_tuple=False)
                 for _ in range(self.n_layers)]
        self.cell = rnn.MultiRNNCell(cells, state_is_tuple=False)

        self.x_tf = tf.placeholder(tf.int32, [self.batch_size, None])
        self.y_tf = tf.placeholder(tf.int32, [self.batch_size, None])

        self.initial_state = self.cell.zero_state(self.batch_size, tf.float32)

        with tf.variable_scope('rnnlm'):
            softmax_w = tf.get_variable("softmax_w", [self.rnn_size, data.words_size])
            softmax_b = tf.get_variable("softmax_b", [data.words_size])
            with tf.device("/cpu:0"):
                embedding = tf.get_variable(
                    "embedding", [data.words_size, self.rnn_size])
                inputs = tf.nn.embedding_lookup(embedding, self.x_tf)

        outputs, final_state = tf.nn.dynamic_rnn(
            self.cell, inputs, initial_state=self.initial_state, scope='rnnlm')

        # [batch, time, rnn_size] -> [batch * time, rnn_size]
        self.output = tf.reshape(outputs, [-1, self.rnn_size])
        self.logits = tf.matmul(self.output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)
        self.final_state = final_state
        pred = tf.reshape(self.y_tf, [-1])
        # Per-position cross-entropy with uniform weights.
        loss = seq2seq.sequence_loss_by_example([self.logits],
                                                [pred],
                                                [tf.ones_like(pred, dtype=tf.float32)])

        self.cost = tf.reduce_mean(loss)
        self.learning_rate = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), 5)

        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))


def train(data, model):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        # Resume from the latest checkpoint if one exists; on a fresh run
        # latest_checkpoint() returns None and restoring would crash.
        model_file = tf.train.latest_checkpoint(save_dir)
        if model_file:
            saver.restore(sess, model_file)
        n = 0
        for epoch in range(epochs):
            # Decay the learning rate by 3% per epoch.
            sess.run(tf.assign(model.learning_rate, 0.002 * (0.97 ** epoch)))
            for batch in range(data.n_size):
                n += 1
                feed_dict = {model.x_tf: data.x_batches[batch],
                             model.y_tf: data.y_batches[batch]}
                train_loss, _, _ = sess.run([model.cost, model.final_state, model.train_op],
                                            feed_dict=feed_dict)
                sys.stdout.write('\r')
                info = "{}/{} (epoch {}) | train_loss {:.3f}" \
                    .format(epoch * data.n_size + batch,
                            epochs * data.n_size, epoch, train_loss)
                sys.stdout.write(info)
                sys.stdout.flush()
                # Save a checkpoint every 1000 steps and at the very end.
                if (epoch * data.n_size + batch) % 1000 == 0 \
                        or (epoch == epochs - 1 and batch == data.n_size - 1):
                    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=n)
                    sys.stdout.write('\n')
                    print("model saved to {}".format(checkpoint_path))
        sys.stdout.write('\n')
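
# A minimal sketch (not part of the original script) of the sampling rule used
# by to_word() in sample() below: taking the first index whose cumulative
# probability exceeds a uniform random draw samples from the distribution,
# i.e. it is equivalent to np.random.choice(len(p), p=p / p.sum()).
#
#   >>> import numpy as np
#   >>> p = np.array([0.1, 0.6, 0.3])
#   >>> idx = int(np.searchsorted(np.cumsum(p), np.random.rand() * p.sum()))
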
def sample(data, model, head=u''):
    def to_word(weights):
        # Draw a character id from the model's output distribution.
        t = np.cumsum(weights)
        s = np.sum(weights)
        sa = int(np.searchsorted(t, np.random.rand(1) * s))
        return data.id2char(sa)

    for word in head:
        if word not in data.words:
            return u'{} 不在字典中'.format(word)  # character is not in the vocabulary

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables())
        model_file = tf.train.latest_checkpoint(save_dir)
        saver.restore(sess, model_file)

        if head:
            # Acrostic mode: each given character opens a clause, and the
            # model completes the clause up to the next ',' or '。'.
            print('生成藏头诗 ---> ', head)
            poem = BEGIN_CHAR
            for head_word in head:
                poem += head_word
                x = np.array([list(map(data.char2id, poem))])
                state = sess.run(model.cell.zero_state(1, tf.float32))
                feed_dict = {model.x_tf: x, model.initial_state: state}
                [probs, state] = sess.run([model.probs, model.final_state], feed_dict)
                word = to_word(probs[-1])
                while word != u',' and word != u'。':
                    poem += word
                    x = np.zeros((1, 1))
                    x[0, 0] = data.char2id(word)
                    [probs, state] = sess.run([model.probs, model.final_state],
                                              {model.x_tf: x, model.initial_state: state})
                    word = to_word(probs[-1])
                poem += word
            return poem[1:]
        else:
            # Free mode: start from the begin marker and sample one character
            # at a time until the end marker is produced.
            poem = ''
            head = BEGIN_CHAR
            x = np.array([list(map(data.char2id, head))])
            state = sess.run(model.cell.zero_state(1, tf.float32))
            feed_dict = {model.x_tf: x, model.initial_state: state}
            [probs, state] = sess.run([model.probs, model.final_state], feed_dict)
            word = to_word(probs[-1])
            while word != END_CHAR:
                poem += word
                x = np.zeros((1, 1))
                x[0, 0] = data.char2id(word)
                [probs, state] = sess.run([model.probs, model.final_state],
                                          {model.x_tf: x, model.initial_state: state})
                word = to_word(probs[-1])
            return poem


def main():
    msg = u"""
    Usage:
    Training:
        python poetry_gen.py --mode train
    Sampling:
        python poetry_gen.py --mode sample --head 明月别枝惊鹊
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='sample',
                        help=u'usage: train or sample, sample is default')
    parser.add_argument('--head', type=str, default='',
                        help=u'head characters for an acrostic poem')

    args = parser.parse_args()

    if args.mode == 'sample':
        data = Data()
        model = Model(data=data, infer=True)
        print(sample(data, model, head=args.head))
    elif args.mode == 'train':
        data = Data()
        model = Model(data=data, infer=False)
        train(data, model)
    else:
        print(msg)


if __name__ == '__main__':
    main()
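
# A minimal sketch (an assumption, not from the original README) of using the
# module programmatically instead of through the CLI; it mirrors main() and
# requires a trained checkpoint in the `log` directory. The head characters
# here ('春夏秋冬') are only an illustrative example:
#
#   from poetry_gen import Data, Model, sample
#   data = Data()
#   model = Model(data=data, infer=True)
#   print(sample(data, model, head=u'春夏秋冬'))
--------------------------------------------------------------------------------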