├── .gitignore
├── README.md
├── compose_poem.py
├── data
│   ├── poems.txt
│   ├── qijue-all.txt
│   ├── qilv-all.txt
│   ├── wujue-all.txt
│   └── wulv-all.txt
├── model
│   ├── checkpoint
│   ├── poems-0.data-00000-of-00001
│   ├── poems-0.index
│   ├── poems-0.meta
│   ├── poems-1.data-00000-of-00001
│   ├── poems-1.index
│   ├── poems-1.meta
│   ├── poems-2.data-00000-of-00001
│   └── poems-2.index
├── poems
│   ├── __pycache__
│   │   ├── model.cpython-37.pyc
│   │   ├── model.cpython-38.pyc
│   │   ├── poems.cpython-37.pyc
│   │   └── poems.cpython-38.pyc
│   ├── model.py
│   └── poems.py
├── test.py
├── train.py
└── utils
    ├── clean_cn.py
    └── make_regulated_verse.py
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints/
2 | model/
3 | .idea/
4 | __pycache__/
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Welcome to LiBai AI Composer 👋
2 |
9 | > An AI-powered bot that automatically generates poems in Chinese.
10 | >
11 | > We have long wanted machines to compose poetry by themselves. While countless writers and editors have yet to lift their pens, AI has already finished thousands of pieces. Here is the first step....
12 |
13 | ### 🏠 [Homepage](https://github.com/jinfagang/tensorflow_poems)
14 |
15 | ## 👍 Outcome 结果
16 |
17 | Having read nearly 40,000 Tang poems, it composed:
18 |
19 | ```
20 | 龙舆迎池里,控列守龙猱。
21 | 几岁芳篁落,来和晚月中。
22 | 殊乘暮心处,麦光属激羁。
23 | 铁门通眼峡,高桂露沙连。
24 | 倘子门中望,何妨嶮锦楼。
25 | 择闻洛臣识,椒苑根觞吼。
26 | 柳翰天河酒,光方入胶明。
27 | ```
28 |
29 | These poems read with real feeling, and that is the fruit of diligence: the model absorbed essentially all the essence of the complete Tang poems to gain such formidable ability. Could an ordinary person do that?
30 | This post explains some of the implementation details; for anything left uncovered, you can find me on WeChat (the man with the magical avatar). Without further chit-chat, the GitHub link is above. I will keep maintaining this poem bot, so if you are too pressed for time to read it now, star or fork the project,
31 | and you will see updates the moment I push them, mainly fixes for API issues: TensorFlow has reached 1.0, but its API still changes.
32 | Pile the stars up so more people can see the poem bot we created; more impressive features, such as rhyming, will be added later.
33 |
34 | ## 📥 Install 安装
35 |
36 | ```sh
37 | git clone https://github.com/jinfagang/tensorflow_poems.git
38 | ```
39 |
40 | ## 🛠 Usage 使用
41 |
42 | ```sh
43 | # train on poems 训练
44 | python3 train.py
45 | # compose poems 作诗
46 | python3 compose_poem.py
47 | ```
48 |
49 | When you kick off training, you will see something like this:
50 |
52 |
53 | 
54 |
55 | ## 📈 Updates 更新
56 |
57 | #### 2018-8-16
58 |
59 | We have now officially announced the start of a new project: **StrangeAI School** - an artificial intelligence learning school and advanced algorithm exchange platform! What we believe is: AI should be made to change people's lives, rather than be controlled by giant companies.
60 | Here you can get a preview of our projects: http://ai.loliloli.pro (strangeai.pro available soon)
61 |
62 | #### 2018-3-12
63 |
64 | **tensorflow_poems** rises from the dead! This project went without updates for ages, yet it has quietly gathered over a thousand stars, so people are clearly still interested. I am honored by the attention, but we cannot stand still because of it, and that is exactly why I am back. I will show you my latest progress; first, I sincerely hope you will follow the Zhihu columns I have poured my heart into: "AI from beginner to god-slayer" and "one black-tech project a week". We follow not only artificial intelligence but also frontier technology such as blockchain:
65 |
66 | - AI from beginner to god-slayer (Zhihu column): https://zhuanlan.zhihu.com/ai-man
67 | - One black-tech project a week - TrackTech (Zhihu column): https://zhuanlan.zhihu.com/tracktech
68 | If you want to talk about AI, visit our website (for now): http://ai.loliloli.pro (strangeai.pro available soon)
69 | and **subscribe** to our WeChat channel: 奇异人工智能学院
70 |
71 | #### 2017-11-8
72 |
73 | It has apparently been quite a while since this repo was last updated. In the meantime many folks, including some big names, have found me on WeChat. Back then this project was just a practice piece, and my craft was not yet seasoned, so you all stepped into some pits. Now **LiBai** returns in force. This update adds the following improvements:
74 |
75 | - The data preprocessing script has been simplified beyond recognition; even a primary school student could follow it now
76 | - Training only requires running train.py; the data and a pretrained model are already included
77 | - You can compose poems directly with compose_poem.py, and the infinite loop no longer occurs.
78 |
79 | #### 2017-6-1 ~~Possibly the last update~~
80 |
81 | I have decided to refactor this project when I find the time. Classical poetry springs from the literary streak in my bones, but things have gotten a bit messy lately, so I am calling everyone together: if you are interested in this project, you are welcome to join the QQ group:
82 |
83 | ```
84 | 292889553
85 | ```
86 |
87 |
88 | #### 2017-3-22 Major update: acrostic poems (藏头诗)
89 |
90 | A small wave of updates; the following issues have been resolved:
91 |
92 | * After training, composition sometimes never produced output; it was actually stuck in an endless composition loop. Fixed.
93 | * Added a pretty-print feature: poems are printed in a standard format, so third-party apps and other platforms can consume them directly
94 | * Disabled TensorFlow's default debug logging in the terminal
95 | Last and most important of all: **our poem bot (let's call it LiBai for now) can now compose a poem starting from a character you specify!!**
96 | Come and try it, and if you have not starred yet, star now!! Updates will keep coming!! Open source forever!!!
97 | Let's look at an acrostic poem LiBai composed:
98 |
99 | ```
100 | # It has been raining non-stop lately, so here is one starting with 雨 (rain)
101 | 雨霁开门中,山听淮水流。
102 | 落花遍霜霰,金壶横河湟。
103 | 年年忽息世,径远谁论吟。
104 | 惊舟望秋月,应柳待晨围。
105 | 人处山霜月,萧萧广野虚。
106 |
107 | # The author of the LiBai poem bot is rather handsome, so here is one starting with 帅 (handsome)
108 | 帅主何幸化,自日兼春连。
109 | 命钱犯夕兴,职馀玄赏圣。
110 | 君有不知益,浮于但神衍。
111 | (a strong air of unappreciated talent...)
112 | ```
113 |
114 | ## 👊 It can do more than classical poems: it also writes lyrics in the style of Jay Chou!!
115 |
116 | This feature was added on 2017-3-9: lyric composition modeled on Jay Chou's songs. First, get a feel for the lyrics it wrote:
117 |
118 | ```
119 | 我的你的她
120 | 蛾眉脚的泪花
121 | 乱飞从慌乱
122 | 笛卡尔的悲伤
123 | 迟早在是石板上
124 | 荒废了晚上
125 | 夜你的她不是她
126 | ....
127 | ```
128 |
129 | How to put it: for lack of training text, the lyrics our AI writes are a bit.... well, not bad actually, with a touch of melancholy, though the style is nothing like Jay Chou's.
130 | No matter; it has only been trained on 112 songs so far, so I am calling on everyone to help assemble **a lyrics corpus of Chinese singers!!!**
131 | If you like Jay Chou's songs, save them to a txt file, one song per line with each song's sentences separated by spaces, and send them to my [email](mailto:jinfagang19@163.com):
132 | With ever more training text, our lyric bot is bound to get better and better! I will keep the dataset updated on GitHub; star the project to follow its updates.
133 |
134 | ## 👥 Authors 作者
135 |
136 | 👤 **jinfagang**
137 |
138 | * Website: http://jinfagang.github.io
139 | * GitHub: [@JinTian](https://github.com/JinTian)
140 |
141 | 👤 **William Song**
142 |
143 | - Website: http://williamzjc.gitee.io/morninglake/
144 | - GitHub: [@Freakwill](https://github.com/Freakwill)
145 | - Twitter: [@WilliamPython](https://twitter.com/WilliamPython)
146 |
147 | 👤 **Harvey Dam**
148 |
149 | - GitHub: [@damtharvey](https://github.com/damtharvey)
150 |
151 | 👤 **KnowsCount**
152 |
153 | - Website: http://docs.knowscount.cc/
154 | - GitHub: [@KnowsCount](https://github.com/KnowsCount)
155 |
156 | ## 🎉 Show your support 支持
157 |
160 | Give a 🌟 if this project helped you!
161 |
162 | ## 📝 License 协议
163 |
164 | Copyright 版权 © 2020
165 |
--------------------------------------------------------------------------------
/compose_poem.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: compose_poem.py
3 | # author: JinTian
4 | # time: 11/03/2017 9:53 AM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | # ------------------------------------------------------------------------
19 | import tensorflow as tf
20 | from poems.model import rnn_model
21 | from poems.poems import process_poems
22 | import numpy as np
23 |
24 | start_token = 'B'
25 | end_token = 'E'
26 | model_dir = './model/'
27 | corpus_file = './data/poems.txt'
28 |
29 | lr = 0.0002
30 |
31 |
32 | def to_word(predict, vocabs):
33 | predict = predict[0]
34 | predict /= np.sum(predict)
35 | sample = np.random.choice(np.arange(len(predict)), p=predict)
36 |     if sample >= len(vocabs):  # index len(vocabs) is the out-of-vocabulary bucket
37 | return vocabs[-1]
38 | else:
39 | return vocabs[sample]
40 |
41 |
42 | def gen_poem(begin_word):
43 | batch_size = 1
44 |     print('## loading corpus from %s' % corpus_file)
45 | poems_vector, word_int_map, vocabularies = process_poems(corpus_file)
46 |
47 | input_data = tf.placeholder(tf.int32, [batch_size, None])
48 |
49 | end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(
50 | vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=lr)
51 |
52 | saver = tf.train.Saver(tf.global_variables())
53 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
54 | with tf.Session() as sess:
55 | sess.run(init_op)
56 |
57 | checkpoint = tf.train.latest_checkpoint(model_dir)
58 | saver.restore(sess, checkpoint)
59 |
60 | x = np.array([list(map(word_int_map.get, start_token))])
61 |
62 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
63 | feed_dict={input_data: x})
64 | word = begin_word or to_word(predict, vocabularies)
65 | poem_ = ''
66 |
67 | i = 0
68 | while word != end_token:
69 | poem_ += word
70 | i += 1
71 | if i > 24:
72 | break
73 | x = np.array([[word_int_map[word]]])
74 | [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
75 | feed_dict={input_data: x, end_points['initial_state']: last_state})
76 | word = to_word(predict, vocabularies)
77 |
78 | return poem_
79 |
80 |
81 | def pretty_print_poem(poem_):
82 | poem_sentences = poem_.split('。')
83 | for s in poem_sentences:
84 | if s != '' and len(s) > 10:
85 | print(s + '。')
86 |
87 | if __name__ == '__main__':
88 | begin_char = input('## (输入 quit 退出)请输入第一个字 please input the first character: ')
89 | if begin_char == 'quit':
90 | exit()
91 | poem = gen_poem(begin_char)
92 | pretty_print_poem(poem_=poem)
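93 |
94 | # Usage sketch (a hypothetical driver, mirroring the __main__ block above):
95 | # reuse gen_poem from another script to seed a poem with a chosen character.
96 | #
97 | #   from compose_poem import gen_poem, pretty_print_poem
98 | #   pretty_print_poem(poem_=gen_poem('雨'))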
--------------------------------------------------------------------------------
/model/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "/Users/admin/Desktop/tensorflow_poems/model/poems-1"
2 | all_model_checkpoint_paths: "/Users/admin/Desktop/tensorflow_poems/model/poems-0"
3 | all_model_checkpoint_paths: "/Users/admin/Desktop/tensorflow_poems/model/poems-1"
4 |
--------------------------------------------------------------------------------
/model/poems-0.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-0.data-00000-of-00001
--------------------------------------------------------------------------------
/model/poems-0.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-0.index
--------------------------------------------------------------------------------
/model/poems-0.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-0.meta
--------------------------------------------------------------------------------
/model/poems-1.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-1.data-00000-of-00001
--------------------------------------------------------------------------------
/model/poems-1.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-1.index
--------------------------------------------------------------------------------
/model/poems-1.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-1.meta
--------------------------------------------------------------------------------
/model/poems-2.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-2.data-00000-of-00001
--------------------------------------------------------------------------------
/model/poems-2.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/model/poems-2.index
--------------------------------------------------------------------------------
/poems/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/poems/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/poems/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/poems/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/poems/__pycache__/poems.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/poems/__pycache__/poems.cpython-37.pyc
--------------------------------------------------------------------------------
/poems/__pycache__/poems.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucasjinreal/tensorflow_poems/d92867b59c2b8ffc876b891755fbd9f63fda0ebf/poems/__pycache__/poems.cpython-38.pyc
--------------------------------------------------------------------------------
/poems/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: model.py
3 | # author: JinTian
4 | # time: 07/03/2017 3:07 PM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | # ------------------------------------------------------------------------
19 | import tensorflow as tf
20 | import numpy as np
21 |
22 |
23 | def rnn_model(model, input_data, output_data, vocab_size, rnn_size=128, num_layers=2, batch_size=64,
24 | learning_rate=0.01):
25 | """
26 | construct rnn seq2seq model.
27 | :param model: model class
28 | :param input_data: input data placeholder
29 | :param output_data: output data placeholder
30 | :param vocab_size:
31 | :param rnn_size:
32 | :param num_layers:
33 | :param batch_size:
34 | :param learning_rate:
35 | :return:
36 | """
37 | end_points = {}
38 |
39 | if model == 'rnn':
40 | cell_fun = tf.contrib.rnn.BasicRNNCell
41 | elif model == 'gru':
42 | cell_fun = tf.contrib.rnn.GRUCell
43 | elif model == 'lstm':
44 | cell_fun = tf.contrib.rnn.BasicLSTMCell
45 |
46 | cell = cell_fun(rnn_size, state_is_tuple=True)
47 | cell = tf.contrib.rnn.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
48 |
49 | if output_data is not None:
50 | initial_state = cell.zero_state(batch_size, tf.float32)
51 | else:
52 | initial_state = cell.zero_state(1, tf.float32)
53 |
54 | with tf.device("/cpu:0"):
55 | embedding = tf.get_variable('embedding', initializer=tf.random_uniform(
56 | [vocab_size + 1, rnn_size], -1.0, 1.0))
57 | inputs = tf.nn.embedding_lookup(embedding, input_data)
58 |
59 | # [batch_size, ?, rnn_size] = [64, ?, 128]
60 | outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)
61 | output = tf.reshape(outputs, [-1, rnn_size])
62 |
63 | weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
64 | bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
65 | logits = tf.nn.bias_add(tf.matmul(output, weights), bias=bias)
66 | # [?, vocab_size+1]
67 |
68 | if output_data is not None:
69 | # output_data must be one-hot encode
70 | labels = tf.one_hot(tf.reshape(output_data, [-1]), depth=vocab_size + 1)
71 | # should be [?, vocab_size+1]
72 |
73 |         loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
74 |         # loss shape is [?], one cross-entropy value per time step
75 | total_loss = tf.reduce_mean(loss)
76 | train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)
77 |
78 | end_points['initial_state'] = initial_state
79 | end_points['output'] = output
80 | end_points['train_op'] = train_op
81 | end_points['total_loss'] = total_loss
82 | end_points['loss'] = loss
83 | end_points['last_state'] = last_state
84 | else:
85 | prediction = tf.nn.softmax(logits)
86 |
87 | end_points['initial_state'] = initial_state
88 | end_points['last_state'] = last_state
89 | end_points['prediction'] = prediction
90 |
91 | return end_points
92 |
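93 | # Usage sketch (mirrors train.py; the placeholder shapes and vocab_size are
94 | # illustrative only):
95 | #
96 | #   input_data = tf.placeholder(tf.int32, [64, None])
97 | #   output_targets = tf.placeholder(tf.int32, [64, None])
98 | #   end_points = rnn_model(model='lstm', input_data=input_data,
99 | #                          output_data=output_targets, vocab_size=6000)
100 | #
101 | # Pass output_data=None instead to build the inference graph used by compose_poem.py.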
--------------------------------------------------------------------------------
/poems/poems.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: poems.py
3 | # author: JinTian
4 | # time: 08/03/2017 7:39 PM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | # ------------------------------------------------------------------------
19 | import collections
20 | import numpy as np
21 |
22 | start_token = 'B'
23 | end_token = 'E'
24 |
25 |
26 | def process_poems(file_name):
27 | # poems -> list of numbers
28 | poems = []
29 | with open(file_name, "r", encoding='utf-8', ) as f:
30 | for line in f.readlines():
31 | try:
32 | title, content = line.strip().split(':')
33 | content = content.replace(' ', '')
34 | if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or \
35 | start_token in content or end_token in content:
36 | continue
37 | if len(content) < 5 or len(content) > 79:
38 | continue
39 | content = start_token + content + end_token
40 | poems.append(content)
41 |             except ValueError:
42 | pass
43 | # poems = sorted(poems, key=len)
44 |
45 | all_words = [word for poem in poems for word in poem]
46 | counter = collections.Counter(all_words)
47 | words = sorted(counter.keys(), key=lambda x: counter[x], reverse=True)
48 |
49 | words.append(' ')
50 | L = len(words)
51 | word_int_map = dict(zip(words, range(L)))
52 | poems_vector = [list(map(lambda word: word_int_map.get(word, L), poem)) for poem in poems]
53 |
54 | return poems_vector, word_int_map, words
55 |
56 |
57 | def generate_batch(batch_size, poems_vec, word_to_int):
58 | n_chunk = len(poems_vec) // batch_size
59 | x_batches = []
60 | y_batches = []
61 | for i in range(n_chunk):
62 | start_index = i * batch_size
63 | end_index = start_index + batch_size
64 |
65 | batches = poems_vec[start_index:end_index]
66 | length = max(map(len, batches))
67 | x_data = np.full((batch_size, length), word_to_int[' '], np.int32)
68 | for row, batch in enumerate(batches):
69 | x_data[row, :len(batch)] = batch
70 | y_data = np.copy(x_data)
71 | y_data[:, :-1] = x_data[:, 1:]
72 | """
73 | x_data y_data
74 | [6,2,4,6,9] [2,4,6,9,9]
75 | [1,4,2,8,5] [4,2,8,5,5]
76 | """
77 | x_batches.append(x_data)
78 | y_batches.append(y_data)
79 | return x_batches, y_batches
80 |
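81 | # Usage sketch (mirrors train.py, using the bundled default corpus path):
82 | #
83 | #   poems_vector, word_int_map, words = process_poems('./data/poems.txt')
84 | #   x_batches, y_batches = generate_batch(64, poems_vector, word_int_map)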
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | user_input_str = input('input string:')
2 | if user_input_str == 'exit':
3 | exit()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: train.py
3 | # author: JinTian
4 | # time: 11/03/2017 9:53 AM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | # ------------------------------------------------------------------------
19 | import os
20 | import tensorflow as tf
21 | from poems.model import rnn_model
22 | from poems.poems import process_poems, generate_batch
23 |
24 | tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size.')
25 | tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate.')
26 | tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path.')
27 | tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
28 | tf.app.flags.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
29 | tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs.')
30 |
31 | FLAGS = tf.app.flags.FLAGS
32 |
33 |
34 | def run_training():
35 | if not os.path.exists(FLAGS.model_dir):
36 | os.makedirs(FLAGS.model_dir)
37 |
38 | poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
39 | batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
40 |
41 | input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
42 | output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
43 |
44 |     end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
45 |         vocabularies), rnn_size=128, num_layers=2, batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)
46 |
47 | saver = tf.train.Saver(tf.global_variables())
48 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
49 | with tf.Session() as sess:
50 | # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
51 | # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
52 | sess.run(init_op)
53 |
54 | start_epoch = 0
55 | checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
56 | if checkpoint:
57 | saver.restore(sess, checkpoint)
58 | print("## restore from the checkpoint {0}".format(checkpoint))
59 | start_epoch += int(checkpoint.split('-')[-1])
60 | print('## start training...')
61 | try:
62 | n_chunk = len(poems_vector) // FLAGS.batch_size
63 | for epoch in range(start_epoch, FLAGS.epochs):
64 | n = 0
65 | for batch in range(n_chunk):
66 | loss, _, _ = sess.run([
67 | end_points['total_loss'],
68 | end_points['last_state'],
69 | end_points['train_op']
70 | ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
71 | n += 1
72 | print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
73 | if epoch % 6 == 0:
74 | saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
75 | except KeyboardInterrupt:
76 |             print('## Interrupted manually, trying to save the checkpoint now...')
77 |             saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
78 |             print('## Last epoch was saved; next time training will resume from epoch {}.'.format(epoch))
79 |
80 |
81 | def main(_):
82 | run_training()
83 |
84 |
85 | if __name__ == '__main__':
86 | tf.app.run()
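87 |
88 | # Usage sketch: every flag is optional and defaults to the values defined above.
89 | #
90 | #   python3 train.py --batch_size 64 --learning_rate 0.01 --epochs 50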
--------------------------------------------------------------------------------
/utils/clean_cn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # file: clean_cn.py
3 | # author: JinTian
4 | # time: 08/03/2017 8:02 PM
5 | # Copyright 2017 JinTian. All Rights Reserved.
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | # ------------------------------------------------------------------------
19 | """
20 | this script using for clean Chinese corpus.
21 | you can set level for clean, i.e.:
22 | level='all', will clean all character that not Chinese, include punctuations
23 | level='normal', this will generate corpus like normal use, reserve alphabets and numbers
24 | level='clean', this will remove all except Chinese and Chinese punctuations
25 |
26 | besides, if you want remove complex Chinese characters, just set this to be true:
27 | simple_only=True
28 | """
29 |
30 | import os
31 | import string
32 |
33 |
34 | cn_punctuation_set = [',', '。', '!', '?', '"', '"', '、']
35 | en_punctuation_set = [',', '.', '?', '!', '"', '"']
36 |
37 |
38 | def clean_cn_corpus(file_name, clean_level='all', simple_only=True, is_save=True):
39 | """
40 | clean Chinese corpus.
41 | :param file_name:
42 | :param clean_level:
43 | :param simple_only:
44 | :param is_save:
45 | :return: clean corpus in list type.
46 | """
47 |     base_dir = os.path.dirname(file_name)
48 |     if not base_dir:
49 |         print('## no directory in file_name, saving to the current directory')
50 |         base_dir = '.'
51 |
52 | save_file = os.path.join(base_dir, os.path.basename(file_name).split('.')[0] + '_cleaned.txt')
53 | with open(file_name, 'r+') as f:
54 | clean_content = []
55 | for l in f.readlines():
56 | l = l.strip()
57 | if l:
58 | l = list(l)
59 | should_remove_words = [w for w in l if not should_reserve(w, clean_level)]
60 | clean_line = ''.join(c for c in l if c not in should_remove_words)
61 | if clean_line:
62 | clean_content.append(clean_line)
63 | if is_save:
64 | with open(save_file, 'w+') as f:
65 | for l in clean_content:
66 | f.write(l + '\n')
67 |         print('[INFO] cleaned file has been saved to %s.' % save_file)
68 | return clean_content
69 |
70 |
71 | def should_reserve(w, clean_level):
72 | if w == ' ':
73 | return True
74 | else:
75 | if clean_level == 'all':
76 | # only reserve Chinese characters
77 | if w in cn_punctuation_set or w in string.punctuation or is_alphabet(w):
78 | return False
79 | else:
80 | return is_chinese(w)
81 | elif clean_level == 'normal':
82 | # reserve Chinese characters, English alphabet, number
83 | if is_chinese(w) or is_alphabet(w) or is_number(w):
84 | return True
85 | elif w in cn_punctuation_set or w in en_punctuation_set:
86 | return True
87 | else:
88 | return False
89 | elif clean_level == 'clean':
90 | if is_chinese(w):
91 | return True
92 | elif w in cn_punctuation_set:
93 | return True
94 | else:
95 | return False
96 | else:
97 | raise "clean_level not support %s, please set for all, normal, clean" % clean_level
98 |
99 |
100 | def is_chinese(uchar):
101 | """is chinese"""
102 | return '\u4e00' <= uchar <= '\u9fa5'
103 |
104 |
105 | def is_number(uchar):
106 | """is number"""
107 | return '\u0030' <= uchar <= '\u0039'
108 |
109 |
110 | def is_alphabet(uchar):
111 | """is alphabet"""
112 | return ('\u0041' <= uchar <= '\u005a') or ('\u0061' <= uchar <= '\u007a')
113 |
114 | def semi_angle_to_sbc(uchar):
115 | """半角转全角"""
116 | inside_code = ord(uchar)
117 | if inside_code < 0x0020 or inside_code > 0x7e:
118 | return uchar
119 | if inside_code == 0x0020:
120 | inside_code = 0x3000
121 | else:
122 | inside_code += 0xfee0
123 | return chr(inside_code)
124 |
125 |
126 | def sbc_to_semi_angle(uchar):
127 | """全角转半角"""
128 | inside_code = ord(uchar)
129 | if inside_code == 0x3000:
130 | inside_code = 0x0020
131 | else:
132 | inside_code -= 0xfee0
133 | if inside_code < 0x0020 or inside_code > 0x7e:
134 | return uchar
135 | return chr(inside_code)
136 |
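137 | # Usage sketch (hypothetical corpus path): keep only Chinese characters and
138 | # Chinese punctuation, writing poems_cleaned.txt next to the input file.
139 | #
140 | #   clean_cn_corpus('./data/poems.txt', clean_level='clean', simple_only=False)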
--------------------------------------------------------------------------------
/utils/make_regulated_verse.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # MIT License
4 | #
5 | # Copyright (c) 2018 damtharvey
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
23 | """
24 | These functions can process data from
25 | github.com/chinese-poetry/chinese-poetry/json to make regulated verse data
26 | ready to use for training.
27 |
28 | If you don't have the JSON files already, just
29 | git clone https://github.com/chinese-poetry/chinese-poetry.git
30 | to get them.
31 |
32 | Regulated verse forms are expected to be tuples of (number of couplets,
33 | characters per couplet). For reference:
34 | wujue 五言絕句 = (2, 10)
35 | qijue 七言絕句 = (2, 14)
36 | wulv 五言律詩 = (4, 10)
37 | qilv 七言律詩 = (4, 14)
38 | """
39 |
40 | import glob
41 | import os
42 | import pandas as pd
43 |
44 |
45 | def unregulated(paragraphs):
46 |     """
47 |     Return True if the list of couplets describes unregulated verse.
48 |     """
49 | if all(len(couplet) == len(paragraphs[0]) for couplet in paragraphs):
50 | return False
51 | else:
52 | return True
53 |
54 | def get_poems_in_df(df, form):
55 | """
56 | Return a txt-friendly string of only poems in df of the specified form.
57 | """
58 | big_string = ""
59 | for row in range(len(df)):
60 | if len(df["strains"][row]) != form[0] or \
61 | len(df["strains"][row][0]) - 2 != form[1]:
62 | continue
63 | if "○" in str(df["paragraphs"][row]):
64 | continue
65 | if unregulated(df["paragraphs"][row]):
66 | continue
67 | big_string += df["title"][row] + ":"
68 | for couplet in df["paragraphs"][row]:
69 | big_string += couplet
70 | big_string += "\n"
71 | return big_string
72 |
73 | def get_poems_in_dir(dir, form, save_dir):
74 | """
75 |     Save poems of the given form found in dir to save_dir, one txt file per source JSON.
76 | """
77 | files = [f for f in os.listdir(dir) if "poet" in f]
78 | for file in files: # Restart partway through if kernel dies.
79 | with open(os.path.join(save_dir, file[:-5] + ".txt"), "w") as data:
80 | print("Now reading " + file)
81 | df = pd.read_json(os.path.join(dir, file))
82 | poems = get_poems_in_df(df, form)
83 | data.write(poems)
84 | print(str(len(poems)) + " chars written to "
85 | + save_dir + "/" + file[:-5] + ".txt")
86 | return 0
87 |
88 | def combine_txt(txts_dir, save_file):
89 | """
90 | Combine .txt files in txts_dir and save to save_file.
91 | """
92 | read_files = glob.glob(os.path.join(txts_dir, "*.txt"))
93 | with open(save_file, "wb") as outfile:
94 | for f in read_files:
95 | with open(f, "rb") as infile:
96 | outfile.write(infile.read())
97 | return 0
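98 |
99 | # Usage sketch (hypothetical paths; form=(4, 14) selects qilv per the module
100 | # docstring, and save_dir must already exist):
101 | #
102 | #   get_poems_in_dir('chinese-poetry/json', (4, 14), 'qilv_txts')
103 | #   combine_txt('qilv_txts', 'data/qilv-all.txt')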
--------------------------------------------------------------------------------