├── RunLog
│   └── readme.md
├── DataSet
│   └── readme.md
├── ModelStorage
│   └── readme.md
├── README.md
├── preprocess.py
├── langconv.py
├── MultiTaskXLIR-DRMC.py
├── GRUIRMoS.py
├── MultiTaskXLIR-DuReader.py
└── MultiTaskXLIR-Final.py

/RunLog/readme.md:
--------------------------------------------------------------------------------
Training logs
--------------------------------------------------------------------------------
/DataSet/readme.md:
--------------------------------------------------------------------------------
See the top-level README
--------------------------------------------------------------------------------
/ModelStorage/readme.md:
--------------------------------------------------------------------------------
Placeholder: this folder stores the model files produced during training
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Tianchi2020ChineseMedicineQuestionGeneration
2020 Alibaba Cloud Tianchi Big Data Competition - Traditional Chinese Medicine Literature Question Generation Challenge

Official competition page: https://tianchi.aliyun.com/competition/entrance/531826/introduction

`Preliminary round`: 0.6133 (11/868)  `Final round`: 0.6215 (8/868 => ranked 6th after the final-round code review)

**Both results were obtained with a single model**

Baidu Netdisk link for the complete project files, including the datasets: `https://pan.baidu.com/s/1crAYwtDLrGnkls9xdfQdQg`, extraction code: `qagl`
(Note: the Netdisk link is unstable and may be mistakenly blocked by Baidu; if you need the complete data files, email anlin781205936@126.com)

Overall model design: a pre-trained language model (RoBERTa_wwm_ext_large) serves as the encoder and a Transformer-XL decoder is trained from scratch; the model is first pre-trained on other reading-comprehension datasets and then fine-tuned on the competition data.

Overall pipeline:
> 1. Data preprocessing: `python preprocess.py` generates `multi_task.pkl`
> 2. Coarse-grained pre-training on the DuReader dataset: `nohup python -u MultiTaskXLIR-DuReader.py train gpu-0 &` (set the batch size and number of GPUs yourself)
> 3. Fine-grained pre-training on the DRCD and CMRC2018 datasets: `nohup python -u MultiTaskXLIR-DRMC.py train gpu-0 &`
> 4. Training on the competition dataset: `nohup python -u MultiTaskXLIR-Final.py train gpu-0 final &`
> 5. Generate the test-set predictions with beam search: `python MultiTaskXLIR-Final.py test gpu-0`
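For a quick sanity check after step 1, the generated `multi_task.pkl` can be inspected as follows (a minimal sketch; the key names follow `preprocess.py` — competition items are dicts with `context`/`query`/`answer`, while `test_items` keeps the raw test-set entries):

```python
import pickle

with open("DataSet/multi_task.pkl", "rb") as f:
    data = pickle.load(f)

# Splits written by preprocess.py: competition data plus the external
# reading-comprehension corpora used for pre-training.
for key in ["train_items", "valid_items", "test_items",
            "dureader_train_items", "cmrc_train_items", "drcd_train_items"]:
    print(key, len(data[key]))

print(data["train_items"][0])  # {"context": ..., "query": ..., "answer": ...}
```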
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
#!usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Time: 2020/9/14
@Author: menghuanlater
@Software: Pycharm 2019.2
@Usage: data preprocess
-----------------------------
Description:
-----------------------------
"""
import json
import os
import pickle
from random import shuffle
import numpy as np

train_file = open("DataSet/round1_train_0907.json", "r", encoding="UTF-8")
test_file = open("DataSet/juesai_1011.json", "r", encoding="UTF-8")
train_data = json.load(train_file)

x = []

for item in train_data:
    for jtem in item["annotations"]:
        x.append({
            "context": item["text"],
            "query": jtem["Q"],
            "answer": jtem["A"]
        })

for i in range(10):
    shuffle(x)

output = {
    "train_items": x[3000:],
    "test_items": list(json.load(test_file)),
    "valid_items": x[:3000],
    "dureader_train_items": [],
    "cmrc_train_items": [],
    "drcd_train_items": [],
    "multi_task_epoch": 6
}
print("===完成比赛数据处理===")


# First, process the CMRC data ==> relatively standard format
def cmrc_json(data):
    for dtem in data:
        paragraphs = dtem["paragraphs"]
        for ptem in paragraphs:
            context = ptem["context"][:600]
            qas = ptem["qas"]
            for qtem in qas:
                query = qtem["question"]
                answer = qtem["answers"][0]["text"]
                output["cmrc_train_items"].append({
                    "context": context, "query": query, "answer": answer
                })


for file in os.listdir("DataSet/MultiTask/CMRC"):
    with open("DataSet/MultiTask/CMRC/" + file, "r", encoding="UTF-8") as f:
        cmrc_json(json.load(f)["data"])


print("===完成CMRC数据处理===")

from langconv import *
obj = Converter('zh-hans')


# Next, process the DRCD data (converted from Traditional to Simplified Chinese)
def drcd_json(data):
    for dtem in data:
        paragraphs = dtem["paragraphs"]
        for ptem in paragraphs:
            context = obj.convert(ptem["context"][:600])
            qas = ptem["qas"]
            for qtem in qas:
                query = obj.convert(qtem["question"])
                answer = obj.convert(qtem["answers"][0]["text"])
                output["drcd_train_items"].append({
                    "context": context, "query": query, "answer": answer
                })


for file in os.listdir("DataSet/MultiTask/DRCD"):
    with open("DataSet/MultiTask/DRCD/" + file, "r", encoding="UTF-8") as f:
        drcd_json(json.load(f)["data"])

print("===完成DRCD数据处理===")


# Finally, process DuReader (used purely for coarse-grained pre-training ==> the data is too fragmented/noisy)
def dureader_json(data):
    for item in data:
        if item["question_type"] == "YES_NO":
            continue
        context = ""
        for doc_item in item["documents"]:
            if doc_item["is_selected"]:
                context += " ".join(doc_item["paragraphs"])
                if len(context) >= 600:
                    break
        context = context[:600]
        answers = item["answers"]
        for atem in answers:
            output["dureader_train_items"].append({
                "context": context,
                "query": item["question"],
                "answer": atem
            })


for file in os.listdir("DataSet/MultiTask/DuReader/devset"):
    with open("DataSet/MultiTask/DuReader/devset/" + file, "r", encoding="UTF-8") as f:
        dureader_json([json.loads(s) for s in f.readlines()])
print("===完成DuReader Dev数据处理===")

for file in 
os.listdir("DataSet/MultiTask/DuReader/trainset"): 122 | with open("DataSet/MultiTask/DuReader/trainset/" + file, "r", encoding="UTF-8") as f: 123 | dureader_json([json.loads(s) for s in f.readlines()]) 124 | print("===完成DuReader Train数据处理===") 125 | 126 | for i in range(3): 127 | shuffle(output["dureader_train_items"]) 128 | shuffle(output["drcd_train_items"]) 129 | shuffle(output["cmrc_train_items"]) 130 | print("CMRC2018用于训练的数据一共有%d条" % len(output["cmrc_train_items"])) 131 | print("DRCD用于训练的数据一共有%d条" % len(output["drcd_train_items"])) 132 | print("DuReader用于训练的数据一共有%d条" % len(output["dureader_train_items"])) 133 | with open("DataSet/multi_task.pkl", "wb") as f: 134 | pickle.dump(output, f) 135 | 136 | -------------------------------------------------------------------------------- /langconv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from copy import deepcopy 5 | import re 6 | 7 | try: 8 | import psyco 9 | psyco.full() 10 | except: 11 | pass 12 | 13 | try: 14 | from zh_wiki import zh2Hant, zh2Hans 15 | except ImportError: 16 | from zhtools.zh_wiki import zh2Hant, zh2Hans 17 | 18 | import sys 19 | py3k = sys.version_info >= (3, 0, 0) 20 | 21 | if py3k: 22 | UEMPTY = '' 23 | else: 24 | _zh2Hant, _zh2Hans = {}, {} 25 | for old, new in ((zh2Hant, _zh2Hant), (zh2Hans, _zh2Hans)): 26 | for k, v in old.items(): 27 | new[k.decode('utf8')] = v.decode('utf8') 28 | zh2Hant = _zh2Hant 29 | zh2Hans = _zh2Hans 30 | UEMPTY = ''.decode('utf8') 31 | 32 | # states 33 | (START, END, FAIL, WAIT_TAIL) = list(range(4)) 34 | # conditions 35 | (TAIL, ERROR, MATCHED_SWITCH, UNMATCHED_SWITCH, CONNECTOR) = list(range(5)) 36 | 37 | MAPS = {} 38 | 39 | class Node(object): 40 | def __init__(self, from_word, to_word=None, is_tail=True, 41 | have_child=False): 42 | self.from_word = from_word 43 | if to_word is None: 44 | self.to_word = from_word 45 | self.data = (is_tail, have_child, from_word) 46 | self.is_original = True 47 | else: 48 | self.to_word = to_word or from_word 49 | self.data = (is_tail, have_child, to_word) 50 | self.is_original = False 51 | self.is_tail = is_tail 52 | self.have_child = have_child 53 | 54 | def is_original_long_word(self): 55 | return self.is_original and len(self.from_word)>1 56 | 57 | def is_follow(self, chars): 58 | return chars != self.from_word[:-1] 59 | 60 | def __str__(self): 61 | return '' % (repr(self.from_word), 62 | repr(self.to_word), self.is_tail, self.have_child) 63 | 64 | __repr__ = __str__ 65 | 66 | class ConvertMap(object): 67 | def __init__(self, name, mapping=None): 68 | self.name = name 69 | self._map = {} 70 | if mapping: 71 | self.set_convert_map(mapping) 72 | 73 | def set_convert_map(self, mapping): 74 | convert_map = {} 75 | have_child = {} 76 | max_key_length = 0 77 | for key in sorted(mapping.keys()): 78 | if len(key)>1: 79 | for i in range(1, len(key)): 80 | parent_key = key[:i] 81 | have_child[parent_key] = True 82 | have_child[key] = False 83 | max_key_length = max(max_key_length, len(key)) 84 | for key in sorted(have_child.keys()): 85 | convert_map[key] = (key in mapping, have_child[key], 86 | mapping.get(key, UEMPTY)) 87 | self._map = convert_map 88 | self.max_key_length = max_key_length 89 | 90 | def __getitem__(self, k): 91 | try: 92 | is_tail, have_child, to_word = self._map[k] 93 | return Node(k, to_word, is_tail, have_child) 94 | except: 95 | return Node(k) 96 | 97 | def __contains__(self, k): 98 | return k in self._map 99 | 100 | def __len__(self): 
101 | return len(self._map) 102 | 103 | class StatesMachineException(Exception): pass 104 | 105 | class StatesMachine(object): 106 | def __init__(self): 107 | self.state = START 108 | self.final = UEMPTY 109 | self.len = 0 110 | self.pool = UEMPTY 111 | 112 | def clone(self, pool): 113 | new = deepcopy(self) 114 | new.state = WAIT_TAIL 115 | new.pool = pool 116 | return new 117 | 118 | def feed(self, char, map): 119 | node = map[self.pool+char] 120 | 121 | if node.have_child: 122 | if node.is_tail: 123 | if node.is_original: 124 | cond = UNMATCHED_SWITCH 125 | else: 126 | cond = MATCHED_SWITCH 127 | else: 128 | cond = CONNECTOR 129 | else: 130 | if node.is_tail: 131 | cond = TAIL 132 | else: 133 | cond = ERROR 134 | 135 | new = None 136 | if cond == ERROR: 137 | self.state = FAIL 138 | elif cond == TAIL: 139 | if self.state == WAIT_TAIL and node.is_original_long_word(): 140 | self.state = FAIL 141 | else: 142 | self.final += node.to_word 143 | self.len += 1 144 | self.pool = UEMPTY 145 | self.state = END 146 | elif self.state == START or self.state == WAIT_TAIL: 147 | if cond == MATCHED_SWITCH: 148 | new = self.clone(node.from_word) 149 | self.final += node.to_word 150 | self.len += 1 151 | self.state = END 152 | self.pool = UEMPTY 153 | elif cond == UNMATCHED_SWITCH or cond == CONNECTOR: 154 | if self.state == START: 155 | new = self.clone(node.from_word) 156 | self.final += node.to_word 157 | self.len += 1 158 | self.state = END 159 | else: 160 | if node.is_follow(self.pool): 161 | self.state = FAIL 162 | else: 163 | self.pool = node.from_word 164 | elif self.state == END: 165 | # END is a new START 166 | self.state = START 167 | new = self.feed(char, map) 168 | elif self.state == FAIL: 169 | raise StatesMachineException('Translate States Machine ' 170 | 'have error with input data %s' % node) 171 | return new 172 | 173 | def __len__(self): 174 | return self.len + 1 175 | 176 | def __str__(self): 177 | return '' % ( 178 | id(self), self.pool, self.state, self.final) 179 | __repr__ = __str__ 180 | 181 | class Converter(object): 182 | def __init__(self, to_encoding): 183 | self.to_encoding = to_encoding 184 | self.map = MAPS[to_encoding] 185 | self.start() 186 | 187 | def feed(self, char): 188 | branches = [] 189 | for fsm in self.machines: 190 | new = fsm.feed(char, self.map) 191 | if new: 192 | branches.append(new) 193 | if branches: 194 | self.machines.extend(branches) 195 | self.machines = [fsm for fsm in self.machines if fsm.state != FAIL] 196 | all_ok = True 197 | for fsm in self.machines: 198 | if fsm.state != END: 199 | all_ok = False 200 | if all_ok: 201 | self._clean() 202 | return self.get_result() 203 | 204 | def _clean(self): 205 | if len(self.machines): 206 | self.machines.sort(key=lambda x: len(x)) 207 | # self.machines.sort(cmp=lambda x,y: cmp(len(x), len(y))) 208 | self.final += self.machines[0].final 209 | self.machines = [StatesMachine()] 210 | 211 | def start(self): 212 | self.machines = [StatesMachine()] 213 | self.final = UEMPTY 214 | 215 | def end(self): 216 | self.machines = [fsm for fsm in self.machines 217 | if fsm.state == FAIL or fsm.state == END] 218 | self._clean() 219 | 220 | def convert(self, string): 221 | self.start() 222 | for char in string: 223 | self.feed(char) 224 | self.end() 225 | return self.get_result() 226 | 227 | def get_result(self): 228 | return self.final 229 | 230 | 231 | def registery(name, mapping): 232 | global MAPS 233 | MAPS[name] = ConvertMap(name, mapping) 234 | 235 | registery('zh-hant', zh2Hant) 236 | registery('zh-hans', zh2Hans) 
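# Conversion maps registered above: 'zh-hant' converts Simplified to Traditional Chinese,
# 'zh-hans' converts Traditional to Simplified (the direction preprocess.py uses for the
# DRCD data via Converter('zh-hans').convert(text)).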
237 | del zh2Hant, zh2Hans 238 | 239 | 240 | def run(): 241 | import sys 242 | from optparse import OptionParser 243 | parser = OptionParser() 244 | parser.add_option('-e', type='string', dest='encoding', 245 | help='encoding') 246 | parser.add_option('-f', type='string', dest='file_in', 247 | help='input file (- for stdin)') 248 | parser.add_option('-t', type='string', dest='file_out', 249 | help='output file') 250 | (options, args) = parser.parse_args() 251 | if not options.encoding: 252 | parser.error('encoding must be set') 253 | if options.file_in: 254 | if options.file_in == '-': 255 | file_in = sys.stdin 256 | else: 257 | file_in = open(options.file_in) 258 | else: 259 | file_in = sys.stdin 260 | if options.file_out: 261 | if options.file_out == '-': 262 | file_out = sys.stdout 263 | else: 264 | file_out = open(options.file_out, 'wb') 265 | else: 266 | file_out = sys.stdout 267 | 268 | c = Converter(options.encoding) 269 | for line in file_in: 270 | # print >> file_out, c.convert(line.rstrip('\n').decode( 271 | file_out.write(c.convert(line.rstrip('\n').decode( 272 | 'utf8')).encode('utf8')) 273 | 274 | 275 | if __name__ == '__main__': 276 | run() -------------------------------------------------------------------------------- /MultiTaskXLIR-DRMC.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Time: 2020/9/14 5 | @Author: menghuanlater 6 | @Software: Pycharm 2019.2 7 | @Usage: data preprocess 8 | ----------------------------- 9 | Description: Base on RoBERTa and Transformer-XL Decoder and Copy Mechanism 10 | Transformer Decoder采用Transformer-XL 11 | ----------------------------- 12 | """ 13 | import os 14 | import sys 15 | 16 | os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[2][4:] 17 | from typing import Any 18 | 19 | from transformers import BertTokenizer, BertModel, BertConfig 20 | import torch 21 | from torch import nn 22 | import pickle 23 | from torch.utils.data import DataLoader, Dataset 24 | from torch import optim 25 | import numpy as np 26 | import json 27 | from tensorboardX import SummaryWriter 28 | 29 | 30 | class MyDataset(Dataset): 31 | def __init__(self, data, max_enc_len, max_dec_len): 32 | self.data = data 33 | self.max_encode_len = max_enc_len 34 | self.max_decode_len = max_dec_len 35 | self.SEG_A = 0 36 | self.SEG_P = 1 37 | self.ID_PAD = 0 38 | 39 | def __len__(self): 40 | return len(self.data) 41 | 42 | def __getitem__(self, index): 43 | item = self.data[index] 44 | context, query, answer = item["context"], item["query"], item["answer"] 45 | context_tokens = tokenizer.tokenize(context) 46 | query_tokens = tokenizer.tokenize(query) 47 | answer_tokens = tokenizer.tokenize(answer)[:args["max_answer_len"]] 48 | 49 | c = ["[CLS]"] + answer_tokens + ["[SEP]"] + context_tokens 50 | if len(c) > self.max_encode_len - 1: 51 | c = c[:self.max_encode_len - 1] 52 | c += ["[SEP]"] 53 | input_ids = tokenizer.convert_tokens_to_ids(c) 54 | input_mask = [1.0] * len(input_ids) 55 | input_seg = [self.SEG_A] * (len(answer_tokens) + 2) + [self.SEG_P] * (len(input_ids) - 2 - len(answer_tokens)) 56 | extra = self.max_encode_len - len(input_ids) 57 | if extra > 0: 58 | input_ids += [self.ID_PAD] * extra 59 | input_mask += [0.0] * extra 60 | input_seg += [self.SEG_P] * extra 61 | if len(query_tokens) > self.max_decode_len - 1: 62 | query_tokens = query_tokens[: self.max_decode_len - 1] 63 | c = tokenizer.convert_tokens_to_ids(query_tokens) 64 | dec_input = [args["start_token_id"]] + c 65 | 
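        # Teacher forcing: the decoder input is the (truncated) query prefixed with the
        # start token, and the decode target is the same query followed by the end token;
        # both are padded with ID_PAD below.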
dec_target = c + [args["end_token_id"]] 66 | extra = self.max_decode_len - len(dec_input) 67 | if extra > 0: 68 | dec_input += [self.ID_PAD] * extra 69 | dec_target += [self.ID_PAD] * extra 70 | return { 71 | "input_ids": torch.tensor(input_ids).long(), "input_mask": torch.tensor(input_mask).float(), 72 | "input_seg": torch.tensor(input_seg).long(), "decode_input": torch.tensor(dec_input).long(), 73 | "decode_target": torch.tensor(dec_target).long(), "label": query 74 | } 75 | 76 | 77 | class XLRelPosEmb(nn.Module): 78 | def _forward_unimplemented(self, *input: Any) -> None: 79 | pass 80 | 81 | def __init__(self, d_embed): 82 | super().__init__() 83 | 84 | self.d_embed = d_embed 85 | inv_freq = 1 / (10000 ** (torch.arange(0.0, self.d_embed, 2.0) / self.d_embed)) 86 | self.register_buffer("inv_freq", inv_freq) 87 | 88 | def forward(self, pos_seq): 89 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq) 90 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 91 | return pos_emb 92 | 93 | 94 | class PositionwiseFFN(nn.Module): 95 | def _forward_unimplemented(self, *input: Any) -> None: 96 | pass 97 | 98 | def __init__(self, d_model, d_inner, layer_norm_epsilon=1e-5): 99 | super().__init__() 100 | self.d_model = d_model 101 | self.d_inner = d_inner 102 | self.CoreNet = nn.Sequential( 103 | nn.Linear(d_model, d_inner), 104 | nn.GELU(), 105 | nn.Dropout(p=args["dropout"]), 106 | nn.Linear(d_inner, d_model), 107 | nn.Dropout(p=args["dropout"]) 108 | ) 109 | self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon) 110 | 111 | def forward(self, inp): 112 | core_out = self.CoreNet(inp) 113 | output = self.layer_norm(inp + core_out) 114 | return output 115 | 116 | 117 | class RelPartialLearnableMultiHeadAttn(torch.nn.Module): 118 | 119 | def _forward_unimplemented(self, *input: Any) -> None: 120 | pass 121 | 122 | def __init__(self, n_heads, d_model, layer_norm_epsilon=1e-5): 123 | super().__init__() 124 | 125 | self.n_heads = n_heads 126 | self.d_model = d_model 127 | self.d_head = d_model // n_heads 128 | 129 | self.mask_attn_qkv_net = nn.Linear(d_model, 3 * d_model, bias=False) 130 | self.mask_attn_o_net = nn.Linear(d_model, d_model, bias=False) 131 | 132 | self.interaction_kv_net = nn.Linear(d_model, 2 * d_model, bias=False) 133 | self.interaction_q_net = nn.Linear(d_model, d_model, bias=False) 134 | self.interaction_o_net = nn.Linear(d_model, d_model, bias=False) 135 | 136 | self.layer_norm_mask_attn = nn.LayerNorm(d_model, eps=layer_norm_epsilon) 137 | self.layer_norm_interaction = nn.LayerNorm(d_model, eps=layer_norm_epsilon) 138 | self.scale = 1 / (self.d_head ** 0.5) 139 | 140 | self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_heads, self.d_head)) 141 | self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_heads, self.d_head)) 142 | 143 | self.r_net = nn.Linear(d_model, d_model, bias=False) 144 | 145 | self.drop = nn.Dropout(p=args["dropout"]) 146 | 147 | @staticmethod 148 | def _rel_shift(x): 149 | zero_pad_shape = (x.size(0), 1) + x.size()[2:] 150 | zero_pad = torch.zeros(zero_pad_shape, device=device, dtype=x.dtype) 151 | x_padded = torch.cat([zero_pad, x], dim=1) 152 | 153 | x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:] 154 | x_padded = x_padded.view(*x_padded_shape) 155 | 156 | x = x_padded[1:].view_as(x) 157 | 158 | return x 159 | 160 | def forward(self, w, r, enc_context, attn_mask, padding_mask): 161 | # attn_mask用于Masked-Attn Mechanism(decode自身部分) 162 | # padding_mask用于Norm Multi-Attn, Decode与Encode Contextual Rep交互 163 | dec_len, bsz, 
enc_len = w.size(0), w.size(1), enc_context.size(0) 164 | w_heads = self.mask_attn_qkv_net(w) 165 | r_head_k = self.r_net(r) 166 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 167 | 168 | w_head_q = w_head_q.view(dec_len, bsz, self.n_heads, self.d_head) # dec_len x bsz x n_head x d_head 169 | w_head_k = w_head_k.view(dec_len, bsz, self.n_heads, self.d_head) # dec_len x bsz x n_head x d_head 170 | w_head_v = w_head_v.view(dec_len, bsz, self.n_heads, self.d_head) # dec_len x bsz x n_head x d_head 171 | 172 | r_head_k = r_head_k.view(dec_len, self.n_heads, self.d_head) # dec_len x n_head x d_head 173 | rw_head_q = w_head_q + self.r_w_bias # dec_len x bsz x n_head x d_head 174 | AC = torch.einsum("ibnd,jbnd->ijbn", rw_head_q, w_head_k) # dec_len x dec_len x bsz x n_head 175 | rr_head_q = w_head_q + self.r_r_bias 176 | BD = torch.einsum("ibnd,jnd->ijbn", rr_head_q, r_head_k) # dec_len x dec_len x bsz x n_head 177 | BD = self._rel_shift(BD) 178 | 179 | attn_score = AC + BD 180 | attn_score.mul_(self.scale) 181 | 182 | # causal masking mechanism 183 | attn_mask = attn_mask == 0 # switch to bool 184 | attn_score = attn_score.float().masked_fill(attn_mask, -1e30).type_as(attn_score) 185 | attn_prob = torch.softmax(attn_score, dim=1) 186 | attn_prob = self.drop(attn_prob) 187 | 188 | attn_vec = torch.einsum("ijbn,jbnd->ibnd", attn_prob, w_head_v) 189 | attn_vec = attn_vec.contiguous().view(dec_len, bsz, self.d_model) 190 | 191 | attn_out = self.mask_attn_o_net(attn_vec) 192 | attn_out = self.drop(attn_out) 193 | 194 | mask_attn_output = self.layer_norm_mask_attn(w + attn_out) 195 | 196 | # 与编码器交互部分 197 | inter_k, inter_v = torch.chunk(self.interaction_kv_net(enc_context), 2, dim=-1) 198 | inter_q = self.interaction_q_net(mask_attn_output) 199 | inter_q = inter_q.view(dec_len, bsz, self.n_heads, self.d_head) 200 | inter_k = inter_k.view(enc_len, bsz, self.n_heads, self.d_head) 201 | inter_v = inter_v.view(enc_len, bsz, self.n_heads, self.d_head) 202 | 203 | attn_score = torch.einsum("qbnd,kbnd->qkbn", inter_q, inter_k) 204 | attn_score.mul_(self.scale) 205 | 206 | # use padding_mask to mask input [PAD] token 207 | padding_mask = padding_mask[None, :, :, None].repeat(dec_len, 1, 1, 1) 208 | attn_score = attn_score + (1 - padding_mask) * (-1e30) 209 | attn_prob = torch.softmax(attn_score, dim=1) 210 | attn_prob = self.drop(attn_prob) 211 | attn_vec = torch.einsum("ijbn,jbnd->ibnd", attn_prob, inter_v) 212 | attn_vec = attn_vec.contiguous().view(dec_len, bsz, self.d_model) 213 | 214 | attn_out = self.interaction_o_net(attn_vec) 215 | attn_out = self.drop(attn_out) 216 | 217 | interaction_output = self.layer_norm_interaction(attn_out + mask_attn_output) 218 | return interaction_output 219 | 220 | 221 | class RelPartialLearnableDecoderLayer(torch.nn.Module): 222 | 223 | def _forward_unimplemented(self, *input: Any) -> None: 224 | pass 225 | 226 | def __init__(self, n_heads, d_model, d_inner): 227 | super().__init__() 228 | 229 | self.dec_attn = RelPartialLearnableMultiHeadAttn(n_heads=n_heads, d_model=d_model) 230 | self.ffn_layer = PositionwiseFFN(d_model=d_model, d_inner=d_inner) 231 | 232 | def forward(self, dec_inp, r, enc_inp, dec_mask, enc_mask): 233 | attn_output = self.dec_attn(w=dec_inp, r=r, enc_context=enc_inp, attn_mask=dec_mask, padding_mask=enc_mask) 234 | ffn_out = self.ffn_layer(attn_output) 235 | return ffn_out 236 | 237 | 238 | class XLDecoder(torch.nn.Module): 239 | 240 | def _forward_unimplemented(self, *input: Any) -> None: 241 | pass 242 | 243 | def __init__(self, 
dim, embedding_matrix: nn.Embedding, seq_len): 244 | super().__init__() 245 | self.d_model = dim 246 | self.word_emb = embedding_matrix 247 | self.seq_len = seq_len 248 | self.n_layers = args["decoder_layers"] 249 | self.n_heads = 16 250 | self.ffn = 4 * dim 251 | self.epsilon = 1e-6 252 | 253 | self.drop = nn.Dropout(p=args["dropout"]) 254 | self.pos_emb = XLRelPosEmb(d_embed=dim) 255 | self.layers = nn.ModuleList() 256 | 257 | self.layers = nn.ModuleList() 258 | for i in range(self.n_layers): 259 | self.layers.append( 260 | RelPartialLearnableDecoderLayer( 261 | n_heads=self.n_heads, d_model=self.d_model, d_inner=self.ffn 262 | ) 263 | ) 264 | self.output = nn.Linear(in_features=dim, out_features=dim) 265 | self.copy_output = nn.Linear(in_features=dim, out_features=dim) 266 | # 自适应的解码概率结合 267 | self.mode_select = nn.Sequential( 268 | nn.Linear(in_features=dim, out_features=1), 269 | nn.Sigmoid() 270 | ) 271 | 272 | def forward(self, input_ids, encoder_rep, input_mask, decode_input, decode_target, use_beam_search, beam_width): 273 | bsz = input_ids.size(0) 274 | if decode_input is not None: # 代表训练模式 275 | input_ids = input_ids[:, None, :].repeat(1, self.seq_len, 1) 276 | decode_embed = self.drop(self.word_emb(decode_input)) 277 | all_ones = decode_embed.new_ones((self.seq_len, self.seq_len), dtype=torch.uint8) 278 | dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None] 279 | pos_seq = torch.arange(self.seq_len - 1, -1, -1.0, device=device, dtype=decode_embed.dtype) 280 | pos_embed = self.drop(self.pos_emb(pos_seq)) 281 | core_out = decode_embed.transpose(0, 1).contiguous() 282 | enc_rep_t = encoder_rep.transpose(0, 1).contiguous() 283 | enc_mask_t = input_mask.transpose(0, 1).contiguous() 284 | for layer in self.layers: 285 | core_out = layer( 286 | dec_inp=core_out, r=pos_embed, enc_inp=enc_rep_t, 287 | dec_mask=dec_attn_mask, enc_mask=enc_mask_t 288 | ) 289 | core_out = self.drop(core_out.transpose(0, 1).contiguous()) # (bsz, dec_len, dim) 290 | output = self.output(core_out) 291 | vocab_logits = torch.nn.functional.linear(input=output, weight=self.word_emb.weight) 292 | vocab_prob = torch.softmax(vocab_logits, dim=-1) 293 | input_logits = torch.einsum("bid,bjd->bij", self.copy_output(core_out), encoder_rep) # (bsz, dec_len, enc_len) 294 | input_logits = input_logits + (1.0 - input_mask[:, None, :].repeat(1, self.seq_len, 1)) * (-1e30) 295 | input_prob = torch.softmax(input_logits, dim=-1) # (bsz, dec_len, enc_len) 296 | mode_sig = self.mode_select(core_out) # (bsz, dec_len, 1) 297 | vocab_prob = vocab_prob * mode_sig 298 | vocab_prob = torch.scatter_add(vocab_prob, dim=2, index=input_ids, src=(1 - mode_sig) * input_prob) 299 | vocab_prob = vocab_prob.view(-1, args["vocab_size"]) 300 | decode_target = decode_target.view(-1) 301 | predict = torch.gather(vocab_prob, dim=1, index=decode_target[:, None]).squeeze(dim=-1) 302 | init_loss = -torch.log(predict + self.epsilon) 303 | init_loss *= (decode_target != 0).float() 304 | loss = torch.sum(init_loss) / torch.nonzero(decode_target != 0, as_tuple=False).size(0) 305 | # 为了并行化设计, 将loss变成(bsz,) 306 | return loss[None].repeat(bsz) 307 | else: # 代表验证或者测试解码模式 ==> 比较耗时 308 | dec_list = [] 309 | decode_ids = torch.full(size=(bsz, 1), fill_value=args["start_token_id"], dtype=torch.int32).long().to(device) 310 | for i in range(1, self.seq_len + 1): 311 | if i > 1: 312 | decode_ids = torch.cat([decode_ids, dec_list[i - 2]], dim=-1) 313 | decode_embed = self.word_emb(decode_ids) 314 | all_ones = decode_embed.new_ones((i, i), 
dtype=torch.uint8) 315 | dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None] 316 | pos_seq = torch.arange(i - 1, -1, -1.0, device=device, dtype=decode_embed.dtype) 317 | pos_embed = self.pos_emb(pos_seq) 318 | core_out = decode_embed.transpose(0, 1).contiguous() 319 | enc_rep_t = encoder_rep.transpose(0, 1).contiguous() 320 | enc_mask_t = input_mask.transpose(0, 1).contiguous() 321 | for layer in self.layers: 322 | core_out = layer( 323 | dec_inp=core_out, r=pos_embed, enc_inp=enc_rep_t, 324 | dec_mask=dec_attn_mask, enc_mask=enc_mask_t 325 | ) 326 | core_out = core_out.transpose(0, 1).contiguous()[:, -1, :] 327 | output = self.output(core_out) 328 | vocab_logits = torch.nn.functional.linear(input=output, weight=self.word_emb.weight) 329 | vocab_prob = torch.softmax(vocab_logits, dim=-1) 330 | input_logits = torch.einsum("bd,bjd->bj", self.copy_output(core_out), encoder_rep) # (bsz, enc_len) 331 | input_logits = input_logits + (1.0 - input_mask) * (-1e30) 332 | input_prob = torch.softmax(input_logits, dim=-1) # (bsz, enc_len) 333 | mode_sig = self.mode_select(core_out) # (bsz, 1) 334 | vocab_prob = vocab_prob * mode_sig 335 | vocab_prob = torch.scatter_add(vocab_prob, dim=1, index=input_ids, src=(1 - mode_sig) * input_prob) 336 | dec_list.append(torch.argmax(vocab_prob, dim=-1)[:, None]) 337 | return torch.cat(dec_list, dim=-1) 338 | 339 | 340 | class MyModel(torch.nn.Module): 341 | def _forward_unimplemented(self, *input: Any) -> None: 342 | pass 343 | 344 | def __init__(self, pre_train_dir: str): 345 | super().__init__() 346 | self.roberta_encoder = BertModel(config=BertConfig.from_json_file(pre_train_dir + "config.json")) 347 | self.decoder_layer = XLDecoder( 348 | dim=args["dimension"], embedding_matrix=self.roberta_encoder.get_input_embeddings(), 349 | seq_len=args["max_dec_len"]) 350 | 351 | def forward(self, input_ids, input_mask, input_seg, decode_input=None, decode_target=None): 352 | encoder_rep = self.roberta_encoder(input_ids, input_mask, input_seg)[0] 353 | return self.decoder_layer(input_ids, encoder_rep, input_mask, decode_input, decode_target, 354 | args["use_beam_search"], 355 | args["beam_width"]) 356 | 357 | 358 | class WarmUp_LinearDecay: 359 | def __init__(self, optimizer: optim.AdamW, init_rate, warm_up_steps, decay_steps, min_lr_rate): 360 | self.optimizer = optimizer 361 | self.init_rate = init_rate 362 | self.warm_up_steps = warm_up_steps 363 | self.decay_steps = decay_steps 364 | self.min_lr_rate = min_lr_rate 365 | self.optimizer_step = 0 366 | 367 | def step(self): 368 | self.optimizer_step += 1 369 | if self.optimizer_step <= self.warm_up_steps: 370 | rate = (self.optimizer_step / self.warm_up_steps) * self.init_rate 371 | elif self.warm_up_steps < self.optimizer_step <= (self.warm_up_steps + self.decay_steps): 372 | rate = (1.0 - ((self.optimizer_step - self.warm_up_steps) / self.decay_steps)) * self.init_rate 373 | else: 374 | rate = self.min_lr_rate 375 | for p in self.optimizer.param_groups: 376 | p["lr"] = rate 377 | self.optimizer.step() 378 | 379 | 380 | class Main(object): 381 | def __init__(self, train_loader): 382 | self.train_loader = train_loader 383 | self.model = MyModel(pre_train_dir=args["pre_train_dir"]) 384 | 385 | self.model.load_state_dict(torch.load(args["load_path"], map_location=device), strict=False) 386 | param_optimizer = list(self.model.named_parameters()) 387 | no_decay = ['bias', 'gamma', 'beta'] 388 | optimizer_grouped_parameters = [ 389 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in 
no_decay)], 390 | 'weight_decay_rate': args["weight_decay"]}, 391 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 392 | 'weight_decay_rate': 0.0} 393 | ] 394 | 395 | self.optimizer = optim.AdamW(params=optimizer_grouped_parameters, lr=args["init_lr"]) 396 | self.schedule = WarmUp_LinearDecay(optimizer=self.optimizer, init_rate=args["init_lr"], 397 | warm_up_steps=args["warm_up_steps"], 398 | decay_steps=args["lr_decay_steps"], min_lr_rate=args["min_lr_rate"]) 399 | self.model.to(device=device) 400 | self.model = nn.parallel.DistributedDataParallel(module=self.model, dim=0, find_unused_parameters=True) 401 | 402 | def train(self): 403 | self.model.train() 404 | steps = 0 405 | while True: 406 | for item in self.train_loader: 407 | input_ids, input_mask, input_seg, decode_input, decode_target = \ 408 | item["input_ids"], item["input_mask"], item["input_seg"], item["decode_input"], item[ 409 | "decode_target"] 410 | self.optimizer.zero_grad() 411 | loss = self.model( 412 | input_ids=input_ids.to(device), 413 | input_mask=input_mask.to(device), 414 | input_seg=input_seg.to(device), 415 | decode_input=decode_input.to(device), 416 | decode_target=decode_target.to(device) 417 | ) 418 | loss = loss.float().mean().type_as(loss) 419 | loss.backward() 420 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=args["clip_norm"]) 421 | self.schedule.step() 422 | steps += 1 423 | writer.add_scalar("loss", loss.item(), global_step=steps) 424 | if steps % args["save_interval"] == 0: 425 | torch.save(self.model.module.state_dict(), f=args["save_path"]) 426 | if steps >= args["max_steps"]: 427 | break 428 | if steps >= args["max_steps"]: 429 | break 430 | writer.flush() 431 | writer.close() 432 | 433 | 434 | if __name__ == "__main__": 435 | device = "cuda" 436 | args = { 437 | "init_lr": 2e-5, 438 | "batch_size": 24, 439 | "mos": 2, 440 | "weight_decay": 0.01, 441 | "warm_up_steps": 1000, 442 | "lr_decay_steps": 9000, 443 | "max_steps": 10000, 444 | "min_lr_rate": 1e-9, 445 | "save_interval": 1000, 446 | "save_path": "ModelStorage/xl_dureader_drmc.pth", 447 | "load_path": "ModelStorage/xl_dureader.pth", 448 | "pre_train_dir": "/home/ldmc/quanlin/Pretrained_NLP_Models/Pytorch/RoBERTa_Large_ZH/", 449 | "clip_norm": 0.25, 450 | "start_token": "[unused1]", 451 | "end_token": "[unused2]", 452 | "start_token_id": 1, 453 | "end_token_id": 2, 454 | "dimension": 1024, 455 | "max_enc_len": 512, 456 | "max_dec_len": 50, 457 | "max_answer_len": 100, 458 | "use_beam_search": False, 459 | "beam_width": 5, 460 | "decoder_layers": 3, 461 | "dropout": 0.1, 462 | "vocab_size": 21128, 463 | "init_range": 0.02, 464 | "init_std": 0.02 465 | } 466 | 467 | with open("DataSet/multi_task.pkl", "rb") as f: 468 | x = pickle.load(f) 469 | 470 | tokenizer = BertTokenizer(vocab_file="/home/ldmc/quanlin/Pretrained_NLP_Models/Pytorch/RoBERTa_Large_ZH/vocab.txt") 471 | 472 | if sys.argv[1] == "train": 473 | torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1, init_method='tcp://localhost:7001') 474 | writer = SummaryWriter(logdir="RunLog/Multi-DRMC") 475 | train_dataset = MyDataset(data=x["drcd_train_items"] + x["cmrc_train_items"], max_enc_len=args["max_enc_len"], 476 | max_dec_len=args["max_dec_len"]) 477 | train_loader = DataLoader(train_dataset, batch_size=args["batch_size"], shuffle=True, num_workers=4) 478 | 479 | m = Main(train_loader) 480 | m.train() 481 | else: 482 | print("Invalid args.") 483 | 
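
# A minimal, self-contained sketch (toy sizes, illustrative tensor names; never called)
# of the gated copy mechanism used in XLDecoder.forward above: a sigmoid gate blends the
# generator distribution over the vocabulary with a copy distribution over the encoder
# positions, scattered back onto vocabulary ids via torch.scatter_add.
def _copy_mechanism_demo():
    bsz, enc_len, vocab = 2, 4, 10
    input_ids = torch.randint(0, vocab, (bsz, enc_len))           # encoder token ids
    vocab_prob = torch.softmax(torch.randn(bsz, vocab), dim=-1)   # "generate" distribution
    copy_prob = torch.softmax(torch.randn(bsz, enc_len), dim=-1)  # attention over encoder positions
    gate = torch.sigmoid(torch.randn(bsz, 1))                     # probability of generating vs. copying
    mixed = torch.scatter_add(vocab_prob * gate, dim=1,
                              index=input_ids, src=copy_prob * (1 - gate))
    return mixed  # rows still sum to 1, with extra mass on tokens present in the input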
-------------------------------------------------------------------------------- /GRUIRMoS.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Time: 2020/9/14 5 | @Author: menghuanlater 6 | @Software: Pycharm 2019.2 7 | @Usage: data preprocess 8 | ----------------------------- 9 | Description: Base on RoBERTa and GRU 10 | -- 增加输入token增强机制(输入的token在解码时具有更高的接受概率) copy mechanism 11 | ----------------------------- 12 | """ 13 | import os 14 | import sys 15 | 16 | os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[2][4:] 17 | from typing import Any 18 | 19 | from transformers import BertTokenizer, BertModel 20 | import torch 21 | from torch import nn 22 | import pickle 23 | from torch.utils.data import DataLoader, Dataset 24 | from torch import optim 25 | import numpy as np 26 | import json 27 | from tensorboardX import SummaryWriter 28 | 29 | 30 | class MyDataset(Dataset): 31 | def __init__(self, data, max_enc_len, max_dec_len): 32 | self.data = data 33 | self.max_encode_len = max_enc_len 34 | self.max_decode_len = max_dec_len 35 | self.SEG_A = 0 36 | self.SEG_P = 1 37 | self.ID_PAD = 0 38 | 39 | def __len__(self): 40 | return len(self.data) 41 | 42 | def __getitem__(self, index): 43 | item = self.data[index] 44 | context, query, answer = item["context"], item["query"], item["answer"] 45 | context_tokens = tokenizer.tokenize(context) 46 | query_tokens = tokenizer.tokenize(query) 47 | answer_tokens = tokenizer.tokenize(answer)[:args["max_answer_len"]] 48 | 49 | c = ["[CLS]"] + answer_tokens + ["[SEP]"] + context_tokens 50 | if len(c) > self.max_encode_len - 1: 51 | c = c[:self.max_encode_len - 1] 52 | c += ["[SEP]"] 53 | input_ids = tokenizer.convert_tokens_to_ids(c) 54 | input_mask = [1.0] * len(input_ids) 55 | input_seg = [self.SEG_A] * (len(answer_tokens) + 2) + [self.SEG_P] * (len(input_ids) - 2 - len(answer_tokens)) 56 | extra = self.max_encode_len - len(input_ids) 57 | if extra > 0: 58 | input_ids += [self.ID_PAD] * extra 59 | input_mask += [0.0] * extra 60 | input_seg += [self.SEG_P] * extra 61 | if len(query_tokens) > self.max_decode_len - 1: 62 | query_tokens = query_tokens[: self.max_decode_len - 1] 63 | c = tokenizer.convert_tokens_to_ids(query_tokens) 64 | dec_input = [args["start_token_id"]] + c 65 | dec_target = c + [args["end_token_id"]] 66 | extra = self.max_decode_len - len(dec_input) 67 | if extra > 0: 68 | dec_input += [self.ID_PAD] * extra 69 | dec_target += [self.ID_PAD] * extra 70 | return { 71 | "input_ids": torch.tensor(input_ids).long(), "input_mask": torch.tensor(input_mask).float(), 72 | "input_seg": torch.tensor(input_seg).long(), "decode_input": torch.tensor(dec_input).long(), 73 | "decode_target": torch.tensor(dec_target).long(), "label": query 74 | } 75 | 76 | 77 | class GRUAttnDecoder(torch.nn.Module): 78 | 79 | def _forward_unimplemented(self, *input: Any) -> None: 80 | pass 81 | 82 | def __init__(self, dim, embedding_matrix: nn.Embedding, seq_len): 83 | # 为了保持一致性, context_vector input_vector 以及 hidden_vector保持相同维度 84 | # 同时为了减少参数, 注意力机制采取点积缩放形式 85 | super().__init__() 86 | self.embedding_matrix = embedding_matrix 87 | self.seq_len = seq_len # 解码长度 88 | self.scale = 1 / np.sqrt(dim) 89 | self.reset_gate = nn.Sequential( 90 | nn.Linear(in_features=3 * dim, out_features=dim), 91 | nn.Sigmoid() 92 | ) 93 | self.update_gate = nn.Sequential( 94 | nn.Linear(in_features=3 * dim, out_features=dim), 95 | nn.Sigmoid() 96 | ) 97 | self.update = nn.Sequential( 98 | 
nn.Linear(in_features=3 * dim, out_features=dim), 99 | nn.Tanh() 100 | ) 101 | # self.pi_mos = nn.Sequential( 102 | # nn.Linear(in_features=dim, out_features=args["mos"]), 103 | # nn.Softmax() 104 | # ) 105 | # self.output = nn.ModuleList() 106 | # for i in range(args["mos"]): 107 | # self.output.append(nn.Linear(in_features=dim, out_features=dim)) 108 | self.output = nn.Linear(in_features=dim, out_features=dim) 109 | self.copy_output = nn.Linear(in_features=dim, out_features=dim) 110 | self.init_hidden_unit = nn.Parameter(torch.FloatTensor(1, dim)) # 状态初始值 111 | 112 | # 自适应的概率结合 113 | self.mode_select = nn.Sequential( 114 | nn.Linear(in_features=dim, out_features=1), 115 | nn.Sigmoid() 116 | ) 117 | self.epsilon = 1e-6 118 | 119 | def forward(self, input_ids, input_context, context_mask, decode_input, decode_target, use_beam_search, beam_width): 120 | """ 121 | :param input_ids: 用于解码增强的输入ids序列 122 | :param input_context: 编码的context (bsz, enc_seq, dim) 123 | :param context_mask: 沿用encoder部分的input_mask, 将pad的输入忽略 124 | :param decode_input: 解码输入 ==> 训练时才有 125 | :param decode_target: 解码目标 ==> 训练时才有, 测试时为空 126 | :param use_beam_search: 是否启动beam search解码 127 | :param beam_width: beam宽度 128 | :return: 训练时返回损失, 测试时返回解码序列 129 | """ 130 | bsz = input_context.size(0) 131 | net_state = self.init_hidden_unit.repeat(bsz, 1) 132 | if decode_target is not None: 133 | dec_list = [] 134 | decode_emb = self.embedding_matrix(decode_input) # 作为输入的一部分(bsz, dec_seq, dim) 135 | for i in range(self.seq_len): 136 | # step1: 通过注意力机制获取当前的context_rep 137 | attn_score = torch.einsum("bsd,bd->bs", input_context, net_state) 138 | attn_score.mul_(self.scale) 139 | attn_score += (1.0 - context_mask) * (-1e30) 140 | attn_prob = torch.softmax(attn_score, dim=-1) 141 | attn_vec = torch.einsum("bs,bsd->bd", attn_prob, input_context) 142 | # step2: 更新状态 143 | x = torch.cat([attn_vec, decode_emb[:, i, :], net_state], dim=-1) 144 | reset_sig = self.reset_gate(x) 145 | update_sig = self.update_gate(x) 146 | update_value = self.update(torch.cat([attn_vec, decode_emb[:, i, :], reset_sig * net_state], dim=-1)) 147 | net_state = (1 - update_sig) * net_state + update_sig * update_value 148 | # step3: 计算分布概率--> mos 149 | vocab_prob_list = [] 150 | # pi_k = self.pi_mos(net_state) 151 | # for k in range(args["mos"]): 152 | # output = self.output[k](net_state) 153 | # vocab_logits = torch.nn.functional.linear(input=output, weight=self.embedding_matrix.weight) 154 | # vocab_prob_list.append(torch.softmax(vocab_logits, dim=-1)[..., None]) 155 | # vocab_prob = torch.einsum("bk,bvk->bv", pi_k, torch.cat(vocab_prob_list, dim=-1)) 156 | output = self.output(net_state) 157 | vocab_logits = torch.nn.functional.linear(input=output, weight=self.embedding_matrix.weight) 158 | vocab_prob = torch.softmax(vocab_logits, dim=-1) 159 | input_logits = torch.einsum("bd,bsd->bs", self.copy_output(net_state), input_context) 160 | input_logits += (1.0 - context_mask) * (-1e30) 161 | input_prob = torch.softmax(input_logits, dim=-1) # (bsz, enc_seq) 162 | # step4: 根据mode_sig混合两个概率 163 | mode_sig = self.mode_select(net_state) 164 | vocab_prob = vocab_prob * mode_sig 165 | vocab_prob = torch.scatter_add(vocab_prob, dim=1, index=input_ids, src=input_prob * (1 - mode_sig)) 166 | dec_list.append(vocab_prob[:, None, :]) 167 | # 计算损失 168 | predict = torch.cat(dec_list, dim=1) # (bsz, dec_seq, vocab) 169 | predict = predict.view(size=(-1, predict.size(-1))) 170 | decode_target = decode_target.view(size=(-1,)) 171 | predict = torch.gather(predict, dim=1, 
index=decode_target[:, None]).squeeze(dim=-1) 172 | init_loss = -torch.log(predict + self.epsilon) 173 | init_loss *= (decode_target != 0).float() 174 | loss = torch.sum(init_loss) / torch.nonzero(decode_target != 0, as_tuple=False).size(0) 175 | return loss[None].repeat(bsz) 176 | else: 177 | if use_beam_search: 178 | pass 179 | else: # 贪婪式解码 180 | dec_list = [] 181 | for i in range(self.seq_len): 182 | # step1: 通过注意力机制获取当前的context_rep 183 | attn_score = torch.einsum("bsd,bd->bs", input_context, net_state) 184 | attn_score.mul_(self.scale) 185 | attn_score += (1.0 - context_mask) * (-1e30) 186 | attn_prob = torch.softmax(attn_score, dim=-1) 187 | attn_vec = torch.einsum("bs,bsd->bd", attn_prob, input_context) 188 | # step2: 更新状态 189 | if i == 0: 190 | emb = self.embedding_matrix( 191 | torch.full(size=(bsz,), fill_value=args["start_token_id"], dtype=torch.int32).long().to(device)) 192 | else: 193 | emb = self.embedding_matrix(dec_list[i - 1].squeeze(dim=-1)) 194 | x = torch.cat([attn_vec, emb, net_state], dim=-1) 195 | reset_sig = self.reset_gate(x) 196 | update_sig = self.update_gate(x) 197 | update_value = self.update(torch.cat([attn_vec, emb, reset_sig * net_state], dim=-1)) 198 | net_state = (1 - update_sig) * net_state + update_sig * update_value 199 | # step3: 计算分布得分 200 | vocab_prob_list = [] 201 | # pi_k = self.pi_mos(net_state) 202 | # for k in range(args["mos"]): 203 | # output = self.output[k](net_state) 204 | # vocab_logits = torch.nn.functional.linear(input=output, weight=self.embedding_matrix.weight) 205 | # vocab_prob_list.append(torch.softmax(vocab_logits, dim=-1)[..., None]) 206 | # vocab_prob = torch.einsum("bk,bvk->bv", pi_k, torch.cat(vocab_prob_list, dim=-1)) 207 | output = self.output(net_state) 208 | vocab_logits = torch.nn.functional.linear(input=output, weight=self.embedding_matrix.weight) 209 | vocab_prob = torch.softmax(vocab_logits, dim=-1) 210 | input_logits = torch.einsum("bd,bsd->bs", self.copy_output(net_state), input_context) 211 | input_logits += (1.0 - context_mask) * (-1e30) 212 | input_prob = torch.softmax(input_logits, dim=-1) # (bsz, enc_seq) 213 | # step4: 根据mode_sig混合两个概率 214 | mode_sig = self.mode_select(net_state) 215 | vocab_prob = vocab_prob * mode_sig 216 | vocab_prob = torch.scatter_add(vocab_prob, dim=1, index=input_ids, src=input_prob * (1 - mode_sig)) 217 | dec_list.append(torch.argmax(vocab_prob, dim=-1)[:, None]) 218 | return torch.cat(dec_list, dim=-1) 219 | 220 | 221 | class MyModel(torch.nn.Module): 222 | def _forward_unimplemented(self, *input: Any) -> None: 223 | pass 224 | 225 | def __init__(self, pre_train_dir: str): 226 | super().__init__() 227 | self.roberta_encoder = BertModel.from_pretrained(pre_train_dir) 228 | self.decoder_cell = GRUAttnDecoder(dim=args["dimension"], 229 | embedding_matrix=self.roberta_encoder.get_input_embeddings(), 230 | seq_len=args["max_dec_len"]) 231 | if args["freeze_roberta"]: 232 | for p in self.roberta_encoder.parameters(): 233 | p.requires_grad = False 234 | 235 | def forward(self, input_ids, input_mask, input_seg, decode_input=None, decode_target=None): 236 | encoder_rep = self.roberta_encoder(input_ids, input_mask, input_seg)[0] 237 | return self.decoder_cell(input_ids, encoder_rep, input_mask, decode_input, decode_target, args["use_beam_search"], 238 | args["beam_width"]) 239 | 240 | 241 | class WarmUp_LinearDecay: 242 | def __init__(self, optimizer: optim.AdamW, init_rate, warm_up_steps, decay_steps, min_lr_rate): 243 | self.optimizer = optimizer 244 | self.init_rate = init_rate 245 | 
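        # Schedule: the learning rate rises linearly to init_rate over warm_up_steps,
        # then decays linearly towards zero over decay_steps, and is finally held at
        # min_lr_rate (see step() below).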
self.warm_up_steps = warm_up_steps 246 | self.decay_steps = decay_steps 247 | self.min_lr_rate = min_lr_rate 248 | self.optimizer_step = 0 249 | 250 | def step(self): 251 | self.optimizer_step += 1 252 | if self.optimizer_step <= self.warm_up_steps: 253 | rate = (self.optimizer_step / self.warm_up_steps) * self.init_rate 254 | elif self.warm_up_steps < self.optimizer_step <= (self.warm_up_steps + self.decay_steps): 255 | rate = (1.0 - ((self.optimizer_step - self.warm_up_steps) / self.decay_steps)) * self.init_rate 256 | else: 257 | rate = self.min_lr_rate 258 | for p in self.optimizer.param_groups: 259 | p["lr"] = rate 260 | self.optimizer.step() 261 | 262 | 263 | class Main(object): 264 | def __init__(self, train_loader, valid_loader, test_flag=False, test_items=None): 265 | self.train_loader = train_loader 266 | self.valid_loader = valid_loader 267 | self.test_items = test_items 268 | self.model = MyModel(pre_train_dir=args["pre_train_dir"]) 269 | 270 | if test_flag: 271 | self.model.load_state_dict(torch.load(args["save_path"], map_location=device), strict=False) 272 | else: 273 | if args["warm_start"]: 274 | self.model.load_state_dict(torch.load(args["save_path"], map_location=device), strict=False) 275 | param_optimizer = list(self.model.named_parameters()) 276 | no_decay = ['bias', 'gamma', 'beta'] 277 | optimizer_grouped_parameters = [ 278 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 279 | 'weight_decay_rate': args["weight_decay"]}, 280 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 281 | 'weight_decay_rate': 0.0} 282 | ] 283 | 284 | self.optimizer = optim.AdamW(params=optimizer_grouped_parameters, lr=args["init_lr"]) 285 | self.schedule = WarmUp_LinearDecay(optimizer=self.optimizer, init_rate=args["init_lr"], 286 | warm_up_steps=args["warm_up_steps"], 287 | decay_steps=args["lr_decay_steps"], min_lr_rate=args["min_lr_rate"]) 288 | self.model.to(device=device) 289 | if args["is_train"]: 290 | self.model = nn.parallel.DistributedDataParallel(module=self.model, dim=0, find_unused_parameters=True) 291 | 292 | def train(self): 293 | best_rl = 0.0 294 | self.model.train() 295 | steps = 0 296 | while True: 297 | for item in self.train_loader: 298 | input_ids, input_mask, input_seg, decode_input, decode_target = \ 299 | item["input_ids"], item["input_mask"], item["input_seg"], item["decode_input"], item["decode_target"] 300 | self.optimizer.zero_grad() 301 | loss = self.model( 302 | input_ids=input_ids.to(device), 303 | input_mask=input_mask.to(device), 304 | input_seg=input_seg.to(device), 305 | decode_input=decode_input.to(device), 306 | decode_target=decode_target.to(device) 307 | ) 308 | loss = loss.float().mean().type_as(loss) 309 | loss.backward() 310 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=args["clip_norm"]) 311 | self.schedule.step() 312 | steps += 1 313 | writer.add_scalar("loss", loss.item(), global_step=steps) 314 | if steps % args["eval_interval"] == 0: 315 | rl = self.valid() 316 | writer.add_scalar("valid_rl", rl, global_step=steps) 317 | if rl > best_rl: 318 | best_rl = rl 319 | torch.save(self.model.module.state_dict(), f=args["save_path"]) 320 | if steps >= args["max_steps"]: 321 | break 322 | if steps >= args["max_steps"]: 323 | break 324 | writer.flush() 325 | writer.close() 326 | 327 | def valid(self): 328 | self.model.eval() 329 | rouge_l = [] 330 | with torch.no_grad(): 331 | for item in self.valid_loader: 332 | input_ids, input_mask, input_seg, label = 
item["input_ids"], item["input_mask"], item["input_seg"], item["label"] 333 | dec_seq = self.model( 334 | input_ids=input_ids.to(device), 335 | input_mask=input_mask.to(device), 336 | input_seg=input_seg.to(device) 337 | ) 338 | dec_seq = dec_seq.cpu().numpy() 339 | for i in range(len(dec_seq)): 340 | x = dec_seq[i] 341 | s = [] 342 | for j in x: 343 | if int(j) == args["end_token_id"]: 344 | break 345 | else: 346 | s.append(int(j)) 347 | s = "".join(tokenizer.convert_ids_to_tokens(s)) 348 | s = s.replace(",", "").replace("[UNK]", "") 349 | char_lis = [] 350 | for c in s: 351 | if c not in char_lis: 352 | char_lis.append(c) 353 | for c in char_lis: 354 | try: 355 | p = re.compile("(%s){2,}" % c) 356 | s = re.sub(p, c, s) 357 | except Exception as e: 358 | continue 359 | rouge_l.append(self.rouge_l(hypo=s, refer=label[i])) 360 | self.model.train() 361 | return np.average(rouge_l) 362 | 363 | @staticmethod 364 | def test_encode(context, answer): 365 | context_tokens = tokenizer.tokenize(context) 366 | answer_tokens = tokenizer.tokenize(answer)[:args["max_answer_len"]] 367 | c = ["[CLS]"] + answer_tokens + ["[SEP]"] + context_tokens 368 | if len(c) > args["max_enc_len"] - 1: 369 | c = c[:args["max_enc_len"] - 1] 370 | c += ["[SEP]"] 371 | input_ids = tokenizer.convert_tokens_to_ids(c) 372 | input_mask = [1.0] * len(input_ids) 373 | input_seg = [0] * (len(answer_tokens) + 2) + [1] * (len(input_ids) - 2 - len(answer_tokens)) 374 | extra = args["max_enc_len"] - len(input_ids) 375 | if extra > 0: 376 | input_ids += [0] * extra 377 | input_mask += [0.0] * extra 378 | input_seg += [1] * extra 379 | return { 380 | "input_ids": torch.tensor(input_ids).long().unsqueeze(dim=0).to(device), 381 | "input_mask": torch.tensor(input_mask).float().unsqueeze(dim=0).to(device), 382 | "input_seg": torch.tensor(input_seg).long().unsqueeze(dim=0).to(device) 383 | } 384 | 385 | def test(self): 386 | output = x["test_items"] 387 | for i in range(len(output)): 388 | text = output[i]["text"] 389 | annotations = output[i]["annotations"] 390 | tmp_enc_ids, tmp_enc_mask, tmp_enc_seg = [], [], [] 391 | for j in range(len(annotations)): 392 | y = self.test_encode(text, annotations[j]["A"]) 393 | tmp_enc_ids.append(y["input_ids"]) 394 | tmp_enc_mask.append(y["input_mask"]) 395 | tmp_enc_seg.append(y["input_seg"]) 396 | dec_seq = self.model( 397 | input_ids=torch.cat(tmp_enc_ids, dim=0), 398 | input_mask=torch.cat(tmp_enc_mask, dim=0), 399 | input_seg=torch.cat(tmp_enc_seg, dim=0) 400 | ) 401 | dec_seq = dec_seq.cpu().numpy() 402 | for j in range(len(dec_seq)): 403 | y = dec_seq[j] 404 | s = [] 405 | for k in y: 406 | if int(k) == args["end_token_id"]: 407 | break 408 | else: 409 | s.append(int(k)) 410 | s = "".join(tokenizer.convert_ids_to_tokens(s)) 411 | s = s.replace(",", "").replace("[UNK]", "") 412 | char_lis = [] 413 | for c in s: 414 | if c not in char_lis: 415 | char_lis.append(c) 416 | for c in char_lis: 417 | try: 418 | p = re.compile("(%s){2,}" % c) 419 | s = re.sub(p, c, s) 420 | except Exception as e: 421 | continue 422 | annotations[j]["Q"] = s 423 | if i % 50 == 0 and i > 0: 424 | print("The program has completed %s predictions" % i) 425 | with open("submit.json", "w", encoding="UTF-8") as f: 426 | json.dump(output, f, ensure_ascii=False) 427 | print("The program has completed all predictions") 428 | 429 | @staticmethod 430 | def rouge_l(hypo, refer): 431 | if len(hypo) == 0 or len(refer) == 0: 432 | return 0 433 | x = [[0 for _ in range(len(refer) + 1)] for _ in range(len(hypo) + 1)] 434 | lcs = 0 435 | for 
i in range(len(hypo)): 436 | for j in range(len(refer)): 437 | if hypo[i] == refer[j]: 438 | x[i + 1][j + 1] = x[i][j] + 1 439 | if x[i + 1][j + 1] > lcs: 440 | lcs = x[i + 1][j + 1] 441 | else: 442 | x[i + 1][j + 1] = max(x[i + 1][j], x[i][j + 1]) 443 | p, r = lcs / len(hypo), lcs / len(refer) 444 | if (p + r) == 0: 445 | return 0 446 | else: 447 | return (2 * p * r) / (p + r) 448 | 449 | 450 | if __name__ == "__main__": 451 | device = "cuda" 452 | 453 | args = { 454 | "init_lr": 2e-5, 455 | "batch_size": 10, 456 | "weight_decay": 0.01, 457 | "warm_up_steps": 1000, 458 | "lr_decay_steps": 15000, 459 | "max_steps": 18000, 460 | "min_lr_rate": 1e-9, 461 | "eval_interval": 1000, 462 | "save_path": "ModelStorage/gru_ir.pth", 463 | "mos": 4, 464 | "pre_train_dir": "/home/ldmc/quanlin/Pretrained_NLP_Models/Pytorch/RoBERTa_Large_ZH/", 465 | "clip_norm": 0.25, 466 | "start_token": "[unused1]", 467 | "end_token": "[unused2]", 468 | "start_token_id": 1, 469 | "end_token_id": 2, 470 | "dimension": 1024, 471 | "max_enc_len": 512, 472 | "max_dec_len": 50, 473 | "max_answer_len": 100, 474 | "use_beam_search": False, 475 | "beam_width": 5, 476 | "warm_start": False, 477 | "freeze_roberta": False 478 | } 479 | 480 | with open("DataSet/baseline.pkl", "rb") as f: 481 | x = pickle.load(f) 482 | 483 | tokenizer = BertTokenizer(vocab_file="/home/ldmc/quanlin/Pretrained_NLP_Models/Pytorch/RoBERTa_Large_ZH/vocab.txt") 484 | 485 | if sys.argv[1] == "train": 486 | torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1, init_method='tcp://localhost:6666') 487 | args["is_train"] = True 488 | writer = SummaryWriter(logdir="RunLog/%s" % sys.argv[3]) 489 | train_dataset = MyDataset(data=x["train_items"], max_enc_len=args["max_enc_len"], 490 | max_dec_len=args["max_dec_len"]) 491 | valid_dataset = MyDataset(data=x["valid_items"], max_enc_len=args["max_enc_len"], 492 | max_dec_len=args["max_dec_len"]) 493 | 494 | train_loader = DataLoader(train_dataset, batch_size=args["batch_size"], shuffle=True, num_workers=4) 495 | valid_loader = DataLoader(valid_dataset, batch_size=args["batch_size"], shuffle=False, num_workers=4) 496 | 497 | m = Main(train_loader, valid_loader) 498 | m.train() 499 | else: 500 | args["is_train"] = False 501 | writer = None 502 | m = Main(None, None, test_flag=True, test_items=x["test_items"]) 503 | m.test() 504 | -------------------------------------------------------------------------------- /MultiTaskXLIR-DuReader.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Time: 2020/9/14 5 | @Author: menghuanlater 6 | @Software: Pycharm 2019.2 7 | @Usage: data preprocess 8 | ----------------------------- 9 | Description: Base on RoBERTa and Transformer-XL Decoder and Copy Mechanism 10 | Transformer Decoder采用Transformer-XL 11 | ----------------------------- 12 | """ 13 | import os 14 | import sys 15 | 16 | os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[2][4:] 17 | from typing import Any 18 | 19 | from transformers import BertTokenizer, BertModel 20 | import torch 21 | from torch import nn 22 | import pickle 23 | from torch.utils.data import DataLoader, Dataset 24 | from torch import optim 25 | import numpy as np 26 | import json 27 | from tensorboardX import SummaryWriter 28 | 29 | 30 | class MyDataset(Dataset): 31 | def __init__(self, data, max_enc_len, max_dec_len): 32 | self.data = data 33 | self.max_encode_len = max_enc_len 34 | self.max_decode_len = max_dec_len 35 | self.SEG_A = 0 36 | 
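        # Segment ids for the encoder input: SEG_A marks the "[CLS] answer [SEP]" span,
        # SEG_P marks the context tokens and any padding.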
self.SEG_P = 1 37 | self.ID_PAD = 0 38 | 39 | def __len__(self): 40 | return len(self.data) 41 | 42 | def __getitem__(self, index): 43 | item = self.data[index] 44 | context, query, answer = item["context"], item["query"], item["answer"] 45 | context_tokens = tokenizer.tokenize(context) 46 | query_tokens = tokenizer.tokenize(query) 47 | answer_tokens = tokenizer.tokenize(answer)[:args["max_answer_len"]] 48 | 49 | c = ["[CLS]"] + answer_tokens + ["[SEP]"] + context_tokens 50 | if len(c) > self.max_encode_len - 1: 51 | c = c[:self.max_encode_len - 1] 52 | c += ["[SEP]"] 53 | input_ids = tokenizer.convert_tokens_to_ids(c) 54 | input_mask = [1.0] * len(input_ids) 55 | input_seg = [self.SEG_A] * (len(answer_tokens) + 2) + [self.SEG_P] * (len(input_ids) - 2 - len(answer_tokens)) 56 | extra = self.max_encode_len - len(input_ids) 57 | if extra > 0: 58 | input_ids += [self.ID_PAD] * extra 59 | input_mask += [0.0] * extra 60 | input_seg += [self.SEG_P] * extra 61 | if len(query_tokens) > self.max_decode_len - 1: 62 | query_tokens = query_tokens[: self.max_decode_len - 1] 63 | c = tokenizer.convert_tokens_to_ids(query_tokens) 64 | dec_input = [args["start_token_id"]] + c 65 | dec_target = c + [args["end_token_id"]] 66 | extra = self.max_decode_len - len(dec_input) 67 | if extra > 0: 68 | dec_input += [self.ID_PAD] * extra 69 | dec_target += [self.ID_PAD] * extra 70 | return { 71 | "input_ids": torch.tensor(input_ids).long(), "input_mask": torch.tensor(input_mask).float(), 72 | "input_seg": torch.tensor(input_seg).long(), "decode_input": torch.tensor(dec_input).long(), 73 | "decode_target": torch.tensor(dec_target).long(), "label": query 74 | } 75 | 76 | 77 | class XLRelPosEmb(nn.Module): 78 | def _forward_unimplemented(self, *input: Any) -> None: 79 | pass 80 | 81 | def __init__(self, d_embed): 82 | super().__init__() 83 | 84 | self.d_embed = d_embed 85 | inv_freq = 1 / (10000 ** (torch.arange(0.0, self.d_embed, 2.0) / self.d_embed)) 86 | self.register_buffer("inv_freq", inv_freq) 87 | 88 | def forward(self, pos_seq): 89 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq) 90 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 91 | return pos_emb 92 | 93 | 94 | class PositionwiseFFN(nn.Module): 95 | def _forward_unimplemented(self, *input: Any) -> None: 96 | pass 97 | 98 | def __init__(self, d_model, d_inner, layer_norm_epsilon=1e-5): 99 | super().__init__() 100 | self.d_model = d_model 101 | self.d_inner = d_inner 102 | self.CoreNet = nn.Sequential( 103 | nn.Linear(d_model, d_inner), 104 | nn.GELU(), 105 | nn.Dropout(p=args["dropout"]), 106 | nn.Linear(d_inner, d_model), 107 | nn.Dropout(p=args["dropout"]) 108 | ) 109 | self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon) 110 | 111 | def forward(self, inp): 112 | core_out = self.CoreNet(inp) 113 | output = self.layer_norm(inp + core_out) 114 | return output 115 | 116 | 117 | class RelPartialLearnableMultiHeadAttn(torch.nn.Module): 118 | 119 | def _forward_unimplemented(self, *input: Any) -> None: 120 | pass 121 | 122 | def __init__(self, n_heads, d_model, layer_norm_epsilon=1e-5): 123 | super().__init__() 124 | 125 | self.n_heads = n_heads 126 | self.d_model = d_model 127 | self.d_head = d_model // n_heads 128 | 129 | self.mask_attn_qkv_net = nn.Linear(d_model, 3 * d_model, bias=False) 130 | self.mask_attn_o_net = nn.Linear(d_model, d_model, bias=False) 131 | 132 | self.interaction_kv_net = nn.Linear(d_model, 2 * d_model, bias=False) 133 | self.interaction_q_net = nn.Linear(d_model, d_model, bias=False) 134 | 
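        # Encoder-decoder (cross) attention projections: queries are computed from the
        # decoder stream, keys/values from the encoder output (enc_context in forward()).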
self.interaction_o_net = nn.Linear(d_model, d_model, bias=False) 135 | 136 | self.layer_norm_mask_attn = nn.LayerNorm(d_model, eps=layer_norm_epsilon) 137 | self.layer_norm_interaction = nn.LayerNorm(d_model, eps=layer_norm_epsilon) 138 | self.scale = 1 / (self.d_head ** 0.5) 139 | 140 | self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_heads, self.d_head)) 141 | self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_heads, self.d_head)) 142 | 143 | self.r_net = nn.Linear(d_model, d_model, bias=False) 144 | 145 | self.drop = nn.Dropout(p=args["dropout"]) 146 | 147 | @staticmethod 148 | def _rel_shift(x): 149 | zero_pad_shape = (x.size(0), 1) + x.size()[2:] 150 | zero_pad = torch.zeros(zero_pad_shape, device=device, dtype=x.dtype) 151 | x_padded = torch.cat([zero_pad, x], dim=1) 152 | 153 | x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:] 154 | x_padded = x_padded.view(*x_padded_shape) 155 | 156 | x = x_padded[1:].view_as(x) 157 | 158 | return x 159 | 160 | def forward(self, w, r, enc_context, attn_mask, padding_mask): 161 | # attn_mask用于Masked-Attn Mechanism(decode自身部分) 162 | # padding_mask用于Norm Multi-Attn, Decode与Encode Contextual Rep交互 163 | dec_len, bsz, enc_len = w.size(0), w.size(1), enc_context.size(0) 164 | w_heads = self.mask_attn_qkv_net(w) 165 | r_head_k = self.r_net(r) 166 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 167 | 168 | w_head_q = w_head_q.view(dec_len, bsz, self.n_heads, self.d_head) # dec_len x bsz x n_head x d_head 169 | w_head_k = w_head_k.view(dec_len, bsz, self.n_heads, self.d_head) # dec_len x bsz x n_head x d_head 170 | w_head_v = w_head_v.view(dec_len, bsz, self.n_heads, self.d_head) # dec_len x bsz x n_head x d_head 171 | 172 | r_head_k = r_head_k.view(dec_len, self.n_heads, self.d_head) # dec_len x n_head x d_head 173 | rw_head_q = w_head_q + self.r_w_bias # dec_len x bsz x n_head x d_head 174 | AC = torch.einsum("ibnd,jbnd->ijbn", rw_head_q, w_head_k) # dec_len x dec_len x bsz x n_head 175 | rr_head_q = w_head_q + self.r_r_bias 176 | BD = torch.einsum("ibnd,jnd->ijbn", rr_head_q, r_head_k) # dec_len x dec_len x bsz x n_head 177 | BD = self._rel_shift(BD) 178 | 179 | attn_score = AC + BD 180 | attn_score.mul_(self.scale) 181 | 182 | # causal masking mechanism 183 | attn_mask = attn_mask == 0 # switch to bool 184 | attn_score = attn_score.float().masked_fill(attn_mask, -1e30).type_as(attn_score) 185 | attn_prob = torch.softmax(attn_score, dim=1) 186 | attn_prob = self.drop(attn_prob) 187 | 188 | attn_vec = torch.einsum("ijbn,jbnd->ibnd", attn_prob, w_head_v) 189 | attn_vec = attn_vec.contiguous().view(dec_len, bsz, self.d_model) 190 | 191 | attn_out = self.mask_attn_o_net(attn_vec) 192 | attn_out = self.drop(attn_out) 193 | 194 | mask_attn_output = self.layer_norm_mask_attn(w + attn_out) 195 | 196 | # 与编码器交互部分 197 | inter_k, inter_v = torch.chunk(self.interaction_kv_net(enc_context), 2, dim=-1) 198 | inter_q = self.interaction_q_net(mask_attn_output) 199 | inter_q = inter_q.view(dec_len, bsz, self.n_heads, self.d_head) 200 | inter_k = inter_k.view(enc_len, bsz, self.n_heads, self.d_head) 201 | inter_v = inter_v.view(enc_len, bsz, self.n_heads, self.d_head) 202 | 203 | attn_score = torch.einsum("qbnd,kbnd->qkbn", inter_q, inter_k) 204 | attn_score.mul_(self.scale) 205 | 206 | # use padding_mask to mask input [PAD] token 207 | padding_mask = padding_mask[None, :, :, None].repeat(dec_len, 1, 1, 1) 208 | attn_score = attn_score + (1 - padding_mask) * (-1e30) 209 | attn_prob = torch.softmax(attn_score, dim=1) 210 | attn_prob = 
self.drop(attn_prob) 211 | attn_vec = torch.einsum("ijbn,jbnd->ibnd", attn_prob, inter_v) 212 | attn_vec = attn_vec.contiguous().view(dec_len, bsz, self.d_model) 213 | 214 | attn_out = self.interaction_o_net(attn_vec) 215 | attn_out = self.drop(attn_out) 216 | 217 | interaction_output = self.layer_norm_interaction(attn_out + mask_attn_output) 218 | return interaction_output 219 | 220 | 221 | class RelPartialLearnableDecoderLayer(torch.nn.Module): 222 | 223 | def _forward_unimplemented(self, *input: Any) -> None: 224 | pass 225 | 226 | def __init__(self, n_heads, d_model, d_inner): 227 | super().__init__() 228 | 229 | self.dec_attn = RelPartialLearnableMultiHeadAttn(n_heads=n_heads, d_model=d_model) 230 | self.ffn_layer = PositionwiseFFN(d_model=d_model, d_inner=d_inner) 231 | 232 | def forward(self, dec_inp, r, enc_inp, dec_mask, enc_mask): 233 | attn_output = self.dec_attn(w=dec_inp, r=r, enc_context=enc_inp, attn_mask=dec_mask, padding_mask=enc_mask) 234 | ffn_out = self.ffn_layer(attn_output) 235 | return ffn_out 236 | 237 | 238 | class XLDecoder(torch.nn.Module): 239 | 240 | def _forward_unimplemented(self, *input: Any) -> None: 241 | pass 242 | 243 | def __init__(self, dim, embedding_matrix: nn.Embedding, seq_len): 244 | super().__init__() 245 | self.d_model = dim 246 | self.word_emb = embedding_matrix 247 | self.seq_len = seq_len 248 | self.n_layers = args["decoder_layers"] 249 | self.n_heads = 16 250 | self.ffn = 4 * dim 251 | self.epsilon = 1e-6 252 | 253 | self.drop = nn.Dropout(p=args["dropout"]) 254 | self.pos_emb = XLRelPosEmb(d_embed=dim) 255 | self.layers = nn.ModuleList() 256 | 257 | self.layers = nn.ModuleList() 258 | for i in range(self.n_layers): 259 | self.layers.append( 260 | RelPartialLearnableDecoderLayer( 261 | n_heads=self.n_heads, d_model=self.d_model, d_inner=self.ffn 262 | ) 263 | ) 264 | self.output = nn.Linear(in_features=dim, out_features=dim) 265 | self.copy_output = nn.Linear(in_features=dim, out_features=dim) 266 | # 自适应的解码概率结合 267 | self.mode_select = nn.Sequential( 268 | nn.Linear(in_features=dim, out_features=1), 269 | nn.Sigmoid() 270 | ) 271 | 272 | def forward(self, input_ids, encoder_rep, input_mask, decode_input, decode_target, use_beam_search, beam_width): 273 | bsz = input_ids.size(0) 274 | if decode_input is not None: # 代表训练模式 275 | input_ids = input_ids[:, None, :].repeat(1, self.seq_len, 1) 276 | decode_embed = self.drop(self.word_emb(decode_input)) 277 | all_ones = decode_embed.new_ones((self.seq_len, self.seq_len), dtype=torch.uint8) 278 | dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None] 279 | pos_seq = torch.arange(self.seq_len - 1, -1, -1.0, device=device, dtype=decode_embed.dtype) 280 | pos_embed = self.drop(self.pos_emb(pos_seq)) 281 | core_out = decode_embed.transpose(0, 1).contiguous() 282 | enc_rep_t = encoder_rep.transpose(0, 1).contiguous() 283 | enc_mask_t = input_mask.transpose(0, 1).contiguous() 284 | for layer in self.layers: 285 | core_out = layer( 286 | dec_inp=core_out, r=pos_embed, enc_inp=enc_rep_t, 287 | dec_mask=dec_attn_mask, enc_mask=enc_mask_t 288 | ) 289 | core_out = self.drop(core_out.transpose(0, 1).contiguous()) # (bsz, dec_len, dim) 290 | output = self.output(core_out) 291 | vocab_logits = torch.nn.functional.linear(input=output, weight=self.word_emb.weight) 292 | vocab_prob = torch.softmax(vocab_logits, dim=-1) 293 | input_logits = torch.einsum("bid,bjd->bij", self.copy_output(core_out), encoder_rep) # (bsz, dec_len, enc_len) 294 | input_logits = input_logits + (1.0 - input_mask[:, None, 
:].repeat(1, self.seq_len, 1)) * (-1e30) 295 | input_prob = torch.softmax(input_logits, dim=-1) # (bsz, dec_len, enc_len) 296 | mode_sig = self.mode_select(core_out) # (bsz, dec_len, 1) 297 | vocab_prob = vocab_prob * mode_sig 298 | vocab_prob = torch.scatter_add(vocab_prob, dim=2, index=input_ids, src=(1 - mode_sig) * input_prob) 299 | vocab_prob = vocab_prob.view(-1, args["vocab_size"]) 300 | decode_target = decode_target.view(-1) 301 | predict = torch.gather(vocab_prob, dim=1, index=decode_target[:, None]).squeeze(dim=-1) 302 | init_loss = -torch.log(predict + self.epsilon) 303 | init_loss *= (decode_target != 0).float() 304 | loss = torch.sum(init_loss) / torch.nonzero(decode_target != 0, as_tuple=False).size(0) 305 | # 为了并行化设计, 将loss变成(bsz,) 306 | return loss[None].repeat(bsz) 307 | else: # 代表验证或者测试解码模式 ==> 比较耗时 308 | dec_list = [] 309 | decode_ids = torch.full(size=(bsz, 1), fill_value=args["start_token_id"], dtype=torch.int32).long().to(device) 310 | for i in range(1, self.seq_len + 1): 311 | if i > 1: 312 | decode_ids = torch.cat([decode_ids, dec_list[i - 2]], dim=-1) 313 | decode_embed = self.word_emb(decode_ids) 314 | all_ones = decode_embed.new_ones((i, i), dtype=torch.uint8) 315 | dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None] 316 | pos_seq = torch.arange(i - 1, -1, -1.0, device=device, dtype=decode_embed.dtype) 317 | pos_embed = self.pos_emb(pos_seq) 318 | core_out = decode_embed.transpose(0, 1).contiguous() 319 | enc_rep_t = encoder_rep.transpose(0, 1).contiguous() 320 | enc_mask_t = input_mask.transpose(0, 1).contiguous() 321 | for layer in self.layers: 322 | core_out = layer( 323 | dec_inp=core_out, r=pos_embed, enc_inp=enc_rep_t, 324 | dec_mask=dec_attn_mask, enc_mask=enc_mask_t 325 | ) 326 | core_out = core_out.transpose(0, 1).contiguous()[:, -1, :] 327 | output = self.output(core_out) 328 | vocab_logits = torch.nn.functional.linear(input=output, weight=self.word_emb.weight) 329 | vocab_prob = torch.softmax(vocab_logits, dim=-1) 330 | input_logits = torch.einsum("bd,bjd->bj", self.copy_output(core_out), encoder_rep) # (bsz, enc_len) 331 | input_logits = input_logits + (1.0 - input_mask) * (-1e30) 332 | input_prob = torch.softmax(input_logits, dim=-1) # (bsz, enc_len) 333 | mode_sig = self.mode_select(core_out) # (bsz, 1) 334 | vocab_prob = vocab_prob * mode_sig 335 | vocab_prob = torch.scatter_add(vocab_prob, dim=1, index=input_ids, src=(1 - mode_sig) * input_prob) 336 | dec_list.append(torch.argmax(vocab_prob, dim=-1)[:, None]) 337 | return torch.cat(dec_list, dim=-1) 338 | 339 | 340 | class MyModel(torch.nn.Module): 341 | def _forward_unimplemented(self, *input: Any) -> None: 342 | pass 343 | 344 | def __init__(self, pre_train_dir: str): 345 | super().__init__() 346 | self.roberta_encoder = BertModel.from_pretrained(pre_train_dir) 347 | self.decoder_layer = XLDecoder( 348 | dim=args["dimension"], embedding_matrix=self.roberta_encoder.get_input_embeddings(), 349 | seq_len=args["max_dec_len"]) 350 | 351 | def forward(self, input_ids, input_mask, input_seg, decode_input=None, decode_target=None): 352 | encoder_rep = self.roberta_encoder(input_ids, input_mask, input_seg)[0] 353 | return self.decoder_layer(input_ids, encoder_rep, input_mask, decode_input, decode_target, 354 | args["use_beam_search"], 355 | args["beam_width"]) 356 | 357 | 358 | class InitializeNetWeight(object): 359 | def __init__(self, init_range, init_std): 360 | self.init_range = init_range 361 | self.init_std = init_std 362 | 363 | def _init_weight(self, weight): 364 | 
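        # Note: nn.init.normal_(tensor, mean, std) is called just below with
        # init_range (0.02) passed as the mean and init_std (0.02) as the std;
        # a zero mean is the more usual choice for linear/attention weights,
        # so this argument order looks worth double-checking.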
nn.init.normal_(weight, self.init_range, self.init_std) 365 | 366 | @staticmethod 367 | def _init_bias(bias): 368 | nn.init.constant_(bias, 0) 369 | 370 | def _init_emb_proj(self, proj): 371 | if self.init_method == "normal": 372 | nn.init.normal_(proj, 0.0, self.proj_init_std) 373 | elif self.init_method == "uniform": 374 | nn.init.uniform_(proj, self.init_range, self.proj_init_std) 375 | 376 | def _init_weights(self, m): 377 | """ 378 | :param parameters: 379 | """ 380 | classname = m.__class__.__name__ 381 | if classname.find("Embedding") != -1: # 解码器部分不能动embedding矩阵的参数 382 | return 383 | if classname.find("Linear") != -1: 384 | if hasattr(m, "weight") and m.weight is not None: 385 | self._init_weight(m.weight) 386 | if hasattr(m, "bias") and m.bias is not None: 387 | self._init_bias(m.bias) 388 | elif classname.find("LayerNorm") != -1: 389 | if hasattr(m, "weight"): 390 | nn.init.normal_(m.weight, 1.0, self.init_std) 391 | if hasattr(m, "bias") and m.bias is not None: 392 | self._init_bias(m.bias) 393 | else: 394 | if hasattr(m, "r_w_bias"): 395 | self._init_weight(m.r_w_bias) 396 | if hasattr(m, "r_r_bias"): 397 | self._init_weight(m.r_r_bias) 398 | if hasattr(m, "bias"): 399 | self._init_bias(m.bias) 400 | 401 | def init_weights(self, model): 402 | model.apply(self._init_weights) 403 | print("random initialize weights succeed.") 404 | 405 | 406 | class WarmUp_LinearDecay: 407 | def __init__(self, optimizer: optim.AdamW, init_rate, warm_up_steps, decay_steps, min_lr_rate): 408 | self.optimizer = optimizer 409 | self.init_rate = init_rate 410 | self.warm_up_steps = warm_up_steps 411 | self.decay_steps = decay_steps 412 | self.min_lr_rate = min_lr_rate 413 | self.optimizer_step = 0 414 | 415 | def step(self): 416 | self.optimizer_step += 1 417 | if self.optimizer_step <= self.warm_up_steps: 418 | rate = (self.optimizer_step / self.warm_up_steps) * self.init_rate 419 | elif self.warm_up_steps < self.optimizer_step <= (self.warm_up_steps + self.decay_steps): 420 | rate = (1.0 - ((self.optimizer_step - self.warm_up_steps) / self.decay_steps)) * self.init_rate 421 | else: 422 | rate = self.min_lr_rate 423 | for p in self.optimizer.param_groups: 424 | p["lr"] = rate 425 | self.optimizer.step() 426 | 427 | 428 | class Main(object): 429 | def __init__(self, train_loader): 430 | self.train_loader = train_loader 431 | self.model = MyModel(pre_train_dir=args["pre_train_dir"]) 432 | 433 | self.init_obj = InitializeNetWeight(init_std=args["init_std"], init_range=args["init_range"]) 434 | self.init_obj.init_weights(self.model.decoder_layer) 435 | param_optimizer = list(self.model.named_parameters()) 436 | no_decay = ['bias', 'gamma', 'beta'] 437 | optimizer_grouped_parameters = [ 438 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 439 | 'weight_decay_rate': args["weight_decay"]}, 440 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 441 | 'weight_decay_rate': 0.0} 442 | ] 443 | 444 | self.optimizer = optim.AdamW(params=optimizer_grouped_parameters, lr=args["init_lr"]) 445 | self.schedule = WarmUp_LinearDecay(optimizer=self.optimizer, init_rate=args["init_lr"], 446 | warm_up_steps=args["warm_up_steps"], 447 | decay_steps=args["lr_decay_steps"], min_lr_rate=args["min_lr_rate"]) 448 | self.model.to(device=device) 449 | self.model = nn.parallel.DistributedDataParallel(module=self.model, dim=0, find_unused_parameters=True) 450 | 451 | def train(self): 452 | self.model.train() 453 | steps = 0 454 | while True: 455 | for item in 
self.train_loader: 456 | input_ids, input_mask, input_seg, decode_input, decode_target = \ 457 | item["input_ids"], item["input_mask"], item["input_seg"], item["decode_input"], item[ 458 | "decode_target"] 459 | self.optimizer.zero_grad() 460 | loss = self.model( 461 | input_ids=input_ids.to(device), 462 | input_mask=input_mask.to(device), 463 | input_seg=input_seg.to(device), 464 | decode_input=decode_input.to(device), 465 | decode_target=decode_target.to(device) 466 | ) 467 | loss = loss.float().mean().type_as(loss) 468 | loss.backward() 469 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=args["clip_norm"]) 470 | self.schedule.step() 471 | steps += 1 472 | writer.add_scalar("loss", loss.item(), global_step=steps) 473 | if steps % args["save_interval"] == 0: 474 | torch.save(self.model.module.state_dict(), f=args["save_path"]) 475 | if steps >= args["max_steps"]: 476 | break 477 | if steps >= args["max_steps"]: 478 | break 479 | writer.flush() 480 | writer.close() 481 | 482 | 483 | if __name__ == "__main__": 484 | device = "cuda" 485 | args = { 486 | "init_lr": 2e-5, 487 | "batch_size": 24, 488 | "mos": 2, 489 | "weight_decay": 0.01, 490 | "warm_up_steps": 3600, 491 | "lr_decay_steps": 56000, 492 | "max_steps": 60000, 493 | "min_lr_rate": 1e-9, 494 | "save_interval": 1000, 495 | "save_path": "ModelStorage/xl_dureader.pth", 496 | "pre_train_dir": "/home/ldmc/quanlin/Pretrained_NLP_Models/Pytorch/RoBERTa_Large_ZH/", 497 | "clip_norm": 0.25, 498 | "start_token": "[unused1]", 499 | "end_token": "[unused2]", 500 | "start_token_id": 1, 501 | "end_token_id": 2, 502 | "dimension": 1024, 503 | "max_enc_len": 512, 504 | "max_dec_len": 50, 505 | "max_answer_len": 100, 506 | "use_beam_search": False, 507 | "beam_width": 5, 508 | "decoder_layers": 3, 509 | "dropout": 0.1, 510 | "vocab_size": 21128, 511 | "init_range": 0.02, 512 | "init_std": 0.02 513 | } 514 | 515 | with open("DataSet/multi_task.pkl", "rb") as f: 516 | x = pickle.load(f) 517 | 518 | tokenizer = BertTokenizer(vocab_file="/home/ldmc/quanlin/Pretrained_NLP_Models/Pytorch/RoBERTa_Large_ZH/vocab.txt") 519 | 520 | if sys.argv[1] == "train": 521 | torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1, init_method='tcp://localhost:7001') 522 | writer = SummaryWriter(logdir="RunLog/Multi-DuReader") 523 | train_dataset = MyDataset(data=x["dureader_train_items"], max_enc_len=args["max_enc_len"], 524 | max_dec_len=args["max_dec_len"]) 525 | train_loader = DataLoader(train_dataset, batch_size=args["batch_size"], shuffle=True, num_workers=4) 526 | 527 | m = Main(train_loader) 528 | m.train() 529 | else: 530 | print("Invalid args.") 531 | -------------------------------------------------------------------------------- /MultiTaskXLIR-Final.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | """ 4 | @Time: 2020/9/14 5 | @Author: menghuanlater 6 | @Software: Pycharm 2019.2 7 | @Usage: data preprocess 8 | ----------------------------- 9 | Description: Base on RoBERTa and Transformer-XL Decoder and Copy Mechanism 10 | Transformer Decoder采用Transformer-XL 11 | ----------------------------- 12 | """ 13 | import os 14 | import sys 15 | 16 | os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[2][4:] 17 | from typing import Any 18 | 19 | from transformers import BertTokenizer, BertModel, BertConfig 20 | import torch 21 | from torch import nn 22 | import pickle 23 | from torch.utils.data import DataLoader, Dataset 24 | from torch import 
optim 25 | import numpy as np 26 | import json 27 | import re 28 | from copy import deepcopy 29 | from tensorboardX import SummaryWriter 30 | 31 | 32 | class MyDataset(Dataset): 33 | def __init__(self, data, max_enc_len, max_dec_len): 34 | self.data = data 35 | self.max_encode_len = max_enc_len 36 | self.max_decode_len = max_dec_len 37 | self.SEG_A = 0 38 | self.SEG_P = 1 39 | self.ID_PAD = 0 40 | 41 | def __len__(self): 42 | return len(self.data) 43 | 44 | def __getitem__(self, index): 45 | item = self.data[index] 46 | context, query, answer = item["context"], item["query"], item["answer"] 47 | context_tokens = tokenizer.tokenize(context.replace("\n", " ").replace("\t", " ").replace("\\", "")) 48 | query_tokens = tokenizer.tokenize(query) 49 | answer_tokens = tokenizer.tokenize(answer)[:args["max_answer_len"]] 50 | 51 | c = ["[CLS]"] + answer_tokens + ["[SEP]"] + context_tokens 52 | if len(c) > self.max_encode_len - 1: 53 | c = c[:self.max_encode_len - 1] 54 | c += ["[SEP]"] 55 | input_ids = tokenizer.convert_tokens_to_ids(c) 56 | input_mask = [1.0] * len(input_ids) 57 | input_seg = [self.SEG_A] * (len(answer_tokens) + 2) + [self.SEG_P] * (len(input_ids) - 2 - len(answer_tokens)) 58 | extra = self.max_encode_len - len(input_ids) 59 | if extra > 0: 60 | input_ids += [self.ID_PAD] * extra 61 | input_mask += [0.0] * extra 62 | input_seg += [self.SEG_P] * extra 63 | if len(query_tokens) > self.max_decode_len - 1: 64 | query_tokens = query_tokens[: self.max_decode_len - 1] 65 | c = tokenizer.convert_tokens_to_ids(query_tokens) 66 | dec_input = [args["start_token_id"]] + c 67 | dec_target = c + [args["end_token_id"]] 68 | extra = self.max_decode_len - len(dec_input) 69 | if extra > 0: 70 | dec_input += [self.ID_PAD] * extra 71 | dec_target += [self.ID_PAD] * extra 72 | return { 73 | "input_ids": torch.tensor(input_ids).long(), "input_mask": torch.tensor(input_mask).float(), 74 | "input_seg": torch.tensor(input_seg).long(), "decode_input": torch.tensor(dec_input).long(), 75 | "decode_target": torch.tensor(dec_target).long(), "label": query 76 | } 77 | 78 | 79 | class XLRelPosEmb(nn.Module): 80 | def _forward_unimplemented(self, *input: Any) -> None: 81 | pass 82 | 83 | def __init__(self, d_embed): 84 | super().__init__() 85 | 86 | self.d_embed = d_embed 87 | inv_freq = 1 / (10000 ** (torch.arange(0.0, self.d_embed, 2.0) / self.d_embed)) 88 | self.register_buffer("inv_freq", inv_freq) 89 | 90 | def forward(self, pos_seq): 91 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq) 92 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 93 | return pos_emb 94 | 95 | 96 | class PositionwiseFFN(nn.Module): 97 | def _forward_unimplemented(self, *input: Any) -> None: 98 | pass 99 | 100 | def __init__(self, d_model, d_inner, layer_norm_epsilon=1e-5): 101 | super().__init__() 102 | self.d_model = d_model 103 | self.d_inner = d_inner 104 | self.CoreNet = nn.Sequential( 105 | nn.Linear(d_model, d_inner), 106 | nn.GELU(), 107 | nn.Dropout(p=args["dropout"]), 108 | nn.Linear(d_inner, d_model), 109 | nn.Dropout(p=args["dropout"]) 110 | ) 111 | self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon) 112 | 113 | def forward(self, inp): 114 | core_out = self.CoreNet(inp) 115 | output = self.layer_norm(inp + core_out) 116 | return output 117 | 118 | 119 | class RelPartialLearnableMultiHeadAttn(torch.nn.Module): 120 | 121 | def _forward_unimplemented(self, *input: Any) -> None: 122 | pass 123 | 124 | def __init__(self, n_heads, d_model, layer_norm_epsilon=1e-5): 125 | super().__init__() 126 | 
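        # Descriptive note: this attention block mirrors the one defined in
        # MultiTaskXLIR-DuReader.py. r_w_bias and r_r_bias are the learned global
        # content/position biases (the u and v terms of Transformer-XL's relative
        # attention), and _rel_shift below applies the usual XL shift that aligns
        # the relative-position term with the query/key index grid.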
127 | self.n_heads = n_heads 128 | self.d_model = d_model 129 | self.d_head = d_model // n_heads 130 | 131 | self.mask_attn_qkv_net = nn.Linear(d_model, 3 * d_model, bias=False) 132 | self.mask_attn_o_net = nn.Linear(d_model, d_model, bias=False) 133 | 134 | self.interaction_kv_net = nn.Linear(d_model, 2 * d_model, bias=False) 135 | self.interaction_q_net = nn.Linear(d_model, d_model, bias=False) 136 | self.interaction_o_net = nn.Linear(d_model, d_model, bias=False) 137 | 138 | self.layer_norm_mask_attn = nn.LayerNorm(d_model, eps=layer_norm_epsilon) 139 | self.layer_norm_interaction = nn.LayerNorm(d_model, eps=layer_norm_epsilon) 140 | self.scale = 1 / (self.d_head ** 0.5) 141 | 142 | self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_heads, self.d_head)) 143 | self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_heads, self.d_head)) 144 | 145 | self.r_net = nn.Linear(d_model, d_model, bias=False) 146 | 147 | self.drop = nn.Dropout(p=args["dropout"]) 148 | 149 | @staticmethod 150 | def _rel_shift(x): 151 | zero_pad_shape = (x.size(0), 1) + x.size()[2:] 152 | zero_pad = torch.zeros(zero_pad_shape, device=device, dtype=x.dtype) 153 | x_padded = torch.cat([zero_pad, x], dim=1) 154 | 155 | x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:] 156 | x_padded = x_padded.view(*x_padded_shape) 157 | 158 | x = x_padded[1:].view_as(x) 159 | 160 | return x 161 | 162 | def forward(self, w, r, enc_context, attn_mask, padding_mask): 163 | # attn_mask用于Masked-Attn Mechanism(decode自身部分) 164 | # padding_mask用于Norm Multi-Attn, Decode与Encode Contextual Rep交互 165 | dec_len, bsz, enc_len = w.size(0), w.size(1), enc_context.size(0) 166 | w_heads = self.mask_attn_qkv_net(w) 167 | r_head_k = self.r_net(r) 168 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 169 | 170 | w_head_q = w_head_q.view(dec_len, bsz, self.n_heads, self.d_head) # dec_len x bsz x n_head x d_head 171 | w_head_k = w_head_k.view(dec_len, bsz, self.n_heads, self.d_head) # dec_len x bsz x n_head x d_head 172 | w_head_v = w_head_v.view(dec_len, bsz, self.n_heads, self.d_head) # dec_len x bsz x n_head x d_head 173 | 174 | r_head_k = r_head_k.view(dec_len, self.n_heads, self.d_head) # dec_len x n_head x d_head 175 | rw_head_q = w_head_q + self.r_w_bias # dec_len x bsz x n_head x d_head 176 | AC = torch.einsum("ibnd,jbnd->ijbn", rw_head_q, w_head_k) # dec_len x dec_len x bsz x n_head 177 | rr_head_q = w_head_q + self.r_r_bias 178 | BD = torch.einsum("ibnd,jnd->ijbn", rr_head_q, r_head_k) # dec_len x dec_len x bsz x n_head 179 | BD = self._rel_shift(BD) 180 | 181 | attn_score = AC + BD 182 | attn_score.mul_(self.scale) 183 | 184 | # causal masking mechanism 185 | attn_mask = attn_mask == 0 # switch to bool 186 | attn_score = attn_score.float().masked_fill(attn_mask, -1e30).type_as(attn_score) 187 | attn_prob = torch.softmax(attn_score, dim=1) 188 | attn_prob = self.drop(attn_prob) 189 | 190 | attn_vec = torch.einsum("ijbn,jbnd->ibnd", attn_prob, w_head_v) 191 | attn_vec = attn_vec.contiguous().view(dec_len, bsz, self.d_model) 192 | 193 | attn_out = self.mask_attn_o_net(attn_vec) 194 | attn_out = self.drop(attn_out) 195 | 196 | mask_attn_output = self.layer_norm_mask_attn(w + attn_out) 197 | 198 | # 与编码器交互部分 199 | inter_k, inter_v = torch.chunk(self.interaction_kv_net(enc_context), 2, dim=-1) 200 | inter_q = self.interaction_q_net(mask_attn_output) 201 | inter_q = inter_q.view(dec_len, bsz, self.n_heads, self.d_head) 202 | inter_k = inter_k.view(enc_len, bsz, self.n_heads, self.d_head) 203 | inter_v = inter_v.view(enc_len, 
bsz, self.n_heads, self.d_head) 204 | 205 | attn_score = torch.einsum("qbnd,kbnd->qkbn", inter_q, inter_k) 206 | attn_score.mul_(self.scale) 207 | 208 | # use padding_mask to mask input [PAD] token 209 | padding_mask = padding_mask[None, :, :, None].repeat(dec_len, 1, 1, 1) 210 | attn_score = attn_score + (1 - padding_mask) * (-1e30) 211 | attn_prob = torch.softmax(attn_score, dim=1) 212 | attn_prob = self.drop(attn_prob) 213 | attn_vec = torch.einsum("ijbn,jbnd->ibnd", attn_prob, inter_v) 214 | attn_vec = attn_vec.contiguous().view(dec_len, bsz, self.d_model) 215 | 216 | attn_out = self.interaction_o_net(attn_vec) 217 | attn_out = self.drop(attn_out) 218 | 219 | interaction_output = self.layer_norm_interaction(attn_out + mask_attn_output) 220 | return interaction_output 221 | 222 | 223 | class RelPartialLearnableDecoderLayer(torch.nn.Module): 224 | 225 | def _forward_unimplemented(self, *input: Any) -> None: 226 | pass 227 | 228 | def __init__(self, n_heads, d_model, d_inner): 229 | super().__init__() 230 | 231 | self.dec_attn = RelPartialLearnableMultiHeadAttn(n_heads=n_heads, d_model=d_model) 232 | self.ffn_layer = PositionwiseFFN(d_model=d_model, d_inner=d_inner) 233 | 234 | def forward(self, dec_inp, r, enc_inp, dec_mask, enc_mask): 235 | attn_output = self.dec_attn(w=dec_inp, r=r, enc_context=enc_inp, attn_mask=dec_mask, padding_mask=enc_mask) 236 | ffn_out = self.ffn_layer(attn_output) 237 | return ffn_out 238 | 239 | 240 | class XLDecoder(torch.nn.Module): 241 | 242 | def _forward_unimplemented(self, *input: Any) -> None: 243 | pass 244 | 245 | def __init__(self, dim, embedding_matrix: nn.Embedding, seq_len): 246 | super().__init__() 247 | self.d_model = dim 248 | self.word_emb = embedding_matrix 249 | self.seq_len = seq_len 250 | self.n_layers = args["decoder_layers"] 251 | self.n_heads = 16 252 | self.ffn = 4 * dim 253 | self.epsilon = 1e-6 254 | 255 | self.drop = nn.Dropout(p=args["dropout"]) 256 | self.pos_emb = XLRelPosEmb(d_embed=dim) 257 | self.layers = nn.ModuleList() 258 | 259 | self.layers = nn.ModuleList() 260 | for i in range(self.n_layers): 261 | self.layers.append( 262 | RelPartialLearnableDecoderLayer( 263 | n_heads=self.n_heads, d_model=self.d_model, d_inner=self.ffn 264 | ) 265 | ) 266 | self.output = nn.Linear(in_features=dim, out_features=dim) 267 | self.copy_output = nn.Linear(in_features=dim, out_features=dim) 268 | # 自适应的解码概率结合 269 | self.mode_select = nn.Sequential( 270 | nn.Linear(in_features=dim, out_features=1), 271 | nn.Sigmoid() 272 | ) 273 | 274 | def forward(self, input_ids, encoder_rep, input_mask, decode_input, decode_target, use_beam_search, beam_width): 275 | bsz = input_ids.size(0) 276 | if decode_input is not None: # 代表训练模式 277 | input_ids = input_ids[:, None, :].repeat(1, self.seq_len, 1) 278 | decode_embed = self.drop(self.word_emb(decode_input)) 279 | all_ones = decode_embed.new_ones((self.seq_len, self.seq_len), dtype=torch.uint8) 280 | dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None] 281 | pos_seq = torch.arange(self.seq_len - 1, -1, -1.0, device=device, dtype=decode_embed.dtype) 282 | pos_embed = self.drop(self.pos_emb(pos_seq)) 283 | core_out = decode_embed.transpose(0, 1).contiguous() 284 | enc_rep_t = encoder_rep.transpose(0, 1).contiguous() 285 | enc_mask_t = input_mask.transpose(0, 1).contiguous() 286 | for layer in self.layers: 287 | core_out = layer( 288 | dec_inp=core_out, r=pos_embed, enc_inp=enc_rep_t, 289 | dec_mask=dec_attn_mask, enc_mask=enc_mask_t 290 | ) 291 | core_out = self.drop(core_out.transpose(0, 
1).contiguous()) # (bsz, dec_len, dim) 292 | output = self.output(core_out) 293 | vocab_logits = torch.nn.functional.linear(input=output, weight=self.word_emb.weight) 294 | vocab_prob = torch.softmax(vocab_logits, dim=-1) 295 | input_logits = torch.einsum("bid,bjd->bij", self.copy_output(core_out), encoder_rep) # (bsz, dec_len, enc_len) 296 | input_logits = input_logits + (1.0 - input_mask[:, None, :].repeat(1, self.seq_len, 1)) * (-1e30) 297 | input_prob = torch.softmax(input_logits, dim=-1) # (bsz, dec_len, enc_len) 298 | mode_sig = self.mode_select(core_out) # (bsz, dec_len, 1) 299 | vocab_prob = vocab_prob * mode_sig 300 | vocab_prob = torch.scatter_add(vocab_prob, dim=2, index=input_ids, src=(1 - mode_sig) * input_prob) 301 | vocab_prob = vocab_prob.view(-1, args["vocab_size"]) 302 | decode_target = decode_target.view(-1) 303 | predict = torch.gather(vocab_prob, dim=1, index=decode_target[:, None]).squeeze(dim=-1) 304 | init_loss = -torch.log(predict + self.epsilon) 305 | init_loss *= (decode_target != 0).float() 306 | loss = torch.sum(init_loss) / torch.nonzero(decode_target != 0, as_tuple=False).size(0) 307 | # 为了并行化设计, 将loss变成(bsz,) 308 | return loss[None].repeat(bsz) 309 | else: # 代表验证或者测试解码模式 ==> 比较耗时 310 | if not use_beam_search: # 使用贪心搜索 ==> 验证集 311 | dec_list = [] 312 | decode_ids = torch.full(size=(bsz, 1), fill_value=args["start_token_id"], dtype=torch.int32).long().to(device) 313 | for i in range(1, self.seq_len + 1): 314 | if i > 1: 315 | decode_ids = torch.cat([decode_ids, dec_list[i - 2]], dim=-1) 316 | decode_embed = self.word_emb(decode_ids) 317 | all_ones = decode_embed.new_ones((i, i), dtype=torch.uint8) 318 | dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None] 319 | pos_seq = torch.arange(i - 1, -1, -1.0, device=device, dtype=decode_embed.dtype) 320 | pos_embed = self.pos_emb(pos_seq) 321 | core_out = decode_embed.transpose(0, 1).contiguous() 322 | enc_rep_t = encoder_rep.transpose(0, 1).contiguous() 323 | enc_mask_t = input_mask.transpose(0, 1).contiguous() 324 | for layer in self.layers: 325 | core_out = layer( 326 | dec_inp=core_out, r=pos_embed, enc_inp=enc_rep_t, 327 | dec_mask=dec_attn_mask, enc_mask=enc_mask_t 328 | ) 329 | core_out = core_out.transpose(0, 1).contiguous()[:, -1, :] 330 | output = self.output(core_out) 331 | vocab_logits = torch.nn.functional.linear(input=output, weight=self.word_emb.weight) 332 | vocab_prob = torch.softmax(vocab_logits, dim=-1) 333 | input_logits = torch.einsum("bd,bjd->bj", self.copy_output(core_out), encoder_rep) # (bsz, enc_len) 334 | input_logits = input_logits + (1.0 - input_mask) * (-1e30) 335 | input_prob = torch.softmax(input_logits, dim=-1) # (bsz, enc_len) 336 | mode_sig = self.mode_select(core_out) # (bsz, 1) 337 | vocab_prob = vocab_prob * mode_sig 338 | vocab_prob = torch.scatter_add(vocab_prob, dim=1, index=input_ids, src=(1 - mode_sig) * input_prob) 339 | dec_list.append(torch.argmax(vocab_prob, dim=-1)[:, None]) 340 | return torch.cat(dec_list, dim=-1) 341 | else: # 使用集束搜索 342 | # 扩展成beam_width * bsz 343 | """ 344 | 需要注意: 1. trigram-block的使用 ==> 出现重复直接加上-1e9(需要考虑end_token边界=>只在边界范围内使用) 345 | 2. 
长度惩罚, 考虑end_token边界 346 | """ 347 | decode_ids = torch.full(size=(bsz * beam_width, 1), fill_value=args["start_token_id"], dtype=torch.int32).long().to(device) 348 | input_ids = input_ids.repeat((beam_width, 1)) 349 | encoder_rep = encoder_rep.repeat((beam_width, 1, 1)) 350 | input_mask = input_mask.repeat((beam_width, 1)) 351 | dec_topK_log_probs = [0] * (beam_width * bsz) # (bsz*beam) 每个序列的当前log概率和 352 | dec_topK_sequences = [[] for _ in range(beam_width * bsz)] # (bsz*beam, seq_len) 解码id序列 353 | dec_topK_seq_lens = [1] * (beam_width * bsz) # 解码序列长度 ==> 加上一个偏置项1, 防止进行长度惩罚时出现div 0的情况 354 | for i in range(1, self.seq_len + 1): 355 | if i > 1: 356 | input_decode_ids = torch.cat([decode_ids, torch.tensor(dec_topK_sequences).long().to(device)], dim=-1) 357 | else: 358 | input_decode_ids = decode_ids 359 | decode_embed = self.word_emb(input_decode_ids) 360 | all_ones = decode_embed.new_ones((i, i), dtype=torch.uint8) 361 | dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None] 362 | pos_seq = torch.arange(i - 1, -1, -1.0, device=device, dtype=decode_embed.dtype) 363 | pos_embed = self.pos_emb(pos_seq) 364 | core_out = decode_embed.transpose(0, 1).contiguous() 365 | enc_rep_t = encoder_rep.transpose(0, 1).contiguous() 366 | enc_mask_t = input_mask.transpose(0, 1).contiguous() 367 | for layer in self.layers: 368 | core_out = layer( 369 | dec_inp=core_out, r=pos_embed, enc_inp=enc_rep_t, 370 | dec_mask=dec_attn_mask, enc_mask=enc_mask_t 371 | ) 372 | core_out = core_out.transpose(0, 1).contiguous()[:, -1, :] 373 | output = self.output(core_out) 374 | vocab_logits = torch.nn.functional.linear(input=output, weight=self.word_emb.weight) 375 | vocab_prob = torch.softmax(vocab_logits, dim=-1) 376 | input_logits = torch.einsum("bd,bjd->bj", self.copy_output(core_out), encoder_rep) # (bsz*beam, enc_len) 377 | input_logits = input_logits + (1.0 - input_mask) * (-1e30) 378 | input_prob = torch.softmax(input_logits, dim=-1) # (bsz*beam, enc_len) 379 | mode_sig = self.mode_select(core_out) # (bsz*beam, 1) 380 | vocab_prob = vocab_prob * mode_sig 381 | vocab_prob = torch.scatter_add(vocab_prob, dim=1, index=input_ids, src=(1 - mode_sig) * input_prob) # (bsz*beam, vocab) 382 | vocab_logp = torch.log(vocab_prob + self.epsilon) # 取对数, 加eps 383 | """ step1: 检查是否存在trigram blocking重叠, 只需要检查最后一项和之前项是否存在重叠即可 """ 384 | if i > 4: # 当序列长度大于等于4时才有意义, 或者当前解码时刻大于4时才有检查的必要 385 | for j in range(beam_width * bsz): 386 | trigram_blocks = [] 387 | for k in range(3, i): 388 | if dec_topK_sequences[j][k-1] == args["end_token_id"]: 389 | break 390 | trigram_blocks.append(dec_topK_sequences[j][k-3:k]) 391 | if len(trigram_blocks) > 1 and trigram_blocks[-1] in trigram_blocks[:-1]: 392 | dec_topK_log_probs[j] += -1e9 393 | """ step2: 为每个样本, 选择topK个序列 ==> 类似于重构dec_topK_sequences""" 394 | for j in range(bsz): 395 | topK_vocab_logp = vocab_logp[j::bsz] # (k, vocab) 396 | candidate_list = [] 397 | """ 容易出错的地方, i=1的时候不需要为每个K生成K个候选,否则beam search将完全沦为greedy search """ 398 | for k in range(beam_width): 399 | ind = j + k * bsz 400 | if args["end_token_id"] in dec_topK_sequences[ind]: 401 | candidate_list.append({ 402 | "add_logit": 0, "add_seq_len": 0, "affiliate_k": k, "add_token_id": args["end_token_id"], 403 | "sort_logits": dec_topK_log_probs[ind] / (dec_topK_seq_lens[ind] ** args["beam_length_penalty"]) 404 | }) 405 | else: 406 | k_logps, k_indices = topK_vocab_logp[k].topk(dim=0, k=beam_width) 407 | k_logps, k_indices = k_logps.cpu().numpy(), k_indices.cpu().numpy() 408 | for l in range(beam_width): 409 | aff = l if i == 1 
else k 410 | candidate_list.append({ 411 | "add_logit": k_logps[l], "add_seq_len": 1, "affiliate_k": aff, "add_token_id": k_indices[l], 412 | "sort_logits": (dec_topK_log_probs[ind] + k_logps[l]) / ((dec_topK_seq_lens[ind] + 1) ** args["beam_length_penalty"]) 413 | }) 414 | if i == 1: ## 当解码第一个词的时候只能考虑一个 415 | break 416 | candidate_list.sort(key=lambda x: x["sort_logits"], reverse=True) 417 | candidate_list = candidate_list[:beam_width] 418 | """ 序列修正, 更新topK """ 419 | c_dec_topK_sequences, c_dec_topK_log_probs, c_dec_topK_seq_lens = \ 420 | deepcopy(dec_topK_sequences), deepcopy(dec_topK_log_probs), deepcopy(dec_topK_seq_lens) 421 | for k in range(beam_width): 422 | ind = bsz * candidate_list[k]["affiliate_k"] + j 423 | r_ind = bsz * k + j 424 | father_seq, father_logits, father_len = c_dec_topK_sequences[ind], c_dec_topK_log_probs[ind], c_dec_topK_seq_lens[ind] 425 | dec_topK_sequences[r_ind] = father_seq + [candidate_list[k]["add_token_id"]] 426 | dec_topK_log_probs[r_ind] = father_logits + candidate_list[k]["add_logit"] 427 | dec_topK_seq_lens[r_ind] = father_len + candidate_list[k]["add_seq_len"] 428 | return torch.tensor(dec_topK_sequences[:bsz]).long().to(device) 429 | 430 | 431 | class MyModel(torch.nn.Module): 432 | def _forward_unimplemented(self, *input: Any) -> None: 433 | pass 434 | 435 | def __init__(self, pre_train_dir: str): 436 | super().__init__() 437 | self.roberta_encoder = BertModel(config=BertConfig.from_json_file(pre_train_dir+ "config.json")) 438 | self.decoder_layer = XLDecoder( 439 | dim=args["dimension"], embedding_matrix=self.roberta_encoder.get_input_embeddings(), 440 | seq_len=args["max_dec_len"]) 441 | 442 | def forward(self, input_ids, input_mask, input_seg, decode_input=None, decode_target=None): 443 | encoder_rep = self.roberta_encoder(input_ids, input_mask, input_seg)[0] 444 | return self.decoder_layer(input_ids, encoder_rep, input_mask, decode_input, decode_target, 445 | args["use_beam_search"], 446 | args["beam_width"]) 447 | 448 | 449 | class WarmUp_LinearDecay: 450 | def __init__(self, optimizer: optim.AdamW, init_rate, warm_up_steps, decay_steps, min_lr_rate): 451 | self.optimizer = optimizer 452 | self.init_rate = init_rate 453 | self.warm_up_steps = warm_up_steps 454 | self.decay_steps = decay_steps 455 | self.min_lr_rate = min_lr_rate 456 | self.optimizer_step = 0 457 | 458 | def step(self): 459 | self.optimizer_step += 1 460 | if self.optimizer_step <= self.warm_up_steps: 461 | rate = (self.optimizer_step / self.warm_up_steps) * self.init_rate 462 | elif self.warm_up_steps < self.optimizer_step <= (self.warm_up_steps + self.decay_steps): 463 | rate = (1.0 - ((self.optimizer_step - self.warm_up_steps) / self.decay_steps)) * self.init_rate 464 | else: 465 | rate = self.min_lr_rate 466 | for p in self.optimizer.param_groups: 467 | p["lr"] = rate 468 | self.optimizer.step() 469 | 470 | 471 | class Main(object): 472 | def __init__(self, train_loader, valid_loader, test_flag=False, test_items=None): 473 | self.train_loader = train_loader 474 | self.valid_loader = valid_loader 475 | self.test_items = test_items 476 | self.model = MyModel(pre_train_dir=args["pre_train_dir"]) 477 | 478 | if test_flag: 479 | self.model.load_state_dict(torch.load(args["save_path"], map_location=device), strict=False) 480 | else: 481 | self.model.load_state_dict(torch.load(args["load_path"], map_location=device), strict=False) 482 | param_optimizer = list(self.model.named_parameters()) 483 | no_decay = ['bias', 'gamma', 'beta'] 484 | optimizer_grouped_parameters = [ 485 | 
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 486 | 'weight_decay_rate': args["weight_decay"]}, 487 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 488 | 'weight_decay_rate': 0.0} 489 | ] 490 | 491 | self.optimizer = optim.AdamW(params=optimizer_grouped_parameters, lr=args["init_lr"]) 492 | self.schedule = WarmUp_LinearDecay(optimizer=self.optimizer, init_rate=args["init_lr"], 493 | warm_up_steps=args["warm_up_steps"], 494 | decay_steps=args["lr_decay_steps"], min_lr_rate=args["min_lr_rate"]) 495 | self.model.to(device=device) 496 | if args["is_train"]: 497 | self.model = nn.parallel.DistributedDataParallel(module=self.model, dim=0, find_unused_parameters=True) 498 | 499 | def train(self): 500 | best_rl = 0.0 501 | self.model.train() 502 | steps = 0 503 | while True: 504 | for item in self.train_loader: 505 | input_ids, input_mask, input_seg, decode_input, decode_target = \ 506 | item["input_ids"], item["input_mask"], item["input_seg"], item["decode_input"], item[ 507 | "decode_target"] 508 | self.optimizer.zero_grad() 509 | loss = self.model( 510 | input_ids=input_ids.to(device), 511 | input_mask=input_mask.to(device), 512 | input_seg=input_seg.to(device), 513 | decode_input=decode_input.to(device), 514 | decode_target=decode_target.to(device) 515 | ) 516 | loss = loss.float().mean().type_as(loss) 517 | loss.backward() 518 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=args["clip_norm"]) 519 | self.schedule.step() 520 | steps += 1 521 | writer.add_scalar("loss", loss.item(), global_step=steps) 522 | if steps % args["eval_interval"] == 0: 523 | rl = self.valid() 524 | writer.add_scalar("valid_rl", rl, global_step=steps) 525 | if rl > best_rl: 526 | best_rl = rl 527 | torch.save(self.model.module.state_dict(), f=args["save_path"]) 528 | if steps >= args["max_steps"]: 529 | break 530 | if steps >= args["max_steps"]: 531 | break 532 | writer.flush() 533 | writer.close() 534 | 535 | def valid(self): 536 | self.model.eval() 537 | rouge_l = [] 538 | with torch.no_grad(): 539 | for item in self.valid_loader: 540 | input_ids, input_mask, input_seg, label = item["input_ids"], item["input_mask"], item["input_seg"], \ 541 | item["label"] 542 | dec_seq = self.model( 543 | input_ids=input_ids.to(device), 544 | input_mask=input_mask.to(device), 545 | input_seg=input_seg.to(device) 546 | ) 547 | dec_seq = dec_seq.cpu().numpy() 548 | for i in range(len(dec_seq)): 549 | x = dec_seq[i] 550 | s = [] 551 | for j in x: 552 | if int(j) == args["end_token_id"]: 553 | break 554 | else: 555 | s.append(int(j)) 556 | s = "".join(tokenizer.convert_ids_to_tokens(s)) 557 | s = s.replace(",", "").replace("[UNK]", "") 558 | char_lis = [] 559 | for c in s: 560 | if c not in char_lis: 561 | char_lis.append(c) 562 | for c in char_lis: 563 | try: 564 | p = re.compile("(%s){2,}" % c) 565 | s = re.sub(p, c, s) 566 | except Exception as e: 567 | continue 568 | rouge_l.append(self.rouge_l(hypo=s, refer=label[i])) 569 | self.model.train() 570 | return np.average(rouge_l) 571 | 572 | @staticmethod 573 | def test_encode(context, answer): 574 | context_tokens = tokenizer.tokenize(context.replace("\n", " ").replace("\t", " ").replace("\\", "")) 575 | answer_tokens = tokenizer.tokenize(answer)[:args["max_answer_len"]] 576 | c = ["[CLS]"] + answer_tokens + ["[SEP]"] + context_tokens 577 | if len(c) > args["max_enc_len"] - 1: 578 | c = c[:args["max_enc_len"] - 1] 579 | c += ["[SEP]"] 580 | input_ids = tokenizer.convert_tokens_to_ids(c) 581 | 
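        # Descriptive note: test_encode mirrors MyDataset.__getitem__ but builds a
        # single unbatched example and omits the decoder targets, since at test
        # time the question is generated. Encoder input layout:
        # [CLS] answer tokens [SEP] context tokens [SEP] + [PAD] padding up to
        # max_enc_len, with token_type_ids 0 over the answer segment and 1 elsewhere.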
input_mask = [1.0] * len(input_ids) 582 | input_seg = [0] * (len(answer_tokens) + 2) + [1] * (len(input_ids) - 2 - len(answer_tokens)) 583 | extra = args["max_enc_len"] - len(input_ids) 584 | if extra > 0: 585 | input_ids += [0] * extra 586 | input_mask += [0.0] * extra 587 | input_seg += [1] * extra 588 | return { 589 | "input_ids": torch.tensor(input_ids).long().unsqueeze(dim=0).to(device), 590 | "input_mask": torch.tensor(input_mask).float().unsqueeze(dim=0).to(device), 591 | "input_seg": torch.tensor(input_seg).long().unsqueeze(dim=0).to(device) 592 | } 593 | 594 | def test(self): 595 | self.model.eval() 596 | output = x["test_items"] 597 | with torch.no_grad(): 598 | for i in range(len(output)): 599 | text = output[i]["text"] 600 | annotations = output[i]["annotations"] 601 | tmp_enc_ids, tmp_enc_mask, tmp_enc_seg = [], [], [] 602 | for j in range(len(annotations)): 603 | y = self.test_encode(text, annotations[j]["A"]) 604 | tmp_enc_ids.append(y["input_ids"]) 605 | tmp_enc_mask.append(y["input_mask"]) 606 | tmp_enc_seg.append(y["input_seg"]) 607 | dec_seq = self.model( 608 | input_ids=torch.cat(tmp_enc_ids, dim=0), 609 | input_mask=torch.cat(tmp_enc_mask, dim=0), 610 | input_seg=torch.cat(tmp_enc_seg, dim=0) 611 | ) 612 | dec_seq = dec_seq.cpu().numpy() 613 | for j in range(len(dec_seq)): 614 | y = dec_seq[j] 615 | s = [] 616 | for k in y: 617 | if int(k) == args["end_token_id"]: 618 | break 619 | else: 620 | s.append(int(k)) 621 | s = "".join(tokenizer.convert_ids_to_tokens(s)) 622 | s = s.replace(",", "").replace("[UNK]", "").replace("#", "") 623 | char_lis = [] 624 | for c in s: 625 | if c not in char_lis: 626 | char_lis.append(c) 627 | for c in char_lis: 628 | try: 629 | p = re.compile("(%s){2,}" % c) 630 | s = re.sub(p, c, s) 631 | except Exception as e: 632 | continue 633 | # 针对英文的一些修正 634 | t_text = text.lower() 635 | p = re.compile("([A-Za-z]+)") 636 | m = re.finditer(p, s) 637 | for i_match in m: 638 | start, end, i_str = i_match.start(), i_match.end(), i_match.group() 639 | if i_str in t_text: 640 | i_index = t_text.index(i_str) 641 | s = s[:start] + text[i_index: i_index + (end - start)] + s[end:] 642 | if len(s) < 2: 643 | s = annotations[j]["A"] 644 | annotations[j]["Q"] = s 645 | if i % 50 == 0 and i > 0: 646 | print("The program has completed %s predictions" % i) 647 | with open("submit.json", "w", encoding="UTF-8") as f: 648 | json.dump(output, f, ensure_ascii=False) 649 | print("The program has completed all predictions") 650 | 651 | @staticmethod 652 | def rouge_l(hypo, refer): 653 | if len(hypo) == 0 or len(refer) == 0: 654 | return 0 655 | x = [[0 for _ in range(len(refer) + 1)] for _ in range(len(hypo) + 1)] 656 | lcs = 0 657 | for i in range(len(hypo)): 658 | for j in range(len(refer)): 659 | if hypo[i] == refer[j]: 660 | x[i + 1][j + 1] = x[i][j] + 1 661 | if x[i + 1][j + 1] > lcs: 662 | lcs = x[i + 1][j + 1] 663 | else: 664 | x[i + 1][j + 1] = max(x[i + 1][j], x[i][j + 1]) 665 | p, r = lcs / len(hypo), lcs / len(refer) 666 | if (p + r) == 0: 667 | return 0 668 | else: 669 | return (2 * p * r) / (p + r) 670 | 671 | 672 | if __name__ == "__main__": 673 | device = "cuda" 674 | args = { 675 | "init_lr": 2e-5, 676 | "batch_size": 8, 677 | "mos": 2, 678 | "weight_decay": 0.01, 679 | "warm_up_steps": 1000, 680 | "lr_decay_steps": 15000, 681 | "max_steps": 16000, 682 | "min_lr_rate": 1e-9, 683 | "eval_interval": 1000, 684 | "save_path": "ModelStorage/final.pth", 685 | "load_path": "ModelStorage/xl_dureader_drmc.pth", 686 | "pre_train_dir": 
"/home/ldmc/quanlin/Pretrained_NLP_Models/Pytorch/RoBERTa_Large_ZH/", 687 | "clip_norm": 0.25, 688 | "start_token": "[unused1]", 689 | "end_token": "[unused2]", 690 | "start_token_id": 1, 691 | "end_token_id": 2, 692 | "dimension": 1024, 693 | "max_enc_len": 512, 694 | "max_dec_len": 50, 695 | "max_answer_len": 100, 696 | "use_beam_search": False, 697 | "beam_width": 5, 698 | "beam_length_penalty": 0.6, 699 | "decoder_layers": 3, 700 | "dropout": 0.1, 701 | "vocab_size": 21128, 702 | "init_range": 0.02, 703 | "init_std": 0.02 704 | } 705 | 706 | with open("DataSet/multi_task.pkl", "rb") as f: 707 | x = pickle.load(f) 708 | 709 | tokenizer = BertTokenizer(vocab_file="/home/ldmc/quanlin/Pretrained_NLP_Models/Pytorch/RoBERTa_Large_ZH/vocab.txt") 710 | 711 | if sys.argv[1] == "train": 712 | torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1, init_method='tcp://localhost:7011') 713 | args["is_train"] = True 714 | writer = SummaryWriter(logdir="RunLog/%s" % sys.argv[3]) 715 | train_dataset = MyDataset(data=x["train_items"], max_enc_len=args["max_enc_len"], 716 | max_dec_len=args["max_dec_len"]) 717 | valid_dataset = MyDataset(data=x["valid_items"], max_enc_len=args["max_enc_len"], 718 | max_dec_len=args["max_dec_len"]) 719 | 720 | train_loader = DataLoader(train_dataset, batch_size=args["batch_size"], shuffle=True, num_workers=4) 721 | valid_loader = DataLoader(valid_dataset, batch_size=args["batch_size"], shuffle=False, num_workers=4) 722 | 723 | m = Main(train_loader, valid_loader) 724 | m.train() 725 | else: 726 | writer = None 727 | args["is_train"] = False 728 | args["use_beam_search"] = True 729 | m = Main(None, None, test_flag=True, test_items=x["test_items"]) 730 | m.test() 731 | --------------------------------------------------------------------------------