├── LICENSE
├── README.md
├── dlt2.xls
├── dlt4all.py
├── dlt4all_test.py
├── hyper_parameters.py
├── poems
│   ├── model.py
│   ├── poems.py
│   └── resnet.py
├── ssq.py
├── ssq.xls
├── ssq4all.py
├── ssq4all_test.py
├── ssq4all_test_v4.py
├── ssq4all_v4.py
├── ssq_data.py
└── ssq_test.py


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Yanchao Liang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# How to Predict China's Union Lotto (双色球) with a Neural Network and LSTM
## Update 2020/12/4
Many people have asked about TensorFlow 2, so a [branch](https://github.com/Liang-yc/ssq/tree/tf2) adapted for TensorFlow 2.x is now available.
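
## Update 2019/6/18
Predictions can now use the results and prize-pool status of the most recent several draws. For training and testing, see `ssq4all_v4.py` and `ssq4all_test_v4.py`.
## Update 2018/9/16
Added training and test files for the Super Lotto (大乐透).
## Update 2018/5/30
Fixed several bugs reported by users and removed two files. If you have a question, please open an issue.
## Update 2018/3/29
Tried predicting with a CNN plus an LSTM, using a ResNet as the CNN feature extractor. The best result so far is a fifth prize (4 red numbers correct).
## Background
------
This project predicts the Union Lotto by combining a neural network with Long Short-Term Memory (LSTM). The Union Lotto itself is not described here (see this [website](https://baike.baidu.com/item/中国福利彩票双色球/8676030?fr=aladdin&fromid=75279&fromtitle=%E5%8F%8C%E8%89%B2%E7%90%83)). The project takes real Union Lotto draw results as input (7 values: the first 6 are the red numbers and the last is the blue number) and outputs a prediction in the same format (7 values: 6 red numbers followed by 1 blue number). The project is still under development.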
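A 7-value record is only meaningful if it follows the game's rules: six distinct red numbers from 1-33, then one blue number from 1-16. The following minimal sketch checks a record before it is fed to a model. The function name `is_valid_ssq_draw` is hypothetical and not part of this repo, and the example draws are made up.

```python
from typing import Sequence

RED_RANGE = range(1, 34)   # red balls are numbered 1-33
BLUE_RANGE = range(1, 17)  # the blue ball is numbered 1-16


def is_valid_ssq_draw(draw: Sequence[int]) -> bool:
    """Return True if `draw` is six distinct reds (1-33) followed by one blue (1-16)."""
    if len(draw) != 7:
        return False
    reds, blue = draw[:6], draw[6]
    return (
        len(set(reds)) == 6                      # no repeated red number
        and all(r in RED_RANGE for r in reds)
        and blue in BLUE_RANGE
    )


print(is_valid_ssq_draw([3, 7, 12, 19, 25, 33, 9]))  # True
print(is_valid_ssq_draw([3, 7, 12, 19, 25, 34, 9]))  # False: 34 is not a red number
```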

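The core code is largely adapted from this [repository](https://github.com/jinfagang/tensorflow_poems).
So far, after full training, the model gets 0-2 red numbers right (most often 1, but that outcome is likely even by chance, so it is not a good result) and 0-1 blue numbers right (most often 0). A Union Lotto prize requires either the correct blue number or at least 4 correct red numbers, so **I do not recommend buying tickets directly from these predictions**.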
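To see why getting one red number right is unremarkable, consider a ticket whose 6 red numbers are chosen uniformly at random. The number of them that appear among the 6 drawn reds follows a hypergeometric distribution. The sketch below models only a purely random guess, not this project's model, and needs Python 3.8+ for `math.comb`. Under these assumptions, exactly one correct red (about 44%) is the most likely outcome, and the expected count is 36/33.

```python
from math import comb

N_RED, N_PICK = 33, 6  # 6 of 33 red balls are drawn, and a ticket picks 6 reds


def red_match_distribution():
    """Return {k: P(exactly k picked reds are among the drawn reds)} for k = 0..6."""
    total = comb(N_RED, N_PICK)
    return {
        k: comb(N_PICK, k) * comb(N_RED - N_PICK, N_PICK - k) / total
        for k in range(N_PICK + 1)
    }


dist = red_match_distribution()
for k, p in dist.items():
    print(f"{k} red correct: {p:.4f}")
expected = sum(k * p for k, p in dist.items())
print(f"expected red matches: {expected:.4f}")  # 36/33, about 1.0909
print(f"P(blue correct): {1 / 16:.4f}")          # 0.0625
```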
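There are several possible reasons for the poor results:
1. The Union Lotto is inherently random, yet I am trying to find patterns in it;
2. The dataset is too small and cannot be meaningfully enlarged;
3. My model may not be good enough.

If you have suggestions or questions, email me (yc_liang@qq.com) or leave a message in an issue.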
If you find this project even slightly interesting, please give it a STAR.
And if you actually win a prize, I don't ask for much: just a STAR.
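The first reason listed above can be quantified. The sketch below computes exactly how often a uniformly random ticket wins anything, along with the jackpot odds. It assumes the standard Union Lotto prize table, in which the blue number alone wins a prize and a ticket without it needs at least 4 correct reds. The helper names `ways` and `wins` are hypothetical.

```python
from fractions import Fraction
from math import comb

TOTAL = comb(33, 6) * 16  # every (6-red set, blue) ticket: 17,721,088


def ways(reds_hit: int, blue_hit: bool) -> int:
    """Count tickets matching exactly `reds_hit` drawn reds, with or without the blue."""
    red_ways = comb(6, reds_hit) * comb(27, 6 - reds_hit)
    return red_ways * (1 if blue_hit else 15)


def wins(reds_hit: int, blue_hit: bool) -> bool:
    """Assumed prize rule: the blue alone wins something; without it, 4+ reds are needed."""
    return blue_hit or reds_hit >= 4


winning = sum(ways(r, b) for r in range(7) for b in (True, False) if wins(r, b))
p_any = Fraction(winning, TOTAL)
p_jackpot = Fraction(ways(6, True), TOTAL)
print(f"P(any prize) = {float(p_any):.4f}")           # 0.0671
print(f"P(jackpot)   = 1 / {p_jackpot.denominator}")  # 1 / 17721088
```

Exact `Fraction` arithmetic avoids rounding error until the final print.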

## Requirements
-----

1. TensorFlow;
2. pyexcel_xls;
3. CUDA (optional).

The code may also need a few Python packages that are not listed here (for example `win_unicode_console`, which `dlt4all.py` imports). Install them with pip as needed.

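
## File Contents
-----

1. `poems`: contains 3 files. `model.py` defines the model used by this project, `resnet.py` defines the ResNet model, and `poems.py` holds the poem-processing helpers (`process_poems`, `generate_batch`) carried over from tensorflow_poems;
2. `ssq.py`: trains the Union Lotto model;
3. `ssq.xls`: stores the historical Union Lotto draw and prize data; it must be updated with a macro;
4. `ssq_data.py`: reads the data in `ssq.xls` and converts it into the required format;
5. `ssq_test.py`: takes the most recent Union Lotto draw as input and outputs the model's prediction;
6. `ssq4all.py`: also trains a Union Lotto model (same purpose as `ssq.py`, but with a different model);
7. `ssq4all_test.py`: takes the most recent Union Lotto draw as input and outputs the model's prediction;
8. `dlt.xls`: stores the historical Super Lotto (大乐透) draw and prize data; it must be updated with a macro;
9. `dlt4all.py`: trains the Super Lotto model;
10. `dlt4all_test.py`: takes the most recent Super Lotto draw as input and outputs the model's prediction.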
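`dlt4all_test.py` turns the model's argmax class indices back into ball numbers by adding the offsets `[1, 1, 1, 1, 1, -34, -34]`, and training uses `vocab_size=35+12`. This suggests one shared vocabulary of 47 classes: front numbers 1-35 map to indices 0-34, and back numbers 1-12 map to 35-46 (`rnn_model` adds one extra logit on top). The sketch below writes out that mapping. The encoder is an assumption, inferred by inverting the decoding step, because `ssq_data.py` is not shown here. The function names are hypothetical.

```python
import numpy as np

# Offsets copied from dlt4all_test.py: class index + offset = ball number.
DECODE_OFFSET = np.array([1, 1, 1, 1, 1, -34, -34])


def encode_dlt(draw):
    """Map 5 front numbers (1-35) and 2 back numbers (1-12) to class indices 0-46."""
    return np.asarray(draw) - DECODE_OFFSET


def decode_dlt(class_indices):
    """Inverse of encode_dlt, as applied to the argmax output in dlt4all_test.py."""
    return np.asarray(class_indices) + DECODE_OFFSET


idx = encode_dlt([1, 8, 15, 22, 35, 3, 12])
print(idx)              # [ 0  7 14 21 34 37 46]
print(decode_dlt(idx))  # [ 1  8 15 22 35  3 12]
```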
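
## How To Use
-----
Set the parameters and run `ssq.py`. Running it from a console is recommended: you can then press CTRL+C partway through to stop training and save the result, and the next run resumes from the last checkpoint. After training, a `model` folder is created in the project directory to hold the model. Then run `ssq_test.py` to make a prediction.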
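The training scripts treat prediction as "next draw given this draw". In `dlt4all.py`, the inputs are every draw except the last, the targets are every draw except the first, and inputs are fed with shape `[batch, 1, 7, 1]`. Batches are taken `ceil(N / batch_size)` times, and a short final batch is replaced by the last `batch_size` rows. The sketch below reproduces that logic in plain NumPy. As a simplifying assumption, it uses one `(N, 7)` array for both inputs and targets; the repo instead calls `get_dlt_data` twice, with `use_resnet=True` and `False`. The function names are hypothetical.

```python
import math

import numpy as np


def make_pairs(draws):
    """Pair draw t (input) with draw t+1 (target) from an (N, 7) chronological array."""
    draws = np.asarray(draws)
    inputs = draws[:-1].reshape(-1, 1, 7, 1).astype(np.float32)  # NHWC, like input_data
    targets = draws[1:].astype(np.int32)
    return inputs, targets


def iter_batches(inputs, targets, batch_size):
    """Yield ceil(N / batch_size) equal-size batches; assumes N >= batch_size."""
    n = len(inputs)
    for b in range(math.ceil(n / batch_size)):
        start = b * batch_size
        if start + batch_size > n:  # short tail: reuse the last full window
            start = n - batch_size
        yield inputs[start:start + batch_size], targets[start:start + batch_size]


draws = np.arange(1, 71).reshape(10, 7)  # 10 fake draws, only to show the shapes
x, y = make_pairs(draws)
print(x.shape, y.shape)                                          # (9, 1, 7, 1) (9, 7)
print([len(bx) for bx, _ in iter_batches(x, y, batch_size=4)])  # [4, 4, 4]
```

With `batch_size` set to the number of samples, as `dlt4all.py` does, this reduces to a single full batch.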
The project runs on Windows 10 with TensorFlow 1.4/1.5, with or without a GPU. Because neither the model nor the dataset is large, all training data can be fed as a single batch. If the model is made deeper later, the batch size may need to be reduced.

--------------------------------------------------------------------------------
/dlt2.xls:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liang-yc/ssq/4c69745d1fb731b4a39cc571af58a31368439789/dlt2.xls

--------------------------------------------------------------------------------
/dlt4all.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: main.py
# author: JinTian
# time: 11/03/2017 9:53 AM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import math  # math.ceil is used in run_training; imported explicitly rather than relying on `from ssq_data import *`
import os
import numpy as np
import tensorflow as tf
from poems.model import rnn_model
from poems.resnet import *
from poems.poems import process_poems, generate_batch
from ssq_data import *
# for Windows10:OSError: raw write() returned invalid length 96 (should have been between 0 and 48)
import win_unicode_console
win_unicode_console.enable()
tf.app.flags.DEFINE_integer('batch_size', 2214, 'batch size.')
tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'learning rate.')
tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./dlt_model'), 'model save path.')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
tf.app.flags.DEFINE_integer('epochs', 500000, 'train how many epochs.')

FLAGS = tf.app.flags.FLAGS


def run_training():
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)
    #
    # poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    # batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
    # ssqdata=get_exl_data(random_order=False,use_resnet=True)
    # # print(ssqdata[len(ssqdata)-1])
    # batches_inputs=ssqdata[0:(len(ssqdata)-1)]
    # ssqdata=get_exl_data(random_order=False,use_resnet=False)
    # inputs: every draw except the last; targets: every draw except the first (next-draw prediction)
    ssqdata=get_dlt_data(random_order=False,use_resnet=True)
    # print(ssqdata[len(ssqdata)-1])
    batches_inputs=ssqdata[0:(len(ssqdata)-1)]
    ssqdata=get_dlt_data(random_order=False,use_resnet=False)
    batches_outputs = ssqdata[1:(len(ssqdata))]
    FLAGS.batch_size=len(batches_inputs)  # train on the whole dataset as one batch
    # print(np.shape(batches_outputs))
    # data=batches_outputs[1:7]
    # print(len(data))
    del ssqdata
    input_data = tf.placeholder(tf.float32, [FLAGS.batch_size, 1,7,1])
    logits = inference(input_data, 1, reuse=False,output_num=128)

    # print(tf.shape(input_data))
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=35+12,output_num=7,
                           rnn_size=128, num_layers=7, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)
    # end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
    #     vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)

        start_epoch = 0
        # saver.restore(sess, "D:/tensorflow_poems-master/model4all/poems-207261")
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                # n_chunk = len(poems_vector) // FLAGS.batch_size
                # n_chunk = len(batches_inputs) // FLAGS.batch_size
                n_chunk=math.ceil(len(batches_inputs) / FLAGS.batch_size)
                for batch in range(n_chunk):
                    left=(batch+1)*FLAGS.batch_size-len(batches_inputs)
                    if left<0:
                        inputdata=batches_inputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)]
                        outputdata=batches_outputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)]
                    else:
                        # short final batch: use the last batch_size samples instead
                        # temp=batches_inputs[batch*FLAGS.batch_size:len(batches_inputs) ]
                        # temp.extend(batches_inputs[0:left])
                        inputdata=batches_inputs[len(batches_inputs)-FLAGS.batch_size:len(batches_inputs)]
                        # temp=batches_outputs[batch*FLAGS.batch_size:len(batches_inputs) ]
                        # temp.extend(batches_outputs[0:left])
                        outputdata=batches_outputs[len(batches_outputs)-FLAGS.batch_size:len(batches_outputs)]
                    # print(len(inputdata))
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: inputdata, output_targets: outputdata})
                    # ], feed_dict={input_data: batches_inputs, output_targets: batches_outputs})
                    n += 1
                    if epoch % 1000 == 0:
                        print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                if epoch % 50000 == 0:
                    saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupt manually, try saving checkpoint for now...')
        finally:
            saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
            print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch))


def main(_):
    run_training()


if __name__ == '__main__':
    tf.app.run()

--------------------------------------------------------------------------------
/dlt4all_test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: main.py
# author: JinTian
# time: 11/03/2017 9:53 AM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import tensorflow as tf
from poems.model import rnn_model
from poems.resnet import *
from poems.poems import process_poems
import numpy as np
from ssq_data import *
start_token = 'B'
end_token = 'E'
model_dir = './model4all/'
corpus_file = './data/poems.txt'


def to_word(predict, vocabs):
    # sample an index in proportion to the predicted probabilities
    t = np.cumsum(predict)
    s = np.sum(predict)
    sample = int(np.searchsorted(t, np.random.rand(1) * s))
    if sample >= len(vocabs):  # `>=`: index len(vocabs) would be out of range
        sample = len(vocabs) - 1
    return vocabs[sample]


def gen_poem():
    batch_size = 1
    print('## loading model from %s' % model_dir)
    input_data = tf.placeholder(tf.float32, [1, 1,7,1])
    logits = inference(input_data, 1, reuse=False,output_num=128)

    # print(tf.shape(input_data))
    output_targets = tf.placeholder(tf.int32, [1, None])
    end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=35+12,
                           output_num=7,
                           rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)

    # input_data = tf.placeholder(tf.int32, [batch_size, None])
    #
    # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33,
    #                        rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        checkpoint = tf.train.latest_checkpoint('./dlt_model/')
        saver.restore(sess, checkpoint)
        # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4red/poems-150000")
        ssqdata = get_dlt_data(random_order=False,use_resnet=True)
        # x = np.array([list(map(word_int_map.get, start_token))])
        x=[ssqdata[len(ssqdata)-1]]  # the most recent draw
        print("input: %s"%(x+np.asarray([[[[1],[1],[1],[1],[1],[-34],[-34]]]])))
        [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         feed_dict={input_data: x})
        poem_=np.argmax(np.array(predict),axis=1)
        sorted_result = np.argsort(np.array(predict), axis=1)
        # class index -> ball number: front numbers are index+1, back numbers are index-34
        results=poem_+np.asarray([1,1,1,1,1,-34,-34])
        # print(sorted_result)
        print("output: %s"%results)
        return poem_


def gen_blue():
    batch_size = 1
    print('## loading model from %s' % model_dir)
    input_data = tf.placeholder(tf.float32, [1, 1,7,1])
    logits = inference(input_data, 10, reuse=False,output_num=128)

    # print(tf.shape(input_data))
    output_targets = tf.placeholder(tf.int32, [1, None])
    end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33,
                           output_num=1,
                           rnn_size=128, num_layers=3, batch_size=1, learning_rate=0.01)

    # input_data = tf.placeholder(tf.int32, [batch_size, None])
    #
    # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33,
    #                        rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        checkpoint = tf.train.latest_checkpoint('./model4all/')
        saver.restore(sess, checkpoint)
        # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4blue/poems-84201")
        ssqdata = get_exl_data(random_order=False,use_resnet=True)
        # x = np.array([list(map(word_int_map.get, start_token))])
        x=[ssqdata[len(ssqdata)-1]]
        print("input: %s"%(x))
        [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         feed_dict={input_data: x})
        poem_=np.argmax(np.array(predict),axis=1)
        sorted_result=np.argsort(np.array(predict),axis=1)
        results=poem_+np.ones(1)
        print(results)
        print(sorted_result)
        return poem_


if __name__ == '__main__':
    # begin_char = input('## please input the first character:')
    # poem=gen_blue()#13
    poem = gen_poem()
    # pretty_print_poem(poem_=poem)
--------------------------------------------------------------------------------
/hyper_parameters.py:
--------------------------------------------------------------------------------
# Coder: Wenxin Xu
# Github: https://github.com/wenxinxu/resnet_in_tensorflow
# ==============================================================================
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

## The following flags are related to save paths, tensorboard outputs and screen outputs

tf.app.flags.DEFINE_string('version', 'test_110', '''A version number defining the directory to save
logs and checkpoints''')
tf.app.flags.DEFINE_integer('report_freq', 391, '''Steps takes to output errors on the screen
and write summaries''')
tf.app.flags.DEFINE_float('train_ema_decay', 0.95, '''The decay factor of the train error's
moving average shown on tensorboard''')


## The following flags define hyper-parameters regards training

tf.app.flags.DEFINE_integer('train_steps', 80000, '''Total steps that you want to train''')
tf.app.flags.DEFINE_boolean('is_full_validation', False, '''Validation w/ full validation set or
a random batch''')
tf.app.flags.DEFINE_integer('train_batch_size', 128, '''Train batch size''')
tf.app.flags.DEFINE_integer('validation_batch_size', 250, '''Validation batch size, better to be
a divisor of 10000 for this task''')
tf.app.flags.DEFINE_integer('test_batch_size', 125, '''Test batch size''')

tf.app.flags.DEFINE_float('init_lr', 0.1, '''Initial learning rate''')
tf.app.flags.DEFINE_float('lr_decay_factor', 0.1, '''How much to decay the learning rate each
time''')
tf.app.flags.DEFINE_integer('decay_step0', 40000, '''At which step to decay the learning rate''')
tf.app.flags.DEFINE_integer('decay_step1', 60000, '''At which step to decay the learning rate''')


## The following flags define hyper-parameters modifying the training network

tf.app.flags.DEFINE_integer('num_residual_blocks', 5, '''How many residual blocks do you want''')
tf.app.flags.DEFINE_float('weight_decay', 0.0002, '''scale for l2 regularization''')


## The following flags are related to data-augmentation

tf.app.flags.DEFINE_integer('padding_size', 2, '''In data augmentation, layers of zero padding on
each side of the image''')


## If you want to load a checkpoint and continue training

tf.app.flags.DEFINE_string('ckpt_path', 'cache/logs_repeat20/model.ckpt-100000', '''Checkpoint
directory to restore''')
tf.app.flags.DEFINE_boolean('is_use_ckpt', False, '''Whether to load a checkpoint and continue
training''')

tf.app.flags.DEFINE_string('test_ckpt_path', 'model_110.ckpt-79999', '''Checkpoint
directory to restore''')


train_dir = 'logs_' + FLAGS.version + '/'

--------------------------------------------------------------------------------
/poems/model.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: model.py
# author: JinTian
# time: 07/03/2017 3:07 PM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import tensorflow as tf
import numpy as np


def rnn_model(model, input_data, output_data, vocab_size, rnn_size=128, num_layers=2, batch_size=64,
              input_num=128,output_num=1,learning_rate=0.01,use_cnn=True):
    """
    construct rnn seq2seq model.
    :param model: cell type: 'rnn', 'gru' or 'lstm'
    :param input_data: input data placeholder (ResNet features when use_cnn=True)
    :param output_data: target placeholder of integer class indices, or None at inference time
    :param vocab_size: number of classes; the output layer has vocab_size + 1 logits
    :param rnn_size: number of hidden units per RNN cell
    :param num_layers: number of stacked RNN layers
    :param batch_size: batch size used for the initial state and the input reshape
    :param input_num: width of each input feature vector
    :param output_num: number of values predicted per sample
    :param learning_rate: Adam learning rate
    :param use_cnn: feed input_data through fully connected layers instead of an embedding lookup
    :return: dict of end points (prediction and last_state; plus loss and train_op when training)
    """
    end_points = {}


    if model == 'rnn':
        cell_fun = tf.contrib.rnn.BasicRNNCell
        cell = cell_fun(rnn_size)
    elif model == 'gru':
        cell_fun = tf.contrib.rnn.GRUCell
        cell = cell_fun(rnn_size)
    elif model == 'lstm':
        cell_fun = tf.contrib.rnn.BasicLSTMCell
        cell = cell_fun(rnn_size, state_is_tuple=True)

    cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

    if output_data is not None:
        initial_state = cell.zero_state(batch_size, tf.float32)
    else:
        initial_state = cell.zero_state(1, tf.float32)
    if use_cnn:
        with tf.name_scope('fc1'):
            fc1_weights = tf.Variable(  # fully connected, depth 128.
                tf.truncated_normal([input_num, 128],
                                    # mean=1.0,
                                    stddev=0.1,
                                    dtype=tf.float32))
            fc1_biases = tf.Variable(tf.constant(1.0, shape=[128], dtype=tf.float32))
            fc1 = tf.nn.relu(tf.matmul(tf.to_float(input_data), fc1_weights) + fc1_biases)
        with tf.name_scope('fc2'):
            fc2_weights = tf.Variable(  # fully connected, depth 256.
                tf.truncated_normal([128, 256],
                                    # mean=1.0,
                                    stddev=0.1,
                                    dtype=tf.float32))
            fc2_biases = tf.Variable(tf.constant(1.0, shape=[256], dtype=tf.float32))
            fc2 = tf.nn.relu(tf.matmul(tf.to_float(fc1), fc2_weights) + fc2_biases)
        with tf.name_scope('fc3'):
            fc3_weights = tf.Variable(  # fully connected, depth 512.
                tf.truncated_normal([256, 512],
                                    # mean=1.0,
                                    stddev=0.1,
                                    dtype=tf.float32))
            fc3_biases = tf.Variable(tf.constant(1.0, shape=[512], dtype=tf.float32))
            fc3 = tf.nn.relu(tf.matmul(tf.to_float(fc2), fc3_weights) + fc3_biases)
        with tf.name_scope('fc4'):
            fc4_weights = tf.Variable(  # fully connected, depth 128 * output_num.
                tf.truncated_normal([512, 128*output_num],
                                    # mean=1.0,
                                    stddev=0.1,
                                    dtype=tf.float32))
            fc4_biases = tf.Variable(tf.constant(1.0, shape=[128*output_num], dtype=tf.float32))
            fc4 = tf.nn.relu(tf.matmul(fc3, fc4_weights) + fc4_biases)


        # [batch_size, ?, rnn_size] = [64, ?, 128]
        embedded=tf.reshape(fc4,[batch_size,-1,128])
        outputs, last_state = tf.nn.dynamic_rnn(cell, embedded, initial_state=initial_state)
    else:
        with tf.device("/cpu:0"):
            embedding = tf.get_variable('embedding', initializer=tf.random_uniform(
                [vocab_size + 1, rnn_size], -1.0, 1.0))
            inputs = tf.nn.embedding_lookup(embedding, input_data)
        # embedded = tf.reshape(input_data, [batch_size, -1, 128])
        # outputs, last_state = tf.nn.dynamic_rnn(cell, embedded, initial_state=initial_state)
        outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)
    output = tf.reshape(outputs, [-1, rnn_size])

    weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
    bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
    logits = tf.nn.bias_add(tf.matmul(output, weights), bias=bias)
    # [?, vocab_size+1]

    if output_data is not None:
        # output_data holds integer class indices; one-hot encode them here
        labels = tf.one_hot(tf.reshape(output_data, [-1]), depth=vocab_size + 1)
        # should be [?, vocab_size+1]
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)

        # loss shape is [?]: one cross-entropy value per predicted number
        total_loss = tf.reduce_mean(loss)
        if use_cnn:
            # L2 penalty on the fully connected layers (they only exist when use_cnn=True)
            regularizers = (tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) +
                            tf.nn.l2_loss(fc2_weights) + tf.nn.l2_loss(fc2_biases)+
                            tf.nn.l2_loss(fc3_weights) + tf.nn.l2_loss(fc3_biases)+
                            tf.nn.l2_loss(fc4_weights) + tf.nn.l2_loss(fc4_biases))
            total_loss+=5e-4 * regularizers
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

        end_points['initial_state'] = initial_state
        end_points['output'] = output
        end_points['train_op'] = train_op
        end_points['total_loss'] = total_loss
        end_points['loss'] = loss
        end_points['last_state'] = last_state
        prediction = tf.nn.softmax(logits)
        end_points['prediction'] = prediction
    else:
        prediction = tf.nn.softmax(logits)

        end_points['initial_state'] = initial_state
        end_points['last_state'] = last_state
        end_points['prediction'] = prediction

    return end_points


--------------------------------------------------------------------------------
/poems/poems.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: poems.py
# author: JinTian
# time: 08/03/2017 7:39 PM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18 | # ------------------------------------------------------------------------ 19 | import collections 20 | import os 21 | import sys 22 | import numpy as np 23 | 24 | start_token = 'B' 25 | end_token = 'E' 26 | 27 | 28 | def process_poems(file_name): 29 | # 诗集 30 | poems = [] 31 | with open(file_name, "r", encoding='utf-8', ) as f: 32 | for line in f.readlines(): 33 | try: 34 | title, content = line.strip().split(':') 35 | content = content.replace(' ', '') 36 | if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \ 37 | start_token in content or end_token in content: 38 | continue 39 | if len(content) < 5 or len(content) > 79: 40 | continue 41 | content = start_token + content + end_token 42 | poems.append(content) 43 | except ValueError as e: 44 | pass 45 | poems = sorted(poems, key=lambda l: len(line)) 46 | 47 | all_words = [] 48 | for poem in poems: 49 | all_words += [word for word in poem] 50 | counter = collections.Counter(all_words) 51 | count_pairs = sorted(counter.items(), key=lambda x: -x[1]) 52 | words, _ = zip(*count_pairs) 53 | 54 | words = words[:len(words)] + (' ',) 55 | word_int_map = dict(zip(words, range(len(words)))) 56 | poems_vector = [list(map(lambda word: word_int_map.get(word, len(words)), poem)) for poem in poems] 57 | 58 | return poems_vector, word_int_map, words 59 | 60 | 61 | def generate_batch(batch_size, poems_vec, word_to_int): 62 | n_chunk = len(poems_vec) // batch_size 63 | x_batches = [] 64 | y_batches = [] 65 | for i in range(n_chunk): 66 | start_index = i * batch_size 67 | end_index = start_index + batch_size 68 | 69 | batches = poems_vec[start_index:end_index] 70 | length = max(map(len, batches)) 71 | x_data = np.full((batch_size, length), word_to_int[' '], np.int32) 72 | for row in range(batch_size): 73 | x_data[row, :len(batches[row])] = batches[row] 74 | y_data = np.copy(x_data) 75 | y_data[:, :-1] = x_data[:, 1:] 76 | """ 77 | x_data y_data 78 | [6,2,4,6,9] [2,4,6,9,9] 79 | 
[1,4,2,8,5] [4,2,8,5,5] 80 | """ 81 | x_batches.append(x_data) 82 | y_batches.append(y_data) 83 | return x_batches, y_batches 84 | -------------------------------------------------------------------------------- /poems/resnet.py: -------------------------------------------------------------------------------- 1 | # Coder: Wenxin Xu 2 | # Github: https://github.com/wenxinxu/resnet_in_tensorflow 3 | # ============================================================================== 4 | ''' 5 | This is the resnet structure 6 | ''' 7 | import numpy as np 8 | 9 | from hyper_parameters import * 10 | 11 | BN_EPSILON = 0.001 12 | 13 | 14 | def activation_summary(x): 15 | ''' 16 | :param x: A Tensor 17 | :return: Add histogram summary and scalar summary of the sparsity of the tensor 18 | ''' 19 | tensor_name = x.op.name 20 | tf.summary.histogram(tensor_name + '/activations', x) 21 | tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x)) 22 | 23 | 24 | def create_variables(name, shape, initializer=tf.contrib.layers.xavier_initializer(), is_fc_layer=False): 25 | ''' 26 | :param name: A string. The name of the new variable 27 | :param shape: A list of dimensions 28 | :param initializer: User Xavier as default. 29 | :param is_fc_layer: Want to create fc layer variable? May use different weight_decay for fc 30 | layers. 31 | :return: The created variable 32 | ''' 33 | 34 | ## TODO: to allow different weight decay to fully connected layer and conv layer 35 | if is_fc_layer is True: 36 | regularizer = tf.contrib.layers.l2_regularizer(scale=FLAGS.weight_decay) 37 | else: 38 | regularizer = tf.contrib.layers.l2_regularizer(scale=FLAGS.weight_decay) 39 | 40 | new_variables = tf.get_variable(name, shape=shape, initializer=initializer, 41 | regularizer=regularizer) 42 | return new_variables 43 | 44 | 45 | def output_layer(input_layer, num_labels): 46 | ''' 47 | :param input_layer: 2D tensor 48 | :param num_labels: int. How many output labels in total? 
(10 for cifar10 and 100 for cifar100) 49 | :return: output layer Y = WX + B 50 | ''' 51 | input_dim = input_layer.get_shape().as_list()[-1] 52 | fc_w = create_variables(name='fc_weights', shape=[input_dim, num_labels], is_fc_layer=True, 53 | initializer=tf.uniform_unit_scaling_initializer(factor=1.0)) 54 | fc_b = create_variables(name='fc_bias', shape=[num_labels], initializer=tf.zeros_initializer()) 55 | 56 | fc_h = tf.matmul(input_layer, fc_w) + fc_b 57 | return fc_h 58 | 59 | 60 | def batch_normalization_layer(input_layer, dimension): 61 | ''' 62 | Helper function to do batch normalziation 63 | :param input_layer: 4D tensor 64 | :param dimension: input_layer.get_shape().as_list()[-1]. The depth of the 4D tensor 65 | :return: the 4D tensor after being normalized 66 | ''' 67 | mean, variance = tf.nn.moments(input_layer, axes=[0, 1, 2]) 68 | beta = tf.get_variable('beta', dimension, tf.float32, 69 | initializer=tf.constant_initializer(0.0, tf.float32)) 70 | gamma = tf.get_variable('gamma', dimension, tf.float32, 71 | initializer=tf.constant_initializer(1.0, tf.float32)) 72 | bn_layer = tf.nn.batch_normalization(input_layer, mean, variance, beta, gamma, BN_EPSILON) 73 | 74 | return bn_layer 75 | 76 | 77 | def conv_bn_relu_layer(input_layer, filter_shape, stride): 78 | ''' 79 | A helper function to conv, batch normalize and relu the input tensor sequentially 80 | :param input_layer: 4D tensor 81 | :param filter_shape: list. [filter_height, filter_width, filter_depth, filter_number] 82 | :param stride: stride size for conv 83 | :return: 4D tensor. 
Y = Relu(batch_normalize(conv(X))) 84 | ''' 85 | 86 | out_channel = filter_shape[-1] 87 | filter = create_variables(name='conv', shape=filter_shape) 88 | 89 | conv_layer = tf.nn.conv2d(input_layer, filter, strides=[1, stride, stride, 1], padding='SAME') 90 | bn_layer = batch_normalization_layer(conv_layer, out_channel) 91 | 92 | output = tf.nn.relu(bn_layer) 93 | return output 94 | 95 | 96 | def bn_relu_conv_layer(input_layer, filter_shape, stride): 97 | ''' 98 | A helper function to batch normalize, relu and conv the input layer sequentially 99 | :param input_layer: 4D tensor 100 | :param filter_shape: list. [filter_height, filter_width, filter_depth, filter_number] 101 | :param stride: stride size for conv 102 | :return: 4D tensor. Y = conv(Relu(batch_normalize(X))) 103 | ''' 104 | 105 | in_channel = input_layer.get_shape().as_list()[-1] 106 | 107 | bn_layer = batch_normalization_layer(input_layer, in_channel) 108 | relu_layer = tf.nn.relu(bn_layer) 109 | 110 | filter = create_variables(name='conv', shape=filter_shape) 111 | conv_layer = tf.nn.conv2d(relu_layer, filter, strides=[1, stride, stride, 1], padding='SAME') 112 | return conv_layer 113 | 114 | 115 | def residual_block(input_layer, output_channel, first_block=False): 116 | ''' 117 | Defines a residual block in ResNet 118 | :param input_layer: 4D tensor 119 | :param output_channel: int. return_tensor.get_shape().as_list()[-1] = output_channel 120 | :param first_block: if this is the first residual block of the whole network 121 | :return: 4D tensor. 
    '''
    input_channel = input_layer.get_shape().as_list()[-1]

    # When it's time to "shrink" the image size, we use stride = 2
    if input_channel * 2 == output_channel:
        increase_dim = True
        stride = 2
    elif input_channel == output_channel:
        increase_dim = False
        stride = 1
    else:
        raise ValueError('Output and input channels do not match in residual blocks!!!')

    # The first conv layer of the first residual block does not need to be normalized and relu-ed.
    with tf.variable_scope('conv1_in_block'):
        if first_block:
            filter = create_variables(name='conv', shape=[3, 3, input_channel, output_channel])
            conv1 = tf.nn.conv2d(input_layer, filter=filter, strides=[1, 1, 1, 1], padding='SAME')
        else:
            conv1 = bn_relu_conv_layer(input_layer, [3, 3, input_channel, output_channel], stride)

    with tf.variable_scope('conv2_in_block'):
        conv2 = bn_relu_conv_layer(conv1, [3, 3, output_channel, output_channel], 1)

    # When the channels of the input layer and conv2 do not match, we add zero pads to increase the
    # depth of the input layer
    if increase_dim is True:
        pooled_input = tf.nn.avg_pool(input_layer, ksize=[1, 1, 1, 1],
                                      strides=[1, 2, 2, 1], padding='VALID')
        padded_input = tf.pad(pooled_input, [[0, 0], [0, 0], [0, 0], [input_channel // 2,
                                                                     input_channel // 2]])
    else:
        padded_input = input_layer

    output = conv2 + padded_input
    return output


def inference(input_tensor_batch, n, reuse, output_num=128):
    '''
    The main function that defines the ResNet. total layers = 1 + 2n + 2n + 2n + 1 = 6n + 2
    :param input_tensor_batch: 4D tensor
    :param n: num_residual_blocks
    :param reuse: To build train graph, reuse=False. To build validation graph and share weights
    with train graph, reuse=True
    :param output_num: number of units in the final fully connected layer (default 128)
    :return: last layer in the network.
    Not softmax-ed
    '''

    layers = []
    with tf.variable_scope('conv0', reuse=reuse):
        conv0 = conv_bn_relu_layer(input_tensor_batch, [1, 1, 1, 16], 1)
        activation_summary(conv0)
        layers.append(conv0)

    for i in range(n):
        with tf.variable_scope('conv1_%d' % i, reuse=reuse):
            if i == 0:
                conv1 = residual_block(layers[-1], 16, first_block=True)
            else:
                conv1 = residual_block(layers[-1], 16)
            activation_summary(conv1)
            layers.append(conv1)

    for i in range(n):
        with tf.variable_scope('conv2_%d' % i, reuse=reuse):
            conv2 = residual_block(layers[-1], 32)
            activation_summary(conv2)
            layers.append(conv2)

    for i in range(n):
        with tf.variable_scope('conv3_%d' % i, reuse=reuse):
            conv3 = residual_block(layers[-1], 64)
            layers.append(conv3)
        # assert conv3.get_shape().as_list()[1:] == [8, 8, 64]

    with tf.variable_scope('fc', reuse=reuse):
        in_channel = layers[-1].get_shape().as_list()[-1]
        bn_layer = batch_normalization_layer(layers[-1], in_channel)
        relu_layer = tf.nn.relu(bn_layer)
        global_pool = tf.reduce_mean(relu_layer, [1, 2])

        assert global_pool.get_shape().as_list()[-1:] == [64]
        output = output_layer(global_pool, output_num)
        layers.append(output)

    return layers[-1]


def test_graph(train_dir='logs'):
    '''
    Run this function to look at the graph structure on tensorboard. A fast way!
    :param train_dir: directory where the TensorBoard event file is written
    '''
    # conv0 uses a [1, 1, 1, 16] filter, so the dummy input must have exactly one channel
    input_tensor = tf.constant(np.ones([128, 32, 32, 1]), dtype=tf.float32)
    result = inference(input_tensor, 2, reuse=False)
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
--------------------------------------------------------------------------------
/ssq.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: main.py
# author: JinTian
# time: 11/03/2017 9:53 AM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import math
import os
import numpy as np
import tensorflow as tf
from poems.model import rnn_model
from poems.poems import process_poems, generate_batch
from ssq_data import *
tf.app.flags.DEFINE_integer('batch_size', -1, 'batch size.')
tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'learning rate.')
tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path.')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
tf.app.flags.DEFINE_integer('epochs', 150000, 'train how many epochs.')

FLAGS = tf.app.flags.FLAGS


def run_training():
    # saver.save() fails if the checkpoint directory does not exist yet
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)
    #
    # poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    # batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
    ssqdata=get_exl_data(random_order=True)
    batches_inputs=ssqdata[0:(len(ssqdata)-1)]
    batches_outputs = ssqdata[1:(len(ssqdata))]
    FLAGS.batch_size=len(ssqdata)-1
    # data=batches_outputs[1:7]
    # print(len(data))
    del ssqdata
    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    # print(tf.shape(input_data))
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=33+16,
                           output_num=7,input_num=7,
                           rnn_size=128, num_layers=7, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)
    # end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
    #     vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)

        start_epoch = 0
        # saver.restore(sess, "F:/tensorflow_poems-master/model/poems-100000")
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                # n_chunk = len(poems_vector) // FLAGS.batch_size
                # n_chunk = len(batches_inputs) // FLAGS.batch_size
                n_chunk=math.ceil(len(batches_inputs) / FLAGS.batch_size)
                for batch in range(n_chunk):
                    left=(batch+1)*FLAGS.batch_size-len(batches_inputs)
                    if left<0:
                        inputdata=batches_inputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)]
                        outputdata=batches_outputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)]
                    else:
                        # temp=batches_inputs[batch*FLAGS.batch_size:len(batches_inputs) ]
                        # temp.extend(batches_inputs[0:left])
                        inputdata=batches_inputs[len(batches_inputs)-FLAGS.batch_size:len(batches_inputs)]
                        # temp=batches_outputs[batch*FLAGS.batch_size:len(batches_inputs) ]
                        # temp.extend(batches_outputs[0:left])
                        outputdata=batches_outputs[len(batches_outputs)-FLAGS.batch_size:len(batches_outputs)]
                    # print(len(inputdata))
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: inputdata, output_targets: outputdata})
                    # # ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
                    n += 1
                    print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                if epoch % 5000 == 0:
                    saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupted manually, trying to save a checkpoint now...')
        finally:
            saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
            print('## Last epoch was saved, next time will start from epoch {}.'.format(epoch))


def main(_):
    run_training()


if __name__ == '__main__':
    tf.app.run()

--------------------------------------------------------------------------------
/ssq.xls:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liang-yc/ssq/4c69745d1fb731b4a39cc571af58a31368439789/ssq.xls
--------------------------------------------------------------------------------
/ssq4all.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: main.py
# author: JinTian
# time: 11/03/2017 9:53 AM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import math
import os
import numpy as np
import tensorflow as tf
from poems.model import rnn_model
from poems.resnet import *
from poems.poems import process_poems, generate_batch
from ssq_data import *
# For Windows 10: works around "OSError: raw write() returned invalid length 96 (should have been between 0 and 48)"
# import win_unicode_console
# win_unicode_console.enable()
tf.app.flags.DEFINE_integer('batch_size', 2214, 'batch size.')
tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'learning rate.')
tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model4all'), 'model save path.')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
tf.app.flags.DEFINE_integer('epochs', 500000, 'train how many epochs.')

FLAGS = tf.app.flags.FLAGS


def run_training():
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)
    #
    # poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    # batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
    ssqdata=get_exl_data(random_order=False,use_resnet=True)
    # print(ssqdata[len(ssqdata)-1])
    batches_inputs=ssqdata[0:(len(ssqdata)-1)]
    ssqdata=get_exl_data(random_order=False,use_resnet=False)
    batches_outputs = ssqdata[1:(len(ssqdata))]
    FLAGS.batch_size=len(batches_inputs)
    # print(np.shape(batches_outputs))
    # data=batches_outputs[1:7]
    # print(len(data))
    del ssqdata
    input_data = tf.placeholder(tf.float32, [FLAGS.batch_size, 1,7,1])
    logits = inference(input_data, 2, reuse=False,output_num=128)

    # print(tf.shape(input_data))
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33+16,output_num=7,
                           rnn_size=128, num_layers=7, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)
    # end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
    #     vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                # n_chunk = len(poems_vector) // FLAGS.batch_size
                # n_chunk = len(batches_inputs) // FLAGS.batch_size
                n_chunk=math.ceil(len(batches_inputs) / FLAGS.batch_size)
                for batch in range(n_chunk):
                    left=(batch+1)*FLAGS.batch_size-len(batches_inputs)
                    if left<0:
                        inputdata=batches_inputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)]
                        outputdata=batches_outputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)]
                    else:
                        # temp=batches_inputs[batch*FLAGS.batch_size:len(batches_inputs) ]
                        # temp.extend(batches_inputs[0:left])
                        inputdata=batches_inputs[len(batches_inputs)-FLAGS.batch_size:len(batches_inputs)]
                        # temp=batches_outputs[batch*FLAGS.batch_size:len(batches_inputs) ]
                        # temp.extend(batches_outputs[0:left])
                        outputdata=batches_outputs[len(batches_outputs)-FLAGS.batch_size:len(batches_outputs)]
                    # print(len(inputdata))
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: inputdata, output_targets: outputdata})
                    # ], feed_dict={input_data: batches_inputs, output_targets: batches_outputs})
                    n += 1
                    print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                if epoch % 50000 == 0:
                    saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupted manually, trying to save a checkpoint now...')
        finally:
            saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
            print('## Last epoch was saved, next time will start from epoch {}.'.format(epoch))


def main(_):
    run_training()


if __name__ == '__main__':
    tf.app.run()

--------------------------------------------------------------------------------
/ssq4all_test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: main.py
# author: JinTian
# time: 11/03/2017 9:53 AM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import tensorflow as tf
from poems.model import rnn_model
from poems.resnet import *
from poems.poems import process_poems
import numpy as np
from ssq_data import *
start_token = 'B'
end_token = 'E'
model_dir = './model4all/'
corpus_file = './data/poems.txt'


def to_word(predict, vocabs):
    t = np.cumsum(predict)
    s = np.sum(predict)
    sample = int(np.searchsorted(t, np.random.rand(1) * s))
    # searchsorted can return len(t), which is one past the last valid index
    if sample >= len(vocabs):
        sample = len(vocabs) - 1
    return vocabs[sample]


def gen_poem():
    batch_size = 1
    print('## loading model from %s' % model_dir)
    input_data = tf.placeholder(tf.float32, [1, 1,7,1])
    logits = inference(input_data, 2, reuse=False,output_num=128)

    # print(tf.shape(input_data))
    output_targets = tf.placeholder(tf.int32, [1, None])
    end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33+16,
                           output_num=7,
                           rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)

    # input_data = tf.placeholder(tf.int32, [batch_size, None])
    #
    # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33,
    #                        rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        checkpoint = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, checkpoint)
        # saver.restore(sess, "F:/tensorflow_poems-master/model4all/poems-401713")
        ssqdata = get_exl_data(random_order=True,use_resnet=True)
        # x = np.array([list(map(word_int_map.get, start_token))])
        x=[ssqdata[len(ssqdata)-1]]
        print("input: %s"%(x+np.asarray([[[[1],[1],[1],[1],[1],[1],[-32]]]])))
        [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         feed_dict={input_data: x})
        poem_=np.argmax(np.array(predict),axis=1)
        sorted_result = np.argsort(np.array(predict), axis=1)
        results=poem_+np.asarray([1,1,1,1,1,1,-32])
        print(sorted_result)
        print("output: %s"%results)
        poem_=np.argmin(np.array(predict),axis=1)
        results=poem_+np.asarray([1,1,1,1,1,1,-32])
        print("min output:%s"%results)
        return poem_


def gen_blue():
    batch_size = 1
    print('## loading model from %s' % model_dir)
    input_data = tf.placeholder(tf.float32, [1, 1,7,1])
    logits = inference(input_data, 10, reuse=False,output_num=128)

    # print(tf.shape(input_data))
    output_targets = tf.placeholder(tf.int32, [1, None])
    end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33,
                           output_num=1,
                           rnn_size=128, num_layers=3, batch_size=1, learning_rate=0.01)

    # input_data = tf.placeholder(tf.int32, [batch_size, None])
    #
    # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33,
    #                        rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        checkpoint = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, checkpoint)
        # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4blue/poems-84201")
        ssqdata = get_exl_data(random_order=True,use_resnet=True)
        # x = np.array([list(map(word_int_map.get, start_token))])
        x=[ssqdata[len(ssqdata)-1]]

        print("input: %s"%(x))
        [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         feed_dict={input_data: x})
        poem_=np.argmax(np.array(predict),axis=1)
        sorted_result=np.argsort(np.array(predict),axis=1)
        results=poem_+np.ones(1)
        print(results)
        print(sorted_result)
        return poem_


if __name__ == '__main__':
    # begin_char = input('## please input the first character:')
    # poem=gen_blue()#13
    poem = gen_poem()
    # pretty_print_poem(poem_=poem)
--------------------------------------------------------------------------------
/ssq4all_test_v4.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: main.py
# author: JinTian
# time: 11/03/2017 9:53 AM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import tensorflow as tf
from poems.model import rnn_model
from poems.resnet import *
from poems.poems import process_poems
import numpy as np
from ssq_data import *
start_token = 'B'
end_token = 'E'
# ssq4all_v4.py saves its checkpoints to ./model4all_v4
model_dir = './model4all_v4/'
corpus_file = './data/poems.txt'


def to_word(predict, vocabs):
    t = np.cumsum(predict)
    s = np.sum(predict)
    sample = int(np.searchsorted(t, np.random.rand(1) * s))
    # searchsorted can return len(t), which is one past the last valid index
    if sample >= len(vocabs):
        sample = len(vocabs) - 1
    return vocabs[sample]


def gen_poem():
    batch_size = 1
    print('## loading model from %s' % model_dir)
    input_data = tf.placeholder(tf.float32, [1, 10,7+8,1])
    logits = inference(input_data, 1, reuse=False,output_num=128)

    # print(tf.shape(input_data))
    output_targets = tf.placeholder(tf.int32, [1, None])
    end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33+16,output_num=7,
                           rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.001)

    # input_data = tf.placeholder(tf.int32, [batch_size, None])
    #
    # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33,
    #                        rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        checkpoint = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, checkpoint)
        # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4all/poems-208368")
        # ssqdata = get_exl_data(random_order=True,use_resnet=True)
        # ssqdata = get_exl_data_v3(random_order=True, use_resnet=True)
        ssqdata = get_exl_data_by_period(random_order=True,
                                         use_resnet=True, times=10)
        x=[ssqdata[len(ssqdata)-1]]
        print("input: %s"%(x+np.asarray([[[[1],[1],[1],[1],[1],[1],[-32],[0],[0],[0],[0],[0],[0],[0],[0]]]])))
        [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         feed_dict={input_data: x})
        poem_=np.argmax(np.array(predict),axis=1)
        sorted_result = np.argsort(np.array(predict), axis=1)
        results=poem_+np.asarray([1,1,1,1,1,1,-32])
        print(sorted_result)
        print("output: %s"%results)
        return poem_


if __name__ == '__main__':
    # begin_char = input('## please input the first character:')
    # poem=gen_blue()#13
    poem = gen_poem()
    # pretty_print_poem(poem_=poem)

--------------------------------------------------------------------------------
/ssq4all_v4.py:
--------------------------------------------------------------------------------
# *_*coding:utf-8 *_*

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(),"..")))
import numpy as np
import tensorflow as tf
from poems.model import rnn_model
from poems.resnet import *
from poems.poems import process_poems, generate_batch
from ssq_data import *
# For Windows 10: works around "OSError: raw write() returned invalid length 96 (should have been between 0 and 48)"
try:
    import win_unicode_console
    win_unicode_console.enable()
except ImportError:  # the package is only needed on Windows consoles
    pass
tf.app.flags.DEFINE_integer('batch_size', 2214, 'batch size.')
tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'learning rate.')
tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model4all_v4'), 'model save path.')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
tf.app.flags.DEFINE_integer('epochs', 2000001, 'train how many epochs.')
tf.app.flags.DEFINE_integer('times', 10, 'how many consecutive draws make up one input sample.')
FLAGS = tf.app.flags.FLAGS


def run_training():
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)
    #
    # poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    # batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
    # ssqdata=get_exl_data(random_order=False,use_resnet=True)
    # # print(ssqdata[len(ssqdata)-1])
    # batches_inputs=ssqdata[0:(len(ssqdata)-1)]
    ssqdata=get_exl_data(random_order=False,use_resnet=False)
    # ssqdata=get_dlt_data(random_order=False,use_resnet=True)
    # print(ssqdata[len(ssqdata)-1])
    batches_inputs=ssqdata[0:(len(ssqdata)-1)]
    # ssqdata=get_dlt_data(random_order=False,use_resnet=False)
    batches_outputs = ssqdata[1:(len(ssqdata))]
    ssqdata=get_exl_data_by_period(random_order=False,use_resnet=True,times=FLAGS.times)
    batches_inputs=ssqdata[0:(len(ssqdata)-1)]
    FLAGS.batch_size=len(batches_inputs)
    # print(np.shape(batches_outputs))
    # data=batches_outputs[1:7]
    # print(len(data))
    del ssqdata
    input_data = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.times,7+8,1])
    logits = inference(input_data, 1, reuse=False,output_num=128)

    # print(tf.shape(input_data))
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33+16,output_num=7,
                           rnn_size=128, num_layers=7, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)
    # end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
    #     vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)

        start_epoch = 0
        # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4all/poems-208368")
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        try:
            # for epoch in range(start_epoch, FLAGS.epochs):
            epoch = start_epoch
            FLAGS.epochs=epoch+100000
            print('till',FLAGS.epochs)
            # for epoch in range(start_epoch, FLAGS.epochs):
            while(epoch len(vocabs):
        sample = len(vocabs) - 1
    return vocabs[sample]


def gen_poem():
    batch_size = 1
    print('## loading model from %s' % model_dir)

    input_data = tf.placeholder(tf.int32, [batch_size, None])

    end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33+16,
                           rnn_size=128, output_num=7,input_num=7,num_layers=7, batch_size=1, learning_rate=0.01)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        checkpoint = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, checkpoint)
        ssqdata = get_exl_data()
        # x = np.array([list(map(word_int_map.get, start_token))])
        x=[ssqdata[len(ssqdata)-1]]
        print("input: %s"%(x+np.asarray([1,1,1,1,1,1,-32])))
        [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         feed_dict={input_data: x})
        poem_=np.argmax(np.array(predict),axis=1)
        results=poem_+np.asarray([1,1,1,1,1,1,-32])
        print("output:%s"%results)
        return poem_


if __name__ == '__main__':
    # begin_char = input('## please input the first character:')
    poem = gen_poem()
    # pretty_print_poem(poem_=poem)
--------------------------------------------------------------------------------
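The test scripts (`ssq4all_test.py`, `ssq4all_test_v4.py`, `ssq_test.py`) turn the model's scores into ball numbers by taking the argmax over a 49-entry vocabulary (`vocab_size=33+16`) and adding the offset `[1,1,1,1,1,1,-32]`. The NumPy-only sketch below reproduces that decoding outside TensorFlow. It assumes, as the offset implies, that red balls 1 to 33 are stored as indices 0 to 32 and blue balls 1 to 16 as indices 33 to 48, and that `prediction` has shape `(7, 49)` for a batch of one. `decode_prediction` and `sample_prediction` are hypothetical helpers, not part of the repository; the sampling variant follows the inverse-CDF idea in `to_word`, with the index clamped to the last valid entry.

```python
import numpy as np

# Assumed encoding (inferred from vocab_size=33+16 and the offset used in gen_poem()):
# index 0..32 -> red ball 1..33, index 33..48 -> blue ball 1..16.
OFFSET = np.array([1, 1, 1, 1, 1, 1, -32])


def decode_prediction(predict):
    """Greedy decoding: one row of 49 scores per ball position, shape (7, 49)."""
    indices = np.argmax(np.asarray(predict), axis=1)
    return indices + OFFSET


def sample_prediction(predict, rng=None):
    """Inverse-CDF sampling per row, in the spirit of to_word()."""
    rng = np.random.default_rng() if rng is None else rng
    predict = np.asarray(predict, dtype=float)
    indices = []
    for row in predict:
        cdf = np.cumsum(row)
        idx = int(np.searchsorted(cdf, rng.random() * cdf[-1]))
        indices.append(min(idx, len(row) - 1))  # never step past the last index
    return np.array(indices) + OFFSET


if __name__ == '__main__':
    # Toy scores: red position i peaks at index 5*i, the blue row peaks at index 40.
    scores = np.zeros((7, 49))
    for i in range(6):
        scores[i, 5 * i] = 1.0
    scores[6, 40] = 1.0
    print(decode_prediction(scores))  # reds 1, 6, 11, 16, 21, 26 and blue 8
```

Note that greedy decoding does not stop two red positions from choosing the same number; the scripts above share that property.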
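In `residual_block` in `poems/resnet.py`, when the depth doubles, the identity branch is brought to the same shape as the convolution branch: a 1x1 average pool with stride 2 and `'VALID'` padding subsamples the input, and `tf.pad` adds `input_channel // 2` zero channels on each side. The NumPy sketch below mimics that branch so the shapes can be checked without TensorFlow. It is an illustration, not the repository's code, and `shortcut_nhwc` and `same_conv_output` are hypothetical names. For the `[batch, 1, 7, 1]` inputs used by `ssq4all.py`, `conv0` yields 16 channels; a stride-2 `'SAME'` convolution then maps width 7 to 4, and the pooled shortcut also ends up with width 4.

```python
import numpy as np


def shortcut_nhwc(x, increase_dim):
    """Identity branch of residual_block() for an NHWC array.

    A 1x1 pool with stride 2 and 'VALID' padding keeps every second row and
    column; the channel axis is then zero-padded so that the depth doubles.
    """
    if not increase_dim:
        return x
    pooled = x[:, ::2, ::2, :]  # same result as avg_pool(ksize=1, stride=2, 'VALID')
    c = x.shape[-1]
    return np.pad(pooled, [(0, 0), (0, 0), (0, 0), (c // 2, c // 2)])  # zeros by default


def same_conv_output(size, stride):
    """Spatial size produced by a convolution with padding='SAME': ceil(size / stride)."""
    return -(-size // stride)


if __name__ == '__main__':
    x = np.ones((1, 1, 7, 16))  # output of conv0 for one draw
    print(shortcut_nhwc(x, increase_dim=True).shape)  # (1, 1, 4, 32)
    print(same_conv_output(7, 2), same_conv_output(1, 2))  # 4 1
```

Because both branches end up `(1, 1, 4, 32)`, `conv2 + padded_input` adds element-wise. The same holds for the `[1, 10, 15, 1]` inputs of the v4 scripts, where both branches shrink to 5 x 8.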