├── ssq.xls
├── dlt2.xls
├── LICENSE
├── README.md
├── hyper_parameters.py
├── ssq_test.py
├── poems
│   ├── poems.py
│   ├── model.py
│   └── resnet.py
├── ssq4all_test_v4.py
├── dlt4all_test.py
├── ssq4all_test.py
├── ssq.py
├── ssq4all.py
├── ssq4all_v4.py
├── dlt4all.py
└── ssq_data.py

/ssq.xls:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liang-yc/ssq/HEAD/ssq.xls
--------------------------------------------------------------------------------
/dlt2.xls:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liang-yc/ssq/HEAD/dlt2.xls
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Yanchao Liang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Predicting China's Union Lotto with Neural Networks and LSTM (利用神经网络和LSTM预测双色球)
2 | ## Update 2020/12/4
3 | Since quite a few people have asked, a [branch](https://github.com/Liang-yc/ssq/tree/tf2) adapted to TensorFlow 2.X is now available.
4 | 
5 | ## Update 2019/6/18
6 | Predictions now use the results of the last several draws together with the prize-pool information; see `ssq4all_v4.py` and `ssq4all_test_v4.py` for training and testing.
7 | ## Update 2018/9/16
8 | Added training and testing files for the Super Lotto (大乐透).
9 | ## Update 2018/5/30
10 | Fixed some bugs based on feedback and removed two files. If you hit a problem, please open an issue.
11 | ## Update 2018/3/29
12 | Tried predicting with a CNN plus LSTM: the CNN (a ResNet) is used to extract features. The best result so far is a fifth prize (4 red numbers correct).
13 | ## Background
14 | ------
15 | This project combines a neural network with a Long Short-Term Memory (LSTM) network to predict the China Union Lotto (双色球). An introduction to the lottery itself is omitted here (see this [page](https://baike.baidu.com/item/中国福利彩票双色球/8676030?fr=aladdin&fromid=75279&fromtitle=%E5%8F%8C%E8%89%B2%E7%90%83)). The project takes real draw results as input (7 numbers, of which the first 6 are the red numbers and the last is the blue number) and outputs a prediction in the same format (again 7 numbers: 6 red and 1 blue). The project is still under development.<br>
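To make the representation concrete, here is a minimal sketch (my own illustration, not code from this repository) of how a draw maps to and from the shared 33+16 class space that the training and test scripts use, which is also why the test scripts decode predictions with the offset `[1,1,1,1,1,1,-32]` (the actual encoding is presumably done in `ssq_data.py`):

```python
import numpy as np

# One draw: 6 red numbers in 1..33 plus 1 blue number in 1..16 (made-up example values).
draw = np.array([2, 9, 14, 21, 27, 33, 7])

# Shared class space of 33 + 16 = 49 indices (the vocab_size in the training scripts):
# reds 1..33 -> classes 0..32, blue 1..16 -> classes 33..48.
offset = np.array([1, 1, 1, 1, 1, 1, -32])
encoded = draw - offset            # e.g. blue 7 becomes class 39

# Decoding an argmax prediction back to lottery numbers mirrors the test scripts'
# `results = poem_ + np.asarray([1, 1, 1, 1, 1, 1, -32])`.
decoded = encoded + offset
assert np.array_equal(decoded, draw)
```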

16 | The core code of this project is largely adapted from this [repository](https://github.com/jinfagang/tensorflow_poems).
17 | So far, after sufficient training, the predictions get 0-2 red numbers right (most often 1, which is not impressive given that the baseline probability of this is already fairly high) and 0-1 blue numbers right (most often 0). Considering the winning conditions of the Union Lotto (at least 3 red numbers correct, or the blue number correct), **I do not recommend buying tickets based directly on these predictions**.<br>
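For context, a quick back-of-the-envelope calculation (mine, not the author's) shows why matching one red number is expected even from pure guessing: a random pick of 6 reds out of 33 already matches about 1.09 reds on average, and a random blue is right 1 time in 16.

```python
from math import comb

# Expected number of matched reds for a uniformly random guess of 6 out of 33
# (hypergeometric expectation): 6 * 6 / 33 ≈ 1.09.
expected_reds = 6 * 6 / 33

# Probability of matching at least 3 reds purely by chance (≈ 0.058).
p_at_least_3_reds = sum(comb(6, k) * comb(27, 6 - k) for k in range(3, 7)) / comb(33, 6)

# Probability of guessing the blue number (1..16) purely by chance.
p_blue = 1 / 16

print(f"E[matched reds] ≈ {expected_reds:.2f}")
print(f"P(>= 3 reds)    ≈ {p_at_least_3_reds:.3f}")
print(f"P(blue correct) = {p_blue:.4f}")
```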
18 | The poor predictions come down to a few reasons:
19 | 1. The Union Lotto is by nature a random event, yet I am trying to find patterns in it;
20 | 2. The dataset is small and cannot be effectively enlarged;
21 | 3. My model may simply not be good enough.
22 | 
23 | If you have suggestions or questions, just email me (yc_liang@qq.com) or leave a comment in an issue;<br>
24 | If you find this project interesting, please remember to STAR it;<br>
25 | And of course, if you actually win a prize, I don't ask for much: just a STAR.<br>
26 | 
27 | ## Requirements
28 | -----
29 | 
30 | 1. TensorFlow;<br>
31 | 2. pyexcel_xls;<br>
32 | 3. CUDA (optional).<br>
33 | The code may require a few more Python packages that are not listed here; please pip install them as needed.<br>
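For reference, a rough environment check might look like this (the package list is my assumption of what the master branch needs, not an official requirements file; the TF 2.X branch mentioned above has its own requirements):

```python
# Typical install for the TF 1.x master branch (assumed, not pinned by the repo):
#   pip install "tensorflow>=1.4,<2" pyexcel-xls numpy
# Use tensorflow-gpu instead of tensorflow if you want CUDA support.
import importlib

for module in ("tensorflow", "pyexcel_xls", "numpy"):
    try:
        importlib.import_module(module)
        print(f"{module}: OK")
    except ImportError:
        print(f"{module}: missing, install it with pip")
```

On Windows, `ssq4all_v4.py` additionally imports `win_unicode_console` to work around a console encoding issue.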

34 | 
35 | ## File contents
36 | -----
37 | 
38 | 1. `poems`: the `poems` folder contains three model files; `model.py` is the model file used by this project, and `resnet.py` defines the ResNet model;<br>
39 | 2. `ssq.py`: trains the Union Lotto model;<br>
40 | 3. `ssq.xls`: stores the historical Union Lotto draw and prize data; the data is updated via a macro;<br>
41 | 4. `ssq_data.py`: reads the data in `ssq.xls` and converts it into the required form (see the sketch after this list);<br>
42 | 5. `ssq_test.py`: takes the most recent Union Lotto draw as input and prints the model's prediction;<br>
43 | 6. `ssq4all.py`: also trains a Union Lotto model (same purpose as `ssq.py`, but with a different model);<br>
44 | 7. `ssq4all_test.py`: takes the most recent Union Lotto draw as input and prints the prediction produced by that model;<br>
45 | 8. `dlt2.xls`: stores the historical Super Lotto (大乐透) draw and prize data; the data is updated via a macro;<br>
46 | 9. `dlt4all.py`: trains the Super Lotto model;<br>
47 | 10. `dlt4all_test.py`: takes the most recent Super Lotto draw as input and prints the model's prediction;<br>
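As a rough idea of what the loading step in `ssq_data.py` amounts to, here is a minimal sketch built on pyexcel_xls's `get_data`; the helper name and the assumption that the drawn numbers sit in the first 7 columns of each row are mine, not the repository's actual code:

```python
from pyexcel_xls import get_data
import numpy as np

def load_draws(path="ssq.xls"):
    """Hypothetical loader: collect the 7 drawn numbers (6 red + 1 blue) per row."""
    book = get_data(path)                 # OrderedDict: sheet name -> list of rows
    rows = next(iter(book.values()))
    draws = []
    for row in rows:
        if len(row) < 7:
            continue                      # skip headers and incomplete rows
        try:
            draws.append([int(v) for v in row[:7]])   # assumes numbers in the first 7 columns
        except (TypeError, ValueError):
            continue                      # skip non-numeric rows
    return np.asarray(draws, dtype=np.int32)

# draws = load_draws()   # shape (number_of_draws, 7)
```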
48 | 
49 | ## How To Use
50 | -----
51 | Set the parameters and run ssq.py (preferably from a console, so that you can press CTRL+C part-way through training to stop it and save the current state; the next run will then resume from that checkpoint). After training, a model folder is created in the project directory to hold the saved model. Then run ssq_test.py to get a prediction.<br>
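The resume-from-checkpoint behaviour described above is implemented in the training scripts themselves; condensed (and slightly simplified) from `ssq.py`, the pattern is roughly:

```python
import os
import tensorflow as tf  # FLAGS (model_dir, model_prefix, epochs) come from ssq.py's tf.app.flags definitions

saver = tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    start_epoch = 0
    checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)   # e.g. ./model/poems-5000
    if checkpoint:
        saver.restore(sess, checkpoint)                        # continue from the saved weights
        start_epoch += int(checkpoint.split('-')[-1])          # and from the saved epoch counter

    try:
        for epoch in range(start_epoch, FLAGS.epochs):
            ...  # feed the batches and run end_points['train_op'] here
            if epoch % 5000 == 0:
                saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
    except KeyboardInterrupt:
        print('## Interrupt manually, try saving checkpoint for now...')
    finally:
        saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
```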
52 | This project runs fine on Windows 10 with tensorflow 1.4/1.5, with or without a GPU. In principle the model is small and the dataset is small as well, so all of the training data can be fed as a single batch; if the model is made deeper later on, the batch size may need to be reduced.
53 | 
--------------------------------------------------------------------------------
/hyper_parameters.py:
--------------------------------------------------------------------------------
1 | # Coder: Wenxin Xu
2 | # Github: https://github.com/wenxinxu/resnet_in_tensorflow
3 | # ==============================================================================
4 | import tensorflow as tf
5 | 
6 | FLAGS = tf.app.flags.FLAGS
7 | 
8 | ## The following flags are related to save paths, tensorboard outputs and screen outputs
9 | 
10 | tf.app.flags.DEFINE_string('version', 'test_110', '''A version number defining the directory to save
11 | logs and checkpoints''')
12 | tf.app.flags.DEFINE_integer('report_freq', 391, '''Steps takes to output errors on the screen
13 | and write summaries''')
14 | tf.app.flags.DEFINE_float('train_ema_decay', 0.95, '''The decay factor of the train error's
15 | moving average shown on tensorboard''')
16 | 
17 | 
18 | ## The following flags define hyper-parameters regards training
19 | 
20 | tf.app.flags.DEFINE_integer('train_steps', 80000, '''Total steps that you want to train''')
21 | tf.app.flags.DEFINE_boolean('is_full_validation', False, '''Validation w/ full validation set or
22 | a random batch''')
23 | tf.app.flags.DEFINE_integer('train_batch_size', 128, '''Train batch size''')
24 | tf.app.flags.DEFINE_integer('validation_batch_size', 250, '''Validation batch size, better to be
25 | a divisor of 10000 for this task''')
26 | tf.app.flags.DEFINE_integer('test_batch_size', 125, '''Test batch size''')
27 | 
28 | tf.app.flags.DEFINE_float('init_lr', 0.1, '''Initial learning rate''')
29 | tf.app.flags.DEFINE_float('lr_decay_factor', 0.1, '''How much to decay the learning rate each
30 | time''')
31 | tf.app.flags.DEFINE_integer('decay_step0', 40000, '''At which step to decay the learning rate''')
32 | tf.app.flags.DEFINE_integer('decay_step1', 60000, '''At which step to decay the learning rate''')
33 | 
34 | 
35 | ## The following flags define hyper-parameters modifying the training network
36 | 
37 | tf.app.flags.DEFINE_integer('num_residual_blocks', 5, '''How many residual blocks do you want''')
38 | tf.app.flags.DEFINE_float('weight_decay', 0.0002, '''scale for l2 regularization''')
39 | 
40 | 
41 | ## The following flags are related to data-augmentation
42 | 
43 | tf.app.flags.DEFINE_integer('padding_size', 2, '''In data augmentation, layers of zero padding on
44 | each side of the image''')
45 | 
46 | 
47 | ## If you want to load a checkpoint and continue training
48 | 
49 | tf.app.flags.DEFINE_string('ckpt_path', 'cache/logs_repeat20/model.ckpt-100000', '''Checkpoint
50 | directory to restore''')
51 | tf.app.flags.DEFINE_boolean('is_use_ckpt', False, '''Whether to load a checkpoint and continue
52 | training''')
53 | 
54 | tf.app.flags.DEFINE_string('test_ckpt_path', 'model_110.ckpt-79999', '''Checkpoint
55 | directory to restore''')
56 | 
57 | 
58 | train_dir = 'logs_' + FLAGS.version + '/'
--------------------------------------------------------------------------------
/ssq_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: main.py
3 | # author: JinTian
4 | # time: 11/03/2017 9:53 AM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # ------------------------------------------------------------------------ 19 | import tensorflow as tf 20 | from poems.model import rnn_model 21 | from poems.poems import process_poems 22 | import numpy as np 23 | from ssq_data import * 24 | start_token = 'B' 25 | end_token = 'E' 26 | model_dir = './model/' 27 | corpus_file = './data/poems.txt' 28 | 29 | 30 | def to_word(predict, vocabs): 31 | t = np.cumsum(predict) 32 | s = np.sum(predict) 33 | sample = int(np.searchsorted(t, np.random.rand(1) * s)) 34 | if sample > len(vocabs): 35 | sample = len(vocabs) - 1 36 | return vocabs[sample] 37 | 38 | 39 | def gen_poem(): 40 | batch_size = 1 41 | print('## loading model from %s' % model_dir) 42 | 43 | input_data = tf.placeholder(tf.int32, [batch_size, None]) 44 | 45 | end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33+16, 46 | rnn_size=128, output_num=7,input_num=7,num_layers=7, batch_size=1, learning_rate=0.01) 47 | 48 | saver = tf.train.Saver(tf.global_variables()) 49 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 50 | with tf.Session() as sess: 51 | sess.run(init_op) 52 | 53 | checkpoint = tf.train.latest_checkpoint(model_dir) 54 | saver.restore(sess, checkpoint) 55 | ssqdata = get_exl_data() 56 | # x = np.array([list(map(word_int_map.get, start_token))]) 57 | x=[ssqdata[len(ssqdata)-1]] 58 | print("input: %s"%(x+np.asarray([1,1,1,1,1,1,-32]))) 59 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']], 60 | feed_dict={input_data: x}) 61 | poem_=np.argmax(np.array(predict),axis=1) 62 | results=poem_+np.asarray([1,1,1,1,1,1,-32]) 63 | print("output:%s"%results) 64 | return poem_ 65 | 66 | 67 | 68 | if __name__ == '__main__': 69 | # begin_char = input('## please input the first character:') 70 | poem = gen_poem() 71 | # pretty_print_poem(poem_=poem) -------------------------------------------------------------------------------- /poems/poems.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # file: poems.py 3 | # author: JinTian 4 | # time: 08/03/2017 7:39 PM 5 | # Copyright 2017 JinTian. All Rights Reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | # ------------------------------------------------------------------------ 19 | import collections 20 | import os 21 | import sys 22 | import numpy as np 23 | 24 | start_token = 'B' 25 | end_token = 'E' 26 | 27 | 28 | def process_poems(file_name): 29 | # 诗集 30 | poems = [] 31 | with open(file_name, "r", encoding='utf-8', ) as f: 32 | for line in f.readlines(): 33 | try: 34 | title, content = line.strip().split(':') 35 | content = content.replace(' ', '') 36 | if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \ 37 | start_token in content or end_token in content: 38 | continue 39 | if len(content) < 5 or len(content) > 79: 40 | continue 41 | content = start_token + content + end_token 42 | poems.append(content) 43 | except ValueError as e: 44 | pass 45 | poems = sorted(poems, key=lambda l: len(line)) 46 | 47 | all_words = [] 48 | for poem in poems: 49 | all_words += [word for word in poem] 50 | counter = collections.Counter(all_words) 51 | count_pairs = sorted(counter.items(), key=lambda x: -x[1]) 52 | words, _ = zip(*count_pairs) 53 | 54 | words = words[:len(words)] + (' ',) 55 | word_int_map = dict(zip(words, range(len(words)))) 56 | poems_vector = [list(map(lambda word: word_int_map.get(word, len(words)), poem)) for poem in poems] 57 | 58 | return poems_vector, word_int_map, words 59 | 60 | 61 | def generate_batch(batch_size, poems_vec, word_to_int): 62 | n_chunk = len(poems_vec) // batch_size 63 | x_batches = [] 64 | y_batches = [] 65 | for i in range(n_chunk): 66 | start_index = i * batch_size 67 | end_index = start_index + batch_size 68 | 69 | batches = poems_vec[start_index:end_index] 70 | length = max(map(len, batches)) 71 | x_data = np.full((batch_size, length), word_to_int[' '], np.int32) 72 | for row in range(batch_size): 73 | x_data[row, :len(batches[row])] = batches[row] 74 | y_data = np.copy(x_data) 75 | y_data[:, :-1] = x_data[:, 1:] 76 | """ 77 | x_data y_data 78 | [6,2,4,6,9] [2,4,6,9,9] 79 | [1,4,2,8,5] [4,2,8,5,5] 80 | """ 81 | x_batches.append(x_data) 82 | y_batches.append(y_data) 83 | return x_batches, y_batches 84 | -------------------------------------------------------------------------------- /ssq4all_test_v4.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # file: main.py 3 | # author: JinTian 4 | # time: 11/03/2017 9:53 AM 5 | # Copyright 2017 JinTian. All Rights Reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | # ------------------------------------------------------------------------ 19 | import tensorflow as tf 20 | from poems.model import rnn_model 21 | from poems.resnet import * 22 | from poems.poems import process_poems 23 | import numpy as np 24 | from ssq_data import * 25 | start_token = 'B' 26 | end_token = 'E' 27 | model_dir = './model4all_v2/' 28 | corpus_file = './data/poems.txt' 29 | 30 | 31 | def to_word(predict, vocabs): 32 | t = np.cumsum(predict) 33 | s = np.sum(predict) 34 | sample = int(np.searchsorted(t, np.random.rand(1) * s)) 35 | if sample > len(vocabs): 36 | sample = len(vocabs) - 1 37 | return vocabs[sample] 38 | 39 | 40 | def gen_poem(): 41 | batch_size = 1 42 | print('## loading model from %s' % model_dir) 43 | input_data = tf.placeholder(tf.float32, [1, 10,7+8,1]) 44 | logits = inference(input_data, 1, reuse=False,output_num=128) 45 | 46 | # print(tf.shape(input_data)) 47 | output_targets = tf.placeholder(tf.int32, [1, None]) 48 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33+16,output_num=7, 49 | rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.001) 50 | 51 | # input_data = tf.placeholder(tf.int32, [batch_size, None]) 52 | # 53 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33, 54 | # rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01) 55 | 56 | saver = tf.train.Saver(tf.global_variables()) 57 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 58 | with tf.Session() as sess: 59 | sess.run(init_op) 60 | 61 | checkpoint = tf.train.latest_checkpoint('./model4all_v4/') 62 | saver.restore(sess, checkpoint) 63 | # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4all/poems-208368") 64 | # ssqdata = get_exl_data(random_order=True,use_resnet=True) 65 | # ssqdata = get_exl_data_v3(random_order=True, use_resnet=True) 66 | ssqdata = get_exl_data_by_period(random_order=True, use_resnet=True, times=10) 67 | x=[ssqdata[len(ssqdata)-1]] 68 | print("input: %s"%(x+np.asarray([[[[1],[1],[1],[1],[1],[1],[-32],[0],[0],[0],[0],[0],[0],[0],[0]]]]))) 69 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']], 70 | feed_dict={input_data: x}) 71 | poem_=np.argmax(np.array(predict),axis=1) 72 | sorted_result = np.argsort(np.array(predict), axis=1) 73 | results=poem_+np.asarray([1,1,1,1,1,1,-32]) 74 | print(sorted_result) 75 | print("output: %s"%results) 76 | return poem_ 77 | 78 | 79 | 80 | 81 | if __name__ == '__main__': 82 | # begin_char = input('## please input the first character:') 83 | # poem=gen_blue()#13 84 | poem = gen_poem() 85 | # pretty_print_poem(poem_=poem) 86 | -------------------------------------------------------------------------------- /dlt4all_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # file: main.py 3 | # author: JinTian 4 | # time: 11/03/2017 9:53 AM 5 | # Copyright 2017 JinTian. All Rights Reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # ------------------------------------------------------------------------ 19 | import tensorflow as tf 20 | from poems.model import rnn_model 21 | from poems.resnet import * 22 | from poems.poems import process_poems 23 | import numpy as np 24 | from ssq_data import * 25 | start_token = 'B' 26 | end_token = 'E' 27 | model_dir = './model4all/' 28 | corpus_file = './data/poems.txt' 29 | 30 | 31 | def to_word(predict, vocabs): 32 | t = np.cumsum(predict) 33 | s = np.sum(predict) 34 | sample = int(np.searchsorted(t, np.random.rand(1) * s)) 35 | if sample > len(vocabs): 36 | sample = len(vocabs) - 1 37 | return vocabs[sample] 38 | 39 | 40 | def gen_poem(): 41 | batch_size = 1 42 | print('## loading model from %s' % model_dir) 43 | input_data = tf.placeholder(tf.float32, [1, 1,7,1]) 44 | logits = inference(input_data, 1, reuse=False,output_num=128) 45 | 46 | # print(tf.shape(input_data)) 47 | output_targets = tf.placeholder(tf.int32, [1, None]) 48 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=35+12, 49 | output_num=7, 50 | rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01) 51 | 52 | # input_data = tf.placeholder(tf.int32, [batch_size, None]) 53 | # 54 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33, 55 | # rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01) 56 | 57 | saver = tf.train.Saver(tf.global_variables()) 58 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 59 | with tf.Session() as sess: 60 | sess.run(init_op) 61 | 62 | checkpoint = tf.train.latest_checkpoint('./dlt_model/') 63 | saver.restore(sess, checkpoint) 64 | # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4red/poems-150000") 65 | ssqdata = get_dlt_data(random_order=False,use_resnet=True) 66 | # x = np.array([list(map(word_int_map.get, start_token))]) 67 | x=[ssqdata[len(ssqdata)-1]] 68 | print("input: %s"%(x+np.asarray([[[[1],[1],[1],[1],[1],[-34],[-34]]]]))) 69 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']], 70 | feed_dict={input_data: x}) 71 | poem_=np.argmax(np.array(predict),axis=1) 72 | sorted_result = np.argsort(np.array(predict), axis=1) 73 | results=poem_+np.asarray([1,1,1,1,1,-34,-34]) 74 | # print(sorted_result) 75 | print("output: %s"%results) 76 | return poem_ 77 | 78 | def gen_blue(): 79 | batch_size = 1 80 | print('## loading model from %s' % model_dir) 81 | input_data = tf.placeholder(tf.float32, [1, 1,7,1]) 82 | logits = inference(input_data, 10, reuse=False,output_num=128) 83 | 84 | # print(tf.shape(input_data)) 85 | output_targets = tf.placeholder(tf.int32, [1, None]) 86 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33, 87 | output_num=1, 88 | rnn_size=128, num_layers=3, batch_size=1, learning_rate=0.01) 89 | 90 | # input_data = tf.placeholder(tf.int32, [batch_size, None]) 91 | # 92 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33, 93 | # rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01) 94 | 95 | saver = tf.train.Saver(tf.global_variables()) 96 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 97 | with tf.Session() as sess: 98 | sess.run(init_op) 99 | 100 | checkpoint = tf.train.latest_checkpoint('./model4all/') 101 | 
saver.restore(sess, checkpoint) 102 | # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4blue/poems-84201") 103 | ssqdata = get_exl_data(random_order=False,use_resnet=True) 104 | # x = np.array([list(map(word_int_map.get, start_token))]) 105 | x=[ssqdata[len(ssqdata)-1]] 106 | print("input: %s"%(x)) 107 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']], 108 | feed_dict={input_data: x}) 109 | poem_=np.argmax(np.array(predict),axis=1) 110 | sorted_result=np.argsort(np.array(predict),axis=1) 111 | results=poem_+np.ones(1) 112 | print(results) 113 | print(sorted_result) 114 | return poem_ 115 | 116 | 117 | if __name__ == '__main__': 118 | # begin_char = input('## please input the first character:') 119 | # poem=gen_blue()#13 120 | poem = gen_poem() 121 | # pretty_print_poem(poem_=poem) -------------------------------------------------------------------------------- /ssq4all_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # file: main.py 3 | # author: JinTian 4 | # time: 11/03/2017 9:53 AM 5 | # Copyright 2017 JinTian. All Rights Reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # ------------------------------------------------------------------------ 19 | import tensorflow as tf 20 | from poems.model import rnn_model 21 | from poems.resnet import * 22 | from poems.poems import process_poems 23 | import numpy as np 24 | from ssq_data import * 25 | start_token = 'B' 26 | end_token = 'E' 27 | model_dir = './model4all/' 28 | corpus_file = './data/poems.txt' 29 | 30 | 31 | def to_word(predict, vocabs): 32 | t = np.cumsum(predict) 33 | s = np.sum(predict) 34 | sample = int(np.searchsorted(t, np.random.rand(1) * s)) 35 | if sample > len(vocabs): 36 | sample = len(vocabs) - 1 37 | return vocabs[sample] 38 | 39 | 40 | def gen_poem(): 41 | batch_size = 1 42 | print('## loading model from %s' % model_dir) 43 | input_data = tf.placeholder(tf.float32, [1, 1,7,1]) 44 | logits = inference(input_data, 2, reuse=False,output_num=128) 45 | 46 | # print(tf.shape(input_data)) 47 | output_targets = tf.placeholder(tf.int32, [1, None]) 48 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33+16, 49 | output_num=7, 50 | rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01) 51 | 52 | # input_data = tf.placeholder(tf.int32, [batch_size, None]) 53 | # 54 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33, 55 | # rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01) 56 | 57 | saver = tf.train.Saver(tf.global_variables()) 58 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 59 | with tf.Session() as sess: 60 | sess.run(init_op) 61 | 62 | checkpoint = tf.train.latest_checkpoint('./model4all/') 63 | saver.restore(sess, checkpoint) 64 | # saver.restore(sess, 
"F:/tensorflow_poems-master/model4all/poems-401713") 65 | ssqdata = get_exl_data(random_order=True,use_resnet=True) 66 | # x = np.array([list(map(word_int_map.get, start_token))]) 67 | x=[ssqdata[len(ssqdata)-1]] 68 | print("input: %s"%(x+np.asarray([[[[1],[1],[1],[1],[1],[1],[-32]]]]))) 69 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']], 70 | feed_dict={input_data: x}) 71 | poem_=np.argmax(np.array(predict),axis=1) 72 | sorted_result = np.argsort(np.array(predict), axis=1) 73 | results=poem_+np.asarray([1,1,1,1,1,1,-32]) 74 | print(sorted_result) 75 | print("output: %s"%results) 76 | poem_=np.argmin(np.array(predict),axis=1) 77 | results=poem_+np.asarray([1,1,1,1,1,1,-32]) 78 | print("min output:%s"%results) 79 | return poem_ 80 | 81 | def gen_blue(): 82 | batch_size = 1 83 | print('## loading model from %s' % model_dir) 84 | input_data = tf.placeholder(tf.float32, [1, 1,7,1]) 85 | logits = inference(input_data, 10, reuse=False,output_num=128) 86 | 87 | # print(tf.shape(input_data)) 88 | output_targets = tf.placeholder(tf.int32, [1, None]) 89 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33, 90 | output_num=1, 91 | rnn_size=128, num_layers=3, batch_size=1, learning_rate=0.01) 92 | 93 | # input_data = tf.placeholder(tf.int32, [batch_size, None]) 94 | # 95 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=33, 96 | # rnn_size=128, num_layers=7, batch_size=1, learning_rate=0.01) 97 | 98 | saver = tf.train.Saver(tf.global_variables()) 99 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 100 | with tf.Session() as sess: 101 | sess.run(init_op) 102 | 103 | checkpoint = tf.train.latest_checkpoint('./model4all/') 104 | saver.restore(sess, checkpoint) 105 | # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4blue/poems-84201") 106 | ssqdata = get_exl_data(random_order=True,use_resnet=True) 107 | # x = np.array([list(map(word_int_map.get, start_token))]) 108 | x=[ssqdata[len(ssqdata)-1]] 109 | 110 | print("input: %s"%(x)) 111 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']], 112 | feed_dict={input_data: x}) 113 | poem_=np.argmax(np.array(predict),axis=1) 114 | sorted_result=np.argsort(np.array(predict),axis=1) 115 | results=poem_+np.ones(1) 116 | print(results) 117 | print(sorted_result) 118 | return poem_ 119 | 120 | 121 | if __name__ == '__main__': 122 | # begin_char = input('## please input the first character:') 123 | # poem=gen_blue()#13 124 | poem = gen_poem() 125 | # pretty_print_poem(poem_=poem) -------------------------------------------------------------------------------- /ssq.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # file: main.py 3 | # author: JinTian 4 | # time: 11/03/2017 9:53 AM 5 | # Copyright 2017 JinTian. All Rights Reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # ------------------------------------------------------------------------ 19 | import os 20 | import numpy as np 21 | import tensorflow as tf 22 | from poems.model import rnn_model 23 | from poems.poems import process_poems, generate_batch 24 | from ssq_data import * 25 | tf.app.flags.DEFINE_integer('batch_size', -1, 'batch size.') 26 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'learning rate.') 27 | tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path.') 28 | tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.') 29 | tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.') 30 | tf.app.flags.DEFINE_integer('epochs', 150000, 'train how many epochs.') 31 | 32 | FLAGS = tf.app.flags.FLAGS 33 | 34 | 35 | def run_training(): 36 | # if not os.path.exists(FLAGS.model_dir): 37 | # os.makedirs(FLAGS.model_dir) 38 | # 39 | # poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path) 40 | # batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int) 41 | ssqdata=get_exl_data(random_order=True) 42 | batches_inputs=ssqdata[0:(len(ssqdata)-1)] 43 | batches_outputs = ssqdata[1:(len(ssqdata))] 44 | FLAGS.batch_size=len(ssqdata)-1 45 | # data=batches_outputs[1:7] 46 | # print(len(data)) 47 | del ssqdata 48 | input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None]) 49 | # print(tf.shape(input_data)) 50 | output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None]) 51 | end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=33+16, 52 | output_num=7,input_num=7, 53 | rnn_size=128, num_layers=7, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate) 54 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len( 55 | # vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate) 56 | 57 | saver = tf.train.Saver(tf.global_variables()) 58 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 59 | with tf.Session() as sess: 60 | # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess) 61 | # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan) 62 | sess.run(init_op) 63 | 64 | start_epoch = 0 65 | # saver.restore(sess, "F:/tensorflow_poems-master/model/poems-100000") 66 | checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir) 67 | if checkpoint: 68 | saver.restore(sess, checkpoint) 69 | print("## restore from the checkpoint {0}".format(checkpoint)) 70 | start_epoch += int(checkpoint.split('-')[-1]) 71 | print('## start training...') 72 | try: 73 | for epoch in range(start_epoch, FLAGS.epochs): 74 | n = 0 75 | # n_chunk = len(poems_vector) // FLAGS.batch_size 76 | # n_chunk = len(batches_inputs) // FLAGS.batch_size 77 | n_chunk=math.ceil(len(batches_inputs) / FLAGS.batch_size) 78 | for batch in range(n_chunk): 79 | left=(batch+1)*FLAGS.batch_size-len(batches_inputs) 80 | if left<0: 81 | inputdata=batches_inputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)] 82 | outputdata=batches_outputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)] 83 | else: 84 | # temp=batches_inputs[batch*FLAGS.batch_size:len(batches_inputs) ] 85 | # temp.extend(batches_inputs[0:left]) 86 | 
inputdata=batches_inputs[len(batches_inputs)-FLAGS.batch_size:len(batches_inputs)] 87 | # temp=batches_outputs[batch*FLAGS.batch_size:len(batches_inputs) ] 88 | # temp.extend(batches_outputs[0:left]) 89 | outputdata=batches_outputs[len(batches_outputs)-FLAGS.batch_size:len(batches_outputs)] 90 | # print(len(inputdata)) 91 | loss, _, _ = sess.run([ 92 | end_points['total_loss'], 93 | end_points['last_state'], 94 | end_points['train_op'] 95 | ], feed_dict={input_data: inputdata, output_targets: outputdata}) 96 | # # ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]}) 97 | n += 1 98 | print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss)) 99 | if epoch % 5000 == 0: 100 | saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch) 101 | except KeyboardInterrupt: 102 | print('## Interrupt manually, try saving checkpoint for now...') 103 | finally: 104 | saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch) 105 | print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch)) 106 | 107 | 108 | def main(_): 109 | run_training() 110 | 111 | 112 | if __name__ == '__main__': 113 | tf.app.run() 114 | -------------------------------------------------------------------------------- /ssq4all.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # file: main.py 3 | # author: JinTian 4 | # time: 11/03/2017 9:53 AM 5 | # Copyright 2017 JinTian. All Rights Reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | # ------------------------------------------------------------------------ 19 | import os 20 | import numpy as np 21 | import tensorflow as tf 22 | from poems.model import rnn_model 23 | from poems.resnet import * 24 | from poems.poems import process_poems, generate_batch 25 | from ssq_data import * 26 | # for Windows10:OSError: raw write() returned invalid length 96 (should have been between 0 and 48) 27 | # import win_unicode_console 28 | # win_unicode_console.enable() 29 | tf.app.flags.DEFINE_integer('batch_size', 2214, 'batch size.') 30 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'learning rate.') 31 | tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model4all'), 'model save path.') 32 | tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.') 33 | tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.') 34 | tf.app.flags.DEFINE_integer('epochs', 500000, 'train how many epochs.') 35 | 36 | FLAGS = tf.app.flags.FLAGS 37 | 38 | 39 | def run_training(): 40 | if not os.path.exists(FLAGS.model_dir): 41 | os.makedirs(FLAGS.model_dir) 42 | # 43 | # poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path) 44 | # batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int) 45 | ssqdata=get_exl_data(random_order=False,use_resnet=True) 46 | # print(ssqdata[len(ssqdata)-1]) 47 | batches_inputs=ssqdata[0:(len(ssqdata)-1)] 48 | ssqdata=get_exl_data(random_order=False,use_resnet=False) 49 | batches_outputs = ssqdata[1:(len(ssqdata))] 50 | FLAGS.batch_size=len(batches_inputs) 51 | # print(np.shape(batches_outputs)) 52 | # data=batches_outputs[1:7] 53 | # print(len(data)) 54 | del ssqdata 55 | input_data = tf.placeholder(tf.float32, [FLAGS.batch_size, 1,7,1]) 56 | logits = inference(input_data, 2, reuse=False,output_num=128) 57 | 58 | # print(tf.shape(input_data)) 59 | output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None]) 60 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33+16,output_num=7, 61 | rnn_size=128, num_layers=7, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate) 62 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len( 63 | # vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate) 64 | 65 | saver = tf.train.Saver(tf.global_variables()) 66 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 67 | with tf.Session() as sess: 68 | # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess) 69 | # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan) 70 | sess.run(init_op) 71 | 72 | start_epoch = 0 73 | checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir) 74 | if checkpoint: 75 | saver.restore(sess, checkpoint) 76 | print("## restore from the checkpoint {0}".format(checkpoint)) 77 | start_epoch += int(checkpoint.split('-')[-1]) 78 | print('## start training...') 79 | try: 80 | for epoch in range(start_epoch, FLAGS.epochs): 81 | n = 0 82 | # n_chunk = len(poems_vector) // FLAGS.batch_size 83 | # n_chunk = len(batches_inputs) // FLAGS.batch_size 84 | n_chunk=math.ceil(len(batches_inputs) / FLAGS.batch_size) 85 | for batch in range(n_chunk): 86 | left=(batch+1)*FLAGS.batch_size-len(batches_inputs) 87 | if left<0: 88 | inputdata=batches_inputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)] 89 | 
outputdata=batches_outputs[(batch*FLAGS.batch_size):((batch+1)*FLAGS.batch_size)] 90 | else: 91 | # temp=batches_inputs[batch*FLAGS.batch_size:len(batches_inputs) ] 92 | # temp.extend(batches_inputs[0:left]) 93 | inputdata=batches_inputs[len(batches_inputs)-FLAGS.batch_size:len(batches_inputs)] 94 | # temp=batches_outputs[batch*FLAGS.batch_size:len(batches_inputs) ] 95 | # temp.extend(batches_outputs[0:left]) 96 | outputdata=batches_outputs[len(batches_outputs)-FLAGS.batch_size:len(batches_outputs)] 97 | # print(len(inputdata)) 98 | loss, _, _ = sess.run([ 99 | end_points['total_loss'], 100 | end_points['last_state'], 101 | end_points['train_op'] 102 | ], feed_dict={input_data: inputdata, output_targets: outputdata}) 103 | # ], feed_dict={input_data: batches_inputs, output_targets: batches_outputs}) 104 | n += 1 105 | print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss)) 106 | if epoch % 50000 == 0: 107 | saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch) 108 | except KeyboardInterrupt: 109 | print('## Interrupt manually, try saving checkpoint for now...') 110 | finally: 111 | saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch) 112 | print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch)) 113 | 114 | 115 | def main(_): 116 | run_training() 117 | 118 | 119 | if __name__ == '__main__': 120 | tf.app.run() 121 | -------------------------------------------------------------------------------- /ssq4all_v4.py: -------------------------------------------------------------------------------- 1 | # *_*coding:utf-8 *_* 2 | 3 | import os 4 | import sys 5 | sys.path.append(os.path.abspath(os.path.join(os.getcwd(),".."))) 6 | import numpy as np 7 | import tensorflow as tf 8 | from poems.model import rnn_model 9 | from poems.resnet import * 10 | from poems.poems import process_poems, generate_batch 11 | from ssq_data import * 12 | # for Windows10:OSError: raw write() returned invalid length 96 (should have been between 0 and 48) 13 | import win_unicode_console 14 | win_unicode_console.enable() 15 | tf.app.flags.DEFINE_integer('batch_size', 2214, 'batch size.') 16 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'learning rate.') 17 | tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model4all_v4'), 'model save path.') 18 | tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.') 19 | tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.') 20 | tf.app.flags.DEFINE_integer('epochs', 2000001, 'train how many epochs.') 21 | tf.app.flags.DEFINE_integer('times', 10, 'train how many epochs.') 22 | FLAGS = tf.app.flags.FLAGS 23 | 24 | 25 | def run_training(): 26 | if not os.path.exists(FLAGS.model_dir): 27 | os.makedirs(FLAGS.model_dir) 28 | # 29 | # poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path) 30 | # batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int) 31 | # ssqdata=get_exl_data(random_order=False,use_resnet=True) 32 | # # print(ssqdata[len(ssqdata)-1]) 33 | # batches_inputs=ssqdata[0:(len(ssqdata)-1)] 34 | ssqdata=get_exl_data(random_order=False,use_resnet=False) 35 | # ssqdata=get_dlt_data(random_order=False,use_resnet=True) 36 | # print(ssqdata[len(ssqdata)-1]) 37 | batches_inputs=ssqdata[0:(len(ssqdata)-1)] 38 | # ssqdata=get_dlt_data(random_order=False,use_resnet=False) 39 | batches_outputs = ssqdata[1:(len(ssqdata))] 40 | 
ssqdata=get_exl_data_by_period(random_order=False,use_resnet=True,times=FLAGS.times) 41 | batches_inputs=ssqdata[0:(len(ssqdata)-1)] 42 | FLAGS.batch_size=len(batches_inputs) 43 | # print(np.shape(batches_outputs)) 44 | # data=batches_outputs[1:7] 45 | # print(len(data)) 46 | del ssqdata 47 | input_data = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.times,7+8,1]) 48 | logits = inference(input_data, 1, reuse=False,output_num=128) 49 | 50 | # print(tf.shape(input_data)) 51 | output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None]) 52 | end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets, vocab_size=33+16,output_num=7, 53 | rnn_size=128, num_layers=7, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate) 54 | # end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len( 55 | # vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate) 56 | 57 | saver = tf.train.Saver(tf.global_variables()) 58 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 59 | with tf.Session() as sess: 60 | # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess) 61 | # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan) 62 | sess.run(init_op) 63 | 64 | start_epoch = 0 65 | # saver.restore(sess, "E:/workplace/tensorflow_poems-master/model4all/poems-208368") 66 | checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir) 67 | if checkpoint: 68 | saver.restore(sess, checkpoint) 69 | print("## restore from the checkpoint {0}".format(checkpoint)) 70 | start_epoch += int(checkpoint.split('-')[-1]) 71 | print('## start training...') 72 | try: 73 | # for epoch in range(start_epoch, FLAGS.epochs): 74 | epoch = start_epoch 75 | FLAGS.epochs=epoch+100000 76 | print('till',FLAGS.epochs) 77 | # for epoch in range(start_epoch, FLAGS.epochs): 78 | while(epoch