├── log └── tmp.txt ├── checkpoint └── tmp.txt ├── corpus └── tmp.txt ├── pictures ├── p0.png ├── p1.png ├── p2.png ├── p3.png ├── p4.png ├── p5.png └── logo.jpg ├── codes ├── __pycache__ │ ├── beam.cpython-37.pyc │ ├── tool.cpython-37.pyc │ ├── config.cpython-37.pyc │ ├── decay.cpython-37.pyc │ ├── filter.cpython-37.pyc │ ├── graphs.cpython-37.pyc │ ├── layers.cpython-37.pyc │ ├── logger.cpython-37.pyc │ ├── utils.cpython-37.pyc │ ├── criterion.cpython-37.pyc │ ├── generator.cpython-37.pyc │ ├── scheduler.cpython-37.pyc │ ├── cl_trainer.cpython-37.pyc │ ├── dae_trainer.cpython-37.pyc │ ├── dseq_trainer.cpython-37.pyc │ ├── mix_trainer.cpython-37.pyc │ ├── rhythm_tool.cpython-37.pyc │ ├── wm_trainer.cpython-37.pyc │ ├── visualization.cpython-37.pyc │ └── spectral_normalization.cpython-37.pyc ├── criterion.py ├── decay.py ├── scheduler.py ├── train.py ├── config.py ├── logger.py ├── utils.py ├── visualization.py ├── dseq_trainer.py ├── generate.py ├── wm_trainer.py ├── filter.py ├── rhythm_tool.py ├── beam.py ├── layers.py ├── generator.py ├── tool.py └── graphs.py ├── preprocess ├── __pycache__ │ ├── rhythm_tool.cpython-37.pyc │ └── pattern_extractor.cpython-37.pyc ├── pattern_extractor.py ├── rhythm_tool.py └── preprocess.py ├── data ├── GenrePatterns.txt └── fchars.txt └── README.md /log/tmp.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoint/tmp.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /corpus/tmp.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pictures/p0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/p0.png -------------------------------------------------------------------------------- /pictures/p1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/p1.png -------------------------------------------------------------------------------- /pictures/p2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/p2.png -------------------------------------------------------------------------------- /pictures/p3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/p3.png -------------------------------------------------------------------------------- /pictures/p4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/p4.png -------------------------------------------------------------------------------- /pictures/p5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/p5.png -------------------------------------------------------------------------------- /pictures/logo.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/pictures/logo.jpg -------------------------------------------------------------------------------- /codes/__pycache__/beam.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/beam.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/tool.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/tool.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/decay.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/decay.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/filter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/filter.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/graphs.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/graphs.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/layers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/layers.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/criterion.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/criterion.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/generator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/generator.cpython-37.pyc -------------------------------------------------------------------------------- 
/codes/__pycache__/scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/cl_trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/cl_trainer.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/dae_trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/dae_trainer.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/dseq_trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/dseq_trainer.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/mix_trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/mix_trainer.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/rhythm_tool.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/rhythm_tool.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/wm_trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/wm_trainer.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/visualization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/visualization.cpython-37.pyc -------------------------------------------------------------------------------- /preprocess/__pycache__/rhythm_tool.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/preprocess/__pycache__/rhythm_tool.cpython-37.pyc -------------------------------------------------------------------------------- /codes/__pycache__/spectral_normalization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/codes/__pycache__/spectral_normalization.cpython-37.pyc -------------------------------------------------------------------------------- /preprocess/__pycache__/pattern_extractor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUNLP-AIPoet/WMPoetry/HEAD/preprocess/__pycache__/pattern_extractor.cpython-37.pyc -------------------------------------------------------------------------------- /data/GenrePatterns.txt: 
-------------------------------------------------------------------------------- 1 | [00]#七绝一#4#0 32 0 31 31 32 32|0 31 0 32 32 31 33|0 31 0 32 31 31 32|0 32 31 31 32 32 33 2 | [01]#七绝二#4#0 31 0 32 32 31 33|0 32 31 31 32 32 33|0 32 0 31 31 32 32|0 31 0 32 32 31 33 3 | [02]#七绝三#4#0 32 31 31 32 32 33|0 31 0 32 32 31 33|0 31 0 32 31 31 32|0 32 31 31 32 32 33 4 | [03]#七绝四#4#0 31 0 32 31 31 32|0 32 31 31 32 32 33|0 32 0 31 31 32 32|0 31 0 32 32 31 33 5 | [04]#五绝一#4#0 31 31 32 32|0 32 32 31 33|0 32 31 31 32|31 31 32 32 33 6 | [05]#五绝二#4#0 32 32 31 33|31 31 32 32 33|0 31 31 32 32|0 32 32 31 33 7 | [06]#五绝三#4#31 31 32 32 33|0 32 32 31 33|0 32 0 31 32|31 31 32 32 33 8 | [07]#五绝四#4#0 32 31 31 32|31 31 32 32 33|0 31 31 32 32|0 32 32 31 33 9 | -------------------------------------------------------------------------------- /codes/criterion.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 18:06:09 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 12 | ''' 13 | import torch 14 | from torch import nn 15 | 16 | class Criterion(nn.Module): 17 | def __init__(self, pad_idx): 18 | super().__init__() 19 | self._criterion = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=pad_idx) 20 | self._pad_idx = pad_idx 21 | 22 | 23 | def forward(self, outputs, targets): 24 | # outputs: (B, L, V) 25 | # targets: (B, L) 26 | 27 | vocab_size = outputs.size(-1) 28 | outs = outputs.contiguous().view(-1, vocab_size) # outs: (N, V) 29 | tgts = targets.contiguous().view(-1) # tgts: (N) 30 | 31 | non_pad_mask = tgts.ne(self._pad_idx) 32 | 33 | loss = self._criterion(outs, tgts) # [N] 34 | loss = loss.masked_select(non_pad_mask).mean() 35 | 36 | return loss -------------------------------------------------------------------------------- /codes/decay.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 18:06:44 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 12 | ''' 13 | import numpy as np 14 | 15 | #--------------------------------------------------- 16 | class RateDecay(object): 17 | '''Basic class for different types of rate decay, 18 | e.g., teach forcing ratio, gumbel temperature, 19 | KL annealing. 
20 | ''' 21 | def __init__(self, burn_down_steps, decay_steps, limit_v): 22 | 23 | self.step = 0 24 | self.rate = 1.0 25 | 26 | self.burn_down_steps = burn_down_steps 27 | self.decay_steps = decay_steps 28 | 29 | self.limit_v = limit_v 30 | 31 | 32 | def decay_funtion(self): 33 | # to be reconstructed 34 | return self.rate 35 | 36 | 37 | def do_step(self): 38 | # update rate 39 | self.step += 1 40 | if self.step > self.burn_down_steps: 41 | self.rate = self.decay_funtion() 42 | 43 | return self.rate 44 | 45 | 46 | def get_rate(self): 47 | return self.rate 48 | 49 | 50 | class ExponentialDecay(RateDecay): 51 | def __init__(self, burn_down_steps, decay_steps, min_v): 52 | super(ExponentialDecay, self).__init__( 53 | burn_down_steps, decay_steps, min_v) 54 | 55 | self.__alpha = np.log(self.limit_v) / (-decay_steps) 56 | 57 | def decay_funtion(self): 58 | new_rate = max(np.exp(-self.__alpha*self.step), self.limit_v) 59 | return new_rate -------------------------------------------------------------------------------- /codes/scheduler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 18:09:56 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 12 | ''' 13 | class ISRScheduler(object): 14 | '''Inverse Square Root Schedule 15 | ''' 16 | def __init__(self, optimizer, warmup_steps, max_lr=5e-4, min_lr=3e-5, init_lr=1e-5, beta=0.55): 17 | self._optimizer = optimizer 18 | 19 | self._step = 0 20 | self._rate = init_lr 21 | 22 | self._warmup_steps = warmup_steps 23 | self._max_lr = max_lr 24 | self._min_lr = min_lr 25 | self._init_lr = init_lr 26 | 27 | self._alpha = (max_lr-init_lr) / warmup_steps 28 | self._beta = beta 29 | self._gama = max_lr * warmup_steps ** (beta) 30 | 31 | 32 | def step(self): 33 | self._step += 1 34 | rate = self.rate() 35 | for p in self._optimizer.param_groups: 36 | p['lr'] = rate 37 | self._rate = rate 38 | self._optimizer.step() 39 | 40 | def rate(self): 41 | step = self._step 42 | if step < self._warmup_steps: 43 | lr = self._init_lr + self._alpha * step 44 | else: 45 | lr = self._gama * step ** (-self._beta) 46 | 47 | if step > self._warmup_steps: 48 | lr = max(lr, self._min_lr) 49 | return lr 50 | 51 | def zero_grad(self): 52 | self._optimizer.zero_grad() 53 | 54 | def state_dict(self): 55 | return self._optimizer.state_dict() 56 | 57 | def load_state_dict(self, dic): 58 | self._optimizer.load_state_dict(dic) -------------------------------------------------------------------------------- /codes/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 18:15:55 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 
12 | ''' 13 | from dseq_trainer import DSeqTrainer 14 | from wm_trainer import WMTrainer 15 | 16 | from graphs import WorkingMemoryModel 17 | from tool import Tool 18 | from config import device, hparams 19 | import utils 20 | 21 | 22 | def pretrain(wm_model, tool, hps, specified_device): 23 | dseq_trainer = DSeqTrainer(hps, specified_device) 24 | 25 | print ("dseq pretraining...") 26 | dseq_trainer.train(wm_model, tool) 27 | print ("dseq pretraining done!") 28 | 29 | 30 | 31 | def train(wm_model, tool, hps, specified_device): 32 | last_epoch = utils.restore_checkpoint( 33 | hps.model_dir, specified_device, wm_model) 34 | 35 | if last_epoch is not None: 36 | print ("checkpoint exsits! directly recover!") 37 | else: 38 | print ("checkpoint not exsits! train from scratch!") 39 | 40 | wm_trainer = WMTrainer(hps, specified_device) 41 | wm_trainer.train(wm_model, tool) 42 | 43 | 44 | def main(): 45 | hps = hparams 46 | tool = Tool(hps.sens_num, hps.sen_len, 47 | hps.key_len, hps.topic_slots, hps.corrupt_ratio) 48 | tool.load_dic(hps.vocab_path, hps.ivocab_path) 49 | vocab_size = tool.get_vocab_size() 50 | PAD_ID = tool.get_PAD_ID() 51 | B_ID = tool.get_B_ID() 52 | assert vocab_size > 0 and PAD_ID >=0 and B_ID >= 0 53 | hps = hps._replace(vocab_size=vocab_size, pad_idx=PAD_ID, bos_idx=B_ID) 54 | 55 | print ("hyper-patameters:") 56 | print (hps) 57 | input("please check the hyper-parameters, and then press any key to continue >") 58 | 59 | wm_model = WorkingMemoryModel(hps, device) 60 | wm_model = wm_model.to(device) 61 | 62 | pretrain(wm_model, tool, hps, device) 63 | train(wm_model, tool, hps, device) 64 | 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /data/fchars.txt: -------------------------------------------------------------------------------- 1 | 一 2 | 二 3 | 三 4 | 四 5 | 五 6 | 六 7 | 七 8 | 八 9 | 九 10 | 十 11 | 百 12 | 千 13 | 万 14 | 几 15 | 只 16 | 个 17 | 两 18 | 傍 19 | 上 20 | 更 21 | 在 22 | 此 23 | 止 24 | 正 25 | 动 26 | 了 27 | 加 28 | 于 29 | 互 30 | 亟 31 | 交 32 | 亦 33 | 随 34 | 洊 35 | 弥 36 | 纵 37 | 弗 38 | 率 39 | 再 40 | 继 41 | 啻 42 | 促 43 | 愈 44 | 焉 45 | 如 46 | 嘻 47 | 矣 48 | 然 49 | 属 50 | 綦 51 | 俱 52 | 况 53 | 殊 54 | 讫 55 | 间 56 | 讵 57 | 孰 58 | 许 59 | 俄 60 | 设 61 | 披 62 | 等 63 | 犹 64 | 渐 65 | 抵 66 | 把 67 | 顾 68 | 须 69 | 屡 70 | 顷 71 | 呼 72 | 希 73 | 忽 74 | 遽 75 | 是 76 | 汔 77 | 遍 78 | 呜 79 | 遄 80 | 忝 81 | 遂 82 | 必 83 | 常 84 | 有 85 | 鲜 86 | 倏 87 | 最 88 | 沓 89 | 敬 90 | 唯 91 | 虑 92 | 故 93 | 沿 94 | 便 95 | 益 96 | 盍 97 | 抑 98 | 固 99 | 被 100 | 相 101 | 直 102 | 因 103 | 漫 104 | 莫 105 | 却 106 | 即 107 | 凡 108 | 到 109 | 特 110 | 蹔 111 | 能 112 | 廑 113 | 依 114 | 阿 115 | 兼 116 | 举 117 | 猥 118 | 蔑 119 | 否 120 | 循 121 | 徐 122 | 后 123 | 徒 124 | 差 125 | 得 126 | 通 127 | 速 128 | 总 129 | 向 130 | 适 131 | 同 132 | 猝 133 | 逆 134 | 各 135 | 猗 136 | 斯 137 | 临 138 | 奚 139 | 於 140 | 奉 141 | 奈 142 | 方 143 | 已 144 | 要 145 | 奄 146 | 普 147 | 类 148 | 吁 149 | 全 150 | 样 151 | 仅 152 | 略 153 | 从 154 | 仍 155 | 今 156 | 介 157 | 似 158 | 伪 159 | 傥 160 | 令 161 | 每 162 | 会 163 | 洎 164 | 前 165 | 尝 166 | 尚 167 | 苟 168 | 尔 169 | 少 170 | 尽 171 | 偶 172 | 若 173 | 溘 174 | 庸 175 | 约 176 | 庶 177 | 酷 178 | 请 179 | 应 180 | 底 181 | 身 182 | 蚤 183 | 杂 184 | 罕 185 | 暨 186 | 都 187 | 来 188 | 诸 189 | 暂 190 | 暇 191 | 哉 192 | 比 193 | 毕 194 | 兹 195 | 为 196 | 其 197 | 共 198 | 靡 199 | 公 200 | 兮 201 | 毋 202 | 详 203 | 业 204 | 丕 205 | 且 206 | 专 207 | 非 208 | 与 209 | 不 210 | 先 211 | 克 212 | 嗟 213 | 兀 214 | 元 215 | 稍 216 | 潜 217 | 叨 218 | 可 219 | 较 220 | 辄 221 | 叵 222 | 堪 223 | 及 224 | 又 225 | 反 
226 | 或 227 | 往 228 | 恒 229 | 彼 230 | 当 231 | 恰 232 | 皆 233 | 的 234 | 索 235 | 擅 236 | 累 237 | 謇 238 | 偏 239 | 假 240 | 使 241 | 由 242 | 以 243 | 佥 244 | 偕 245 | 佯 246 | 何 247 | 刚 248 | 则 249 | 切 250 | 但 251 | 甚 252 | 生 253 | 竟 254 | 竞 255 | 勿 256 | 立 257 | 耶 258 | 而 259 | 者 260 | 端 261 | 预 262 | 颇 263 | 岂 264 | 频 265 | 蒙 266 | 无 267 | 既 268 | 时 269 | 矧 270 | 处 271 | 喏 272 | 复 273 | 虽 274 | 块 275 | 坐 276 | 致 277 | 至 278 | 自 279 | 欻 280 | 也 281 | 乌 282 | 乍 283 | 乎 284 | 之 285 | 欤 286 | 乃 287 | 宁 288 | 它 289 | 脱 290 | 倘 291 | 趋 292 | 趁 293 | 夫 294 | 例 295 | 连 296 | 厪 297 | 荐 298 | 这 299 | 迄 300 | 所 301 | 迭 302 | 迪 303 | 替 304 | 那 305 | 咸 306 | 曷 307 | 曩 308 | 氐 309 | 悉 310 | 咄 311 | 并 312 | 和 313 | 噫 314 | -------------------------------------------------------------------------------- /codes/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 21:05:31 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 12 | ''' 13 | 14 | from collections import namedtuple 15 | import torch 16 | HParams = namedtuple('HParams', 17 | 'vocab_size, pad_idx, bos_idx,' 18 | 'word_emb_size, ph_emb_size, len_emb_size,' 19 | 'hidden_size, mem_size, global_trace_size, topic_trace_size,' 20 | 'his_mem_slots, topic_slots, sens_num, sen_len, key_len,' 21 | 22 | 'batch_size, drop_ratio, attn_drop_ratio, weight_decay, clip_grad_norm,' 23 | 'max_lr, min_lr, init_lr, warmup_steps,' 24 | 'min_tr, burn_down_tr, decay_tr,' 25 | 'tau_annealing_steps, min_tau,' 26 | 27 | 'log_steps, sample_num, max_epoches,' 28 | 'save_epoches, validate_epoches,' 29 | 'vocab_path, ivocab_path, train_data, valid_data,' 30 | 'model_dir, data_dir, train_log_path, valid_log_path,' 31 | 32 | 'corrupt_ratio, dseq_epoches, dseq_batch_size,' 33 | 'dseq_max_lr, dseq_min_lr, dseq_init_lr dseq_warmup_steps,' 34 | 'dseq_min_tr, dseq_burn_down_tr, dseq_decay_tr,' 35 | 36 | 'dseq_log_steps, dseq_validate_epoches, dseq_save_epoches,' 37 | 'dseq_train_log_path, dseq_valid_log_path,' 38 | ) 39 | 40 | 41 | hparams = HParams( 42 | # -------------------- 43 | # general settings 44 | vocab_size=-1, pad_idx=-1, bos_idx=-1, # to be replaced by true size after loading dictionary 45 | word_emb_size=256, ph_emb_size=64, len_emb_size=32, 46 | hidden_size=512, mem_size=512, global_trace_size=512, topic_trace_size=20, 47 | his_mem_slots=4, topic_slots=4, sens_num=4, sen_len=10, key_len=2, 48 | 49 | 50 | batch_size=128, drop_ratio=0.25, attn_drop_ratio=0.1, 51 | weight_decay=2.5e-4, clip_grad_norm=2.0, 52 | max_lr=1e-3, min_lr=5e-8, init_lr=1e-4, warmup_steps=6000, # learning rate decay 53 | min_tr=0.80, burn_down_tr=3, decay_tr=10, # epoches for teach forcing ratio decay 54 | tau_annealing_steps=30000, min_tau=0.01,# Gumbel temperature, from 1 to min_tau 55 | 56 | log_steps=100, sample_num=1, max_epoches=14, 57 | save_epoches=2, validate_epoches=1, 58 | 59 | vocab_path="../corpus/vocab.pickle", 60 | ivocab_path="../corpus/ivocab.pickle", 61 | train_data="../corpus/train_data.pickle", 62 | valid_data="../corpus/valid_data.pickle", 63 | model_dir="../checkpoint/", 64 | data_dir="../data/", 65 | train_log_path="../log/wm_train_log.txt", 66 | 
valid_log_path="../log/wm_valid_log.txt", 67 | 68 | #-------------------------- 69 | # for pre-training 70 | corrupt_ratio=0.1, dseq_epoches=10, dseq_batch_size=256, 71 | dseq_max_lr=1e-3, dseq_min_lr=5e-5, dseq_init_lr=1e-4, dseq_warmup_steps=6000, 72 | dseq_min_tr=0.80, dseq_burn_down_tr=3, dseq_decay_tr=7, 73 | 74 | dseq_log_steps=200, dseq_validate_epoches=1, dseq_save_epoches=2, 75 | dseq_train_log_path="../log/dae_train_log.txt", 76 | dseq_valid_log_path="../log/dae_valid_log.txt" 77 | 78 | ) 79 | 80 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 81 | -------------------------------------------------------------------------------- /codes/logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 18:09:32 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 12 | ''' 13 | import numpy as np 14 | import time 15 | 16 | 17 | class InfoLogger(object): 18 | """docstring for LogInfo""" 19 | def __init__(self, mode): 20 | super(InfoLogger).__init__() 21 | self._mode = mode # string, 'train' or 'valid' 22 | self._total_steps = 0 23 | self._batch_num = 0 24 | self._log_steps = 0 25 | self._cur_step = 0 26 | self._cur_epoch = 1 27 | 28 | self._start_time = 0 29 | self._end_time = 0 30 | 31 | #-------------------------- 32 | self._log_path = "" # path to save the log file 33 | 34 | # ------------------------- 35 | self._decay_rates = {'learning_rate':1.0, 36 | 'teach_ratio':1.0, 'temperature':1.0} 37 | 38 | 39 | def set_batch_num(self, batch_num): 40 | self._batch_num = batch_num 41 | def set_log_steps(self, log_steps): 42 | self._log_steps = log_steps 43 | def set_log_path(self, log_path): 44 | self._log_path = log_path 45 | 46 | def set_rate(self, name, value): 47 | self._decay_rates[name] = value 48 | 49 | 50 | def set_start_time(self): 51 | self._start_time = time.time() 52 | 53 | def set_end_time(self): 54 | self._end_time = time.time() 55 | 56 | def add_step(self): 57 | self._total_steps += 1 58 | self._cur_step += 1 59 | 60 | def add_epoch(self): 61 | self._cur_step = 0 62 | self._cur_epoch += 1 63 | 64 | 65 | # ------------------------------ 66 | @property 67 | def cur_process(self): 68 | ratio = float(self._cur_step) / self._batch_num * 100 69 | process_str = "%d/%d %.1f%%" % (self._cur_step, self._batch_num, ratio) 70 | return process_str 71 | 72 | @property 73 | def time_cost(self): 74 | return (self._end_time-self._start_time) / self._log_steps 75 | 76 | @property 77 | def total_steps(self): 78 | return self._total_steps 79 | 80 | @property 81 | def epoch(self): 82 | return self._cur_epoch 83 | 84 | @property 85 | def mode(self): 86 | return self._mode 87 | 88 | @property 89 | def log_path(self): 90 | return self._log_path 91 | 92 | 93 | @property 94 | def learning_rate(self): 95 | return self._decay_rates['learning_rate'] 96 | 97 | @property 98 | def teach_ratio(self): 99 | return self._decay_rates['teach_ratio'] 100 | 101 | @property 102 | def temperature(self): 103 | return self._decay_rates['temperature'] 104 | 105 | 106 | #------------------------------------ 107 | class SimpleLogger(InfoLogger): 108 | def __init__(self, mode): 
109 | super(SimpleLogger, self).__init__(mode) 110 | self._gen_loss = 0.0 111 | 112 | def add_losses(self, gen_loss): 113 | self.add_step() 114 | self._gen_loss += gen_loss 115 | 116 | 117 | def print_log(self, epoch=None): 118 | 119 | gen_loss = self._gen_loss / self.total_steps 120 | ppl = np.exp(gen_loss) 121 | 122 | if self.mode == 'train': 123 | process_info = "epoch: %d, %s, %.2fs per iter, lr: %.4f, tr: %.2f, tau: %.3f" % (self.epoch, 124 | self.cur_process, self.time_cost, self.learning_rate, self.teach_ratio, self.temperature) 125 | else: 126 | process_info = "epoch: %d, lr: %.4f, tr: %.2f, tau: %.3f" % ( 127 | epoch, self.learning_rate, self.teach_ratio, self.temperature) 128 | 129 | train_info = " gen loss: %.3f ppl:%.2f" % (gen_loss, ppl) 130 | 131 | 132 | print (process_info) 133 | print (train_info) 134 | print ("______________________") 135 | 136 | info = process_info + "\n" + train_info 137 | fout = open(self.log_path, 'a') 138 | fout.write(info + "\n\n") 139 | fout.close() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WMPoetry 2 | The source code of [*Chinese Poetry Generation with a Working Memory Model*](https://www.ijcai.org/Proceedings/2018/0633.pdf) (IJCAI 2018). 3 | 4 | ## 1. Rights 5 | All rights reserved. 6 | 7 | ## 2. Requirements 8 | * python>=3.7.0 9 | * pytorch>=1.3.1 10 | * matplotlib>=2.2.3 11 | 12 | A Tensorflow version of our model is available [here](https://github.com/XiaoyuanYi/WMPoetry). 13 | 14 | ## 3. Data Preparation 15 | To train the model and generate poems, please 16 | 17 | * add the training, validation and testing sets of our [THU-CCPC](https://github.com/THUNLP-AIPoet/Datasets/tree/master/CCPC) data into the *WMPoetry/preprocess/* directory; 18 | * add the pingsheng.txt, zesheng.txt, pingshui.txt and pingshui_amb.pkl files of our [THU-CRRD](https://github.com/THUNLP-AIPoet/Datasets/tree/master/CRRD) set into the *WMPoetry/data/* directory. 19 | 20 | ## 4. Preprocessing 21 | In WMPoetry/preprocess/, just run: 22 | ``` 23 | python preprocess.py 24 | ``` 25 | 26 | Then, move the produced vocab.pickle, ivocab.pickle, train_data.pickle and valid_data.pickle into *WMPoetry/corpus/*; and move test_inps.txt, test_trgs.txt and training_lines.txt into *WMPoetry/data/*. 27 | 28 | ## 5. Training 29 | In WMPoetry/codes/, run: 30 | ``` 31 | python train.py 32 | ``` 33 | The encoder and decoder will be pre-trained as a denoising seq2seq model, and then the Working Memory model is trained based on the pre-trained one. 34 | 35 | One can also edit WMPoetry/codes/**config.py** to modify the configuration, such as the hidden size, embedding size, data path, training epoch, learning rate and so on. 36 | 37 | During the training process, some training information is outputed, such as: 38 | 39 |
40 |  41 | The training and validation information is saved in the log directory, e.g., WMPoetry/log/. 42 |  43 | ## 6. Generation 44 | To generate a poem in an interactive interface, in WMPoetry/codes/, run: 45 | ``` 46 | python generate.py -v 1 47 | ``` 48 | Then one can input some keywords, select the genre pattern and the rhyme category, and get the generated poem: 49 |  50 |
51 | (screenshot: an example of interactive generation and the generated poem) 52 |  53 |  54 |
55 |  56 | We provide a genre pattern file for Chinese classical quatrains; please refer to WMPoetry/data/GenrePatterns.txt for details. 57 |  58 | By running: 59 | ``` 60 | python generate.py -v 1 -s 1 61 | ``` 62 | , one can manually select each generated line from the beam candidates. 63 |  64 |
65 | (screenshot: manually selecting a line from the beam candidates) 66 |
67 | 68 | By running: 69 | ``` 70 | python generate.py -v 1 -d 1 71 | ``` 72 | , one can get the visualization of memory reading probabilities and the contents stored in the memory for each generation step, such as: 73 | 74 |
75 | (visualization example: memory reading probabilities and the contents stored in each memory slot) 76 |
77 |  78 | These visualization pictures are saved in the log directory. *NOTE*: We use the simhei font in matplotlib for Chinese characters. Therefore, to use our provided visualization script, please make sure you have correctly installed and set this font. 79 |  80 |  81 | To generate poems with an input testing file, which contains a set of keywords and genre patterns, run: 82 | ``` 83 | python generate.py -m file -i ../data/test_inps.txt -o test_outs.txt 84 | ``` 85 |  86 | ## 7. Cite 87 | If you use our source code, please kindly cite this paper: 88 |  89 | Xiaoyuan Yi, Maosong Sun, Ruoyu Li and Zonghan Yang. Chinese Poetry Generation with a Working Memory Model. *In Proceedings of the Twenty-Seventh International Joint Conference on Artificial Intelligence*, pages 4553–4559, Stockholm, Sweden, 2018. 90 |  91 | The bib format is as follows: 92 | ``` 93 | @inproceedings{Yimemory:18, 94 | author = {Xiaoyuan Yi and Maosong Sun and Ruoyu Li and Zonghan Yang}, 95 | title = {Chinese Poetry Generation with a Working Memory Model}, 96 | year = "2018", 97 | pages = "4553--4559", 98 | booktitle = {Proceedings of the Twenty-Seventh International Joint Conference on Artificial Intelligence}, 99 | address = {Stockholm, Sweden} 100 | } 101 | ``` 102 | ## 8. System 103 | This work is a part of the automatic Chinese poetry generation system, [THUAIPoet (Jiuge, 九歌)](https://jiuge.thunlp.cn), developed by the Research Center for Natural Language Processing, Computational Humanities and Social Sciences, Tsinghua University (清华大学人工智能研究院, 自然语言处理与社会人文计算研究中心). Please refer to [THUNLP](https://github.com/thunlp) and [THUNLP Lab](http://nlp.csai.tsinghua.edu.cn/site2/) for more information. 104 |  105 |
106 | 107 | ## 9. Contact 108 | If you have any questions, suggestions or bug reports, please feel free to email yi-xy16@mails.tsinghua.edu.cn or mtmoonyi@gmail.com. 109 | -------------------------------------------------------------------------------- /codes/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 20:07:53 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 12 | ''' 13 | import os 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | import random 18 | 19 | def save_checkpoint(model_dir, epoch, model, prefix='', optimizer=None): 20 | # save model state dict 21 | checkpoint_name = "model_ckpt_{}_{}e.tar".format(prefix, epoch) 22 | model_state_path = os.path.join(model_dir, checkpoint_name) 23 | 24 | saved_dic = { 25 | 'epoch': epoch, 26 | 'model_state_dict': model.state_dict() 27 | } 28 | 29 | if optimizer is not None: 30 | saved_dic['optimizer'] = optimizer.state_dict() 31 | 32 | 33 | torch.save(saved_dic, model_state_path) 34 | 35 | # write checkpoint information 36 | log_path = os.path.join(model_dir, "ckpt_list.txt") 37 | fout = open(log_path, 'a') 38 | fout.write(checkpoint_name+"\n") 39 | fout.close() 40 | 41 | 42 | def restore_checkpoint(model_dir, device, model, optimizer=None): 43 | ckpt_list_path = os.path.join(model_dir, "ckpt_list.txt") 44 | if not os.path.exists(ckpt_list_path): 45 | print ("checkpoint list not exists, creat new one!") 46 | return None 47 | 48 | # get latest ckpt name 49 | fin = open(ckpt_list_path, 'r') 50 | latest_ckpt_path = fin.readlines()[-1].strip() 51 | fin.close() 52 | 53 | latest_ckpt_path = os.path.join(model_dir, latest_ckpt_path) 54 | if not os.path.exists(latest_ckpt_path): 55 | print ("latest checkpoint not exists!") 56 | return None 57 | 58 | 59 | print ("restore checkpoint from %s" % (latest_ckpt_path)) 60 | print ("loading...") 61 | checkpoint = torch.load(latest_ckpt_path, map_location=device) 62 | #checkpoint = torch.load(latest_ckpt_path) 63 | print ("load state dic, params: %d..." 
% (len(checkpoint['model_state_dict']))) 64 | model.load_state_dict(checkpoint['model_state_dict']) 65 | 66 | 67 | if optimizer is not None: 68 | print ("load optimizer dic...") 69 | optimizer.load_state_dict(checkpoint['optimizer']) 70 | 71 | 72 | epoch = checkpoint['epoch'] 73 | 74 | 75 | return epoch 76 | 77 | 78 | def sample_dseq(inputs, targets, logits, sample_num, tool): 79 | # inps, trgs [batch size, sen len] 80 | # logits [batch size, trg len, vocab size] 81 | batch_size = inputs.size(0) 82 | inp_len = inputs.size(1) 83 | trg_len = targets.size(1) 84 | out_len = logits.size(1) 85 | 86 | 87 | sample_num = min(sample_num, batch_size) 88 | 89 | # randomly select some examples 90 | sample_ids = random.sample(list(range(0, batch_size)), sample_num) 91 | 92 | for sid in sample_ids: 93 | # Build lines 94 | inps = [inputs[sid, t].item() for t in range(0, inp_len)] 95 | sline = tool.idxes2line(inps) 96 | 97 | # ------------------------------------------- 98 | trgs = [targets[sid, t].item() for t in range(0, trg_len)] 99 | tline = tool.idxes2line(trgs) 100 | 101 | 102 | # ------------------------------------------ 103 | probs = F.softmax(logits, dim=-1) 104 | outs = [probs[sid, t, :].cpu().data.numpy() for t in range(0, out_len)] 105 | oline = tool.greedy_search(outs) 106 | 107 | 108 | print("inp: " + sline) 109 | print("trg: " + tline) 110 | print("out: " + oline) 111 | print ("") 112 | 113 | 114 | #------------------------------ 115 | def sample_wm(keys, all_trgs, all_outs, sample_num, tool): 116 | batch_size = all_trgs[0].size(0) 117 | sample_num = min(sample_num, batch_size) 118 | 119 | # random select some examples 120 | sample_ids = random.sample(list(range(0, batch_size)), sample_num) 121 | 122 | for sid in sample_ids: 123 | key_lines = [] 124 | for key in keys: 125 | key_idxes = [key[sid, t].item() for t in range(0, key.size(1))] 126 | key_str = tool.idxes2line(key_idxes) 127 | if len(key_str) == 0: 128 | key_str = "PAD" 129 | key_lines.append(key_str) 130 | 131 | trg_lines = [] 132 | for trg in all_trgs: 133 | trg_idxes = [trg[sid, t].item() for t in range(0, trg.size(1))] 134 | trg_lines.append(tool.idxes2line(trg_idxes)) 135 | 136 | out_lines = [] 137 | for out in all_outs: 138 | probs = F.softmax(out, dim=-1) 139 | out_probs = [probs[sid, t, :].cpu().data.numpy() for t in range(0, probs.size(1))] 140 | out_lines.append(tool.greedy_search(out_probs)) 141 | 142 | #-------------------------------------------- 143 | print("keywords: " + "|".join(key_lines)) 144 | print("target: " + "|".join(trg_lines)) 145 | print("output: " + "|".join(out_lines)) 146 | print ("") 147 | 148 | 149 | def print_parameter_list(model, prefix=None): 150 | params = model.named_parameters() 151 | 152 | param_num = 0 153 | for name, param in params: 154 | if prefix is not None: 155 | seg = name.split(".")[1] 156 | if seg in prefix: 157 | print(name, param.size()) 158 | param_num += 1 159 | else: 160 | print(name, param.size()) 161 | param_num += 1 162 | 163 | print ("params num: %d" % (param_num)) 164 | #------------------------------ -------------------------------------------------------------------------------- /codes/visualization.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 22:04:36 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 
9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 12 | ''' 13 | from matplotlib import pyplot as plt 14 | plt.rcParams['font.family'] = ['simhei'] 15 | 16 | from matplotlib.colors import from_levels_and_colors 17 | 18 | import numpy as np 19 | import copy 20 | 21 | import torch 22 | 23 | class Visualization(object): 24 | """docstring for LogInfo""" 25 | def __init__(self, topic_slots, history_slots, log_path): 26 | super(Visualization).__init__() 27 | 28 | self._topic_slots = topic_slots 29 | self._history_slots = history_slots 30 | self._log_path = log_path 31 | 32 | 33 | def reset(self, keywords): 34 | self._keywords = keywords 35 | self._history_mem = [' ']*self._history_slots 36 | self._gen_lines = [] 37 | 38 | 39 | def add_gen_line(self, line): 40 | self._gen_lines.append(line.strip()) 41 | 42 | def normalization(self, ori_matrix): 43 | new_matrix = ori_matrix / ori_matrix.sum(axis=1, keepdims=True) 44 | return new_matrix 45 | 46 | 47 | def draw(self, read_log, write_log, step, visual_mode): 48 | assert visual_mode in [0, 1, 2] 49 | # read_log: (1, 1, mem_slots) * L_gen 50 | # write_log: (B, L_gen, mem_slots) 51 | current_gen_chars = [c for c in self._gen_lines[-1]] 52 | gen_len = len(current_gen_chars) 53 | 54 | if len(self._gen_lines) >= 2: 55 | last_gen_chars = [c for c in self._gen_lines[-2]] 56 | last_gen_len = len(last_gen_chars) 57 | else: 58 | last_gen_chars = [''] * gen_len 59 | last_gen_len = gen_len 60 | 61 | # (L_gen, mem_slots) 62 | mem_slots = self._topic_slots+self._history_slots+last_gen_len 63 | read_matrix = torch.cat(read_log, dim=1)[0, 0:gen_len, 0:mem_slots].detach().cpu().numpy() 64 | read_matrix = self.normalization(read_matrix) 65 | 66 | plt.figure(figsize=(11, 5)) 67 | 68 | # visualization of reading attention weights 69 | num_levels = 100 70 | vmin, vmax = read_matrix.min(), read_matrix.max() 71 | midpoint = 0 72 | levels = np.linspace(vmin, vmax, num_levels) 73 | midp = np.mean(np.c_[levels[:-1], levels[1:]], axis=1) 74 | vals = np.interp(midp, [vmin, midpoint, vmax], [0, 0.5, 1]) 75 | colors = plt.cm.seismic(vals) 76 | cmap, norm = from_levels_and_colors(levels, colors) 77 | 78 | 79 | plt.imshow(read_matrix, cmap=cmap, interpolation='none') 80 | 81 | # print generated chars and chars in the memory 82 | fontsize = 14 83 | 84 | plt.text(0.2, gen_len+0.5, "Topic Memory", fontsize=fontsize) 85 | plt.text(self._topic_slots, gen_len+0.5, "History Memory", fontsize=fontsize) 86 | if last_gen_len == 5: 87 | shift = 5 88 | else: 89 | shift = 6 90 | plt.text(self._topic_slots+shift, gen_len+0.5, "Local Memory", fontsize=fontsize) 91 | 92 | # topic memory 93 | for i in range(0, len(self._keywords)): 94 | key = self._keywords[i] 95 | if len(key) == 1: 96 | key = " " + key + " " 97 | key = key + "|" 98 | plt.text(i-0.4,-0.7, key, fontsize=fontsize) 99 | 100 | start_pos = self._topic_slots 101 | end_pos = self._topic_slots + self._history_slots 102 | 103 | # history memory 104 | for i in range(start_pos, end_pos): 105 | c = self._history_mem[i - start_pos] 106 | if i == end_pos - 1: 107 | c = c + " |" 108 | 109 | plt.text(i-0.2,-0.7, c, fontsize=fontsize) 110 | 111 | start_pos = end_pos 112 | end_pos = start_pos + last_gen_len 113 | 114 | # local memory 115 | for i in range(start_pos, end_pos): 116 | idx = i - start_pos 117 | plt.text(i-0.2,-0.7, last_gen_chars[idx], fontsize=fontsize) 118 | 119 | # generated line 120 | for i 
in range(0, len(current_gen_chars)): 121 | plt.text(-1.2, i+0.15, current_gen_chars[i], fontsize=fontsize) 122 | 123 | plt.colorbar() 124 | plt.tick_params(labelbottom=False, labelleft=False) 125 | 126 | 127 | x_major_locator = plt.MultipleLocator(1) 128 | y_major_locator = plt.MultipleLocator(1) 129 | ax = plt.gca() 130 | ax.xaxis.set_major_locator(x_major_locator) 131 | ax.yaxis.set_major_locator(y_major_locator) 132 | #plt.tight_layout() 133 | 134 | 135 | if visual_mode == 1: 136 | fig = plt.gcf() 137 | fig.savefig(self._log_path + 'visual_step_{}.png'.format(step), dpi=300, quality=100, bbox_inches="tight") 138 | elif visual_mode == 2: 139 | plt.show() 140 | 141 | 142 | # update history memory 143 | if write_log is not None: 144 | if len(last_gen_chars) == 0: 145 | print ("last generated line is empty!") 146 | 147 | write_log = write_log[0, :, :].detach().cpu().numpy() 148 | history_mem = copy.deepcopy(self._history_mem) 149 | for i, c in enumerate(last_gen_chars): 150 | selected_slot = np.argmax(write_log[i, :]) 151 | if selected_slot >= self._history_slots: 152 | continue 153 | history_mem[selected_slot] = c 154 | 155 | self._history_mem = history_mem 156 | -------------------------------------------------------------------------------- /codes/dseq_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 20:39:50 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 
12 | ''' 13 | import torch 14 | 15 | from scheduler import ISRScheduler 16 | from criterion import Criterion 17 | from decay import ExponentialDecay 18 | from logger import SimpleLogger 19 | import utils 20 | 21 | class DSeqTrainer(object): 22 | 23 | def __init__(self, hps, device): 24 | self.hps = hps 25 | self.device = device 26 | 27 | 28 | def run_validation(self, epoch, wm_model, criterion, tool, lr): 29 | logger = SimpleLogger('valid') 30 | logger.set_batch_num(tool.valid_batch_num) 31 | logger.set_log_path(self.hps.dseq_valid_log_path) 32 | logger.set_rate('learning_rate', lr) 33 | logger.set_rate('teach_ratio', wm_model.get_teach_ratio()) 34 | 35 | for step in range(0, tool.valid_batch_num): 36 | 37 | batch = tool.valid_batches[step] 38 | 39 | inps = batch[0].to(self.device) 40 | trgs = batch[1].to(self.device) 41 | ph_inps = batch[2].to(self.device) 42 | len_inps = batch[3].to(self.device) 43 | 44 | with torch.no_grad(): 45 | gen_loss, _ = self.run_step(wm_model, None, criterion, 46 | inps, trgs, ph_inps, len_inps, True) 47 | logger.add_losses(gen_loss) 48 | 49 | logger.print_log(epoch) 50 | 51 | 52 | def run_step(self, wm_model, optimizer, criterion, 53 | inps, trgs, ph_inps, len_inps, valid=False): 54 | if not valid: 55 | optimizer.zero_grad() 56 | 57 | outs = wm_model.dseq_graph(inps, trgs, ph_inps, len_inps) 58 | 59 | loss = criterion(outs, trgs) 60 | 61 | if not valid: 62 | loss.backward() 63 | torch.nn.utils.clip_grad_norm_(wm_model.dseq_parameters(), 64 | self.hps.clip_grad_norm) 65 | optimizer.step() 66 | 67 | return loss.item(), outs 68 | 69 | 70 | def run_train(self, wm_model, tool, optimizer, criterion, logger): 71 | logger.set_start_time() 72 | 73 | for step in range(0, tool.train_batch_num): 74 | 75 | batch = tool.train_batches[step] 76 | 77 | inps = batch[0].to(self.device) 78 | trgs = batch[1].to(self.device) 79 | ph_inps = batch[2].to(self.device) 80 | len_inps = batch[3].to(self.device) 81 | 82 | gen_loss, outs = self.run_step(wm_model, optimizer, criterion, 83 | inps, trgs, ph_inps, len_inps) 84 | 85 | logger.add_losses(gen_loss) 86 | logger.set_rate("learning_rate", optimizer.rate()) 87 | if step % self.hps.dseq_log_steps == 0: 88 | logger.set_end_time() 89 | utils.sample_dseq(inps, trgs, outs, self.hps.sample_num, tool) 90 | logger.print_log() 91 | logger.set_start_time() 92 | 93 | 94 | 95 | def train(self, wm_model, tool): 96 | #utils.print_parameter_list(wm_model, wm_model.dseq_parameter_names()) 97 | 98 | # load data for pre-training 99 | print ("building data for dseq...") 100 | tool.build_data(self.hps.train_data, self.hps.valid_data, 101 | self.hps.dseq_batch_size, mode='dseq') 102 | 103 | print ("train batch num: %d" % (tool.train_batch_num)) 104 | print ("valid batch num: %d" % (tool.valid_batch_num)) 105 | 106 | #input("please check the parameters, and then press any key to continue >") 107 | 108 | 109 | # training logger 110 | logger = SimpleLogger('train') 111 | logger.set_batch_num(tool.train_batch_num) 112 | logger.set_log_steps(self.hps.dseq_log_steps) 113 | logger.set_log_path(self.hps.dseq_train_log_path) 114 | logger.set_rate('learning_rate', 0.0) 115 | logger.set_rate('teach_ratio', 1.0) 116 | 117 | 118 | # build optimizer 119 | opt = torch.optim.AdamW(wm_model.dseq_parameters(), 120 | lr=1e-3, betas=(0.9, 0.99), weight_decay=self.hps.weight_decay) 121 | optimizer = ISRScheduler(optimizer=opt, warmup_steps=self.hps.dseq_warmup_steps, 122 | max_lr=self.hps.dseq_max_lr, min_lr=self.hps.dseq_min_lr, 123 | init_lr=self.hps.dseq_init_lr, beta=0.6) 
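        # ISRScheduler (see scheduler.py) wraps the optimizer with an inverse-square-root
        # learning-rate schedule: the rate grows linearly from dseq_init_lr to dseq_max_lr
        # over dseq_warmup_steps optimizer steps, then decays as
        # max_lr * warmup_steps**beta * step**(-beta), never falling below dseq_min_lr.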
124 | 125 | wm_model.train() 126 | 127 | criterion = Criterion(self.hps.pad_idx) 128 | 129 | # tech forcing ratio decay 130 | tr_decay_tool = ExponentialDecay(self.hps.dseq_burn_down_tr, self.hps.dseq_decay_tr, 131 | self.hps.dseq_min_tr) 132 | 133 | # train 134 | for epoch in range(1, self.hps.dseq_epoches+1): 135 | 136 | self.run_train(wm_model, tool, optimizer, criterion, logger) 137 | 138 | if epoch % self.hps.dseq_validate_epoches == 0: 139 | print("run validation...") 140 | wm_model.eval() 141 | print ("in training mode: %d" % (wm_model.training)) 142 | self.run_validation(epoch, wm_model, criterion, tool, optimizer.rate()) 143 | wm_model.train() 144 | print ("validation Done: %d" % (wm_model.training)) 145 | 146 | 147 | if (self.hps.dseq_save_epoches >= 1) and \ 148 | (epoch % self.hps.dseq_save_epoches) == 0: 149 | # save checkpoint 150 | print("saving model...") 151 | utils.save_checkpoint(self.hps.model_dir, epoch, wm_model, prefix="dseq") 152 | 153 | 154 | logger.add_epoch() 155 | 156 | print ("teach forcing ratio decay...") 157 | wm_model.set_teach_ratio(tr_decay_tool.do_step()) 158 | logger.set_rate('teach_ratio', tr_decay_tool.get_rate()) 159 | 160 | print("shuffle data...") 161 | tool.shuffle_train_data() -------------------------------------------------------------------------------- /codes/generate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 18:17:27 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 12 | ''' 13 | from generator import Generator 14 | from config import hparams, device 15 | import copy 16 | import argparse 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description="The parametrs for the generator.") 20 | parser.add_argument("-m", "--mode", type=str, choices=['interact', 'file'], default='interact', 21 | help='The mode of generation. interact: generate in a interactive mode.\ 22 | file: take an input file and generate poems for each input in the file.') 23 | parser.add_argument("-b", "--bsize", type=int, default=20, help="beam size, 20 by default.") 24 | parser.add_argument("-v", "--verbose", type=int, default=0, choices=[0, 1, 2, 3], 25 | help="Show other information during the generation, False by default.") 26 | parser.add_argument("-d", "--draw", type=int, default=0, choices=[0, 1, 2], 27 | help="Show the visualization of memory reading and writing. It only works in the interact mode.\ 28 | 0: not work, 1: save the visualization as pictures, 2: show the visualization at each step.") 29 | parser.add_argument("-s", "--select", type=int, default=0, 30 | help="If manually select each generated line from beam candidates? False by default.\ 31 | It works only in the interact mode.") 32 | parser.add_argument("-i", "--inp", type=str, 33 | help="input file path. it works only in the file mode.") 34 | parser.add_argument("-o", "--out", type=str, 35 | help="output file path. 
it works only in the file mode") 36 | return parser.parse_args() 37 | 38 | 39 | class GenerateTool(object): 40 | """docstring for GenerateTool""" 41 | def __init__(self): 42 | super(GenerateTool, self).__init__() 43 | self.generator = Generator(hparams, device) 44 | self._load_patterns(hparams.data_dir+"/GenrePatterns.txt") 45 | 46 | 47 | def _load_patterns(self, path): 48 | with open(path, 'r') as fin: 49 | lines = fin.readlines() 50 | 51 | self._patterns = [] 52 | ''' 53 | each line contains: 54 | pattern id, pattern name, the number of lines, 55 | pattern: 0 either, 31 pingze, 32 ze, 33 rhyme position 56 | ''' 57 | for line in lines: 58 | line = line.strip() 59 | para = line.split("#") 60 | pas = para[3].split("|") 61 | newpas = [] 62 | for pa in pas: 63 | pa = pa.split(" ") 64 | newpas.append([int(p) for p in pa]) 65 | 66 | self._patterns.append((para[1], newpas)) 67 | 68 | self.p_num = len(self._patterns) 69 | print ("load %d patterns." % (self.p_num)) 70 | 71 | 72 | def build_pattern(self, pstr): 73 | pstr_vec = pstr.split("|") 74 | patterns = [] 75 | for pstr in pstr_vec: 76 | pas = pstr.split(" ") 77 | pas = [int(pa) for pa in pas] 78 | patterns.append(pas) 79 | 80 | return patterns 81 | 82 | 83 | def generate_file(self, args): 84 | beam_size = args.bsize 85 | verbose = args.verbose 86 | manu = True if args.select ==1 else False 87 | 88 | assert args.inp is not None 89 | assert args.out is not None 90 | 91 | with open(args.inp, 'r') as fin: 92 | inps = fin.readlines() 93 | 94 | 95 | fout = open(args.out, 'w') 96 | 97 | poems = [] 98 | N = len(inps) 99 | log_step = max(int(N/100), 2) 100 | for i, inp in enumerate(inps): 101 | para = inp.strip().split("#") 102 | keywords = para[0].split(" ") 103 | pattern = self.build_pattern(para[1]) 104 | 105 | lines, info = self.generator.generate_one(keywords, pattern, 106 | beam_size, verbose, manu=manu) 107 | 108 | if len(lines) != 4: 109 | ans = info 110 | else: 111 | ans = "|".join(lines) 112 | 113 | fout.write(ans+"\n") 114 | 115 | if i % log_step == 0: 116 | print ("generating, %d/%d" % (i, N)) 117 | fout.flush() 118 | 119 | 120 | fout.close() 121 | 122 | 123 | def _set_rhyme_into_pattern(self, ori_pattern, rhyme): 124 | pattern = copy.deepcopy(ori_pattern) 125 | for i in range(0, len(pattern)): 126 | if pattern[i][-1] == 33: 127 | pattern[i][-1] = rhyme 128 | return pattern 129 | 130 | 131 | def generate_manu(self, args): 132 | beam_size = args.bsize 133 | verbose = args.verbose 134 | manu = True if args.select ==1 else False 135 | visual_mode = args.draw 136 | 137 | while True: 138 | keys = input("please input keywords (with whitespace split), 4 at most > ") 139 | pattern_id = int(input("please select genre pattern 0~{} > ".format(self.p_num-1))) 140 | rhyme = int(input("please input rhyme id, 1~30> ")) 141 | 142 | ori_pattern = self._patterns[pattern_id] 143 | name = ori_pattern[0] 144 | pattern = ori_pattern[1] 145 | pattern = self._set_rhyme_into_pattern(pattern, rhyme) 146 | print ("select pattern: %s" % (name)) 147 | 148 | keywords = keys.strip().split(" ") 149 | lines, info = self.generator.generate_one(keywords, pattern, 150 | beam_size, verbose, manu=manu, visual=visual_mode) 151 | 152 | if len(lines) != 4: 153 | print("generation failed!") 154 | continue 155 | else: 156 | print("\n".join(lines)) 157 | 158 | 159 | def main(): 160 | args = parse_args() 161 | generate_tool = GenerateTool() 162 | if args.mode == 'interact': 163 | generate_tool.generate_manu(args) 164 | else: 165 | generate_tool.generate_file(args) 166 | 167 | 168 | 
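# A minimal sketch of the file-mode input format, inferred from generate_file()
# and data/GenrePatterns.txt (the keywords below are only illustrative):
#
#   <space-separated keywords>#<pattern of line 1>|<pattern of line 2>|...
#
# e.g. a seven-character quatrain input line could look like:
#   秋风 孤雁#0 32 0 31 31 32 32|0 31 0 32 32 31 33|0 31 0 32 31 31 32|0 32 31 31 32 32 33
#
# where 0 = either tone, 31 = level tone, 32 = oblique tone, and 33 (or a concrete
# rhyme category id in 1~30) marks the rhyme position, as described in _load_patterns().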
if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /preprocess/pattern_extractor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi and Jiannan Liang 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 20:24:22 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 12 | ''' 13 | 14 | import pickle 15 | import os 16 | import copy 17 | 18 | import numpy as np 19 | 20 | from rhythm_tool import RhythmRecognizer 21 | 22 | 23 | class PatternExtractor(object): 24 | 25 | def __init__(self, data_dir): 26 | ''' 27 | rhythm patterns. 28 | for Chinese quatrains, we generalize four main poem-level patterns. 29 | 30 | NOTE: We use pingshuiyun (The Pingshui Rhyme Category) 31 | We only consider level-tone rhyme in terms of the requirements of 32 | Chinese classical quatrains. 33 | 0: either level or oblique tone; 1~30 rhyme categorizes, 34 | 31: level tone, 32: oblique tone 35 | 36 | ''' 37 | self._RHYTHM_PATTERNS = {7: [[0, 32, 0, 31, 31, 32, 32], [0, 31, 0, 32, 32, 31, 31], 38 | [0, 32, 31, 31, 32, 32, 31], [0, 31, 0, 32, 31, 31, 32]], 39 | 5: [[0, 31, 31, 32, 32], [0, 32, 32, 31, 31], [31, 31, 32, 32, 31], [0, 32, 0, 31, 32]]} 40 | 41 | 42 | ''' 43 | rhythm patterns. 44 | for Chinese quatrains, we generalize four main poem-level patterns 45 | ''' 46 | self._RHYTHM_TYPES = [[0, 1, 3, 2], [1, 2, 0, 1], [2, 1, 3, 2], [3, 2, 0, 1]] 47 | 48 | 49 | self._rhythm_tool = RhythmRecognizer(data_dir+"pingsheng.txt", data_dir+"zesheng.txt") 50 | 51 | self._load_rhythm_dic(data_dir+"pingsheng.txt", data_dir+"zesheng.txt") 52 | self._load_rhyme_dic(data_dir+"pingshui.txt", data_dir+"pingshui_amb.pkl") 53 | 54 | 55 | 56 | def _load_rhythm_dic(self, level_path, oblique_path): 57 | with open(level_path, 'r') as fin: 58 | level_chars = fin.read() 59 | 60 | with open(oblique_path, 'r') as fin: 61 | oblique_chars = fin.read() 62 | 63 | self._level_list = list(level_chars) 64 | self._oblique_list = list(oblique_chars) 65 | 66 | 67 | print (" rhythm dic loaded, level tone chars: %d, oblique tone chars: %d" %\ 68 | (len(self._level_list), len(self._oblique_list))) 69 | 70 | 71 | #------------------------------------------ 72 | def _load_rhyme_dic(self, rhyme_dic_path, rhyme_disamb_path): 73 | 74 | self._rhyme_dic = {} # char id to rhyme category ids 75 | self._rhyme_idic = {} # rhyme category id to char ids 76 | 77 | with open(rhyme_dic_path, 'r') as fin: 78 | lines = fin.readlines() 79 | 80 | amb_count = 0 81 | for line in lines: 82 | (char, rhyme_id) = line.strip().split(' ') 83 | 84 | rhyme_id = int(rhyme_id) 85 | 86 | if not char in self._rhyme_dic: 87 | self._rhyme_dic.update({char:[rhyme_id]}) 88 | elif not rhyme_id in self._rhyme_dic[char]: 89 | self._rhyme_dic[char].append(rhyme_id) 90 | amb_count += 1 91 | 92 | if not rhyme_id in self._rhyme_idic: 93 | self._rhyme_idic.update({rhyme_id:[char]}) 94 | else: 95 | self._rhyme_idic[rhyme_id].append(char) 96 | 97 | print (" rhyme dic loaded, ambiguous rhyme chars: %d" % (amb_count)) 98 | 99 | # load data for rhyme disambiguation 100 | self._ngram_rhyme_map = {} # rhyme id list of each bigram or trigram 101 | 
self._char_rhyme_map = {} # the most likely rhyme id for each char 102 | # load the calculated data, if there is any 103 | #print (rhyme_disamb_path) 104 | assert rhyme_disamb_path is not None and os.path.exists(rhyme_disamb_path) 105 | 106 | with open(rhyme_disamb_path, 'rb') as fin: 107 | self._char_rhyme_map = pickle.load(fin) 108 | self._ngram_rhyme_map = pickle.load(fin) 109 | 110 | print (" rhyme disamb data loaded, cached chars: %d, ngrams: %d" 111 | % (len(self._char_rhyme_map), len(self._ngram_rhyme_map))) 112 | 113 | 114 | def get_line_rhyme(self, line): 115 | """ we use statistics of ngram to disambiguate the rhyme category, 116 | but there is still risk of mismatching and ambiguity 117 | """ 118 | tail_char = line[-1] 119 | 120 | if tail_char in self._rhyme_dic: 121 | rhyme_candis = self._rhyme_dic[tail_char] 122 | if len(rhyme_candis) == 1: 123 | return rhyme_candis[0] 124 | 125 | if tail_char in self._char_rhyme_map: 126 | bigram = line[-2] + line[-1] 127 | if bigram in self._ngram_rhyme_map: 128 | return int(self._ngram_rhyme_map[bigram][0]) 129 | 130 | trigram = line[-3] + line[-2] + line[-1] 131 | if trigram in self._ngram_rhyme_map: 132 | return int(self._ngram_rhyme_map[trigram][0]) 133 | 134 | return int(self._char_rhyme_map[tail_char][0]) 135 | 136 | 137 | return -1 138 | 139 | 140 | def get_poem_rhyme(self, sens): 141 | assert len(sens) == 4 142 | rhymes = [self.get_line_rhyme(sen) for sen in sens] 143 | 144 | #print (rhymes) 145 | #input(">") 146 | 147 | if rhymes[1] == -1 and rhymes[3] != -1: 148 | rhymes[1] = rhymes[3] 149 | elif rhymes[1] != -1 and rhymes[3] == -1: 150 | rhymes[3] = rhymes[1] 151 | elif rhymes[1] == -1 and rhymes[3] == -1: 152 | return [] 153 | 154 | 155 | if (rhymes[0] != -1) and (rhymes[0]!= rhymes[1]): 156 | rhymes[0] = rhymes[1] 157 | 158 | rhymes[2] = -1 159 | 160 | return rhymes 161 | 162 | 163 | def pattern_complete(self, rhythm_ids): 164 | if rhythm_ids.count(-1) == 0: 165 | return rhythm_ids 166 | 167 | if rhythm_ids.count(-1) > 1: 168 | return [] 169 | 170 | #print (rhythm_ids) 171 | for poem_pattern in self._RHYTHM_TYPES: 172 | eq = (np.array(poem_pattern) \ 173 | == np.array(rhythm_ids)).astype(np.float) 174 | 175 | if np.sum(eq) != 3: 176 | continue 177 | 178 | pos = list(eq).index(0) 179 | rhythm_ids[pos] = poem_pattern[pos] 180 | #print (rhythm_ids) 181 | #input(">") 182 | return rhythm_ids 183 | 184 | return [] 185 | 186 | 187 | def get_poem_rhythm(self, sens, length): 188 | assert len(sens) == 4 189 | 190 | rhythm_ids = [] 191 | for sen in sens: 192 | #print (sen) 193 | rhythm_id = self._rhythm_tool.get_rhythm(sen) 194 | rhythm_ids.append(rhythm_id) 195 | 196 | 197 | rhythm_ids = self.pattern_complete(rhythm_ids) 198 | 199 | if len(rhythm_ids) == 0: 200 | return [] 201 | 202 | rhythm_pattern = [copy.deepcopy(self._RHYTHM_PATTERNS[length][id]) 203 | for id in rhythm_ids] 204 | 205 | return rhythm_pattern -------------------------------------------------------------------------------- /preprocess/rhythm_tool.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Ruoyu Li 3 | # @Description: 4 | ''' 5 | Copyright 2019 THUNLP Lab. All Rights Reserved. 6 | This code is part of the online Chinese poetry generation system, Jiuge. 7 | System URL: https://jiuge.thunlp.cn/. 8 | Github: https://github.com/THUNLP-AIPoet. 
9 | ''' 10 | class RhythmRecognizer(object): 11 | """Get the rhythm id of a input line 12 | This tool can be applied to Chinese classical quatrains only 13 | """ 14 | 15 | def __init__(self, ping_file, ze_file): 16 | 17 | # read level tone char list 18 | with open(ping_file, 'r') as fin: 19 | self.__ping = fin.read() 20 | #print (type(self.__ping)) 21 | 22 | with open(ze_file, 'r') as fin: 23 | self.__ze = fin.read() 24 | #print (type(self.__ze)) 25 | 26 | def get_rhythm(self, sentence): 27 | # print "#" + sentence + "#" 28 | if(len(sentence) == 5): 29 | #1 30 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze): 31 | return 0 32 | #2 33 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ze): 34 | return 0 35 | #3 36 | if(sentence[0] in self.__ze and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze): 37 | return 0 38 | #4 39 | if(sentence[0] in self.__ze and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze): 40 | return 0 41 | #5 42 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze): 43 | return 0 44 | #6 45 | if(sentence[0] in self.__ze and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping): 46 | return 1 47 | #7 48 | if(sentence[0] in self.__ping and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping): 49 | return 1 50 | #8 51 | if(sentence[0] in self.__ping and sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze): 52 | return 3 53 | #9 54 | if(sentence[0] in self.__ping and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze): 55 | return 3 56 | #10 57 | if(sentence[0] in self.__ze and sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze): 58 | return 3 59 | #11 60 | if(sentence[0] in self.__ze and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze): 61 | return 3 62 | #12 63 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ping): 64 | return 2 65 | #13 66 | if(sentence[0] in self.__ze and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ping): 67 | return 2 68 | #14 69 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ping): 70 | return 2 71 | 72 | 73 | elif (len(sentence) == 7): 74 | #1 75 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ze): 76 | return 0 77 | #2 78 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ze and sentence[6] in self.__ze): 79 | return 0 80 | #3 81 | if(sentence[1] in self.__ze and 
sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ze): 82 | return 0 83 | #4 84 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze): 85 | return 0 86 | #5 87 | if(sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze): 88 | return 0 89 | #6 90 | if(sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ping): 91 | return 1 92 | #7 93 | if(sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ping): 94 | return 1 95 | #8 96 | if(sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ping and sentence[5] in self.__ping and sentence[6] in self.__ze): 97 | return 3 98 | #9 99 | if(sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze): 100 | return 3 101 | #10 102 | if(sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ping and sentence[5] in self.__ping and sentence[6] in self.__ze): 103 | return 3 104 | #11 105 | if(sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze): 106 | return 3 107 | #12 108 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ze and sentence[6] in self.__ping): 109 | return 2 110 | #13 111 | if(sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ping): 112 | return 2 113 | #14 114 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ping): 115 | return 2 116 | else: 117 | return -2 118 | return -1 -------------------------------------------------------------------------------- /codes/wm_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 20:39:43 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 
12 | ''' 13 | import torch 14 | from scheduler import ISRScheduler 15 | from criterion import Criterion 16 | from decay import ExponentialDecay 17 | from logger import SimpleLogger 18 | import utils 19 | 20 | 21 | class WMTrainer(object): 22 | 23 | def __init__(self, hps, device): 24 | self.hps = hps 25 | self.device = device 26 | 27 | def run_validation(self, epoch, wm_model, criterion, tool, lr): 28 | logger = SimpleLogger('valid') 29 | logger.set_batch_num(tool.valid_batch_num) 30 | logger.set_log_path(self.hps.valid_log_path) 31 | logger.set_rate('learning_rate', lr) 32 | logger.set_rate('teach_ratio', wm_model.get_teach_ratio()) 33 | logger.set_rate('temperature', wm_model.get_tau()) 34 | 35 | for step in range(0, tool.valid_batch_num): 36 | 37 | batch = tool.valid_batches[step] 38 | 39 | all_inps = [inps.to(self.device) for inps in batch[0]] 40 | all_trgs = [trgs.to(self.device) for trgs in batch[1]] 41 | all_ph_inps = [ph_inps.to(self.device) for ph_inps in batch[2]] 42 | all_len_inps = [len_inps.to(self.device) for len_inps in batch[3]] 43 | keys = [key.to(self.device) for key in batch[4]] 44 | 45 | with torch.no_grad(): 46 | gen_loss, _ = self.run_step(wm_model, None, criterion, 47 | all_inps, all_trgs, all_ph_inps, all_len_inps, keys, True) 48 | 49 | logger.add_losses(gen_loss) 50 | 51 | logger.print_log(epoch) 52 | 53 | 54 | def run_step(self, wm_model, optimizer, criterion, 55 | all_inps, all_trgs, all_ph_inps, all_len_inps, keys, valid=False): 56 | 57 | if not valid: 58 | optimizer.zero_grad() 59 | 60 | all_outs = wm_model(all_inps, all_trgs, all_ph_inps, all_len_inps, keys) 61 | 62 | loss_vec = [] 63 | assert len(all_outs) == len(all_trgs) 64 | for out, trg in zip(all_outs, all_trgs): 65 | loss = criterion(out, trg) 66 | loss_vec.append(loss.unsqueeze(0)) 67 | 68 | loss = torch.cat(loss_vec, dim=0).mean() 69 | 70 | if not valid: 71 | loss.backward() 72 | torch.nn.utils.clip_grad_norm_(wm_model.parameters(), 73 | self.hps.clip_grad_norm) 74 | optimizer.step() 75 | 76 | return loss.item(), all_outs 77 | 78 | # ------------------------------------------------------------------------- 79 | def run_train(self, wm_model, tool, optimizer, criterion, logger): 80 | 81 | logger.set_start_time() 82 | 83 | for step in range(0, tool.train_batch_num): 84 | 85 | batch = tool.train_batches[step] 86 | all_inps = [inps.to(self.device) for inps in batch[0]] 87 | all_trgs = [trgs.to(self.device) for trgs in batch[1]] 88 | all_ph_inps = [ph_inps.to(self.device) for ph_inps in batch[2]] 89 | all_len_inps = [len_inps.to(self.device) for len_inps in batch[3]] 90 | keys = [key.to(self.device) for key in batch[4]] 91 | 92 | # train the classifier, recognition network and decoder 93 | gen_loss, all_outs = self.run_step(wm_model, optimizer, criterion, 94 | all_inps, all_trgs, all_ph_inps, all_len_inps, keys) 95 | 96 | logger.add_losses(gen_loss) 97 | logger.set_rate("learning_rate", optimizer.rate()) 98 | 99 | # temperature annealing 100 | wm_model.set_tau(self.tau_decay_tool.do_step()) 101 | logger.set_rate('temperature', self.tau_decay_tool.get_rate()) 102 | 103 | if step % self.hps.log_steps == 0: 104 | logger.set_end_time() 105 | utils.sample_wm(keys, all_trgs, all_outs, self.hps.sample_num, tool) 106 | logger.print_log() 107 | logger.set_start_time() 108 | 109 | 110 | 111 | def train(self, wm_model, tool): 112 | #utils.print_parameter_list(wm_model) 113 | # load data for pre-training 114 | print ("building data for wm...") 115 | tool.build_data(self.hps.train_data, self.hps.valid_data, 116 | 
self.hps.batch_size, mode='wm') 117 | 118 | print ("train batch num: %d" % (tool.train_batch_num)) 119 | print ("valid batch num: %d" % (tool.valid_batch_num)) 120 | 121 | 122 | #input("please check the parameters, and then press any key to continue >") 123 | 124 | # training logger 125 | logger = SimpleLogger('train') 126 | logger.set_batch_num(tool.train_batch_num) 127 | logger.set_log_steps(self.hps.log_steps) 128 | logger.set_log_path(self.hps.train_log_path) 129 | logger.set_rate('learning_rate', 0.0) 130 | logger.set_rate('teach_ratio', 1.0) 131 | logger.set_rate('temperature', 1.0) 132 | 133 | 134 | # build optimizer 135 | opt = torch.optim.AdamW(wm_model.parameters(), 136 | lr=1e-3, betas=(0.9, 0.99), weight_decay=self.hps.weight_decay) 137 | optimizer = ISRScheduler(optimizer=opt, warmup_steps=self.hps.warmup_steps, 138 | max_lr=self.hps.max_lr, min_lr=self.hps.min_lr, 139 | init_lr=self.hps.init_lr, beta=0.6) 140 | 141 | wm_model.train() 142 | 143 | null_idxes = tool.load_function_tokens(self.hps.data_dir + "fchars.txt").to(self.device) 144 | wm_model.set_null_idxes(null_idxes) 145 | 146 | criterion = Criterion(self.hps.pad_idx) 147 | 148 | # change each epoch 149 | tr_decay_tool = ExponentialDecay(self.hps.burn_down_tr, self.hps.decay_tr, self.hps.min_tr) 150 | # change each iteration 151 | self.tau_decay_tool = ExponentialDecay(0, self.hps.tau_annealing_steps, self.hps.min_tau) 152 | 153 | 154 | # ----------------------------------------------------------- 155 | # train with all data 156 | for epoch in range(1, self.hps.max_epoches+1): 157 | 158 | self.run_train(wm_model, tool, optimizer, criterion, logger) 159 | 160 | if epoch % self.hps.validate_epoches == 0: 161 | print("run validation...") 162 | wm_model.eval() 163 | print ("in training mode: %d" % (wm_model.training)) 164 | self.run_validation(epoch, wm_model, criterion, tool, optimizer.rate()) 165 | wm_model.train() 166 | print ("validation Done: %d" % (wm_model.training)) 167 | 168 | 169 | if (self.hps.save_epoches >= 1) and \ 170 | (epoch % self.hps.save_epoches) == 0: 171 | # save checkpoint 172 | print("saving model...") 173 | utils.save_checkpoint(self.hps.model_dir, epoch, wm_model, prefix="wm") 174 | 175 | 176 | logger.add_epoch() 177 | 178 | print ("teach forcing ratio decay...") 179 | wm_model.set_teach_ratio(tr_decay_tool.do_step()) 180 | logger.set_rate('teach_ratio', tr_decay_tool.get_rate()) 181 | 182 | print("shuffle data...") 183 | tool.shuffle_train_data() -------------------------------------------------------------------------------- /codes/filter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi and Jiannan Liang 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 18:22:16 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 
12 | ''' 13 | import pickle 14 | import copy 15 | import os 16 | 17 | class PoetryFilter(object): 18 | 19 | 20 | def __init__(self, vocab, ivocab, data_dir): 21 | self._vocab = vocab 22 | self._ivocab = ivocab 23 | 24 | self._load_rhythm_dic(data_dir+"pingsheng.txt", data_dir+"zesheng.txt") 25 | self._load_rhyme_dic(data_dir+"pingshui.txt", data_dir+"pingshui_amb.pkl") 26 | self._load_line_lib(data_dir+"training_lines.txt") 27 | 28 | 29 | def _load_line_lib(self, data_path): 30 | self._line_lib = {} 31 | 32 | with open(data_path, 'r') as fin: 33 | lines = fin.readlines() 34 | 35 | for line in lines: 36 | line = line.strip() 37 | self._line_lib[line] = 1 38 | 39 | print (" line lib loaded, %d lines: " % (len(self._line_lib))) 40 | 41 | 42 | 43 | def _load_rhythm_dic(self, level_path, oblique_path): 44 | with open(level_path, 'r') as fin: 45 | level_chars = fin.read() 46 | 47 | with open(oblique_path, 'r') as fin: 48 | oblique_chars = fin.read() 49 | 50 | self._level_list = [] 51 | self._oblique_list = [] 52 | # convert char to id 53 | for char, idx in self._vocab.items(): 54 | if char in level_chars: 55 | self._level_list.append(idx) 56 | 57 | if char in oblique_chars: 58 | self._oblique_list.append(idx) 59 | 60 | print (" rhythm dic loaded, level tone chars: %d, oblique tone chars: %d" %\ 61 | (len(self._level_list), len(self._oblique_list))) 62 | 63 | 64 | #------------------------------------------ 65 | def _load_rhyme_dic(self, rhyme_dic_path, rhyme_disamb_path): 66 | 67 | self._rhyme_dic = {} # char id to rhyme category ids 68 | self._rhyme_idic = {} # rhyme category id to char ids 69 | 70 | with open(rhyme_dic_path, 'r') as fin: 71 | lines = fin.readlines() 72 | 73 | 74 | amb_count = 0 75 | for line in lines: 76 | (char, rhyme_id) = line.strip().split(' ') 77 | if char not in self._vocab: 78 | continue 79 | char_id = self._vocab[char] 80 | rhyme_id = int(rhyme_id) 81 | 82 | if not char_id in self._rhyme_dic: 83 | self._rhyme_dic.update({char_id:[rhyme_id]}) 84 | elif not rhyme_id in self._rhyme_dic[char_id]: 85 | self._rhyme_dic[char_id].append(rhyme_id) 86 | 87 | 88 | if not rhyme_id in self._rhyme_idic: 89 | self._rhyme_idic.update({rhyme_id:[char_id]}) 90 | else: 91 | self._rhyme_idic[rhyme_id].append(char_id) 92 | 93 | 94 | # load data for rhyme disambiguation 95 | self._ngram_rhyme_map = {} # rhyme id list of each bigram or trigram 96 | self._char_rhyme_map = {} # the most likely rhyme id for each char 97 | # load the calculated data, if there is any 98 | #print (rhyme_disamb_path) 99 | assert rhyme_disamb_path is not None and os.path.exists(rhyme_disamb_path) 100 | 101 | with open(rhyme_disamb_path, 'rb') as fin: 102 | self._char_rhyme_map = pickle.load(fin) 103 | self._ngram_rhyme_map = pickle.load(fin) 104 | 105 | print (" rhyme disamb data loaded, cached chars: %d, ngrams: %d" 106 | % (len(self._char_rhyme_map), len(self._ngram_rhyme_map))) 107 | 108 | 109 | 110 | def get_line_rhyme(self, line): 111 | """ we use statistics of ngram to disambiguate the rhyme category, 112 | but there is still a risk of mismatching and ambiguity 113 | """ 114 | tail_char = line[-1] 115 | 116 | if tail_char in self._char_rhyme_map: 117 | rhyme_candis = self._char_rhyme_map[tail_char] 118 | if len(rhyme_candis) == 1: 119 | return rhyme_candis[0] 120 | 121 | bigram = line[-2] + line[-1] 122 | if bigram in self._ngram_rhyme_map: 123 | return self._ngram_rhyme_map[bigram][0] 124 | 125 | trigram = line[-3] + line[-2] + line[-1] 126 | if trigram in self._ngram_rhyme_map: 127 | return 
self._ngram_rhyme_map[trigram][0] 128 | 129 | return self._char_rhyme_map[tail_char][0] 130 | 131 | 132 | if not tail_char in self._vocab: 133 | return -1 134 | else: 135 | tail_id = self._vocab[tail_char] 136 | 137 | 138 | if tail_id in self._rhyme_dic: 139 | return self._rhyme_dic[tail_id][0] 140 | 141 | return -1 142 | 143 | # ------------------------------ 144 | def reset(self, length, verbose): 145 | assert length == 5 or length == 7 146 | self._length = length 147 | self._repetitive_ids = [] 148 | self._verbose = verbose 149 | 150 | 151 | def add_repetitive(self, ids): 152 | self._repetitive_ids = list(set(ids+self._repetitive_ids)) 153 | 154 | 155 | # ------------------------------- 156 | def get_level_cids(self): 157 | return copy.deepcopy(self._level_list) 158 | 159 | def get_oblique_cids(self): 160 | return copy.deepcopy(self._oblique_list) 161 | 162 | def get_rhyme_cids(self, rhyme_id): 163 | if rhyme_id not in self._rhyme_idic: 164 | return [] 165 | else: 166 | return copy.deepcopy(self._rhyme_idic[rhyme_id]) 167 | 168 | def get_repetitive_ids(self): 169 | return copy.deepcopy(self._repetitive_ids) 170 | 171 | 172 | def filter_illformed(self, lines, costs, states, aligns, rhyme_id): 173 | if len(lines) == 0: 174 | return [], [], [], [] 175 | 176 | new_lines, new_costs = [], [] 177 | new_states, new_aligns = [], [] 178 | 179 | len_error, lib_error, rhyme_error = 0, 0, 0 180 | 181 | for i in range(len(lines)): 182 | #print (lines[i]) 183 | if len(lines[i]) < self._length: 184 | len_error += 1 185 | continue 186 | 187 | line = lines[i][0:self._length] 188 | 189 | # we filter out the lines that already exist in the 190 | # training set, to guarantee the novelty of generated poems 191 | if line in self._line_lib: 192 | lib_error += 1 193 | continue 194 | 195 | if 1 <= rhyme_id <= 30: 196 | if self.get_line_rhyme(line) != rhyme_id: 197 | rhyme_error += 1 198 | continue 199 | 200 | new_lines.append(line) 201 | new_costs.append(costs[i]) 202 | new_states.append(states[i]) 203 | new_aligns.append(aligns[i]) 204 | 205 | 206 | if self._verbose >= 3: 207 | print ("input lines: %d, filter out %d illformed lines, %d remain" 208 | % (len(lines), len(lines)-len(new_lines), len(new_lines))) 209 | print ("%d len error, %d exist in lib, %d rhyme error" 210 | % (len_error, lib_error, rhyme_error)) 211 | 212 | return new_lines, new_costs, new_states, new_aligns -------------------------------------------------------------------------------- /codes/rhythm_tool.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Ruoyu Li 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 18:21:37 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet.
12 | ''' 13 | 14 | class RhythmRecognizer(object): 15 | """Get the rhythm id of a input line 16 | This tool can be applied to Chinese classical quatrains only 17 | """ 18 | 19 | def __init__(self, ping_file, ze_file): 20 | 21 | # read level tone char list 22 | with open(ping_file, 'r') as fin: 23 | self.__ping = fin.read() 24 | #print (type(self.__ping)) 25 | 26 | with open(ze_file, 'r') as fin: 27 | self.__ze = fin.read() 28 | #print (type(self.__ze)) 29 | 30 | def get_rhythm(self, sentence): 31 | # print "#" + sentence + "#" 32 | if(len(sentence) == 5): 33 | #1 34 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze): 35 | return 0 36 | #2 37 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ze): 38 | return 0 39 | #3 40 | if(sentence[0] in self.__ze and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze): 41 | return 0 42 | #4 43 | if(sentence[0] in self.__ze and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze): 44 | return 0 45 | #5 46 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze): 47 | return 0 48 | #6 49 | if(sentence[0] in self.__ze and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping): 50 | return 1 51 | #7 52 | if(sentence[0] in self.__ping and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping): 53 | return 1 54 | #8 55 | if(sentence[0] in self.__ping and sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze): 56 | return 3 57 | #9 58 | if(sentence[0] in self.__ping and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze): 59 | return 3 60 | #10 61 | if(sentence[0] in self.__ze and sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze): 62 | return 3 63 | #11 64 | if(sentence[0] in self.__ze and sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze): 65 | return 3 66 | #12 67 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ping): 68 | return 2 69 | #13 70 | if(sentence[0] in self.__ze and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ping): 71 | return 2 72 | #14 73 | if(sentence[0] in self.__ping and sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ping): 74 | return 2 75 | 76 | 77 | elif (len(sentence) == 7): 78 | #1 79 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ze): 80 | return 0 81 | #2 82 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ze and sentence[6] in self.__ze): 83 | return 0 84 | #3 85 | if(sentence[1] in 
self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ze): 86 | return 0 87 | #4 88 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze): 89 | return 0 90 | #5 91 | if(sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze): 92 | return 0 93 | #6 94 | if(sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ping): 95 | return 1 96 | #7 97 | if(sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ping): 98 | return 1 99 | #8 100 | if(sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ping and sentence[5] in self.__ping and sentence[6] in self.__ze): 101 | return 3 102 | #9 103 | if(sentence[1] in self.__ping and sentence[2] in self.__ping and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze): 104 | return 3 105 | #10 106 | if(sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ping and sentence[5] in self.__ping and sentence[6] in self.__ze): 107 | return 3 108 | #11 109 | if(sentence[1] in self.__ping and sentence[2] in self.__ze and sentence[3] in self.__ze and sentence[4] in self.__ze and sentence[5] in self.__ping and sentence[6] in self.__ze): 110 | return 3 111 | #12 112 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ze and sentence[5] in self.__ze and sentence[6] in self.__ping): 113 | return 2 114 | #13 115 | if(sentence[1] in self.__ze and sentence[2] in self.__ze and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ping): 116 | return 2 117 | #14 118 | if(sentence[1] in self.__ze and sentence[2] in self.__ping and sentence[3] in self.__ping and sentence[4] in self.__ping and sentence[5] in self.__ze and sentence[6] in self.__ping): 119 | return 2 120 | else: 121 | return -2 122 | return -1 123 | -------------------------------------------------------------------------------- /codes/beam.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 18:04:40 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 
12 | ''' 13 | import numpy as np 14 | import torch 15 | import copy 16 | 17 | 18 | class Hypothesis(object): 19 | ''' 20 | a hypothesis which holds the generated tokens, 21 | current state, beam score and memory reading weights 22 | ''' 23 | def __init__(self, tokens, states, score, read_aligns): 24 | self.score = score 25 | self.states = states 26 | self.candidate = copy.deepcopy(tokens) 27 | self.read_aligns = read_aligns # (1, L_i, mem_slots) 28 | 29 | 30 | class PoetryBeam(object): 31 | def __init__(self, device, beam_size, length, B_ID, E_ID, UNK_ID, 32 | level_char_ids, oblique_char_ids): 33 | """Initialize params.""" 34 | self.device = device 35 | 36 | self._length = length 37 | self._beam_size = beam_size 38 | 39 | self._B_ID = B_ID 40 | self._E_ID = E_ID 41 | self._UNK_ID = UNK_ID 42 | 43 | self._level_cids = level_char_ids 44 | self._oblique_cids = oblique_char_ids 45 | 46 | 47 | def reset(self, init_state, rhythms, rhyme_char_ids, repetitive_ids): 48 | # reset before generating each line 49 | self._hypotheses \ 50 | = [Hypothesis([self._B_ID], [init_state.clone().detach()], 0.0, []) 51 | for _ in range(0, self._beam_size)] 52 | 53 | self._completed_hypotheses = [] 54 | 55 | self._rhythms = rhythms # rhythm pattern of each chars in a line 56 | self._rhyme_cids = rhyme_char_ids # char ids in the required rhyme category 57 | self._repetitive_ids = repetitive_ids 58 | 59 | 60 | def get_candidates(self, completed=False, with_states=False): 61 | if completed: 62 | hypotheses = self._completed_hypotheses 63 | else: 64 | hypotheses = self._hypotheses 65 | 66 | candidates = [hypo.candidate for hypo in hypotheses] 67 | scores = [hypo.score for hypo in hypotheses] 68 | 69 | read_aligns = [hypo.read_aligns for hypo in hypotheses] 70 | 71 | if with_states: 72 | # (L, H) * B 73 | all_states = [hypo.states for hypo in hypotheses] 74 | return candidates, scores, read_aligns, all_states 75 | 76 | else: 77 | return candidates, scores, read_aligns 78 | 79 | 80 | def get_search_results(self, only_completed=True, sort=True): 81 | candidates, scores, aligns, states = self.get_candidates(True, True) 82 | 83 | if not only_completed: 84 | add_candis, add_scores, add_aligns, add_states = self.get_candidates(False, True) 85 | candidates = candidates + add_candis 86 | scores = scores + add_scores 87 | states = states + add_states 88 | aligns = aligns + add_aligns 89 | 90 | 91 | scores = [score/(len(candi)-1) for score, candi in zip(scores, candidates)] 92 | 93 | # sort with costs 94 | if sort: 95 | sort_indices = list(np.argsort(scores)) 96 | candidates = [candidates[i] for i in sort_indices] 97 | scores = [scores[i] for i in sort_indices] 98 | states = [states[i] for i in sort_indices] 99 | aligns = [aligns[i] for i in sort_indices] 100 | 101 | # ignore the bos symbol and initial state 102 | candidates = [candi[1: ] for candi in candidates] 103 | states = [ state[1:] for state in states ] 104 | 105 | return candidates, scores, states, aligns 106 | 107 | 108 | def get_beam_tails(self): 109 | # get the last token and state of each hypothesis 110 | tokens = [hypo.candidate[-1] for hypo in self._hypotheses] 111 | # (B) 112 | tail_tokens = torch.tensor(tokens, dtype=torch.long, device=self.device) 113 | 114 | tail_states = [hypo.states[-1] for hypo in self._hypotheses] 115 | # [1, H] * B -> [B, H] 116 | tail_states = torch.cat(tail_states, dim=0) 117 | 118 | return tail_tokens, tail_states 119 | 120 | 121 | def uncompleted_num(self): 122 | return len(self._hypotheses) 123 | 124 | 125 | def advance(self, logit, 
state, read_align, position): 126 | # logit: (B, V) 127 | # state: (B, H) 128 | # read_align: (B, 1, mem_slots) 129 | log_prob = torch.nn.functional.log_softmax(logit, dim=-1).cpu().data.numpy() 130 | 131 | beam_ids, word_ids, scores = self._beam_select(log_prob, position) 132 | 133 | # update beams 134 | updated_hypotheses = [] 135 | for beam_id, word_id, score in zip(beam_ids, word_ids, scores): 136 | # (1, H) 137 | new_states = self._hypotheses[beam_id].states + [state[beam_id, :].unsqueeze(0)] 138 | 139 | new_candidate = self._hypotheses[beam_id].candidate + [word_id] 140 | 141 | new_aligns = self._hypotheses[beam_id].read_aligns + \ 142 | [read_align[beam_id, :, :].unsqueeze(0)] 143 | 144 | hypo = Hypothesis(new_candidate, new_states, score, new_aligns) 145 | 146 | if word_id == self._E_ID: 147 | self._completed_hypotheses.append(hypo) 148 | else: 149 | updated_hypotheses.append(hypo) 150 | 151 | self._hypotheses = updated_hypotheses 152 | 153 | 154 | def _beam_select(self, log_probs, position): 155 | # log_probs: (B, V) 156 | B, V = log_probs.shape[0], log_probs.shape[1] 157 | 158 | 159 | if position == 0: 160 | costs = - log_probs[0, :].reshape(1, V) # (1, V) 161 | else: 162 | current_scores = [hypo.score for hypo in self._hypotheses] 163 | costs = np.reshape(current_scores, (B, 1)) - log_probs # (B, V) 164 | 165 | # filter with rhythm, rhyme and length 166 | # candidates that don't meet requirements are assigned a large cost 167 | filter_v = 1e5 168 | 169 | costs[:, self._UNK_ID] = filter_v 170 | 171 | # filter eos symbol 172 | if position < self._length: 173 | costs[:, self._E_ID] = filter_v 174 | 175 | # restrain the model from generating chars 176 | # that already generated in previous lines 177 | costs[:, self._repetitive_ids] = filter_v 178 | 179 | # restrain in-line repetitive chars 180 | inline_filter_ids = self.inline_filter(position) 181 | for i in range(0, costs.shape[0]): 182 | costs[i, inline_filter_ids[i]] = filter_v 183 | 184 | 185 | # for the tail char, filter out non-rhyme chars 186 | if (position == self._length-1) and (1 <= self._rhythms[-1] <= 30): 187 | filter_ids = list(set(range(0, V)) - set(self._rhyme_cids)) 188 | costs[:, filter_ids] = filter_v 189 | 190 | ''' 191 | filter out chars of undesired tones 192 | NOTE: since some Chinese characters may belong to both tones, 193 | here we only consider the non-overlap ones 194 | TODO: disambiguation 195 | ''' 196 | pos_rhythm = self._rhythms[position] 197 | if position < self._length and pos_rhythm != 0: 198 | if pos_rhythm == 31: # level tone 199 | costs[:, self._oblique_cids] = filter_v 200 | elif pos_rhythm == 32: # oblique 201 | costs[:, self._level_cids] = filter_v 202 | 203 | flat_costs = costs.flatten() # (B*V) 204 | 205 | # idx of the smallest B elements 206 | best_indices = np.argpartition( 207 | flat_costs, B)[0:B].copy() 208 | 209 | scores = flat_costs[best_indices] 210 | 211 | # get beam id and word id 212 | beam_ids = [int(idx // V) for idx in best_indices] 213 | word_ids = [int(idx % V) for idx in best_indices] 214 | 215 | if position == 0: 216 | beam_ids = list(range(0, B)) 217 | 218 | return beam_ids, word_ids, scores 219 | 220 | 221 | def inline_filter(self, pos): 222 | candidates, _, _ = self.get_candidates() 223 | # candidates: (L_i) * B 224 | B = len(candidates) 225 | forbidden_list = [[] for _ in range(0, B)] 226 | 227 | limit_pos = pos - 1 if pos % 2 != 0 else pos 228 | preidx = range(0, limit_pos) 229 | 230 | for i in range(0, B): # iter ever batch 231 | forbidden_list[i] = [candidates[i][c] 
for c in preidx] 232 | 233 | return forbidden_list -------------------------------------------------------------------------------- /preprocess/preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 20:25:32 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 12 | ''' 13 | import pickle 14 | import json 15 | import random 16 | 17 | from pattern_extractor import PatternExtractor 18 | 19 | def outFile(data, file_name): 20 | print ("output data to %s, num: %d" % (file_name, len(data))) 21 | with open(file_name, 'w') as fout: 22 | for d in data: 23 | fout.write(d+"\n") 24 | 25 | 26 | class PreProcess(object): 27 | """A Tool for data preprocess. 28 | Please note that this tool is only for Chinese quatrains. 29 | """ 30 | def __init__(self): 31 | super(PreProcess, self).__init__() 32 | self.min_freq = 1 33 | self.sens_num = 4 # sens_num must be 4 34 | self.key_num = 4 # max number of keywords 35 | 36 | self.extractor = PatternExtractor("../data/") 37 | 38 | 39 | def create_dic(self, poems): 40 | print ("creating the word dictionary...") 41 | print ("input poems: %d" % (len(poems))) 42 | count_dic = {} 43 | for p in poems: 44 | poem = p.strip().replace("|", "") 45 | 46 | for c in poem: 47 | if c in count_dic: 48 | count_dic[c] += 1 49 | else: 50 | count_dic[c] = 1 51 | 52 | vec = sorted(count_dic.items(), key=lambda d:d[1], reverse=True) 53 | print ("original word num:%d" % (len(vec))) 54 | 55 | # add special symbols 56 | # -------------------------------------- 57 | dic = {} 58 | idic = {} 59 | dic['PAD'] = 0 60 | idic[0] = 'PAD' 61 | 62 | dic['UNK'] = 1 63 | idic[1] = 'UNK' 64 | 65 | dic[''] = 2 66 | idic[2] = '' 67 | 68 | dic[''] = 3 69 | idic[3] = '' 70 | 71 | 72 | idx = 4 73 | print ("min freq:%d" % (self.min_freq)) 74 | 75 | for c, v in vec: 76 | if v < self.min_freq: 77 | continue 78 | if not c in dic: 79 | dic[c] = idx 80 | idic[idx] = c 81 | idx += 1 82 | 83 | print ("total word num: %s" % (len(dic))) 84 | 85 | return dic, idic 86 | 87 | 88 | def build_dic(self, infile): 89 | with open(infile, 'r') as fin: 90 | lines = fin.readlines() 91 | 92 | poems = [] 93 | training_lines = [] 94 | for line in lines: 95 | dic = json.loads(line.strip()) 96 | poem = dic['content'] 97 | poems.append(poem) 98 | training_lines.extend(poem.split("|")) 99 | 100 | dic, idic = self.create_dic(poems) 101 | self.dic = dic 102 | self.idic = idic 103 | 104 | # output dic file 105 | # read 106 | dic_file = "vocab.pickle" 107 | idic_file = "ivocab.pickle" 108 | 109 | print ("saving dictionary to %s" % (dic_file)) 110 | with open(dic_file, 'wb') as fout: 111 | pickle.dump(dic, fout, -1) 112 | 113 | 114 | print ("saving inverting dictionary to %s" % (idic_file)) 115 | with open(idic_file, 'wb') as fout: 116 | pickle.dump(idic, fout, -1) 117 | 118 | 119 | # building training lines 120 | outFile(training_lines, "training_lines.txt") 121 | 122 | 123 | 124 | def line2idxes(self, line): 125 | chars = [c for c in line] 126 | idxes = [] 127 | for c in chars: 128 | if c in self.dic: 129 | idx = self.dic[c] 130 | else: 131 | idx = self.dic['UNK'] 132 | idxes.append(idx) 133 | 134 | return idxes 
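# ----------------------------------------------------------------------
# Hedged sketch (added for illustration; the keywords and the rhyme
# category 14 are hypothetical). build_pattern() below returns one list
# of tone codes per poem line, drawn from PatternExtractor._RHYTHM_PATTERNS,
# and overwrites the tail code of every rhyming line with its rhyme
# category (1~30). build_test_data() then serializes the result as
# "<keywords>#<pattern>", the same format generate.py consumes, e.g.:
#
#   pattern = [[0, 31, 31, 32, 32],
#              [0, 32, 32, 31, 14],
#              [0, 32, 0, 31, 32],
#              [31, 31, 32, 32, 14]]
#   line = "kw1 kw2#" + "|".join(" ".join(map(str, p)) for p in pattern)
# ----------------------------------------------------------------------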
135 | 136 | 137 | 138 | def read_corpus(self, infile): 139 | with open(infile, 'r') as fin: 140 | lines = fin.readlines() 141 | 142 | corpus = [] 143 | for line in lines: 144 | dic = json.loads(line.strip()) 145 | poem = dic['content'].strip() 146 | keywords = dic['keywords'].strip().split(" ") 147 | 148 | corpus.append((keywords, poem)) 149 | 150 | return corpus 151 | 152 | 153 | def build_pattern(self, sens): 154 | length = len(sens[0]) 155 | assert length == 5 or length == 7 156 | 157 | rhymes = self.extractor.get_poem_rhyme(sens) 158 | if len(rhymes) == 0: 159 | return "" 160 | 161 | 162 | rhythm_pattern = self.extractor.get_poem_rhythm(sens, length) 163 | if len(rhythm_pattern) == 0: 164 | return "" 165 | 166 | for i in range(0, len(sens)): 167 | if rhymes[i] >= 1: 168 | rhythm_pattern[i][-1] = rhymes[i] 169 | 170 | return rhythm_pattern 171 | 172 | 173 | def build_data(self, corpus, convert_to_indices=True): 174 | skip_count = 0 175 | 176 | data = [] 177 | for keywords, poem in corpus: 178 | lines = poem.strip().split("|") 179 | 180 | if len(keywords) == 0: 181 | skip_count += 1 182 | continue 183 | 184 | keywords = keywords[0:self.key_num] 185 | 186 | if len(lines) != 4: 187 | skip_count += 1 188 | continue 189 | 190 | 191 | lens = [len(line) for line in lines] 192 | 193 | if not (lens[0] == lens[1] == lens[2] == lens[3]): 194 | skip_count += 1 195 | continue 196 | 197 | 198 | length = lens[0] 199 | # only for Chinese quatrains 200 | if length != 5 and length != 7: 201 | skip_count += 1 202 | continue 203 | 204 | 205 | pattern = self.build_pattern(lines) 206 | if len(pattern) == 0: 207 | skip_count += 1 208 | continue 209 | 210 | if not convert_to_indices: 211 | for keynum in range(1, len(keywords)+1): 212 | tup = (random.sample(keywords, keynum), 213 | lines, keynum, pattern) 214 | data.append(tup) 215 | continue 216 | 217 | # poem to indices 218 | line_idxes_vec = [] 219 | for line in lines: 220 | idxes = self.line2idxes(line) 221 | assert len(idxes) == 5 or len(idxes) == 7 222 | line_idxes_vec.append(idxes) 223 | 224 | assert len(line_idxes_vec) == 4 225 | 226 | # keywords to indices 227 | key_idxes_vec = [self.line2idxes(keyword) for keyword in keywords] 228 | 229 | 230 | for keynum in range(1, len(key_idxes_vec)+1): 231 | tup = (random.sample(key_idxes_vec, keynum), 232 | line_idxes_vec, keynum, pattern) 233 | data.append(tup) 234 | 235 | print ("data num: %d, skip_count: %d" %\ 236 | (len(data), skip_count)) 237 | 238 | return data 239 | 240 | 241 | def build_test_data(self, infile, out_inp_file, out_trg_file): 242 | with open(infile, 'r') as fin: 243 | lines = fin.readlines() 244 | 245 | test = self.read_corpus(infile) 246 | test_data = self.build_data(test, False) 247 | 248 | inps, trgs = [], [] 249 | for tup in test_data: 250 | keywords = " ".join(tup[0]) 251 | poem = "|".join(tup[1]) 252 | pattern = "|".join([" ".join(map(str, p)) for p in tup[3]]) 253 | 254 | inps.append(keywords+"#"+pattern) 255 | trgs.append(poem) 256 | 257 | 258 | outFile(inps, out_inp_file) 259 | outFile(trgs, out_trg_file) 260 | 261 | 262 | 263 | def process(self): 264 | # build the word dictionary 265 | self.build_dic("ccpc_train_v1.0.json") 266 | 267 | # build training and validation datasets 268 | train = self.read_corpus("ccpc_train_v1.0.json") 269 | valid = self.read_corpus("ccpc_valid_v1.0.json") 270 | 271 | train_data = self.build_data(train) 272 | valid_data = self.build_data(valid) 273 | 274 | random.shuffle(train_data) 275 | random.shuffle(valid_data) 276 | 277 | print ("training data: %d" 
% (len(train_data))) 278 | print ("validation data: %d" % (len(valid_data))) 279 | 280 | train_file = "train_data.pickle" 281 | print ("saving training data to %s" % (train_file)) 282 | with open(train_file, 'wb') as fout: 283 | pickle.dump(train_data, fout, -1) 284 | 285 | 286 | valid_file = "valid_data.pickle" 287 | print ("saving validation data to %s" % (valid_file)) 288 | with open(valid_file, 'wb') as fout: 289 | pickle.dump(valid_data, fout, -1) 290 | 291 | # build testing inputs and trgs 292 | self.build_test_data("ccpc_test_v1.0.json", "test_inps.txt", "test_trgs.txt") 293 | 294 | 295 | 296 | def main(): 297 | processor = PreProcess() 298 | processor.process() 299 | 300 | 301 | 302 | if __name__ == "__main__": 303 | main() -------------------------------------------------------------------------------- /codes/layers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 20:08:55 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 12 | ''' 13 | import math 14 | from collections import OrderedDict 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | 20 | 21 | class BidirEncoder(nn.Module): 22 | def __init__(self, input_size, hidden_size, cell='GRU', n_layers=1, drop_ratio=0.1): 23 | super(BidirEncoder, self).__init__() 24 | 25 | if cell == 'GRU': 26 | self.rnn = nn.GRU(input_size, hidden_size, n_layers, 27 | bidirectional=True, batch_first=True) 28 | elif cell == 'Elman': 29 | self.rnn = nn.RNN(input_size, hidden_size, n_layers, 30 | bidirectional=True, batch_first=True) 31 | elif cell == 'LSTM': 32 | self.rnn = nn.LSTM(input_size, hidden_size, n_layers, 33 | bidirectional=True, batch_first=True) 34 | 35 | self.dropout_layer = nn.Dropout(drop_ratio) 36 | 37 | 38 | def forward(self, embed_seq, input_lens=None): 39 | # embed_seq: (B, L, emb_dim) 40 | # input_lens: (B) 41 | embed_inps = self.dropout_layer(embed_seq) 42 | 43 | if input_lens is None: 44 | outputs, state = self.rnn(embed_inps, None) 45 | else: 46 | # Dynamic RNN 47 | total_len = embed_inps.size(1) 48 | packed = torch.nn.utils.rnn.pack_padded_sequence(embed_inps, 49 | input_lens, batch_first=True, enforce_sorted=False) 50 | outputs, state = self.rnn(packed, None) 51 | # outputs: (B, L, num_directions*H) 52 | # state: (num_layers*num_directions, B, H) 53 | outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, 54 | batch_first=True, total_length=total_len) 55 | 56 | return outputs, state 57 | 58 | 59 | class Decoder(nn.Module): 60 | def __init__(self, input_size, hidden_size, cell='GRU', n_layers=1, drop_ratio=0.1): 61 | super(Decoder, self).__init__() 62 | 63 | self.dropout_layer = nn.Dropout(drop_ratio) 64 | 65 | if cell == 'GRU': 66 | self.rnn = nn.GRU(input_size, hidden_size, n_layers, batch_first=True) 67 | elif cell == 'Elman': 68 | self.rnn = nn.RNN(input_size, hidden_size, n_layers, batch_first=True) 69 | elif cell == 'LSTM': 70 | self.rnn = nn.LSTM(input_size, hidden_size, n_layers, batch_first=True) 71 | 72 | 73 | def forward(self, embed_seq, last_state): 74 | # embed_seq: (B, L, H) 75 | # last_state: (B, H) 76 | embed_inps = 
self.dropout_layer(embed_seq) 77 | output, state = self.rnn(embed_inps, last_state.unsqueeze(0)) 78 | output = output.squeeze(1) # (B, 1, N) -> (B,N) 79 | return output, state.squeeze(0) # (B, H) 80 | 81 | 82 | class AttentionReader(nn.Module): 83 | def __init__(self, d_q, d_v, drop_ratio=0.0): 84 | super(AttentionReader, self).__init__() 85 | self.attn = nn.Linear(d_q+d_v, d_v) 86 | self.v = nn.Parameter(torch.rand(d_v)) 87 | stdv = 1. / math.sqrt(self.v.size(0)) 88 | self.v.data.normal_(mean=0, std=stdv) 89 | self.dropout = nn.Dropout(drop_ratio) 90 | 91 | 92 | def forward(self, Q, K, V, attn_mask): 93 | # Q: (B, 1, d_q) 94 | # K: (B, L, d_v) 95 | # V: (B, L, d_v) 96 | # attn_mask: (B, L), True means mask 97 | k_len = K.size(1) 98 | q_state = Q.repeat(1, k_len, 1) # (B, L, d_q) 99 | 100 | attn_energies = self.score(q_state, K) # (B, L) 101 | 102 | attn_energies.masked_fill_(attn_mask, -1e12) 103 | 104 | attn_weights = F.softmax(attn_energies, dim=1).unsqueeze(1) 105 | attn_weights = self.dropout(attn_weights) 106 | 107 | # (B, 1, L) * (B, L, d_v) -> (B, 1, d_v) 108 | context = attn_weights.bmm(V) 109 | 110 | return context.squeeze(1), attn_weights 111 | 112 | 113 | def score(self, query, memory): 114 | # query (B, L, d_q) 115 | # memory (B, L, d_v) 116 | 117 | # (B, L, d_q+d_v)->(B, L, d_v) 118 | energy = torch.tanh(self.attn(torch.cat([query, memory], 2))) 119 | energy = energy.transpose(1, 2) # (B, d_v, L) 120 | 121 | v = self.v.repeat(memory.size(0), 1).unsqueeze(1) # (B, 1, d_v) 122 | energy = torch.bmm(v, energy) # (B, 1, d_v) * (B, d_v, L) -> [B, 1, L] 123 | return energy.squeeze(1) # (B, L) 124 | 125 | 126 | 127 | class AttentionWriter(nn.Module): 128 | def __init__(self, d_q, mem_size): 129 | super(AttentionWriter, self).__init__() 130 | self.attn = nn.Linear(d_q+mem_size, mem_size) 131 | self.v = nn.Parameter(torch.rand(mem_size)) 132 | stdv = 1. 
/ math.sqrt(self.v.size(0)) 133 | self.v.data.normal_(mean=0, std=stdv) 134 | 135 | self._tau = 1.0 # Gumbel temperature 136 | 137 | 138 | def set_tau(self, tau): 139 | self._tau = tau 140 | 141 | def get_tau(self): 142 | return self._tau 143 | 144 | 145 | def forward(self, his_mem, states, states_mask, global_trace, null_mem): 146 | # mem: (B, mem_slots, mem_size) 147 | # states: (B, L, mem_size) 148 | # states_mask: (B, L), 0 means pad_idx and not to be written 149 | # global_trace: (B, D) 150 | # null_mem: (B, 1, mem_size) 151 | n = states.size(1) 152 | mem_slots = his_mem.size(1) + 1 # including the null slot 153 | 154 | write_log = [] 155 | 156 | for i in range(0, n): 157 | mem = torch.cat([his_mem, null_mem], dim=1) 158 | state = states[:, i, :] # (B, mem_size) 159 | 160 | 161 | query = torch.cat([state, global_trace], dim=1).unsqueeze(1).repeat(1, mem_slots, 1) 162 | attn_energies = self.score(query, mem) 163 | 164 | attn_weights = F.softmax(attn_energies, dim=-1) # (B, mem_slots) 165 | 166 | # manually give the empty slots higher weights 167 | empty_mask = mem.abs().sum(-1).eq(0).float() # (B, mem_slots) 168 | attn_weights = attn_weights + empty_mask * 10.0 169 | 170 | # one-hot (B, mem_slots) 171 | slot_select = F.gumbel_softmax(attn_weights, tau=self._tau, hard=True) 172 | 173 | write_mask = slot_select[:, 0:mem_slots-1] * \ 174 | (states_mask[:, i].unsqueeze(1).repeat(1, mem_slots-1)) 175 | write_mask = write_mask.unsqueeze(2) # (B, mem_slots-1, 1) 176 | 177 | 178 | write_state = state.unsqueeze(1).repeat(1, mem_slots-1, 1) 179 | 180 | his_mem = (1.0 - write_mask) * his_mem + write_mask * write_state 181 | 182 | write_log.append(slot_select.unsqueeze(1)) 183 | 184 | write_log = torch.cat(write_log, dim=1) 185 | return his_mem, write_log 186 | 187 | 188 | def score(self, query, memory): 189 | energy = torch.tanh(self.attn(torch.cat([query, memory], 2))) 190 | energy = energy.transpose(1, 2) 191 | 192 | v = self.v.repeat(memory.size(0), 1).unsqueeze(1) 193 | 194 | energy = torch.bmm(v, energy) 195 | return energy.squeeze(1) 196 | 197 | 198 | class MLP(nn.Module): 199 | def __init__(self, ori_input_size, layer_sizes, activs=None, 200 | drop_ratio=0.0, no_drop=False): 201 | super(MLP, self).__init__() 202 | 203 | layer_num = len(layer_sizes) 204 | 205 | orderedDic = OrderedDict() 206 | input_size = ori_input_size 207 | for i, (layer_size, activ) in enumerate(zip(layer_sizes, activs)): 208 | linear_name = 'linear_' + str(i) 209 | orderedDic[linear_name] = nn.Linear(input_size, layer_size) 210 | input_size = layer_size 211 | 212 | if activ is not None: 213 | assert activ in ['tanh', 'relu', 'leakyrelu'] 214 | 215 | active_name = 'activ_' + str(i) 216 | if activ == 'tanh': 217 | orderedDic[active_name] = nn.Tanh() 218 | elif activ == 'relu': 219 | orderedDic[active_name] = nn.ReLU() 220 | elif activ == 'leakyrelu': 221 | orderedDic[active_name] = nn.LeakyReLU(0.2) 222 | 223 | 224 | if (drop_ratio > 0) and (i < layer_num-1) and (not no_drop): 225 | orderedDic["drop_" + str(i)] = nn.Dropout(drop_ratio) 226 | 227 | self.mlp = nn.Sequential(orderedDic) 228 | 229 | 230 | def forward(self, inps): 231 | return self.mlp(inps) 232 | 233 | 234 | class ContextLayer(nn.Module): 235 | def __init__(self, inp_size, out_size, kernel_size=3): 236 | super(ContextLayer, self).__init__() 237 | # (B, L, H) 238 | self.conv = nn.Conv1d(inp_size, out_size, kernel_size) 239 | self.linear = nn.Linear(out_size+inp_size, out_size) 240 | 241 | def forward(self, last_context, dec_states): 242 | # last_context: (B, 
context_size) 243 | # dec_states: (B, H, L) 244 | hidden_feature = self.conv(dec_states) # (B, out_size, L_out) 245 | feature = torch.tanh(hidden_feature).mean(dim=2) # (B, out_size) 246 | new_context = torch.tanh(self.linear(torch.cat([last_context, feature], dim=1))) 247 | return new_context -------------------------------------------------------------------------------- /codes/generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 20:19:25 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 12 | ''' 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from graphs import WorkingMemoryModel 17 | 18 | from tool import Tool 19 | from beam import PoetryBeam 20 | from filter import PoetryFilter 21 | from visualization import Visualization 22 | import utils 23 | 24 | class Generator(object): 25 | ''' 26 | generator for testing 27 | ''' 28 | 29 | def __init__(self, hps, device): 30 | self.tool = Tool(hps.sens_num, hps.sen_len, 31 | hps.key_len, hps.topic_slots, 0.0) 32 | self.tool.load_dic(hps.vocab_path, hps.ivocab_path) 33 | vocab_size = self.tool.get_vocab_size() 34 | print ("vocabulary size: %d" % (vocab_size)) 35 | PAD_ID = self.tool.get_PAD_ID() 36 | B_ID = self.tool.get_B_ID() 37 | assert vocab_size > 0 and PAD_ID >=0 and B_ID >= 0 38 | self.hps = hps._replace(vocab_size=vocab_size, pad_idx=PAD_ID, bos_idx=B_ID) 39 | self.device = device 40 | 41 | # load model 42 | model = WorkingMemoryModel(self.hps, device) 43 | 44 | # load trained model 45 | utils.restore_checkpoint(self.hps.model_dir, device, model) 46 | self.model = model.to(device) 47 | self.model.eval() 48 | 49 | null_idxes = self.tool.load_function_tokens(self.hps.data_dir + "fchars.txt").to(self.device) 50 | self.model.set_null_idxes(null_idxes) 51 | 52 | self.model.set_tau(hps.min_tau) 53 | 54 | # load poetry filter 55 | print ("loading poetry filter...") 56 | self.filter = PoetryFilter(self.tool.get_vocab(), 57 | self.tool.get_ivocab(), self.hps.data_dir) 58 | 59 | self.visual_tool = Visualization(hps.topic_slots, hps.his_mem_slots, 60 | "../log/") 61 | print("--------------------------") 62 | 63 | 64 | 65 | def generate_one(self, keywords, pattern, beam_size=20, verbose=1, manu=False, visual=0): 66 | ''' 67 | generate one poem according to the inputs: 68 | keyword: a list of topic words, at most key_slots 69 | pattern: genre pattern of the poem to be generated 70 | verbose: 0, 1, 2, 3 71 | visual: 0, 1, 2 72 | ''' 73 | key_inps = self.tool.keywords2tensor([keywords]*beam_size) 74 | key_inps = [key.to(self.device) for key in key_inps] 75 | 76 | if visual > 0: 77 | self.visual_tool.reset(keywords) 78 | 79 | 80 | # inps is a pseudo tensor with pad symbols for generating the first line 81 | inps, all_ph_inps, all_len_inps, all_lengths = self.tool.patterns2tensor([pattern]*beam_size) 82 | 83 | inps = inps.to(self.device) 84 | all_ph_inps = [ph_inps.to(self.device) for ph_inps in all_ph_inps] 85 | all_len_inps = [len_inps.to(self.device) for len_inps in all_len_inps] 86 | 87 | # for quatrains, all lines in a poem share the same length 88 | length = all_lengths[0] 89 | 90 | # initialize 
beam pool 91 | beam_pool = PoetryBeam(self.device, beam_size, length, 92 | self.tool.get_B_ID(), self.tool.get_E_ID(), self.tool.get_UNK_ID(), 93 | self.filter.get_level_cids(), self.filter.get_oblique_cids()) 94 | 95 | self.filter.reset(length, verbose) 96 | 97 | with torch.no_grad(): 98 | topic_mem, topic_mask, history_mem, history_mask,\ 99 | global_trace, topic_trace, key_init_state = self.model.initialize_mems(key_inps) 100 | 101 | # beam search 102 | poem = [] 103 | for step in range(0, self.hps.sens_num): 104 | # generate each line 105 | if verbose >= 1: 106 | print ("\ngenerating step: %d" % (step)) 107 | 108 | if step > 0: 109 | key_init_state = None 110 | 111 | candidates, costs, states, read_aligns, local_mem, local_mask = self.beam_search(beam_pool, 112 | length, inps, 113 | all_ph_inps[step], pattern[step], all_len_inps[step], 114 | key_init_state, history_mem, history_mask, 115 | topic_mem, topic_mask, global_trace, topic_trace) 116 | 117 | lines = [self.tool.idxes2line(idxes) for idxes in candidates] 118 | 119 | lines, costs, states, read_aligns = self.filter.filter_illformed(lines, costs, 120 | states, read_aligns, pattern[step][-1]) 121 | 122 | if len(lines) == 0: 123 | return [], "line {} generation failed!".format(step) 124 | 125 | which = 0 126 | if manu: 127 | for i, (line, cost) in enumerate(zip(lines, costs)): 128 | print ("%d, %s, %.2f" % (i, line, cost)) 129 | which = int(input("select sentence>")) 130 | 131 | line = lines[which] 132 | poem.append(line) 133 | 134 | # set repetitive chars 135 | self.filter.add_repetitive(self.tool.line2idxes(line)) 136 | 137 | # --------------------------------------- 138 | # write into history memory 139 | write_log = None 140 | if step >= 1: 141 | with torch.no_grad(): 142 | history_mem, write_log = self.update_history_mem(history_mem, 143 | local_mem, local_mask, global_trace) 144 | 145 | history_mask = history_mem.abs().sum(-1).eq(0) # (B, mem_slots) 146 | 147 | 148 | with torch.no_grad(): 149 | # update global trace 150 | global_trace = self.update_glocal_trace(global_trace, states[which], length) 151 | 152 | # update topic trace 153 | topic_trace = self.update_topic_trace(topic_trace, topic_mem, read_aligns[which]) 154 | 155 | 156 | # build inps 157 | inps = self.tool.line2tensor(line, beam_size).to(self.device) 158 | 159 | 160 | if visual > 0: 161 | # show visualization of memory reading 162 | self.visual_tool.add_gen_line(line) 163 | self.visual_tool.draw(read_aligns[which], write_log, step, visual) 164 | 165 | 166 | return poem, "ok" 167 | 168 | 169 | # ------------------------------------ 170 | def beam_search(self, beam_pool, trg_len, inputs, phs, ph_labels, lens, key_init_state, 171 | history_mem, history_mask, topic_mem, topic_mask, global_trace, topic_trace): 172 | 173 | local_mem, local_mask, init_state = \ 174 | self.model.computer_local_memory(inputs, key_init_state is None) 175 | 176 | 177 | if key_init_state is not None: 178 | init_state = key_init_state 179 | 180 | # reset beam pool 181 | if 1 <= ph_labels[-1] <= 30: 182 | rhyme = ph_labels[-1] 183 | else: 184 | rhyme = -1 185 | 186 | beam_pool.reset(init_state[0, :].unsqueeze(0), ph_labels+[0]*10, 187 | self.filter.get_rhyme_cids(rhyme), self.filter.get_repetitive_ids()) 188 | 189 | # current size of beam candidates in the beam pool 190 | n_samples = beam_pool.uncompleted_num() 191 | 192 | total_mask = torch.cat([topic_mask, history_mask, local_mask], dim=1) 193 | total_mem = torch.cat([topic_mem, history_mem, local_mem], dim=1) 194 | 195 | 196 | for k in 
range(0, trg_len+5): 197 | inp, state = beam_pool.get_beam_tails() 198 | 199 | if k <= trg_len: 200 | ph_inp = phs[0:n_samples, k] 201 | len_inp = lens[0:n_samples, k] 202 | else: 203 | ph_inp = torch.zeros(n_samples, dtype=torch.long, device=self.device) 204 | len_inp = torch.zeros(n_samples, dtype=torch.long, device=self.device) 205 | 206 | 207 | with torch.no_grad(): 208 | logit, new_state, read_align = self.model.dec_step(inp, state, 209 | ph_inp, len_inp, 210 | total_mem[0:n_samples, :, :], total_mask[0:n_samples, :], 211 | global_trace[0:n_samples, :], topic_trace[0:n_samples, :]) 212 | 213 | 214 | beam_pool.advance(logit, new_state, read_align, k) 215 | 216 | n_samples = beam_pool.uncompleted_num() 217 | 218 | if n_samples == 0: 219 | break 220 | 221 | candidates, costs, dec_states, read_aligns = beam_pool.get_search_results() 222 | return candidates, costs, dec_states, read_aligns, local_mem, local_mask 223 | 224 | 225 | # --------------- 226 | def update_history_mem(self, history_mem, local_mem, local_mask, global_trace): 227 | new_history_mem, write_log = self.model.layers['memory_write'](history_mem, local_mem, 228 | 1.0-local_mask.float(), global_trace, self.model.null_mem) 229 | 230 | 231 | return new_history_mem, write_log 232 | 233 | 234 | def update_glocal_trace(self, global_trace, dec_states, length): 235 | # dec_states: (1, H) * L_gen 236 | 237 | batch_size = global_trace.size(0) 238 | 239 | # (1, H) -> (B, H, 1) 240 | states = [state.unsqueeze(2).repeat(batch_size, 1, 1) for state in dec_states] 241 | 242 | l = len(states) 243 | 244 | mask = torch.zeros(batch_size, l, dtype=torch.float, device=self.device) 245 | mask[:, 0:length+1] = 1.0 246 | 247 | # update global trace vector 248 | new_global_trace = self.model.update_global_trace(global_trace, states, mask) 249 | 250 | 251 | return new_global_trace 252 | 253 | 254 | def update_topic_trace(self, topic_trace, topic_mem, read_align): 255 | # read_align: (1, 1, mem_slots) * L_gen 256 | 257 | batch_size = topic_trace.size(0) 258 | # concat_aligns: (B, L_gen, mem_slots) 259 | concat_aligns = torch.cat(read_align, dim=1).repeat(batch_size, 1, 1) 260 | new_topic_trace = self.model.update_topic_trace(topic_trace, topic_mem, concat_aligns) 261 | 262 | return new_topic_trace 263 | 264 | 265 | # -------------------------------------- -------------------------------------------------------------------------------- /codes/tool.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 20:38:33 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 
12 | ''' 13 | import pickle 14 | import numpy as np 15 | import random 16 | import copy 17 | import torch 18 | 19 | 20 | def readPickle(data_path): 21 | corpus_file = open(data_path, 'rb') 22 | corpus = pickle.load(corpus_file) 23 | corpus_file.close() 24 | 25 | return corpus 26 | 27 | #------------------------------------------------------------------------- 28 | #------------------------------------------------------------------------- 29 | class Tool(object): 30 | ''' 31 | a tool to hold training data and the vocabulary 32 | ''' 33 | def __init__(self, sens_num, sen_len, key_len, key_slots, 34 | corrupt_ratio=0): 35 | # corrupt ratio for dae 36 | self._sens_num = sens_num 37 | self._sen_len = sen_len 38 | self._key_len = key_len 39 | self._key_slots = key_slots 40 | 41 | self._corrupt_ratio = corrupt_ratio 42 | 43 | self._vocab = None 44 | self._ivocab = None 45 | 46 | self._PAD_ID = None 47 | self._B_ID = None 48 | self._E_ID = None 49 | self._UNK_ID = None 50 | 51 | # ----------------------------------- 52 | # map functions 53 | def idxes2line(self, idxes, truncate=True): 54 | if truncate and self._E_ID in idxes: 55 | idxes = idxes[:idxes.index(self._E_ID)] 56 | 57 | tokens = self.idxes2tokens(idxes, truncate) 58 | line = self.tokens2line(tokens) 59 | return line 60 | 61 | def line2idxes(self, line): 62 | tokens = self.line2tokens(line) 63 | return self.tokens2idxes(tokens) 64 | 65 | def line2tokens(self, line): 66 | ''' 67 | in this work, we treat each Chinese character as a token. 68 | ''' 69 | line = line.strip() 70 | tokens = [c for c in line] 71 | return tokens 72 | 73 | 74 | def tokens2line(self, tokens): 75 | return "".join(tokens) 76 | 77 | 78 | def tokens2idxes(self, tokens): 79 | ''' Characters to idx list ''' 80 | idxes = [] 81 | for w in tokens: 82 | if w in self._vocab: 83 | idxes.append(self._vocab[w]) 84 | else: 85 | idxes.append(self._UNK_ID) 86 | return idxes 87 | 88 | 89 | def idxes2tokens(self, idxes, omit_special=True): 90 | tokens = [] 91 | for idx in idxes: 92 | if (idx == self._PAD_ID or idx == self._B_ID 93 | or idx == self._E_ID) and omit_special: 94 | continue 95 | tokens.append(self._ivocab[idx]) 96 | 97 | return tokens 98 | 99 | # ------------------------------------------------- 100 | def greedy_search(self, probs): 101 | # probs: (V) 102 | out_idxes = [int(np.argmax(prob, axis=-1)) for prob in probs] 103 | 104 | return self.idxes2line(out_idxes) 105 | 106 | # ---------------------------- 107 | def get_vocab(self): 108 | return copy.deepcopy(self._vocab) 109 | 110 | def get_ivocab(self): 111 | return copy.deepcopy(self._ivocab) 112 | 113 | def get_vocab_size(self): 114 | if self._vocab is not None: 115 | return len(self._vocab) 116 | else: 117 | return -1 118 | 119 | def get_PAD_ID(self): 120 | assert self._PAD_ID is not None 121 | return self._PAD_ID 122 | 123 | def get_B_ID(self): 124 | assert self._B_ID is not None 125 | return self._B_ID 126 | 127 | def get_E_ID(self): 128 | assert self._E_ID is not None 129 | return self._E_ID 130 | 131 | def get_UNK_ID(self): 132 | assert self._UNK_ID is not None 133 | return self._UNK_ID 134 | 135 | 136 | # ---------------------------------------------------------------- 137 | def load_dic(self, vocab_path, ivocab_path): 138 | dic = readPickle(vocab_path) 139 | idic = readPickle(ivocab_path) 140 | 141 | assert len(dic) == len(idic) 142 | 143 | 144 | self._vocab = dic 145 | self._ivocab = idic 146 | 147 | self._PAD_ID = dic['PAD'] 148 | self._UNK_ID = dic['UNK'] 149 | self._E_ID = dic['<E>'] 150 | self._B_ID = dic['<B>']
151 | 152 | 153 | def load_function_tokens(self, file_dir): 154 | # please run load_dic() before loading function tokens! 155 | with open(file_dir, 'r') as fin: 156 | lines = fin.readlines() 157 | 158 | tokens = [line.strip() for line in lines] 159 | 160 | f_idxes = [] 161 | for token in tokens: 162 | if token in self._vocab: 163 | f_idxes.append(self._vocab[token]) 164 | 165 | f_idxes = torch.tensor(f_idxes, dtype=torch.long) 166 | return f_idxes 167 | 168 | 169 | def build_data(self, train_data_path, valid_data_path, batch_size, mode): 170 | ''' 171 | Build data as batches. 172 | NOTE: please run load_dic() first. 173 | mode: 174 | dseq: pre-train the encoder and decoder as a denoising Seq2Seq model 175 | wm: train the working memory model 176 | ''' 177 | assert mode in ['dseq', 'wm'] 178 | train_data = readPickle(train_data_path) 179 | valid_data = readPickle(valid_data_path) 180 | 181 | 182 | # data limit for debug 183 | self.train_batches = self._build_data_core(train_data, batch_size, mode, None) 184 | self.valid_batches = self._build_data_core(valid_data, batch_size, mode, None) 185 | 186 | self.train_batch_num = len(self.train_batches) 187 | self.valid_batch_num = len(self.valid_batches) 188 | 189 | 190 | def _build_data_core(self, data, batch_size, mode, data_limit=None): 191 | # data: [keywords, sens, key_num, pattern] * data_num 192 | if data_limit is not None: 193 | data = data[0:data_limit] 194 | 195 | if mode == 'dseq': 196 | return self.build_dseq_batches(data, batch_size) 197 | elif mode == 'wm': 198 | return self.build_wm_batches(data, batch_size) 199 | 200 | 201 | def build_dseq_batches(self, data, batch_size): 202 | # data: [keywords, sens, key_num, pattern] * data_num 203 | batched_data = [] 204 | batch_num = int(np.ceil(len(data) / float(batch_size))) 205 | for bi in range(0, batch_num): 206 | instances = data[bi*batch_size : (bi+1)*batch_size] 207 | if len(instances) < batch_size: 208 | instances = instances + random.sample(data, batch_size-len(instances)) 209 | 210 | # build poetry batch 211 | poems = [instance[1] for instance in instances] # all poems 212 | genre_patterns = [instance[3] for instance in instances] 213 | for i in range(0, self._sens_num-1): 214 | line0 = [poem[i] for poem in poems] 215 | line1 = [poem[i+1] for poem in poems] 216 | phs = [pattern[i+1] for pattern in genre_patterns] 217 | 218 | inps, trgs, ph_inps, len_inps = \ 219 | self._build_batch_seqs(line0, line1, phs, corrupt=True) 220 | 221 | batched_data.append((inps, trgs, ph_inps, len_inps)) 222 | 223 | random.shuffle(batched_data) 224 | return batched_data 225 | 226 | 227 | def build_wm_batches(self, data, batch_size): 228 | # data: [keywords, sens, key_num, pattern] * data_num 229 | batched_data = [] 230 | batch_num = int(np.ceil(len(data) / float(batch_size))) 231 | for bi in range(0, batch_num): 232 | instances = data[bi*batch_size : (bi+1)*batch_size] 233 | if len(instances) < batch_size: 234 | instances = instances + random.sample(data, batch_size-len(instances)) 235 | 236 | # build poetry batch 237 | poems = [instance[1] for instance in instances] # all poems 238 | genre_patterns = [instance[3] for instance in instances] 239 | 240 | 241 | all_inps, all_trgs = [], [] 242 | all_ph_inps, all_len_inps = [], [] 243 | 244 | for i in range(-1, self._sens_num-1): 245 | 246 | if i < 0: 247 | line0 = [[] for poem in poems] 248 | else: 249 | line0 = [poem[i] for poem in poems] 250 | 251 | line1 = [poem[i+1] for poem in poems] 252 | phs = [pattern[i+1] for pattern in genre_patterns] 253 |
254 | 255 | inps, trgs, ph_inps, len_inps = \ 256 | self._build_batch_seqs(line0, line1, phs, corrupt=False) 257 | 258 | 259 | all_inps.append(inps) 260 | all_trgs.append(trgs) 261 | all_ph_inps.append(ph_inps) 262 | all_len_inps.append(len_inps) 263 | 264 | 265 | # build keys 266 | keywords = [instance[0] for instance in instances] 267 | keys = self._build_batch_keys(keywords) 268 | 269 | batched_data.append((all_inps, all_trgs, all_ph_inps, all_len_inps, keys)) 270 | 271 | 272 | random.shuffle(batched_data) 273 | return batched_data 274 | 275 | 276 | def _build_batch_keys(self, keywords): 277 | # build key batch 278 | batch_size = len(keywords) 279 | key_inps = [[] for _ in range(self._key_slots)] 280 | 281 | for i in range(0, batch_size): 282 | keys = keywords[i] # batch_size * at most 4 283 | for step in range(0, len(keys)): 284 | key = keys[step] 285 | assert len(key) <= self._key_len 286 | key_inps[step].append(key + [self._PAD_ID] * (self._key_len-len(key))) 287 | 288 | for step in range(0, self._key_slots-len(keys)): 289 | key_inps[len(keys)+step].append([self._PAD_ID] * self._key_len) 290 | 291 | 292 | key_tensor = [self._sens2tensor(key) for key in key_inps] 293 | return key_tensor 294 | 295 | 296 | 297 | def _build_batch_seqs(self, inputs, targets, pattern, corrupt=False): 298 | # pack sequences as a tensor 299 | inps, _, _ = self._get_batch_seq(inputs, pattern, False, corrupt=corrupt) 300 | trgs, phs, lens = self._get_batch_seq(targets, pattern, True, corrupt=False) 301 | 302 | inps_tensor = self._sens2tensor(inps) 303 | trgs_tensor = self._sens2tensor(trgs) 304 | 305 | phs_tensor = self._sens2tensor(phs) 306 | lens_tensor = self._sens2tensor(lens) 307 | 308 | 309 | return inps_tensor, trgs_tensor, phs_tensor, lens_tensor 310 | 311 | 312 | def _get_batch_seq(self, seqs, phs, with_E, corrupt): 313 | batch_size = len(seqs) 314 | max_len = max([len(seq) for seq in seqs]) 315 | max_len = max_len + int(with_E) 316 | 317 | if max_len == 0: 318 | max_len = self._sen_len 319 | 320 | batched_seqs = [] 321 | batched_lens, batched_phs = [], [] 322 | for i in range(0, batch_size): 323 | # max length for each sequence 324 | ori_seq = copy.deepcopy(seqs[i]) 325 | 326 | if corrupt: 327 | seq = self._do_corruption(ori_seq) 328 | else: 329 | seq = ori_seq 330 | # ---------------------------------- 331 | 332 | pad_size = max_len - len(seq) - int(with_E) 333 | pads = [self._PAD_ID] * pad_size 334 | 335 | new_seq = seq + [self._E_ID] * int(with_E) + pads 336 | 337 | #--------------------------------- 338 | # 0 means either 339 | ph = phs[i] 340 | ph_inp = ph + [0] * (max_len - len(ph)) 341 | 342 | assert len(ph_inp) == len(new_seq) 343 | batched_phs.append(ph_inp) 344 | 345 | len_inp = list(range(1, len(seq)+1+int(with_E))) 346 | len_inp.reverse() 347 | len_inp = len_inp + [0] * pad_size 348 | 349 | assert len(len_inp) == len(new_seq) 350 | 351 | batched_lens.append(len_inp) 352 | #--------------------------------- 353 | batched_seqs.append(new_seq) 354 | 355 | 356 | return batched_seqs, batched_phs, batched_lens 357 | 358 | 359 | 360 | def _sens2tensor(self, sens): 361 | batch_size = len(sens) 362 | sen_len = max([len(sen) for sen in sens]) 363 | tensor = torch.zeros(batch_size, sen_len, dtype=torch.long) 364 | for i, sen in enumerate(sens): 365 | for j, token in enumerate(sen): 366 | tensor[i][j] = token 367 | return tensor 368 | 369 | 370 | def _do_corruption(self, inp): 371 | # corrupt the sequence by setting some tokens as UNK 372 | m = int(np.ceil(len(inp) * self._corrupt_ratio)) 373 | m = 
min(m, len(inp)) 374 | m = max(1, m) 375 | 376 | unk_id = self.get_UNK_ID() 377 | 378 | corrupted_inp = copy.deepcopy(inp) 379 | pos = random.sample(list(range(0, len(inp))), m) 380 | for p in pos: 381 | corrupted_inp[p] = unk_id 382 | 383 | return corrupted_inp 384 | 385 | 386 | 387 | def shuffle_train_data(self): 388 | random.shuffle(self.train_batches) 389 | 390 | 391 | 392 | # ----------------------------------------------------------- 393 | # ----------------------------------------------------------- 394 | # Tools for beam search 395 | def keywords2tensor(self, keywords): 396 | # input: keywords: list of string 397 | 398 | # string to idxes 399 | key_idxes = [] 400 | for keyword in keywords: 401 | keys = [self.line2idxes(key_str) for key_str in keyword] 402 | if len(keys) < self._key_slots: 403 | add_num = self._key_slots - len(keys) 404 | add_keys = [[self._PAD_ID]*self._key_len]*add_num 405 | keys = keys + add_keys 406 | 407 | key_idxes.append(keys) 408 | 409 | keys_tensor = self._build_batch_keys(key_idxes) 410 | return keys_tensor 411 | 412 | 413 | 414 | def patterns2tensor(self, patterns): 415 | batch_size = len(patterns) 416 | # assume all poems share the same sens_num 417 | all_seqs = [] 418 | all_lengths = [] 419 | all_ph_inps, all_len_inps = [], [] 420 | for step in range(0, self._sens_num): 421 | # each line 422 | phs = [pattern[step] for pattern in patterns] 423 | pseudo_seqs = [ [0] * len(ph) for ph in phs ] 424 | all_lengths.append(max([len(seq) for seq in pseudo_seqs])) 425 | 426 | batched_seqs, batched_phs, batched_lens = \ 427 | self._get_batch_seq(pseudo_seqs, phs, True, False) 428 | 429 | phs_tensor = self._sens2tensor(batched_phs) 430 | lens_tensor = self._sens2tensor(batched_lens) 431 | inps_tensor = self._sens2tensor(batched_seqs) 432 | 433 | all_ph_inps.append(phs_tensor) 434 | all_len_inps.append(lens_tensor) 435 | all_seqs.append(inps_tensor) 436 | 437 | 438 | return all_seqs[0], all_ph_inps, all_len_inps, all_lengths 439 | 440 | 441 | def line2tensor(self, line, beam_size): 442 | idxes = self.line2idxes(line.strip()) 443 | return self._sens2tensor([idxes]*beam_size) 444 | -------------------------------------------------------------------------------- /codes/graphs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Xiaoyuan Yi 3 | # @Last Modified by: Xiaoyuan Yi 4 | # @Last Modified time: 2020-06-11 20:16:17 5 | # @Email: yi-xy16@mails.tsinghua.edu.cn 6 | # @Description: 7 | ''' 8 | Copyright 2020 THUNLP Lab. All Rights Reserved. 9 | This code is part of the online Chinese poetry generation system, Jiuge. 10 | System URL: https://jiuge.thunlp.cn/ and https://jiuge.thunlp.org/. 11 | Github: https://github.com/THUNLP-AIPoet. 
12 | ''' 13 | import random 14 | from itertools import chain 15 | import torch 16 | from torch import nn 17 | import torch.nn.functional as F 18 | 19 | from layers import BidirEncoder, Decoder, MLP, ContextLayer, AttentionReader, AttentionWriter 20 | 21 | def get_non_pad_mask(seq, pad_idx, device): 22 | # seq: [B, L] 23 | assert seq.dim() == 2 24 | # [B, L] 25 | mask = seq.ne(pad_idx).type(torch.float) 26 | return mask.to(device) 27 | 28 | 29 | def get_seq_length(seq, pad_idx, device): 30 | mask = get_non_pad_mask(seq, pad_idx, device) 31 | # mask: [B, T] 32 | lengths = mask.sum(dim=-1).long() 33 | return lengths 34 | 35 | 36 | class WorkingMemoryModel(nn.Module): 37 | def __init__(self, hps, device): 38 | super(WorkingMemoryModel, self).__init__() 39 | self.hps = hps 40 | self.device = device 41 | 42 | self.global_trace_size = hps.global_trace_size 43 | self.topic_trace_size = hps.topic_trace_size 44 | self.topic_slots = hps.topic_slots 45 | self.his_mem_slots = hps.his_mem_slots 46 | 47 | self.vocab_size = hps.vocab_size 48 | self.mem_size = hps.mem_size 49 | 50 | self.sens_num = hps.sens_num 51 | 52 | self.pad_idx = hps.pad_idx 53 | self.bos_tensor = torch.tensor(hps.bos_idx, dtype=torch.long, device=device) 54 | 55 | # ---------------------------- 56 | # build componets 57 | self.layers = nn.ModuleDict() 58 | self.layers['word_embed'] = nn.Embedding(hps.vocab_size, 59 | hps.word_emb_size, padding_idx=hps.pad_idx) 60 | 61 | # NOTE: We set fixed 33 phonology categories: 0~32 62 | # please refer to preprocess.py for more details 63 | self.layers['ph_embed'] = nn.Embedding(33, hps.ph_emb_size) 64 | 65 | self.layers['len_embed'] = nn.Embedding(hps.sen_len, hps.len_emb_size) 66 | 67 | 68 | self.layers['encoder'] = BidirEncoder(hps.word_emb_size, hps.hidden_size, drop_ratio=hps.drop_ratio) 69 | self.layers['decoder'] = Decoder(hps.hidden_size, hps.hidden_size, drop_ratio=hps.drop_ratio) 70 | 71 | # project the decoder hidden state to a vocanbulary-size output logit 72 | self.layers['out_proj'] = nn.Linear(hps.hidden_size, hps.vocab_size) 73 | 74 | # update the context vector 75 | self.layers['global_trace_updater'] = ContextLayer(hps.hidden_size, hps.global_trace_size) 76 | self.layers['topic_trace_updater'] = MLP(self.mem_size+self.topic_trace_size, 77 | layer_sizes=[self.topic_trace_size], activs=['tanh'], drop_ratio=hps.drop_ratio) 78 | 79 | 80 | # MLP for calculate initial decoder state 81 | self.layers['dec_init'] = MLP(hps.hidden_size*2, layer_sizes=[hps.hidden_size], 82 | activs=['tanh'], drop_ratio=hps.drop_ratio) 83 | self.layers['key_init'] = MLP(hps.hidden_size*2, layer_sizes=[hps.hidden_size], 84 | activs=['tanh'], drop_ratio=hps.drop_ratio) 85 | 86 | # history memory reading and writing layers 87 | # query: concatenation of hidden state, global_trace and topic_trace 88 | self.layers['memory_read'] = AttentionReader( 89 | d_q=hps.hidden_size+self.global_trace_size+self.topic_trace_size+self.topic_slots, 90 | d_v=hps.mem_size, drop_ratio=hps.attn_drop_ratio) 91 | 92 | self.layers['memory_write'] = AttentionWriter(hps.mem_size+self.global_trace_size, hps.mem_size) 93 | 94 | # NOTE: a layer to compress the encoder hidden states to a smaller size for larger number of slots 95 | self.layers['mem_compress'] = MLP(hps.hidden_size*2, layer_sizes=[hps.mem_size], 96 | activs=['tanh'], drop_ratio=hps.drop_ratio) 97 | 98 | # [inp, attns, ph_inp, len_inp, global_trace] 99 | self.layers['merge_x'] = MLP( 100 | 
hps.word_emb_size+hps.ph_emb_size+hps.len_emb_size+hps.global_trace_size+hps.mem_size, 101 | layer_sizes=[hps.hidden_size], 102 | activs=['tanh'], drop_ratio=hps.drop_ratio) 103 | 104 | 105 | # two annealing parameters 106 | self._tau = 1.0 107 | self._teach_ratio = 0.8 108 | 109 | 110 | # --------------------------------------------------------- 111 | # only used for for pre-training 112 | self.layers['dec_init_pre'] = MLP(hps.hidden_size*2, 113 | layer_sizes=[hps.hidden_size], 114 | activs=['tanh'], drop_ratio=hps.drop_ratio) 115 | 116 | self.layers['merge_x_pre'] = MLP( 117 | hps.word_emb_size+hps.ph_emb_size+hps.len_emb_size, 118 | layer_sizes=[hps.hidden_size], 119 | activs=['tanh'], drop_ratio=hps.drop_ratio) 120 | 121 | 122 | 123 | #--------------------------------- 124 | def set_tau(self, tau): 125 | if 0.0 < tau <= 1.0: 126 | self.layers['memory_write'].set_tau(tau) 127 | 128 | def get_tau(self): 129 | return self.layers['memory_write'].get_tau() 130 | 131 | def set_teach_ratio(self, teach_ratio): 132 | if 0.0 < teach_ratio <= 1.0: 133 | self._teach_ratio = teach_ratio 134 | 135 | def get_teach_ratio(self): 136 | return self._teach_ratio 137 | 138 | 139 | def set_null_idxes(self, null_idxes): 140 | self.null_idxes = null_idxes.to(self.device).unsqueeze(0) 141 | 142 | 143 | #--------------------------------- 144 | def compute_null_mem(self, batch_size): 145 | # we initialize the null memory slot with an average of stop words 146 | # by supposing that the model could learn to ignore these words 147 | emb_null = self.layers['word_embed'](self.null_idxes) 148 | 149 | # (1, L, 2*H) 150 | enc_outs, _ = self.layers['encoder'](emb_null) 151 | 152 | # (1, L, 2 * H) -> (1, L, D) 153 | null_mem = self.layers['mem_compress'](enc_outs) 154 | 155 | # (1, L, D)->(1, 1, D)->(B, 1, D) 156 | self.null_mem = null_mem.mean(dim=1, keepdim=True).repeat(batch_size, 1, 1) 157 | 158 | 159 | def computer_topic_memory(self, keys): 160 | # (B, key_len) 161 | emb_keys = [self.layers['word_embed'](key) for key in keys] 162 | key_lens = [get_seq_length(key, self.pad_idx, self.device) for key in keys] 163 | 164 | batch_size = emb_keys[0].size(0) 165 | 166 | # length == 0 means this is am empty topic slot 167 | topic_mask = torch.zeros(batch_size, self.topic_slots, 168 | dtype=torch.float, device=self.device).bool() # (B, topic_slots) 169 | for step in range(0, self.topic_slots): 170 | topic_mask[:, step] = torch.eq(key_lens[step], 0) 171 | 172 | 173 | key_states_vec, topic_slots = [], [] 174 | for step, (emb_key, length) in enumerate(zip(emb_keys, key_lens)): 175 | 176 | # we set the length of empty keys to 1 for parallel processing, 177 | # which will be masked then for memory reading 178 | length.masked_fill_(length.eq(0), 1) 179 | 180 | _, state = self.layers['encoder'](emb_key, length) 181 | # (2, B, H) -> (B, 2, H) -> (B, 2*H) 182 | key_state = state.transpose(0, 1).contiguous().view(batch_size, -1) 183 | mask = (1 - topic_mask[:, step].float()).unsqueeze(1) # (B, 1) 184 | 185 | key_states_vec.append((key_state*mask).unsqueeze(1)) 186 | 187 | topic = self.layers['mem_compress'](key_state) 188 | topic_slots.append((topic*mask).unsqueeze(1)) 189 | 190 | # (B, topic_slots, mem_size) 191 | topic_mem = torch.cat(topic_slots, dim=1) 192 | 193 | # (B, H) 194 | key_init_state = self.layers['key_init']( 195 | torch.cat(key_states_vec, dim=1).sum(1)) 196 | 197 | return topic_mem, topic_mask, key_init_state 198 | 199 | 200 | def computer_local_memory(self, inps, with_length): 201 | batch_size = inps.size(0) 202 | if 
with_length: 203 | length = get_seq_length(inps, self.pad_idx, self.device) 204 | else: 205 | length = None 206 | 207 | emb_inps = self.layers['word_embed'](inps) 208 | 209 | # outs: (B, L, 2 * H) 210 | # states: (2, B, H) 211 | enc_outs, enc_states = self.layers['encoder'](emb_inps, length) 212 | 213 | init_state = self.layers['dec_init'](enc_states.transpose(0, 1). 214 | contiguous().view(batch_size, -1)) 215 | 216 | # (B, L, 2 * H) -> (B, L, D) 217 | local_mem = self.layers['mem_compress'](enc_outs) 218 | 219 | local_mask = torch.eq(inps, self.pad_idx) 220 | 221 | return local_mem, local_mask, init_state 222 | 223 | 224 | def update_global_trace(self, old_global_trace, dec_states, dec_mask): 225 | states = torch.cat(dec_states, dim=2) # (B, H, L) 226 | global_trace = self.layers['global_trace_updater']( 227 | old_global_trace, states*(dec_mask.unsqueeze(1))) 228 | return global_trace 229 | 230 | 231 | def update_topic_trace(self, topic_trace, topic_mem, concat_aligns): 232 | # topic_trace: (B, topic_trace_size+topic_slots) 233 | # concat_aligns: (B, L_gen, mem_slots) 234 | 235 | # 1: topic memory, 2: history memory 3: local memory 236 | topic_align = concat_aligns[:, :, 0:self.topic_slots].mean(dim=1) # (B, topic_slots) 237 | 238 | # (B, topic_slots, mem_size) * (B, topic_slots, 1) -> (B, topic_slots, mem_size) 239 | # -> (B, mem_size) 240 | topic_used = torch.mul(topic_mem, topic_align.unsqueeze(2)).mean(dim=1) 241 | 242 | 243 | new_topic_trace = self.layers['topic_trace_updater']( 244 | torch.cat([topic_trace[:, 0:self.topic_trace_size], topic_used], dim=1)) 245 | 246 | read_log = topic_trace[:, self.topic_trace_size:] + topic_align 247 | 248 | fin_topic_trace = torch.cat([new_topic_trace, read_log], dim=1) 249 | 250 | return fin_topic_trace 251 | 252 | 253 | def dec_step(self, inp, state, ph, length, total_mem, total_mask, 254 | global_trace, topic_trace): 255 | 256 | emb_inp = self.layers['word_embed'](inp) 257 | emb_ph = self.layers['ph_embed'](ph) 258 | emb_len = self.layers['len_embed'](length) 259 | 260 | # query for reading read memory 261 | # (B, 1, H] 262 | query = torch.cat([state, global_trace, topic_trace], dim=1).unsqueeze(1) 263 | 264 | # attns: (B, 1, mem_size), align: (B, 1, L) 265 | attns, align = self.layers['memory_read'](query, total_mem, total_mem, total_mask) 266 | 267 | 268 | x = torch.cat([emb_inp, emb_ph, emb_len, attns, global_trace], dim=1).unsqueeze(1) 269 | x = self.layers['merge_x'](x) 270 | 271 | cell_out, new_state = self.layers['decoder'](x, state) 272 | out = self.layers['out_proj'](cell_out) 273 | return out, new_state, align 274 | 275 | 276 | def run_decoder(self, inps, trgs, phs, lens, key_init_state, 277 | history_mem, history_mask, topic_mem, topic_mask, global_trace, topic_trace, 278 | specified_teach_ratio): 279 | 280 | local_mem, local_mask, init_state = \ 281 | self.computer_local_memory(inps, key_init_state is None) 282 | 283 | if key_init_state is not None: 284 | init_state = key_init_state 285 | 286 | if specified_teach_ratio is None: 287 | teach_ratio = self._teach_ratio 288 | else: 289 | teach_ratio = specified_teach_ratio 290 | 291 | 292 | # Note this order: 1: topic memory, 2: history memory 3: local memory 293 | total_mask = torch.cat([topic_mask, history_mask, local_mask], dim=1) 294 | total_mem = torch.cat([topic_mem, history_mem, local_mem], dim=1) 295 | 296 | batch_size = inps.size(0) 297 | trg_len = trgs.size(1) 298 | 299 | outs = torch.zeros(batch_size, trg_len, self.vocab_size, 300 | dtype=torch.float, device=self.device) 301 | 
302 | state = init_state 303 | inp = self.bos_tensor.repeat(batch_size) 304 | dec_states, attn_weights = [], [] 305 | 306 | # generate each line 307 | for t in range(0, trg_len): 308 | out, state, align = self.dec_step(inp, state, phs[:, t], 309 | lens[:, t], total_mem, total_mask, global_trace, topic_trace) 310 | outs[:, t, :] = out 311 | 312 | attn_weights.append(align) 313 | 314 | # teach force with a probability 315 | is_teach = random.random() < teach_ratio 316 | if is_teach or (not self.training): 317 | inp = trgs[:, t] 318 | else: 319 | normed_out = F.softmax(out, dim=-1) 320 | inp = normed_out.data.max(1)[1] 321 | 322 | dec_states.append(state.unsqueeze(2)) # (B, H, 1) 323 | attn_weights.append(align) 324 | 325 | 326 | 327 | # write the history memory 328 | if key_init_state is None: 329 | new_history_mem, _ = self.layers['memory_write'](history_mem, local_mem, 330 | 1.0-local_mask.float(), global_trace, self.null_mem) 331 | else: 332 | new_history_mem = history_mem 333 | 334 | # (B, L) 335 | dec_mask = get_non_pad_mask(trgs, self.pad_idx, self.device) 336 | 337 | # update global trace vector 338 | new_global_trace = self.update_global_trace(global_trace, dec_states, dec_mask) 339 | 340 | 341 | # update topic trace vector 342 | # attn_weights: (B, 1, all_mem_slots) * L_gen 343 | concat_aligns = torch.cat(attn_weights, dim=1) 344 | new_topic_trace = self.update_topic_trace(topic_trace, topic_mem, concat_aligns) 345 | 346 | 347 | return outs, new_history_mem, new_global_trace, new_topic_trace 348 | 349 | 350 | 351 | def initialize_mems(self, keys): 352 | batch_size = keys[0].size(0) 353 | topic_mem, topic_mask, key_init_state = self.computer_topic_memory(keys) 354 | 355 | history_mem = torch.zeros(batch_size, self.his_mem_slots, self.mem_size, 356 | dtype=torch.float, device=self.device) 357 | 358 | # default: True, masked 359 | history_mask = torch.ones(batch_size, self.his_mem_slots, 360 | dtype=torch.float, device=self.device).bool() 361 | 362 | global_trace = torch.zeros(batch_size, self.global_trace_size, 363 | dtype=torch.float, device=self.device) 364 | topic_trace = torch.zeros(batch_size, self.topic_trace_size+self.topic_slots, 365 | dtype=torch.float, device=self.device) 366 | 367 | self.compute_null_mem(batch_size) 368 | 369 | return topic_mem, topic_mask, history_mem, history_mask,\ 370 | global_trace, topic_trace, key_init_state 371 | 372 | 373 | def rebuild_inps(self, ori_inps, last_outs, teach_ratio): 374 | # ori_inps: (B, L) 375 | # last_outs: (B, L, V) 376 | inp_len = ori_inps.size(1) 377 | new_inps = torch.ones_like(ori_inps) * self.pad_idx 378 | 379 | mask = get_non_pad_mask(ori_inps, self.pad_idx, self.device).long() 380 | 381 | if teach_ratio is None: 382 | teach_ratio = self._teach_ratio 383 | 384 | for t in range(0, inp_len): 385 | is_teach = random.random() < teach_ratio 386 | if is_teach or (not self.training): 387 | new_inps[:, t] = ori_inps[:, t] 388 | else: 389 | normed_out = F.softmax(last_outs[:, t], dim=-1) 390 | new_inps[:, t] = normed_out.data.max(1)[1] 391 | 392 | new_inps = new_inps * mask 393 | 394 | return new_inps 395 | 396 | 397 | def forward(self, all_inps, all_trgs, all_ph_inps, all_len_inps, keys, teach_ratio=None, 398 | flexible_inps=False): 399 | ''' 400 | all_inps: (B, L) * sens_num 401 | all_trgs: (B, L) * sens_num 402 | all_ph_inps: (B, L) * sens_num 403 | all_len_inps: (B, L) * sens_num 404 | keys: (B, L) * topic_slots 405 | flexible_inps: if apply partial teaching force to local memory. 
406 | False: the ground-truth src line is stored into the local memory 407 | True: for local memory, ground-truth characters will be replaced with generated characters with 408 | the probability of 1- teach_ratio. 409 | NOTE: this trick is *not* adopted in our original paper, which could lead to 410 | better BLEU and topic relevance, but worse diversity of generated poems. 411 | ''' 412 | all_outs = [] 413 | 414 | topic_mem, topic_mask, history_mem, history_mask,\ 415 | global_trace, topic_trace, key_init_state = self.initialize_mems(keys) 416 | 417 | for step in range(0, self.sens_num): 418 | if step > 0: 419 | key_init_state = None 420 | 421 | if step >= 1 and flexible_inps: 422 | inps = self.rebuild_inps(all_inps[step], all_outs[-1], teach_ratio) 423 | else: 424 | inps = all_inps[step] 425 | 426 | outs, history_mem, global_trace, topic_trace \ 427 | = self.run_decoder(inps, all_trgs[step], 428 | all_ph_inps[step], all_len_inps[step], key_init_state, 429 | history_mem, history_mask, topic_mem, topic_mask, 430 | global_trace, topic_trace, teach_ratio) 431 | 432 | if step >= 1: 433 | history_mask = history_mem.abs().sum(-1).eq(0) # (B, mem_slots) 434 | 435 | 436 | all_outs.append(outs) 437 | 438 | 439 | return all_outs 440 | 441 | 442 | 443 | # -------------------------- 444 | # graphs for pre-training 445 | def dseq_graph(self, inps, trgs, ph_inps, len_inps, teach_ratio=None): 446 | # pre-train the encoder and decoder as a denoising Seq2Seq model 447 | batch_size, trg_len = trgs.size(0), trgs.size(1) 448 | length = get_seq_length(inps, self.pad_idx, self.device) 449 | 450 | 451 | emb_inps = self.layers['word_embed'](inps) 452 | emb_phs = self.layers['ph_embed'](ph_inps) 453 | emb_lens = self.layers['len_embed'](len_inps) 454 | 455 | 456 | # outs: (B, L, 2 * H) 457 | # states: (2, B, H) 458 | _, enc_states = self.layers['encoder'](emb_inps, length) 459 | 460 | 461 | init_state = self.layers['dec_init_pre'](enc_states.transpose(0, 1). 
462 | contiguous().view(batch_size, -1)) 463 | 464 | 465 | outs = torch.zeros(batch_size, trg_len, self.vocab_size, 466 | dtype=torch.float, device=self.device) 467 | 468 | if teach_ratio is None: 469 | teach_ratio = self._teach_ratio 470 | 471 | state = init_state 472 | inp = self.bos_tensor.repeat(batch_size, 1) 473 | 474 | # generate each line 475 | for t in range(0, trg_len): 476 | emb_inp = self.layers['word_embed'](inp) 477 | x = self.layers['merge_x_pre'](torch.cat( 478 | [emb_inp, emb_phs[:, t].unsqueeze(1), emb_lens[:, t].unsqueeze(1)], 479 | dim=-1)) 480 | 481 | cell_out, state, = self.layers['decoder'](x, state) 482 | out = self.layers['out_proj'](cell_out) 483 | 484 | outs[:, t, :] = out 485 | 486 | # teach force with a probability 487 | is_teach = random.random() < teach_ratio 488 | if is_teach or (not self.training): 489 | inp = trgs[:, t].unsqueeze(1) 490 | else: 491 | normed_out = F.softmax(out, dim=-1) 492 | top1 = normed_out.data.max(1)[1] 493 | inp = top1.unsqueeze(1) 494 | 495 | 496 | return outs 497 | 498 | 499 | # ---------------------------------------------- 500 | def dseq_parameter_names(self): 501 | required_names = ['word_embed', 'ph_embed', 'len_embed', 502 | 'encoder', 'decoder', 'out_proj', 503 | 'dec_init_pre', 'merge_x_pre'] 504 | return required_names 505 | 506 | def dseq_parameters(self): 507 | names = self.dseq_parameter_names() 508 | 509 | required_params = [self.layers[name].parameters() for name in names] 510 | 511 | return chain.from_iterable(required_params) 512 | 513 | # ------------------------------------------------ --------------------------------------------------------------------------------
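Note: the following is a minimal, self-contained sketch (not part of the repository) of the masked memory read performed in WorkingMemoryModel.dec_step: a query built from the decoder state, global trace and topic trace attends over the concatenation of topic, history and local memory, with empty slots masked out. The tensor sizes, the query projection and the scaled dot-product scoring here are illustrative assumptions; the repository's actual mechanism lives in layers.AttentionReader and may differ in detail.

import torch
import torch.nn.functional as F

B, H, D, S = 4, 64, 32, 12             # batch size, query size, mem_size, total memory slots (illustrative)
query = torch.randn(B, 1, H)           # stands in for [state; global_trace; topic_trace]
memory = torch.randn(B, S, D)          # stands in for cat([topic_mem, history_mem, local_mem], dim=1)
mask = torch.zeros(B, S, dtype=torch.bool)
mask[:, 5:8] = True                    # pretend some history slots are still empty (True = masked)

proj = torch.nn.Linear(H, D, bias=False)                               # illustrative query projection
scores = torch.bmm(proj(query), memory.transpose(1, 2)) / D ** 0.5     # (B, 1, S) attention scores
scores = scores.masked_fill(mask.unsqueeze(1), float("-inf"))          # exclude masked (empty) slots
align = F.softmax(scores, dim=-1)                                      # (B, 1, S) read weights, like "align" in dec_step
attn = torch.bmm(align, memory)                                        # (B, 1, D) memory vector, like "attns" in dec_step

In the real model, attn is concatenated with the word, phonology and length embeddings and the global trace before merge_x, and align is accumulated to update the topic trace and to drive the visualization of memory reading.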