├── LICENSE
├── README.md
├── dlt2.xls
├── dlt4all.py
├── dlt4all_test.py
├── hyper_parameters.py
├── poems
│   ├── model.py
│   ├── poems.py
│   └── resnet.py
├── ssq.py
├── ssq.xls
├── ssq4all.py
├── ssq4all_test.py
├── ssq4all_test_v4.py
├── ssq4all_v4.py
├── ssq_data.py
└── ssq_test.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Yanchao Liang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Predicting China's Union Lotto (双色球) with a Neural Network and LSTM
2 | ## Update 2020/12/4
3 | Since quite a few people have asked, there is now a [branch](https://github.com/Liang-yc/ssq/tree/tf2) adapted to TensorFlow 2.X.
4 |
5 | ## Update 2019/6/18
6 | Predictions now use the results of the last few draws together with the prize-pool figures; see `ssq4all_v4.py` and `ssq4all_test_v4.py` for training and testing.
7 | ## Update 2018/9/16
8 | Added training and test files for the Super Lotto (大乐透).
9 | ## Update 2018/5/30
10 | Fixed several bugs and removed two files based on feedback. If you run into a problem, please open an issue.
11 | ## Update 2018/3/29
12 | Tried prediction with a CNN plus the LSTM. The CNN (a ResNet) is used to extract features. The best result so far is a fifth prize (4 red numbers correct).
13 | ## Background
14 | ------
15 | This project predicts the Union Lotto by combining a neural network with a Long Short-Term Memory (LSTM) network. An introduction to the Union Lotto itself is omitted here (see this [page](https://baike.baidu.com/item/中国福利彩票双色球/8676030?fr=aladdin&fromid=75279&fromtitle=%E5%8F%8C%E8%89%B2%E7%90%83)). The project takes a real draw result as input (7 numbers: the first 6 are the red balls, the last one is the blue ball) and outputs a prediction in the same format (again 7 numbers: 6 red, 1 blue). The project is still under development.
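For illustration, here is a minimal sketch of this encoding; the numbers below are invented and do not come from any real draw:

```python
# One Union Lotto draw is encoded as 7 integers:
# 6 red balls in the range 1-33, followed by 1 blue ball in the range 1-16.
previous_draw = [3, 7, 12, 19, 26, 31, 8]  # hypothetical example, not a real draw
# The model takes a draw in this form as input and emits 7 numbers in the
# same format as its guess for the next draw.
```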
16 | The core code is largely adapted from this [repository](https://github.com/jinfagang/tensorflow_poems).
17 | So far, after full training, the predictions get 0-2 red numbers right (most often 1, which is roughly what chance alone gives, so this is not a strong result) and 0-1 blue numbers right (most often 0). Given the Union Lotto winning conditions (at least 3 correct red numbers, or a correct blue number), **I do not recommend buying tickets based on these predictions**.
18 | The predictions are poor for several reasons:
19 | 1. A lottery draw is a random event, yet I am trying to find patterns in it;
20 | 2. The dataset is small and cannot be effectively enlarged;
21 | 3. My model may simply not be good enough.
22 |
23 | If you have suggestions or questions, email me (yc_liang@qq.com) or leave a message in an issue;
24 | If you find the project interesting, remember to STAR it;
25 | And of course, if you actually win a prize, all I ask for is a STAR.
26 |
27 | ## Requirements
28 | -----
29 |
30 | 1. TensorFlow;
31 | 2. pyexcel_xls;
32 | 3. CUDA (optional).
33 | The code may need a few more Python packages that are not listed here; please pip-install them as needed.
34 |
35 | ## File contents
36 | -----
37 |
38 | 1. `poems`: the `poems` folder contains three model files (`model.py`, `poems.py`, `resnet.py`); `model.py` is the model used by this project, and `resnet.py` defines the ResNet;
39 | 2. `ssq.py`: trains the Union Lotto model;
40 | 3. `ssq.xls`: stores the historical Union Lotto draw and prize data; the data is updated via the workbook's macro;
41 | 4. `ssq_data.py`: reads the data in `ssq.xls` and converts it into the required form;
42 | 5. `ssq_test.py`: takes the most recent draw as input and prints the model's prediction;
43 | 6. `ssq4all.py`: trains the Union Lotto model (same purpose as `ssq.py`, but with a different model);
44 | 7. `ssq4all_test.py`: takes the most recent draw as input and prints the model's prediction;
45 | 8. `dlt2.xls`: stores the historical Super Lotto draw and prize data; the data is updated via the workbook's macro;
46 | 9. `dlt4all.py`: trains the Super Lotto model;
47 | 10. `dlt4all_test.py`: takes the most recent Super Lotto draw as input and prints the model's prediction;
48 |
49 | ## How to use
50 | -----
51 | Set the parameters and run `ssq.py` (preferably from a console, so that you can press CTRL+C halfway through training to stop and still save the result; the next run then resumes from the last checkpoint). After training, a `model` folder is created under the project directory to hold the saved model. Then run `ssq_test.py` to get a prediction.
52 | The project runs on Windows 10 with TensorFlow 1.4/1.5, with or without a GPU. Since neither the model nor the dataset is large, all training data can be fed as a single batch; if the model is deepened later, the batch size may need to be reduced.
53 |
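For reference, here is a minimal sketch of how the test scripts turn the network's class indices back into ball numbers; it mirrors the `poem_ + np.asarray([1,1,1,1,1,1,-32])` step in `ssq_test.py`/`ssq4all_test.py`, and the predicted indices below are made up:

```python
import numpy as np

# The softmax output has 33 + 16 = 49 classes per position:
# indices 0-32 stand for red balls 1-33, indices 33-48 for blue balls 1-16.
predicted_indices = np.array([2, 6, 11, 18, 25, 30, 40])  # hypothetical argmax output
ball_numbers = predicted_indices + np.asarray([1, 1, 1, 1, 1, 1, -32])
print(ball_numbers)  # -> [ 3  7 12 19 26 31  8]: reds 3 7 12 19 26 31, blue 8
```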
--------------------------------------------------------------------------------
/dlt2.xls:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liang-yc/ssq/4c69745d1fb731b4a39cc571af58a31368439789/dlt2.xls
--------------------------------------------------------------------------------
/dlt4all.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: main.py
3 | # author: JinTian
4 | # time: 11/03/2017 9:53 AM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | # ------------------------------------------------------------------------
19 | import os
20 | import numpy as np
21 | import tensorflow as tf
22 | from poems.model import rnn_model
23 | from poems.resnet import *
24 | from poems.poems import process_poems, generate_batch
25 | from ssq_data import *
26 | # for Windows10:OSError: raw write() returned invalid length 96 (should have been between 0 and 48)
27 | import win_unicode_console
28 | win_unicode_console.enable()
29 | tf.app.flags.DEFINE_integer('batch_size', 2214, 'batch size.')
30 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'learning rate.')
31 | tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./dlt_model'), 'model save path.')
32 | tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
33 | tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
34 | tf.app.flags.DEFINE_integer('epochs', 500000, 'train how many epochs.')
35 |
36 | FLAGS = tf.app.flags.FLAGS
37 |
38 |
39 | def run_training():
40 | if not os.path.exists(FLAGS.model_dir):
41 | os.makedirs(FLAGS.model_dir)
42 | #
43 | # poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
44 | # batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
45 | # ssqdata=get_exl_data(random_order=False,use_resnet=True)
46 | # # print(ssqdata[len(ssqdata)-1])
47 | # batches_inputs=ssqdata[0:(len(ssqdata)-1)]
48 | # ssqdata=get_exl_data(random_order=False,use_resnet=False)
49 | ssqdata=get_dlt_data(random_order=False,use_resnet=True)
50 | # print(ssqdata[len(ssqdata)-1])
51 | batches_inputs=ssqdata[0:(len(ssqdata)-1)]
52 | ssqdata=get_dlt_data(random_order=False,use_resnet=False)
53 | batches_outputs = ssqdata[1:(len(ssqdata))]
54 | FLAGS.batch_size=len(batches_inputs)
55 | # print(np.shape(batches_outputs))
56 | # data=batches_outputs[1:7]
57 | # print(len(data))
58 | del ssqdata
59 | input_data = tf.placeholder(tf.float32, [FLAGS.batch_size, 1,7,1])
60 | logits = inference(input_data, 1, reuse=False,output_num=128)
61 |
62 | # print(tf.shape(input_data))
63 | output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
64 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=35+12,output_num=7,
65 | rnn_size=128, num_layers=7, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)
66 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
67 | # vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate)
68 |
69 | saver = tf.train.Saver(tf.global_variables())
70 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
71 | with tf.Session() as sess:
72 | # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
73 | # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
74 | sess.run(init_op)
75 |
76 | start_epoch = 0
77 | # saver.restore(sess, "D:/tensorflow_poems-master/model4all/poems-207261")
78 | checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
79 | if checkpoint:
80 | saver.restore(sess, checkpoint)
81 | print("## restore from the checkpoint {0}".format(checkpoint))
82 | start_epoch += int(checkpoint.split('-')[-1])
83 | print('## start training...')
84 | try:
85 | for epoch in range(start_epoch, FLAGS.epochs):
86 | n = 0
87 | # n_chunk = len(poems_vector) // FLAGS.batch_size
88 | # n_chunk = len(batches_inputs) // FLAGS.batch_size
89 | n_chunk=math.ceil(len(batches_inputs) / FLAGS.batch_size)
90 | for batch in range(n_chunk):
91 |                     left=(batch+1)*FLAGS.batch_size-len(batches_inputs)  # how far this chunk would overrun the data; if >= 0, reuse the last batch_size samples
92 | if left<0:
93 | inputdata=batches_inputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)]
94 | outputdata=batches_outputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)]
95 | else:
96 | # temp=batches_inputs[batch*FLAGS.batch_size:len(batches_inputs) ]
97 | # temp.extend(batches_inputs[0:left])
98 | inputdata=batches_inputs[len(batches_inputs)-FLAGS.batch_size:len(batches_inputs)]
99 | # temp=batches_outputs[batch*FLAGS.batch_size:len(batches_inputs) ]
100 | # temp.extend(batches_outputs[0:left])
101 | outputdata=batches_outputs[len(batches_outputs)-FLAGS.batch_size:len(batches_outputs)]
102 | # print(len(inputdata))
103 | loss, _, _ = sess.run([
104 | end_points['total_loss'],
105 | end_points['last_state'],
106 | end_points['train_op']
107 | ], feed_dict={input_data: inputdata, output_targets: outputdata})
108 | # ], feed_dict={input_data: batches_inputs, output_targets: batches_outputs})
109 | n += 1
110 | if epoch % 1000 == 0:
111 | print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
112 | if epoch % 50000 == 0:
113 | saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
114 | except KeyboardInterrupt:
115 | print('## Interrupt manually, try saving checkpoint for now...')
116 | finally:
117 | saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
118 |             print('## Last epoch was saved, next time will start from epoch {}.'.format(epoch))
119 |
120 |
121 | def main(_):
122 | run_training()
123 |
124 |
125 | if __name__ == '__main__':
126 | tf.app.run()
127 |
--------------------------------------------------------------------------------
/dlt4all_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: main.py
3 | # author: JinTian
4 | # time: 11/03/2017 9:53 AM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | # ------------------------------------------------------------------------
19 | import tensorflow as tf
20 | from poems.model import rnn_model
21 | from poems.resnet import *
22 | from poems.poems import process_poems
23 | import numpy as np
24 | from ssq_data import *
25 | start_token = 'B'
26 | end_token = 'E'
27 | model_dir = './model4all/'
28 | corpus_file = './data/poems.txt'
29 |
30 |
31 | def to_word(predict, vocabs):
32 | t = np.cumsum(predict)
33 | s = np.sum(predict)
34 | sample = int(np.searchsorted(t, np.random.rand(1) * s))
35 | if sample > len(vocabs):
36 | sample = len(vocabs) - 1
37 | return vocabs[sample]
38 |
39 |
40 | def gen_poem():
41 | batch_size = 1
42 | print('## loading model from %s' % model_dir)
43 | input_data = tf.placeholder(tf.float32, [1, 1,7,1])
44 | logits = inference(input_data, 1, reuse=False,output_num=128)
45 |
46 | # print(tf.shape(input_data))
47 | output_targets = tf.placeholder(tf.int32, [1, None])
48 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=35+12,
49 | output_num=7,
50 | rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)
51 |
52 | # input_data = tf.placeholder(tf.int32, [batch_size, None])
53 | #
54 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33,
55 | # rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)
56 |
57 | saver = tf.train.Saver(tf.global_variables())
58 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
59 | with tf.Session() as sess:
60 | sess.run(init_op)
61 |
62 | checkpoint = tf.train.latest_checkpoint('./dlt_model/')
63 | saver.restore(sess, checkpoint)
64 | # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4red/poems-150000")
65 | ssqdata = get_dlt_data(random_order=False,use_resnet=True)
66 | # x = np.array([list(map(word_int_map.get, start_token))])
67 | x=[ssqdata[len(ssqdata)-1]]
68 | print("input: %s"%(x+np.asarray([[[[1],[1],[1],[1],[1],[-34],[-34]]]])))
69 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
70 | feed_dict={input_data: x})
71 | poem_=np.argmax(np.array(predict),axis=1)
72 | sorted_result = np.argsort(np.array(predict), axis=1)
73 |         results=poem_+np.asarray([1,1,1,1,1,-34,-34])  # map class indices back to ball numbers: 0-34 -> front 1-35, 35-46 -> back 1-12
74 | # print(sorted_result)
75 | print("output: %s"%results)
76 | return poem_
77 |
78 | def gen_blue():
79 | batch_size = 1
80 | print('## loading model from %s' % model_dir)
81 | input_data = tf.placeholder(tf.float32, [1, 1,7,1])
82 | logits = inference(input_data, 10, reuse=False,output_num=128)
83 |
84 | # print(tf.shape(input_data))
85 | output_targets = tf.placeholder(tf.int32, [1, None])
86 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33,
87 | output_num=1,
88 | rnn_size=128, num_layers=3, batch_size=1, learning_rate=0.01)
89 |
90 | # input_data = tf.placeholder(tf.int32, [batch_size, None])
91 | #
92 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33,
93 | # rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)
94 |
95 | saver = tf.train.Saver(tf.global_variables())
96 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
97 | with tf.Session() as sess:
98 | sess.run(init_op)
99 |
100 | checkpoint = tf.train.latest_checkpoint('./model4all/')
101 | saver.restore(sess, checkpoint)
102 | # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4blue/poems-84201")
103 | ssqdata = get_exl_data(random_order=False,use_resnet=True)
104 | # x = np.array([list(map(word_int_map.get, start_token))])
105 | x=[ssqdata[len(ssqdata)-1]]
106 | print("input: %s"%(x))
107 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
108 | feed_dict={input_data: x})
109 | poem_=np.argmax(np.array(predict),axis=1)
110 | sorted_result=np.argsort(np.array(predict),axis=1)
111 | results=poem_+np.ones(1)
112 | print(results)
113 | print(sorted_result)
114 | return poem_
115 |
116 |
117 | if __name__ == '__main__':
118 | # begin_char = input('## please input the first character:')
119 | # poem=gen_blue()#13
120 | poem = gen_poem()
121 | # pretty_print_poem(poem_=poem)
--------------------------------------------------------------------------------
/hyper_parameters.py:
--------------------------------------------------------------------------------
1 | # Coder: Wenxin Xu
2 | # Github: https://github.com/wenxinxu/resnet_in_tensorflow
3 | # ==============================================================================
4 | import tensorflow as tf
5 |
6 | FLAGS = tf.app.flags.FLAGS
7 |
8 | ## The following flags are related to save paths, tensorboard outputs and screen outputs
9 |
10 | tf.app.flags.DEFINE_string('version', 'test_110', '''A version number defining the directory to save
11 | logs and checkpoints''')
12 | tf.app.flags.DEFINE_integer('report_freq', 391, '''Steps takes to output errors on the screen
13 | and write summaries''')
14 | tf.app.flags.DEFINE_float('train_ema_decay', 0.95, '''The decay factor of the train error's
15 | moving average shown on tensorboard''')
16 |
17 |
18 | ## The following flags define hyper-parameters regards training
19 |
20 | tf.app.flags.DEFINE_integer('train_steps', 80000, '''Total steps that you want to train''')
21 | tf.app.flags.DEFINE_boolean('is_full_validation', False, '''Validation w/ full validation set or
22 | a random batch''')
23 | tf.app.flags.DEFINE_integer('train_batch_size', 128, '''Train batch size''')
24 | tf.app.flags.DEFINE_integer('validation_batch_size', 250, '''Validation batch size, better to be
25 | a divisor of 10000 for this task''')
26 | tf.app.flags.DEFINE_integer('test_batch_size', 125, '''Test batch size''')
27 |
28 | tf.app.flags.DEFINE_float('init_lr', 0.1, '''Initial learning rate''')
29 | tf.app.flags.DEFINE_float('lr_decay_factor', 0.1, '''How much to decay the learning rate each
30 | time''')
31 | tf.app.flags.DEFINE_integer('decay_step0', 40000, '''At which step to decay the learning rate''')
32 | tf.app.flags.DEFINE_integer('decay_step1', 60000, '''At which step to decay the learning rate''')
33 |
34 |
35 | ## The following flags define hyper-parameters modifying the training network
36 |
37 | tf.app.flags.DEFINE_integer('num_residual_blocks', 5, '''How many residual blocks do you want''')
38 | tf.app.flags.DEFINE_float('weight_decay', 0.0002, '''scale for l2 regularization''')
39 |
40 |
41 | ## The following flags are related to data-augmentation
42 |
43 | tf.app.flags.DEFINE_integer('padding_size', 2, '''In data augmentation, layers of zero padding on
44 | each side of the image''')
45 |
46 |
47 | ## If you want to load a checkpoint and continue training
48 |
49 | tf.app.flags.DEFINE_string('ckpt_path', 'cache/logs_repeat20/model.ckpt-100000', '''Checkpoint
50 | directory to restore''')
51 | tf.app.flags.DEFINE_boolean('is_use_ckpt', False, '''Whether to load a checkpoint and continue
52 | training''')
53 |
54 | tf.app.flags.DEFINE_string('test_ckpt_path', 'model_110.ckpt-79999', '''Checkpoint
55 | directory to restore''')
56 |
57 |
58 | train_dir = 'logs_' + FLAGS.version + '/'
59 |
--------------------------------------------------------------------------------
/poems/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: model.py
3 | # author: JinTian
4 | # time: 07/03/2017 3:07 PM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | # ------------------------------------------------------------------------
19 | import tensorflow as tf
20 | import numpy as np
21 |
22 |
23 | def rnn_model(model, input_data, output_data, vocab_size, rnn_size=128, num_layers=2, batch_size=64,
24 | input_num=128,output_num=1,learning_rate=0.01,use_cnn=True):
25 | """
26 | construct rnn seq2seq model.
27 | :param model: model class
28 | :param input_data: input data placeholder
29 | :param output_data: output data placeholder
30 | :param vocab_size:
31 | :param rnn_size:
32 | :param num_layers:
33 | :param batch_size:
34 | :param learning_rate:
35 | :return:
36 | """
37 | end_points = {}
38 |
39 |
40 | if model == 'rnn':
41 | cell_fun = tf.contrib.rnn.BasicRNNCell
42 | cell = cell_fun(rnn_size)
43 | elif model == 'gru':
44 | cell_fun = tf.contrib.rnn.GRUCell
45 | cell = cell_fun(rnn_size)
46 | elif model == 'lstm':
47 | cell_fun = tf.contrib.rnn.BasicLSTMCell
48 | cell = cell_fun(rnn_size, state_is_tuple=True)
49 |
50 | cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
51 |
52 | if output_data is not None:
53 | initial_state = cell.zero_state(batch_size, tf.float32)
54 | else:
55 | initial_state = cell.zero_state(1, tf.float32)
56 | if use_cnn:
57 | with tf.name_scope('fc1'):
58 | fc1_weights = tf.Variable( # fully connected, depth 512.
59 | tf.truncated_normal([input_num, 128],
60 | # mean=1.0,
61 | stddev=0.1,
62 | dtype=tf.float32))
63 | fc1_biases = tf.Variable(tf.constant(1.0, shape=[128], dtype=tf.float32))
64 | fc1 = tf.nn.relu(tf.matmul(tf.to_float(input_data), fc1_weights) + fc1_biases)
65 | with tf.name_scope('fc2'):
66 | fc2_weights = tf.Variable( # fully connected, depth 512.
67 | tf.truncated_normal([128, 256],
68 | # mean=1.0,
69 | stddev=0.1,
70 | dtype=tf.float32))
71 | fc2_biases = tf.Variable(tf.constant(1.0, shape=[256], dtype=tf.float32))
72 | fc2 = tf.nn.relu(tf.matmul(tf.to_float(fc1), fc2_weights) + fc2_biases)
73 | with tf.name_scope('fc3'):
74 | fc3_weights = tf.Variable( # fully connected, depth 512.
75 | tf.truncated_normal([256, 512],
76 | # mean=1.0,
77 | stddev=0.1,
78 | dtype=tf.float32))
79 | fc3_biases = tf.Variable(tf.constant(1.0, shape=[512], dtype=tf.float32))
80 | fc3 = tf.nn.relu(tf.matmul(tf.to_float(fc2), fc3_weights) + fc3_biases)
81 | with tf.name_scope('fc4'):
82 | fc4_weights = tf.Variable( # fully connected, depth 512.
83 | tf.truncated_normal([512, 128*output_num],
84 | # mean=1.0,
85 | stddev=0.1,
86 | dtype=tf.float32))
87 | fc4_biases = tf.Variable(tf.constant(1.0, shape=[128*output_num], dtype=tf.float32))
88 | fc4 = tf.nn.relu(tf.matmul(fc3, fc4_weights) + fc4_biases)
89 |
90 |
91 | # [batch_size, ?, rnn_size] = [64, ?, 128]
92 | embedded=tf.reshape(fc4,[batch_size,-1,128])
93 | outputs, last_state = tf.nn.dynamic_rnn(cell, embedded, initial_state=initial_state)
94 | else:
95 | with tf.device("/cpu:0"):
96 | embedding = tf.get_variable('embedding', initializer=tf.random_uniform(
97 | [vocab_size + 1, rnn_size], -1.0, 1.0))
98 | inputs = tf.nn.embedding_lookup(embedding, input_data)
99 | # embedded = tf.reshape(input_data, [batch_size, -1, 128])
100 | # outputs, last_state = tf.nn.dynamic_rnn(cell, embedded, initial_state=initial_state)
101 |         outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)  # needed so outputs/last_state exist when use_cnn is False
102 | output = tf.reshape(outputs, [-1, rnn_size])
103 |
104 | weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
105 | bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
106 | logits = tf.nn.bias_add(tf.matmul(output, weights), bias=bias)
107 | # [?, vocab_size+1]
108 |
109 | if output_data is not None:
110 | # output_data must be one-hot encode
111 | labels = tf.one_hot(tf.reshape(output_data, [-1]), depth=vocab_size + 1)
112 | # should be [?, vocab_size+1]
113 | loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
114 |
115 | # loss shape should be [?, vocab_size+1]
116 | total_loss = tf.reduce_mean(loss)
117 | regularizers = (tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) +
118 | tf.nn.l2_loss(fc2_weights) + tf.nn.l2_loss(fc2_biases)+
119 | tf.nn.l2_loss(fc3_weights) + tf.nn.l2_loss(fc3_biases)+
120 | tf.nn.l2_loss(fc4_weights) + tf.nn.l2_loss(fc4_biases))
121 | total_loss+=5e-4 * regularizers
122 | train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)
123 |
124 | end_points['initial_state'] = initial_state
125 | end_points['output'] = output
126 | end_points['train_op'] = train_op
127 | end_points['total_loss'] = total_loss
128 | end_points['loss'] = loss
129 | end_points['last_state'] = last_state
130 | prediction = tf.nn.softmax(logits)
131 | end_points['prediction'] = prediction
132 | else:
133 | prediction = tf.nn.softmax(logits)
134 |
135 | end_points['initial_state'] = initial_state
136 | end_points['last_state'] = last_state
137 | end_points['prediction'] = prediction
138 |
139 | return end_points
140 |
141 |
--------------------------------------------------------------------------------
/poems/poems.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: poems.py
3 | # author: JinTian
4 | # time: 08/03/2017 7:39 PM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | # ------------------------------------------------------------------------
19 | import collections
20 | import os
21 | import sys
22 | import numpy as np
23 |
24 | start_token = 'B'
25 | end_token = 'E'
26 |
27 |
28 | def process_poems(file_name):
29 |     # poem collection
30 | poems = []
31 | with open(file_name, "r", encoding='utf-8', ) as f:
32 | for line in f.readlines():
33 | try:
34 | title, content = line.strip().split(':')
35 | content = content.replace(' ', '')
36 | if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \
37 | start_token in content or end_token in content:
38 | continue
39 | if len(content) < 5 or len(content) > 79:
40 | continue
41 | content = start_token + content + end_token
42 | poems.append(content)
43 | except ValueError as e:
44 | pass
45 |     poems = sorted(poems, key=lambda l: len(l))
46 |
47 | all_words = []
48 | for poem in poems:
49 | all_words += [word for word in poem]
50 | counter = collections.Counter(all_words)
51 | count_pairs = sorted(counter.items(), key=lambda x: -x[1])
52 | words, _ = zip(*count_pairs)
53 |
54 | words = words[:len(words)] + (' ',)
55 | word_int_map = dict(zip(words, range(len(words))))
56 | poems_vector = [list(map(lambda word: word_int_map.get(word, len(words)), poem)) for poem in poems]
57 |
58 | return poems_vector, word_int_map, words
59 |
60 |
61 | def generate_batch(batch_size, poems_vec, word_to_int):
62 | n_chunk = len(poems_vec) // batch_size
63 | x_batches = []
64 | y_batches = []
65 | for i in range(n_chunk):
66 | start_index = i * batch_size
67 | end_index = start_index + batch_size
68 |
69 | batches = poems_vec[start_index:end_index]
70 | length = max(map(len, batches))
71 | x_data = np.full((batch_size, length), word_to_int[' '], np.int32)
72 | for row in range(batch_size):
73 | x_data[row, :len(batches[row])] = batches[row]
74 | y_data = np.copy(x_data)
75 | y_data[:, :-1] = x_data[:, 1:]
76 | """
77 | x_data y_data
78 | [6,2,4,6,9] [2,4,6,9,9]
79 | [1,4,2,8,5] [4,2,8,5,5]
80 | """
81 | x_batches.append(x_data)
82 | y_batches.append(y_data)
83 | return x_batches, y_batches
84 |
--------------------------------------------------------------------------------
/poems/resnet.py:
--------------------------------------------------------------------------------
1 | # Coder: Wenxin Xu
2 | # Github: https://github.com/wenxinxu/resnet_in_tensorflow
3 | # ==============================================================================
4 | '''
5 | This is the resnet structure
6 | '''
7 | import numpy as np
8 |
9 | from hyper_parameters import *
10 |
11 | BN_EPSILON = 0.001
12 |
13 |
14 | def activation_summary(x):
15 | '''
16 | :param x: A Tensor
17 | :return: Add histogram summary and scalar summary of the sparsity of the tensor
18 | '''
19 | tensor_name = x.op.name
20 | tf.summary.histogram(tensor_name + '/activations', x)
21 | tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))
22 |
23 |
24 | def create_variables(name, shape, initializer=tf.contrib.layers.xavier_initializer(), is_fc_layer=False):
25 | '''
26 | :param name: A string. The name of the new variable
27 | :param shape: A list of dimensions
28 |     :param initializer: Use Xavier as default.
29 | :param is_fc_layer: Want to create fc layer variable? May use different weight_decay for fc
30 | layers.
31 | :return: The created variable
32 | '''
33 |
34 | ## TODO: to allow different weight decay to fully connected layer and conv layer
35 | if is_fc_layer is True:
36 | regularizer = tf.contrib.layers.l2_regularizer(scale=FLAGS.weight_decay)
37 | else:
38 | regularizer = tf.contrib.layers.l2_regularizer(scale=FLAGS.weight_decay)
39 |
40 | new_variables = tf.get_variable(name, shape=shape, initializer=initializer,
41 | regularizer=regularizer)
42 | return new_variables
43 |
44 |
45 | def output_layer(input_layer, num_labels):
46 | '''
47 | :param input_layer: 2D tensor
48 | :param num_labels: int. How many output labels in total? (10 for cifar10 and 100 for cifar100)
49 | :return: output layer Y = WX + B
50 | '''
51 | input_dim = input_layer.get_shape().as_list()[-1]
52 | fc_w = create_variables(name='fc_weights', shape=[input_dim, num_labels], is_fc_layer=True,
53 | initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
54 | fc_b = create_variables(name='fc_bias', shape=[num_labels], initializer=tf.zeros_initializer())
55 |
56 | fc_h = tf.matmul(input_layer, fc_w) + fc_b
57 | return fc_h
58 |
59 |
60 | def batch_normalization_layer(input_layer, dimension):
61 | '''
62 |     Helper function to do batch normalization
63 | :param input_layer: 4D tensor
64 | :param dimension: input_layer.get_shape().as_list()[-1]. The depth of the 4D tensor
65 | :return: the 4D tensor after being normalized
66 | '''
67 | mean, variance = tf.nn.moments(input_layer, axes=[0, 1, 2])
68 | beta = tf.get_variable('beta', dimension, tf.float32,
69 | initializer=tf.constant_initializer(0.0, tf.float32))
70 | gamma = tf.get_variable('gamma', dimension, tf.float32,
71 | initializer=tf.constant_initializer(1.0, tf.float32))
72 | bn_layer = tf.nn.batch_normalization(input_layer, mean, variance, beta, gamma, BN_EPSILON)
73 |
74 | return bn_layer
75 |
76 |
77 | def conv_bn_relu_layer(input_layer, filter_shape, stride):
78 | '''
79 | A helper function to conv, batch normalize and relu the input tensor sequentially
80 | :param input_layer: 4D tensor
81 | :param filter_shape: list. [filter_height, filter_width, filter_depth, filter_number]
82 | :param stride: stride size for conv
83 | :return: 4D tensor. Y = Relu(batch_normalize(conv(X)))
84 | '''
85 |
86 | out_channel = filter_shape[-1]
87 | filter = create_variables(name='conv', shape=filter_shape)
88 |
89 | conv_layer = tf.nn.conv2d(input_layer, filter, strides=[1, stride, stride, 1], padding='SAME')
90 | bn_layer = batch_normalization_layer(conv_layer, out_channel)
91 |
92 | output = tf.nn.relu(bn_layer)
93 | return output
94 |
95 |
96 | def bn_relu_conv_layer(input_layer, filter_shape, stride):
97 | '''
98 | A helper function to batch normalize, relu and conv the input layer sequentially
99 | :param input_layer: 4D tensor
100 | :param filter_shape: list. [filter_height, filter_width, filter_depth, filter_number]
101 | :param stride: stride size for conv
102 | :return: 4D tensor. Y = conv(Relu(batch_normalize(X)))
103 | '''
104 |
105 | in_channel = input_layer.get_shape().as_list()[-1]
106 |
107 | bn_layer = batch_normalization_layer(input_layer, in_channel)
108 | relu_layer = tf.nn.relu(bn_layer)
109 |
110 | filter = create_variables(name='conv', shape=filter_shape)
111 | conv_layer = tf.nn.conv2d(relu_layer, filter, strides=[1, stride, stride, 1], padding='SAME')
112 | return conv_layer
113 |
114 |
115 | def residual_block(input_layer, output_channel, first_block=False):
116 | '''
117 | Defines a residual block in ResNet
118 | :param input_layer: 4D tensor
119 | :param output_channel: int. return_tensor.get_shape().as_list()[-1] = output_channel
120 | :param first_block: if this is the first residual block of the whole network
121 | :return: 4D tensor.
122 | '''
123 | input_channel = input_layer.get_shape().as_list()[-1]
124 |
125 | # When it's time to "shrink" the image size, we use stride = 2
126 | if input_channel * 2 == output_channel:
127 | increase_dim = True
128 | stride = 2
129 | elif input_channel == output_channel:
130 | increase_dim = False
131 | stride = 1
132 | else:
133 |         raise ValueError('Output and input channels do not match in residual blocks!!!')
134 |
135 | # The first conv layer of the first residual block does not need to be normalized and relu-ed.
136 | with tf.variable_scope('conv1_in_block'):
137 | if first_block:
138 | filter = create_variables(name='conv', shape=[3, 3, input_channel, output_channel])
139 | conv1 = tf.nn.conv2d(input_layer, filter=filter, strides=[1, 1, 1, 1], padding='SAME')
140 | else:
141 | conv1 = bn_relu_conv_layer(input_layer, [3, 3, input_channel, output_channel], stride)
142 |
143 | with tf.variable_scope('conv2_in_block'):
144 | conv2 = bn_relu_conv_layer(conv1, [3, 3, output_channel, output_channel], 1)
145 |
146 |     # When the channels of input layer and conv2 do not match, we add zero pads to increase the
147 | # depth of input layers
148 | if increase_dim is True:
149 | pooled_input = tf.nn.avg_pool(input_layer, ksize=[1, 1, 1, 1],
150 | strides=[1, 2, 2, 1], padding='VALID')
151 | padded_input = tf.pad(pooled_input, [[0, 0], [0, 0], [0, 0], [input_channel // 2,
152 | input_channel // 2]])
153 | else:
154 | padded_input = input_layer
155 |
156 | output = conv2 + padded_input
157 | return output
158 |
159 |
160 | def inference(input_tensor_batch, n, reuse,output_num=128):
161 | '''
162 | The main function that defines the ResNet. total layers = 1 + 2n + 2n + 2n +1 = 6n + 2
163 | :param input_tensor_batch: 4D tensor
164 | :param n: num_residual_blocks
165 | :param reuse: To build train graph, reuse=False. To build validation graph and share weights
166 |     with train graph, reuse=True
167 | :return: last layer in the network. Not softmax-ed
168 | '''
169 |
170 | layers = []
171 | with tf.variable_scope('conv0', reuse=reuse):
172 | conv0 = conv_bn_relu_layer(input_tensor_batch, [1, 1, 1, 16], 1)
173 | activation_summary(conv0)
174 | layers.append(conv0)
175 |
176 | for i in range(n):
177 | with tf.variable_scope('conv1_%d' % i, reuse=reuse):
178 | if i == 0:
179 | conv1 = residual_block(layers[-1], 16, first_block=True)
180 | else:
181 | conv1 = residual_block(layers[-1], 16)
182 | activation_summary(conv1)
183 | layers.append(conv1)
184 |
185 | for i in range(n):
186 | with tf.variable_scope('conv2_%d' % i, reuse=reuse):
187 | conv2 = residual_block(layers[-1], 32)
188 | activation_summary(conv2)
189 | layers.append(conv2)
190 |
191 | for i in range(n):
192 | with tf.variable_scope('conv3_%d' % i, reuse=reuse):
193 | conv3 = residual_block(layers[-1], 64)
194 | layers.append(conv3)
195 | # assert conv3.get_shape().as_list()[1:] == [8, 8, 64]
196 |
197 | with tf.variable_scope('fc', reuse=reuse):
198 | in_channel = layers[-1].get_shape().as_list()[-1]
199 | bn_layer = batch_normalization_layer(layers[-1], in_channel)
200 | relu_layer = tf.nn.relu(bn_layer)
201 | global_pool = tf.reduce_mean(relu_layer, [1, 2])
202 |
203 | assert global_pool.get_shape().as_list()[-1:] == [64]
204 | output = output_layer(global_pool, output_num)
205 | layers.append(output)
206 |
207 | return layers[-1]
208 |
209 |
210 | def test_graph(train_dir='logs'):
211 | '''
212 | Run this function to look at the graph structure on tensorboard. A fast way!
213 | :param train_dir:
214 | '''
215 | input_tensor = tf.constant(np.ones([128, 32, 32, 3]), dtype=tf.float32)
216 | result = inference(input_tensor, 2, reuse=False)
217 | init = tf.initialize_all_variables()
218 | sess = tf.Session()
219 | sess.run(init)
220 | summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
--------------------------------------------------------------------------------
/ssq.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: main.py
3 | # author: JinTian
4 | # time: 11/03/2017 9:53 AM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | # ------------------------------------------------------------------------
19 | import os
20 | import numpy as np
21 | import tensorflow as tf
22 | from poems.model import rnn_model
23 | from poems.poems import process_poems, generate_batch
24 | from ssq_data import *
25 | tf.app.flags.DEFINE_integer('batch_size', -1, 'batch size.')
26 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'learning rate.')
27 | tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path.')
28 | tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
29 | tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
30 | tf.app.flags.DEFINE_integer('epochs', 150000, 'train how many epochs.')
31 |
32 | FLAGS = tf.app.flags.FLAGS
33 |
34 |
35 | def run_training():
36 | # if not os.path.exists(FLAGS.model_dir):
37 | # os.makedirs(FLAGS.model_dir)
38 | #
39 | # poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
40 | # batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
41 | ssqdata=get_exl_data(random_order=True)
42 | batches_inputs=ssqdata[0:(len(ssqdata)-1)]
43 | batches_outputs = ssqdata[1:(len(ssqdata))]
44 | FLAGS.batch_size=len(ssqdata)-1
45 | # data=batches_outputs[1:7]
46 | # print(len(data))
47 | del ssqdata
48 | input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
49 | # print(tf.shape(input_data))
50 | output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
51 | end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=33+16,
52 | output_num=7,input_num=7,
53 | rnn_size=128, num_layers=7, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)
54 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
55 | # vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate)
56 |
57 | saver = tf.train.Saver(tf.global_variables())
58 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
59 | with tf.Session() as sess:
60 | # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
61 | # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
62 | sess.run(init_op)
63 |
64 | start_epoch = 0
65 | # saver.restore(sess, "F:/tensorflow_poems-master/model/poems-100000")
66 | checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
67 | if checkpoint:
68 | saver.restore(sess, checkpoint)
69 | print("## restore from the checkpoint {0}".format(checkpoint))
70 | start_epoch += int(checkpoint.split('-')[-1])
71 | print('## start training...')
72 | try:
73 | for epoch in range(start_epoch, FLAGS.epochs):
74 | n = 0
75 | # n_chunk = len(poems_vector) // FLAGS.batch_size
76 | # n_chunk = len(batches_inputs) // FLAGS.batch_size
77 | n_chunk=math.ceil(len(batches_inputs) / FLAGS.batch_size)
78 | for batch in range(n_chunk):
79 | left=(batch+1)*FLAGS.batch_size-len(batches_inputs)
80 | if left<0:
81 | inputdata=batches_inputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)]
82 | outputdata=batches_outputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)]
83 | else:
84 | # temp=batches_inputs[batch*FLAGS.batch_size:len(batches_inputs) ]
85 | # temp.extend(batches_inputs[0:left])
86 | inputdata=batches_inputs[len(batches_inputs)-FLAGS.batch_size:len(batches_inputs)]
87 | # temp=batches_outputs[batch*FLAGS.batch_size:len(batches_inputs) ]
88 | # temp.extend(batches_outputs[0:left])
89 | outputdata=batches_outputs[len(batches_outputs)-FLAGS.batch_size:len(batches_outputs)]
90 | # print(len(inputdata))
91 | loss, _, _ = sess.run([
92 | end_points['total_loss'],
93 | end_points['last_state'],
94 | end_points['train_op']
95 | ], feed_dict={input_data: inputdata, output_targets: outputdata})
96 | # # ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
97 | n += 1
98 | print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
99 | if epoch % 5000 == 0:
100 | saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
101 | except KeyboardInterrupt:
102 | print('## Interrupt manually, try saving checkpoint for now...')
103 | finally:
104 | saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
105 |             print('## Last epoch was saved, next time will start from epoch {}.'.format(epoch))
106 |
107 |
108 | def main(_):
109 | run_training()
110 |
111 |
112 | if __name__ == '__main__':
113 | tf.app.run()
114 |
--------------------------------------------------------------------------------
/ssq.xls:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liang-yc/ssq/4c69745d1fb731b4a39cc571af58a31368439789/ssq.xls
--------------------------------------------------------------------------------
/ssq4all.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: main.py
3 | # author: JinTian
4 | # time: 11/03/2017 9:53 AM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | # ------------------------------------------------------------------------
19 | import os
20 | import numpy as np
21 | import tensorflow as tf
22 | from poems.model import rnn_model
23 | from poems.resnet import *
24 | from poems.poems import process_poems, generate_batch
25 | from ssq_data import *
26 | # for Windows10:OSError: raw write() returned invalid length 96 (should have been between 0 and 48)
27 | # import win_unicode_console
28 | # win_unicode_console.enable()
29 | tf.app.flags.DEFINE_integer('batch_size', 2214, 'batch size.')
30 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'learning rate.')
31 | tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model4all'), 'model save path.')
32 | tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
33 | tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
34 | tf.app.flags.DEFINE_integer('epochs', 500000, 'train how many epochs.')
35 |
36 | FLAGS = tf.app.flags.FLAGS
37 |
38 |
39 | def run_training():
40 | if not os.path.exists(FLAGS.model_dir):
41 | os.makedirs(FLAGS.model_dir)
42 | #
43 | # poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
44 | # batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
45 | ssqdata=get_exl_data(random_order=False,use_resnet=True)
46 | # print(ssqdata[len(ssqdata)-1])
47 | batches_inputs=ssqdata[0:(len(ssqdata)-1)]
48 | ssqdata=get_exl_data(random_order=False,use_resnet=False)
49 | batches_outputs = ssqdata[1:(len(ssqdata))]
50 | FLAGS.batch_size=len(batches_inputs)
51 | # print(np.shape(batches_outputs))
52 | # data=batches_outputs[1:7]
53 | # print(len(data))
54 | del ssqdata
55 | input_data = tf.placeholder(tf.float32, [FLAGS.batch_size, 1,7,1])
56 | logits = inference(input_data, 2, reuse=False,output_num=128)
57 |
58 | # print(tf.shape(input_data))
59 | output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
60 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33+16,output_num=7,
61 | rnn_size=128, num_layers=7, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)
62 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
63 | # vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate)
64 |
65 | saver = tf.train.Saver(tf.global_variables())
66 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
67 | with tf.Session() as sess:
68 | # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
69 | # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
70 | sess.run(init_op)
71 |
72 | start_epoch = 0
73 | checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
74 | if checkpoint:
75 | saver.restore(sess, checkpoint)
76 | print("## restore from the checkpoint {0}".format(checkpoint))
77 | start_epoch += int(checkpoint.split('-')[-1])
78 | print('## start training...')
79 | try:
80 | for epoch in range(start_epoch, FLAGS.epochs):
81 | n = 0
82 | # n_chunk = len(poems_vector) // FLAGS.batch_size
83 | # n_chunk = len(batches_inputs) // FLAGS.batch_size
84 | n_chunk=math.ceil(len(batches_inputs) / FLAGS.batch_size)
85 | for batch in range(n_chunk):
86 | left=(batch+1)*FLAGS.batch_size-len(batches_inputs)
87 | if left<0:
88 | inputdata=batches_inputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)]
89 | outputdata=batches_outputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)]
90 | else:
91 | # temp=batches_inputs[batch*FLAGS.batch_size:len(batches_inputs) ]
92 | # temp.extend(batches_inputs[0:left])
93 | inputdata=batches_inputs[len(batches_inputs)-FLAGS.batch_size:len(batches_inputs)]
94 | # temp=batches_outputs[batch*FLAGS.batch_size:len(batches_inputs) ]
95 | # temp.extend(batches_outputs[0:left])
96 | outputdata=batches_outputs[len(batches_outputs)-FLAGS.batch_size:len(batches_outputs)]
97 | # print(len(inputdata))
98 | loss, _, _ = sess.run([
99 | end_points['total_loss'],
100 | end_points['last_state'],
101 | end_points['train_op']
102 | ], feed_dict={input_data: inputdata, output_targets: outputdata})
103 | # ], feed_dict={input_data: batches_inputs, output_targets: batches_outputs})
104 | n += 1
105 | print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
106 | if epoch % 50000 == 0:
107 | saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
108 | except KeyboardInterrupt:
109 | print('## Interrupt manually, try saving checkpoint for now...')
110 | finally:
111 | saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
112 |             print('## Last epoch was saved, next time will start from epoch {}.'.format(epoch))
113 |
114 |
115 | def main(_):
116 | run_training()
117 |
118 |
119 | if __name__ == '__main__':
120 | tf.app.run()
121 |
--------------------------------------------------------------------------------
/ssq4all_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: main.py
3 | # author: JinTian
4 | # time: 11/03/2017 9:53 AM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | # ------------------------------------------------------------------------
19 | import tensorflow as tf
20 | from poems.model import rnn_model
21 | from poems.resnet import *
22 | from poems.poems import process_poems
23 | import numpy as np
24 | from ssq_data import *
25 | start_token = 'B'
26 | end_token = 'E'
27 | model_dir = './model4all/'
28 | corpus_file = './data/poems.txt'
29 |
30 |
31 | def to_word(predict, vocabs):
32 | t = np.cumsum(predict)
33 | s = np.sum(predict)
34 | sample = int(np.searchsorted(t, np.random.rand(1) * s))
35 | if sample > len(vocabs):
36 | sample = len(vocabs) - 1
37 | return vocabs[sample]
38 |
39 |
40 | def gen_poem():
41 | batch_size = 1
42 | print('## loading model from %s' % model_dir)
43 | input_data = tf.placeholder(tf.float32, [1, 1,7,1])
44 | logits = inference(input_data, 2, reuse=False,output_num=128)
45 |
46 | # print(tf.shape(input_data))
47 | output_targets = tf.placeholder(tf.int32, [1, None])
48 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33+16,
49 | output_num=7,
50 | rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)
51 |
52 | # input_data = tf.placeholder(tf.int32, [batch_size, None])
53 | #
54 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33,
55 | # rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)
56 |
57 | saver = tf.train.Saver(tf.global_variables())
58 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
59 | with tf.Session() as sess:
60 | sess.run(init_op)
61 |
62 | checkpoint = tf.train.latest_checkpoint('./model4all/')
63 | saver.restore(sess, checkpoint)
64 | # saver.restore(sess, "F:/tensorflow_poems-master/model4all/poems-401713")
65 | ssqdata = get_exl_data(random_order=True,use_resnet=True)
66 | # x = np.array([list(map(word_int_map.get, start_token))])
67 | x=[ssqdata[len(ssqdata)-1]]
68 | print("input: %s"%(x+np.asarray([[[[1],[1],[1],[1],[1],[1],[-32]]]])))
69 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
70 | feed_dict={input_data: x})
71 | poem_=np.argmax(np.array(predict),axis=1)
72 | sorted_result = np.argsort(np.array(predict), axis=1)
73 | results=poem_+np.asarray([1,1,1,1,1,1,-32])
74 | print(sorted_result)
75 | print("output: %s"%results)
76 | poem_=np.argmin(np.array(predict),axis=1)
77 | results=poem_+np.asarray([1,1,1,1,1,1,-32])
78 | print("min output:%s"%results)
79 | return poem_
80 |
81 | def gen_blue():
82 | batch_size = 1
83 | print('## loading model from %s' % model_dir)
84 | input_data = tf.placeholder(tf.float32, [1, 1,7,1])
85 | logits = inference(input_data, 10, reuse=False,output_num=128)
86 |
87 | # print(tf.shape(input_data))
88 | output_targets = tf.placeholder(tf.int32, [1, None])
89 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33,
90 | output_num=1,
91 | rnn_size=128, num_layers=3, batch_size=1, learning_rate=0.01)
92 |
93 | # input_data = tf.placeholder(tf.int32, [batch_size, None])
94 | #
95 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33,
96 | # rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)
97 |
98 | saver = tf.train.Saver(tf.global_variables())
99 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
100 | with tf.Session() as sess:
101 | sess.run(init_op)
102 |
103 | checkpoint = tf.train.latest_checkpoint('./model4all/')
104 | saver.restore(sess, checkpoint)
105 | # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4blue/poems-84201")
106 | ssqdata = get_exl_data(random_order=True,use_resnet=True)
107 | # x = np.array([list(map(word_int_map.get, start_token))])
108 | x=[ssqdata[len(ssqdata)-1]]
109 |
110 | print("input: %s"%(x))
111 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
112 | feed_dict={input_data: x})
113 | poem_=np.argmax(np.array(predict),axis=1)
114 | sorted_result=np.argsort(np.array(predict),axis=1)
115 | results=poem_+np.ones(1)
116 | print(results)
117 | print(sorted_result)
118 | return poem_
119 |
120 |
121 | if __name__ == '__main__':
122 | # begin_char = input('## please input the first character:')
123 | # poem=gen_blue()#13
124 | poem = gen_poem()
125 | # pretty_print_poem(poem_=poem)
--------------------------------------------------------------------------------
/ssq4all_test_v4.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: main.py
3 | # author: JinTian
4 | # time: 11/03/2017 9:53 AM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | # ------------------------------------------------------------------------
19 | import tensorflow as tf
20 | from poems.model import rnn_model
21 | from poems.resnet import *
22 | from poems.poems import process_poems
23 | import numpy as np
24 | from ssq_data import *
25 | start_token = 'B'
26 | end_token = 'E'
27 | model_dir = './model4all_v4/'
28 | corpus_file = './data/poems.txt'
29 |
30 |
31 | def to_word(predict, vocabs):
32 | t = np.cumsum(predict)
33 | s = np.sum(predict)
34 | sample = int(np.searchsorted(t, np.random.rand(1) * s))
35 | if sample > len(vocabs):
36 | sample = len(vocabs) - 1
37 | return vocabs[sample]
38 |
39 |
40 | def gen_poem():
41 | batch_size = 1
42 | print('## loading model from %s' % model_dir)
43 | input_data = tf.placeholder(tf.float32, [1, 10,7+8,1])
44 | logits = inference(input_data, 1, reuse=False,output_num=128)
45 |
46 | # print(tf.shape(input_data))
47 | output_targets = tf.placeholder(tf.int32, [1, None])
48 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33+16,output_num=7,
49 | rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.001)
50 |
51 | # input_data = tf.placeholder(tf.int32, [batch_size, None])
52 | #
53 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33,
54 | # rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01)
55 |
56 | saver = tf.train.Saver(tf.global_variables())
57 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
58 | with tf.Session() as sess:
59 | sess.run(init_op)
60 |
61 | checkpoint = tf.train.latest_checkpoint('./model4all_v4/')
62 | saver.restore(sess, checkpoint)
63 | # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4all/poems-208368")
64 | # ssqdata = get_exl_data(random_order=True,use_resnet=True)
65 | # ssqdata = get_exl_data_v3(random_order=True, use_resnet=True)
66 | ssqdata = get_exl_data_by_period(random_order=True, use_resnet=True, times=10)
67 | x=[ssqdata[len(ssqdata)-1]]
68 | print("input: %s"%(x+np.asarray([[[[1],[1],[1],[1],[1],[1],[-32],[0],[0],[0],[0],[0],[0],[0],[0]]]])))
69 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
70 | feed_dict={input_data: x})
71 | poem_=np.argmax(np.array(predict),axis=1)
72 | sorted_result = np.argsort(np.array(predict), axis=1)
73 | results=poem_+np.asarray([1,1,1,1,1,1,-32])
74 | print(sorted_result)
75 | print("output: %s"%results)
76 | return poem_
77 |
78 |
79 |
80 |
81 | if __name__ == '__main__':
82 | # begin_char = input('## please input the first character:')
83 | # poem=gen_blue()#13
84 | poem = gen_poem()
85 | # pretty_print_poem(poem_=poem)
86 |
--------------------------------------------------------------------------------
/ssq4all_v4.py:
--------------------------------------------------------------------------------
1 | # *_*coding:utf-8 *_*
2 |
3 | import os
4 | import sys
5 | sys.path.append(os.path.abspath(os.path.join(os.getcwd(),"..")))
6 | import numpy as np
7 | import tensorflow as tf
8 | from poems.model import rnn_model
9 | from poems.resnet import *
10 | from poems.poems import process_poems, generate_batch
11 | from ssq_data import *
12 | # for Windows10:OSError: raw write() returned invalid length 96 (should have been between 0 and 48)
13 | import win_unicode_console
14 | win_unicode_console.enable()
15 | tf.app.flags.DEFINE_integer('batch_size', 2214, 'batch size.')
16 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'learning rate.')
17 | tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model4all_v4'), 'model save path.')
18 | tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
19 | tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
20 | tf.app.flags.DEFINE_integer('epochs', 2000001, 'train how many epochs.')
21 | tf.app.flags.DEFINE_integer('times', 10, 'how many recent draws are used as one input sample.')
22 | FLAGS = tf.app.flags.FLAGS
23 |
24 |
25 | def run_training():
26 | if not os.path.exists(FLAGS.model_dir):
27 | os.makedirs(FLAGS.model_dir)
28 | #
29 | # poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
30 | # batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
31 | # ssqdata=get_exl_data(random_order=False,use_resnet=True)
32 | # # print(ssqdata[len(ssqdata)-1])
33 | # batches_inputs=ssqdata[0:(len(ssqdata)-1)]
34 | ssqdata=get_exl_data(random_order=False,use_resnet=False)
35 | # ssqdata=get_dlt_data(random_order=False,use_resnet=True)
36 | # print(ssqdata[len(ssqdata)-1])
37 | batches_inputs=ssqdata[0:(len(ssqdata)-1)]
38 | # ssqdata=get_dlt_data(random_order=False,use_resnet=False)
39 | batches_outputs = ssqdata[1:(len(ssqdata))]
40 | ssqdata=get_exl_data_by_period(random_order=False,use_resnet=True,times=FLAGS.times)
41 | batches_inputs=ssqdata[0:(len(ssqdata)-1)]
42 | FLAGS.batch_size=len(batches_inputs)
43 | # print(np.shape(batches_outputs))
44 | # data=batches_outputs[1:7]
45 | # print(len(data))
46 | del ssqdata
47 | input_data = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.times,7+8,1])
48 | logits = inference(input_data, 1, reuse=False,output_num=128)
49 |
50 | # print(tf.shape(input_data))
51 | output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
52 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33+16,output_num=7,
53 | rnn_size=128, num_layers=7, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)
54 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
55 | # vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate)
56 |
57 | saver = tf.train.Saver(tf.global_variables())
58 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
59 | with tf.Session() as sess:
60 | # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
61 | # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
62 | sess.run(init_op)
63 |
64 | start_epoch = 0
65 | # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4all/poems-208368")
66 | checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
67 | if checkpoint:
68 | saver.restore(sess, checkpoint)
69 | print("## restore from the checkpoint {0}".format(checkpoint))
70 | start_epoch += int(checkpoint.split('-')[-1])
71 | print('## start training...')
72 | try:
73 | # for epoch in range(start_epoch, FLAGS.epochs):
74 | epoch = start_epoch
75 | FLAGS.epochs=epoch+100000
76 | print('till',FLAGS.epochs)
77 | # for epoch in range(start_epoch, FLAGS.epochs):
78 |             while(epoch < FLAGS.epochs):
--------------------------------------------------------------------------------
/ssq_test.py:
--------------------------------------------------------------------------------
34 |     if sample > len(vocabs):
35 | sample = len(vocabs) - 1
36 |     return vocabs[sample]
37 |
38 |
39 | def gen_poem():
40 | batch_size = 1
41 | print('## loading model from %s' % model_dir)
42 |
43 | input_data = tf.placeholder(tf.int32, [batch_size, None])
44 |
45 | end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33+16,
46 | rnn_size=128, output_num=7,input_num=7,num_layers=7, batch_size=1, learning_rate=0.01)
47 |
48 | saver = tf.train.Saver(tf.global_variables())
49 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
50 | with tf.Session() as sess:
51 | sess.run(init_op)
52 |
53 | checkpoint = tf.train.latest_checkpoint(model_dir)
54 | saver.restore(sess, checkpoint)
55 | ssqdata = get_exl_data()
56 | # x = np.array([list(map(word_int_map.get, start_token))])
57 | x=[ssqdata[len(ssqdata)-1]]
58 | print("input: %s"%(x+np.asarray([1,1,1,1,1,1,-32])))
59 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
60 | feed_dict={input_data: x})
61 | poem_=np.argmax(np.array(predict),axis=1)
62 | results=poem_+np.asarray([1,1,1,1,1,1,-32])
63 | print("output:%s"%results)
64 | return poem_
65 |
66 |
67 |
68 | if __name__ == '__main__':
69 | # begin_char = input('## please input the first character:')
70 | poem = gen_poem()
71 | # pretty_print_poem(poem_=poem)
--------------------------------------------------------------------------------