├── README.md
└── poetry_gen.py
/README.md:
--------------------------------------------------------------------------------
# RNN_poetry_generator

Generating classical Chinese poetry with an RNN.

### Environment

- python 3.6
- tensorflow 1.2.0

### Usage

- Training:

        python poetry_gen.py --mode train

- Sampling:

        python poetry_gen.py    # or: python poetry_gen.py --mode sample

- Generating an acrostic poem:

        python poetry_gen.py --mode sample --head 明月别枝惊鹊

> Generating acrostic poem ---> 明月别枝惊鹊

> 明年襟宠任,月出画床帘。别有平州伯性悔,枝边折得李桑迷。惊腰每异年三杰,鹊出交钟玉笛频。

### Help

    python poetry_gen.py --help

    usage: poetry_gen.py [-h] [--mode MODE] [--head HEAD]

    optional arguments:
      -h, --help   show this help message and exit
      --mode MODE  usage: train or sample, sample is default
      --head HEAD  generate an acrostic poem
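### Data

Training reads `poetry.txt` (UTF-8) from the working directory. Judging from `Data.load`, each line is expected to be `title:poem`, and only the part after the colon is used; a purely illustrative line (not from any bundled file) would be:

    静夜思:床前明月光,疑是地上霜。举头望明月,低头思故乡。

Poems of 10 characters or fewer are dropped, poems longer than 100 characters are truncated at the last `。` before that limit, and only the 3000 most frequent characters enter the vocabulary; all other characters map to `*`.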
--------------------------------------------------------------------------------
/poetry_gen.py:
--------------------------------------------------------------------------------
# coding:utf-8

import argparse
import collections
import os
import sys

import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
import tensorflow.contrib.legacy_seq2seq as seq2seq

BEGIN_CHAR = '^'
END_CHAR = '$'
UNKNOWN_CHAR = '*'
MAX_LENGTH = 100
MIN_LENGTH = 10
max_words = 3000
epochs = 50
poetry_file = 'poetry.txt'
save_dir = 'log'


class Data:
    def __init__(self):
        self.batch_size = 64
        self.poetry_file = poetry_file
        self.load()
        self.create_batches()

    def load(self):
        def handle(line):
            # Truncate over-long poems at the last '。' inside the limit,
            # then wrap the poem in begin/end markers.
            if len(line) > MAX_LENGTH:
                index_end = line.rfind('。', 0, MAX_LENGTH)
                index_end = index_end if index_end > 0 else MAX_LENGTH
                line = line[:index_end + 1]
            return BEGIN_CHAR + line + END_CHAR

        # Each line of poetry.txt is 'title:poem'; keep only the poem text.
        self.poetrys = [line.strip().replace(' ', '').split(':')[1] for line in
                        open(self.poetry_file, encoding='utf-8')]
        self.poetrys = [handle(line) for line in self.poetrys if len(line) > MIN_LENGTH]

        # Collect every character in the corpus, most frequent first.
        words = []
        for poetry in self.poetrys:
            words += [word for word in poetry]
        counter = collections.Counter(words)
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        words, _ = zip(*count_pairs)

        # Keep at most max_words characters as the vocabulary; characters
        # outside it are replaced by '*'.
        words_size = min(max_words, len(words))
        self.words = words[:words_size] + (UNKNOWN_CHAR,)
        self.words_size = len(self.words)

        # Map characters to ids and back.
        self.char2id_dict = {w: i for i, w in enumerate(self.words)}
        self.id2char_dict = {i: w for i, w in enumerate(self.words)}
        self.unknow_char = self.char2id_dict.get(UNKNOWN_CHAR)
        self.char2id = lambda char: self.char2id_dict.get(char, self.unknow_char)
        self.id2char = lambda num: self.id2char_dict.get(num)
        # Sort by length so each batch holds poems of similar length.
        self.poetrys = sorted(self.poetrys, key=lambda line: len(line))
        self.poetrys_vector = [list(map(self.char2id, poetry)) for poetry in self.poetrys]
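    # Example round trip (the ids are hypothetical -- real values depend on
    # character frequencies in poetry.txt):
    #
    #   data.char2id('月')   # -> e.g. 11
    #   data.id2char(11)     # -> '月'
    #   # any out-of-vocabulary character maps to the id of '*'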

    def create_batches(self):
        # Drop the tail so the corpus divides evenly into batches.
        self.n_size = len(self.poetrys_vector) // self.batch_size
        self.poetrys_vector = self.poetrys_vector[:self.n_size * self.batch_size]
        self.x_batches = []
        self.y_batches = []
        for i in range(self.n_size):
            batches = self.poetrys_vector[i * self.batch_size: (i + 1) * self.batch_size]
            length = max(map(len, batches))
            # Pad every poem up to the longest poem in this batch.
            for row in range(self.batch_size):
                if len(batches[row]) < length:
                    r = length - len(batches[row])
                    batches[row][len(batches[row]): length] = [self.unknow_char] * r
            xdata = np.array(batches)
            ydata = np.copy(xdata)
            # Targets are the inputs shifted left by one character.
            ydata[:, :-1] = xdata[:, 1:]
            self.x_batches.append(xdata)
            self.y_batches.append(ydata)
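    # A toy (x, y) row to make the shift concrete (characters shown instead
    # of their ids):
    #
    #   x: ^ 床 前 明 月 $
    #   y: 床 前 明 月 $ $     (x shifted left; the final id is repeated)
    #
    # so at every position the model learns to predict the next character.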


class Model:
    def __init__(self, data, model='lstm', infer=False):
        self.rnn_size = 128
        self.n_layers = 2

        # Sampling feeds one sequence at a time.
        if infer:
            self.batch_size = 1
        else:
            self.batch_size = data.batch_size

        if model == 'rnn':
            cell_rnn = rnn.BasicRNNCell
        elif model == 'gru':
            cell_rnn = rnn.GRUCell
        elif model == 'lstm':
            cell_rnn = rnn.BasicLSTMCell
        else:
            raise ValueError('unknown model type: {}'.format(model))

        # Build one cell per layer; reusing a single cell instance across
        # layers would tie their weights in TensorFlow >= 1.1.
        cells = [cell_rnn(self.rnn_size, state_is_tuple=False)
                 for _ in range(self.n_layers)]
        self.cell = rnn.MultiRNNCell(cells, state_is_tuple=False)

        self.x_tf = tf.placeholder(tf.int32, [self.batch_size, None])
        self.y_tf = tf.placeholder(tf.int32, [self.batch_size, None])

        self.initial_state = self.cell.zero_state(self.batch_size, tf.float32)

        with tf.variable_scope('rnnlm'):
            softmax_w = tf.get_variable("softmax_w", [self.rnn_size, data.words_size])
            softmax_b = tf.get_variable("softmax_b", [data.words_size])
            with tf.device("/cpu:0"):
                embedding = tf.get_variable(
                    "embedding", [data.words_size, self.rnn_size])
                inputs = tf.nn.embedding_lookup(embedding, self.x_tf)

        outputs, final_state = tf.nn.dynamic_rnn(
            self.cell, inputs, initial_state=self.initial_state, scope='rnnlm')

        self.output = tf.reshape(outputs, [-1, self.rnn_size])
        self.logits = tf.matmul(self.output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)
        self.final_state = final_state
        targets = tf.reshape(self.y_tf, [-1])
        # Average cross-entropy over every time step in the batch.
        loss = seq2seq.sequence_loss_by_example(
            [self.logits],
            [targets],
            [tf.ones_like(targets, dtype=tf.float32)])

        self.cost = tf.reduce_mean(loss)
        self.learning_rate = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        # Clip gradients to a global norm of 5 to avoid exploding gradients.
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), 5)

        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))
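# Tensor shapes through the model, with rnn_size = 128 and V = data.words_size:
#
#   x_tf     [batch, time]         character ids
#   inputs   [batch, time, 128]    embedded characters
#   outputs  [batch, time, 128]    dynamic_rnn hidden states
#   output   [batch * time, 128]   flattened for the softmax projection
#   logits   [batch * time, V]
#   probs    [batch * time, V]     per-step next-character distribution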


def train(data, model):
    # Make sure the checkpoint directory exists before saving into it
    # (assumption: the repo may not ship with a log/ directory).
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        # Resume from the latest checkpoint if one exists; on a fresh run
        # latest_checkpoint returns None and training starts from scratch.
        model_file = tf.train.latest_checkpoint(save_dir)
        if model_file:
            saver.restore(sess, model_file)
        n = 0
        for epoch in range(epochs):
            # Learning rate decays exponentially: 0.002 * 0.97 ** epoch
            # (0.002 at epoch 0, ~0.0015 at epoch 10, ~0.00045 at epoch 49).
            sess.run(tf.assign(model.learning_rate, 0.002 * (0.97 ** epoch)))
            for batch in range(data.n_size):
                n += 1
                feed_dict = {model.x_tf: data.x_batches[batch],
                             model.y_tf: data.y_batches[batch]}
                train_loss, _, _ = sess.run([model.cost, model.final_state, model.train_op],
                                            feed_dict=feed_dict)
                sys.stdout.write('\r')
                info = "{}/{} (epoch {}) | train_loss {:.3f}" \
                    .format(epoch * data.n_size + batch,
                            epochs * data.n_size, epoch, train_loss)
                sys.stdout.write(info)
                sys.stdout.flush()
                # Checkpoint every 1000 steps and at the end of training.
                if (epoch * data.n_size + batch) % 1000 == 0 \
                        or (epoch == epochs - 1 and batch == data.n_size - 1):
                    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=n)
                    sys.stdout.write('\n')
                    print("model saved to {}".format(checkpoint_path))
        sys.stdout.write('\n')


def sample(data, model, head=u''):
    def to_word(weights):
        # Sample an id from the predicted distribution by inverting the
        # cumulative sum at a uniform random point (roulette-wheel sampling).
        t = np.cumsum(weights)
        s = np.sum(weights)
        sa = int(np.searchsorted(t, np.random.rand(1) * s))
        return data.id2char(sa)
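    # Numeric sketch of to_word with made-up weights:
    #   weights = [0.1, 0.6, 0.3]  ->  t = [0.1, 0.7, 1.0]
    #   a uniform draw of 0.65 lands in [0.1, 0.7), so index 1 is returned;
    #   higher-probability characters are chosen proportionally more often.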

    for word in head:
        if word not in data.words:
            return u'{} is not in the vocabulary'.format(word)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables())
        model_file = tf.train.latest_checkpoint(save_dir)
        saver.restore(sess, model_file)

        if head:
            print('Generating acrostic poem ---> ', head)
            poem = BEGIN_CHAR
            for head_word in head:
                # Append the fixed head character, then let the model extend
                # the line until it produces a comma or a full stop.
                poem += head_word
                x = np.array([list(map(data.char2id, poem))])
                state = sess.run(model.cell.zero_state(1, tf.float32))
                feed_dict = {model.x_tf: x, model.initial_state: state}
                [probs, state] = sess.run([model.probs, model.final_state], feed_dict)
                word = to_word(probs[-1])
                while word != u',' and word != u'。':
                    poem += word
                    x = np.zeros((1, 1))
                    x[0, 0] = data.char2id(word)
                    [probs, state] = sess.run([model.probs, model.final_state],
                                              {model.x_tf: x, model.initial_state: state})
                    word = to_word(probs[-1])
                poem += word
            return poem[1:]
        else:
            # Free generation: start from the begin marker and sample until
            # the end marker is produced.
            poem = ''
            head = BEGIN_CHAR
            x = np.array([list(map(data.char2id, head))])
            state = sess.run(model.cell.zero_state(1, tf.float32))
            feed_dict = {model.x_tf: x, model.initial_state: state}
            [probs, state] = sess.run([model.probs, model.final_state], feed_dict)
            word = to_word(probs[-1])
            while word != END_CHAR:
                poem += word
                x = np.zeros((1, 1))
                x[0, 0] = data.char2id(word)
                [probs, state] = sess.run([model.probs, model.final_state],
                                          {model.x_tf: x, model.initial_state: state})
                word = to_word(probs[-1])
            return poem


def main():
    msg = u"""
Usage:
  Training:
    python poetry_gen.py --mode train
  Sampling:
    python poetry_gen.py --mode sample --head 明月别枝惊鹊
"""
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='sample',
                        help=u'usage: train or sample, sample is default')
    parser.add_argument('--head', type=str, default='',
                        help=u'generate an acrostic poem')

    args = parser.parse_args()

    if args.mode == 'sample':
        data = Data()
        model = Model(data=data, infer=True)
        print(sample(data, model, head=args.head))
    elif args.mode == 'train':
        data = Data()
        model = Model(data=data, infer=False)
        train(data, model)
    else:
        print(msg)


if __name__ == '__main__':
    main()
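# A minimal interactive sketch (hypothetical usage, assuming a trained
# checkpoint already exists under log/):
#
#   >>> data = Data()
#   >>> model = Model(data=data, infer=True)
#   >>> print(sample(data, model, head=u'明月'))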
--------------------------------------------------------------------------------