├── src
│   ├── tool
│   │   ├── __init__.py
│   │   ├── extract_emotion_event.py
│   │   └── graph.py
│   ├── config.py
│   └── main.py
├── scripts
│   ├── train.sh
│   ├── generate.sh
│   └── preprocess.sh
├── requirements.txt
├── LICENSE
└── README.md

/src/tool/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
python src/main.py \
    --train_data data/train_dynamic_persona.json \
    --valid_data data/valid_dynamic_persona.json \
    --config_path gpt2 \
    --epoch 5 \
    --batch_size 4 \
    --accumulate_grad 2 \
    --gpu $1 \
    --save_dir results
--------------------------------------------------------------------------------
/scripts/generate.sh:
--------------------------------------------------------------------------------
python src/main.py \
    --generate \
    --test_data data/test_dynamic_persona.json \
    --output_path results/lightning_logs/version_0/gen.json \
    --ckpt_path results/ckpt/epoch=4-step=29944.ckpt \
    --config_path gpt2 \
    --batch_size 32 \
    --gpu $1
--------------------------------------------------------------------------------
/scripts/preprocess.sh:
--------------------------------------------------------------------------------
cd ../src/preprocess
echo 'nltk download corpus...'
python -c "import nltk; nltk.download('stopwords'); nltk.download('vader_lexicon')"
echo 'aggregate...'
python process.py
echo 'split scene...'
python split.py
echo 'get target sentence...'
python keep_one_card.py
echo 'extract keywords...'
python extract_emotion_event.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
rouge==1.0.0
nltk==3.5
pandas==1.2.1
torch==1.7.1
transformers==4.0.1
scipy==1.4.1
dataclasses==0.6
requests==2.25.1
pytorch_lightning==1.2.1
numpy==1.18.5
bert_score==0.3.7
tqdm==4.56.0
six==1.15.0
munkres==1.1.4
matplotlib==3.3.4
grequests==0.6.0
pyenchant==3.2.0
scikit_learn==0.24.2
spacy==3.0.6
statsmodels==0.12.2
summa==1.2.0
zhon==1.1.5
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
import os
import sys


class Config:

    def __init__(self, args=None, file=None):
        self.py_name, _ = os.path.splitext(file)

        if args is not None:
            for k, v in args.__dict__.items():
                self.__setattr__(k, v)

    def show(self):
        for name, value in vars(self).items():
            print(f"{name}={value}")

    def add_display(self, name):
        if hasattr(self, name):
            return f'_{name}{getattr(self, name)}'
        else:
            return ''

    def get_generate_out_file_name(self):
        res = self.py_name
        names = ['version_num', 'sent']
        for name in names:
            res += self.add_display(name)
        return res


if __name__ == '__main__':
    print(__file__)
    config = Config(file=__file__)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 thu-coai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ConPer

Code and datasets for our paper [Persona-Guided Planning for Persona-Aware Story Generation](https://arxiv.org/pdf/2204.10703.pdf)

## 1. Environment Setup

python 3.7

pip 20.3.3

Install dependencies:
```
pip install -r requirements.txt
```

## 2. Run

### Preparation

#### Download datasets

The preprocessed datasets can be obtained from this [link](https://drive.google.com/drive/u/0/folders/1MrcOc04waE13U-PXx5nkQRQiaB7Si5ON).

You need to put the preprocessed data in `data/`.

#### Download fine-tuned model

The fine-tuned model can be obtained from this [link](https://drive.google.com/drive/u/0/folders/13p0TZocDWLfnUQLcO_78Zo01q57-xjWE).

You need to put the fine-tuned checkpoint in `results/`.

### Train

To train a model, run the following command, where `0` denotes the GPU ID.

```
bash scripts/train.sh 0
```

### Generate

To generate stories, run the following command, where `0` denotes the GPU ID.

```
bash scripts/generate.sh 0
```

Main arguments:

- `--ckpt_path`: path of the fine-tuned checkpoint
- `--output_path`: path of the generation result

--------------------------------------------------------------------------------
/src/tool/extract_emotion_event.py:
--------------------------------------------------------------------------------
import json
import os
import random
import string

import nltk.sentiment.sentiment_analyzer
import numpy as np
import torch
from bert_score import BERTScorer
from nltk import sent_tokenize
from nltk.corpus import stopwords, wordnet
from nltk.sentiment.vader import SentimentIntensityAnalyzer
from nltk.stem import WordNetLemmatizer
from tqdm import tqdm
from multiprocessing import Pool


random.seed(2020)
stop = stopwords.words('english') + list(string.punctuation)  # + ["'s", "'m", "'re", "'ve"]

sid = SentimentIntensityAnalyzer()
stemmer = WordNetLemmatizer()

discard_words = []


def extract_emotion_event_at_least_one(text, limit=False, per_sent=False):
    def get_wordnet_pos(tag):
        if tag.startswith("J"):
            return wordnet.ADJ
        elif tag.startswith("V"):
            return wordnet.VERB
        elif tag.startswith("N"):
            return wordnet.NOUN
        elif tag.startswith("R"):
            return wordnet.ADV
        else:
            return wordnet.NOUN

    def lemmatize(word, nltk_tag):
        tag = get_wordnet_pos(nltk_tag)
        return stemmer.lemmatize(word, tag).lower()

    def sample_sorted(words, num):
        # sample `num` words while preserving their original order
        ids = list(range(len(words)))
        chosen = sorted(random.sample(ids, num))
        res = []
        for i in chosen:
            res.append(words[i])
        return res

    def extract_emotion_event(sent, last_word):
        global sid, discard_words

        words = nltk.word_tokenize(sent)
        tagged = nltk.pos_tag(words)
        res = []
        res.append(last_word)

        for word, tag in tagged:
            if not word[0].isalpha():
                continue
            origin_word = lemmatize(word, tag)

            # emotion keywords: strongly positive or negative words
            ss = sid.polarity_scores(word)
            for key in ss:
                if ss[key] > 0.5 and (key == 'pos' or key == 'neg'):
                    res.append(origin_word)

            # event keywords: verbs and nouns that are not stopwords
            if get_wordnet_pos(tag) in [wordnet.VERB, wordnet.NOUN]:
                if origin_word in stop or word in stop:
                    continue
                if origin_word in discard_words or word in discard_words:
                    continue
                res.append(origin_word)

        res.pop(0)
        from math import ceil
        max_len = min(5, ceil(len(words) * 0.1))
        if limit and max_len < len(res):
            return sample_sorted(res, max_len)
        else:
            return res

    def choose_one_word(sent):
        words = nltk.word_tokenize(sent)
        tagged = nltk.pos_tag(words)
        new_tagged = [(word, tag) for word, tag in tagged if word[0].isalpha()]
        if not new_tagged:
            word, tag = random.choice(tagged)
        else:
            word, tag = random.choice(new_tagged)
        return [lemmatize(word, tag)]

    if isinstance(text, list):
        sents = text
    else:
        sents = nltk.sent_tokenize(text)

    res = []
    choose_one_word_cnt = 0
    for sent in sents:
        last_word = None
        if res:
            last_word = res[-1][-1]
        w = extract_emotion_event(sent, last_word)
        if not per_sent:
            res.extend(w)
        else:
            if w:
                res.append(w)
            else:
                choose_one_word_cnt += 1
                res.append(choose_one_word(sent))

    if per_sent:
        return res, choose_one_word_cnt

    if res:
        return res, 0
    else:
        return choose_one_word(sents[0]), 1
--------------------------------------------------------------------------------
/src/tool/graph.py:
--------------------------------------------------------------------------------
class RelationList:
    def __init__(self):
        self.relation2id = {}
        self.cnt = 0

    def add(self, relation):
        if relation not in self.relation2id:
            self.relation2id[relation] = self.cnt
            self.cnt += 1
        return self.relation2id[relation]

    def get_idx(self, relation):
        return self.relation2id[relation]

    def __len__(self):
        return self.cnt

    def __repr__(self):
        result = ''
        for k, v in self.relation2id.items():
            result += f'{k} : {v}\n'
        return result


class KnowledgeGraph:

    def __init__(self, edges, bidir=False):
        self.data = {}
        self.prev = {}
        self.weights = {}
        self.relations = RelationList()

        self.relations.add('unrelated')  # relation_id 0 corresponds to the connection to NOT_A_FACT

        for item in edges:
            # [head, relation, tail, weight]
            head = item[0]
            relation = item[1]
            tail = item[2]

            if not self.eng_word(head) or not self.eng_word(tail):
                continue
            assert '/' not in head and '/' not in tail

            relation_id = self.relations.add(relation)
            self.add(head, relation_id, tail)
            if bidir:
                self.add(tail, relation_id, head)
            self.weights[self.get_name(item[0], item[-2])] = float(item[-1])
            if bidir:
                self.weights[self.get_name(item[-2], item[0])] = float(item[-1])

        print(f"relation nums:{len(self.relations)}")
        print(self.relations)

    def get_relation_size(self):
        # return the total number of relations
        return len(self.relations)

    def get_relation_list(self):
        return self.relations

    # def get_relation_idx(self, relation):
    #     relations are already converted to ids when added to the graph
    #     return self.relations.get_idx(relation)

    def filter_points(self, points):
        res = []
        for pt in points:
            if pt in self.data:
                res.append(pt)
        return res

    def check(self, point):
        return point in self.data

    def get_name(self, src, dst):
        return src + "___" + dst

    def get_weight(self, src, dst):
        name = self.get_name(src, dst)
        if name in self.weights:
            return self.weights[name]
        return None

    def eng_word(self, word):
        if '_' in word:
            return False
        return True

    def get_avg_deg(self):
        r = 0
        for src in self.data:
            r += len(self.data[src])

        return r / len(self.data)

    def show_degs(self):
        data = list(self.data.items())
        print(data[-3:])
        data.sort(key=lambda x: len(x[1]))
        for k, v in data:
            print(f'{k}:{len(v)}')

    def get_node_num(self):
        return len(self.data)

    def add(self, src, relation, dst):
        w = (dst, relation)
        if src in self.data:
            if w not in self.data[src]:
                self.data[src].append(w)
        else:
            self.data[src] = [w]

        q = (src, relation)
        if dst in self.prev:
            if q not in self.prev[dst]:
                self.prev[dst].append(q)
        else:
            self.prev[dst] = [q]

    def get_neighbors(self, pt, relation=False):
        if pt not in self.data:
            return []
        else:
            if relation:
                return self.data[pt]
            else:
                return [i[0] for i in self.data[pt]]

    def get_triples(self, word):

        res = []
        if word in self.data:
            for dst, r in self.data[word]:
                res.append((word, r, dst))

        if word in self.prev:
            for src, r in self.prev[word]:
                res.append((src, r, word))

        if not res:
            res.append((word, 0, 'NOT_A_FACT'))

        return res

    def get_hops_set(self, srcs, hop, relation=False):
        res = set(srcs)
        step = 0
        temp = set(srcs)
        while step < hop:
            step += 1
            new_temp = []
            for pt in temp:
                ns = self.get_neighbors(pt, relation=relation)
                for n in ns:
                    if n not in res:
                        new_temp.append(n)
            new_temp = set(new_temp)
            temp = new_temp
            res = res | new_temp
        return res

    def get_intersect(self, srcs, dsts, hop=2):
        src_neis = self.get_hops_set(srcs, hop)
        dst_neis = self.get_hops_set(dsts, hop)
        return src_neis & dst_neis

    def find_neigh_in_set(self, src, points):
        res = []
        if src not in self.data:
            return res
        # self.data[src] stores (dst, relation) tuples, so compare against dst
        for dst, _ in self.data[src]:
            if dst in points:
                res.append(dst)
        return set(res)

    def find_paths(self, srcs, dsts):
        a = self.get_hops_set(srcs, 1)
        res = []
        for w in a:
            x = self.find_neigh_in_set(w, srcs)
            y = self.find_neigh_in_set(w, dsts)
            if x and y:
                res.append([x, w, y])
        return res

    def show_paths(self, srcs, dsts):
        paths = self.find_paths(srcs, dsts)
        for path in paths:
            print(path)

    def get_dis(self, dst, srcs, max_hop=3):
        vis = set()
        points = [dst]
        vis.add(dst)
        step = 0
        if dst in srcs:
            return step
        while step < max_hop:
            step += 1
            temp_points = []
            for pt in points:
                ns = self.get_neighbors(pt)
                for n in ns:
                    if n in srcs:
                        return step
                    if n in vis:
                        continue
                    vis.add(n)
                    temp_points.append(n)
            points = temp_points
        return step


def get_conceptnet(path):

    with open(path, encoding='utf-8') as f:
        lines = f.readlines()
        print(len(lines))
        edges = []
        for line in lines:
            edge = line.strip().split('|||')
            edges.append(edge)

    return KnowledgeGraph(edges)


if __name__ == '__main__':

    graph = get_conceptnet('data/conceptnet_cleaned_final.txt')  # default path used by main.py
    print(f"node num:{graph.get_node_num()}, avg deg:{graph.get_avg_deg()}")

    print(graph.get_hops_set(['people'], hop=1, relation=False))
    print('=' * 100)
    print(graph.get_hops_set(['people'], hop=1, relation=True))
    # graph.show_degs()
    # print(graph.get_hops_set(['people'], 1))
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
import json
import argparse
import os
import nltk
from nltk import word_tokenize, sent_tokenize
from itertools import chain
import numpy as np
from tqdm import tqdm
import copy
from time import sleep

from bert_score import BERTScorer

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam

from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, GPT2PreTrainedModel, GPT2Model, AutoConfig, LogitsProcessorList
from transformers.modeling_outputs import CausalLMOutputWithPastAndCrossAttentions
from typing import Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import seed_everything

from tool.graph import get_conceptnet, KnowledgeGraph
from tool.extract_emotion_event import extract_emotion_event_at_least_one

torch.multiprocessing.set_sharing_strategy('file_system')


def get_idf_sents():
    def get_idf_docs():
        if os.path.exists('train_doc_idf.json'):
            with open('train_doc_idf.json', encoding='utf-8') as f:
                return json.load(f)
        file_name = 'train.json'
        with open(file_name, encoding='utf-8') as f:
            a = json.load(f)
        hyps = []
        refs = []
        for scene in a:
            for entry in scene['entries']:
                hyps.append(entry['description'])
            for card in scene['entries'][-1]['cards']:
                refs.append(card['description'])
        with open('train_doc_idf.json', 'w', encoding='utf-8') as fi:
            json.dump(hyps + refs, fi, ensure_ascii=False)
        print('finish get_idf_sent')

        return hyps + refs
    return get_idf_docs()


class Helper:
    def __init__(self, args):
        self.tokenizer = GPT2Tokenizer.from_pretrained(args.config_path)
        self.eoc_token = '<|endofcard|>'
        self.bot_token = '<|beginoftarget|>'
        self.eot_token = '<|endoftarget|>'
        self.boo_token = '<|beginofoutline|>'
        self.bob_token = '<|beginofbedding|>'
        self.boe_token = '<|beginofending|>'
        self.soo_token = '<|sepofoutline|>'
        self.soos_token = '<|sepofoutlinesent|>'
        self.eop_token = '<|endofprompt|>'
        self.son_token = '<|sepofname|>'
        self.not_a_fact = 'NOT_A_FACT'
        self.tokenizer.add_tokens(
            [self.eop_token, self.eoc_token, self.bot_token, self.eot_token, self.bob_token, self.boe_token,
             self.boo_token, self.soo_token, self.soos_token])
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    def get_token_id(self, token):
        return self.tokenizer.convert_tokens_to_ids(token)

    def pad_seq(self, ids, max_len):
        return self.tokenizer.pad(ids, max_len)

    def call_tokenizer(self, text):
        return self.tokenizer(text)

    def get_vocab_size(self):
        return len(self.tokenizer)


def get_nodes_dis(words, device=None, in_graph=True):
    if not in_graph:
        if device is None:
            return torch.zeros(helper.get_vocab_size(), dtype=torch.float)
        else:
            return torch.zeros(helper.get_vocab_size(), dtype=torch.float, device=device)

    seq = helper.soo_token.join(words)
    ids = helper.tokenizer.encode(seq)
    if device is None:
        res = torch.zeros(helper.get_vocab_size(), dtype=torch.float)
    else:
        res = torch.zeros(helper.get_vocab_size(), dtype=torch.float, device=device)
    res[ids] = 1
    return (1 - res) * (-1e10)


class StoriumDataset(Dataset):

    def __init__(self, in_file):
        super().__init__()
        self.in_file = in_file

        with open(self.in_file, encoding='utf-8') as f:
            self.data = json.load(f)

        self.vocab_size = len(helper.tokenizer)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        item = self.data[idx]

        input = ''

        # persona
        for card in item['entries'][-1]['cards']:
            input += card['description']
            input += helper.eoc_token

        # context
        entry_des = []
        for entry in item['entries'][:-1]:
            entry_des.append(entry['description'])
        input += ' '.join(entry_des)
        input += helper.eop_token

        output = ''

        # target sentence
        sents = nltk.sent_tokenize(item['entries'][-1]['description'])
        peak_idx = item['peak_idx']
        output += helper.bot_token + sents[peak_idx] + helper.eot_token

        # outline keywords
        keywords_list = item['bedding_kws'] + item['ending_kws']
        old_keywords_list = item['bedding_kws'] + item['ending_kws']
        context_keywords_list = item['context_kws'] + list(chain(*item['target_kws']))
        for i in range(len(keywords_list)):
            keywords_list[i] = helper.soo_token.join(keywords_list[i])
        output += helper.soos_token.join(keywords_list)

        output += helper.bob_token
        output += ' '.join(sents[:peak_idx])
        output += helper.boe_token
        output += ' '.join(sents[peak_idx + 1:])
        output += helper.tokenizer.eos_token
        input_ts = torch.tensor(helper.tokenizer.encode(input, add_special_tokens=False), dtype=torch.long)
        output_ts = torch.tensor(helper.tokenizer.encode(output, add_special_tokens=False), dtype=torch.long)

        nodes_dis = []
        outline_mask = []
        keywords_list = list(chain(*old_keywords_list))
        for i in range(len(keywords_list)):
            words = graph.get_hops_set(keywords_list[:i] + context_keywords_list, hop=1)
            if keywords_list[i] in words:
                outline_mask.append(1)
                nodes_dis.append(get_nodes_dis(words))
            else:
                outline_mask.append(0)
                nodes_dis.append(get_nodes_dis(words))

        outline_mask_piece = []
        node_dis_piece = []
        ori_id = 0
        cat_tensor = torch.cat([input_ts, output_ts], dim=0)
        for i, value in enumerate(cat_tensor):
            if value == helper.get_token_id(helper.eot_token):
                for j, value in enumerate(cat_tensor[i + 1:]):
                    if value == helper.get_token_id(helper.soo_token) or value == helper.get_token_id(
                            helper.soos_token):
                        outline_mask_piece.append(0)  # sepofoutline tokens count as not in the graph, so set 0
                        node_dis_piece.append(nodes_dis[ori_id])
                        ori_id += 1
                        continue
                    elif value == helper.get_token_id(helper.bob_token):
                        end = j
                        outline_mask_piece.append(0)
                        if len(keywords_list):
                            node_dis_piece.append(nodes_dis[ori_id])
                        else:
                            node_dis_piece.append(get_nodes_dis(words=None, in_graph=False))
                        break
                    else:
                        outline_mask_piece.append(outline_mask[ori_id])
                        node_dis_piece.append(nodes_dis[ori_id])
                break

        node_dis_ts = torch.stack(node_dis_piece, dim=0)
        outline_mask_ts = torch.tensor(outline_mask_piece, dtype=torch.long)

        start_idx = (cat_tensor == helper.get_token_id(helper.eot_token)).nonzero(as_tuple=False).item()
        end_idx = (cat_tensor == helper.get_token_id(helper.bob_token)).nonzero(as_tuple=False).item()
        ids = cat_tensor[start_idx + 1: end_idx + 1].tolist()

        context_kws = item['context_kws']
        target_kws = list(chain(*item['target_kws']))
        outline_kws = list(chain(*old_keywords_list))
        return {'input': input_ts, 'output': output_ts, 'peak_idx': peak_idx,
                'nodes_dis': node_dis_ts,
                'context_kws': context_kws, 'target_kws': target_kws, 'outline_kws': outline_kws,
                'outline_mask': outline_mask_ts}


def pad_collate(batch):
    def get_attention_mask(max_len, one_len):
        a = np.ones(max_len)
        a[one_len:] = 0
        return torch.tensor(a, dtype=torch.float)

    def loss_mask(ts, start, end):
        ts[:start] = -100
        ts[end:] = -100
        return ts

    res = {}
    res['input_ids'] = pad_sequence([torch.cat([x['input'], x['output']])[:1024] for x in batch], batch_first=True,
                                    padding_value=helper.tokenizer.pad_token_id)
    res['attention_mask'] = torch.stack(
        [get_attention_mask(res['input_ids'].size(1), len(x['input']) + len(x['output'])) for x in batch])
    res['labels'] = torch.stack(
        [loss_mask(copy.deepcopy(res['input_ids'][idx]), len(x['input']), len(x['input']) + len(x['output'])) for idx, x
         in enumerate(batch)])
    res['nodes_dis'] = [sample['nodes_dis'] for sample in batch]

    res['peak_idx'] = [sample['peak_idx'] for sample in batch]
    res['context_kws'] = [sample['context_kws'] for sample in batch]
    res['target_kws'] = [sample['target_kws'] for sample in batch]
    res['outline_kws'] = [sample['outline_kws'] for sample in batch]
    res['outline_mask'] = [sample['outline_mask'] for sample in batch]
    return res


class WordTokenizer:
    def __init__(self, file_path=''):
        with open(file_path, encoding='utf-8') as f:
            self.word2ids = json.load(f)

    def encode(self, word):
        try:
            result = self.word2ids[word]
        except KeyError:
            result = helper.tokenizer.encode(word, add_special_tokens=False)
            self.word2ids[word] = result
        return result


class Gpt2OutLineModel(GPT2LMHeadModel):

    @property
    def wte(self):
        return self.transformer.wte

    def __init__(self, config):
        super().__init__(config)
        self.outline_classify_head = nn.Linear(3 * config.n_embd, 2, bias=False)
        self.outline_wquery = nn.Linear(config.n_embd, config.n_embd, bias=True)
        self.outline_wvalue = nn.Linear(config.n_embd, config.n_embd, bias=True)
        self.word_wkey = nn.Linear(config.n_embd, config.n_embd, bias=True)
        self.word_wvalue = nn.Linear(config.n_embd, config.n_embd, bias=True)
        self.outline_lm_head = nn.Linear(3 * config.n_embd, helper.get_vocab_size(), bias=False)
        self.relation_num = graph.get_relation_size()  # total number of relations
        print('relation num = ', self.relation_num)
        self.relation_tensor = torch.nn.Embedding(self.relation_num, config.n_embd)
        self.Wh = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.Wt = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.Wr = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.Wk = nn.Linear(2 * config.n_embd, config.n_embd, bias=False)
        self.dis_matrix = nn.Linear(3 * config.n_embd, config.n_embd, bias=False)
        print('Gpt2OutLineModel init (with config)!')

    def get_concept_embedding(self, word):
        return wordTokenizer.encode(word)

    def get_graph_vectors_words(self, words, device, generate=False):
        # return: [words_len, hidden_size]
        relation_embs = []
        embedding_matrix = self.get_input_embeddings().weight

        lens = []

        encode_lens = []

        relation_ids = []
        entity_ids = []
        for word in words:
            triples = graph.get_triples(word)
            for h, r, t in triples:
                relation_ids.append(r)
                cur_head_id = self.get_concept_embedding(h)
                cur_tail_id = self.get_concept_embedding(t)
                entity_ids.append(cur_head_id)
                entity_ids.append(cur_tail_id)
                encode_lens.append(len(cur_head_id))
                encode_lens.append(len(cur_tail_id))
            lens.append(len(triples))

        max_len = max(encode_lens)
        padded_entity_ids = [x + [0] * (max_len - len(x)) for x in entity_ids]
        pad_entity_ids = torch.tensor(padded_entity_ids, device=device)
        embeds = embedding_matrix[pad_entity_ids]  # [triple_lens, max_word_piece_len, hidden_size]

        # [triple_lens, max_len]
        mask_emb = np.zeros((len(encode_lens), max_len))
        for data, l in zip(mask_emb, encode_lens):
            data[: l] = 1
        mask_emb = torch.tensor(mask_emb, device=device, dtype=torch.float).unsqueeze(-1)

        temp = embeds * mask_emb
        real_emb = torch.sum(temp, dim=1)  # [triple_len, hidden_size]
        real_emb = real_emb.reshape((-1, 2, 768))  # [real_triple_len, 2, hidden_size]
        head_embs = real_emb[:, 0, :]
        tail_embs = real_emb[:, 1, :]
        concat_emb = torch.cat([head_embs, tail_embs], dim=1)  # [real_triple_len, hidden_size * 2]

        relation_embs = self.relation_tensor(torch.tensor(relation_ids, device=device))

        x = self.Wr(relation_embs)
        y = torch.tanh(self.Wh(head_embs) + self.Wt(tail_embs))
        betas = torch.sum(x * y, dim=-1)
        start = 0
        ans = []

        for i, l in enumerate(lens):
            end = start + l
            b = betas[start:end]
            alphas = F.softmax(b, dim=0)
            concat_ts = concat_emb[start:end]
            alphas = alphas.unsqueeze(dim=-1)
            result = alphas * concat_ts
            ans.append(torch.sum(result, dim=0))
            start = end

        if generate:
            return ans
        else:
            return torch.stack(ans, dim=0)

    def compute_all_graph_vectors(self, context_kws, target_kws, outline_kws, device):
        res = []
        for c, t, o in zip(context_kws, target_kws, outline_kws):
            cv = self.get_graph_vectors_words(c, device) if c else None
            tv = self.get_graph_vectors_words(t, device) if t else None
            ov = self.get_graph_vectors_words(o, device) if o else None
            res.append((cv, tv, ov))
        return res

    def get_context_graph_vector(self, hidden_state, words=None, gvs=None, mask=None):
        # hidden_state : [seq_len, hidden_size]
        # graph_vectors : [kws_len, 2 * hidden_size]
        # mask: [seq_len, kws_len]
        # return : [seq_len, hidden_size * 2]
        betas = []
        graph_vectors = []

        assert gvs is not None
        graph_vectors = gvs

        betas = torch.matmul(hidden_state, self.Wk(graph_vectors).T)
        if mask is not None:
            betas.masked_fill_(mask, -1e10)

        alphas = torch.softmax(betas, dim=-1)
        return torch.matmul(alphas, graph_vectors)

    def get_logits(self, hidden_states, input_ids, logits_mask=None, logits_mask_on=None, generate=False,
                   outline_label=None, context_kws=None, target_kws=None, outline_kws=None):
        if not generate:
            start_idxs = [(batch == helper.get_token_id(helper.eot_token)).nonzero(as_tuple=False).item() for batch in input_ids]
            end_idxs = [(batch == helper.get_token_id(helper.bob_token)).nonzero(as_tuple=False).item() for batch in input_ids]
            a = []
            for idx, batch in enumerate(hidden_states):
                before = batch[:start_idxs[idx]]
                after = batch[end_idxs[idx]:]
                x = batch[start_idxs[idx]: end_idxs[idx]]
                x2 = self.get_hidden_combine_kg(x, input_ids[idx][start_idxs[idx]: end_idxs[idx]],
                                                context_kws=context_kws[idx], target_kws=target_kws[idx],
                                                outline_kws=outline_kws[idx], batch_idx=idx)
                x = torch.cat([x, x2], dim=-1)
                x = self.outline_lm_head(x)
                w = logits_mask[idx] * (outline_label[idx].unsqueeze(dim=-1))
                x += w
                before = self.lm_head(before)
                after = self.lm_head(after)
                ts = torch.cat([before, x, after], dim=0)
                a.append(ts)
            return torch.stack(a, dim=0)
        else:
            res = []
            lm_logits = self.lm_head(hidden_states[:, -1, :])
            for idx, batch in enumerate(hidden_states):
                if not logits_mask_on[idx]:
                    res.append(lm_logits[idx])
                    continue
                hid = batch[-1].unsqueeze(0)
                hid2 = self.get_context_graph_vector(hid, gvs=torch.stack(self.generated_graph_vectors[idx], dim=0))
                self.hidden_combine_kg_res[idx] = hid2
                combine_hid = torch.cat([hid, hid2], dim=-1).squeeze(0)
                logit = self.outline_classify_head(combine_hid)

                w = self.outline_lm_head(combine_hid)

                if logit[1] > logit[0]:
                    w += logits_mask[idx]

                res.append(w)
            return torch.stack(res, dim=0).unsqueeze(1)

    def get_hidden_combine_kg(self, hidden_states, input_ids, context_kws, target_kws, outline_kws, batch_idx):
        if batch_idx in self.hidden_combine_kg_res:
            return self.hidden_combine_kg_res[batch_idx]

        gvs = torch.cat(
            [self.graph_vectors[batch_idx][i] for i in range(3) if self.graph_vectors[batch_idx][i] is not None], dim=0)
        col_num = gvs.shape[0]
        cnt = 0
        for i in range(2):
            if self.graph_vectors[batch_idx][i] is not None:
                cnt += self.graph_vectors[batch_idx][i].size(0)
        mask_lens = []

        for idx, (hidden_state, input_id) in enumerate(zip(hidden_states, input_ids)):
            if input_id in [helper.get_token_id(helper.soo_token), helper.get_token_id(helper.soos_token)]:
                cnt += 1
            mask_lens.append(cnt)

        masks = F.one_hot(torch.tensor(mask_lens, device=hidden_states.device, dtype=torch.long), col_num + 1)
        masks = torch.cumsum(masks, dim=-1)[:, :-1].bool()
        self.hidden_combine_kg_res[batch_idx] = self.get_context_graph_vector(hidden_states, gvs=gvs, mask=masks)

        return self.hidden_combine_kg_res[batch_idx]

    def get_outline_classify_loss(self, hidden_states, input_ids, outline_label, context_kws, target_kws, outline_kws):
        start_idxs = [(batch == helper.get_token_id(helper.eot_token)).nonzero(as_tuple=False).item() for batch in input_ids]
        end_idxs = [(batch == helper.get_token_id(helper.bob_token)).nonzero(as_tuple=False).item() for batch in input_ids]

        hid_ts = None
        label_ts = None
        hid2_ts = None
        new_hidden_states = hidden_states

        for idx, batch in enumerate(hidden_states):
            hid = batch[start_idxs[idx]: end_idxs[idx]]

            if hid_ts is None:
                hid_ts = new_hidden_states[idx][start_idxs[idx]: end_idxs[idx]]
                hid2_ts = self.get_hidden_combine_kg(hid, input_ids[idx][start_idxs[idx]: end_idxs[idx]],
                                                     context_kws=context_kws[idx], target_kws=target_kws[idx],
                                                     outline_kws=outline_kws[idx], batch_idx=idx)
            else:
                hid_ts = torch.cat([hid_ts, new_hidden_states[idx][start_idxs[idx]: end_idxs[idx]]], dim=0)
                hid2_ts = torch.cat([hid2_ts,
                                     self.get_hidden_combine_kg(hid, input_ids[idx][start_idxs[idx]: end_idxs[idx]],
                                                                context_kws=context_kws[idx],
                                                                target_kws=target_kws[idx],
                                                                outline_kws=outline_kws[idx], batch_idx=idx)], dim=0)

            if label_ts is None:
                label_ts = outline_label[idx]
            else:
                label_ts = torch.cat([label_ts, outline_label[idx]])

        new_hid_ts = torch.cat([hid_ts, hid2_ts], dim=-1)
        new_hid_ts = self.outline_classify_head(new_hid_ts)
        loss_fct = nn.CrossEntropyLoss()
        return loss_fct(new_hid_ts, label_ts), new_hid_ts

    def combine_relative_dis(self, hidden_states, input_ids, lm_logits, generate, logits_mask_on=None):

        if not generate:
            start_idxs = [(batch == helper.get_token_id(helper.eot_token)).nonzero(as_tuple=False).item() for batch in input_ids]
            end_idxs = [(batch == helper.get_token_id(helper.bob_token)).nonzero(as_tuple=False).item() for batch in input_ids]

            target_start_idxs = [(batch == helper.get_token_id(helper.bot_token)).nonzero(as_tuple=False).item() for batch in
                                 input_ids]
            target_end_idxs = start_idxs
        word_embeddings = self.get_input_embeddings().weight

        res = []

        for idx, batch in enumerate(hidden_states):
            if generate and not logits_mask_on[idx]:
                res.append(lm_logits[idx])
                continue
            if not generate:
                target_ts = torch.mean(batch[target_start_idxs[idx]: target_end_idxs[idx]], dim=0)
            else:
                target_ts = torch.mean(torch.stack(self.target_hidden_states[idx], dim=0), dim=0)
            hid_combine_ts = self.hidden_combine_kg_res[idx]
            # [seq_len, hidden_size * 3]
            concat_ts = torch.cat([hid_combine_ts, target_ts.unsqueeze(0).repeat(hid_combine_ts.size(0), 1)], dim=-1)
            d = torch.matmul(self.dis_matrix(concat_ts), word_embeddings.T)
            if not generate:
                before = lm_logits[idx][:start_idxs[idx]]
                mid = lm_logits[idx][start_idxs[idx]:end_idxs[idx]]
                after = lm_logits[idx][end_idxs[idx]:]
                # mid_combine = (d + mid) / 2
                mid_combine = d + mid
                res.append(torch.cat([before, mid_combine, after], dim=0))
            else:
                res.append(lm_logits[idx][-1].unsqueeze(0) + d)

        return torch.stack(res, dim=0)

    def forward(
            self,
            input_ids=None,
            past_key_values=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            labels=None,
            use_cache=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            context_kws=None,
            target_kws=None,
            outline_kws=None,

            logits_mask=None,
            logits_mask_on=None,
            generate=False,
            outline_label=None,
            train_step=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        if generate:
            for idx, hidden_state in enumerate(hidden_states):
                if self.target_flag[idx]:
                    self.target_hidden_states[idx].append(hidden_state[-1, :])

        if not generate:
            self.graph_vectors = self.compute_all_graph_vectors(context_kws=context_kws, target_kws=target_kws,
                                                                outline_kws=outline_kws,
                                                                device=hidden_states.device)
            self.hidden_combine_kg_res = {}
        lm_logits = self.get_logits(hidden_states, input_ids, logits_mask,
                                    logits_mask_on, generate, outline_label, context_kws=context_kws,
                                    target_kws=target_kws, outline_kws=outline_kws)

        lm_logits = self.combine_relative_dis(hidden_states, input_ids, lm_logits, generate, logits_mask_on)
        loss = None
        loss2 = None
        cls_logits = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss2, cls_logits = self.get_outline_classify_loss(hidden_states, input_ids, outline_label,
                                                               context_kws=context_kws, target_kws=target_kws,
                                                               outline_kws=outline_kws)
            self.hidden_combine_kg_res = {}
            loss += loss2

        if not return_dict:
            output = (lm_logits,) + (cls_logits,) + transformer_outputs[1:]
            return ((loss, loss2) + output) if loss is not None else output

        return CausalLMOutputWithPastAndCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    def init_state(self, input_ids):
        batch_size = input_ids.size(0)
        self.logits_mask_on = [False] * batch_size
        self.logits_mask = torch.zeros([batch_size, helper.get_vocab_size()], dtype=torch.float,
                                       device=input_ids.device)
        self.target_flag = [False] * batch_size
        self.target_hidden_states = [[] for _ in range(batch_size)]
        self.computed = [False] * batch_size
        self.last_spec_pos = [-1] * batch_size
        self.generated_outline = [[] for _ in range(batch_size)]
        self.generated_graph_vectors = [[] for _ in range(batch_size)]
        self.generated_target_kws = [[] for _ in range(batch_size)]

    def get_target_kws(self, input_ids):
        # input_ids: [seq_len]
        start_idx = (input_ids == helper.get_token_id(helper.bot_token)).nonzero(as_tuple=False)[0][0].item()
        end_idx = (input_ids == helper.get_token_id(helper.eot_token)).nonzero(as_tuple=False)[0][0].item()
        ids = input_ids[start_idx + 1: end_idx].tolist()
        sent = helper.tokenizer.decode(ids)
        words, _ = extract_emotion_event_at_least_one(sent, per_sent=True)
        assert isinstance(words[0], list)
        words = list(chain(*words))
        return words

    def update_state(self, input_ids, context_kws, device):
        # when <|endoftarget|> appears, compute intersect nodes and turn on logits_mask
        # when <|beginofbedding|> appears, turn off logits_mask
        for idx, batch in enumerate(input_ids):
            if batch[-1].item() == helper.get_token_id(helper.bot_token):
                self.target_flag[idx] = True

            if batch[-1].item() == helper.get_token_id(helper.eot_token):
                self.last_spec_pos[idx] = len(batch) - 1
                self.logits_mask_on[idx] = True
                self.target_flag[idx] = False

                self.generated_target_kws[idx] = self.get_target_kws(batch)
                self.generated_graph_vectors[idx] = self.get_graph_vectors_words(
                    context_kws[idx] + self.generated_target_kws[idx], device=device, generate=True)

            if self.logits_mask_on[idx] and (batch[-1].item() == helper.get_token_id(helper.soo_token) or batch[
                    -1].item() == helper.get_token_id(helper.soos_token)) and not self.computed[idx]:
                # self.logits_mask_on[idx] = True
                word_pieces = batch[self.last_spec_pos[idx] + 1:-1]
                word = helper.tokenizer.decode(word_pieces.tolist())
                self.generated_outline[idx].append(word)

                self.generated_graph_vectors[idx].append(
                    self.get_graph_vectors_words([word], device=device, generate=True)[0])
                assert self.generated_target_kws[idx]
                target_kws = self.generated_target_kws[idx]
                words = set(self.generated_outline[idx])
                words |= set(target_kws + context_kws[idx])
                words = graph.get_hops_set(words, hop=1)  # use the 1-hop set of context + target + already generated outline
                self.logits_mask[idx] = get_nodes_dis(words, device=input_ids.device)
                self.last_spec_pos[idx] = len(batch) - 1

            elif batch[-1].item() == helper.get_token_id(helper.bob_token):
                self.logits_mask_on[idx] = False
                self.computed[idx] = True

    def sample(
            self,
            input_ids: torch.LongTensor,
            logits_processor: Optional[LogitsProcessorList] = None,
            logits_warper: Optional[LogitsProcessorList] = None,
            max_length: Optional[int] = None,
            pad_token_id: Optional[int] = None,
            eos_token_id: Optional[int] = None,
            **model_kwargs
    ):
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        # init sequence length tensors
        sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
            input_ids, max_length
        )

        # auto-regressive generation
        context_kws = model_kwargs['context_kws']
        self.init_state(input_ids)

        while cur_len < max_length:
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            self.update_state(input_ids, context_kws, f"cuda:{args.gpu}")
            # forward pass to get next token
            outputs = self(**model_inputs, return_dict=True, generate=True, logits_mask=self.logits_mask,
                           logits_mask_on=self.logits_mask_on)
            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            scores = logits_processor(input_ids, next_token_logits)
            scores = logits_warper(input_ids, scores)

            # sample
            probs = F.softmax(scores, dim=-1)

            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # add code that transforms next_tokens to tokens_to_add
            if eos_token_id is not None:
                assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
                next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * (1 - unfinished_sequences)

            # add token and increase length by one
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

            cur_len = cur_len + 1

            # update sequence length
            if eos_token_id is not None:
                sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
                )

            # stop when there is an EOS token in each sentence, or if we exceed the maximum length
            if unfinished_sequences.max() == 0:
                break

            # update model kwargs
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

        return input_ids


class Gpt2(pl.LightningModule):
    def __init__(self, config_dir):
        super().__init__()
        self.model = Gpt2OutLineModel.from_pretrained(config_dir)
        self.model.resize_token_embeddings(len(helper.tokenizer))
        print('vocab size = ', len(helper.tokenizer))

    def get_inputs_embeds(self, input_ids, segment_ids=None):
        if segment_ids:
            return self.model.wte(input_ids) + self.model.wte(segment_ids)
        else:
            return self.model.wte(input_ids)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        logits_mask = batch['nodes_dis']
        outline_mask = batch['outline_mask']

        context_kws = batch['context_kws']
        target_kws = batch['target_kws']
        outline_kws = batch['outline_kws']

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels,
                             logits_mask=logits_mask, outline_label=outline_mask, context_kws=context_kws,
                             target_kws=target_kws, outline_kws=outline_kws, return_dict=False, train_step=batch_idx)

        self.log('train_loss', outputs[0].item())
        self.log('train_classify_loss', outputs[1].item())
        return outputs[0]

    def validation_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        logits_mask = batch['nodes_dis']
        outline_mask = batch['outline_mask']

        context_kws = batch['context_kws']
        target_kws = batch['target_kws']
        outline_kws = batch['outline_kws']

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels,
                             logits_mask=logits_mask, outline_label=outline_mask, context_kws=context_kws,
                             target_kws=target_kws, outline_kws=outline_kws, return_dict=False)

        self.log('val_loss', outputs[0].item())
        self.log('val_classify_loss', outputs[1].item())
        return outputs[0]

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-5)
        return optimizer


def train(args):
    valid_dataset = StoriumDataset(args.valid_data)
    valid_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
                                  num_workers=args.num_workers, collate_fn=pad_collate)

    train_dataset = StoriumDataset(args.train_data)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.num_workers, collate_fn=pad_collate)

    print('after load data')

    model = Gpt2(args.config_path)

    checkpoint_callback = ModelCheckpoint(monitor='val_loss', save_top_k=1, mode='min', verbose=True)
    earlystop_callback = EarlyStopping(monitor='val_loss', verbose=True, mode='min')
    trainer = pl.Trainer(gpus=[args.gpu], max_epochs=args.epoch, val_check_interval=args.eval_interval,
                         callbacks=[checkpoint_callback, earlystop_callback],
                         default_root_dir=args.save_dir,
                         accumulate_grad_batches=args.accumulate_grad)

    trainer.fit(model=model, train_dataloader=train_dataloader, val_dataloaders=valid_dataloader)


def generate(args):
    model = Gpt2(args.config_path)
    ckpt = torch.load(args.ckpt_path, map_location="cuda:{}".format(args.gpu))
    model.load_state_dict(ckpt['state_dict'])
    device = torch.device("cuda:{}".format(args.gpu))
    model.to(device)
    model.eval()

    test_dataset = StoriumDataset(args.test_data)

    data = test_dataset.data
    res = []

    print(f"out file:{args.output_path}")
    ending = len(test_dataset)
    for idx in tqdm(range(0, ending, args.batch_size)):
        end = min(ending, idx + args.batch_size)
        batch = []
        for j in range(idx, end):
            batch.append(test_dataset[j])

        input_texts = [helper.tokenizer.decode(sample['input']) for sample in batch]

        helper.tokenizer.padding_side = "left"
        inputs = helper.tokenizer(input_texts, return_tensors="pt", padding=True)
        context_kws = [sample['context_kws'] for sample in batch]
        output_seqs = model.model.generate(input_ids=inputs['input_ids'].to(device), attention_mask=inputs['attention_mask'].to(device),
                                           context_kws=context_kws, max_length=1024, top_p=0.9, temperature=args.temperature,
                                           do_sample=True, no_repeat_ngram_size=args.norepeatngram, use_cache=True)
        cards_text = []
        for j in range(idx, end):
            scene = data[j]
            card_text = ''
            for entry in scene['entries'][-1:]:
                for card in entry['cards']:
                    card_text += card['description']
            cards_text.append(card_text)
        answer = [helper.tokenizer.decode(sample['output'].tolist(), skip_special_tokens=True) for sample in batch]
        prompt = [helper.tokenizer.decode(sample['input'].tolist(), skip_special_tokens=True) for sample in batch]
        output_text = [helper.tokenizer.decode(sample.tolist(), skip_special_tokens=True).replace(prompt[i], '', 1) for i, sample in enumerate(output_seqs)]

        for j in range(0, end - idx):
            res.append(
                {'prompt': prompt[j], 'generated': output_text[j], 'answer': answer[j], 'cards': cards_text[j]})

    with open(args.output_path, 'w', encoding='utf-8') as f:
        json.dump(res, f, indent=1, ensure_ascii=False)
    print(f"finish generate to {args.output_path}")


def parse_args():

    parser = argparse.ArgumentParser()

    # data args
    parser.add_argument('--train_data', type=str, default=None)
    parser.add_argument("--valid_data", type=str, default=None)
    parser.add_argument("--test_data", type=str, default=None)
    parser.add_argument("--conceptnet_path", type=str, default='data/conceptnet_cleaned_final.txt')
    parser.add_argument("--outlinevocab_path", type=str, default='data/outline_ids.json')

    # config / checkpoint
    parser.add_argument("--config_path", type=str, default="gpt2")
parser.add_argument("--ckpt_path", type=str, default=None) 879 | 880 | # training args 881 | parser.add_argument('--gpu', type=int, default=0) 882 | parser.add_argument("--num_workers", type=int, default=4) 883 | parser.add_argument("--epoch", type=int, default=1) 884 | parser.add_argument("--eval_interval", type=float, default=0.5) 885 | parser.add_argument('--batch_size', type=int, default=16) 886 | parser.add_argument("--accumulate_grad", type=int, default=1) 887 | parser.add_argument("--save_dir", type=str, default='logs') 888 | parser.add_argument("--seed", type=int, default=42) 889 | 890 | # generate args 891 | parser.add_argument('--generate', action='store_true') 892 | parser.add_argument('--norepeatngram', type=int, default=0) 893 | parser.add_argument('--temperature', type=float, default=1.0) 894 | parser.add_argument("--output_path", type=str, default=None) 895 | 896 | args = parser.parse_args() 897 | 898 | return args 899 | 900 | if __name__ == '__main__': 901 | args = parse_args() 902 | from config import Config 903 | configs = Config(args=args, file=__file__) 904 | configs.show() 905 | args = configs 906 | 907 | pl.seed_everything(args.seed) 908 | 909 | graph = get_conceptnet(args.conceptnet_path) 910 | print('finish load graph!') 911 | wordTokenizer = WordTokenizer(file_path=args.outlinevocab_path) 912 | 913 | helper = Helper(args) 914 | 915 | if args.generate: 916 | generate(args) 917 | else: 918 | train(args) 919 | --------------------------------------------------------------------------------