├── README.md
├── char_rnn_model.py
├── config_poem.py
├── data
│   └── poem
│       ├── poem_ids.txt
│       ├── poems_edge_split.txt
│       ├── rhyme_words.txt
│       └── vectors_poem.bin
├── data_loader.py
├── doc
│   ├── client.png
│   └── train.png
├── poem_server.py
├── rhyme_helper.py
├── train.py
├── word2vec_helper.py
└── write_poem.py
/README.md: -------------------------------------------------------------------------------- 1 | # poet 2 | I am "小诗姬" ("Little Poetess"), trained on the Complete Tang Poems. I can write rhyming free verse, acrostic poems, and poems themed on a few given characters. 3 |

4 | Requirements: 5 | --- 6 | python: 3.x
7 | tensorflow: 1.x 8 | 9 |

10 | Running
11 | --- 12 | Run training: python train.py
13 | --- 14 | ![image](https://github.com/norybaby/poet/blob/master/doc/train.png) 15 | 16 |
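Note that when `train.py` is run directly it ignores command-line arguments: its `__main__` block passes a hard-coded option string to `main()`, and the available flags are defined in `config_poem.py`. A minimal sketch of changing hyperparameters is therefore to edit that string or call `main()` yourself; the values below simply mirror the defaults used in the repository:

```python
# Sketch: invoking training programmatically with an explicit option string.
# Flag names are taken from config_poem.py; values mirror train.py's __main__.
from train import main

main('--output_dir output_poem --data_path ./data/poem/ '
     '--hidden_size 128 --embedding_size 128 --cell_type lstm')
```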

17 |   18 | 19 | Run the poem-writing service: python poem_server.py
20 | --- 21 | This step requires that training has been run first and a model has been produced. 22 | 23 | Access the service from a browser; the result looks like:
24 | --- 25 | ![image](https://github.com/norybaby/poet/blob/master/doc/client.png) 26 | 27 |
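As read from `poem_server.py`, the service exposes a single `/poem` route with two query parameters: `style` (1 = free verse, 2 = rhyming free verse, 3 = acrostic poem, 4 = poem containing the given characters) and `start` (the characters used by styles 3 and 4). Assuming Flask's default host and port (127.0.0.1:5000), a request can be made roughly like this; the theme characters are only illustrative:

```python
# Sketch: querying the running poem service (assumes Flask's default port 5000).
# style=3 asks for an acrostic poem; 'start' supplies the head characters.
from urllib.parse import quote
from urllib.request import urlopen

url = 'http://127.0.0.1:5000/poem?style=3&start=' + quote('春夏秋冬')
print(urlopen(url).read().decode('utf-8'))
```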

28 |   29 | 30 | 如有问题欢迎讨论 xiuyunchen@126.com 31 | 32 | -------------------------------------------------------------------------------- /char_rnn_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from enum import Enum 4 | import heapq 5 | import numpy as np 6 | import tensorflow as tf 7 | from rhyme_helper import RhymeWords 8 | 9 | logging.getLogger('tensorflow').setLevel(logging.WARNING) 10 | SampleType = Enum('SampleType',('max_prob', 'weighted_sample', 'rhyme','select_given')) 11 | 12 | class CharRNNLM(object): 13 | def __init__(self, is_training, batch_size, num_unrollings, vocab_size,w2v_model, 14 | hidden_size, max_grad_norm, embedding_size, num_layers, 15 | learning_rate, cell_type, dropout=0.0, input_dropout=0.0, infer=False): 16 | self.batch_size = batch_size 17 | self.num_unrollings = num_unrollings 18 | if infer: 19 | self.batch_size = 1 20 | self.num_unrollings = 1 21 | self.hidden_size = hidden_size 22 | self.vocab_size = vocab_size 23 | self.max_grad_norm = max_grad_norm 24 | self.num_layers = num_layers 25 | self.embedding_size = embedding_size 26 | self.cell_type = cell_type 27 | self.dropout = dropout 28 | self.input_dropout = input_dropout 29 | self.w2v_model = w2v_model 30 | 31 | if embedding_size <= 0: 32 | self.input_size = vocab_size 33 | self.input_dropout = 0.0 34 | else: 35 | self.input_size = embedding_size 36 | 37 | self.input_data = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='inputs') 38 | self.targets = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name='targets') 39 | 40 | if self.cell_type == 'rnn': 41 | cell_fn = tf.nn.rnn_cell.BasicRNNCell 42 | elif self.cell_type == 'lstm': 43 | cell_fn = tf.nn.rnn_cell.BasicLSTMCell 44 | elif self.cell_type == 'gru': 45 | cell_fn = tf.nn.rnn_cell.GRUCell 46 | 47 | params = dict() 48 | #params = {'input_size': self.input_size} 49 | if self.cell_type == 'lstm': 50 | params['forget_bias'] = 1.0 # 1.0 is default value 51 | cell = cell_fn(self.hidden_size, **params) 52 | 53 | cells = [cell] 54 | #params['input_size'] = self.hidden_size 55 | for i in range(self.num_layers-1): 56 | higher_layer_cell = cell_fn(self.hidden_size, **params) 57 | cells.append(higher_layer_cell) 58 | 59 | if is_training and self.dropout > 0: 60 | cells = [tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=1.0-self.dropout) for cell in cells] 61 | 62 | multi_cell = tf.nn.rnn_cell.MultiRNNCell(cells) 63 | 64 | with tf.name_scope('initial_state'): 65 | self.zero_state = multi_cell.zero_state(self.batch_size, tf.float32) 66 | if self.cell_type == 'rnn' or self.cell_type == 'gru': 67 | self.initial_state = tuple( 68 | [tf.placeholder(tf.float32, 69 | [self.batch_size, multi_cell.state_size[idx]], 70 | 'initial_state_'+str(idx+1)) for idx in range(self.num_layers)]) 71 | elif self.cell_type == 'lstm': 72 | self.initial_state = tuple( 73 | [tf.nn.rnn_cell.LSTMStateTuple( 74 | tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][0]], 75 | 'initial_lstm_state_'+str(idx+1)), 76 | tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][1]], 77 | 'initial_lstm_state_'+str(idx+1))) 78 | for idx in range(self.num_layers)]) 79 | 80 | with tf.name_scope('embedding_layer'): 81 | if embedding_size > 0: 82 | # self.embedding = tf.get_variable('embedding', [self.vocab_size, self.embedding_size]) 83 | self.embedding = tf.get_variable("word_embeddings", 84 | initializer=self.w2v_model.vectors.astype(np.float32)) 
85 | 86 | else: 87 | self.embedding = tf.constant(np.eye(self.vocab_size), dtype=tf.float32) 88 | 89 | inputs = tf.nn.embedding_lookup(self.embedding, self.input_data) 90 | if is_training and self.input_dropout > 0: 91 | inputs = tf.nn.dropout(inputs, 1-self.input_dropout) 92 | 93 | with tf.name_scope('slice_inputs'): 94 | # num_unrollings * (batch_size, embedding_size), the format of rnn inputs. 95 | sliced_inputs = [tf.squeeze(input_, [1]) for input_ in tf.split( 96 | axis = 1, num_or_size_splits = self.num_unrollings, value = inputs)] 97 | 98 | # sliced_inputs: list of shape xx 99 | # inputs: A length T list of inputs, each a Tensor of shape [batch_size, input_size] 100 | # initial_state: An initial state for the RNN. 101 | # If cell.state_size is an integer, this must be a Tensor of appropriate 102 | # type and shape [batch_size, cell.state_size] 103 | # outputs: a length T list of outputs (one for each input), or a nested tuple of such elements. 104 | # state: the final state 105 | outputs, final_state = tf.nn.static_rnn( 106 | cell = multi_cell, 107 | inputs = sliced_inputs, 108 | initial_state=self.initial_state) 109 | self.final_state = final_state 110 | 111 | with tf.name_scope('flatten_outputs'): 112 | flat_outputs = tf.reshape(tf.concat(axis = 1, values = outputs), [-1, hidden_size]) 113 | 114 | with tf.name_scope('flatten_targets'): 115 | flat_targets = tf.reshape(tf.concat(axis = 1, values = self.targets), [-1]) 116 | 117 | with tf.variable_scope('softmax') as sm_vs: 118 | softmax_w = tf.get_variable('softmax_w', [hidden_size, vocab_size]) 119 | softmax_b = tf.get_variable('softmax_b', [vocab_size]) 120 | self.logits = tf.matmul(flat_outputs, softmax_w) + softmax_b 121 | self.probs = tf.nn.softmax(self.logits) 122 | 123 | with tf.name_scope('loss'): 124 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 125 | logits = self.logits, labels = flat_targets) 126 | self.mean_loss = tf.reduce_mean(loss) 127 | 128 | with tf.name_scope('loss_montor'): 129 | count = tf.Variable(1.0, name='count') 130 | sum_mean_loss = tf.Variable(1.0, name='sum_mean_loss') 131 | 132 | self.reset_loss_monitor = tf.group(sum_mean_loss.assign(0.0), 133 | count.assign(0.0), name='reset_loss_monitor') 134 | self.update_loss_monitor = tf.group(sum_mean_loss.assign(sum_mean_loss+self.mean_loss), 135 | count.assign(count+1), name='update_loss_monitor') 136 | 137 | with tf.control_dependencies([self.update_loss_monitor]): 138 | self.average_loss = sum_mean_loss / count 139 | self.ppl = tf.exp(self.average_loss) 140 | 141 | average_loss_summary = tf.summary.scalar( 142 | name = 'average loss', tensor = self.average_loss) 143 | ppl_summary = tf.summary.scalar( 144 | name = 'perplexity', tensor = self.ppl) 145 | 146 | self.summaries = tf.summary.merge( 147 | inputs = [average_loss_summary, ppl_summary], name='loss_monitor') 148 | 149 | self.global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0.0)) 150 | 151 | # self.learning_rate = tf.constant(learning_rate) 152 | self.learning_rate = tf.placeholder(tf.float32, [], name='learning_rate') 153 | 154 | if is_training: 155 | tvars = tf.trainable_variables() 156 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.mean_loss, tvars), self.max_grad_norm) 157 | optimizer = tf.train.AdamOptimizer(self.learning_rate) 158 | self.train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step) 159 | 160 | 161 | def run_epoch(self, session, batch_generator, is_training, learning_rate, verbose=0, freq=10): 162 | 
epoch_size = batch_generator.num_batches 163 | 164 | if verbose > 0: 165 | logging.info('epoch_size: %d', epoch_size) 166 | logging.info('data_size: %d', batch_generator.seq_length) 167 | logging.info('num_unrollings: %d', self.num_unrollings) 168 | logging.info('batch_size: %d', self.batch_size) 169 | 170 | if is_training: 171 | extra_op = self.train_op 172 | else: 173 | extra_op = tf.no_op() 174 | 175 | 176 | if self.cell_type in ['rnn', 'gru']: 177 | state = self.zero_state.eval() 178 | else: 179 | state = tuple([(np.zeros((self.batch_size, self.hidden_size)), 180 | np.zeros((self.batch_size, self.hidden_size))) 181 | for _ in range(self.num_layers)]) 182 | 183 | self.reset_loss_monitor.run() 184 | batch_generator.reset_batch_pointer() 185 | start_time = time.time() 186 | ppl_cumsum = 0 187 | for step in range(epoch_size): 188 | x, y = batch_generator.next_batch() 189 | 190 | ops = [self.average_loss, self.ppl, self.final_state, extra_op, 191 | self.summaries, self.global_step] 192 | 193 | feed_dict = {self.input_data: x, self.targets: y, self.initial_state: state, 194 | self.learning_rate: learning_rate} 195 | 196 | results = session.run(ops, feed_dict) 197 | average_loss, ppl, final_state, _, summary_str, global_step = results 198 | ppl_cumsum += ppl 199 | 200 | # if (verbose > 0) and ((step+1) % freq == 0): 201 | if ((step+1) % freq == 0): 202 | logging.info('%.1f%%, step:%d, perplexity: %.3f, speed: %.0f words', 203 | (step + 1) * 1.0 / epoch_size * 100, step, ppl_cumsum/(step+1), 204 | (step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time)) 205 | logging.info("Perplexity: %.3f, speed: %.0f words per sec", 206 | ppl, (step + 1) * self.batch_size * self.num_unrollings / (time.time() - start_time)) 207 | 208 | return ppl, summary_str, global_step 209 | 210 | def sample_seq(self, session, length, start_text, sample_type= SampleType.max_prob,given='',rhyme_ref='',rhyme_idx = 0): 211 | #state = self.zero_state.eval() 212 | if self.cell_type in ['rnn', 'gru']: 213 | state = self.zero_state.eval() 214 | else: 215 | state = tuple([(np.zeros((self.batch_size, self.hidden_size)), 216 | np.zeros((self.batch_size, self.hidden_size))) 217 | for _ in range(self.num_layers)]) 218 | 219 | # use start_text to warm up the RNN. 
220 | start_text = self.check_start(start_text) 221 | if start_text is not None and len(start_text) > 0: 222 | seq = list(start_text) 223 | for char in start_text[:-1]: 224 | x = np.array([[self.w2v_model.vocab_hash[char]]]) 225 | state = session.run(self.final_state, {self.input_data: x, self.initial_state: state}) 226 | x = np.array([[self.w2v_model.vocab_hash[start_text[-1]]]]) 227 | else: 228 | x = np.array([[np.random.randint(0, self.vocab_size)]]) 229 | seq = [] 230 | 231 | for i in range(length): 232 | state, logits = session.run([self.final_state, self.logits], 233 | {self.input_data: x, self.initial_state: state}) 234 | unnormalized_probs = np.exp(logits[0] - np.max(logits[0])) 235 | probs = unnormalized_probs / np.sum(unnormalized_probs) 236 | 237 | if rhyme_ref and i == rhyme_idx : 238 | sample = self.select_rhyme(rhyme_ref,probs) 239 | elif sample_type == SampleType.max_prob: 240 | sample = np.argmax(probs) 241 | elif sample_type == SampleType.select_given: 242 | sample,given = self.select_by_given(given,probs) 243 | else: #SampleType.weighted_sample 244 | sample = np.random.choice(self.vocab_size, 1, p=probs)[0] 245 | 246 | seq.append(self.w2v_model.vocab[sample]) 247 | x = np.array([[sample]]) 248 | 249 | return ''.join(seq) 250 | 251 | def select_by_given(self,given,probs,max_prob = False): 252 | if given: 253 | seq_probs = zip(probs,range(0,self.vocab_size)) 254 | topn = heapq.nlargest(100,seq_probs,key=lambda sp :sp[0]) 255 | 256 | for _,seq in topn: 257 | if self.w2v_model.vocab[seq] in given: 258 | given = given.replace(self.w2v_model.vocab[seq],'') 259 | return seq,given 260 | if max_prob: 261 | return np.argmax(probs),given 262 | 263 | return np.random.choice(self.vocab_size, 1, p=probs)[0],given 264 | 265 | 266 | def select_rhyme(self,rhyme_ref,probs): 267 | if rhyme_ref: 268 | rhyme_set = RhymeWords.get_rhyme_words(rhyme_ref) 269 | if rhyme_set: 270 | seq_probs = zip(probs,range(0,self.vocab_size)) 271 | topn = heapq.nlargest(50,seq_probs,key=lambda sp :sp[0]) 272 | 273 | for _,seq in topn: 274 | if self.w2v_model.vocab[seq] in rhyme_set: 275 | return seq 276 | 277 | return np.argmax(probs) 278 | 279 | def check_start(self,text): 280 | idx = text.find('<') 281 | if idx > -1: 282 | text = text[:idx] 283 | 284 | valid_text = [] 285 | for w in text: 286 | if w in self.w2v_model.vocab: 287 | valid_text.append(w) 288 | return ''.join(valid_text) 289 | -------------------------------------------------------------------------------- /config_poem.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | def config_poem_train(args=''): 5 | parser = argparse.ArgumentParser() 6 | 7 | # Data and vocabulary file 8 | # parser.add_argument('--data_file', type=str, 9 | # default='../data/poem/poems_space.txt', 10 | # help='data file') 11 | 12 | parser.add_argument('--data_path', type=str, 13 | default='./data/poem/', 14 | help='data path') 15 | 16 | 17 | parser.add_argument('--encoding', type=str, 18 | default='utf-8', 19 | help='the encoding of the data file.') 20 | 21 | # Parameters for saving models. 22 | parser.add_argument('--output_dir', type=str, default='output_model', 23 | help=('directory to store final and' 24 | ' intermediate results and models.')) 25 | # Parameters for using saved best models. 26 | parser.add_argument('--init_dir', type=str, default='', 27 | help='continue from the outputs in the given directory') 28 | 29 | # Parameters to configure the neural network. 
30 | parser.add_argument('--hidden_size', type=int, default=128,#128, 31 | help='size of RNN hidden state vector') 32 | parser.add_argument('--embedding_size', type=int, default=128,#0, 33 | help='size of character embeddings, 0 for one-hot') 34 | parser.add_argument('--num_layers', type=int, default=2, 35 | help='number of layers in the RNN') 36 | parser.add_argument('--num_unrollings', type=int, default=64,#10, 37 | help='number of unrolling steps.') 38 | parser.add_argument('--cell_type', type=str, default='lstm', 39 | help='which model to use (rnn, lstm or gru).') 40 | 41 | # Parameters to control the training. 42 | parser.add_argument('--num_epochs', type=int, default=5, 43 | help='number of epochs') 44 | parser.add_argument('--batch_size', type=int, default=16, 45 | help='minibatch size') 46 | parser.add_argument('--train_frac', type=float, default=0.9, 47 | help='fraction of data used for training.') 48 | parser.add_argument('--valid_frac', type=float, default=0.05, 49 | help='fraction of data used for validation.') 50 | # test_frac is computed as (1 - train_frac - valid_frac). 51 | parser.add_argument('--dropout', type=float, default=0.0, 52 | help='dropout rate, default to 0 (no dropout).') 53 | 54 | parser.add_argument('--input_dropout', type=float, default=0.0, 55 | help=('dropout rate on input layer, default to 0 (no dropout),' 56 | 'and no dropout if using one-hot representation.')) 57 | 58 | # Parameters for gradient descent. 59 | parser.add_argument('--max_grad_norm', type=float, default=5., 60 | help='clip global grad norm') 61 | parser.add_argument('--learning_rate', type=float, default=5e-3, 62 | help='initial learning rate') 63 | 64 | # Parameters for logging. 65 | parser.add_argument('--progress_freq', type=int, default=100, 66 | help=('frequency for progress report in training and evalution.')) 67 | parser.add_argument('--verbose', type=int, default=0, 68 | help=('whether to show progress report in training and evalution.')) 69 | 70 | # Parameters to feed in the initial model and current best model. 71 | parser.add_argument('--init_model', type=str, 72 | default='', help=('initial model')) 73 | parser.add_argument('--best_model', type=str, 74 | default='', help=('current best model')) 75 | parser.add_argument('--best_valid_ppl', type=float, 76 | default=np.Inf, help=('current valid perplexity')) 77 | 78 | # # Parameters for using saved best models. 79 | # parser.add_argument('--model_dir', type=str, default='', 80 | # help='continue from the outputs in the given directory') 81 | 82 | # Parameters for debugging. 83 | parser.add_argument('--debug', dest='debug', action='store_true', 84 | help='show debug information') 85 | parser.set_defaults(debug=False) 86 | 87 | # Parameters for unittesting the implementation. 88 | parser.add_argument('--test', dest='test', action='store_true', 89 | help=('use the first 1000 character to as data to test the implementation')) 90 | parser.set_defaults(test=False) 91 | 92 | # input_args = '--data_path ./data/poem --output_dir output_poem --hidden_size 256 --embedding_size 128 --num_unrollings 128 --debug --encoding utf-8' 93 | args = parser.parse_args(args.split()) 94 | 95 | return args 96 | 97 | 98 | 99 | def config_sample(args=''): 100 | parser = argparse.ArgumentParser() 101 | 102 | # hyper-parameters for using saved best models. 
103 | # 学习日志和结果相关的超参数 104 | logging_args = parser.add_argument_group('Logging_Options') 105 | logging_args.add_argument('--model_dir', type=str, 106 | default='demo_model/', 107 | help='continue from the outputs in the given directory') 108 | 109 | logging_args.add_argument('--data_dir', type=str, 110 | default='./data/poem', 111 | help='data file path') 112 | 113 | logging_args.add_argument('--best_model', type=str, 114 | default='', help=('current best model')) 115 | 116 | # hyper-parameters for sampling. 117 | # 设置sampling相关的超参数 118 | testing_args = parser.add_argument_group('Sampling Options') 119 | testing_args.add_argument('--max_prob', dest='max_prob', action='store_true', 120 | help='always pick the most probable next character in sampling') 121 | testing_args.set_defaults(max_prob=False) 122 | 123 | testing_args.add_argument('--start_text', type=str, 124 | default='The meaning of life is ', 125 | help='the text to start with') 126 | 127 | testing_args.add_argument('--length', type=int, 128 | default=100, 129 | help='length of sampled sequence') 130 | 131 | testing_args.add_argument('--seed', type=int, 132 | default=-1, 133 | help=('seed for sampling to replicate results, ' 134 | 'an integer between 0 and 4294967295.')) 135 | 136 | args = parser.parse_args(args.split()) 137 | 138 | return args -------------------------------------------------------------------------------- /data/poem/poem_ids.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/norybaby/poet/b0f7b5dd7a31109995b921ad5269323e0fe794f8/data/poem/poem_ids.txt -------------------------------------------------------------------------------- /data/poem/rhyme_words.txt: -------------------------------------------------------------------------------- 1 | 爸 把 八 罢 坝 拔 霸 扒 靶 叭 吧 擦 差 茶 插 查 叉 察 刹 咤 衩 大 达 发 法 伐 罚 尬 挂 瓜 刮 寡 呱 哈 花 华 化 话 画 滑 划 猾 家 加 价 甲 佳 假 夹 颊 驾 嫁 卡 跨 垮 夸 啦 拉 蜡 辣 妈 吗 嘛 马 砝 码 麻 骂 那 霎 那 拿 怕 爬 恰 撒 洒 沙 纱 啥 傻 厦 她 踏 塌 塔 榻 哇 娃 下 夏 霞 暇 压 雅 哑 崖 涯 杂 砸 眨 炸 吒 爪 抓 佳 玛 瑕 蚂 咖 垃 2 | 3 | 4 | 爱 哀 皑 埃 摆 白 百 拜 败 才 彩 财 采 踩 猜 睬 柴 踹 揣 待 带 代 袋 呆 该 改 怪 乖 海 害 孩 怀 坏 徘 徊 开 慨 快 块 来 赖 赖 莱 买 卖 脉 迈 麦 埋 耐 奶 奈 派 牌 排 湃 赛 腮 塞 晒 帅 衰 甩 摔 台 太 态 泰 外 歪 在 载 再 灾 栽 仔 债 摘 窄 拽 暧 儿 二 耳 饵 而 尔 5 | 6 | 安 暗 岸 边 变 辨 滨 伴 烂 餐 惨 残 惭 璀 璨 缠 颤 馋 掺 忏 潺 传 川 穿 船 串 喘 单 淡 担 胆 诞 耽 蛋 短 段 断 端 缎 帆 凡 反 饭 翻 返 犯 繁 烦 干 感 甘 敢 赶 肝 关 管 冠 灌 罐 惯 官 观 含 汉 寒 喊 汗 函 瀚 旱 憨 酣 欢 唤 还 换 幻 患 缓 焕 看 坎 侃 砍 堪 槛 款 宽 蓝 栏 懒 澜 揽 泛 滥 兰 婪 橄 榄 乱 峦 卵 满 慢 漫 瞒 男 南 难 暖 盼 判 盘 攀 畔 叛 然 渲 染 燃 软 伞 三 散 酸 算 山 善 闪 扇 删 蹒 跚 衫 拴 滩 叹 谈 坛 坦 探 贪 碳 毯 瘫 团 万 玩 完 晚 碗 湾 弯 婉 腕 挽 战 占 站 展 站 盏 湛 绽 转 专 煎 点 电 殿 垫 甸 巅 间 肩 见 件 剑 建 健 减 尖 坚 箭 渐 剪 舰 贱 捡 煎 俭 艰 拣 茧 连 怜 脸 联 链 练 恋 炼 莲 廉 帘 面 眠 免 绵 棉 勉 年 念 粘 篇 片 骗 偏 翩 牵 前 钱 千 浅 迁 签 欠 倩 谦 潜 遣 歉 虔 纤 阡 天 田 甜 填 添 舔 先 现 仙 线 限 鲜 弦 贤 闲 险 献 陷 娴 掀 倦 卷 娟 绢 捐 卷 涓 圈 搀 涟 珊 全 权 泉 犬 圈 拳 劝 选 轩 玄 喧 宣 旋 炫 悬 眩 绚 漩 言 眼 烟 燕 严 盐 严 演 艳 炎 宴 雁 验 焰 掩 嫣 厌 衍 淹 原 园 圆 源 员 远 院 愿 缘 苑 怨 渊 冤 援 尴 忐 蝙 颠 乾 店 7 | 8 | 昂 帮 绑 藏 沧 苍 茫 唱 长 唱 场 畅 怅 创 伤 窗 闯 党 荡 挡 方 放 芳 港 岗 光 广 逛 行 航 彷 徨 黄 荒 晃 皇 慌 谎 凰 江 降 浆 疆 匠 将 康 扛 抗 狂 旷 郎 浪 狼 朗 琅 谅 量 梁 亮 良 凉 忙 茫 氓 囊 娘 旁 胖 强 墙 枪 抢 腔 让 嚷 壤 伤 上 尚 赏 双 爽 霜 糖 棠 唐 汤 塘 趟 烫 躺 淌 网 望 往 汪 忘 旺 妄 想 向 乡 香 翔 巷 祥 响 详 样 阳 样 扬 养 氧 仰 汪 洋 痒 鸳 鸯 荡 漾 脏 葬 章 掌 涨 账 仗 杖 障 胀 装 状 壮 妆 桩 恍 滂 9 | 10 | 熬 骄 傲 奥 凹 报 宝 包 保 堡 爆 抱 暴 薄 饱 豹 曝 标 表 草 操 槽 超 炒 巢 潮 抄 吵 钞 朝 到 道 道 岛 刀 盗 倒 稻 捣 蹈 唠 叨 悼 高 搞 告 稿 糕 好 号 浩 豪 耗 嚎 毫 叫 教 交 脚 靠 烤 拷 老 劳 牢 捞 烙 酪 涝 了 聊 料 疗 毛 猫 茂 帽 冒 貌 卯 矛 髦 锚 苗 秒 妙 庙 描 瞄 淼 缈 脑 闹 挠 恼 瑙 孬 鸟 尿 袅 跑 泡 抛 炮 袍 刨 飘 漂 票 嫖 巧 桥 敲 瞧 侨 翘 俏 窍 撬 悄 峭 跷 锹 绕 饶 扰 饶 扫 嫂 骚 少 烧 稍 绍 勺 梢 哨 捎 艄 套 涛 桃 讨 逃 淘 掏 滔 嚎 啕 跳 条 调 挑 眺 迢 窕 小 笑 晓 肖 萧 校 效 消 孝 销 咆 哮 霄 削 宵 要 腰 耀 妖 遥 窑 谣 摇 咬 药 肴 耀 吆 幺 早 造 遭 枣 灶 燥 
皂 糟 藻 凿 噪 澡 跳 蚤 找 照 招 着 兆 朝 罩 憔 狡 陶 11 | 12 | 层 曾 村 寸 存 忖 臣 晨 尘 称 辰 沉 趁 秤 成 城 称 呈 乘 诚 承 橙 澄 撑 骋 惩 春 纯 唇 醇 淳 蠢 等 灯 登 蹬 瞪 凳 吨 顿 盾 炖 敦 墩 蹲 钝 盹 恩 本 奔 笨 嗯 摁 苯 嫩 恁 恁 能 13 | 14 | 林 令 灵 另 零 临 邻 岭 凌 领 龄 铃 陵 淋 拎 琳 顶 定 丁 鼎 订 钉 盯 锭 叮 腚 均 军 君 俊 菌 峻 金 经 近 进 仅 今 精 景 静 尽 京 净 井 紧 锦 敬 镜 巾 劲 境 晶 津 斤 惊 禁 警 径 筋 浸 睛 憬 兢 荆 缤 冰 病 滨 鬓 宾 兵 并 饼 柄 丙 濒 玲 伶 淋 15 | 16 | 17 | 文 问 温 闻 纹 稳 吻 蚊 瘟 紊 门 闷 焖 们 汶 冷 肯 垦 恳 啃 坑 吭 坤 困 捆 冷 愣 棱 论 轮 昆 仑 抡 分 份 粉 纷 奋 粪 焚 愤 氛 跟 根 艮 更 耿 耕 庚 羹 埂 滚 棍 很 恨 痕 狠 恒 横 哼 衡 混 魂 婚 昏 馄 浑 荤 人 仁 任 认 忍 韧 刃 饪 润 森 僧 孙 损 损 孙 神 身 深 伸 渗 参 生 升 声 胜 圣 盛 绳 剩 吞 臀 豚 腾 藤 疼 怎 增 赠 憎 尊 遵 镇 真 振 针 珍 震 阵 帧 贞 枕 诊 疹 斟 正 证 政 整 争 症 征 蒸 挣 铮 怔 筝 睁 诤 准 芬 18 | 19 | 名 明 民 命 敏 鸣 铭 抿 您 宁 凝 拧 咛 喷 盆 品 平 评 凭 瓶 萍 屏 拼 坪 聘 贫 频 群 请 清 青 情 轻 庆 琴 亲 晴 勤 卿 倾 禽 侵 寝 擒 引擎 顷 听 厅 亭 停 挺 庭 廷 艇 蜻 蜓 心 新 型 性 行 信 星 兴 形 欣 姓 刑 馨 幸 醒 薪 腥 惺 寻 讯 训 迅 勋 巡 循 询 熏 旬 汛 驯 音 应 因 英 银 印 影 营 引 阴 硬 饮 隐 迎 鹰 赢 淫 殷 颖 盈 映 吟 荫 茵 婴 瘾 莺 缨 蝇 萦 姻 韵 云 运 晕 芸 允 匀 孕 蕴 熨 耘 酝 20 | 21 | 22 | 23 | 蹦 绷 甭 嘣 从 聪 丛 匆 重 冲 虫 充 崇 宠 忡 东 动 洞 冬 栋 懂 冻 咚 风 峰 封 凤 锋 奉 逢 缝 疯 蜂 讽 工 公 共 供 宫 功 攻 弓 贡 恭 拱 躬 红 洪 宏 鸿 弘 哄 轰 虹 泓 烘 讧 炯 窘 空 孔 控 恐 龙 隆 笼 聋 拢 珑 梦 蒙 猛 盟 萌 朦 农 弄 浓 侬 脓 鹏 碰 捧 棚 朋 蓬 砰 膨 怦 穷 琼 穹 荣 容 融 溶 绒 熔 送 松 颂 诵 耸 讼 通 痛 同 统 童 桶 筒 桐 捅 翁 嗡 熊 雄 兄 胸 凶 汹 用 永 勇 拥 涌 咏 庸 泳 蛹 总 宗 纵 综 踪 中 钟 重 种 众 忠 终 肿 踵 盅 憧 澎 烹 碰 24 | 25 | 不 步 部 布 补 捕 簿 怖 醋 促 粗 出 除 楚 初 处 储 触 厨 橱 雏 础 矗 搐 躇 度 都 读 独 毒 渡 赌 堵 肚 镀 督 嘟 督 妒 睹 渎 副 福 富 夫 复 服 附 付 负 府 赴 伏 父 扶 符 傅 浮 腹 辅 覆 赋 抚 妇 肤 腐 拂 斧 咐 腑 蝠 孵 俘 袱 匍 匐 扶 故 股 古 谷 顾 骨 固 鼓 孤 姑 菇 估 雇 咕 胡 户 湖 虎 乎 护 呼 互 壶 忽 狐 糊 弧 唬 惚 瑚 苦 库 哭 酷 裤 枯 窟 录 路 鲁 陆 炉 鹿 禄 鹭 虏 颅 噜 赂 木 亩 目 母 姆 墓 牧 幕 慕 沐 睦 怒 奴 努 弩 驽 呶 谱 铺 普 朴 扑 瀑 仆 铺 如 入 汝 乳 儒 辱 褥 苏 素 速 宿 诉 塑 俗 酥 溯 数 书 属 树 术 舒 输 熟 叔 鼠 蜀 束 述 淑 竖 梳 殊 恕 暑 孰 墅 漱 赎 图 土 涂 兔 吐 突 徒 途 屠 凸 秃 荼 涂 无 五 吴 物 武 舞 屋 吾 务 乌 勿 雾 伍 悟 误 午 巫 污 唔 呜 恶 捂 忤 侮 圬 组 租 足 族 祖 阻 卒 主 注 著 注 朱 住 猪 诸 竹 驻 祝 柱 助 煮 逐 筑 铸 烛 嘱 拄 蛛 敷 26 | 27 | 28 | 29 | 次 此 词 刺 赐 慈 辞 瓷 雌 疵 吃 持 池 赤 尺 迟 驰 齿 痴 翅 耻 弛 地 低 底 帝 滴 敌 蒂 抵 逐 弟 堤 递 笛 谛 及 即 级 机 集 基 记 吉 计 极 几 继 鸡 急 剂 季 纪 积 极 技 忌 击 姬 寄 籍 济 肌 疾 挤 祭 际 迹 己 激 辑 寂 绩 脊 妓 饥 叽 棘 讥 圾 嫉 笈 据 局 具 居 距 聚 句 举 咀 巨 俱 菊 剧 拒 锯 惧 桔 鞠 矩 拘 炬 驹 理 里 李 力 利 立 离 丽 例 礼 莉 黎 历 厉 梨 栗 沥 励 隶 厘 犁 俐 漓 狸 率 绿 铝 律 驴 旅 履 虑 滤 缕 侣 米 迷 密 秘 蜜 弥 谜 觅 咪 靡 眯 泌 你 昵 拟 泥 逆 腻 溺 女 皮 批 匹 脾 劈 屁 披 坯 辟 疲 啤 癖 僻 其 起 期 七 气 器 齐 奇 企 旗 启 漆 骑 妻 弃 棋 汽 欺 泣 乞 戚 契 歧 杞 去 区 取 曲 趣 娶 屈 驱 渠 躯 蛆 趋 蹊 30 | 31 | 日 四 斯 思 死 司 寺 丝 似 私 撕 肆 嘶 是 时 市 使 式 室 石 事 十 师 实 诗 世 史 氏 食 视 士 施 试 饰 始 失 识 湿 势 示 释 适 仕 狮 尸 拾 侍 逝 誓 蚀 驶 虱 拭 自 子 字 紫 资 仔 姿 滋 咨 渍 制 值 只 至 止 之 指 志 直 知 智 治 支 纸 质 致 置 枝 职 址 执 汁 植 芝 脂 织 滞 旨 掷 肢 趾 峙 稚 秩 秩 挚 帜 蜘 32 | 33 | 体 题 提 替 踢 梯 蹄 啼 锑 剃 涕 剔 惕 屉 嚏 西 系 喜 希 细 溪 洗 戏 吸 息 兮 席 熙 习 悉 稀 惜 夕 袭 析 隙 嘻 析 昔 嬉 熄 晰 需 须 许 虚 需 续 旭 序 恤 叙 绪 蓄 胥 絮 嘘 墟 栩 煦 酗 一 以 已 亿 易 亦 意 依 义 宜 伊 乙 仪 衣 益 矣 艺 异 已 译 医 逸 翼 怡 移 毅 议 疑 忆 倚 椅 遗 溢 奕 役 抑 谊 姨 疫 蚁 裔 弈 漪 奕 与 于 玉 鱼 余 於 雨 欲 语 宇 羽 遇 域 预 愈 育 予 御 豫 郁 浴 喻 愚 誉 娱 狱 屿 淤 愉 吁 谀 彼 比 碧 壁 闭 蔽 驿 逼 笔 币 必 毕 避 鼻 弊 毙 哔 鄙 34 | 35 | 36 | 波 博 薄 播 勃 搏 剥 泊 脖 渤 胳 膊 跛 萝 卜 恶 饿 测 册 侧 策 车 撤 扯 彻 澈 戳 龌 龊 多 朵 夺 躲 堕 剁 舵 惰 得 德 碟 蝶 笛 叠 爹 谍 跌 佛 国 过 果 郭 锅 裹 蝈 帼 个 各 格 歌 哥 阁 戈 隔 革 割 鸽 嗝 或 获 火 活 货 霍 祸 惑 伙 豁 攉 和 合 何 河 喝 盒 贺 核 荷 鹤 赫 禾 褐 呵 吓 接 街 节 杰 解 界 结 借 姐 洁 戒 捷 介 劫 阶 揭 截 竭 诫 觉 绝 决 抉 诀 爵 掘 嚼 倔 撅 崛 可 克 科 客 课 克 科 颗 壳 刻 咳 棵 渴 磕 苛 刻 扩 阔 括 廓 罗 落 洛 络 螺 裸 锣 骡 菠 萝 了 乐 勒 列 烈 裂 劣 猎 咧 冽 洌 略 掠 莫 魔 末 膜 墨 模 摩 磨 默 摸 么 抹 沫 漠 陌 馍 灭 诺 挪 懦 捏 孽 破 坡 颇 婆 迫 泼 魄 珀 陂 泊 撇 且 切 妾 窃 怯 却 缺 确 阙 雀 鹊 阕 瘸 若 弱 热 惹 缩 锁 所 索 梭 琐 娑 啰 哆 嗦 婆 娑 色 塞 瑟 涩 啬 说 硕 烁 设 社 摄 舍 蛇 射 折 舌 涉 赦 奢 慑 赊 涉 托 脱 拖 拓 陀 妥 骆 驼 驮 沱 唾 特 忑 蝌 做 作 座 左 坐 佐 昨 则 泽 责 择 仄 啧 卓 桌 着 捉 啄 浊 拙 灼 酌 这 者 着 折 哲 遮 辄 蔗 辙 37 | 38 | 贴 铁 我 窝 沃 卧 握 涡 些 血 写 谢 鞋 协 邪 斜 携 蟹 歇 泄 泻 卸 屑 械 挟 谐 蝎 偕 懈 胁 学 雪 血 穴 靴 削 也 业 页 叶 夜 液 野 耶 爷 冶 椰 噎 月 约 曰 乐 越 粤 岳 跃 悦 阅 亵 别 憋 鳖 瘪 邂 39 | 40 | 41 | 凑 愁 抽 丑 臭 仇 筹 酬 稠 绸 瞅 踌 畴 都 斗 豆 抖 逗 兜 陡 痘 蚪 丢 狗 购 沟 够 钩 勾 构 垢 后 侯 厚 候 猴 吼 喉 逅 就 旧 九 酒 久 救 纠 揪 
鸠 舅 疚 口 扣 寇 抠 叩 蔻 楼 露 漏 喽 搂 篓 陋 六 流 留 柳 瘤 溜 榴 遛 某 谋 眸 牛 扭 纽 钮 妞 拗 忸 剖 球 求 秋 邱 丘 裘 囚 鳅 肉 柔 揉 搜 艘 馊 擞 手 受 收 首 守 售 寿 兽 授 瘦 熟 头 投 透 偷 修 秀 休 绣 袖 羞 嗅 朽 有 游 由 又 友 油 右 优 尤 邮 幽 佑 幼 忧 悠 诱 走 奏 揍 周 州 洲 轴 舟 粥 咒 皱 肘 骤 昼 柚 宙 帚 惆 犹 42 | 43 | 北 备 倍 杯 背 贝 碑 悲 辈 卑 惫 狈 翠 催 脆 萃 摧 粹 悴 对 队 堆 飞 非 费 菲 肥 废 肺 妃 匪 霏 啡 给 归 贵 鬼 桂 规 柜 龟 轨 跪 诡 玫 瑰 黑 嘿 会 灰 回 慧 惠 辉 汇 绘 徽 毁 晖 挥 悔 卉 讳 亏 魁 窥 葵 溃 馈 盔 愧 睽 类 雷 泪 累 蕾 磊 类 垒 擂 梅 每 没 美 枚 煤 妹 眉 酶 媚 媒 霉 昧 内 配 培 陪 佩 赔 沛 呸 瑞 锐 岁 随 虽 遂 碎 穗 隋 髓 退 推 腿 褪 为 位 未 卫 维 伟 威 微 味 魏 尾 喂 谓 围 胃 伪 危 纬 畏 违 贼 最 嘴 罪 醉 追 坠 锥 缀 委 贿 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /data/poem/vectors_poem.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/norybaby/poet/b0f7b5dd7a31109995b921ad5269323e0fe794f8/data/poem/vectors_poem.bin -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | #import collections 3 | from six.moves import cPickle 4 | import numpy as np 5 | from word2vec_helper import Word2Vec 6 | import math 7 | 8 | 9 | 10 | class DataLoader(): 11 | def __init__(self, data_dir, batch_size,seq_max_length,w2v,data_type): 12 | self.data_dir = data_dir 13 | self.batch_size = batch_size 14 | self.seq_max_length = seq_max_length 15 | self.w2v = w2v 16 | self.trainingSamples = [] 17 | self.validationSamples = [] 18 | self.testingSamples = [] 19 | self.train_frac = 0.85 20 | self.valid_frac = 0.05 21 | 22 | self.load_corpus(self.data_dir) 23 | 24 | if data_type == 'train': 25 | self.create_batches(self.trainingSamples) 26 | elif data_type == 'test': 27 | self.create_batches(self.testingSamples) 28 | elif data_type == 'valid': 29 | self.create_batches(self.validationSamples) 30 | 31 | self.reset_batch_pointer() 32 | 33 | def _print_stats(self): 34 | print('Loaded {}: training samples:{} ,validationSamples:{},testingSamples:{}'.format( 35 | self.data_dir, len(self.trainingSamples),len(self.validationSamples),len(self.testingSamples))) 36 | 37 | def load_corpus(self,base_path): 38 | """读/创建 对话数据: 39 | 在训练文件创建的过程中,由两个文件 40 | 1. self.fullSamplePath 41 | 2. 
self.filteredSamplesPath 42 | """ 43 | tensor_file = os.path.join(base_path,'poem_ids.txt') 44 | print('tensor_file:%s' % tensor_file) 45 | 46 | datasetExist = os.path.isfile(tensor_file) 47 | # 如果处理过的对话数据文件不存在,创建数据文件 48 | if not datasetExist: 49 | print('训练样本不存在。从原始样本数据集创建训练样本...') 50 | 51 | fullSamplesPath = os.path.join(self.data_dir,'poems_edge_split.txt') 52 | # 创建/读取原始对话样本数据集: self.trainingSamples 53 | print('fullSamplesPath:%s' % fullSamplesPath) 54 | self.load_from_text_file(fullSamplesPath) 55 | 56 | else: 57 | self.load_dataset(tensor_file) 58 | 59 | self.padToken = self.w2v.ix('') 60 | self.goToken = self.w2v.ix('[') 61 | self.eosToken = self.w2v.ix(']') 62 | self.unknownToken = self.w2v.ix('') 63 | 64 | self._print_stats() 65 | # assert self.padToken == 0 66 | 67 | def load_from_text_file(self,in_file): 68 | # base_path = 'F:\BaiduYunDownload\chatbot_lecture\lecture2\data\ice_and_fire_zh' 69 | # in_file = os.path.join(base_path,'poems_edge.txt') 70 | fr = open(in_file, "r",encoding='utf-8') 71 | poems = fr.readlines() 72 | fr.close() 73 | 74 | print("唐诗总数: %d"%len(poems)) 75 | # self.seq_max_length = max([len(poem) for poem in poems]) 76 | # print("seq_max_length: %d"% (self.seq_max_length)) 77 | 78 | poem_ids = DataLoader.get_text_idx(poems,self.w2v.vocab_hash,self.seq_max_length) 79 | 80 | # # 后续处理 81 | # # 1. 单词过滤,去掉不常见(<=filterVocab)的单词,保留最常见的vocabSize个单词 82 | # print('Filtering words (vocabSize = {} and wordCount > {})...'.format( 83 | # self.args.vocabularySize, 84 | # self.args.filterVocab 85 | # )) 86 | # self.filterFromFull() 87 | 88 | # 2. 分割数据 89 | print('分割数据为 train, valid, test 数据集...') 90 | n_samples = len(poem_ids) 91 | train_size = int(self.train_frac * n_samples) 92 | valid_size = int(self.valid_frac * n_samples) 93 | test_size = n_samples - train_size - valid_size 94 | 95 | print('n_samples=%d, train-size=%d, valid_size=%d, test_size=%d' % ( 96 | n_samples, train_size, valid_size, test_size)) 97 | self.testingSamples = poem_ids[-test_size:] 98 | self.validationSamples = poem_ids[-valid_size-test_size : -test_size] 99 | self.trainingSamples = poem_ids[:train_size] 100 | 101 | # 保存处理过的训练数据集 102 | print('Saving dataset...') 103 | poem_ids_file = os.path.join(self.data_dir,'poem_ids.txt') 104 | self.save_dataset(poem_ids_file) 105 | 106 | # 2. utility 函数,使用pickle写文件 107 | def save_dataset(self, filename): 108 | """使用pickle保存数据文件。 109 | 110 | 数据文件包含词典和对话样本。 111 | 112 | Args: 113 | filename (str): pickle 文件名 114 | """ 115 | with open(filename, 'wb') as handle: 116 | data = { 117 | 'trainingSamples': self.trainingSamples 118 | } 119 | 120 | if len(self.validationSamples)>0: 121 | data['validationSamples'] = self.validationSamples 122 | data['testingSamples'] = self.testingSamples 123 | data['maxSeqLen'] = self.seq_max_length 124 | 125 | cPickle.dump(data, handle, -1) # Using the highest protocol available 126 | 127 | # 3. 
utility 函数,使用pickle读文件 128 | def load_dataset(self, filename): 129 | """使用pickle读入数据文件 130 | Args: 131 | filename (str): pickle filename 132 | """ 133 | 134 | print('Loading dataset from {}'.format(filename)) 135 | with open(filename, 'rb') as handle: 136 | data = cPickle.load(handle) 137 | self.trainingSamples = data['trainingSamples'] 138 | 139 | if 'validationSamples' in data: 140 | self.validationSamples = data['validationSamples'] 141 | self.testingSamples = data['testingSamples'] 142 | 143 | print('file maxSeqLen = {}'.format( data['maxSeqLen'])) 144 | 145 | 146 | @classmethod 147 | def get_text_idx(text,vocab,max_document_length): 148 | text_array = [] 149 | for i,x in enumerate(text): 150 | line = [] 151 | for j, w in enumerate(x): 152 | if (w not in vocab): 153 | w = '' 154 | line.append(vocab[w]) 155 | text_array.append(line) 156 | # else : 157 | # print w,'not exist' 158 | 159 | return text_array 160 | 161 | def create_batches(self,samples): 162 | 163 | sample_size = len(samples) 164 | self.num_batches = math.ceil(sample_size /self.batch_size) 165 | new_sample_size = self.num_batches * self.batch_size 166 | 167 | # Create the batch tensor 168 | # x_lengths = [len(sample) for sample in samples] 169 | 170 | x_lengths = [] 171 | x_seqs = np.ndarray((new_sample_size,self.seq_max_length),dtype=np.int32) 172 | y_seqs = np.ndarray((new_sample_size,self.seq_max_length),dtype=np.int32) 173 | self.x_lengths = [] 174 | for i,sample in enumerate(samples): 175 | # fill with padding to align batchSize samples into one 2D list 176 | x_lengths.append(len(sample)) 177 | x_seqs[i] = sample + [self.padToken] * (self.seq_max_length - len(sample)) 178 | 179 | for i in range(sample_size,new_sample_size): 180 | copyi = i - sample_size 181 | x_seqs[i] = x_seqs[copyi] 182 | x_lengths.append(x_lengths[copyi]) 183 | 184 | y_seqs[:,:-1] = x_seqs[:,1:] 185 | y_seqs[:,-1] = x_seqs[:,0] 186 | x_len_array = np.array(x_lengths) 187 | 188 | 189 | 190 | self.x_batches = np.split(x_seqs.reshape(self.batch_size, -1), self.num_batches, 1) 191 | self.x_len_batches = np.split(x_len_array.reshape(self.batch_size, -1), self.num_batches, 1) 192 | self.y_batches = np.split(y_seqs.reshape(self.batch_size, -1), self.num_batches, 1) 193 | 194 | def next_batch_dynamic(self): 195 | x,x_len, y = self.x_batches[self.pointer], self.x_len_batches[self.pointer],self.y_batches[self.pointer] 196 | self.pointer += 1 197 | return x,x_len, y 198 | 199 | def next_batch(self): 200 | x, y = self.x_batches[self.pointer], self.y_batches[self.pointer] 201 | self.pointer += 1 202 | return x,y 203 | 204 | def reset_batch_pointer(self): 205 | self.pointer = 0 206 | 207 | @staticmethod 208 | def get_text_idx(text,vocab,max_document_length): 209 | max_document_length_without_end = max_document_length - 1 210 | text_array = [] 211 | for i,x in enumerate(text): 212 | line = [] 213 | if len(x) > max_document_length: 214 | x_parts = x[:max_document_length_without_end] 215 | idx = x_parts.rfind('。') 216 | if idx > -1 : 217 | x_parts = x_parts[0:idx + 1] + ']' 218 | x = x_parts 219 | 220 | for j, w in enumerate(x): 221 | # if j >= max_document_length: 222 | # break 223 | 224 | if (w not in vocab): 225 | w = '' 226 | line.append(vocab[w]) 227 | text_array.append(line) 228 | # else : 229 | # print w,'not exist' 230 | 231 | return text_array 232 | 233 | if __name__ == '__main__': 234 | base_path = './data/poem' 235 | # poem = '风急云轻鹤背寒,洞天谁道却归难。千山万水瀛洲路,何处烟飞是醮坛。是的' 236 | # idx = poem.rfind('。') 237 | # poem_part = poem[:idx + 1] 238 | w2v_file = 
os.path.join(base_path, "vectors_poem.bin") 239 | w2v = Word2Vec(w2v_file) 240 | 241 | # vect = w2v_model['['][:10] 242 | # print(vect) 243 | # 244 | # vect = w2v_model['春'][:10] 245 | # print(vect) 246 | 247 | in_file = os.path.join(base_path,'poems_edge.txt') 248 | # fr = open(in_file, "r",encoding='utf-8') 249 | # poems = fr.readlines() 250 | # fr.close() 251 | # 252 | # 253 | # 254 | # print("唐诗总数: %d"%len(poems)) 255 | # 256 | # poem_ids = get_text_idx(poems,w2v.model.vocab_hash,100) 257 | # poem_ids_file = os.path.join(base_path,'poem_ids.txt') 258 | # with open(poem_ids_file, 'wb') as f: 259 | # cPickle.dump(poem_ids, f) 260 | 261 | dataloader = DataLoader(base_path,20,w2v.model,'train') 262 | 263 | -------------------------------------------------------------------------------- /doc/client.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/norybaby/poet/b0f7b5dd7a31109995b921ad5269323e0fe794f8/doc/client.png -------------------------------------------------------------------------------- /doc/train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/norybaby/poet/b0f7b5dd7a31109995b921ad5269323e0fe794f8/doc/train.png -------------------------------------------------------------------------------- /poem_server.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import os 3 | from flask import Flask,request 4 | from write_poem import WritePoem,start_model 5 | 6 | app = Flask(__name__) 7 | application = app 8 | 9 | path = os.getcwd() #获取当前工作目录 10 | print(path) 11 | 12 | writer = start_model() 13 | 14 | # @app.route('/') 15 | # def test(title): 16 | # return 'test ok' 17 | 18 | sytle_help = '
para style : 1:自由诗
2:带押韵的自由诗
3:藏头诗
4:给定若干字,以最大概率生成诗' 19 | @app.route('/poem') 20 | def write_poem(): 21 | params = request.args 22 | start_with= '' 23 | poem_style = 0 24 | 25 | # print(params) 26 | if 'start' in params : 27 | start_with = params['start'] 28 | if 'style' in params: 29 | poem_style = int(params['style']) 30 | 31 | # return 'hello' 32 | if start_with: 33 | if poem_style == 3: 34 | return writer.cangtou(start_with) 35 | elif poem_style == 4: 36 | return writer.hide_words(start_with) 37 | 38 | if poem_style == 1: 39 | return writer.free_verse() 40 | elif poem_style == 2: 41 | return writer.rhyme_verse() 42 | 43 | return 'hello,what do you want? {}'.format(sytle_help) 44 | 45 | 46 | if __name__ == "__main__": 47 | app.run() -------------------------------------------------------------------------------- /rhyme_helper.py: -------------------------------------------------------------------------------- 1 | 2 | class RhymeWords(): 3 | rhyme_list = [] 4 | 5 | @staticmethod 6 | def read_rhyme_words(infile): 7 | with open(infile,'r',encoding='utf-8',errors='ignore') as fr: 8 | for line in fr: 9 | words = set(line.split()) 10 | RhymeWords.rhyme_list.append(words) 11 | 12 | @staticmethod 13 | def get_rhyme_words(w): 14 | for words in RhymeWords.rhyme_list: 15 | if w in words: 16 | return words 17 | return None 18 | 19 | @staticmethod 20 | def print_stats(): 21 | count = 0 22 | for words in RhymeWords.rhyme_list: 23 | count += len(words) 24 | print(words) 25 | 26 | for w in words: 27 | if len(w) > 1: 28 | print(w) 29 | 30 | print('count = ',count) 31 | 32 | if __name__ == '__main__': 33 | infile = './data/poem/rhyme_words.txt' 34 | RhymeWords.read_rhyme_words(infile) 35 | RhymeWords.print_stats() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import json 3 | import logging 4 | import os 5 | import shutil 6 | import sys 7 | import time 8 | import numpy as np 9 | import tensorflow as tf 10 | from char_rnn_model import CharRNNLM 11 | from config_poem import config_poem_train 12 | from data_loader import DataLoader 13 | from word2vec_helper import Word2Vec 14 | TF_VERSION = int(tf.__version__.split('.')[1]) 15 | 16 | 17 | def main(args=''): 18 | args = config_poem_train(args) 19 | # Specifying location to store model, best model and tensorboard log. 20 | args.save_model = os.path.join(args.output_dir, 'save_model/model') 21 | args.save_best_model = os.path.join(args.output_dir, 'best_model/model') 22 | # args.tb_log_dir = os.path.join(args.output_dir, 'tensorboard_log/') 23 | timestamp = str(int(time.time())) 24 | args.tb_log_dir = os.path.abspath(os.path.join(args.output_dir, "tensorboard_log", timestamp)) 25 | print("Writing to {}\n".format(args.tb_log_dir)) 26 | 27 | # Create necessary directories. 
28 | if len(args.init_dir) != 0: 29 | args.output_dir = args.init_dir 30 | else: 31 | if os.path.exists(args.output_dir): 32 | shutil.rmtree(args.output_dir) 33 | for paths in [args.save_model, args.save_best_model, args.tb_log_dir]: 34 | os.makedirs(os.path.dirname(paths)) 35 | 36 | logging.basicConfig(stream=sys.stdout, 37 | format='%(asctime)s %(levelname)s:%(message)s', 38 | level=logging.INFO, datefmt='%I:%M:%S') 39 | 40 | print('=' * 60) 41 | print('All final and intermediate outputs will be stored in %s/' % args.output_dir) 42 | print('=' * 60 + '\n') 43 | 44 | logging.info('args are:\n%s', args) 45 | 46 | if len(args.init_dir) != 0: 47 | with open(os.path.join(args.init_dir, 'result.json'), 'r') as f: 48 | result = json.load(f) 49 | params = result['params'] 50 | args.init_model = result['latest_model'] 51 | best_model = result['best_model'] 52 | best_valid_ppl = result['best_valid_ppl'] 53 | if 'encoding' in result: 54 | args.encoding = result['encoding'] 55 | else: 56 | args.encoding = 'utf-8' 57 | 58 | else: 59 | params = {'batch_size': args.batch_size, 60 | 'num_unrollings': args.num_unrollings, 61 | 'hidden_size': args.hidden_size, 62 | 'max_grad_norm': args.max_grad_norm, 63 | 'embedding_size': args.embedding_size, 64 | 'num_layers': args.num_layers, 65 | 'learning_rate': args.learning_rate, 66 | 'cell_type': args.cell_type, 67 | 'dropout': args.dropout, 68 | 'input_dropout': args.input_dropout} 69 | best_model = '' 70 | logging.info('Parameters are:\n%s\n', json.dumps(params, sort_keys=True, indent=4)) 71 | 72 | # Create batch generators. 73 | batch_size = params['batch_size'] 74 | num_unrollings = params['num_unrollings'] 75 | 76 | base_path = args.data_path 77 | w2v_file = os.path.join(base_path, "vectors_poem.bin") 78 | w2v = Word2Vec(w2v_file) 79 | 80 | train_data_loader = DataLoader(base_path,batch_size,num_unrollings,w2v.model,'train') 81 | test1_data_loader = DataLoader(base_path,batch_size,num_unrollings,w2v.model,'test') 82 | valid_data_loader = DataLoader(base_path,batch_size,num_unrollings,w2v.model,'valid') 83 | 84 | # Create graphs 85 | logging.info('Creating graph') 86 | graph = tf.Graph() 87 | with graph.as_default(): 88 | w2v_vocab_size = len(w2v.model.vocab) 89 | with tf.name_scope('training'): 90 | train_model = CharRNNLM(is_training=True,w2v_model = w2v.model,vocab_size=w2v_vocab_size, infer=False, **params) 91 | tf.get_variable_scope().reuse_variables() 92 | 93 | with tf.name_scope('validation'): 94 | valid_model = CharRNNLM(is_training=False,w2v_model = w2v.model, vocab_size=w2v_vocab_size, infer=False, **params) 95 | 96 | with tf.name_scope('evaluation'): 97 | test_model = CharRNNLM(is_training=False,w2v_model = w2v.model,vocab_size=w2v_vocab_size, infer=False, **params) 98 | saver = tf.train.Saver(name='model_saver') 99 | best_model_saver = tf.train.Saver(name='best_model_saver') 100 | 101 | logging.info('Start training\n') 102 | 103 | result = {} 104 | result['params'] = params 105 | 106 | 107 | try: 108 | with tf.Session(graph=graph) as session: 109 | # Version 8 changed the api of summary writer to use 110 | # graph instead of graph_def. 
111 | if TF_VERSION >= 8: 112 | graph_info = session.graph 113 | else: 114 | graph_info = session.graph_def 115 | 116 | train_summary_dir = os.path.join(args.tb_log_dir, "summaries", "train") 117 | train_writer = tf.summary.FileWriter(train_summary_dir, graph_info) 118 | valid_summary_dir = os.path.join(args.tb_log_dir, "summaries", "valid") 119 | valid_writer = tf.summary.FileWriter(valid_summary_dir, graph_info) 120 | 121 | # load a saved model or start from random initialization. 122 | if len(args.init_model) != 0: 123 | saver.restore(session, args.init_model) 124 | else: 125 | tf.global_variables_initializer().run() 126 | 127 | learning_rate = args.learning_rate 128 | for epoch in range(args.num_epochs): 129 | logging.info('=' * 19 + ' Epoch %d ' + '=' * 19 + '\n', epoch) 130 | logging.info('Training on training set') 131 | # training step 132 | ppl, train_summary_str, global_step = train_model.run_epoch(session, train_data_loader, is_training=True, 133 | learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq) 134 | # record the summary 135 | train_writer.add_summary(train_summary_str, global_step) 136 | train_writer.flush() 137 | # save model 138 | saved_path = saver.save(session, args.save_model, 139 | global_step=train_model.global_step) 140 | 141 | logging.info('Latest model saved in %s\n', saved_path) 142 | logging.info('Evaluate on validation set') 143 | 144 | valid_ppl, valid_summary_str, _ = valid_model.run_epoch(session, valid_data_loader, is_training=False, 145 | learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq) 146 | 147 | # save and update best model 148 | if (len(best_model) == 0) or (valid_ppl < best_valid_ppl): 149 | best_model = best_model_saver.save(session, args.save_best_model, 150 | global_step=train_model.global_step) 151 | best_valid_ppl = valid_ppl 152 | else: 153 | learning_rate /= 2.0 154 | logging.info('Decay the learning rate: ' + str(learning_rate)) 155 | 156 | valid_writer.add_summary(valid_summary_str, global_step) 157 | valid_writer.flush() 158 | 159 | logging.info('Best model is saved in %s', best_model) 160 | logging.info('Best validation ppl is %f\n', best_valid_ppl) 161 | 162 | result['latest_model'] = saved_path 163 | result['best_model'] = best_model 164 | # Convert to float because numpy.float is not json serializable. 
165 | result['best_valid_ppl'] = float(best_valid_ppl) 166 | 167 | result_path = os.path.join(args.output_dir, 'result.json') 168 | if os.path.exists(result_path): 169 | os.remove(result_path) 170 | with open(result_path, 'w') as f: 171 | json.dump(result, f, indent=2, sort_keys=True) 172 | 173 | logging.info('Latest model is saved in %s', saved_path) 174 | logging.info('Best model is saved in %s', best_model) 175 | logging.info('Best validation ppl is %f\n', best_valid_ppl) 176 | 177 | logging.info('Evaluate the best model on test set') 178 | saver.restore(session, best_model) 179 | test_ppl, _, _ = test_model.run_epoch(session, test1_data_loader, is_training=False, 180 | learning_rate=learning_rate, verbose=args.verbose, freq=args.progress_freq) 181 | result['test_ppl'] = float(test_ppl) 182 | except Exception as e: 183 | print('err :{}'.format(e)) 184 | finally: 185 | result_path = os.path.join(args.output_dir, 'result.json') 186 | if os.path.exists(result_path): 187 | os.remove(result_path) 188 | with open(result_path, 'w',encoding='utf-8',errors='ignore') as f: 189 | json.dump(result, f, indent=2, sort_keys=True) 190 | 191 | 192 | if __name__ == '__main__': 193 | args = '--output_dir output_poem --data_path ./data/poem/ --hidden_size 128 --embedding_size 128 --cell_type lstm' 194 | main(args) 195 | -------------------------------------------------------------------------------- /word2vec_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import word2vec 3 | 4 | class Word2Vec(): 5 | def __init__(self,file_path): 6 | # w2v_file = os.path.join(base_path, "vectors_poem.bin") 7 | self.model = word2vec.load(file_path) 8 | self.add_word('') 9 | self.add_word('') 10 | # self.vocab_size = len(self.model.vocab) 11 | 12 | def add_word(self,word): 13 | if word not in self.model.vocab_hash: 14 | w_vec = np.random.uniform(-0.1,0.1,size=128) 15 | self.model.vocab_hash[word] = len(self.model.vocab) 16 | self.model.vectors = np.row_stack((self.model.vectors,w_vec)) 17 | self.model.vocab = np.concatenate((self.model.vocab,np.array([word]))) 18 | 19 | # vocab = np.empty(1, dtype='= 0: 39 | np.random.seed(args.seed) 40 | 41 | logging.info('best_model: %s\n', best_model) 42 | 43 | self.sess = tf.Session() 44 | w2v_vocab_size = len(self.w2v.model.vocab) 45 | with tf.name_scope('evaluation'): 46 | self.model = CharRNNLM(is_training=False,w2v_model = self.w2v.model,vocab_size=w2v_vocab_size, infer=True, **params) 47 | saver = tf.train.Saver(name='model_saver') 48 | saver.restore(self.sess, best_model) 49 | 50 | def free_verse(self): 51 | ''' 52 | 自由诗 53 | Returns: 54 | 55 | ''' 56 | sample = self.model.sample_seq(self.sess, 40, '[',sample_type= SampleType.weighted_sample) 57 | if not sample: 58 | return 'err occar!' 
59 | 60 | print('free_verse:',sample) 61 | 62 | idx_end = sample.find(']') 63 | parts = sample.split('。') 64 | if len(parts) > 1: 65 | two_sentence_len = len(parts[0]) + len(parts[1]) 66 | if idx_end < 0 or two_sentence_len < idx_end: 67 | return sample[1:two_sentence_len + 2] 68 | 69 | return sample[1:idx_end] 70 | 71 | @staticmethod 72 | def assemble(sample): 73 | if sample: 74 | parts = sample.split('。') 75 | if len(parts) > 1: 76 | return '{}。{}。'.format(parts[0][1:],parts[1][:len(parts[0])]) 77 | 78 | return '' 79 | 80 | 81 | def rhyme_verse(self): 82 | ''' 83 | 押韵诗 84 | Returns: 85 | 86 | ''' 87 | gen_len = 20 88 | sample = self.model.sample_seq(self.sess, gen_len, start_text='[',sample_type= SampleType.weighted_sample) 89 | if not sample: 90 | return 'err occar!' 91 | 92 | print('rhyme_verse:',sample) 93 | 94 | parts = sample.split('。') 95 | if len(parts) > 0: 96 | start = parts[0] + '。' 97 | rhyme_ref_word = start[-2] 98 | rhyme_seq = len(start) - 3 99 | 100 | sample = self.model.sample_seq(self.sess, gen_len , start, 101 | sample_type= SampleType.weighted_sample,rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq ) 102 | print(sample) 103 | return WritePoem.assemble(sample) 104 | 105 | return sample[1:] 106 | 107 | def hide_words(self,given_text): 108 | ''' 109 | 藏字诗 110 | Args: 111 | given_text: 112 | 113 | Returns: 114 | 115 | ''' 116 | if(not given_text): 117 | return self.rhyme_verse() 118 | 119 | givens = ['',''] 120 | split_len = math.ceil(len(given_text)/2) 121 | givens[0] = given_text[:split_len] 122 | givens[1] = given_text[split_len:] 123 | 124 | gen_len = 20 125 | sample = self.model.sample_seq(self.sess, gen_len, start_text='[',sample_type= SampleType.select_given,given=givens[0]) 126 | if not sample: 127 | return 'err occar!' 128 | 129 | print('rhyme_verse:',sample) 130 | 131 | parts = sample.split('。') 132 | if len(parts) > 0: 133 | start = parts[0] + '。' 134 | rhyme_ref_word = start[-2] 135 | rhyme_seq = len(start) - 3 136 | # gen_len = len(start) - 1 137 | 138 | sample = self.model.sample_seq(self.sess, gen_len , start, 139 | sample_type= SampleType.select_given,given=givens[1],rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq ) 140 | print(sample) 141 | return WritePoem.assemble(sample) 142 | 143 | return sample[1:] 144 | 145 | def cangtou(self,given_text): 146 | ''' 147 | 藏头诗 148 | Returns: 149 | 150 | ''' 151 | if(not given_text): 152 | return self.rhyme_verse() 153 | 154 | start = '' 155 | rhyme_ref_word = '' 156 | rhyme_seq = 0 157 | 158 | # for i,word in enumerate(given_text): 159 | for i in range(4): 160 | word = '' 161 | if i < len(given_text): 162 | word = given_text[i] 163 | 164 | if i == 0: 165 | start = '[' + word 166 | else: 167 | start += word 168 | 169 | before_idx = len(start) 170 | if(i != 3): 171 | sample = self.model.sample_seq(self.sess, self.args.length, start, 172 | sample_type= SampleType.weighted_sample ) 173 | 174 | else: 175 | if not word: 176 | rhyme_seq += 1 177 | 178 | sample = self.model.sample_seq(self.sess, self.args.length, start, 179 | sample_type= SampleType.max_prob,rhyme_ref =rhyme_ref_word,rhyme_idx = rhyme_seq ) 180 | 181 | print('Sampled text is:\n\n%s' % sample) 182 | 183 | sample = sample[before_idx:] 184 | idx1 = sample.find(',') 185 | idx2 = sample.find('。') 186 | min_idx = min(idx1,idx2) 187 | 188 | if min_idx == -1: 189 | if idx1 > -1 : 190 | min_idx = idx1 191 | else: min_idx =idx2 192 | if min_idx > 0: 193 | # last_sample.append(sample[:min_idx + 1]) 194 | start ='{}{}'.format(start, sample[:min_idx + 1]) 195 | 196 | if i 
== 1: 197 | rhyme_seq = min_idx - 1 198 | rhyme_ref_word = sample[rhyme_seq] 199 | 200 | print('last_sample text is:\n\n%s' % start) 201 | 202 | return WritePoem.assemble(start) 203 | 204 | def start_model(): 205 | now = int(time.time()) 206 | args = config_sample('--model_dir output_poem --length 16 --seed {}'.format(now)) 207 | writer = WritePoem(args) 208 | return writer 209 | 210 | if __name__ == '__main__': 211 | writer = start_model() 212 | --------------------------------------------------------------------------------