├── .gitignore
├── README.md
├── compose_poem.py
├── data
│   ├── poems.txt
│   ├── qijue-all.txt
│   ├── qilv-all.txt
│   ├── wujue-all.txt
│   └── wulv-all.txt
├── model
│   ├── checkpoint
│   ├── poems-0.data-00000-of-00001
│   ├── poems-0.index
│   ├── poems-0.meta
│   ├── poems-1.data-00000-of-00001
│   ├── poems-1.index
│   ├── poems-1.meta
│   ├── poems-2.data-00000-of-00001
│   └── poems-2.index
├── poems
│   ├── __pycache__
│   │   ├── model.cpython-37.pyc
│   │   ├── model.cpython-38.pyc
│   │   ├── poems.cpython-37.pyc
│   │   └── poems.cpython-38.pyc
│   ├── model.py
│   └── poems.py
├── test.py
├── train.py
└── utils
    ├── clean_cn.py
    └── make_regulated_verse.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
checkpoints/
model/
.idea/
__pycache__/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Welcome to LiBai AI Composer 👋

License: Apache
> An AI that automatically generates poems in Chinese.
>
> For a long time we have wanted machines to compose poetry on their own. While countless writers and editors have yet to pick up their pens, AI has already finished thousands of pieces. This is the first step....

### 🏠 [Homepage](https://github.com/jinfagang/tensorflow_poems)

## 👍 Outcome 结果

After reading nearly 40,000 Tang poems, it composed:

```
龙舆迎池里,控列守龙猱。
几岁芳篁落,来和晚月中。
殊乘暮心处,麦光属激羁。
铁门通眼峡,高桂露沙连。
倘子门中望,何妨嶮锦楼。
择闻洛臣识,椒苑根觞吼。
柳翰天河酒,光方入胶明。
```

These poems have real feeling to them, and that is the payoff of diligence: the model studied essentially all the essence of the Complete Tang Poems to reach this level. Could an ordinary person manage that?
This post explains some of the implementation details. If anything is left uncovered, you can find me on WeChat — the guy with the magical avatar. Without further ado, here is the GitHub link. I will keep maintaining this poem bot, so if you are too busy to read it now, star or fork the project
and you will see every update as soon as I push it. Updates are mainly fixes for API breakage: even though TensorFlow has reached 1.0, its API still changes.
Stack up the stars so more people can find this poem bot; more impressive features, such as rhyming, will be added later.

## 📥 Install 安装

```sh
git clone https://github.com/jinfagang/tensorflow_poems.git
```

## 🛠 Usage 使用

```sh
# train on poems 训练
python3 train.py
# compose poems 作诗
python3 compose_poem.py
```

When you kick off training, you will see something like this:

![](https://i.loli.net/2018/03/12/5aa5fd903c041.jpeg)
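
Composing is interactive: compose_poem.py asks for a starting character and prints the result. A sample session is sketched below — the prompt string comes straight from compose_poem.py, and the poem shown is the 雨 acrostic example quoted in the update notes further down; since each character is sampled from the model's output distribution, every run will differ:

```sh
$ python3 compose_poem.py
## (输入 quit 退出)请输入第一个字 please input the first character: 雨
雨霁开门中,山听淮水流。
落花遍霜霰,金壶横河湟。
...
```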
## 📈 Updates 更新

#### 2018-8-16

We have officially announced the start of a new project: **StrangeAI School** — an artificial intelligence learning school and advanced algorithm exchange platform! What we believe is this: AI should be made to change people's lives, rather than be controlled by giant companies.
You can get some previews of our projects here: http://ai.loliloli.pro (strangeai.pro available soon)

#### 2018-3-12

**tensorflow_poems** rises from the dead! This project went a long while without updates and quietly collected more than a thousand stars, so clearly there is still plenty of interest. I am honored by the attention, but we cannot stand still — which is exactly why I am back, and I would like to show you my latest progress. First of all, please have a look at the Zhihu columns I have poured my heart into — an AI column taking you from beginner to master, plus one black-tech project every week. We care not only about AI but also about frontier technology such as blockchain:

- 人工智能从入门到逆天杀神 (Zhihu column — AI from beginner to master): https://zhuanlan.zhihu.com/ai-man
- 每周一项目黑科技-TrackTech (Zhihu column — one black-tech project a week): https://zhuanlan.zhihu.com/tracktech

If you want to talk about AI, visit our website (for now): http://ai.loliloli.pro (strangeai.pro available soon), and **subscribe** to our WeChat channel: 奇异人工智能学院

#### 2017-11-8

It has been quite a while since this repo was last updated, and in the meantime many of you — including some heavyweights — found me on WeChat. This project started as a practice piece back when my technique was not yet seasoned, and some of you fell into its pits. Now **LiBai** is back in force. This update brings the following improvements:

- The data preprocessing script has been simplified beyond recognition; even a primary school student could follow it now
- Training only requires running train.py; the data and a pretrained model are already in place
- You can run compose_poem.py directly to compose poems, without the old infinite-loop problem

#### 2017-6-1 ~~possibly the last update~~

I have decided to refactor this project when I find the time. Classical poetry springs from the literary spirit in my bones. Things have been a bit scattered lately, so I am gathering everyone up: if you are interested in this project, you are welcome to join the QQ group:

```
292889553
```


#### 2017-3-22 Major update: acrostic poems

A small wave of updates; the following issues have been fixed:

* Fixed the case where, after training, no poem ever appeared: generation was stuck in an endless composing loop
* Added a pretty-print feature: poems are printed in a standard format, so third-party apps and other platforms can consume them directly
* Disabled TensorFlow's default debug messages in the terminal

Last and most important of all: **our poem bot (tentatively named LiBai) can now compose a poem starting from a character you specify!!**
Come and try it out; if you haven't starred yet, star it now!! Updates will keep coming!! Forever open source!!!
Let's look at the acrostic poems LiBai composed:

```
# It has been raining non-stop lately, so here is one starting with 雨 (rain)
雨霁开门中,山听淮水流。
落花遍霜霰,金壶横河湟。
年年忽息世,径远谁论吟。
惊舟望秋月,应柳待晨围。
人处山霜月,萧萧广野虚。

# The author of the LiBai poem bot is rather handsome, so here is one starting with 帅 (handsome)
帅主何幸化,自日兼春连。
命钱犯夕兴,职馀玄赏圣。
君有不知益,浮于但神衍。
(a thick air of unrecognized talent...)
```

## 👊 It does more than classical poems — it can also write lyrics in the style of Jay Chou!!

This feature was added on 2017-03-09: writing lyrics modeled on Jay Chou's songs. First, get a feel for the lyrics it wrote:

```
我的你的她
蛾眉脚的泪花
乱飞从慌乱
笛卡尔的悲伤
迟早在是石板上
荒废了晚上
夜你的她不是她
....
```

How to put it — for lack of training text, the lyrics our AI writes are a bit... well, not bad actually, with a touch of melancholy, though the style is nothing like Jay Chou's.
No matter: it has simply been trained on too little text so far, only 112 songs. So here I call on everyone to help assemble **a corpus of Chinese singers' lyrics!!!**
If you like Jay Chou's songs, save them to a txt file — one song per line, with the sentences of each song separated by spaces — and send it to my [email](mailto:jinfagang19@163.com).
With training text added continually, our lyric-writing bot will only get better! I will keep pushing the dataset to GitHub; star the project to follow its updates.
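
A minimal sketch of preparing lyrics in that format (the file and directory names here are made up for illustration):

```python
import glob

# Hypothetical layout: each song lives in its own text file under songs/,
# one sentence per line.
with open('jay_chou_lyrics.txt', 'w', encoding='utf-8') as out:
    for path in sorted(glob.glob('songs/*.txt')):
        with open(path, encoding='utf-8') as f:
            sentences = [line.strip() for line in f if line.strip()]
        # One song per line, sentences separated by spaces, as described above.
        out.write(' '.join(sentences) + '\n')
```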

## 👥 Authors 作者

👤 **jinfagang**

* Website: http://jinfagang.github.io
* GitHub: [@JinTian](https://github.com/JinTian)

👤 **William Song**

- Website: http://williamzjc.gitee.io/morninglake/
- GitHub: [@Freakwill](https://github.com/Freakwill)
- Twitter: [@WilliamPython](https://twitter.com/WilliamPython)

👤 **Harvey Dam**

- GitHub: [@damtharvey](https://github.com/damtharvey)

👤 **KnowsCount**

- Website: http://docs.knowscount.cc/
- GitHub: [@KnowsCount](https://github.com/KnowsCount)

## 🎉 Show your support 支持

Give a 🌟 if this project helped you!

## 📝 License 协议

Copyright © 2020
--------------------------------------------------------------------------------
/compose_poem.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: compose_poem.py
# author: JinTian
# time: 11/03/2017 9:53 AM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import tensorflow as tf
from poems.model import rnn_model
from poems.poems import process_poems
import numpy as np

start_token = 'B'
end_token = 'E'
model_dir = './model/'
corpus_file = './data/poems.txt'

lr = 0.0002


def to_word(predict, vocabs):
    # Renormalize the prediction and sample the next character from it.
    predict = predict[0]
    predict /= np.sum(predict)
    sample = np.random.choice(np.arange(len(predict)), p=predict)
    if sample >= len(vocabs):  # guard against sampling the out-of-vocabulary index
        return vocabs[-1]
    else:
        return vocabs[sample]

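
# Generation runs the model one character at a time: each step feeds in the
# previously produced character together with the RNN state carried over from
# the previous step, and stops at the end token 'E' (or after 24 characters).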
def gen_poem(begin_word):
    batch_size = 1
    print('## loading corpus from %s' % corpus_file)
    poems_vector, word_int_map, vocabularies = process_poems(corpus_file)

    input_data = tf.placeholder(tf.int32, [batch_size, None])

    end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(
        vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=lr)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)

        checkpoint = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, checkpoint)

        x = np.array([list(map(word_int_map.get, start_token))])

        [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                         feed_dict={input_data: x})
        word = begin_word or to_word(predict, vocabularies)
        poem_ = ''

        i = 0
        while word != end_token:
            poem_ += word
            i += 1
            if i > 24:
                break
            x = np.array([[word_int_map[word]]])
            [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
                                             feed_dict={input_data: x, end_points['initial_state']: last_state})
            word = to_word(predict, vocabularies)

        return poem_


def pretty_print_poem(poem_):
    poem_sentences = poem_.split('。')
    for s in poem_sentences:
        if s != '' and len(s) > 10:
            print(s + '。')


if __name__ == '__main__':
    begin_char = input('## (输入 quit 退出)请输入第一个字 please input the first character: ')
    if begin_char == 'quit':
        exit()
    poem = gen_poem(begin_char)
    pretty_print_poem(poem_=poem)
--------------------------------------------------------------------------------
/model/checkpoint:
--------------------------------------------------------------------------------
model_checkpoint_path: "/Users/admin/Desktop/tensorflow_poems/model/poems-1"
all_model_checkpoint_paths: "/Users/admin/Desktop/tensorflow_poems/model/poems-0"
all_model_checkpoint_paths: "/Users/admin/Desktop/tensorflow_poems/model/poems-1"
--------------------------------------------------------------------------------
/model/poems-0.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-0.data-00000-of-00001
--------------------------------------------------------------------------------
/model/poems-0.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-0.index
--------------------------------------------------------------------------------
/model/poems-0.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-0.meta
--------------------------------------------------------------------------------
/model/poems-1.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-1.data-00000-of-00001
--------------------------------------------------------------------------------
/model/poems-1.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-1.index
--------------------------------------------------------------------------------
/model/poems-1.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-1.meta
--------------------------------------------------------------------------------
/model/poems-2.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-2.data-00000-of-00001
--------------------------------------------------------------------------------
/model/poems-2.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-2.index
--------------------------------------------------------------------------------
/poems/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/poems/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/poems/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/poems/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/poems/__pycache__/poems.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/poems/__pycache__/poems.cpython-37.pyc
--------------------------------------------------------------------------------
/poems/__pycache__/poems.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/poems/__pycache__/poems.cpython-38.pyc
--------------------------------------------------------------------------------
/poems/model.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: model.py
# author: JinTian
# time: 07/03/2017 3:07 PM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import tensorflow as tf
import numpy as np


def rnn_model(model, input_data, output_data, vocab_size, rnn_size=128, num_layers=2, batch_size=64,
              learning_rate=0.01):
    """
    construct rnn seq2seq model.
    :param model: model class
    :param input_data: input data placeholder
    :param output_data: output data placeholder
    :param vocab_size:
    :param rnn_size:
    :param num_layers:
    :param batch_size:
    :param learning_rate:
    :return:
    """
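    # Architecture: character ids -> embedding lookup -> num_layers stacked
    # RNN/GRU/LSTM cells -> linear projection to vocab_size + 1 logits.
    # With output_data given (training) this returns loss and train_op;
    # with output_data=None (generation) it returns softmax predictions.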
    end_points = {}

    if model == 'rnn':
        cell_fun = tf.contrib.rnn.BasicRNNCell
    elif model == 'gru':
        cell_fun = tf.contrib.rnn.GRUCell
    elif model == 'lstm':
        cell_fun = tf.contrib.rnn.BasicLSTMCell
    else:
        raise ValueError('unknown model type: %s' % model)

    cell = cell_fun(rnn_size, state_is_tuple=True)
    cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

    if output_data is not None:
        initial_state = cell.zero_state(batch_size, tf.float32)
    else:
        initial_state = cell.zero_state(1, tf.float32)

    with tf.device("/cpu:0"):
        embedding = tf.get_variable('embedding', initializer=tf.random_uniform(
            [vocab_size + 1, rnn_size], -1.0, 1.0))
        inputs = tf.nn.embedding_lookup(embedding, input_data)

    # [batch_size, ?, rnn_size] = [64, ?, 128]
    outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)
    output = tf.reshape(outputs, [-1, rnn_size])

    weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
    bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
    logits = tf.nn.bias_add(tf.matmul(output, weights), bias=bias)
    # [?, vocab_size+1]

    if output_data is not None:
        # output_data must be one-hot encoded
        labels = tf.one_hot(tf.reshape(output_data, [-1]), depth=vocab_size + 1)
        # should be [?, vocab_size+1]

        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
        # loss has shape [?] (one value per time step in the flattened batch)
        total_loss = tf.reduce_mean(loss)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

        end_points['initial_state'] = initial_state
        end_points['output'] = output
        end_points['train_op'] = train_op
        end_points['total_loss'] = total_loss
        end_points['loss'] = loss
        end_points['last_state'] = last_state
    else:
        prediction = tf.nn.softmax(logits)

        end_points['initial_state'] = initial_state
        end_points['last_state'] = last_state
        end_points['prediction'] = prediction

    return end_points
--------------------------------------------------------------------------------
/poems/poems.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: poems.py
# author: JinTian
# time: 08/03/2017 7:39 PM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import collections
import numpy as np

start_token = 'B'
end_token = 'E'


def process_poems(file_name):
    # poems -> list of character-id lists
    poems = []
    with open(file_name, "r", encoding='utf-8') as f:
        for line in f.readlines():
            try:
                title, content = line.strip().split(':')
                content = content.replace(' ', '')
                if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \
                        start_token in content or end_token in content:
                    continue
                if len(content) < 5 or len(content) > 79:
                    continue
                content = start_token + content + end_token
                poems.append(content)
            except ValueError:
                # skip lines that do not match the "title:content" format
                pass
    # poems = sorted(poems, key=len)

    all_words = [word for poem in poems for word in poem]
    counter = collections.Counter(all_words)
    words = sorted(counter.keys(), key=lambda x: counter[x], reverse=True)

    words.append(' ')
    L = len(words)
    word_int_map = dict(zip(words, range(L)))
    poems_vector = [list(map(lambda word: word_int_map.get(word, L), poem)) for poem in poems]

    return poems_vector, word_int_map, words

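
# A sketch of what process_poems returns, assuming a toy one-poem corpus
# "静夜思:床前明月光,疑是地上霜。" (ids depend on character frequency; the
# most frequent characters get the smallest ids):
#   poems_vector -> [[id('B'), id('床'), id('前'), ..., id('。'), id('E')]]
#   word_int_map -> {'月': ..., ',': ..., ' ': L - 1}  (a space is appended as padding)
#   words        -> the same vocabulary as a list, most frequent first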
def generate_batch(batch_size, poems_vec, word_to_int):
    n_chunk = len(poems_vec) // batch_size
    x_batches = []
    y_batches = []
    for i in range(n_chunk):
        start_index = i * batch_size
        end_index = start_index + batch_size

        batches = poems_vec[start_index:end_index]
        # pad every poem in the batch with spaces up to the longest one
        length = max(map(len, batches))
        x_data = np.full((batch_size, length), word_to_int[' '], np.int32)
        for row, batch in enumerate(batches):
            x_data[row, :len(batch)] = batch
        y_data = np.copy(x_data)
        # targets are the inputs shifted one step to the left
        y_data[:, :-1] = x_data[:, 1:]
        """
        x_data             y_data
        [6,2,4,6,9]        [2,4,6,9,9]
        [1,4,2,8,5]        [4,2,8,5,5]
        """
        x_batches.append(x_data)
        y_batches.append(y_data)
    return x_batches, y_batches
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
user_input_str = input('input string:')
if user_input_str == 'exit':
    exit()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: train.py
# author: JinTian
# time: 11/03/2017 9:53 AM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import os
import tensorflow as tf
from poems.model import rnn_model
from poems.poems import process_poems, generate_batch

tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size.')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate.')
tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path.')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs.')

FLAGS = tf.app.flags.FLAGS

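
# These flags can be overridden on the command line, e.g.:
#   python3 train.py --batch_size 32 --learning_rate 0.005 --epochs 100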

def run_training():
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)

    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
        vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        epoch = start_epoch  # so the interrupt handler below can always save
        try:
            n_chunk = len(poems_vector) // FLAGS.batch_size
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                for batch in range(n_chunk):
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
                    n += 1
                    print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                if epoch % 6 == 0:
                    saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupted manually, trying to save a checkpoint now...')
            saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
            print('## Last epoch was saved; next run will resume from epoch {}.'.format(epoch))


def main(_):
    run_training()


if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/utils/clean_cn.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# file: clean_cn.py
# author: JinTian
# time: 08/03/2017 8:02 PM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""
This script is used to clean a Chinese corpus.
You can set the cleaning level:
level='all': remove every character that is not Chinese, including punctuation
level='normal': generate a corpus for normal use, keeping alphabets and numbers
level='clean': remove everything except Chinese characters and Chinese punctuation

Besides, if you want to remove complex (traditional) Chinese characters,
set simple_only=True.
"""

import os
import string


cn_punctuation_set = [',', '。', '!', '?', '"', '"', '、']
en_punctuation_set = [',', '.', '?', '!', '"', '"']


def clean_cn_corpus(file_name, clean_level='all', simple_only=True, is_save=True):
    """
    clean Chinese corpus.
    :param file_name:
    :param clean_level: one of 'all', 'normal', 'clean'
    :param simple_only: (currently not used by the implementation)
    :param is_save:
    :return: clean corpus as a list.
    """
    base_dir = os.path.dirname(file_name)
    if not base_dir:
        # no directory given; fall back to the current one so save_file is valid
        base_dir = '.'

    save_file = os.path.join(base_dir, os.path.basename(file_name).split('.')[0] + '_cleaned.txt')
    with open(file_name, 'r+') as f:
        clean_content = []
        for l in f.readlines():
            l = l.strip()
            if l:
                l = list(l)
                should_remove_words = [w for w in l if not should_reserve(w, clean_level)]
                clean_line = ''.join(c for c in l if c not in should_remove_words)
                if clean_line:
                    clean_content.append(clean_line)
    if is_save:
        with open(save_file, 'w+') as f:
            for l in clean_content:
                f.write(l + '\n')
        print('[INFO] cleaned file has been saved to %s.' % save_file)
    return clean_content

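
# Example (with a made-up path): keep only Chinese characters and Chinese
# punctuation from data/lyrics.txt; since is_save defaults to True, the result
# is also written to data/lyrics_cleaned.txt:
#   lines = clean_cn_corpus('data/lyrics.txt', clean_level='clean')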

def should_reserve(w, clean_level):
    if w == ' ':
        return True
    else:
        if clean_level == 'all':
            # only reserve Chinese characters
            if w in cn_punctuation_set or w in string.punctuation or is_alphabet(w):
                return False
            else:
                return is_chinese(w)
        elif clean_level == 'normal':
            # reserve Chinese characters, English alphabet, numbers
            if is_chinese(w) or is_alphabet(w) or is_number(w):
                return True
            elif w in cn_punctuation_set or w in en_punctuation_set:
                return True
            else:
                return False
        elif clean_level == 'clean':
            if is_chinese(w):
                return True
            elif w in cn_punctuation_set:
                return True
            else:
                return False
        else:
            # raising a plain string is invalid in Python 3; use an exception type
            raise ValueError("clean_level %s not supported, please use all, normal or clean" % clean_level)


def is_chinese(uchar):
    """is chinese"""
    return '\u4e00' <= uchar <= '\u9fa5'


def is_number(uchar):
    """is number"""
    return '\u0030' <= uchar <= '\u0039'


def is_alphabet(uchar):
    """is alphabet"""
    return ('\u0041' <= uchar <= '\u005a') or ('\u0061' <= uchar <= '\u007a')


def semi_angle_to_sbc(uchar):
    """convert a half-width (ASCII) character to its full-width form"""
    inside_code = ord(uchar)
    if inside_code < 0x0020 or inside_code > 0x7e:
        return uchar
    if inside_code == 0x0020:
        inside_code = 0x3000
    else:
        inside_code += 0xfee0
    return chr(inside_code)


def sbc_to_semi_angle(uchar):
    """convert a full-width character to its half-width (ASCII) form"""
    inside_code = ord(uchar)
    if inside_code == 0x3000:
        inside_code = 0x0020
    else:
        inside_code -= 0xfee0
    if inside_code < 0x0020 or inside_code > 0x7e:
        return uchar
    return chr(inside_code)
--------------------------------------------------------------------------------
/utils/make_regulated_verse.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
#
# MIT License
#
# Copyright (c) 2018 damtharvey
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
These functions can process data from
github.com/chinese-poetry/chinese-poetry/json to make regulated verse data
ready to use for training.

If you don't have the jsons already, just
git clone https://github.com/chinese-poetry/chinese-poetry.git
to get them.

Regulated verse forms are expected to be tuples of (number of couplets,
characters per couplet). For reference:
wujue 五言絕句 = (2, 10)
qijue 七言絕句 = (2, 14)
wulv 五言律詩 = (4, 10)
qilv 七言律詩 = (4, 14)
"""
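
# A usage sketch (paths are illustrative): extract all qijue from the cloned
# chinese-poetry json directory into per-source txt files, then merge them
# into a single training file like data/qijue-all.txt:
#   get_poems_in_dir('chinese-poetry/json', (2, 14), 'out/qijue')
#   combine_txt('out/qijue', 'data/qijue-all.txt')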

import glob
import os
import pandas as pd


def unregulated(paragraphs):
    """
    Return True if the given couplets describe unregulated verse, i.e. not
    all couplets have the same length.
    """
    return not all(len(couplet) == len(paragraphs[0]) for couplet in paragraphs)


def get_poems_in_df(df, form):
    """
    Return a txt-friendly string of only the poems in df of the specified form.
    """
    big_string = ""
    for row in range(len(df)):
        if len(df["strains"][row]) != form[0] or \
                len(df["strains"][row][0]) - 2 != form[1]:
            continue
        if "○" in str(df["paragraphs"][row]):
            continue
        if unregulated(df["paragraphs"][row]):
            continue
        big_string += df["title"][row] + ":"
        for couplet in df["paragraphs"][row]:
            big_string += couplet
        big_string += "\n"
    return big_string


def get_poems_in_dir(json_dir, form, save_dir):
    """
    Save poems of the given form found in json_dir to save_dir, one txt file
    per source json.
    """
    files = [f for f in os.listdir(json_dir) if "poet" in f]
    for file in files:  # Restart partway through if the kernel dies.
        with open(os.path.join(save_dir, file[:-5] + ".txt"), "w") as data:
            print("Now reading " + file)
            df = pd.read_json(os.path.join(json_dir, file))
            poems = get_poems_in_df(df, form)
            data.write(poems)
            print(str(len(poems)) + " chars written to "
                  + save_dir + "/" + file[:-5] + ".txt")
    return 0


def combine_txt(txts_dir, save_file):
    """
    Combine the .txt files in txts_dir and save them to save_file.
    """
    read_files = glob.glob(os.path.join(txts_dir, "*.txt"))
    with open(save_file, "wb") as outfile:
        for f in read_files:
            with open(f, "rb") as infile:
                outfile.write(infile.read())
    return 0
--------------------------------------------------------------------------------