├── README.md
├── char_rnn_model.py
├── config_poem.py
├── data
│   └── poem
│       ├── poem_ids.txt
│       ├── poems_edge_split.txt
│       ├── rhyme_words.txt
│       └── vectors_poem.bin
├── data_loader.py
├── doc
│   ├── client.png
│   └── train.png
├── poem_server.py
├── rhyme_helper.py
├── train.py
├── word2vec_helper.py
└── write_poem.py
/README.md:
--------------------------------------------------------------------------------
1 | # poet
2 | I am "小诗姬" (Little Poem Maiden), trained on the Complete Tang Poems. I can write rhymed free verse, acrostic poems (藏头诗), and poems themed on a given set of characters.
3 |
4 | Requirements:
5 | ---
6 | python: 3.x
7 | tensorflow: 1.x
8 |
9 |
10 | Running
11 | ---
12 | Train the model: python train.py
13 | ---
14 | ![training](doc/train.png)
15 |
16 |
17 |
18 |
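Note that `train.py` does not read hyper-parameters from the command line: its `__main__` block builds an argument string and passes it to `main()`, which hands it to `config_poem_train` in `config_poem.py` (where all available options are defined). To change settings such as the hidden size or cell type, edit that string, for example:

```python
# bottom of train.py (as shipped in this repo) -- edit the string to change hyper-parameters
args = '--output_dir output_poem --data_path ./data/poem/ --hidden_size 128 --embedding_size 128 --cell_type lstm'
main(args)
```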
19 | Run the poem-writing service: python poem_server.py
20 | ---
21 | This step requires a trained model, so run the training step first.
22 |
23 | Access the service from a browser as the client; the result looks like this:
24 | ---
25 | ![client](doc/client.png)
26 |
27 |
28 |
29 |
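The service exposes a single `/poem` endpoint (see `poem_server.py`) with two optional query parameters: `start` (the starting or hidden characters) and `style` (1: free verse, 2: rhymed free verse, 3: acrostic, 4: poem built around the given characters). Assuming Flask's default host and port (`app.run()` with no arguments), example requests look like:

```
http://127.0.0.1:5000/poem?style=2
http://127.0.0.1:5000/poem?style=3&start=春夏秋冬
```

The four characters in the second URL are only an example of acrostic heads; any characters in the model's vocabulary can be used.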
30 | Questions and discussion are welcome: xiuyunchen@126.com
31 |
32 |
--------------------------------------------------------------------------------
/char_rnn_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | from enum import Enum
4 | import heapq
5 | import numpy as np
6 | import tensorflow as tf
7 | from rhyme_helper import RhymeWords
8 |
9 | logging.getLogger('tensorflow').setLevel(logging.WARNING)
10 | SampleType = Enum('SampleType',('max_prob', 'weighted_sample', 'rhyme','select_given'))
11 |
12 | class CharRNNLM(object):
13 | def __init__(self, is_training, batch_size, num_unrollings, vocab_size,w2v_model,
14 | hidden_size, max_grad_norm, embedding_size, num_layers,
15 | learning_rate, cell_type, dropout=0.0, input_dropout=0.0, infer=False):
16 | self.batch_size = batch_size
17 | self.num_unrollings = num_unrollings
18 | if infer:
19 | self.batch_size = 1
20 | self.num_unrollings = 1
21 | self.hidden_size = hidden_size
22 | self.vocab_size = vocab_size
23 | self.max_grad_norm = max_grad_norm
24 | self.num_layers = num_layers
25 | self.embedding_size = embedding_size
26 | self.cell_type = cell_type
27 | self.dropout = dropout
28 | self.input_dropout = input_dropout
29 | self.w2v_model = w2v_model
30 |
31 | if embedding_size <= 0:
32 | self.input_size = vocab_size
33 | self.input_dropout = 0.0
34 | else:
35 | self.input_size = embedding_size
36 |
37 | self.input_data = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='inputs')
38 | self.targets = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='targets')
39 |
40 | if self.cell_type == 'rnn':
41 | cell_fn = tf.nn.rnn_cell.BasicRNNCell
42 | elif self.cell_type == 'lstm':
43 | cell_fn = tf.nn.rnn_cell.BasicLSTMCell
44 | elif self.cell_type == 'gru':
45 | cell_fn = tf.nn.rnn_cell.GRUCell
46 |
47 | params = dict()
48 | #params = {'input_size': self.input_size}
49 | if self.cell_type == 'lstm':
50 | params['forget_bias'] = 1.0 # 1.0 is default value
51 | cell = cell_fn(self.hidden_size, **params)
52 |
53 | cells = [cell]
54 | #params['input_size'] = self.hidden_size
55 | for i in range(self.num_layers-1):
56 | higher_layer_cell = cell_fn(self.hidden_size, **params)
57 | cells.append(higher_layer_cell)
58 |
59 | if is_training and self.dropout > 0:
60 | cells = [tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=1.0-self.dropout) for cell in cells]
61 |
62 | multi_cell = tf.nn.rnn_cell.MultiRNNCell(cells)
63 |
64 | with tf.name_scope('initial_state'):
65 | self.zero_state = multi_cell.zero_state(self.batch_size, tf.float32)
66 | if self.cell_type == 'rnn' or self.cell_type == 'gru':
67 | self.initial_state = tuple(
68 | [tf.placeholder(tf.float32,
69 | [self.batch_size, multi_cell.state_size[idx]],
70 | 'initial_state_'+str(idx+1)) for idx in range(self.num_layers)])
71 | elif self.cell_type == 'lstm':
72 | self.initial_state = tuple(
73 | [tf.nn.rnn_cell.LSTMStateTuple(
74 | tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][0]],
75 | 'initial_lstm_state_'+str(idx+1)),
76 | tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][1]],
77 | 'initial_lstm_state_'+str(idx+1)))
78 | for idx in range(self.num_layers)])
79 |
80 | with tf.name_scope('embedding_layer'):
81 | if embedding_size > 0:
82 | # self.embedding = tf.get_variable('embedding', [self.vocab_size, self.embedding_size])
83 | self.embedding = tf.get_variable("word_embeddings",
84 | initializer=self.w2v_model.vectors.astype(np.float32))
85 |
86 | else:
87 | self.embedding = tf.constant(np.eye(self.vocab_size), dtype=tf.float32)
88 |
89 | inputs = tf.nn.embedding_lookup(self.embedding, self.input_data)
90 | if is_training and self.input_dropout > 0:
91 | inputs = tf.nn.dropout(inputs, 1-self.input_dropout)
92 |
93 | with tf.name_scope('slice_inputs'):
94 | # num_unrollings * (batch_size, embedding_size), the format of rnn inputs.
95 | sliced_inputs = [tf.squeeze(input_, [1]) for input_ in tf.split(
96 | axis = 1, num_or_size_splits = self.num_unrollings, value = inputs)]
97 |
98 | # sliced_inputs: a list of num_unrollings tensors, each of shape [batch_size, input_size]
99 | # inputs: A length T list of inputs, each a Tensor of shape [batch_size, input_size]
100 | # initial_state: An initial state for the RNN.
101 | # If cell.state_size is an integer, this must be a Tensor of appropriate
102 | # type and shape [batch_size, cell.state_size]
103 | # outputs: a length T list of outputs (one for each input), or a nested tuple of such elements.
104 | # state: the final state
105 | outputs, final_state = tf.nn.static_rnn(
106 | cell = multi_cell,
107 | inputs = sliced_inputs,
108 | initial_state=self.initial_state)
109 | self.final_state = final_state
110 |
111 | with tf.name_scope('flatten_outputs'):
112 | flat_outputs = tf.reshape(tf.concat(axis = 1, values = outputs), [-1, hidden_size])
113 |
114 | with tf.name_scope('flatten_targets'):
115 | flat_targets = tf.reshape(tf.concat(axis = 1, values = self.targets), [-1])
116 |
117 | with tf.variable_scope('softmax') as sm_vs:
118 | softmax_w = tf.get_variable('softmax_w', [hidden_size, vocab_size])
119 | softmax_b = tf.get_variable('softmax_b', [vocab_size])
120 | self.logits = tf.matmul(flat_outputs, softmax_w) + softmax_b
121 | self.probs = tf.nn.softmax(self.logits)
122 |
123 | with tf.name_scope('loss'):
124 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
125 | logits = self.logits, labels = flat_targets)
126 | self.mean_loss = tf.reduce_mean(loss)
127 |
128 | with tf.name_scope('loss_monitor'):
129 | count = tf.Variable(1.0, name='count')
130 | sum_mean_loss = tf.Variable(1.0, name='sum_mean_loss')
131 |
132 | self.reset_loss_monitor = tf.group(sum_mean_loss.assign(0.0),
133 | count.assign(0.0), name='reset_loss_monitor')
134 | self.update_loss_monitor = tf.group(sum_mean_loss.assign(sum_mean_loss+self.mean_loss),
135 | count.assign(count+1), name='update_loss_monitor')
136 |
137 | with tf.control_dependencies([self.update_loss_monitor]):
138 | self.average_loss = sum_mean_loss / count
139 | self.ppl = tf.exp(self.average_loss)
140 |
141 | average_loss_summary = tf.summary.scalar(
142 | name = 'average loss', tensor = self.average_loss)
143 | ppl_summary = tf.summary.scalar(
144 | name = 'perplexity', tensor = self.ppl)
145 |
146 | self.summaries = tf.summary.merge(
147 | inputs = [average_loss_summary, ppl_summary], name='loss_monitor')
148 |
149 | self.global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0.0))
150 |
151 | # self.learning_rate = tf.constant(learning_rate)
152 | self.learning_rate = tf.placeholder(tf.float32, [], name='learning_rate')
153 |
154 | if is_training:
155 | tvars = tf.trainable_variables()
156 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.mean_loss, tvars), self.max_grad_norm)
157 | optimizer = tf.train.AdamOptimizer(self.learning_rate)
158 | self.train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step)
159 |
160 |
161 | def run_epoch(self, session, batch_generator, is_training, learning_rate, verbose=0, freq=10):
162 | epoch_size = batch_generator.num_batches
163 |
164 | if verbose > 0:
165 | logging.info('epoch_size: %d', epoch_size)
166 | logging.info('data_size: %d', batch_generator.seq_length)
167 | logging.info('num_unrollings: %d', self.num_unrollings)
168 | logging.info('batch_size: %d', self.batch_size)
169 |
170 | if is_training:
171 | extra_op = self.train_op
172 | else:
173 | extra_op = tf.no_op()
174 |
175 |
176 | if self.cell_type in ['rnn', 'gru']:
177 | state = self.zero_state.eval()
178 | else:
179 | state = tuple([(np.zeros((self.batch_size, self.hidden_size)),
180 | np.zeros((self.batch_size, self.hidden_size)))
181 | for _ in range(self.num_layers)])
182 |
183 | self.reset_loss_monitor.run()
184 | batch_generator.reset_batch_pointer()
185 | start_time = time.time()
186 | ppl_cumsum = 0
187 | for step in range(epoch_size):
188 | x, y = batch_generator.next_batch()
189 |
190 | ops = [self.average_loss, self.ppl, self.final_state, extra_op,
191 | self.summaries, self.global_step]
192 |
193 | feed_dict = {self.input_data: x, self.targets: y, self.initial_state: state,
194 | self.learning_rate: learning_rate}
195 |
196 | results = session.run(ops, feed_dict)
197 | average_loss, ppl, final_state, _, summary_str, global_step = results
198 | ppl_cumsum += ppl
199 |
200 | # if (verbose > 0) and ((step+1) % freq == 0):
201 | if ((step+1) % freq == 0):
202 | logging.info('%.1f%%, step:%d, perplexity: %.3f, speed: %.0f words',
203 | (step + 1) * 1.0 / epoch_size * 100, step, ppl_cumsum/(step+1),
204 | (step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time))
205 | logging.info("Perplexity: %.3f, speed: %.0f words per sec",
206 | ppl, (step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time))
207 |
208 | return ppl, summary_str, global_step
209 |
210 | def sample_seq(self, session, length, start_text, sample_type= SampleType.max_prob,given='',rhyme_ref='',rhyme_idx = 0):
211 | #state = self.zero_state.eval()
212 | if self.cell_type in ['rnn', 'gru']:
213 | state = self.zero_state.eval()
214 | else:
215 | state = tuple([(np.zeros((self.batch_size, self.hidden_size)),
216 | np.zeros((self.batch_size, self.hidden_size)))
217 | for _ in range(self.num_layers)])
218 |
219 | # use start_text to warm up the RNN.
220 | start_text = self.check_start(start_text)
221 | if start_text is not None and len(start_text) > 0:
222 | seq = list(start_text)
223 | for char in start_text[:-1]:
224 | x = np.array([[self.w2v_model.vocab_hash[char]]])
225 | state = session.run(self.final_state, {self.input_data: x, self.initial_state: state})
226 | x = np.array([[self.w2v_model.vocab_hash[start_text[-1]]]])
227 | else:
228 | x = np.array([[np.random.randint(0, self.vocab_size)]])
229 | seq = []
230 |
231 | for i in range(length):
232 | state, logits = session.run([self.final_state, self.logits],
233 | {self.input_data: x, self.initial_state: state})
234 | unnormalized_probs = np.exp(logits[0] - np.max(logits[0]))
235 | probs = unnormalized_probs / np.sum(unnormalized_probs)
236 |
237 | if rhyme_ref and i == rhyme_idx :
238 | sample = self.select_rhyme(rhyme_ref,probs)
239 | elif sample_type == SampleType.max_prob:
240 | sample = np.argmax(probs)
241 | elif sample_type == SampleType.select_given:
242 | sample,given = self.select_by_given(given,probs)
243 | else: #SampleType.weighted_sample
244 | sample = np.random.choice(self.vocab_size, 1, p=probs)[0]
245 |
246 | seq.append(self.w2v_model.vocab[sample])
247 | x = np.array([[sample]])
248 |
249 | return ''.join(seq)
250 |
251 | def select_by_given(self,given,probs,max_prob = False):
252 | if given:
253 | seq_probs = zip(probs,range(0,self.vocab_size))
254 | topn = heapq.nlargest(100,seq_probs,key=lambda sp :sp[0])
255 |
256 | for _,seq in topn:
257 | if self.w2v_model.vocab[seq] in given:
258 | given = given.replace(self.w2v_model.vocab[seq],'')
259 | return seq,given
260 | if max_prob:
261 | return np.argmax(probs),given
262 |
263 | return np.random.choice(self.vocab_size, 1, p=probs)[0],given
264 |
265 |
266 | def select_rhyme(self,rhyme_ref,probs):
267 | if rhyme_ref:
268 | rhyme_set = RhymeWords.get_rhyme_words(rhyme_ref)
269 | if rhyme_set:
270 | seq_probs = zip(probs,range(0,self.vocab_size))
271 | topn = heapq.nlargest(50,seq_probs,key=lambda sp :sp[0])
272 |
273 | for _,seq in topn:
274 | if self.w2v_model.vocab[seq] in rhyme_set:
275 | return seq
276 |
277 | return np.argmax(probs)
278 |
279 | def check_start(self,text):
280 | idx = text.find('<')
281 | if idx > -1:
282 | text = text[:idx]
283 |
284 | valid_text = []
285 | for w in text:
286 | if w in self.w2v_model.vocab:
287 | valid_text.append(w)
288 | return ''.join(valid_text)
289 |
--------------------------------------------------------------------------------
/config_poem.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 |
4 | def config_poem_train(args=''):
5 | parser = argparse.ArgumentParser()
6 |
7 | # Data and vocabulary file
8 | # parser.add_argument('--data_file', type=str,
9 | # default='../data/poem/poems_space.txt',
10 | # help='data file')
11 |
12 | parser.add_argument('--data_path', type=str,
13 | default='./data/poem/',
14 | help='data path')
15 |
16 |
17 | parser.add_argument('--encoding', type=str,
18 | default='utf-8',
19 | help='the encoding of the data file.')
20 |
21 | # Parameters for saving models.
22 | parser.add_argument('--output_dir', type=str, default='output_model',
23 | help=('directory to store final and'
24 | ' intermediate results and models.'))
25 | # Parameters for using saved best models.
26 | parser.add_argument('--init_dir', type=str, default='',
27 | help='continue from the outputs in the given directory')
28 |
29 | # Parameters to configure the neural network.
30 | parser.add_argument('--hidden_size', type=int, default=128,#128,
31 | help='size of RNN hidden state vector')
32 | parser.add_argument('--embedding_size', type=int, default=128,#0,
33 | help='size of character embeddings, 0 for one-hot')
34 | parser.add_argument('--num_layers', type=int, default=2,
35 | help='number of layers in the RNN')
36 | parser.add_argument('--num_unrollings', type=int, default=64,#10,
37 | help='number of unrolling steps.')
38 | parser.add_argument('--cell_type', type=str, default='lstm',
39 | help='which model to use (rnn, lstm or gru).')
40 |
41 | # Parameters to control the training.
42 | parser.add_argument('--num_epochs', type=int, default=5,
43 | help='number of epochs')
44 | parser.add_argument('--batch_size', type=int, default=16,
45 | help='minibatch size')
46 | parser.add_argument('--train_frac', type=float, default=0.9,
47 | help='fraction of data used for training.')
48 | parser.add_argument('--valid_frac', type=float, default=0.05,
49 | help='fraction of data used for validation.')
50 | # test_frac is computed as (1 - train_frac - valid_frac).
51 | parser.add_argument('--dropout', type=float, default=0.0,
52 | help='dropout rate, default to 0 (no dropout).')
53 |
54 | parser.add_argument('--input_dropout', type=float, default=0.0,
55 | help=('dropout rate on input layer, default to 0 (no dropout),'
56 | 'and no dropout if using one-hot representation.'))
57 |
58 | # Parameters for gradient descent.
59 | parser.add_argument('--max_grad_norm', type=float, default=5.,
60 | help='clip global grad norm')
61 | parser.add_argument('--learning_rate', type=float, default=5e-3,
62 | help='initial learning rate')
63 |
64 | # Parameters for logging.
65 | parser.add_argument('--progress_freq', type=int, default=100,
66 | help=('frequency for progress report in training and evaluation.'))
67 | parser.add_argument('--verbose', type=int, default=0,
68 | help=('whether to show progress report in training and evaluation.'))
69 |
70 | # Parameters to feed in the initial model and current best model.
71 | parser.add_argument('--init_model', type=str,
72 | default='', help=('initial model'))
73 | parser.add_argument('--best_model', type=str,
74 | default='', help=('current best model'))
75 | parser.add_argument('--best_valid_ppl', type=float,
76 | default=np.Inf, help=('current valid perplexity'))
77 |
78 | # # Parameters for using saved best models.
79 | # parser.add_argument('--model_dir', type=str, default='',
80 | # help='continue from the outputs in the given directory')
81 |
82 | # Parameters for debugging.
83 | parser.add_argument('--debug', dest='debug', action='store_true',
84 | help='show debug information')
85 | parser.set_defaults(debug=False)
86 |
87 | # Parameters for unittesting the implementation.
88 | parser.add_argument('--test', dest='test', action='store_true',
89 | help=('use only the first 1000 characters as data to test the implementation'))
90 | parser.set_defaults(test=False)
91 |
92 | # input_args = '--data_path ./data/poem --output_dir output_poem --hidden_size 256 --embedding_size 128 --num_unrollings 128 --debug --encoding utf-8'
93 | args = parser.parse_args(args.split())
94 |
95 | return args
96 |
97 |
98 |
99 | def config_sample(args=''):
100 | parser = argparse.ArgumentParser()
101 |
102 | # hyper-parameters for using saved best models.
103 | # hyper-parameters related to logging and results
104 | logging_args = parser.add_argument_group('Logging_Options')
105 | logging_args.add_argument('--model_dir', type=str,
106 | default='demo_model/',
107 | help='continue from the outputs in the given directory')
108 |
109 | logging_args.add_argument('--data_dir', type=str,
110 | default='./data/poem',
111 | help='data file path')
112 |
113 | logging_args.add_argument('--best_model', type=str,
114 | default='', help=('current best model'))
115 |
116 | # hyper-parameters for sampling.
117 | # hyper-parameters for sampling
118 | testing_args = parser.add_argument_group('Sampling Options')
119 | testing_args.add_argument('--max_prob', dest='max_prob', action='store_true',
120 | help='always pick the most probable next character in sampling')
121 | testing_args.set_defaults(max_prob=False)
122 |
123 | testing_args.add_argument('--start_text', type=str,
124 | default='The meaning of life is ',
125 | help='the text to start with')
126 |
127 | testing_args.add_argument('--length', type=int,
128 | default=100,
129 | help='length of sampled sequence')
130 |
131 | testing_args.add_argument('--seed', type=int,
132 | default=-1,
133 | help=('seed for sampling to replicate results, '
134 | 'an integer between 0 and 4294967295.'))
135 |
136 | args = parser.parse_args(args.split())
137 |
138 | return args
--------------------------------------------------------------------------------
/data/poem/poem_ids.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/norybaby/poet/b0f7b5dd7a31109995b921ad5269323e0fe794f8/data/poem/poem_ids.txt
--------------------------------------------------------------------------------
/data/poem/rhyme_words.txt:
--------------------------------------------------------------------------------
1 | 爸 把 八 罢 坝 拔 霸 扒 靶 叭 吧 擦 差 茶 插 查 叉 察 刹 咤 衩 大 达 发 法 伐 罚 尬 挂 瓜 刮 寡 呱 哈 花 华 化 话 画 滑 划 猾 家 加 价 甲 佳 假 夹 颊 驾 嫁 卡 跨 垮 夸 啦 拉 蜡 辣 妈 吗 嘛 马 砝 码 麻 骂 那 霎 那 拿 怕 爬 恰 撒 洒 沙 纱 啥 傻 厦 她 踏 塌 塔 榻 哇 娃 下 夏 霞 暇 压 雅 哑 崖 涯 杂 砸 眨 炸 吒 爪 抓 佳 玛 瑕 蚂 咖 垃
2 |
3 |
4 | 爱 哀 皑 埃 摆 白 百 拜 败 才 彩 财 采 踩 猜 睬 柴 踹 揣 待 带 代 袋 呆 该 改 怪 乖 海 害 孩 怀 坏 徘 徊 开 慨 快 块 来 赖 赖 莱 买 卖 脉 迈 麦 埋 耐 奶 奈 派 牌 排 湃 赛 腮 塞 晒 帅 衰 甩 摔 台 太 态 泰 外 歪 在 载 再 灾 栽 仔 债 摘 窄 拽 暧 儿 二 耳 饵 而 尔
5 |
6 | 安 暗 岸 边 变 辨 滨 伴 烂 餐 惨 残 惭 璀 璨 缠 颤 馋 掺 忏 潺 传 川 穿 船 串 喘 单 淡 担 胆 诞 耽 蛋 短 段 断 端 缎 帆 凡 反 饭 翻 返 犯 繁 烦 干 感 甘 敢 赶 肝 关 管 冠 灌 罐 惯 官 观 含 汉 寒 喊 汗 函 瀚 旱 憨 酣 欢 唤 还 换 幻 患 缓 焕 看 坎 侃 砍 堪 槛 款 宽 蓝 栏 懒 澜 揽 泛 滥 兰 婪 橄 榄 乱 峦 卵 满 慢 漫 瞒 男 南 难 暖 盼 判 盘 攀 畔 叛 然 渲 染 燃 软 伞 三 散 酸 算 山 善 闪 扇 删 蹒 跚 衫 拴 滩 叹 谈 坛 坦 探 贪 碳 毯 瘫 团 万 玩 完 晚 碗 湾 弯 婉 腕 挽 战 占 站 展 站 盏 湛 绽 转 专 煎 点 电 殿 垫 甸 巅 间 肩 见 件 剑 建 健 减 尖 坚 箭 渐 剪 舰 贱 捡 煎 俭 艰 拣 茧 连 怜 脸 联 链 练 恋 炼 莲 廉 帘 面 眠 免 绵 棉 勉 年 念 粘 篇 片 骗 偏 翩 牵 前 钱 千 浅 迁 签 欠 倩 谦 潜 遣 歉 虔 纤 阡 天 田 甜 填 添 舔 先 现 仙 线 限 鲜 弦 贤 闲 险 献 陷 娴 掀 倦 卷 娟 绢 捐 卷 涓 圈 搀 涟 珊 全 权 泉 犬 圈 拳 劝 选 轩 玄 喧 宣 旋 炫 悬 眩 绚 漩 言 眼 烟 燕 严 盐 严 演 艳 炎 宴 雁 验 焰 掩 嫣 厌 衍 淹 原 园 圆 源 员 远 院 愿 缘 苑 怨 渊 冤 援 尴 忐 蝙 颠 乾 店
7 |
8 | 昂 帮 绑 藏 沧 苍 茫 唱 长 唱 场 畅 怅 创 伤 窗 闯 党 荡 挡 方 放 芳 港 岗 光 广 逛 行 航 彷 徨 黄 荒 晃 皇 慌 谎 凰 江 降 浆 疆 匠 将 康 扛 抗 狂 旷 郎 浪 狼 朗 琅 谅 量 梁 亮 良 凉 忙 茫 氓 囊 娘 旁 胖 强 墙 枪 抢 腔 让 嚷 壤 伤 上 尚 赏 双 爽 霜 糖 棠 唐 汤 塘 趟 烫 躺 淌 网 望 往 汪 忘 旺 妄 想 向 乡 香 翔 巷 祥 响 详 样 阳 样 扬 养 氧 仰 汪 洋 痒 鸳 鸯 荡 漾 脏 葬 章 掌 涨 账 仗 杖 障 胀 装 状 壮 妆 桩 恍 滂
9 |
10 | 熬 骄 傲 奥 凹 报 宝 包 保 堡 爆 抱 暴 薄 饱 豹 曝 标 表 草 操 槽 超 炒 巢 潮 抄 吵 钞 朝 到 道 道 岛 刀 盗 倒 稻 捣 蹈 唠 叨 悼 高 搞 告 稿 糕 好 号 浩 豪 耗 嚎 毫 叫 教 交 脚 靠 烤 拷 老 劳 牢 捞 烙 酪 涝 了 聊 料 疗 毛 猫 茂 帽 冒 貌 卯 矛 髦 锚 苗 秒 妙 庙 描 瞄 淼 缈 脑 闹 挠 恼 瑙 孬 鸟 尿 袅 跑 泡 抛 炮 袍 刨 飘 漂 票 嫖 巧 桥 敲 瞧 侨 翘 俏 窍 撬 悄 峭 跷 锹 绕 饶 扰 饶 扫 嫂 骚 少 烧 稍 绍 勺 梢 哨 捎 艄 套 涛 桃 讨 逃 淘 掏 滔 嚎 啕 跳 条 调 挑 眺 迢 窕 小 笑 晓 肖 萧 校 效 消 孝 销 咆 哮 霄 削 宵 要 腰 耀 妖 遥 窑 谣 摇 咬 药 肴 耀 吆 幺 早 造 遭 枣 灶 燥 皂 糟 藻 凿 噪 澡 跳 蚤 找 照 招 着 兆 朝 罩 憔 狡 陶
11 |
12 | 层 曾 村 寸 存 忖 臣 晨 尘 称 辰 沉 趁 秤 成 城 称 呈 乘 诚 承 橙 澄 撑 骋 惩 春 纯 唇 醇 淳 蠢 等 灯 登 蹬 瞪 凳 吨 顿 盾 炖 敦 墩 蹲 钝 盹 恩 本 奔 笨 嗯 摁 苯 嫩 恁 恁 能
13 |
14 | 林 令 灵 另 零 临 邻 岭 凌 领 龄 铃 陵 淋 拎 琳 顶 定 丁 鼎 订 钉 盯 锭 叮 腚 均 军 君 俊 菌 峻 金 经 近 进 仅 今 精 景 静 尽 京 净 井 紧 锦 敬 镜 巾 劲 境 晶 津 斤 惊 禁 警 径 筋 浸 睛 憬 兢 荆 缤 冰 病 滨 鬓 宾 兵 并 饼 柄 丙 濒 玲 伶 淋
15 |
16 |
17 | 文 问 温 闻 纹 稳 吻 蚊 瘟 紊 门 闷 焖 们 汶 冷 肯 垦 恳 啃 坑 吭 坤 困 捆 冷 愣 棱 论 轮 昆 仑 抡 分 份 粉 纷 奋 粪 焚 愤 氛 跟 根 艮 更 耿 耕 庚 羹 埂 滚 棍 很 恨 痕 狠 恒 横 哼 衡 混 魂 婚 昏 馄 浑 荤 人 仁 任 认 忍 韧 刃 饪 润 森 僧 孙 损 损 孙 神 身 深 伸 渗 参 生 升 声 胜 圣 盛 绳 剩 吞 臀 豚 腾 藤 疼 怎 增 赠 憎 尊 遵 镇 真 振 针 珍 震 阵 帧 贞 枕 诊 疹 斟 正 证 政 整 争 症 征 蒸 挣 铮 怔 筝 睁 诤 准 芬
18 |
19 | 名 明 民 命 敏 鸣 铭 抿 您 宁 凝 拧 咛 喷 盆 品 平 评 凭 瓶 萍 屏 拼 坪 聘 贫 频 群 请 清 青 情 轻 庆 琴 亲 晴 勤 卿 倾 禽 侵 寝 擒 引擎 顷 听 厅 亭 停 挺 庭 廷 艇 蜻 蜓 心 新 型 性 行 信 星 兴 形 欣 姓 刑 馨 幸 醒 薪 腥 惺 寻 讯 训 迅 勋 巡 循 询 熏 旬 汛 驯 音 应 因 英 银 印 影 营 引 阴 硬 饮 隐 迎 鹰 赢 淫 殷 颖 盈 映 吟 荫 茵 婴 瘾 莺 缨 蝇 萦 姻 韵 云 运 晕 芸 允 匀 孕 蕴 熨 耘 酝
20 |
21 |
22 |
23 | 蹦 绷 甭 嘣 从 聪 丛 匆 重 冲 虫 充 崇 宠 忡 东 动 洞 冬 栋 懂 冻 咚 风 峰 封 凤 锋 奉 逢 缝 疯 蜂 讽 工 公 共 供 宫 功 攻 弓 贡 恭 拱 躬 红 洪 宏 鸿 弘 哄 轰 虹 泓 烘 讧 炯 窘 空 孔 控 恐 龙 隆 笼 聋 拢 珑 梦 蒙 猛 盟 萌 朦 农 弄 浓 侬 脓 鹏 碰 捧 棚 朋 蓬 砰 膨 怦 穷 琼 穹 荣 容 融 溶 绒 熔 送 松 颂 诵 耸 讼 通 痛 同 统 童 桶 筒 桐 捅 翁 嗡 熊 雄 兄 胸 凶 汹 用 永 勇 拥 涌 咏 庸 泳 蛹 总 宗 纵 综 踪 中 钟 重 种 众 忠 终 肿 踵 盅 憧 澎 烹 碰
24 |
25 | 不 步 部 布 补 捕 簿 怖 醋 促 粗 出 除 楚 初 处 储 触 厨 橱 雏 础 矗 搐 躇 度 都 读 独 毒 渡 赌 堵 肚 镀 督 嘟 督 妒 睹 渎 副 福 富 夫 复 服 附 付 负 府 赴 伏 父 扶 符 傅 浮 腹 辅 覆 赋 抚 妇 肤 腐 拂 斧 咐 腑 蝠 孵 俘 袱 匍 匐 扶 故 股 古 谷 顾 骨 固 鼓 孤 姑 菇 估 雇 咕 胡 户 湖 虎 乎 护 呼 互 壶 忽 狐 糊 弧 唬 惚 瑚 苦 库 哭 酷 裤 枯 窟 录 路 鲁 陆 炉 鹿 禄 鹭 虏 颅 噜 赂 木 亩 目 母 姆 墓 牧 幕 慕 沐 睦 怒 奴 努 弩 驽 呶 谱 铺 普 朴 扑 瀑 仆 铺 如 入 汝 乳 儒 辱 褥 苏 素 速 宿 诉 塑 俗 酥 溯 数 书 属 树 术 舒 输 熟 叔 鼠 蜀 束 述 淑 竖 梳 殊 恕 暑 孰 墅 漱 赎 图 土 涂 兔 吐 突 徒 途 屠 凸 秃 荼 涂 无 五 吴 物 武 舞 屋 吾 务 乌 勿 雾 伍 悟 误 午 巫 污 唔 呜 恶 捂 忤 侮 圬 组 租 足 族 祖 阻 卒 主 注 著 注 朱 住 猪 诸 竹 驻 祝 柱 助 煮 逐 筑 铸 烛 嘱 拄 蛛 敷
26 |
27 |
28 |
29 | 次 此 词 刺 赐 慈 辞 瓷 雌 疵 吃 持 池 赤 尺 迟 驰 齿 痴 翅 耻 弛 地 低 底 帝 滴 敌 蒂 抵 逐 弟 堤 递 笛 谛 及 即 级 机 集 基 记 吉 计 极 几 继 鸡 急 剂 季 纪 积 极 技 忌 击 姬 寄 籍 济 肌 疾 挤 祭 际 迹 己 激 辑 寂 绩 脊 妓 饥 叽 棘 讥 圾 嫉 笈 据 局 具 居 距 聚 句 举 咀 巨 俱 菊 剧 拒 锯 惧 桔 鞠 矩 拘 炬 驹 理 里 李 力 利 立 离 丽 例 礼 莉 黎 历 厉 梨 栗 沥 励 隶 厘 犁 俐 漓 狸 率 绿 铝 律 驴 旅 履 虑 滤 缕 侣 米 迷 密 秘 蜜 弥 谜 觅 咪 靡 眯 泌 你 昵 拟 泥 逆 腻 溺 女 皮 批 匹 脾 劈 屁 披 坯 辟 疲 啤 癖 僻 其 起 期 七 气 器 齐 奇 企 旗 启 漆 骑 妻 弃 棋 汽 欺 泣 乞 戚 契 歧 杞 去 区 取 曲 趣 娶 屈 驱 渠 躯 蛆 趋 蹊
30 |
31 | 日 四 斯 思 死 司 寺 丝 似 私 撕 肆 嘶 是 时 市 使 式 室 石 事 十 师 实 诗 世 史 氏 食 视 士 施 试 饰 始 失 识 湿 势 示 释 适 仕 狮 尸 拾 侍 逝 誓 蚀 驶 虱 拭 自 子 字 紫 资 仔 姿 滋 咨 渍 制 值 只 至 止 之 指 志 直 知 智 治 支 纸 质 致 置 枝 职 址 执 汁 植 芝 脂 织 滞 旨 掷 肢 趾 峙 稚 秩 秩 挚 帜 蜘
32 |
33 | 体 题 提 替 踢 梯 蹄 啼 锑 剃 涕 剔 惕 屉 嚏 西 系 喜 希 细 溪 洗 戏 吸 息 兮 席 熙 习 悉 稀 惜 夕 袭 析 隙 嘻 析 昔 嬉 熄 晰 需 须 许 虚 需 续 旭 序 恤 叙 绪 蓄 胥 絮 嘘 墟 栩 煦 酗 一 以 已 亿 易 亦 意 依 义 宜 伊 乙 仪 衣 益 矣 艺 异 已 译 医 逸 翼 怡 移 毅 议 疑 忆 倚 椅 遗 溢 奕 役 抑 谊 姨 疫 蚁 裔 弈 漪 奕 与 于 玉 鱼 余 於 雨 欲 语 宇 羽 遇 域 预 愈 育 予 御 豫 郁 浴 喻 愚 誉 娱 狱 屿 淤 愉 吁 谀 彼 比 碧 壁 闭 蔽 驿 逼 笔 币 必 毕 避 鼻 弊 毙 哔 鄙
34 |
35 |
36 | 波 博 薄 播 勃 搏 剥 泊 脖 渤 胳 膊 跛 萝 卜 恶 饿 测 册 侧 策 车 撤 扯 彻 澈 戳 龌 龊 多 朵 夺 躲 堕 剁 舵 惰 得 德 碟 蝶 笛 叠 爹 谍 跌 佛 国 过 果 郭 锅 裹 蝈 帼 个 各 格 歌 哥 阁 戈 隔 革 割 鸽 嗝 或 获 火 活 货 霍 祸 惑 伙 豁 攉 和 合 何 河 喝 盒 贺 核 荷 鹤 赫 禾 褐 呵 吓 接 街 节 杰 解 界 结 借 姐 洁 戒 捷 介 劫 阶 揭 截 竭 诫 觉 绝 决 抉 诀 爵 掘 嚼 倔 撅 崛 可 克 科 客 课 克 科 颗 壳 刻 咳 棵 渴 磕 苛 刻 扩 阔 括 廓 罗 落 洛 络 螺 裸 锣 骡 菠 萝 了 乐 勒 列 烈 裂 劣 猎 咧 冽 洌 略 掠 莫 魔 末 膜 墨 模 摩 磨 默 摸 么 抹 沫 漠 陌 馍 灭 诺 挪 懦 捏 孽 破 坡 颇 婆 迫 泼 魄 珀 陂 泊 撇 且 切 妾 窃 怯 却 缺 确 阙 雀 鹊 阕 瘸 若 弱 热 惹 缩 锁 所 索 梭 琐 娑 啰 哆 嗦 婆 娑 色 塞 瑟 涩 啬 说 硕 烁 设 社 摄 舍 蛇 射 折 舌 涉 赦 奢 慑 赊 涉 托 脱 拖 拓 陀 妥 骆 驼 驮 沱 唾 特 忑 蝌 做 作 座 左 坐 佐 昨 则 泽 责 择 仄 啧 卓 桌 着 捉 啄 浊 拙 灼 酌 这 者 着 折 哲 遮 辄 蔗 辙
37 |
38 | 贴 铁 我 窝 沃 卧 握 涡 些 血 写 谢 鞋 协 邪 斜 携 蟹 歇 泄 泻 卸 屑 械 挟 谐 蝎 偕 懈 胁 学 雪 血 穴 靴 削 也 业 页 叶 夜 液 野 耶 爷 冶 椰 噎 月 约 曰 乐 越 粤 岳 跃 悦 阅 亵 别 憋 鳖 瘪 邂
39 |
40 |
41 | 凑 愁 抽 丑 臭 仇 筹 酬 稠 绸 瞅 踌 畴 都 斗 豆 抖 逗 兜 陡 痘 蚪 丢 狗 购 沟 够 钩 勾 构 垢 后 侯 厚 候 猴 吼 喉 逅 就 旧 九 酒 久 救 纠 揪 鸠 舅 疚 口 扣 寇 抠 叩 蔻 楼 露 漏 喽 搂 篓 陋 六 流 留 柳 瘤 溜 榴 遛 某 谋 眸 牛 扭 纽 钮 妞 拗 忸 剖 球 求 秋 邱 丘 裘 囚 鳅 肉 柔 揉 搜 艘 馊 擞 手 受 收 首 守 售 寿 兽 授 瘦 熟 头 投 透 偷 修 秀 休 绣 袖 羞 嗅 朽 有 游 由 又 友 油 右 优 尤 邮 幽 佑 幼 忧 悠 诱 走 奏 揍 周 州 洲 轴 舟 粥 咒 皱 肘 骤 昼 柚 宙 帚 惆 犹
42 |
43 | 北 备 倍 杯 背 贝 碑 悲 辈 卑 惫 狈 翠 催 脆 萃 摧 粹 悴 对 队 堆 飞 非 费 菲 肥 废 肺 妃 匪 霏 啡 给 归 贵 鬼 桂 规 柜 龟 轨 跪 诡 玫 瑰 黑 嘿 会 灰 回 慧 惠 辉 汇 绘 徽 毁 晖 挥 悔 卉 讳 亏 魁 窥 葵 溃 馈 盔 愧 睽 类 雷 泪 累 蕾 磊 类 垒 擂 梅 每 没 美 枚 煤 妹 眉 酶 媚 媒 霉 昧 内 配 培 陪 佩 赔 沛 呸 瑞 锐 岁 随 虽 遂 碎 穗 隋 髓 退 推 腿 褪 为 位 未 卫 维 伟 威 微 味 魏 尾 喂 谓 围 胃 伪 危 纬 畏 违 贼 最 嘴 罪 醉 追 坠 锥 缀 委 贿
44 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
/data/poem/vectors_poem.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/norybaby/poet/b0f7b5dd7a31109995b921ad5269323e0fe794f8/data/poem/vectors_poem.bin
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | #import collections
3 | from six.moves import cPickle
4 | import numpy as np
5 | from word2vec_helper import Word2Vec
6 | import math
7 |
8 |
9 |
10 | class DataLoader():
11 | def __init__(self, data_dir, batch_size,seq_max_length,w2v,data_type):
12 | self.data_dir = data_dir
13 | self.batch_size = batch_size
14 | self.seq_max_length = seq_max_length
15 | self.w2v = w2v
16 | self.trainingSamples = []
17 | self.validationSamples = []
18 | self.testingSamples = []
19 | self.train_frac = 0.85
20 | self.valid_frac = 0.05
21 |
22 | self.load_corpus(self.data_dir)
23 |
24 | if data_type == 'train':
25 | self.create_batches(self.trainingSamples)
26 | elif data_type == 'test':
27 | self.create_batches(self.testingSamples)
28 | elif data_type == 'valid':
29 | self.create_batches(self.validationSamples)
30 |
31 | self.reset_batch_pointer()
32 |
33 | def _print_stats(self):
34 | print('Loaded {}: training samples:{} ,validationSamples:{},testingSamples:{}'.format(
35 | self.data_dir, len(self.trainingSamples),len(self.validationSamples),len(self.testingSamples)))
36 |
37 | def load_corpus(self,base_path):
38 | """读/创建 对话数据:
39 | 在训练文件创建的过程中,由两个文件
40 | 1. self.fullSamplePath
41 | 2. self.filteredSamplesPath
42 | """
43 | tensor_file = os.path.join(base_path,'poem_ids.txt')
44 | print('tensor_file:%s' % tensor_file)
45 |
46 | datasetExist = os.path.isfile(tensor_file)
47 | # If the preprocessed data file does not exist, create it
48 | if not datasetExist:
49 | print('Training samples not found. Creating them from the raw dataset...')
50 |
51 | fullSamplesPath = os.path.join(self.data_dir,'poems_edge_split.txt')
52 | # Create/load the raw sample dataset: self.trainingSamples
53 | print('fullSamplesPath:%s' % fullSamplesPath)
54 | self.load_from_text_file(fullSamplesPath)
55 |
56 | else:
57 | self.load_dataset(tensor_file)
58 |
59 | self.padToken = self.w2v.ix('')
60 | self.goToken = self.w2v.ix('[')
61 | self.eosToken = self.w2v.ix(']')
62 | self.unknownToken = self.w2v.ix('')
63 |
64 | self._print_stats()
65 | # assert self.padToken == 0
66 |
67 | def load_from_text_file(self,in_file):
68 | # base_path = 'F:\BaiduYunDownload\chatbot_lecture\lecture2\data\ice_and_fire_zh'
69 | # in_file = os.path.join(base_path,'poems_edge.txt')
70 | fr = open(in_file, "r",encoding='utf-8')
71 | poems = fr.readlines()
72 | fr.close()
73 |
74 | print("唐诗总数: %d"%len(poems))
75 | # self.seq_max_length = max([len(poem) for poem in poems])
76 | # print("seq_max_length: %d"% (self.seq_max_length))
77 |
78 | poem_ids = DataLoader.get_text_idx(poems,self.w2v.vocab_hash,self.seq_max_length)
79 |
80 | # # Post-processing
81 | # # 1. Word filtering: drop infrequent words (count <= filterVocab), keep the vocabSize most common words
82 | # print('Filtering words (vocabSize = {} and wordCount > {})...'.format(
83 | # self.args.vocabularySize,
84 | # self.args.filterVocab
85 | # ))
86 | # self.filterFromFull()
87 |
88 | # 2. Split the data
89 | print('Splitting data into train, valid, and test sets...')
90 | n_samples = len(poem_ids)
91 | train_size = int(self.train_frac * n_samples)
92 | valid_size = int(self.valid_frac * n_samples)
93 | test_size = n_samples - train_size - valid_size
94 |
95 | print('n_samples=%d, train-size=%d, valid_size=%d, test_size=%d' % (
96 | n_samples, train_size, valid_size, test_size))
97 | self.testingSamples = poem_ids[-test_size:]
98 | self.validationSamples = poem_ids[-valid_size-test_size : -test_size]
99 | self.trainingSamples = poem_ids[:train_size]
100 |
101 | # Save the processed dataset
102 | print('Saving dataset...')
103 | poem_ids_file = os.path.join(self.data_dir,'poem_ids.txt')
104 | self.save_dataset(poem_ids_file)
105 |
106 | # 2. utility function: write the dataset to file with pickle
107 | def save_dataset(self, filename):
108 | """使用pickle保存数据文件。
109 |
110 | 数据文件包含词典和对话样本。
111 |
112 | Args:
113 | filename (str): pickle 文件名
114 | """
115 | with open(filename, 'wb') as handle:
116 | data = {
117 | 'trainingSamples': self.trainingSamples
118 | }
119 |
120 | if len(self.validationSamples)>0:
121 | data['validationSamples'] = self.validationSamples
122 | data['testingSamples'] = self.testingSamples
123 | data['maxSeqLen'] = self.seq_max_length
124 |
125 | cPickle.dump(data, handle, -1) # Using the highest protocol available
126 |
127 | # 3. utility function: read the dataset from file with pickle
128 | def load_dataset(self, filename):
129 | """使用pickle读入数据文件
130 | Args:
131 | filename (str): pickle filename
132 | """
133 |
134 | print('Loading dataset from {}'.format(filename))
135 | with open(filename, 'rb') as handle:
136 | data = cPickle.load(handle)
137 | self.trainingSamples = data['trainingSamples']
138 |
139 | if 'validationSamples' in data:
140 | self.validationSamples = data['validationSamples']
141 | self.testingSamples = data['testingSamples']
142 |
143 | print('file maxSeqLen = {}'.format( data['maxSeqLen']))
144 |
145 |
146 | @staticmethod
147 | def get_text_idx(text,vocab,max_document_length):
148 | text_array = []
149 | for i,x in enumerate(text):
150 | line = []
151 | for j, w in enumerate(x):
152 | if (w not in vocab):
153 | w = ''
154 | line.append(vocab[w])
155 | text_array.append(line)
156 | # else :
157 | # print w,'not exist'
158 |
159 | return text_array
160 |
161 | def create_batches(self,samples):
162 |
163 | sample_size = len(samples)
164 | self.num_batches = math.ceil(sample_size /self.batch_size)
165 | new_sample_size = self.num_batches * self.batch_size
166 |
167 | # Create the batch tensor
168 | # x_lengths = [len(sample) for sample in samples]
169 |
170 | x_lengths = []
171 | x_seqs = np.ndarray((new_sample_size,self.seq_max_length),dtype=np.int32)
172 | y_seqs = np.ndarray((new_sample_size,self.seq_max_length),dtype=np.int32)
173 | self.x_lengths = []
174 | for i,sample in enumerate(samples):
175 | # fill with padding to align batchSize samples into one 2D list
176 | x_lengths.append(len(sample))
177 | x_seqs[i] = sample + [self.padToken] * (self.seq_max_length - len(sample))
178 |
179 | for i in range(sample_size,new_sample_size):
180 | copyi = i - sample_size
181 | x_seqs[i] = x_seqs[copyi]
182 | x_lengths.append(x_lengths[copyi])
183 |
184 | y_seqs[:,:-1] = x_seqs[:,1:]
185 | y_seqs[:,-1] = x_seqs[:,0]
186 | x_len_array = np.array(x_lengths)
187 |
188 |
189 |
190 | self.x_batches = np.split(x_seqs.reshape(self.batch_size, -1), self.num_batches, 1)
191 | self.x_len_batches = np.split(x_len_array.reshape(self.batch_size, -1), self.num_batches, 1)
192 | self.y_batches = np.split(y_seqs.reshape(self.batch_size, -1), self.num_batches, 1)
193 |
194 | def next_batch_dynamic(self):
195 | x,x_len, y = self.x_batches[self.pointer], self.x_len_batches[self.pointer],self.y_batches[self.pointer]
196 | self.pointer += 1
197 | return x,x_len, y
198 |
199 | def next_batch(self):
200 | x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
201 | self.pointer += 1
202 | return x,y
203 |
204 | def reset_batch_pointer(self):
205 | self.pointer = 0
206 |
207 | @staticmethod
208 | def get_text_idx(text,vocab,max_document_length):
209 | max_document_length_without_end = max_document_length - 1
210 | text_array = []
211 | for i,x in enumerate(text):
212 | line = []
213 | if len(x) > max_document_length:
214 | x_parts = x[:max_document_length_without_end]
215 | idx = x_parts.rfind('。')
216 | if idx > -1 :
217 | x_parts = x_parts[0:idx + 1] + ']'
218 | x = x_parts
219 |
220 | for j, w in enumerate(x):
221 | # if j >= max_document_length:
222 | # break
223 |
224 | if (w not in vocab):
225 | w = ''
226 | line.append(vocab[w])
227 | text_array.append(line)
228 | # else :
229 | # print w,'not exist'
230 |
231 | return text_array
232 |
233 | if __name__ == '__main__':
234 | base_path = './data/poem'
235 | # poem = '风急云轻鹤背寒,洞天谁道却归难。千山万水瀛洲路,何处烟飞是醮坛。是的'
236 | # idx = poem.rfind('。')
237 | # poem_part = poem[:idx + 1]
238 | w2v_file = os.path.join(base_path, "vectors_poem.bin")
239 | w2v = Word2Vec(w2v_file)
240 |
241 | # vect = w2v_model['['][:10]
242 | # print(vect)
243 | #
244 | # vect = w2v_model['春'][:10]
245 | # print(vect)
246 |
247 | in_file = os.path.join(base_path,'poems_edge.txt')
248 | # fr = open(in_file, "r",encoding='utf-8')
249 | # poems = fr.readlines()
250 | # fr.close()
251 | #
252 | #
253 | #
254 | # print("唐诗总数: %d"%len(poems))
255 | #
256 | # poem_ids = get_text_idx(poems,w2v.model.vocab_hash,100)
257 | # poem_ids_file = os.path.join(base_path,'poem_ids.txt')
258 | # with open(poem_ids_file, 'wb') as f:
259 | # cPickle.dump(poem_ids, f)
260 |
261 | dataloader = DataLoader(base_path, 20, 100, w2v.model, 'train')  # seq_max_length of 100 is an assumed test value
262 |
263 |
--------------------------------------------------------------------------------
/doc/client.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/norybaby/poet/b0f7b5dd7a31109995b921ad5269323e0fe794f8/doc/client.png
--------------------------------------------------------------------------------
/doc/train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/norybaby/poet/b0f7b5dd7a31109995b921ad5269323e0fe794f8/doc/train.png
--------------------------------------------------------------------------------
/poem_server.py:
--------------------------------------------------------------------------------
1 | #coding=utf8
2 | import os
3 | from flask import Flask,request
4 | from write_poem import WritePoem,start_model
5 |
6 | app = Flask(__name__)
7 | application = app
8 |
9 | path = os.getcwd()  # get the current working directory
10 | print(path)
11 |
12 | writer = start_model()
13 |
14 | # @app.route('/')
15 | # def test(title):
16 | # return 'test ok'
17 |
18 | style_help = '\npara style : 1: free verse \n2: rhymed free verse \n3: acrostic poem \n4: poem built around the given characters (max-probability sampling)'
19 | @app.route('/poem')
20 | def write_poem():
21 | params = request.args
22 | start_with= ''
23 | poem_style = 0
24 |
25 | # print(params)
26 | if 'start' in params :
27 | start_with = params['start']
28 | if 'style' in params:
29 | poem_style = int(params['style'])
30 |
31 | # return 'hello'
32 | if start_with:
33 | if poem_style == 3:
34 | return writer.cangtou(start_with)
35 | elif poem_style == 4:
36 | return writer.hide_words(start_with)
37 |
38 | if poem_style == 1:
39 | return writer.free_verse()
40 | elif poem_style == 2:
41 | return writer.rhyme_verse()
42 |
43 | return 'hello, what do you want? {}'.format(style_help)
44 |
45 |
46 | if __name__ == "__main__":
47 | app.run()
--------------------------------------------------------------------------------
/rhyme_helper.py:
--------------------------------------------------------------------------------
1 |
2 | class RhymeWords():
3 | rhyme_list = []
4 |
5 | @staticmethod
6 | def read_rhyme_words(infile):
7 | with open(infile,'r',encoding='utf-8',errors='ignore') as fr:
8 | for line in fr:
9 | words = set(line.split())
10 | RhymeWords.rhyme_list.append(words)
11 |
12 | @staticmethod
13 | def get_rhyme_words(w):
14 | for words in RhymeWords.rhyme_list:
15 | if w in words:
16 | return words
17 | return None
18 |
19 | @staticmethod
20 | def print_stats():
21 | count = 0
22 | for words in RhymeWords.rhyme_list:
23 | count += len(words)
24 | print(words)
25 |
26 | for w in words:
27 | if len(w) > 1:
28 | print(w)
29 |
30 | print('count = ',count)
31 |
32 | if __name__ == '__main__':
33 | infile = './data/poem/rhyme_words.txt'
34 | RhymeWords.read_rhyme_words(infile)
35 | RhymeWords.print_stats()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import codecs
2 | import json
3 | import logging
4 | import os
5 | import shutil
6 | import sys
7 | import time
8 | import numpy as np
9 | import tensorflow as tf
10 | from char_rnn_model import CharRNNLM
11 | from config_poem import config_poem_train
12 | from data_loader import DataLoader
13 | from word2vec_helper import Word2Vec
14 | TF_VERSION = int(tf.__version__.split('.')[1])
15 |
16 |
17 | def main(args=''):
18 | args = config_poem_train(args)
19 | # Specifying location to store model, best model and tensorboard log.
20 | args.save_model = os.path.join(args.output_dir, 'save_model/model')
21 | args.save_best_model = os.path.join(args.output_dir, 'best_model/model')
22 | # args.tb_log_dir = os.path.join(args.output_dir, 'tensorboard_log/')
23 | timestamp = str(int(time.time()))
24 | args.tb_log_dir = os.path.abspath(os.path.join(args.output_dir, "tensorboard_log", timestamp))
25 | print("Writing to {}\n".format(args.tb_log_dir))
26 |
27 | # Create necessary directories.
28 | if len(args.init_dir) != 0:
29 | args.output_dir = args.init_dir
30 | else:
31 | if os.path.exists(args.output_dir):
32 | shutil.rmtree(args.output_dir)
33 | for paths in [args.save_model, args.save_best_model, args.tb_log_dir]:
34 | os.makedirs(os.path.dirname(paths))
35 |
36 | logging.basicConfig(stream=sys.stdout,
37 | format='%(asctime)s %(levelname)s:%(message)s',
38 | level=logging.INFO, datefmt='%I:%M:%S')
39 |
40 | print('=' * 60)
41 | print('All final and intermediate outputs will be stored in %s/' % args.output_dir)
42 | print('=' * 60 + '\n')
43 |
44 | logging.info('args are:\n%s', args)
45 |
46 | if len(args.init_dir) != 0:
47 | with open(os.path.join(args.init_dir, 'result.json'), 'r') as f:
48 | result = json.load(f)
49 | params = result['params']
50 | args.init_model = result['latest_model']
51 | best_model = result['best_model']
52 | best_valid_ppl = result['best_valid_ppl']
53 | if 'encoding' in result:
54 | args.encoding = result['encoding']
55 | else:
56 | args.encoding = 'utf-8'
57 |
58 | else:
59 | params = {'batch_size': args.batch_size,
60 | 'num_unrollings': args.num_unrollings,
61 | 'hidden_size': args.hidden_size,
62 | 'max_grad_norm': args.max_grad_norm,
63 | 'embedding_size': args.embedding_size,
64 | 'num_layers': args.num_layers,
65 | 'learning_rate': args.learning_rate,
66 | 'cell_type': args.cell_type,
67 | 'dropout': args.dropout,
68 | 'input_dropout': args.input_dropout}
69 | best_model = ''
70 | logging.info('Parameters are:\n%s\n', json.dumps(params, sort_keys=True, indent=4))
71 |
72 | # Create batch generators.
73 | batch_size = params['batch_size']
74 | num_unrollings = params['num_unrollings']
75 |
76 | base_path = args.data_path
77 | w2v_file = os.path.join(base_path, "vectors_poem.bin")
78 | w2v = Word2Vec(w2v_file)
79 |
80 | train_data_loader = DataLoader(base_path,batch_size,num_unrollings,w2v.model,'train')
81 | test1_data_loader = DataLoader(base_path,batch_size,num_unrollings,w2v.model,'test')
82 | valid_data_loader = DataLoader(base_path,batch_size,num_unrollings,w2v.model,'valid')
83 |
84 | # Create graphs
85 | logging.info('Creating graph')
86 | graph = tf.Graph()
87 | with graph.as_default():
88 | w2v_vocab_size = len(w2v.model.vocab)
89 | with tf.name_scope('training'):
90 | train_model = CharRNNLM(is_training=True,w2v_model = w2v.model,vocab_size=w2v_vocab_size, infer=False, **params)
91 | tf.get_variable_scope().reuse_variables()
92 |
93 | with tf.name_scope('validation'):
94 | valid_model = CharRNNLM(is_training=False,w2v_model = w2v.model, vocab_size=w2v_vocab_size, infer=False, **params)
95 |
96 | with tf.name_scope('evaluation'):
97 | test_model = CharRNNLM(is_training=False,w2v_model = w2v.model,vocab_size=w2v_vocab_size, infer=False, **params)
98 | saver = tf.train.Saver(name='model_saver')
99 | best_model_saver = tf.train.Saver(name='best_model_saver')
100 |
101 | logging.info('Start training\n')
102 |
103 | result = {}
104 | result['params'] = params
105 |
106 |
107 | try:
108 | with tf.Session(graph=graph) as session:
109 | # Version 8 changed the api of summary writer to use
110 | # graph instead of graph_def.
111 | if TF_VERSION >= 8:
112 | graph_info = session.graph
113 | else:
114 | graph_info = session.graph_def
115 |
116 | train_summary_dir = os.path.join(args.tb_log_dir, "summaries", "train")
117 | train_writer = tf.summary.FileWriter(train_summary_dir, graph_info)
118 | valid_summary_dir = os.path.join(args.tb_log_dir, "summaries", "valid")
119 | valid_writer = tf.summary.FileWriter(valid_summary_dir, graph_info)
120 |
121 | # load a saved model or start from random initialization.
122 | if len(args.init_model) != 0:
123 | saver.restore(session, args.init_model)
124 | else:
125 | tf.global_variables_initializer().run()
126 |
127 | learning_rate = args.learning_rate
128 | for epoch in range(args.num_epochs):
129 | logging.info('=' * 19 + ' Epoch %d ' + '=' * 19 + '\n', epoch)
130 | logging.info('Training on training set')
131 | # training step
132 | ppl, train_summary_str, global_step = train_model.run_epoch(session, train_data_loader, is_training=True,
133 | learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq)
134 | # record the summary
135 | train_writer.add_summary(train_summary_str, global_step)
136 | train_writer.flush()
137 | # save model
138 | saved_path = saver.save(session, args.save_model,
139 | global_step=train_model.global_step)
140 |
141 | logging.info('Latest model saved in %s\n', saved_path)
142 | logging.info('Evaluate on validation set')
143 |
144 | valid_ppl, valid_summary_str, _ = valid_model.run_epoch(session, valid_data_loader, is_training=False,
145 | learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq)
146 |
147 | # save and update best model
148 | if (len(best_model) == 0) or (valid_ppl < best_valid_ppl):
149 | best_model = best_model_saver.save(session, args.save_best_model,
150 | global_step=train_model.global_step)
151 | best_valid_ppl = valid_ppl
152 | else:
153 | learning_rate /= 2.0
154 | logging.info('Decay the learning rate: ' + str(learning_rate))
155 |
156 | valid_writer.add_summary(valid_summary_str, global_step)
157 | valid_writer.flush()
158 |
159 | logging.info('Best model is saved in %s', best_model)
160 | logging.info('Best validation ppl is %f\n', best_valid_ppl)
161 |
162 | result['latest_model'] = saved_path
163 | result['best_model'] = best_model
164 | # Convert to float because numpy.float is not json serializable.
165 | result['best_valid_ppl'] = float(best_valid_ppl)
166 |
167 | result_path = os.path.join(args.output_dir, 'result.json')
168 | if os.path.exists(result_path):
169 | os.remove(result_path)
170 | with open(result_path, 'w') as f:
171 | json.dump(result, f, indent=2, sort_keys=True)
172 |
173 | logging.info('Latest model is saved in %s', saved_path)
174 | logging.info('Best model is saved in %s', best_model)
175 | logging.info('Best validation ppl is %f\n', best_valid_ppl)
176 |
177 | logging.info('Evaluate the best model on test set')
178 | saver.restore(session, best_model)
179 | test_ppl, _, _ = test_model.run_epoch(session, test1_data_loader, is_training=False,
180 | learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq)
181 | result['test_ppl'] = float(test_ppl)
182 | except Exception as e:
183 | print('err :{}'.format(e))
184 | finally:
185 | result_path = os.path.join(args.output_dir, 'result.json')
186 | if os.path.exists(result_path):
187 | os.remove(result_path)
188 | with open(result_path, 'w',encoding='utf-8',errors='ignore') as f:
189 | json.dump(result, f, indent=2, sort_keys=True)
190 |
191 |
192 | if __name__ == '__main__':
193 | args = '--output_dir output_poem --data_path ./data/poem/ --hidden_size 128 --embedding_size 128 --cell_type lstm'
194 | main(args)
195 |
--------------------------------------------------------------------------------
/word2vec_helper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import word2vec
3 |
4 | class Word2Vec():
5 | def __init__(self,file_path):
6 | # w2v_file = os.path.join(base_path, "vectors_poem.bin")
7 | self.model = word2vec.load(file_path)
8 | self.add_word('')
9 | self.add_word('')
10 | # self.vocab_size = len(self.model.vocab)
11 |
12 | def add_word(self,word):
13 | if word not in self.model.vocab_hash:
14 | w_vec = np.random.uniform(-0.1,0.1,size=128)
15 | self.model.vocab_hash[word] = len(self.model.vocab)
16 | self.model.vectors = np.row_stack((self.model.vectors,w_vec))
17 | self.model.vocab = np.concatenate((self.model.vocab,np.array([word])))
18 |
19 | # vocab = np.empty(1, dtype='
--------------------------------------------------------------------------------
/write_poem.py:
--------------------------------------------------------------------------------
38 | if args.seed >= 0:
39 | np.random.seed(args.seed)
40 |
41 | logging.info('best_model: %s\n', best_model)
42 |
43 | self.sess = tf.Session()
44 | w2v_vocab_size = len(self.w2v.model.vocab)
45 | with tf.name_scope('evaluation'):
46 | self.model = CharRNNLM(is_training=False,w2v_model = self.w2v.model,vocab_size=w2v_vocab_size, infer=True, **params)
47 | saver = tf.train.Saver(name='model_saver')
48 | saver.restore(self.sess, best_model)
49 |
50 | def free_verse(self):
51 | '''
52 | Free verse.
53 | Returns:
54 |
55 | '''
56 | sample = self.model.sample_seq(self.sess, 40, '[',sample_type= SampleType.weighted_sample)
57 | if not sample:
58 | return 'error occurred!'
59 |
60 | print('free_verse:',sample)
61 |
62 | idx_end = sample.find(']')
63 | parts = sample.split('。')
64 | if len(parts) > 1:
65 | two_sentence_len = len(parts[0]) + len(parts[1])
66 | if idx_end < 0 or two_sentence_len < idx_end:
67 | return sample[1:two_sentence_len + 2]
68 |
69 | return sample[1:idx_end]
70 |
71 | @staticmethod
72 | def assemble(sample):
73 | if sample:
74 | parts = sample.split('。')
75 | if len(parts) > 1:
76 | return '{}。{}。'.format(parts[0][1:],parts[1][:len(parts[0])])
77 |
78 | return ''
79 |
80 |
81 | def rhyme_verse(self):
82 | '''
83 | Rhymed verse.
84 | Returns:
85 |
86 | '''
87 | gen_len = 20
88 | sample = self.model.sample_seq(self.sess, gen_len, start_text='[',sample_type= SampleType.weighted_sample)
89 | if not sample:
90 | return 'error occurred!'
91 |
92 | print('rhyme_verse:',sample)
93 |
94 | parts = sample.split('。')
95 | if len(parts) > 0:
96 | start = parts[0] + '。'
97 | rhyme_ref_word = start[-2]
98 | rhyme_seq = len(start) - 3
99 |
100 | sample = self.model.sample_seq(self.sess, gen_len , start,
101 | sample_type= SampleType.weighted_sample,rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq )
102 | print(sample)
103 | return WritePoem.assemble(sample)
104 |
105 | return sample[1:]
106 |
107 | def hide_words(self,given_text):
108 | '''
109 | Poem that weaves in the given characters.
110 | Args:
111 | given_text:
112 |
113 | Returns:
114 |
115 | '''
116 | if(not given_text):
117 | return self.rhyme_verse()
118 |
119 | givens = ['','']
120 | split_len = math.ceil(len(given_text)/2)
121 | givens[0] = given_text[:split_len]
122 | givens[1] = given_text[split_len:]
123 |
124 | gen_len = 20
125 | sample = self.model.sample_seq(self.sess, gen_len, start_text='[',sample_type= SampleType.select_given,given=givens[0])
126 | if not sample:
127 | return 'error occurred!'
128 |
129 | print('hide_words:',sample)
130 |
131 | parts = sample.split('。')
132 | if len(parts) > 0:
133 | start = parts[0] + '。'
134 | rhyme_ref_word = start[-2]
135 | rhyme_seq = len(start) - 3
136 | # gen_len = len(start) - 1
137 |
138 | sample = self.model.sample_seq(self.sess, gen_len , start,
139 | sample_type= SampleType.select_given,given=givens[1],rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq )
140 | print(sample)
141 | return WritePoem.assemble(sample)
142 |
143 | return sample[1:]
144 |
145 | def cangtou(self,given_text):
146 | '''
147 | Acrostic poem.
148 | Returns:
149 |
150 | '''
151 | if(not given_text):
152 | return self.rhyme_verse()
153 |
154 | start = ''
155 | rhyme_ref_word = ''
156 | rhyme_seq = 0
157 |
158 | # for i,word in enumerate(given_text):
159 | for i in range(4):
160 | word = ''
161 | if i < len(given_text):
162 | word = given_text[i]
163 |
164 | if i == 0:
165 | start = '[' + word
166 | else:
167 | start += word
168 |
169 | before_idx = len(start)
170 | if(i != 3):
171 | sample = self.model.sample_seq(self.sess, self.args.length, start,
172 | sample_type= SampleType.weighted_sample )
173 |
174 | else:
175 | if not word:
176 | rhyme_seq += 1
177 |
178 | sample = self.model.sample_seq(self.sess, self.args.length, start,
179 | sample_type= SampleType.max_prob,rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq )
180 |
181 | print('Sampled text is:\n\n%s' % sample)
182 |
183 | sample = sample[before_idx:]
184 | idx1 = sample.find(',')
185 | idx2 = sample.find('。')
186 | min_idx = min(idx1,idx2)
187 |
188 | if min_idx == -1:
189 | if idx1 > -1 :
190 | min_idx = idx1
191 | else: min_idx =idx2
192 | if min_idx > 0:
193 | # last_sample.append(sample[:min_idx + 1])
194 | start ='{}{}'.format(start, sample[:min_idx + 1])
195 |
196 | if i == 1:
197 | rhyme_seq = min_idx - 1
198 | rhyme_ref_word = sample[rhyme_seq]
199 |
200 | print('last_sample text is:\n\n%s' % start)
201 |
202 | return WritePoem.assemble(start)
203 |
204 | def start_model():
205 | now = int(time.time())
206 | args = config_sample('--model_dir output_poem --length 16 --seed {}'.format(now))
207 | writer = WritePoem(args)
208 | return writer
209 |
210 | if __name__ == '__main__':
211 | writer = start_model()
212 |
--------------------------------------------------------------------------------