├── log
│   └── tmp.txt
├── checkpoint
│   └── tmp.txt
├── corpus
│   └── tmp.txt
├── pictures
│   ├── p0.png
│   ├── p1.png
│   ├── p2.png
│   ├── p3.png
│   ├── p4.png
│   ├── p5.png
│   └── logo.jpg
├── codes
│   ├── __pycache__
│   │   ├── beam.cpython-37.pyc
│   │   ├── tool.cpython-37.pyc
│   │   ├── config.cpython-37.pyc
│   │   ├── decay.cpython-37.pyc
│   │   ├── filter.cpython-37.pyc
│   │   ├── graphs.cpython-37.pyc
│   │   ├── layers.cpython-37.pyc
│   │   ├── logger.cpython-37.pyc
│   │   ├── utils.cpython-37.pyc
│   │   ├── criterion.cpython-37.pyc
│   │   ├── generator.cpython-37.pyc
│   │   ├── scheduler.cpython-37.pyc
│   │   ├── cl_trainer.cpython-37.pyc
│   │   ├── dae_trainer.cpython-37.pyc
│   │   ├── dseq_trainer.cpython-37.pyc
│   │   ├── mix_trainer.cpython-37.pyc
│   │   ├── rhythm_tool.cpython-37.pyc
│   │   ├── wm_trainer.cpython-37.pyc
│   │   ├── visualization.cpython-37.pyc
│   │   └── spectral_normalization.cpython-37.pyc
│   ├── criterion.py
│   ├── decay.py
│   ├── scheduler.py
│   ├── train.py
│   ├── config.py
│   ├── logger.py
│   ├── utils.py
│   ├── visualization.py
│   ├── dseq_trainer.py
│   ├── generate.py
│   ├── wm_trainer.py
│   ├── filter.py
│   ├── rhythm_tool.py
│   ├── beam.py
│   ├── layers.py
│   ├── generator.py
│   ├── tool.py
│   └── graphs.py
├── preprocess
│   ├── __pycache__
│   │   ├── rhythm_tool.cpython-37.pyc
│   │   └── pattern_extractor.cpython-37.pyc
│   ├── pattern_extractor.py
│   ├── rhythm_tool.py
│   └── preprocess.py
├── data
│   ├── GenrePatterns.txt
│   └── fchars.txt
└── README.md
/log/tmp.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/checkpoint/tmp.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/corpus/tmp.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pictures/p0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/p0.png
--------------------------------------------------------------------------------
/pictures/p1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/p1.png
--------------------------------------------------------------------------------
/pictures/p2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/p2.png
--------------------------------------------------------------------------------
/pictures/p3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/p3.png
--------------------------------------------------------------------------------
/pictures/p4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/p4.png
--------------------------------------------------------------------------------
/pictures/p5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/p5.png
--------------------------------------------------------------------------------
/pictures/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/logo.jpg
--------------------------------------------------------------------------------
/codes/__pycache__/beam.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/beam.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/tool.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/tool.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/decay.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/decay.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/filter.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/filter.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/graphs.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/graphs.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/layers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/layers.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/logger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/logger.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/criterion.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/criterion.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/generator.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/generator.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/scheduler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/scheduler.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/cl_trainer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/cl_trainer.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/dae_trainer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/dae_trainer.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/dseq_trainer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/dseq_trainer.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/mix_trainer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/mix_trainer.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/rhythm_tool.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/rhythm_tool.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/wm_trainer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/wm_trainer.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/visualization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/visualization.cpython-37.pyc
--------------------------------------------------------------------------------
/preprocess/__pycache__/rhythm_tool.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/preprocess/__pycache__/rhythm_tool.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/__pycache__/spectral_normalization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/spectral_normalization.cpython-37.pyc
--------------------------------------------------------------------------------
/preprocess/__pycache__/pattern_extractor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/preprocess/__pycache__/pattern_extractor.cpython-37.pyc
--------------------------------------------------------------------------------
/data/GenrePatterns.txt:
--------------------------------------------------------------------------------
1 | [00]#七绝一#4#0 32 0 31 31 32 32|0 31 0 32 32 31 33|0 31 0 32 31 31 32|0 32 31 31 32 32 33
2 | [01]#七绝二#4#0 31 0 32 32 31 33|0 32 31 31 32 32 33|0 32 0 31 31 32 32|0 31 0 32 32 31 33
3 | [02]#七绝三#4#0 32 31 31 32 32 33|0 31 0 32 32 31 33|0 31 0 32 31 31 32|0 32 31 31 32 32 33
4 | [03]#七绝四#4#0 31 0 32 31 31 32|0 32 31 31 32 32 33|0 32 0 31 31 32 32|0 31 0 32 32 31 33
5 | [04]#五绝一#4#0 31 31 32 32|0 32 32 31 33|0 32 31 31 32|31 31 32 32 33
6 | [05]#五绝二#4#0 32 32 31 33|31 31 32 32 33|0 31 31 32 32|0 32 32 31 33
7 | [06]#五绝三#4#31 31 32 32 33|0 32 32 31 33|0 32 0 31 32|31 31 32 32 33
8 | [07]#五绝四#4#0 32 31 31 32|31 31 32 32 33|0 31 31 32 32|0 32 32 31 33
9 |
--------------------------------------------------------------------------------
/codes/criterion.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 18:06:09
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | import torch
14 | from torch import nn
15 |
16 | class Criterion(nn.Module):
17 | def __init__(self, pad_idx):
18 | super().__init__()
19 | self._criterion = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=pad_idx)
20 | self._pad_idx = pad_idx
21 |
22 |
23 | def forward(self, outputs, targets):
24 | # outputs: (B, L, V)
25 | # targets: (B, L)
26 |
27 | vocab_size = outputs.size(-1)
28 | outs = outputs.contiguous().view(-1, vocab_size) # outs: (N, V)
29 | tgts = targets.contiguous().view(-1) # tgts: (N)
30 |
31 | non_pad_mask = tgts.ne(self._pad_idx)
32 |
33 | loss = self._criterion(outs, tgts) # [N]
34 | loss = loss.masked_select(non_pad_mask).mean()
35 |
36 | return loss
--------------------------------------------------------------------------------
/codes/decay.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 18:06:44
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | import numpy as np
14 |
15 | #---------------------------------------------------
16 | class RateDecay(object):
17 | '''Basic class for different types of rate decay,
18 | e.g., teach forcing ratio, gumbel temperature,
19 | KL annealing.
20 | '''
21 | def __init__(self, burn_down_steps, decay_steps, limit_v):
22 |
23 | self.step = 0
24 | self.rate = 1.0
25 |
26 | self.burn_down_steps = burn_down_steps
27 | self.decay_steps = decay_steps
28 |
29 | self.limit_v = limit_v
30 |
31 |
32 | def decay_funtion(self):
33 | # to be reconstructed
34 | return self.rate
35 |
36 |
37 | def do_step(self):
38 | # update rate
39 | self.step += 1
40 | if self.step > self.burn_down_steps:
41 | self.rate = self.decay_funtion()
42 |
43 | return self.rate
44 |
45 |
46 | def get_rate(self):
47 | return self.rate
48 |
49 |
50 | class ExponentialDecay(RateDecay):
51 | def __init__(self, burn_down_steps, decay_steps, min_v):
52 | super(ExponentialDecay, self).__init__(
53 | burn_down_steps, decay_steps, min_v)
54 |
55 | self.__alpha = np.log(self.limit_v) / (-decay_steps)
56 |
57 | def decay_funtion(self):
58 | new_rate = max(np.exp(-self.__alpha*self.step), self.limit_v)
59 | return new_rate
--------------------------------------------------------------------------------
/codes/scheduler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 18:09:56
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | class ISRScheduler(object):
14 | '''Inverse Square Root Schedule
15 | '''
16 | def __init__(self, optimizer, warmup_steps, max_lr=5e-4, min_lr=3e-5, init_lr=1e-5, beta=0.55):
17 | self._optimizer = optimizer
18 |
19 | self._step = 0
20 | self._rate = init_lr
21 |
22 | self._warmup_steps = warmup_steps
23 | self._max_lr = max_lr
24 | self._min_lr = min_lr
25 | self._init_lr = init_lr
26 |
27 | self._alpha = (max_lr-init_lr) / warmup_steps
28 | self._beta = beta
29 | self._gama = max_lr * warmup_steps ** (beta)
30 |
31 |
32 | def step(self):
33 | self._step += 1
34 | rate = self.rate()
35 | for p in self._optimizer.param_groups:
36 | p['lr'] = rate
37 | self._rate = rate
38 | self._optimizer.step()
39 |
40 | def rate(self):
41 | step = self._step
42 | if step < self._warmup_steps:
43 | lr = self._init_lr + self._alpha * step
44 | else:
45 | lr = self._gama * step ** (-self._beta)
46 |
47 | if step > self._warmup_steps:
48 | lr = max(lr, self._min_lr)
49 | return lr
50 |
51 | def zero_grad(self):
52 | self._optimizer.zero_grad()
53 |
54 | def state_dict(self):
55 | return self._optimizer.state_dict()
56 |
57 | def load_state_dict(self, dic):
58 | self._optimizer.load_state_dict(dic)
--------------------------------------------------------------------------------
/codes/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 18:15:55
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | from dseq_trainer import DSeqTrainer
14 | from wm_trainer import WMTrainer
15 |
16 | from graphs import WorkingMemoryModel
17 | from tool import Tool
18 | from config import device, hparams
19 | import utils
20 |
21 |
22 | def pretrain(wm_model, tool, hps, specified_device):
23 | dseq_trainer = DSeqTrainer(hps, specified_device)
24 |
25 | print ("dseq pretraining...")
26 | dseq_trainer.train(wm_model, tool)
27 | print ("dseq pretraining done!")
28 |
29 |
30 |
31 | def train(wm_model, tool, hps, specified_device):
32 | last_epoch = utils.restore_checkpoint(
33 | hps.model_dir, specified_device, wm_model)
34 |
35 | if last_epoch is not None:
36 | print ("checkpoint exsits! directly recover!")
37 | else:
38 | print ("checkpoint not exsits! train from scratch!")
39 |
40 | wm_trainer = WMTrainer(hps, specified_device)
41 | wm_trainer.train(wm_model, tool)
42 |
43 |
44 | def main():
45 | hps = hparams
46 | tool = Tool(hps.sens_num, hps.sen_len,
47 | hps.key_len, hps.topic_slots, hps.corrupt_ratio)
48 | tool.load_dic(hps.vocab_path, hps.ivocab_path)
49 | vocab_size = tool.get_vocab_size()
50 | PAD_ID = tool.get_PAD_ID()
51 | B_ID = tool.get_B_ID()
52 | assert vocab_size > 0 and PAD_ID >=0 and B_ID >= 0
53 | hps = hps._replace(vocab_size=vocab_size, pad_idx=PAD_ID, bos_idx=B_ID)
54 |
55 | print ("hyper-patameters:")
56 | print (hps)
57 | input("please check the hyper-parameters, and then press any key to continue >")
58 |
59 | wm_model = WorkingMemoryModel(hps, device)
60 | wm_model = wm_model.to(device)
61 |
62 | pretrain(wm_model, tool, hps, device)
63 | train(wm_model, tool, hps, device)
64 |
65 |
66 | if __name__ == "__main__":
67 | main()
68 |
--------------------------------------------------------------------------------
/data/fchars.txt:
--------------------------------------------------------------------------------
1 | 一
2 | 二
3 | 三
4 | 四
5 | 五
6 | 六
7 | 七
8 | 八
9 | 九
10 | 十
11 | 百
12 | 千
13 | 万
14 | 几
15 | 只
16 | 个
17 | 两
18 | 傍
19 | 上
20 | 更
21 | 在
22 | 此
23 | 止
24 | 正
25 | 动
26 | 了
27 | 加
28 | 于
29 | 互
30 | 亟
31 | 交
32 | 亦
33 | 随
34 | 洊
35 | 弥
36 | 纵
37 | 弗
38 | 率
39 | 再
40 | 继
41 | 啻
42 | 促
43 | 愈
44 | 焉
45 | 如
46 | 嘻
47 | 矣
48 | 然
49 | 属
50 | 綦
51 | 俱
52 | 况
53 | 殊
54 | 讫
55 | 间
56 | 讵
57 | 孰
58 | 许
59 | 俄
60 | 设
61 | 披
62 | 等
63 | 犹
64 | 渐
65 | 抵
66 | 把
67 | 顾
68 | 须
69 | 屡
70 | 顷
71 | 呼
72 | 希
73 | 忽
74 | 遽
75 | 是
76 | 汔
77 | 遍
78 | 呜
79 | 遄
80 | 忝
81 | 遂
82 | 必
83 | 常
84 | 有
85 | 鲜
86 | 倏
87 | 最
88 | 沓
89 | 敬
90 | 唯
91 | 虑
92 | 故
93 | 沿
94 | 便
95 | 益
96 | 盍
97 | 抑
98 | 固
99 | 被
100 | 相
101 | 直
102 | 因
103 | 漫
104 | 莫
105 | 却
106 | 即
107 | 凡
108 | 到
109 | 特
110 | 蹔
111 | 能
112 | 廑
113 | 依
114 | 阿
115 | 兼
116 | 举
117 | 猥
118 | 蔑
119 | 否
120 | 循
121 | 徐
122 | 后
123 | 徒
124 | 差
125 | 得
126 | 通
127 | 速
128 | 总
129 | 向
130 | 适
131 | 同
132 | 猝
133 | 逆
134 | 各
135 | 猗
136 | 斯
137 | 临
138 | 奚
139 | 於
140 | 奉
141 | 奈
142 | 方
143 | 已
144 | 要
145 | 奄
146 | 普
147 | 类
148 | 吁
149 | 全
150 | 样
151 | 仅
152 | 略
153 | 从
154 | 仍
155 | 今
156 | 介
157 | 似
158 | 伪
159 | 傥
160 | 令
161 | 每
162 | 会
163 | 洎
164 | 前
165 | 尝
166 | 尚
167 | 苟
168 | 尔
169 | 少
170 | 尽
171 | 偶
172 | 若
173 | 溘
174 | 庸
175 | 约
176 | 庶
177 | 酷
178 | 请
179 | 应
180 | 底
181 | 身
182 | 蚤
183 | 杂
184 | 罕
185 | 暨
186 | 都
187 | 来
188 | 诸
189 | 暂
190 | 暇
191 | 哉
192 | 比
193 | 毕
194 | 兹
195 | 为
196 | 其
197 | 共
198 | 靡
199 | 公
200 | 兮
201 | 毋
202 | 详
203 | 业
204 | 丕
205 | 且
206 | 专
207 | 非
208 | 与
209 | 不
210 | 先
211 | 克
212 | 嗟
213 | 兀
214 | 元
215 | 稍
216 | 潜
217 | 叨
218 | 可
219 | 较
220 | 辄
221 | 叵
222 | 堪
223 | 及
224 | 又
225 | 反
226 | 或
227 | 往
228 | 恒
229 | 彼
230 | 当
231 | 恰
232 | 皆
233 | 的
234 | 索
235 | 擅
236 | 累
237 | 謇
238 | 偏
239 | 假
240 | 使
241 | 由
242 | 以
243 | 佥
244 | 偕
245 | 佯
246 | 何
247 | 刚
248 | 则
249 | 切
250 | 但
251 | 甚
252 | 生
253 | 竟
254 | 竞
255 | 勿
256 | 立
257 | 耶
258 | 而
259 | 者
260 | 端
261 | 预
262 | 颇
263 | 岂
264 | 频
265 | 蒙
266 | 无
267 | 既
268 | 时
269 | 矧
270 | 处
271 | 喏
272 | 复
273 | 虽
274 | 块
275 | 坐
276 | 致
277 | 至
278 | 自
279 | 欻
280 | 也
281 | 乌
282 | 乍
283 | 乎
284 | 之
285 | 欤
286 | 乃
287 | 宁
288 | 它
289 | 脱
290 | 倘
291 | 趋
292 | 趁
293 | 夫
294 | 例
295 | 连
296 | 厪
297 | 荐
298 | 这
299 | 迄
300 | 所
301 | 迭
302 | 迪
303 | 替
304 | 那
305 | 咸
306 | 曷
307 | 曩
308 | 氐
309 | 悉
310 | 咄
311 | 并
312 | 和
313 | 噫
314 |
--------------------------------------------------------------------------------
/codes/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 21:05:31
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 |
14 | from collections import namedtuple
15 | import torch
16 | HParams = namedtuple('HParams',
17 | 'vocab_size, pad_idx, bos_idx,'
18 | 'word_emb_size, ph_emb_size, len_emb_size,'
19 | 'hidden_size, mem_size, global_trace_size, topic_trace_size,'
20 | 'his_mem_slots, topic_slots, sens_num, sen_len, key_len,'
21 |
22 | 'batch_size, drop_ratio, attn_drop_ratio, weight_decay, clip_grad_norm,'
23 | 'max_lr, min_lr, init_lr, warmup_steps,'
24 | 'min_tr, burn_down_tr, decay_tr,'
25 | 'tau_annealing_steps, min_tau,'
26 |
27 | 'log_steps, sample_num, max_epoches,'
28 | 'save_epoches, validate_epoches,'
29 | 'vocab_path, ivocab_path, train_data, valid_data,'
30 | 'model_dir, data_dir, train_log_path, valid_log_path,'
31 |
32 | 'corrupt_ratio, dseq_epoches, dseq_batch_size,'
33 |     'dseq_max_lr, dseq_min_lr, dseq_init_lr, dseq_warmup_steps,'
34 | 'dseq_min_tr, dseq_burn_down_tr, dseq_decay_tr,'
35 |
36 | 'dseq_log_steps, dseq_validate_epoches, dseq_save_epoches,'
37 | 'dseq_train_log_path, dseq_valid_log_path,'
38 | )
39 |
40 |
41 | hparams = HParams(
42 | # --------------------
43 | # general settings
44 | vocab_size=-1, pad_idx=-1, bos_idx=-1, # to be replaced by true size after loading dictionary
45 | word_emb_size=256, ph_emb_size=64, len_emb_size=32,
46 | hidden_size=512, mem_size=512, global_trace_size=512, topic_trace_size=20,
47 | his_mem_slots=4, topic_slots=4, sens_num=4, sen_len=10, key_len=2,
48 |
49 |
50 | batch_size=128, drop_ratio=0.25, attn_drop_ratio=0.1,
51 | weight_decay=2.5e-4, clip_grad_norm=2.0,
52 | max_lr=1e-3, min_lr=5e-8, init_lr=1e-4, warmup_steps=6000, # learning rate decay
53 |     min_tr=0.80, burn_down_tr=3, decay_tr=10, # epochs for teacher forcing ratio decay
54 | tau_annealing_steps=30000, min_tau=0.01,# Gumbel temperature, from 1 to min_tau
55 |
56 | log_steps=100, sample_num=1, max_epoches=14,
57 | save_epoches=2, validate_epoches=1,
58 |
59 | vocab_path="../corpus/vocab.pickle",
60 | ivocab_path="../corpus/ivocab.pickle",
61 | train_data="../corpus/train_data.pickle",
62 | valid_data="../corpus/valid_data.pickle",
63 | model_dir="../checkpoint/",
64 | data_dir="../data/",
65 | train_log_path="../log/wm_train_log.txt",
66 | valid_log_path="../log/wm_valid_log.txt",
67 |
68 | #--------------------------
69 | # for pre-training
70 | corrupt_ratio=0.1, dseq_epoches=10, dseq_batch_size=256,
71 | dseq_max_lr=1e-3, dseq_min_lr=5e-5, dseq_init_lr=1e-4, dseq_warmup_steps=6000,
72 | dseq_min_tr=0.80, dseq_burn_down_tr=3, dseq_decay_tr=7,
73 |
74 | dseq_log_steps=200, dseq_validate_epoches=1, dseq_save_epoches=2,
75 | dseq_train_log_path="../log/dae_train_log.txt",
76 | dseq_valid_log_path="../log/dae_valid_log.txt"
77 |
78 | )
79 |
80 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
81 |
--------------------------------------------------------------------------------
/codes/logger.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 18:09:32
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | import numpy as np
14 | import time
15 |
16 |
17 | class InfoLogger(object):
18 |     """Base class for logging training and validation information."""
19 | def __init__(self, mode):
20 |         super(InfoLogger, self).__init__()
21 | self._mode = mode # string, 'train' or 'valid'
22 | self._total_steps = 0
23 | self._batch_num = 0
24 | self._log_steps = 0
25 | self._cur_step = 0
26 | self._cur_epoch = 1
27 |
28 | self._start_time = 0
29 | self._end_time = 0
30 |
31 | #--------------------------
32 | self._log_path = "" # path to save the log file
33 |
34 | # -------------------------
35 | self._decay_rates = {'learning_rate':1.0,
36 | 'teach_ratio':1.0, 'temperature':1.0}
37 |
38 |
39 | def set_batch_num(self, batch_num):
40 | self._batch_num = batch_num
41 | def set_log_steps(self, log_steps):
42 | self._log_steps = log_steps
43 | def set_log_path(self, log_path):
44 | self._log_path = log_path
45 |
46 | def set_rate(self, name, value):
47 | self._decay_rates[name] = value
48 |
49 |
50 | def set_start_time(self):
51 | self._start_time = time.time()
52 |
53 | def set_end_time(self):
54 | self._end_time = time.time()
55 |
56 | def add_step(self):
57 | self._total_steps += 1
58 | self._cur_step += 1
59 |
60 | def add_epoch(self):
61 | self._cur_step = 0
62 | self._cur_epoch += 1
63 |
64 |
65 | # ------------------------------
66 | @property
67 | def cur_process(self):
68 | ratio = float(self._cur_step) / self._batch_num * 100
69 | process_str = "%d/%d %.1f%%" % (self._cur_step, self._batch_num, ratio)
70 | return process_str
71 |
72 | @property
73 | def time_cost(self):
74 | return (self._end_time-self._start_time) / self._log_steps
75 |
76 | @property
77 | def total_steps(self):
78 | return self._total_steps
79 |
80 | @property
81 | def epoch(self):
82 | return self._cur_epoch
83 |
84 | @property
85 | def mode(self):
86 | return self._mode
87 |
88 | @property
89 | def log_path(self):
90 | return self._log_path
91 |
92 |
93 | @property
94 | def learning_rate(self):
95 | return self._decay_rates['learning_rate']
96 |
97 | @property
98 | def teach_ratio(self):
99 | return self._decay_rates['teach_ratio']
100 |
101 | @property
102 | def temperature(self):
103 | return self._decay_rates['temperature']
104 |
105 |
106 | #------------------------------------
107 | class SimpleLogger(InfoLogger):
108 | def __init__(self, mode):
109 | super(SimpleLogger, self).__init__(mode)
110 | self._gen_loss = 0.0
111 |
112 | def add_losses(self, gen_loss):
113 | self.add_step()
114 | self._gen_loss += gen_loss
115 |
116 |
117 | def print_log(self, epoch=None):
118 |
119 | gen_loss = self._gen_loss / self.total_steps
120 | ppl = np.exp(gen_loss)
121 |
122 | if self.mode == 'train':
123 | process_info = "epoch: %d, %s, %.2fs per iter, lr: %.4f, tr: %.2f, tau: %.3f" % (self.epoch,
124 | self.cur_process, self.time_cost, self.learning_rate, self.teach_ratio, self.temperature)
125 | else:
126 | process_info = "epoch: %d, lr: %.4f, tr: %.2f, tau: %.3f" % (
127 | epoch, self.learning_rate, self.teach_ratio, self.temperature)
128 |
129 | train_info = " gen loss: %.3f ppl:%.2f" % (gen_loss, ppl)
130 |
131 |
132 | print (process_info)
133 | print (train_info)
134 | print ("______________________")
135 |
136 | info = process_info + "\n" + train_info
137 | fout = open(self.log_path, 'a')
138 | fout.write(info + "\n\n")
139 | fout.close()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WMPoetry
2 | The source code of [*Chinese Poetry Generation with a Working Memory Model*](https://www.ijcai.org/Proceedings/2018/0633.pdf) (IJCAI 2018).
3 |
4 | ## 1. Rights
5 | All rights reserved.
6 |
7 | ## 2. Requirements
8 | * python>=3.7.0
9 | * pytorch>=1.3.1
10 | * matplotlib>=2.2.3
11 |
12 | A Tensorflow version of our model is available [here](https://github.com/XiaoyuanYi/WMPoetry).
13 |
14 | ## 3. Data Preparation
15 | To train the model and generate poems, please
16 |
17 | * add the training, validation and testing sets of our [THU-CCPC](https://github.com/THUNLP-AIPoet/Datasets/tree/master/CCPC) data into the *WMPoetry/preprocess/* directory;
18 | * add the pingsheng.txt, zesheng.txt, pingshui.txt and pingshui_amb.pkl files of our [THU-CRRD](https://github.com/THUNLP-AIPoet/Datasets/tree/master/CRRD) set into the *WMPoetry/data/* directory.
19 |
20 | ## 4. Preprocessing
21 | In WMPoetry/preprocess/, just run:
22 | ```
23 | python preprocess.py
24 | ```
25 |
26 | Then, move the produced vocab.pickle, ivocab.pickle, train_data.pickle and valid_data.pickle into *WMPoetry/corpus/*; and move test_inps.txt, test_trgs.txt and training_lines.txt into *WMPoetry/data/*.
27 |
28 | ## 5. Training
29 | In WMPoetry/codes/, run:
30 | ```
31 | python train.py
32 | ```
33 | The encoder and decoder will be pre-trained as a denoising seq2seq model, and then the Working Memory model is trained based on the pre-trained one.
34 |
35 | One can also edit WMPoetry/codes/**config.py** to modify the configuration, such as the hidden size, embedding size, data paths, number of training epochs, learning rate and so on.
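
For reference, all of these settings live in the `HParams` namedtuple defined in config.py, and a field can also be overridden programmatically with `_replace`, which is the same mechanism train.py uses to fill in the vocabulary size. A minimal sketch (the values below are purely illustrative, not recommended settings):
```
from config import hparams

# namedtuples are immutable; _replace returns a copy with the given
# fields overridden (train.py does the same for vocab_size/pad_idx/bos_idx)
hps = hparams._replace(batch_size=64, max_lr=5e-4, max_epoches=20)
print(hps.batch_size, hps.max_lr, hps.max_epoches)
```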
36 |
37 | During the training process, some training information is printed, such as:
38 |
39 | 
40 |
41 | The training and validation information is saved in the log directory, e.g., WMPoetry/log/.
42 |
43 | ## 6. Generation
44 | To generate a poem in an interactive interface, in WMPoetry/codes/, run:
45 | ```
46 | python generate.py -v 1
47 | ```
48 | Then one can input some keywords, select the genre pattern and the rhyme category, and then get the generated poem:
49 |
50 |
55 |
56 | We provide a genre pattern file for Chinese classical quatrains; please refer to WMPoetry/data/GenrePatterns.txt for details.
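
Each line of GenrePatterns.txt takes the form `id#name#number of lines#pattern`, where the per-line patterns are separated by `|` and each token is 0 (either tone), 31 (level tone), 32 (oblique tone) or 33 (a rhyme position, which is replaced by the selected rhyme category at generation time). For example, the first entry in the provided file is:
```
[00]#七绝一#4#0 32 0 31 31 32 32|0 31 0 32 32 31 33|0 31 0 32 31 31 32|0 32 31 31 32 32 33
```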
57 |
58 | By running:
59 | ```
60 | python generate.py -v 1 -s 1
61 | ```
62 | , one can manually select each generated line from the beam candidates.
63 |
64 |
65 | 
66 |
67 |
68 | By running:
69 | ```
70 | python generate.py -v 1 -d 1
71 | ```
72 | , one can get the visualization of memory reading probabilities and the contents stored in the memory for each generation step, such as:
73 |
74 |
75 | 
76 |
77 |
78 | These visualization pictures are saved in the log directory. *NOTE*: We use the SimHei font in matplotlib to display Chinese characters. Therefore, to use the provided visualization script, please make sure this font is correctly installed and configured.
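
A small sketch for checking the font setup (visualization.py sets `plt.rcParams['font.family'] = ['simhei']`; the filtering below is only illustrative):
```
from matplotlib import pyplot as plt
from matplotlib import font_manager

# list the font names matplotlib has registered; SimHei should appear here,
# otherwise the Chinese labels in the visualization will not render
names = sorted({f.name for f in font_manager.fontManager.ttflist})
print([n for n in names if "hei" in n.lower()])

plt.rcParams['font.family'] = ['simhei']  # the setting used in visualization.py
```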
79 |
80 |
81 | To generate poems with an input testing file, which contains a set of keywords and genre patterns, run:
82 | ```
83 | python generate.py -m file -i ../data/test_inps.txt -o test_outs.txt
84 | ```
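
generate.py splits each input line on `#`: the part before `#` contains the whitespace-separated keywords, and the part after it is the genre pattern string, in the same `|`-separated, space-separated token format as GenrePatterns.txt. A hypothetical example line (the keywords are only placeholders):
```
keyword1 keyword2#0 32 0 31 31 32 32|0 31 0 32 32 31 33|0 31 0 32 31 31 32|0 32 31 31 32 32 33
```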
85 |
86 | ## 7. Cite
87 | If you use our source code, please kindly cite this paper:
88 |
89 | Xiaoyuan Yi, Maosong Sun, Ruoyu Li and Zonghan Yang. Chinese Poetry Generation with a Working Memory Model. *In Proceedings of the Twenty-Seventh International Joint Conference on Artificial Intelligence*, pages 4553–4559, Stockholm, Sweden, 2018.
90 |
91 | The bib format is as follows:
92 | ```
93 | @inproceedings{Yimemory:18,
94 | author = {Xiaoyuan Yi and Maosong Sun and Ruoyu Li and Zonghan Yang},
95 |   title = {Chinese Poetry Generation with a Working Memory Model},
96 | year = "2018",
97 | pages = "4553--4559",
98 | booktitle = {Proceedings of the Twenty-Seventh International Joint Conference on Artificial Intelligence},
99 | address = {Stockholm, Sweden}
100 | }
101 | ```
102 | ## 8. System
103 | This work is a part of the automatic Chinese poetry generation system, [THUAIPoet (Jiuge, 九歌)](https://jiuge.thunlp.cn) developed by Research Center for Natural Language Processing, Computational Humanities and Social Sciences, Tsinghua University (清华大学人工智能研究院, 自然语言处理与社会人文计算研究中心). Please refer to [THUNLP](https://github.com/thunlp) and [THUNLP Lab](http://nlp.csai.tsinghua.edu.cn/site2/) for more information.
104 |
105 | 
106 |
107 | ## 9. Contact
108 | If you have any questions, suggestions or bug reports, please feel free to email yi-xy16@mails.tsinghua.edu.cn or mtmoonyi@gmail.com.
109 |
--------------------------------------------------------------------------------
/codes/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 20:07:53
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | import os
14 | import torch
15 | import torch.nn.functional as F
16 |
17 | import random
18 |
19 | def save_checkpoint(model_dir, epoch, model, prefix='', optimizer=None):
20 | # save model state dict
21 | checkpoint_name = "model_ckpt_{}_{}e.tar".format(prefix, epoch)
22 | model_state_path = os.path.join(model_dir, checkpoint_name)
23 |
24 | saved_dic = {
25 | 'epoch': epoch,
26 | 'model_state_dict': model.state_dict()
27 | }
28 |
29 | if optimizer is not None:
30 | saved_dic['optimizer'] = optimizer.state_dict()
31 |
32 |
33 | torch.save(saved_dic, model_state_path)
34 |
35 | # write checkpoint information
36 | log_path = os.path.join(model_dir, "ckpt_list.txt")
37 | fout = open(log_path, 'a')
38 | fout.write(checkpoint_name+"\n")
39 | fout.close()
40 |
41 |
42 | def restore_checkpoint(model_dir, device, model, optimizer=None):
43 | ckpt_list_path = os.path.join(model_dir, "ckpt_list.txt")
44 | if not os.path.exists(ckpt_list_path):
45 |         print ("checkpoint list does not exist, creating a new one!")
46 | return None
47 |
48 | # get latest ckpt name
49 | fin = open(ckpt_list_path, 'r')
50 | latest_ckpt_path = fin.readlines()[-1].strip()
51 | fin.close()
52 |
53 | latest_ckpt_path = os.path.join(model_dir, latest_ckpt_path)
54 | if not os.path.exists(latest_ckpt_path):
55 |         print ("latest checkpoint does not exist!")
56 | return None
57 |
58 |
59 | print ("restore checkpoint from %s" % (latest_ckpt_path))
60 | print ("loading...")
61 | checkpoint = torch.load(latest_ckpt_path, map_location=device)
62 | #checkpoint = torch.load(latest_ckpt_path)
63 | print ("load state dic, params: %d..." % (len(checkpoint['model_state_dict'])))
64 | model.load_state_dict(checkpoint['model_state_dict'])
65 |
66 |
67 | if optimizer is not None:
68 | print ("load optimizer dic...")
69 | optimizer.load_state_dict(checkpoint['optimizer'])
70 |
71 |
72 | epoch = checkpoint['epoch']
73 |
74 |
75 | return epoch
76 |
77 |
78 | def sample_dseq(inputs, targets, logits, sample_num, tool):
79 | # inps, trgs [batch size, sen len]
80 | # logits [batch size, trg len, vocab size]
81 | batch_size = inputs.size(0)
82 | inp_len = inputs.size(1)
83 | trg_len = targets.size(1)
84 | out_len = logits.size(1)
85 |
86 |
87 | sample_num = min(sample_num, batch_size)
88 |
89 | # randomly select some examples
90 | sample_ids = random.sample(list(range(0, batch_size)), sample_num)
91 |
92 | for sid in sample_ids:
93 | # Build lines
94 | inps = [inputs[sid, t].item() for t in range(0, inp_len)]
95 | sline = tool.idxes2line(inps)
96 |
97 | # -------------------------------------------
98 | trgs = [targets[sid, t].item() for t in range(0, trg_len)]
99 | tline = tool.idxes2line(trgs)
100 |
101 |
102 | # ------------------------------------------
103 | probs = F.softmax(logits, dim=-1)
104 | outs = [probs[sid, t, :].cpu().data.numpy() for t in range(0, out_len)]
105 | oline = tool.greedy_search(outs)
106 |
107 |
108 | print("inp: " + sline)
109 | print("trg: " + tline)
110 | print("out: " + oline)
111 | print ("")
112 |
113 |
114 | #------------------------------
115 | def sample_wm(keys, all_trgs, all_outs, sample_num, tool):
116 | batch_size = all_trgs[0].size(0)
117 | sample_num = min(sample_num, batch_size)
118 |
119 |     # randomly select some examples
120 | sample_ids = random.sample(list(range(0, batch_size)), sample_num)
121 |
122 | for sid in sample_ids:
123 | key_lines = []
124 | for key in keys:
125 | key_idxes = [key[sid, t].item() for t in range(0, key.size(1))]
126 | key_str = tool.idxes2line(key_idxes)
127 | if len(key_str) == 0:
128 | key_str = "PAD"
129 | key_lines.append(key_str)
130 |
131 | trg_lines = []
132 | for trg in all_trgs:
133 | trg_idxes = [trg[sid, t].item() for t in range(0, trg.size(1))]
134 | trg_lines.append(tool.idxes2line(trg_idxes))
135 |
136 | out_lines = []
137 | for out in all_outs:
138 | probs = F.softmax(out, dim=-1)
139 | out_probs = [probs[sid, t, :].cpu().data.numpy() for t in range(0, probs.size(1))]
140 | out_lines.append(tool.greedy_search(out_probs))
141 |
142 | #--------------------------------------------
143 | print("keywords: " + "|".join(key_lines))
144 | print("target: " + "|".join(trg_lines))
145 | print("output: " + "|".join(out_lines))
146 | print ("")
147 |
148 |
149 | def print_parameter_list(model, prefix=None):
150 | params = model.named_parameters()
151 |
152 | param_num = 0
153 | for name, param in params:
154 | if prefix is not None:
155 | seg = name.split(".")[1]
156 | if seg in prefix:
157 | print(name, param.size())
158 | param_num += 1
159 | else:
160 | print(name, param.size())
161 | param_num += 1
162 |
163 | print ("params num: %d" % (param_num))
164 | #------------------------------
--------------------------------------------------------------------------------
/codes/visualization.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 22:04:36
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | from matplotlib import pyplot as plt
14 | plt.rcParams['font.family'] = ['simhei']
15 |
16 | from matplotlib.colors import from_levels_and_colors
17 |
18 | import numpy as np
19 | import copy
20 |
21 | import torch
22 |
23 | class Visualization(object):
24 |     """Visualization of memory reading and writing during generation."""
25 | def __init__(self, topic_slots, history_slots, log_path):
26 |         super(Visualization, self).__init__()
27 |
28 | self._topic_slots = topic_slots
29 | self._history_slots = history_slots
30 | self._log_path = log_path
31 |
32 |
33 | def reset(self, keywords):
34 | self._keywords = keywords
35 | self._history_mem = [' ']*self._history_slots
36 | self._gen_lines = []
37 |
38 |
39 | def add_gen_line(self, line):
40 | self._gen_lines.append(line.strip())
41 |
42 | def normalization(self, ori_matrix):
43 | new_matrix = ori_matrix / ori_matrix.sum(axis=1, keepdims=True)
44 | return new_matrix
45 |
46 |
47 | def draw(self, read_log, write_log, step, visual_mode):
48 | assert visual_mode in [0, 1, 2]
49 | # read_log: (1, 1, mem_slots) * L_gen
50 | # write_log: (B, L_gen, mem_slots)
51 | current_gen_chars = [c for c in self._gen_lines[-1]]
52 | gen_len = len(current_gen_chars)
53 |
54 | if len(self._gen_lines) >= 2:
55 | last_gen_chars = [c for c in self._gen_lines[-2]]
56 | last_gen_len = len(last_gen_chars)
57 | else:
58 | last_gen_chars = [''] * gen_len
59 | last_gen_len = gen_len
60 |
61 | # (L_gen, mem_slots)
62 | mem_slots = self._topic_slots+self._history_slots+last_gen_len
63 | read_matrix = torch.cat(read_log, dim=1)[0, 0:gen_len, 0:mem_slots].detach().cpu().numpy()
64 | read_matrix = self.normalization(read_matrix)
65 |
66 | plt.figure(figsize=(11, 5))
67 |
68 | # visualization of reading attention weights
69 | num_levels = 100
70 | vmin, vmax = read_matrix.min(), read_matrix.max()
71 | midpoint = 0
72 | levels = np.linspace(vmin, vmax, num_levels)
73 | midp = np.mean(np.c_[levels[:-1], levels[1:]], axis=1)
74 | vals = np.interp(midp, [vmin, midpoint, vmax], [0, 0.5, 1])
75 | colors = plt.cm.seismic(vals)
76 | cmap, norm = from_levels_and_colors(levels, colors)
77 |
78 |
79 | plt.imshow(read_matrix, cmap=cmap, interpolation='none')
80 |
81 | # print generated chars and chars in the memory
82 | fontsize = 14
83 |
84 | plt.text(0.2, gen_len+0.5, "Topic Memory", fontsize=fontsize)
85 | plt.text(self._topic_slots, gen_len+0.5, "History Memory", fontsize=fontsize)
86 | if last_gen_len == 5:
87 | shift = 5
88 | else:
89 | shift = 6
90 | plt.text(self._topic_slots+shift, gen_len+0.5, "Local Memory", fontsize=fontsize)
91 |
92 | # topic memory
93 | for i in range(0, len(self._keywords)):
94 | key = self._keywords[i]
95 | if len(key) == 1:
96 | key = " " + key + " "
97 | key = key + "|"
98 | plt.text(i-0.4,-0.7, key, fontsize=fontsize)
99 |
100 | start_pos = self._topic_slots
101 | end_pos = self._topic_slots + self._history_slots
102 |
103 | # history memory
104 | for i in range(start_pos, end_pos):
105 | c = self._history_mem[i - start_pos]
106 | if i == end_pos - 1:
107 | c = c + " |"
108 |
109 | plt.text(i-0.2,-0.7, c, fontsize=fontsize)
110 |
111 | start_pos = end_pos
112 | end_pos = start_pos + last_gen_len
113 |
114 | # local memory
115 | for i in range(start_pos, end_pos):
116 | idx = i - start_pos
117 | plt.text(i-0.2,-0.7, last_gen_chars[idx], fontsize=fontsize)
118 |
119 | # generated line
120 | for i in range(0, len(current_gen_chars)):
121 | plt.text(-1.2, i+0.15, current_gen_chars[i], fontsize=fontsize)
122 |
123 | plt.colorbar()
124 | plt.tick_params(labelbottom=False, labelleft=False)
125 |
126 |
127 | x_major_locator = plt.MultipleLocator(1)
128 | y_major_locator = plt.MultipleLocator(1)
129 | ax = plt.gca()
130 | ax.xaxis.set_major_locator(x_major_locator)
131 | ax.yaxis.set_major_locator(y_major_locator)
132 | #plt.tight_layout()
133 |
134 |
135 | if visual_mode == 1:
136 | fig = plt.gcf()
137 | fig.savefig(self._log_path + 'visual_step_{}.png'.format(step), dpi=300, quality=100, bbox_inches="tight")
138 | elif visual_mode == 2:
139 | plt.show()
140 |
141 |
142 | # update history memory
143 | if write_log is not None:
144 | if len(last_gen_chars) == 0:
145 | print ("last generated line is empty!")
146 |
147 | write_log = write_log[0, :, :].detach().cpu().numpy()
148 | history_mem = copy.deepcopy(self._history_mem)
149 | for i, c in enumerate(last_gen_chars):
150 | selected_slot = np.argmax(write_log[i, :])
151 | if selected_slot >= self._history_slots:
152 | continue
153 | history_mem[selected_slot] = c
154 |
155 | self._history_mem = history_mem
156 |
--------------------------------------------------------------------------------
/codes/dseq_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 20:39:50
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | import torch
14 |
15 | from scheduler import ISRScheduler
16 | from criterion import Criterion
17 | from decay import ExponentialDecay
18 | from logger import SimpleLogger
19 | import utils
20 |
21 | class DSeqTrainer(object):
22 |
23 | def __init__(self, hps, device):
24 | self.hps = hps
25 | self.device = device
26 |
27 |
28 | def run_validation(self, epoch, wm_model, criterion, tool, lr):
29 | logger = SimpleLogger('valid')
30 | logger.set_batch_num(tool.valid_batch_num)
31 | logger.set_log_path(self.hps.dseq_valid_log_path)
32 | logger.set_rate('learning_rate', lr)
33 | logger.set_rate('teach_ratio', wm_model.get_teach_ratio())
34 |
35 | for step in range(0, tool.valid_batch_num):
36 |
37 | batch = tool.valid_batches[step]
38 |
39 | inps = batch[0].to(self.device)
40 | trgs = batch[1].to(self.device)
41 | ph_inps = batch[2].to(self.device)
42 | len_inps = batch[3].to(self.device)
43 |
44 | with torch.no_grad():
45 | gen_loss, _ = self.run_step(wm_model, None, criterion,
46 | inps, trgs, ph_inps, len_inps, True)
47 | logger.add_losses(gen_loss)
48 |
49 | logger.print_log(epoch)
50 |
51 |
52 | def run_step(self, wm_model, optimizer, criterion,
53 | inps, trgs, ph_inps, len_inps, valid=False):
54 | if not valid:
55 | optimizer.zero_grad()
56 |
57 | outs = wm_model.dseq_graph(inps, trgs, ph_inps, len_inps)
58 |
59 | loss = criterion(outs, trgs)
60 |
61 | if not valid:
62 | loss.backward()
63 | torch.nn.utils.clip_grad_norm_(wm_model.dseq_parameters(),
64 | self.hps.clip_grad_norm)
65 | optimizer.step()
66 |
67 | return loss.item(), outs
68 |
69 |
70 | def run_train(self, wm_model, tool, optimizer, criterion, logger):
71 | logger.set_start_time()
72 |
73 | for step in range(0, tool.train_batch_num):
74 |
75 | batch = tool.train_batches[step]
76 |
77 | inps = batch[0].to(self.device)
78 | trgs = batch[1].to(self.device)
79 | ph_inps = batch[2].to(self.device)
80 | len_inps = batch[3].to(self.device)
81 |
82 | gen_loss, outs = self.run_step(wm_model, optimizer, criterion,
83 | inps, trgs, ph_inps, len_inps)
84 |
85 | logger.add_losses(gen_loss)
86 | logger.set_rate("learning_rate", optimizer.rate())
87 | if step % self.hps.dseq_log_steps == 0:
88 | logger.set_end_time()
89 | utils.sample_dseq(inps, trgs, outs, self.hps.sample_num, tool)
90 | logger.print_log()
91 | logger.set_start_time()
92 |
93 |
94 |
95 | def train(self, wm_model, tool):
96 | #utils.print_parameter_list(wm_model, wm_model.dseq_parameter_names())
97 |
98 | # load data for pre-training
99 | print ("building data for dseq...")
100 | tool.build_data(self.hps.train_data, self.hps.valid_data,
101 | self.hps.dseq_batch_size, mode='dseq')
102 |
103 | print ("train batch num: %d" % (tool.train_batch_num))
104 | print ("valid batch num: %d" % (tool.valid_batch_num))
105 |
106 | #input("please check the parameters, and then press any key to continue >")
107 |
108 |
109 | # training logger
110 | logger = SimpleLogger('train')
111 | logger.set_batch_num(tool.train_batch_num)
112 | logger.set_log_steps(self.hps.dseq_log_steps)
113 | logger.set_log_path(self.hps.dseq_train_log_path)
114 | logger.set_rate('learning_rate', 0.0)
115 | logger.set_rate('teach_ratio', 1.0)
116 |
117 |
118 | # build optimizer
119 | opt = torch.optim.AdamW(wm_model.dseq_parameters(),
120 | lr=1e-3, betas=(0.9, 0.99), weight_decay=self.hps.weight_decay)
121 | optimizer = ISRScheduler(optimizer=opt, warmup_steps=self.hps.dseq_warmup_steps,
122 | max_lr=self.hps.dseq_max_lr, min_lr=self.hps.dseq_min_lr,
123 | init_lr=self.hps.dseq_init_lr, beta=0.6)
124 |
125 | wm_model.train()
126 |
127 | criterion = Criterion(self.hps.pad_idx)
128 |
129 |         # teacher forcing ratio decay
130 | tr_decay_tool = ExponentialDecay(self.hps.dseq_burn_down_tr, self.hps.dseq_decay_tr,
131 | self.hps.dseq_min_tr)
132 |
133 | # train
134 | for epoch in range(1, self.hps.dseq_epoches+1):
135 |
136 | self.run_train(wm_model, tool, optimizer, criterion, logger)
137 |
138 | if epoch % self.hps.dseq_validate_epoches == 0:
139 | print("run validation...")
140 | wm_model.eval()
141 | print ("in training mode: %d" % (wm_model.training))
142 | self.run_validation(epoch, wm_model, criterion, tool, optimizer.rate())
143 | wm_model.train()
144 | print ("validation Done: %d" % (wm_model.training))
145 |
146 |
147 | if (self.hps.dseq_save_epoches >= 1) and \
148 | (epoch % self.hps.dseq_save_epoches) == 0:
149 | # save checkpoint
150 | print("saving model...")
151 | utils.save_checkpoint(self.hps.model_dir, epoch, wm_model, prefix="dseq")
152 |
153 |
154 | logger.add_epoch()
155 |
156 |             print ("teacher forcing ratio decay...")
157 | wm_model.set_teach_ratio(tr_decay_tool.do_step())
158 | logger.set_rate('teach_ratio', tr_decay_tool.get_rate())
159 |
160 | print("shuffle data...")
161 | tool.shuffle_train_data()
--------------------------------------------------------------------------------
/codes/generate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 18:17:27
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | from generator import Generator
14 | from config import hparams, device
15 | import copy
16 | import argparse
17 |
18 | def parse_args():
19 |     parser = argparse.ArgumentParser(description="The parameters for the generator.")
20 | parser.add_argument("-m", "--mode", type=str, choices=['interact', 'file'], default='interact',
21 | help='The mode of generation. interact: generate in a interactive mode.\
22 | file: take an input file and generate poems for each input in the file.')
23 | parser.add_argument("-b", "--bsize", type=int, default=20, help="beam size, 20 by default.")
24 | parser.add_argument("-v", "--verbose", type=int, default=0, choices=[0, 1, 2, 3],
25 | help="Show other information during the generation, False by default.")
26 | parser.add_argument("-d", "--draw", type=int, default=0, choices=[0, 1, 2],
27 | help="Show the visualization of memory reading and writing. It only works in the interact mode.\
28 | 0: not work, 1: save the visualization as pictures, 2: show the visualization at each step.")
29 | parser.add_argument("-s", "--select", type=int, default=0,
30 | help="If manually select each generated line from beam candidates? False by default.\
31 | It works only in the interact mode.")
32 | parser.add_argument("-i", "--inp", type=str,
33 | help="input file path. it works only in the file mode.")
34 | parser.add_argument("-o", "--out", type=str,
35 | help="output file path. it works only in the file mode")
36 | return parser.parse_args()
37 |
38 |
39 | class GenerateTool(object):
40 | """docstring for GenerateTool"""
41 | def __init__(self):
42 | super(GenerateTool, self).__init__()
43 | self.generator = Generator(hparams, device)
44 | self._load_patterns(hparams.data_dir+"/GenrePatterns.txt")
45 |
46 |
47 | def _load_patterns(self, path):
48 | with open(path, 'r') as fin:
49 | lines = fin.readlines()
50 |
51 | self._patterns = []
52 | '''
53 | each line contains:
54 | pattern id, pattern name, the number of lines,
55 |         pattern: 0 either tone, 31 level tone (ping), 32 oblique tone (ze), 33 rhyme position
56 | '''
57 | for line in lines:
58 | line = line.strip()
59 | para = line.split("#")
60 | pas = para[3].split("|")
61 | newpas = []
62 | for pa in pas:
63 | pa = pa.split(" ")
64 | newpas.append([int(p) for p in pa])
65 |
66 | self._patterns.append((para[1], newpas))
67 |
68 | self.p_num = len(self._patterns)
69 | print ("load %d patterns." % (self.p_num))
70 |
71 |
72 | def build_pattern(self, pstr):
73 | pstr_vec = pstr.split("|")
74 | patterns = []
75 | for pstr in pstr_vec:
76 | pas = pstr.split(" ")
77 | pas = [int(pa) for pa in pas]
78 | patterns.append(pas)
79 |
80 | return patterns
81 |
82 |
83 | def generate_file(self, args):
84 | beam_size = args.bsize
85 | verbose = args.verbose
86 | manu = True if args.select ==1 else False
87 |
88 | assert args.inp is not None
89 | assert args.out is not None
90 |
91 | with open(args.inp, 'r') as fin:
92 | inps = fin.readlines()
93 |
94 |
95 | fout = open(args.out, 'w')
96 |
97 | poems = []
98 | N = len(inps)
99 | log_step = max(int(N/100), 2)
100 | for i, inp in enumerate(inps):
101 | para = inp.strip().split("#")
102 | keywords = para[0].split(" ")
103 | pattern = self.build_pattern(para[1])
104 |
105 | lines, info = self.generator.generate_one(keywords, pattern,
106 | beam_size, verbose, manu=manu)
107 |
108 | if len(lines) != 4:
109 | ans = info
110 | else:
111 | ans = "|".join(lines)
112 |
113 | fout.write(ans+"\n")
114 |
115 | if i % log_step == 0:
116 | print ("generating, %d/%d" % (i, N))
117 | fout.flush()
118 |
119 |
120 | fout.close()
121 |
122 |
123 | def _set_rhyme_into_pattern(self, ori_pattern, rhyme):
124 | pattern = copy.deepcopy(ori_pattern)
125 | for i in range(0, len(pattern)):
126 | if pattern[i][-1] == 33:
127 | pattern[i][-1] = rhyme
128 | return pattern
129 |
130 |
131 | def generate_manu(self, args):
132 | beam_size = args.bsize
133 | verbose = args.verbose
134 | manu = True if args.select ==1 else False
135 | visual_mode = args.draw
136 |
137 | while True:
138 |             keys = input("please input keywords (separated by whitespace), 4 at most > ")
139 | pattern_id = int(input("please select genre pattern 0~{} > ".format(self.p_num-1)))
140 | rhyme = int(input("please input rhyme id, 1~30> "))
141 |
142 | ori_pattern = self._patterns[pattern_id]
143 | name = ori_pattern[0]
144 | pattern = ori_pattern[1]
145 | pattern = self._set_rhyme_into_pattern(pattern, rhyme)
146 | print ("select pattern: %s" % (name))
147 |
148 | keywords = keys.strip().split(" ")
149 | lines, info = self.generator.generate_one(keywords, pattern,
150 | beam_size, verbose, manu=manu, visual=visual_mode)
151 |
152 | if len(lines) != 4:
153 | print("generation failed!")
154 | continue
155 | else:
156 | print("\n".join(lines))
157 |
158 |
159 | def main():
160 | args = parse_args()
161 | generate_tool = GenerateTool()
162 | if args.mode == 'interact':
163 | generate_tool.generate_manu(args)
164 | else:
165 | generate_tool.generate_file(args)
166 |
167 |
168 | if __name__ == "__main__":
169 | main()
170 |
--------------------------------------------------------------------------------
/preprocess/pattern_extractor.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi and Jiannan Liang
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 20:24:22
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 |
14 | import pickle
15 | import os
16 | import copy
17 |
18 | import numpy as np
19 |
20 | from rhythm_tool import RhythmRecognizer
21 |
22 |
23 | class PatternExtractor(object):
24 |
25 | def __init__(self, data_dir):
26 | '''
27 | rhythm patterns.
28 | for Chinese quatrains, we generalize four main poem-level patterns.
29 |
30 | NOTE: We use pingshuiyun (The Pingshui Rhyme Category)
31 | We only consider level-tone rhyme in terms of the requirements of
32 | Chinese classical quatrains.
33 |         0: either level or oblique tone; 1~30: rhyme categories,
34 | 31: level tone, 32: oblique tone
35 |
36 | '''
37 | self._RHYTHM_PATTERNS = {7: [[0, 32, 0, 31, 31, 32, 32], [0, 31, 0, 32, 32, 31, 31],
38 | [0, 32, 31, 31, 32, 32, 31], [0, 31, 0, 32, 31, 31, 32]],
39 | 5: [[0, 31, 31, 32, 32], [0, 32, 32, 31, 31], [31, 31, 32, 32, 31], [0, 32, 0, 31, 32]]}
40 |
41 |
42 | '''
43 | rhythm patterns.
44 | for Chinese quatrains, we generalize four main poem-level patterns
45 | '''
46 | self._RHYTHM_TYPES = [[0, 1, 3, 2], [1, 2, 0, 1], [2, 1, 3, 2], [3, 2, 0, 1]]
47 |
48 |
49 | self._rhythm_tool = RhythmRecognizer(data_dir+"pingsheng.txt", data_dir+"zesheng.txt")
50 |
51 | self._load_rhythm_dic(data_dir+"pingsheng.txt", data_dir+"zesheng.txt")
52 | self._load_rhyme_dic(data_dir+"pingshui.txt", data_dir+"pingshui_amb.pkl")
53 |
54 |
55 |
56 | def _load_rhythm_dic(self, level_path, oblique_path):
57 | with open(level_path, 'r') as fin:
58 | level_chars = fin.read()
59 |
60 | with open(oblique_path, 'r') as fin:
61 | oblique_chars = fin.read()
62 |
63 | self._level_list = list(level_chars)
64 | self._oblique_list = list(oblique_chars)
65 |
66 |
67 | print (" rhythm dic loaded, level tone chars: %d, oblique tone chars: %d" %\
68 | (len(self._level_list), len(self._oblique_list)))
69 |
70 |
71 | #------------------------------------------
72 | def _load_rhyme_dic(self, rhyme_dic_path, rhyme_disamb_path):
73 |
74 | self._rhyme_dic = {} # char id to rhyme category ids
75 | self._rhyme_idic = {} # rhyme category id to char ids
76 |
77 | with open(rhyme_dic_path, 'r') as fin:
78 | lines = fin.readlines()
79 |
80 | amb_count = 0
81 | for line in lines:
82 | (char, rhyme_id) = line.strip().split(' ')
83 |
84 | rhyme_id = int(rhyme_id)
85 |
86 | if not char in self._rhyme_dic:
87 | self._rhyme_dic.update({char:[rhyme_id]})
88 | elif not rhyme_id in self._rhyme_dic[char]:
89 | self._rhyme_dic[char].append(rhyme_id)
90 | amb_count += 1
91 |
92 | if not rhyme_id in self._rhyme_idic:
93 | self._rhyme_idic.update({rhyme_id:[char]})
94 | else:
95 | self._rhyme_idic[rhyme_id].append(char)
96 |
97 | print (" rhyme dic loaded, ambiguous rhyme chars: %d" % (amb_count))
98 |
99 | # load data for rhyme disambiguation
100 | self._ngram_rhyme_map = {} # rhyme id list of each bigram or trigram
101 | self._char_rhyme_map = {} # the most likely rhyme id for each char
102 | # load the calculated data, if there is any
103 | #print (rhyme_disamb_path)
104 | assert rhyme_disamb_path is not None and os.path.exists(rhyme_disamb_path)
105 |
106 | with open(rhyme_disamb_path, 'rb') as fin:
107 | self._char_rhyme_map = pickle.load(fin)
108 | self._ngram_rhyme_map = pickle.load(fin)
109 |
110 | print (" rhyme disamb data loaded, cached chars: %d, ngrams: %d"
111 | % (len(self._char_rhyme_map), len(self._ngram_rhyme_map)))
112 |
113 |
114 | def get_line_rhyme(self, line):
115 | """ we use statistics of ngram to disambiguate the rhyme category,
116 |         but there is still a risk of mismatching and ambiguity
117 | """
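        # Lookup order (a summary of the code below, added for readability):
        #   1. a tail char with exactly one rhyme category is returned directly;
        #   2. otherwise the cached bigram, then trigram, statistics are tried;
        #   3. otherwise the most likely category for the char is used;
        #   4. -1 is returned for unknown chars.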
118 | tail_char = line[-1]
119 |
120 | if tail_char in self._rhyme_dic:
121 | rhyme_candis = self._rhyme_dic[tail_char]
122 | if len(rhyme_candis) == 1:
123 | return rhyme_candis[0]
124 |
125 | if tail_char in self._char_rhyme_map:
126 | bigram = line[-2] + line[-1]
127 | if bigram in self._ngram_rhyme_map:
128 | return int(self._ngram_rhyme_map[bigram][0])
129 |
130 | trigram = line[-3] + line[-2] + line[-1]
131 | if trigram in self._ngram_rhyme_map:
132 | return int(self._ngram_rhyme_map[trigram][0])
133 |
134 | return int(self._char_rhyme_map[tail_char][0])
135 |
136 |
137 | return -1
138 |
139 |
140 | def get_poem_rhyme(self, sens):
141 | assert len(sens) == 4
142 | rhymes = [self.get_line_rhyme(sen) for sen in sens]
143 |
144 | #print (rhymes)
145 | #input(">")
146 |
147 | if rhymes[1] == -1 and rhymes[3] != -1:
148 | rhymes[1] = rhymes[3]
149 | elif rhymes[1] != -1 and rhymes[3] == -1:
150 | rhymes[3] = rhymes[1]
151 | elif rhymes[1] == -1 and rhymes[3] == -1:
152 | return []
153 |
154 |
155 | if (rhymes[0] != -1) and (rhymes[0]!= rhymes[1]):
156 | rhymes[0] = rhymes[1]
157 |
158 | rhymes[2] = -1
159 |
160 | return rhymes
161 |
162 |
163 | def pattern_complete(self, rhythm_ids):
164 | if rhythm_ids.count(-1) == 0:
165 | return rhythm_ids
166 |
167 | if rhythm_ids.count(-1) > 1:
168 | return []
169 |
170 | #print (rhythm_ids)
171 | for poem_pattern in self._RHYTHM_TYPES:
172 | eq = (np.array(poem_pattern) \
173 |                 == np.array(rhythm_ids)).astype(float)
174 |
175 | if np.sum(eq) != 3:
176 | continue
177 |
178 | pos = list(eq).index(0)
179 | rhythm_ids[pos] = poem_pattern[pos]
180 | #print (rhythm_ids)
181 | #input(">")
182 | return rhythm_ids
183 |
184 | return []
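    # Worked example (illustrative, not from the original code): given line-level
    # rhythm ids [0, 1, -1, 2], the template [0, 1, 3, 2] matches at the three
    # recognized positions, so the missing id is completed to 3; if two or more
    # lines are unrecognized, the poem is discarded ([]).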
185 |
186 |
187 | def get_poem_rhythm(self, sens, length):
188 | assert len(sens) == 4
189 |
190 | rhythm_ids = []
191 | for sen in sens:
192 | #print (sen)
193 | rhythm_id = self._rhythm_tool.get_rhythm(sen)
194 | rhythm_ids.append(rhythm_id)
195 |
196 |
197 | rhythm_ids = self.pattern_complete(rhythm_ids)
198 |
199 | if len(rhythm_ids) == 0:
200 | return []
201 |
202 | rhythm_pattern = [copy.deepcopy(self._RHYTHM_PATTERNS[length][id])
203 | for id in rhythm_ids]
204 |
205 | return rhythm_pattern
--------------------------------------------------------------------------------
/preprocess/rhythm_tool.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Ruoyu Li
3 | # @Description:
4 | '''
5 | Copyright 2019 THUNLP Lab. All Rights Reserved.
6 | This code is part of the online Chinese poetry generation system, Jiuge.
7 | System URL: https://jiuge.thunlp.cn/.
8 | Github: https://github.com/THUNLP-AIPoet.
9 | '''
10 | class RhythmRecognizer(object):
11 |     """Get the rhythm id of an input line.
12 |     This tool can be applied to Chinese classical quatrains only.
13 | """
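    # Return values (inferred from the matching rules below):
    #   0~3 : index of the matched line-level rhythm pattern;
    #   -1  : a 5- or 7-char line that matches no pattern;
    #   -2  : unsupported line length.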
14 |
15 | def __init__(self, ping_file, ze_file):
16 |
17 | # read level tone char list
18 | with open(ping_file, 'r') as fin:
19 | self.__ping = fin.read()
20 | #print (type(self.__ping))
21 |
22 | with open(ze_file, 'r') as fin:
23 | self.__ze = fin.read()
24 | #print (type(self.__ze))
25 |
26 | def get_rhythm(self, sentence):
27 | # print "#" + sentence + "#"
28 | if(len(sentence) == 5):
29 | #1
30 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze):
31 | return 0
32 | #2
33 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ze):
34 | return 0
35 | #3
36 | if(sentence[0] in self.__ze and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze):
37 | return 0
38 | #4
39 | if(sentence[0] in self.__ze and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze):
40 | return 0
41 | #5
42 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze):
43 | return 0
44 | #6
45 | if(sentence[0] in self.__ze and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping):
46 | return 1
47 | #7
48 | if(sentence[0] in self.__ping and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping):
49 | return 1
50 | #8
51 | if(sentence[0] in self.__ping and sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze):
52 | return 3
53 | #9
54 | if(sentence[0] in self.__ping and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze):
55 | return 3
56 | #10
57 | if(sentence[0] in self.__ze and sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze):
58 | return 3
59 | #11
60 | if(sentence[0] in self.__ze and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze):
61 | return 3
62 | #12
63 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ping):
64 | return 2
65 | #13
66 | if(sentence[0] in self.__ze and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ping):
67 | return 2
68 | #14
69 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ping):
70 | return 2
71 |
72 |
73 | elif (len(sentence) == 7):
74 | #1
75 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ze):
76 | return 0
77 | #2
78 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ze and sentence[6] in self.__ze):
79 | return 0
80 | #3
81 | if(sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ze):
82 | return 0
83 | #4
84 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze):
85 | return 0
86 | #5
87 | if(sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze):
88 | return 0
89 | #6
90 | if(sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ping):
91 | return 1
92 | #7
93 | if(sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ping):
94 | return 1
95 | #8
96 | if(sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ping and sentence[5] in self.__ping and sentence[6] in self.__ze):
97 | return 3
98 | #9
99 | if(sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze):
100 | return 3
101 | #10
102 | if(sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ping and sentence[5] in self.__ping and sentence[6] in self.__ze):
103 | return 3
104 | #11
105 | if(sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze):
106 | return 3
107 | #12
108 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ze and sentence[6] in self.__ping):
109 | return 2
110 | #13
111 | if(sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ping):
112 | return 2
113 | #14
114 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ping):
115 | return 2
116 | else:
117 | return -2
118 | return -1
--------------------------------------------------------------------------------
/codes/wm_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 20:39:43
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | import torch
14 | from scheduler import ISRScheduler
15 | from criterion import Criterion
16 | from decay import ExponentialDecay
17 | from logger import SimpleLogger
18 | import utils
19 |
20 |
21 | class WMTrainer(object):
22 |
23 | def __init__(self, hps, device):
24 | self.hps = hps
25 | self.device = device
26 |
27 | def run_validation(self, epoch, wm_model, criterion, tool, lr):
28 | logger = SimpleLogger('valid')
29 | logger.set_batch_num(tool.valid_batch_num)
30 | logger.set_log_path(self.hps.valid_log_path)
31 | logger.set_rate('learning_rate', lr)
32 | logger.set_rate('teach_ratio', wm_model.get_teach_ratio())
33 | logger.set_rate('temperature', wm_model.get_tau())
34 |
35 | for step in range(0, tool.valid_batch_num):
36 |
37 | batch = tool.valid_batches[step]
38 |
39 | all_inps = [inps.to(self.device) for inps in batch[0]]
40 | all_trgs = [trgs.to(self.device) for trgs in batch[1]]
41 | all_ph_inps = [ph_inps.to(self.device) for ph_inps in batch[2]]
42 | all_len_inps = [len_inps.to(self.device) for len_inps in batch[3]]
43 | keys = [key.to(self.device) for key in batch[4]]
44 |
45 | with torch.no_grad():
46 | gen_loss, _ = self.run_step(wm_model, None, criterion,
47 | all_inps, all_trgs, all_ph_inps, all_len_inps, keys, True)
48 |
49 | logger.add_losses(gen_loss)
50 |
51 | logger.print_log(epoch)
52 |
53 |
54 | def run_step(self, wm_model, optimizer, criterion,
55 | all_inps, all_trgs, all_ph_inps, all_len_inps, keys, valid=False):
56 |
57 | if not valid:
58 | optimizer.zero_grad()
59 |
60 | all_outs = wm_model(all_inps, all_trgs, all_ph_inps, all_len_inps, keys)
61 |
62 | loss_vec = []
63 | assert len(all_outs) == len(all_trgs)
64 | for out, trg in zip(all_outs, all_trgs):
65 | loss = criterion(out, trg)
66 | loss_vec.append(loss.unsqueeze(0))
67 |
68 | loss = torch.cat(loss_vec, dim=0).mean()
69 |
70 | if not valid:
71 | loss.backward()
72 | torch.nn.utils.clip_grad_norm_(wm_model.parameters(),
73 | self.hps.clip_grad_norm)
74 | optimizer.step()
75 |
76 | return loss.item(), all_outs
77 |
78 | # -------------------------------------------------------------------------
79 | def run_train(self, wm_model, tool, optimizer, criterion, logger):
80 |
81 | logger.set_start_time()
82 |
83 | for step in range(0, tool.train_batch_num):
84 |
85 | batch = tool.train_batches[step]
86 | all_inps = [inps.to(self.device) for inps in batch[0]]
87 | all_trgs = [trgs.to(self.device) for trgs in batch[1]]
88 | all_ph_inps = [ph_inps.to(self.device) for ph_inps in batch[2]]
89 | all_len_inps = [len_inps.to(self.device) for len_inps in batch[3]]
90 | keys = [key.to(self.device) for key in batch[4]]
91 |
92 | # train the classifier, recognition network and decoder
93 | gen_loss, all_outs = self.run_step(wm_model, optimizer, criterion,
94 | all_inps, all_trgs, all_ph_inps, all_len_inps, keys)
95 |
96 | logger.add_losses(gen_loss)
97 | logger.set_rate("learning_rate", optimizer.rate())
98 |
99 | # temperature annealing
100 | wm_model.set_tau(self.tau_decay_tool.do_step())
101 | logger.set_rate('temperature', self.tau_decay_tool.get_rate())
102 |
103 | if step % self.hps.log_steps == 0:
104 | logger.set_end_time()
105 | utils.sample_wm(keys, all_trgs, all_outs, self.hps.sample_num, tool)
106 | logger.print_log()
107 | logger.set_start_time()
108 |
109 |
110 |
111 | def train(self, wm_model, tool):
112 | #utils.print_parameter_list(wm_model)
113 | # load data for pre-training
114 | print ("building data for wm...")
115 | tool.build_data(self.hps.train_data, self.hps.valid_data,
116 | self.hps.batch_size, mode='wm')
117 |
118 | print ("train batch num: %d" % (tool.train_batch_num))
119 | print ("valid batch num: %d" % (tool.valid_batch_num))
120 |
121 |
122 | #input("please check the parameters, and then press any key to continue >")
123 |
124 | # training logger
125 | logger = SimpleLogger('train')
126 | logger.set_batch_num(tool.train_batch_num)
127 | logger.set_log_steps(self.hps.log_steps)
128 | logger.set_log_path(self.hps.train_log_path)
129 | logger.set_rate('learning_rate', 0.0)
130 | logger.set_rate('teach_ratio', 1.0)
131 | logger.set_rate('temperature', 1.0)
132 |
133 |
134 | # build optimizer
135 | opt = torch.optim.AdamW(wm_model.parameters(),
136 | lr=1e-3, betas=(0.9, 0.99), weight_decay=self.hps.weight_decay)
137 | optimizer = ISRScheduler(optimizer=opt, warmup_steps=self.hps.warmup_steps,
138 | max_lr=self.hps.max_lr, min_lr=self.hps.min_lr,
139 | init_lr=self.hps.init_lr, beta=0.6)
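        # NOTE (assumption, not taken from the original docs): ISRScheduler is the
        # repo's learning-rate wrapper, presumably an inverse-square-root schedule;
        # as used in this file it exposes zero_grad(), step() and rate(), warming
        # the rate up towards max_lr over warmup_steps and then decaying it,
        # bounded by init_lr / min_lr.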
140 |
141 | wm_model.train()
142 |
143 | null_idxes = tool.load_function_tokens(self.hps.data_dir + "fchars.txt").to(self.device)
144 | wm_model.set_null_idxes(null_idxes)
145 |
146 | criterion = Criterion(self.hps.pad_idx)
147 |
148 | # change each epoch
149 | tr_decay_tool = ExponentialDecay(self.hps.burn_down_tr, self.hps.decay_tr, self.hps.min_tr)
150 | # change each iteration
151 | self.tau_decay_tool = ExponentialDecay(0, self.hps.tau_annealing_steps, self.hps.min_tau)
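        # The tau schedule anneals the Gumbel-softmax temperature used by the
        # memory-writing module (AttentionWriter.set_tau in layers.py): a lower
        # temperature pushes the slot selection towards a hard one-hot choice.
        # The exact decay curve is whatever ExponentialDecay implements.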
152 |
153 |
154 | # -----------------------------------------------------------
155 | # train with all data
156 | for epoch in range(1, self.hps.max_epoches+1):
157 |
158 | self.run_train(wm_model, tool, optimizer, criterion, logger)
159 |
160 | if epoch % self.hps.validate_epoches == 0:
161 | print("run validation...")
162 | wm_model.eval()
163 | print ("in training mode: %d" % (wm_model.training))
164 | self.run_validation(epoch, wm_model, criterion, tool, optimizer.rate())
165 | wm_model.train()
166 | print ("validation Done: %d" % (wm_model.training))
167 |
168 |
169 | if (self.hps.save_epoches >= 1) and \
170 | (epoch % self.hps.save_epoches) == 0:
171 | # save checkpoint
172 | print("saving model...")
173 | utils.save_checkpoint(self.hps.model_dir, epoch, wm_model, prefix="wm")
174 |
175 |
176 | logger.add_epoch()
177 |
178 |             print ("teacher forcing ratio decay...")
179 | wm_model.set_teach_ratio(tr_decay_tool.do_step())
180 | logger.set_rate('teach_ratio', tr_decay_tool.get_rate())
181 |
182 | print("shuffle data...")
183 | tool.shuffle_train_data()
--------------------------------------------------------------------------------
/codes/filter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi and Jiannan Liang
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 18:22:16
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | import pickle
14 | import copy
15 | import os
16 |
17 | class PoetryFilter(object):
18 |
19 |
20 | def __init__(self, vocab, ivocab, data_dir):
21 | self._vocab = vocab
22 | self._ivocab = ivocab
23 |
24 | self._load_rhythm_dic(data_dir+"pingsheng.txt", data_dir+"zesheng.txt")
25 | self._load_rhyme_dic(data_dir+"pingshui.txt", data_dir+"pingshui_amb.pkl")
26 | self._load_line_lib(data_dir+"training_lines.txt")
27 |
28 |
29 | def _load_line_lib(self, data_path):
30 | self._line_lib = {}
31 |
32 | with open(data_path, 'r') as fin:
33 | lines = fin.readlines()
34 |
35 | for line in lines:
36 | line = line.strip()
37 | self._line_lib[line] = 1
38 |
39 |         print (" line lib loaded, %d lines" % (len(self._line_lib)))
40 |
41 |
42 |
43 | def _load_rhythm_dic(self, level_path, oblique_path):
44 | with open(level_path, 'r') as fin:
45 | level_chars = fin.read()
46 |
47 | with open(oblique_path, 'r') as fin:
48 | oblique_chars = fin.read()
49 |
50 | self._level_list = []
51 | self._oblique_list = []
52 | # convert char to id
53 | for char, idx in self._vocab.items():
54 | if char in level_chars:
55 | self._level_list.append(idx)
56 |
57 | if char in oblique_chars:
58 | self._oblique_list.append(idx)
59 |
60 | print (" rhythm dic loaded, level tone chars: %d, oblique tone chars: %d" %\
61 | (len(self._level_list), len(self._oblique_list)))
62 |
63 |
64 | #------------------------------------------
65 | def _load_rhyme_dic(self, rhyme_dic_path, rhyme_disamb_path):
66 |
67 | self._rhyme_dic = {} # char id to rhyme category ids
68 | self._rhyme_idic = {} # rhyme category id to char ids
69 |
70 | with open(rhyme_dic_path, 'r') as fin:
71 | lines = fin.readlines()
72 |
73 |
74 | amb_count = 0
75 | for line in lines:
76 | (char, rhyme_id) = line.strip().split(' ')
77 | if char not in self._vocab:
78 | continue
79 | char_id = self._vocab[char]
80 | rhyme_id = int(rhyme_id)
81 |
82 | if not char_id in self._rhyme_dic:
83 | self._rhyme_dic.update({char_id:[rhyme_id]})
84 | elif not rhyme_id in self._rhyme_dic[char_id]:
85 | self._rhyme_dic[char_id].append(rhyme_id)
86 |
87 |
88 | if not rhyme_id in self._rhyme_idic:
89 | self._rhyme_idic.update({rhyme_id:[char_id]})
90 | else:
91 | self._rhyme_idic[rhyme_id].append(char_id)
92 |
93 |
94 | # load data for rhyme disambiguation
95 | self._ngram_rhyme_map = {} # rhyme id list of each bigram or trigram
96 | self._char_rhyme_map = {} # the most likely rhyme id for each char
97 | # load the calculated data, if there is any
98 | #print (rhyme_disamb_path)
99 | assert rhyme_disamb_path is not None and os.path.exists(rhyme_disamb_path)
100 |
101 | with open(rhyme_disamb_path, 'rb') as fin:
102 | self._char_rhyme_map = pickle.load(fin)
103 | self._ngram_rhyme_map = pickle.load(fin)
104 |
105 | print (" rhyme disamb data loaded, cached chars: %d, ngrams: %d"
106 | % (len(self._char_rhyme_map), len(self._ngram_rhyme_map)))
107 |
108 |
109 |
110 | def get_line_rhyme(self, line):
111 | """ we use statistics of ngram to disambiguate the rhyme category,
112 | but there is still a risk of mismatching and ambiguity
113 | """
114 | tail_char = line[-1]
115 |
116 | if tail_char in self._char_rhyme_map:
117 | rhyme_candis = self._char_rhyme_map[tail_char]
118 | if len(rhyme_candis) == 1:
119 | return rhyme_candis[0]
120 |
121 | bigram = line[-2] + line[-1]
122 | if bigram in self._ngram_rhyme_map:
123 | return self._ngram_rhyme_map[bigram][0]
124 |
125 | trigram = line[-3] + line[-2] + line[-1]
126 | if trigram in self._ngram_rhyme_map:
127 | return self._ngram_rhyme_map[trigram][0]
128 |
129 | return self._char_rhyme_map[tail_char][0]
130 |
131 |
132 | if not tail_char in self._vocab:
133 | return -1
134 | else:
135 | tail_id = self._vocab[tail_char]
136 |
137 |
138 | if tail_id in self._rhyme_dic:
139 | return self._rhyme_dic[tail_id][0]
140 |
141 | return -1
142 |
143 | # ------------------------------
144 | def reset(self, length, verbose):
145 | assert length == 5 or length == 7
146 | self._length = length
147 | self._repetitive_ids = []
148 | self._verbose = verbose
149 |
150 |
151 | def add_repetitive(self, ids):
152 | self._repetitive_ids = list(set(ids+self._repetitive_ids))
153 |
154 |
155 | # -------------------------------
156 | def get_level_cids(self):
157 | return copy.deepcopy(self._level_list)
158 |
159 | def get_oblique_cids(self):
160 | return copy.deepcopy(self._oblique_list)
161 |
162 | def get_rhyme_cids(self, rhyme_id):
163 | if rhyme_id not in self._rhyme_idic:
164 | return []
165 | else:
166 | return copy.deepcopy(self._rhyme_idic[rhyme_id])
167 |
168 | def get_repetitive_ids(self):
169 | return copy.deepcopy(self._repetitive_ids)
170 |
171 |
172 | def filter_illformed(self, lines, costs, states, aligns, rhyme_id):
173 | if len(lines) == 0:
174 | return [], [], [], []
175 |
176 | new_lines, new_costs = [], []
177 | new_states, new_aligns = [], []
178 |
179 | len_error, lib_error, rhyme_error = 0, 0, 0
180 |
181 | for i in range(len(lines)):
182 | #print (lines[i])
183 | if len(lines[i]) < self._length:
184 | len_error += 1
185 | continue
186 |
187 | line = lines[i][0:self._length]
188 |
189 | # we filter out the lines that already exist in the
190 | # training set, to guarantee the novelty of generated poems
191 | if line in self._line_lib:
192 | lib_error += 1
193 | continue
194 |
195 | if 1 <= rhyme_id <= 30:
196 | if self.get_line_rhyme(line) != rhyme_id:
197 |                     rhyme_error += 1
198 | continue
199 |
200 | new_lines.append(line)
201 | new_costs.append(costs[i])
202 | new_states.append(states[i])
203 | new_aligns.append(aligns[i])
204 |
205 |
206 | if self._verbose >= 3:
207 |             print ("input lines: %d, filtered out %d ill-formed lines, %d remain"
208 | % (len(lines), len(lines)-len(new_lines), len(new_lines)))
209 | print ("%d len error, %d exist in lib, %d rhyme error"
210 | % (len_error, lib_error, rhyme_error))
211 |
212 | return new_lines, new_costs, new_states, new_aligns
--------------------------------------------------------------------------------
/codes/rhythm_tool.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Ruoyu Li
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 18:21:37
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 |
14 | class RhythmRecognizer(object):
15 |     """Get the rhythm id of an input line.
16 |     This tool can be applied to Chinese classical quatrains only.
17 | """
18 |
19 | def __init__(self, ping_file, ze_file):
20 |
21 | # read level tone char list
22 | with open(ping_file, 'r') as fin:
23 | self.__ping = fin.read()
24 | #print (type(self.__ping))
25 |
26 | with open(ze_file, 'r') as fin:
27 | self.__ze = fin.read()
28 | #print (type(self.__ze))
29 |
30 | def get_rhythm(self, sentence):
31 | # print "#" + sentence + "#"
32 | if(len(sentence) == 5):
33 | #1
34 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze):
35 | return 0
36 | #2
37 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ze):
38 | return 0
39 | #3
40 | if(sentence[0] in self.__ze and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze):
41 | return 0
42 | #4
43 | if(sentence[0] in self.__ze and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze):
44 | return 0
45 | #5
46 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze):
47 | return 0
48 | #6
49 | if(sentence[0] in self.__ze and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping):
50 | return 1
51 | #7
52 | if(sentence[0] in self.__ping and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping):
53 | return 1
54 | #8
55 | if(sentence[0] in self.__ping and sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze):
56 | return 3
57 | #9
58 | if(sentence[0] in self.__ping and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze):
59 | return 3
60 | #10
61 | if(sentence[0] in self.__ze and sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze):
62 | return 3
63 | #11
64 | if(sentence[0] in self.__ze and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze):
65 | return 3
66 | #12
67 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ping):
68 | return 2
69 | #13
70 | if(sentence[0] in self.__ze and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ping):
71 | return 2
72 | #14
73 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ping):
74 | return 2
75 |
76 |
77 | elif (len(sentence) == 7):
78 | #1
79 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ze):
80 | return 0
81 | #2
82 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ze and sentence[6] in self.__ze):
83 | return 0
84 | #3
85 | if(sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ze):
86 | return 0
87 | #4
88 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze):
89 | return 0
90 | #5
91 | if(sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze):
92 | return 0
93 | #6
94 | if(sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ping):
95 | return 1
96 | #7
97 | if(sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ping):
98 | return 1
99 | #8
100 | if(sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ping and sentence[5] in self.__ping and sentence[6] in self.__ze):
101 | return 3
102 | #9
103 | if(sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze):
104 | return 3
105 | #10
106 | if(sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ping and sentence[5] in self.__ping and sentence[6] in self.__ze):
107 | return 3
108 | #11
109 | if(sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze):
110 | return 3
111 | #12
112 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ze and sentence[6] in self.__ping):
113 | return 2
114 | #13
115 | if(sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ping):
116 | return 2
117 | #14
118 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ping):
119 | return 2
120 | else:
121 | return -2
122 | return -1
123 |
--------------------------------------------------------------------------------
/codes/beam.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 18:04:40
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | import numpy as np
14 | import torch
15 | import copy
16 |
17 |
18 | class Hypothesis(object):
19 | '''
20 | a hypothesis which holds the generated tokens,
21 | current state, beam score and memory reading weights
22 | '''
23 | def __init__(self, tokens, states, score, read_aligns):
24 | self.score = score
25 | self.states = states
26 | self.candidate = copy.deepcopy(tokens)
27 | self.read_aligns = read_aligns # (1, L_i, mem_slots)
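        # Note (added for clarity): `score` accumulates the negative log-probability
        # of the generated tokens, so smaller is better; PoetryBeam later
        # length-normalizes these scores and sorts them in ascending order.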
28 |
29 |
30 | class PoetryBeam(object):
31 | def __init__(self, device, beam_size, length, B_ID, E_ID, UNK_ID,
32 | level_char_ids, oblique_char_ids):
33 | """Initialize params."""
34 | self.device = device
35 |
36 | self._length = length
37 | self._beam_size = beam_size
38 |
39 | self._B_ID = B_ID
40 | self._E_ID = E_ID
41 | self._UNK_ID = UNK_ID
42 |
43 | self._level_cids = level_char_ids
44 | self._oblique_cids = oblique_char_ids
45 |
46 |
47 | def reset(self, init_state, rhythms, rhyme_char_ids, repetitive_ids):
48 | # reset before generating each line
49 | self._hypotheses \
50 | = [Hypothesis([self._B_ID], [init_state.clone().detach()], 0.0, [])
51 | for _ in range(0, self._beam_size)]
52 |
53 | self._completed_hypotheses = []
54 |
55 | self._rhythms = rhythms # rhythm pattern of each chars in a line
56 | self._rhyme_cids = rhyme_char_ids # char ids in the required rhyme category
57 | self._repetitive_ids = repetitive_ids
58 |
59 |
60 | def get_candidates(self, completed=False, with_states=False):
61 | if completed:
62 | hypotheses = self._completed_hypotheses
63 | else:
64 | hypotheses = self._hypotheses
65 |
66 | candidates = [hypo.candidate for hypo in hypotheses]
67 | scores = [hypo.score for hypo in hypotheses]
68 |
69 | read_aligns = [hypo.read_aligns for hypo in hypotheses]
70 |
71 | if with_states:
72 | # (L, H) * B
73 | all_states = [hypo.states for hypo in hypotheses]
74 | return candidates, scores, read_aligns, all_states
75 |
76 | else:
77 | return candidates, scores, read_aligns
78 |
79 |
80 | def get_search_results(self, only_completed=True, sort=True):
81 | candidates, scores, aligns, states = self.get_candidates(True, True)
82 |
83 | if not only_completed:
84 | add_candis, add_scores, add_aligns, add_states = self.get_candidates(False, True)
85 | candidates = candidates + add_candis
86 | scores = scores + add_scores
87 | states = states + add_states
88 | aligns = aligns + add_aligns
89 |
90 |
91 | scores = [score/(len(candi)-1) for score, candi in zip(scores, candidates)]
92 |
93 | # sort with costs
94 | if sort:
95 | sort_indices = list(np.argsort(scores))
96 | candidates = [candidates[i] for i in sort_indices]
97 | scores = [scores[i] for i in sort_indices]
98 | states = [states[i] for i in sort_indices]
99 | aligns = [aligns[i] for i in sort_indices]
100 |
101 | # ignore the bos symbol and initial state
102 | candidates = [candi[1: ] for candi in candidates]
103 | states = [ state[1:] for state in states ]
104 |
105 | return candidates, scores, states, aligns
106 |
107 |
108 | def get_beam_tails(self):
109 | # get the last token and state of each hypothesis
110 | tokens = [hypo.candidate[-1] for hypo in self._hypotheses]
111 | # (B)
112 | tail_tokens = torch.tensor(tokens, dtype=torch.long, device=self.device)
113 |
114 | tail_states = [hypo.states[-1] for hypo in self._hypotheses]
115 | # [1, H] * B -> [B, H]
116 | tail_states = torch.cat(tail_states, dim=0)
117 |
118 | return tail_tokens, tail_states
119 |
120 |
121 | def uncompleted_num(self):
122 | return len(self._hypotheses)
123 |
124 |
125 | def advance(self, logit, state, read_align, position):
126 | # logit: (B, V)
127 | # state: (B, H)
128 | # read_align: (B, 1, mem_slots)
129 | log_prob = torch.nn.functional.log_softmax(logit, dim=-1).cpu().data.numpy()
130 |
131 | beam_ids, word_ids, scores = self._beam_select(log_prob, position)
132 |
133 | # update beams
134 | updated_hypotheses = []
135 | for beam_id, word_id, score in zip(beam_ids, word_ids, scores):
136 | # (1, H)
137 | new_states = self._hypotheses[beam_id].states + [state[beam_id, :].unsqueeze(0)]
138 |
139 | new_candidate = self._hypotheses[beam_id].candidate + [word_id]
140 |
141 | new_aligns = self._hypotheses[beam_id].read_aligns + \
142 | [read_align[beam_id, :, :].unsqueeze(0)]
143 |
144 | hypo = Hypothesis(new_candidate, new_states, score, new_aligns)
145 |
146 | if word_id == self._E_ID:
147 | self._completed_hypotheses.append(hypo)
148 | else:
149 | updated_hypotheses.append(hypo)
150 |
151 | self._hypotheses = updated_hypotheses
152 |
153 |
154 | def _beam_select(self, log_probs, position):
155 | # log_probs: (B, V)
156 | B, V = log_probs.shape[0], log_probs.shape[1]
157 |
158 |
159 | if position == 0:
160 | costs = - log_probs[0, :].reshape(1, V) # (1, V)
161 | else:
162 | current_scores = [hypo.score for hypo in self._hypotheses]
163 | costs = np.reshape(current_scores, (B, 1)) - log_probs # (B, V)
164 |
165 | # filter with rhythm, rhyme and length
166 | # candidates that don't meet requirements are assigned a large cost
167 | filter_v = 1e5
168 |
169 | costs[:, self._UNK_ID] = filter_v
170 |
171 | # filter eos symbol
172 | if position < self._length:
173 | costs[:, self._E_ID] = filter_v
174 |
175 | # restrain the model from generating chars
176 |         # that were already generated in previous lines
177 | costs[:, self._repetitive_ids] = filter_v
178 |
179 | # restrain in-line repetitive chars
180 | inline_filter_ids = self.inline_filter(position)
181 | for i in range(0, costs.shape[0]):
182 | costs[i, inline_filter_ids[i]] = filter_v
183 |
184 |
185 | # for the tail char, filter out non-rhyme chars
186 | if (position == self._length-1) and (1 <= self._rhythms[-1] <= 30):
187 | filter_ids = list(set(range(0, V)) - set(self._rhyme_cids))
188 | costs[:, filter_ids] = filter_v
189 |
190 | '''
191 | filter out chars of undesired tones
192 | NOTE: since some Chinese characters may belong to both tones,
193 | here we only consider the non-overlap ones
194 | TODO: disambiguation
195 | '''
196 | pos_rhythm = self._rhythms[position]
197 | if position < self._length and pos_rhythm != 0:
198 | if pos_rhythm == 31: # level tone
199 | costs[:, self._oblique_cids] = filter_v
200 | elif pos_rhythm == 32: # oblique
201 | costs[:, self._level_cids] = filter_v
202 |
203 | flat_costs = costs.flatten() # (B*V)
204 |
205 | # idx of the smallest B elements
206 | best_indices = np.argpartition(
207 | flat_costs, B)[0:B].copy()
208 |
209 | scores = flat_costs[best_indices]
210 |
211 | # get beam id and word id
212 | beam_ids = [int(idx // V) for idx in best_indices]
213 | word_ids = [int(idx % V) for idx in best_indices]
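        # Example of the flat-index arithmetic above (illustrative numbers only):
        # with B = 2 beams and V = 5 words, flat index 7 maps to
        # beam_id = 7 // 5 = 1 and word_id = 7 % 5 = 2.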
214 |
215 | if position == 0:
216 | beam_ids = list(range(0, B))
217 |
218 | return beam_ids, word_ids, scores
219 |
220 |
221 | def inline_filter(self, pos):
222 | candidates, _, _ = self.get_candidates()
223 | # candidates: (L_i) * B
224 | B = len(candidates)
225 | forbidden_list = [[] for _ in range(0, B)]
226 |
227 | limit_pos = pos - 1 if pos % 2 != 0 else pos
228 | preidx = range(0, limit_pos)
229 |
230 |         for i in range(0, B): # iterate over each beam
231 | forbidden_list[i] = [candidates[i][c] for c in preidx]
232 |
233 | return forbidden_list
--------------------------------------------------------------------------------
/preprocess/preprocess.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 20:25:32
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | import pickle
14 | import json
15 | import random
16 |
17 | from pattern_extractor import PatternExtractor
18 |
19 | def outFile(data, file_name):
20 | print ("output data to %s, num: %d" % (file_name, len(data)))
21 | with open(file_name, 'w') as fout:
22 | for d in data:
23 | fout.write(d+"\n")
24 |
25 |
26 | class PreProcess(object):
27 |     """A tool for data preprocessing.
28 | Please note that this tool is only for Chinese quatrains.
29 | """
30 | def __init__(self):
31 | super(PreProcess, self).__init__()
32 | self.min_freq = 1
33 | self.sens_num = 4 # sens_num must be 4
34 | self.key_num = 4 # max number of keywords
35 |
36 | self.extractor = PatternExtractor("../data/")
37 |
38 |
39 | def create_dic(self, poems):
40 | print ("creating the word dictionary...")
41 | print ("input poems: %d" % (len(poems)))
42 | count_dic = {}
43 | for p in poems:
44 | poem = p.strip().replace("|", "")
45 |
46 | for c in poem:
47 | if c in count_dic:
48 | count_dic[c] += 1
49 | else:
50 | count_dic[c] = 1
51 |
52 | vec = sorted(count_dic.items(), key=lambda d:d[1], reverse=True)
53 | print ("original word num:%d" % (len(vec)))
54 |
55 | # add special symbols
56 | # --------------------------------------
57 | dic = {}
58 | idic = {}
59 | dic['PAD'] = 0
60 | idic[0] = 'PAD'
61 |
62 | dic['UNK'] = 1
63 | idic[1] = 'UNK'
64 |
65 |         dic['<B>'] = 2    # begin-of-sequence symbol; its token appears stripped in this dump, '<B>' is assumed
66 |         idic[2] = '<B>'
67 | 
68 |         dic['<E>'] = 3    # end-of-sequence symbol; '<E>' likewise assumed
69 |         idic[3] = '<E>'
70 |
71 |
72 | idx = 4
73 | print ("min freq:%d" % (self.min_freq))
74 |
75 | for c, v in vec:
76 | if v < self.min_freq:
77 | continue
78 | if not c in dic:
79 | dic[c] = idx
80 | idic[idx] = c
81 | idx += 1
82 |
83 | print ("total word num: %s" % (len(dic)))
84 |
85 | return dic, idic
86 |
87 |
88 | def build_dic(self, infile):
89 | with open(infile, 'r') as fin:
90 | lines = fin.readlines()
91 |
92 | poems = []
93 | training_lines = []
94 | for line in lines:
95 | dic = json.loads(line.strip())
96 | poem = dic['content']
97 | poems.append(poem)
98 | training_lines.extend(poem.split("|"))
99 |
100 | dic, idic = self.create_dic(poems)
101 | self.dic = dic
102 | self.idic = idic
103 |
104 |         # output dic files
105 |         # (pickled here; loaded back by the training / generation tools)
106 | dic_file = "vocab.pickle"
107 | idic_file = "ivocab.pickle"
108 |
109 | print ("saving dictionary to %s" % (dic_file))
110 | with open(dic_file, 'wb') as fout:
111 | pickle.dump(dic, fout, -1)
112 |
113 |
114 |         print ("saving inverted dictionary to %s" % (idic_file))
115 | with open(idic_file, 'wb') as fout:
116 | pickle.dump(idic, fout, -1)
117 |
118 |
119 | # building training lines
120 | outFile(training_lines, "training_lines.txt")
121 |
122 |
123 |
124 | def line2idxes(self, line):
125 | chars = [c for c in line]
126 | idxes = []
127 | for c in chars:
128 | if c in self.dic:
129 | idx = self.dic[c]
130 | else:
131 | idx = self.dic['UNK']
132 | idxes.append(idx)
133 |
134 | return idxes
135 |
136 |
137 |
138 | def read_corpus(self, infile):
139 | with open(infile, 'r') as fin:
140 | lines = fin.readlines()
141 |
142 | corpus = []
143 | for line in lines:
144 | dic = json.loads(line.strip())
145 | poem = dic['content'].strip()
146 | keywords = dic['keywords'].strip().split(" ")
147 |
148 | corpus.append((keywords, poem))
149 |
150 | return corpus
151 |
152 |
153 | def build_pattern(self, sens):
154 | length = len(sens[0])
155 | assert length == 5 or length == 7
156 |
157 | rhymes = self.extractor.get_poem_rhyme(sens)
158 | if len(rhymes) == 0:
159 | return ""
160 |
161 |
162 | rhythm_pattern = self.extractor.get_poem_rhythm(sens, length)
163 | if len(rhythm_pattern) == 0:
164 | return ""
165 |
166 | for i in range(0, len(sens)):
167 | if rhymes[i] >= 1:
168 | rhythm_pattern[i][-1] = rhymes[i]
169 |
170 | return rhythm_pattern
171 |
172 |
173 | def build_data(self, corpus, convert_to_indices=True):
174 | skip_count = 0
175 |
176 | data = []
177 | for keywords, poem in corpus:
178 | lines = poem.strip().split("|")
179 |
180 | if len(keywords) == 0:
181 | skip_count += 1
182 | continue
183 |
184 | keywords = keywords[0:self.key_num]
185 |
186 | if len(lines) != 4:
187 | skip_count += 1
188 | continue
189 |
190 |
191 | lens = [len(line) for line in lines]
192 |
193 | if not (lens[0] == lens[1] == lens[2] == lens[3]):
194 | skip_count += 1
195 | continue
196 |
197 |
198 | length = lens[0]
199 | # only for Chinese quatrains
200 | if length != 5 and length != 7:
201 | skip_count += 1
202 | continue
203 |
204 |
205 | pattern = self.build_pattern(lines)
206 | if len(pattern) == 0:
207 | skip_count += 1
208 | continue
209 |
210 | if not convert_to_indices:
211 | for keynum in range(1, len(keywords)+1):
212 | tup = (random.sample(keywords, keynum),
213 | lines, keynum, pattern)
214 | data.append(tup)
215 | continue
216 |
217 | # poem to indices
218 | line_idxes_vec = []
219 | for line in lines:
220 | idxes = self.line2idxes(line)
221 | assert len(idxes) == 5 or len(idxes) == 7
222 | line_idxes_vec.append(idxes)
223 |
224 | assert len(line_idxes_vec) == 4
225 |
226 | # keywords to indices
227 | key_idxes_vec = [self.line2idxes(keyword) for keyword in keywords]
228 |
229 |
230 | for keynum in range(1, len(key_idxes_vec)+1):
231 | tup = (random.sample(key_idxes_vec, keynum),
232 | line_idxes_vec, keynum, pattern)
233 | data.append(tup)
234 |
235 | print ("data num: %d, skip_count: %d" %\
236 | (len(data), skip_count))
237 |
238 | return data
239 |
240 |
241 | def build_test_data(self, infile, out_inp_file, out_trg_file):
242 | with open(infile, 'r') as fin:
243 | lines = fin.readlines()
244 |
245 | test = self.read_corpus(infile)
246 | test_data = self.build_data(test, False)
247 |
248 | inps, trgs = [], []
249 | for tup in test_data:
250 | keywords = " ".join(tup[0])
251 | poem = "|".join(tup[1])
252 | pattern = "|".join([" ".join(map(str, p)) for p in tup[3]])
253 |
254 | inps.append(keywords+"#"+pattern)
255 | trgs.append(poem)
256 |
257 |
258 | outFile(inps, out_inp_file)
259 | outFile(trgs, out_trg_file)
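        # Resulting line formats (illustrative, with made-up keywords):
        #   inputs : "keyword1 keyword2#0 32 0 31 31 32 7|0 31 0 32 32 31 7|..."
        #            i.e. space-joined keywords, '#', then per-line rhythm patterns
        #            joined by '|', each a space-joined list of ids;
        #   targets: the reference poem with its four lines joined by '|'.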
260 |
261 |
262 |
263 | def process(self):
264 | # build the word dictionary
265 | self.build_dic("ccpc_train_v1.0.json")
266 |
267 | # build training and validation datasets
268 | train = self.read_corpus("ccpc_train_v1.0.json")
269 | valid = self.read_corpus("ccpc_valid_v1.0.json")
270 |
271 | train_data = self.build_data(train)
272 | valid_data = self.build_data(valid)
273 |
274 | random.shuffle(train_data)
275 | random.shuffle(valid_data)
276 |
277 | print ("training data: %d" % (len(train_data)))
278 | print ("validation data: %d" % (len(valid_data)))
279 |
280 | train_file = "train_data.pickle"
281 | print ("saving training data to %s" % (train_file))
282 | with open(train_file, 'wb') as fout:
283 | pickle.dump(train_data, fout, -1)
284 |
285 |
286 | valid_file = "valid_data.pickle"
287 | print ("saving validation data to %s" % (valid_file))
288 | with open(valid_file, 'wb') as fout:
289 | pickle.dump(valid_data, fout, -1)
290 |
291 | # build testing inputs and trgs
292 | self.build_test_data("ccpc_test_v1.0.json", "test_inps.txt", "test_trgs.txt")
293 |
294 |
295 |
296 | def main():
297 | processor = PreProcess()
298 | processor.process()
299 |
300 |
301 |
302 | if __name__ == "__main__":
303 | main()
--------------------------------------------------------------------------------
/codes/layers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 20:08:55
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | import math
14 | from collections import OrderedDict
15 |
16 | import torch
17 | from torch import nn
18 | import torch.nn.functional as F
19 |
20 |
21 | class BidirEncoder(nn.Module):
22 | def __init__(self, input_size, hidden_size, cell='GRU', n_layers=1, drop_ratio=0.1):
23 | super(BidirEncoder, self).__init__()
24 |
25 | if cell == 'GRU':
26 | self.rnn = nn.GRU(input_size, hidden_size, n_layers,
27 | bidirectional=True, batch_first=True)
28 | elif cell == 'Elman':
29 | self.rnn = nn.RNN(input_size, hidden_size, n_layers,
30 | bidirectional=True, batch_first=True)
31 | elif cell == 'LSTM':
32 | self.rnn = nn.LSTM(input_size, hidden_size, n_layers,
33 | bidirectional=True, batch_first=True)
34 |
35 | self.dropout_layer = nn.Dropout(drop_ratio)
36 |
37 |
38 | def forward(self, embed_seq, input_lens=None):
39 | # embed_seq: (B, L, emb_dim)
40 | # input_lens: (B)
41 | embed_inps = self.dropout_layer(embed_seq)
42 |
43 | if input_lens is None:
44 | outputs, state = self.rnn(embed_inps, None)
45 | else:
46 | # Dynamic RNN
47 | total_len = embed_inps.size(1)
48 | packed = torch.nn.utils.rnn.pack_padded_sequence(embed_inps,
49 | input_lens, batch_first=True, enforce_sorted=False)
50 | outputs, state = self.rnn(packed, None)
51 | # outputs: (B, L, num_directions*H)
52 | # state: (num_layers*num_directions, B, H)
53 | outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs,
54 | batch_first=True, total_length=total_len)
55 |
56 | return outputs, state
57 |
58 |
59 | class Decoder(nn.Module):
60 | def __init__(self, input_size, hidden_size, cell='GRU', n_layers=1, drop_ratio=0.1):
61 | super(Decoder, self).__init__()
62 |
63 | self.dropout_layer = nn.Dropout(drop_ratio)
64 |
65 | if cell == 'GRU':
66 | self.rnn = nn.GRU(input_size, hidden_size, n_layers, batch_first=True)
67 | elif cell == 'Elman':
68 | self.rnn = nn.RNN(input_size, hidden_size, n_layers, batch_first=True)
69 | elif cell == 'LSTM':
70 | self.rnn = nn.LSTM(input_size, hidden_size, n_layers, batch_first=True)
71 |
72 |
73 | def forward(self, embed_seq, last_state):
74 | # embed_seq: (B, L, H)
75 | # last_state: (B, H)
76 | embed_inps = self.dropout_layer(embed_seq)
77 | output, state = self.rnn(embed_inps, last_state.unsqueeze(0))
78 | output = output.squeeze(1) # (B, 1, N) -> (B,N)
79 | return output, state.squeeze(0) # (B, H)
80 |
81 |
82 | class AttentionReader(nn.Module):
83 | def __init__(self, d_q, d_v, drop_ratio=0.0):
84 | super(AttentionReader, self).__init__()
85 | self.attn = nn.Linear(d_q+d_v, d_v)
86 | self.v = nn.Parameter(torch.rand(d_v))
87 | stdv = 1. / math.sqrt(self.v.size(0))
88 | self.v.data.normal_(mean=0, std=stdv)
89 | self.dropout = nn.Dropout(drop_ratio)
90 |
91 |
92 | def forward(self, Q, K, V, attn_mask):
93 | # Q: (B, 1, d_q)
94 | # K: (B, L, d_v)
95 | # V: (B, L, d_v)
96 | # attn_mask: (B, L), True means mask
97 | k_len = K.size(1)
98 | q_state = Q.repeat(1, k_len, 1) # (B, L, d_q)
99 |
100 | attn_energies = self.score(q_state, K) # (B, L)
101 |
102 | attn_energies.masked_fill_(attn_mask, -1e12)
103 |
104 | attn_weights = F.softmax(attn_energies, dim=1).unsqueeze(1)
105 | attn_weights = self.dropout(attn_weights)
106 |
107 | # (B, 1, L) * (B, L, d_v) -> (B, 1, d_v)
108 | context = attn_weights.bmm(V)
109 |
110 | return context.squeeze(1), attn_weights
111 |
112 |
113 | def score(self, query, memory):
114 | # query (B, L, d_q)
115 | # memory (B, L, d_v)
116 |
117 | # (B, L, d_q+d_v)->(B, L, d_v)
118 | energy = torch.tanh(self.attn(torch.cat([query, memory], 2)))
119 | energy = energy.transpose(1, 2) # (B, d_v, L)
120 |
121 | v = self.v.repeat(memory.size(0), 1).unsqueeze(1) # (B, 1, d_v)
122 | energy = torch.bmm(v, energy) # (B, 1, d_v) * (B, d_v, L) -> [B, 1, L]
123 | return energy.squeeze(1) # (B, L)
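    # In formula form (standard additive / Bahdanau-style attention, stated here
    # for reference): e_i = v^T tanh(W[q ; k_i]), and the read weights are the
    # softmax over e with masked positions set to -1e12 beforehand.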
124 |
125 |
126 |
127 | class AttentionWriter(nn.Module):
128 | def __init__(self, d_q, mem_size):
129 | super(AttentionWriter, self).__init__()
130 | self.attn = nn.Linear(d_q+mem_size, mem_size)
131 | self.v = nn.Parameter(torch.rand(mem_size))
132 | stdv = 1. / math.sqrt(self.v.size(0))
133 | self.v.data.normal_(mean=0, std=stdv)
134 |
135 | self._tau = 1.0 # Gumbel temperature
136 |
137 |
138 | def set_tau(self, tau):
139 | self._tau = tau
140 |
141 | def get_tau(self):
142 | return self._tau
143 |
144 |
145 | def forward(self, his_mem, states, states_mask, global_trace, null_mem):
146 | # mem: (B, mem_slots, mem_size)
147 | # states: (B, L, mem_size)
148 | # states_mask: (B, L), 0 means pad_idx and not to be written
149 | # global_trace: (B, D)
150 | # null_mem: (B, 1, mem_size)
151 | n = states.size(1)
152 | mem_slots = his_mem.size(1) + 1 # including the null slot
153 |
154 | write_log = []
155 |
156 | for i in range(0, n):
157 | mem = torch.cat([his_mem, null_mem], dim=1)
158 | state = states[:, i, :] # (B, mem_size)
159 |
160 |
161 | query = torch.cat([state, global_trace], dim=1).unsqueeze(1).repeat(1, mem_slots, 1)
162 | attn_energies = self.score(query, mem)
163 |
164 | attn_weights = F.softmax(attn_energies, dim=-1) # (B, mem_slots)
165 |
166 | # manually give the empty slots higher weights
167 | empty_mask = mem.abs().sum(-1).eq(0).float() # (B, mem_slots)
168 | attn_weights = attn_weights + empty_mask * 10.0
169 |
170 | # one-hot (B, mem_slots)
171 | slot_select = F.gumbel_softmax(attn_weights, tau=self._tau, hard=True)
172 |
173 | write_mask = slot_select[:, 0:mem_slots-1] * \
174 | (states_mask[:, i].unsqueeze(1).repeat(1, mem_slots-1))
175 | write_mask = write_mask.unsqueeze(2) # (B, mem_slots-1, 1)
176 |
177 |
178 | write_state = state.unsqueeze(1).repeat(1, mem_slots-1, 1)
179 |
180 | his_mem = (1.0 - write_mask) * his_mem + write_mask * write_state
181 |
182 | write_log.append(slot_select.unsqueeze(1))
183 |
184 | write_log = torch.cat(write_log, dim=1)
185 | return his_mem, write_log
186 |
187 |
188 | def score(self, query, memory):
189 | energy = torch.tanh(self.attn(torch.cat([query, memory], 2)))
190 | energy = energy.transpose(1, 2)
191 |
192 | v = self.v.repeat(memory.size(0), 1).unsqueeze(1)
193 |
194 | energy = torch.bmm(v, energy)
195 | return energy.squeeze(1)
196 |
197 |
198 | class MLP(nn.Module):
199 | def __init__(self, ori_input_size, layer_sizes, activs=None,
200 | drop_ratio=0.0, no_drop=False):
201 | super(MLP, self).__init__()
202 |
203 | layer_num = len(layer_sizes)
204 |
205 | orderedDic = OrderedDict()
206 | input_size = ori_input_size
207 | for i, (layer_size, activ) in enumerate(zip(layer_sizes, activs)):
208 | linear_name = 'linear_' + str(i)
209 | orderedDic[linear_name] = nn.Linear(input_size, layer_size)
210 | input_size = layer_size
211 |
212 | if activ is not None:
213 | assert activ in ['tanh', 'relu', 'leakyrelu']
214 |
215 | active_name = 'activ_' + str(i)
216 | if activ == 'tanh':
217 | orderedDic[active_name] = nn.Tanh()
218 | elif activ == 'relu':
219 | orderedDic[active_name] = nn.ReLU()
220 | elif activ == 'leakyrelu':
221 | orderedDic[active_name] = nn.LeakyReLU(0.2)
222 |
223 |
224 | if (drop_ratio > 0) and (i < layer_num-1) and (not no_drop):
225 | orderedDic["drop_" + str(i)] = nn.Dropout(drop_ratio)
226 |
227 | self.mlp = nn.Sequential(orderedDic)
228 |
229 |
230 | def forward(self, inps):
231 | return self.mlp(inps)
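    # Minimal usage sketch (illustrative sizes, not taken from the repo configs):
    #   mlp = MLP(128, layer_sizes=[256, 64], activs=['relu', 'tanh'], drop_ratio=0.1)
    #   out = mlp(torch.rand(32, 128))   # -> (32, 64)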
232 |
233 |
234 | class ContextLayer(nn.Module):
235 | def __init__(self, inp_size, out_size, kernel_size=3):
236 | super(ContextLayer, self).__init__()
237 | # (B, L, H)
238 | self.conv = nn.Conv1d(inp_size, out_size, kernel_size)
239 | self.linear = nn.Linear(out_size+inp_size, out_size)
240 |
241 | def forward(self, last_context, dec_states):
242 | # last_context: (B, context_size)
243 | # dec_states: (B, H, L)
244 | hidden_feature = self.conv(dec_states) # (B, out_size, L_out)
245 | feature = torch.tanh(hidden_feature).mean(dim=2) # (B, out_size)
246 | new_context = torch.tanh(self.linear(torch.cat([last_context, feature], dim=1)))
247 | return new_context
--------------------------------------------------------------------------------
/codes/generator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 20:19:25
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from graphs import WorkingMemoryModel
17 |
18 | from tool import Tool
19 | from beam import PoetryBeam
20 | from filter import PoetryFilter
21 | from visualization import Visualization
22 | import utils
23 |
24 | class Generator(object):
25 | '''
26 | generator for testing
27 | '''
28 |
29 | def __init__(self, hps, device):
30 | self.tool = Tool(hps.sens_num, hps.sen_len,
31 | hps.key_len, hps.topic_slots, 0.0)
32 | self.tool.load_dic(hps.vocab_path, hps.ivocab_path)
33 | vocab_size = self.tool.get_vocab_size()
34 | print ("vocabulary size: %d" % (vocab_size))
35 | PAD_ID = self.tool.get_PAD_ID()
36 | B_ID = self.tool.get_B_ID()
37 | assert vocab_size > 0 and PAD_ID >=0 and B_ID >= 0
38 | self.hps = hps._replace(vocab_size=vocab_size, pad_idx=PAD_ID, bos_idx=B_ID)
39 | self.device = device
40 |
41 | # load model
42 | model = WorkingMemoryModel(self.hps, device)
43 |
44 | # load trained model
45 | utils.restore_checkpoint(self.hps.model_dir, device, model)
46 | self.model = model.to(device)
47 | self.model.eval()
48 |
49 | null_idxes = self.tool.load_function_tokens(self.hps.data_dir + "fchars.txt").to(self.device)
50 | self.model.set_null_idxes(null_idxes)
51 |
52 | self.model.set_tau(hps.min_tau)
53 |
54 | # load poetry filter
55 | print ("loading poetry filter...")
56 | self.filter = PoetryFilter(self.tool.get_vocab(),
57 | self.tool.get_ivocab(), self.hps.data_dir)
58 |
59 | self.visual_tool = Visualization(hps.topic_slots, hps.his_mem_slots,
60 | "../log/")
61 | print("--------------------------")
62 |
63 |
64 |
65 | def generate_one(self, keywords, pattern, beam_size=20, verbose=1, manu=False, visual=0):
66 | '''
67 | generate one poem according to the inputs:
68 | keywords: a list of topic words, at most key_slots words
69 | pattern: genre pattern of the poem to be generated
70 | verbose: 0, 1, 2, 3
71 | visual: 0, 1, 2
72 | '''
73 | key_inps = self.tool.keywords2tensor([keywords]*beam_size)
74 | key_inps = [key.to(self.device) for key in key_inps]
75 |
76 | if visual > 0:
77 | self.visual_tool.reset(keywords)
78 |
79 |
80 | # inps is a pseudo tensor with pad symbols for generating the first line
81 | inps, all_ph_inps, all_len_inps, all_lengths = self.tool.patterns2tensor([pattern]*beam_size)
82 |
83 | inps = inps.to(self.device)
84 | all_ph_inps = [ph_inps.to(self.device) for ph_inps in all_ph_inps]
85 | all_len_inps = [len_inps.to(self.device) for len_inps in all_len_inps]
86 |
87 | # for quatrains, all lines in a poem share the same length
88 | length = all_lengths[0]
89 |
90 | # initialize beam pool
91 | beam_pool = PoetryBeam(self.device, beam_size, length,
92 | self.tool.get_B_ID(), self.tool.get_E_ID(), self.tool.get_UNK_ID(),
93 | self.filter.get_level_cids(), self.filter.get_oblique_cids())
94 |
95 | self.filter.reset(length, verbose)
96 |
97 | with torch.no_grad():
98 | topic_mem, topic_mask, history_mem, history_mask,\
99 | global_trace, topic_trace, key_init_state = self.model.initialize_mems(key_inps)
100 |
101 | # beam search
102 | poem = []
103 | for step in range(0, self.hps.sens_num):
104 | # generate each line
105 | if verbose >= 1:
106 | print ("\ngenerating step: %d" % (step))
107 |
108 | if step > 0:
109 | key_init_state = None
110 |
111 | candidates, costs, states, read_aligns, local_mem, local_mask = self.beam_search(beam_pool,
112 | length, inps,
113 | all_ph_inps[step], pattern[step], all_len_inps[step],
114 | key_init_state, history_mem, history_mask,
115 | topic_mem, topic_mask, global_trace, topic_trace)
116 |
117 | lines = [self.tool.idxes2line(idxes) for idxes in candidates]
118 |
119 | lines, costs, states, read_aligns = self.filter.filter_illformed(lines, costs,
120 | states, read_aligns, pattern[step][-1])
121 |
122 | if len(lines) == 0:
123 | return [], "line {} generation failed!".format(step)
124 |
125 | which = 0
126 | if manu:
127 | for i, (line, cost) in enumerate(zip(lines, costs)):
128 | print ("%d, %s, %.2f" % (i, line, cost))
129 | which = int(input("select sentence>"))
130 |
131 | line = lines[which]
132 | poem.append(line)
133 |
134 | # set repetitive chars
135 | self.filter.add_repetitive(self.tool.line2idxes(line))
136 |
137 | # ---------------------------------------
138 | # write into history memory
139 | write_log = None
140 | if step >= 1:
141 | with torch.no_grad():
142 | history_mem, write_log = self.update_history_mem(history_mem,
143 | local_mem, local_mask, global_trace)
144 |
145 | history_mask = history_mem.abs().sum(-1).eq(0) # (B, mem_slots)
146 |
147 |
148 | with torch.no_grad():
149 | # update global trace
150 | global_trace = self.update_global_trace(global_trace, states[which], length)
151 |
152 | # update topic trace
153 | topic_trace = self.update_topic_trace(topic_trace, topic_mem, read_aligns[which])
154 |
155 |
156 | # build inps
157 | inps = self.tool.line2tensor(line, beam_size).to(self.device)
158 |
159 |
160 | if visual > 0:
161 | # show visualization of memory reading
162 | self.visual_tool.add_gen_line(line)
163 | self.visual_tool.draw(read_aligns[which], write_log, step, visual)
164 |
165 |
166 | return poem, "ok"
167 |
168 |
169 | # ------------------------------------
170 | def beam_search(self, beam_pool, trg_len, inputs, phs, ph_labels, lens, key_init_state,
171 | history_mem, history_mask, topic_mem, topic_mask, global_trace, topic_trace):
172 |
173 | local_mem, local_mask, init_state = \
174 | self.model.computer_local_memory(inputs, key_init_state is None)
175 |
176 |
177 | if key_init_state is not None:
178 | init_state = key_init_state
179 |
180 | # reset beam pool
181 | if 1 <= ph_labels[-1] <= 30:
182 | rhyme = ph_labels[-1]
183 | else:
184 | rhyme = -1
185 |
186 | beam_pool.reset(init_state[0, :].unsqueeze(0), ph_labels+[0]*10,
187 | self.filter.get_rhyme_cids(rhyme), self.filter.get_repetitive_ids())
188 |
189 | # current number of uncompleted candidates in the beam pool
190 | n_samples = beam_pool.uncompleted_num()
191 |
192 | total_mask = torch.cat([topic_mask, history_mask, local_mask], dim=1)
193 | total_mem = torch.cat([topic_mem, history_mem, local_mem], dim=1)
194 |
195 |
196 | for k in range(0, trg_len+5):
197 | inp, state = beam_pool.get_beam_tails()
198 |
199 | if k <= trg_len:
200 | ph_inp = phs[0:n_samples, k]
201 | len_inp = lens[0:n_samples, k]
202 | else:
203 | ph_inp = torch.zeros(n_samples, dtype=torch.long, device=self.device)
204 | len_inp = torch.zeros(n_samples, dtype=torch.long, device=self.device)
205 |
206 |
207 | with torch.no_grad():
208 | logit, new_state, read_align = self.model.dec_step(inp, state,
209 | ph_inp, len_inp,
210 | total_mem[0:n_samples, :, :], total_mask[0:n_samples, :],
211 | global_trace[0:n_samples, :], topic_trace[0:n_samples, :])
212 |
213 |
214 | beam_pool.advance(logit, new_state, read_align, k)
215 |
216 | n_samples = beam_pool.uncompleted_num()
217 |
218 | if n_samples == 0:
219 | break
220 |
221 | candidates, costs, dec_states, read_aligns = beam_pool.get_search_results()
222 | return candidates, costs, dec_states, read_aligns, local_mem, local_mask
223 |
224 |
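# Editorial note: the concatenation order used above -- topic memory, history memory,
# local memory -- must match WorkingMemoryModel.update_topic_trace(), which treats the
# first topic_slots columns of each read alignment as the topic-memory part.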
225 | # ---------------
226 | def update_history_mem(self, history_mem, local_mem, local_mask, global_trace):
227 | new_history_mem, write_log = self.model.layers['memory_write'](history_mem, local_mem,
228 | 1.0-local_mask.float(), global_trace, self.model.null_mem)
229 |
230 |
231 | return new_history_mem, write_log
232 |
233 |
234 | def update_global_trace(self, global_trace, dec_states, length):
235 | # dec_states: (1, H) * L_gen
236 |
237 | batch_size = global_trace.size(0)
238 |
239 | # (1, H) -> (B, H, 1)
240 | states = [state.unsqueeze(2).repeat(batch_size, 1, 1) for state in dec_states]
241 |
242 | l = len(states)
243 |
244 | mask = torch.zeros(batch_size, l, dtype=torch.float, device=self.device)
245 | mask[:, 0:length+1] = 1.0
246 |
247 | # update global trace vector
248 | new_global_trace = self.model.update_global_trace(global_trace, states, mask)
249 |
250 |
251 | return new_global_trace
252 |
253 |
254 | def update_topic_trace(self, topic_trace, topic_mem, read_align):
255 | # read_align: (1, 1, mem_slots) * L_gen
256 |
257 | batch_size = topic_trace.size(0)
258 | # concat_aligns: (B, L_gen, mem_slots)
259 | concat_aligns = torch.cat(read_align, dim=1).repeat(batch_size, 1, 1)
260 | new_topic_trace = self.model.update_topic_trace(topic_trace, topic_mem, concat_aligns)
261 |
262 | return new_topic_trace
263 |
264 |
265 | # --------------------------------------
--------------------------------------------------------------------------------
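A rough sketch of how Generator.generate_one could be driven (see generate.py for the actual entry point). The config import, the keywords and the genre pattern below are placeholders and assumptions, not real entries from the repository data:

import torch
from config import hps              # assumption: config.py exposes the hyper-parameter namedtuple
from generator import Generator

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gen = Generator(hps, device)

keywords = ["key1", "key2"]               # placeholders: up to topic_slots topic words (Chinese characters)
pattern = [[0, 0, 0, 0, 0, 0, 13]] * 4    # placeholder: one phonology-label list (0~32) per line;
                                          # the last label marks the rhyme category (1~30) when constrained
poem, info = gen.generate_one(keywords, pattern, beam_size=20, verbose=1)
if info == "ok":
    print("\n".join(poem))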
/codes/tool.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 20:38:33
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | import pickle
14 | import numpy as np
15 | import random
16 | import copy
17 | import torch
18 |
19 |
20 | def readPickle(data_path):
21 | corpus_file = open(data_path, 'rb')
22 | corpus = pickle.load(corpus_file)
23 | corpus_file.close()
24 |
25 | return corpus
26 |
27 | #-------------------------------------------------------------------------
28 | #-------------------------------------------------------------------------
29 | class Tool(object):
30 | '''
31 | a tool to hold training data and the vocabulary
32 | '''
33 | def __init__(self, sens_num, sen_len, key_len, key_slots,
34 | corrupt_ratio=0):
35 | # corrupt ratio for dae
36 | self._sens_num = sens_num
37 | self._sen_len = sen_len
38 | self._key_len = key_len
39 | self._key_slots = key_slots
40 |
41 | self._corrupt_ratio = corrupt_ratio
42 |
43 | self._vocab = None
44 | self._ivocab = None
45 |
46 | self._PAD_ID = None
47 | self._B_ID = None
48 | self._E_ID = None
49 | self._UNK_ID = None
50 |
51 | # -----------------------------------
52 | # map functions
53 | def idxes2line(self, idxes, truncate=True):
54 | if truncate and self._E_ID in idxes:
55 | idxes = idxes[:idxes.index(self._E_ID)]
56 |
57 | tokens = self.idxes2tokens(idxes, truncate)
58 | line = self.tokens2line(tokens)
59 | return line
60 |
61 | def line2idxes(self, line):
62 | tokens = self.line2tokens(line)
63 | return self.tokens2idxes(tokens)
64 |
65 | def line2tokens(self, line):
66 | '''
67 | in this work, we treat each Chinese character as a token.
68 | '''
69 | line = line.strip()
70 | tokens = [c for c in line]
71 | return tokens
72 |
73 |
74 | def tokens2line(self, tokens):
75 | return "".join(tokens)
76 |
77 |
78 | def tokens2idxes(self, tokens):
79 | ''' Characters to idx list '''
80 | idxes = []
81 | for w in tokens:
82 | if w in self._vocab:
83 | idxes.append(self._vocab[w])
84 | else:
85 | idxes.append(self._UNK_ID)
86 | return idxes
87 |
88 |
89 | def idxes2tokens(self, idxes, omit_special=True):
90 | tokens = []
91 | for idx in idxes:
92 | if (idx == self._PAD_ID or idx == self._B_ID
93 | or idx == self._E_ID) and omit_special:
94 | continue
95 | tokens.append(self._ivocab[idx])
96 |
97 | return tokens
98 |
99 | # -------------------------------------------------
100 | def greedy_search(self, probs):
101 | # probs: a list of (V) distributions, one per step
102 | out_idxes = [int(np.argmax(prob, axis=-1)) for prob in probs]
103 |
104 | return self.idxes2line(out_idxes)
105 |
106 | # ----------------------------
107 | def get_vocab(self):
108 | return copy.deepcopy(self._vocab)
109 |
110 | def get_ivocab(self):
111 | return copy.deepcopy(self._ivocab)
112 |
113 | def get_vocab_size(self):
114 | if self._vocab is not None:
115 | return len(self._vocab)
116 | else:
117 | return -1
118 |
119 | def get_PAD_ID(self):
120 | assert self._PAD_ID is not None
121 | return self._PAD_ID
122 |
123 | def get_B_ID(self):
124 | assert self._B_ID is not None
125 | return self._B_ID
126 |
127 | def get_E_ID(self):
128 | assert self._E_ID is not None
129 | return self._E_ID
130 |
131 | def get_UNK_ID(self):
132 | assert self._UNK_ID is not None
133 | return self._UNK_ID
134 |
135 |
136 | # ----------------------------------------------------------------
137 | def load_dic(self, vocab_path, ivocab_path):
138 | dic = readPickle(vocab_path)
139 | idic = readPickle(ivocab_path)
140 |
141 | assert len(dic) == len(idic)
142 |
143 |
144 | self._vocab = dic
145 | self._ivocab = idic
146 |
147 | self._PAD_ID = dic['PAD']
148 | self._UNK_ID = dic['UNK']
149 | self._E_ID = dic['']
150 | self._B_ID = dic['']
151 |
152 |
153 | def load_function_tokens(self, file_dir):
154 | # please run load_dic() before loading function tokens!
155 | with open(file_dir, 'r') as fin:
156 | lines = fin.readlines()
157 |
158 | tokens = [line.strip() for line in lines]
159 |
160 | f_idxes = []
161 | for token in tokens:
162 | if token in self._vocab:
163 | f_idxes.append(self._vocab[token])
164 |
165 | f_idxes = torch.tensor(f_idxes, dtype=torch.long)
166 | return f_idxes
167 |
168 |
169 | def build_data(self, train_data_path, valid_data_path, batch_size, mode):
170 | '''
171 | Build data as batches.
172 | NOTE: please run load_dic() first.
173 | mode:
174 | dseq: pre-train the encoder and decoder as a denoising Seq2Seq model
175 | wm: train the working memory model
176 | '''
177 | assert mode in ['dseq', 'wm']
178 | train_data = readPickle(train_data_path)
179 | valid_data = readPickle(valid_data_path)
180 |
181 |
182 | # data limit for debug
183 | self.train_batches = self._build_data_core(train_data, batch_size, mode, None)
184 | self.valid_batches = self._build_data_core(valid_data, batch_size, mode, None)
185 |
186 | self.train_batch_num = len(self.train_batches)
187 | self.valid_batch_num = len(self.valid_batches)
188 |
189 |
190 | def _build_data_core(self, data, batch_size, mode, data_limit=None):
191 | # data: [keywords, sens, key_num, pattern] * data_num
192 | if data_limit is not None:
193 | data = data[0:data_limit]
194 |
195 | if mode == 'dseq':
196 | return self.build_dseq_batches(data, batch_size)
197 | elif mode == 'wm':
198 | return self.build_wm_batches(data, batch_size)
199 |
200 |
201 | def build_dseq_batches(self, data, batch_size):
202 | # data: [keywords, sens, key_num, pattern] * data_num
203 | batched_data = []
204 | batch_num = int(np.ceil(len(data) / float(batch_size)))
205 | for bi in range(0, batch_num):
206 | instances = data[bi*batch_size : (bi+1)*batch_size]
207 | if len(instances) < batch_size:
208 | instances = instances + random.sample(data, batch_size-len(instances))
209 |
210 | # build poetry batch
211 | poems = [instance[1] for instance in instances] # all poems
212 | genre_patterns = [instance[3] for instance in instances]
213 | for i in range(0, self._sens_num-1):
214 | line0 = [poem[i] for poem in poems]
215 | line1 = [poem[i+1] for poem in poems]
216 | phs = [pattern[i+1] for pattern in genre_patterns]
217 |
218 | inps, trgs, ph_inps, len_inps = \
219 | self._build_batch_seqs(line0, line1, phs, corrupt=True)
220 |
221 | batched_data.append((inps, trgs, ph_inps, len_inps))
222 |
223 | random.shuffle(batched_data)
224 | return batched_data
225 |
226 |
227 | def build_wm_batches(self, data, batch_size):
228 | # data: [keywords, sens, key_num, pattern] * data_num
229 | batched_data = []
230 | batch_num = int(np.ceil(len(data) / float(batch_size)))
231 | for bi in range(0, batch_num):
232 | instances = data[bi*batch_size : (bi+1)*batch_size]
233 | if len(instances) < batch_size:
234 | instances = instances + random.sample(data, batch_size-len(instances))
235 |
236 | # build poetry batch
237 | poems = [instance[1] for instance in instances] # all poems
238 | genre_patterns = [instance[3] for instance in instances]
239 |
240 |
241 | all_inps, all_trgs = [], []
242 | all_ph_inps, all_len_inps = [], []
243 |
244 | for i in range(-1, self._sens_num-1):
245 |
246 | if i < 0:
247 | line0 = [[] for poem in poems]
248 | else:
249 | line0 = [poem[i] for poem in poems]
250 |
251 | line1 = [poem[i+1] for poem in poems]
252 | phs = [pattern[i+1] for pattern in genre_patterns]
253 |
254 |
255 | inps, trgs, ph_inps, len_inps = \
256 | self._build_batch_seqs(line0, line1, phs, corrupt=False)
257 |
258 |
259 | all_inps.append(inps)
260 | all_trgs.append(trgs)
261 | all_ph_inps.append(ph_inps)
262 | all_len_inps.append(len_inps)
263 |
264 |
265 | # build keys
266 | keywords = [instance[0] for instance in instances]
267 | keys = self._build_batch_keys(keywords)
268 |
269 | batched_data.append((all_inps, all_trgs, all_ph_inps, all_len_inps, keys))
270 |
271 |
272 | random.shuffle(batched_data)
273 | return batched_data
274 |
275 |
276 | def _build_batch_keys(self, keywords):
277 | # build key batch
278 | batch_size = len(keywords)
279 | key_inps = [[] for _ in range(self._key_slots)]
280 |
281 | for i in range(0, batch_size):
282 | keys = keywords[i] # batch_size * at most 4
283 | for step in range(0, len(keys)):
284 | key = keys[step]
285 | assert len(key) <= self._key_len
286 | key_inps[step].append(key + [self._PAD_ID] * (self._key_len-len(key)))
287 |
288 | for step in range(0, self._key_slots-len(keys)):
289 | key_inps[len(keys)+step].append([self._PAD_ID] * self._key_len)
290 |
291 |
292 | key_tensor = [self._sens2tensor(key) for key in key_inps]
293 | return key_tensor
294 |
295 |
296 |
297 | def _build_batch_seqs(self, inputs, targets, pattern, corrupt=False):
298 | # pack sequences as a tensor
299 | inps, _, _ = self._get_batch_seq(inputs, pattern, False, corrupt=corrupt)
300 | trgs, phs, lens = self._get_batch_seq(targets, pattern, True, corrupt=False)
301 |
302 | inps_tensor = self._sens2tensor(inps)
303 | trgs_tensor = self._sens2tensor(trgs)
304 |
305 | phs_tensor = self._sens2tensor(phs)
306 | lens_tensor = self._sens2tensor(lens)
307 |
308 |
309 | return inps_tensor, trgs_tensor, phs_tensor, lens_tensor
310 |
311 |
312 | def _get_batch_seq(self, seqs, phs, with_E, corrupt):
313 | batch_size = len(seqs)
314 | max_len = max([len(seq) for seq in seqs])
315 | max_len = max_len + int(with_E)
316 |
317 | if max_len == 0:
318 | max_len = self._sen_len
319 |
320 | batched_seqs = []
321 | batched_lens, batched_phs = [], []
322 | for i in range(0, batch_size):
323 | # max length for each sequence
324 | ori_seq = copy.deepcopy(seqs[i])
325 |
326 | if corrupt:
327 | seq = self._do_corruption(ori_seq)
328 | else:
329 | seq = ori_seq
330 | # ----------------------------------
331 |
332 | pad_size = max_len - len(seq) - int(with_E)
333 | pads = [self._PAD_ID] * pad_size
334 |
335 | new_seq = seq + [self._E_ID] * int(with_E) + pads
336 |
337 | #---------------------------------
338 | # 0 means either tone is allowed (no phonological constraint)
339 | ph = phs[i]
340 | ph_inp = ph + [0] * (max_len - len(ph))
341 |
342 | assert len(ph_inp) == len(new_seq)
343 | batched_phs.append(ph_inp)
344 |
345 | len_inp = list(range(1, len(seq)+1+int(with_E)))
346 | len_inp.reverse()
347 | len_inp = len_inp + [0] * pad_size
348 |
349 | assert len(len_inp) == len(new_seq)
350 |
351 | batched_lens.append(len_inp)
352 | #---------------------------------
353 | batched_seqs.append(new_seq)
354 |
355 |
356 | return batched_seqs, batched_phs, batched_lens
357 |
358 |
359 |
360 | def _sens2tensor(self, sens):
361 | batch_size = len(sens)
362 | sen_len = max([len(sen) for sen in sens])
363 | tensor = torch.zeros(batch_size, sen_len, dtype=torch.long)
364 | for i, sen in enumerate(sens):
365 | for j, token in enumerate(sen):
366 | tensor[i][j] = token
367 | return tensor
368 |
369 |
370 | def _do_corruption(self, inp):
371 | # corrupt the sequence by setting some tokens as UNK
372 | m = int(np.ceil(len(inp) * self._corrupt_ratio))
373 | m = min(m, len(inp))
374 | m = max(1, m)
375 |
376 | unk_id = self.get_UNK_ID()
377 |
378 | corrupted_inp = copy.deepcopy(inp)
379 | pos = random.sample(list(range(0, len(inp))), m)
380 | for p in pos:
381 | corrupted_inp[p] = unk_id
382 |
383 | return corrupted_inp
384 |
385 |
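# Worked example of the corruption rate (editorial, not part of the original file):
#   with corrupt_ratio = 0.15 and a 7-character input,
#   m = max(1, min(7, ceil(7 * 0.15))) = 2,
#   so two randomly chosen positions are replaced by the UNK id.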
386 |
387 | def shuffle_train_data(self):
388 | random.shuffle(self.train_batches)
389 |
390 |
391 |
392 | # -----------------------------------------------------------
393 | # -----------------------------------------------------------
394 | # Tools for beam search
395 | def keywords2tensor(self, keywords):
396 | # input: keywords: list of string
397 |
398 | # string to idxes
399 | key_idxes = []
400 | for keyword in keywords:
401 | keys = [self.line2idxes(key_str) for key_str in keyword]
402 | if len(keys) < self._key_slots:
403 | add_num = self._key_slots - len(keys)
404 | add_keys = [[self._PAD_ID]*self._key_len]*add_num
405 | keys = keys + add_keys
406 |
407 | key_idxes.append(keys)
408 |
409 | keys_tensor = self._build_batch_keys(key_idxes)
410 | return keys_tensor
411 |
412 |
413 |
414 | def patterns2tensor(self, patterns):
415 | batch_size = len(patterns)
416 | # assume all poems share the same sens_num
417 | all_seqs = []
418 | all_lengths = []
419 | all_ph_inps, all_len_inps = [], []
420 | for step in range(0, self._sens_num):
421 | # each line
422 | phs = [pattern[step] for pattern in patterns]
423 | pseudo_seqs = [ [0] * len(ph) for ph in phs ]
424 | all_lengths.append(max([len(seq) for seq in pseudo_seqs]))
425 |
426 | batched_seqs, batched_phs, batched_lens = \
427 | self._get_batch_seq(pseudo_seqs, phs, True, False)
428 |
429 | phs_tensor = self._sens2tensor(batched_phs)
430 | lens_tensor = self._sens2tensor(batched_lens)
431 | inps_tensor = self._sens2tensor(batched_seqs)
432 |
433 | all_ph_inps.append(phs_tensor)
434 | all_len_inps.append(lens_tensor)
435 | all_seqs.append(inps_tensor)
436 |
437 |
438 | return all_seqs[0], all_ph_inps, all_len_inps, all_lengths
439 |
440 |
441 | def line2tensor(self, line, beam_size):
442 | idxes = self.line2idxes(line.strip())
443 | return self._sens2tensor([idxes]*beam_size)
444 |
--------------------------------------------------------------------------------
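To make the batching code above concrete, here is a worked illustration of what _get_batch_seq() returns for one 5-character target line; the token ids are placeholders, not real vocabulary indices:

# suppose PAD_ID = 0, E_ID = 2, the line maps to [11, 12, 13, 14, 15] with phonology
# labels [1, 2, 1, 2, 1], with_E=True, and a 7-character line elsewhere in the batch
# makes max_len = 8:
seq     = [11, 12, 13, 14, 15, 2, 0, 0]   # tokens + E_ID + PAD padding
ph_inp  = [1, 2, 1, 2, 1, 0, 0, 0]        # phonology labels, padded with 0 (no constraint)
len_inp = [6, 5, 4, 3, 2, 1, 0, 0]        # countdown of remaining characters (incl. E_ID), 0 on padding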
/codes/graphs.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Xiaoyuan Yi
3 | # @Last Modified by: Xiaoyuan Yi
4 | # @Last Modified time: 2020-06-11 20:16:17
5 | # @Email: yi-xy16@mails.tsinghua.edu.cn
6 | # @Description:
7 | '''
8 | Copyright 2020 THUNLP Lab. All Rights Reserved.
9 | This code is part of the online Chinese poetry generation system, Jiuge.
10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/.
11 | Github: https://github.com/THUNLP-AIPoet.
12 | '''
13 | import random
14 | from itertools import chain
15 | import torch
16 | from torch import nn
17 | import torch.nn.functional as F
18 |
19 | from layers import BidirEncoder, Decoder, MLP, ContextLayer, AttentionReader, AttentionWriter
20 |
21 | def get_non_pad_mask(seq, pad_idx, device):
22 | # seq: [B, L]
23 | assert seq.dim() == 2
24 | # [B, L]
25 | mask = seq.ne(pad_idx).type(torch.float)
26 | return mask.to(device)
27 |
28 |
29 | def get_seq_length(seq, pad_idx, device):
30 | mask = get_non_pad_mask(seq, pad_idx, device)
31 | # mask: [B, T]
32 | lengths = mask.sum(dim=-1).long()
33 | return lengths
34 |
35 |
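# A quick illustration of the two helpers above (editorial; pad_idx = 0 is only an example):
#   seq = torch.tensor([[5, 8, 2, 0, 0],
#                       [3, 1, 0, 0, 0]])
#   get_non_pad_mask(seq, 0, torch.device('cpu'))
#   # -> tensor([[1., 1., 1., 0., 0.],
#   #            [1., 1., 0., 0., 0.]])
#   get_seq_length(seq, 0, torch.device('cpu'))
#   # -> tensor([3, 2])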
36 | class WorkingMemoryModel(nn.Module):
37 | def __init__(self, hps, device):
38 | super(WorkingMemoryModel, self).__init__()
39 | self.hps = hps
40 | self.device = device
41 |
42 | self.global_trace_size = hps.global_trace_size
43 | self.topic_trace_size = hps.topic_trace_size
44 | self.topic_slots = hps.topic_slots
45 | self.his_mem_slots = hps.his_mem_slots
46 |
47 | self.vocab_size = hps.vocab_size
48 | self.mem_size = hps.mem_size
49 |
50 | self.sens_num = hps.sens_num
51 |
52 | self.pad_idx = hps.pad_idx
53 | self.bos_tensor = torch.tensor(hps.bos_idx, dtype=torch.long, device=device)
54 |
55 | # ----------------------------
56 | # build components
57 | self.layers = nn.ModuleDict()
58 | self.layers['word_embed'] = nn.Embedding(hps.vocab_size,
59 | hps.word_emb_size, padding_idx=hps.pad_idx)
60 |
61 | # NOTE: we use a fixed set of 33 phonology categories: 0~32
62 | # please refer to preprocess.py for more details
63 | self.layers['ph_embed'] = nn.Embedding(33, hps.ph_emb_size)
64 |
65 | self.layers['len_embed'] = nn.Embedding(hps.sen_len, hps.len_emb_size)
66 |
67 |
68 | self.layers['encoder'] = BidirEncoder(hps.word_emb_size, hps.hidden_size, drop_ratio=hps.drop_ratio)
69 | self.layers['decoder'] = Decoder(hps.hidden_size, hps.hidden_size, drop_ratio=hps.drop_ratio)
70 |
71 | # project the decoder hidden state to a vocabulary-size output logit
72 | self.layers['out_proj'] = nn.Linear(hps.hidden_size, hps.vocab_size)
73 |
74 | # update the context vector
75 | self.layers['global_trace_updater'] = ContextLayer(hps.hidden_size, hps.global_trace_size)
76 | self.layers['topic_trace_updater'] = MLP(self.mem_size+self.topic_trace_size,
77 | layer_sizes=[self.topic_trace_size], activs=['tanh'], drop_ratio=hps.drop_ratio)
78 |
79 |
80 | # MLPs for calculating the initial decoder state
81 | self.layers['dec_init'] = MLP(hps.hidden_size*2, layer_sizes=[hps.hidden_size],
82 | activs=['tanh'], drop_ratio=hps.drop_ratio)
83 | self.layers['key_init'] = MLP(hps.hidden_size*2, layer_sizes=[hps.hidden_size],
84 | activs=['tanh'], drop_ratio=hps.drop_ratio)
85 |
86 | # history memory reading and writing layers
87 | # query: concatenation of hidden state, global_trace and topic_trace
88 | self.layers['memory_read'] = AttentionReader(
89 | d_q=hps.hidden_size+self.global_trace_size+self.topic_trace_size+self.topic_slots,
90 | d_v=hps.mem_size, drop_ratio=hps.attn_drop_ratio)
91 |
92 | self.layers['memory_write'] = AttentionWriter(hps.mem_size+self.global_trace_size, hps.mem_size)
93 |
94 | # NOTE: a layer to compress the encoder hidden states to a smaller size, so that more memory slots can be used
95 | self.layers['mem_compress'] = MLP(hps.hidden_size*2, layer_sizes=[hps.mem_size],
96 | activs=['tanh'], drop_ratio=hps.drop_ratio)
97 |
98 | # [inp, attns, ph_inp, len_inp, global_trace]
99 | self.layers['merge_x'] = MLP(
100 | hps.word_emb_size+hps.ph_emb_size+hps.len_emb_size+hps.global_trace_size+hps.mem_size,
101 | layer_sizes=[hps.hidden_size],
102 | activs=['tanh'], drop_ratio=hps.drop_ratio)
103 |
104 |
105 | # two annealing parameters
106 | self._tau = 1.0
107 | self._teach_ratio = 0.8
108 |
109 |
110 | # ---------------------------------------------------------
111 | # only used for pre-training
112 | self.layers['dec_init_pre'] = MLP(hps.hidden_size*2,
113 | layer_sizes=[hps.hidden_size],
114 | activs=['tanh'], drop_ratio=hps.drop_ratio)
115 |
116 | self.layers['merge_x_pre'] = MLP(
117 | hps.word_emb_size+hps.ph_emb_size+hps.len_emb_size,
118 | layer_sizes=[hps.hidden_size],
119 | activs=['tanh'], drop_ratio=hps.drop_ratio)
120 |
121 |
122 |
123 | #---------------------------------
124 | def set_tau(self, tau):
125 | if 0.0 < tau <= 1.0:
126 | self.layers['memory_write'].set_tau(tau)
127 |
128 | def get_tau(self):
129 | return self.layers['memory_write'].get_tau()
130 |
131 | def set_teach_ratio(self, teach_ratio):
132 | if 0.0 < teach_ratio <= 1.0:
133 | self._teach_ratio = teach_ratio
134 |
135 | def get_teach_ratio(self):
136 | return self._teach_ratio
137 |
138 |
139 | def set_null_idxes(self, null_idxes):
140 | self.null_idxes = null_idxes.to(self.device).unsqueeze(0)
141 |
142 |
143 | #---------------------------------
144 | def compute_null_mem(self, batch_size):
145 | # we initialize the null memory slot with an average of encoded stop (function) words,
146 | # assuming that the model can learn to ignore these words
147 | emb_null = self.layers['word_embed'](self.null_idxes)
148 |
149 | # (1, L, 2*H)
150 | enc_outs, _ = self.layers['encoder'](emb_null)
151 |
152 | # (1, L, 2 * H) -> (1, L, D)
153 | null_mem = self.layers['mem_compress'](enc_outs)
154 |
155 | # (1, L, D)->(1, 1, D)->(B, 1, D)
156 | self.null_mem = null_mem.mean(dim=1, keepdim=True).repeat(batch_size, 1, 1)
157 |
158 |
159 | def computer_topic_memory(self, keys):
160 | # (B, key_len)
161 | emb_keys = [self.layers['word_embed'](key) for key in keys]
162 | key_lens = [get_seq_length(key, self.pad_idx, self.device) for key in keys]
163 |
164 | batch_size = emb_keys[0].size(0)
165 |
166 | # length == 0 means this is an empty topic slot
167 | topic_mask = torch.zeros(batch_size, self.topic_slots,
168 | dtype=torch.float, device=self.device).bool() # (B, topic_slots)
169 | for step in range(0, self.topic_slots):
170 | topic_mask[:, step] = torch.eq(key_lens[step], 0)
171 |
172 |
173 | key_states_vec, topic_slots = [], []
174 | for step, (emb_key, length) in enumerate(zip(emb_keys, key_lens)):
175 |
176 | # we set the length of empty keys to 1 for parallel processing,
177 | # which will then be masked out for memory reading
178 | length.masked_fill_(length.eq(0), 1)
179 |
180 | _, state = self.layers['encoder'](emb_key, length)
181 | # (2, B, H) -> (B, 2, H) -> (B, 2*H)
182 | key_state = state.transpose(0, 1).contiguous().view(batch_size, -1)
183 | mask = (1 - topic_mask[:, step].float()).unsqueeze(1) # (B, 1)
184 |
185 | key_states_vec.append((key_state*mask).unsqueeze(1))
186 |
187 | topic = self.layers['mem_compress'](key_state)
188 | topic_slots.append((topic*mask).unsqueeze(1))
189 |
190 | # (B, topic_slots, mem_size)
191 | topic_mem = torch.cat(topic_slots, dim=1)
192 |
193 | # (B, H)
194 | key_init_state = self.layers['key_init'](
195 | torch.cat(key_states_vec, dim=1).sum(1))
196 |
197 | return topic_mem, topic_mask, key_init_state
198 |
199 |
200 | def computer_local_memory(self, inps, with_length):
201 | batch_size = inps.size(0)
202 | if with_length:
203 | length = get_seq_length(inps, self.pad_idx, self.device)
204 | else:
205 | length = None
206 |
207 | emb_inps = self.layers['word_embed'](inps)
208 |
209 | # outs: (B, L, 2 * H)
210 | # states: (2, B, H)
211 | enc_outs, enc_states = self.layers['encoder'](emb_inps, length)
212 |
213 | init_state = self.layers['dec_init'](enc_states.transpose(0, 1).
214 | contiguous().view(batch_size, -1))
215 |
216 | # (B, L, 2 * H) -> (B, L, D)
217 | local_mem = self.layers['mem_compress'](enc_outs)
218 |
219 | local_mask = torch.eq(inps, self.pad_idx)
220 |
221 | return local_mem, local_mask, init_state
222 |
223 |
224 | def update_global_trace(self, old_global_trace, dec_states, dec_mask):
225 | states = torch.cat(dec_states, dim=2) # (B, H, L)
226 | global_trace = self.layers['global_trace_updater'](
227 | old_global_trace, states*(dec_mask.unsqueeze(1)))
228 | return global_trace
229 |
230 |
231 | def update_topic_trace(self, topic_trace, topic_mem, concat_aligns):
232 | # topic_trace: (B, topic_trace_size+topic_slots)
233 | # concat_aligns: (B, L_gen, mem_slots)
234 |
235 | # 1: topic memory, 2: history memory 3: local memory
236 | topic_align = concat_aligns[:, :, 0:self.topic_slots].mean(dim=1) # (B, topic_slots)
237 |
238 | # (B, topic_slots, mem_size) * (B, topic_slots, 1) -> (B, topic_slots, mem_size)
239 | # -> (B, mem_size)
240 | topic_used = torch.mul(topic_mem, topic_align.unsqueeze(2)).mean(dim=1)
241 |
242 |
243 | new_topic_trace = self.layers['topic_trace_updater'](
244 | torch.cat([topic_trace[:, 0:self.topic_trace_size], topic_used], dim=1))
245 |
246 | read_log = topic_trace[:, self.topic_trace_size:] + topic_align
247 |
248 | fin_topic_trace = torch.cat([new_topic_trace, read_log], dim=1)
249 |
250 | return fin_topic_trace
251 |
252 |
253 | def dec_step(self, inp, state, ph, length, total_mem, total_mask,
254 | global_trace, topic_trace):
255 |
256 | emb_inp = self.layers['word_embed'](inp)
257 | emb_ph = self.layers['ph_embed'](ph)
258 | emb_len = self.layers['len_embed'](length)
259 |
260 | # query for reading the memory
261 | # (B, 1, H)
262 | query = torch.cat([state, global_trace, topic_trace], dim=1).unsqueeze(1)
263 |
264 | # attns: (B, 1, mem_size), align: (B, 1, L)
265 | attns, align = self.layers['memory_read'](query, total_mem, total_mem, total_mask)
266 |
267 |
268 | x = torch.cat([emb_inp, emb_ph, emb_len, attns, global_trace], dim=1).unsqueeze(1)
269 | x = self.layers['merge_x'](x)
270 |
271 | cell_out, new_state = self.layers['decoder'](x, state)
272 | out = self.layers['out_proj'](cell_out)
273 | return out, new_state, align
274 |
275 |
276 | def run_decoder(self, inps, trgs, phs, lens, key_init_state,
277 | history_mem, history_mask, topic_mem, topic_mask, global_trace, topic_trace,
278 | specified_teach_ratio):
279 |
280 | local_mem, local_mask, init_state = \
281 | self.computer_local_memory(inps, key_init_state is None)
282 |
283 | if key_init_state is not None:
284 | init_state = key_init_state
285 |
286 | if specified_teach_ratio is None:
287 | teach_ratio = self._teach_ratio
288 | else:
289 | teach_ratio = specified_teach_ratio
290 |
291 |
292 | # Note this order: 1: topic memory, 2: history memory 3: local memory
293 | total_mask = torch.cat([topic_mask, history_mask, local_mask], dim=1)
294 | total_mem = torch.cat([topic_mem, history_mem, local_mem], dim=1)
295 |
296 | batch_size = inps.size(0)
297 | trg_len = trgs.size(1)
298 |
299 | outs = torch.zeros(batch_size, trg_len, self.vocab_size,
300 | dtype=torch.float, device=self.device)
301 |
302 | state = init_state
303 | inp = self.bos_tensor.repeat(batch_size)
304 | dec_states, attn_weights = [], []
305 |
306 | # generate each line
307 | for t in range(0, trg_len):
308 | out, state, align = self.dec_step(inp, state, phs[:, t],
309 | lens[:, t], total_mem, total_mask, global_trace, topic_trace)
310 | outs[:, t, :] = out
311 |
312 | attn_weights.append(align)
313 |
314 | # teacher forcing with some probability
315 | is_teach = random.random() < teach_ratio
316 | if is_teach or (not self.training):
317 | inp = trgs[:, t]
318 | else:
319 | normed_out = F.softmax(out, dim=-1)
320 | inp = normed_out.data.max(1)[1]
321 |
322 | dec_states.append(state.unsqueeze(2)) # (B, H, 1)
323 |
324 |
325 |
326 |
327 | # write the history memory
328 | if key_init_state is None:
329 | new_history_mem, _ = self.layers['memory_write'](history_mem, local_mem,
330 | 1.0-local_mask.float(), global_trace, self.null_mem)
331 | else:
332 | new_history_mem = history_mem
333 |
334 | # (B, L)
335 | dec_mask = get_non_pad_mask(trgs, self.pad_idx, self.device)
336 |
337 | # update global trace vector
338 | new_global_trace = self.update_global_trace(global_trace, dec_states, dec_mask)
339 |
340 |
341 | # update topic trace vector
342 | # attn_weights: (B, 1, all_mem_slots) * L_gen
343 | concat_aligns = torch.cat(attn_weights, dim=1)
344 | new_topic_trace = self.update_topic_trace(topic_trace, topic_mem, concat_aligns)
345 |
346 |
347 | return outs, new_history_mem, new_global_trace, new_topic_trace
348 |
349 |
350 |
351 | def initialize_mems(self, keys):
352 | batch_size = keys[0].size(0)
353 | topic_mem, topic_mask, key_init_state = self.computer_topic_memory(keys)
354 |
355 | history_mem = torch.zeros(batch_size, self.his_mem_slots, self.mem_size,
356 | dtype=torch.float, device=self.device)
357 |
358 | # default: True, masked
359 | history_mask = torch.ones(batch_size, self.his_mem_slots,
360 | dtype=torch.float, device=self.device).bool()
361 |
362 | global_trace = torch.zeros(batch_size, self.global_trace_size,
363 | dtype=torch.float, device=self.device)
364 | topic_trace = torch.zeros(batch_size, self.topic_trace_size+self.topic_slots,
365 | dtype=torch.float, device=self.device)
366 |
367 | self.compute_null_mem(batch_size)
368 |
369 | return topic_mem, topic_mask, history_mem, history_mask,\
370 | global_trace, topic_trace, key_init_state
371 |
372 |
373 | def rebuild_inps(self, ori_inps, last_outs, teach_ratio):
374 | # ori_inps: (B, L)
375 | # last_outs: (B, L, V)
376 | inp_len = ori_inps.size(1)
377 | new_inps = torch.ones_like(ori_inps) * self.pad_idx
378 |
379 | mask = get_non_pad_mask(ori_inps, self.pad_idx, self.device).long()
380 |
381 | if teach_ratio is None:
382 | teach_ratio = self._teach_ratio
383 |
384 | for t in range(0, inp_len):
385 | is_teach = random.random() < teach_ratio
386 | if is_teach or (not self.training):
387 | new_inps[:, t] = ori_inps[:, t]
388 | else:
389 | normed_out = F.softmax(last_outs[:, t], dim=-1)
390 | new_inps[:, t] = normed_out.data.max(1)[1]
391 |
392 | new_inps = new_inps * mask
393 |
394 | return new_inps
395 |
396 |
397 | def forward(self, all_inps, all_trgs, all_ph_inps, all_len_inps, keys, teach_ratio=None,
398 | flexible_inps=False):
399 | '''
400 | all_inps: (B, L) * sens_num
401 | all_trgs: (B, L) * sens_num
402 | all_ph_inps: (B, L) * sens_num
403 | all_len_inps: (B, L) * sens_num
404 | keys: (B, L) * topic_slots
405 | flexible_inps: whether to apply partial teacher forcing to the local memory.
406 | False: the ground-truth source line is stored into the local memory.
407 | True: for the local memory, ground-truth characters are replaced with generated characters
408 | with probability 1 - teach_ratio.
409 | NOTE: this trick is *not* adopted in our original paper; it could lead to
410 | better BLEU and topic relevance, but worse diversity of the generated poems.
411 | '''
412 | all_outs = []
413 |
414 | topic_mem, topic_mask, history_mem, history_mask,\
415 | global_trace, topic_trace, key_init_state = self.initialize_mems(keys)
416 |
417 | for step in range(0, self.sens_num):
418 | if step > 0:
419 | key_init_state = None
420 |
421 | if step >= 1 and flexible_inps:
422 | inps = self.rebuild_inps(all_inps[step], all_outs[-1], teach_ratio)
423 | else:
424 | inps = all_inps[step]
425 |
426 | outs, history_mem, global_trace, topic_trace \
427 | = self.run_decoder(inps, all_trgs[step],
428 | all_ph_inps[step], all_len_inps[step], key_init_state,
429 | history_mem, history_mask, topic_mem, topic_mask,
430 | global_trace, topic_trace, teach_ratio)
431 |
432 | if step >= 1:
433 | history_mask = history_mem.abs().sum(-1).eq(0) # (B, mem_slots)
434 |
435 |
436 | all_outs.append(outs)
437 |
438 |
439 | return all_outs
440 |
441 |
442 |
443 | # --------------------------
444 | # graphs for pre-training
445 | def dseq_graph(self, inps, trgs, ph_inps, len_inps, teach_ratio=None):
446 | # pre-train the encoder and decoder as a denoising Seq2Seq model
447 | batch_size, trg_len = trgs.size(0), trgs.size(1)
448 | length = get_seq_length(inps, self.pad_idx, self.device)
449 |
450 |
451 | emb_inps = self.layers['word_embed'](inps)
452 | emb_phs = self.layers['ph_embed'](ph_inps)
453 | emb_lens = self.layers['len_embed'](len_inps)
454 |
455 |
456 | # outs: (B, L, 2 * H)
457 | # states: (2, B, H)
458 | _, enc_states = self.layers['encoder'](emb_inps, length)
459 |
460 |
461 | init_state = self.layers['dec_init_pre'](enc_states.transpose(0, 1).
462 | contiguous().view(batch_size, -1))
463 |
464 |
465 | outs = torch.zeros(batch_size, trg_len, self.vocab_size,
466 | dtype=torch.float, device=self.device)
467 |
468 | if teach_ratio is None:
469 | teach_ratio = self._teach_ratio
470 |
471 | state = init_state
472 | inp = self.bos_tensor.repeat(batch_size, 1)
473 |
474 | # generate each line
475 | for t in range(0, trg_len):
476 | emb_inp = self.layers['word_embed'](inp)
477 | x = self.layers['merge_x_pre'](torch.cat(
478 | [emb_inp, emb_phs[:, t].unsqueeze(1), emb_lens[:, t].unsqueeze(1)],
479 | dim=-1))
480 |
481 | cell_out, state = self.layers['decoder'](x, state)
482 | out = self.layers['out_proj'](cell_out)
483 |
484 | outs[:, t, :] = out
485 |
486 | # teacher forcing with some probability
487 | is_teach = random.random() < teach_ratio
488 | if is_teach or (not self.training):
489 | inp = trgs[:, t].unsqueeze(1)
490 | else:
491 | normed_out = F.softmax(out, dim=-1)
492 | top1 = normed_out.data.max(1)[1]
493 | inp = top1.unsqueeze(1)
494 |
495 |
496 | return outs
497 |
498 |
499 | # ----------------------------------------------
500 | def dseq_parameter_names(self):
501 | required_names = ['word_embed', 'ph_embed', 'len_embed',
502 | 'encoder', 'decoder', 'out_proj',
503 | 'dec_init_pre', 'merge_x_pre']
504 | return required_names
505 |
506 | def dseq_parameters(self):
507 | names = self.dseq_parameter_names()
508 |
509 | required_params = [self.layers[name].parameters() for name in names]
510 |
511 | return chain.from_iterable(required_params)
512 |
513 | # ------------------------------------------------
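# A hedged usage sketch (editorial, not part of graphs.py): during denoising Seq2Seq
# pre-training only the sub-modules listed in dseq_parameter_names() need to be optimized, e.g.
#   model = WorkingMemoryModel(hps, device)
#   optimizer = torch.optim.Adam(model.dseq_parameters(), lr=1e-3)   # lr is illustrative
# dseq_parameters() chains the parameter iterators of those layers, so it can be passed
# to any torch optimizer directly.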
--------------------------------------------------------------------------------